"""
Visualization module for creating annotated videos
"""
import os
import cv2
import numpy as np
from tqdm import tqdm
import logging
import mediapipe as mp
# BGR colors (OpenCV channel order) used to draw each body-part group,
# both for keypoint circles/skeleton lines and for the on-frame legend.
BODY_PART_COLORS = {
    "head": (255, 0, 0),  # Blue
    "torso": (0, 255, 0),  # Green
    "arms": (255, 165, 0),  # Orange
    "hands": (255, 0, 255),  # Magenta
    "legs": (0, 255, 255),  # Cyan
    "feet": (255, 255, 0)  # Yellow
}
# Pose landmark indices belonging to each body-part group.
# Indices follow the MediaPipe Pose 33-landmark numbering — TODO confirm
# against the pose-estimation stage that fills pose_data.
# Note: shoulders (11, 12) and hips (23, 24) intentionally appear in two
# groups ("torso" plus "arms"/"legs") so connected limbs share a color.
BODY_PARTS_MAPPING = {
    "head": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],  # Nose, eyes, ears, mouth
    "torso": [11, 12, 23, 24],  # Shoulders and hips
    "arms": [11, 12, 13, 14],  # Shoulders and elbows
    "hands": [15, 16, 17, 18, 19, 20, 21,
              22],  # Wrists, pinkies, indices, thumbs
    "legs": [23, 24, 25, 26],  # Hips and knees
    "feet": [27, 28, 29, 30, 31, 32]  # Ankles, heels, foot indices
}
logger = logging.getLogger(__name__)


def _rotate_point(x, y, rotation, width, height):
    """Map an (x, y) point from original-frame coordinates to the frame
    after a ``rotation``-degree counterclockwise rotation.

    ``width``/``height`` are the ORIGINAL frame dimensions. Used for
    keypoints, bbox corners and trajectory points so all annotations share
    one consistent transform.
    """
    if rotation == 90:  # matches cv2.ROTATE_90_COUNTERCLOCKWISE
        return (y, width - x - 1)
    if rotation == 180:  # matches cv2.ROTATE_180
        return (width - x - 1, height - y - 1)
    if rotation == 270:  # matches cv2.ROTATE_90_CLOCKWISE
        return (height - y - 1, x)
    return (x, y)


def _rotate_bbox(bbox, rotation, width, height):
    """Rotate an (x1, y1, x2, y2) box, re-normalizing corner order."""
    x1, y1, x2, y2 = bbox
    ax, ay = _rotate_point(x1, y1, rotation, width, height)
    bx, by = _rotate_point(x2, y2, rotation, width, height)
    # Rotation can swap which corner is top-left; restore min/max order.
    return (min(ax, bx), min(ay, by), max(ax, bx), max(ay, by))


def _read_rotation(video_path):
    """Read orientation metadata from the source video.

    Returns the counterclockwise rotation (0, 90, 180 or 270) needed to
    display frames upright; 0 when no metadata is available.

    Raises:
        IOError: If the video cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError(f"Could not open original video: {video_path}")
    rotation = 0
    try:
        # CAP_PROP_ORIENTATION_META may be absent in older OpenCV builds.
        prop = getattr(cv2, "CAP_PROP_ORIENTATION_META", None)
        if prop is not None:
            # Metadata stores the clockwise recording rotation; compensate
            # with the complementary counterclockwise rotation.
            rotation = {90: 270, 180: 180, 270: 90}.get(int(cap.get(prop)), 0)
    except Exception:
        # If metadata reading fails, don't apply any rotation.
        rotation = 0
    finally:
        cap.release()
    return rotation


def _transform_annotations(frame_idx, rotation, width, height, pose_data,
                           detections, sample_rate):
    """Rotate the stored keypoints and detection bboxes for one frame.

    Mutates ``pose_data`` entries and ``detection.bbox`` IN PLACE, matching
    how the drawing code consumes them afterwards.
    """
    keypoints = pose_data.get(frame_idx)
    if keypoints is not None:
        for j, kp in enumerate(keypoints):
            if kp is not None and len(kp) >= 2:
                try:
                    keypoints[j] = _rotate_point(kp[0], kp[1], rotation,
                                                 width, height)
                except Exception as exc:
                    logger.error("Error transforming keypoint %d: %s", j, exc)
                    # Keep the keypoint as is if there's an error
    for detection in detections:
        # detections are indexed in original-video frames; frames here are
        # sampled, hence the sample_rate multiplier.
        if detection.frame_idx == frame_idx * sample_rate:
            try:
                detection.bbox = _rotate_bbox(detection.bbox, rotation,
                                              width, height)
            except Exception as exc:
                logger.error("Error transforming detection bbox: %s", exc)
                # Keep the bbox as is if there's an error


def _draw_detections(frame, detections, frame_idx, sample_rate):
    """Draw green boxes and labels for this frame's person detections."""
    for detection in detections:
        # Only person detections are visualized; other classes are skipped.
        if (detection.frame_idx != frame_idx * sample_rate
                or detection.class_name != "person"):
            continue
        bbox = getattr(detection, "bbox", None)
        if not isinstance(bbox, tuple) or len(bbox) != 4:
            logger.warning("Invalid bbox format: %s", bbox)
            continue
        try:
            x1, y1, x2, y2 = map(int, bbox)
            color = (0, 255, 0)  # Green for person
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
            label = f"{detection.class_name}: {detection.confidence:.2f}"
            cv2.putText(frame, label, (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
        except Exception as exc:
            logger.error("Error drawing detection: %s", exc)


def _part_color(landmark_idx):
    """Return the body-part color for a landmark index (white if unknown)."""
    for part_name, indices in BODY_PARTS_MAPPING.items():
        if landmark_idx in indices:
            return BODY_PART_COLORS[part_name]
    return (255, 255, 255)


def _draw_pose(frame, keypoints):
    """Draw color-coded keypoint circles and skeleton connections."""
    for part_name, part_indices in BODY_PARTS_MAPPING.items():
        color = BODY_PART_COLORS[part_name]
        for idx in part_indices:
            if (idx < len(keypoints) and keypoints[idx] is not None
                    and len(keypoints[idx]) >= 2):
                try:
                    x, y = int(keypoints[idx][0]), int(keypoints[idx][1])
                    cv2.circle(frame, (x, y), 5, color, -1)
                except Exception as exc:
                    logger.error("Error drawing keypoint %d: %s", idx, exc)
    for start_idx, end_idx in mp.solutions.pose.POSE_CONNECTIONS:
        if (start_idx < len(keypoints) and end_idx < len(keypoints)
                and keypoints[start_idx] is not None
                and keypoints[end_idx] is not None
                and len(keypoints[start_idx]) >= 2
                and len(keypoints[end_idx]) >= 2):
            try:
                # Each bone takes the color of its start landmark's group.
                color = _part_color(start_idx)
                start_point = (int(keypoints[start_idx][0]),
                               int(keypoints[start_idx][1]))
                end_point = (int(keypoints[end_idx][0]),
                             int(keypoints[end_idx][1]))
                cv2.line(frame, start_point, end_point, color, 2)
            except Exception as exc:
                logger.error("Error drawing connection %d-%d: %s",
                             start_idx, end_idx, exc)


def _draw_phase(frame, frame_idx, swing_phases):
    """Overlay the swing phase label for this frame, if any."""
    for phase_name, phase_frames in swing_phases.items():
        # Skip non-phase entries (e.g. boolean flags like timing_unreliable).
        if isinstance(phase_frames, list) and frame_idx in phase_frames:
            cv2.putText(frame, f"Phase: {phase_name}", (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
            return


def _draw_trajectory(frame, traj_info, rotation, width, height):
    """Draw the ball trajectory polyline, rotated to match the frame.

    BUGFIX: the original code used the clockwise point mapping for the
    90-degree (counterclockwise) rotation and vice versa, so trajectories
    were mirrored on rotated videos while keypoints/bboxes were correct.
    All annotations now go through the same _rotate_point transform.
    """
    points = traj_info.get("ball_trajectory")
    if not points:
        return
    adjusted_points = []
    for point in points:
        try:
            adjusted_points.append(
                _rotate_point(point[0], point[1], rotation, width, height))
        except Exception as exc:
            logger.error("Error transforming trajectory point: %s", exc)
            # Skip this point if there's an error
    for prev, cur in zip(adjusted_points, adjusted_points[1:]):
        try:
            cv2.line(frame, (int(prev[0]), int(prev[1])),
                     (int(cur[0]), int(cur[1])), (0, 255, 255), 2)
        except Exception as exc:
            logger.error("Error drawing trajectory line: %s", exc)


def _draw_legend(frame):
    """Draw the body-part color legend in the top-left corner."""
    legend_y_start = 110
    legend_y_spacing = 30
    legend_x = 10
    legend_box_size = 20
    cv2.putText(frame, "Body Parts Legend:",
                (legend_x, legend_y_start - 10),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
    for idx, (part_name, color) in enumerate(BODY_PART_COLORS.items()):
        y_pos = legend_y_start + idx * legend_y_spacing
        cv2.rectangle(frame,
                      (legend_x, y_pos - legend_box_size + 5),
                      (legend_x + legend_box_size, y_pos + 5), color, -1)
        cv2.putText(frame, part_name.capitalize(),
                    (legend_x + legend_box_size + 10, y_pos + 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)


def create_annotated_video(video_path,
                           frames,
                           detections,
                           pose_data,
                           swing_phases,
                           trajectory_data,
                           output_dir="downloads",
                           sample_rate=1):
    """
    Create an annotated video with swing analysis visualizations.

    Per frame this draws: person detection boxes, color-coded pose
    keypoints and skeleton connections, the current swing phase label,
    the ball trajectory, and a body-part color legend. If the source
    video carries orientation metadata, frames and all stored coordinates
    are rotated to match.

    Note: when a rotation is applied, keypoints in ``pose_data`` and
    ``detection.bbox`` values are transformed in place.

    Args:
        video_path (str): Path to the original video.
        frames (list): List of video frames (numpy arrays, HxWxC).
        detections (list): Detection objects exposing ``frame_idx``,
            ``bbox``, ``class_name`` and ``confidence``.
        pose_data (dict): Frame index -> list of (x, y, ...) keypoints.
        swing_phases (dict): Phase name -> list of frame indices; non-list
            values (e.g. flags) are ignored.
        trajectory_data (dict): Frame index -> dict optionally holding a
            "ball_trajectory" list of (x, y) points.
        output_dir (str): Directory to save the output video.
        sample_rate (int): The frame sampling rate used during processing.

    Returns:
        str: Path to the annotated video.

    Raises:
        ValueError: If no frames are provided.
        IOError: If the source video can't be opened, the writer can't be
            created, or the output file is missing/empty afterwards.
    """
    try:
        if not frames:
            raise ValueError("No frames provided for annotation")
        os.makedirs(output_dir, exist_ok=True)
        # Short clips are annotated frame-by-frame regardless of sampling.
        if len(frames) < 150 and sample_rate > 1:
            sample_rate = 1
        video_name = os.path.splitext(os.path.basename(video_path))[0]
        output_path = os.path.join(output_dir, f"{video_name}_annotated.mp4")
        height, width = frames[0].shape[:2]
        fps = 30  # Default fps; TODO(review): consider reading fps from the source video
        rotation = _read_rotation(video_path)
        # 90/270-degree rotations swap the output dimensions.
        if rotation in (90, 270):
            output_size = (height, width)
        else:
            output_size = (width, height)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, output_size)
        if not out.isOpened():
            raise IOError(
                f"Failed to create video writer for {output_path}. Check directory permissions."
            )
        rotate_flags = {
            90: cv2.ROTATE_90_COUNTERCLOCKWISE,
            180: cv2.ROTATE_180,
            270: cv2.ROTATE_90_CLOCKWISE,
        }
        for i, frame in enumerate(tqdm(frames,
                                       desc="Creating annotated video")):
            annotated_frame = frame.copy()
            if rotation in rotate_flags:
                annotated_frame = cv2.rotate(annotated_frame,
                                             rotate_flags[rotation])
                _transform_annotations(i, rotation, width, height,
                                       pose_data, detections, sample_rate)
            _draw_detections(annotated_frame, detections, i, sample_rate)
            if i in pose_data:
                _draw_pose(annotated_frame, pose_data[i])
            _draw_phase(annotated_frame, i, swing_phases)
            if i in trajectory_data:
                _draw_trajectory(annotated_frame, trajectory_data[i],
                                 rotation, width, height)
            _draw_legend(annotated_frame)
            out.write(annotated_frame)
        out.release()
        # Verify the file was actually produced before reporting success.
        if not os.path.exists(output_path) or os.path.getsize(
                output_path) == 0:
            raise IOError(f"Failed to create video file at {output_path}")
        print(f"Annotated video saved to: {output_path}")
        return output_path
    except Exception as e:
        print(f"Error creating annotated video: {str(e)}")
        raise