Emily Chen committed on
Commit a422282 · 0 Parent(s)

initial commit


Signed-off-by: Emily Chen <emilychen@Emilys-iMac.lan>

README.md ADDED
@@ -0,0 +1,92 @@
# Golf Swing Analysis

A Python application that analyzes golf swings from YouTube videos using computer vision and AI.

## Features

- YouTube video retrieval and processing with yt-dlp
- Golfer, club, and ball detection with YOLOv8
- Pose estimation for swing analysis
- Swing phase segmentation (setup, backswing, downswing, impact, follow-through)
- Trajectory and speed analysis
- AI-powered swing evaluation and coaching tips
- Visual feedback with annotations
- Streamlit web interface

## Installation

1. Clone this repository
2. Run the setup script to create the necessary directories:
   ```
   chmod +x setup_directories.sh
   ./setup_directories.sh
   ```
3. Create a virtual environment:
   ```
   python -m venv .venv
   source .venv/bin/activate  # On Windows: .venv\Scripts\activate
   ```
4. Install the dependencies:
   ```
   pip install -r requirements.txt
   ```
5. Edit the `.env` file with your OpenAI API key:
   ```
   OPENAI_API_KEY=your_api_key_here
   ```

## Usage

### Command Line Interface

Run the main application:

```
python app/main.py
```

Follow the prompts to enter the URL of a YouTube video containing a golf swing.

### Streamlit Web Interface

Run the Streamlit web app using the provided shell script:

```
./run_streamlit.sh
```
Or run it manually:

```
source .venv/bin/activate
streamlit run app/streamlit_app.py
```

The web interface provides:
- Options to upload a video or use a YouTube URL
- Control over the frame skip rate for YOLO detection
- A toggle for enabling/disabling GPT analysis
- Interactive display of analysis results
- An option to create and view annotated videos

## File Organization

- **downloads/**: Contains both downloaded YouTube videos and annotated videos
  - All videos (original and annotated) are stored in the same directory for easy access
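
For reference, the full layout created by this commit:

```
.
├── app/
│   ├── main.py              # CLI entry point
│   ├── streamlit_app.py     # Streamlit web UI
│   ├── components/
│   ├── models/              # pose_estimator, swing_analyzer, llm_analyzer
│   └── utils/               # video_downloader, video_processor, visualizer
├── downloads/               # created by setup_directories.sh
├── requirements.txt
├── run_streamlit.sh
└── setup_directories.sh
```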

## Troubleshooting

If you encounter issues with the "Create Annotated Video" button:
1. Make sure you have run the setup script to create the downloads directory
2. Check that the `downloads` directory has write permissions
3. Try restarting the Streamlit app

## Requirements

- Python 3.8+
- OpenCV
- YOLOv8
- MediaPipe
- yt-dlp
- OpenAI API key
- Streamlit
app/__init__.py ADDED
File without changes
app/components/__init__.py ADDED
File without changes
app/main.py ADDED
@@ -0,0 +1,102 @@
#!/usr/bin/env python3
"""
Golf Swing Analysis - Main Application
"""

import os
import sys

from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Add the project root to the path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from app.utils.video_downloader import download_youtube_video
from app.utils.video_processor import process_video
from app.models.pose_estimator import analyze_pose
from app.models.swing_analyzer import segment_swing, analyze_trajectory
from app.models.llm_analyzer import generate_swing_analysis
from app.utils.visualizer import create_annotated_video


def main():
    """Main application function"""
    print("\n===== Golf Swing Analysis =====\n")

    # Step 1: Get YouTube URL from the user
    youtube_url = input("Enter YouTube URL of golf swing: ")

    # Step 2: Configure analysis options
    enable_gpt = input("\nEnable GPT analysis? (y/n, default: y): ").lower() != 'n'

    sample_rate_input = input("\nFrame skip rate for YOLO (1-10, default: 5): ")
    sample_rate = 5  # Default value
    if sample_rate_input.isdigit():
        sample_rate = max(1, min(10, int(sample_rate_input)))

    try:
        # Step 3: Download the video
        print("\nDownloading video...")
        video_path = download_youtube_video(youtube_url)
        print(f"Video downloaded to: {video_path}")

        # Step 4: Process the video and detect golfer, club, and ball
        print("\nProcessing video and detecting objects...")
        frames, detections = process_video(video_path, sample_rate=sample_rate)

        # Step 5: Analyze pose throughout the swing
        print("\nAnalyzing golfer's pose...")
        pose_data = analyze_pose(frames)

        # Step 6: Segment the swing into phases
        print("\nSegmenting swing phases...")
        swing_phases = segment_swing(pose_data, detections, sample_rate=sample_rate)

        # Step 7: Analyze trajectory and speed
        print("\nAnalyzing trajectory and speed...")
        trajectory_data = analyze_trajectory(frames, detections, swing_phases,
                                             sample_rate=sample_rate)

        # Step 8: Generate a swing analysis using the LLM (if enabled)
        if enable_gpt:
            print("\nGenerating swing analysis and coaching tips...")
            analysis = generate_swing_analysis(pose_data, swing_phases,
                                               trajectory_data)

            # Display the results
            print("\n===== Swing Analysis Results =====\n")
            print(analysis)
        else:
            print("\nGPT analysis disabled. Skipping swing evaluation.")

        # Step 9: Create an annotated video (optional)
        create_video = input("\nCreate annotated video? (y/n): ").lower() == 'y'
        if create_video:
            print("\nCreating annotated video...")
            output_path = create_annotated_video(video_path, frames, detections,
                                                 pose_data, swing_phases,
                                                 trajectory_data,
                                                 sample_rate=sample_rate)
            print(f"Annotated video saved to: {output_path}")

        print("\nAnalysis complete!")

    except Exception as e:
        print(f"\nError: {e}")


if __name__ == "__main__":
    main()
app/models/__init__.py ADDED
File without changes
app/models/llm_analyzer.py ADDED
@@ -0,0 +1,172 @@
"""
LLM-based golf swing analysis module
"""

import os

from openai import OpenAI


def generate_swing_analysis(pose_data, swing_phases, trajectory_data):
    """
    Generate swing analysis and coaching tips using an LLM

    Args:
        pose_data (dict): Dictionary mapping frame indices to pose keypoints
        swing_phases (dict): Dictionary mapping phase names to lists of frame indices
        trajectory_data (dict): Dictionary mapping frame indices to trajectory data

    Returns:
        str: Detailed swing analysis and coaching tips
    """
    # Check whether the OpenAI API key is available
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        return ("Error: OpenAI API key not found. "
                "Please set the OPENAI_API_KEY environment variable.")

    # Create the OpenAI client
    client = OpenAI(api_key=api_key)

    # Prepare data for the LLM
    analysis_data = prepare_data_for_llm(pose_data, swing_phases, trajectory_data)

    # Generate the prompt
    prompt = create_llm_prompt(analysis_data)

    try:
        # Call the OpenAI API
        response = client.chat.completions.create(
            model="gpt-4",
            messages=[
                {
                    "role": "system",
                    "content": ("You are a professional golf coach with expertise in "
                                "analyzing golf swings. Provide detailed, actionable "
                                "feedback based on the swing data provided."),
                },
                {"role": "user", "content": prompt},
            ],
            temperature=0.7,
            max_tokens=1000)

        # Extract and return the analysis
        return response.choices[0].message.content

    except Exception as e:
        return f"Error generating swing analysis: {e}"


def prepare_data_for_llm(pose_data, swing_phases, trajectory_data):
    """
    Prepare swing data for LLM analysis

    Args:
        pose_data (dict): Dictionary mapping frame indices to pose keypoints
        swing_phases (dict): Dictionary mapping phase names to lists of frame indices
        trajectory_data (dict): Dictionary mapping frame indices to trajectory data

    Returns:
        dict: Processed data for LLM analysis
    """
    analysis_data = {"swing_phases": {}, "joint_angles": {}, "trajectory": {}}

    # Process swing phases
    for phase, frames in swing_phases.items():
        if frames:
            # Get a representative frame for each phase
            mid_frame = frames[len(frames) // 2]

            # Get joint angles for the representative frame
            if mid_frame in pose_data:
                keypoints = pose_data[mid_frame]

                # Record key metrics for the phase
                analysis_data["swing_phases"][phase] = {
                    "frame_index": mid_frame,
                    "duration_frames": len(frames),
                }

    # Process trajectory data
    impact_frames = swing_phases.get("impact", [])
    if impact_frames:
        impact_frame = impact_frames[len(impact_frames) // 2]
        if impact_frame in trajectory_data:
            impact_data = trajectory_data[impact_frame]
            if "club_speed" in impact_data and impact_data["club_speed"]:
                analysis_data["trajectory"]["club_speed_mph"] = impact_data[
                    "club_speed"]

    # Additional metrics that would be calculated in a real implementation
    # (placeholder values for demonstration)
    analysis_data["metrics"] = {
        "tempo_ratio": 3.0,  # Backswing-to-downswing time ratio
        "swing_plane_consistency": 0.85,  # 0-1 scale
        "weight_shift": 0.7,  # 0-1 scale
        "hip_rotation": 45,  # degrees
        "shoulder_rotation": 90,  # degrees
        "wrist_hinge": 80,  # degrees
        "posture_score": 0.8,  # 0-1 scale
    }

    return analysis_data


def create_llm_prompt(analysis_data):
    """
    Create a prompt for the LLM based on swing analysis data

    Args:
        analysis_data (dict): Processed swing analysis data

    Returns:
        str: Prompt for the LLM
    """
    prompt = """
I've analyzed a golf swing and extracted the following data:

## Swing Phases
"""

    # Add swing phase information
    for phase, data in analysis_data["swing_phases"].items():
        prompt += (f"- {phase.capitalize()}: Frame {data['frame_index']}, "
                   f"Duration: {data['duration_frames']} frames\n")

    # Add trajectory information
    prompt += "\n## Trajectory Data\n"
    if "trajectory" in analysis_data and "club_speed_mph" in analysis_data[
            "trajectory"]:
        prompt += f"- Club Speed: {analysis_data['trajectory']['club_speed_mph']:.1f} mph\n"

    # Add metrics
    prompt += "\n## Swing Metrics\n"
    for metric, value in analysis_data["metrics"].items():
        # Format the metric name for readability
        metric_name = metric.replace("_", " ").title()

        # Format the value based on type
        if isinstance(value, float):
            if 0 <= value <= 1:
                # Format as a percentage for 0-1 scale metrics
                formatted_value = f"{value * 100:.0f}%"
            else:
                # Format as a decimal for other floats
                formatted_value = f"{value:.1f}"
        else:
            # Use as-is for integers and other types
            formatted_value = str(value)

        prompt += f"- {metric_name}: {formatted_value}\n"

    prompt += """
Based on this data, please provide:
1. A detailed analysis of the golf swing
2. Key strengths and weaknesses
3. Specific recommendations for improvement
4. Drills or exercises that could help address the identified issues

Please be specific and actionable in your feedback.
"""

    return prompt
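
As an illustration (not part of the commit), feeding `create_llm_prompt` a small hand-made dict that mirrors `prepare_data_for_llm`'s output shape shows the prompt structure the module produces:

```
from app.models.llm_analyzer import create_llm_prompt

sample = {
    "swing_phases": {"backswing": {"frame_index": 12, "duration_frames": 8}},
    "trajectory": {"club_speed_mph": 95.0},
    "metrics": {"tempo_ratio": 3.0, "posture_score": 0.8},
}
print(create_llm_prompt(sample))
# Among the boilerplate instructions, the prompt lists:
# - Backswing: Frame 12, Duration: 8 frames
# - Club Speed: 95.0 mph
# - Tempo Ratio: 3.0
# - Posture Score: 80%
```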
app/models/pose_estimator.py ADDED
@@ -0,0 +1,156 @@
"""
Pose estimation module for golf swing analysis
"""

import cv2
import mediapipe as mp
import numpy as np
from tqdm import tqdm


class PoseEstimator:
    """MediaPipe-based pose estimator for golf swing analysis"""

    def __init__(self):
        """Initialize the pose estimator"""
        self.mp_pose = mp.solutions.pose
        self.pose = self.mp_pose.Pose(static_image_mode=False,
                                      model_complexity=2,
                                      enable_segmentation=False,
                                      min_detection_confidence=0.5,
                                      min_tracking_confidence=0.5)

    def process_frame(self, frame):
        """
        Process a single frame and extract pose landmarks

        Args:
            frame (numpy.ndarray): Input frame

        Returns:
            list: List of keypoints [x, y, visibility], or None if no pose is detected
        """
        # Convert BGR to RGB
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # Process the frame
        results = self.pose.process(frame_rgb)

        if not results.pose_landmarks:
            return None

        # Convert normalized landmark coordinates to pixel coordinates
        h, w, _ = frame.shape
        keypoints = []
        for landmark in results.pose_landmarks.landmark:
            x, y = int(landmark.x * w), int(landmark.y * h)
            keypoints.append([x, y, landmark.visibility])

        return keypoints

    def close(self):
        """Release resources"""
        self.pose.close()


def analyze_pose(frames):
    """
    Analyze pose in video frames

    Args:
        frames (list): List of video frames

    Returns:
        dict: Dictionary mapping frame indices to pose keypoints
    """
    pose_estimator = PoseEstimator()
    pose_data = {}

    for i, frame in enumerate(tqdm(frames, desc="Analyzing pose")):
        keypoints = pose_estimator.process_frame(frame)
        if keypoints:
            pose_data[i] = keypoints

    pose_estimator.close()

    return pose_data


def calculate_joint_angles(keypoints):
    """
    Calculate joint angles from pose keypoints

    Args:
        keypoints (list): List of keypoints [x, y, visibility]

    Returns:
        dict: Dictionary of joint angles in degrees
    """
    PL = mp.solutions.pose.PoseLandmark

    # Each joint angle is formed by three landmarks: outer point, vertex, outer point
    joint_connections = {
        "right_shoulder": [PL.RIGHT_ELBOW.value, PL.RIGHT_SHOULDER.value,
                           PL.RIGHT_HIP.value],
        "left_shoulder": [PL.LEFT_ELBOW.value, PL.LEFT_SHOULDER.value,
                          PL.LEFT_HIP.value],
        "right_elbow": [PL.RIGHT_WRIST.value, PL.RIGHT_ELBOW.value,
                        PL.RIGHT_SHOULDER.value],
        "left_elbow": [PL.LEFT_WRIST.value, PL.LEFT_ELBOW.value,
                       PL.LEFT_SHOULDER.value],
        "right_hip": [PL.RIGHT_KNEE.value, PL.RIGHT_HIP.value,
                      PL.RIGHT_SHOULDER.value],
        "left_hip": [PL.LEFT_KNEE.value, PL.LEFT_HIP.value,
                     PL.LEFT_SHOULDER.value],
        "right_knee": [PL.RIGHT_ANKLE.value, PL.RIGHT_KNEE.value,
                       PL.RIGHT_HIP.value],
        "left_knee": [PL.LEFT_ANKLE.value, PL.LEFT_KNEE.value,
                      PL.LEFT_HIP.value],
    }

    angles = {}

    for joint_name, landmarks in joint_connections.items():
        # Make sure all three points are available
        if all(landmarks[i] < len(keypoints) for i in range(3)):
            p1 = np.array(keypoints[landmarks[0]][:2])
            p2 = np.array(keypoints[landmarks[1]][:2])  # Vertex of the angle
            p3 = np.array(keypoints[landmarks[2]][:2])

            # Vectors from the vertex to the outer points
            v1 = p1 - p2
            v2 = p3 - p2

            # Angle between the vectors (skip degenerate zero-length vectors)
            norm_product = np.linalg.norm(v1) * np.linalg.norm(v2)
            if norm_product == 0:
                continue
            cosine_angle = np.dot(v1, v2) / norm_product
            angle = np.arccos(np.clip(cosine_angle, -1.0, 1.0))
            angles[joint_name] = np.degrees(angle)

    return angles
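
A quick sanity check of the vector math above (not part of the commit): with a wrist at the origin, an elbow at (1, 0), and a shoulder at (1, 1), the elbow angle should come out to 90 degrees.

```
import numpy as np

# Wrist, elbow (vertex), shoulder -- made-up points for illustration
p1, p2, p3 = np.array([0, 0]), np.array([1, 0]), np.array([1, 1])
v1, v2 = p1 - p2, p3 - p2  # Vectors from the vertex to the outer points
cos_a = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
print(np.degrees(np.arccos(np.clip(cos_a, -1.0, 1.0))))  # 90.0
```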
app/models/swing_analyzer.py ADDED
@@ -0,0 +1,174 @@
"""
Swing analysis module for golf swing segmentation and trajectory analysis
"""

from app.models.pose_estimator import calculate_joint_angles


def segment_swing(pose_data, detections, sample_rate=5):
    """
    Segment the golf swing into key phases

    Args:
        pose_data (dict): Dictionary mapping frame indices to pose keypoints
        detections (list): List of Detection objects
        sample_rate (int): The frame sampling rate used during processing

    Returns:
        dict: Dictionary mapping phase names to lists of frame indices
    """
    # Initialize swing phases
    swing_phases = {
        "setup": [],
        "backswing": [],
        "downswing": [],
        "impact": [],
        "follow_through": [],
    }

    # Get frame indices with pose data
    frame_indices = sorted(pose_data.keys())

    if not frame_indices:
        return swing_phases

    # Calculate joint angles for each frame
    angles_by_frame = {
        idx: calculate_joint_angles(pose_data[idx]) for idx in frame_indices
    }

    # Analyze shoulder rotation to identify swing phases.
    # This is a simplified approach; a production system would need a more
    # sophisticated algorithm.

    # Find the frame with the maximum right-shoulder angle (top of backswing)
    max_shoulder_angle = -1
    top_backswing_frame = frame_indices[0]

    for idx in frame_indices:
        angles = angles_by_frame[idx]
        if "right_shoulder" in angles and angles["right_shoulder"] > max_shoulder_angle:
            max_shoulder_angle = angles["right_shoulder"]
            top_backswing_frame = idx

    # Find the impact frame (when the club meets the ball).
    # A real implementation would use club and ball detection.
    impact_frame = None
    person_positions = {}

    # Extract person positions from detections
    for detection in detections:
        if detection.class_name == "person":
            frame_idx = detection.frame_idx // sample_rate  # Convert to processed frame index
            if frame_idx in frame_indices:
                person_positions[frame_idx] = detection.bbox

    # Find the frame with the most forward position (impact)
    if person_positions:
        min_x = float('inf')
        for idx, bbox in person_positions.items():
            if idx > top_backswing_frame and bbox[0] < min_x:
                min_x = bbox[0]
                impact_frame = idx

    # If no impact frame was found, estimate it as 2/3 of the way between
    # the top of the backswing and the end
    if impact_frame is None:
        impact_frame = frame_indices[0] + int(
            (frame_indices[-1] - top_backswing_frame) * 2 / 3)

    # Assign frames to phases
    for idx in frame_indices:
        if idx < frame_indices[len(frame_indices) // 5]:
            # First 20% of frames are setup
            swing_phases["setup"].append(idx)
        elif idx < top_backswing_frame:
            # Frames before the top of the backswing are backswing
            swing_phases["backswing"].append(idx)
        elif idx < impact_frame:
            # Frames between the top of the backswing and impact are downswing
            swing_phases["downswing"].append(idx)
        elif idx < impact_frame + 5:
            # Frames around impact
            swing_phases["impact"].append(idx)
        else:
            # Remaining frames are follow-through
            swing_phases["follow_through"].append(idx)

    return swing_phases


def analyze_trajectory(frames, detections, swing_phases, sample_rate=5):
    """
    Analyze club and ball trajectory and speed

    Args:
        frames (list): List of video frames
        detections (list): List of Detection objects
        swing_phases (dict): Dictionary mapping phase names to lists of frame indices
        sample_rate (int): The frame sampling rate used during processing

    Returns:
        dict: Dictionary mapping frame indices to trajectory data
    """
    trajectory_data = {}

    # Extract ball detections
    ball_detections = [d for d in detections if d.class_name == "sports ball"]

    # Get the impact frame index
    impact_frames = swing_phases.get("impact", [])
    if not impact_frames:
        return trajectory_data

    impact_frame_idx = impact_frames[len(impact_frames) // 2]

    # Track the ball trajectory after impact
    ball_trajectory = []
    ball_positions = {}

    for detection in ball_detections:
        frame_idx = detection.frame_idx // sample_rate  # Convert to processed frame index
        if frame_idx >= impact_frame_idx:
            # Calculate the ball center
            x1, y1, x2, y2 = detection.bbox
            ball_positions[frame_idx] = ((x1 + x2) / 2, (y1 + y2) / 2)

    # Sort ball positions by frame index
    for idx in sorted(ball_positions.keys()):
        ball_trajectory.append(ball_positions[idx])

    # Estimate club speed at impact.
    # A real implementation would track the club head specifically.
    club_speed = None
    if len(swing_phases.get("downswing", [])) >= 2:
        downswing_frames = swing_phases["downswing"]
        time_diff = (downswing_frames[-1] - downswing_frames[0]) / 30  # Assuming 30 fps
        if time_diff > 0:
            # Simplified speed calculation (just an example)
            club_speed = 100 * (1 / time_diff)  # Arbitrary scaling

    # Populate trajectory data
    for phase_name in sorted(swing_phases.keys()):
        for frame_idx in swing_phases[phase_name]:
            trajectory_data[frame_idx] = {
                "phase": phase_name,
                "club_speed": club_speed if phase_name == "impact" else None,
                "ball_trajectory": ball_trajectory
                if phase_name in ("impact", "follow_through") else None,
            }

    return trajectory_data
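
The `tempo_ratio` metric that `prepare_data_for_llm` currently hardcodes could in principle be derived from `segment_swing`'s output. A minimal sketch (an assumption, not this commit's code), using the same processed-frame indexing:

```
def tempo_ratio(swing_phases, sample_rate=5, fps=30.0):
    """Backswing-to-downswing duration ratio (~3.0 is the classic benchmark)."""
    back = len(swing_phases.get("backswing", []))
    down = len(swing_phases.get("downswing", []))
    if down == 0:
        return None
    # Each processed frame spans sample_rate real frames; the factor
    # cancels in the ratio but is kept explicit for clarity.
    return (back * sample_rate / fps) / (down * sample_rate / fps)
```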
app/streamlit_app.py ADDED
@@ -0,0 +1,281 @@
"""
Streamlit web UI for Golf Swing Analysis
"""

import os
import sys

import streamlit as st
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Add the project root to the path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from app.utils.video_downloader import download_youtube_video
from app.utils.video_processor import process_video
from app.models.pose_estimator import analyze_pose
from app.models.swing_analyzer import segment_swing, analyze_trajectory
from app.models.llm_analyzer import (generate_swing_analysis, create_llm_prompt,
                                     prepare_data_for_llm)
from app.utils.visualizer import create_annotated_video

# Set page config
st.set_page_config(page_title="Golf Swing Analysis",
                   page_icon="🏌️",
                   layout="wide",
                   initial_sidebar_state="expanded")


def validate_youtube_url(url):
    """Validate that the URL is a YouTube URL"""
    return "youtube.com" in url or "youtu.be" in url


def process_uploaded_video(uploaded_file):
    """Save an uploaded video file to the downloads directory"""
    # Create the downloads directory if it doesn't exist
    os.makedirs("downloads", exist_ok=True)

    # Save the uploaded file
    file_path = os.path.join("downloads", uploaded_file.name)
    with open(file_path, "wb") as f:
        f.write(uploaded_file.getvalue())

    return file_path


def display_video(video_path):
    """Display a video with a download option"""
    # Read the video bytes
    with open(video_path, "rb") as file:
        video_bytes = file.read()

    # Display the video
    st.video(video_bytes)

    # Show a download button
    st.download_button(label="Download Video",
                       data=video_bytes,
                       file_name=os.path.basename(video_path),
                       mime="video/mp4")


def main():
    """Main Streamlit application"""
    st.title("🏌️ Golf Swing Analysis")
    st.write("Analyze your golf swing using computer vision and AI")

    # Initialize session state for storing analysis results
    if 'video_analyzed' not in st.session_state:
        st.session_state.video_analyzed = False
    if 'analysis_data' not in st.session_state:
        st.session_state.analysis_data = {
            'video_path': None,
            'frames': None,
            'detections': None,
            'pose_data': None,
            'swing_phases': None,
            'trajectory_data': None,
            'sample_rate': None,
        }

    # Sidebar for configuration
    st.sidebar.title("Configuration")

    # Option to enable/disable GPT analysis
    enable_gpt = st.sidebar.checkbox("Enable GPT Analysis", value=True)

    # Frame skip rate for YOLO
    sample_rate = st.sidebar.slider(
        "Frame Skip Rate (YOLO)",
        min_value=1,
        max_value=10,
        value=5,
        help="Process every Nth frame. Higher values = faster but less accurate.")

    # Video input options
    st.header("Video Input")
    input_option = st.radio("Choose input method:", ["YouTube URL", "Upload Video"])

    video_path = None
    analyze_clicked = False

    if input_option == "YouTube URL":
        youtube_url = st.text_input("Enter YouTube URL of golf swing:")

        analyze_clicked = st.button("Analyze Swing", key="analyze_youtube")
        if youtube_url and analyze_clicked:
            if validate_youtube_url(youtube_url):
                with st.spinner("Downloading video..."):
                    try:
                        video_path = download_youtube_video(youtube_url)
                        st.success("Video downloaded successfully!")
                        display_video(video_path)
                    except Exception as e:
                        st.error(f"Error downloading video: {e}")
                        st.session_state.video_analyzed = False
                        return
            else:
                st.error("Please enter a valid YouTube URL")
                st.session_state.video_analyzed = False
                return

    else:  # Upload Video
        uploaded_file = st.file_uploader("Upload a golf swing video",
                                         type=["mp4", "mov", "avi"])

        analyze_clicked = st.button("Analyze Swing", key="analyze_upload")
        if uploaded_file and analyze_clicked:
            with st.spinner("Processing uploaded video..."):
                try:
                    video_path = process_uploaded_video(uploaded_file)
                    st.success("Video uploaded successfully!")
                    display_video(video_path)
                except Exception as e:
                    st.error(f"Error processing video: {e}")
                    st.session_state.video_analyzed = False
                    return

    # Process the video if available and the analyze button was clicked
    if video_path and analyze_clicked:
        try:
            # Step 1: Process the video and detect objects
            with st.spinner("Processing video and detecting objects..."):
                frames, detections = process_video(video_path,
                                                   sample_rate=sample_rate)
                st.success(f"Processed {len(frames)} frames")

            # Step 2: Analyze pose
            with st.spinner("Analyzing golfer's pose..."):
                pose_data = analyze_pose(frames)
                st.success("Pose analysis complete")

            # Step 3: Segment the swing into phases
            with st.spinner("Segmenting swing phases..."):
                swing_phases = segment_swing(pose_data,
                                             detections,
                                             sample_rate=sample_rate)

                # Display the swing phases
                st.subheader("Swing Phases")
                phase_cols = st.columns(5)
                for i, (phase, frames_in_phase) in enumerate(swing_phases.items()):
                    with phase_cols[i]:
                        st.metric(label=phase.capitalize(),
                                  value=f"{len(frames_in_phase)} frames")

            # Step 4: Analyze trajectory and speed
            with st.spinner("Analyzing trajectory and speed..."):
                trajectory_data = analyze_trajectory(frames,
                                                     detections,
                                                     swing_phases,
                                                     sample_rate=sample_rate)

                # Display club speed if available
                impact_frames = swing_phases.get("impact", [])
                if impact_frames:
                    impact_frame = impact_frames[len(impact_frames) // 2]
                    if impact_frame in trajectory_data and trajectory_data[
                            impact_frame].get("club_speed"):
                        st.subheader("Club Speed")
                        st.metric(
                            label="Estimated Club Speed",
                            value=f"{trajectory_data[impact_frame]['club_speed']:.1f} mph")

            # Step 5: Generate a swing analysis using the LLM (if enabled).
            # Prepare the data regardless of whether GPT is enabled.
            analysis_data = prepare_data_for_llm(pose_data, swing_phases,
                                                 trajectory_data)
            prompt = create_llm_prompt(analysis_data)

            # Display the GPT prompt
            with st.expander("View GPT Prompt"):
                st.code(prompt, language="text")

            if enable_gpt:
                with st.spinner("Generating swing analysis and coaching tips..."):
                    analysis = generate_swing_analysis(pose_data, swing_phases,
                                                       trajectory_data)

                    # Display the analysis
                    st.subheader("Swing Analysis")
                    st.write(analysis)
            else:
                st.info("GPT Analysis is disabled. Enable it in the sidebar to "
                        "generate coaching tips.")

            # Store analysis data in session state
            st.session_state.video_analyzed = True
            st.session_state.analysis_data = {
                'video_path': video_path,
                'frames': frames,
                'detections': detections,
                'pose_data': pose_data,
                'swing_phases': swing_phases,
                'trajectory_data': trajectory_data,
                'sample_rate': sample_rate,
            }

        except Exception as e:
            st.error(f"Error during analysis: {e}")
            st.session_state.video_analyzed = False

    # "Create annotated video" section (only shown once analysis is complete)
    if st.session_state.video_analyzed:
        st.header("Create Annotated Video")
        st.write("Create a video with annotations showing the analysis results")

        if st.button("Generate Annotated Video", key="create_annotated"):
            try:
                with st.spinner("Creating annotated video..."):
                    # Create the downloads directory if it doesn't exist
                    os.makedirs("downloads", exist_ok=True)

                    # Get data from session state
                    data = st.session_state.analysis_data

                    # Create the annotated video
                    output_path = create_annotated_video(
                        data['video_path'],
                        data['frames'],
                        data['detections'],
                        data['pose_data'],
                        data['swing_phases'],
                        data['trajectory_data'],
                        sample_rate=data['sample_rate'])

                    # Verify the file exists
                    if not os.path.exists(output_path):
                        raise FileNotFoundError(
                            f"Annotated video file not found at {output_path}")

                    st.success("Annotated video created successfully!")

                    # Display the video with a download option
                    display_video(output_path)

            except Exception as e:
                st.error(f"Error creating annotated video: {e}")
                st.error("Please check that the downloads directory exists "
                         "and is writable")


if __name__ == "__main__":
    main()
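
One practical note: every click of "Analyze Swing" reloads the YOLO weights inside `process_video`. A sketch of how Streamlit's resource cache could keep the model warm (this assumes `process_video` were refactored to accept a model instance; the current code does not):

```
import streamlit as st
from ultralytics import YOLO

@st.cache_resource
def get_model():
    # Loaded once per server process and reused across reruns
    return YOLO("yolov8n.pt")
```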
app/utils/__init__.py ADDED
File without changes
app/utils/video_downloader.py ADDED
@@ -0,0 +1,79 @@
"""
YouTube video downloader module using yt-dlp
"""

import os

import yt_dlp


def download_youtube_video(url, output_dir="downloads"):
    """
    Download a YouTube video from the provided URL using yt-dlp

    Args:
        url (str): YouTube video URL
        output_dir (str): Directory to save the downloaded video

    Returns:
        str: Path to the downloaded video file

    Raises:
        ValueError: If the URL is invalid or the video is unavailable
    """
    # Create the output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Output template for the downloaded file
    output_template = os.path.join(output_dir, "%(title)s.%(ext)s")

    # Configure yt-dlp options
    ydl_opts = {
        'format': 'best[ext=mp4]/best',  # Prefer mp4 format
        'outtmpl': output_template,
        'noplaylist': True,
        'quiet': False,
        'no_warnings': False,
        'ignoreerrors': False,
    }

    try:
        # Create a yt-dlp object and download the video
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            info = ydl.extract_info(url, download=True)

            # Playlists should not occur with noplaylist=True
            if 'entries' in info:
                raise ValueError("Playlists are not supported")

            # Get the video title and extension
            title = info.get('title', 'video')
            ext = info.get('ext', 'mp4')

            # Construct the expected file path
            video_path = os.path.join(output_dir, f"{title}.{ext}")

            # If the file isn't there, try a sanitized filename
            if not os.path.exists(video_path):
                sanitized_title = ''.join(c for c in title
                                          if c.isalnum() or c in ' ._-')
                video_path = os.path.join(output_dir, f"{sanitized_title}.{ext}")

            if not os.path.exists(video_path):
                # If still not found, fall back to any mp4 file in the directory
                mp4_files = [
                    f for f in os.listdir(output_dir) if f.endswith('.mp4')
                ]
                if mp4_files:
                    video_path = os.path.join(output_dir, mp4_files[0])
                else:
                    raise ValueError("Downloaded file not found")

            return video_path

    except yt_dlp.utils.DownloadError as e:
        raise ValueError(f"Error downloading video: {e}")
    except Exception as e:
        raise ValueError(f"Error: {e}")
app/utils/video_processor.py ADDED
@@ -0,0 +1,91 @@
"""
Video processing and object detection module
"""

import cv2
from tqdm import tqdm
from ultralytics import YOLO


class Detection:
    """Container for a single detection result"""

    def __init__(self, frame_idx, class_id, class_name, bbox, confidence):
        self.frame_idx = frame_idx
        self.class_id = class_id
        self.class_name = class_name
        self.bbox = bbox  # [x1, y1, x2, y2]
        self.confidence = confidence


def process_video(video_path, sample_rate=5):
    """
    Process a video and detect the golfer, club, and ball

    Args:
        video_path (str): Path to the video file
        sample_rate (int): Process every nth frame

    Returns:
        tuple: (frames, detections)
            - frames: List of processed frames
            - detections: List of Detection objects
    """
    # Load the YOLOv8 model
    model = YOLO("yolov8n.pt")

    # COCO class names provided by the model
    class_names = model.names

    # Open the video file
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError("Error opening video file")

    # Get video properties
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    frames = []
    detections = []

    # Process every sample_rate-th frame
    for frame_idx in tqdm(range(0, frame_count, sample_rate),
                          desc="Processing frames"):
        # Set the frame position
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)

        # Read the frame
        ret, frame = cap.read()
        if not ret:
            break

        # Store the original frame
        frames.append(frame)

        # Run YOLOv8 detection
        results = model(frame)

        # Process detection results
        for result in results:
            for box in result.boxes:
                class_id = int(box.cls.item())
                class_name = class_names[class_id]

                # Keep only relevant objects (person, sports ball)
                if class_name in ("person", "sports ball"):
                    bbox = box.xyxy[0].tolist()  # [x1, y1, x2, y2]
                    confidence = box.conf.item()
                    detections.append(
                        Detection(frame_idx, class_id, class_name, bbox,
                                  confidence))

    # Release the video capture
    cap.release()

    return frames, detections
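
Seeking with `CAP_PROP_POS_FRAMES` on every sampled frame forces many codecs to re-decode from the nearest keyframe. A sequential read that simply skips frames is usually faster; a sketch with the same sampling semantics (not this commit's code):

```
import cv2

def iter_sampled_frames(video_path, sample_rate=5):
    cap = cv2.VideoCapture(video_path)
    idx = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if idx % sample_rate == 0:
            yield idx, frame  # Original frame index and frame
        idx += 1
    cap.release()
```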
app/utils/visualizer.py ADDED
@@ -0,0 +1,225 @@
"""
Visualization module for creating annotated videos
"""

import os

import cv2
import mediapipe as mp
from tqdm import tqdm

# Body part groups and their colors (BGR, as OpenCV expects)
BODY_PART_COLORS = {
    "head": (255, 0, 0),  # Blue
    "torso": (0, 255, 0),  # Green
    "arms": (0, 165, 255),  # Orange
    "hands": (255, 0, 255),  # Magenta
    "legs": (255, 255, 0),  # Cyan
    "feet": (0, 255, 255),  # Yellow
}

# Which MediaPipe landmarks belong to which body part group
BODY_PARTS_MAPPING = {
    "head": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],  # Nose, eyes, ears, mouth
    "torso": [11, 12, 23, 24],  # Shoulders and hips
    "arms": [11, 12, 13, 14],  # Shoulders and elbows
    "hands": [15, 16, 17, 18, 19, 20, 21, 22],  # Wrists, pinkies, indices, thumbs
    "legs": [23, 24, 25, 26],  # Hips and knees
    "feet": [27, 28, 29, 30, 31, 32],  # Ankles, heels, foot indices
}


def create_annotated_video(video_path,
                           frames,
                           detections,
                           pose_data,
                           swing_phases,
                           trajectory_data,
                           output_dir="downloads",
                           sample_rate=5):
    """
    Create an annotated video with swing analysis visualizations

    Args:
        video_path (str): Path to the original video
        frames (list): List of video frames
        detections (list): List of Detection objects
        pose_data (dict): Pose estimation data
        swing_phases (dict): Swing phase segmentation data
        trajectory_data (dict): Trajectory and speed analysis data
        output_dir (str): Directory to save the output video
        sample_rate (int): The frame sampling rate used during processing

    Returns:
        str: Path to the annotated video
    """
    try:
        # Create the output directory if it doesn't exist
        os.makedirs(output_dir, exist_ok=True)

        # Build the output path from the original video filename
        video_name = os.path.splitext(os.path.basename(video_path))[0]
        output_path = os.path.join(output_dir, f"{video_name}_annotated.mp4")

        # Get video properties
        if not frames:
            raise ValueError("No frames provided for annotation")

        height, width = frames[0].shape[:2]
        fps = 30  # Default fps

        # Create the video writer
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

        if not out.isOpened():
            raise IOError(f"Failed to create video writer for {output_path}. "
                          "Check directory permissions.")

        # Process each frame
        for i, frame in enumerate(tqdm(frames, desc="Creating annotated video")):
            # Work on a copy of the frame
            annotated_frame = frame.copy()

            # Draw detections for this frame (processed frame i corresponds to
            # original frame i * sample_rate)
            frame_detections = [
                d for d in detections if d.frame_idx == i * sample_rate
            ]
            for detection in frame_detections:
                x1, y1, x2, y2 = map(int, detection.bbox)

                # Draw the bounding box (green for person, red otherwise)
                color = (0, 255, 0) if detection.class_name == "person" else (0, 0, 255)
                cv2.rectangle(annotated_frame, (x1, y1), (x2, y2), color, 2)

                # Draw the label
                label = f"{detection.class_name}: {detection.confidence:.2f}"
                cv2.putText(annotated_frame, label, (x1, y1 - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

            # Draw pose keypoints, colored by body part
            if i in pose_data:
                keypoints = pose_data[i]

                for part_name, part_indices in BODY_PARTS_MAPPING.items():
                    color = BODY_PART_COLORS[part_name]
                    for idx in part_indices:
                        if (idx < len(keypoints) and keypoints[idx] is not None
                                and len(keypoints[idx]) >= 2):
                            x, y = int(keypoints[idx][0]), int(keypoints[idx][1])
                            cv2.circle(annotated_frame, (x, y), 5, color, -1)

                # Draw connections between keypoints
                connections = mp.solutions.pose.POSE_CONNECTIONS

                for start_idx, end_idx in connections:
                    if (start_idx < len(keypoints) and end_idx < len(keypoints)
                            and keypoints[start_idx] is not None
                            and keypoints[end_idx] is not None
                            and len(keypoints[start_idx]) >= 2
                            and len(keypoints[end_idx]) >= 2):

                        # Color the line by the body part of the start point;
                        # fall back to white if no group matches
                        color = (255, 255, 255)
                        for part_name, part_indices in BODY_PARTS_MAPPING.items():
                            if start_idx in part_indices:
                                color = BODY_PART_COLORS[part_name]
                                break

                        start_point = (int(keypoints[start_idx][0]),
                                       int(keypoints[start_idx][1]))
                        end_point = (int(keypoints[end_idx][0]),
                                     int(keypoints[end_idx][1]))

                        cv2.line(annotated_frame, start_point, end_point, color, 2)

            # Draw swing phase information
            phase = None
            for phase_name, phase_frames in swing_phases.items():
                if i in phase_frames:
                    phase = phase_name
                    break

            if phase:
                cv2.putText(annotated_frame, f"Phase: {phase}", (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

            # Draw trajectory information if available
            if i in trajectory_data:
                traj_info = trajectory_data[i]
                if traj_info.get("club_speed"):
                    cv2.putText(annotated_frame,
                                f"Club Speed: {traj_info['club_speed']:.1f} mph",
                                (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 0.7,
                                (255, 0, 0), 2)

                if traj_info.get("ball_trajectory"):
                    points = traj_info["ball_trajectory"]
                    for j in range(1, len(points)):
                        pt1 = (int(points[j - 1][0]), int(points[j - 1][1]))
                        pt2 = (int(points[j][0]), int(points[j][1]))
                        cv2.line(annotated_frame, pt1, pt2, (0, 255, 255), 2)

            # Add a legend for the body part colors
            legend_y_start = 110
            legend_y_spacing = 30
            legend_x = 10
            legend_box_size = 20

            # Draw the legend title
            cv2.putText(annotated_frame, "Body Parts Legend:",
                        (legend_x, legend_y_start - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

            # Draw a color box and label for each body part
            for idx, (part_name, color) in enumerate(BODY_PART_COLORS.items()):
                y_pos = legend_y_start + idx * legend_y_spacing

                # Color box
                cv2.rectangle(annotated_frame,
                              (legend_x, y_pos - legend_box_size + 5),
                              (legend_x + legend_box_size, y_pos + 5), color, -1)

                # Part name
                cv2.putText(annotated_frame, part_name.capitalize(),
                            (legend_x + legend_box_size + 10, y_pos + 5),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

            # Write the annotated frame to the output video
            out.write(annotated_frame)

        # Release the video writer
        out.release()

        # Verify that the file was created
        if not os.path.exists(output_path) or os.path.getsize(output_path) == 0:
            raise IOError(f"Failed to create video file at {output_path}")

        print(f"Annotated video saved to: {output_path}")
        return output_path

    except Exception as e:
        print(f"Error creating annotated video: {e}")
        raise
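
One caveat: the writer hardcodes `fps = 30` even though only every `sample_rate`-th frame was kept, so the annotated video plays roughly `sample_rate` times faster than real time. A sketch of an adjustment (not this commit's code) that reads the true rate from the source video:

```
cap = cv2.VideoCapture(video_path)
src_fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # Fall back if fps is unreadable
cap.release()
fps = max(src_fps / sample_rate, 1.0)  # Preserve real-time pacing
```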
requirements.txt ADDED
@@ -0,0 +1,12 @@
opencv-python==4.8.1.78
yt-dlp==2025.2.19
ultralytics==8.1.0
mediapipe==0.10.13
numpy==1.26.4
matplotlib==3.8.0
torch==2.2.0
torchvision==0.17.0
openai==1.6.0
python-dotenv==1.0.0
tqdm==4.66.1
streamlit==1.30.0
run_streamlit.sh ADDED
@@ -0,0 +1,10 @@
#!/bin/bash

# Activate the virtual environment
source .venv/bin/activate

# Run the Streamlit app
streamlit run app/streamlit_app.py

# Deactivate the virtual environment when done
deactivate
setup_directories.sh ADDED
@@ -0,0 +1,19 @@
#!/bin/bash

# Create the downloads directory for the application
mkdir -p downloads

# Set permissions
chmod 755 downloads

echo "Directory created and permissions set:"
echo "- downloads: for storing downloaded YouTube videos and annotated videos"

# Create the .env file if it doesn't exist
if [ ! -f .env ]; then
    echo "Creating .env file template..."
    echo "OPENAI_API_KEY=your_api_key_here" > .env
    echo ".env file created. Please edit it to add your OpenAI API key."
fi

echo "Setup complete!"