File size: 5,734 Bytes
7a7e825
 
 
6251904
 
 
7a7e825
 
 
 
 
 
 
 
 
 
 
 
 
6251904
 
7a7e825
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6251904
 
 
7a7e825
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import numpy as np
from filterpy.kalman import KalmanFilter
from scipy.optimize import linear_sum_assignment
import logging

# Module-level logger; BYTETracker.__init__ reports its configuration through it.
logger = logging.getLogger(__name__)

class BYTETracker:
    """Multi-object tracker: IoU-based optimal assignment plus a Kalman filter.

    Boxes are ``(cx, cy, w, h)`` with the centre first: ``_calculate_iou``
    treats the first two values as the box centre, and each track's Kalman
    filter models only that centre (state ``[cx, cy, vx, vy]``).

    Despite the name, association is a single IoU pass over all detections
    (no separate high/low-score stages); ``track_thresh`` only gates whether
    an *unmatched* detection may start or revive a track.

    Tracks unseen for more than ``track_buffer / frame_rate`` time units are
    moved to ``recently_removed``.  A later unmatched detection whose centre
    is close to a removed track's last observed centre (``_is_same_worker``)
    revives that track ID instead of allocating a new one.
    """

    def __init__(self, track_thresh=0.3, track_buffer=90, match_thresh=0.5, frame_rate=30,
                 reid_window=None):
        """Create an empty tracker.

        Args:
            track_thresh: Minimum score for an unmatched detection to start or
                revive a track.
            track_buffer: Number of frames a track may go unseen before it is
                removed.
            match_thresh: Minimum IoU for a detection/track assignment to be
                accepted.
            frame_rate: Converts ``track_buffer`` (frames) into the time units
                of ``current_time``.
            reid_window: If not None, removed tracks older than this many time
                units (same units as ``current_time``) are forgotten and can no
                longer be revived.  ``None`` (default) keeps them forever.
        """
        self.track_thresh = track_thresh
        self.track_buffer = track_buffer
        self.match_thresh = match_thresh
        self.frame_rate = frame_rate
        self.reid_window = reid_window
        self.next_id = 1
        self.tracks = {}            # track_id -> {'bbox', 'score', 'cls', 'last_seen'}
        self.kalman_filters = {}    # track_id -> KalmanFilter
        self.worker_history = {}    # track_id -> list of observed centres
        self.last_positions = {}    # track_id -> last observed centre
        self.recently_removed = {}  # track_id -> {'bbox', 'last_seen', 'last_position'}
        logger.info("BYTETracker initialized with track_thresh=%.2f, track_buffer=%d, match_thresh=%.2f", 
                    track_thresh, track_buffer, match_thresh)

    def _init_kalman(self):
        """Return a fresh constant-velocity filter with state ``[cx, cy, vx, vy]``.

        Only the centre is measured (``H`` selects the first two state
        entries); one time step corresponds to one ``update`` call.
        """
        kf = KalmanFilter(dim_x=4, dim_z=2)
        kf.F = np.array([[1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 1, 0], [0, 0, 0, 1]])
        kf.H = np.array([[1, 0, 0, 0], [0, 1, 0, 0]])
        kf.P *= 1000.  # large initial uncertainty on the whole state
        kf.R = np.array([[10, 0], [0, 10]])
        kf.Q = np.eye(4)
        return kf

    def update(self, dets, scores, cls, current_time):
        """Advance the tracker by one frame.

        Args:
            dets: Sequence of boxes ``(cx, cy, w, h)`` (lists or array rows).
                They are copied internally and never modified.
            scores: Detection confidences, parallel to ``dets``.
            cls: Class labels, parallel to ``dets``.
            current_time: Frame timestamp; compared against
                ``track_buffer / frame_rate`` (presumably seconds — confirm
                with the caller).

        Returns:
            List of ``{'id', 'bbox', 'score', 'cls'}`` dicts for tracks that
            were matched, revived or created in this frame.  ``'bbox'`` is the
            caller's own detection object.
        """
        self._prune_stale(current_time)
        self._predict()

        tracks = []
        if len(dets) == 0:
            return tracks

        # Private float copies: Kalman predictions are later written into the
        # stored bboxes, which must never alias the caller's arrays.
        boxes = [np.array(det, dtype=float) for det in dets]

        matched = self._associate(boxes, scores, cls, current_time)
        for i, track_id in matched.items():
            tracks.append({'id': track_id, 'bbox': dets[i], 'score': scores[i], 'cls': cls[i]})

        # Unmatched, sufficiently confident detections revive a removed track
        # or start a new one.
        for i, (box, score, cl) in enumerate(zip(boxes, scores, cls)):
            if i in matched or score < self.track_thresh:
                continue
            track_id = self._pop_reid_candidate(box[:2])
            if track_id is None:
                track_id = self.next_id
                self.next_id += 1
            self._start_track(track_id, box, score, cl, current_time)
            tracks.append({'id': track_id, 'bbox': dets[i], 'score': score, 'cls': cl})

        return tracks

    def _prune_stale(self, current_time):
        """Move tracks unseen for too long into ``recently_removed``; expire old removals."""
        max_age = self.track_buffer / self.frame_rate
        for track_id in list(self.tracks):
            if current_time - self.tracks[track_id]['last_seen'] > max_age:
                self.recently_removed[track_id] = {
                    'bbox': self.tracks[track_id]['bbox'],
                    'last_seen': current_time,  # time of removal
                    'last_position': self.last_positions.get(track_id, [0, 0]),
                }
                del self.tracks[track_id]
                self.kalman_filters.pop(track_id, None)
                self.worker_history.pop(track_id, None)
                self.last_positions.pop(track_id, None)

        if self.reid_window is not None:
            expired = [tid for tid, info in self.recently_removed.items()
                       if current_time - info['last_seen'] > self.reid_window]
            for track_id in expired:
                del self.recently_removed[track_id]

    def _predict(self):
        """Step every Kalman filter and move each track's box centre to the prediction."""
        for track_id, kf in self.kalman_filters.items():
            kf.predict()
            # Safe in-place write: stored bboxes are the tracker's own copies.
            self.tracks[track_id]['bbox'][:2] = kf.x[:2].flatten()

    def _associate(self, boxes, scores, cls, current_time):
        """Assign boxes to live tracks by IoU; return ``{detection_index: track_id}``."""
        matched = {}
        track_ids = list(self.tracks)
        if not track_ids:
            return matched

        cost_matrix = np.array([[1 - self._calculate_iou(box, self.tracks[tid]['bbox'])
                                 for tid in track_ids]
                                for box in boxes])
        row_indices, col_indices = linear_sum_assignment(cost_matrix)
        for i, j in zip(row_indices, col_indices):
            # cost < 1 - match_thresh  <=>  IoU > match_thresh
            if cost_matrix[i, j] >= 1 - self.match_thresh:
                continue
            track_id = track_ids[j]
            box = boxes[i]
            self.tracks[track_id].update({
                'bbox': box,
                'score': scores[i],
                'cls': cls[i],
                'last_seen': current_time
            })
            self.kalman_filters[track_id].update(box[:2])
            position = box[:2].copy()  # detached from the bbox the predictor mutates
            self.worker_history.setdefault(track_id, []).append(position)
            self.last_positions[track_id] = position
            matched[i] = track_id
        return matched

    def _pop_reid_candidate(self, position):
        """Remove and return the nearest revivable track ID for ``position``, or None."""
        pos = np.asarray(position, dtype=float)
        candidates = [tid for tid, info in self.recently_removed.items()
                      if self._is_same_worker(pos, info['last_position'])]
        if not candidates:
            return None
        # min() keeps the earliest-removed track on distance ties.
        best = min(candidates, key=lambda tid: np.linalg.norm(
            pos - np.asarray(self.recently_removed[tid]['last_position'], dtype=float)))
        del self.recently_removed[best]
        return best

    def _start_track(self, track_id, box, score, cl, current_time):
        """(Re)initialise all per-track state for ``track_id`` from one detection."""
        self.tracks[track_id] = {
            'bbox': box,
            'score': score,
            'cls': cl,
            'last_seen': current_time
        }
        kf = self._init_kalman()
        kf.x[:2] = box[:2].reshape(2, 1)
        self.kalman_filters[track_id] = kf
        position = box[:2].copy()
        self.worker_history[track_id] = [position]
        self.last_positions[track_id] = position

    def _calculate_iou(self, box1, box2):
        """Return the IoU of two ``(cx, cy, w, h)`` boxes; 0.0 for disjoint or zero-area pairs."""
        x1, y1, w1, h1 = box1
        x2, y2, w2, h2 = box2
        x_left = max(x1 - w1/2, x2 - w2/2)
        y_top = max(y1 - h1/2, y2 - h2/2)
        x_right = min(x1 + w1/2, x2 + w2/2)
        y_bottom = min(y1 + h1/2, y2 + h2/2)
        if x_right < x_left or y_bottom < y_top:
            return 0.0
        intersection = (x_right - x_left) * (y_bottom - y_top)
        union = w1 * h1 + w2 * h2 - intersection
        if union <= 0:
            # Degenerate (zero-area) boxes: avoid dividing by zero.
            return 0.0
        return intersection / union

    def _is_same_worker(self, pos1, pos2, threshold=150):
        """Return True if two centres are closer than ``threshold`` (Euclidean)."""
        delta = np.asarray(pos1, dtype=float) - np.asarray(pos2, dtype=float)
        return np.linalg.norm(delta) < threshold