Spaces:
Paused
Paused
| """ | |
| LLM-based Golf Swing Analysis | |
| This module handles LLM analysis and prompt generation for golf swing feedback. | |
| """ | |
| import os | |
| import json | |
| import requests | |
| import re | |
| from typing import Dict, Any, Optional | |
| from .metrics_calculator import ( | |
| calculate_back_tilt_degree, calculate_knee_bend_degree, | |
| calculate_shoulder_tilt_swing_plane_at_top, | |
| compute_dtl_three | |
| ) | |
| from .front_facing_metrics import compute_front_facing_metrics | |
# Import grading functions to avoid duplication
try:
    from ..streamlit_app import (
        get_shoulder_tilt_swing_plane_grading,
        get_back_tilt_grading,
        get_knee_flexion_grading,
        get_head_drop_grading,
        get_shoulder_tilt_impact_grading,
        get_hip_sway_grading,
        get_wrist_hinge_grading,
        get_hip_shoulder_separation_impact_grading,
    )
except ImportError:
    # Fallback if imports fail - define minimal placeholder graders that mark
    # each metric as ungraded ('n/a') while still passing the value through.
    def get_shoulder_tilt_swing_plane_grading(value, confidence):
        return {'value': value, 'status': 'n/a', 'badge': '⚪'}

    def get_back_tilt_grading(value, confidence):
        return {'value': value, 'status': 'n/a', 'badge': '⚪'}

    def get_knee_flexion_grading(value, confidence):
        return {'value': value, 'status': 'n/a', 'badge': '⚪'}

    def get_shoulder_tilt_impact_grading(value, confidence):
        return {'value': value, 'status': 'n/a', 'badge': '⚪'}

    def get_hip_sway_grading(value, confidence, position):
        return {'value': value, 'status': 'n/a', 'badge': '⚪'}

    def get_wrist_hinge_grading(value, confidence, position):
        return {'value': value, 'status': 'n/a', 'badge': '⚪'}

    def get_hip_shoulder_separation_impact_grading(value, confidence):
        return {'value': value, 'status': 'n/a', 'badge': '⚪'}

    def get_head_drop_grading(value, confidence):
        """Grade head drop metric based on percentage of torso length.

        Args:
            value: percentage value (positive = head moved down)
            confidence: confidence score (not used in current implementation)

        Returns:
            dict: 'value' formatted as e.g. "4.0%" (None on error), a
            'status' of good/caution/poor/error, and a matching 'badge'.
        """
        if value is None:
            return {'value': None, 'status': 'error', 'badge': '❌'}
        # Address -> Top: ~2-6% typical; >8% = too much "sit"
        if 2 <= value <= 6:
            return {'value': f"{value:.1f}%", 'status': 'good', 'badge': '✅'}
        elif 6 < value <= 8:
            return {'value': f"{value:.1f}%", 'status': 'caution', 'badge': '⚠️'}
        else:
            return {'value': f"{value:.1f}%", 'status': 'poor', 'badge': '❌'}

# get_hip_depth_grading is resolved on its own so that it exists whether or
# not the import above succeeds (it is not part of that import list), and so
# that its absence from streamlit_app cannot disable the other graders.
try:
    from ..streamlit_app import get_hip_depth_grading
except ImportError:
    def get_hip_depth_grading(value, confidence):
        """Placeholder hip-depth grader used when streamlit_app is unavailable."""
        return {'value': value, 'status': 'n/a', 'badge': '⚪'}
def safe_fmt_deg(v):
    """Render *v* as a one-decimal angle string such as ``'12.3°'``.

    ``None``, the strings ``'none'`` / ``'n/a'`` / ``''`` (case-insensitive),
    and anything ``float()`` cannot convert all yield ``'n/a'``.
    """
    missing_markers = ('none', 'n/a', '')
    if v is None:
        return 'n/a'
    if isinstance(v, str) and v.lower() in missing_markers:
        return 'n/a'
    try:
        as_float = float(v)
    except (ValueError, TypeError):
        return 'n/a'
    return f"{as_float:.1f}°"
def safe_fmt_percent(v):
    """Render *v* as a one-decimal percentage string such as ``'4.5%'``.

    Placeholder inputs (``None`` or the strings ``'none'``, ``'n/a'``, ``''``
    in any case) and values ``float()`` rejects produce ``'n/a'``.
    """
    is_placeholder = v is None or (isinstance(v, str) and v.lower() in {'none', 'n/a', ''})
    if is_placeholder:
        return 'n/a'
    try:
        return '{:.1f}%'.format(float(v))
    except (TypeError, ValueError):
        return 'n/a'
def check_llm_services():
    """Report which LLM back-ends appear usable.

    Returns:
        dict: ``{'openai': {'available': bool}, 'ollama': {'available': bool}}``.
        OpenAI counts as available when ``OPENAI_API_KEY`` is set to a
        non-empty value. Ollama counts as available when
        ``GET http://localhost:11434/api/version`` answers with HTTP 200
        within the 2-second timeout.
    """
    services = {
        'openai': {'available': bool(os.getenv('OPENAI_API_KEY'))},
        'ollama': {'available': False},  # updated by the probe below
    }
    # Probe the local Ollama server. Network failures (connection refused,
    # timeout, ...) simply mean "not available". Catch only requests' own
    # exception hierarchy so KeyboardInterrupt/SystemExit still propagate.
    try:
        response = requests.get('http://localhost:11434/api/version', timeout=2)
    except requests.RequestException:
        pass
    else:
        services['ollama']['available'] = response.status_code == 200
    return services
def call_openai_service(prompt, config):
    """Send *prompt* to the OpenAI chat-completions endpoint.

    Args:
        prompt (str): user message sent as the single chat turn.
        config (dict): optional ``model`` (default ``'gpt-4'``),
            ``max_tokens`` (default 2000) and ``temperature`` (default 0.7).

    Returns:
        str | None: content of the first choice, or ``None`` when
        ``OPENAI_API_KEY`` is unset/empty or the request or response
        handling fails (the error is printed).
    """
    api_key = os.getenv('OPENAI_API_KEY')
    if not api_key:
        return None

    request_headers = {
        'Authorization': f'Bearer {api_key}',
        'Content-Type': 'application/json',
    }
    request_body = {
        'model': config.get('model', 'gpt-4'),
        'messages': [{'role': 'user', 'content': prompt}],
        'max_tokens': config.get('max_tokens', 2000),
        'temperature': config.get('temperature', 0.7),
    }

    try:
        reply = requests.post(
            'https://api.openai.com/v1/chat/completions',
            headers=request_headers,
            json=request_body,
            timeout=30,
        )
        reply.raise_for_status()
        first_choice = reply.json()['choices'][0]
        return first_choice['message']['content']
    except Exception as err:
        print(f"OpenAI API error: {err}")
        return None
def call_ollama_service(prompt, config):
    """Send *prompt* to the local Ollama ``/api/generate`` endpoint (non-streaming).

    Args:
        prompt (str): text sent as the generation prompt.
        config (dict): optional ``model`` (default ``'llama2'``).

    Returns:
        str | None: the ``'response'`` field of the JSON reply, or ``None``
        when any step fails (the error is printed).
    """
    try:
        request_body = {
            'model': config.get('model', 'llama2'),
            'prompt': prompt,
            'stream': False,
        }
        reply = requests.post(
            'http://localhost:11434/api/generate',
            json=request_body,
            timeout=60,
        )
        reply.raise_for_status()
        return reply.json().get('response')
    except Exception as err:
        print(f"Ollama service error: {err}")
        return None
def compute_core_metrics(pose_data, swing_phases, frame_timestamps_ms=None, total_ms=None, player_handedness='right', is_front_facing=False, frames=None, world_landmarks=None, club="iron"):
    """Dispatch to the front-facing or down-the-line (DTL) metric pipeline.

    Args:
        pose_data (dict): frame index -> pose keypoints.
        swing_phases (dict): phase name -> list of frame indices.
        frame_timestamps_ms (list, optional): per-frame timestamps in ms
            (forwarded to the DTL pipeline only).
        total_ms (float, optional): total video duration in ms
            (forwarded to the DTL pipeline only).
        player_handedness (str): 'right' or 'left' (forwarded to the DTL
            pipeline only).
        is_front_facing (bool): True selects the front-facing pipeline.
        frames (optional): video frames, forwarded to either pipeline.
        world_landmarks (optional): forwarded to the front-facing pipeline only.
        club (str): club type ("driver", "iron", "wedge"), forwarded to the
            front-facing pipeline only.

    Returns:
        dict: core metrics with grading/status information.
    """
    if not is_front_facing:
        return compute_dtl_core_metrics(
            pose_data, swing_phases, frame_timestamps_ms, total_ms,
            player_handedness, frames,
        )
    return compute_front_facing_core_metrics(
        pose_data, swing_phases, frames, world_landmarks, club,
    )
def compute_dtl_core_metrics(pose_data, swing_phases, frame_timestamps_ms=None, total_ms=None, player_handedness='right', frames=None):
    """Compute DTL (Down-the-Line) golf swing metrics via ``compute_dtl_three``.

    Args:
        pose_data (dict): frame index -> pose keypoints.
        swing_phases (dict): phase name -> list of frame indices.
        frame_timestamps_ms, total_ms, player_handedness: not used by the
            unified path; forwarded to compute_dtl_core_metrics_fallback.
        frames (sequence, optional): video frames. The first frame's
            ``.shape`` (if present) supplies the frame height/width.

    Returns:
        dict: legacy metric key -> grading dict, plus ``_validation`` /
        ``_debug`` entries when compute_dtl_three provides them. The result of
        compute_dtl_core_metrics_fallback is returned instead when
        compute_dtl_three returns None, or when it or the grading step raises.
    """
    frame_w, frame_h = 1920, 1080  # defaults when no frame shape is available
    # Explicit None/len checks instead of `if frames:` so that a NumPy array of
    # frames does not raise "truth value of an array is ambiguous".
    if frames is not None and len(frames) > 0 and hasattr(frames[0], 'shape'):
        frame_h, frame_w = frames[0].shape[:2]

    try:
        dtl_metrics = compute_dtl_three(pose_data, swing_phases, frame_w, frame_h, frames)
        core_metrics = None if dtl_metrics is None else _grade_dtl_metrics(dtl_metrics)
    except Exception as e:
        print(f"Error in compute_dtl_three: {e}")
        core_metrics = None

    if core_metrics is None:
        # Called outside the try so that a failure inside the fallback is not
        # reported as a compute_dtl_three error and the fallback re-run.
        return compute_dtl_core_metrics_fallback(
            pose_data, swing_phases, frame_timestamps_ms, total_ms,
            player_handedness, frames,
        )
    return core_metrics


def _grade_dtl_metrics(dtl_metrics):
    """Map compute_dtl_three's keys to the legacy metric keys and grade each value.

    Metrics whose value is None are omitted. If a grader raises, the raw value
    is kept with status 'Calculated'. ``_validation`` and ``_debug`` entries are
    copied through when truthy. (Hip depth and hip turn are not included.)
    """
    metric_mapping = {
        'shoulder_plane_top_deg': ('shoulder_tilt_swing_plane_top_deg', get_shoulder_tilt_swing_plane_grading),
        'back_tilt_setup_deg': ('back_tilt_deg', get_back_tilt_grading),
        'knee_flexion_deg': ('knee_flexion_deg', get_knee_flexion_grading),
        'head_drop_top_pct': ('head_drop_top_pct', get_head_drop_grading),
    }
    core_metrics = {}
    for new_key, (old_key, grading_func) in metric_mapping.items():
        value = dtl_metrics.get(new_key)
        if value is None:
            continue
        try:
            grading = grading_func(value, 0.8)
        except Exception:
            # Keep the raw value if grading fails.
            grading = {'value': value, 'status': 'Calculated', 'badge': '✅'}
        if grading:
            core_metrics[old_key] = grading

    for extra_key in ('_validation', '_debug'):
        extra = dtl_metrics.get(extra_key)
        if extra:
            core_metrics[extra_key] = extra
    return core_metrics
def compute_dtl_core_metrics_fallback(pose_data, swing_phases, frame_timestamps_ms=None, total_ms=None, player_handedness='right', frames=None):
    """Compute DTL metrics one at a time when the unified path is unusable.

    Each metric is best-effort. A calculation that raises, or a missing pose
    for the relevant key frame, omits that metric without aborting the others.

    Args:
        pose_data (dict): frame index -> pose keypoints.
        swing_phases (dict): phase name -> list of frame indices.
        frame_timestamps_ms, total_ms, player_handedness: accepted for
            signature parity with compute_dtl_core_metrics; not used here.
        frames (optional): video frames passed to the calculators.

    Returns:
        dict: any of 'shoulder_tilt_swing_plane_top_deg', 'back_tilt_deg' and
        'knee_flexion_deg' mapped to their grading dicts.
    """
    setup_frames = swing_phases.get("setup", [])
    backswing_frames = swing_phases.get("backswing", [])

    # Key frames: address = first setup frame (default 0);
    # top = last backswing frame (default: the address frame).
    address_idx = setup_frames[0] if setup_frames else 0
    top_idx = backswing_frames[-1] if backswing_frames else address_idx

    def has_pose(idx):
        return idx in pose_data and pose_data[idx] is not None

    core_metrics = {}

    # Shoulder tilt / swing plane angle at top (professional = 36°, 30 handicap = 29°).
    try:
        if backswing_frames and has_pose(top_idx):
            shoulder_tilt = calculate_shoulder_tilt_swing_plane_at_top(pose_data, swing_phases, top_idx, frames)
            if isinstance(shoulder_tilt, (int, float)):
                grading = get_shoulder_tilt_swing_plane_grading(shoulder_tilt, 0.8)
                if grading and isinstance(grading, dict):
                    core_metrics["shoulder_tilt_swing_plane_top_deg"] = grading
    except Exception as e:
        print(f"Shoulder tilt / swing plane calculation failed: {e}")

    # Back tilt at setup (address frame).
    try:
        if setup_frames and has_pose(address_idx):
            back_tilt = calculate_back_tilt_degree(pose_data, swing_phases, address_idx, frames)
            if isinstance(back_tilt, (int, float)):
                grading = get_back_tilt_grading(back_tilt, 0.8)
                if grading and isinstance(grading, dict):
                    core_metrics["back_tilt_deg"] = grading
    except Exception as e:
        print(f"Back tilt calculation failed: {e}")

    # Knee flexion at setup. Guarded like the other metrics so a failure here
    # does not abort the whole fallback.
    try:
        if has_pose(address_idx):
            knee_bend_data = calculate_knee_bend_degree(pose_data, address_idx)
            primary_value = knee_bend_data.get('primary_value') if knee_bend_data is not None else None  # average or single value
            if primary_value is not None:
                grading = get_knee_flexion_grading(primary_value, 0.8)
                if grading and isinstance(grading, dict):
                    # Attach separate lead/trail values when available.
                    for side_key in ('lead_knee_flexion', 'trail_knee_flexion'):
                        side_value = knee_bend_data.get(side_key)
                        if side_value is not None:
                            grading[side_key] = round(side_value, 1)
                    core_metrics["knee_flexion_deg"] = grading
    except Exception as e:
        print(f"Knee flexion calculation failed: {e}")

    # Hip depth / early extension and hip turn are not computed here, matching
    # the primary path in compute_dtl_core_metrics. This module never imports
    # calculate_hip_depth_early_extension, so attempting hip depth could only
    # produce a placeholder error entry.
    return core_metrics
def compute_front_facing_core_metrics(pose_data, swing_phases, frames=None, world_landmarks=None, club="iron", handedness="right"):
    """Compute and grade front-facing golf swing metrics.

    Args:
        pose_data (dict): frame index -> pose keypoints.
        swing_phases (dict): phase name -> list of frame indices.
        frames (optional): video frames forwarded to compute_front_facing_metrics.
        world_landmarks (optional): forwarded to compute_front_facing_metrics.
        club (str): club type ("driver", "iron", "wedge") forwarded for grading.
        handedness (str): player handedness forwarded to
            compute_front_facing_metrics. Defaults to "right" so existing
            callers are unaffected.

    Returns:
        dict: metric name -> grading dict. Metrics whose name contains
        'hip_turn_impact' are dropped. Names without a dedicated grader get
        ``{'value': ..., 'status': 'n/a', 'badge': '⚪'}``. Returns an empty
        dict when compute_front_facing_metrics returns nothing.
    """
    front_metrics = compute_front_facing_metrics(
        pose_data, swing_phases, world_landmarks=world_landmarks,
        frames=frames, club=club, handedness=handedness,
    )
    # Treat a missing result as "no metrics" rather than crashing on .items().
    if not front_metrics:
        return {}

    core_metrics = {}
    for metric_name, metric_data in front_metrics.items():
        if 'hip_turn_impact' in metric_name:
            continue  # hip turn metric is intentionally excluded

        is_dict = isinstance(metric_data, dict)
        value = metric_data.get('value') if is_dict else metric_data

        if 'shoulder_tilt_impact' in metric_name:
            core_metrics[metric_name] = get_shoulder_tilt_impact_grading(value, 0.9)
        elif 'hip_shoulder_separation_impact' in metric_name:
            confidence = metric_data.get('confidence', 0.8) if is_dict else 0.8
            core_metrics[metric_name] = get_hip_shoulder_separation_impact_grading(value, confidence)
        elif 'hip_sway_top' in metric_name:
            core_metrics[metric_name] = get_hip_sway_grading(value, 0.8, "top")
        elif 'wrist_hinge_top' in metric_name:
            core_metrics[metric_name] = get_wrist_hinge_grading(value, 0.8, "top")
        else:
            # No dedicated grader: include the value with the same 'n/a' badge
            # the placeholder graders use, so every entry carries a badge.
            core_metrics[metric_name] = {'value': value, 'status': 'n/a', 'badge': '⚪'}
    return core_metrics
def prepare_data_for_llm(pose_data, swing_phases, trajectory_data=None, fps=30.0, frame_shape=None, frame_timestamps_ms=None, total_ms=None, is_front_facing=False, frames=None):
    """Bundle per-phase timing and core metrics into the dict sent to the LLM.

    Args:
        pose_data (dict): frame index -> pose keypoints.
        swing_phases (dict): phase name -> list of frame indices.
        trajectory_data (dict, optional): accepted for compatibility; not read here.
        fps (float): frame rate used when *total_ms* is not given, and reported
            as 'actual_fps' when the per-frame duration works out to zero.
        frame_shape (tuple, optional): accepted for compatibility; not read here.
        frame_timestamps_ms (list, optional): forwarded to compute_core_metrics.
        total_ms (float, optional): duration spread evenly over the phase frames.
        is_front_facing (bool): selects the front-facing metric pipeline.
        frames (optional): video frames forwarded to compute_core_metrics.

    Returns:
        dict: 'swing_phases' (frame_count/duration_ms per phase),
        'timing_metrics' and 'core_metrics'.
    """
    core_metrics = compute_core_metrics(
        pose_data, swing_phases, frame_timestamps_ms, total_ms,
        player_handedness='right', is_front_facing=is_front_facing, frames=frames,
    )

    phase_names = ("setup", "backswing", "downswing", "impact", "follow_through")
    frame_counts = {name: len(swing_phases.get(name, [])) for name in phase_names}
    total_frames = sum(frame_counts.values())

    # Without an explicit duration, assume each phase frame lasts 1/fps seconds.
    if total_ms is None:
        total_ms = total_frames * (1000.0 / fps)
    seconds_per_frame = total_ms / max(total_frames, 1) / 1000.0

    phases_summary = {
        name: {
            "frame_count": count,
            "duration_ms": count * seconds_per_frame * 1000.0,
        }
        for name, count in frame_counts.items()
    }

    if seconds_per_frame > 0:
        effective_fps = 1.0 / seconds_per_frame
    else:
        effective_fps = fps

    return {
        "swing_phases": phases_summary,
        "timing_metrics": {
            "total_swing_frames": total_frames,
            "total_swing_time_ms": total_frames * seconds_per_frame * 1000.0,
            "actual_fps": round(effective_fps, 1),
        },
        "core_metrics": core_metrics,
    }
def create_llm_prompt(analysis_data):
    """Build the LLM prompt: fixed instructions followed by a SWING METRICS list.

    Args:
        analysis_data (dict): output of prepare_data_for_llm. Only its
            'core_metrics' mapping (metric key -> grading dict) is read.

    Returns:
        str: the prompt template, a blank line, then "SWING METRICS:" and one
        "- Name: value[ (confidence: NN%)] - status" line per metric that has
        a value. Bookkeeping keys starting with '_' (e.g. '_validation',
        '_debug') and non-dict entries are skipped.
    """
    prompt_template = """# Golf Swing Analysis
## NEW METRICS CALIBRATION
Use these new professional/amateur benchmarks for scoring. These represent updated golf swing mechanics based on the latest analysis:
### **NEW FRONT-FACING METRICS:**
**Shoulder Tilt @ Impact:**
- Professional = 39°
- 30 Handicap = 27°
**Hip Sway @ Top:**
- Professional = 3.9" towards target
- 30 Handicap = 2.5" towards target
**Wrist Hinge @ Top:**
- To be measured and calibrated
### **NEW DTL METRICS:**
**Shoulder Tilt/Swing Plane Angle @ Top:**
- Professional = 36° (Iron avg: 34.3°, Driver avg: 30.5°)
- 30 Handicap = 29°
**Back Tilt (°):**
- Professional = 31° (Iron avg: 30.8°, Driver avg: 32.3°)
**Knee Flexion (°):**
- Professional = 22° (Iron avg: 21.6°, Driver avg: 27.2°)
**Head Drop/Rise @ Top (%):**
- Professional = 0-5% drop (Iron: 0-7.6% range, Driver: 4.6-5.5% drop)
### **ANALYSIS INSTRUCTIONS**
**GOLF SWING ANALYSIS FORMAT**
Use the benchmarks above to guide your evaluation. Follow this exact format:
**OVERALL_SUMMARY:** [1-2 sentences maximum providing a concise evaluation of the swing's overall quality and main strengths/areas for improvement]
**PERFORMANCE_CLASSIFICATION:** [XX%]
(XX = number from 10% to 100%)
**Metric Evaluations**
For each metric below, write exactly 3 sentences evaluating the metric:
1. First sentence: State if it's good, bad, or needs improvement compared to professional standards
2. Second sentence: Compare the specific value to professional/amateur ranges
3. Third sentence: Brief explanation of impact on swing performance
**Classification Bands:**
- **90–100%**: Tour-level
- **80–89%**: Advanced amateur
- **70–79%**: Skilled
- **60–69%**: Intermediate
- **50–59%**: Developing
- **40–49%**: Beginner
- **10–39%**: Novice
**STYLE & FORMATTING RULES:**
- Use these headers: OVERALL_SUMMARY, PERFORMANCE_CLASSIFICATION, Metric Evaluations
- No emojis anywhere in the response
- Write 1-2 sentences maximum for the overall summary
- Write exactly 3 sentences for each metric evaluation
- Use a positive, coaching tone throughout
- Focus on biomechanics and compare actual values to the professional ranges provided
"""
    core_metrics = (analysis_data or {}).get('core_metrics') or {}

    metric_lines = ["SWING METRICS:"]
    for metric_name, metric_data in core_metrics.items():
        # Skip bookkeeping entries and anything that is not a grading dict.
        if str(metric_name).startswith('_') or not isinstance(metric_data, dict):
            continue
        value = metric_data.get('value')
        if value is None:
            continue
        status = metric_data.get('status', 'unknown')
        confidence_text = _format_confidence(metric_data.get('confidence'))

        if isinstance(value, dict):
            # Complex metrics carry their numbers in a nested dict.
            if 'depth_loss_pct' in value:
                metric_lines.append(
                    f"- Hip Depth/Early Extension: {value['depth_loss_pct']}% loss{confidence_text} - {status}"
                )
            elif 'displacement_pct' in value:
                direction = value.get('displacement_direction', 'unknown')
                metric_lines.append(
                    f"- Head Displacement: {value['displacement_pct']}% {direction}{confidence_text} - {status}"
                )
        else:
            label = metric_name.replace('_', ' ').title()
            metric_lines.append(
                f"- {label}: {_format_metric_value(metric_name, value)}{confidence_text} - {status}"
            )

    metrics_text = "\n".join(metric_lines) + "\n"
    return f"{prompt_template}\n\n{metrics_text}"


def _format_confidence(confidence):
    """Return ' (confidence: NN%)' for a numeric confidence, else an empty string.

    Graders do not always report a confidence. Printing a default of 0% would
    tell the LLM the measurement is worthless, so the annotation is omitted.
    """
    if isinstance(confidence, (int, float)) and not isinstance(confidence, bool):
        return f" (confidence: {confidence:.0%})"
    return ""


def _format_metric_value(metric_name, value):
    """Render a scalar metric value, choosing a unit from *metric_name*."""
    # Graders may already return display strings such as "4.2%". Pass them
    # through unchanged instead of appending a second unit.
    if isinstance(value, str):
        return value
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return str(value)
    if 'pct' in metric_name:
        return f"{value:.1f}%"
    if 'sway' in metric_name:
        # NOTE(review): the prompt benchmarks hip sway in inches, but the unit of
        # the computed value is not established here, so no unit is printed.
        return f"{value:.1f}"
    return f"{value:.1f}°"
def generate_swing_analysis(pose_data, swing_phases, trajectory_data):
    """Run the full pipeline (metrics -> prompt -> LLM) and return the analysis text.

    OpenAI is tried first when OPENAI_API_KEY is set. If OpenAI is
    unavailable or its call returns no text, a reachable local Ollama server
    is tried next.

    Returns:
        str: the LLM's analysis, or "LLM analysis unavailable. Please check
        your API configuration." when no service produced text.
    """
    analysis_data = prepare_data_for_llm(pose_data, swing_phases, trajectory_data)
    prompt = create_llm_prompt(analysis_data)
    services = check_llm_services()

    openai_config = {
        'model': 'gpt-4',
        'max_tokens': 2000,
        'temperature': 0.7,
    }
    # 'gpt-4' is an OpenAI model name that a local Ollama install is unlikely
    # to have, so let call_ollama_service fall back to its own default model.
    ollama_config = {key: val for key, val in openai_config.items() if key != 'model'}

    analysis = None
    if services.get('openai', {}).get('available'):
        analysis = call_openai_service(prompt, openai_config)
    # Fall back to Ollama when OpenAI is unavailable *or* its call failed.
    if not analysis and services.get('ollama', {}).get('available'):
        analysis = call_ollama_service(prompt, ollama_config)

    return analysis or "LLM analysis unavailable. Please check your API configuration."
def parse_and_format_analysis(raw_analysis):
    """Strip markdown emphasis from an LLM response for plain-text display.

    Returns:
        None for empty/None input. Otherwise a dict with 'formatted_analysis'
        (trimmed text with ``**bold**`` then ``*italic*`` markers removed)
        and the untouched 'raw_analysis'.
    """
    if not raw_analysis:
        return None
    text = raw_analysis.strip()
    # Remove bold (double-asterisk) markers first, then single-asterisk ones.
    for pattern in (r'\*\*(.*?)\*\*', r'\*(.*?)\*'):
        text = re.sub(pattern, r'\1', text)
    return {'formatted_analysis': text, 'raw_analysis': raw_analysis}
def display_formatted_analysis(analysis_data):
    """Return the text to show for a parsed analysis (compatibility shim).

    Prefers 'formatted_analysis' (even if its value is None), then
    'raw_analysis', then the placeholder "No analysis available". A falsy
    *analysis_data* (None or an empty dict) yields "No analysis data available".
    """
    if not analysis_data:
        return "No analysis data available"
    if 'formatted_analysis' in analysis_data:
        return analysis_data['formatted_analysis']
    return analysis_data.get('raw_analysis', 'No analysis available')
def test_dtl_five_metrics_fixes(pose_data, swing_phases, frames=None):
    """Smoke-check the unified DTL metrics and print a human-readable report.

    Runs compute_dtl_three (the function compute_dtl_core_metrics relies on).
    It then checks that each expected metric is present and, for angle
    metrics, below 80°. Readings near 90° were the symptom of earlier
    shoulder-tilt/hip-turn calculation bugs.

    Args:
        pose_data (dict): Dictionary mapping frame indices to pose keypoints
        swing_phases (dict): Dictionary mapping phase names to lists of frame indices
        frames (list, optional): Video frames for analysis. The first frame's
            ``.shape`` (if present) supplies the frame height/width.

    Returns:
        dict | None: None if compute_dtl_three returned None. Otherwise a
        dict with 'total_metrics', 'expected_metrics', 'metric_values',
        'validation_results' and 'overall_status'.
    """
    print("\n=== Testing DTL Five Metrics Fixes ===")

    frame_w, frame_h = 1920, 1080  # same defaults as compute_dtl_core_metrics
    if frames is not None and len(frames) > 0 and hasattr(frames[0], 'shape'):
        frame_h, frame_w = frames[0].shape[:2]

    # This module imports compute_dtl_three; compute_dtl_five_metrics is neither
    # defined nor imported here, so calling it could only raise NameError.
    metrics = compute_dtl_three(pose_data, swing_phases, frame_w, frame_h, frames)
    if metrics is None:
        print("❌ Failed: compute_dtl_three returned None")
        return None
    print(f"✅ Success: Got {len(metrics)} metrics")

    # The keys compute_dtl_core_metrics reads from compute_dtl_three
    # (hip depth and hip turn are no longer part of the DTL set).
    expected_metrics = [
        "shoulder_plane_top_deg",
        "back_tilt_setup_deg",
        "knee_flexion_deg",
        "head_drop_top_pct",
    ]
    results = {
        "total_metrics": len(metrics),
        "expected_metrics": len(expected_metrics),
        "metric_values": {},
        "validation_results": {},
    }

    for metric_name in expected_metrics:
        value = metrics.get(metric_name)
        results["metric_values"][metric_name] = value
        if value is None:
            print(f"⚠️ {metric_name}: None (calculation failed or outside sanity bounds)")
            results["validation_results"][metric_name] = "failed_or_clamped"
        elif metric_name.endswith("_deg") and isinstance(value, (int, float)) and value >= 80:
            print(f"❌ {metric_name}: {value}° (unrealistic high value - check calculation)")
            results["validation_results"][metric_name] = "unrealistic_high"
        else:
            unit = '%' if 'pct' in metric_name else '°'
            print(f"✅ {metric_name}: {value}{unit} (reasonable value)")
            results["validation_results"][metric_name] = "success"

    # Summary
    successful_metrics = sum(1 for v in results["validation_results"].values() if v == "success")
    print(f"\n📊 Summary: {successful_metrics}/{len(expected_metrics)} metrics returning reasonable values")
    if successful_metrics >= 4:
        print("🎉 Overall: GOOD - Most metrics working properly")
    elif successful_metrics >= 2:
        print("⚠️ Overall: FAIR - Some metrics still need work")
    else:
        print("❌ Overall: POOR - Major issues remain")
    results["overall_status"] = "good" if successful_metrics >= 4 else "fair" if successful_metrics >= 2 else "poor"
    return results
def get_hip_depth_grading_from_value(value, confidence):
    """Grade hip depth / early extension from a bare depth-loss percentage.

    Wraps *value* in the ``{'depth_loss_pct', 'confidence'}`` dict shape and
    delegates to streamlit_app's get_hip_depth_grading when it can be
    imported. Otherwise it applies local thresholds: <= 5 excellent,
    <= 15 good, anything else needs work.

    Returns:
        None when *value* is None, else a grading dict.
    """
    if value is None:
        return None

    hip_depth_data = {'depth_loss_pct': value, 'confidence': confidence}

    try:
        from ..streamlit_app import get_hip_depth_grading
        return get_hip_depth_grading(hip_depth_data, confidence)
    except ImportError:
        pass

    # Local fallback thresholds, checked from tightest to loosest.
    bands = (
        (5, "Excellent - minimal early extension", "🟢"),
        (15, "Good - slight early extension", "🟡"),
    )
    status, badge = "Needs work - significant early extension", "🔴"
    for upper_bound, band_status, band_badge in bands:
        if value <= upper_bound:
            status, badge = band_status, band_badge
            break

    return {
        'value': value,
        'status': status,
        'badge': badge,
        'detailed_data': hip_depth_data,
    }