Par-ity_Project / app /models /segmentation.py
chenemii's picture
Major code cleanup: Remove unused metrics and consolidate model files
95ad9d6
"""
Simple swing segmentation with downswing re-gating
"""
def _dt_and_fps(frame_timestamps_ms, frames: int, total_ms: float):
"""Calculate time delta and FPS from frame data"""
if frame_timestamps_ms and len(frame_timestamps_ms) >= 2:
dt = (frame_timestamps_ms[-1] - frame_timestamps_ms[0]) / max(len(frame_timestamps_ms) - 1, 1) / 1000.0
else:
dt = (total_ms / 1000.0) / max(frames, 1)
return dt, 1.0 / max(dt, 1e-6)
def detect_arm_velocity_zero_crossing(pose_data, frames):
    """Return a placeholder top-of-backswing frame.

    ``pose_data`` is accepted for interface compatibility but not read; the
    frame one third of the way through ``frames`` is returned, or 0 when
    ``frames`` is empty.
    """
    return frames[len(frames) // 3] if frames else 0
# Keypoint index tracked as a clubhead proxy for impact detection.
# NOTE(review): in MediaPipe's 33-landmark pose layout index 15 is the LEFT
# wrist (the lead wrist for a right-handed golfer), although an earlier inline
# comment called it the right wrist — confirm which pose model fills pose_data.
_IMPACT_KEYPOINT_INDEX = 15
# A keypoint sample is used only if its third component (visibility) exceeds this.
_MIN_KEYPOINT_VISIBILITY = 0.5
# Below this many frames / visible samples, velocity-based impact detection is skipped.
_MIN_IMPACT_SAMPLES = 5


def _find_impact_by_velocity(pose_data, frames_after_top):
    """Return the frame where the proxy keypoint stops moving downward, or None.

    Collects the y coordinate of keypoint ``_IMPACT_KEYPOINT_INDEX`` for each
    frame in ``frames_after_top`` whose visibility exceeds
    ``_MIN_KEYPOINT_VISIBILITY``, then returns the first sample frame where the
    frame-to-frame y change goes from positive to non-positive (a local maximum
    of y, i.e. the lowest point if image y grows downward). Returns None when
    there are too few frames/samples or no such change exists.
    """
    if len(frames_after_top) < _MIN_IMPACT_SAMPLES:
        return None
    ys = []
    sample_frames = []
    for frame_idx in frames_after_top:
        kp = pose_data[frame_idx]
        # len() instead of truthiness: `kp and ...` raises for NumPy arrays.
        if len(kp) <= _IMPACT_KEYPOINT_INDEX:
            continue
        point = kp[_IMPACT_KEYPOINT_INDEX]
        if point[2] > _MIN_KEYPOINT_VISIBILITY:
            ys.append(point[1])  # vertical coordinate
            sample_frames.append(frame_idx)
    if len(ys) < _MIN_IMPACT_SAMPLES:
        return None
    # velocities[j] = ys[j + 1] - ys[j]; positive means moving down the image.
    velocities = [after - before for before, after in zip(ys, ys[1:])]
    for i in range(1, len(velocities)):
        if velocities[i - 1] > 0 and velocities[i] <= 0:
            # ys[i] is the turning point, so sample_frames[i] is impact.
            return sample_frames[i]
    return None


def _estimate_impact_by_timing(frames, top):
    """Estimate impact as a fixed number of frames after the top of the swing.

    Places impact max(8, 25% of all frames) frames after ``top``, clamped to
    the last frame. If ``top`` is not one of ``frames``, the position one third
    of the way through ``frames`` is used as the top instead.
    """
    total_frames = len(frames)
    top_idx = frames.index(top) if top in frames else total_frames // 3
    # Tour pro downswing: ~8-10 frames at 30fps (25-30% of total swing)
    expected_downswing_frames = max(8, int(total_frames * 0.25))
    return frames[min(total_frames - 1, top_idx + expected_downswing_frames)]


def segment_swing(pose_data, detections=None, sample_rate=1, frame_shape=None,
                  frame_timestamps_ms=None, total_ms=None, fps=30.0):
    """
    Segment a golf swing into phases and flag unreliable downswing timing.

    Frames whose pose is None are ignored. The remaining frames are assigned,
    in frame order, to: setup (roughly the first 1/8 of frames), backswing (up
    to the top frame from ``detect_arm_velocity_zero_crossing``), downswing (up
    to but excluding impact), impact (at most one frame) and follow-through.
    Impact is the first downward-to-upward turn of keypoint 15's y coordinate
    after the top; if none is found, a timing-based estimate is used.

    Args:
        pose_data: Dictionary mapping frame indices to pose keypoints (or None).
            Each keypoint is read as ``(x, y, visibility)`` via indices 0-2.
        detections: Object detections (unused in current implementation).
        sample_rate: Frame sampling rate (unused in current implementation).
        frame_shape: Frame shape (unused in current implementation).
        frame_timestamps_ms: List of frame timestamps in milliseconds.
        total_ms: Total video duration in milliseconds.
        fps: Video frame rate (fallback if timestamps not available).

    Returns:
        dict with keys "setup", "backswing", "downswing", "impact",
        "follow_through" (lists of frame indices) and "timing_unreliable"
        (True when the downswing frame count lies outside the expected range,
        6-15 frames at 30 fps scaled by the measured frame rate).
    """
    frames = [i for i in sorted(pose_data) if pose_data[i] is not None]
    out = {"setup": [], "backswing": [], "downswing": [], "impact": [],
           "follow_through": [], "timing_unreliable": False}
    if not frames:
        return out
    total_frames = len(frames)

    # Setup phase - first 12.5% of frames. Clamp the index so a single-frame
    # input does not read frames[1], which would raise IndexError.
    setup_end = frames[min(max(1, total_frames // 8), total_frames - 1)]

    # Top of backswing, searched only among frames after setup.
    top = detect_arm_velocity_zero_crossing(
        pose_data, [f for f in frames if f > setup_end])

    # Impact: keypoint-velocity zero crossing after the top, else a timing estimate.
    imp = _find_impact_by_velocity(pose_data, [f for f in frames if f > top])
    if imp is None:
        imp = _estimate_impact_by_timing(frames, top)

    for f in frames:
        if f <= setup_end:
            phase = "setup"
        elif f <= top:
            phase = "backswing"
        elif f < imp:
            phase = "downswing"
        elif f == imp:
            phase = "impact"
        else:
            phase = "follow_through"
        out[phase].append(f)

    # Timing information for downswing re-gating.
    if total_ms is None:
        # NOTE(review): sample_rate is not applied here; if pose_data holds only
        # every sample_rate-th video frame, the derived fps is presumably
        # overstated — confirm with the caller.
        total_ms = total_frames * (1000.0 / fps)
    _, actual_fps = _dt_and_fps(frame_timestamps_ms, total_frames, total_ms)

    downswing_count = len(out["downswing"])
    if downswing_count:
        # Scale the 30 fps baseline range (6-15 frames) to the measured fps.
        fps_scale = actual_fps / 30.0
        min_expected = max(1, int(6 * fps_scale))
        # Keep the range non-empty at very low fps, where int(15 * scale)
        # could drop below min_expected and flag every downswing.
        max_expected = max(min_expected, int(15 * fps_scale))
        if not min_expected <= downswing_count <= max_expected:
            # (Angular velocity re-gating removed for 5 core metrics simplification)
            out["timing_unreliable"] = True
    return out