Spaces:
Paused
Paused
| """ | |
| Pose estimation module for golf swing analysis | |
| """ | |
| import cv2 | |
| import numpy as np | |
| import mediapipe as mp | |
| import math | |
| from tqdm import tqdm | |
| # Keep only essential imports for pose detection | |
class PoseEstimator:
    """Wrap a MediaPipe Pose graph and convert its output to plain lists.

    Each call to :meth:`process_frame` returns ``(keypoints, world_landmarks)``:

    - ``keypoints``: one ``[x, y, visibility]`` entry per landmark. ``x`` and
      ``y`` are the normalized MediaPipe coordinates scaled by the frame's
      width/height and truncated to ``int``.
    - ``world_landmarks``: one ``[x, y, z]`` entry per landmark, copied from
      MediaPipe's ``pose_world_landmarks``, or ``None`` when MediaPipe did
      not provide them.

    Placeholder keypoints replace the real ones in two cases: no pose was
    detected, or the average landmark visibility is at or below
    ``visibility_threshold``. Placeholders sit at the frame centre with very
    low visibility, so callers always receive a list.

    The estimator can be used as a context manager; :meth:`close` is then
    called automatically on exit.
    """

    # Landmark count used when building placeholder results (the size of
    # MediaPipe's pose landmark set).
    NUM_LANDMARKS = 33
    # Visibility given to placeholder keypoints so downstream code can tell
    # them apart from real detections.
    FALLBACK_VISIBILITY = 0.05

    def __init__(self, *, model_complexity=2, min_detection_confidence=0.5,
                 min_tracking_confidence=0.5, visibility_threshold=0.3):
        """Create the underlying MediaPipe Pose graph in video (tracking) mode.

        Args:
            model_complexity: Passed to ``mp.solutions.pose.Pose``. The
                default of 2 was chosen for improved hip/foot stability.
            min_detection_confidence: Passed to ``mp.solutions.pose.Pose``.
            min_tracking_confidence: Passed to ``mp.solutions.pose.Pose``.
            visibility_threshold: A detection is returned only when its
                average landmark visibility is strictly greater than this
                value; otherwise placeholders are returned.
        """
        self.visibility_threshold = visibility_threshold
        self.mp_pose = mp.solutions.pose
        self.pose = self.mp_pose.Pose(
            static_image_mode=False,
            model_complexity=model_complexity,
            enable_segmentation=False,
            min_detection_confidence=min_detection_confidence,
            min_tracking_confidence=min_tracking_confidence,
        )

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Always release the graph; never suppress the caller's exception.
        self.close()
        return False

    @classmethod
    def _fallback_keypoints(cls, w, h):
        """Return placeholder keypoints at the frame centre with low visibility."""
        return [[w // 2, h // 2, cls.FALLBACK_VISIBILITY]
                for _ in range(cls.NUM_LANDMARKS)]

    def process_frame(self, frame):
        """Run pose estimation on a single BGR frame.

        Args:
            frame: A BGR image as produced by OpenCV. It is converted to RGB
                before being passed to MediaPipe, and ``frame.shape`` must
                unpack to ``(height, width, channels)``.

        Returns:
            tuple: ``(keypoints, world_landmarks)`` as described on the class.
        """
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        results = self.pose.process(frame_rgb)
        h, w, _ = frame.shape

        # No detection (or an empty landmark list, which would otherwise
        # divide by zero below): return placeholders with no 3D data.
        if not results.pose_landmarks or not results.pose_landmarks.landmark:
            return self._fallback_keypoints(w, h), None

        landmarks = results.pose_landmarks.landmark
        keypoints = [[int(lm.x * w), int(lm.y * h), lm.visibility]
                     for lm in landmarks]
        avg_visibility = sum(kp[2] for kp in keypoints) / len(keypoints)

        world_landmarks = None
        if results.pose_world_landmarks:
            # World landmarks are in meters relative to the hip centre.
            world_landmarks = [[lm.x, lm.y, lm.z]
                               for lm in results.pose_world_landmarks.landmark]

        if avg_visibility > self.visibility_threshold:
            return keypoints, world_landmarks

        # Low-quality detection: return placeholders. Zeroed 3D landmarks are
        # included only when MediaPipe actually produced world landmarks.
        fallback_3d = ([[0.0, 0.0, 0.0] for _ in range(self.NUM_LANDMARKS)]
                       if world_landmarks else None)
        return self._fallback_keypoints(w, h), fallback_3d

    def close(self):
        """Release the underlying MediaPipe Pose graph."""
        self.pose.close()
def analyze_pose(frames):
    """Analyze pose in video frames.

    Args:
        frames (list): Video frames, each passed to
            ``PoseEstimator.process_frame``. That method converts BGR to RGB,
            so frames are expected to be BGR (as read by OpenCV).

    Returns:
        tuple: (pose_data, world_landmarks) where:
            - pose_data: Dictionary mapping every frame index to its 2D
              keypoints (placeholders when no reliable pose was found).
            - world_landmarks: Dictionary mapping frame indices to 3D world
              landmarks; only indices where 3D data was returned are present.
    """
    pose_estimator = PoseEstimator()
    pose_data = {}
    world_landmarks = {}
    try:
        for i, frame in enumerate(tqdm(frames, desc="Analyzing pose")):
            keypoints, world_lm = pose_estimator.process_frame(frame)
            # Always store keypoints (never None due to fallback in process_frame)
            pose_data[i] = keypoints
            if world_lm is not None:
                world_landmarks[i] = world_lm
    finally:
        # Release the MediaPipe graph even if a frame raises mid-loop.
        pose_estimator.close()
    return pose_data, world_landmarks
# Deprecated: joint-angle calculations were removed (not part of the 5 core metrics).
def calculate_joint_angles(keypoints):
    """Return an empty mapping; joint-angle computation has been retired.

    Retained only so that existing callers keep working. Joint angles are no
    longer one of the 5 core metrics.

    Args:
        keypoints: List of ``[x, y, visibility]`` for each landmark, or
            ``None``. The argument is accepted but ignored.

    Returns:
        dict: A new, empty dictionary on every call.
    """
    no_angles = dict()
    return no_angles