import logging

import numpy as np
from filterpy.kalman import KalmanFilter
from scipy.optimize import linear_sum_assignment

logger = logging.getLogger(__name__)


class BYTETracker:
    """IoU/Kalman multi-object tracker with position-based re-identification.

    Each live track owns a constant-velocity Kalman filter over its box
    centre. On every ``update`` call:

    1. stale tracks are retired into ``recently_removed``;
    2. filters are propagated one step;
    3. detections are matched to live tracks with the Hungarian algorithm on
       ``1 - IoU``;
    4. each unmatched detection whose score reaches ``track_thresh`` either
       revives a recently retired track (see ``_is_same_worker``) or opens a
       new one.

    Boxes are interpreted as ``(cx, cy, w, h)`` (centre plus size). This is
    how ``_calculate_iou`` reads them and why ``bbox[:2]`` is treated as the
    position.
    """

    def __init__(self, track_thresh=0.3, track_buffer=90, match_thresh=0.5,
                 frame_rate=30, reid_buffer=None):
        """Configure the tracker.

        Args:
            track_thresh: Minimum score an unmatched detection needs to start
                or revive a track.
            track_buffer: Frames a track may go unmatched before it is
                retired. Converted to seconds as ``track_buffer / frame_rate``.
            match_thresh: A detection/track assignment is accepted only when
                their IoU is strictly greater than this value.
            frame_rate: Frames per second. Used to convert the buffers into
                seconds of ``current_time``.
            reid_buffer: Frames a retired track stays eligible for
                re-identification before it is forgotten. ``None`` (the
                default) reuses ``track_buffer``.
        """
        self.track_thresh = track_thresh
        self.track_buffer = track_buffer
        self.match_thresh = match_thresh
        self.frame_rate = frame_rate
        self.reid_buffer = track_buffer if reid_buffer is None else reid_buffer
        self.next_id = 1
        # track_id -> {'bbox', 'score', 'cls', 'last_seen'} for live tracks.
        self.tracks = {}
        # track_id -> KalmanFilter, kept in sync with self.tracks.
        self.kalman_filters = {}
        # track_id -> list of measured (x, y) positions, one per assigned
        # detection, for the lifetime of the track.
        self.worker_history = {}
        # track_id -> most recent measured (x, y) position.
        self.last_positions = {}
        # track_id -> {'bbox', 'last_seen', 'last_position'} for retired
        # tracks still eligible for re-identification.
        self.recently_removed = {}
        logger.info("BYTETracker initialized with track_thresh=%.2f, track_buffer=%d, match_thresh=%.2f",
                    track_thresh, track_buffer, match_thresh)

    def _init_kalman(self):
        """Return a fresh constant-velocity Kalman filter.

        The state is ``[x, y, vx, vy]`` and only ``(x, y)`` is measured. The
        time step is fixed at one ``update`` call, regardless of how much
        ``current_time`` advanced.
        """
        kf = KalmanFilter(dim_x=4, dim_z=2)
        kf.F = np.array([[1, 0, 1, 0],
                         [0, 1, 0, 1],
                         [0, 0, 1, 0],
                         [0, 0, 0, 1]])
        kf.H = np.array([[1, 0, 0, 0],
                         [0, 1, 0, 0]])
        kf.P *= 1000.
        kf.R = np.array([[10, 0], [0, 10]])
        kf.Q = np.eye(4)
        return kf

    def _new_kalman(self, position):
        """Return a fresh filter whose position state is set to ``position``."""
        kf = self._init_kalman()
        kf.x[:2] = np.asarray(position, dtype=float).reshape(2, 1)
        return kf

    def update(self, dets, scores, cls, current_time):
        """Advance the tracker by one frame.

        Args:
            dets: Sequence of boxes ``(cx, cy, w, h)``. Each row is copied to
                a private float array, so the caller's data is never mutated
                by later predictions.
            scores: Per-detection confidence, aligned with ``dets``.
            cls: Per-detection class label, aligned with ``dets``.
            current_time: Timestamp in seconds. It is compared against
                ``track_buffer / frame_rate`` and
                ``reid_buffer / frame_rate``.

        Returns:
            A list of ``{'id', 'bbox', 'score', 'cls'}`` dicts, one per
            detection assigned to a track this frame (matched, revived or
            newly created). The list is empty when ``dets`` is empty.
        """
        self._prune_stale(current_time)
        self._predict()

        if len(dets) == 0:
            return []

        # Private copies. In the original code, tracks, history and
        # last_positions aliased views of the caller's array, and the
        # in-place predict step corrupted all of them.
        dets = [np.array(det, dtype=float) for det in dets]

        tracks, matched = self._associate(dets, scores, cls, current_time)

        for i, (det, score, cl) in enumerate(zip(dets, scores, cls)):
            if i in matched or score < self.track_thresh:
                continue
            tracks.append(self._start_or_revive(det, score, cl, current_time))
        return tracks

    def _prune_stale(self, current_time):
        """Retire tracks unseen for too long and forget expired retired ones."""
        max_age = self.track_buffer / self.frame_rate
        for track_id in list(self.tracks):
            if current_time - self.tracks[track_id]['last_seen'] <= max_age:
                continue
            self.recently_removed[track_id] = {
                'bbox': self.tracks[track_id]['bbox'],
                # This holds the removal time, not the last detection time.
                # The key name is kept for compatibility, and the value
                # drives re-identification expiry below.
                'last_seen': current_time,
                'last_position': self.last_positions.get(track_id, [0, 0]),
            }
            del self.tracks[track_id]
            del self.kalman_filters[track_id]
            self.worker_history.pop(track_id, None)
            self.last_positions.pop(track_id, None)

        # Bound recently_removed. Without this it grows with every track
        # ever seen, and every unmatched detection scans all of it.
        reid_age = self.reid_buffer / self.frame_rate
        expired = [track_id for track_id, info in self.recently_removed.items()
                   if current_time - info['last_seen'] > reid_age]
        for track_id in expired:
            del self.recently_removed[track_id]

    def _predict(self):
        """Step every live filter and move its track's box centre to the prediction."""
        for track_id, kf in self.kalman_filters.items():
            kf.predict()
            # Rebind rather than write in place. The previous bbox object may
            # already have been handed back to the caller in an earlier
            # update() result.
            bbox = np.array(self.tracks[track_id]['bbox'], dtype=float)
            bbox[:2] = kf.x[:2].flatten()
            self.tracks[track_id]['bbox'] = bbox

    def _associate(self, dets, scores, cls, current_time):
        """Hungarian-match detections to live tracks on ``1 - IoU``.

        Returns:
            ``(tracks, matched)``: output dicts for accepted matches, and the
            set of detection indices that were matched.
        """
        tracks = []
        matched = set()
        track_ids = list(self.tracks)
        if not track_ids:
            return tracks, matched

        cost_matrix = np.array([
            [1 - self._calculate_iou(det, self.tracks[track_id]['bbox'])
             for track_id in track_ids]
            for det in dets
        ])
        row_indices, col_indices = linear_sum_assignment(cost_matrix)

        for i, j in zip(row_indices, col_indices):
            # Accept only IoU > match_thresh (i.e. cost < 1 - match_thresh).
            if cost_matrix[i, j] >= 1 - self.match_thresh:
                continue
            track_id = track_ids[j]
            det = dets[i]
            self.tracks[track_id].update({
                'bbox': det,
                'score': scores[i],
                'cls': cls[i],
                'last_seen': current_time,
            })
            self.kalman_filters[track_id].update(det[:2])
            self.worker_history[track_id].append(det[:2].copy())
            self.last_positions[track_id] = det[:2].copy()
            tracks.append({'id': track_id, 'bbox': det, 'score': scores[i], 'cls': cls[i]})
            matched.add(i)
        return tracks, matched

    def _start_or_revive(self, det, score, cl, current_time):
        """Revive the nearest eligible retired track for ``det``, else open a new one.

        Returns:
            The output dict for the track that now owns ``det``.
        """
        track_id = self._find_reid_candidate(det[:2])
        if track_id is not None:
            del self.recently_removed[track_id]
        else:
            track_id = self.next_id
            self.next_id += 1

        self.tracks[track_id] = {
            'bbox': det,
            'score': score,
            'cls': cl,
            'last_seen': current_time,
        }
        self.kalman_filters[track_id] = self._new_kalman(det[:2])
        self.worker_history[track_id] = [det[:2].copy()]
        self.last_positions[track_id] = det[:2].copy()
        return {'id': track_id, 'bbox': det, 'score': score, 'cls': cl}

    def _find_reid_candidate(self, position):
        """Return the closest retired track id accepted by ``_is_same_worker``, or None.

        The original code took the first acceptable candidate in insertion
        order. Picking the nearest one avoids handing a returning worker the
        id of a different, farther-away track.
        """
        point = np.asarray(position, dtype=float)
        best_id, best_dist = None, np.inf
        for track_id, info in self.recently_removed.items():
            if not self._is_same_worker(position, info['last_position']):
                continue
            dist = np.linalg.norm(point - np.asarray(info['last_position'], dtype=float))
            if dist < best_dist:
                best_id, best_dist = track_id, dist
        return best_id

    def _calculate_iou(self, box1, box2):
        """Return the IoU of two centre-format boxes ``(cx, cy, w, h)``.

        Returns 0.0 when the boxes do not overlap or when the union has no
        area. The zero-union guard matters because degenerate boxes would
        otherwise produce a ZeroDivisionError, or a NaN cost that
        ``linear_sum_assignment`` rejects.
        """
        x1, y1, w1, h1 = box1
        x2, y2, w2, h2 = box2
        x_left = max(x1 - w1 / 2, x2 - w2 / 2)
        y_top = max(y1 - h1 / 2, y2 - h2 / 2)
        x_right = min(x1 + w1 / 2, x2 + w2 / 2)
        y_bottom = min(y1 + h1 / 2, y2 + h2 / 2)
        if x_right < x_left or y_bottom < y_top:
            return 0.0
        intersection = (x_right - x_left) * (y_bottom - y_top)
        union = w1 * h1 + w2 * h2 - intersection
        if union <= 0:
            return 0.0
        return intersection / union

    def _is_same_worker(self, pos1, pos2, threshold=150):
        """Return True when two (x, y) positions are closer than ``threshold`` (Euclidean)."""
        return np.sqrt(np.sum((np.array(pos1) - np.array(pos2)) ** 2)) < threshold