Zhen Ye committed on
Commit
6896025
·
1 Parent(s): 85ec659

feat: replace SimpleTracker with ByteTrack

Browse files
Files changed (2) hide show
  1. inference.py +4 -111
  2. utils/tracker.py +663 -0
inference.py CHANGED
@@ -271,115 +271,7 @@ def _build_detection_records(
271
  return detections
272
 
273
 
274
class SimpleTracker:
    """Greedy IoU-based multi-object tracker.

    Tracks are kept in a dict keyed by integer id. Each frame, detections
    are greedily matched (best IoU first) to label-consistent, non-expired
    tracks; unmatched detections spawn new tracks and unmatched tracks age
    out after ``max_age`` consecutive missed frames.
    """

    def __init__(self, max_age: int = 30, iou_thresh: float = 0.3):
        self.tracks = {}  # id -> {bbox, label, score, history, missed_frames, ...}
        self.next_id = 1
        self.max_age = max_age
        self.iou_thresh = iou_thresh

    def update(self, detections: List[Dict[str, Any]]):
        """Match ``detections`` (dicts with bbox/label/score) to tracks, in place."""
        # Collect every candidate (track, detection) pair above the IoU
        # threshold, restricted to label-consistent, non-expired tracks.
        candidates = []  # (track_id, det_idx, iou)
        for tid, trk in self.tracks.items():
            if trk['missed_frames'] >= self.max_age:
                continue
            for di, det in enumerate(detections):
                if trk['label'] != det['label']:
                    continue
                overlap = self._calculate_iou(trk['bbox'], det['bbox'])
                if overlap > self.iou_thresh:
                    candidates.append((tid, di, overlap))

        # Greedy assignment: best IoU first, each track/detection used once.
        candidates.sort(key=lambda c: c[2], reverse=True)
        used_tracks, used_dets = set(), set()
        for tid, di, _ in candidates:
            if tid in used_tracks or di in used_dets:
                continue
            trk = self.tracks[tid]
            det = detections[di]
            trk['bbox'] = det['bbox']
            trk['score'] = det['score']
            trk['missed_frames'] = 0
            trk['history'].append(trk['bbox'])
            if len(trk['history']) > 30:
                trk['history'].pop(0)

            # Forward-fill persisted GPT attributes track -> detection...
            for key in ['gpt_distance_m', 'gpt_direction', 'gpt_description']:
                if key in trk:
                    det[key] = trk[key]
            # ...then refresh them on the track from the detection.
            for key in ['gpt_distance_m', 'gpt_direction', 'gpt_description']:
                if key in det:
                    trk[key] = det[key]

            det['track_id'] = f"T{str(tid).zfill(2)}"
            det['history'] = trk['history']  # consumed later by SpeedEstimator
            used_tracks.add(tid)
            used_dets.add(di)

        # Spawn new tracks for unmatched detections.
        for di, det in enumerate(detections):
            if di in used_dets:
                continue
            tid = self.next_id
            self.next_id += 1
            entry = {
                'bbox': det['bbox'],
                'label': det['label'],
                'score': det['score'],
                'missed_frames': 0,
                'history': [det['bbox']],
            }
            # Seed GPT attributes if the detection already carries them.
            for key in ['gpt_distance_m', 'gpt_direction', 'gpt_description']:
                if key in det:
                    entry[key] = det[key]
            self.tracks[tid] = entry
            det['track_id'] = f"T{str(tid).zfill(2)}"
            det['history'] = [det['bbox']]

        # Age unmatched tracks; drop the ones past max_age.
        for tid in list(self.tracks.keys()):
            if tid in used_tracks:
                continue
            self.tracks[tid]['missed_frames'] += 1
            if self.tracks[tid]['missed_frames'] > self.max_age:
                del self.tracks[tid]

    def _calculate_iou(self, boxA, boxB):
        """Standard intersection-over-union of two xyxy boxes."""
        ix1 = max(boxA[0], boxB[0])
        iy1 = max(boxA[1], boxB[1])
        ix2 = min(boxA[2], boxB[2])
        iy2 = min(boxA[3], boxB[3])
        inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
        area_a = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
        area_b = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
        return inter / float(area_a + area_b - inter + 1e-6)
383
 
384
 
385
  class SpeedEstimator:
@@ -1186,7 +1078,7 @@ def run_inference(
1186
  buffer = {}
1187
 
1188
  # Initialize Tracker & Speed Estimator
1189
- tracker = SimpleTracker()
1190
  speed_est = SpeedEstimator(fps=fps)
1191
 
1192
  try:
@@ -1259,7 +1151,8 @@ def run_inference(
1259
 
1260
  # --- SEQUENTIAL TRACKING ---
1261
  # Update tracker with current frame detections
1262
- tracker.update(dets)
 
1263
  speed_est.estimate(dets)
1264
 
1265
  # --- RENDER BOXES & OVERLAYS ---
 
271
  return detections
272
 
273
 
274
+ from utils.tracker import ByteTracker
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
 
276
 
277
  class SpeedEstimator:
 
1078
  buffer = {}
1079
 
1080
  # Initialize Tracker & Speed Estimator
1081
+ tracker = ByteTracker(frame_rate=fps)
1082
  speed_est = SpeedEstimator(fps=fps)
1083
 
1084
  try:
 
1151
 
1152
  # --- SEQUENTIAL TRACKING ---
1153
  # Update tracker with current frame detections
1154
+ # ByteTracker returns the list of ACTIVE tracks with IDs
1155
+ dets = tracker.update(dets)
1156
  speed_est.estimate(dets)
1157
 
1158
  # --- RENDER BOXES & OVERLAYS ---
utils/tracker.py ADDED
@@ -0,0 +1,663 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from scipy.optimize import linear_sum_assignment
3
+ import scipy.linalg
4
+
5
+
6
class KalmanFilter:
    """
    A simple Kalman Filter for tracking bounding boxes in image space.

    The 8-dimensional state space is (x, y, a, h, vx, vy, va, vh), where
    x, y is the center position, a is the aspect ratio (width / height),
    and h is the height. Motion follows a constant-velocity model; the
    box location (x, y, a, h) is observed directly.
    """

    def __init__(self):
        ndim, dt = 4, 1.0

        # Constant-velocity motion model: position' = position + dt * velocity.
        self._motion_mat = np.eye(2 * ndim, 2 * ndim)
        for i in range(ndim):
            self._motion_mat[i, ndim + i] = dt
        # Observation model: only the positional half of the state is measured.
        self._update_mat = np.eye(ndim, 2 * ndim)

        # Motion and observation uncertainty are chosen relative to the current
        # state estimate (scaled by box height). These weights control the
        # amount of uncertainty in the model; the values are heuristic.
        self._std_weight_position = 1.0 / 20
        self._std_weight_velocity = 1.0 / 160

    def initiate(self, measurement):
        """Create track from unassociated measurement.

        Parameters
        ----------
        measurement : array_like
            Bounding box coordinates (x1, y1, x2, y2).

        Returns
        -------
        (mean, covariance)
            Returns the mean vector (8 dimensional) and covariance matrix (8x8)
            of the new track. Unobserved velocities start at 0 mean with a
            wide covariance.
        """
        mean_pos = self._xyah_from_xyxy(measurement)
        mean = np.r_[mean_pos, np.zeros_like(mean_pos)]

        # Initial uncertainty scales with box height; the aspect-ratio terms
        # use small fixed stds since a is nearly constant per object.
        std = [
            2 * self._std_weight_position * mean_pos[3],
            2 * self._std_weight_position * mean_pos[3],
            1e-2,
            2 * self._std_weight_position * mean_pos[3],
            10 * self._std_weight_velocity * mean_pos[3],
            10 * self._std_weight_velocity * mean_pos[3],
            1e-5,
            10 * self._std_weight_velocity * mean_pos[3],
        ]
        covariance = np.diag(np.square(std))
        return mean, covariance

    def predict(self, mean, covariance):
        """Run Kalman filter prediction step.

        Parameters
        ----------
        mean : ndarray
            The 8 dimensional mean vector of the object state at the previous
            time step.
        covariance : ndarray
            The 8x8 dimensional covariance matrix of the object state at the
            previous time step.

        Returns
        -------
        (mean, covariance)
            Returns the mean vector and covariance matrix of the predicted
            state.
        """
        std_pos = [
            self._std_weight_position * mean[3],
            self._std_weight_position * mean[3],
            1e-2,
            self._std_weight_position * mean[3],
        ]
        std_vel = [
            self._std_weight_velocity * mean[3],
            self._std_weight_velocity * mean[3],
            1e-5,
            self._std_weight_velocity * mean[3],
        ]

        motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
        mean = np.dot(self._motion_mat, mean)
        covariance = (
            np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T))
            + motion_cov
        )
        return mean, covariance

    def project(self, mean, covariance):
        """Project state distribution to measurement space.

        Parameters
        ----------
        mean : ndarray
            The state's mean vector (8 dimensional).
        covariance : ndarray
            The state's covariance matrix (8x8 dimensional).

        Returns
        -------
        (mean, covariance)
            Returns the projected mean and covariance matrix of the given
            state estimate, with observation noise added.
        """
        std = [
            self._std_weight_position * mean[3],
            self._std_weight_position * mean[3],
            1e-1,
            self._std_weight_position * mean[3],
        ]

        innovation_cov = np.diag(np.square(std))
        mean = np.dot(self._update_mat, mean)
        covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T))
        return mean, covariance + innovation_cov

    def update(self, mean, covariance, measurement):
        """Run Kalman filter correction step.

        Parameters
        ----------
        mean : ndarray
            The predicted state's mean vector (8 dimensional).
        covariance : ndarray
            The state's covariance matrix (8x8 dimensional).
        measurement : ndarray
            The 4 dimensional measurement vector (x, y, a, h), where (x, y)
            is the center position, a the aspect ratio, and h the height.

        Returns
        -------
        (mean, covariance)
            Returns the measurement-corrected state distribution.
        """
        projected_mean, projected_cov = self.project(mean, covariance)
        # Solve for the Kalman gain via Cholesky factorization instead of an
        # explicit matrix inverse (better numerical stability).
        chol_factor, lower = scipy.linalg.cho_factor(
            projected_cov, lower=True, check_finite=False
        )
        kalman_gain = scipy.linalg.cho_solve(
            (chol_factor, lower),
            np.dot(covariance, self._update_mat.T).T,
            check_finite=False,
        ).T
        innovation = measurement - projected_mean
        new_mean = mean + np.dot(innovation, kalman_gain.T)
        new_covariance = covariance - np.linalg.multi_dot(
            (kalman_gain, projected_cov, kalman_gain.T)
        )
        return new_mean, new_covariance

    def gating_distance(self, mean, covariance, measurements, only_position=False, metric="mahalanobis"):
        """Compute gating distance between state distribution and measurements.

        With ``only_position`` the distance uses only (x, y). ``metric`` is
        either "gaussian" (squared Euclidean) or "mahalanobis".
        """
        mean, covariance = self.project(mean, covariance)
        if only_position:
            mean, covariance = mean[:2], covariance[:2, :2]
            measurements = measurements[:, :2]

        d = measurements - mean
        if metric == "gaussian":
            return np.sum(d * d, axis=1)
        elif metric == "mahalanobis":
            cholesky_factor = np.linalg.cholesky(covariance)
            z = scipy.linalg.solve_triangular(
                cholesky_factor, d.T, lower=True, check_finite=False, overwrite_b=True
            )
            squared_maha = np.sum(z * z, axis=0)
            return squared_maha
        else:
            raise ValueError("invalid distance metric")

    def _xyah_from_xyxy(self, xyxy):
        """Convert an (x1, y1, x2, y2) box to `(center x, center y, aspect
        ratio, height)`, where the aspect ratio is `width / height`.

        The result is always float64. BUG FIX: the original allocated the
        output with the INPUT's dtype, so an integer box silently truncated
        the aspect ratio (e.g. 0.5 -> 0), corrupting the filter state.
        """
        bbox = np.asarray(xyxy, dtype=np.float64)
        cx = (bbox[0] + bbox[2]) / 2.0
        cy = (bbox[1] + bbox[3]) / 2.0
        w = bbox[2] - bbox[0]
        h = bbox[3] - bbox[1]
        return np.array([cx, cy, w / h, h], dtype=np.float64)
196
+
197
+
198
class STrack:
    """Single object track: a detection wrapper plus Kalman filter state.

    NOTE: despite the historical parameter name ``tlwh``, the constructor
    takes a box in (x1, y1, x2, y2) format — that is what the detection
    pipeline produces — and converts it to top-left/width/height internally.
    The name is kept for signature compatibility with existing callers.
    """

    # Global track-id counter shared by all instances (see next_id()).
    _count = 0

    def __init__(self, tlwh, score, label):
        # `tlwh` is actually an xyxy box; stored internally as tlwh.
        self._tlwh = np.asarray(self._tlwh_from_xyxy(tlwh), dtype=np.float32)
        self.is_activated = False
        self.track_id = 0
        self.state = 1  # 1: New, 2: Tracked, 3: Lost, 4: Removed

        self.score = score
        self.label = label
        self.start_frame = 0
        self.frame_id = 0
        self.time_since_update = 0

        # Multi-frame bbox history (consumed by downstream estimators).
        self.history = []

        # Kalman Filter state (set on activate()).
        self.kalman_filter = None
        self.mean = None
        self.covariance = None

        # GPT attributes carried across frames (persistent per track).
        self.gpt_data = {}

    def _tlwh_from_xyxy(self, xyxy):
        """Convert xyxy to tlwh (top-left x, top-left y, width, height)."""
        w = xyxy[2] - xyxy[0]
        h = xyxy[3] - xyxy[1]
        return [xyxy[0], xyxy[1], w, h]

    def _xyxy_from_tlwh(self, tlwh):
        """Convert tlwh to xyxy."""
        x1 = tlwh[0]
        y1 = tlwh[1]
        x2 = x1 + tlwh[2]
        y2 = y1 + tlwh[3]
        return [x1, y1, x2, y2]

    @property
    def tlwh(self):
        """Current position as `(top left x, top left y, width, height)`.

        Uses the Kalman state (x, y, a, h) when available, otherwise the
        raw detection box.
        """
        if self.mean is None:
            return self._tlwh.copy()
        ret = self.mean[:4].copy()
        ret[2] *= ret[3]        # a * h -> width
        ret[:2] -= ret[2:] / 2  # center -> top-left
        return ret

    @property
    def tlbr(self):
        """Current position as `(min x, min y, max x, max y)`."""
        ret = self.tlwh.copy()
        ret[2:] += ret[:2]
        return ret

    def activate(self, kalman_filter, frame_id):
        """Start a new track: assign an id and initialize the KF state."""
        self.kalman_filter = kalman_filter
        self.track_id = self.next_id()
        # initiate() expects an xyxy box.
        self.mean, self.covariance = self.kalman_filter.initiate(self.tlbr)

        self.state = 2  # Tracked
        self.frame_id = frame_id
        self.start_frame = frame_id
        self.is_activated = True

    def re_activate(self, new_track, frame_id, new_id=False):
        """Reactivate a lost track with a new detection."""
        self.mean, self.covariance = self.kalman_filter.update(
            self.mean, self.covariance, self._xyah_from_xyxy(new_track.tlbr)
        )
        self.time_since_update = 0
        self.state = 2  # Tracked
        # Reference ByteTrack marks reactivated tracks as activated so they
        # are emitted; added for consistency (a previously activated track
        # already has is_activated True, so this is a safety net).
        self.is_activated = True
        self.frame_id = frame_id
        self.score = new_track.score

        if new_id:
            self.track_id = self.next_id()

    def update(self, new_track, frame_id):
        """Update a tracked object with a new matched detection."""
        self.frame_id = frame_id
        self.time_since_update = 0
        self.score = new_track.score

        self.mean, self.covariance = self.kalman_filter.update(
            self.mean, self.covariance, self._xyah_from_xyxy(new_track.tlbr)
        )
        self.state = 2  # Tracked
        self.is_activated = True

    def predict(self):
        """Propagate the state distribution one time step forward.

        No-op for tracks that were never activated (no KF state yet).
        """
        if self.mean is None:
            return
        self.mean, self.covariance = self.kalman_filter.predict(self.mean, self.covariance)

    def _xyah_from_xyxy(self, xyxy):
        """Convert an xyxy box to the (cx, cy, aspect, height) measurement.

        Always returns float64. BUG FIX: the original used the input's
        dtype, truncating the aspect ratio to 0 for integer boxes.
        """
        bbox = np.asarray(xyxy, dtype=np.float64)
        cx = (bbox[0] + bbox[2]) / 2.0
        cy = (bbox[1] + bbox[3]) / 2.0
        w = bbox[2] - bbox[0]
        h = bbox[3] - bbox[1]
        return np.array([cx, cy, w / h, h], dtype=np.float64)

    @staticmethod
    def next_id():
        """Return the next globally unique track id (1-based)."""
        STrack._count += 1
        return STrack._count
331
+
332
+
333
class ByteTracker:
    """BYTE association tracker (ByteTrack).

    High-confidence detections are matched to live tracks first; remaining
    tracks then get a second chance against low-confidence detections
    before being marked lost. Unmatched high-confidence detections start
    new tracks; lost tracks are dropped after ``track_buffer`` frames.
    """

    def __init__(self, track_thresh=0.5, track_buffer=30, match_thresh=0.8, frame_rate=30):
        # Detections with score >= track_thresh are "high confidence".
        self.track_thresh = track_thresh
        # Frames a lost track is retained before removal.
        # NOTE(review): frame_rate is currently unused — the reference
        # implementation scales track_buffer by fps; confirm if intended.
        self.track_buffer = track_buffer
        self.match_thresh = match_thresh
        self.frame_id = 0

        self.tracked_stracks = []  # List[STrack]: currently tracked
        self.lost_stracks = []     # List[STrack]: recently lost
        # NOTE(review): removed_stracks grows without bound over long runs.
        self.removed_stracks = []  # List[STrack]: permanently removed

        self.kalman_filter = KalmanFilter()

    def update(self, detections_list):
        """Advance the tracker one frame.

        Args:
            detections_list: list of dicts, each with:
                - bbox: [x1, y1, x2, y2]
                - score: float
                - label: str
                - (optional) other keys, preserved in the output

        Returns:
            List of detection dicts for the currently ACTIVE tracks, each
            with 'track_id', a KF-smoothed 'bbox', 'history', and any
            persisted GPT attributes forward-filled.
        """
        self.frame_id += 1

        activated_stracks = []
        refind_stracks = []
        lost_stracks = []

        # 0. Wrap raw detection dicts in STracks and split by confidence.
        detections = []         # high confidence
        detections_second = []  # low confidence
        for d in detections_list:
            t = STrack(d['bbox'], d['score'], d['label'])
            t.original_data = d  # link back to the caller's dict
            if d['score'] >= self.track_thresh:
                detections.append(t)
            else:
                detections_second.append(t)

        # 1. Separate confirmed/unconfirmed tracks and predict with the KF.
        unconfirmed = []
        tracked_stracks = []
        for track in self.tracked_stracks:
            if not track.is_activated:
                unconfirmed.append(track)
            else:
                tracked_stracks.append(track)

        strack_pool = join_stracks(tracked_stracks, self.lost_stracks)
        STrack.multi_predict(strack_pool, self.kalman_filter)

        # 2. First association: high-score detections vs all live tracks,
        # with IoU cost fused with detection confidence.
        dists = iou_distance(strack_pool, detections)
        dists = fuse_score(dists, detections)
        matches, u_track, u_detection = linear_assignment(dists, thresh=self.match_thresh)

        for itracked, idet in matches:
            track = strack_pool[itracked]
            det = detections[idet]
            if track.state == 2:
                track.update(det, self.frame_id)
                activated_stracks.append(track)
            else:
                track.re_activate(det, self.frame_id, new_id=False)
                refind_stracks.append(track)
            self._sync_data(track, det)

        # 3. Second association: leftover TRACKED tracks vs low-score dets.
        r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == 2]
        dists = iou_distance(r_tracked_stracks, detections_second)
        matches, u_track, _ = linear_assignment(dists, thresh=0.5)

        for itracked, idet in matches:
            track = r_tracked_stracks[itracked]
            det = detections_second[idet]
            if track.state == 2:
                track.update(det, self.frame_id)
                activated_stracks.append(track)
            else:
                track.re_activate(det, self.frame_id, new_id=False)
                refind_stracks.append(track)
            self._sync_data(track, det)

        # Tracks that survived neither association become lost.
        for it in u_track:
            track = r_tracked_stracks[it]
            if track.state != 3:
                track.state = 3  # Lost
                lost_stracks.append(track)

        # 4. Start new tracks from unmatched high-score detections.
        # (Unmatched low-score detections are discarded as noise.)
        # BUG FIX: the original re-indexed `detections` to the unmatched
        # subset and then indexed the shrunken list with the ORIGINAL
        # u_detection indices — an IndexError / wrong-detection bug. Index
        # the original high-score list directly.
        for inew in u_detection:
            track = detections[inew]
            if track.score < self.track_thresh:
                continue
            track.activate(self.kalman_filter, self.frame_id)
            activated_stracks.append(track)
            self._sync_data(track, track)  # seed GPT data from itself

        # 5. Merge bookkeeping lists.
        self.tracked_stracks = [t for t in self.tracked_stracks if t.state == 2]
        self.tracked_stracks = join_stracks(self.tracked_stracks, activated_stracks)
        self.tracked_stracks = join_stracks(self.tracked_stracks, refind_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
        self.lost_stracks.extend(lost_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
        self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(
            self.tracked_stracks, self.lost_stracks
        )

        # 6. Age out lost tracks past the buffer window.
        for track in self.lost_stracks:
            if self.frame_id - track.frame_id > self.track_buffer:
                self.removed_stracks.append(track)
        self.lost_stracks = [
            t for t in self.lost_stracks
            if self.frame_id - t.frame_id <= self.track_buffer
        ]

        # 7. Build output dicts for all active tracks. The KF-smoothed box
        # (tlbr) replaces the raw detection box.
        results = []
        output_stracks = [t for t in self.tracked_stracks if t.is_activated]
        for track in output_stracks:
            d_out = track.original_data.copy() if hasattr(track, 'original_data') else {}

            d_out['bbox'] = [float(x) for x in track.tlbr]
            d_out['track_id'] = f"T{str(track.track_id).zfill(2)}"

            # Forward-fill GPT attributes persisted on the track.
            for k, v in track.gpt_data.items():
                if k not in d_out:
                    d_out[k] = v

            # Maintain a capped bbox history (consumed by SpeedEstimator).
            history = track.gpt_data.setdefault('history', [])
            history.append(d_out['bbox'])
            if len(history) > 30:
                history.pop(0)
            d_out['history'] = history

            results.append(d_out)

        return results

    def _sync_data(self, track, det_source):
        """Copy persistent GPT attributes from a detection onto its track.

        Track -> detection forward-fill happens in the output construction.
        """
        source_data = getattr(det_source, 'original_data', {})
        for k in ['gpt_distance_m', 'gpt_direction', 'gpt_description']:
            if k in source_data:
                track.gpt_data[k] = source_data[k]
529
+
530
+
531
+ # --- Helper Functions ---
532
+
533
def linear_assignment(cost_matrix, thresh):
    """Optimal assignment with a cost threshold (Hungarian algorithm).

    Args:
        cost_matrix: 2D ndarray of costs (rows = tracks, cols = detections).
        thresh: assignments with cost above this are rejected.

    Returns:
        (matches, unmatched_rows, unmatched_cols) where matches is an
        (N, 2) int array of (row, col) pairs with cost <= thresh, and the
        unmatched index lists cover every row/col not in matches.

    The original interleaved two threshold-filtering loops with ad-hoc
    de-duplication and returned tuples on one path and lists on the other;
    this version computes each set exactly once and always returns lists.
    """
    if cost_matrix.size == 0:
        return (
            np.empty((0, 2), dtype=int),
            list(range(cost_matrix.shape[0])),
            list(range(cost_matrix.shape[1])),
        )

    # scipy finds the min-cost matching; rectangular matrices leave the
    # surplus rows/cols unmatched.
    row_ind, col_ind = linear_sum_assignment(cost_matrix)

    matches = []
    unmatched_a = set(range(cost_matrix.shape[0])) - set(row_ind)
    unmatched_b = set(range(cost_matrix.shape[1])) - set(col_ind)

    # Reject assignments whose cost exceeds the threshold.
    for r, c in zip(row_ind, col_ind):
        if cost_matrix[r, c] <= thresh:
            matches.append((r, c))
        else:
            unmatched_a.add(r)
            unmatched_b.add(c)

    matches = np.array(matches, dtype=int) if matches else np.empty((0, 2), dtype=int)
    return matches, sorted(unmatched_a), sorted(unmatched_b)
569
+
570
+
571
def iou_distance(atracks, btracks):
    """Cost matrix (1 - IoU) between two lists of tracks/detections.

    Each element must expose a ``tlbr`` property. Returns a matrix of
    shape (len(atracks), len(btracks)); empty inputs yield an empty
    zero matrix without touching the elements.

    (The original guard `(a==0 and b==0) or a==0 or b==0` reduces to
    `a==0 or b==0`.)
    """
    if len(atracks) == 0 or len(btracks) == 0:
        return np.zeros((len(atracks), len(btracks)), dtype=float)

    atlbrs = [track.tlbr for track in atracks]
    btlbrs = [track.tlbr for track in btracks]
    return 1 - bbox_ious(np.array(atlbrs), np.array(btlbrs))
582
+
583
def bbox_ious(boxes1, boxes2):
    """Pairwise IoU matrix for two arrays of (x1, y1, x2, y2) boxes.

    Returns an array of shape (len(boxes1), len(boxes2)); a small epsilon
    in the denominator guards against zero-area unions.
    """
    ax1, ay1, ax2, ay2 = boxes1[:, 0], boxes1[:, 1], boxes1[:, 2], boxes1[:, 3]
    bx1, by1, bx2, by2 = boxes2[:, 0], boxes2[:, 1], boxes2[:, 2], boxes2[:, 3]

    # Broadcast row boxes against column boxes to get the intersections.
    ix1 = np.maximum(ax1[:, None], bx1)
    iy1 = np.maximum(ay1[:, None], by1)
    ix2 = np.minimum(ax2[:, None], bx2)
    iy2 = np.minimum(ay2[:, None], by2)

    inter = np.clip(ix2 - ix1, 0, None) * np.clip(iy2 - iy1, 0, None)

    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)

    return inter / (area_a[:, None] + area_b - inter + 1e-6)
600
+
601
+
602
def fuse_score(cost_matrix, detections):
    """Blend IoU similarity with detection confidence.

    Converts the IoU cost back to similarity, scales each column by that
    detection's score, and returns the result as a cost again — so
    low-confidence detections become more expensive to match.
    """
    if cost_matrix.size == 0:
        return cost_matrix
    similarity = 1 - cost_matrix
    scores = np.asarray([det.score for det in detections])
    # Replicate the score row for every track (one row per cost row).
    scores = np.tile(scores, (cost_matrix.shape[0], 1))
    return 1 - similarity * scores
611
+
612
+
613
+ # STrack collection helpers
614
+
615
def join_stracks(tlist_a, tlist_b):
    """Union of two track lists, keeping the first occurrence per track_id.

    Everything in ``tlist_a`` is kept; items from ``tlist_b`` are appended
    only when their track_id has not been seen yet.
    """
    seen = set()
    merged = []
    for trk in tlist_a:
        seen.add(trk.track_id)
        merged.append(trk)
    for trk in tlist_b:
        if trk.track_id not in seen:
            seen.add(trk.track_id)
            merged.append(trk)
    return merged
627
+
628
def sub_stracks(tlist_a, tlist_b):
    """Tracks from ``tlist_a`` whose track_id does not appear in ``tlist_b``.

    Preserves the (insertion) order of ``tlist_a``; duplicate ids in
    ``tlist_a`` collapse to the last occurrence, matching the original
    dict-based implementation.
    """
    remaining = {trk.track_id: trk for trk in tlist_a}
    for trk in tlist_b:
        remaining.pop(trk.track_id, None)
    return list(remaining.values())
637
+
638
def remove_duplicate_stracks(stracksa, stracksb):
    """Resolve tracks that overlap heavily across the two lists.

    For each pair with IoU > 0.85 (cost < 0.15), the shorter-lived track is
    dropped and the longer-lived one kept, matching the reference ByteTrack
    behavior.

    BUG FIX: the original seeded ``dupa``/``dupb`` with ALL paired indices
    — so both members of every pair were removed — and then appended to the
    very lists it was zipping over (mutation during iteration). Removal
    sets are now built cleanly, keeping one track per pair.
    """
    pdist = iou_distance(stracksa, stracksb)
    pairs = np.where(pdist < 0.15)

    dupa, dupb = set(), set()
    for a, b in zip(*pairs):
        time_a = stracksa[a].frame_id - stracksa[a].start_frame
        time_b = stracksb[b].frame_id - stracksb[b].start_frame
        if time_a > time_b:
            dupb.add(b)  # a has lived longer: drop the b duplicate
        else:
            dupa.add(a)  # otherwise drop the a duplicate

    res_a = [t for i, t in enumerate(stracksa) if i not in dupa]
    res_b = [t for i, t in enumerate(stracksb) if i not in dupb]
    return res_a, res_b
654
+
655
+
656
# Attached to STrack below because the helper needs module scope here.
def multi_predict(stracks, kalman_filter):
    """Run the KF prediction step for every track in ``stracks``.

    Tracks that were never activated (no KF state) are skipped — the
    original would have raised on ``t.mean[7]``. For tracks that are not
    actively tracked (state != 2) the height-velocity component is zeroed
    before predicting, as in reference ByteTrack.
    """
    for t in stracks:
        if t.mean is None:
            continue  # never activated; nothing to predict
        if t.state != 2:
            t.mean[7] = 0  # damp height velocity while lost
        t.mean, t.covariance = kalman_filter.predict(t.mean, t.covariance)


# staticmethod so STrack.multi_predict(pool, kf) works without binding self.
# (The original chained assignment also created a stray
# `static_method_multi_predict` module global; removed.)
STrack.multi_predict = staticmethod(multi_predict)