import cv2
import numpy as np
import torch
from mediapipe.python.solutions import (drawing_styles, drawing_utils,
                                        holistic, pose)
from torchvision.transforms.v2 import Compose, UniformTemporalSubsample


def draw_skeleton_on_image(
    image: np.ndarray,
    detection_results,
    resize_to: tuple[int, int] | None = None,
) -> np.ndarray:
    """
    Draw the detected skeleton on the image.

    Parameters
    ----------
    image : np.ndarray
        Image to draw the skeleton on.
    detection_results
        MediaPipe Holistic detection results.
    resize_to : tuple[int, int], optional
        Target (width, height) to resize the annotated image to, as
        expected by `cv2.resize`.

    Returns
    -------
    np.ndarray
        Annotated image with the skeleton drawn on it.
    """
    annotated_image = np.copy(image)

    # Draw pose connections.
    drawing_utils.draw_landmarks(
        annotated_image,
        detection_results.pose_landmarks,
        holistic.POSE_CONNECTIONS,
        landmark_drawing_spec=drawing_styles.get_default_pose_landmarks_style(),
    )

    # Draw left hand connections.
    drawing_utils.draw_landmarks(
        annotated_image,
        detection_results.left_hand_landmarks,
        holistic.HAND_CONNECTIONS,
        drawing_utils.DrawingSpec(color=(121, 22, 76), thickness=2, circle_radius=4),
        drawing_utils.DrawingSpec(color=(121, 44, 250), thickness=2, circle_radius=2),
    )

    # Draw right hand connections.
    drawing_utils.draw_landmarks(
        annotated_image,
        detection_results.right_hand_landmarks,
        holistic.HAND_CONNECTIONS,
        drawing_utils.DrawingSpec(color=(245, 117, 66), thickness=2, circle_radius=4),
        drawing_utils.DrawingSpec(color=(245, 66, 230), thickness=2, circle_radius=2),
    )

    if resize_to is not None:
        annotated_image = cv2.resize(
            annotated_image,
            resize_to,
            interpolation=cv2.INTER_AREA,
        )

    return annotated_image


def are_hands_down(pose_landmarks) -> bool:
    """
    Check whether both hands are down (each wrist below its elbow).

    Parameters
    ----------
    pose_landmarks
        MediaPipe pose landmarks, or None if no pose was detected.

    Returns
    -------
    bool
        True if both hands are down (or no pose was detected),
        False otherwise.
    """
    if pose_landmarks is None:
        return True

    landmarks = pose_landmarks.landmark

    # Each entry is [x, y, visibility]. Visibility is taken from the
    # same-side shoulder rather than the joint itself.
    left_elbow = [
        landmarks[pose.PoseLandmark.LEFT_ELBOW.value].x,
        landmarks[pose.PoseLandmark.LEFT_ELBOW.value].y,
        landmarks[pose.PoseLandmark.LEFT_SHOULDER.value].visibility,
    ]
    left_wrist = [
        landmarks[pose.PoseLandmark.LEFT_WRIST.value].x,
        landmarks[pose.PoseLandmark.LEFT_WRIST.value].y,
        landmarks[pose.PoseLandmark.LEFT_SHOULDER.value].visibility,
    ]
    right_elbow = [
        landmarks[pose.PoseLandmark.RIGHT_ELBOW.value].x,
        landmarks[pose.PoseLandmark.RIGHT_ELBOW.value].y,
        landmarks[pose.PoseLandmark.RIGHT_SHOULDER.value].visibility,
    ]
    right_wrist = [
        landmarks[pose.PoseLandmark.RIGHT_WRIST.value].x,
        landmarks[pose.PoseLandmark.RIGHT_WRIST.value].y,
        landmarks[pose.PoseLandmark.RIGHT_SHOULDER.value].visibility,
    ]

    is_visible = all(
        [left_elbow[2] > 0, left_wrist[2] > 0, right_elbow[2] > 0, right_wrist[2] > 0]
    )

    # Image coordinates grow downwards, so a wrist below its elbow has a
    # larger y value.
    return (
        is_visible
        and left_wrist[1] > left_elbow[1]
        and right_wrist[1] > right_elbow[1]
    )


def get_predictions(
    inputs: dict,
    model,
    k: int = 3,
) -> list:
    """
    Run the model and return its top-k predictions.

    Parameters
    ----------
    inputs : dict
        Model inputs as returned by `preprocess`, or None.
    model
        Classification model whose output exposes `logits` and whose
        config provides an `id2label` mapping.
    k : int
        Number of top predictions to return.

    Returns
    -------
    list
        Top-k predictions as dicts with 'label' and 'score' keys.
    """
    if inputs is None:
        return []

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    # Get the top-k predictions. Softmax is applied over the top-k logits
    # only, so the scores are renormalized over those k classes.
    topk_scores, topk_indices = torch.topk(logits, k, dim=1)
    topk_scores = (
        torch.nn.functional.softmax(topk_scores, dim=1)
        .squeeze().detach().cpu().numpy()
    )
    topk_indices = topk_indices.squeeze().detach().cpu().numpy()

    return [
        {
            'label': model.config.id2label[int(topk_indices[i])],
            'score': topk_scores[i],
        }
        for i in range(k)
    ]


def preprocess(
    model_num_frames: int,
    keypoints_detector,
    source: str,
    model_input_height: int,
    model_input_width: int,
    device: str,
    transform: Compose,
) -> dict | None:
    """
    Read a video, detect keypoints per frame, and build model inputs from
    the signing segment.

    Parameters
    ----------
    model_num_frames : int
        Number of frames the model expects.
    keypoints_detector
        MediaPipe Holistic detector.
    source : str
        Path to the input video.
    model_input_height : int
        Model input height.
    model_input_width : int
        Model input width.
    device : str
        Device to move the inputs to.
    transform : Compose
        Transform applied to each skeleton frame.

    Returns
    -------
    dict | None
        Model inputs, or None if the extracted segment is shorter than
        `model_num_frames`.
    """
    skeleton_video = []
    did_sample_start = False
    cap = cv2.VideoCapture(source)

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # Detect keypoints.
        detection_results = keypoints_detector.process(frame)

        # Draw the skeleton on a black canvas so the model sees only the
        # keypoints; the landmarks are normalized, so the canvas size is
        # independent of the source resolution. `cv2.resize` expects
        # (width, height).
        skeleton_frame = draw_skeleton_on_image(
            image=np.zeros((1080, 1080, 3), dtype=np.uint8),
            detection_results=detection_results,
            resize_to=(model_input_width, model_input_height),
        )

        # (height, width, channels) -> (channels, height, width)
        skeleton_frame = transform(torch.tensor(skeleton_frame).permute(2, 0, 1))

        # Extract the sign video: start sampling when the hands come up,
        # stop as soon as they go back down.
        if not are_hands_down(detection_results.pose_landmarks):
            if not did_sample_start:
                did_sample_start = True
        elif did_sample_start:
            break

        if did_sample_start:
            skeleton_video.append(skeleton_frame)

    cap.release()

    if len(skeleton_video) < model_num_frames:
        return None

    # (num_frames, channels, height, width)
    skeleton_video = torch.stack(skeleton_video)
    # Evenly subsample the clip to the number of frames the model expects.
    skeleton_video = UniformTemporalSubsample(model_num_frames)(skeleton_video)

    inputs = {
        'pixel_values': skeleton_video.unsqueeze(0),
    }
    inputs = {k: v.to(device) for k, v in inputs.items()}

    return inputs