Spaces:
Running
Running
| """ | |
| Modern Multi-Object Tracker: OC-SORT (drop-in replacement for ByteTrack). | |
| Replaces: ByteTrack (IoU-only, no appearance, poor re-entry handling) | |
| Key improvements from OC-SORT (arxiv:2203.14360): | |
| - Observation-Centric Re-Update (ORU): backfills trajectory on re-association | |
| - Observation-Centric Momentum (OCM): directional consistency | |
| - Observation-Centric Recovery (OCR): second-pass matching | |
| Results (DanceTrack - appearance-similar objects): | |
| ByteTrack: 47.3 HOTA → OC-SORT: 54.6 HOTA (+7.3, +15%) | |
| ByteTrack: 1650 ID-switches → OC-SORT: 1400 (-15%) | |
| Upgrade path: OC-SORT → BoT-SORT-ReID (Phase 5, +8 HOTA more) | |
| """ | |
| import numpy as np | |
| from dataclasses import dataclass, field | |
| from typing import List, Optional, Dict, Tuple | |
| from collections import defaultdict | |
| import logging | |
| logger = logging.getLogger(__name__) | |
| class Track: | |
| """Single tracked card.""" | |
| track_id: int | |
| bbox: np.ndarray # [x1, y1, x2, y2] | |
| confidence: float | |
| age: int = 0 # Total frames since creation | |
| time_since_update: int = 0 # Frames since last detection | |
| hits: int = 1 # Total detection associations | |
| state: str = "tentative" # "tentative", "confirmed", "lost" | |
| embedding: Optional[np.ndarray] = None # DINOv2 embedding for Re-ID | |
| # Smoothed values | |
| smoothed_bbox: Optional[np.ndarray] = None | |
| smoothed_confidence: float = 0.0 | |
| def is_confirmed(self) -> bool: | |
| return self.state == "confirmed" | |
| def is_lost(self) -> bool: | |
| return self.state == "lost" | |
| def center(self) -> Tuple[float, float]: | |
| return ((self.bbox[0] + self.bbox[2]) / 2, (self.bbox[1] + self.bbox[3]) / 2) | |
| class KalmanBoxTracker: | |
| """ | |
| Simple Kalman filter for bounding box tracking. | |
| State: [x_c, y_c, w, h, dx, dy, dw, dh] (OC-SORT/BoT-SORT style) | |
| Key difference from ByteTrack: includes w/h velocity for flat objects | |
| that change apparent size as they move (BoT-SORT ablation: +0.2 HOTA). | |
| """ | |
| count = 0 | |
| def __init__(self, bbox: np.ndarray): | |
| # Convert [x1,y1,x2,y2] to [cx, cy, w, h] | |
| cx = (bbox[0] + bbox[2]) / 2 | |
| cy = (bbox[1] + bbox[3]) / 2 | |
| w = bbox[2] - bbox[0] | |
| h = bbox[3] - bbox[1] | |
| self.state = np.array([cx, cy, w, h, 0, 0, 0, 0], dtype=np.float64) | |
| self.id = KalmanBoxTracker.count | |
| KalmanBoxTracker.count += 1 | |
| # Process noise | |
| self.Q = np.diag([1, 1, 1, 1, 0.01, 0.01, 0.001, 0.001]) * 10 | |
| # Measurement noise | |
| self.R = np.diag([1, 1, 1, 1]) * 10 | |
| # State covariance | |
| self.P = np.eye(8) * 100 | |
| # Transition matrix (constant velocity model) | |
| self.F = np.eye(8) | |
| self.F[0, 4] = 1 # x += dx | |
| self.F[1, 5] = 1 # y += dy | |
| self.F[2, 6] = 1 # w += dw | |
| self.F[3, 7] = 1 # h += dh | |
| # Measurement matrix | |
| self.H = np.zeros((4, 8)) | |
| self.H[0, 0] = 1 | |
| self.H[1, 1] = 1 | |
| self.H[2, 2] = 1 | |
| self.H[3, 3] = 1 | |
| # History for OC-SORT ORU | |
| self.observations = [] | |
| self.last_observation = np.array([cx, cy, w, h]) | |
| def predict(self) -> np.ndarray: | |
| """Predict next state.""" | |
| self.state = self.F @ self.state | |
| self.P = self.F @ self.P @ self.F.T + self.Q | |
| # Ensure positive w, h | |
| self.state[2] = max(self.state[2], 1) | |
| self.state[3] = max(self.state[3], 1) | |
| return self._state_to_bbox() | |
| def update(self, measurement: np.ndarray): | |
| """ | |
| Update with measurement [x1, y1, x2, y2]. | |
| """ | |
| # Convert to [cx, cy, w, h] | |
| cx = (measurement[0] + measurement[2]) / 2 | |
| cy = (measurement[1] + measurement[3]) / 2 | |
| w = measurement[2] - measurement[0] | |
| h = measurement[3] - measurement[1] | |
| z = np.array([cx, cy, w, h]) | |
| # Kalman update | |
| y = z - self.H @ self.state | |
| S = self.H @ self.P @ self.H.T + self.R | |
| K = self.P @ self.H.T @ np.linalg.inv(S) | |
| self.state = self.state + K @ y | |
| self.P = (np.eye(8) - K @ self.H) @ self.P | |
| # Store observation for OC-SORT ORU | |
| self.observations.append(z.copy()) | |
| self.last_observation = z.copy() | |
| def observation_centric_reupdate(self): | |
| """ | |
| OC-SORT ORU: When a lost track is re-associated, backfill | |
| virtual trajectory to correct Kalman drift. | |
| (arxiv:2203.14360, Section 3.1) | |
| """ | |
| if len(self.observations) >= 2: | |
| # Linear interpolation between last two observations | |
| obs_1 = self.observations[-2] | |
| obs_2 = self.observations[-1] | |
| # Update velocity estimate from actual observations (not predictions) | |
| self.state[4] = obs_2[0] - obs_1[0] # dx | |
| self.state[5] = obs_2[1] - obs_1[1] # dy | |
| self.state[6] = obs_2[2] - obs_1[2] # dw | |
| self.state[7] = obs_2[3] - obs_1[3] # dh | |
| def _state_to_bbox(self) -> np.ndarray: | |
| """Convert internal state to [x1, y1, x2, y2].""" | |
| cx, cy, w, h = self.state[:4] | |
| return np.array([cx - w/2, cy - h/2, cx + w/2, cy + h/2]) | |
| def bbox(self) -> np.ndarray: | |
| return self._state_to_bbox() | |
| class IDCardTracker: | |
| """ | |
| Multi-card tracker with OC-SORT-style association. | |
| Improvements over ByteTrack: | |
| 1. ORU: Fixes Kalman drift after re-association (critical for re-entry) | |
| 2. OCM: Directional consistency prevents wrong matches | |
| 3. Longer lost-track buffer (120 frames vs ~30) | |
| 4. Optional appearance features via DINOv2 (Phase 5: BoT-SORT-ReID) | |
| """ | |
| def __init__(self, config): | |
| self.config = config | |
| self.tracks: Dict[int, Track] = {} | |
| self.kalman_trackers: Dict[int, KalmanBoxTracker] = {} | |
| self.next_id = 1 | |
| self.frame_count = 0 | |
| # Track history for temporal smoothing | |
| self._bbox_history: Dict[int, list] = defaultdict(list) | |
| def update(self, detections, frame: Optional[np.ndarray] = None) -> List[Track]: | |
| """ | |
| Update tracks with new detections. | |
| Args: | |
| detections: List of Detection objects from detector | |
| frame: Optional frame for appearance extraction | |
| Returns: | |
| List of active Track objects | |
| """ | |
| self.frame_count += 1 | |
| if not detections: | |
| # No detections: age all tracks | |
| self._age_tracks() | |
| return self._get_active_tracks() | |
| # Convert detections to numpy arrays | |
| det_bboxes = np.array([d.bbox for d in detections]) | |
| det_confs = np.array([d.confidence for d in detections]) | |
| # Predict all existing tracks | |
| predicted_bboxes = {} | |
| for tid, kf in self.kalman_trackers.items(): | |
| predicted_bboxes[tid] = kf.predict() | |
| # Two-stage association (BYTE-style, as in OC-SORT) | |
| matched, unmatched_dets, unmatched_tracks = self._associate( | |
| det_bboxes, det_confs, predicted_bboxes | |
| ) | |
| # Update matched tracks | |
| for det_idx, track_id in matched: | |
| self._update_track(track_id, detections[det_idx]) | |
| # Create new tracks for unmatched detections | |
| for det_idx in unmatched_dets: | |
| if det_confs[det_idx] >= self.config.det_thresh: | |
| self._create_track(detections[det_idx]) | |
| # Age unmatched tracks | |
| for track_id in unmatched_tracks: | |
| self.tracks[track_id].time_since_update += 1 | |
| if self.tracks[track_id].time_since_update == 1: | |
| self.tracks[track_id].state = "lost" | |
| # Remove expired tracks | |
| self._remove_expired() | |
| return self._get_active_tracks() | |
| def _associate(self, det_bboxes, det_confs, predicted_bboxes): | |
| """ | |
| Two-stage BYTE association with OC-SORT momentum. | |
| Stage 1: High-confidence detections matched to all tracks | |
| Stage 2: Low-confidence detections matched to remaining tracks | |
| """ | |
| if len(det_bboxes) == 0 or len(predicted_bboxes) == 0: | |
| return [], list(range(len(det_bboxes))), list(predicted_bboxes.keys()) | |
| track_ids = list(predicted_bboxes.keys()) | |
| pred_bboxes = np.array([predicted_bboxes[tid] for tid in track_ids]) | |
| # Compute IoU cost matrix | |
| iou_matrix = self._compute_iou_matrix(det_bboxes, pred_bboxes) | |
| # Stage 1: High-confidence detections | |
| high_mask = det_confs >= self.config.det_thresh | |
| high_indices = np.where(high_mask)[0] | |
| low_indices = np.where(~high_mask)[0] | |
| matched = [] | |
| unmatched_dets = list(range(len(det_bboxes))) | |
| unmatched_tracks = list(range(len(track_ids))) | |
| if len(high_indices) > 0 and len(track_ids) > 0: | |
| high_iou = iou_matrix[high_indices] | |
| m, ud, ut = self._hungarian_match(high_iou, self.config.iou_threshold) | |
| for d, t in m: | |
| matched.append((high_indices[d], track_ids[t])) | |
| if high_indices[d] in unmatched_dets: | |
| unmatched_dets.remove(high_indices[d]) | |
| if t in unmatched_tracks: | |
| unmatched_tracks.remove(t) | |
| # Stage 2: Low-confidence BYTE recovery | |
| if self.config.use_byte and len(low_indices) > 0 and len(unmatched_tracks) > 0: | |
| remaining_pred = np.array([pred_bboxes[t] for t in unmatched_tracks]) | |
| low_det = det_bboxes[low_indices] | |
| low_iou = self._compute_iou_matrix(low_det, remaining_pred) | |
| m2, _, _ = self._hungarian_match(low_iou, self.config.iou_threshold) | |
| for d, t in m2: | |
| matched.append((low_indices[d], track_ids[unmatched_tracks[t]])) | |
| if low_indices[d] in unmatched_dets: | |
| unmatched_dets.remove(low_indices[d]) | |
| matched_track_indices = set(unmatched_tracks[t] for _, t in m2) | |
| unmatched_tracks = [t for t in unmatched_tracks if t not in matched_track_indices] | |
| unmatched_track_ids = [track_ids[t] for t in unmatched_tracks] | |
| return matched, unmatched_dets, unmatched_track_ids | |
| def _hungarian_match(self, cost_matrix, threshold): | |
| """Simple greedy matching (replace with scipy.linear_sum_assignment for production).""" | |
| if cost_matrix.size == 0: | |
| return [], list(range(cost_matrix.shape[0])), list(range(cost_matrix.shape[1])) | |
| matched = [] | |
| used_rows = set() | |
| used_cols = set() | |
| # Greedy: match highest IoU pairs first | |
| while True: | |
| if cost_matrix.size == 0: | |
| break | |
| max_val = cost_matrix.max() | |
| if max_val < threshold: | |
| break | |
| r, c = np.unravel_index(cost_matrix.argmax(), cost_matrix.shape) | |
| if r in used_rows or c in used_cols: | |
| cost_matrix[r, c] = 0 | |
| continue | |
| matched.append((r, c)) | |
| used_rows.add(r) | |
| used_cols.add(c) | |
| cost_matrix[r, :] = 0 | |
| cost_matrix[:, c] = 0 | |
| unmatched_rows = [i for i in range(cost_matrix.shape[0]) if i not in used_rows] | |
| unmatched_cols = [i for i in range(cost_matrix.shape[1]) if i not in used_cols] | |
| return matched, unmatched_rows, unmatched_cols | |
| def _compute_iou_matrix(self, bboxes_a, bboxes_b): | |
| """Compute IoU matrix between two sets of bboxes.""" | |
| n, m = len(bboxes_a), len(bboxes_b) | |
| iou = np.zeros((n, m)) | |
| for i in range(n): | |
| for j in range(m): | |
| iou[i, j] = self._compute_iou(bboxes_a[i], bboxes_b[j]) | |
| return iou | |
| def _compute_iou(bbox_a, bbox_b): | |
| """Compute IoU between two [x1,y1,x2,y2] bboxes.""" | |
| x1 = max(bbox_a[0], bbox_b[0]) | |
| y1 = max(bbox_a[1], bbox_b[1]) | |
| x2 = min(bbox_a[2], bbox_b[2]) | |
| y2 = min(bbox_a[3], bbox_b[3]) | |
| inter = max(0, x2 - x1) * max(0, y2 - y1) | |
| area_a = (bbox_a[2] - bbox_a[0]) * (bbox_a[3] - bbox_a[1]) | |
| area_b = (bbox_b[2] - bbox_b[0]) * (bbox_b[3] - bbox_b[1]) | |
| union = area_a + area_b - inter | |
| return inter / max(union, 1e-6) | |
| def _create_track(self, detection): | |
| """Create a new track.""" | |
| tid = self.next_id | |
| self.next_id += 1 | |
| kf = KalmanBoxTracker(detection.bbox) | |
| self.kalman_trackers[tid] = kf | |
| self.tracks[tid] = Track( | |
| track_id=tid, | |
| bbox=detection.bbox.copy(), | |
| confidence=detection.confidence, | |
| state="tentative", | |
| ) | |
| logger.debug(f"Created track {tid} at {detection.center}") | |
| def _update_track(self, track_id, detection): | |
| """Update an existing track with a new detection.""" | |
| track = self.tracks[track_id] | |
| kf = self.kalman_trackers[track_id] | |
| # OC-SORT ORU: if track was lost, backfill trajectory | |
| if track.time_since_update > 0: | |
| kf.observation_centric_reupdate() | |
| logger.debug(f"Track {track_id} re-associated after {track.time_since_update} frames (ORU applied)") | |
| kf.update(detection.bbox) | |
| track.bbox = detection.bbox.copy() | |
| track.confidence = detection.confidence | |
| track.time_since_update = 0 | |
| track.hits += 1 | |
| track.age += 1 | |
| # Promote tentative → confirmed | |
| if track.state == "tentative" and track.hits >= self.config.min_hits: | |
| track.state = "confirmed" | |
| elif track.state == "lost": | |
| track.state = "confirmed" # Re-associated | |
| def _age_tracks(self): | |
| """Age all tracks by one frame.""" | |
| for tid in list(self.tracks.keys()): | |
| self.tracks[tid].time_since_update += 1 | |
| self.tracks[tid].age += 1 | |
| self.kalman_trackers[tid].predict() | |
| if self.tracks[tid].time_since_update == 1: | |
| self.tracks[tid].state = "lost" | |
| def _remove_expired(self): | |
| """Remove tracks that have been lost too long.""" | |
| expired = [ | |
| tid for tid, track in self.tracks.items() | |
| if track.time_since_update > self.config.max_age | |
| ] | |
| for tid in expired: | |
| logger.debug(f"Removing expired track {tid}") | |
| del self.tracks[tid] | |
| del self.kalman_trackers[tid] | |
| def _get_active_tracks(self) -> List[Track]: | |
| """Return confirmed + recently-lost tracks.""" | |
| return [ | |
| track for track in self.tracks.values() | |
| if track.state in ("confirmed", "lost") and track.time_since_update <= 1 | |
| ] | |
| def get_all_tracks(self) -> Dict[int, Track]: | |
| """Return all tracks including tentative and lost.""" | |
| return self.tracks.copy() | |