Spaces:
Paused
Paused
frame impact
Browse files- app/models/swing_analyzer.py +120 -67
- app/streamlit_app.py +3 -2
- app/utils/comparison.py +62 -112
- app/utils/video_processor.py +11 -2
app/models/swing_analyzer.py
CHANGED
|
@@ -5,84 +5,125 @@ Swing analysis module for golf swing segmentation and trajectory analysis
|
|
| 5 |
import numpy as np
|
| 6 |
from app.models.pose_estimator import calculate_joint_angles
|
| 7 |
|
| 8 |
-
def segment_swing(pose_data, detections, sample_rate=1):
|
| 9 |
-
swing_phases = {"setup": [], "backswing": [], "downswing": [], "impact": [], "follow_through": []}
|
| 10 |
-
frame_indices = sorted(pose_data.keys())
|
| 11 |
-
if not frame_indices:
|
| 12 |
-
return swing_phases
|
| 13 |
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
for idx in frame_indices:
|
| 16 |
keypoints = pose_data[idx]
|
| 17 |
angles = calculate_joint_angles(keypoints)
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
-
# --- Dynamic Phase Segmentation ---
|
| 21 |
-
# 1. Setup: before any significant movement
|
| 22 |
-
# 2. Backswing: from end of setup to top of backswing
|
| 23 |
-
# 3. Downswing: from top of backswing to just before impact
|
| 24 |
-
# 4. Impact: frame(s) where ball first moves
|
| 25 |
-
# 5. Follow-through: after impact
|
| 26 |
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
setup_end = frame_indices[0]
|
| 29 |
-
initial_angles =
|
| 30 |
-
initial_shoulder = initial_angles.get("right_shoulder")
|
| 31 |
-
|
| 32 |
-
movement_threshold = 8 # degrees, can be tuned
|
| 33 |
for idx in frame_indices[1:]:
|
| 34 |
-
angles =
|
| 35 |
-
shoulder = angles.get("right_shoulder")
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
(wrist and initial_wrist and abs(wrist - initial_wrist) > movement_threshold):
|
| 39 |
-
setup_end = idx - 1
|
| 40 |
break
|
| 41 |
|
| 42 |
-
#
|
| 43 |
-
|
| 44 |
-
top_backswing_frame = setup_end + 1
|
| 45 |
-
for idx in frame_indices:
|
| 46 |
-
if idx <= setup_end:
|
| 47 |
-
continue
|
| 48 |
-
shoulder = angles_by_frame[idx].get("right_shoulder")
|
| 49 |
-
if shoulder and shoulder > max_shoulder_angle:
|
| 50 |
-
max_shoulder_angle = shoulder
|
| 51 |
-
top_backswing_frame = idx
|
| 52 |
|
| 53 |
-
#
|
| 54 |
-
impact_frame =
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
x1, y1, x2, y2 = detection.bbox
|
| 65 |
-
ball_x = (x1 + x2) / 2
|
| 66 |
-
ball_y = (y1 + y2) / 2
|
| 67 |
-
if prev_x is not None and prev_y is not None:
|
| 68 |
-
dx = abs(ball_x - prev_x)
|
| 69 |
-
dy = abs(ball_y - prev_y)
|
| 70 |
-
if dx > movement_threshold_px or dy > movement_threshold_px:
|
| 71 |
-
impact_frame = frame_idx
|
| 72 |
-
break
|
| 73 |
-
prev_x = ball_x
|
| 74 |
-
prev_y = ball_y
|
| 75 |
-
prev_frame = frame_idx
|
| 76 |
-
if impact_frame is None and prev_frame is not None:
|
| 77 |
-
impact_frame = prev_frame
|
| 78 |
-
if impact_frame is None:
|
| 79 |
-
impact_frame = frame_indices[-1]
|
| 80 |
-
|
| 81 |
-
# --- 4. Assign phases dynamically ---
|
| 82 |
for idx in frame_indices:
|
| 83 |
if idx <= setup_end:
|
| 84 |
swing_phases["setup"].append(idx)
|
| 85 |
-
elif idx <=
|
| 86 |
swing_phases["backswing"].append(idx)
|
| 87 |
elif idx < impact_frame:
|
| 88 |
swing_phases["downswing"].append(idx)
|
|
@@ -93,7 +134,19 @@ def segment_swing(pose_data, detections, sample_rate=1):
|
|
| 93 |
|
| 94 |
return swing_phases
|
| 95 |
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
trajectory_data = {}
|
| 98 |
if len(frames) < 150:
|
| 99 |
sample_rate = 1
|
|
@@ -108,7 +161,7 @@ def analyze_trajectory(frames, detections, swing_phases, sample_rate=1):
|
|
| 108 |
ball_positions = {}
|
| 109 |
|
| 110 |
for detection in ball_detections:
|
| 111 |
-
frame_idx = detection.frame_idx
|
| 112 |
if frame_idx >= impact_frame_idx:
|
| 113 |
x1, y1, x2, y2 = detection.bbox
|
| 114 |
center_x = (x1 + x2) / 2
|
|
@@ -122,7 +175,7 @@ def analyze_trajectory(frames, detections, swing_phases, sample_rate=1):
|
|
| 122 |
club_speed = None
|
| 123 |
downswing_frames = swing_phases.get("downswing", [])
|
| 124 |
if len(downswing_frames) >= 2:
|
| 125 |
-
actual_frames_elapsed = (downswing_frames[-1] - downswing_frames[0])
|
| 126 |
time_diff = actual_frames_elapsed / 30
|
| 127 |
if time_diff > 0:
|
| 128 |
club_speed = 100 * (1 / time_diff)
|
|
|
|
| 5 |
import numpy as np
|
| 6 |
from app.models.pose_estimator import calculate_joint_angles
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
+
def find_top_of_backswing(pose_data):
    """Return the frame index with the largest right-shoulder angle.

    The frame whose ``"right_shoulder"`` joint angle (as reported by
    ``calculate_joint_angles``) is greatest is taken as the top of the
    backswing. On ties the earliest frame wins.

    Args:
        pose_data (dict): Maps frame index -> keypoints in whatever form
            ``calculate_joint_angles`` accepts.

    Returns:
        The frame index (a key of ``pose_data``) with the maximum
        right-shoulder angle, or ``None`` when ``pose_data`` is empty.
    """
    # Guard: indexing frame_indices[0] on empty input would raise IndexError.
    if not pose_data:
        return None

    frame_indices = sorted(pose_data)
    top_frame = frame_indices[0]
    max_shoulder_angle = -1

    for idx in frame_indices:
        angles = calculate_joint_angles(pose_data[idx])
        # A missing angle counts as 0. A key that is present but None is
        # treated the same way so the comparison below cannot raise
        # TypeError (NOTE(review): assumes undetected joints may be None).
        shoulder = angles.get("right_shoulder") or 0
        if shoulder > max_shoulder_angle:
            max_shoulder_angle = shoulder
            top_frame = idx

    return top_frame
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
+
def detect_impact_frame(pose_data, detections, sample_rate=2, ball_movement_threshold=15):
    """Estimate the frame at which the club strikes the ball.

    Strategies are tried in order:

    1. Ball movement: among "sports ball" detections after the top of the
       backswing, return the first frame whose ball centre moved more than
       ``ball_movement_threshold`` from the previous detected position.
    2. Wrist speed: return the downswing frame with the largest
       frame-to-frame change in the ``"right_wrist"`` joint angle.
    3. Fallback: the frame one third of the way through the downswing.

    Args:
        pose_data (dict): Maps frame index -> keypoints accepted by
            ``calculate_joint_angles``.
        detections: Iterable of detection objects exposing ``class_name``,
            ``frame_idx`` and ``bbox`` (x1, y1, x2, y2), or a falsy value
            to skip the ball-movement strategy.
        sample_rate (int): Divisor applied to ``detection.frame_idx`` to map
            it onto the ``pose_data`` index space.
        ball_movement_threshold (float): Minimum centre displacement, in
            bbox coordinate units (presumably pixels), that counts as the
            ball having moved.

    Returns:
        The estimated impact frame index, or ``None`` when there are fewer
        than 10 pose frames or no frames after the top of the backswing.
    """
    frame_indices = sorted(pose_data)
    if len(frame_indices) < 10:
        return None

    top_backswing = find_top_of_backswing(pose_data)
    downswing_frames = [f for f in frame_indices if f > top_backswing]
    if not downswing_frames:
        return None

    # Method 1: Ball movement (if we have ball detections)
    if detections:
        ball_positions = {}
        for detection in detections:
            if detection.class_name != "sports ball":
                continue
            frame_idx = detection.frame_idx // sample_rate
            if frame_idx > top_backswing:
                x1, y1, x2, y2 = detection.bbox
                ball_positions[frame_idx] = ((x1 + x2) / 2, (y1 + y2) / 2)

        # Compare each detected position with the previous detected one;
        # with fewer than two positions the loop body never runs.
        sorted_frames = sorted(ball_positions)
        for prev_idx, curr_idx in zip(sorted_frames, sorted_frames[1:]):
            prev_x, prev_y = ball_positions[prev_idx]
            curr_x, curr_y = ball_positions[curr_idx]
            movement = np.sqrt((curr_x - prev_x) ** 2 + (curr_y - prev_y) ** 2)
            if movement > ball_movement_threshold:
                print(f"Impact detected via ball movement at frame {curr_idx}")
                return curr_idx

    # Method 2: Wrist speed fallback.
    # Joint angles are computed once per frame (each frame would otherwise be
    # evaluated twice, as "current" and as "previous"). Missing or None
    # angles count as 0 so the subtraction cannot raise TypeError.
    wrist_angles = [
        calculate_joint_angles(pose_data[f]).get("right_wrist") or 0
        for f in downswing_frames
    ]

    max_wrist_speed = 0
    impact_frame = None
    for frame, prev_angle, curr_angle in zip(
        downswing_frames[1:], wrist_angles, wrist_angles[1:]
    ):
        wrist_speed = abs(curr_angle - prev_angle)
        if wrist_speed > max_wrist_speed:
            max_wrist_speed = wrist_speed
            impact_frame = frame

    if impact_frame is not None:
        print(f"Impact detected via wrist speed at frame {impact_frame}")
        return impact_frame

    # Method 3: no wrist motion was measured; report the positional guess
    # honestly instead of logging "detected ... at frame None".
    fallback_frame = downswing_frames[len(downswing_frames) // 3]
    print(f"Impact not detected; falling back to frame {fallback_frame}")
    return fallback_frame
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def segment_swing_pose_based(pose_data, detections=None, sample_rate=2):
|
| 88 |
+
"""
|
| 89 |
+
Simple swing segmentation with clean impact detection
|
| 90 |
+
"""
|
| 91 |
+
swing_phases = {"setup": [], "backswing": [], "downswing": [], "impact": [], "follow_through": []}
|
| 92 |
+
frame_indices = sorted(pose_data.keys())
|
| 93 |
+
|
| 94 |
+
if not frame_indices:
|
| 95 |
+
return swing_phases
|
| 96 |
+
|
| 97 |
+
# 1. Find setup end (first significant movement)
|
| 98 |
setup_end = frame_indices[0]
|
| 99 |
+
initial_angles = calculate_joint_angles(pose_data[frame_indices[0]])
|
| 100 |
+
initial_shoulder = initial_angles.get("right_shoulder", 0)
|
| 101 |
+
|
|
|
|
| 102 |
for idx in frame_indices[1:]:
|
| 103 |
+
angles = calculate_joint_angles(pose_data[idx])
|
| 104 |
+
shoulder = angles.get("right_shoulder", 0)
|
| 105 |
+
if abs(shoulder - initial_shoulder) > 10:
|
| 106 |
+
setup_end = max(frame_indices[0], idx - 2)
|
|
|
|
|
|
|
| 107 |
break
|
| 108 |
|
| 109 |
+
# 2. Find top of backswing
|
| 110 |
+
top_backswing = find_top_of_backswing(pose_data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
+
# 3. Find impact frame
|
| 113 |
+
impact_frame = detect_impact_frame(pose_data, detections, sample_rate)
|
| 114 |
+
|
| 115 |
+
# Simple validation and fallback
|
| 116 |
+
if not impact_frame or impact_frame <= top_backswing:
|
| 117 |
+
downswing_frames = [f for f in frame_indices if f > top_backswing]
|
| 118 |
+
impact_frame = downswing_frames[len(downswing_frames) // 3] if downswing_frames else top_backswing + 1
|
| 119 |
+
|
| 120 |
+
print(f"Swing phases: Setup end={setup_end}, Top backswing={top_backswing}, Impact={impact_frame}")
|
| 121 |
+
|
| 122 |
+
# 4. Assign phases
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
for idx in frame_indices:
|
| 124 |
if idx <= setup_end:
|
| 125 |
swing_phases["setup"].append(idx)
|
| 126 |
+
elif idx <= top_backswing:
|
| 127 |
swing_phases["backswing"].append(idx)
|
| 128 |
elif idx < impact_frame:
|
| 129 |
swing_phases["downswing"].append(idx)
|
|
|
|
| 134 |
|
| 135 |
return swing_phases
|
| 136 |
|
| 137 |
+
|
| 138 |
+
# Wrapper function to maintain compatibility with existing Streamlit app
|
| 139 |
+
def segment_swing(pose_data, detections, sample_rate=2):
    """Segment a swing into phases (compatibility entry point).

    Forwards its arguments unchanged to ``segment_swing_pose_based`` and
    returns that function's result.
    """
    phases = segment_swing_pose_based(pose_data, detections, sample_rate)
    return phases
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def analyze_trajectory(frames, detections, swing_phases, sample_rate=2):
|
| 147 |
+
"""
|
| 148 |
+
Analyze ball trajectory and calculate club speed
|
| 149 |
+
"""
|
| 150 |
trajectory_data = {}
|
| 151 |
if len(frames) < 150:
|
| 152 |
sample_rate = 1
|
|
|
|
| 161 |
ball_positions = {}
|
| 162 |
|
| 163 |
for detection in ball_detections:
|
| 164 |
+
frame_idx = detection.frame_idx // sample_rate
|
| 165 |
if frame_idx >= impact_frame_idx:
|
| 166 |
x1, y1, x2, y2 = detection.bbox
|
| 167 |
center_x = (x1 + x2) / 2
|
|
|
|
| 175 |
club_speed = None
|
| 176 |
downswing_frames = swing_phases.get("downswing", [])
|
| 177 |
if len(downswing_frames) >= 2:
|
| 178 |
+
actual_frames_elapsed = (downswing_frames[-1] - downswing_frames[0]) * sample_rate
|
| 179 |
time_diff = actual_frames_elapsed / 30
|
| 180 |
if time_diff > 0:
|
| 181 |
club_speed = 100 * (1 / time_diff)
|
app/streamlit_app.py
CHANGED
|
@@ -166,7 +166,7 @@ def main():
|
|
| 166 |
"Frame Skip Rate (YOLO)",
|
| 167 |
min_value=1,
|
| 168 |
max_value=10,
|
| 169 |
-
value=
|
| 170 |
help=
|
| 171 |
"Process every Nth frame. Higher values = faster but less accurate.")
|
| 172 |
|
|
@@ -476,7 +476,8 @@ def main():
|
|
| 476 |
with st.spinner("Extracting key frames from your swing..."):
|
| 477 |
user_video_path = st.session_state.analysis_data['video_path']
|
| 478 |
user_swing_phases = st.session_state.analysis_data['swing_phases']
|
| 479 |
-
|
|
|
|
| 480 |
|
| 481 |
st.success("Key frame analysis complete!")
|
| 482 |
st.subheader("Key Frame Analysis: Your Swing's Critical Positions")
|
|
|
|
| 166 |
"Frame Skip Rate (YOLO)",
|
| 167 |
min_value=1,
|
| 168 |
max_value=10,
|
| 169 |
+
value=2,
|
| 170 |
help=
|
| 171 |
"Process every Nth frame. Higher values = faster but less accurate.")
|
| 172 |
|
|
|
|
| 476 |
with st.spinner("Extracting key frames from your swing..."):
|
| 477 |
user_video_path = st.session_state.analysis_data['video_path']
|
| 478 |
user_swing_phases = st.session_state.analysis_data['swing_phases']
|
| 479 |
+
frames = st.session_state.analysis_data['frames']
|
| 480 |
+
key_frames = extract_key_swing_frames(user_video_path, frames, user_swing_phases)
|
| 481 |
|
| 482 |
st.success("Key frame analysis complete!")
|
| 483 |
st.subheader("Key Frame Analysis: Your Swing's Critical Positions")
|
app/utils/comparison.py
CHANGED
|
@@ -120,133 +120,83 @@ def extract_frames(video_path, max_frames=100):
|
|
| 120 |
return frames
|
| 121 |
|
| 122 |
|
| 123 |
-
def extract_key_swing_frames(video_path, swing_phases=None):
|
| 124 |
"""
|
| 125 |
-
Extract 3 key frames from a
|
| 126 |
1. First setup frame
|
| 127 |
2. Last backswing frame (top of backswing)
|
| 128 |
3. First impact frame
|
| 129 |
|
| 130 |
Args:
|
| 131 |
-
video_path (str): Path to the video file
|
|
|
|
| 132 |
swing_phases (dict): Dictionary mapping phase names to lists of frame indices
|
|
|
|
| 133 |
|
| 134 |
Returns:
|
| 135 |
dict: Dictionary mapping phase names to frames
|
| 136 |
"""
|
| 137 |
-
|
| 138 |
-
raise ValueError(f"Video file not found: {video_path}")
|
| 139 |
-
|
| 140 |
-
print(f"Extracting key frames from: {video_path}")
|
| 141 |
-
|
| 142 |
-
# Use basic OpenCV VideoCapture
|
| 143 |
-
cap = cv2.VideoCapture(video_path)
|
| 144 |
-
|
| 145 |
-
if not cap.isOpened():
|
| 146 |
-
raise ValueError(f"Could not open video: {video_path}")
|
| 147 |
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
raise ValueError(f"Invalid video: no frames found in {video_path}")
|
| 152 |
-
|
| 153 |
-
print(f"Total frames in video: {total_frames}")
|
| 154 |
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
if orientation == 90:
|
| 161 |
-
rotation_angle = 270 # Rotate counterclockwise
|
| 162 |
-
elif orientation == 180:
|
| 163 |
-
rotation_angle = 180
|
| 164 |
-
elif orientation == 270:
|
| 165 |
-
rotation_angle = 90 # Rotate counterclockwise
|
| 166 |
-
print(f"Video orientation metadata: {orientation}, applying rotation: {rotation_angle}")
|
| 167 |
-
except:
|
| 168 |
-
print("No orientation metadata available")
|
| 169 |
-
|
| 170 |
-
key_frames = {}
|
| 171 |
-
|
| 172 |
-
# Determine frame indices based on swing phases
|
| 173 |
-
if swing_phases:
|
| 174 |
-
# Get first setup frame
|
| 175 |
-
setup_frames = swing_phases.get('setup', [])
|
| 176 |
-
setup_idx = setup_frames[0] if setup_frames else 0
|
| 177 |
-
|
| 178 |
-
# Get last backswing frame (top of backswing)
|
| 179 |
-
backswing_frames = swing_phases.get('backswing', [])
|
| 180 |
-
backswing_idx = backswing_frames[-1] if backswing_frames else total_frames//3
|
| 181 |
-
|
| 182 |
-
# Get first impact frame
|
| 183 |
-
impact_frames = swing_phases.get('impact', [])
|
| 184 |
-
impact_idx = impact_frames[0] if impact_frames else total_frames//2
|
| 185 |
-
else:
|
| 186 |
-
# Fallback to default indices if no swing phases provided
|
| 187 |
-
setup_idx = 0
|
| 188 |
-
backswing_idx = total_frames // 3
|
| 189 |
-
impact_idx = int(total_frames * 0.6)
|
| 190 |
-
|
| 191 |
-
print(f"Frame indices - Setup: {setup_idx}, Backswing: {backswing_idx}, Impact: {impact_idx}")
|
| 192 |
-
|
| 193 |
-
# Extract frames for each phase
|
| 194 |
-
phases = [
|
| 195 |
-
('setup', setup_idx),
|
| 196 |
-
('backswing', backswing_idx),
|
| 197 |
-
('impact', impact_idx)
|
| 198 |
-
]
|
| 199 |
-
|
| 200 |
-
for phase_name, frame_idx in phases:
|
| 201 |
-
frame = _extract_single_frame(cap, frame_idx, total_frames, rotation_angle, phase_name)
|
| 202 |
-
if frame is not None:
|
| 203 |
-
key_frames[phase_name] = frame
|
| 204 |
-
print(f"Successfully extracted {phase_name} frame")
|
| 205 |
-
else:
|
| 206 |
-
print(f"Failed to extract {phase_name} frame")
|
| 207 |
|
| 208 |
-
|
|
|
|
|
|
|
| 209 |
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
|
|
|
|
| 215 |
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
if not ret or frame is None:
|
| 229 |
-
print(f"Failed to read frame at index {attempt_idx} for {phase_name}")
|
| 230 |
-
continue
|
| 231 |
-
|
| 232 |
-
# Validate frame has 3 channels (color)
|
| 233 |
-
if len(frame.shape) != 3 or frame.shape[2] != 3:
|
| 234 |
-
print(f"Frame at index {attempt_idx} for {phase_name} is not in color format: {frame.shape}")
|
| 235 |
-
continue
|
| 236 |
-
|
| 237 |
-
print(f"Successfully read frame at index {attempt_idx} for {phase_name}, shape: {frame.shape}")
|
| 238 |
-
|
| 239 |
-
# Apply rotation correction if needed
|
| 240 |
-
if rotation_angle != 0:
|
| 241 |
-
print(f"Before rotation: {frame.shape}")
|
| 242 |
-
frame = _apply_rotation(frame, rotation_angle)
|
| 243 |
-
print(f"After {rotation_angle}° rotation: {frame.shape}")
|
| 244 |
-
print(f"Applied {rotation_angle}° rotation to {phase_name} frame")
|
| 245 |
-
|
| 246 |
-
return frame.copy()
|
| 247 |
-
|
| 248 |
-
print(f"Could not extract valid frame for {phase_name} after trying multiple indices")
|
| 249 |
-
return None
|
| 250 |
|
| 251 |
|
| 252 |
def _apply_rotation(frame, rotation_angle):
|
|
@@ -432,13 +382,13 @@ def create_key_frame_comparison(user_video_path, pro_video_path=None, user_swing
|
|
| 432 |
'user_image_path', 'pro_image_path', 'title', and 'comments' as values
|
| 433 |
"""
|
| 434 |
# Extract key frames from user video
|
| 435 |
-
user_frames = extract_key_swing_frames(user_video_path, user_swing_phases)
|
| 436 |
|
| 437 |
# Get pro frames either from provided images or video
|
| 438 |
if use_pro_images:
|
| 439 |
pro_frames = load_pro_reference_images()
|
| 440 |
else:
|
| 441 |
-
pro_frames = extract_key_swing_frames(pro_video_path, pro_swing_phases)
|
| 442 |
|
| 443 |
# Create output directory with absolute path
|
| 444 |
output_dir = os.path.abspath(output_dir)
|
|
|
|
| 120 |
return frames
|
| 121 |
|
| 122 |
|
| 123 |
+
def extract_key_swing_frames(video_path, frames, swing_phases=None):
    """
    Extract 3 key frames from a list of processed frames.
    1. First setup frame
    2. Last backswing frame (top of backswing)
    3. First impact frame

    Args:
        video_path (str): Path to the original video file (used only to read
            rotation metadata). May be None or missing, in which case no
            rotation is applied.
        frames (list): List of processed video frames.
        swing_phases (dict): Dictionary mapping phase names to lists of frame indices
            relative to the 'frames' list. When omitted, or when a phase has
            no indices, positional defaults are used (0, one third and one
            half of the way through 'frames').

    Returns:
        dict: Dictionary mapping 'setup', 'backswing' and 'impact' to a copy
        of the selected frame, or to None when its index is out of range.
    """
    key_frames = {'setup': None, 'backswing': None, 'impact': None}

    # len() rather than truthiness so an array of frames is also accepted.
    if frames is None or len(frames) == 0:
        print("Warning: No frames provided to extract_key_swing_frames.")
        return key_frames

    frame_count = len(frames)
    phases = swing_phases or {}

    # `or []` also tolerates a phase that is present but set to None.
    setup_frames = phases.get('setup') or []
    backswing_frames = phases.get('backswing') or []
    impact_frames = phases.get('impact') or []

    setup_idx = setup_frames[0] if setup_frames else 0
    backswing_idx = backswing_frames[-1] if backswing_frames else frame_count // 3
    impact_idx = impact_frames[0] if impact_frames else frame_count // 2

    print(f"Key frame indices (relative to processed frames) - Setup: {setup_idx}, Backswing: {backswing_idx}, Impact: {impact_idx}")

    rotation_angle = _rotation_from_metadata(video_path)

    phase_indices = {'setup': setup_idx, 'backswing': backswing_idx, 'impact': impact_idx}

    for phase_name, frame_idx in phase_indices.items():
        if 0 <= frame_idx < frame_count:
            # Copy so rotating or later edits never touch the caller's frame.
            frame = frames[frame_idx].copy()
            if rotation_angle != 0:
                frame = _apply_rotation(frame, rotation_angle)
            key_frames[phase_name] = frame
            print(f"Successfully extracted {phase_name} frame from memory.")
        else:
            print(f"Failed to extract {phase_name} frame: index {frame_idx} is out of bounds for {frame_count} frames.")

    return key_frames


def _rotation_from_metadata(video_path):
    """Return the rotation angle to apply for a video's orientation metadata.

    Returns 0 when the path is None/empty or missing, the video cannot be
    opened, or the orientation is not 90, 180 or 270.
    """
    # os.path.exists(None) raises TypeError, so reject falsy paths first.
    if not video_path or not os.path.exists(video_path):
        print(f"Warning: Video path {video_path} not found for rotation check.")
        return 0

    rotation_angle = 0
    cap = cv2.VideoCapture(video_path)
    try:
        if cap.isOpened():
            orientation = int(cap.get(cv2.CAP_PROP_ORIENTATION_META))
            if orientation == 90:
                rotation_angle = 270  # Rotate counterclockwise
            elif orientation == 180:
                rotation_angle = 180
            elif orientation == 270:
                rotation_angle = 90  # Rotate counterclockwise
            print(f"Video orientation metadata: {orientation}, applying rotation: {rotation_angle}")
    except Exception as e:
        # Best-effort: failing to read metadata must not block extraction.
        print(f"Could not read orientation metadata: {e}")
    finally:
        # Release the handle even when the capture failed to open.
        cap.release()

    return rotation_angle
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
|
| 201 |
|
| 202 |
def _apply_rotation(frame, rotation_angle):
|
|
|
|
| 382 |
'user_image_path', 'pro_image_path', 'title', and 'comments' as values
|
| 383 |
"""
|
| 384 |
# Extract key frames from user video
|
| 385 |
+
user_frames = extract_key_swing_frames(user_video_path, user_frames, user_swing_phases)
|
| 386 |
|
| 387 |
# Get pro frames either from provided images or video
|
| 388 |
if use_pro_images:
|
| 389 |
pro_frames = load_pro_reference_images()
|
| 390 |
else:
|
| 391 |
+
pro_frames = extract_key_swing_frames(pro_video_path, pro_frames, pro_swing_phases)
|
| 392 |
|
| 393 |
# Create output directory with absolute path
|
| 394 |
output_dir = os.path.abspath(output_dir)
|
app/utils/video_processor.py
CHANGED
|
@@ -3,6 +3,7 @@ Video processing and object detection module
|
|
| 3 |
"""
|
| 4 |
|
| 5 |
import cv2
|
|
|
|
| 6 |
import numpy as np
|
| 7 |
from tqdm import tqdm
|
| 8 |
from ultralytics import YOLO
|
|
@@ -35,9 +36,17 @@ def process_video(video_path, sample_rate=5):
|
|
| 35 |
model = YOLO("yolov8n.pt")
|
| 36 |
class_names = model.names
|
| 37 |
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
if not cap.isOpened():
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 43 |
|
|
|
|
| 3 |
"""
|
| 4 |
|
| 5 |
import cv2
|
| 6 |
+
import platform
|
| 7 |
import numpy as np
|
| 8 |
from tqdm import tqdm
|
| 9 |
from ultralytics import YOLO
|
|
|
|
| 36 |
model = YOLO("yolov8n.pt")
|
| 37 |
class_names = model.names
|
| 38 |
|
| 39 |
+
# On macOS ("Darwin"), the AVFoundation backend is often more reliable.
|
| 40 |
+
# For other systems, FFMPEG is a good choice.
|
| 41 |
+
backend = cv2.CAP_AVFOUNDATION if platform.system() == "Darwin" else cv2.CAP_FFMPEG
|
| 42 |
+
cap = cv2.VideoCapture(video_path, backend)
|
| 43 |
+
|
| 44 |
if not cap.isOpened():
|
| 45 |
+
backend_name = "AVFoundation" if platform.system() == "Darwin" else "FFMPEG"
|
| 46 |
+
print(f"Warning: Could not open video with {backend_name} backend. Trying default.")
|
| 47 |
+
cap = cv2.VideoCapture(video_path) # Fallback to default
|
| 48 |
+
if not cap.isOpened():
|
| 49 |
+
raise ValueError("Error opening video file with any available backend.")
|
| 50 |
|
| 51 |
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 52 |
|