# inference.py
"""Sign-language inference on a video file or webcam stream.

Frames are buffered while at least one arm is raised (tracked with
Mediapipe Holistic); when a complete sign is detected, the buffered
frames are batched and sent to an ONNX model for gloss prediction.
"""
import logging
import shutil
from argparse import Namespace
from pathlib import Path
from time import time
from traceback import format_exc

import cv2
import mediapipe as mp
import numpy as np
import pandas as pd
from mediapipe.python.solutions.drawing_utils import DrawingSpec
from mediapipe.python.solutions.hands import HandLandmark
from mediapipe.python.solutions.pose import PoseLandmark
from pydantic import BaseModel

from configs import ModelConfig, InferenceConfig
from data import Arm, get_sample_timestamp, ok_to_get_frame
from tools.models import load_pipeline, get_predictions, Predictions
from utils import config_logger, POSE_BASED_MODELS
from visualization import draw_text_on_image

# NOTE(review): two names used below are never imported in this file:
#   - `ort` (the `session: ort.InferenceSession` annotation) — add
#     `import onnxruntime as ort` at the call site / here.
#   - `ArgumentParser` in get_args() — given `add_config_path_arg` and
#     `add_arguments(dataclass, dest)` this is presumably
#     `simple_parsing.ArgumentParser`; TODO confirm and add the import.

# Mapping from the model's class id to its gloss label.
# Replace this mapping with your actual one.
id2gloss = {
    "0": "hello",
    "1": "thanks",
    "2": "yes",
    # Add the remaining mappings as needed.
}

# Subset of Mediapipe pose landmarks drawn for the SPOTER-style skeleton.
SPOTER_POSE_LANDMARKS = [
    PoseLandmark.NOSE,
    PoseLandmark.LEFT_EYE,
    PoseLandmark.RIGHT_EYE,
    PoseLandmark.RIGHT_SHOULDER,
    PoseLandmark.LEFT_SHOULDER,
    PoseLandmark.RIGHT_ELBOW,
    PoseLandmark.LEFT_ELBOW,
    PoseLandmark.RIGHT_WRIST,
    PoseLandmark.LEFT_WRIST,
]

# Subset of Mediapipe hand landmarks drawn for the SPOTER-style skeleton.
SPOTER_HAND_LANDMARKS = [
    HandLandmark.WRIST,
    HandLandmark.INDEX_FINGER_TIP,
    HandLandmark.INDEX_FINGER_DIP,
    HandLandmark.INDEX_FINGER_PIP,
    HandLandmark.INDEX_FINGER_MCP,
    HandLandmark.MIDDLE_FINGER_TIP,
    HandLandmark.MIDDLE_FINGER_DIP,
    HandLandmark.MIDDLE_FINGER_PIP,
    HandLandmark.MIDDLE_FINGER_MCP,
    HandLandmark.RING_FINGER_TIP,
    HandLandmark.RING_FINGER_DIP,
    HandLandmark.RING_FINGER_PIP,
    HandLandmark.RING_FINGER_MCP,
    HandLandmark.PINKY_TIP,
    HandLandmark.PINKY_DIP,
    HandLandmark.PINKY_PIP,
    HandLandmark.PINKY_MCP,
    HandLandmark.THUMB_TIP,
    HandLandmark.THUMB_IP,
    HandLandmark.THUMB_MCP,
    HandLandmark.THUMB_CMC,
]


def get_args() -> Namespace:
    """Parse ModelConfig and InferenceConfig from the command line.

    NOTE(review): ``ArgumentParser`` is not imported in this file —
    see the note near the imports above.
    """
    parser = ArgumentParser(
        description="Run inference on VSL",  # was "Train a model on VSL" — wrong for an inference script
        add_config_path_arg=True,
    )
    parser.add_arguments(ModelConfig, "model")
    parser.add_arguments(InferenceConfig, "inference")
    return parser.parse_args()


def _restrict_to_spoter(pose_style, right_hand_style, left_hand_style,
                        pose_connections, hand_connections):
    """Restrict the skeleton overlay to the SPOTER landmark subset.

    Mutates the three style dicts in place (visible landmarks get a
    colored spec, hidden ones a zero-size spec) and returns the pose and
    hand connection lists filtered so no edge touches a hidden landmark.
    """
    for landmark in PoseLandmark:
        if landmark in SPOTER_POSE_LANDMARKS:
            pose_style[landmark] = DrawingSpec(color=(0, 255, 0), thickness=2, circle_radius=2)
        else:
            pose_style[landmark] = DrawingSpec(color=(0, 0, 0), thickness=0, circle_radius=0)
            # Remove every connection touching this hidden landmark.
            pose_connections = [conn for conn in pose_connections if landmark.value not in conn]
    for landmark in HandLandmark:
        if landmark in SPOTER_HAND_LANDMARKS:
            right_hand_style[landmark] = DrawingSpec(color=(0, 0, 255), thickness=2, circle_radius=2)
            left_hand_style[landmark] = DrawingSpec(color=(255, 0, 0), thickness=2, circle_radius=2)
        else:
            # Remove every connection touching this hidden landmark.
            hand_connections = [conn for conn in hand_connections if landmark.value not in conn]
            right_hand_style[landmark] = DrawingSpec(color=(0, 0, 0), thickness=0, circle_radius=0)
            left_hand_style[landmark] = DrawingSpec(color=(0, 0, 0), thickness=0, circle_radius=0)
    return pose_connections, hand_connections


def inference(model_config: ModelConfig,
              inference_config: InferenceConfig,
              session: "ort.InferenceSession") -> dict:
    """Run sign-language inference over a video source.

    Args:
        model_config: model configuration (currently unused in the body).
        inference_config: source, output, visualization and segmentation
            thresholds controlling the loop below.
        session: ONNX Runtime session used by ``get_predictions``.

    Returns:
        ``{"results": results}`` where ``results`` is the merged
        prediction results (``None`` if no sign was ever detected).
    """
    # Open the video source: the file path if it exists, else webcam 0.
    source = str(inference_config.source) if Path(inference_config.source).is_file() else 0
    cap = cv2.VideoCapture(source)

    writer = None
    if inference_config.output_dir is not None:
        writer = cv2.VideoWriter(
            str(Path(inference_config.output_dir) / "output.mp4"),
            cv2.VideoWriter_fourcc(*"mp4v"),
            cap.get(cv2.CAP_PROP_FPS),
            # Named properties instead of magic indices 3 / 4.
            (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))),
        )

    # Init Mediapipe helpers and default drawing styles.
    mp_holistic = mp.solutions.holistic
    mp_drawing = mp.solutions.drawing_utils
    mp_drawing_styles = mp.solutions.drawing_styles
    custom_pose_style = mp_drawing_styles.get_default_pose_landmarks_style()
    custom_right_hand_style = mp_drawing_styles.get_default_hand_landmarks_style()
    custom_left_hand_style = mp_drawing_styles.get_default_hand_landmarks_style()
    custom_pose_connections = list(mp_holistic.POSE_CONNECTIONS)
    custom_hand_connections = list(mp_holistic.HAND_CONNECTIONS)
    if inference_config.show_skeleton:
        # Special styling for the 'spoter' landmark subset.
        custom_pose_connections, custom_hand_connections = _restrict_to_spoter(
            custom_pose_style,
            custom_right_hand_style,
            custom_left_hand_style,
            custom_pose_connections,
            custom_hand_connections,
        )

    # Per-run state.
    right_arm = Arm("right", inference_config.visibility)
    left_arm = Arm("left", inference_config.visibility)
    data = []            # frames buffered while a sign is being performed
    results = None       # merged predictions across all detected signs
    predictions = Predictions()

    with mp_holistic.Holistic(min_detection_confidence=0.9,
                              min_tracking_confidence=0.5) as holistic:
        while cap.isOpened():
            success, frame = cap.read()
            if not success:
                break

            # Mediapipe processes RGB; mark read-only during detection.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame.flags.writeable = False
            detection_results = holistic.process(frame)
            # cv2 works on BGR; restore for drawing/output.
            frame.flags.writeable = True
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

            # Skip frames with no detected pose (pose_landmarks is None,
            # so the attribute access raises AttributeError).
            try:
                landmarks = detection_results.pose_landmarks.landmark
            except AttributeError:
                continue
            left_arm.set_pose(landmarks)
            right_arm.set_pose(landmarks)

            # Check whether each arm is up long enough to collect frames.
            # POS_MSEC hoisted: no frame is read between the two calls.
            current_time = cap.get(cv2.CAP_PROP_POS_MSEC)
            left_arm_ok_to_get_frame = ok_to_get_frame(
                arm=left_arm,
                angle_threshold=inference_config.angle_threshold,
                min_num_up_frames=inference_config.min_num_up_frames,
                min_num_down_frames=inference_config.min_num_down_frames,
                current_time=current_time,
                delay=inference_config.delay,
            )
            right_arm_ok_to_get_frame = ok_to_get_frame(
                arm=right_arm,
                angle_threshold=inference_config.angle_threshold,
                min_num_up_frames=inference_config.min_num_up_frames,
                min_num_down_frames=inference_config.min_num_down_frames,
                current_time=current_time,
                delay=inference_config.delay,
            )
            if left_arm_ok_to_get_frame or right_arm_ok_to_get_frame:
                predictions = Predictions()
                # Only the raw frame is buffered — the .onnx model
                # consumes images directly.
                data.append(frame)

            # Start/end of the sign; non-zero only once a sign completed.
            start_time, end_time = get_sample_timestamp(left_arm, right_arm)
            # Convert from milliseconds to seconds.
            start_time /= 1_000
            end_time /= 1_000

            if start_time != 0 and end_time != 0:
                # Render a waiting screen while the model runs.
                if inference_config.visualize:
                    wait_frame = draw_text_on_image(
                        np.zeros_like(frame),
                        text="Please wait for the prediction...",
                        position=(20, 20),
                        color=(255, 255, 255),
                        font_size=20,
                    )
                    cv2.imshow("Video Visualization", wait_frame)
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break

                start_inference_time = time()
                # Batch the buffered frames for the ONNX model
                # (assumes the model accepts a batch of images).
                data_np = np.stack(data, axis=0)
                predictions = get_predictions(
                    data_np, session, id2gloss=id2gloss, k=inference_config.top_k
                )
                predictions.inference_time = time() - start_inference_time
                predictions.start_time = start_time
                predictions.end_time = end_time
                logging.info(str(predictions))
                results = predictions.merge_results(results)

                # Reset per-sign state for the next sign.
                left_arm.reset_state()
                right_arm.reset_state()
                data = []

            # Render overlays: arm angles, predictions, optional skeleton.
            frame = left_arm.visualize(frame, (20, 10), "Left arm angle")
            frame = right_arm.visualize(frame, (20, 40), "Right arm angle")
            frame = predictions.visualize(frame, (20, 70))
            if inference_config.show_skeleton:
                mp_drawing.draw_landmarks(
                    frame,
                    detection_results.pose_landmarks,
                    connections=custom_pose_connections,
                    landmark_drawing_spec=custom_pose_style,
                )
                mp_drawing.draw_landmarks(
                    frame,
                    detection_results.right_hand_landmarks,
                    connections=custom_hand_connections,
                    landmark_drawing_spec=custom_right_hand_style,
                )
                mp_drawing.draw_landmarks(
                    frame,
                    detection_results.left_hand_landmarks,
                    connections=custom_hand_connections,
                    landmark_drawing_spec=custom_left_hand_style,
                )

            if writer is not None:
                writer.write(frame)
            if inference_config.visualize:
                cv2.imshow("Video Visualization", frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break

    cap.release()
    cv2.destroyAllWindows()
    if writer is not None:
        writer.release()
        logging.info("Video is recorded and saved to %s",
                     Path(inference_config.output_dir) / "output.mp4")
        pd.DataFrame(results).to_csv(
            Path(inference_config.output_dir) / "results.csv", index=False
        )
        logging.info("Results saved to %s",
                     Path(inference_config.output_dir) / "results.csv")
    return {"results": results}