import cv2
import numpy as np
import torch
from mediapipe.python.solutions import (drawing_styles, drawing_utils,
                                        holistic, pose)
from torchvision.transforms.v2 import Compose, UniformTemporalSubsample


def draw_skeleton_on_image(
    image: np.ndarray,
    detection_results,
    resize_to: tuple[int, int] | None = None,
) -> np.ndarray:
    """
    Draw the detected skeleton on the image.

    Parameters
    ----------
    image : np.ndarray
        Image to draw the skeleton on.
    detection_results
        MediaPipe Holistic detection results.
    resize_to : tuple[int, int], optional
        Resize the annotated image to the specified (width, height).

    Returns
    -------
    np.ndarray
        Annotated image with the skeleton drawn on it.
    """
    annotated_image = np.copy(image)

    # Draw pose connections
    drawing_utils.draw_landmarks(
        annotated_image,
        detection_results.pose_landmarks,
        holistic.POSE_CONNECTIONS,
        landmark_drawing_spec=drawing_styles.get_default_pose_landmarks_style(),
    )
    # Draw left hand connections
    drawing_utils.draw_landmarks(
        annotated_image,
        detection_results.left_hand_landmarks,
        holistic.HAND_CONNECTIONS,
        drawing_utils.DrawingSpec(color=(121, 22, 76), thickness=2, circle_radius=4),
        drawing_utils.DrawingSpec(color=(121, 44, 250), thickness=2, circle_radius=2),
    )
    # Draw right hand connections
    drawing_utils.draw_landmarks(
        annotated_image,
        detection_results.right_hand_landmarks,
        holistic.HAND_CONNECTIONS,
        drawing_utils.DrawingSpec(color=(245, 117, 66), thickness=2, circle_radius=4),
        drawing_utils.DrawingSpec(color=(245, 66, 230), thickness=2, circle_radius=2),
    )

    if resize_to is not None:
        annotated_image = cv2.resize(
            annotated_image,
            resize_to,
            interpolation=cv2.INTER_AREA,
        )
    return annotated_image
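

# A minimal usage sketch for draw_skeleton_on_image, assuming a single image
# file; the path 'example.jpg' and the 224x224 target size are placeholders,
# not part of the original Space.
def example_draw_skeleton(image_path: str = 'example.jpg') -> np.ndarray:
    image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
    # Run MediaPipe Holistic in static-image mode on the single frame.
    with holistic.Holistic(static_image_mode=True) as detector:
        detection_results = detector.process(image)
    return draw_skeleton_on_image(image, detection_results, resize_to=(224, 224))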


def are_hands_down(pose_landmarks) -> bool:
    """
    Check whether both hands are down (wrists below elbows).

    Parameters
    ----------
    pose_landmarks
        MediaPipe pose landmarks, or None if no pose was detected.

    Returns
    -------
    bool
        True if both hands are down (or no pose was detected), False otherwise.
    """
    if pose_landmarks is None:
        return True

    landmarks = pose_landmarks.landmark
    # Each entry is [x, y, visibility]. Note that visibility is read from the
    # shoulders rather than the joints themselves, since wrists and elbows
    # often leave the frame when the arms hang down.
    left_elbow = [
        landmarks[pose.PoseLandmark.LEFT_ELBOW.value].x,
        landmarks[pose.PoseLandmark.LEFT_ELBOW.value].y,
        landmarks[pose.PoseLandmark.LEFT_SHOULDER.value].visibility,
    ]
    left_wrist = [
        landmarks[pose.PoseLandmark.LEFT_WRIST.value].x,
        landmarks[pose.PoseLandmark.LEFT_WRIST.value].y,
        landmarks[pose.PoseLandmark.LEFT_SHOULDER.value].visibility,
    ]
    right_elbow = [
        landmarks[pose.PoseLandmark.RIGHT_ELBOW.value].x,
        landmarks[pose.PoseLandmark.RIGHT_ELBOW.value].y,
        landmarks[pose.PoseLandmark.RIGHT_SHOULDER.value].visibility,
    ]
    right_wrist = [
        landmarks[pose.PoseLandmark.RIGHT_WRIST.value].x,
        landmarks[pose.PoseLandmark.RIGHT_WRIST.value].y,
        landmarks[pose.PoseLandmark.RIGHT_SHOULDER.value].visibility,
    ]

    is_visible = all(
        [left_elbow[2] > 0, left_wrist[2] > 0, right_elbow[2] > 0, right_wrist[2] > 0]
    )
    # Image y grows downwards, so a wrist below its elbow has a larger y.
    return (
        is_visible
        and left_wrist[1] > left_elbow[1]
        and right_wrist[1] > right_elbow[1]
    )
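

# A small sanity-check sketch for are_hands_down on synthetic landmarks,
# assuming MediaPipe's protobuf types (landmark_pb2); every coordinate below
# is made up for illustration.
def example_hands_down_check() -> bool:
    from mediapipe.framework.formats import landmark_pb2

    # 33 pose landmarks, all fully visible, at the same position.
    landmarks = [
        landmark_pb2.NormalizedLandmark(x=0.5, y=0.5, visibility=1.0)
        for _ in range(33)
    ]
    # Move both wrists below the elbows (larger y means lower in the image).
    for wrist in (pose.PoseLandmark.LEFT_WRIST, pose.PoseLandmark.RIGHT_WRIST):
        landmarks[wrist.value].y = 0.9
    pose_landmarks = landmark_pb2.NormalizedLandmarkList(landmark=landmarks)
    return are_hands_down(pose_landmarks)  # expected: True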


def get_predictions(
    inputs: dict,
    model,
    k: int = 3,
) -> list:
    """
    Run the model on the preprocessed inputs and return the top-k predicted
    labels with their scores.
    """
    if inputs is None:
        return []

    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits

    # Get the top-k predictions and normalize their logits into scores.
    topk_scores, topk_indices = torch.topk(logits, k, dim=1)
    topk_scores = torch.nn.functional.softmax(topk_scores, dim=1).squeeze().cpu().numpy()
    topk_indices = topk_indices.squeeze().cpu().numpy()

    return [
        {
            'label': model.config.id2label[int(topk_indices[i])],
            'score': float(topk_scores[i]),
        }
        for i in range(k)
    ]
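

# A hedged sketch of calling get_predictions with a Hugging Face video
# classifier; the VideoMAE checkpoint below is a placeholder, since the
# Space's actual model id does not appear in this file.
def example_get_predictions(inputs: dict) -> list:
    from transformers import VideoMAEForVideoClassification

    model = VideoMAEForVideoClassification.from_pretrained(
        'MCG-NJU/videomae-base-finetuned-kinetics'  # placeholder checkpoint
    )
    model.eval()
    return get_predictions(inputs, model, k=3)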


def preprocess(
    model_num_frames: int,
    keypoints_detector,
    source: str,
    model_input_height: int,
    model_input_width: int,
    device: str,
    transform: Compose,
) -> dict | None:
    """
    Extract the signing segment from a video, render its skeleton frames, and
    pack them into model inputs. Returns None if too few frames are collected.
    """
    skeleton_video = []
    did_sample_start = False
    cap = cv2.VideoCapture(source)

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # Detect keypoints.
        detection_results = keypoints_detector.process(frame)

        # Draw the skeleton on a blank canvas rather than the original frame,
        # then resize to the model's input size (cv2.resize takes (width, height)).
        skeleton_frame = draw_skeleton_on_image(
            image=np.zeros((1080, 1080, 3), dtype=np.uint8),
            detection_results=detection_results,
            resize_to=(model_input_width, model_input_height),
        )
        # (height, width, channels) -> (channels, height, width)
        skeleton_frame = transform(torch.from_numpy(skeleton_frame).permute(2, 0, 1))

        # Extract the sign video: sample frames from the moment the hands
        # rise until the moment they drop again.
        if not are_hands_down(detection_results.pose_landmarks):
            if not did_sample_start:
                did_sample_start = True
        elif did_sample_start:
            break

        if did_sample_start:
            skeleton_video.append(skeleton_frame)

    cap.release()

    if len(skeleton_video) < model_num_frames:
        return None

    # (num_frames, channels, height, width), subsampled to the model's length.
    skeleton_video = torch.stack(skeleton_video)
    skeleton_video = UniformTemporalSubsample(model_num_frames)(skeleton_video)

    inputs = {
        'pixel_values': skeleton_video.unsqueeze(0),
    }
    inputs = {k: v.to(device) for k, v in inputs.items()}
    return inputs
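

# An end-to-end sketch tying preprocess and get_predictions together, assuming
# a VideoMAE-style model whose config exposes num_frames and image_size; the
# checkpoint, video path, and [0, 1] scaling transform are all placeholders.
def example_pipeline(video_path: str = 'sign.mp4') -> list:
    from torchvision.transforms.v2 import Lambda
    from transformers import VideoMAEForVideoClassification

    model = VideoMAEForVideoClassification.from_pretrained(
        'MCG-NJU/videomae-base-finetuned-kinetics'  # placeholder checkpoint
    )
    model.eval()
    # Scale uint8 pixels to [0, 1]; the Space may normalize differently.
    transform = Compose([Lambda(lambda x: x / 255.0)])
    keypoints_detector = holistic.Holistic(
        min_detection_confidence=0.5,
        min_tracking_confidence=0.5,
    )
    inputs = preprocess(
        model_num_frames=model.config.num_frames,
        keypoints_detector=keypoints_detector,
        source=video_path,
        model_input_height=model.config.image_size,
        model_input_width=model.config.image_size,
        device='cpu',
        transform=transform,
    )
    return get_predictions(inputs, model)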