chenemii committed
Commit 0723bf0 · 1 Parent(s): 3301758

frame analysis

app/models/pose_estimator.py CHANGED
@@ -7,12 +7,8 @@ import numpy as np
 import mediapipe as mp
 from tqdm import tqdm
 
-
 class PoseEstimator:
-    """MediaPipe-based pose estimator for golf swing analysis"""
-
     def __init__(self):
-        """Initialize the pose estimator"""
         self.mp_pose = mp.solutions.pose
         self.pose = self.mp_pose.Pose(static_image_mode=False,
                                       model_complexity=1,
@@ -21,40 +17,26 @@ class PoseEstimator:
                                       min_tracking_confidence=0.5)
 
     def process_frame(self, frame):
-        """
-        Process a single frame and extract pose landmarks
-
-        Args:
-            frame (numpy.ndarray): Input frame
-
-        Returns:
-            list: List of keypoints [x, y, visibility] or None if not detected
-        """
-        # Convert BGR to RGB
         frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-
-        # Process the frame
         results = self.pose.process(frame_rgb)
-
-        if not results.pose_landmarks:
-            return None
-
-        # Extract keypoints
         keypoints = []
-        for landmark in results.pose_landmarks.landmark:
-            # Convert normalized coordinates to pixel coordinates
-            h, w, _ = frame.shape
-            x, y = int(landmark.x * w), int(landmark.y * h)
-            visibility = landmark.visibility
-            keypoints.append([x, y, visibility])
 
         return keypoints
 
     def close(self):
-        """Release resources"""
         self.pose.close()
 
-
 def analyze_pose(frames):
     """
     Analyze pose in video frames
@@ -70,87 +52,55 @@ def analyze_pose(frames):
 
     for i, frame in enumerate(tqdm(frames, desc="Analyzing pose")):
         keypoints = pose_estimator.process_frame(frame)
-        if keypoints:
-            pose_data[i] = keypoints
 
     pose_estimator.close()
-
     return pose_data
 
-
 def calculate_joint_angles(keypoints):
     """
-    Calculate joint angles from pose keypoints
 
     Args:
-        keypoints (list): List of keypoints [x, y, visibility]
 
     Returns:
-        dict: Dictionary of joint angles in degrees
     """
-    # Define joint connections for angle calculation
-    joint_connections = {
-        "right_shoulder": [
-            mp.solutions.pose.PoseLandmark.RIGHT_ELBOW.value,
-            mp.solutions.pose.PoseLandmark.RIGHT_SHOULDER.value,
-            mp.solutions.pose.PoseLandmark.RIGHT_HIP.value
-        ],
-        "left_shoulder": [
-            mp.solutions.pose.PoseLandmark.LEFT_ELBOW.value,
-            mp.solutions.pose.PoseLandmark.LEFT_SHOULDER.value,
-            mp.solutions.pose.PoseLandmark.LEFT_HIP.value
-        ],
-        "right_elbow": [
-            mp.solutions.pose.PoseLandmark.RIGHT_WRIST.value,
-            mp.solutions.pose.PoseLandmark.RIGHT_ELBOW.value,
-            mp.solutions.pose.PoseLandmark.RIGHT_SHOULDER.value
-        ],
-        "left_elbow": [
-            mp.solutions.pose.PoseLandmark.LEFT_WRIST.value,
-            mp.solutions.pose.PoseLandmark.LEFT_ELBOW.value,
-            mp.solutions.pose.PoseLandmark.LEFT_SHOULDER.value
-        ],
-        "right_hip": [
-            mp.solutions.pose.PoseLandmark.RIGHT_KNEE.value,
-            mp.solutions.pose.PoseLandmark.RIGHT_HIP.value,
-            mp.solutions.pose.PoseLandmark.RIGHT_SHOULDER.value
-        ],
-        "left_hip": [
-            mp.solutions.pose.PoseLandmark.LEFT_KNEE.value,
-            mp.solutions.pose.PoseLandmark.LEFT_HIP.value,
-            mp.solutions.pose.PoseLandmark.LEFT_SHOULDER.value
-        ],
-        "right_knee": [
-            mp.solutions.pose.PoseLandmark.RIGHT_ANKLE.value,
-            mp.solutions.pose.PoseLandmark.RIGHT_KNEE.value,
-            mp.solutions.pose.PoseLandmark.RIGHT_HIP.value
-        ],
-        "left_knee": [
-            mp.solutions.pose.PoseLandmark.LEFT_ANKLE.value,
-            mp.solutions.pose.PoseLandmark.LEFT_KNEE.value,
-            mp.solutions.pose.PoseLandmark.LEFT_HIP.value
-        ]
-    }
-
     angles = {}
-
-    for joint_name, landmarks in joint_connections.items():
-        # Get the three points that form the angle
-        if all(landmarks[i] < len(keypoints) for i in range(3)):
-            p1 = np.array(keypoints[landmarks[0]][:2])
-            p2 = np.array(keypoints[landmarks[1]][:2])
-            p3 = np.array(keypoints[landmarks[2]][:2])
-
-            # Calculate vectors
-            v1 = p1 - p2
-            v2 = p3 - p2
-
-            # Calculate angle
-            cosine_angle = np.dot(
-                v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
-            angle = np.arccos(np.clip(cosine_angle, -1.0, 1.0))
-            angle_degrees = np.degrees(angle)
-
-            angles[joint_name] = angle_degrees
-
-    return angles

 import mediapipe as mp
 from tqdm import tqdm
 
 class PoseEstimator:
     def __init__(self):
         self.mp_pose = mp.solutions.pose
         self.pose = self.mp_pose.Pose(static_image_mode=False,
                                       model_complexity=1,
                                       min_tracking_confidence=0.5)
 
     def process_frame(self, frame):
         frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
         results = self.pose.process(frame_rgb)
         keypoints = []
+        h, w, _ = frame.shape
+
+        if results.pose_landmarks:
+            for landmark in results.pose_landmarks.landmark:
+                x, y = int(landmark.x * w), int(landmark.y * h)
+                visibility = landmark.visibility
+                keypoints.append([x, y, visibility])
+        else:
+            center_x, center_y = w // 2, h // 2
+            for _ in range(33):
+                keypoints.append([center_x, center_y, 0.0])
 
         return keypoints
 
     def close(self):
         self.pose.close()
 
 def analyze_pose(frames):
     """
     Analyze pose in video frames
 
     for i, frame in enumerate(tqdm(frames, desc="Analyzing pose")):
         keypoints = pose_estimator.process_frame(frame)
+        # Store all frames, even if no pose is detected
+        pose_data[i] = keypoints if keypoints is not None else []
 
     pose_estimator.close()
     return pose_data
 
 def calculate_joint_angles(keypoints):
     """
+    Calculate joint angles from keypoints.
 
     Args:
+        keypoints: List of [x, y, visibility] for each landmark
 
     Returns:
+        Dictionary of joint angles
     """
+    if not keypoints or len(keypoints) < 33:  # MediaPipe Pose has 33 landmarks
+        return {}
+
     angles = {}
+
+    # Right shoulder angle (landmarks 11, 13, 15)
+    if all(keypoints[i][2] > 0.5 for i in [11, 13, 15]):
+        shoulder = np.array(keypoints[11][:2])
+        elbow = np.array(keypoints[13][:2])
+        wrist = np.array(keypoints[15][:2])
+        v1 = shoulder - elbow
+        v2 = wrist - elbow
+        angle = np.degrees(np.arccos(np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))))
+        angles["right_shoulder"] = angle
+
+    # Right elbow angle (landmarks 13, 15, 17)
+    if all(keypoints[i][2] > 0.5 for i in [13, 15, 17]):
+        upper_arm = np.array(keypoints[13][:2])
+        elbow = np.array(keypoints[15][:2])
+        wrist = np.array(keypoints[17][:2])
+        v1 = upper_arm - elbow
+        v2 = wrist - elbow
+        angle = np.degrees(np.arccos(np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))))
+        angles["right_elbow"] = angle
+
+    # Right wrist angle (landmarks 15, 17, 19)
+    if all(keypoints[i][2] > 0.5 for i in [15, 17, 19]):
+        elbow = np.array(keypoints[15][:2])
+        wrist = np.array(keypoints[17][:2])
+        hand = np.array(keypoints[19][:2])
+        v1 = elbow - wrist
+        v2 = hand - wrist
+        angle = np.degrees(np.arccos(np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))))
+        angles["right_wrist"] = angle
+
+    return angles
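
Both the removed joint_connections loop and the new per-joint blocks compute an angle the same way: the angle at the middle landmark between the two limb vectors, from the arccos of their normalized dot product. Below is a minimal, self-contained sketch of that formula; the pixel coordinates are invented for illustration and are not taken from the repository.

import numpy as np

def three_point_angle(p1, p2, p3):
    # Angle at p2, in degrees, between the segments p2->p1 and p2->p3.
    v1 = np.asarray(p1, dtype=float) - np.asarray(p2, dtype=float)
    v2 = np.asarray(p3, dtype=float) - np.asarray(p2, dtype=float)
    cosine = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
    return np.degrees(np.arccos(np.clip(cosine, -1.0, 1.0)))

# Invented shoulder/elbow/wrist pixel coordinates; prints roughly 133 degrees.
print(three_point_angle([320, 180], [360, 260], [330, 340]))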
app/models/swing_analyzer.py CHANGED
@@ -3,187 +3,142 @@ Swing analysis module for golf swing segmentation and trajectory analysis
 """
 
 import numpy as np
-import cv2
 from app.models.pose_estimator import calculate_joint_angles
 
-
-def segment_swing(pose_data, detections, sample_rate=5):
-    """
-    Segment the golf swing into key phases
-
-    Args:
-        pose_data (dict): Dictionary mapping frame indices to pose keypoints
-        detections (list): List of Detection objects
-        sample_rate (int): The frame sampling rate used during processing
-
-    Returns:
-        dict: Dictionary mapping phase names to lists of frame indices
-    """
-    # Initialize swing phases
-    swing_phases = {
-        "setup": [],
-        "backswing": [],
-        "downswing": [],
-        "impact": [],
-        "follow_through": []
-    }
-
-    # Get frame indices with pose data
     frame_indices = sorted(pose_data.keys())
-
     if not frame_indices:
         return swing_phases
-
-    # Auto-adjust sample rate based on number of frames
-    # For short videos (less than 150 frames), don't skip any frames
-    if len(frame_indices) < 150 and sample_rate > 1:
-        # Get the max frame idx to understand video length
-        max_frame_idx = max(frame_indices) if frame_indices else 0
-        # For videos with less than 150 frames, use sample_rate=1
-        if max_frame_idx < 150:
-            sample_rate = 1
-
-    # Calculate joint angles for each frame
     angles_by_frame = {}
     for idx in frame_indices:
         keypoints = pose_data[idx]
         angles = calculate_joint_angles(keypoints)
         angles_by_frame[idx] = angles
 
-    # Analyze shoulder rotation to identify swing phases
-    # This is a simplified approach - a more sophisticated algorithm would be needed for production
 
-    # Find the frame with the maximum right shoulder angle (top of backswing)
-    max_shoulder_angle = -1
-    top_backswing_frame = frame_indices[0]
 
     for idx in frame_indices:
-        angles = angles_by_frame[idx]
-        if "right_shoulder" in angles and angles[
-                "right_shoulder"] > max_shoulder_angle:
-            max_shoulder_angle = angles["right_shoulder"]
             top_backswing_frame = idx
 
-    # Find impact frame (when club meets ball)
-    # In a real implementation, this would use club and ball detection
     impact_frame = None
-    person_positions = {}
-
-    # Extract person positions from detections
-    for detection in detections:
-        if detection.class_name == "person":
-            frame_idx = detection.frame_idx // sample_rate  # Convert to processed frame index
-            if frame_idx in frame_indices:
-                person_positions[frame_idx] = detection.bbox
-
-    # Find the frame with the most forward position (impact)
-    if person_positions:
-        min_x = float('inf')
-        for idx, bbox in person_positions.items():
-            if idx > top_backswing_frame and bbox[0] < min_x:
-                min_x = bbox[0]
                 impact_frame = idx
 
-    # If impact frame not found, estimate it as 2/3 between top of backswing and end
     if impact_frame is None:
-        impact_frame = frame_indices[0] + int(
-            (frame_indices[-1] - top_backswing_frame) * 2 / 3)
 
-    # Assign frames to phases
     for idx in frame_indices:
-        if idx < frame_indices[len(frame_indices) // 5]:
-            # First 20% of frames are setup
             swing_phases["setup"].append(idx)
-        elif idx < top_backswing_frame:
-            # Frames before top of backswing are backswing
             swing_phases["backswing"].append(idx)
         elif idx < impact_frame:
-            # Frames between top of backswing and impact are downswing
             swing_phases["downswing"].append(idx)
-        elif idx < impact_frame + 5:
-            # Frames around impact
             swing_phases["impact"].append(idx)
         else:
-            # Remaining frames are follow-through
             swing_phases["follow_through"].append(idx)
 
     return swing_phases
 
-
-def analyze_trajectory(frames, detections, swing_phases, sample_rate=5):
-    """
-    Analyze club and ball trajectory and speed
-
-    Args:
-        frames (list): List of video frames
-        detections (list): List of Detection objects
-        swing_phases (dict): Dictionary mapping phase names to lists of frame indices
-        sample_rate (int): The frame sampling rate used during processing
-
-    Returns:
-        dict: Dictionary mapping frame indices to trajectory data
-    """
     trajectory_data = {}
-
-    # Auto-adjust sample rate based on number of frames
-    # For short videos (less than 150 frames), don't skip any frames
-    if len(frames) < 150 and sample_rate > 1:
         sample_rate = 1
 
-    # Extract ball detections
     ball_detections = [d for d in detections if d.class_name == "sports ball"]
-
-    # Get impact frame index
     impact_frames = swing_phases.get("impact", [])
     if not impact_frames:
         return trajectory_data
 
     impact_frame_idx = impact_frames[len(impact_frames) // 2]
-
-    # Track ball trajectory after impact
     ball_trajectory = []
     ball_positions = {}
 
     for detection in ball_detections:
-        frame_idx = detection.frame_idx // sample_rate  # Convert to processed frame index
        if frame_idx >= impact_frame_idx:
-            # Calculate ball center
            x1, y1, x2, y2 = detection.bbox
            center_x = (x1 + x2) / 2
            center_y = (y1 + y2) / 2
            ball_positions[frame_idx] = (center_x, center_y)
 
-    # Sort ball positions by frame index
     sorted_frames = sorted(ball_positions.keys())
     for idx in sorted_frames:
         ball_trajectory.append(ball_positions[idx])
 
-    # Estimate club speed at impact
-    # In a real implementation, this would use more sophisticated tracking
     club_speed = None
-    if len(swing_phases.get("downswing", [])) >= 2:
-        # Simplified club speed calculation
-        # In reality, this would require tracking the club head specifically
-        downswing_frames = swing_phases["downswing"]
-        # Account for sample rate when calculating time difference
-        actual_frames_elapsed = (downswing_frames[-1] - downswing_frames[0]) * sample_rate
-        time_diff = actual_frames_elapsed / 30  # Assuming 30 fps
        if time_diff > 0:
-            # Simplified speed calculation (just an example)
-            club_speed = 100 * (1 / time_diff)  # Arbitrary scaling
 
-    # Populate trajectory data
-    for idx in sorted(swing_phases.keys()):
-        frames_in_phase = swing_phases[idx]
        for frame_idx in frames_in_phase:
            trajectory_data[frame_idx] = {
-                "phase":
-                idx,
-                "club_speed":
-                club_speed if idx == "impact" else None,
-                "ball_trajectory":
-                ball_trajectory
-                if idx == "impact" or idx == "follow_through" else None
            }
 
-    return trajectory_data
 
 """
 
 import numpy as np
 from app.models.pose_estimator import calculate_joint_angles
 
+def segment_swing(pose_data, detections, sample_rate=1):
+    swing_phases = {"setup": [], "backswing": [], "downswing": [], "impact": [], "follow_through": []}
     frame_indices = sorted(pose_data.keys())
     if not frame_indices:
         return swing_phases
+
     angles_by_frame = {}
     for idx in frame_indices:
         keypoints = pose_data[idx]
         angles = calculate_joint_angles(keypoints)
         angles_by_frame[idx] = angles
 
+    setup_end = frame_indices[0]
+    initial_angles = angles_by_frame[frame_indices[0]]
+    initial_shoulder = initial_angles.get("right_shoulder")
+    initial_wrist = initial_angles.get("right_elbow")
 
+    for idx in frame_indices[1:]:
+        angles = angles_by_frame[idx]
+        shoulder = angles.get("right_shoulder")
+        wrist = angles.get("right_elbow")
+        if shoulder and initial_shoulder and abs(shoulder - initial_shoulder) > 10:
+            setup_end = idx
+            break
+        if wrist and initial_wrist and abs(wrist - initial_wrist) > 10:
+            setup_end = idx
+            break
 
+    max_shoulder_angle = -1
+    top_backswing_frame = setup_end
     for idx in frame_indices:
+        if idx < setup_end:
+            continue
+        shoulder = angles_by_frame[idx].get("right_shoulder")
+        if shoulder and shoulder > max_shoulder_angle:
+            max_shoulder_angle = shoulder
            top_backswing_frame = idx
 
+    # Find impact frame by looking for the point where the club head is at its lowest point
+    # during the downswing, before it starts rising in the follow-through
     impact_frame = None
+    min_wrist_y = float('inf')
+    prev_wrist_y = None
+    wrist_velocities = []
+
+    # First pass: collect wrist positions and calculate velocities
+    wrist_positions = []
+    for idx in frame_indices:
+        if idx < top_backswing_frame:
+            continue
+        keypoints = pose_data[idx]
+        if len(keypoints) > 16:
+            wrist_y = keypoints[16][1]
+            wrist_positions.append((idx, wrist_y))
+
+    # Calculate velocities between consecutive frames
+    for i in range(1, len(wrist_positions)):
+        idx, wrist_y = wrist_positions[i]
+        prev_idx, prev_y = wrist_positions[i-1]
+        velocity = (wrist_y - prev_y) / (idx - prev_idx)
+        wrist_velocities.append((idx, velocity))
+
+    # Find impact as the point where velocity changes from negative (downward) to positive (upward)
+    for i in range(1, len(wrist_velocities)):
+        idx, velocity = wrist_velocities[i]
+        prev_idx, prev_velocity = wrist_velocities[i-1]
+        if prev_velocity < 0 and velocity > 0:  # Velocity changes from negative to positive
+            impact_frame = prev_idx
+            break
+
+    # If no clear impact point found, use the frame with minimum wrist Y position
+    if impact_frame is None:
+        for idx, wrist_y in wrist_positions:
+            if wrist_y < min_wrist_y:
+                min_wrist_y = wrist_y
                impact_frame = idx
 
     if impact_frame is None:
+        impact_frame = frame_indices[-1]
 
     for idx in frame_indices:
+        if idx <= setup_end:
            swing_phases["setup"].append(idx)
+        elif idx <= top_backswing_frame:
            swing_phases["backswing"].append(idx)
         elif idx < impact_frame:
            swing_phases["downswing"].append(idx)
+        elif idx == impact_frame:
            swing_phases["impact"].append(idx)
         else:
            swing_phases["follow_through"].append(idx)
 
     return swing_phases
 
+def analyze_trajectory(frames, detections, swing_phases, sample_rate=1):
     trajectory_data = {}
+    if len(frames) < 150:
        sample_rate = 1
 
     ball_detections = [d for d in detections if d.class_name == "sports ball"]
     impact_frames = swing_phases.get("impact", [])
     if not impact_frames:
         return trajectory_data
 
     impact_frame_idx = impact_frames[len(impact_frames) // 2]
     ball_trajectory = []
     ball_positions = {}
 
     for detection in ball_detections:
+        frame_idx = detection.frame_idx
        if frame_idx >= impact_frame_idx:
            x1, y1, x2, y2 = detection.bbox
            center_x = (x1 + x2) / 2
            center_y = (y1 + y2) / 2
            ball_positions[frame_idx] = (center_x, center_y)
 
     sorted_frames = sorted(ball_positions.keys())
     for idx in sorted_frames:
         ball_trajectory.append(ball_positions[idx])
 
     club_speed = None
+    downswing_frames = swing_phases.get("downswing", [])
+    if len(downswing_frames) >= 2:
+        actual_frames_elapsed = (downswing_frames[-1] - downswing_frames[0])
+        time_diff = actual_frames_elapsed / 30
        if time_diff > 0:
+            club_speed = 100 * (1 / time_diff)
 
+    for phase_name, frames_in_phase in swing_phases.items():
        for frame_idx in frames_in_phase:
            trajectory_data[frame_idx] = {
+                "phase": phase_name,
+                "club_speed": club_speed if phase_name == "impact" else None,
+                "ball_trajectory": ball_trajectory if phase_name in ["impact", "follow_through"] else None
            }
 
+    return trajectory_data
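
The new segment_swing estimates the impact frame from the wrist's vertical track after the top of the backswing: it computes a per-frame velocity and takes the frame where that velocity changes sign, falling back to the extreme wrist position if no sign change is found. A self-contained sketch of that idea on synthetic data follows; note the sign test depends on the coordinate convention (this sketch assumes image coordinates where y grows downward, so the lowest point is where the change in y flips from positive to negative).

# Synthetic wrist y-coordinates (pixels) for the frames after the top of the backswing.
frames = list(range(10, 20))
wrist_y = [300, 340, 390, 450, 500, 520, 510, 480, 440, 400]

# Per-frame velocity of the wrist's y-coordinate between consecutive sampled frames.
velocities = [(frames[i], (wrist_y[i] - wrist_y[i - 1]) / (frames[i] - frames[i - 1]))
              for i in range(1, len(frames))]

impact_frame = None
for (prev_idx, prev_v), (_, v) in zip(velocities, velocities[1:]):
    if prev_v > 0 and v < 0:  # y stops increasing: the wrist has reached its lowest point
        impact_frame = prev_idx
        break

# Fallback: the frame with the largest y value (lowest point in the image).
if impact_frame is None:
    impact_frame = max(zip(frames, wrist_y), key=lambda p: p[1])[0]

print(impact_frame)  # 15 for this synthetic series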
app/utils/comparison.py CHANGED
@@ -123,11 +123,16 @@ def extract_frames(video_path, max_frames=100):
 def extract_key_swing_frames(video_path, swing_phases=None):
     """
     Extract 3 key frames from a golf swing video:
-    1. Starting position (setup)
-    2. Top of backswing
-    3. Impact with ball
 
-    Simplified version that uses basic OpenCV and handles rotation properly.
     """
     if not os.path.exists(video_path):
         raise ValueError(f"Video file not found: {video_path}")
@@ -164,12 +169,21 @@ def extract_key_swing_frames(video_path, swing_phases=None):
 
     key_frames = {}
 
-    # Determine frame indices
     if swing_phases:
-        setup_idx = 0  # Always start from beginning
-        backswing_idx = swing_phases.get('backswing', [total_frames//3])[-1] if swing_phases.get('backswing') else total_frames//3
-        impact_idx = swing_phases.get('impact', [total_frames//2])[len(swing_phases.get('impact', [total_frames//2]))//2] if swing_phases.get('impact') else total_frames//2
     else:
         setup_idx = 0
         backswing_idx = total_frames // 3
         impact_idx = int(total_frames * 0.6)

 def extract_key_swing_frames(video_path, swing_phases=None):
     """
     Extract 3 key frames from a golf swing video:
+    1. First setup frame
+    2. Last backswing frame (top of backswing)
+    3. First impact frame
 
+    Args:
+        video_path (str): Path to the video file
+        swing_phases (dict): Dictionary mapping phase names to lists of frame indices
+
+    Returns:
+        dict: Dictionary mapping phase names to frames
     """
     if not os.path.exists(video_path):
         raise ValueError(f"Video file not found: {video_path}")
 
     key_frames = {}
 
+    # Determine frame indices based on swing phases
     if swing_phases:
+        # Get first setup frame
+        setup_frames = swing_phases.get('setup', [])
+        setup_idx = setup_frames[0] if setup_frames else 0
+
+        # Get last backswing frame (top of backswing)
+        backswing_frames = swing_phases.get('backswing', [])
+        backswing_idx = backswing_frames[-1] if backswing_frames else total_frames//3
+
+        # Get first impact frame
+        impact_frames = swing_phases.get('impact', [])
+        impact_idx = impact_frames[0] if impact_frames else total_frames//2
     else:
+        # Fallback to default indices if no swing phases provided
         setup_idx = 0
         backswing_idx = total_frames // 3
         impact_idx = int(total_frames * 0.6)
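
The rewritten index selection is simple enough to show in isolation: first setup frame, last backswing frame, first impact frame, each with a fallback when the phase list is empty. A small sketch with a made-up swing_phases dict:

swing_phases = {
    "setup": [0, 1, 2],
    "backswing": [3, 4, 5, 6],
    "downswing": [7, 8],
    "impact": [9],
    "follow_through": [10, 11, 12],
}
total_frames = 60  # the fallbacks only matter when a phase list is empty

setup_frames = swing_phases.get("setup", [])
setup_idx = setup_frames[0] if setup_frames else 0

backswing_frames = swing_phases.get("backswing", [])
backswing_idx = backswing_frames[-1] if backswing_frames else total_frames // 3

impact_frames = swing_phases.get("impact", [])
impact_idx = impact_frames[0] if impact_frames else total_frames // 2

print(setup_idx, backswing_idx, impact_idx)  # 0 6 9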
app/utils/video_processor.py CHANGED
@@ -32,40 +32,28 @@ def process_video(video_path, sample_rate=5):
     - frames: List of processed frames
     - detections: List of Detection objects
     """
-    # Load YOLOv8 model
     model = YOLO("yolov8n.pt")
-
-    # Custom class names for golf-specific objects
     class_names = model.names
 
-    # Open video file
     cap = cv2.VideoCapture(video_path)
     if not cap.isOpened():
         raise ValueError("Error opening video file")
 
-    # Get video properties
     frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-    fps = cap.get(cv2.CAP_PROP_FPS)
-
-    # Auto-adjust sample rate based on video length
-    # For short videos (less than 150 frames), don't skip any frames
-    if frame_count < 150 and sample_rate > 1:
         print(f"Short video detected ({frame_count} frames). Processing all frames.")
         sample_rate = 1
 
     frames = []
     detections = []
 
-    # Process frames
-    for frame_idx in tqdm(range(0, frame_count, sample_rate),
-                          desc="Processing frames"):
-        # Set frame position
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
-
-        # Read frame
        ret, frame = cap.read()
        if not ret:
-            break
 
        # Store original frame
        frames.append(frame)
@@ -77,21 +65,11 @@ def process_video(video_path, sample_rate=5):
        for result in results:
            boxes = result.boxes
            for box in boxes:
-                # Get detection information
                class_id = int(box.cls.item())
                class_name = class_names[class_id]
 
-                # Filter for relevant objects (person, sports ball)
-                if class_name in ["person", "sports ball"]:
-                    bbox = box.xyxy[0].tolist()  # [x1, y1, x2, y2]
-                    confidence = box.conf.item()
-
-                    # Create Detection object
-                    detection = Detection(frame_idx, class_id, class_name,
-                                          bbox, confidence)
-                    detections.append(detection)
-
-    # Release video capture
     cap.release()
-
-    return frames, detections

     - frames: List of processed frames
     - detections: List of Detection objects
     """
     model = YOLO("yolov8n.pt")
     class_names = model.names
 
     cap = cv2.VideoCapture(video_path)
     if not cap.isOpened():
         raise ValueError("Error opening video file")
 
     frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+    if frame_count < 150:
         print(f"Short video detected ({frame_count} frames). Processing all frames.")
         sample_rate = 1
 
     frames = []
     detections = []
 
+    for frame_idx in tqdm(range(0, frame_count, sample_rate), desc="Processing frames"):
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ret, frame = cap.read()
        if not ret:
+            print(f"Warning: Could not read frame {frame_idx}")
+            continue
 
        # Store original frame
        frames.append(frame)
 
        for result in results:
            boxes = result.boxes
            for box in boxes:
                class_id = int(box.cls.item())
                class_name = class_names[class_id]
+                bbox = box.xyxy[0].tolist()
+                confidence = box.conf.item()
+                detections.append(Detection(frame_idx, class_id, class_name, bbox, confidence))
 
     cap.release()
+    return frames, detections
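
For context, a hedged sketch of how the functions touched by this commit appear to fit together; the input path is a placeholder and the exact call sites in the app are an assumption, not something shown in this diff.

from app.utils.video_processor import process_video
from app.models.pose_estimator import analyze_pose
from app.models.swing_analyzer import segment_swing, analyze_trajectory

# Placeholder input clip; any golf swing video would do.
frames, detections = process_video("swing.mp4")

# One keypoint list per processed frame (33 entries each after this commit).
pose_data = analyze_pose(frames)

# Phase segmentation and trajectory/speed estimates keyed by frame index.
phases = segment_swing(pose_data, detections)
trajectory = analyze_trajectory(frames, detections, phases)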