| """ |
| ShortSmith v2 - Object Tracker Module |
| |
| Multi-object tracking using ByteTrack for: |
| - Maintaining person identity across frames |
| - Handling occlusions and reappearances |
| - Tracking specific individuals through video |
| |
| ByteTrack uses two-stage association for robust tracking. |
| """ |
|
|
| from pathlib import Path |
| from typing import List, Optional, Dict, Tuple, Union |
| from dataclasses import dataclass, field |
| import numpy as np |
|
|
| from utils.logger import get_logger, LogTimer |
| from utils.helpers import InferenceError |
| from config import get_config |
|
|
# Module-level logger; ObjectTracker uses it for info messages on init/reset
# and debug messages when tracks are created or recovered.
logger = get_logger("models.tracker")
|
|
|
|
@dataclass
class TrackedObject:
    """State of one object as followed by the tracker across frames.

    Bounding boxes are ``(x1, y1, x2, y2)`` corner tuples; ``center`` and
    ``area`` unpack them in that order.
    """

    track_id: int
    bbox: Tuple[int, int, int, int]
    confidence: float
    class_id: int = 0
    frame_id: int = 0

    # Previously observed bounding boxes, oldest first.
    history: List[Tuple[int, int, int, int]] = field(default_factory=list)
    # Bookkeeping counters maintained by the tracker.
    age: int = 0
    hits: int = 0
    time_since_update: int = 0

    @property
    def center(self) -> Tuple[int, int]:
        """Integer midpoint of the box (floor division on each axis)."""
        left, top, right, bottom = self.bbox
        cx = (left + right) // 2
        cy = (top + bottom) // 2
        return (cx, cy)

    @property
    def area(self) -> int:
        """Box area, computed without clamping (malformed boxes can yield <= 0)."""
        left, top, right, bottom = self.bbox
        width = right - left
        height = bottom - top
        return width * height

    @property
    def is_confirmed(self) -> bool:
        """Whether the track has accumulated at least three detection hits."""
        return 3 <= self.hits
|
|
|
|
@dataclass
class TrackingResult:
    """Tracker output for one processed frame."""

    # Index of the frame this result describes.
    frame_id: int
    # Tracks that are active after the update (lost tracks are not included).
    tracks: List[TrackedObject]
    # IDs of tracks removed from the tracker on this frame.
    lost_tracks: List[int]
    # IDs of tracks created on this frame.
    new_tracks: List[int]
|
|
|
|
class ObjectTracker:
    """
    Multi-object tracker using a simplified ByteTrack-style algorithm.

    Features:
    - Two-stage association: high-confidence detections are matched first,
      then low-confidence detections against the tracks still unmatched.
    - Handles occlusions by keeping unmatched tracks in a "lost" pool for up
      to ``track_buffer`` frames before discarding them.
    - Re-identifies objects after temporary disappearance by matching
      leftover high-confidence detections against the lost pool.

    Association is purely IoU-based on each track's last observed box (there
    is no motion model). Bounding boxes are ``(x1, y1, x2, y2)`` tuples.
    """

    # Second-stage (low-confidence) matching uses a slightly looser IoU gate.
    _LOW_CONF_MATCH_SCALE = 0.9
    # Lost-track recovery uses an even looser IoU gate.
    _RECOVERY_MATCH_SCALE = 0.7

    def __init__(
        self,
        track_thresh: float = 0.5,
        track_buffer: int = 30,
        match_thresh: float = 0.8,
    ):
        """
        Initialize tracker.

        Args:
            track_thresh: Detection confidence threshold. Detections at or
                above it are "high confidence" and may start new tracks.
            track_buffer: A lost track may stay unmatched for up to this many
                consecutive frames; it is discarded on the next miss.
            match_thresh: Minimum IoU for first-stage matching. The second
                stage and lost-track recovery use scaled-down versions.
        """
        self.track_thresh = track_thresh
        self.track_buffer = track_buffer
        self.match_thresh = match_thresh

        self._tracks: Dict[int, TrackedObject] = {}
        self._lost_tracks: Dict[int, TrackedObject] = {}
        self._next_id = 1
        self._frame_id = 0

        logger.info(
            f"ObjectTracker initialized (thresh={track_thresh}, "
            f"buffer={track_buffer}, match={match_thresh})"
        )

    def update(
        self,
        detections: List[Tuple[Tuple[int, int, int, int], float]],
    ) -> TrackingResult:
        """
        Update tracker with the detections of the next frame.

        Args:
            detections: List of ``(bbox, confidence)`` tuples, with ``bbox``
                given as ``(x1, y1, x2, y2)``.

        Returns:
            TrackingResult with the active tracks, the IDs of tracks
            discarded on this frame, and the IDs of tracks created on it.
        """
        self._frame_id += 1

        # Every existing track (active or lost) is one frame older.
        for track in self._tracks.values():
            track.age += 1
        for track in self._lost_tracks.values():
            track.age += 1

        if not detections:
            return self._handle_no_detections()

        high_conf = [
            (bbox, conf) for bbox, conf in detections if conf >= self.track_thresh
        ]
        low_conf = [
            (bbox, conf) for bbox, conf in detections if conf < self.track_thresh
        ]

        # Stage 1: active tracks vs. high-confidence detections.
        matched, unmatched_tracks, unmatched_dets = self._associate(
            list(self._tracks.values()),
            high_conf,
            self.match_thresh,
        )
        for track_id, det_idx in matched:
            bbox, conf = high_conf[det_idx]
            self._update_track(track_id, bbox, conf)

        # Stage 2: tracks still unmatched vs. low-confidence detections.
        # Unmatched low-confidence detections are dropped (they never start tracks).
        if low_conf and unmatched_tracks:
            remaining_tracks = [self._tracks[tid] for tid in unmatched_tracks]
            matched2, unmatched_tracks, _ = self._associate(
                remaining_tracks,
                low_conf,
                self.match_thresh * self._LOW_CONF_MATCH_SCALE,
            )
            for track_id, det_idx in matched2:
                bbox, conf = low_conf[det_idx]
                self._update_track(track_id, bbox, conf)

        # Tracks that found no detection this frame join the lost pool.
        self._move_to_lost(unmatched_tracks)

        # Leftover high-confidence detections (stage-1 indices only) may
        # re-identify a lost track; the remainder start new tracks.
        leftovers = [high_conf[i] for i in sorted(unmatched_dets)]
        recovered = self._recover_lost_tracks(leftovers)

        new_tracks = []
        for pos, (bbox, conf) in enumerate(leftovers):
            # ``recovered`` holds positions within ``leftovers``, not indices
            # into ``high_conf`` -- compare like with like.
            if pos not in recovered:
                new_tracks.append(self._create_track(bbox, conf))

        lost_this_frame = self._age_lost_tracks()

        return TrackingResult(
            frame_id=self._frame_id,
            tracks=list(self._tracks.values()),
            lost_tracks=lost_this_frame,
            new_tracks=new_tracks,
        )

    def _associate(
        self,
        tracks: List[TrackedObject],
        detections: List[Tuple[Tuple[int, int, int, int], float]],
        thresh: float,
    ) -> Tuple[List[Tuple[int, int]], List[int], List[int]]:
        """
        Greedily associate detections to tracks by IoU.

        Repeatedly takes the highest-IoU (track, detection) pair whose IoU is
        at least ``thresh`` and removes both from further consideration.

        Returns:
            (matched_pairs, unmatched_track_ids, unmatched_detection_indices),
            where each matched pair is ``(track_id, detection_index)`` and the
            unmatched lists are sorted.
        """
        if not tracks or not detections:
            return [], [t.track_id for t in tracks], list(range(len(detections)))

        iou_matrix = np.zeros((len(tracks), len(detections)))
        for i, track in enumerate(tracks):
            for j, (det_bbox, _) in enumerate(detections):
                iou_matrix[i, j] = self._compute_iou(track.bbox, det_bbox)

        matched = []
        unmatched_tracks = set(t.track_id for t in tracks)
        unmatched_dets = set(range(len(detections)))

        # At most min(#tracks, #detections) pairs can be formed; bounding the
        # loop also guarantees termination for any ``thresh`` value.
        for _ in range(min(iou_matrix.shape)):
            track_idx, det_idx = np.unravel_index(
                np.argmax(iou_matrix), iou_matrix.shape
            )
            if iou_matrix[track_idx, det_idx] < thresh:
                break
            track_idx, det_idx = int(track_idx), int(det_idx)

            track_id = tracks[track_idx].track_id
            matched.append((track_id, det_idx))
            unmatched_tracks.discard(track_id)
            unmatched_dets.discard(det_idx)

            # Exclude this row and column from later rounds (IoU is never < 0).
            iou_matrix[track_idx, :] = -1
            iou_matrix[:, det_idx] = -1

        return matched, sorted(unmatched_tracks), sorted(unmatched_dets)

    def _compute_iou(
        self,
        bbox1: Tuple[int, int, int, int],
        bbox2: Tuple[int, int, int, int],
    ) -> float:
        """Compute IoU between two ``(x1, y1, x2, y2)`` boxes (0.0 if disjoint)."""
        x1_1, y1_1, x2_1, y2_1 = bbox1
        x1_2, y1_2, x2_2, y2_2 = bbox2

        # Intersection rectangle.
        xi1 = max(x1_1, x1_2)
        yi1 = max(y1_1, y1_2)
        xi2 = min(x2_1, x2_2)
        yi2 = min(y2_1, y2_2)

        if xi2 <= xi1 or yi2 <= yi1:
            return 0.0

        intersection = (xi2 - xi1) * (yi2 - yi1)

        area1 = (x2_1 - x1_1) * (y2_1 - y1_1)
        area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
        union = area1 + area2 - intersection

        return intersection / union if union > 0 else 0.0

    def _update_track(
        self,
        track_id: int,
        bbox: Tuple[int, int, int, int],
        confidence: float,
    ) -> None:
        """
        Apply a matched detection to a track, reviving it if it was lost.

        The previous box is appended to ``history``; ``hits`` is incremented
        and ``time_since_update`` reset. Unknown ``track_id`` values are ignored.
        """
        track = self._lost_tracks.pop(track_id, None)
        if track is not None:
            # Revived: move back into the active set.
            self._tracks[track_id] = track
        else:
            track = self._tracks.get(track_id)
            if track is None:
                return

        track.history.append(track.bbox)
        track.bbox = bbox
        track.confidence = confidence
        track.frame_id = self._frame_id
        track.hits += 1
        track.time_since_update = 0

    def _create_track(
        self,
        bbox: Tuple[int, int, int, int],
        confidence: float,
    ) -> int:
        """Create a new active track and return its ID."""
        track_id = self._next_id
        self._next_id += 1

        track = TrackedObject(
            track_id=track_id,
            bbox=bbox,
            confidence=confidence,
            frame_id=self._frame_id,
            age=1,
            hits=1,
        )

        self._tracks[track_id] = track
        logger.debug(f"Created new track {track_id}")
        return track_id

    def _move_to_lost(self, track_ids: List[int]) -> None:
        """Move the given active tracks into the lost pool."""
        for track_id in track_ids:
            self._lost_tracks[track_id] = self._tracks.pop(track_id)

    def _age_lost_tracks(self) -> List[int]:
        """
        Advance the miss counter of every lost track by one frame.

        Returns:
            IDs of lost tracks discarded because they have now gone unmatched
            for more than ``track_buffer`` frames.
        """
        expired = []
        for track_id in list(self._lost_tracks):
            track = self._lost_tracks[track_id]
            track.time_since_update += 1
            if track.time_since_update > self.track_buffer:
                del self._lost_tracks[track_id]
                expired.append(track_id)
        return expired

    def _recover_lost_tracks(
        self,
        detections: List[Tuple[Tuple[int, int, int, int], float]],
    ) -> set:
        """
        Try to re-identify lost tracks using otherwise-unmatched detections.

        Each detection, in order, is assigned to the lost track with the
        highest IoU above ``match_thresh * _RECOVERY_MATCH_SCALE``. A
        recovered track moves back to the active set, so a later detection
        cannot claim it again.

        Returns:
            Positions within ``detections`` that recovered a track.
        """
        recovered = set()

        if not self._lost_tracks or not detections:
            return recovered

        min_iou = self.match_thresh * self._RECOVERY_MATCH_SCALE
        for det_idx, (bbox, conf) in enumerate(detections):
            best_iou = 0
            best_track_id = None

            for track_id, track in self._lost_tracks.items():
                iou = self._compute_iou(track.bbox, bbox)
                if iou > best_iou and iou > min_iou:
                    best_iou = iou
                    best_track_id = track_id

            if best_track_id is not None:
                self._update_track(best_track_id, bbox, conf)
                recovered.add(det_idx)
                logger.debug(f"Recovered track {best_track_id}")

        return recovered

    def _handle_no_detections(self) -> TrackingResult:
        """Handle a frame with no detections: every active track becomes lost."""
        self._move_to_lost(list(self._tracks))
        lost_this_frame = self._age_lost_tracks()

        return TrackingResult(
            frame_id=self._frame_id,
            tracks=list(self._tracks.values()),
            lost_tracks=lost_this_frame,
            new_tracks=[],
        )

    def get_track(self, track_id: int) -> Optional[TrackedObject]:
        """Get a track by ID, whether active or lost; ``None`` if unknown."""
        track = self._tracks.get(track_id)
        if track is None:
            track = self._lost_tracks.get(track_id)
        return track

    def get_active_tracks(self) -> List[TrackedObject]:
        """Get all active (currently matched) tracks."""
        return list(self._tracks.values())

    def get_confirmed_tracks(self) -> List[TrackedObject]:
        """Get only active tracks that are confirmed (multiple detections)."""
        return [t for t in self._tracks.values() if t.is_confirmed]

    def reset(self) -> None:
        """
        Reset tracker state.

        ``_next_id`` is deliberately not reset, so IDs issued after a reset
        never collide with IDs issued before it.
        """
        self._tracks.clear()
        self._lost_tracks.clear()
        self._frame_id = 0
        logger.info("Tracker reset")

    def get_track_for_target(
        self,
        target_bbox: Tuple[int, int, int, int],
        threshold: float = 0.5,
    ) -> Optional[int]:
        """
        Find the active track that best matches a target bounding box.

        Args:
            target_bbox: Target bounding box ``(x1, y1, x2, y2)`` to match
            threshold: IoU that a match must strictly exceed

        Returns:
            Track ID if found, None otherwise
        """
        best_iou = 0
        best_track = None

        for track in self._tracks.values():
            iou = self._compute_iou(track.bbox, target_bbox)
            if iou > best_iou and iou > threshold:
                best_iou = iou
                best_track = track.track_id

        return best_track
|
|
|
|
| |
# Public API of this module.
__all__ = ["ObjectTracker", "TrackedObject", "TrackingResult"]
|
|