Spaces:
Runtime error
Runtime error
Update inference.py
Browse files- inference.py +48 -75
inference.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
-
#inference.py
|
|
|
|
| 2 |
import shutil
|
| 3 |
import logging
|
| 4 |
from time import time
|
|
@@ -8,8 +9,7 @@ import pandas as pd
|
|
| 8 |
import cv2
|
| 9 |
from traceback import format_exc
|
| 10 |
from argparse import Namespace
|
| 11 |
-
from
|
| 12 |
-
from simple_parsing import ArgumentParser
|
| 13 |
import mediapipe as mp
|
| 14 |
from mediapipe.python.solutions.pose import PoseLandmark
|
| 15 |
from mediapipe.python.solutions.hands import HandLandmark
|
|
@@ -19,8 +19,16 @@ from visualization import draw_text_on_image
|
|
| 19 |
from configs import ModelConfig, InferenceConfig
|
| 20 |
from utils import config_logger, POSE_BASED_MODELS
|
| 21 |
from data import Arm, get_sample_timestamp, ok_to_get_frame
|
| 22 |
-
from tools import load_pipeline, Predictions
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
SPOTER_POSE_LANDMARKS = [
|
| 26 |
PoseLandmark.NOSE,
|
|
@@ -31,7 +39,8 @@ SPOTER_POSE_LANDMARKS = [
|
|
| 31 |
PoseLandmark.RIGHT_ELBOW,
|
| 32 |
PoseLandmark.LEFT_ELBOW,
|
| 33 |
PoseLandmark.RIGHT_WRIST,
|
| 34 |
-
PoseLandmark.LEFT_WRIST
|
|
|
|
| 35 |
|
| 36 |
SPOTER_HAND_LANDMARKS = [
|
| 37 |
HandLandmark.WRIST,
|
|
@@ -42,6 +51,7 @@ SPOTER_HAND_LANDMARKS = [
|
|
| 42 |
HandLandmark.THUMB_TIP, HandLandmark.THUMB_IP, HandLandmark.THUMB_MCP, HandLandmark.THUMB_CMC,
|
| 43 |
]
|
| 44 |
|
|
|
|
| 45 |
def get_args() -> Namespace:
|
| 46 |
parser = ArgumentParser(
|
| 47 |
description="Train a model on VSL",
|
|
@@ -52,13 +62,13 @@ def get_args() -> Namespace:
|
|
| 52 |
return parser.parse_args()
|
| 53 |
|
| 54 |
|
| 55 |
-
def inference(model_config, inference_config: InferenceConfig,
|
| 56 |
# Load video
|
| 57 |
-
source = str(inference_config.source) if inference_config.source.is_file() else 0
|
| 58 |
cap = cv2.VideoCapture(source)
|
| 59 |
if inference_config.output_dir is not None:
|
| 60 |
writer = cv2.VideoWriter(
|
| 61 |
-
str(inference_config.output_dir / "output.mp4"),
|
| 62 |
cv2.VideoWriter_fourcc(*"mp4v"),
|
| 63 |
cap.get(cv2.CAP_PROP_FPS),
|
| 64 |
(int(cap.get(3)), int(cap.get(4))),
|
|
@@ -69,7 +79,6 @@ def inference(model_config, inference_config: InferenceConfig, pipeline: Pipelin
|
|
| 69 |
mp_drawing = mp.solutions.drawing_utils
|
| 70 |
mp_drawing_styles = mp.solutions.drawing_styles
|
| 71 |
|
| 72 |
-
|
| 73 |
custom_pose_style = mp_drawing_styles.get_default_pose_landmarks_style()
|
| 74 |
custom_right_hand_style = mp_drawing_styles.get_default_hand_landmarks_style()
|
| 75 |
custom_left_hand_style = mp_drawing_styles.get_default_hand_landmarks_style()
|
|
@@ -77,30 +86,27 @@ def inference(model_config, inference_config: InferenceConfig, pipeline: Pipelin
|
|
| 77 |
custom_hand_connections = list(mp_holistic.HAND_CONNECTIONS)
|
| 78 |
|
| 79 |
if inference_config.show_skeleton:
|
| 80 |
-
#
|
| 81 |
pose_landmarks = SPOTER_POSE_LANDMARKS
|
| 82 |
hand_landmarks = SPOTER_HAND_LANDMARKS
|
| 83 |
|
| 84 |
for landmark in PoseLandmark:
|
| 85 |
if landmark in pose_landmarks:
|
| 86 |
-
custom_pose_style[landmark] = DrawingSpec(color=(0,255,0), thickness=2, circle_radius=2)
|
| 87 |
else:
|
| 88 |
-
custom_pose_style[landmark] = DrawingSpec(color=(0,0,0), thickness=0, circle_radius=0)
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
custom_pose_connections.remove(connection_tuple)
|
| 92 |
|
| 93 |
for landmark in HandLandmark:
|
| 94 |
if landmark in hand_landmarks:
|
| 95 |
-
custom_right_hand_style[landmark] = DrawingSpec(color=(0,0,255), thickness=2, circle_radius=2)
|
| 96 |
-
custom_left_hand_style[landmark] = DrawingSpec(color=(255,0,0), thickness=2, circle_radius=2)
|
| 97 |
else:
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
custom_hand_connections.remove(connection_tuple)
|
| 103 |
-
|
| 104 |
|
| 105 |
# Init variables
|
| 106 |
right_arm = Arm("right", inference_config.visibility)
|
|
@@ -155,19 +161,15 @@ def inference(model_config, inference_config: InferenceConfig, pipeline: Pipelin
|
|
| 155 |
if left_arm_ok_to_get_frame or right_arm_ok_to_get_frame:
|
| 156 |
# logging.info("Frame added to the list")
|
| 157 |
predictions = Predictions()
|
| 158 |
-
data.append(
|
| 159 |
|
| 160 |
# Calculate the start and end time of sign
|
| 161 |
start_time, end_time = get_sample_timestamp(left_arm, right_arm)
|
| 162 |
|
| 163 |
-
# Convert from
|
| 164 |
start_time /= 1_000
|
| 165 |
end_time /= 1_000
|
| 166 |
|
| 167 |
-
# logging.info(f"start_time: {start_time} - end_time: {end_time}")
|
| 168 |
-
# logging.info(f"\tLeft arm: {left_arm.start_time} - {left_arm.end_time} - {left_arm.is_up}")
|
| 169 |
-
# logging.info(f"\tRight arm: {right_arm.start_time} - {right_arm.end_time} - {right_arm.is_up}")
|
| 170 |
-
|
| 171 |
if start_time != 0 and end_time != 0:
|
| 172 |
# Render waiting screen
|
| 173 |
if inference_config.visualize:
|
|
@@ -183,7 +185,9 @@ def inference(model_config, inference_config: InferenceConfig, pipeline: Pipelin
|
|
| 183 |
break
|
| 184 |
|
| 185 |
start_inference_time = time()
|
| 186 |
-
|
|
|
|
|
|
|
| 187 |
predictions.inference_time = time() - start_inference_time
|
| 188 |
|
| 189 |
predictions.start_time = start_time
|
|
@@ -206,20 +210,23 @@ def inference(model_config, inference_config: InferenceConfig, pipeline: Pipelin
|
|
| 206 |
mp_drawing.draw_landmarks(
|
| 207 |
frame,
|
| 208 |
detection_results.pose_landmarks,
|
| 209 |
-
connections
|
| 210 |
-
landmark_drawing_spec=custom_pose_style
|
| 211 |
-
|
|
|
|
| 212 |
mp_drawing.draw_landmarks(
|
| 213 |
frame,
|
| 214 |
detection_results.right_hand_landmarks,
|
| 215 |
-
connections
|
| 216 |
-
landmark_drawing_spec=custom_right_hand_style
|
|
|
|
| 217 |
|
| 218 |
mp_drawing.draw_landmarks(
|
| 219 |
frame,
|
| 220 |
detection_results.left_hand_landmarks,
|
| 221 |
-
connections
|
| 222 |
-
landmark_drawing_spec=custom_left_hand_style
|
|
|
|
| 223 |
|
| 224 |
if inference_config.output_dir is not None:
|
| 225 |
writer.write(frame)
|
|
@@ -234,42 +241,8 @@ def inference(model_config, inference_config: InferenceConfig, pipeline: Pipelin
|
|
| 234 |
|
| 235 |
if inference_config.output_dir is not None:
|
| 236 |
writer.release()
|
| 237 |
-
logging.info(f"Video is recorded and saved to {inference_config.output_dir / 'output.
|
| 238 |
-
pd.DataFrame(results).to_csv(inference_config.output_dir / "results.csv", index=False)
|
| 239 |
-
logging.info(f"Results saved to {inference_config.output_dir / 'results.csv'}")
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
# inference.py
|
| 243 |
-
|
| 244 |
-
def main(args: Namespace) -> None:
|
| 245 |
-
model_config = args.model
|
| 246 |
-
logging.info(model_config)
|
| 247 |
-
inference_config = args.inference
|
| 248 |
-
logging.info(inference_config)
|
| 249 |
-
|
| 250 |
-
if model_config.arch in POSE_BASED_MODELS:
|
| 251 |
-
inference_config.use_pose_model = True
|
| 252 |
-
else:
|
| 253 |
-
inference_config.use_pose_model = False
|
| 254 |
-
|
| 255 |
-
pipeline_or_session = load_pipeline(model_config, inference_config)
|
| 256 |
-
logging.info("Pipeline loaded")
|
| 257 |
-
|
| 258 |
-
inference(model_config, inference_config, pipeline_or_session)
|
| 259 |
-
logging.info("Inference completed")
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
if __name__ == "__main__":
|
| 264 |
-
try:
|
| 265 |
-
args = get_args()
|
| 266 |
-
|
| 267 |
-
config_logger(args.inference.output_dir / "inference.log")
|
| 268 |
-
logging.info(f"Config file loaded from {args.config_path[0]}")
|
| 269 |
-
|
| 270 |
-
shutil.copy(args.config_path[0], args.inference.output_dir / "inference.yaml")
|
| 271 |
-
logging.info(f"Config file saved to {args.inference.output_dir}")
|
| 272 |
|
| 273 |
-
|
| 274 |
-
except Exception:
|
| 275 |
-
print(format_exc())
|
|
|
|
| 1 |
+
# inference.py
|
| 2 |
+
|
| 3 |
import shutil
|
| 4 |
import logging
|
| 5 |
from time import time
|
|
|
|
| 9 |
import cv2
|
| 10 |
from traceback import format_exc
|
| 11 |
from argparse import Namespace
|
| 12 |
+
from pydantic import BaseModel
|
|
|
|
| 13 |
import mediapipe as mp
|
| 14 |
from mediapipe.python.solutions.pose import PoseLandmark
|
| 15 |
from mediapipe.python.solutions.hands import HandLandmark
|
|
|
|
| 19 |
from configs import ModelConfig, InferenceConfig
|
| 20 |
from utils import config_logger, POSE_BASED_MODELS
|
| 21 |
from data import Arm, get_sample_timestamp, ok_to_get_frame
|
| 22 |
+
from tools.models import load_pipeline, get_predictions, Predictions
|
| 23 |
|
| 24 |
+
# Define id2gloss mapping
|
| 25 |
+
# NOTE: placeholder values; replace this mapping with your model's actual id-to-gloss mapping
|
| 26 |
+
id2gloss = {
|
| 27 |
+
"0": "hello",
|
| 28 |
+
"1": "thanks",
|
| 29 |
+
"2": "yes",
|
| 30 |
+
# Add the remaining mappings as needed
|
| 31 |
+
}
|
| 32 |
|
| 33 |
SPOTER_POSE_LANDMARKS = [
|
| 34 |
PoseLandmark.NOSE,
|
|
|
|
| 39 |
PoseLandmark.RIGHT_ELBOW,
|
| 40 |
PoseLandmark.LEFT_ELBOW,
|
| 41 |
PoseLandmark.RIGHT_WRIST,
|
| 42 |
+
PoseLandmark.LEFT_WRIST
|
| 43 |
+
]
|
| 44 |
|
| 45 |
SPOTER_HAND_LANDMARKS = [
|
| 46 |
HandLandmark.WRIST,
|
|
|
|
| 51 |
HandLandmark.THUMB_TIP, HandLandmark.THUMB_IP, HandLandmark.THUMB_MCP, HandLandmark.THUMB_CMC,
|
| 52 |
]
|
| 53 |
|
| 54 |
+
|
| 55 |
def get_args() -> Namespace:
|
| 56 |
parser = ArgumentParser(
|
| 57 |
description="Train a model on VSL",
|
|
|
|
| 62 |
return parser.parse_args()
|
| 63 |
|
| 64 |
|
| 65 |
+
def inference(model_config, inference_config: InferenceConfig, session: ort.InferenceSession) -> dict:
|
| 66 |
# Load video
|
| 67 |
+
source = str(inference_config.source) if Path(inference_config.source).is_file() else 0
|
| 68 |
cap = cv2.VideoCapture(source)
|
| 69 |
if inference_config.output_dir is not None:
|
| 70 |
writer = cv2.VideoWriter(
|
| 71 |
+
str(Path(inference_config.output_dir) / "output.mp4"),
|
| 72 |
cv2.VideoWriter_fourcc(*"mp4v"),
|
| 73 |
cap.get(cv2.CAP_PROP_FPS),
|
| 74 |
(int(cap.get(3)), int(cap.get(4))),
|
|
|
|
| 79 |
mp_drawing = mp.solutions.drawing_utils
|
| 80 |
mp_drawing_styles = mp.solutions.drawing_styles
|
| 81 |
|
|
|
|
| 82 |
custom_pose_style = mp_drawing_styles.get_default_pose_landmarks_style()
|
| 83 |
custom_right_hand_style = mp_drawing_styles.get_default_hand_landmarks_style()
|
| 84 |
custom_left_hand_style = mp_drawing_styles.get_default_hand_landmarks_style()
|
|
|
|
| 86 |
custom_hand_connections = list(mp_holistic.HAND_CONNECTIONS)
|
| 87 |
|
| 88 |
if inference_config.show_skeleton:
|
| 89 |
+
# Special formatting for 'spoter'
|
| 90 |
pose_landmarks = SPOTER_POSE_LANDMARKS
|
| 91 |
hand_landmarks = SPOTER_HAND_LANDMARKS
|
| 92 |
|
| 93 |
for landmark in PoseLandmark:
|
| 94 |
if landmark in pose_landmarks:
|
| 95 |
+
custom_pose_style[landmark] = DrawingSpec(color=(0, 255, 0), thickness=2, circle_radius=2)
|
| 96 |
else:
|
| 97 |
+
custom_pose_style[landmark] = DrawingSpec(color=(0, 0, 0), thickness=0, circle_radius=0)
|
| 98 |
+
# Remove the connections that involve this landmark
|
| 99 |
+
custom_pose_connections = [conn for conn in custom_pose_connections if landmark.value not in conn]
|
|
|
|
| 100 |
|
| 101 |
for landmark in HandLandmark:
|
| 102 |
if landmark in hand_landmarks:
|
| 103 |
+
custom_right_hand_style[landmark] = DrawingSpec(color=(0, 0, 255), thickness=2, circle_radius=2)
|
| 104 |
+
custom_left_hand_style[landmark] = DrawingSpec(color=(255, 0, 0), thickness=2, circle_radius=2)
|
| 105 |
else:
|
| 106 |
+
# Remove the connections that involve this landmark
|
| 107 |
+
custom_hand_connections = [conn for conn in custom_hand_connections if landmark.value not in conn]
|
| 108 |
+
custom_right_hand_style[landmark] = DrawingSpec(color=(0, 0, 0), thickness=0, circle_radius=0)
|
| 109 |
+
custom_left_hand_style[landmark] = DrawingSpec(color=(0, 0, 0), thickness=0, circle_radius=0)
|
|
|
|
|
|
|
| 110 |
|
| 111 |
# Init variables
|
| 112 |
right_arm = Arm("right", inference_config.visibility)
|
|
|
|
| 161 |
if left_arm_ok_to_get_frame or right_arm_ok_to_get_frame:
|
| 162 |
# logging.info("Frame added to the list")
|
| 163 |
predictions = Predictions()
|
| 164 |
+
data.append(frame) # Only the raw frame is stored because the .onnx model is being used
|
| 165 |
|
| 166 |
# Calculate the start and end time of sign
|
| 167 |
start_time, end_time = get_sample_timestamp(left_arm, right_arm)
|
| 168 |
|
| 169 |
+
# Convert from milliseconds to seconds
|
| 170 |
start_time /= 1_000
|
| 171 |
end_time /= 1_000
|
| 172 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
if start_time != 0 and end_time != 0:
|
| 174 |
# Render waiting screen
|
| 175 |
if inference_config.visualize:
|
|
|
|
| 185 |
break
|
| 186 |
|
| 187 |
start_inference_time = time()
|
| 188 |
+
# Convert data into an np.ndarray suitable for the ONNX model
|
| 189 |
+
data_np = np.stack(data, axis=0) # Assumes the model performs recognition on a batch -- TODO confirm
|
| 190 |
+
predictions = get_predictions(data_np, session, id2gloss=id2gloss, k=inference_config.top_k)
|
| 191 |
predictions.inference_time = time() - start_inference_time
|
| 192 |
|
| 193 |
predictions.start_time = start_time
|
|
|
|
| 210 |
mp_drawing.draw_landmarks(
|
| 211 |
frame,
|
| 212 |
detection_results.pose_landmarks,
|
| 213 |
+
connections=custom_pose_connections,
|
| 214 |
+
landmark_drawing_spec=custom_pose_style
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
mp_drawing.draw_landmarks(
|
| 218 |
frame,
|
| 219 |
detection_results.right_hand_landmarks,
|
| 220 |
+
connections=custom_hand_connections,
|
| 221 |
+
landmark_drawing_spec=custom_right_hand_style
|
| 222 |
+
)
|
| 223 |
|
| 224 |
mp_drawing.draw_landmarks(
|
| 225 |
frame,
|
| 226 |
detection_results.left_hand_landmarks,
|
| 227 |
+
connections=custom_hand_connections,
|
| 228 |
+
landmark_drawing_spec=custom_left_hand_style
|
| 229 |
+
)
|
| 230 |
|
| 231 |
if inference_config.output_dir is not None:
|
| 232 |
writer.write(frame)
|
|
|
|
| 241 |
|
| 242 |
if inference_config.output_dir is not None:
|
| 243 |
writer.release()
|
| 244 |
+
logging.info(f"Video is recorded and saved to {Path(inference_config.output_dir) / 'output.mp4'}")
|
| 245 |
+
pd.DataFrame(results).to_csv(Path(inference_config.output_dir) / "results.csv", index=False)
|
| 246 |
+
logging.info(f"Results saved to {Path(inference_config.output_dir) / 'results.csv'}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
|
| 248 |
+
return {"results": results}
|
|
|
|
|
|