File size: 3,851 Bytes
a422282
 
 
 
 
 
 
4a5016f
a422282
 
4a5016f
 
a422282
 
 
 
4a5016f
a422282
 
 
 
 
 
 
0723bf0
 
 
4a5016f
36d65da
4a5016f
36d65da
 
0723bf0
 
 
 
4a5016f
 
36d65da
 
 
 
 
 
 
 
4a5016f
 
 
36d65da
4a5016f
 
36d65da
 
 
0723bf0
4a5016f
36d65da
 
a422282
 
 
 
 
 
 
 
 
 
 
 
36d65da
 
 
a422282
 
 
36d65da
a422282
 
36d65da
4a5016f
 
36d65da
 
a422282
 
36d65da
4a5016f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a422282
 
 
4a5016f
 
a422282
 
4a5016f
a422282
 
4a5016f
a422282
4a5016f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""
Pose estimation module for golf swing analysis
"""

import cv2
import numpy as np
import mediapipe as mp
import math
from tqdm import tqdm

# Keep only essential imports for pose detection

class PoseEstimator:
    """Wrapper around MediaPipe Pose that extracts landmarks frame by frame.

    The estimator owns a ``mp.solutions.pose.Pose`` instance, which must be
    released with :meth:`close`. The class is also a context manager, so the
    release happens automatically even when processing raises::

        with PoseEstimator() as estimator:
            keypoints, world = estimator.process_frame(frame)
    """

    # Number of entries in a fallback result (MediaPipe Pose reports 33
    # landmarks per detection).
    NUM_LANDMARKS = 33
    # A detection is only trusted when its mean landmark visibility is
    # strictly greater than this value.
    MIN_AVG_VISIBILITY = 0.3
    # Visibility assigned to every fallback keypoint so that downstream code
    # can recognise it as a very low confidence marker.
    FALLBACK_VISIBILITY = 0.05

    def __init__(self, *, model_complexity=2, min_detection_confidence=0.5,
                 min_tracking_confidence=0.5):
        """Create the underlying MediaPipe Pose solution.

        Args:
            model_complexity: Passed to ``Pose``; the default of 2 was chosen
                for improved hip/foot stability.
            min_detection_confidence: Passed to ``Pose`` unchanged.
            min_tracking_confidence: Passed to ``Pose`` unchanged.
        """
        self.mp_pose = mp.solutions.pose
        self.pose = self.mp_pose.Pose(
            static_image_mode=False,
            model_complexity=model_complexity,
            enable_segmentation=False,
            min_detection_confidence=min_detection_confidence,
            min_tracking_confidence=min_tracking_confidence,
        )

    def __enter__(self):
        """Return the estimator itself for use inside a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Release the pose solution; never suppresses an exception."""
        self.close()
        return False

    def _fallback_keypoints(self, width, height):
        """Build low-confidence keypoints all placed at the frame centre."""
        return [[width // 2, height // 2, self.FALLBACK_VISIBILITY]
                for _ in range(self.NUM_LANDMARKS)]

    def process_frame(self, frame):
        """Run pose detection on a single BGR frame.

        Args:
            frame: Image array in BGR channel order (it is converted with
                ``cv2.COLOR_BGR2RGB`` before being handed to MediaPipe).

        Returns:
            tuple: ``(keypoints, world_landmarks)`` where

            - ``keypoints`` is a list of ``[x, y, visibility]``; ``x`` and
              ``y`` are the landmark coordinates scaled by the frame width
              and height and truncated to ``int``.
            - ``world_landmarks`` is a list of ``[x, y, z]`` (in meters
              relative to the hip center), or ``None`` when MediaPipe did not
              provide world landmarks.

            When no pose is detected, fallback keypoints and ``None`` are
            returned. When the detection's average visibility is too low,
            fallback keypoints are returned together with all-zero world
            landmarks (or ``None`` if world landmarks were unavailable).
        """
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        results = self.pose.process(frame_rgb)
        h, w = frame.shape[:2]

        if not results.pose_landmarks:
            # No pose detected, return very low confidence markers
            return self._fallback_keypoints(w, h), None

        # Extract 2D keypoints in pixel coordinates
        keypoints = [
            [int(lm.x * w), int(lm.y * h), lm.visibility]
            for lm in results.pose_landmarks.landmark
        ]

        # Extract 3D world landmarks if available
        if results.pose_world_landmarks:
            world_landmarks = [
                [lm.x, lm.y, lm.z]
                for lm in results.pose_world_landmarks.landmark
            ]
        else:
            world_landmarks = None

        # Check if this is a reasonable pose detection
        avg_visibility = sum(kp[2] for kp in keypoints) / len(keypoints)
        if avg_visibility > self.MIN_AVG_VISIBILITY:
            return keypoints, world_landmarks

        # Poor quality detection: keep the output shape but signal low
        # confidence; 3D fallback mirrors whether world landmarks existed.
        if world_landmarks:
            fallback_3d = [[0.0, 0.0, 0.0] for _ in range(self.NUM_LANDMARKS)]
        else:
            fallback_3d = None
        return self._fallback_keypoints(w, h), fallback_3d

    def close(self):
        """Release the resources held by the MediaPipe Pose instance."""
        self.pose.close()

def analyze_pose(frames):
    """
    Analyze pose in video frames

    Args:
        frames (list): List of video frames

    Returns:
        tuple: (pose_data, world_landmarks) where:
            - pose_data: Dictionary mapping frame indices to 2D pose keypoints
              (every frame has an entry, because process_frame falls back to
              low-confidence markers instead of returning None)
            - world_landmarks: Dictionary mapping frame indices to 3D world
              landmarks; frames without world landmarks are omitted
    """
    pose_estimator = PoseEstimator()
    pose_data = {}
    world_landmarks = {}

    try:
        for i, frame in enumerate(tqdm(frames, desc="Analyzing pose")):
            keypoints, world_lm = pose_estimator.process_frame(frame)
            # Always store keypoints (never None due to fallback in process_frame)
            pose_data[i] = keypoints
            if world_lm is not None:
                world_landmarks[i] = world_lm
    finally:
        # Release the estimator even if a frame fails to process, so the
        # underlying pose solution is not leaked on error.
        pose_estimator.close()

    return pose_data, world_landmarks


# Deprecated: joint-angle calculations were removed because they are not part
# of the 5 core metrics. Only a backward-compatible stub is kept below.

























def calculate_joint_angles(keypoints):
    """
    Backward-compatibility stub for the retired joint-angle computation.

    Joint angles are not among the 5 core metrics, so nothing is computed
    any more; the function only exists so that existing callers keep working.

    Args:
        keypoints: List of [x, y, visibility] for each landmark or None.
            The argument is accepted but ignored.

    Returns:
        A new, empty dictionary on every call.
    """
    # No angles are produced; hand back a fresh dict each time so callers
    # that mutate the result cannot affect one another.
    no_angles = dict()
    return no_angles