Zhen Ye commited on
Commit
7e7e04e
·
1 Parent(s): e7dfc36

Refactor ByteTracker: add min_hits confirmation, split thresholds, associate unconfirmed tracks

Browse files
Files changed (1) hide show
  1. utils/tracker.py +62 -24
utils/tracker.py CHANGED
@@ -223,17 +223,18 @@ class STrack:
223
  # wait, input is xyxy usually in our pipeline
224
  # ByteTrack usually uses tlwh internally.
225
  # Let's standardize to input xyxy.
226
-
227
  self._tlwh = np.asarray(self._tlwh_from_xyxy(tlwh), dtype=np.float32)
228
  self.is_activated = False
229
  self.track_id = 0
230
  self.state = 1 # 1: New, 2: Tracked, 3: Lost, 4: Removed
231
-
232
  self.score = score
233
  self.label = label
234
  self.start_frame = 0
235
  self.frame_id = 0
236
  self.time_since_update = 0
 
237
 
238
  # Multi-frame history
239
  self.history = []
@@ -282,15 +283,16 @@ class STrack:
282
  return ret
283
 
284
  def activate(self, kalman_filter, frame_id):
285
- """Start a new track."""
286
  self.kalman_filter = kalman_filter
287
  self.track_id = self.next_id()
288
  self.mean, self.covariance = self.kalman_filter.initiate(self.tlbr) # Initiate needs xyxy
289
-
290
  self.state = 2 # Tracked
291
  self.frame_id = frame_id
292
  self.start_frame = frame_id
293
- self.is_activated = True
 
294
 
295
  def re_activate(self, new_track, frame_id, new_id=False):
296
  """Reactivate a lost track with a new detection."""
@@ -305,17 +307,19 @@ class STrack:
305
  if new_id:
306
  self.track_id = self.next_id()
307
 
308
- def update(self, new_track, frame_id):
309
  """Update a tracked object with a new detection."""
310
  self.frame_id = frame_id
311
  self.time_since_update = 0
312
  self.score = new_track.score
313
-
 
314
  self.mean, self.covariance = self.kalman_filter.update(
315
  self.mean, self.covariance, self._xyah_from_xyxy(new_track.tlbr)
316
  )
317
  self.state = 2 # Tracked
318
- self.is_activated = True
 
319
 
320
  def predict(self):
321
  """Propagate tracking state distribution one time step forward."""
@@ -355,13 +359,18 @@ class STrack:
355
 
356
 
357
  class ByteTracker:
358
- def __init__(self, track_thresh=0.3, track_buffer=60, match_thresh=0.8, frame_rate=30):
 
 
359
  STrack.reset_count()
360
- self.track_thresh = track_thresh
 
 
361
  self.track_buffer = track_buffer
362
  self.match_thresh = match_thresh
 
363
  self.frame_id = 0
364
-
365
  self.tracked_stracks = [] # Type: List[STrack]
366
  self.lost_stracks = [] # Type: List[STrack]
367
  self.removed_stracks = [] # Type: List[STrack]
@@ -401,14 +410,14 @@ class ByteTracker:
401
  # We wrap original dict in STrack
402
 
403
  for d in detections_list:
404
- bbox = d['bbox']
405
  score = d['score']
406
- label = d['label']
407
-
408
- t = STrack(bbox, score, label)
409
- t.original_data = d # Link back
410
-
411
- if score >= self.track_thresh:
 
412
  detections.append(t)
413
  else:
414
  detections_second.append(t)
@@ -435,12 +444,12 @@ class ByteTracker:
435
  track = strack_pool[itracked]
436
  det = detections[idet]
437
  if track.state == 2:
438
- track.update(det, self.frame_id)
439
  activated_stracks.append(track)
440
  else:
441
  track.re_activate(det, self.frame_id, new_id=False)
442
  refind_stracks.append(track)
443
-
444
  # Persist data
445
  self._sync_data(track, det)
446
 
@@ -454,12 +463,12 @@ class ByteTracker:
454
  track = r_tracked_stracks[itracked]
455
  det = detections_second[idet]
456
  if track.state == 2:
457
- track.update(det, self.frame_id)
458
  activated_stracks.append(track)
459
  else:
460
  track.re_activate(det, self.frame_id, new_id=False)
461
  refind_stracks.append(track)
462
-
463
  self._sync_data(track, det)
464
 
465
  for it in u_track:
@@ -468,12 +477,41 @@ class ByteTracker:
468
  track.state = 3 # Lost
469
  lost_stracks.append(track)
470
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
471
  # 4. Init new tracks from unmatched high score detections
472
  # Note: Unmatched low score detections are ignored (noise)
473
  unmatched_dets = [detections[i] for i in u_detection]
474
  for track in unmatched_dets:
475
- if track.score < self.track_thresh:
476
- continue
477
 
478
  track.activate(self.kalman_filter, self.frame_id)
479
  activated_stracks.append(track)
 
223
  # wait, input is xyxy usually in our pipeline
224
  # ByteTrack usually uses tlwh internally.
225
  # Let's standardize to input xyxy.
226
+
227
  self._tlwh = np.asarray(self._tlwh_from_xyxy(tlwh), dtype=np.float32)
228
  self.is_activated = False
229
  self.track_id = 0
230
  self.state = 1 # 1: New, 2: Tracked, 3: Lost, 4: Removed
231
+
232
  self.score = score
233
  self.label = label
234
  self.start_frame = 0
235
  self.frame_id = 0
236
  self.time_since_update = 0
237
+ self.hits = 0
238
 
239
  # Multi-frame history
240
  self.history = []
 
283
  return ret
284
 
285
  def activate(self, kalman_filter, frame_id):
286
+ """Start a new track (tentative until min_hits reached)."""
287
  self.kalman_filter = kalman_filter
288
  self.track_id = self.next_id()
289
  self.mean, self.covariance = self.kalman_filter.initiate(self.tlbr) # Initiate needs xyxy
290
+
291
  self.state = 2 # Tracked
292
  self.frame_id = frame_id
293
  self.start_frame = frame_id
294
+ self.is_activated = False # Tentative until min_hits reached
295
+ self.hits = 1
296
 
297
  def re_activate(self, new_track, frame_id, new_id=False):
298
  """Reactivate a lost track with a new detection."""
 
307
  if new_id:
308
  self.track_id = self.next_id()
309
 
310
+ def update(self, new_track, frame_id, min_hits=3):
311
  """Update a tracked object with a new detection."""
312
  self.frame_id = frame_id
313
  self.time_since_update = 0
314
  self.score = new_track.score
315
+ self.hits += 1
316
+
317
  self.mean, self.covariance = self.kalman_filter.update(
318
  self.mean, self.covariance, self._xyah_from_xyxy(new_track.tlbr)
319
  )
320
  self.state = 2 # Tracked
321
+ if self.hits >= min_hits:
322
+ self.is_activated = True # Confirmed after N consecutive hits
323
 
324
  def predict(self):
325
  """Propagate tracking state distribution one time step forward."""
 
359
 
360
 
361
  class ByteTracker:
362
+ def __init__(self, track_high_thresh=0.3, track_low_thresh=0.1,
363
+ new_track_thresh=0.4, track_buffer=60, match_thresh=0.8,
364
+ frame_rate=30, min_hits=3):
365
  STrack.reset_count()
366
+ self.track_high_thresh = track_high_thresh
367
+ self.track_low_thresh = track_low_thresh
368
+ self.new_track_thresh = new_track_thresh
369
  self.track_buffer = track_buffer
370
  self.match_thresh = match_thresh
371
+ self.min_hits = min_hits
372
  self.frame_id = 0
373
+
374
  self.tracked_stracks = [] # Type: List[STrack]
375
  self.lost_stracks = [] # Type: List[STrack]
376
  self.removed_stracks = [] # Type: List[STrack]
 
410
  # We wrap original dict in STrack
411
 
412
  for d in detections_list:
 
413
  score = d['score']
414
+ if score < self.track_low_thresh:
415
+ continue # Background noise — discard entirely
416
+
417
+ t = STrack(d['bbox'], score, d['label'])
418
+ t.original_data = d # Link back
419
+
420
+ if score >= self.track_high_thresh:
421
  detections.append(t)
422
  else:
423
  detections_second.append(t)
 
444
  track = strack_pool[itracked]
445
  det = detections[idet]
446
  if track.state == 2:
447
+ track.update(det, self.frame_id, min_hits=self.min_hits)
448
  activated_stracks.append(track)
449
  else:
450
  track.re_activate(det, self.frame_id, new_id=False)
451
  refind_stracks.append(track)
452
+
453
  # Persist data
454
  self._sync_data(track, det)
455
 
 
463
  track = r_tracked_stracks[itracked]
464
  det = detections_second[idet]
465
  if track.state == 2:
466
+ track.update(det, self.frame_id, min_hits=self.min_hits)
467
  activated_stracks.append(track)
468
  else:
469
  track.re_activate(det, self.frame_id, new_id=False)
470
  refind_stracks.append(track)
471
+
472
  self._sync_data(track, det)
473
 
474
  for it in u_track:
 
477
  track.state = 3 # Lost
478
  lost_stracks.append(track)
479
 
480
+ # 3.5 Associate unconfirmed tracks with remaining unmatched detections
481
+ if unconfirmed and u_detection:
482
+ remaining_dets = [detections[i] for i in u_detection]
483
+ dists = iou_distance(unconfirmed, remaining_dets)
484
+ matches_unc, u_unconfirmed, u_det_remaining = linear_assignment(dists, thresh=0.7)
485
+
486
+ for itracked, idet in matches_unc:
487
+ track = unconfirmed[itracked]
488
+ det = remaining_dets[idet]
489
+ track.update(det, self.frame_id, min_hits=self.min_hits)
490
+ activated_stracks.append(track)
491
+ self._sync_data(track, det)
492
+
493
+ # Update u_detection to only contain indices not matched to unconfirmed
494
+ matched_det_indices = set(u_detection[idet] for _, idet in matches_unc) if len(matches_unc) > 0 else set()
495
+ u_detection = [i for i in u_detection if i not in matched_det_indices]
496
+
497
+ # Unconfirmed tracks that didn't match → remove (too noisy to keep)
498
+ for it in u_unconfirmed:
499
+ track = unconfirmed[it]
500
+ track.state = 4 # Removed
501
+ removed_stracks.append(track)
502
+
503
+ elif unconfirmed:
504
+ # No detections left to match — remove all unconfirmed
505
+ for track in unconfirmed:
506
+ track.state = 4 # Removed
507
+ removed_stracks.append(track)
508
+
509
  # 4. Init new tracks from unmatched high score detections
510
  # Note: Unmatched low score detections are ignored (noise)
511
  unmatched_dets = [detections[i] for i in u_detection]
512
  for track in unmatched_dets:
513
+ if track.score < self.new_track_thresh:
514
+ continue # Not confident enough to start a new track
515
 
516
  track.activate(self.kalman_filter, self.frame_id)
517
  activated_stracks.append(track)