chenemii commited on
Commit
3301758
·
1 Parent(s): c3f5155

frame analysis

Browse files
Files changed (2) hide show
  1. app/streamlit_app.py +91 -116
  2. app/utils/comparison.py +293 -147
app/streamlit_app.py CHANGED
@@ -10,6 +10,8 @@ from dotenv import load_dotenv
10
  import base64
11
  from pathlib import Path
12
  import shutil
 
 
13
 
14
  # Load environment variables
15
  load_dotenv()
@@ -23,7 +25,7 @@ from app.models.pose_estimator import analyze_pose
23
  from app.models.swing_analyzer import segment_swing, analyze_trajectory
24
  from app.models.llm_analyzer import generate_swing_analysis, create_llm_prompt, prepare_data_for_llm, check_llm_services
25
  from app.utils.visualizer import create_annotated_video
26
- from app.utils.comparison import create_key_frame_comparison
27
 
28
  # Set page config
29
  st.set_page_config(page_title="Par-ity Project: Golf Swing Analysis 🏌️‍♀️",
@@ -302,10 +304,9 @@ def main():
302
  )
303
 
304
  with options_col3:
305
- if enable_pro_comparison and st.session_state.pro_reference_path:
306
- st.info(
307
- "**Option 3: Compare With Pro**\n\nSee side-by-side comparisons of 3 key swing positions with a professional golfer, including improvement tips for each phase."
308
- )
309
 
310
  except Exception as e:
311
  st.error(f"Error during analysis: {str(e)}")
@@ -346,11 +347,7 @@ def main():
346
  language="text")
347
 
348
  # Create columns for the action buttons
349
- if enable_pro_comparison and st.session_state.pro_reference_path:
350
- button_col1, button_col2, button_col3 = st.columns(3)
351
- else:
352
- button_col1, button_col2 = st.columns(2)
353
- button_col3 = None
354
 
355
  with button_col1:
356
  annotated_video_clicked = st.button("Generate Annotated Video",
@@ -362,14 +359,10 @@ def main():
362
  key="gpt_recommendations",
363
  use_container_width=True)
364
 
365
- # Add pro comparison button if enabled
366
- if enable_pro_comparison and st.session_state.pro_reference_path and button_col3:
367
- with button_col3:
368
- comparison_clicked = st.button("Compare Key Positions",
369
- key="pro_comparison",
370
- use_container_width=True)
371
- else:
372
- comparison_clicked = False
373
 
374
  # Handle annotated video creation
375
  if annotated_video_clicked:
@@ -477,111 +470,93 @@ def main():
477
  st.markdown("- Count '1' for your downswing")
478
  st.markdown("- Practice maintaining a 3:1 tempo ratio")
479
 
480
- # Handle pro comparison video creation
481
- if comparison_clicked and st.session_state.pro_reference_path:
482
  try:
483
- with st.spinner("Creating key frame comparison..."):
484
- # Get data from session state
485
  user_video_path = st.session_state.analysis_data['video_path']
486
  user_swing_phases = st.session_state.analysis_data['swing_phases']
487
-
488
- # Create the key frame comparison using static pro reference images
489
- # Don't pass pro_video_path to ensure it uses the static images
490
- comparison_data = create_key_frame_comparison(
491
- user_video_path,
492
- user_swing_phases=user_swing_phases,
493
- use_pro_images=True
494
- )
495
-
496
- # Store the comparison data in session state
497
- st.session_state.comparison_data = comparison_data
498
-
499
- # Display success message
500
- st.success("Key frame comparison created successfully!")
501
- st.subheader("Swing Analysis: Key Position Comparison")
502
-
503
- # Display each comparison with comments
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
504
  phases = ['setup', 'backswing', 'impact']
505
-
506
  for phase in phases:
507
- if phase in comparison_data:
508
- data = comparison_data[phase]
509
-
510
- # Display the comparison image
511
- st.subheader(f"{data['title']}")
512
-
513
- # Display the image
514
- if os.path.exists(data['image_path']):
515
- st.image(data['image_path'], use_column_width=True)
516
 
517
- # Create download button for the image
518
- with open(data['image_path'], "rb") as file:
519
- image_bytes = file.read()
520
- st.download_button(
521
- label=f"Download {data['title']} Comparison",
522
- data=image_bytes,
523
- file_name=os.path.basename(data['image_path']),
524
- mime="image/jpeg",
525
- key=f"download_{phase}"
526
- )
527
-
528
- # Display improvement comments
529
- comments = data['comments']
530
-
531
- col1, col2 = st.columns(2)
532
-
533
- with col1:
534
- st.markdown("**🏆 Professional Analysis:**")
535
- for analysis in comments['pro_analysis']:
536
- st.markdown(f"• {analysis}")
537
-
538
- with col2:
539
- st.markdown("**🔄 User vs Professional Comparison:**")
540
- for comparison in comments['comparison']:
541
- st.markdown(f"• {comparison}")
542
-
543
- st.markdown("---") # Add separator between phases
544
-
545
- # Add general guidance
546
- with st.expander("How to Use This Analysis", expanded=False):
547
- st.markdown("""
548
- ### How to Interpret These Comparisons
549
-
550
- Each comparison shows your swing position (left) next to a professional golfer's position (right) at three critical moments:
551
-
552
- 1. **Starting Position**: Your setup and address position
553
- 2. **Top of Backswing**: The highest point of your backswing
554
- 3. **Impact with Ball**: The moment of contact with the ball
555
-
556
- **Tips for Improvement:**
557
- - Compare your body positioning, posture, and club position to the pro
558
- - Focus on one aspect at a time (e.g., posture, then weight distribution)
559
- - Practice the positions slowly without a ball first
560
- - Use a mirror or video recording to check your positions
561
- - Work with a golf instructor for personalized feedback
562
-
563
- **Remember:** Every golfer is different, so focus on the fundamental principles rather than trying to copy every detail exactly.
564
- """)
565
-
566
  except Exception as e:
567
- st.error(f"Error creating key frame comparison: {str(e)}")
568
-
569
- # Add some guidance for interpreting the comparison
570
- with st.expander("How to use this comparison", expanded=True):
571
- st.markdown("""
572
- ### How to Interpret This Comparison
573
-
574
- This side-by-side comparison allows you to see how your swing compares to a professional golfer's swing frame by frame. Look for these key differences:
575
-
576
- 1. **Posture and Setup**: Compare your stance, grip, and alignment at address
577
- 2. **Backswing Rotation**: Note how much shoulder and hip rotation occurs
578
- 3. **Top of Swing Position**: Observe club position and body alignment
579
- 4. **Downswing Sequence**: Watch how the pro initiates the downswing
580
- 5. **Impact Position**: Compare body positioning at impact
581
- 6. **Follow-through**: Note how weight transfers and body rotates after impact
582
-
583
- Try pausing the video at key positions to analyze differences in detail.
584
- """)
585
 
586
 
587
  if __name__ == "__main__":
 
10
  import base64
11
  from pathlib import Path
12
  import shutil
13
+ import cv2
14
+ from PIL import Image
15
 
16
  # Load environment variables
17
  load_dotenv()
 
25
  from app.models.swing_analyzer import segment_swing, analyze_trajectory
26
  from app.models.llm_analyzer import generate_swing_analysis, create_llm_prompt, prepare_data_for_llm, check_llm_services
27
  from app.utils.visualizer import create_annotated_video
28
+ from app.utils.comparison import create_key_frame_comparison, extract_key_swing_frames
29
 
30
  # Set page config
31
  st.set_page_config(page_title="Par-ity Project: Golf Swing Analysis 🏌️‍♀️",
 
304
  )
305
 
306
  with options_col3:
307
+ st.info(
308
+ "**Option 3: Key Frame Analysis**\n\nExtract and review your setup, top of backswing, and impact frames with helpful comments for each phase."
309
+ )
 
310
 
311
  except Exception as e:
312
  st.error(f"Error during analysis: {str(e)}")
 
347
  language="text")
348
 
349
  # Create columns for the action buttons
350
+ button_col1, button_col2, button_col3 = st.columns(3)
 
 
 
 
351
 
352
  with button_col1:
353
  annotated_video_clicked = st.button("Generate Annotated Video",
 
359
  key="gpt_recommendations",
360
  use_container_width=True)
361
 
362
+ with button_col3:
363
+ keyframe_analysis_clicked = st.button("Key Frame Analysis",
364
+ key="keyframe_analysis",
365
+ use_container_width=True)
 
 
 
 
366
 
367
  # Handle annotated video creation
368
  if annotated_video_clicked:
 
470
  st.markdown("- Count '1' for your downswing")
471
  st.markdown("- Practice maintaining a 3:1 tempo ratio")
472
 
473
+ # Handle key frame analysis (new tab/option)
474
+ if keyframe_analysis_clicked:
475
  try:
476
+ with st.spinner("Extracting key frames from your swing..."):
 
477
  user_video_path = st.session_state.analysis_data['video_path']
478
  user_swing_phases = st.session_state.analysis_data['swing_phases']
479
+ key_frames = extract_key_swing_frames(user_video_path, user_swing_phases)
480
+
481
+ st.success("Key frame analysis complete!")
482
+ st.subheader("Key Frame Analysis: Your Swing's Critical Positions")
483
+
484
+ # Define helpful comments for each phase
485
+ phase_comments = {
486
+ 'setup': [
487
+ "Balanced stance with feet shoulder-width apart.",
488
+ "Even weight distribution on both feet.",
489
+ "Neutral grip with hands in proper position.",
490
+ "Athletic posture with slight forward bend.",
491
+ "Ball positioned correctly for club selection."
492
+ ],
493
+ 'backswing': [
494
+ "Full shoulder rotation with stable lower body.",
495
+ "Club on proper swing plane at top.",
496
+ "Consistent spine angle throughout.",
497
+ "Minimal weight shift to right side."
498
+ ],
499
+ 'impact': [
500
+ "Weight shifted to front foot (70-80%).",
501
+ "Hands ahead of ball at impact.",
502
+ "Square club face to target line.",
503
+ "Head behind ball with steady position.",
504
+ "Hips and shoulders aligned to target."
505
+ ]
506
+ }
507
+ phase_titles = {
508
+ 'setup': 'Starting Position',
509
+ 'backswing': 'Top of Backswing',
510
+ 'impact': 'Impact with Ball'
511
+ }
512
  phases = ['setup', 'backswing', 'impact']
 
513
  for phase in phases:
514
+ st.subheader(f"{phase_titles[phase]}")
515
+ img_col, comment_col = st.columns([1, 1])
516
+ with img_col:
517
+ if key_frames.get(phase) is not None:
518
+ frame = key_frames[phase]
 
 
 
 
519
 
520
+ # Verify frame is in color before conversion
521
+ if len(frame.shape) == 3 and frame.shape[2] == 3:
522
+ try:
523
+ # Save frame to temp file for display
524
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg')
525
+
526
+ # Convert BGR (OpenCV) to RGB (PIL) format
527
+ rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
528
+
529
+ # Debug: Log frame dimensions after extraction and color conversion
530
+ height, width = rgb_frame.shape[:2]
531
+ print(f"Frame dimensions for {phase}: {width}x{height}")
532
+
533
+ pil_img = Image.fromarray(rgb_frame)
534
+ pil_img.save(temp_file.name, format="JPEG", quality=95)
535
+
536
+ # Display the image
537
+ st.image(temp_file.name, use_container_width=True)
538
+
539
+ # Clean up temp file
540
+ try:
541
+ os.unlink(temp_file.name)
542
+ except:
543
+ pass # Ignore cleanup errors
544
+
545
+ except Exception as e:
546
+ st.error(f"Error displaying {phase} frame: {str(e)}")
547
+ st.warning("Frame could not be displayed properly.")
548
+ else:
549
+ st.warning(f"Frame for {phase} is not in color format. Shape: {frame.shape}")
550
+ else:
551
+ st.warning("Frame not found.")
552
+ with comment_col:
553
+ st.markdown("**Comments:**")
554
+ for comment in phase_comments[phase]:
555
+ st.markdown(f"- {comment}")
556
+ st.markdown("---")
 
 
 
 
 
 
 
 
 
 
 
 
557
  except Exception as e:
558
+ st.error(f"Error during key frame analysis: {str(e)}")
559
+ st.info("Please ensure your video is in a supported format and try again.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
560
 
561
 
562
  if __name__ == "__main__":
app/utils/comparison.py CHANGED
@@ -1,11 +1,65 @@
1
  """
2
  Comparison module for frame-by-frame analysis between user and pro swings
 
 
 
3
  """
4
 
5
  import os
6
  import cv2
7
  import numpy as np
8
  from tqdm import tqdm
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
 
11
  def extract_frames(video_path, max_frames=100):
@@ -24,11 +78,22 @@ def extract_frames(video_path, max_frames=100):
24
  if not os.path.exists(video_path):
25
  raise ValueError(f"Video file not found: {video_path}")
26
 
 
27
  cap = cv2.VideoCapture(video_path)
28
 
29
  if not cap.isOpened():
30
  raise ValueError(f"Could not open video: {video_path}")
31
 
 
 
 
 
 
 
 
 
 
 
32
  # Get total frame count
33
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
34
 
@@ -43,7 +108,10 @@ def extract_frames(video_path, max_frames=100):
43
  break
44
 
45
  if current_frame % step == 0:
46
- frames.append(frame)
 
 
 
47
 
48
  current_frame += 1
49
 
@@ -59,56 +127,126 @@ def extract_key_swing_frames(video_path, swing_phases=None):
59
  2. Top of backswing
60
  3. Impact with ball
61
 
62
- Args:
63
- video_path (str): Path to the video file
64
- swing_phases (dict): Optional swing phase data for precise frame selection
65
-
66
- Returns:
67
- dict: Dictionary with keys 'setup', 'backswing', 'impact'
68
- and frame images as values
69
  """
70
  if not os.path.exists(video_path):
71
  raise ValueError(f"Video file not found: {video_path}")
72
-
 
 
 
73
  cap = cv2.VideoCapture(video_path)
74
 
75
  if not cap.isOpened():
76
  raise ValueError(f"Could not open video: {video_path}")
77
 
78
- # Get total frame count
79
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
80
-
81
- key_frames = {}
82
-
83
- if swing_phases:
84
- # Use provided swing phase data for precise frame selection
85
- frame_indices = {
86
- 'setup': swing_phases.get('setup', [0])[0] if swing_phases.get('setup') else 0,
87
- 'backswing': swing_phases.get('backswing', [total_frames//3])[-1] if swing_phases.get('backswing') else total_frames//3,
88
- 'impact': swing_phases.get('impact', [total_frames//2])[len(swing_phases.get('impact', [total_frames//2]))//2] if swing_phases.get('impact') else total_frames//2
89
- }
90
- else:
91
- # Use estimated frame positions for 3 frames
92
- frame_indices = {
93
- 'setup': 0, # First frame
94
- 'backswing': total_frames // 3, # 33% through
95
- 'impact': int(total_frames * 0.6) # 60% through
96
- }
97
-
98
- # Extract the specific frames
99
- for phase, frame_idx in frame_indices.items():
100
- cap.set(cv2.CAP_PROP_POS_FRAMES, min(frame_idx, total_frames - 1))
101
- ret, frame = cap.read()
102
- if ret:
103
- # Keep original orientation - no rotation
104
- key_frames[phase] = frame
 
 
105
  else:
106
- # If frame extraction fails, use a black frame
107
- key_frames[phase] = np.zeros((480, 640, 3), dtype=np.uint8)
108
-
109
- cap.release()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
- return key_frames
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
 
114
  def generate_improvement_comments(phase):
@@ -206,6 +344,8 @@ def load_pro_reference_images(pro_images_dir="pro_reference"):
206
  if os.path.exists(image_path):
207
  image = cv2.imread(image_path)
208
  if image is not None:
 
 
209
  pro_frames[phase] = image
210
  else:
211
  # Create a placeholder if image can't be loaded
@@ -217,20 +357,65 @@ def load_pro_reference_images(pro_images_dir="pro_reference"):
217
  return pro_frames
218
 
219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  def create_key_frame_comparison(user_video_path, pro_video_path=None, user_swing_phases=None, pro_swing_phases=None, output_dir="downloads", use_pro_images=True):
221
  """
222
- Create a comparison of 3 key frames between user and pro golfer swings
 
 
 
 
223
 
224
  Args:
225
  user_video_path (str): Path to the user's golf swing video
226
  pro_video_path (str): Path to the professional golfer's swing video (optional if use_pro_images=True)
227
  user_swing_phases (dict): Optional swing phase data for user video
228
  pro_swing_phases (dict): Optional swing phase data for pro video
229
- output_dir (str): Directory to save the comparison images
230
  use_pro_images (bool): Whether to use provided pro reference images instead of video
231
 
232
  Returns:
233
- dict: Dictionary with phase names as keys and image paths as values
 
234
  """
235
  # Extract key frames from user video
236
  user_frames = extract_key_swing_frames(user_video_path, user_swing_phases)
@@ -254,30 +439,40 @@ def create_key_frame_comparison(user_video_path, pro_video_path=None, user_swing
254
  user_frame = user_frames.get(phase, np.zeros((480, 640, 3), dtype=np.uint8))
255
  pro_frame = pro_frames.get(phase, np.zeros((480, 640, 3), dtype=np.uint8))
256
 
257
- # Resize frames to consistent size while maintaining portrait orientation
258
- target_height = 400
259
- user_frame = resize_frame_maintain_aspect(user_frame, target_height)
260
- pro_frame = resize_frame_maintain_aspect(pro_frame, target_height)
261
 
262
- # Create side-by-side comparison
263
- comparison_image = create_side_by_side_image(user_frame, pro_frame, phase_titles[i])
264
-
265
- # Save the comparison image with absolute path
266
  video_name = os.path.splitext(os.path.basename(user_video_path))[0]
267
- output_path = os.path.join(output_dir, f"{video_name}_{phase}_comparison.jpg")
 
268
 
269
- # Ensure the image is saved successfully
270
- success = cv2.imwrite(output_path, comparison_image)
271
- if not success:
272
- print(f"Warning: Failed to save image to {output_path}")
273
- else:
274
- print(f"Successfully saved comparison image: {output_path}")
 
 
 
 
 
 
 
 
 
 
 
 
275
 
276
  # Get improvement comments
277
  comments = generate_improvement_comments(phase)
278
 
279
  comparison_data[phase] = {
280
- 'image_path': output_path,
 
281
  'title': phase_titles[i],
282
  'comments': comments
283
  }
@@ -285,69 +480,6 @@ def create_key_frame_comparison(user_video_path, pro_video_path=None, user_swing
285
  return comparison_data
286
 
287
 
288
- def resize_frame_maintain_aspect(frame, target_height):
289
- """
290
- Resize frame to target height while maintaining aspect ratio
291
-
292
- Args:
293
- frame (numpy.ndarray): Input frame
294
- target_height (int): Target height
295
-
296
- Returns:
297
- numpy.ndarray: Resized frame
298
- """
299
- h, w = frame.shape[:2]
300
- target_width = int(w * (target_height / h))
301
- return cv2.resize(frame, (target_width, target_height))
302
-
303
-
304
- def create_side_by_side_image(user_frame, pro_frame, title):
305
- """
306
- Create a side-by-side comparison image
307
-
308
- Args:
309
- user_frame (numpy.ndarray): User's swing frame
310
- pro_frame (numpy.ndarray): Pro's swing frame
311
- title (str): Title for the comparison
312
-
313
- Returns:
314
- numpy.ndarray: Combined comparison image
315
- """
316
- # Get dimensions
317
- user_h, user_w = user_frame.shape[:2]
318
- pro_h, pro_w = pro_frame.shape[:2]
319
-
320
- # Create padding and title space
321
- padding = 20
322
- title_height = 60
323
- max_height = max(user_h, pro_h)
324
- total_width = user_w + pro_w + padding
325
- total_height = max_height + title_height
326
-
327
- # Create blank canvas
328
- canvas = np.ones((total_height, total_width, 3), dtype=np.uint8) * 255
329
-
330
- # Add title
331
- font = cv2.FONT_HERSHEY_SIMPLEX
332
- title_size = cv2.getTextSize(title, font, 1.2, 2)[0]
333
- title_x = (total_width - title_size[0]) // 2
334
- cv2.putText(canvas, title, (title_x, 40), font, 1.2, (0, 0, 0), 2)
335
-
336
- # Add user frame
337
- y_offset = title_height + (max_height - user_h) // 2
338
- canvas[y_offset:y_offset + user_h, 0:user_w] = user_frame
339
-
340
- # Add pro frame
341
- y_offset = title_height + (max_height - pro_h) // 2
342
- canvas[y_offset:y_offset + pro_h, user_w + padding:user_w + padding + pro_w] = pro_frame
343
-
344
- # Draw vertical separator line
345
- line_x = user_w + padding // 2
346
- cv2.line(canvas, (line_x, title_height), (line_x, total_height), (200, 200, 200), 2)
347
-
348
- return canvas
349
-
350
-
351
  def normalize_frames(frames, target_height=480):
352
  """
353
  Normalize frames to a consistent size while maintaining aspect ratio
@@ -362,14 +494,8 @@ def normalize_frames(frames, target_height=480):
362
  normalized_frames = []
363
 
364
  for frame in frames:
365
- # Get current dimensions
366
- h, w = frame.shape[:2]
367
-
368
- # Calculate new width to maintain aspect ratio
369
- target_width = int(w * (target_height / h))
370
-
371
- # Resize the frame
372
- resized = cv2.resize(frame, (target_width, target_height))
373
  normalized_frames.append(resized)
374
 
375
  return normalized_frames
@@ -391,25 +517,43 @@ def create_side_by_side_comparison(user_frames, pro_frames, output_path, fps=30)
391
  if not user_frames or not pro_frames:
392
  raise ValueError("Both user and pro frames must be provided")
393
 
394
- # Normalize frames to same height
395
- user_normalized = normalize_frames(user_frames)
396
- pro_normalized = normalize_frames(pro_frames)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
397
 
398
  # Ensure we have the same number of frames by duplicating the last frame if needed
399
- max_frames = max(len(user_normalized), len(pro_normalized))
400
 
401
- while len(user_normalized) < max_frames:
402
- user_normalized.append(user_normalized[-1])
 
403
 
404
- while len(pro_normalized) < max_frames:
405
- pro_normalized.append(pro_normalized[-1])
406
 
407
  # Create output directory if it doesn't exist
408
  os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
409
 
410
- # Get dimensions for the combined frame
411
- user_h, user_w = user_normalized[0].shape[:2]
412
- pro_h, pro_w = pro_normalized[0].shape[:2]
413
 
414
  # Create a combined frame with padding
415
  padding = 20 # Pixels between the two videos
@@ -424,7 +568,7 @@ def create_side_by_side_comparison(user_frames, pro_frames, output_path, fps=30)
424
  raise IOError(f"Failed to create video writer for {output_path}")
425
 
426
  # Create the combined video
427
- for i in tqdm(range(min(len(user_normalized), len(pro_normalized))), desc="Creating comparison video"):
428
  # Create a blank canvas
429
  combined = np.ones((combined_height, combined_width, 3), dtype=np.uint8) * 255
430
 
@@ -434,14 +578,16 @@ def create_side_by_side_comparison(user_frames, pro_frames, output_path, fps=30)
434
  cv2.putText(combined, "Pro Swing", (user_w + padding + pro_w//2 - 60, 30), font, 1, (0, 0, 0), 2)
435
 
436
  # Add frame number
437
- cv2.putText(combined, f"Frame: {i+1}/{min(len(user_normalized), len(pro_normalized))}",
438
  (10, combined_height - 10), font, 0.5, (0, 0, 0), 1)
439
 
440
- # Paste user frame
441
- combined[0:user_h, 0:user_w] = user_normalized[i]
 
442
 
443
  # Paste pro frame
444
- combined[0:pro_h, user_w+padding:user_w+padding+pro_w] = pro_normalized[i]
 
445
 
446
  # Draw vertical line between frames
447
  cv2.line(combined, (user_w + padding//2, 0), (user_w + padding//2, combined_height), (0, 0, 0), 2)
 
1
  """
2
  Comparison module for frame-by-frame analysis between user and pro swings
3
+
4
+ CRITICAL NOTE: This module preserves the original sizes and orientations of both user and professional videos.
5
+ Frames are saved as separate image files at their original resolutions without any resizing, rotation, or distortion.
6
  """
7
 
8
  import os
9
  import cv2
10
  import numpy as np
11
  from tqdm import tqdm
12
+ from PIL import Image
13
+
14
+
15
+ def ensure_color_frame(frame):
16
+ """
17
+ Ensure frame is in color format (3 channels)
18
+
19
+ Args:
20
+ frame (numpy.ndarray): Input frame
21
+
22
+ Returns:
23
+ numpy.ndarray: Color frame with 3 channels
24
+ """
25
+ if frame is None:
26
+ return np.zeros((480, 640, 3), dtype=np.uint8)
27
+
28
+ # If frame is grayscale (2D), convert to color (3D)
29
+ if len(frame.shape) == 2:
30
+ frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
31
+ elif len(frame.shape) == 3 and frame.shape[2] == 1:
32
+ frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
33
+ elif len(frame.shape) == 3 and frame.shape[2] == 4:
34
+ # Convert RGBA to BGR
35
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2BGR)
36
+
37
+ return frame
38
+
39
+
40
+ def resize_frame_proportionally(frame, target_height):
41
+ """
42
+ Resize frame proportionally to target height while maintaining aspect ratio
43
+
44
+ Args:
45
+ frame (numpy.ndarray): Input frame
46
+ target_height (int): Target height
47
+
48
+ Returns:
49
+ numpy.ndarray: Resized frame
50
+ """
51
+ # Ensure frame is in color format
52
+ frame = ensure_color_frame(frame)
53
+
54
+ h, w = frame.shape[:2]
55
+ if h == 0:
56
+ return np.zeros((target_height, target_height, 3), dtype=np.uint8)
57
+
58
+ # Calculate new width to maintain aspect ratio
59
+ target_width = int(w * (target_height / h))
60
+
61
+ # Resize the frame
62
+ return cv2.resize(frame, (target_width, target_height))
63
 
64
 
65
  def extract_frames(video_path, max_frames=100):
 
78
  if not os.path.exists(video_path):
79
  raise ValueError(f"Video file not found: {video_path}")
80
 
81
+ # Use standard OpenCV VideoCapture with explicit settings to prevent any rotation
82
  cap = cv2.VideoCapture(video_path)
83
 
84
  if not cap.isOpened():
85
  raise ValueError(f"Could not open video: {video_path}")
86
 
87
+ # CRITICAL: Explicitly disable ALL automatic transformations
88
+ # This prevents OpenCV from applying any rotation based on metadata
89
+ try:
90
+ cap.set(cv2.CAP_PROP_ORIENTATION_AUTO, 0) # Disable auto-orientation
91
+ cap.set(cv2.CAP_PROP_ORIENTATION_META, 0) # Ignore orientation metadata
92
+ cap.set(cv2.CAP_PROP_CONVERT_RGB, 0) # Keep BGR format
93
+ except:
94
+ # If properties are not supported, continue without them
95
+ pass
96
+
97
  # Get total frame count
98
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
99
 
 
108
  break
109
 
110
  if current_frame % step == 0:
111
+ # Store frame exactly as read from video - no transformations at all
112
+ # Only verify it's a valid color frame before storing
113
+ if frame is not None and len(frame.shape) == 3:
114
+ frames.append(frame.copy())
115
 
116
  current_frame += 1
117
 
 
127
  2. Top of backswing
128
  3. Impact with ball
129
 
130
+ Simplified version that uses basic OpenCV and handles rotation properly.
 
 
 
 
 
 
131
  """
132
  if not os.path.exists(video_path):
133
  raise ValueError(f"Video file not found: {video_path}")
134
+
135
+ print(f"Extracting key frames from: {video_path}")
136
+
137
+ # Use basic OpenCV VideoCapture
138
  cap = cv2.VideoCapture(video_path)
139
 
140
  if not cap.isOpened():
141
  raise ValueError(f"Could not open video: {video_path}")
142
 
143
+ try:
144
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
145
+ if total_frames <= 0:
146
+ raise ValueError(f"Invalid video: no frames found in {video_path}")
147
+
148
+ print(f"Total frames in video: {total_frames}")
149
+
150
+ # Check for rotation metadata
151
+ rotation_angle = 0
152
+ try:
153
+ # Try to get orientation metadata if available
154
+ orientation = cap.get(cv2.CAP_PROP_ORIENTATION_META)
155
+ if orientation == 90:
156
+ rotation_angle = 270 # Rotate counterclockwise
157
+ elif orientation == 180:
158
+ rotation_angle = 180
159
+ elif orientation == 270:
160
+ rotation_angle = 90 # Rotate counterclockwise
161
+ print(f"Video orientation metadata: {orientation}, applying rotation: {rotation_angle}")
162
+ except:
163
+ print("No orientation metadata available")
164
+
165
+ key_frames = {}
166
+
167
+ # Determine frame indices
168
+ if swing_phases:
169
+ setup_idx = 0 # Always start from beginning
170
+ backswing_idx = swing_phases.get('backswing', [total_frames//3])[-1] if swing_phases.get('backswing') else total_frames//3
171
+ impact_idx = swing_phases.get('impact', [total_frames//2])[len(swing_phases.get('impact', [total_frames//2]))//2] if swing_phases.get('impact') else total_frames//2
172
  else:
173
+ setup_idx = 0
174
+ backswing_idx = total_frames // 3
175
+ impact_idx = int(total_frames * 0.6)
176
+
177
+ print(f"Frame indices - Setup: {setup_idx}, Backswing: {backswing_idx}, Impact: {impact_idx}")
178
+
179
+ # Extract frames for each phase
180
+ phases = [
181
+ ('setup', setup_idx),
182
+ ('backswing', backswing_idx),
183
+ ('impact', impact_idx)
184
+ ]
185
+
186
+ for phase_name, frame_idx in phases:
187
+ frame = _extract_single_frame(cap, frame_idx, total_frames, rotation_angle, phase_name)
188
+ if frame is not None:
189
+ key_frames[phase_name] = frame
190
+ print(f"Successfully extracted {phase_name} frame")
191
+ else:
192
+ print(f"Failed to extract {phase_name} frame")
193
+
194
+ return key_frames
195
+
196
+ except Exception as e:
197
+ raise ValueError(f"Error extracting frames from {video_path}: {str(e)}")
198
+ finally:
199
+ cap.release()
200
+
201
+
202
+ def _extract_single_frame(cap, target_idx, total_frames, rotation_angle, phase_name):
203
+ """
204
+ Extract a single frame from video with validation and rotation correction
205
+ """
206
+ # Try the target frame first
207
+ for attempt_idx in [target_idx, target_idx + 1, target_idx - 1, target_idx + 2, target_idx - 2]:
208
+ if attempt_idx < 0 or attempt_idx >= total_frames:
209
+ continue
210
+
211
+ cap.set(cv2.CAP_PROP_POS_FRAMES, attempt_idx)
212
+ ret, frame = cap.read()
213
+
214
+ if not ret or frame is None:
215
+ print(f"Failed to read frame at index {attempt_idx} for {phase_name}")
216
+ continue
217
+
218
+ # Validate frame has 3 channels (color)
219
+ if len(frame.shape) != 3 or frame.shape[2] != 3:
220
+ print(f"Frame at index {attempt_idx} for {phase_name} is not in color format: {frame.shape}")
221
+ continue
222
+
223
+ print(f"Successfully read frame at index {attempt_idx} for {phase_name}, shape: {frame.shape}")
224
+
225
+ # Apply rotation correction if needed
226
+ if rotation_angle != 0:
227
+ print(f"Before rotation: {frame.shape}")
228
+ frame = _apply_rotation(frame, rotation_angle)
229
+ print(f"After {rotation_angle}° rotation: {frame.shape}")
230
+ print(f"Applied {rotation_angle}° rotation to {phase_name} frame")
231
+
232
+ return frame.copy()
233
 
234
+ print(f"Could not extract valid frame for {phase_name} after trying multiple indices")
235
+ return None
236
+
237
+
238
+ def _apply_rotation(frame, rotation_angle):
239
+ """
240
+ Apply rotation to a frame based on angle
241
+ """
242
+ if rotation_angle == 90:
243
+ return cv2.rotate(frame, cv2.ROTATE_90_COUNTERCLOCKWISE)
244
+ elif rotation_angle == 180:
245
+ return cv2.rotate(frame, cv2.ROTATE_180)
246
+ elif rotation_angle == 270:
247
+ return cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE)
248
+ else:
249
+ return frame
250
 
251
 
252
  def generate_improvement_comments(phase):
 
344
  if os.path.exists(image_path):
345
  image = cv2.imread(image_path)
346
  if image is not None:
347
+ # Ensure the image is in color format
348
+ image = ensure_color_frame(image)
349
  pro_frames[phase] = image
350
  else:
351
  # Create a placeholder if image can't be loaded
 
357
  return pro_frames
358
 
359
 
360
+ def save_frame_with_orientation(frame, output_path):
361
+ """
362
+ Save a frame using PIL after converting from BGR to RGB.
363
+ Ensures proper color handling and orientation.
364
+
365
+ Args:
366
+ frame (numpy.ndarray): Frame in BGR format (OpenCV)
367
+ output_path (str): Path to save the image
368
+ """
369
+ try:
370
+ if frame is None or frame.size == 0:
371
+ # Save a black image if frame is invalid
372
+ black = np.zeros((480, 640, 3), dtype=np.uint8)
373
+ img = Image.fromarray(black)
374
+ img.save(output_path, format="JPEG", quality=95)
375
+ return
376
+
377
+ # Verify frame is in color (3 channels)
378
+ if len(frame.shape) != 3 or frame.shape[2] != 3:
379
+ raise ValueError(f"Frame is not in color format. Shape: {frame.shape}")
380
+
381
+ # Convert BGR (OpenCV) to RGB (PIL)
382
+ rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
383
+
384
+ # Create PIL image and save with high quality
385
+ img = Image.fromarray(rgb_frame)
386
+ img.save(output_path, format="JPEG", quality=95)
387
+
388
+ except Exception as e:
389
+ print(f"Warning: Error saving frame to {output_path}: {str(e)}")
390
+ # Create a fallback black image
391
+ try:
392
+ black = np.zeros((480, 640, 3), dtype=np.uint8)
393
+ img = Image.fromarray(black)
394
+ img.save(output_path, format="JPEG", quality=95)
395
+ except Exception as fallback_error:
396
+ print(f"Error: Could not save fallback image: {str(fallback_error)}")
397
+ raise
398
+
399
+
400
  def create_key_frame_comparison(user_video_path, pro_video_path=None, user_swing_phases=None, pro_swing_phases=None, output_dir="downloads", use_pro_images=True):
401
  """
402
+ Create separate images for 3 key frames from user and pro golfer swings
403
+
404
+ IMPORTANT: This function preserves the original sizes of both user and professional frames.
405
+ No resizing, rotation, or distortion is applied to either frame. Each frame is saved
406
+ as a separate image file at its original resolution.
407
 
408
  Args:
409
  user_video_path (str): Path to the user's golf swing video
410
  pro_video_path (str): Path to the professional golfer's swing video (optional if use_pro_images=True)
411
  user_swing_phases (dict): Optional swing phase data for user video
412
  pro_swing_phases (dict): Optional swing phase data for pro video
413
+ output_dir (str): Directory to save the separate images
414
  use_pro_images (bool): Whether to use provided pro reference images instead of video
415
 
416
  Returns:
417
+ dict: Dictionary with phase names as keys and dictionaries containing
418
+ 'user_image_path', 'pro_image_path', 'title', and 'comments' as values
419
  """
420
  # Extract key frames from user video
421
  user_frames = extract_key_swing_frames(user_video_path, user_swing_phases)
 
439
  user_frame = user_frames.get(phase, np.zeros((480, 640, 3), dtype=np.uint8))
440
  pro_frame = pro_frames.get(phase, np.zeros((480, 640, 3), dtype=np.uint8))
441
 
442
+ # CRITICAL: Keep user frame EXACTLY as extracted - no processing at all
443
+ # Only ensure pro frame is in color format since it comes from reference images
444
+ pro_frame = ensure_color_frame(pro_frame)
 
445
 
446
+ # Save user frame with original size using PIL to ensure correct orientation and color
 
 
 
447
  video_name = os.path.splitext(os.path.basename(user_video_path))[0]
448
+ user_output_path = os.path.join(output_dir, f"{video_name}_{phase}_user.jpg")
449
+ pro_output_path = os.path.join(output_dir, f"{video_name}_{phase}_pro.jpg")
450
 
451
+ # Save user image using PIL (handles BGR->RGB and orientation)
452
+ try:
453
+ save_frame_with_orientation(user_frame, user_output_path)
454
+ user_success = True
455
+ except Exception as e:
456
+ print(f"Warning: Failed to save user image to {user_output_path}: {e}")
457
+ user_success = False
458
+ # Save pro image using OpenCV (as before)
459
+ pro_success = cv2.imwrite(pro_output_path, pro_frame)
460
+
461
+ if user_success:
462
+ print(f"Successfully saved user image: {user_output_path}")
463
+ if not user_success:
464
+ print(f"Warning: Failed to save user image to {user_output_path}")
465
+ if pro_success:
466
+ print(f"Successfully saved pro image: {pro_output_path}")
467
+ if not pro_success:
468
+ print(f"Warning: Failed to save pro image to {pro_output_path}")
469
 
470
  # Get improvement comments
471
  comments = generate_improvement_comments(phase)
472
 
473
  comparison_data[phase] = {
474
+ 'user_image_path': user_output_path,
475
+ 'pro_image_path': pro_output_path,
476
  'title': phase_titles[i],
477
  'comments': comments
478
  }
 
480
  return comparison_data
481
 
482
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
483
  def normalize_frames(frames, target_height=480):
484
  """
485
  Normalize frames to a consistent size while maintaining aspect ratio
 
494
  normalized_frames = []
495
 
496
  for frame in frames:
497
+ # Use the color-safe resize function
498
+ resized = resize_frame_proportionally(frame, target_height)
 
 
 
 
 
 
499
  normalized_frames.append(resized)
500
 
501
  return normalized_frames
 
517
  if not user_frames or not pro_frames:
518
  raise ValueError("Both user and pro frames must be provided")
519
 
520
+ # Ensure all frames are in color format
521
+ user_frames = [ensure_color_frame(frame) for frame in user_frames]
522
+ pro_frames = [ensure_color_frame(frame) for frame in pro_frames]
523
+
524
+ # Get dimensions from first frames
525
+ user_h, user_w = user_frames[0].shape[:2]
526
+ pro_h, pro_w = pro_frames[0].shape[:2]
527
+
528
+ # Choose target height (smaller of the two, capped at 720p)
529
+ target_height = min(user_h, pro_h, 720)
530
+
531
+ # Resize both user and pro frames proportionally to the same height
532
+ user_resized = []
533
+ for frame in user_frames:
534
+ resized = resize_frame_proportionally(frame, target_height)
535
+ user_resized.append(resized)
536
+
537
+ pro_resized = []
538
+ for frame in pro_frames:
539
+ resized = resize_frame_proportionally(frame, target_height)
540
+ pro_resized.append(resized)
541
 
542
  # Ensure we have the same number of frames by duplicating the last frame if needed
543
+ max_frames = max(len(user_resized), len(pro_resized))
544
 
545
+ user_aligned = user_resized.copy()
546
+ while len(user_aligned) < max_frames:
547
+ user_aligned.append(user_aligned[-1])
548
 
549
+ while len(pro_resized) < max_frames:
550
+ pro_resized.append(pro_resized[-1])
551
 
552
  # Create output directory if it doesn't exist
553
  os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
554
 
555
+ # Get dimensions for the combined frame using original user frame dimensions
556
+ pro_h, pro_w = pro_resized[0].shape[:2]
 
557
 
558
  # Create a combined frame with padding
559
  padding = 20 # Pixels between the two videos
 
568
  raise IOError(f"Failed to create video writer for {output_path}")
569
 
570
  # Create the combined video
571
+ for i in tqdm(range(min(len(user_aligned), len(pro_resized))), desc="Creating comparison video"):
572
  # Create a blank canvas
573
  combined = np.ones((combined_height, combined_width, 3), dtype=np.uint8) * 255
574
 
 
578
  cv2.putText(combined, "Pro Swing", (user_w + padding + pro_w//2 - 60, 30), font, 1, (0, 0, 0), 2)
579
 
580
  # Add frame number
581
+ cv2.putText(combined, f"Frame: {i+1}/{min(len(user_aligned), len(pro_resized))}",
582
  (10, combined_height - 10), font, 0.5, (0, 0, 0), 1)
583
 
584
+ # Paste user frame at original size and orientation
585
+ y_offset_user = (combined_height - user_h) // 2
586
+ combined[y_offset_user:y_offset_user + user_h, 0:user_w] = user_aligned[i]
587
 
588
  # Paste pro frame
589
+ y_offset_pro = (combined_height - pro_h) // 2
590
+ combined[y_offset_pro:y_offset_pro + pro_h, user_w + padding:user_w + padding + pro_w] = pro_resized[i]
591
 
592
  # Draw vertical line between frames
593
  cv2.line(combined, (user_w + padding//2, 0), (user_w + padding//2, combined_height), (0, 0, 0), 2)