File size: 3,814 Bytes
1346b10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""
tracker.py
ByteTrack wrapper using supervision 0.21 stable API.
Maintains trajectory history and speed estimates per track ID.
"""
from __future__ import annotations

import math
from collections import defaultdict, deque
from typing import Dict, List, Optional, Tuple

import numpy as np


# pixels-per-metre calibration (rough: 200px ≈ 10m pitch width)
DEFAULT_PPM = 20.0


class TrackState:
    """Stores trajectory + speed per track ID.

    Positions are bounding-box centroids in pixel coordinates (truncated to
    int). Speeds are converted from pixels/frame to km/h using ``ppm`` and the
    caller-supplied fps, then smoothed with an exponential moving average.
    """

    def __init__(self, traj_len: int = 60, ppm: float = DEFAULT_PPM) -> None:
        """
        Args:
            traj_len: maximum number of centroids kept per track; older points
                are dropped first (bounded deque).
            ppm: pixels-per-metre calibration used for the speed conversion.
        """
        self.ppm      = ppm
        self.traj_len = traj_len
        # track id -> bounded history of (cx, cy) centroids
        self.trajs:  Dict[int, deque]          = defaultdict(lambda: deque(maxlen=traj_len))
        # track id -> centroid seen in the most recent update() call only
        self.prev:   Dict[int, Tuple[int,int]] = {}
        # track id -> EMA-smoothed speed in km/h
        self.speeds: Dict[int, float]          = {}

    def update(self, tracks: list[dict], fps: float) -> None:
        """
        Record one frame of tracks and refresh per-track speed estimates.

        tracks: list of {"id": int, "xyxy": [x1,y1,x2,y2]}
        fps: frame rate used to turn per-frame displacement into speed; when
            it is not positive, positions are still recorded but speeds are
            left unchanged.
        """
        new_prev: Dict[int, Tuple[int,int]] = {}
        for t in tracks:
            tid = int(t["id"])
            x1, y1, x2, y2 = t["xyxy"]
            cx, cy = int((x1+x2)/2), int((y1+y2)/2)
            new_prev[tid] = (cx, cy)
            self.trajs[tid].append((cx, cy))

            # Only measure speed when the track was present in the previous
            # call, so the displacement spans exactly one frame interval.
            if tid in self.prev and fps > 0:
                d   = math.hypot(cx - self.prev[tid][0], cy - self.prev[tid][1])
                spd = (d / self.ppm) * fps * 3.6          # m/s -> km/h
                old = self.speeds.get(tid, spd)           # first sample seeds the EMA
                self.speeds[tid] = 0.7 * old + 0.3 * spd  # EMA smooth

        # Replace (not merge): tracks absent this frame lose their "previous"
        # position and will not get a speed sample on reappearance.
        self.prev = new_prev

    def trajectory(self, tid: int) -> List[Tuple[int,int]]:
        """Return the stored centroid history for ``tid``, oldest first.

        Unknown IDs yield ``[]`` without being registered. Indexing the
        defaultdict directly would insert an empty deque as a side effect and
        make the ID appear in ``all_ids``.
        """
        traj = self.trajs.get(tid)
        return list(traj) if traj is not None else []

    def speed(self, tid: int) -> Optional[float]:
        """Return the smoothed speed (km/h) for ``tid``, or None if not yet measured."""
        return self.speeds.get(tid)

    @property
    def all_ids(self) -> List[int]:
        """IDs of every track that has recorded at least one position."""
        return list(self.trajs.keys())


class SportsTracker:
    """
    Wraps supervision 0.21 ByteTracker.
    Input/output uses plain Python dicts — no supervision objects exposed.
    """

    def __init__(
        self,
        fps: float      = 30.0,
        conf: float     = 0.30,
        iou: float      = 0.50,
        traj_len: int   = 60,
        ppm: float      = DEFAULT_PPM,
    ) -> None:
        """
        Args:
            fps: frame rate of the input stream; used for ByteTrack's
                ``frame_rate``, the lost-track buffer length and speed conversion.
            conf: passed to ByteTrack as ``track_activation_threshold``.
            iou: NOTE(review): currently unused — the matching threshold below
                is fixed at 0.80. Kept so existing callers keep working.
            traj_len: per-track trajectory length kept by ``TrackState``.
            ppm: pixels-per-metre calibration for speed estimates.
        """
        import supervision as sv

        self.fps   = fps
        self.state = TrackState(traj_len=traj_len, ppm=ppm)

        # supervision 0.21 ByteTrack constructor
        self._tracker = sv.ByteTrack(
            track_activation_threshold=conf,
            # at least 60 frames, or ~3 seconds' worth of frames if that is more
            lost_track_buffer=max(60, int(fps * 3)),
            minimum_matching_threshold=0.80,
            frame_rate=int(fps),
        )

    def update(self, detections: list[dict]) -> list[dict]:
        """
        Run one frame of detections through ByteTrack and update track state.

        Call once per frame, including frames with no detections. Every frame
        must reach ByteTrack so its per-frame bookkeeping (e.g. the lost-track
        buffer) advances. ``TrackState`` must also be refreshed, so that speeds
        are never computed across a skipped frame.

        Args:
            detections: list of {"xyxy": [x1,y1,x2,y2], "conf": float}

        Returns:
            list of {"id": int, "xyxy": [x1,y1,x2,y2], "conf": float}
        """
        import supervision as sv

        if detections:
            sv_det = sv.Detections(
                xyxy=np.array([d["xyxy"] for d in detections], dtype=np.float32),
                confidence=np.array([d["conf"] for d in detections], dtype=np.float32),
                # single-class tracking: every detection gets class 0
                class_id=np.zeros(len(detections), dtype=int),
            )
        else:
            # Previously this frame was skipped entirely (early return), which
            # stalled the tracker and left stale positions in TrackState.prev,
            # doubling the next speed sample. Feed a well-formed empty batch.
            sv_det = sv.Detections.empty()

        tracked = self._tracker.update_with_detections(sv_det)

        results: list[dict] = []
        if tracked.tracker_id is not None:
            for i, tid in enumerate(tracked.tracker_id):
                results.append({
                    "id":   int(tid),
                    "xyxy": tracked.xyxy[i].tolist(),
                    "conf": float(tracked.confidence[i]) if tracked.confidence is not None else 0.0,
                })

        # Always advance TrackState (even with no tracks) so ``prev`` only ever
        # holds positions from the immediately preceding frame.
        self.state.update(results, self.fps)
        return results