"""
Pose estimation module for golf swing analysis
"""
import cv2
import numpy as np
import mediapipe as mp
import math
from tqdm import tqdm

# Keep only essential imports for pose detection
# NOTE(review): `np` and `math` are not referenced in this module; kept in
# case other modules import them from here — confirm before removing.

# Number of landmarks emitted by the MediaPipe Pose model; used to size the
# placeholder outputs returned when no usable pose is found.
NUM_POSE_LANDMARKS = 33


class PoseEstimator:
    """Per-frame pose extraction backed by MediaPipe Pose.

    Supports the context-manager protocol so the underlying MediaPipe graph is
    released even if processing raises::

        with PoseEstimator() as est:
            keypoints, world = est.process_frame(frame)

    All constructor arguments are keyword-only and default to the values this
    module has always used, so ``PoseEstimator()`` behaves as before.

    Args:
        model_complexity: MediaPipe Pose model complexity (0, 1 or 2).
        min_detection_confidence: Passed through to ``mp.solutions.pose.Pose``.
        min_tracking_confidence: Passed through to ``mp.solutions.pose.Pose``.
        min_avg_visibility: A detection is accepted only when the mean
            per-landmark visibility is strictly greater than this value;
            otherwise placeholder keypoints are returned.
    """

    def __init__(self, *, model_complexity=2, min_detection_confidence=0.5,
                 min_tracking_confidence=0.5, min_avg_visibility=0.3):
        self.mp_pose = mp.solutions.pose
        self.min_avg_visibility = min_avg_visibility
        self.pose = self.mp_pose.Pose(
            static_image_mode=False,  # video mode: track across frames
            model_complexity=model_complexity,  # 2 improves hip/foot stability
            enable_segmentation=False,
            min_detection_confidence=min_detection_confidence,
            min_tracking_confidence=min_tracking_confidence,
        )

    @staticmethod
    def _fallback_keypoints(w, h):
        """Return placeholder 2D keypoints: every landmark at the frame centre
        with visibility 0.05."""
        return [[w // 2, h // 2, 0.05] for _ in range(NUM_POSE_LANDMARKS)]

    def process_frame(self, frame):
        """Run pose estimation on a single frame.

        Args:
            frame: Image array in BGR channel order (it is converted with
                ``cv2.COLOR_BGR2RGB`` before inference).

        Returns:
            tuple: ``(keypoints, world_landmarks)`` where

            - ``keypoints`` is a list of ``[x, y, visibility]`` with ``x``/``y``
              in integer pixel coordinates. It is never ``None``: when no pose
              is found, or the mean visibility is not above
              ``min_avg_visibility``, placeholder keypoints at the frame centre
              with visibility 0.05 are returned instead.
            - ``world_landmarks`` is a list of ``[x, y, z]`` world coordinates,
              or ``None`` when MediaPipe produced none. For a rejected
              low-visibility detection that did have world landmarks, a list of
              ``[0.0, 0.0, 0.0]`` placeholders is returned.
        """
        h, w = frame.shape[:2]
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        results = self.pose.process(frame_rgb)

        if not results.pose_landmarks:
            # No pose detected: low-confidence placeholders and no 3D data.
            return self._fallback_keypoints(w, h), None

        # 2D keypoints: normalised landmark coords scaled to pixels.
        keypoints = [
            [int(lm.x * w), int(lm.y * h), lm.visibility]
            for lm in results.pose_landmarks.landmark
        ]

        # 3D world landmarks, if MediaPipe produced them
        # (in meters, relative to the hip centre).
        if results.pose_world_landmarks:
            world_landmarks = [
                [lm.x, lm.y, lm.z]
                for lm in results.pose_world_landmarks.landmark
            ]
        else:
            world_landmarks = None

        # Reject low-quality detections based on mean visibility.
        if keypoints:
            avg_visibility = sum(kp[2] for kp in keypoints) / len(keypoints)
        else:
            avg_visibility = 0.0
        if avg_visibility > self.min_avg_visibility:
            return keypoints, world_landmarks

        # Poor-quality detection: placeholders instead of unreliable points.
        # NOTE(review): the zero 3D placeholders are indistinguishable from
        # real data without checking the matching 2D visibility — confirm
        # downstream consumers handle this.
        fallback_3d = (
            [[0.0, 0.0, 0.0] for _ in range(NUM_POSE_LANDMARKS)]
            if world_landmarks else None
        )
        return self._fallback_keypoints(w, h), fallback_3d

    def close(self):
        """Release the underlying MediaPipe Pose resources."""
        self.pose.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never suppress exceptions


def analyze_pose(frames):
    """
    Analyze pose in video frames

    Args:
        frames (list): List of video frames (BGR images)

    Returns:
        tuple: (pose_data, world_landmarks) where:
            - pose_data: Dictionary mapping frame indices to 2D pose keypoints
              (every frame index is present; see PoseEstimator.process_frame
              for the placeholder values used when no usable pose is found)
            - world_landmarks: Dictionary mapping frame indices to 3D world
              landmarks (only frames for which 3D data was returned)
    """
    pose_data = {}
    world_landmarks = {}

    # The context manager guarantees the MediaPipe graph is closed even if a
    # frame fails to process.
    with PoseEstimator() as pose_estimator:
        for i, frame in enumerate(tqdm(frames, desc="Analyzing pose")):
            keypoints, world_lm = pose_estimator.process_frame(frame)

            # Always store keypoints (never None due to fallback in process_frame)
            pose_data[i] = keypoints

            if world_lm is not None:
                world_landmarks[i] = world_lm

    return pose_data, world_landmarks


# Deprecated - joint angle calculations removed (not part of 5 core metrics)
def calculate_joint_angles(keypoints):
    """
    Deprecated function - joint angles are not part of the 5 core metrics.
    Returns empty dict to maintain backward compatibility.

    Args:
        keypoints: List of [x, y, visibility] for each landmark or None
            (ignored)

    Returns:
        Empty dictionary (joint angles deprecated)
    """
    # Return empty dict since joint angles are not part of the 5 core metrics
    return {}