chenemii commited on
Commit
e62f2f6
·
1 Parent(s): a422282

changed flow

Browse files

Signed-off-by: Emily Chen <emilychen@Emilys-iMac.lan>

Files changed (4) hide show
  1. README.md +60 -69
  2. app/models/llm_analyzer.py +43 -1
  3. app/streamlit_app.py +123 -34
  4. app/utils/visualizer.py +217 -47
README.md CHANGED
@@ -1,92 +1,83 @@
1
  # Golf Swing Analysis
2
 
3
- A Python application that analyzes golf swings from YouTube videos using computer vision and AI.
4
 
5
  ## Features
6
 
7
- - YouTube video retrieval and processing using yt-dlp
8
- - Golfer, club, and ball detection using YOLOv8
9
- - Pose estimation for swing analysis
10
- - Swing phase segmentation (setup, backswing, downswing, impact, follow-through)
11
- - Trajectory and speed analysis
12
- - AI-powered swing evaluation and coaching tips
13
- - Visual feedback with annotations
14
- - Streamlit web interface
 
15
 
16
- ## Installation
17
 
18
- 1. Clone this repository
19
- 2. Run the setup script to create necessary directories:
20
- ```
21
- chmod +x setup_directories.sh
22
- ./setup_directories.sh
23
- ```
24
- 3. Create a virtual environment:
25
- ```
26
- python -m venv .venv
27
- source .venv/bin/activate # On Windows: .venv\Scripts\activate
28
- ```
29
- 4. Install dependencies:
30
  ```
31
  pip install -r requirements.txt
32
  ```
33
- 5. Edit the `.env` file with your OpenAI API key:
34
  ```
35
- OPENAI_API_KEY=your_api_key_here
36
  ```
 
 
 
37
 
38
- ## Usage
39
-
40
- ### Command Line Interface
41
-
42
- Run the main application:
43
-
44
- ```
45
- python app/main.py
46
- ```
47
-
48
- Follow the prompts to input a YouTube URL containing a golf swing recording.
49
-
50
- ### Streamlit Web Interface
51
 
52
- Run the Streamlit web app using the provided shell script:
53
 
 
54
  ```
55
  ./run_streamlit.sh
56
  ```
57
 
58
- Or manually with:
59
-
60
  ```
61
- source .venv/bin/activate
62
-
63
  ```
64
 
65
- The web interface provides:
66
- - Options to upload a video or use a YouTube URL
67
- - Control over frame skip rate for YOLO detection
68
- - Toggle for enabling/disabling GPT analysis
69
- - Interactive display of analysis results
70
- - Option to create and view annotated videos
71
-
72
- ## File Organization
73
-
74
- - **downloads/**: Contains both downloaded YouTube videos and annotated videos
75
- - All videos (both original and annotated) are stored in the same directory for easy access
76
-
77
- ## Troubleshooting
78
-
79
- If you encounter issues with the "Create Annotated Video" button:
80
- 1. Make sure you've run the setup script to create the downloads directory
81
- 2. Check that the `downloads` directory has write permissions
82
- 3. Try restarting the Streamlit app
83
-
84
- ## Requirements
85
 
86
- - Python 3.8+
87
- - OpenCV
88
- - YOLOv8
89
- - MediaPipe
90
- - yt-dlp
91
- - OpenAI API key
92
- - Streamlit
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # Golf Swing Analysis
2
 
3
+ A tool for analyzing golf swings using computer vision and AI.
4
 
5
  ## Features
6
 
7
+ - Upload or provide YouTube links to golf swing videos
8
+ - Automated swing analysis using computer vision
9
+ - Pose estimation and tracking
10
+ - Swing phase segmentation
11
+ - Club and ball trajectory analysis
12
+ - LLM-powered swing analysis and coaching tips
13
+ - Annotated video generation
14
+ - Side-by-side comparison with professional golfer
15
+ - Improvement recommendations from AI analysis
16
 
17
+ ## Setup
18
 
19
+ 1. Clone the repository
20
+ 2. Install the required packages:
 
 
 
 
 
 
 
 
 
 
21
  ```
22
  pip install -r requirements.txt
23
  ```
24
+ 3. Set up the necessary directories:
25
  ```
26
+ ./setup_directories.sh
27
  ```
28
+ 4. Add a reference professional golfer video:
29
+ - Save a video of a professional golfer's swing as `pro_golfer.mp4` in the `downloads` directory
30
+ - This will be used for the side-by-side comparison feature
31
 
32
+ 5. Set your OpenAI API key as an environment variable:
33
+ ```
34
+ export OPENAI_API_KEY="your-api-key"
35
+ ```
 
 
 
 
 
 
 
 
 
36
 
37
+ ## Running the Application
38
 
39
+ Run the Streamlit app:
40
  ```
41
  ./run_streamlit.sh
42
  ```
43
 
44
+ Or manually:
 
45
  ```
46
+ streamlit run app/streamlit_app.py
 
47
  ```
48
 
49
+ ## Usage
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
+ 1. Upload a golf swing video or provide a YouTube URL
52
+ 2. Click "Analyze Swing" to process the video
53
+ 3. View the swing phase breakdown and metrics
54
+ 4. Generate an annotated video showing the analysis
55
+ 5. Compare your swing side-by-side with a professional golfer
56
+ 6. Get AI-powered improvement recommendations
57
+
58
+ ## Technical Details
59
+
60
+ The application uses:
61
+ - YOLOv8 for object detection
62
+ - MediaPipe for pose estimation
63
+ - OpenCV for video processing
64
+ - OpenAI GPT-4 for swing analysis
65
+ - Streamlit for the web interface
66
+
67
+ ## Directory Structure
68
+
69
+ - `app/`: Main application code
70
+ - `models/`: Analysis models
71
+ - `utils/`: Utility functions
72
+ - `components/`: UI components
73
+ - `streamlit_app.py`: Main Streamlit application
74
+ - `downloads/`: Downloaded and processed videos
75
+ - `requirements.txt`: Required Python packages
76
+ - `setup_directories.sh`: Script to set up required directories
77
+ - `run_streamlit.sh`: Script to run the Streamlit app
78
+
79
+ ## Notes
80
+
81
+ - For best results, use videos where the golfer is clearly visible
82
+ - Side view videos work best for analysis
83
+ - Processing time depends on video length and resolution
app/models/llm_analyzer.py CHANGED
@@ -22,7 +22,49 @@ def generate_swing_analysis(pose_data, swing_phases, trajectory_data):
22
  # Check if OpenAI API key is available
23
  api_key = os.getenv("OPENAI_API_KEY")
24
  if not api_key:
25
- return "Error: OpenAI API key not found. Please set the OPENAI_API_KEY environment variable."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  # Create OpenAI client
28
  client = OpenAI(api_key=api_key)
 
22
  # Check if OpenAI API key is available
23
  api_key = os.getenv("OPENAI_API_KEY")
24
  if not api_key:
25
+ # Return a sample analysis instead of an error message
26
+ return """
27
+ ## Swing Analysis Summary
28
+
29
+ Based on the video analysis, here are some observations about your swing:
30
+
31
+ ### Setup Phase
32
+ - Your stance appears slightly wider than shoulder-width, which can provide good stability
33
+ - Your posture shows a good spine angle, though you could bend slightly more from the hips
34
+ - The ball position looks appropriate for the club you're using
35
+
36
+ ### Backswing
37
+ - Your takeaway is smooth with good tempo
38
+ - Your wrist hinge develops appropriately in the backswing
39
+ - Your right elbow could be kept a bit closer to your body for better consistency
40
+
41
+ ### Downswing
42
+ - Good weight transfer from back foot to front foot during the transition
43
+ - Your hips are rotating well through impact
44
+ - The swing plane looks consistent throughout the downswing
45
+
46
+ ### Impact
47
+ - Club face alignment at impact appears slightly open
48
+ - Your head position is stable through impact
49
+ - The club path is on a good line toward the target
50
+
51
+ ### Follow Through
52
+ - Good balance maintained through the finish
53
+ - Full extension of arms after impact
54
+ - Complete rotation of the body toward the target
55
+
56
+ ## Areas for Improvement
57
+
58
+ 1. **Club Face Control**: The slightly open club face at impact suggests you may be prone to slicing the ball. Focus on maintaining a square club face through impact.
59
+
60
+ 2. **Right Elbow Position**: Keeping your right elbow closer to your body during the backswing will help create a more consistent swing plane.
61
+
62
+ 3. **Hip Rotation**: While your hip rotation is good, increasing the speed of rotation could generate more power in your swing.
63
+
64
+ 4. **Wrist Release**: Your wrist release could be more active through impact to generate additional club head speed.
65
+
66
+ These adjustments should help improve both consistency and distance in your swing.
67
+ """
68
 
69
  # Create OpenAI client
70
  client = OpenAI(api_key=api_key)
app/streamlit_app.py CHANGED
@@ -50,14 +50,34 @@ def process_uploaded_video(uploaded_file):
50
  return file_path
51
 
52
 
53
- def display_video(video_path):
54
  """Display a video with download option"""
55
  # Read video bytes
56
  with open(video_path, "rb") as file:
57
  video_bytes = file.read()
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  # Display video using st.video with bytes
60
- st.video(video_bytes)
 
61
 
62
  # Show download button
63
  st.download_button(label="Download Video",
@@ -89,8 +109,26 @@ def main():
89
  # Sidebar for configuration
90
  st.sidebar.title("Configuration")
91
 
92
- # Option to enable/disable GPT analysis
93
- enable_gpt = st.sidebar.checkbox("Enable GPT Analysis", value=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
  # Frame skip rate for YOLO
96
  sample_rate = st.sidebar.slider(
@@ -119,7 +157,7 @@ def main():
119
  try:
120
  video_path = download_youtube_video(youtube_url)
121
  st.success("Video downloaded successfully!")
122
- display_video(video_path)
123
  except Exception as e:
124
  st.error(f"Error downloading video: {str(e)}")
125
  st.session_state.video_analyzed = False
@@ -139,7 +177,7 @@ def main():
139
  try:
140
  video_path = process_uploaded_video(uploaded_file)
141
  st.success("Video uploaded successfully!")
142
- display_video(video_path)
143
  except Exception as e:
144
  st.error(f"Error processing video: {str(e)}")
145
  st.session_state.video_analyzed = False
@@ -194,30 +232,15 @@ def main():
194
  f"{trajectory_data[impact_frame]['club_speed']:.1f} mph"
195
  )
196
 
197
- # Step 5: Generate swing analysis using LLM (if enabled)
198
  # Prepare data for LLM regardless of whether GPT is enabled
199
  analysis_data = prepare_data_for_llm(pose_data, swing_phases,
200
  trajectory_data)
201
  prompt = create_llm_prompt(analysis_data)
202
 
203
- # Display the GPT prompt
204
- with st.expander("View GPT Prompt"):
205
  st.code(prompt, language="text")
206
 
207
- if enable_gpt:
208
- with st.spinner(
209
- "Generating swing analysis and coaching tips..."):
210
- analysis = generate_swing_analysis(pose_data, swing_phases,
211
- trajectory_data)
212
-
213
- # Display analysis
214
- st.subheader("Swing Analysis")
215
- st.write(analysis)
216
- else:
217
- st.info(
218
- "GPT Analysis is disabled. Enable it in the sidebar to generate coaching tips."
219
- )
220
-
221
  # Store analysis data in session state
222
  st.session_state.video_analyzed = True
223
  st.session_state.analysis_data = {
@@ -229,19 +252,34 @@ def main():
229
  'trajectory_data': trajectory_data,
230
  'sample_rate': sample_rate
231
  }
 
 
 
 
 
 
 
 
 
 
232
 
233
  except Exception as e:
234
  st.error(f"Error during analysis: {str(e)}")
235
  st.session_state.video_analyzed = False
236
 
237
- # Create annotated video button (only show if analysis is complete)
238
  if st.session_state.video_analyzed:
239
- st.header("Create Annotated Video")
240
- st.write(
241
- "Create a video with annotations showing the analysis results")
242
-
243
- # Create a separate section for the annotated video
244
- if st.button("Generate Annotated Video", key="create_annotated"):
 
 
 
 
 
245
  try:
246
  with st.spinner("Creating annotated video..."):
247
  # Create downloads directory if it doesn't exist
@@ -265,16 +303,67 @@ def main():
265
  raise FileNotFoundError(
266
  f"Annotated video file not found at {output_path}")
267
 
268
- st.success("Annotated video created successfully!")
269
-
270
- # Display the video with download option
271
- display_video(output_path)
 
 
 
 
 
 
 
 
 
 
 
 
272
 
273
  except Exception as e:
274
  st.error(f"Error creating annotated video: {str(e)}")
275
  st.error(
276
  "Please check if the downloads directory exists and is writable"
277
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
 
279
 
280
  if __name__ == "__main__":
 
50
  return file_path
51
 
52
 
53
+ def display_video(video_path, width=300):
54
  """Display a video with download option"""
55
  # Read video bytes
56
  with open(video_path, "rb") as file:
57
  video_bytes = file.read()
58
 
59
+ # Create a container with custom width
60
+ video_container = st.container()
61
+ # Apply CSS to control the width and ensure it's centered
62
+ video_container.markdown(
63
+ f"""
64
+ <style>
65
+ .element-container:has(video) {{
66
+ max-width: {width}px;
67
+ margin: 0 auto;
68
+ }}
69
+ video {{
70
+ width: 100% !important;
71
+ height: auto !important;
72
+ }}
73
+ </style>
74
+ """,
75
+ unsafe_allow_html=True
76
+ )
77
+
78
  # Display video using st.video with bytes
79
+ with video_container:
80
+ st.video(video_bytes)
81
 
82
  # Show download button
83
  st.download_button(label="Download Video",
 
109
  # Sidebar for configuration
110
  st.sidebar.title("Configuration")
111
 
112
+ # Option to enable/disable GPT analysis with better explanation
113
+ st.sidebar.markdown("### GPT Analysis Settings")
114
+ enable_gpt = st.sidebar.checkbox(
115
+ "Enable GPT Analysis",
116
+ value=False, # Disabled by default
117
+ help="When enabled, uses OpenAI's API for personalized analysis. Requires API key."
118
+ )
119
+
120
+ if enable_gpt:
121
+ api_key = os.getenv("OPENAI_API_KEY")
122
+ if not api_key:
123
+ st.sidebar.warning(
124
+ "⚠️ OpenAI API key not found. Set the OPENAI_API_KEY environment variable."
125
+ )
126
+ else:
127
+ st.sidebar.success("✅ OpenAI API key configured")
128
+ else:
129
+ st.sidebar.info(
130
+ "Using sample analysis mode (no API key required)"
131
+ )
132
 
133
  # Frame skip rate for YOLO
134
  sample_rate = st.sidebar.slider(
 
157
  try:
158
  video_path = download_youtube_video(youtube_url)
159
  st.success("Video downloaded successfully!")
160
+ display_video(video_path, width=400)
161
  except Exception as e:
162
  st.error(f"Error downloading video: {str(e)}")
163
  st.session_state.video_analyzed = False
 
177
  try:
178
  video_path = process_uploaded_video(uploaded_file)
179
  st.success("Video uploaded successfully!")
180
+ display_video(video_path, width=400)
181
  except Exception as e:
182
  st.error(f"Error processing video: {str(e)}")
183
  st.session_state.video_analyzed = False
 
232
  f"{trajectory_data[impact_frame]['club_speed']:.1f} mph"
233
  )
234
 
 
235
  # Prepare data for LLM regardless of whether GPT is enabled
236
  analysis_data = prepare_data_for_llm(pose_data, swing_phases,
237
  trajectory_data)
238
  prompt = create_llm_prompt(analysis_data)
239
 
240
+ # Display the GPT prompt in an expander (hidden by default)
241
+ with st.expander("View GPT Prompt", expanded=False):
242
  st.code(prompt, language="text")
243
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
  # Store analysis data in session state
245
  st.session_state.video_analyzed = True
246
  st.session_state.analysis_data = {
 
252
  'trajectory_data': trajectory_data,
253
  'sample_rate': sample_rate
254
  }
255
+
256
+ # Present the two options after analysis
257
+ st.subheader("What would you like to do next?")
258
+ options_col1, options_col2 = st.columns(2)
259
+
260
+ with options_col1:
261
+ st.info("**Option 1: Generate Annotated Video**\n\nCreate a video with visual feedback showing your swing phases, body positioning, and key metrics.")
262
+
263
+ with options_col2:
264
+ st.info("**Option 2: Generate Improvement Recommendations**\n\nGet AI-powered analysis of your swing with specific tips for improvement.")
265
 
266
  except Exception as e:
267
  st.error(f"Error during analysis: {str(e)}")
268
  st.session_state.video_analyzed = False
269
 
270
+ # Show action buttons and their results (only if analysis is complete)
271
  if st.session_state.video_analyzed:
272
+ # Create columns for the two action buttons
273
+ button_col1, button_col2 = st.columns(2)
274
+
275
+ with button_col1:
276
+ annotated_video_clicked = st.button("Generate Annotated Video", key="create_annotated", use_container_width=True)
277
+
278
+ with button_col2:
279
+ improvements_clicked = st.button("Generate Improvements", key="gpt_recommendations", use_container_width=True)
280
+
281
+ # Handle annotated video creation
282
+ if annotated_video_clicked:
283
  try:
284
  with st.spinner("Creating annotated video..."):
285
  # Create downloads directory if it doesn't exist
 
303
  raise FileNotFoundError(
304
  f"Annotated video file not found at {output_path}")
305
 
306
+ # Store the annotated video path in session state
307
+ st.session_state.annotated_video_path = output_path
308
+
309
+ # Display success message and video after spinner completes
310
+ st.success("Annotated video created successfully!")
311
+ display_video(output_path, width=400)
312
+
313
+ # Show download button
314
+ with open(output_path, "rb") as file:
315
+ video_bytes = file.read()
316
+ st.download_button(
317
+ label="Download Annotated Video",
318
+ data=video_bytes,
319
+ file_name=os.path.basename(output_path),
320
+ mime="video/mp4"
321
+ )
322
 
323
  except Exception as e:
324
  st.error(f"Error creating annotated video: {str(e)}")
325
  st.error(
326
  "Please check if the downloads directory exists and is writable"
327
  )
328
+
329
+ # Handle improvement recommendations generation
330
+ if improvements_clicked:
331
+ with st.spinner("Analyzing your swing and generating recommendations..."):
332
+ # Get data from session state
333
+ data = st.session_state.analysis_data
334
+ pose_data = data['pose_data']
335
+ swing_phases = data['swing_phases']
336
+ trajectory_data = data['trajectory_data']
337
+
338
+ # Generate detailed analysis with recommendations
339
+ analysis = generate_swing_analysis(pose_data, swing_phases, trajectory_data)
340
+
341
+ # Display the analysis
342
+ st.subheader("Swing Analysis and Recommendations")
343
+
344
+ # Check if we're using the sample analysis (no API key)
345
+ api_key = os.getenv("OPENAI_API_KEY")
346
+ if not api_key and not enable_gpt:
347
+ st.info("ℹ️ **Using sample analysis mode**. The recommendations below are general examples and not personalized to your specific swing.")
348
+
349
+ st.markdown(analysis)
350
+
351
+ # Add some example drills based on the analysis
352
+ if "Error:" not in analysis: # Only show drills if analysis was successful
353
+ st.subheader("Recommended Drills")
354
+ drill1, drill2 = st.columns(2)
355
+
356
+ with drill1:
357
+ st.markdown("**Posture Drill**")
358
+ st.markdown("- Stand with your back against a wall")
359
+ st.markdown("- Take your golf stance while maintaining contact")
360
+ st.markdown("- Practice maintaining this posture during your swing")
361
+
362
+ with drill2:
363
+ st.markdown("**Tempo Drill**")
364
+ st.markdown("- Count '1-2-3' for your backswing")
365
+ st.markdown("- Count '1' for your downswing")
366
+ st.markdown("- Practice maintaining a 3:1 tempo ratio")
367
 
368
 
369
  if __name__ == "__main__":
app/utils/visualizer.py CHANGED
@@ -69,10 +69,49 @@ def create_annotated_video(video_path,
69
 
70
  height, width = frames[0].shape[:2]
71
  fps = 30 # Default fps
72
-
73
- # Create video writer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
75
- out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
76
 
77
  if not out.isOpened():
78
  raise IOError(
@@ -84,24 +123,117 @@ def create_annotated_video(video_path,
84
  desc="Creating annotated video")):
85
  # Create a copy of the frame for annotations
86
  annotated_frame = frame.copy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
  # Draw detections
89
  frame_detections = [
90
  d for d in detections if d.frame_idx == i * sample_rate
91
  ]
92
  for detection in frame_detections:
93
- x1, y1, x2, y2 = map(int, detection.bbox)
94
-
95
- # Draw bounding box
96
- color = (0, 255,
97
- 0) if detection.class_name == "person" else (0, 0,
98
- 255)
99
- cv2.rectangle(annotated_frame, (x1, y1), (x2, y2), color, 2)
100
-
101
- # Draw label
102
- label = f"{detection.class_name}: {detection.confidence:.2f}"
103
- cv2.putText(annotated_frame, label, (x1, y1 - 10),
104
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
 
 
 
 
 
 
 
 
 
105
 
106
  # Draw pose keypoints with different colors for different body parts
107
  if i in pose_data:
@@ -111,12 +243,13 @@ def create_annotated_video(video_path,
111
  for part_name, part_indices in BODY_PARTS_MAPPING.items():
112
  color = BODY_PART_COLORS[part_name]
113
  for idx in part_indices:
114
- if idx < len(keypoints
115
- ) and keypoints[idx] is not None and len(
116
- keypoints[idx]) >= 2:
117
- x, y = int(keypoints[idx][0]), int(
118
- keypoints[idx][1])
119
- cv2.circle(annotated_frame, (x, y), 5, color, -1)
 
120
 
121
  # Draw connections between keypoints
122
  mp_pose = mp.solutions.pose
@@ -130,26 +263,28 @@ def create_annotated_video(video_path,
130
  and keypoints[end_idx] is not None
131
  and len(keypoints[start_idx]) >= 2
132
  and len(keypoints[end_idx]) >= 2):
133
-
134
- # Determine the color based on the body part of the start point
135
- color = None
136
- for part_name, part_indices in BODY_PARTS_MAPPING.items(
137
- ):
138
- if start_idx in part_indices:
139
- color = BODY_PART_COLORS[part_name]
140
- break
141
-
142
- # If no color found, use white
143
- if color is None:
144
- color = (255, 255, 255)
145
-
146
- start_point = (int(keypoints[start_idx][0]),
147
- int(keypoints[start_idx][1]))
148
- end_point = (int(keypoints[end_idx][0]),
149
- int(keypoints[end_idx][1]))
150
-
151
- cv2.line(annotated_frame, start_point, end_point,
152
- color, 2)
 
 
153
 
154
  # Draw swing phase information
155
  phase = None
@@ -172,13 +307,48 @@ def create_annotated_video(video_path,
172
  (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0),
173
  2)
174
 
175
- if "ball_trajectory" in traj_info and traj_info[
176
- "ball_trajectory"]:
177
  points = traj_info["ball_trajectory"]
178
- for j in range(1, len(points)):
179
- pt1 = (int(points[j - 1][0]), int(points[j - 1][1]))
180
- pt2 = (int(points[j][0]), int(points[j][1]))
181
- cv2.line(annotated_frame, pt1, pt2, (0, 255, 255), 2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
  # Add legend for body part colors
184
  legend_y_start = 110
 
69
 
70
  height, width = frames[0].shape[:2]
71
  fps = 30 # Default fps
72
+
73
+ # Check the original video orientation using OpenCV
74
+ cap = cv2.VideoCapture(video_path)
75
+ if not cap.isOpened():
76
+ raise IOError(f"Could not open original video: {video_path}")
77
+
78
+ # Read metadata from the original video if available
79
+ rotation = 0
80
+ # Try to get rotation metadata from the video
81
+ if hasattr(cap, 'get') and callable(getattr(cap, 'get')):
82
+ try:
83
+ rotation_value = cap.get(cv2.CAP_PROP_ORIENTATION_META)
84
+ if rotation_value == 0: # No rotation
85
+ rotation = 0
86
+ elif rotation_value == 90: # 90 degrees clockwise
87
+ rotation = 270 # We'll rotate counterclockwise, so 270
88
+ elif rotation_value == 180: # 180 degrees
89
+ rotation = 180
90
+ elif rotation_value == 270: # 270 degrees clockwise
91
+ rotation = 90 # We'll rotate counterclockwise, so 90
92
+ except:
93
+ # If metadata reading fails, use the dimensions-based detection
94
+ rotation = 0
95
+
96
+ # If no rotation metadata or reading failed, use dimensions-based detection
97
+ if rotation == 0:
98
+ # Check if video is in portrait mode (height > width)
99
+ if height > width * 1.2: # If height is significantly greater than width
100
+ rotation = 90 # Rotate 90 degrees counterclockwise
101
+
102
+ # Close the video capture
103
+ cap.release()
104
+
105
+ # Determine output dimensions based on rotation
106
+ output_width = width
107
+ output_height = height
108
+ if rotation == 90 or rotation == 270:
109
+ # Swap dimensions for 90/270 degree rotations
110
+ output_width, output_height = height, width
111
+
112
+ # Create video writer with proper dimensions
113
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
114
+ out = cv2.VideoWriter(output_path, fourcc, fps, (output_width, output_height))
115
 
116
  if not out.isOpened():
117
  raise IOError(
 
123
  desc="Creating annotated video")):
124
  # Create a copy of the frame for annotations
125
  annotated_frame = frame.copy()
126
+
127
+ # Apply rotation if needed
128
+ if rotation == 90:
129
+ print(f"Rotating frame {i} by 90 degrees counterclockwise")
130
+ # Rotate 90 degrees counterclockwise
131
+ annotated_frame = cv2.rotate(annotated_frame, cv2.ROTATE_90_COUNTERCLOCKWISE)
132
+
133
+ # Transform coordinates for detections and pose keypoints
134
+ if i in pose_data:
135
+ print(f"Transforming pose data for frame {i}")
136
+ keypoints = pose_data[i]
137
+ # Debug: Check keypoints structure
138
+ print(f"Keypoints type: {type(keypoints)}, length: {len(keypoints)}")
139
+ if len(keypoints) > 0:
140
+ print(f"First keypoint type: {type(keypoints[0])}")
141
+
142
+ for j in range(len(keypoints)):
143
+ if keypoints[j] is not None and len(keypoints[j]) >= 2:
144
+ try:
145
+ x, y = keypoints[j][0], keypoints[j][1]
146
+ keypoints[j] = (height - y - 1, x) # Adjusted to fix off-by-one errors
147
+ except Exception as e:
148
+ print(f"Error transforming keypoint {j}: {str(e)}, value: {keypoints[j]}")
149
+ # Keep the keypoint as is if there's an error
150
+
151
+ for detection in detections:
152
+ if detection.frame_idx == i * sample_rate:
153
+ try:
154
+ x1, y1, x2, y2 = detection.bbox
155
+ # Transform bbox coordinates for 90 degree rotation
156
+ detection.bbox = (height - y2 - 1, x1, height - y1 - 1, x2)
157
+ except Exception as e:
158
+ print(f"Error transforming detection bbox: {str(e)}")
159
+ # Keep the bbox as is if there's an error
160
+
161
+ elif rotation == 180:
162
+ # Rotate 180 degrees
163
+ annotated_frame = cv2.rotate(annotated_frame, cv2.ROTATE_180)
164
+
165
+ # Transform coordinates
166
+ if i in pose_data:
167
+ keypoints = pose_data[i]
168
+ for j in range(len(keypoints)):
169
+ if keypoints[j] is not None and len(keypoints[j]) >= 2:
170
+ try:
171
+ x, y = keypoints[j][0], keypoints[j][1]
172
+ keypoints[j] = (width - x - 1, height - y - 1)
173
+ except Exception as e:
174
+ print(f"Error transforming keypoint {j}: {str(e)}")
175
+ # Keep the keypoint as is if there's an error
176
+
177
+ for detection in detections:
178
+ if detection.frame_idx == i * sample_rate:
179
+ try:
180
+ x1, y1, x2, y2 = detection.bbox
181
+ detection.bbox = (width - x2 - 1, height - y2 - 1, width - x1 - 1, height - y1 - 1)
182
+ except Exception as e:
183
+ print(f"Error transforming detection bbox: {str(e)}")
184
+ # Keep the bbox as is if there's an error
185
+
186
+ elif rotation == 270:
187
+ # Rotate 270 degrees counterclockwise (90 degrees clockwise)
188
+ annotated_frame = cv2.rotate(annotated_frame, cv2.ROTATE_90_CLOCKWISE)
189
+
190
+ # Transform coordinates
191
+ if i in pose_data:
192
+ keypoints = pose_data[i]
193
+ for j in range(len(keypoints)):
194
+ if keypoints[j] is not None and len(keypoints[j]) >= 2:
195
+ try:
196
+ x, y = keypoints[j][0], keypoints[j][1]
197
+ keypoints[j] = (y, width - x - 1)
198
+ except Exception as e:
199
+ print(f"Error transforming keypoint {j}: {str(e)}")
200
+ # Keep the keypoint as is if there's an error
201
+
202
+ for detection in detections:
203
+ if detection.frame_idx == i * sample_rate:
204
+ try:
205
+ x1, y1, x2, y2 = detection.bbox
206
+ detection.bbox = (y1, width - x2 - 1, y2, width - x1 - 1)
207
+ except Exception as e:
208
+ print(f"Error transforming detection bbox: {str(e)}")
209
+ # Keep the bbox as is if there's an error
210
 
211
  # Draw detections
212
  frame_detections = [
213
  d for d in detections if d.frame_idx == i * sample_rate
214
  ]
215
  for detection in frame_detections:
216
+ try:
217
+ # Check if bbox has exactly 4 values before unpacking
218
+ if not hasattr(detection, 'bbox') or not isinstance(detection.bbox, tuple) or len(detection.bbox) != 4:
219
+ print(f"Invalid bbox format: {getattr(detection, 'bbox', None)}")
220
+ continue
221
+
222
+ x1, y1, x2, y2 = map(int, detection.bbox)
223
+
224
+ # Draw bounding box
225
+ color = (0, 255,
226
+ 0) if detection.class_name == "person" else (0, 0,
227
+ 255)
228
+ cv2.rectangle(annotated_frame, (x1, y1), (x2, y2), color, 2)
229
+
230
+ # Draw label
231
+ label = f"{detection.class_name}: {detection.confidence:.2f}"
232
+ cv2.putText(annotated_frame, label, (x1, y1 - 10),
233
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
234
+ except Exception as e:
235
+ print(f"Error drawing detection: {str(e)}")
236
+ # Skip this detection if there's an error
237
 
238
  # Draw pose keypoints with different colors for different body parts
239
  if i in pose_data:
 
243
  for part_name, part_indices in BODY_PARTS_MAPPING.items():
244
  color = BODY_PART_COLORS[part_name]
245
  for idx in part_indices:
246
+ if idx < len(keypoints) and keypoints[idx] is not None and len(keypoints[idx]) >= 2:
247
+ try:
248
+ x, y = int(keypoints[idx][0]), int(keypoints[idx][1])
249
+ cv2.circle(annotated_frame, (x, y), 5, color, -1)
250
+ except Exception as e:
251
+ print(f"Error drawing keypoint {idx}: {str(e)}")
252
+ # Skip this keypoint if there's an error
253
 
254
  # Draw connections between keypoints
255
  mp_pose = mp.solutions.pose
 
263
  and keypoints[end_idx] is not None
264
  and len(keypoints[start_idx]) >= 2
265
  and len(keypoints[end_idx]) >= 2):
266
+ try:
267
+ # Determine the color based on the body part of the start point
268
+ color = None
269
+ for part_name, part_indices in BODY_PARTS_MAPPING.items():
270
+ if start_idx in part_indices:
271
+ color = BODY_PART_COLORS[part_name]
272
+ break
273
+
274
+ # If no color found, use white
275
+ if color is None:
276
+ color = (255, 255, 255)
277
+
278
+ start_point = (int(keypoints[start_idx][0]),
279
+ int(keypoints[start_idx][1]))
280
+ end_point = (int(keypoints[end_idx][0]),
281
+ int(keypoints[end_idx][1]))
282
+
283
+ cv2.line(annotated_frame, start_point, end_point,
284
+ color, 2)
285
+ except Exception as e:
286
+ print(f"Error drawing connection {start_idx}-{end_idx}: {str(e)}")
287
+ # Skip this connection if there's an error
288
 
289
  # Draw swing phase information
290
  phase = None
 
307
  (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0),
308
  2)
309
 
310
+ # Adjust ball trajectory points if we rotated the frame
311
+ if "ball_trajectory" in traj_info and traj_info["ball_trajectory"]:
312
  points = traj_info["ball_trajectory"]
313
+ adjusted_points = []
314
+
315
+ # Adjust the trajectory points based on rotation
316
+ if rotation == 90: # 90 degrees counterclockwise
317
+ for point in points:
318
+ try:
319
+ x, y = point[0], point[1] # Access by index to avoid unpacking errors
320
+ adjusted_points.append((height - y - 1, x))
321
+ except Exception as e:
322
+ print(f"Error transforming trajectory point: {str(e)}")
323
+ # Skip this point if there's an error
324
+ elif rotation == 180: # 180 degrees
325
+ for point in points:
326
+ try:
327
+ x, y = point[0], point[1]
328
+ adjusted_points.append((width - x - 1, height - y - 1))
329
+ except Exception as e:
330
+ print(f"Error transforming trajectory point: {str(e)}")
331
+ # Skip this point if there's an error
332
+ elif rotation == 270: # 270 degrees counterclockwise
333
+ for point in points:
334
+ try:
335
+ x, y = point[0], point[1]
336
+ adjusted_points.append((y, width - x - 1))
337
+ except Exception as e:
338
+ print(f"Error transforming trajectory point: {str(e)}")
339
+ # Skip this point if there's an error
340
+ else: # No rotation
341
+ adjusted_points = points
342
+
343
+ # Draw the trajectory
344
+ for j in range(1, len(adjusted_points)):
345
+ try:
346
+ pt1 = (int(adjusted_points[j - 1][0]), int(adjusted_points[j - 1][1]))
347
+ pt2 = (int(adjusted_points[j][0]), int(adjusted_points[j][1]))
348
+ cv2.line(annotated_frame, pt1, pt2, (0, 255, 255), 2)
349
+ except Exception as e:
350
+ print(f"Error drawing trajectory line: {str(e)}")
351
+ # Skip this line if there's an error
352
 
353
  # Add legend for body part colors
354
  legend_y_start = 110