thienphuc12339 committed on
Commit
0bfec51
·
verified ·
1 Parent(s): 74050d9

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +48 -75
inference.py CHANGED
@@ -1,4 +1,5 @@
1
- #inference.py
 
2
  import shutil
3
  import logging
4
  from time import time
@@ -8,8 +9,7 @@ import pandas as pd
8
  import cv2
9
  from traceback import format_exc
10
  from argparse import Namespace
11
- from transformers import Pipeline
12
- from simple_parsing import ArgumentParser
13
  import mediapipe as mp
14
  from mediapipe.python.solutions.pose import PoseLandmark
15
  from mediapipe.python.solutions.hands import HandLandmark
@@ -19,8 +19,16 @@ from visualization import draw_text_on_image
19
  from configs import ModelConfig, InferenceConfig
20
  from utils import config_logger, POSE_BASED_MODELS
21
  from data import Arm, get_sample_timestamp, ok_to_get_frame
22
- from tools import load_pipeline, Predictions
23
 
 
 
 
 
 
 
 
 
24
 
25
  SPOTER_POSE_LANDMARKS = [
26
  PoseLandmark.NOSE,
@@ -31,7 +39,8 @@ SPOTER_POSE_LANDMARKS = [
31
  PoseLandmark.RIGHT_ELBOW,
32
  PoseLandmark.LEFT_ELBOW,
33
  PoseLandmark.RIGHT_WRIST,
34
- PoseLandmark.LEFT_WRIST ]
 
35
 
36
  SPOTER_HAND_LANDMARKS = [
37
  HandLandmark.WRIST,
@@ -42,6 +51,7 @@ SPOTER_HAND_LANDMARKS = [
42
  HandLandmark.THUMB_TIP, HandLandmark.THUMB_IP, HandLandmark.THUMB_MCP, HandLandmark.THUMB_CMC,
43
  ]
44
 
 
45
  def get_args() -> Namespace:
46
  parser = ArgumentParser(
47
  description="Train a model on VSL",
@@ -52,13 +62,13 @@ def get_args() -> Namespace:
52
  return parser.parse_args()
53
 
54
 
55
- def inference(model_config, inference_config: InferenceConfig, pipeline: Pipeline) -> None:
56
  # Load video
57
- source = str(inference_config.source) if inference_config.source.is_file() else 0
58
  cap = cv2.VideoCapture(source)
59
  if inference_config.output_dir is not None:
60
  writer = cv2.VideoWriter(
61
- str(inference_config.output_dir / "output.mp4"),
62
  cv2.VideoWriter_fourcc(*"mp4v"),
63
  cap.get(cv2.CAP_PROP_FPS),
64
  (int(cap.get(3)), int(cap.get(4))),
@@ -69,7 +79,6 @@ def inference(model_config, inference_config: InferenceConfig, pipeline: Pipelin
69
  mp_drawing = mp.solutions.drawing_utils
70
  mp_drawing_styles = mp.solutions.drawing_styles
71
 
72
-
73
  custom_pose_style = mp_drawing_styles.get_default_pose_landmarks_style()
74
  custom_right_hand_style = mp_drawing_styles.get_default_hand_landmarks_style()
75
  custom_left_hand_style = mp_drawing_styles.get_default_hand_landmarks_style()
@@ -77,30 +86,27 @@ def inference(model_config, inference_config: InferenceConfig, pipeline: Pipelin
77
  custom_hand_connections = list(mp_holistic.HAND_CONNECTIONS)
78
 
79
  if inference_config.show_skeleton:
80
- # if model_config.arch == 'spoter':
81
  pose_landmarks = SPOTER_POSE_LANDMARKS
82
  hand_landmarks = SPOTER_HAND_LANDMARKS
83
 
84
  for landmark in PoseLandmark:
85
  if landmark in pose_landmarks:
86
- custom_pose_style[landmark] = DrawingSpec(color=(0,255,0), thickness=2, circle_radius=2)
87
  else:
88
- custom_pose_style[landmark] = DrawingSpec(color=(0,0,0), thickness=0, circle_radius=0)
89
- for connection_tuple in custom_pose_connections:
90
- if landmark.value in connection_tuple:
91
- custom_pose_connections.remove(connection_tuple)
92
 
93
  for landmark in HandLandmark:
94
  if landmark in hand_landmarks:
95
- custom_right_hand_style[landmark] = DrawingSpec(color=(0,0,255), thickness=2, circle_radius=2)
96
- custom_left_hand_style[landmark] = DrawingSpec(color=(255,0,0), thickness=2, circle_radius=2)
97
  else:
98
- custom_right_hand_style[HandLandmark[landmark.name]] = DrawingSpec(color=(0,0,0), thickness=0, circle_radius=0)
99
- custom_left_hand_style[HandLandmark[landmark.name]] = DrawingSpec(color=(0,0,0), thickness=0, circle_radius=0)
100
- for connection_tuple in custom_hand_connections:
101
- if landmark.value in connection_tuple:
102
- custom_hand_connections.remove(connection_tuple)
103
-
104
 
105
  # Init variables
106
  right_arm = Arm("right", inference_config.visibility)
@@ -155,19 +161,15 @@ def inference(model_config, inference_config: InferenceConfig, pipeline: Pipelin
155
  if left_arm_ok_to_get_frame or right_arm_ok_to_get_frame:
156
  # logging.info("Frame added to the list")
157
  predictions = Predictions()
158
- data.append(detection_results if inference_config.use_pose_model else frame)
159
 
160
  # Calculate the start and end time of sign
161
  start_time, end_time = get_sample_timestamp(left_arm, right_arm)
162
 
163
- # Convert from miliseconds to seconds
164
  start_time /= 1_000
165
  end_time /= 1_000
166
 
167
- # logging.info(f"start_time: {start_time} - end_time: {end_time}")
168
- # logging.info(f"\tLeft arm: {left_arm.start_time} - {left_arm.end_time} - {left_arm.is_up}")
169
- # logging.info(f"\tRight arm: {right_arm.start_time} - {right_arm.end_time} - {right_arm.is_up}")
170
-
171
  if start_time != 0 and end_time != 0:
172
  # Render waiting screen
173
  if inference_config.visualize:
@@ -183,7 +185,9 @@ def inference(model_config, inference_config: InferenceConfig, pipeline: Pipelin
183
  break
184
 
185
  start_inference_time = time()
186
- predictions = Predictions(predictions=pipeline(np.array(data)))
 
 
187
  predictions.inference_time = time() - start_inference_time
188
 
189
  predictions.start_time = start_time
@@ -206,20 +210,23 @@ def inference(model_config, inference_config: InferenceConfig, pipeline: Pipelin
206
  mp_drawing.draw_landmarks(
207
  frame,
208
  detection_results.pose_landmarks,
209
- connections = custom_pose_connections, # passing the modified connections list
210
- landmark_drawing_spec=custom_pose_style) # and drawing style
211
-
 
212
  mp_drawing.draw_landmarks(
213
  frame,
214
  detection_results.right_hand_landmarks,
215
- connections = custom_hand_connections, # passing the modified connections list
216
- landmark_drawing_spec=custom_right_hand_style) # and drawing style
 
217
 
218
  mp_drawing.draw_landmarks(
219
  frame,
220
  detection_results.left_hand_landmarks,
221
- connections = custom_hand_connections, # passing the modified connections list
222
- landmark_drawing_spec=custom_left_hand_style) # and drawing style
 
223
 
224
  if inference_config.output_dir is not None:
225
  writer.write(frame)
@@ -234,42 +241,8 @@ def inference(model_config, inference_config: InferenceConfig, pipeline: Pipelin
234
 
235
  if inference_config.output_dir is not None:
236
  writer.release()
237
- logging.info(f"Video is recorded and saved to {inference_config.output_dir / 'output.avi'}")
238
- pd.DataFrame(results).to_csv(inference_config.output_dir / "results.csv", index=False)
239
- logging.info(f"Results saved to {inference_config.output_dir / 'results.csv'}")
240
-
241
-
242
- # inference.py
243
-
244
- def main(args: Namespace) -> None:
245
- model_config = args.model
246
- logging.info(model_config)
247
- inference_config = args.inference
248
- logging.info(inference_config)
249
-
250
- if model_config.arch in POSE_BASED_MODELS:
251
- inference_config.use_pose_model = True
252
- else:
253
- inference_config.use_pose_model = False
254
-
255
- pipeline_or_session = load_pipeline(model_config, inference_config)
256
- logging.info("Pipeline loaded")
257
-
258
- inference(model_config, inference_config, pipeline_or_session)
259
- logging.info("Inference completed")
260
-
261
-
262
-
263
- if __name__ == "__main__":
264
- try:
265
- args = get_args()
266
-
267
- config_logger(args.inference.output_dir / "inference.log")
268
- logging.info(f"Config file loaded from {args.config_path[0]}")
269
-
270
- shutil.copy(args.config_path[0], args.inference.output_dir / "inference.yaml")
271
- logging.info(f"Config file saved to {args.inference.output_dir}")
272
 
273
- main(args=args)
274
- except Exception:
275
- print(format_exc())
 
1
+ # inference.py
2
+
3
  import shutil
4
  import logging
5
  from time import time
 
9
  import cv2
10
  from traceback import format_exc
11
  from argparse import Namespace
12
+ from pydantic import BaseModel
 
13
  import mediapipe as mp
14
  from mediapipe.python.solutions.pose import PoseLandmark
15
  from mediapipe.python.solutions.hands import HandLandmark
 
19
  from configs import ModelConfig, InferenceConfig
20
  from utils import config_logger, POSE_BASED_MODELS
21
  from data import Arm, get_sample_timestamp, ok_to_get_frame
22
+ from tools.models import load_pipeline, get_predictions, Predictions
23
 
24
+ # Define id2gloss mapping
25
+ # Bạn cần thay thế bản đồ này với bản đồ thực tế của bạn
26
+ id2gloss = {
27
+ "0": "hello",
28
+ "1": "thanks",
29
+ "2": "yes",
30
+ # Thêm các ánh xạ cần thiết
31
+ }
32
 
33
  SPOTER_POSE_LANDMARKS = [
34
  PoseLandmark.NOSE,
 
39
  PoseLandmark.RIGHT_ELBOW,
40
  PoseLandmark.LEFT_ELBOW,
41
  PoseLandmark.RIGHT_WRIST,
42
+ PoseLandmark.LEFT_WRIST
43
+ ]
44
 
45
  SPOTER_HAND_LANDMARKS = [
46
  HandLandmark.WRIST,
 
51
  HandLandmark.THUMB_TIP, HandLandmark.THUMB_IP, HandLandmark.THUMB_MCP, HandLandmark.THUMB_CMC,
52
  ]
53
 
54
+
55
  def get_args() -> Namespace:
56
  parser = ArgumentParser(
57
  description="Train a model on VSL",
 
62
  return parser.parse_args()
63
 
64
 
65
+ def inference(model_config, inference_config: InferenceConfig, session: ort.InferenceSession) -> dict:
66
  # Load video
67
+ source = str(inference_config.source) if Path(inference_config.source).is_file() else 0
68
  cap = cv2.VideoCapture(source)
69
  if inference_config.output_dir is not None:
70
  writer = cv2.VideoWriter(
71
+ str(Path(inference_config.output_dir) / "output.mp4"),
72
  cv2.VideoWriter_fourcc(*"mp4v"),
73
  cap.get(cv2.CAP_PROP_FPS),
74
  (int(cap.get(3)), int(cap.get(4))),
 
79
  mp_drawing = mp.solutions.drawing_utils
80
  mp_drawing_styles = mp.solutions.drawing_styles
81
 
 
82
  custom_pose_style = mp_drawing_styles.get_default_pose_landmarks_style()
83
  custom_right_hand_style = mp_drawing_styles.get_default_hand_landmarks_style()
84
  custom_left_hand_style = mp_drawing_styles.get_default_hand_landmarks_style()
 
86
  custom_hand_connections = list(mp_holistic.HAND_CONNECTIONS)
87
 
88
  if inference_config.show_skeleton:
89
+ # Định dạng đặc biệt cho 'spoter'
90
  pose_landmarks = SPOTER_POSE_LANDMARKS
91
  hand_landmarks = SPOTER_HAND_LANDMARKS
92
 
93
  for landmark in PoseLandmark:
94
  if landmark in pose_landmarks:
95
+ custom_pose_style[landmark] = DrawingSpec(color=(0, 255, 0), thickness=2, circle_radius=2)
96
  else:
97
+ custom_pose_style[landmark] = DrawingSpec(color=(0, 0, 0), thickness=0, circle_radius=0)
98
+ # Loại bỏ các kết nối liên quan
99
+ custom_pose_connections = [conn for conn in custom_pose_connections if landmark.value not in conn]
 
100
 
101
  for landmark in HandLandmark:
102
  if landmark in hand_landmarks:
103
+ custom_right_hand_style[landmark] = DrawingSpec(color=(0, 0, 255), thickness=2, circle_radius=2)
104
+ custom_left_hand_style[landmark] = DrawingSpec(color=(255, 0, 0), thickness=2, circle_radius=2)
105
  else:
106
+ # Loại bỏ các kết nối liên quan
107
+ custom_hand_connections = [conn for conn in custom_hand_connections if landmark.value not in conn]
108
+ custom_right_hand_style[landmark] = DrawingSpec(color=(0, 0, 0), thickness=0, circle_radius=0)
109
+ custom_left_hand_style[landmark] = DrawingSpec(color=(0, 0, 0), thickness=0, circle_radius=0)
 
 
110
 
111
  # Init variables
112
  right_arm = Arm("right", inference_config.visibility)
 
161
  if left_arm_ok_to_get_frame or right_arm_ok_to_get_frame:
162
  # logging.info("Frame added to the list")
163
  predictions = Predictions()
164
+ data.append(frame) # Chỉ sử dụng frame vì bạn đang dùng .onnx
165
 
166
  # Calculate the start and end time of sign
167
  start_time, end_time = get_sample_timestamp(left_arm, right_arm)
168
 
169
+ # Convert from milliseconds to seconds
170
  start_time /= 1_000
171
  end_time /= 1_000
172
 
 
 
 
 
173
  if start_time != 0 and end_time != 0:
174
  # Render waiting screen
175
  if inference_config.visualize:
 
185
  break
186
 
187
  start_inference_time = time()
188
+ # Chuyển data thành np.ndarray phù hợp với mô hình ONNX
189
+ data_np = np.stack(data, axis=0) # Giả sử mô hình nhận dạng theo batch
190
+ predictions = get_predictions(data_np, session, id2gloss=id2gloss, k=inference_config.top_k)
191
  predictions.inference_time = time() - start_inference_time
192
 
193
  predictions.start_time = start_time
 
210
  mp_drawing.draw_landmarks(
211
  frame,
212
  detection_results.pose_landmarks,
213
+ connections=custom_pose_connections,
214
+ landmark_drawing_spec=custom_pose_style
215
+ )
216
+
217
  mp_drawing.draw_landmarks(
218
  frame,
219
  detection_results.right_hand_landmarks,
220
+ connections=custom_hand_connections,
221
+ landmark_drawing_spec=custom_right_hand_style
222
+ )
223
 
224
  mp_drawing.draw_landmarks(
225
  frame,
226
  detection_results.left_hand_landmarks,
227
+ connections=custom_hand_connections,
228
+ landmark_drawing_spec=custom_left_hand_style
229
+ )
230
 
231
  if inference_config.output_dir is not None:
232
  writer.write(frame)
 
241
 
242
  if inference_config.output_dir is not None:
243
  writer.release()
244
+ logging.info(f"Video is recorded and saved to {Path(inference_config.output_dir) / 'output.mp4'}")
245
+ pd.DataFrame(results).to_csv(Path(inference_config.output_dir) / "results.csv", index=False)
246
+ logging.info(f"Results saved to {Path(inference_config.output_dir) / 'results.csv'}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
 
248
+ return {"results": results}