|
|
""" |
|
|
ByteTrack integration for multi-object tracking |
|
|
Provides temporal consistency for ball and player tracking |
|
|
""" |
|
|
import numpy as np |
|
|
from typing import List, Dict, Tuple, Optional |
|
|
import torch |
|
|
|
|
|
|
|
|
# ByteTrack is an optional dependency. Record whether the import worked so the
# wrapper below can fail with a clear message instead of a NameError.
try:
    from byte_tracker import BYTETracker
except ImportError:
    BYTETRACK_AVAILABLE = False
    print("Warning: byte-track not installed. Install with: pip install byte-track")
else:
    # Only reached when the import above succeeded.
    BYTETRACK_AVAILABLE = True
|
|
|
|
|
|
|
|
class ByteTrackerWrapper:
    """
    Wrapper for ByteTrack multi-object tracking

    Converts detector output into the array handed to ``BYTETracker.update``,
    returns the tracked objects as plain dictionaries, and remembers in how
    many frames each track id has been reported so that short-lived tracks
    can be filtered out afterwards.
    """

    def __init__(self, frame_rate: int = 30, track_thresh: float = 0.5,
                 track_buffer: int = 30, match_thresh: float = 0.8,
                 min_box_area: float = 10.0):
        """
        Initialize ByteTracker

        All arguments are forwarded unchanged to ``BYTETracker``.

        Args:
            frame_rate: Video frame rate
            track_thresh: Detection confidence threshold
            track_buffer: Buffer for track persistence
            match_thresh: IoU threshold for matching
            min_box_area: Minimum box area to track

        Raises:
            ImportError: If the ByteTrack package could not be imported.
        """
        if not BYTETRACK_AVAILABLE:
            raise ImportError("byte-track not installed. Install with: pip install byte-track")

        self.tracker = BYTETracker(
            frame_rate=frame_rate,
            track_thresh=track_thresh,
            track_buffer=track_buffer,
            match_thresh=match_thresh,
            min_box_area=min_box_area,
        )
        # Number of update() calls that reached the tracker.
        self.frame_id = 0
        # track_id -> number of frames in which update() reported that track.
        # Kept for the lifetime of the wrapper so a track that disappears and
        # is reported again later keeps its history.
        self._track_frame_counts: Dict[int, int] = {}

    @staticmethod
    def _to_float_array(values) -> np.ndarray:
        """Return ``values`` (tensor, ndarray or sequence) as a float64 ndarray."""
        if isinstance(values, torch.Tensor):
            # detach() so tensors that still require grad can be converted.
            values = values.detach().cpu().numpy()
        # Force a float dtype: with integer input the centre computation
        # below would otherwise be truncated.
        return np.asarray(values, dtype=np.float64)

    @staticmethod
    def _detections_to_array(detections: Dict) -> np.ndarray:
        """
        Build the (N, 6) array passed to the tracker.

        Columns are ``[cx, cy, w, h, score, label]``, computed from boxes
        given as ``[x1, y1, x2, y2]``.

        Raises:
            ValueError: If boxes are not (N, 4) or the three inputs differ in length.
        """
        boxes = ByteTrackerWrapper._to_float_array(detections['boxes'])
        scores = ByteTrackerWrapper._to_float_array(detections['scores']).reshape(-1)
        labels = ByteTrackerWrapper._to_float_array(detections['labels']).reshape(-1)

        if boxes.size == 0:
            # An empty list/tensor may arrive as shape (0,); give it columns.
            boxes = boxes.reshape(0, 4)
        elif boxes.ndim != 2 or boxes.shape[1] != 4:
            raise ValueError(f"'boxes' must have shape (N, 4), got {boxes.shape}")

        if not (len(boxes) == len(scores) == len(labels)):
            raise ValueError(
                "'boxes', 'scores' and 'labels' must have the same length, got "
                f"{len(boxes)}, {len(scores)} and {len(labels)}"
            )

        # NOTE(review): the tracker is fed centre-format boxes, as in the
        # original code - confirm this is the layout the installed
        # BYTETracker.update expects.
        detections_array = np.zeros((len(boxes), 6))
        detections_array[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2
        detections_array[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2
        detections_array[:, 2] = boxes[:, 2] - boxes[:, 0]
        detections_array[:, 3] = boxes[:, 3] - boxes[:, 1]
        detections_array[:, 4] = scores
        detections_array[:, 5] = labels
        return detections_array

    def update(self, detections: Dict, image_shape: Tuple[int, int]) -> List[Dict]:
        """
        Update tracker with new detections

        Args:
            detections: Dictionary with 'boxes', 'scores', 'labels'
                (tensors, arrays or sequences); boxes are [x1, y1, x2, y2]
            image_shape: (height, width) of image

        Returns:
            List of tracked objects, each with 'track_id', 'bbox', 'score', 'class_id'.
            'bbox' is the tracker object's ``tlbr`` attribute, passed through as is.

        Raises:
            ValueError: If the detection inputs have inconsistent shapes.
        """
        if not BYTETRACK_AVAILABLE:
            return []

        detections_array = self._detections_to_array(detections)
        tracked_objects = self.tracker.update(detections_array, image_shape)

        results = []
        for obj in tracked_objects:
            track_id = int(obj.track_id)
            # Remember how often this id has been reported; used by
            # filter_short_tracks().
            self._track_frame_counts[track_id] = self._track_frame_counts.get(track_id, 0) + 1
            results.append({
                'track_id': track_id,
                'bbox': obj.tlbr,
                'score': float(obj.score),
                'class_id': int(obj.cls),
            })

        self.frame_id += 1
        return results

    def filter_short_tracks(self, tracked_objects: List[Dict], min_frames: int = 3) -> List[Dict]:
        """
        Filter out tracks that exist for less than min_frames

        A track's age is the number of frames in which ``update`` has reported
        its id on this wrapper; ids this wrapper has never reported count as 0.

        Args:
            tracked_objects: List of tracked objects (dicts with a 'track_id' key)
            min_frames: Minimum frames for a track to be valid

        Returns:
            Filtered list of tracked objects (a new list, input order preserved)
        """
        frame_counts = self._track_frame_counts
        return [
            obj for obj in tracked_objects
            if frame_counts.get(int(obj['track_id']), 0) >= min_frames
        ]
|
|
|