Spaces:
Paused
Paused
"""
ShortSmith v2 - Object Tracker Module

Multi-object tracking in the style of ByteTrack, used for:
- Maintaining person identity across frames
- Handling occlusions and reappearances
- Tracking specific individuals through video

As in ByteTrack, association runs in two stages: high-confidence detections
are matched to active tracks first, then low-confidence detections are used
to rescue tracks that are still unmatched. This is a lightweight variant:
matching is greedy on the raw IoU between a track's last known box and the
new detection boxes (there is no motion model / Kalman prediction).
"""
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import numpy as np

from utils.logger import get_logger, LogTimer
from utils.helpers import InferenceError
from config import get_config

# NOTE(review): Path, Union, LogTimer, InferenceError and get_config appear
# unused in this module - confirm no one relies on them before removing.

logger = get_logger("models.tracker")
@dataclass
class TrackedObject:
    """Represents a tracked object across frames."""

    track_id: int  # Unique track identifier
    bbox: Tuple[int, int, int, int]  # Current bounding box (x1, y1, x2, y2)
    confidence: float  # Confidence of the most recently matched detection
    class_id: int = 0  # Object class (0 = person)
    frame_id: int = 0  # Frame number of the most recent update

    # Track history: previous bounding boxes (own list per instance)
    history: List[Tuple[int, int, int, int]] = field(default_factory=list)
    age: int = 0  # Frames since first detection
    hits: int = 0  # Number of detections matched to this track
    time_since_update: int = 0  # Frames since last matched detection

    @property
    def center(self) -> Tuple[int, int]:
        """Integer center point ``(cx, cy)`` of the current box."""
        x1, y1, x2, y2 = self.bbox
        return ((x1 + x2) // 2, (y1 + y2) // 2)

    @property
    def area(self) -> int:
        """Area of the current box, ``(x2 - x1) * (y2 - y1)``."""
        x1, y1, x2, y2 = self.bbox
        return (x2 - x1) * (y2 - y1)

    @property
    def is_confirmed(self) -> bool:
        """Track is confirmed once it has at least 3 matched detections."""
        return self.hits >= 3
@dataclass
class TrackingResult:
    """Result of tracking for a single frame."""

    frame_id: int  # Frame number this result belongs to
    tracks: List["TrackedObject"]  # Active tracks after this frame's update
    lost_tracks: List[int]  # Track IDs removed from the tracker this frame
    new_tracks: List[int]  # New track IDs this frame
class ObjectTracker:
    """
    Multi-object tracker using a ByteTrack-style association scheme.

    Features:
    - Two-stage association (high-confidence detections first, then
      low-confidence detections for the tracks that are still unmatched)
    - Handles occlusions by keeping unmatched tracks in a "lost" pool for up
      to ``track_buffer`` missed frames
    - Re-identifies objects after temporary disappearance by matching
      leftover high-confidence detections against the lost pool

    Matching is greedy on the IoU between each track's last box and the new
    detection boxes; there is no motion prediction.
    """

    def __init__(
        self,
        track_thresh: float = 0.5,
        track_buffer: int = 30,
        match_thresh: float = 0.8,
    ):
        """
        Initialize tracker.

        Args:
            track_thresh: Detection confidence at or above which a detection
                counts as high confidence (used for first-stage matching and
                for starting new tracks).
            track_buffer: Number of consecutive missed frames a track may
                accumulate before it is removed.
            match_thresh: Minimum IoU for a first-stage match. The
                low-confidence stage uses ``0.9 * match_thresh``; lost-track
                recovery needs IoU strictly above ``0.7 * match_thresh``.
        """
        self.track_thresh = track_thresh
        self.track_buffer = track_buffer
        self.match_thresh = match_thresh

        self._tracks: Dict[int, TrackedObject] = {}  # active tracks by ID
        self._lost_tracks: Dict[int, TrackedObject] = {}  # recently missed tracks by ID
        self._next_id = 1
        self._frame_id = 0

        logger.info(
            f"ObjectTracker initialized (thresh={track_thresh}, "
            f"buffer={track_buffer}, match={match_thresh})"
        )

    def update(
        self,
        detections: List[Tuple[Tuple[int, int, int, int], float]],
    ) -> TrackingResult:
        """
        Update tracker with the detections of the next frame.

        Args:
            detections: List of ``(bbox, confidence)`` tuples, with bbox as
                ``(x1, y1, x2, y2)``.

        Returns:
            TrackingResult with the active tracks after this frame, the IDs
            of tracks removed this frame, and the IDs of newly created tracks.
        """
        # Advance the clock first so lost tracks age (and expire) even on
        # frames without detections.
        removed = self._begin_frame()

        if not detections:
            return self._handle_no_detections(removed)

        # Separate high and low confidence detections
        high_conf = [(bbox, conf) for bbox, conf in detections if conf >= self.track_thresh]
        low_conf = [(bbox, conf) for bbox, conf in detections if conf < self.track_thresh]

        # First association: high-confidence detections vs. active tracks
        matched, unmatched_tracks, unmatched_dets = self._associate(
            list(self._tracks.values()),
            high_conf,
            self.match_thresh,
        )
        for track_id, det_idx in matched:
            bbox, conf = high_conf[det_idx]
            self._update_track(track_id, bbox, conf)

        # Second association: low-confidence detections vs. remaining tracks
        if low_conf and unmatched_tracks:
            remaining_tracks = [self._tracks[tid] for tid in unmatched_tracks]
            matched2, unmatched_tracks, _ = self._associate(
                remaining_tracks,
                low_conf,
                self.match_thresh * 0.9,  # Lower threshold for weaker detections
            )
            for track_id, det_idx in matched2:
                bbox, conf = low_conf[det_idx]
                self._update_track(track_id, bbox, conf)

        # Active tracks without a match become lost (or are removed)
        removed.extend(self._mark_missed(unmatched_tracks))

        # unmatched_dets indexes high_conf (first association), so leftover
        # low-confidence detections never recover or seed tracks.
        leftovers = [high_conf[i] for i in unmatched_dets]

        # Recovery reports positions within ``leftovers``, not high_conf indices.
        recovered = self._recover_lost_tracks(leftovers)

        # Create new tracks for remaining detections
        new_tracks = []
        for pos, (bbox, conf) in enumerate(leftovers):
            if pos not in recovered:
                new_tracks.append(self._create_track(bbox, conf))

        return TrackingResult(
            frame_id=self._frame_id,
            tracks=list(self._tracks.values()),
            lost_tracks=removed,
            new_tracks=new_tracks,
        )

    def _begin_frame(self) -> List[int]:
        """
        Advance to the next frame, age all tracks and purge expired lost tracks.

        Returns:
            IDs of lost tracks removed because they went unmatched for more
            than ``track_buffer`` frames.
        """
        self._frame_id += 1

        for track in self._tracks.values():
            track.age += 1

        expired = []
        for track_id, track in list(self._lost_tracks.items()):
            track.age += 1
            track.time_since_update += 1
            if track.time_since_update > self.track_buffer:
                del self._lost_tracks[track_id]
                expired.append(track_id)
        return expired

    def _mark_missed(self, track_ids: List[int]) -> List[int]:
        """
        Move unmatched active tracks to the lost pool, removing expired ones.

        Args:
            track_ids: IDs of active tracks that got no detection this frame.

        Returns:
            IDs of tracks removed because ``time_since_update`` exceeded
            ``track_buffer``.
        """
        removed = []
        for track_id in track_ids:
            track = self._tracks.pop(track_id)
            track.time_since_update += 1
            if track.time_since_update > self.track_buffer:
                removed.append(track_id)
            else:
                self._lost_tracks[track_id] = track
        return removed

    def _associate(
        self,
        tracks: List[TrackedObject],
        detections: List[Tuple[Tuple[int, int, int, int], float]],
        thresh: float,
    ) -> Tuple[List[Tuple[int, int]], List[int], List[int]]:
        """
        Greedily associate detections to tracks by IoU.

        Repeatedly takes the highest-IoU (track, detection) pair that is
        still free, stopping when the best remaining IoU is below ``thresh``.

        Returns:
            (matched_pairs, unmatched_track_ids, unmatched_detection_indices)
            where matched_pairs holds ``(track_id, detection_index)``. Both
            unmatched lists preserve input order.
        """
        if not tracks or not detections:
            return [], [t.track_id for t in tracks], list(range(len(detections)))

        # Compute IoU matrix (rows = tracks, cols = detections)
        iou_matrix = np.zeros((len(tracks), len(detections)))
        for i, track in enumerate(tracks):
            for j, (det_bbox, _) in enumerate(detections):
                iou_matrix[i, j] = self._compute_iou(track.bbox, det_bbox)

        matched: List[Tuple[int, int]] = []
        matched_track_ids = set()
        matched_dets = set()
        n_cols = iou_matrix.shape[1]

        # Each match consumes one row and one column, so at most
        # min(rows, cols) matches exist; bounding the loop also guarantees
        # termination for any threshold.
        for _ in range(min(len(tracks), len(detections))):
            flat_idx = int(np.argmax(iou_matrix))
            if iou_matrix.flat[flat_idx] < thresh:
                break
            track_idx, det_idx = divmod(flat_idx, n_cols)
            track_id = tracks[track_idx].track_id
            matched.append((track_id, det_idx))
            matched_track_ids.add(track_id)
            matched_dets.add(det_idx)
            # Retire the matched row and column
            iou_matrix[track_idx, :] = -1
            iou_matrix[:, det_idx] = -1

        unmatched_tracks = [t.track_id for t in tracks if t.track_id not in matched_track_ids]
        unmatched_dets = [j for j in range(len(detections)) if j not in matched_dets]
        return matched, unmatched_tracks, unmatched_dets

    def _compute_iou(
        self,
        bbox1: Tuple[int, int, int, int],
        bbox2: Tuple[int, int, int, int],
    ) -> float:
        """Compute IoU between two ``(x1, y1, x2, y2)`` bounding boxes."""
        x1_1, y1_1, x2_1, y2_1 = bbox1
        x1_2, y1_2, x2_2, y2_2 = bbox2

        # Intersection
        xi1 = max(x1_1, x1_2)
        yi1 = max(y1_1, y1_2)
        xi2 = min(x2_1, x2_2)
        yi2 = min(y2_1, y2_2)

        if xi2 <= xi1 or yi2 <= yi1:
            return 0.0

        intersection = (xi2 - xi1) * (yi2 - yi1)

        # Union
        area1 = (x2_1 - x1_1) * (y2_1 - y1_1)
        area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
        union = area1 + area2 - intersection

        return intersection / union if union > 0 else 0.0

    def _update_track(
        self,
        track_id: int,
        bbox: Tuple[int, int, int, int],
        confidence: float,
    ) -> None:
        """Apply a matched detection to a track, re-activating it if lost."""
        if track_id in self._lost_tracks:
            # Move from lost back to active
            self._tracks[track_id] = self._lost_tracks.pop(track_id)

        track = self._tracks.get(track_id)
        if track is None:
            return

        track.history.append(track.bbox)
        track.bbox = bbox
        track.confidence = confidence
        track.frame_id = self._frame_id
        track.hits += 1
        track.time_since_update = 0

    def _create_track(
        self,
        bbox: Tuple[int, int, int, int],
        confidence: float,
    ) -> int:
        """Create a new active track and return its ID."""
        track_id = self._next_id
        self._next_id += 1

        track = TrackedObject(
            track_id=track_id,
            bbox=bbox,
            confidence=confidence,
            frame_id=self._frame_id,
            age=1,
            hits=1,
        )

        self._tracks[track_id] = track
        logger.debug(f"Created new track {track_id}")

        return track_id

    def _recover_lost_tracks(
        self,
        detections: List[Tuple[Tuple[int, int, int, int], float]],
    ) -> set:
        """
        Try to recover lost tracks with unmatched detections.

        Each detection, in order, claims the lost track with the highest IoU
        above ``0.7 * match_thresh``. A claimed track becomes active again
        and can no longer be claimed.

        Returns:
            Positions (indices into ``detections``) that recovered a track.
        """
        recovered = set()

        if not self._lost_tracks or not detections:
            return recovered

        for det_idx, (bbox, conf) in enumerate(detections):
            best_iou = 0
            best_track_id = None

            for track_id, track in self._lost_tracks.items():
                iou = self._compute_iou(track.bbox, bbox)
                if iou > best_iou and iou > self.match_thresh * 0.7:
                    best_iou = iou
                    best_track_id = track_id

            if best_track_id is not None:
                self._update_track(best_track_id, bbox, conf)
                recovered.add(det_idx)
                logger.debug(f"Recovered track {best_track_id}")

        return recovered

    def _handle_no_detections(
        self,
        removed: Optional[List[int]] = None,
    ) -> TrackingResult:
        """
        Handle a frame with no detections: every active track becomes lost.

        Args:
            removed: Track IDs already removed earlier in this frame (e.g.
                expired lost tracks); they are reported in the result.
        """
        lost_this_frame = list(removed) if removed else []
        lost_this_frame.extend(self._mark_missed(list(self._tracks)))

        return TrackingResult(
            frame_id=self._frame_id,
            tracks=list(self._tracks.values()),
            lost_tracks=lost_this_frame,
            new_tracks=[],
        )

    def get_track(self, track_id: int) -> Optional[TrackedObject]:
        """Get a specific track by ID (active or lost)."""
        return self._tracks.get(track_id) or self._lost_tracks.get(track_id)

    def get_active_tracks(self) -> List[TrackedObject]:
        """Get all active tracks."""
        return list(self._tracks.values())

    def get_confirmed_tracks(self) -> List[TrackedObject]:
        """Get only confirmed active tracks (multiple detections)."""
        return [t for t in self._tracks.values() if t.is_confirmed]

    def reset(self) -> None:
        """
        Reset tracker state.

        Clears active and lost tracks and the frame counter. ``_next_id`` is
        not reset, so IDs issued after a reset do not collide with earlier ones.
        """
        self._tracks.clear()
        self._lost_tracks.clear()
        self._frame_id = 0
        logger.info("Tracker reset")

    def get_track_for_target(
        self,
        target_bbox: Tuple[int, int, int, int],
        threshold: float = 0.5,
    ) -> Optional[int]:
        """
        Find the active track that best matches a target bounding box.

        Args:
            target_bbox: Target bounding box to match
            threshold: IoU must be strictly above this value to match

        Returns:
            Track ID if found, None otherwise
        """
        best_iou = 0
        best_track = None

        for track in self._tracks.values():
            iou = self._compute_iou(track.bbox, target_bbox)
            if iou > best_iou and iou > threshold:
                best_iou = iou
                best_track = track.track_id

        return best_track
# Public API of this module: the names a star-import exposes.
__all__ = [
    "ObjectTracker",
    "TrackedObject",
    "TrackingResult",
]