Spaces:
Paused
Paused
| """ | |
| Streamlit web UI for Golf Swing Analysis | |
| """ | |
| import os | |
| import sys | |
| import tempfile | |
| import streamlit as st | |
| from dotenv import load_dotenv | |
| import base64 | |
| from pathlib import Path | |
| import shutil | |
| import cv2 | |
| from PIL import Image | |
| from datetime import datetime | |
# Populate os.environ from a local .env file (if present).
load_dotenv()

# Put the parent of this file's directory on sys.path so the project-level
# `models` and `utils` packages imported below can be found.
sys.path.append(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)

# ===== FORCE MODULE RELOAD FOR UPDATED METRICS =====
# If this script executes again in a process that already imported the
# analysis modules, reload them so edits (e.g. to front_facing_metrics.py)
# take effect. Modules not yet in sys.modules are skipped; they will be
# imported fresh further down.
import importlib

modules_to_reload = [
    'models.front_facing_metrics',
    'models.metrics_calculator',
    'models.pose_estimator',
    'models.swing_analyzer',
    'models.llm_analyzer',
]
for module in modules_to_reload:
    if module not in sys.modules:
        continue
    importlib.reload(sys.modules[module])
    print(f"🔄 Reloaded {module}")

# Switch the front-facing metrics module into verbose mode. This is
# best-effort: any failure is logged and startup continues.
try:
    import models.front_facing_metrics as ffm
    ffm.set_debug(True)
    print(f"✅ Front-facing metrics version: {ffm.METRICS_VERSION}")
    print(f"✅ Debug enabled: {ffm.VERBOSE}")
except Exception as debug_err:
    print(f"⚠️ Could not enable debug mode: {debug_err}")

# Empty Streamlit's data and resource caches. The hasattr guards keep this
# working on Streamlit versions that predate these cache APIs.
# NOTE(review): this runs on every execution of the script, so nothing
# cached via st.cache_data / st.cache_resource survives a rerun.
if hasattr(st, 'cache_data'):
    st.cache_data.clear()
if hasattr(st, 'cache_resource'):
    st.cache_resource.clear()

print("🚀 Module reloads complete - running with latest fixes!")
# ===== END MODULE RELOAD SECTION =====
| # Import modules (will use reloaded versions from above) | |
| from utils.video_downloader import download_youtube_video, download_pro_reference, cleanup_video_file, cleanup_downloads_directory | |
| from utils.video_processor import process_video | |
| from models.pose_estimator import analyze_pose | |
| from models.swing_analyzer import segment_swing_pose_based, analyze_trajectory | |
| from models.llm_analyzer import generate_swing_analysis, create_llm_prompt, prepare_data_for_llm, check_llm_services, parse_and_format_analysis, display_formatted_analysis, compute_core_metrics | |
| from utils.visualizer import create_annotated_video | |
| from utils.comparison import create_key_frame_comparison, extract_key_swing_frames | |
# ===== Optional RAG (retrieval-augmented Q&A) support =====
# golf_swing_rag is imported three ways because the app may be started from
# the project root or from inside app/. Diagnostics go to the server log.
# User-facing messages are NOT emitted here: they are collected and shown
# after st.set_page_config(), which Streamlit requires to be the first
# Streamlit command executed on the page.
_rag_import_error_message = None

print("=== RAG Import Debug Information ===")
print(f"Current working directory: {os.getcwd()}")
print(f"Python path: {sys.path}")
print(f"Files in current directory: {os.listdir('.')}")

# Check if we're in the app directory or project root
if os.path.exists("golf_swing_rag.py"):
    print("✓ Found golf_swing_rag.py in current directory")
elif os.path.exists("app/golf_swing_rag.py"):
    print("✓ Found golf_swing_rag.py in app/ subdirectory")
else:
    print("✗ golf_swing_rag.py not found in current directory or app/ subdirectory")
    print("Looking for: golf_swing_rag.py")
    if os.path.exists("app"):
        print(f"Files in app directory: {os.listdir('app')}")

try:
    print("Attempting to import golf_swing_rag...")
    from golf_swing_rag import GolfSwingRAG
    print("✓ Successfully imported GolfSwingRAG from golf_swing_rag")
    RAG_AVAILABLE = True
except ImportError as e:
    print(f"✗ ImportError: {e}")
    print("Trying alternative import methods...")
    # Try importing from app directory explicitly
    try:
        print("Trying: from app.golf_swing_rag import GolfSwingRAG")
        from app.golf_swing_rag import GolfSwingRAG
        print("✓ Successfully imported from app.golf_swing_rag")
        RAG_AVAILABLE = True
    except ImportError as e2:
        print(f"✗ App import failed: {e2}")
        # Last resort: put this file's own directory first on sys.path
        try:
            print("Adding current directory to sys.path and trying again...")
            current_dir = os.path.dirname(os.path.abspath(__file__))
            if current_dir not in sys.path:
                sys.path.insert(0, current_dir)
                print(f"Added to path: {current_dir}")
            from golf_swing_rag import GolfSwingRAG
            print("✓ Successfully imported after adding current dir to path")
            RAG_AVAILABLE = True
        except ImportError as e3:
            print(f"✗ Final import attempt failed: {e3}")
            RAG_AVAILABLE = False
            # Keep the name bound so later references fail predictably
            # (callers are expected to check RAG_AVAILABLE first).
            GolfSwingRAG = None
            # e/e2/e3 are unbound once their except blocks end, so the
            # message must be formatted here.
            _rag_import_error_message = (
                f"RAG functionality not available. Import errors: {e}, {e2}, {e3}"
            )

if RAG_AVAILABLE:
    print("✓ RAG system is available!")
else:
    print("✗ RAG system is NOT available")
print("=== End RAG Import Debug ===")
print("")

# Set page config -- must run before any other Streamlit element call.
st.set_page_config(page_title="Par-ity Project🏌️♀️",
                   page_icon="🏌️♀️",
                   layout="wide",
                   initial_sidebar_state="collapsed")

# Custom CSS for RAG interface
st.markdown("""
<style>
    .chat-message {
        padding: 1rem;
        border-radius: 10px;
        margin: 1rem 0;
    }
    .user-message {
        background-color: #e3f2fd;
        border-left: 4px solid #2196f3;
    }
    .assistant-message {
        background-color: #f1f8e9;
        border-left: 4px solid #4caf50;
    }
    .rag-header {
        color: #2E8B57;
        font-size: 1.5rem;
        font-weight: bold;
        margin-bottom: 1rem;
    }
</style>
""", unsafe_allow_html=True)

# Now that the page is configured, surface any RAG import failure to the user
# (same messages, same order as before).
if not RAG_AVAILABLE:
    if _rag_import_error_message:
        st.error(_rag_import_error_message)
    st.warning("RAG functionality not available. Please ensure golf_swing_rag.py is in the app directory.")
def load_rag_system():
    """Create and initialise a GolfSwingRAG knowledge base for the chat UI.

    Nothing here memoises the result: every call builds a new instance, runs
    ``load_and_process_data()`` and then attempts ``create_embeddings()``.

    Returns:
        The GolfSwingRAG instance, or None when RAG_AVAILABLE is False or when
        construction / data loading raises (the error and traceback are
        printed and an st.error message is shown). If only embedding creation
        fails, the instance is still returned and the user is warned that only
        basic search is available.
    """
    if not RAG_AVAILABLE:
        st.warning("RAG system not available - missing dependencies")
        return None

    try:
        print("=== RAG System Loading Debug ===")
        with st.spinner("Loading golf swing knowledge base..."):
            print("Creating GolfSwingRAG instance...")
            knowledge_base = GolfSwingRAG()
            print("✓ GolfSwingRAG instance created successfully")

            print("Loading and processing data...")
            knowledge_base.load_and_process_data()
            print("✓ Data loaded and processed successfully")

            print("Creating embeddings (this may take a moment)...")
            try:
                knowledge_base.create_embeddings()
                # A non-None `index` attribute is treated as "semantic
                # (FAISS) search is ready"; otherwise fall back to basic search.
                semantic_ready = (
                    hasattr(knowledge_base, 'index')
                    and knowledge_base.index is not None
                )
                if semantic_ready:
                    print("✓ Embeddings created successfully with FAISS")
                    st.success("🎯 RAG system loaded with semantic search capabilities")
                else:
                    print("⚠ Embeddings creation had issues, but fallback search available")
                    st.warning("⚠️ RAG system loaded with basic search (semantic search unavailable)")
            except Exception as embed_err:
                # Embeddings are optional: keep the instance and degrade.
                print(f"⚠ Embedding creation failed: {embed_err}")
                print("RAG will use fallback search methods")
                st.warning("⚠️ RAG system loaded with basic search only")

        print("✓ RAG system initialization completed!")
        print("=== End RAG System Loading Debug ===")
        return knowledge_base
    except Exception as load_err:
        import traceback
        print(f"✗ Critical error loading RAG system: {str(load_err)}")
        print(f"Error type: {type(load_err).__name__}")
        print(f"Full traceback: {traceback.format_exc()}")
        st.error(f"❌ RAG system failed to load: {str(load_err)}")
        return None
def clamp_for_display(value, min_val, max_val):
    """Return *value* limited to [min_val, max_val] for presentation only.

    ``None`` and the ``'n/a'`` placeholder are passed through unchanged, as is
    anything that cannot be converted with ``float()`` or compared with the
    bounds. Numeric input inside the range comes back as a float.
    """
    if value is None or value == 'n/a':
        return value
    try:
        upper_bounded = min(max_val, float(value))
        return max(min_val, upper_bounded)
    except (TypeError, ValueError):
        return value
def format_metric_value(metric_data, unit=""):
    """Format a metric dict as display text, annotating non-'ok' statuses.

    Args:
        metric_data: dict with optional 'value' and 'status' keys. Anything
            that is not a dict yields 'n/a'.
        unit: suffix written directly after the value (e.g. '°' or '"').

    Returns:
        'n/a' when there is no value or the status is 'n/a';
        '<value><unit>' when the status is 'ok';
        otherwise '<value><unit> (<status>)'.
    """
    if not isinstance(metric_data, dict):
        return 'n/a'

    value = metric_data.get('value')
    status = metric_data.get('status', 'n/a')

    if value is None or status == 'n/a':
        return 'n/a'
    if status == 'ok':
        return f"{value}{unit}"
    # Keep the unit attached to the number; the status is a trailing note.
    return f"{value}{unit} ({status})"
def get_back_tilt_grading(value, confidence, camera_roll=0):
    """Grade forward spine tilt at address (degrees from vertical).

    Bands: 28–38 green; 24–42 (outside green) orange "slightly out";
    20–48 (outside those) orange "off"; anything else red.
    ``camera_roll`` is accepted for call compatibility but does not change the
    result -- every measurement is displayed.

    Returns a dict with 'value', 'display_value', 'badge', 'label',
    'confidence' and 'status' (a copy of 'label'); for ``None`` input only
    'value', 'display_value', 'badge' and 'status' are present.
    """
    if value is None:
        return {'value': None, 'display_value': 'n/a', 'badge': '⚪', 'status': 'No data available'}

    # Each test only sees values that failed the narrower band above it.
    if 28 <= value <= 38:
        badge, label = "🟢", "On-plane posture"
    elif 24 <= value <= 42:
        badge, label = "🟠", "Slightly out of range"
    elif 20 <= value <= 48:
        badge, label = "🟠", "Off (adjust)"
    else:
        badge, label = "🔴", "Likely problematic"

    return {
        'value': value,
        'display_value': f"{value:.1f}°",
        'badge': badge,
        'label': label,
        'confidence': confidence,
        'status': label,
    }
def get_knee_flexion_grading(value, confidence):
    """Grade knee flexion at address (degrees of bend).

    Bands: 15–30 green "Athletic"; 12–15 / 30–35 orange "Slightly too
    stiff/bent"; 8–12 / 35–40 orange "Too stiff/bent"; anything else red.

    Returns a grading dict with 'value' (raw input), 'display_value', 'badge',
    'label', 'confidence' and 'status' (a copy of 'label'). This is the same
    shape as get_back_tilt_grading, so consumers can rely on 'value'/'status'
    for every grader. For ``None`` input, 'value' is None and
    'display_value' is 'n/a'.
    """
    if value is None:
        return {'value': None, 'display_value': 'n/a', 'badge': '⚪', 'status': 'No data available'}

    if 15 <= value <= 30:
        badge, label = "🟢", "Athletic"
    elif 12 <= value < 15:
        badge, label = "🟠", "Slightly too stiff"
    elif 30 < value <= 35:
        badge, label = "🟠", "Slightly too bent"
    elif 8 <= value < 12:
        badge, label = "🟠", "Too stiff"
    elif 35 < value <= 40:
        badge, label = "🟠", "Too bent"
    else:
        badge, label = "🔴", "Likely problematic"

    return {
        'value': value,  # raw value, as the other graders expose it
        'display_value': f"{value:.1f}°",
        'badge': badge,
        'label': label,
        'confidence': confidence,
        'status': label,  # legacy alias of 'label'
    }
def get_shoulder_tilt_swing_plane_grading(value, confidence):
    """Grade shoulder tilt / swing plane at the top of the backswing (degrees).

    Reference points: professional ≈ 36°, 30-handicap ≈ 29°.
    Bands: 30–40 green; 25–45 (outside green) orange; otherwise red.

    Returns a dict with 'value', 'display_value', 'badge', 'label',
    'confidence' and 'status' (a copy of 'label'); for ``None`` input only
    'value', 'display_value', 'badge' and 'status' are present.
    """
    if value is None:
        return {'value': None, 'display_value': 'n/a', 'badge': '⚪', 'status': 'No data available'}

    # The second test only sees values already outside the green band.
    if 30 <= value <= 40:
        badge, label = '🟢', 'Excellent plane'
    elif 25 <= value <= 45:
        badge, label = '🟠', 'Good plane'
    else:
        badge, label = '🔴', 'Needs adjustment'

    return {
        'value': value,
        'display_value': f'{value:.1f}°',
        'badge': badge,
        'label': label,
        'confidence': confidence,
        'status': label,
    }
def get_head_drop_grading(value, confidence):
    """Grade head movement at the top, as a percentage of torso length.

    Sign convention: positive = head moved DOWN ("drop"), negative = moved UP
    ("rise"). Grading uses the magnitude only; the direction appears in
    'display_value'.

    *value* may be numeric or a string such as ``"3.5%"``. For strings, every
    character except digits, '.' and '-' is stripped before ``float()``.
    Unparseable input yields a 'n/a' result with status 'Invalid value'.

    Magnitude bands: ≤3 green, ≤6 orange, ≤10 warning, above that red.
    """
    if value is None:
        return {'value': None, 'display_value': 'n/a', 'badge': '⚪', 'status': 'No data available'}

    invalid_result = {'value': None, 'display_value': 'n/a', 'badge': '⚪', 'status': 'Invalid value'}
    try:
        if not isinstance(value, str):
            value = float(value)
        else:
            import re
            numeric_text = re.sub(r'[^\d.-]', '', value)
            if not numeric_text:
                return invalid_result
            value = float(numeric_text)
    except (TypeError, ValueError):
        return invalid_result

    direction = "rise" if not value >= 0 else "drop"
    magnitude = abs(value)

    # First band whose upper limit contains the magnitude wins.
    bands = (
        (3, '🟢', 'Excellent head stability'),
        (6, '🟠', 'Good/typical'),
        (10, '⚠️', 'Borderline (work on stability)'),
    )
    for limit, badge, grade in bands:
        if magnitude <= limit:
            break
    else:
        badge, grade = '🔴', 'Excessive movement'

    return {
        'value': value,
        'display_value': f'{magnitude:.1f}% {direction}',
        'badge': badge,
        'label': grade,
        'confidence': confidence,
        'status': grade,
    }
def get_torso_sidebend_impact_grading(value, confidence):
    """Grade torso side-bend at impact (degrees; trail-side bend positive).

    Graded on magnitude. The professional range is ~10–20° (green); 6–10° and
    20–24° are orange; everything else is red. The sign is preserved in
    'display_value'.

    Returns a grading dict with 'value' (raw signed input), 'display_value',
    'badge', 'label', 'confidence' and 'status' (a copy of 'label'). This is
    the same shape as get_back_tilt_grading. For ``None`` input, 'value' is
    None and 'display_value' is 'n/a'.
    """
    if value is None:
        return {'value': None, 'display_value': 'n/a', 'badge': '⚪', 'status': 'No data available'}

    # Direction is informational only; grade the bend magnitude.
    abs_value = abs(value)

    # Professional range ~10-20°, typical amateur ~6-24°
    if 10 <= abs_value <= 20:
        badge, label = '🟢', 'Excellent side-bend'
    elif 6 <= abs_value < 10 or 20 < abs_value <= 24:
        badge, label = '🟠', 'Good side-bend'
    else:
        badge, label = '🔴', 'Needs work'

    return {
        'value': value,  # raw signed value, as the other graders expose it
        'display_value': f'{value:.1f}°',
        'badge': badge,
        'label': label,
        'confidence': confidence,
        'status': label,  # legacy alias of 'label'
    }
def get_hip_shoulder_separation_impact_grading(value, confidence):
    """Grade hip-shoulder separation at impact (degrees).

    Bands: 20–35 green; 15–20 / 35–45 orange; otherwise red.

    Returns a grading dict with 'value' (raw input), 'display_value', 'badge',
    'label', 'confidence' and 'status' (a copy of 'label'). This is the same
    shape as get_back_tilt_grading. For ``None`` input, 'value' is None and
    'display_value' is 'n/a'.
    """
    if value is None:
        return {'value': None, 'display_value': 'n/a', 'badge': '⚪', 'status': 'No data available'}

    # Professional golfers typically show 20-35° separation, recreational 10-25°
    if 20 <= value <= 35:
        badge, label = '🟢', 'Excellent separation'
    elif 15 <= value < 20 or 35 < value <= 45:
        badge, label = '🟠', 'Good separation'
    else:
        badge, label = '🔴', 'Needs improvement'

    return {
        'value': value,  # raw value, as the other graders expose it
        'display_value': f'{value:.1f}°',
        'badge': badge,
        'label': label,
        'confidence': confidence,
        'status': label,  # legacy alias of 'label'
    }
def get_hip_sway_grading(value, confidence, position="top"):
    """Grade lateral hip sway (inches toward the target).

    Reference points: professional ≈ 3.9", 30-handicap ≈ 2.5" toward target.

    Args:
        value: sway in inches, or None.
        confidence: passed through into the result.
        position: "top" uses bands 3–5 green, 2–3 / 5–6 orange. Any other
            value is treated as impact: 2–4 green, 1–2 / 4–5 orange.
            Everything else is red.

    Returns a grading dict with 'value' (raw input), 'display_value', 'badge',
    'label', 'confidence' and 'status' (a copy of 'label'). This is the same
    shape as get_back_tilt_grading. For ``None`` input, 'value' is None and
    'display_value' is 'n/a'.
    """
    if value is None:
        return {'value': None, 'display_value': 'n/a', 'badge': '⚪', 'status': 'No data available'}

    if position == "top":
        green_lo, green_hi, orange_lo, orange_hi = 3.0, 5.0, 2.0, 6.0
    else:  # impact
        green_lo, green_hi, orange_lo, orange_hi = 2.0, 4.0, 1.0, 5.0

    if green_lo <= value <= green_hi:
        badge, label = '🟢', 'Excellent sway'
    elif orange_lo <= value < green_lo or green_hi < value <= orange_hi:
        badge, label = '🟠', 'Good sway'
    else:
        badge, label = '🔴', 'Needs improvement'

    return {
        'value': value,  # raw value, as the other graders expose it
        'display_value': f'{value:.1f}"',
        'badge': badge,
        'label': label,
        'confidence': confidence,
        'status': label,  # legacy alias of 'label'
    }
def get_wrist_hinge_grading(value, confidence, position="top"):
    """Grade wrist hinge angle (degrees).

    Args:
        value: hinge angle in degrees, or None.
        confidence: passed through into the result.
        position: "top" uses bands 85–125 green "hinge", 70–85 / 125–140
            orange. Any other value is treated as impact: 15–40 green
            "release", 10–15 / 40–50 orange. Everything else is red.

    Returns a grading dict with 'value' (raw input), 'display_value', 'badge',
    'label', 'confidence' and 'status' (a copy of 'label'). This is the same
    shape as get_back_tilt_grading. For ``None`` input, 'value' is None and
    'display_value' is 'n/a'.
    """
    if value is None:
        return {'value': None, 'display_value': 'n/a', 'badge': '⚪', 'status': 'No data available'}

    # Professional ranges: Top ~90-120°, Impact ~15-35°
    if position == "top":
        if 85 <= value <= 125:
            badge, label = '🟢', 'Excellent hinge'
        elif 70 <= value < 85 or 125 < value <= 140:
            badge, label = '🟠', 'Good hinge'
        else:
            badge, label = '🔴', 'Needs improvement'
    else:  # impact
        if 15 <= value <= 40:
            badge, label = '🟢', 'Excellent release'
        elif 10 <= value < 15 or 40 < value <= 50:
            badge, label = '🟠', 'Good release'
        else:
            badge, label = '🔴', 'Needs improvement'

    return {
        'value': value,  # raw value, as the other graders expose it
        'display_value': f"{value:.1f}°",
        'badge': badge,
        'label': label,
        'confidence': confidence,
        'status': label,  # legacy alias of 'label'
    }
# Keys that only front-facing analyses produce; the presence of any of them
# selects the front-facing layout instead of the down-the-line (DTL) one.
_FRONT_FACING_METRIC_KEYS = (
    'torso_side_bend_deg',
    'shoulder_tilt_impact_deg',
    'hip_sway_top_inches',
    'wrist_hinge_top_deg',
    'hip_shoulder_separation_impact_deg',
)


def _metric_entry(core_metrics, key):
    """Return core_metrics[key] if it is a dict, else {} (missing, None, or malformed)."""
    entry = core_metrics.get(key)
    return entry if isinstance(entry, dict) else {}


def _metric_confidence(data, metric_type='general'):
    """Estimate display confidence (0.0-1.0) for a metric from its status text.

    Starts from a base of 90 and subtracts QC penalties keyed on substrings of
    ``data['status']``:
        occlusion (0-25) - scale/crop change (0-10)
        - phase underseg (0-15) - sign unknown (0-5)
    Metric-specific caps follow, and DTL/approximate results are capped at 70.
    Returns 0.0 when the metric has no value.
    """
    if data.get('value') is None:
        return 0.0

    confidence = 90.0
    # A missing or None status is treated like 'n/a' (no penalties).
    status = data.get('status') or 'n/a'
    if not isinstance(status, str):
        status = str(status)

    # Occlusion penalties (0-25 points)
    if 'club_not_visible' in status or 'Unavailable (club occluded)' in status:
        confidence -= 25
    elif 'poor tracking' in status or 'no detection' in status:
        confidence -= 20
    elif 'approximate' in status:
        confidence -= 15
    elif 'low confidence' in status:
        confidence -= 10

    # Scale/crop change penalties (0-10 points)
    if 'scale_drift' in status or 'unstable (scale)' in status:
        confidence -= 10
    elif 'QC fail' in status:
        confidence -= 8

    # Phase undersegmentation penalties (0-15 points)
    if 'timing_unreliable' in status or 'phase underseg' in status:
        confidence -= 15
    elif 'unreliable' in status:
        confidence -= 10

    # Sign unknown/uncertain penalties (0-5 points)
    if 'uncertain' in status or 'extreme value' in status:
        confidence -= 5
    elif 'outside tour range' in status:
        confidence -= 3

    # Metric-specific adjustments. These types are not passed by the current
    # graders in this module but are kept for older metric names.
    if metric_type == 'shaft_angle':
        if 'target_line_error' in status:
            return 0.0
        if 'club occluded' in status:
            confidence = 0.0  # never show shaft angle when the club is hidden
    elif metric_type == 'head_sway':
        if 'tracking_failed' in status:
            return 0.0
    elif metric_type == 'wrist_pattern':
        if 'insufficient_data' in status:
            confidence -= 20
        confidence = min(confidence, 75.0)  # club-tip visibility is uncertain
    elif metric_type == 'hip_rotation':
        confidence = min(confidence, 60.0)  # DTL-only limitation
    elif metric_type == 'shoulder_turn':
        confidence = min(confidence, 65.0)  # DTL-only limitation

    # DTL-limited / approximate measurements are capped at 70%.
    if 'DTL' in status or 'approx' in status:
        confidence = min(confidence, 70.0)

    return max(0.0, min(90.0, confidence)) / 100.0


def _collect_dtl_metric_cards(core_metrics):
    """Grade the down-the-line metrics; return [(display name, grading dict), ...] in display order."""
    cards = []

    # Shoulder Tilt/Swing Plane @ Top
    plane_data = _metric_entry(core_metrics, "shoulder_tilt_swing_plane_top_deg")
    plane_value = plane_data.get('value')
    if plane_value is not None:
        grading = get_shoulder_tilt_swing_plane_grading(
            plane_value, _metric_confidence(plane_data, 'shoulder_tilt_swing_plane'))
        if grading:
            cards.append(("Shoulder Tilt/Swing Plane @ Top", grading))

    # Back Tilt @ Setup
    tilt_data = _metric_entry(core_metrics, "back_tilt_deg")
    tilt_value = tilt_data.get('value')
    if tilt_value is not None:
        grading = get_back_tilt_grading(tilt_value, _metric_confidence(tilt_data, 'back_tilt'))
        if grading:
            cards.append(("Back Tilt @ Setup", grading))

    # Knee Flexion @ Setup
    knee_data = _metric_entry(core_metrics, "knee_flexion_deg")
    knee_value = knee_data.get('value')
    if knee_value is not None:
        # Legacy data: values above 90 are converted with 180 - value
        # (presumably an included joint angle rather than flexion).
        if knee_value > 90:
            knee_value = 180.0 - knee_value
        grading = get_knee_flexion_grading(knee_value, _metric_confidence(knee_data, 'knee_flexion'))
        if grading:
            cards.append(("Knee Flexion", grading))

    # Head Movement @ Top
    head_data = _metric_entry(core_metrics, "head_drop_top_pct")
    head_value = head_data.get('value')
    if head_value is not None:
        grading = get_head_drop_grading(head_value, _metric_confidence(head_data, 'head_drop'))
        if grading:
            cards.append(("Head Movement @ Top", grading))

    # Hip Depth / Early Extension
    hip_depth_data = _metric_entry(core_metrics, "hip_depth_early_extension")
    hip_depth_value = hip_depth_data.get('value')
    if hip_depth_value is not None:
        hip_confidence = _metric_confidence(hip_depth_data, 'hip_depth')
        # The grader takes the full detail dict rather than the scalar value.
        detailed_data = hip_depth_data.get('detailed_data', {})
        hip_grading = get_hip_depth_grading(detailed_data, hip_confidence)
        if hip_grading:
            cards.append(("Hip Depth / Early Extension", hip_grading))
        else:
            st.caption(f"Debug: Hip depth value: {hip_depth_value}, confidence: {hip_confidence}")
    elif hip_depth_data.get('error'):
        # Prefer the detailed error text; fall back to a top-level error string.
        error_details = hip_depth_data.get('detailed_data')
        if not isinstance(error_details, dict):
            error_details = {}
        top_level_error = hip_depth_data.get('error')
        error_msg = error_details.get('error') or (
            top_level_error if isinstance(top_level_error, str) else 'Unknown error'
        )
        cards.append(("Hip Depth / Early Extension", {
            'display_value': f'Error: {error_msg}',
            'badge': '🔴',
            'label': 'Calculation failed',
            'confidence': 0.0,
        }))

    return cards


def _collect_front_facing_metric_cards(core_metrics):
    """Grade the front-facing metrics; return [(display name, grading dict), ...] in display order."""
    cards = []

    # Torso Side-Bend @ Impact; older results store it under the shoulder-tilt key.
    sidebend_data = _metric_entry(core_metrics, "torso_side_bend_deg")
    if sidebend_data.get('value') is None:
        sidebend_data = _metric_entry(core_metrics, "shoulder_tilt_impact_deg")
    if sidebend_data.get('value') is not None:
        grading = get_torso_sidebend_impact_grading(sidebend_data['value'], 0.9)
        if grading:
            cards.append(("Torso Side-Bend @ Impact", grading))

    # Hip-Shoulder Separation @ Impact (uses the metric's own confidence if given)
    separation_data = _metric_entry(core_metrics, "hip_shoulder_separation_impact_deg")
    if separation_data.get('value') is not None:
        grading = get_hip_shoulder_separation_impact_grading(
            separation_data['value'], separation_data.get('confidence', 0.8))
        if grading:
            cards.append(("Hip-Shoulder Separation @ Impact", grading))

    # Hip Sway @ Top
    sway_data = _metric_entry(core_metrics, "hip_sway_top_inches")
    if sway_data.get('value') is not None:
        grading = get_hip_sway_grading(sway_data['value'], 0.8, "top")
        if grading:
            cards.append(("Hip Sway @ Top", grading))

    # Wrist Hinge @ Top (pose-estimated, hence the lower fixed confidence)
    hinge_data = _metric_entry(core_metrics, "wrist_hinge_top_deg")
    if hinge_data.get('value') is not None:
        grading = get_wrist_hinge_grading(hinge_data['value'], 0.8, "top")
        if grading:
            cards.append(("Wrist Hinge @ Top", grading))

    return cards


def display_new_grading_scheme(core_metrics):
    """Render the graded metric cards for one swing analysis.

    Args:
        core_metrics: mapping of metric key -> dict with at least 'value'
            (and optionally 'status', 'confidence', 'detailed_data', 'error').
            Missing, None or non-dict entries are skipped.

    Front-facing results (detected by the presence of any key in
    _FRONT_FACING_METRIC_KEYS) get the "Swing Analysis" layout; everything
    else is treated as down-the-line.
    """
    core_metrics = core_metrics or {}

    has_front_facing_metrics = any(key in core_metrics for key in _FRONT_FACING_METRIC_KEYS)
    if has_front_facing_metrics:
        st.subheader("Swing Analysis")
        metrics_to_display = _collect_front_facing_metric_cards(core_metrics)
    else:
        st.subheader("Down-the-Line Swing Analysis")
        metrics_to_display = _collect_dtl_metric_cards(core_metrics)

    for metric_name, grading in metrics_to_display:
        display_metric_card(metric_name, grading)
def display_metric_card(metric_name, grading):
    """Render a single graded metric as a text card.

    Layout: divider, metric title, optional one-line definition, a bold
    "value — badge label" line, the evaluation paragraph, then a blank spacer.
    *grading* must provide 'display_value' and 'badge'; 'label' is optional.
    """
    with st.container():
        st.markdown("---")
        st.subheader(metric_name)

        definition = get_metric_definition(metric_name)
        if definition:
            st.caption(definition)

        shown_value = grading['display_value']
        shown_badge = grading['badge']
        shown_label = grading.get('label', '')
        st.markdown(f"**{shown_value}** — {shown_badge} {shown_label}")

        st.write(get_metric_evaluation(metric_name, grading))

        st.write("")
# One-line descriptions shown under each metric card's title.
_METRIC_DEFINITIONS = {
    "Shoulder Tilt/Swing Plane @ Top": "Measures shoulder swing plane angle at the top of backswing.",
    "Back Tilt @ Setup": "Measures spine angle from vertical at address position.",
    "Knee Flexion": "Measures knee bend angle at address position.",
    "Head Movement @ Top": "Measures head movement during backswing as percentage of torso length.",
    "Hip Depth / Early Extension": "Tracks loss of hip flexion through impact.",
    "Torso Side-Bend @ Impact": "Measures torso side-bend angle at ball contact.",
    "Shoulder Tilt @ Impact": "Measures shoulder angle at ball contact.",  # deprecated name
    "Hip-Shoulder Separation @ Impact": "Measures hip rotation relative to shoulders at ball contact.",
    "Hip Sway @ Top": "Measures lateral hip movement at the top of backswing.",
    "Wrist Hinge @ Top": "Measures wrist hinge angle at the top of backswing.",
}


def get_metric_definition(metric_name):
    """Return the short description for *metric_name*, or "" if it has none."""
    return _METRIC_DEFINITIONS.get(metric_name, "")
| def get_metric_evaluation(metric_name, grading): | |
| """Generate detailed evaluation text for each metric""" | |
| value = grading.get('display_value', 'Unknown') | |
| badge = grading.get('badge', '') | |
| # DTL METRICS (5 current metrics) | |
| if metric_name == "Shoulder Tilt/Swing Plane @ Top": | |
| if "🟢" in badge: | |
| return f"Your shoulder tilt/swing plane of **{value}** at the top shows excellent position. This optimal swing plane angle promotes powerful, on-plane delivery and consistent ball striking. Professional golfers typically maintain 36° while 30-handicappers average 29°. Your measurement indicates proper shoulder turn and swing plane control." | |
| elif "🟠" in badge: | |
| return f"Your shoulder tilt/swing plane of **{value}** at the top is good with room for improvement. This measurement affects your swing plane consistency and power generation. Refining your shoulder turn and spine angle can enhance ball striking and distance control." | |
| else: | |
| return f"Your shoulder tilt/swing plane of **{value}** at the top needs attention. This metric is crucial for swing plane consistency and power generation. Work on proper shoulder rotation and maintaining spine angle throughout the backswing for better results." | |
| elif metric_name == "Back Tilt @ Setup": | |
| if "🟢" in badge: | |
| return f"Your back tilt of **{value}** shows excellent posture setup. This forward spine angle is crucial for creating the proper swing plane and generating power through impact. Good back tilt promotes consistent contact, optimal launch conditions, and prevents early extension during the downswing." | |
| elif "🟠" in badge: | |
| return f"Your back tilt of **{value}** is acceptable but could be optimized. Back tilt affects your swing plane and ability to rotate properly. Slight adjustments to your setup posture could improve consistency, distance, and ball striking quality." | |
| else: | |
| return f"Your back tilt of **{value}** needs attention for optimal performance. Proper spine angle at setup is fundamental for swing mechanics, power generation, and consistent ball contact. Poor back tilt can lead to swing compensations and inconsistent results." | |
| elif metric_name == "Knee Flexion": | |
| if "🟢" in badge: | |
| return f"Your knee flexion of **{value}** demonstrates an athletic setup position. This optimal knee bend provides stability throughout the swing while allowing proper weight transfer and rotation. Good knee flexion supports powerful, balanced swings and consistent ball striking." | |
| elif "🟠" in badge: | |
| return f"Your knee flexion of **{value}** is workable but could be refined. Knee bend affects your balance, power transfer, and ability to maintain posture during the swing. Minor adjustments could enhance your stability and swing efficiency." | |
| else: | |
| return f"Your knee flexion of **{value}** may be limiting your swing potential. Proper knee bend is essential for balance, power generation, and maintaining spine angle. Too little or too much knee flexion can cause balance issues and inconsistent contact." | |
| elif metric_name == "Head Movement @ Top": | |
| if "🟢" in badge: | |
| return f"Your head drop of **{value}** shows excellent head stability during the backswing. This minimal movement indicates proper head control and balance, which promotes consistent contact and accuracy. Good head stability is fundamental for reliable ball striking." | |
| elif "🟠" in badge: | |
| return f"Your head drop of **{value}** shows moderate movement during the backswing. Some head movement can affect balance and consistency. Working on head stability and maintaining your spine angle can improve ball striking and accuracy." | |
| else: | |
| return f"Your head drop of **{value}** indicates excessive head movement during the backswing. Too much head drop disrupts balance and swing center, leading to inconsistent contact. Focus on keeping your head stable and maintaining spine angle for better results." | |
| elif metric_name == "Hip Depth / Early Extension": | |
| if "🟢" in badge: | |
| return f"Your hip depth of **{value}** shows excellent posture maintenance. This indicates you're maintaining proper spine angle and avoiding early extension through impact. Good hip depth promotes solid contact, power transfer, and consistent ball striking patterns." | |
| elif "🟠" in badge: | |
| return f"Your hip depth of **{value}** is acceptable but could be improved. This measurement indicates some early extension tendencies. Working on maintaining spine angle and hip position through impact can enhance consistency and power transfer." | |
| else: | |
| return f"Your hip depth of **{value}** indicates early extension issues. This movement pattern reduces power transfer and can cause inconsistent contact. Focus on maintaining spine angle and proper hip position throughout the downswing and impact." | |
| # FRONT-FACING METRICS (4 current metrics) | |
| elif metric_name == "Torso Side-Bend @ Impact": | |
| # Interpret the sign | |
| side_description = "trail-side" if float(value.replace('°', '')) > 0 else "lead-side" | |
| abs_value = abs(float(value.replace('°', ''))) | |
| if "🟢" in badge: | |
| return f"Your torso side-bend of **{value}** at impact shows excellent position. You're bending **{abs_value:.1f}°** toward the **{side_description}** which indicates ideal impact dynamics and power transfer. This optimal torso angle promotes solid contact, optimal ball flight, and consistent distance control." | |
| elif "🟠" in badge: | |
| return f"Your torso side-bend of **{value}** at impact is acceptable with room for improvement. You're bending **{abs_value:.1f}°** toward the **{side_description}**. Torso position at impact affects power transfer and ball flight characteristics. Refining your impact position can enhance consistency and distance." | |
| else: | |
| return f"Your torso side-bend of **{value}** at impact needs attention. You're bending **{abs_value:.1f}°** toward the **{side_description}**. Proper torso angle at impact is crucial for power transfer and ball flight control. Work on impact position for better contact and consistency." | |
| elif metric_name == "Shoulder Tilt @ Impact": # Legacy support | |
| if "🟢" in badge: | |
| return f"Your shoulder tilt of **{value}** at impact shows excellent position. This proper shoulder angle indicates ideal impact dynamics and power transfer. Good shoulder tilt at impact promotes solid contact, optimal ball flight, and consistent distance control." | |
| elif "🟠" in badge: | |
| return f"Your shoulder tilt of **{value}** at impact is acceptable with room for improvement. Shoulder position at impact affects power transfer and ball flight characteristics. Refining your impact position can enhance consistency and distance." | |
| else: | |
| return f"Your shoulder tilt of **{value}** at impact needs attention. Proper shoulder angle at impact is crucial for power transfer and ball flight control. Work on impact position for better contact and consistency." | |
| elif metric_name == "Hip-Shoulder Separation @ Impact": | |
| if "🟢" in badge: | |
| return f"Your hip-shoulder separation of **{value}** at impact shows excellent body sequencing. This proper rotation sequence indicates ideal power transfer with the hips leading the shoulders through impact. Good separation promotes solid contact and optimal ball flight." | |
| elif "🟠" in badge: | |
| return f"Your hip-shoulder separation of **{value}** at impact is acceptable with room for improvement. Hip-shoulder sequencing affects power transfer and swing efficiency. Refining your rotation sequence can enhance consistency and distance." | |
| else: | |
| return f"Your hip-shoulder separation of **{value}** at impact needs attention. Proper sequencing with hips leading shoulders is crucial for power transfer and ball flight control. Work on rotation timing for better contact and consistency." | |
| elif metric_name == "Hip Sway @ Top": | |
| if "🟢" in badge: | |
| return f"Your hip sway of **{value}** at the top shows excellent stability. Minimal lateral movement maintains proper balance and swing center, promoting consistent contact and accuracy. This stable foundation supports powerful, controlled swings." | |
| elif "🟠" in badge: | |
| return f"Your hip sway of **{value}** at the top shows moderate movement. Some lateral sway can affect balance and consistency. Working on stability and weight transfer can improve ball striking and accuracy." | |
| else: | |
| return f"Your hip sway of **{value}** at the top indicates excessive lateral movement. Too much sway disrupts balance and swing center, leading to inconsistent contact. Focus on stability and proper weight transfer for better results." | |
| elif metric_name == "Wrist Hinge @ Top": | |
| if "🟢" in badge: | |
| return f"Your wrist hinge of **{value}** at the top shows excellent set. This proper wrist angle stores energy effectively and sets up lag for powerful release through impact. Good wrist hinge contributes significantly to clubhead speed and distance." | |
| elif "🟠" in badge: | |
| return f"Your wrist hinge of **{value}** at the top is adequate but could be optimized. Better wrist action can enhance lag, power generation, and strike consistency. Work on wrist mobility and proper hinge timing." | |
| else: | |
| return f"Your wrist hinge of **{value}** at the top needs improvement. Proper wrist set is crucial for creating lag and power. Limited wrist hinge reduces potential clubhead speed and distance. Focus on wrist mobility and hinge mechanics." | |
| else: | |
| # Default evaluation for any other metrics | |
| return f"Your **{value}** measurement provides insight into your swing mechanics. This metric affects various aspects of your performance including power generation, accuracy, and consistency. Continue working on this fundamental for improved golf performance." | |
def display_swing_phase_breakdown(swing_phases):
    """Render a styled "Swing Phase Breakdown" table.

    Args:
        swing_phases: Mapping of phase name (e.g. ``"follow_through"``) to a dict
            that may carry ``frame_count`` and ``duration_ms``; missing values show as 0.
    """
    st.subheader("Swing Phase Breakdown")

    # One row per phase: readable name, frame count, duration rounded to whole ms.
    rows = [
        [
            name.title().replace('_', ' '),
            info.get('frame_count', 0),
            f"{info.get('duration_ms', 0):.0f} ms",
        ]
        for name, info in swing_phases.items()
    ]

    import pandas as pd
    table = pd.DataFrame(rows, columns=["Phase", "Frames", "Duration"])

    # Par-ity palette for body cells, header row and the left-aligned phase column.
    cell_properties = {
        'background-color': '#f8f9fa',
        'color': '#0B3B0B',
        'border': '1px solid #dee2e6',
    }
    table_rules = [
        {'selector': 'th', 'props': [('background-color', '#e9ecef'), ('color', '#0B3B0B'), ('font-weight', 'bold')]},
        {'selector': 'td', 'props': [('text-align', 'center')]},
        {'selector': 'th:first-child', 'props': [('text-align', 'left')]},
        {'selector': 'td:first-child', 'props': [('text-align', 'left'), ('font-weight', 'bold')]},
    ]
    styled_table = table.style.set_properties(**cell_properties).set_table_styles(table_rules)
    st.dataframe(styled_table, use_container_width=True, hide_index=True)
| def display_rag_sources(sources): | |
| """Display source information in an organized way""" | |
| if not sources: | |
| return | |
| st.subheader("📚 Sources") | |
| for i, source in enumerate(sources[:3]): # Show top 3 sources | |
| with st.expander(f"Source {i+1}: {source['metadata']['title'][:60]}..."): | |
| st.write(f"**Similarity Score:** {source['similarity_score']:.3f}") | |
| st.write(f"**Source:** {source['metadata']['source']}") | |
| if source['metadata']['url']: | |
| st.write(f"**URL:** [Link]({source['metadata']['url']})") | |
| st.write("**Content:**") | |
| st.write(source['chunk'][:500] + "..." if len(source['chunk']) > 500 else source['chunk']) | |
def render_rag_interface():
    """Render the golf-technique Q&A chat backed by the RAG knowledge base.

    Loads the RAG system into session state on first use. When a video has been
    analyzed, the player's swing analysis is added to the question so answers are
    personalized. The chat history is shown newest first.

    The question and the answer are HTML-escaped before being placed in the chat
    bubbles. Those bubbles are rendered with ``unsafe_allow_html=True``, so raw text
    (user input or model output) must not be interpreted as markup.
    """
    import html

    # Initialize RAG system
    if 'rag_system' not in st.session_state and RAG_AVAILABLE:
        st.session_state.rag_system = load_rag_system()
    # Initialize chat history if not exists
    if 'rag_chat_history' not in st.session_state:
        st.session_state.rag_chat_history = []
    if not RAG_AVAILABLE or st.session_state.get('rag_system') is None:
        st.error("RAG system is not available. Please check the setup.")
        return

    # Swing analysis text used to personalize answers ("" when none is available)
    user_swing_context = _build_user_swing_context()

    col1, col2 = st.columns([2, 1])
    with col1:
        question = st.text_area(
            "Question",  # Proper label for accessibility
            height=100,
            placeholder="Ask about your golf swing technique...",
            label_visibility="collapsed"  # Hide the label visually while keeping it for accessibility
        )
        col_submit, col_clear = st.columns([1, 1])
        with col_submit:
            submit_button = st.button("🎯 Get Answer", type="primary", use_container_width=True)
        with col_clear:
            if st.button("🗑️ Clear Chat History", use_container_width=True):
                st.session_state.rag_chat_history = []
                # Don't call st.rerun() here to avoid disappearing interface
                st.success("Chat history cleared!")

        # Process question
        if submit_button and question.strip():
            with st.spinner("Analyzing your question and searching the knowledge base..."):
                try:
                    # More sources when we have swing analysis to ground against
                    num_sources = 5 if user_swing_context else 3
                    result = query_with_user_context(
                        st.session_state.rag_system,
                        question,
                        user_swing_context,
                        top_k=num_sources
                    )
                    st.session_state.rag_chat_history.append({
                        'question': question,
                        'response': result['response'],
                        'sources': result['sources'],
                        'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                        'used_swing_context': bool(user_swing_context)
                    })
                    st.success("Answer generated successfully!")
                except Exception as e:
                    st.error(f"An error occurred: {str(e)}")

        # Display chat history, newest first
        for chat in reversed(st.session_state.rag_chat_history):
            question_html = html.escape(str(chat["question"]))
            answer_html = html.escape(str(chat["response"]))
            st.markdown(f'<div class="chat-message user-message"><strong>🤔 Your Question:</strong><br>{question_html}</div>',
                        unsafe_allow_html=True)
            st.markdown(f'<div class="chat-message assistant-message"><strong>⛳ Expert Answer:</strong><br>{answer_html}</div>',
                        unsafe_allow_html=True)
            st.divider()
    with col2:
        # Right column intentionally left empty (help content was removed).
        pass


def _build_user_swing_context():
    """Format the stored swing analysis as plain text for the RAG prompt.

    Returns "" unless a video has been analyzed. It prefers the structured
    ``analysis_data['analysis_data']`` dict and falls back to a stored ``prompt``
    string. Sections stored as ``None`` are treated as empty instead of raising.
    """
    if not (st.session_state.get('video_analyzed') and 'analysis_data' in st.session_state):
        return ""
    stored_data = st.session_state.analysis_data

    if 'analysis_data' in stored_data:
        structured_analysis = stored_data['analysis_data'] or {}
        phases = structured_analysis.get('swing_phases') or {}
        timing = structured_analysis.get('timing_metrics') or {}
        core_metrics = structured_analysis.get('core_metrics') or {}

        def frames(phase):
            # Frame count for one phase; 0 when the phase is absent.
            return (phases.get(phase) or {}).get('frame_count', 0)

        return f"""
USER'S SWING ANALYSIS:
=== SWING TIMING & PHASES ===
Swing Phases:
- Setup: {frames('setup')} frames
- Backswing: {frames('backswing')} frames
- Downswing: {frames('downswing')} frames
- Impact: {frames('impact')} frames
- Follow-through: {frames('follow_through')} frames
Timing Metrics:
- Total Swing Time: {timing.get('total_swing_time_ms', 'N/A')} ms
=== DTL METRICS ===
- Shoulder Tilt/Swing Plane @ Top: {format_metric_value(core_metrics.get('shoulder_tilt_swing_plane_top_deg', {}), '°')}
- Back Tilt @ Setup: {format_metric_value(core_metrics.get('back_tilt_deg', {}), '°')}
- Knee Flexion @ Setup: {format_metric_value(core_metrics.get('knee_flexion_deg', {}), '°')}
- Head Movement @ Top: {format_metric_value(core_metrics.get('head_drop_top_pct', {}), '%')}
- Hip Depth / Early Extension: {format_metric_value(core_metrics.get('hip_depth_early_extension', {}), '%')}
=== DTL-LIMITED METRICS (Approximate) ===
- Shoulder Turn Quality: {format_metric_value(core_metrics.get('shoulder_turn_quality', {}))}
=== FRONT-FACING METRICS ===
- Torso Side-Bend @ Impact: {format_metric_value(core_metrics.get('torso_side_bend_deg', {}), '°')}
- Hip-Shoulder Separation @ Impact: {format_metric_value(core_metrics.get('hip_shoulder_separation_impact_deg', {}), '°')}
- Hip Sway @ Top: {format_metric_value(core_metrics.get('hip_sway_top_inches', {}), '"')}
- Wrist Hinge @ Top: {format_metric_value(core_metrics.get('wrist_hinge_top_deg', {}), '°')}
"""
    if 'prompt' in stored_data:
        # Fallback to prompt if structured data not available
        return f"\n\nUSER'S SWING ANALYSIS:\n{stored_data['prompt']}"
    return ""
def query_with_user_context(rag_system, question, user_swing_context, top_k=5):
    """Retrieve ``top_k`` knowledge-base chunks for ``question`` and generate an answer.

    The player's swing analysis text (possibly "") is passed through to the
    generator so the answer can be personalized.

    Returns:
        dict with keys ``response`` (answer text), ``sources`` (retrieved chunks),
        ``query`` (the question) and ``timestamp`` (ISO-8601, local time).
    """
    retrieved = rag_system.search_similar_chunks(question, top_k)
    answer = generate_enhanced_response(rag_system, question, retrieved, user_swing_context)
    print(f"Response: {answer}")

    payload = {'response': answer, 'sources': retrieved, 'query': question}
    payload['timestamp'] = datetime.now().isoformat()
    return payload
def generate_enhanced_response(rag_system, query, context_chunks, user_swing_context=""):
    """Answer ``query`` with the OpenAI chat API, grounded in retrieved reference chunks.

    When ``user_swing_context`` is non-empty, the player's swing analysis leads the
    system prompt so the model can quote their measurements. Otherwise a generic
    golf-instructor system prompt is used.

    Falls back to ``generate_enhanced_fallback_response`` when:
    - ``rag_system`` has no OpenAI client,
    - the API call fails, or
    - the API returns an empty or ``None`` message.

    Args:
        rag_system: Object with an ``openai_client`` attribute (may be missing or None).
        query: The user's question.
        context_chunks: Retrieved chunks, each with ``chunk`` text and ``metadata``.
        user_swing_context: Optional swing-analysis text ("" when unavailable).

    Returns:
        The answer text.
    """
    client = getattr(rag_system, 'openai_client', None)
    if not client:
        print("No OpenAI client found")
        return generate_enhanced_fallback_response(query, context_chunks, user_swing_context)

    # Prepare context from knowledge base
    knowledge_context = "\n\n".join(
        f"Reference Material from '{(chunk.get('metadata') or {}).get('title', 'Unknown source')}':\n{chunk['chunk']}"
        for chunk in context_chunks
    )

    if user_swing_context:
        # Drop the header line; the analysis body itself opens the system prompt.
        analysis_content = user_swing_context.replace("USER'S SWING ANALYSIS:\n", "").strip()
        system_prompt = f"""{analysis_content}
You are a golf swing technique expert assistant analyzing this specific player's swing.
INSTRUCTIONS:
1. Always answer golf technique questions using the reference materials below
2. For swing motion biomechanics questions (head movement, hip rotation, weight transfer, etc.), also reference specific measurements from the player's swing analysis above when relevant
3. For setup/stance questions, answer from the reference materials without needing to reference swing motion data
4. Provide clear, actionable advice based on proven golf instruction
5. IMPORTANT: Keep responses to 4 sentences or less - be concise and focused
Reference Materials from Golf Instruction Database:
{knowledge_context}"""
        user_prompt = f"""Based on the golf instruction reference materials provided, please answer this question about golf swing technique:
{query}
Please provide a helpful, concise response (4 sentences or less) that addresses the specific question while drawing from the relevant information in the context. If the question relates to swing motion biomechanics and you have specific measurements from my swing analysis above, include those details for personalized advice."""
    else:
        # Fallback to general system prompt if no swing analysis available
        system_prompt = f"""You are a golf swing technique expert assistant. You help golfers improve their swing by providing detailed, accurate advice based on professional golf instruction content.
Instructions:
- Answer questions about golf swing technique, mechanics, common problems, and solutions
- Provide specific, actionable advice when possible
- Reference relevant technical concepts when appropriate
- Be encouraging and supportive
- Synthesize information from multiple sources rather than just quoting them
- IMPORTANT: Keep responses to 4 sentences or less - be concise and focused
Reference Materials from Golf Instruction Database:
{knowledge_context}"""
        user_prompt = f"""Based on the golf instruction reference materials provided, please answer this question about golf swing technique:
{query}
Please provide a helpful, concise response (4 sentences or less) that synthesizes the relevant information into clear, actionable guidance."""

    try:
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            max_tokens=400,
            temperature=0.7
        )
        content = response.choices[0].message.content
    except Exception as e:
        print(f"OpenAI API error: {e}")
        return generate_enhanced_fallback_response(query, context_chunks, user_swing_context)

    if not content:
        # The API may return a message whose content is None (e.g. a refusal);
        # never hand None back to the UI.
        print("OpenAI API returned an empty message; using fallback response")
        return generate_enhanced_fallback_response(query, context_chunks, user_swing_context)
    return content
def generate_enhanced_fallback_response(query, context_chunks, user_swing_context=""):
    """Build a short answer without the OpenAI API, from the best retrieved chunk.

    The answer has up to three parts:

    1. A sentence quoting the player's own measurement. This is added only when
       the question is about a swing-motion topic (wrist hinge, head movement,
       weight transfer, shoulder rotation) and that measurement appears in
       ``user_swing_context``.
    2. The first two substantial sentences of the top-ranked chunk, credited to
       its title.
    3. A closing call to action.

    Args:
        query: The user's question.
        context_chunks: Retrieved chunks, best first; each has ``chunk`` text and
            ``metadata['title']``.
        user_swing_context: Optional swing-analysis text ("" when unavailable).

    Returns:
        The response text.
    """
    if not context_chunks:
        return "I couldn't find specific information about that topic in the golf swing database. Could you try rephrasing your question or being more specific?"

    best_chunk = context_chunks[0]
    chunk_content = best_chunk['chunk']
    source_title = best_chunk['metadata']['title']
    question_lower = query.lower()

    # Questions about these topics are not about swing-motion biomechanics, so
    # the player's swing measurements are not quoted for them.
    non_biomechanics_topics = (
        'grip', 'hold', 'grip pressure', 'grip size', 'grip style',
        'setup', 'stance', 'address', 'alignment', 'posture at address',
        'equipment', 'club', 'ball', 'tee', 'glove',
        'course management', 'strategy', 'mental', 'psychology',
        'warm up', 'practice', 'routine', 'pre-shot',
    )
    is_non_biomechanics = any(topic in question_lower for topic in non_biomechanics_topics)

    response_parts = []

    # Part 1: quote a relevant measurement (swing-motion questions only)
    found_relevant_measurement = False
    if user_swing_context and not is_non_biomechanics:
        analysis_content = user_swing_context.replace("USER'S SWING ANALYSIS:\n", "").strip()
        measurement = _find_swing_measurement(question_lower, analysis_content.split('\n'))
        if measurement:
            response_parts.append(measurement)
            found_relevant_measurement = True

    # Part 2: expert recommendation condensed from the top source
    expert_advice = _summarize_chunk(chunk_content)
    if expert_advice:
        response_parts.append(f"Based on {source_title}, {expert_advice}")

    # Part 3: closing recommendation (tie it to the swing only if we quoted a measurement)
    if found_relevant_measurement:
        response_parts.append("Focus on implementing this expert advice to address your specific swing characteristics.")
    else:
        response_parts.append("Focus on implementing this expert advice.")

    return " ".join(response_parts)


def _find_swing_measurement(question_lower, analysis_lines):
    """Return a sentence quoting the measurement the question asks about, or None.

    Only the first matching topic is searched, checked in this order:
    wrist hinge, head movement, weight transfer/shift, shoulder rotation/turn.
    Labels may carry a phase qualifier such as "@ Top" (as in
    "Wrist Hinge @ Top: 85.0°") or omit it (as in "Wrist Hinge: 85°").
    """
    import re

    lines = [line.lower() for line in analysis_lines]

    if "wrist" in question_lower and "hinge" in question_lower:
        for line in lines:
            match = re.search(r'wrist hinge(?:\s*@\s*[a-z]+)?[:\s]*(-?\d+\.?\d*°)', line)
            if match:
                return f"I notice that your wrist hinge is {match.group(1)} during your swing."
        return None

    if "head" in question_lower and any(word in question_lower for word in ("movement", "moving", "steady")):
        for line in lines:
            if 'head movement' not in line:
                continue
            lateral = re.search(r'head movement \(lateral\)[:\s]*(\d+\.?\d*)\s*in', line)
            vertical = re.search(r'head movement \(vertical\)[:\s]*(\d+\.?\d*)\s*in', line)
            if lateral or vertical:
                lateral_val = lateral.group(1) if lateral else "N/A"
                vertical_val = vertical.group(1) if vertical else "N/A"
                return f"I notice that your head movement is {lateral_val} inches laterally and {vertical_val} inches vertically during your swing."
            drop = re.search(r'head movement @ top[:\s]*(-?\d+\.?\d*%)', line)
            if drop:
                return f"I notice that your head movement at the top of your backswing is {drop.group(1)}."
        return None

    # Hip rotation question handling removed per user request

    if "weight" in question_lower and ("transfer" in question_lower or "shift" in question_lower):
        for line in lines:
            match = re.search(r'weight (?:transfer|shift)[:\s]*(\d+\.?\d*%)', line)
            if match:
                return f"I notice that your weight transfer is {match.group(1)} during the downswing."
        return None

    if "shoulder" in question_lower and ("rotation" in question_lower or "turn" in question_lower):
        for line in lines:
            match = re.search(r'shoulder rotation[:\s]*(\d+\.?\d*°)', line)
            if match:
                return f"I notice that your shoulder rotation is {match.group(1)} during your swing."
        return None

    return None


def _summarize_chunk(chunk_content, max_sentences=2):
    """Return the first ``max_sentences`` substantial sentences of a chunk, ending in one period.

    A sentence is "substantial" when it is longer than 20 characters. If there is
    none, the first non-empty sentences are used instead. Blank text gives "".
    """
    sentences = [s.strip() for s in chunk_content.split('. ')]
    chosen = [s for s in sentences if len(s) > 20][:max_sentences]
    if not chosen:
        chosen = [s for s in sentences if s][:max_sentences]
    # split('. ') leaves the last sentence's own period attached; strip it so
    # re-joining does not produce "..".
    chosen = [s.rstrip('.') for s in chosen]
    chosen = [s for s in chosen if s]
    if not chosen:
        return ""
    summary = '. '.join(chosen)
    return summary if summary.endswith(('!', '?')) else summary + '.'
| # Define functions | |
def validate_youtube_url(url):
    """Return True if ``url`` points at a YouTube host.

    Accepted hosts are youtube.com, any of its subdomains (www., m., music., ...),
    and youtu.be. The host is parsed from the URL instead of substring-matched, so
    a URL that only mentions "youtube.com" in its path or query
    (e.g. ``https://example.com/?next=youtube.com``) is rejected. A missing scheme
    ("youtube.com/watch?v=...") is still accepted. ``None`` or empty input returns False.
    """
    from urllib.parse import urlparse

    if not url or not isinstance(url, str):
        return False
    candidate = url.strip()
    if candidate.startswith("//"):
        candidate = "https:" + candidate
    elif "://" not in candidate:
        # Without a scheme, urlparse would treat the host as part of the path.
        candidate = "https://" + candidate
    try:
        host = (urlparse(candidate).hostname or "").lower()
    except ValueError:
        return False
    return host in ("youtube.com", "youtu.be") or host.endswith((".youtube.com", ".youtu.be"))
def process_uploaded_video(uploaded_file):
    """Save an uploaded video into the local ``downloads`` directory and return its path.

    The filename comes from the client. Only its final path component is used
    (both "/" and "\\" are treated as separators), so a name like "../../x.mp4"
    cannot write outside ``downloads``. Names that reduce to nothing, ".", or ".."
    are replaced with "uploaded_video.mp4".

    Args:
        uploaded_file: Object with a ``name`` attribute and a ``getvalue()`` method
            returning the file bytes (e.g. Streamlit's UploadedFile).

    Returns:
        Relative path of the saved file, ``downloads/<name>``.
    """
    # Create downloads directory if it doesn't exist
    os.makedirs("downloads", exist_ok=True)

    raw_name = (uploaded_file.name or "").replace("\\", "/")
    safe_name = os.path.basename(raw_name)
    if safe_name in ("", ".", ".."):
        safe_name = "uploaded_video.mp4"

    file_path = os.path.join("downloads", safe_name)
    with open(file_path, "wb") as f:
        f.write(uploaded_file.getvalue())
    return file_path
def display_video(video_path, width=300):
    """Embed a centered video player, capped at ``width`` pixels, plus a download button.

    The file is read into memory once. The same bytes feed both the player and the
    download button, which offers the file under its original basename as video/mp4.
    """
    with open(video_path, "rb") as video_file:
        payload = video_file.read()

    holder = st.container()

    # Limit the element wrapping the <video> tag to `width` px and let the video fill it.
    sizing_css = f"""
<style>
.element-container:has(video) {{
max-width: {width}px;
margin: 0 auto;
}}
video {{
width: 100% !important;
height: auto !important;
}}
</style>
"""
    holder.markdown(sizing_css, unsafe_allow_html=True)

    with holder:
        st.video(payload)

    st.download_button(
        label="Download Video",
        data=payload,
        file_name=os.path.basename(video_path),
        mime="video/mp4",
    )
| # Main app | |
def main():
    """Entry point of the Par-ity swing analyzer.

    Injects the brand CSS, shows the logo (or a text header), seeds session state,
    draws the sidebar navigation from step 3 onward, then delegates the page body
    to ``render_step_N`` for the current step and finishes with the site footer.
    """
    # Par-ity brand styling: colors, fonts, buttons, mobile and dark-mode overrides
    brand_css = """
<style>
/* Set background color for entire app */
.stApp {
background-color: #fffdfa;
}
/* Ensure main content area also has the background */
.main .block-container {
background-color: #fffdfa;
}
/* Par-ity Project Styling */
.main-header {
text-align: center;
color: #0B3B0B;
font-family: 'Georgia', serif;
font-weight: bold;
}
/* Fix text color visibility for mobile and all devices - EXCEPT buttons */
.stMarkdown, .stMarkdown p, .stMarkdown h1, .stMarkdown h2, .stMarkdown h3, .stMarkdown h4 {
color: #0B3B0B !important;
}
/* Ensure all text elements have proper contrast - EXCEPT buttons */
.element-container, .stMarkdown div, p, span, h1, h2, h3, h4, h5, h6 {
color: #0B3B0B !important;
}
/* Override text color for buttons to ensure proper contrast */
.stButton > button, .stButton > button * {
color: #fffdfa !important;
background-color: #0B3B0B !important;
}
/* Mobile-specific text styling */
@media (max-width: 768px) {
.stMarkdown, .stMarkdown p, .stMarkdown h1, .stMarkdown h2, .stMarkdown h3, .stMarkdown h4 {
color: #0B3B0B !important;
font-weight: 500 !important;
}
/* Stronger contrast for mobile */
.element-container, .stMarkdown div, p, span, h1, h2, h3, h4, h5, h6 {
color: #0B3B0B !important;
text-shadow: none !important;
}
/* Ensure button text stays visible on mobile */
.stButton > button, .stButton > button * {
color: #fffdfa !important;
background-color: #0B3B0B !important;
font-weight: bold !important;
}
}
/* Dark mode override for mobile browsers */
@media (prefers-color-scheme: dark) {
.stMarkdown, .stMarkdown p, .stMarkdown h1, .stMarkdown h2, .stMarkdown h3, .stMarkdown h4,
.element-container, .stMarkdown div, p, span, h1, h2, h3, h4, h5, h6 {
color: #0B3B0B !important;
background-color: transparent !important;
}
/* Ensure button text stays beige/white even in dark mode */
.stButton > button, .stButton > button * {
color: #fffdfa !important;
background-color: #0B3B0B !important;
}
}
.stButton > button {
background-color: #0B3B0B !important;
color: #fffdfa !important;
border-radius: 25px;
border: none;
padding: 12px 28px;
font-weight: bold !important;
font-size: 16px;
transition: all 0.3s ease;
}
.stButton > button:hover {
background-color: #4CAF50 !important;
color: #fffdfa !important;
transform: translateY(-1px);
}
</style>
"""
    st.markdown(brand_css, unsafe_allow_html=True)

    # Logo: first existing candidate wins; otherwise (or on any error) use a text header
    text_header = '<div class="main-header"><h1>⛳ Par-ity Project</h1></div>'
    logo_candidates = (
        "3in par-ity project horizontal logo.png",      # New 3-inch logo
        "par-ity project horizontal logo.png",          # Fallback to original
        "app/3in par-ity project horizontal logo.png",  # Original path for local development
        "app/par-ity project horizontal logo.png",      # Fallback original path
        "./3in par-ity project horizontal logo.png",    # Explicit current directory
    )
    try:
        logo_path = next((path for path in logo_candidates if os.path.exists(path)), None)
        if logo_path is None:
            st.markdown(text_header, unsafe_allow_html=True)
        else:
            _, center_column, _ = st.columns([1, 2, 1])
            with center_column:
                st.image(logo_path)
    except Exception:
        st.markdown(text_header, unsafe_allow_html=True)

    # Seed session state for the step-based flow (only keys not already present)
    state = st.session_state
    session_defaults = (
        ('current_step', 1),
        ('video_analyzed', False),
        ('analysis_data', {
            'video_path': None,
            'frames': None,
            'detections': None,
            'pose_data': None,
            'swing_phases': None,
            'trajectory_data': None,
            'sample_rate': None
        }),
        ('show_sidebar', False),
    )
    for key, default_value in session_defaults:
        if key not in state:
            state[key] = default_value

    # First run of a browser session: purge stale downloads (annotated videos are kept)
    if 'session_initialized' not in state:
        cleanup = cleanup_downloads_directory(keep_annotated=True)
        if cleanup.get('files_removed', 0) > 0:
            st.success(f"🗑️ Cleaned up {cleanup['files_removed']} old files ({cleanup['space_freed_mb']} MB freed)")
        state.session_initialized = True

    # Automatic defaults.
    # NOTE(review): enable_gpt and sample_rate are not read anywhere in main();
    # confirm whether they (and this service check) are still needed.
    services = check_llm_services()
    enable_gpt = services['ollama']['available'] or services['openai']['available']
    sample_rate = 1

    # Sidebar navigation becomes available once the user reaches step 3
    if state.current_step >= 3:
        with st.sidebar:
            st.markdown("### Navigation")
            st.markdown("---")
            if st.button("🎯 See Feedback", key="nav_improvements", use_container_width=True):
                state.current_step = 4
                st.rerun()
            if st.button("💬 Ask Questions", key="nav_chatbot", use_container_width=True):
                # Keep the analysis flag set when jumping to the chatbot
                if state.get('video_analyzed', False):
                    state.video_analyzed = True
                state.current_step = 5
                st.rerun()
            if st.button("🔄 Start Over", key="nav_start_over", use_container_width=True):
                # Wipe everything except the one-time cleanup marker
                for key in list(state.keys()):
                    if key != 'session_initialized':
                        del state[key]
                state.current_step = 1
                st.rerun()

    step = state.current_step

    # Steps 4-5 need analysis results: restore the flag if a video path was recorded
    if step >= 4 and not state.get('video_analyzed', False):
        if 'analysis_data' in state and state.analysis_data.get('video_path'):
            state.video_analyzed = True

    if step == 1:
        render_step_1()
    elif step == 2:
        render_step_2()
    elif step == 3:
        render_step_3()
    elif step == 4:
        render_step_4()
    elif step == 5:
        render_step_5()

    st.markdown("---")
    # Footer with website link
    footer_html = (
        '<div style="text-align: center; margin-top: 1px; margin-bottom: 1px;">'
        '<a href="https://par-ityproject.org" target="_blank" style="color: #0B3B0B; text-decoration: none; font-size: 14px;">par-ityproject.org</a>'
        '</div>'
    )
    st.markdown(footer_html, unsafe_allow_html=True)
def render_step_1():
    """Step 1: collect the swing video and the analysis settings.

    Renders the camera-view and club-type selectors and two input methods
    (YouTube URL or file upload). When "Start Analysis" is clicked, an uploaded
    file takes precedence over the URL. On success, results belonging to any
    previous analysis are discarded, the new video path and camera view are
    stored in ``st.session_state.analysis_data``, and the app advances to Step 2.
    """
    st.markdown('<h2 style="color: #0B3B0B; font-family: Georgia, serif;">Step 1: Upload Your Video</h2>', unsafe_allow_html=True)

    # Camera view selection - DTL only, set as default
    st.markdown("### 📹 Camera View", unsafe_allow_html=True)
    camera_view = st.radio(
        "Camera angle (automatically set to Down the Line):",
        options=["Down the Line (DTL)"],  # Only DTL option, Front Facing disabled
        index=0,  # Default to DTL
        key="camera_view",
        horizontal=True,
        help="DTL: Camera positioned behind/in front of golfer along target line. (Front Facing view temporarily disabled)",
    )

    # Club selection (always shown); Step 4 reads st.session_state.club_type for grading.
    st.markdown("### ⛳ Club Type")
    club_type = st.radio(
        "Select club type for accurate grading:",
        options=["iron", "driver"],
        index=0,  # Default to iron
        key="club_type_selector",
        horizontal=True,
        help="Driver allows more hip rotation than irons. This affects grading thresholds for optimal swing metrics.",
    )
    st.session_state.club_type = club_type

    st.markdown("**Choose your input method below.**")
    st.markdown("💡 **Tips:**")
    st.markdown("- Aim for a video of 5 seconds or less")
    st.markdown("- Select the correct camera view above for accurate analysis")

    col1, col2 = st.columns(2)
    with col1:
        st.markdown("#### YouTube URL")
        # Non-empty label for accessibility (Streamlit warns on empty labels);
        # "collapsed" hides it because the heading above already names the field.
        youtube_url = st.text_input("YouTube URL", key="youtube_input", label_visibility="collapsed")
    with col2:
        st.markdown("#### Upload Video File")
        uploaded_file = st.file_uploader(
            "Upload Video File",
            type=["mp4", "mov", "avi"],
            key="video_upload",
            label_visibility="collapsed",
        )

    if not st.button("🏌️ Start Analysis", key="start_analysis", use_container_width=True):
        return

    if not camera_view:
        st.error("⚠️ Please select a camera view before starting analysis.")
        return

    if uploaded_file is not None:
        with st.spinner("Processing uploaded video..."):
            try:
                video_path = process_uploaded_video(uploaded_file)
            except Exception as e:
                st.error(f"Error processing video: {str(e)}")
                return
        success_message = "Video uploaded successfully!"
    elif youtube_url:
        if not validate_youtube_url(youtube_url):
            st.error("Please enter a valid YouTube URL")
            return
        with st.spinner("Downloading video..."):
            try:
                video_path = download_youtube_video(youtube_url)
            except Exception as e:
                st.error(f"Error downloading video: {str(e)}")
                return
        success_message = "Video downloaded successfully!"
    else:
        st.error("Please provide either a YouTube URL or upload a video file.")
        return

    if not video_path:
        # The helper neither raised nor returned a usable path; tell the user
        # instead of silently doing nothing.
        st.error("Could not obtain the video file. Please try again.")
        return
    st.success(success_message)

    # "Start Over" in main() deletes this key, so make sure it exists.
    if 'analysis_data' not in st.session_state:
        st.session_state.analysis_data = {}
    analysis_state = st.session_state.analysis_data

    # Start a fresh analysis. Step 2 skips work while `video_analyzed` is set, and
    # Step 4 reuses a cached `llm_analysis`. Both, plus the per-video results that
    # Step 2 stores, must be cleared or the new video would show old results.
    for stale_key in ('frames', 'detections', 'pose_data', 'world_landmarks',
                      'swing_phases', 'trajectory_data', 'sample_rate',
                      'analysis_data', 'prompt'):
        analysis_state.pop(stale_key, None)
    st.session_state.video_analyzed = False
    if 'llm_analysis' in st.session_state:
        del st.session_state['llm_analysis']

    analysis_state['video_path'] = video_path
    analysis_state['camera_view'] = camera_view
    st.session_state.current_step = 2
    st.rerun()
def _run_step_2_pipeline(video_path, analysis_state):
    """Run the Step 2 analysis pipeline for *video_path* and store its outputs.

    Stages: frame extraction/detection, pose estimation, swing-phase
    segmentation, trajectory analysis, then LLM input/prompt preparation. On
    success the results are merged into *analysis_state* (the session's
    ``analysis_data`` dict) and ``None`` is returned. If any stage raises, the
    exception message is returned instead and *analysis_state* is left untouched.

    This helper never calls ``st.rerun()``. That keeps Streamlit's rerun
    control-flow exception away from the broad ``except Exception`` below.
    """
    try:
        with st.spinner("Processing video and detecting objects..."):
            frames, detections = process_video(video_path, sample_rate=1)
            st.success("✅ Video processing complete!")
        with st.spinner("Analyzing golfer's pose..."):
            pose_data, world_landmarks = analyze_pose(frames)
            st.success("✅ Pose analysis complete!")
        with st.spinner("Segmenting swing phases..."):
            # Frame shape lets segmentation use thresholds relative to the image size.
            frame_shape = frames[0].shape if frames else None
            swing_phases = segment_swing_pose_based(pose_data, detections, sample_rate=1, frame_shape=frame_shape, fps=30.0)
            st.success("✅ Swing segmentation complete!")
        with st.spinner("Analyzing trajectory and speed..."):
            trajectory_data = analyze_trajectory(frames, detections, swing_phases, sample_rate=1)
            st.success("✅ Trajectory analysis complete!")

        camera_view = analysis_state.get('camera_view', 'Down the Line (DTL)')
        is_front_facing = camera_view == "Front Facing (Face-On)"
        llm_input = prepare_data_for_llm(pose_data, swing_phases, trajectory_data, fps=30.0, frame_shape=frame_shape, is_front_facing=is_front_facing, frames=frames)
        prompt = create_llm_prompt(llm_input)
    except Exception as e:
        return str(e)

    analysis_state.update({
        'frames': frames,
        'detections': detections,
        'pose_data': pose_data,
        'world_landmarks': world_landmarks,
        'swing_phases': swing_phases,
        'trajectory_data': trajectory_data,
        'sample_rate': 1,
        'analysis_data': llm_input,
        'prompt': prompt,
    })
    return None


def render_step_2():
    """Step 2: analyze the video selected in Step 1, then advance to Step 3.

    If there is no video path, or the analysis already completed
    (``video_analyzed``), this forwards straight to Step 3, which handles both
    cases. A failure is remembered in ``st.session_state.analysis_failure``
    (keyed by video path). Later reruns, such as the one triggered by clicking
    "Back to Upload", then show the error without re-running the slow pipeline.
    """
    st.markdown('<h2 style="color: #0B3B0B; font-family: Georgia, serif;">Step 2: Analyzing Video and Pose</h2>', unsafe_allow_html=True)

    analysis_state = st.session_state.get('analysis_data', {})
    video_path = analysis_state.get('video_path')

    if not video_path or st.session_state.get('video_analyzed', False):
        # Analysis already completed (or nothing to analyze): move on to Step 3.
        st.session_state.current_step = 3
        st.rerun()
        return

    failure = st.session_state.get('analysis_failure')
    if failure is not None and failure.get('video_path') == video_path:
        error_message = failure.get('message', '')
    else:
        error_message = _run_step_2_pipeline(video_path, analysis_state)
        if error_message is None:
            st.session_state.pop('analysis_failure', None)
            st.session_state.video_analyzed = True
            st.session_state.current_step = 3
            # Outside any try/except: st.rerun() signals Streamlit by raising a
            # control-flow exception, which a broad handler could otherwise intercept
            # (depending on the Streamlit version).
            st.rerun()
            return
        st.session_state.analysis_failure = {'video_path': video_path, 'message': error_message}

    st.error(f"❌ **Analysis Failed**\n\n{error_message}\n\nPlease try again.")
    if st.button("← Back to Upload", key="back_to_upload"):
        # Forget the failure so the same video can be retried from Step 1.
        st.session_state.pop('analysis_failure', None)
        st.session_state.current_step = 1
        st.rerun()
def render_step_3():
    """Step 3: confirm the analysis finished and let the user choose what's next.

    Offers two routes: graded feedback (Step 4) or the Q&A chat (Step 5). If
    no completed analysis is recorded in session state, shows an error and a
    button back to Step 1 instead.
    """
    st.markdown('<h2 style="color: #0B3B0B; font-family: Georgia, serif;">Step 3: Choose Your Next Step</h2>', unsafe_allow_html=True)

    # .get(): the flag may be absent because "Start Over" in main() deletes session keys.
    if not st.session_state.get('video_analyzed', False):
        st.error("No analysis data available. Please start over.")
        if st.button("🔄 Start Over", key="restart_analysis"):
            st.session_state.current_step = 1
            st.rerun()
        return

    data = st.session_state.get('analysis_data', {})
    video_path = data.get('video_path', '')
    video_name = os.path.basename(video_path) if video_path else 'Unknown'

    # Blank lines keep the status and the prompt as separate Markdown paragraphs.
    st.markdown(
        "## ✅ Analysis Complete!\n\n"
        f"**Video processed successfully:** {video_name}\n\n"
        "**What's Next?** Choose how you'd like to get your swing feedback:"
    )

    # Display video if available (smaller)
    if video_path and os.path.exists(video_path):
        with st.expander("📹 View Your Swing Video", expanded=False):
            display_video(video_path, width=300)

    def option_card(title, blurb):
        # Unindented, single-line HTML: Markdown treats lines indented by four or
        # more spaces as a code block, which would show the raw tags.
        return (
            '<div style="text-align: center; padding: 20px; border: 2px solid #4CAF50; '
            'border-radius: 15px; margin: 10px 0;">'
            f'<h3 style="color: #0B3B0B;">{title}</h3>'
            f'<p>{blurb}</p>'
            '</div>'
        )

    # Main action buttons - larger and more prominent
    col1, col2 = st.columns(2)
    with col1:
        st.markdown(
            option_card("🎯 See Feedback", "Get personalized swing analysis with specific tips for improvement"),
            unsafe_allow_html=True,
        )
        if st.button("🎯 Get Improvements", key="get_improvements", use_container_width=True):
            st.session_state.current_step = 4
            st.rerun()
    with col2:
        st.markdown(
            option_card("💬 Ask Questions", "Ask specific questions about golf swing technique from our knowledge base"),
            unsafe_allow_html=True,
        )
        if st.button("💬 Ask Questions", key="ask_questions", use_container_width=True):
            st.session_state.current_step = 5
            st.rerun()
def render_step_4():
    """Step 4: show the graded swing metrics and the AI-written swing summary.

    Requires Step 2's outputs in ``st.session_state.analysis_data``:
    ``pose_data`` and ``swing_phases`` must be present, while ``frames``,
    ``world_landmarks``, ``trajectory_data`` and ``camera_view`` are optional.
    It also reads the club chosen in Step 1 from ``st.session_state.club_type``
    (default ``'iron'``).

    Core metrics are recomputed on every run and written back to the cached
    ``analysis_data['analysis_data']['core_metrics']``. The LLM analysis is
    generated only when ``st.session_state.llm_analysis`` is not already set.
    It is stored there as ``{'raw', 'formatted'}`` on success or
    ``{'error'}`` on failure.
    """
    st.markdown('<h2 style="color: #0B3B0B; font-family: Georgia, serif;">Step 4: Swing Analysis</h2>', unsafe_allow_html=True)

    data = st.session_state.get('analysis_data', {})
    # main() restores `video_analyzed` for steps 4-5 whenever a video path is
    # stored, so the flag alone does not prove Step 2's outputs exist. Checking
    # them here avoids a KeyError.
    has_results = 'pose_data' in data and 'swing_phases' in data
    if not (st.session_state.get('video_analyzed', False) and has_results):
        st.error("No analysis data available. Please analyze a video first.")
        return

    pose_data = data['pose_data']
    swing_phases = data['swing_phases']
    trajectory_data = data.get('trajectory_data', {})

    # Camera view decides which metric set is calculated.
    camera_view = data.get('camera_view', 'Down the Line (DTL)')
    is_front_facing = camera_view == "Front Facing (Face-On)"
    club_selection = st.session_state.get('club_type', 'iron')

    # Always recompute so the latest validation logic and the current club
    # choice are applied, then refresh the cached copy.
    core_metrics = compute_core_metrics(
        pose_data,
        swing_phases,
        is_front_facing=is_front_facing,
        frames=data.get('frames'),
        world_landmarks=data.get('world_landmarks'),
        club=club_selection,
    )
    if 'analysis_data' in data:
        data['analysis_data']['core_metrics'] = core_metrics

    # Generate the LLM analysis once; later reruns reuse the cached result.
    if 'llm_analysis' not in st.session_state:
        with st.spinner("Generating AI swing analysis..."):
            try:
                raw_analysis = generate_swing_analysis(pose_data, swing_phases, trajectory_data)
                if raw_analysis and not raw_analysis.startswith("Error:"):
                    st.session_state.llm_analysis = {
                        'raw': raw_analysis,
                        'formatted': parse_and_format_analysis(raw_analysis),
                    }
                else:
                    st.session_state.llm_analysis = {
                        'error': raw_analysis or "Failed to generate analysis"
                    }
            except Exception as e:
                st.session_state.llm_analysis = {
                    'error': f"Error generating AI analysis: {str(e)}"
                }

    # Overall summary at the very top when the analysis succeeded.
    llm_analysis = st.session_state.get('llm_analysis', {})
    if 'formatted' in llm_analysis:
        overall_summary = llm_analysis['formatted'].get('overall_summary', '')
        if overall_summary:
            # Unindented HTML: Markdown treats lines indented by four or more spaces as code.
            # NOTE(review): the LLM text is inserted into HTML unescaped; consider
            # html.escape() if the summary is never meant to carry markup.
            summary_html = (
                "<div style='background-color: #e8f4fd; padding: 20px; border-radius: 10px; "
                "margin-bottom: 20px; border-left: 5px solid #1f77b4;'>"
                "<h3 style='color: #1f77b4; margin-top: 0; margin-bottom: 10px;'>🎯 Overall Swing Assessment</h3>"
                "<p style='margin: 0; line-height: 1.6; font-size: 16px; color: #2c3e50;'>"
                f"{overall_summary}</p>"
                "</div>"
            )
            st.markdown(summary_html, unsafe_allow_html=True)
    elif 'error' in llm_analysis:
        st.warning(f"AI Analysis unavailable: {llm_analysis['error']}")

    # Display the grading scheme below the summary
    display_new_grading_scheme(core_metrics)

    # Developer tools, kept small in the bottom-right corner.
    st.write("")
    st.write("")
    _, dev_col2, dev_col3 = st.columns([4, 1, 1])
    with dev_col2:
        if st.button("📊 Debug", key="debug_calibration", help="View calibration data"):
            with st.expander("Calibration Data", expanded=True):
                st.caption("Raw values for qualitative assessments:")
                # Shoulder tilt / swing plane at the top of the backswing
                swing_plane = core_metrics.get("shoulder_tilt_swing_plane_top_deg", {})
                st.caption(f"Shoulder Tilt/Swing Plane @ Top: {swing_plane.get('value')}° ({swing_plane.get('status', 'n/a')})")
                # Torso side-bend at impact; fall back to the older
                # shoulder_tilt_impact_deg metric for compatibility.
                side_bend = core_metrics.get("torso_side_bend_deg", {})
                if side_bend.get('value') is None:
                    side_bend = core_metrics.get("shoulder_tilt_impact_deg", {})
                st.caption(f"Torso Side-Bend @ Impact: {side_bend.get('value')}° ({side_bend.get('status', 'n/a')})")
                # Hip depth / early extension
                hip_depth = core_metrics.get("hip_depth_early_extension", {})
                st.caption(f"Hip Depth / Early Extension: {hip_depth.get('value')} ({hip_depth.get('status', 'n/a')})")
                # Reports the camera view the metrics above were computed for.
                st.caption("Front-facing metrics detected" if is_front_facing else "DTL metrics detected")
    with dev_col3:
        if st.button("🔍 Prompt", key="show_prompt_btn", help="View LLM prompt"):
            if 'prompt' in data:
                with st.expander("LLM Prompt", expanded=True):
                    st.code(data['prompt'], language="text")
            else:
                st.error("No prompt data available.")
def render_step_5():
    """Step 5: open the "Ask the Golf Expert" chat.

    The chat never requires a finished swing analysis. Without one, the user
    only sees a hint that answers will be general. When ``RAG_AVAILABLE`` is
    falsy, an error and an explanation replace the chat interface.
    """
    st.markdown('<h2 style="color: #0B3B0B; font-family: Georgia, serif;">Step 5: Ask the Golf Expert</h2>', unsafe_allow_html=True)
    st.markdown("💬 **Ready to answer your swing questions!**")

    analysis_done = st.session_state.get('video_analyzed', False)
    if not (analysis_done and 'analysis_data' in st.session_state):
        # Informational only; access to the chat is not blocked.
        st.info("💡 For personalized answers about your swing, complete the video analysis first. You can still ask general golf questions!")

    if not RAG_AVAILABLE:
        st.error("❌ **RAG System**: Not available due to missing dependencies")
        st.info("The Golf Expert chatbot requires additional dependencies that are not currently available.")
        return

    render_rag_interface()
if __name__ == "__main__":
    # Launch the app only when this file runs as the main script,
    # not when it is imported as a module.
    main()