chenemii commited on
Commit
4a0e3a8
·
1 Parent(s): 6035b75

compare with pro

Browse files
app/streamlit_app.py CHANGED
@@ -17,12 +17,13 @@ load_dotenv()
17
  # Add the app directory to the path
18
  sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
19
 
20
- from app.utils.video_downloader import download_youtube_video
21
  from app.utils.video_processor import process_video
22
  from app.models.pose_estimator import analyze_pose
23
  from app.models.swing_analyzer import segment_swing, analyze_trajectory
24
  from app.models.llm_analyzer import generate_swing_analysis, create_llm_prompt, prepare_data_for_llm, check_llm_services
25
  from app.utils.visualizer import create_annotated_video
 
26
 
27
  # Set page config
28
  st.set_page_config(page_title="Par-ity Project: Golf Swing Analysis 🏌️‍♀️",
@@ -102,6 +103,8 @@ def main():
102
  'trajectory_data': None,
103
  'sample_rate': None
104
  }
 
 
105
 
106
  # Sidebar for configuration
107
  st.sidebar.title("Configuration")
@@ -164,6 +167,23 @@ def main():
164
  value=5,
165
  help=
166
  "Process every Nth frame. Higher values = faster but less accurate.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
  # Video input options
169
  st.header("Video Input")
@@ -208,6 +228,18 @@ def main():
208
  st.error(f"Error processing video: {str(e)}")
209
  st.session_state.video_analyzed = False
210
  return
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
  # Process video if available and analyze button was clicked
213
  if video_path and analyze_clicked:
@@ -255,9 +287,9 @@ def main():
255
  'prompt': prompt
256
  }
257
 
258
- # Present the two options after analysis
259
  st.subheader("What would you like to do next?")
260
- options_col1, options_col2 = st.columns(2)
261
 
262
  with options_col1:
263
  st.info(
@@ -268,6 +300,12 @@ def main():
268
  st.info(
269
  "**Option 2: Generate Improvement Recommendations**\n\nGet AI-powered analysis of your swing with specific tips for improvement."
270
  )
 
 
 
 
 
 
271
 
272
  except Exception as e:
273
  st.error(f"Error during analysis: {str(e)}")
@@ -307,8 +345,12 @@ def main():
307
  st.code(st.session_state.analysis_data['prompt'],
308
  language="text")
309
 
310
- # Create columns for the two action buttons
311
- button_col1, button_col2 = st.columns(2)
 
 
 
 
312
 
313
  with button_col1:
314
  annotated_video_clicked = st.button("Generate Annotated Video",
@@ -319,6 +361,15 @@ def main():
319
  improvements_clicked = st.button("Generate Improvements",
320
  key="gpt_recommendations",
321
  use_container_width=True)
 
 
 
 
 
 
 
 
 
322
 
323
  # Handle annotated video creation
324
  if annotated_video_clicked:
@@ -425,6 +476,65 @@ def main():
425
  st.markdown("- Count '1-2-3' for your backswing")
426
  st.markdown("- Count '1' for your downswing")
427
  st.markdown("- Practice maintaining a 3:1 tempo ratio")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
428
 
429
 
430
  if __name__ == "__main__":
 
17
  # Add the app directory to the path
18
  sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
19
 
20
+ from app.utils.video_downloader import download_youtube_video, download_pro_reference
21
  from app.utils.video_processor import process_video
22
  from app.models.pose_estimator import analyze_pose
23
  from app.models.swing_analyzer import segment_swing, analyze_trajectory
24
  from app.models.llm_analyzer import generate_swing_analysis, create_llm_prompt, prepare_data_for_llm, check_llm_services
25
  from app.utils.visualizer import create_annotated_video
26
+ from app.utils.comparison import create_frame_by_frame_comparison
27
 
28
  # Set page config
29
  st.set_page_config(page_title="Par-ity Project: Golf Swing Analysis 🏌️‍♀️",
 
103
  'trajectory_data': None,
104
  'sample_rate': None
105
  }
106
+ if 'pro_reference_path' not in st.session_state:
107
+ st.session_state.pro_reference_path = None
108
 
109
  # Sidebar for configuration
110
  st.sidebar.title("Configuration")
 
167
  value=5,
168
  help=
169
  "Process every Nth frame. Higher values = faster but less accurate.")
170
+
171
+ # Pro reference toggle
172
+ enable_pro_comparison = st.sidebar.checkbox(
173
+ "Enable Pro Comparison",
174
+ value=True,
175
+ help="Compare your swing with a professional golfer reference"
176
+ )
177
+
178
+ # Pro reference URL input
179
+ if enable_pro_comparison:
180
+ pro_url = st.sidebar.text_input(
181
+ "Pro Golfer Reference URL",
182
+ value="https://www.youtube.com/shorts/geR666LWSHg",
183
+ help="YouTube URL of professional golfer swing (default provided)"
184
+ )
185
+ else:
186
+ pro_url = None
187
 
188
  # Video input options
189
  st.header("Video Input")
 
228
  st.error(f"Error processing video: {str(e)}")
229
  st.session_state.video_analyzed = False
230
  return
231
+
232
+ # Download pro reference if enabled
233
+ if enable_pro_comparison and (video_path or st.session_state.video_analyzed):
234
+ if not st.session_state.pro_reference_path:
235
+ with st.spinner("Downloading professional golfer reference..."):
236
+ try:
237
+ pro_path = download_pro_reference(pro_url)
238
+ st.session_state.pro_reference_path = pro_path
239
+ st.success("Professional reference downloaded successfully!")
240
+ except Exception as e:
241
+ st.error(f"Error downloading pro reference: {str(e)}")
242
+ st.session_state.pro_reference_path = None
243
 
244
  # Process video if available and analyze button was clicked
245
  if video_path and analyze_clicked:
 
287
  'prompt': prompt
288
  }
289
 
290
+ # Present the options after analysis
291
  st.subheader("What would you like to do next?")
292
+ options_col1, options_col2, options_col3 = st.columns(3)
293
 
294
  with options_col1:
295
  st.info(
 
300
  st.info(
301
  "**Option 2: Generate Improvement Recommendations**\n\nGet AI-powered analysis of your swing with specific tips for improvement."
302
  )
303
+
304
+ with options_col3:
305
+ if enable_pro_comparison and st.session_state.pro_reference_path:
306
+ st.info(
307
+ "**Option 3: Compare With Pro**\n\nSee a side-by-side frame-by-frame comparison with a professional golfer's swing."
308
+ )
309
 
310
  except Exception as e:
311
  st.error(f"Error during analysis: {str(e)}")
 
345
  st.code(st.session_state.analysis_data['prompt'],
346
  language="text")
347
 
348
+ # Create columns for the action buttons
349
+ if enable_pro_comparison and st.session_state.pro_reference_path:
350
+ button_col1, button_col2, button_col3 = st.columns(3)
351
+ else:
352
+ button_col1, button_col2 = st.columns(2)
353
+ button_col3 = None
354
 
355
  with button_col1:
356
  annotated_video_clicked = st.button("Generate Annotated Video",
 
361
  improvements_clicked = st.button("Generate Improvements",
362
  key="gpt_recommendations",
363
  use_container_width=True)
364
+
365
+ # Add pro comparison button if enabled
366
+ if enable_pro_comparison and st.session_state.pro_reference_path and button_col3:
367
+ with button_col3:
368
+ comparison_clicked = st.button("Compare With Pro",
369
+ key="pro_comparison",
370
+ use_container_width=True)
371
+ else:
372
+ comparison_clicked = False
373
 
374
  # Handle annotated video creation
375
  if annotated_video_clicked:
 
476
  st.markdown("- Count '1-2-3' for your backswing")
477
  st.markdown("- Count '1' for your downswing")
478
  st.markdown("- Practice maintaining a 3:1 tempo ratio")
479
+
480
+ # Handle pro comparison video creation
481
+ if comparison_clicked and st.session_state.pro_reference_path:
482
+ try:
483
+ with st.spinner("Creating frame-by-frame comparison..."):
484
+ # Get data from session state
485
+ user_video_path = st.session_state.analysis_data['video_path']
486
+ pro_video_path = st.session_state.pro_reference_path
487
+
488
+ # Create the comparison video
489
+ comparison_path = create_frame_by_frame_comparison(
490
+ user_video_path,
491
+ pro_video_path
492
+ )
493
+
494
+ # Verify the file exists
495
+ if not os.path.exists(comparison_path):
496
+ raise FileNotFoundError(
497
+ f"Comparison video file not found at {comparison_path}")
498
+
499
+ # Store the comparison video path in session state
500
+ st.session_state.comparison_video_path = comparison_path
501
+
502
+ # Display success message and video
503
+ st.success("Frame-by-frame comparison created successfully!")
504
+ st.subheader("Side-by-Side Comparison with Pro Golfer")
505
+
506
+ # Display video with larger width
507
+ display_video(comparison_path, width=800)
508
+
509
+ # Show download button
510
+ with open(comparison_path, "rb") as file:
511
+ video_bytes = file.read()
512
+ st.download_button(
513
+ label="Download Comparison Video",
514
+ data=video_bytes,
515
+ file_name=os.path.basename(comparison_path),
516
+ mime="video/mp4"
517
+ )
518
+
519
+ # Add some guidance for interpreting the comparison
520
+ with st.expander("How to use this comparison", expanded=True):
521
+ st.markdown("""
522
+ ### How to Interpret This Comparison
523
+
524
+ This side-by-side comparison allows you to see how your swing compares to a professional golfer's swing frame by frame. Look for these key differences:
525
+
526
+ 1. **Posture and Setup**: Compare your stance, grip, and alignment at address
527
+ 2. **Backswing Rotation**: Note how much shoulder and hip rotation occurs
528
+ 3. **Top of Swing Position**: Observe club position and body alignment
529
+ 4. **Downswing Sequence**: Watch how the pro initiates the downswing
530
+ 5. **Impact Position**: Compare body positioning at impact
531
+ 6. **Follow-through**: Note how weight transfers and body rotates after impact
532
+
533
+ Try pausing the video at key positions to analyze differences in detail.
534
+ """)
535
+
536
+ except Exception as e:
537
+ st.error(f"Error creating comparison video: {str(e)}")
538
 
539
 
540
  if __name__ == "__main__":
app/utils/comparison.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Comparison module for frame-by-frame analysis between user and pro swings
3
+ """
4
+
5
+ import os
6
+ import cv2
7
+ import numpy as np
8
+ from tqdm import tqdm
9
+
10
+
11
+ def extract_frames(video_path, max_frames=100):
12
+ """
13
+ Extract frames from a video
14
+
15
+ Args:
16
+ video_path (str): Path to the video file
17
+ max_frames (int): Maximum number of frames to extract
18
+
19
+ Returns:
20
+ list: List of extracted frames as numpy arrays
21
+ """
22
+ frames = []
23
+
24
+ if not os.path.exists(video_path):
25
+ raise ValueError(f"Video file not found: {video_path}")
26
+
27
+ cap = cv2.VideoCapture(video_path)
28
+
29
+ if not cap.isOpened():
30
+ raise ValueError(f"Could not open video: {video_path}")
31
+
32
+ # Get total frame count
33
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
34
+
35
+ # Calculate step to get approximately max_frames
36
+ step = max(1, total_frames // max_frames)
37
+
38
+ current_frame = 0
39
+
40
+ while True:
41
+ ret, frame = cap.read()
42
+ if not ret:
43
+ break
44
+
45
+ if current_frame % step == 0:
46
+ frames.append(frame)
47
+
48
+ current_frame += 1
49
+
50
+ cap.release()
51
+
52
+ return frames
53
+
54
+
55
+ def normalize_frames(frames, target_height=480):
56
+ """
57
+ Normalize frames to a consistent size while maintaining aspect ratio
58
+
59
+ Args:
60
+ frames (list): List of frames
61
+ target_height (int): Target height for normalized frames
62
+
63
+ Returns:
64
+ list: List of normalized frames
65
+ """
66
+ normalized_frames = []
67
+
68
+ for frame in frames:
69
+ # Get current dimensions
70
+ h, w = frame.shape[:2]
71
+
72
+ # Calculate new width to maintain aspect ratio
73
+ target_width = int(w * (target_height / h))
74
+
75
+ # Resize the frame
76
+ resized = cv2.resize(frame, (target_width, target_height))
77
+ normalized_frames.append(resized)
78
+
79
+ return normalized_frames
80
+
81
+
82
+ def create_side_by_side_comparison(user_frames, pro_frames, output_path, fps=30):
83
+ """
84
+ Create a side-by-side comparison video
85
+
86
+ Args:
87
+ user_frames (list): List of user swing frames
88
+ pro_frames (list): List of pro swing frames
89
+ output_path (str): Path to save the comparison video
90
+ fps (int): Frames per second for output video
91
+
92
+ Returns:
93
+ str: Path to the comparison video
94
+ """
95
+ if not user_frames or not pro_frames:
96
+ raise ValueError("Both user and pro frames must be provided")
97
+
98
+ # Normalize frames to same height
99
+ user_normalized = normalize_frames(user_frames)
100
+ pro_normalized = normalize_frames(pro_frames)
101
+
102
+ # Ensure we have the same number of frames by duplicating the last frame if needed
103
+ max_frames = max(len(user_normalized), len(pro_normalized))
104
+
105
+ while len(user_normalized) < max_frames:
106
+ user_normalized.append(user_normalized[-1])
107
+
108
+ while len(pro_normalized) < max_frames:
109
+ pro_normalized.append(pro_normalized[-1])
110
+
111
+ # Create output directory if it doesn't exist
112
+ os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
113
+
114
+ # Get dimensions for the combined frame
115
+ user_h, user_w = user_normalized[0].shape[:2]
116
+ pro_h, pro_w = pro_normalized[0].shape[:2]
117
+
118
+ # Create a combined frame with padding
119
+ padding = 20 # Pixels between the two videos
120
+ combined_width = user_w + pro_w + padding
121
+ combined_height = max(user_h, pro_h)
122
+
123
+ # Create video writer
124
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
125
+ out = cv2.VideoWriter(output_path, fourcc, fps, (combined_width, combined_height))
126
+
127
+ if not out.isOpened():
128
+ raise IOError(f"Failed to create video writer for {output_path}")
129
+
130
+ # Create the combined video
131
+ for i in tqdm(range(min(len(user_normalized), len(pro_normalized))), desc="Creating comparison video"):
132
+ # Create a blank canvas
133
+ combined = np.ones((combined_height, combined_width, 3), dtype=np.uint8) * 255
134
+
135
+ # Add title text
136
+ font = cv2.FONT_HERSHEY_SIMPLEX
137
+ cv2.putText(combined, "Your Swing", (user_w//2 - 60, 30), font, 1, (0, 0, 0), 2)
138
+ cv2.putText(combined, "Pro Swing", (user_w + padding + pro_w//2 - 60, 30), font, 1, (0, 0, 0), 2)
139
+
140
+ # Add frame number
141
+ cv2.putText(combined, f"Frame: {i+1}/{min(len(user_normalized), len(pro_normalized))}",
142
+ (10, combined_height - 10), font, 0.5, (0, 0, 0), 1)
143
+
144
+ # Paste user frame
145
+ combined[0:user_h, 0:user_w] = user_normalized[i]
146
+
147
+ # Paste pro frame
148
+ combined[0:pro_h, user_w+padding:user_w+padding+pro_w] = pro_normalized[i]
149
+
150
+ # Draw vertical line between frames
151
+ cv2.line(combined, (user_w + padding//2, 0), (user_w + padding//2, combined_height), (0, 0, 0), 2)
152
+
153
+ # Write to video
154
+ out.write(combined)
155
+
156
+ out.release()
157
+
158
+ return output_path
159
+
160
+
161
+ def align_swings(user_frames, pro_frames, method="manual"):
162
+ """
163
+ Align user and pro swings based on swing phases
164
+
165
+ Args:
166
+ user_frames (list): List of user swing frames
167
+ pro_frames (list): List of pro swing frames
168
+ method (str): Alignment method ('manual' or 'auto')
169
+
170
+ Returns:
171
+ tuple: Aligned user frames and pro frames
172
+ """
173
+ # For now, we'll use a simple frame stretching approach
174
+ # In the future, this could be enhanced with ML-based swing phase detection
175
+
176
+ # Get frame counts
177
+ user_count = len(user_frames)
178
+ pro_count = len(pro_frames)
179
+
180
+ # If almost equal, return as-is
181
+ if abs(user_count - pro_count) <= 5:
182
+ return user_frames, pro_frames
183
+
184
+ # If user has more frames, subsample
185
+ if user_count > pro_count:
186
+ indices = np.linspace(0, user_count - 1, pro_count, dtype=int)
187
+ return [user_frames[i] for i in indices], pro_frames
188
+
189
+ # If pro has more frames, subsample
190
+ indices = np.linspace(0, pro_count - 1, user_count, dtype=int)
191
+ return user_frames, [pro_frames[i] for i in indices]
192
+
193
+
194
+ def create_frame_by_frame_comparison(user_video_path, pro_video_path, output_dir="downloads"):
195
+ """
196
+ Create a frame-by-frame comparison between user and pro golfer swings
197
+
198
+ Args:
199
+ user_video_path (str): Path to the user's golf swing video
200
+ pro_video_path (str): Path to the professional golfer's swing video
201
+ output_dir (str): Directory to save the comparison video
202
+
203
+ Returns:
204
+ str: Path to the comparison video
205
+ """
206
+ # Extract frames
207
+ user_frames = extract_frames(user_video_path)
208
+ pro_frames = extract_frames(pro_video_path)
209
+
210
+ # Align swings
211
+ aligned_user, aligned_pro = align_swings(user_frames, pro_frames)
212
+
213
+ # Create output path
214
+ video_name = os.path.splitext(os.path.basename(user_video_path))[0]
215
+ output_path = os.path.join(output_dir, f"{video_name}_comparison.mp4")
216
+
217
+ # Create side-by-side comparison
218
+ return create_side_by_side_comparison(aligned_user, aligned_pro, output_path)
app/utils/video_downloader.py CHANGED
@@ -77,3 +77,61 @@ def download_youtube_video(url, output_dir="downloads"):
77
  raise ValueError(f"Error downloading video: {str(e)}")
78
  except Exception as e:
79
  raise ValueError(f"Error: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  raise ValueError(f"Error downloading video: {str(e)}")
78
  except Exception as e:
79
  raise ValueError(f"Error: {str(e)}")
80
+
81
+
82
+ def download_pro_reference(url="https://www.youtube.com/shorts/geR666LWSHg", output_dir="downloads"):
83
+ """
84
+ Download a professional golfer reference video
85
+
86
+ Args:
87
+ url (str): YouTube video URL of professional golfer (default: provided reference)
88
+ output_dir (str): Directory to save the downloaded video
89
+
90
+ Returns:
91
+ str: Path to the downloaded pro reference video file
92
+ """
93
+ try:
94
+ # Create a specific filename for the pro reference
95
+ os.makedirs(output_dir, exist_ok=True)
96
+
97
+ # Check if pro reference already exists to avoid re-downloading
98
+ pro_file_path = os.path.join(output_dir, "pro_reference.mp4")
99
+ if os.path.exists(pro_file_path):
100
+ return pro_file_path
101
+
102
+ # Set output template for the downloaded file with fixed name
103
+ output_template = os.path.join(output_dir, "pro_reference.%(ext)s")
104
+
105
+ # Configure yt-dlp options
106
+ ydl_opts = {
107
+ 'format': 'best[ext=mp4]/best', # Prefer mp4 format
108
+ 'outtmpl': output_template,
109
+ 'noplaylist': True,
110
+ 'quiet': False,
111
+ 'no_warnings': False,
112
+ 'ignoreerrors': False,
113
+ }
114
+
115
+ # Create yt-dlp object and download the video
116
+ with yt_dlp.YoutubeDL(ydl_opts) as ydl:
117
+ ydl.extract_info(url, download=True)
118
+
119
+ # Check if file exists with mp4 extension
120
+ if os.path.exists(pro_file_path):
121
+ return pro_file_path
122
+ else:
123
+ # Try other extensions
124
+ for ext in ['webm', 'mkv']:
125
+ alt_path = os.path.join(output_dir, f"pro_reference.{ext}")
126
+ if os.path.exists(alt_path):
127
+ return alt_path
128
+
129
+ # If still not found, download as normal video and rename
130
+ video_path = download_youtube_video(url, output_dir)
131
+ ext = os.path.splitext(video_path)[1]
132
+ new_path = os.path.join(output_dir, f"pro_reference{ext}")
133
+ os.rename(video_path, new_path)
134
+ return new_path
135
+
136
+ except Exception as e:
137
+ raise ValueError(f"Error downloading pro reference: {str(e)}")
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
  opencv-python-headless
2
- yt-dlp==2025.2.19
3
  ultralytics
4
  mediapipe
5
  numpy
 
1
  opencv-python-headless
2
+ yt-dlp==2025.05.22
3
  ultralytics
4
  mediapipe
5
  numpy