# Par-ity_Project/app/models/pose_estimator.py
# Commit 36d65da (chenemii): "Add head height change metric to replace hip turn metric"
"""
Pose estimation module for golf swing analysis
"""
import cv2
import numpy as np
import mediapipe as mp
import math
from tqdm import tqdm
# Keep only essential imports for pose detection
class PoseEstimator:
    """Wrap a MediaPipe Pose model to extract per-frame body landmarks.

    Each call to :meth:`process_frame` returns ``NUM_LANDMARKS`` 2D keypoints
    (``[x, y, visibility]`` in pixel coordinates) plus, when MediaPipe reports
    them, ``NUM_LANDMARKS`` 3D world landmarks (``[x, y, z]``). When no pose is
    found, or the detection is too faint, low-confidence placeholder keypoints
    are returned instead of ``None`` so callers always get a full-length list.

    The instance owns a MediaPipe graph; release it with :meth:`close` or by
    using the estimator as a context manager (``with PoseEstimator() as pe:``).
    """

    # Number of landmarks produced by the MediaPipe Pose model.
    NUM_LANDMARKS = 33
    # Visibility given to placeholder keypoints (deliberately very low
    # confidence so they are not mistaken for a real detection).
    FALLBACK_VISIBILITY = 0.05

    def __init__(self, model_complexity=2, min_detection_confidence=0.5,
                 min_tracking_confidence=0.5, min_avg_visibility=0.3):
        """Create the MediaPipe Pose model in video (tracking) mode.

        Args:
            model_complexity: MediaPipe Pose model size. Defaults to 2, chosen
                for improved hip/foot stability.
            min_detection_confidence: Forwarded to ``mp.solutions.pose.Pose``.
            min_tracking_confidence: Forwarded to ``mp.solutions.pose.Pose``.
            min_avg_visibility: A detection is accepted only when the mean
                landmark visibility is strictly greater than this value;
                otherwise placeholder keypoints are returned.
        """
        self.mp_pose = mp.solutions.pose
        # static_image_mode=False makes MediaPipe treat the frames as a video
        # stream and track the person between frames.
        self.pose = self.mp_pose.Pose(
            static_image_mode=False,
            model_complexity=model_complexity,
            enable_segmentation=False,
            min_detection_confidence=min_detection_confidence,
            min_tracking_confidence=min_tracking_confidence,
        )
        self.min_avg_visibility = min_avg_visibility
        self._closed = False

    def process_frame(self, frame):
        """Run pose estimation on a single OpenCV (BGR) frame.

        Args:
            frame: Image array with shape ``(h, w, channels)`` in BGR order,
                as produced by OpenCV.

        Returns:
            tuple: ``(keypoints, world_landmarks)``.

            - ``keypoints`` is a list of ``NUM_LANDMARKS``
              ``[x, y, visibility]`` entries, where ``x`` and ``y`` are
              integer pixel coordinates.
            - ``world_landmarks`` is a list of ``[x, y, z]`` entries, or
              ``None`` when MediaPipe did not report world landmarks.
            - If no pose is detected, the keypoints are placeholders centred
              in the frame with visibility ``FALLBACK_VISIBILITY``, and the
              world landmarks are ``None``.
            - If the mean visibility is not above ``min_avg_visibility``, the
              same placeholder keypoints are returned. The world landmarks
              become all-zero placeholders if MediaPipe reported any, else
              ``None``.
        """
        # MediaPipe expects RGB input while OpenCV frames are BGR.
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        results = self.pose.process(frame_rgb)
        h, w, _ = frame.shape

        detected = results.pose_landmarks
        if not detected:
            # No pose detected: return very low confidence placeholder markers.
            return self._fallback_keypoints(w, h), None

        # Landmark x/y are normalised to the image size; scale to pixels.
        keypoints = [
            [int(lm.x * w), int(lm.y * h), lm.visibility]
            for lm in detected.landmark
        ]

        # World landmarks are in meters relative to the hip center.
        world = results.pose_world_landmarks
        world_landmarks = (
            [[lm.x, lm.y, lm.z] for lm in world.landmark] if world else None
        )

        # Guard against an empty landmark list instead of dividing by zero.
        if keypoints:
            avg_visibility = sum(kp[2] for kp in keypoints) / len(keypoints)
            if avg_visibility > self.min_avg_visibility:
                return keypoints, world_landmarks

        # Poor-quality detection: replace with placeholders. World landmarks
        # are zeroed (not dropped) when MediaPipe reported them.
        # NOTE(review): all-zero world landmarks look like a real pose at the
        # hip origin to downstream code -- confirm this is intended.
        fallback_3d = (
            [[0.0, 0.0, 0.0] for _ in range(self.NUM_LANDMARKS)]
            if world_landmarks else None
        )
        return self._fallback_keypoints(w, h), fallback_3d

    def _fallback_keypoints(self, w, h):
        """Return fresh placeholder keypoints centred in a ``w`` x ``h`` frame."""
        return [[w // 2, h // 2, self.FALLBACK_VISIBILITY]
                for _ in range(self.NUM_LANDMARKS)]

    def close(self):
        """Release the underlying MediaPipe graph; safe to call more than once."""
        if not self._closed:
            self.pose.close()
            self._closed = True

    def __enter__(self):
        """Return the estimator itself for use in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc, tb):
        """Close the estimator on leaving the ``with`` block; never suppress errors."""
        self.close()
        return False
def analyze_pose(frames):
    """
    Analyze pose in video frames.

    Args:
        frames (list): List of video frames (OpenCV BGR images).

    Returns:
        tuple: (pose_data, world_landmarks) where:
            - pose_data: Dictionary mapping every frame index to its 2D pose
              keypoints (placeholders when no usable pose was found).
            - world_landmarks: Dictionary mapping frame indices to 3D world
              landmarks; frames without world landmarks are omitted, so this
              dict may have gaps.
    """
    pose_estimator = PoseEstimator()
    pose_data = {}
    world_landmarks = {}
    try:
        for i, frame in enumerate(tqdm(frames, desc="Analyzing pose")):
            keypoints, world_lm = pose_estimator.process_frame(frame)
            # Always store keypoints: process_frame substitutes placeholders
            # rather than returning None.
            pose_data[i] = keypoints
            if world_lm is not None:
                world_landmarks[i] = world_lm
    finally:
        # Release the MediaPipe graph even if processing a frame raises.
        pose_estimator.close()
    return pose_data, world_landmarks
# Deprecated - joint angle calculations removed (not part of 5 core metrics)
def calculate_joint_angles(keypoints):
    """Return an empty joint-angle mapping (deprecated stub).

    Joint angles are no longer computed because they are not one of the five
    core swing metrics. The function keeps its original signature only so
    existing callers continue to work.

    Args:
        keypoints: Ignored. Historically a list of ``[x, y, visibility]``
            entries per landmark, or ``None``.

    Returns:
        dict: A new, empty dictionary on every call.
    """
    del keypoints  # unused; accepted only for backward compatibility
    return dict()