Spaces:
Paused
Paused
File size: 3,851 Bytes
a422282 4a5016f a422282 4a5016f a422282 4a5016f a422282 0723bf0 4a5016f 36d65da 4a5016f 36d65da 0723bf0 4a5016f 36d65da 4a5016f 36d65da 4a5016f 36d65da 0723bf0 4a5016f 36d65da a422282 36d65da a422282 36d65da a422282 36d65da 4a5016f 36d65da a422282 36d65da 4a5016f a422282 4a5016f a422282 4a5016f a422282 4a5016f a422282 4a5016f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
"""
Pose estimation module for golf swing analysis
"""
import cv2
import numpy as np
import mediapipe as mp
import math
from tqdm import tqdm
# Keep only essential imports for pose detection
class PoseEstimator:
    """Thin wrapper around MediaPipe Pose for per-frame 2D/3D landmark extraction.

    Owns a native MediaPipe ``Pose`` graph, so instances must be released via
    :meth:`close` (or by using the instance as a context manager).
    """

    # Number of landmarks MediaPipe Pose always emits; used to size fallbacks.
    NUM_LANDMARKS = 33

    def __init__(self):
        self.mp_pose = mp.solutions.pose
        self.pose = self.mp_pose.Pose(static_image_mode=False,
                                      model_complexity=2,  # Improved hip/foot stability
                                      enable_segmentation=False,
                                      min_detection_confidence=0.5,
                                      min_tracking_confidence=0.5)

    def __enter__(self):
        """Allow ``with PoseEstimator() as pe:`` so the native graph is always freed."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        return False  # never suppress exceptions

    def process_frame(self, frame):
        """Run pose detection on one BGR frame.

        Args:
            frame: BGR image as a numpy array of shape (H, W, 3) — OpenCV convention.

        Returns:
            tuple: ``(keypoints, world_landmarks)`` where ``keypoints`` is a list of
            33 ``[x_px, y_px, visibility]`` entries (never None — a low-confidence
            centered fallback is returned when detection fails or is poor) and
            ``world_landmarks`` is a list of 33 ``[x, y, z]`` in meters relative to
            the hip center, or None when unavailable.
        """
        # MediaPipe expects RGB; OpenCV delivers BGR.
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        results = self.pose.process(frame_rgb)
        h, w, _ = frame.shape
        if results.pose_landmarks:
            keypoints = []
            world_landmarks = []
            total_visibility = 0
            # Extract 2D keypoints in pixel coordinates.
            for landmark in results.pose_landmarks.landmark:
                x, y = int(landmark.x * w), int(landmark.y * h)
                visibility = landmark.visibility
                keypoints.append([x, y, visibility])
                total_visibility += visibility
            # Extract 3D world landmarks if available.
            if results.pose_world_landmarks:
                for landmark in results.pose_world_landmarks.landmark:
                    # World landmarks are in meters relative to hip center.
                    world_landmarks.append([landmark.x, landmark.y, landmark.z])
            else:
                world_landmarks = None
            # Gate on average landmark visibility to reject junk detections.
            avg_visibility = total_visibility / len(results.pose_landmarks.landmark)
            if avg_visibility > 0.3:  # Only return if average visibility is decent
                return keypoints, world_landmarks
            # Poor quality detection: return a centered, near-zero-confidence fallback
            # so downstream code always gets 33 keypoints per frame.
            fallback_2d = [[w // 2, h // 2, 0.05] for _ in range(self.NUM_LANDMARKS)]
            fallback_3d = ([[0.0, 0.0, 0.0] for _ in range(self.NUM_LANDMARKS)]
                           if world_landmarks else None)
            return fallback_2d, fallback_3d
        # No pose detected at all: same centered fallback, no 3D data.
        fallback_2d = [[w // 2, h // 2, 0.05] for _ in range(self.NUM_LANDMARKS)]
        return fallback_2d, None

    def close(self):
        """Release the underlying MediaPipe graph. Idempotence not guaranteed."""
        self.pose.close()
def analyze_pose(frames):
    """
    Analyze pose in video frames.

    Args:
        frames (list): List of BGR video frames (numpy arrays).

    Returns:
        tuple: (pose_data, world_landmarks) where:
            - pose_data: dict mapping frame index -> 2D pose keypoints
              (33 entries of [x, y, visibility]); every frame gets an entry
              because process_frame falls back to low-confidence markers.
            - world_landmarks: dict mapping frame index -> 3D world landmarks;
              frames without 3D data are simply absent from this dict.
    """
    pose_estimator = PoseEstimator()
    pose_data = {}
    world_landmarks = {}
    try:
        for i, frame in enumerate(tqdm(frames, desc="Analyzing pose")):
            keypoints, world_lm = pose_estimator.process_frame(frame)
            # Always store keypoints (never None due to fallback in process_frame).
            pose_data[i] = keypoints
            if world_lm is not None:
                world_landmarks[i] = world_lm
    finally:
        # Release the native MediaPipe graph even if a frame raises mid-loop
        # (the original leaked the estimator on any exception).
        pose_estimator.close()
    return pose_data, world_landmarks
# Deprecated - joint angle calculations removed (not part of 5 core metrics)
def calculate_joint_angles(keypoints):
    """Deprecated stub kept for backward compatibility.

    Joint-angle computation was removed because it is not one of the
    5 core metrics; existing callers still get a well-typed result.

    Args:
        keypoints: List of [x, y, visibility] per landmark, or None. Ignored.

    Returns:
        dict: Always empty.
    """
    return {}