# ShortSmith_v3/models/tracker.py
# chaitanya.musale
# Fix models folder issues: bugs and code cleanup
# 15c68da
"""
ShortSmith v2 - Object Tracker Module
Multi-object tracking using ByteTrack for:
- Maintaining person identity across frames
- Handling occlusions and reappearances
- Tracking specific individuals through video
ByteTrack uses two-stage association for robust tracking.
"""
from typing import List, Optional, Dict, Tuple
from dataclasses import dataclass, field
import numpy as np
from utils.logger import get_logger
logger = get_logger("models.tracker")
@dataclass
class TrackedObject:
    """
    Represents a tracked object across frames.

    Attributes:
        track_id: Unique track identifier
        bbox: Current bounding box as (x1, y1, x2, y2)
        confidence: Detection confidence
        class_id: Object class (0 = person)
        frame_id: Current frame number
        history: Track history, stored as bounding-box tuples
        age: Frames since first detection
        hits: Number of detections
        time_since_update: Frames since last detection
    """

    # Identity and most recent observation
    track_id: int
    bbox: Tuple[int, int, int, int]
    confidence: float
    class_id: int = 0
    frame_id: int = 0

    # Bookkeeping accumulated over the life of the track
    history: List[Tuple[int, int, int, int]] = field(default_factory=list)
    age: int = 0
    hits: int = 0
    time_since_update: int = 0

    @property
    def center(self) -> Tuple[int, int]:
        """Midpoint of the bounding box (integer floor division)."""
        left, top, right, bottom = self.bbox
        mid_x = (left + right) // 2
        mid_y = (top + bottom) // 2
        return (mid_x, mid_y)

    @property
    def area(self) -> int:
        """Width times height of the bounding box."""
        left, top, right, bottom = self.bbox
        width = right - left
        height = bottom - top
        return width * height

    @property
    def is_confirmed(self) -> bool:
        """Track is confirmed once it has at least three detections."""
        return self.hits >= 3
@dataclass
class TrackingResult:
    """
    Result of tracking for a single frame.

    Attributes:
        frame_id: Number of the frame this result belongs to
        tracks: Tracked objects reported for this frame
        lost_tracks: Track IDs lost this frame
        new_tracks: New track IDs this frame
    """

    frame_id: int
    tracks: List[TrackedObject]
    lost_tracks: List[int]
    new_tracks: List[int]
class ObjectTracker:
    """
    Multi-object tracker using a ByteTrack-style association scheme.

    Each call to ``update`` processes one frame:

    1. High-confidence detections are matched to active tracks by IoU
       (greedy, largest IoU first).
    2. Low-confidence detections are matched to the tracks that are still
       unmatched, using a slightly relaxed IoU threshold.
    3. Tracks that matched nothing are moved to a "lost" pool. Unmatched
       high-confidence detections first try to revive a lost track and
       otherwise start a new track. Low-confidence detections that match
       nothing are discarded.
    4. Lost tracks that stay unmatched for more than ``track_buffer``
       frames are dropped and reported in ``TrackingResult.lost_tracks``.
    """

    # Scale applied to ``match_thresh`` for the low-confidence second stage.
    _LOW_CONF_MATCH_SCALE = 0.9
    # Scale applied to ``match_thresh`` when reviving lost tracks.
    _RECOVERY_MATCH_SCALE = 0.7

    def __init__(
        self,
        track_thresh: float = 0.5,
        track_buffer: int = 30,
        match_thresh: float = 0.8,
    ):
        """
        Initialize tracker.

        Args:
            track_thresh: Confidence at or above which a detection is treated
                as high-confidence (first-stage matching, recovery and new
                tracks)
            track_buffer: Number of consecutive missed frames a track may
                accumulate before it is dropped
            match_thresh: Minimum IoU for a first-stage match; the later
                stages use scaled-down versions of this value
        """
        self.track_thresh = track_thresh
        self.track_buffer = track_buffer
        self.match_thresh = match_thresh
        # Active tracks (matched recently) and lost tracks (awaiting recovery),
        # both keyed by track ID.
        self._tracks: Dict[int, TrackedObject] = {}
        self._lost_tracks: Dict[int, TrackedObject] = {}
        self._next_id = 1
        self._frame_id = 0
        logger.info(
            f"ObjectTracker initialized (thresh={track_thresh}, "
            f"buffer={track_buffer}, match={match_thresh})"
        )

    def update(
        self,
        detections: List[Tuple[Tuple[int, int, int, int], float]],
    ) -> TrackingResult:
        """
        Update tracker with new detections.

        Args:
            detections: List of (bbox, confidence) tuples, where bbox is
                (x1, y1, x2, y2)

        Returns:
            TrackingResult holding the active tracks after this frame, the
            IDs of tracks dropped during this frame, and the IDs of tracks
            created during this frame
        """
        self._frame_id += 1

        # Every track that already exists is one frame older, whether or not
        # it gets matched below. Tracks created in this frame start at age 1.
        for track in self._tracks.values():
            track.age += 1
        for track in self._lost_tracks.values():
            track.age += 1

        if not detections:
            return self._handle_no_detections()

        # Tracks that were already lost when this frame started. They are
        # aged at the end of the frame unless a detection revives them first.
        previously_lost = list(self._lost_tracks)

        # Separate high and low confidence detections
        high_conf = [(bbox, conf) for bbox, conf in detections if conf >= self.track_thresh]
        low_conf = [(bbox, conf) for bbox, conf in detections if conf < self.track_thresh]

        # First association: high-confidence detections vs. active tracks.
        # ``unmatched_dets`` holds indices into ``high_conf``.
        matched, unmatched_tracks, unmatched_dets = self._associate(
            list(self._tracks.values()),
            high_conf,
            self.match_thresh,
        )
        for track_id, det_idx in matched:
            bbox, conf = high_conf[det_idx]
            self._update_track(track_id, bbox, conf)

        # Second association: low-confidence detections vs. remaining tracks.
        # Low-confidence detections left over here are dropped on purpose of
        # this stage: they are never used for recovery or for new tracks.
        if low_conf and unmatched_tracks:
            remaining_tracks = [self._tracks[tid] for tid in unmatched_tracks]
            matched_low, unmatched_tracks, _ = self._associate(
                remaining_tracks,
                low_conf,
                self.match_thresh * self._LOW_CONF_MATCH_SCALE,
            )
            for track_id, det_idx in matched_low:
                bbox, conf = low_conf[det_idx]
                self._update_track(track_id, bbox, conf)

        # Active tracks that matched nothing become lost (or are dropped).
        lost_this_frame = self._retire_unmatched(unmatched_tracks)

        # Try to revive lost tracks with the unmatched high-confidence
        # detections; this includes tracks that became lost just above.
        recovered_indices = self._recover_lost_tracks(
            [(i, high_conf[i]) for i in unmatched_dets]
        )

        # Age the tracks that were lost before this frame and were not
        # revived, dropping those that have exceeded the buffer.
        lost_this_frame.extend(self._age_lost_tracks(previously_lost))

        # Create new tracks for the detections that are still unclaimed.
        new_tracks = []
        for i in unmatched_dets:
            if i not in recovered_indices:
                bbox, conf = high_conf[i]
                new_tracks.append(self._create_track(bbox, conf))

        return TrackingResult(
            frame_id=self._frame_id,
            tracks=list(self._tracks.values()),
            lost_tracks=lost_this_frame,
            new_tracks=new_tracks,
        )

    def _retire_unmatched(self, track_ids: List[int]) -> List[int]:
        """
        Register a missed frame for active tracks that matched nothing.

        Each track's ``time_since_update`` is incremented; the track is then
        either dropped (buffer exceeded) or moved to the lost pool.

        Args:
            track_ids: IDs of active tracks without a match this frame

        Returns:
            IDs of the tracks that were dropped
        """
        dropped = []
        for track_id in track_ids:
            track = self._tracks[track_id]
            track.time_since_update += 1
            if track.time_since_update > self.track_buffer:
                del self._tracks[track_id]
                dropped.append(track_id)
            else:
                self._lost_tracks[track_id] = self._tracks.pop(track_id)
        return dropped

    def _age_lost_tracks(self, track_ids: List[int]) -> List[int]:
        """
        Register a missed frame for lost tracks and expire stale ones.

        Args:
            track_ids: IDs to age; IDs no longer in the lost pool (because
                they were recovered) are skipped

        Returns:
            IDs of the lost tracks that exceeded ``track_buffer`` and were
            removed
        """
        expired = []
        for track_id in track_ids:
            track = self._lost_tracks.get(track_id)
            if track is None:
                continue
            track.time_since_update += 1
            if track.time_since_update > self.track_buffer:
                del self._lost_tracks[track_id]
                expired.append(track_id)
        return expired

    def _associate(
        self,
        tracks: List[TrackedObject],
        detections: List[Tuple[Tuple[int, int, int, int], float]],
        thresh: float,
    ) -> Tuple[List[Tuple[int, int]], List[int], List[int]]:
        """
        Associate detections to tracks using greedy IoU matching.

        The pair with the largest IoU is matched first, then the next
        largest among the remaining tracks and detections, until the best
        remaining IoU falls below ``thresh``.

        Args:
            tracks: Candidate tracks
            detections: Candidate (bbox, confidence) detections
            thresh: Minimum IoU for a pair to be matched

        Returns:
            (matched_pairs, unmatched_track_ids, unmatched_detection_indices)
            where matched_pairs contains (track_id, detection_index) and the
            two unmatched lists are sorted ascending
        """
        if not tracks or not detections:
            return [], [t.track_id for t in tracks], list(range(len(detections)))

        # Compute IoU matrix: rows are tracks, columns are detections.
        iou_matrix = np.zeros((len(tracks), len(detections)))
        for row, track in enumerate(tracks):
            for col, (det_bbox, _) in enumerate(detections):
                iou_matrix[row, col] = self._compute_iou(track.bbox, det_bbox)

        matched = []
        unmatched_tracks = {t.track_id for t in tracks}
        unmatched_dets = set(range(len(detections)))

        # At most min(#tracks, #detections) pairs can exist, so the loop is
        # bounded and terminates for any threshold value.
        for _ in range(min(len(tracks), len(detections))):
            row, col = np.unravel_index(np.argmax(iou_matrix), iou_matrix.shape)
            if iou_matrix[row, col] < thresh:
                break
            # Cast numpy integers to plain ints before handing them out.
            track_idx, det_idx = int(row), int(col)
            track_id = tracks[track_idx].track_id
            matched.append((track_id, det_idx))
            unmatched_tracks.discard(track_id)
            unmatched_dets.discard(det_idx)
            # Take the matched row and column out of further consideration.
            iou_matrix[track_idx, :] = -1
            iou_matrix[:, det_idx] = -1

        return matched, sorted(unmatched_tracks), sorted(unmatched_dets)

    def _compute_iou(
        self,
        bbox1: Tuple[int, int, int, int],
        bbox2: Tuple[int, int, int, int],
    ) -> float:
        """
        Compute IoU between two bounding boxes.

        Args:
            bbox1: First box as (x1, y1, x2, y2)
            bbox2: Second box as (x1, y1, x2, y2)

        Returns:
            Intersection area divided by union area, or 0.0 when the boxes
            do not overlap or the union is not positive
        """
        x1_1, y1_1, x2_1, y2_1 = bbox1
        x1_2, y1_2, x2_2, y2_2 = bbox2

        # Intersection rectangle
        xi1 = max(x1_1, x1_2)
        yi1 = max(y1_1, y1_2)
        xi2 = min(x2_1, x2_2)
        yi2 = min(y2_1, y2_2)
        if xi2 <= xi1 or yi2 <= yi1:
            return 0.0
        intersection = (xi2 - xi1) * (yi2 - yi1)

        # Union
        area1 = (x2_1 - x1_1) * (y2_1 - y1_1)
        area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
        union = area1 + area2 - intersection
        return intersection / union if union > 0 else 0.0

    def _update_track(
        self,
        track_id: int,
        bbox: Tuple[int, int, int, int],
        confidence: float,
    ) -> None:
        """
        Update an existing track with a matched detection.

        A lost track is moved back to the active set. The previous bounding
        box is appended to the track's history. Unknown IDs are ignored.

        Args:
            track_id: ID of the track to update
            bbox: Newly matched bounding box
            confidence: Confidence of the matched detection
        """
        # Move from lost to active if needed
        if track_id not in self._tracks and track_id in self._lost_tracks:
            self._tracks[track_id] = self._lost_tracks.pop(track_id)

        track = self._tracks.get(track_id)
        if track is None:
            return

        track.history.append(track.bbox)
        track.bbox = bbox
        track.confidence = confidence
        track.frame_id = self._frame_id
        track.hits += 1
        track.time_since_update = 0

    def _create_track(
        self,
        bbox: Tuple[int, int, int, int],
        confidence: float,
    ) -> int:
        """
        Create a new active track.

        Args:
            bbox: Initial bounding box
            confidence: Confidence of the originating detection

        Returns:
            ID assigned to the new track
        """
        track_id = self._next_id
        self._next_id += 1
        track = TrackedObject(
            track_id=track_id,
            bbox=bbox,
            confidence=confidence,
            frame_id=self._frame_id,
            age=1,
            hits=1,
        )
        self._tracks[track_id] = track
        logger.debug(f"Created new track {track_id}")
        return track_id

    def _recover_lost_tracks(
        self,
        detections: List[Tuple[int, Tuple[Tuple[int, int, int, int], float]]],
    ) -> set:
        """
        Try to recover lost tracks with unmatched detections.

        Detections are processed in order; each one revives the lost track
        with the highest IoU, provided that IoU exceeds the recovery
        threshold. A revived track leaves the lost pool, so it cannot be
        claimed by a later detection.

        Args:
            detections: List of (original_index, (bbox, confidence)) tuples

        Returns:
            Set of original indices that were successfully recovered
        """
        recovered = set()
        if not self._lost_tracks or not detections:
            return recovered

        min_iou = self.match_thresh * self._RECOVERY_MATCH_SCALE
        for orig_idx, (bbox, conf) in detections:
            best_iou = 0.0
            best_track_id = None
            for track_id, track in self._lost_tracks.items():
                iou = self._compute_iou(track.bbox, bbox)
                if iou > best_iou and iou > min_iou:
                    best_iou = iou
                    best_track_id = track_id
            if best_track_id is not None:
                # Mutates the lost pool, so it must run after the scan above.
                self._update_track(best_track_id, bbox, conf)
                recovered.add(orig_idx)
                logger.debug(f"Recovered track {best_track_id}")
        return recovered

    def _handle_no_detections(self) -> TrackingResult:
        """
        Handle frame with no detections.

        Every lost track and every active track registers a missed frame;
        tracks that exceed ``track_buffer`` are dropped, and the remaining
        active tracks are moved to the lost pool.

        Returns:
            TrackingResult listing the dropped track IDs and no new tracks
        """
        # Age the existing lost pool first so that tracks which become lost
        # in this very frame are not counted twice.
        lost_this_frame = self._age_lost_tracks(list(self._lost_tracks))
        lost_this_frame.extend(self._retire_unmatched(list(self._tracks)))
        return TrackingResult(
            frame_id=self._frame_id,
            tracks=list(self._tracks.values()),
            lost_tracks=lost_this_frame,
            new_tracks=[],
        )

    def get_track(self, track_id: int) -> Optional[TrackedObject]:
        """Get a specific track by ID, searching active then lost tracks."""
        track = self._tracks.get(track_id)
        if track is None:
            track = self._lost_tracks.get(track_id)
        return track

    def get_active_tracks(self) -> List[TrackedObject]:
        """Get all active tracks."""
        return list(self._tracks.values())

    def get_confirmed_tracks(self) -> List[TrackedObject]:
        """Get only the active tracks that are confirmed (multiple detections)."""
        return [t for t in self._tracks.values() if t.is_confirmed]

    def reset(self) -> None:
        """
        Reset tracker state.

        Clears active and lost tracks and rewinds the frame counter. The ID
        counter is left untouched, so IDs issued after a reset do not repeat
        IDs issued before it.
        """
        self._tracks.clear()
        self._lost_tracks.clear()
        self._frame_id = 0
        logger.info("Tracker reset")

    def get_track_for_target(
        self,
        target_bbox: Tuple[int, int, int, int],
        threshold: float = 0.5,
    ) -> Optional[int]:
        """
        Find the active track that best matches a target bounding box.

        Args:
            target_bbox: Target bounding box to match
            threshold: IoU that a track must exceed to count as a match

        Returns:
            Track ID if found, None otherwise
        """
        best_iou = 0.0
        best_track = None
        for track in self._tracks.values():
            iou = self._compute_iou(track.bbox, target_bbox)
            if iou > best_iou and iou > threshold:
                best_iou = iou
                best_track = track.track_id
        return best_track
# Public interface: the names exported by a star-import of this module
__all__ = [
    "ObjectTracker",
    "TrackedObject",
    "TrackingResult",
]