File size: 5,694 Bytes
4a5016f
 
 
95ad9d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a5016f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""
Simple swing segmentation with a downswing-duration check that flags unreliable timing
"""

def _dt_and_fps(frame_timestamps_ms, frames: int, total_ms: float):
    """Calculate time delta and FPS from frame data"""
    if frame_timestamps_ms and len(frame_timestamps_ms) >= 2:
        dt = (frame_timestamps_ms[-1] - frame_timestamps_ms[0]) / max(len(frame_timestamps_ms) - 1, 1) / 1000.0
    else:
        dt = (total_ms / 1000.0) / max(frames, 1)
    return dt, 1.0 / max(dt, 1e-6)


def detect_arm_velocity_zero_crossing(pose_data, frames):
    """Placeholder top-of-swing detector.

    Despite the name, no velocity analysis is performed and ``pose_data`` is
    never read: the entry one third of the way into ``frames`` is returned,
    or 0 when ``frames`` is empty or None.
    """
    # Positional guess only; see the docstring.
    return frames[len(frames) // 3] if frames else 0

def _find_impact_frame(pose_data, candidate_frames):
    """Return the first frame where keypoint 15 stops moving down, or None.

    Keypoint 15 is used as a stand-in for the clubhead. Each keypoint is read
    as ``[x, y, score]``; samples whose score is not above 0.5 are skipped.
    NOTE(review): the original comments describe index 15 both as the lead
    wrist for a right-handed player and as the right wrist, which disagree.
    In MediaPipe Pose numbering 15 is the left wrist. Confirm which keypoint
    layout ``pose_data`` uses.

    Velocity is the frame-to-frame change in y. This assumes image
    coordinates where y grows downward, so positive means moving down.
    Impact is the sample at which the velocity flips from positive to
    zero-or-negative, i.e. the local lowest point of the tracked keypoint.

    At least 5 candidate frames and 5 usable samples are required; otherwise
    None is returned so the caller can fall back to a timing estimate.
    """
    if len(candidate_frames) < 5:
        return None

    samples = []  # (frame index, keypoint-15 y) for sufficiently visible frames
    for frame_idx in candidate_frames:
        kp = pose_data[frame_idx]
        if kp and len(kp) > 15 and kp[15][2] > 0.5:
            samples.append((frame_idx, kp[15][1]))

    if len(samples) < 5:
        return None

    # The very first sample is never reported: a crossing at sample i needs
    # both the velocity into i and the velocity out of i, and i starts at 1.
    for i in range(1, len(samples) - 1):
        velocity_in = samples[i][1] - samples[i - 1][1]
        velocity_out = samples[i + 1][1] - samples[i][1]
        if velocity_in > 0 and velocity_out <= 0:
            return samples[i][0]
    return None


def segment_swing(pose_data, detections=None, sample_rate=1, frame_shape=None,
                  frame_timestamps_ms=None, total_ms=None, fps=30.0):
    """
    Split the frames of a swing into setup, backswing, downswing, impact and
    follow-through, and flag the timing when the downswing length looks wrong.

    Only frames whose pose entry is not None take part. The setup phase is the
    first eighth of those frames (at least through the second one). The top of
    the swing comes from ``detect_arm_velocity_zero_crossing``. Impact is the
    first downward-to-upward reversal of keypoint 15 after the top, or a
    timing-based estimate when no reversal is found.

    Args:
        pose_data: Dictionary mapping frame indices to pose keypoints
            (None for frames without a pose)
        detections: Object detections (unused in current implementation)
        sample_rate: Frame sampling rate (unused in current implementation)
        frame_shape: Frame shape (unused in current implementation)
        frame_timestamps_ms: List of frame timestamps in milliseconds
        total_ms: Total video duration in milliseconds; estimated from ``fps``
            and the number of valid frames when None
        fps: Video frame rate (fallback if timestamps not available)

    Returns:
        dict: Lists of frame indices under "setup", "backswing", "downswing",
        "impact" and "follow_through", plus a boolean "timing_unreliable".
        With no valid frames every list is empty; with a single valid frame
        that frame is reported as setup.
    """
    frames = [i for i in sorted(pose_data) if pose_data[i] is not None]
    out = {"setup": [], "backswing": [], "downswing": [], "impact": [],
           "follow_through": [], "timing_unreliable": False}
    if not frames:
        return out

    total_frames = len(frames)
    if total_frames == 1:
        # Nothing to segment. Without this guard the setup index below would
        # be 1 and indexing frames[1] would raise IndexError.
        out["setup"].append(frames[0])
        return out

    # Setup phase - first 12.5% of the valid frames, never less than through
    # the second frame. For two or more frames this index is always in range.
    setup_end = frames[max(1, total_frames // 8)]

    # Top of swing is chosen among the frames after setup.
    backswing_frames_for_analysis = [f for f in frames if f > setup_end]
    top = detect_arm_velocity_zero_crossing(pose_data, backswing_frames_for_analysis)

    # Preferred impact estimate: reversal of the tracked keypoint after the top.
    imp = _find_impact_frame(pose_data, [f for f in frames if f > top])
    if imp is None:
        # Fallback: if no velocity zero-crossing found, use timing-based estimate
        top_idx_in_total = frames.index(top) if top in frames else total_frames // 3
        # Tour pro downswing: ~8-10 frames at 30fps (25-30% of total swing)
        expected_downswing_frames = max(8, int(total_frames * 0.25))
        imp = frames[min(total_frames - 1, top_idx_in_total + expected_downswing_frames)]

    # Assign frames to phases
    for f in frames:
        if f <= setup_end:
            out["setup"].append(f)
        elif f <= top:
            out["backswing"].append(f)
        elif f < imp:
            out["downswing"].append(f)
        elif f == imp:
            out["impact"].append(f)
        else:
            out["follow_through"].append(f)

    # Effective frame rate, used to scale the expected downswing length.
    if total_ms is None:
        total_ms = total_frames * (1000.0 / fps)  # fallback estimate
    _, actual_fps = _dt_and_fps(frame_timestamps_ms, total_frames, total_ms)

    # Downswing duration check; an empty downswing is left unflagged.
    downswing_duration_frames = len(out["downswing"])
    if downswing_duration_frames:
        # Scale expected range by fps (~30 fps baseline: 6 to 15 frames)
        fps_scale = actual_fps / 30.0
        min_expected = max(1, int(6 * fps_scale))
        max_expected = int(15 * fps_scale)

        # Mark timing as unreliable when downswing duration is outside expected range
        # (Angular velocity re-gating removed for 5 core metrics simplification)
        if not (min_expected <= downswing_duration_frames <= max_expected):
            out["timing_unreliable"] = True

    return out