# NOTE: stray export artifacts ("Spaces:" / "Runtime error") removed — they were
# not part of the original source file.
# inference.py
import logging
import shutil
from argparse import Namespace
from pathlib import Path
from time import time
from traceback import format_exc

import cv2
import mediapipe as mp
import numpy as np
import pandas as pd
from mediapipe.python.solutions.drawing_utils import DrawingSpec
from mediapipe.python.solutions.hands import HandLandmark
from mediapipe.python.solutions.pose import PoseLandmark
from pydantic import BaseModel

from configs import ModelConfig, InferenceConfig
from data import Arm, get_sample_timestamp, ok_to_get_frame
from tools.models import load_pipeline, get_predictions, Predictions
from utils import config_logger, POSE_BASED_MODELS
from visualization import draw_text_on_image

# TODO(review): this module also references `ArgumentParser` (apparently
# simple_parsing's, given `add_arguments`/`add_config_path_arg`) and `ort`
# (onnxruntime) without importing them. Add the missing imports:
#   from simple_parsing import ArgumentParser
#   import onnxruntime as ort
# Define id2gloss mapping: class id (string key) -> gloss label, consumed by
# `get_predictions` to turn model outputs into readable labels.
# NOTE(review): placeholder — replace this mapping with the real one for your model.
id2gloss = {
    "0": "hello",
    "1": "thanks",
    "2": "yes",
    # Add the remaining id -> gloss entries as needed
}
# Upper-body pose landmarks kept visible when drawing the reduced "SPOTER"
# skeleton; all other PoseLandmark entries are styled invisible in `inference`.
SPOTER_POSE_LANDMARKS = [
    PoseLandmark.NOSE,
    PoseLandmark.LEFT_EYE,
    PoseLandmark.RIGHT_EYE,
    PoseLandmark.RIGHT_SHOULDER,
    PoseLandmark.LEFT_SHOULDER,
    PoseLandmark.RIGHT_ELBOW,
    PoseLandmark.LEFT_ELBOW,
    PoseLandmark.RIGHT_WRIST,
    PoseLandmark.LEFT_WRIST
]
# All 21 MediaPipe hand landmarks (wrist + 4 joints per finger/thumb) kept
# visible for the "SPOTER" skeleton overlay; grouped one finger per line.
SPOTER_HAND_LANDMARKS = [
    HandLandmark.WRIST,
    HandLandmark.INDEX_FINGER_TIP, HandLandmark.INDEX_FINGER_DIP, HandLandmark.INDEX_FINGER_PIP, HandLandmark.INDEX_FINGER_MCP,
    HandLandmark.MIDDLE_FINGER_TIP, HandLandmark.MIDDLE_FINGER_DIP, HandLandmark.MIDDLE_FINGER_PIP, HandLandmark.MIDDLE_FINGER_MCP,
    HandLandmark.RING_FINGER_TIP, HandLandmark.RING_FINGER_DIP, HandLandmark.RING_FINGER_PIP, HandLandmark.RING_FINGER_MCP,
    HandLandmark.PINKY_TIP, HandLandmark.PINKY_DIP, HandLandmark.PINKY_PIP, HandLandmark.PINKY_MCP,
    HandLandmark.THUMB_TIP, HandLandmark.THUMB_IP, HandLandmark.THUMB_MCP, HandLandmark.THUMB_CMC,
]
def get_args() -> Namespace:
    """Parse the model and inference configurations from the command line.

    Returns:
        Namespace with two dataclass attributes: ``model`` (ModelConfig)
        and ``inference`` (InferenceConfig).

    NOTE(review): ``ArgumentParser`` is never imported in this file; given the
    ``add_arguments``/``add_config_path_arg`` API it is presumably
    ``simple_parsing.ArgumentParser`` — confirm and add the import.
    """
    parser = ArgumentParser(
        # Fixed: the description said "Train a model on VSL" — a copy-paste
        # from the training script; this entry point runs inference.
        description="Run inference on VSL",
        add_config_path_arg=True,
    )
    parser.add_arguments(ModelConfig, "model")
    parser.add_arguments(InferenceConfig, "inference")
    return parser.parse_args()
# NOTE(review): the `ort.InferenceSession` annotation is quoted below because
# `ort` (onnxruntime) is never imported in this file — unquoted it would raise
# NameError at `def` time. Add `import onnxruntime as ort` and unquote.
def inference(model_config, inference_config: InferenceConfig, session: "ort.InferenceSession") -> dict:
    """Run the sign-recognition loop over a video file or webcam stream.

    Per frame: run MediaPipe Holistic, track both arms' up/down state, buffer
    frames while an arm is raised, and once a sign's start/end timestamps are
    available run the ONNX session on the buffered frames and accumulate the
    predictions. Optionally draws skeleton/prediction overlays, shows a live
    window, and records annotated video + a CSV of results to `output_dir`.

    Args:
        model_config: unused in this function body — presumably kept for a
            uniform entry-point signature; verify before removing.
        inference_config: run options (source, output_dir, visualize,
            show_skeleton, thresholds, delay, top_k, visibility).
        session: ONNX Runtime session consumed by `get_predictions`.

    Returns:
        dict with a single key "results": the merged predictions accumulated
        by `Predictions.merge_results` (None if no sign was ever detected).
    """
    # Load video: a real file path is used as-is, anything else falls back to
    # webcam device 0.
    source = str(inference_config.source) if Path(inference_config.source).is_file() else 0
    cap = cv2.VideoCapture(source)
    if inference_config.output_dir is not None:
        writer = cv2.VideoWriter(
            str(Path(inference_config.output_dir) / "output.mp4"),
            cv2.VideoWriter_fourcc(*"mp4v"),
            cap.get(cv2.CAP_PROP_FPS),
            # props 3/4 are CAP_PROP_FRAME_WIDTH / CAP_PROP_FRAME_HEIGHT
            (int(cap.get(3)), int(cap.get(4))),
        )
    # Init Mediapipe
    mp_holistic = mp.solutions.holistic
    mp_drawing = mp.solutions.drawing_utils
    mp_drawing_styles = mp.solutions.drawing_styles
    custom_pose_style = mp_drawing_styles.get_default_pose_landmarks_style()
    custom_right_hand_style = mp_drawing_styles.get_default_hand_landmarks_style()
    custom_left_hand_style = mp_drawing_styles.get_default_hand_landmarks_style()
    custom_pose_connections = list(mp_holistic.POSE_CONNECTIONS)
    custom_hand_connections = list(mp_holistic.HAND_CONNECTIONS)
    if inference_config.show_skeleton:
        # Special reduced styling for the 'SPOTER' landmark subset: only the
        # listed landmarks stay visible, everything else is drawn invisible.
        pose_landmarks = SPOTER_POSE_LANDMARKS
        hand_landmarks = SPOTER_HAND_LANDMARKS
        for landmark in PoseLandmark:
            if landmark in pose_landmarks:
                # kept landmarks: green dots
                custom_pose_style[landmark] = DrawingSpec(color=(0, 255, 0), thickness=2, circle_radius=2)
            else:
                custom_pose_style[landmark] = DrawingSpec(color=(0, 0, 0), thickness=0, circle_radius=0)
                # Remove any connection that touches this hidden landmark
                custom_pose_connections = [conn for conn in custom_pose_connections if landmark.value not in conn]
        for landmark in HandLandmark:
            if landmark in hand_landmarks:
                # right hand red (BGR), left hand blue
                custom_right_hand_style[landmark] = DrawingSpec(color=(0, 0, 255), thickness=2, circle_radius=2)
                custom_left_hand_style[landmark] = DrawingSpec(color=(255, 0, 0), thickness=2, circle_radius=2)
            else:
                # Remove any connection that touches this hidden landmark
                custom_hand_connections = [conn for conn in custom_hand_connections if landmark.value not in conn]
                custom_right_hand_style[landmark] = DrawingSpec(color=(0, 0, 0), thickness=0, circle_radius=0)
                custom_left_hand_style[landmark] = DrawingSpec(color=(0, 0, 0), thickness=0, circle_radius=0)
    # Init variables
    right_arm = Arm("right", inference_config.visibility)
    left_arm = Arm("left", inference_config.visibility)
    data = []           # frames buffered while a sign is being performed
    results = None      # accumulated merged predictions across all signs
    predictions = Predictions()
    with mp_holistic.Holistic(min_detection_confidence=0.9, min_tracking_confidence=0.5) as holistic:
        while cap.isOpened():
            success, frame = cap.read()
            if not success:
                break
            # Recolor image to RGB, because mp processes on RGB image
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame.flags.writeable = False
            # Make detections
            detection_results = holistic.process(frame)
            # Recolor image back to BGR, because cv2 processes on BGR image
            frame.flags.writeable = True
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
            # Extract landmarks; skip the frame entirely when no pose was
            # detected (pose_landmarks is None -> AttributeError).
            try:
                landmarks = detection_results.pose_landmarks.landmark
            except Exception:
                continue
            left_arm.set_pose(landmarks)
            right_arm.set_pose(landmarks)
            # Check if arms are up or down
            left_arm_ok_to_get_frame = ok_to_get_frame(
                arm=left_arm,
                angle_threshold=inference_config.angle_threshold,
                min_num_up_frames=inference_config.min_num_up_frames,
                min_num_down_frames=inference_config.min_num_down_frames,
                current_time=cap.get(cv2.CAP_PROP_POS_MSEC),
                delay=inference_config.delay,
            )
            right_arm_ok_to_get_frame = ok_to_get_frame(
                arm=right_arm,
                angle_threshold=inference_config.angle_threshold,
                min_num_up_frames=inference_config.min_num_up_frames,
                min_num_down_frames=inference_config.min_num_down_frames,
                current_time=cap.get(cv2.CAP_PROP_POS_MSEC),
                delay=inference_config.delay,
            )
            if left_arm_ok_to_get_frame or right_arm_ok_to_get_frame:
                # logging.info("Frame added to the list")
                predictions = Predictions()
                # Only the raw frame is buffered because the .onnx model
                # consumes frames directly (no keypoint extraction here).
                data.append(frame)
            # Calculate the start and end time of sign
            start_time, end_time = get_sample_timestamp(left_arm, right_arm)
            # Convert from milliseconds to seconds
            start_time /= 1_000
            end_time /= 1_000
            # Nonzero start AND end => a complete sign was captured; run the model.
            if start_time != 0 and end_time != 0:
                # Render waiting screen
                if inference_config.visualize:
                    wait_frame = draw_text_on_image(
                        np.zeros_like(frame),
                        text="Please wait for the prediction...",
                        position=(20, 20),
                        color=(255, 255, 255),
                        font_size=20,
                    )
                    cv2.imshow("Video Visualization", wait_frame)
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break
                start_inference_time = time()
                # Convert the buffered frames into one np.ndarray for the ONNX
                # model — presumably it expects a batch along axis 0; verify
                # against the model's input signature.
                data_np = np.stack(data, axis=0)
                predictions = get_predictions(data_np, session, id2gloss=id2gloss, k=inference_config.top_k)
                predictions.inference_time = time() - start_inference_time
                predictions.start_time = start_time
                predictions.end_time = end_time
                logging.info(str(predictions))
                results = predictions.merge_results(results)
                # Reset variables for the next sign
                start_time = 0
                end_time = 0
                left_arm.reset_state()
                right_arm.reset_state()
                data = []
            # Render detections
            frame = left_arm.visualize(frame, (20, 10), "Left arm angle")
            frame = right_arm.visualize(frame, (20, 40), "Right arm angle")
            frame = predictions.visualize(frame, (20, 70))
            if inference_config.show_skeleton:
                mp_drawing.draw_landmarks(
                    frame,
                    detection_results.pose_landmarks,
                    connections=custom_pose_connections,
                    landmark_drawing_spec=custom_pose_style
                )
                mp_drawing.draw_landmarks(
                    frame,
                    detection_results.right_hand_landmarks,
                    connections=custom_hand_connections,
                    landmark_drawing_spec=custom_right_hand_style
                )
                mp_drawing.draw_landmarks(
                    frame,
                    detection_results.left_hand_landmarks,
                    connections=custom_hand_connections,
                    landmark_drawing_spec=custom_left_hand_style
                )
            if inference_config.output_dir is not None:
                writer.write(frame)
            if inference_config.visualize:
                cv2.imshow("Video Visualization", frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
    cap.release()
    cv2.destroyAllWindows()
    if inference_config.output_dir is not None:
        writer.release()
        logging.info(f"Video is recorded and saved to {Path(inference_config.output_dir) / 'output.mp4'}")
        pd.DataFrame(results).to_csv(Path(inference_config.output_dir) / "results.csv", index=False)
        logging.info(f"Results saved to {Path(inference_config.output_dir) / 'results.csv'}")
    return {"results": results}