"""Run a VSL sign-recognition pipeline on a video file or webcam stream.

Arm poses from MediaPipe Holistic gate which frames are collected into a
sample; when a complete sign (non-zero start and end time) is detected the
collected sample is passed to the model pipeline and the prediction is
overlaid on the output video.
"""
import shutil
import logging
import sys
from time import time
from typing import NamedTuple
import numpy as np
import pandas as pd
import cv2
from traceback import format_exc
from argparse import Namespace
from transformers import Pipeline
from simple_parsing import ArgumentParser
import mediapipe as mp
from mediapipe.python.solutions.pose import PoseLandmark
from mediapipe.python.solutions.hands import HandLandmark
from mediapipe.python.solutions.drawing_utils import DrawingSpec
from visualization import draw_text_on_image
from configs import ModelConfig, InferenceConfig
from utils import config_logger, POSE_BASED_MODELS
from data import Arm, get_sample_timestamp, ok_to_get_frame
from tools import load_pipeline, Predictions


# Pose landmarks highlighted when the skeleton overlay is enabled.
SPOTER_POSE_LANDMARKS = [
    PoseLandmark.NOSE,
    PoseLandmark.LEFT_EYE,
    PoseLandmark.RIGHT_EYE,
    PoseLandmark.RIGHT_SHOULDER,
    PoseLandmark.LEFT_SHOULDER,
    PoseLandmark.RIGHT_ELBOW,
    PoseLandmark.LEFT_ELBOW,
    PoseLandmark.RIGHT_WRIST,
    PoseLandmark.LEFT_WRIST,
]

# Hand landmarks highlighted when the skeleton overlay is enabled.
SPOTER_HAND_LANDMARKS = [
    HandLandmark.WRIST,
    HandLandmark.INDEX_FINGER_TIP,
    HandLandmark.INDEX_FINGER_DIP,
    HandLandmark.INDEX_FINGER_PIP,
    HandLandmark.INDEX_FINGER_MCP,
    HandLandmark.MIDDLE_FINGER_TIP,
    HandLandmark.MIDDLE_FINGER_DIP,
    HandLandmark.MIDDLE_FINGER_PIP,
    HandLandmark.MIDDLE_FINGER_MCP,
    HandLandmark.RING_FINGER_TIP,
    HandLandmark.RING_FINGER_DIP,
    HandLandmark.RING_FINGER_PIP,
    HandLandmark.RING_FINGER_MCP,
    HandLandmark.PINKY_TIP,
    HandLandmark.PINKY_DIP,
    HandLandmark.PINKY_PIP,
    HandLandmark.PINKY_MCP,
    HandLandmark.THUMB_TIP,
    HandLandmark.THUMB_IP,
    HandLandmark.THUMB_MCP,
    HandLandmark.THUMB_CMC,
]

# Frame rate used for the output video when the capture reports none
# (cap.get(CAP_PROP_FPS) can return 0 for some live sources).
DEFAULT_FPS = 30.0


class _SkeletonStyle(NamedTuple):
    """Per-landmark drawing specs and connection lists for the skeleton overlay."""

    pose_style: dict
    right_hand_style: dict
    left_hand_style: dict
    pose_connections: list
    hand_connections: list


def get_args() -> Namespace:
    """Parse the model and inference configuration from the command line / config file."""
    parser = ArgumentParser(
        description="Run inference with a trained model on VSL",
        add_config_path_arg=True,
    )
    parser.add_arguments(ModelConfig, "model")
    parser.add_arguments(InferenceConfig, "inference")
    return parser.parse_args()


def _hidden_spec() -> DrawingSpec:
    """Return the zero-size black spec used to de-emphasize unused landmarks."""
    return DrawingSpec(color=(0, 0, 0), thickness=0, circle_radius=0)


def _build_skeleton_styles(mp_holistic) -> _SkeletonStyle:
    """Build drawing styles that highlight only the SPOTER landmarks.

    Landmarks outside SPOTER_POSE_LANDMARKS / SPOTER_HAND_LANDMARKS get a
    zero-size black spec, and every connection touching such a landmark is
    dropped. Colors are BGR because the skeleton is drawn on the BGR frame.

    Args:
        mp_holistic: The ``mp.solutions.holistic`` module (source of the
            default POSE_CONNECTIONS / HAND_CONNECTIONS).
    """
    pose_keep = {int(lm) for lm in SPOTER_POSE_LANDMARKS}
    hand_keep = {int(lm) for lm in SPOTER_HAND_LANDMARKS}

    def visible(color):
        return DrawingSpec(color=color, thickness=2, circle_radius=2)

    pose_style = {
        lm: visible((0, 255, 0)) if int(lm) in pose_keep else _hidden_spec()
        for lm in PoseLandmark
    }
    right_hand_style = {
        lm: visible((0, 0, 255)) if int(lm) in hand_keep else _hidden_spec()
        for lm in HandLandmark
    }
    left_hand_style = {
        lm: visible((255, 0, 0)) if int(lm) in hand_keep else _hidden_spec()
        for lm in HandLandmark
    }

    # Keep a connection only if both endpoints are kept. (Building a new list
    # instead of calling .remove() while iterating, which skips elements.)
    pose_connections = [
        conn for conn in mp_holistic.POSE_CONNECTIONS
        if all(idx in pose_keep for idx in conn)
    ]
    hand_connections = [
        conn for conn in mp_holistic.HAND_CONNECTIONS
        if all(idx in hand_keep for idx in conn)
    ]
    return _SkeletonStyle(
        pose_style, right_hand_style, left_hand_style, pose_connections, hand_connections
    )


def _draw_skeleton(mp_drawing, frame, detection_results, skeleton: _SkeletonStyle) -> None:
    """Draw pose and both hands onto ``frame`` in place using ``skeleton`` styles."""
    mp_drawing.draw_landmarks(
        frame,
        detection_results.pose_landmarks,
        connections=skeleton.pose_connections,
        landmark_drawing_spec=skeleton.pose_style,
    )
    for hand_landmarks, hand_style in (
        (detection_results.right_hand_landmarks, skeleton.right_hand_style),
        (detection_results.left_hand_landmarks, skeleton.left_hand_style),
    ):
        mp_drawing.draw_landmarks(
            frame,
            hand_landmarks,
            connections=skeleton.hand_connections,
            landmark_drawing_spec=hand_style,
        )


def _open_writer(cap, output_dir):
    """Create an mp4 VideoWriter at ``output_dir / 'output.mp4'`` matching ``cap``'s size."""
    fps = cap.get(cv2.CAP_PROP_FPS)
    if fps <= 0:
        fps = DEFAULT_FPS
    frame_size = (
        int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
        int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
    )
    return cv2.VideoWriter(
        str(output_dir / "output.mp4"),
        cv2.VideoWriter_fourcc(*"mp4v"),
        fps,
        frame_size,
    )


def inference(model_config, inference_config: InferenceConfig, pipeline: Pipeline) -> None:
    """Run gated sign recognition over a video file or the default webcam.

    Args:
        model_config: Model configuration (currently unused here).
        inference_config: Source, output, gating thresholds and display options.
            ``use_pose_model`` selects whether MediaPipe detections or raw BGR
            frames are collected as model input.
        pipeline: Callable applied to ``np.array(collected_samples)``.

    Side effects:
        Shows OpenCV windows when ``visualize`` is set. When ``output_dir`` is
        set, writes ``output.mp4`` and ``results.csv`` there.
    """
    source = str(inference_config.source) if inference_config.source.is_file() else 0
    cap = cv2.VideoCapture(source)
    output_dir = inference_config.output_dir
    writer = _open_writer(cap, output_dir) if output_dir is not None else None

    mp_holistic = mp.solutions.holistic
    mp_drawing = mp.solutions.drawing_utils
    skeleton = _build_skeleton_styles(mp_holistic) if inference_config.show_skeleton else None

    right_arm = Arm("right", inference_config.visibility)
    left_arm = Arm("left", inference_config.visibility)
    data = []
    results = None
    predictions = Predictions()

    # Shared gating parameters for both arms.
    gate_kwargs = dict(
        angle_threshold=inference_config.angle_threshold,
        min_num_up_frames=inference_config.min_num_up_frames,
        min_num_down_frames=inference_config.min_num_down_frames,
        delay=inference_config.delay,
    )

    try:
        with mp_holistic.Holistic(min_detection_confidence=0.9, min_tracking_confidence=0.5) as holistic:
            while cap.isOpened():
                success, frame = cap.read()
                if not success:
                    break

                # MediaPipe processes RGB; following MediaPipe's examples the
                # image is marked read-only while it is processed.
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                rgb.flags.writeable = False
                detection_results = holistic.process(rgb)
                rgb.flags.writeable = True
                frame = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)

                # No body detected: skip this frame entirely (it is neither
                # written nor displayed).
                if detection_results.pose_landmarks is None:
                    continue
                landmarks = detection_results.pose_landmarks.landmark

                left_arm.set_pose(landmarks)
                right_arm.set_pose(landmarks)

                # Evaluate both arms unconditionally (no short-circuit) so each
                # arm's gate sees every frame.
                current_time = cap.get(cv2.CAP_PROP_POS_MSEC)
                left_ok = ok_to_get_frame(arm=left_arm, current_time=current_time, **gate_kwargs)
                right_ok = ok_to_get_frame(arm=right_arm, current_time=current_time, **gate_kwargs)
                if left_ok or right_ok:
                    predictions = Predictions()
                    # Store a copy: overlays are drawn onto `frame` in place
                    # further down (draw_landmarks; presumably the visualize
                    # helpers too), which must not leak into model input.
                    data.append(
                        detection_results if inference_config.use_pose_model else frame.copy()
                    )

                # Sign boundaries, converted from milliseconds to seconds.
                start_time, end_time = get_sample_timestamp(left_arm, right_arm)
                start_time /= 1_000
                end_time /= 1_000

                if start_time != 0 and end_time != 0:
                    if inference_config.visualize:
                        wait_frame = draw_text_on_image(
                            np.zeros_like(frame),
                            text="Please wait for the prediction...",
                            position=(20, 20),
                            color=(255, 255, 255),
                            font_size=20,
                        )
                        cv2.imshow("Video Visualization", wait_frame)
                        if cv2.waitKey(1) & 0xFF == ord('q'):
                            break

                    start_inference_time = time()
                    predictions = Predictions(predictions=pipeline(np.array(data)))
                    predictions.inference_time = time() - start_inference_time
                    predictions.start_time = start_time
                    predictions.end_time = end_time
                    logging.info(str(predictions))
                    results = predictions.merge_results(results)

                    # Start collecting the next sign.
                    left_arm.reset_state()
                    right_arm.reset_state()
                    data = []

                # Render overlays.
                frame = left_arm.visualize(frame, (20, 10), "Left arm angle")
                frame = right_arm.visualize(frame, (20, 40), "Right arm angle")
                frame = predictions.visualize(frame, (20, 70))
                if skeleton is not None:
                    _draw_skeleton(mp_drawing, frame, detection_results, skeleton)

                if writer is not None:
                    writer.write(frame)
                if inference_config.visualize:
                    cv2.imshow("Video Visualization", frame)
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break
    finally:
        # Release capture/writer even if the loop raised.
        cap.release()
        cv2.destroyAllWindows()
        if writer is not None:
            writer.release()

    if output_dir is not None:
        logging.info(f"Video is recorded and saved to {output_dir / 'output.mp4'}")
        pd.DataFrame(results).to_csv(output_dir / "results.csv", index=False)
        logging.info(f"Results saved to {output_dir / 'results.csv'}")


def main(args: Namespace) -> None:
    """Load the pipeline described by ``args`` and run inference."""
    model_config = args.model
    logging.info(model_config)
    inference_config = args.inference
    logging.info(inference_config)

    # Pose-based architectures are fed MediaPipe detections instead of frames.
    inference_config.use_pose_model = model_config.arch in POSE_BASED_MODELS

    pipeline = load_pipeline(model_config, inference_config)
    logging.info("Pipeline loaded")

    inference(model_config, inference_config, pipeline)
    logging.info("Inference completed")


if __name__ == "__main__":
    try:
        args = get_args()
        output_dir = args.inference.output_dir
        config_path = args.config_path[0]
        config_logger(output_dir / "inference.log")
        logging.info(f"Config file loaded from {config_path}")
        shutil.copy(config_path, output_dir / "inference.yaml")
        logging.info(f"Config file saved to {output_dir}")
        main(args=args)
    except Exception:
        print(format_exc())
        # Signal failure to the calling shell instead of exiting with 0.
        sys.exit(1)