eeeeeeeeeeeeee3's picture
Upload src/tracker.py with huggingface_hub
5bc2f1f verified
"""
ByteTrack integration for multi-object tracking
Provides temporal consistency for ball and player tracking
"""
import numpy as np
from typing import List, Dict, Tuple, Optional
import torch
try:
from byte_tracker import BYTETracker
BYTETRACK_AVAILABLE = True
except ImportError:
BYTETRACK_AVAILABLE = False
print("Warning: byte-track not installed. Install with: pip install byte-track")
class ByteTrackerWrapper:
"""
Wrapper for ByteTrack multi-object tracking
"""
def __init__(self, frame_rate: int = 30, track_thresh: float = 0.5,
track_buffer: int = 30, match_thresh: float = 0.8,
min_box_area: float = 10.0):
"""
Initialize ByteTracker
Args:
frame_rate: Video frame rate
track_thresh: Detection confidence threshold
track_buffer: Buffer for track persistence
match_thresh: IoU threshold for matching
min_box_area: Minimum box area to track
"""
if not BYTETRACK_AVAILABLE:
raise ImportError("byte-track not installed. Install with: pip install byte-track")
self.tracker = BYTETracker(
frame_rate=frame_rate,
track_thresh=track_thresh,
track_buffer=track_buffer,
match_thresh=match_thresh,
min_box_area=min_box_area
)
self.frame_id = 0
def update(self, detections: Dict, image_shape: Tuple[int, int]) -> List[Dict]:
"""
Update tracker with new detections
Args:
detections: Dictionary with 'boxes', 'scores', 'labels' (tensors)
image_shape: (height, width) of image
Returns:
List of tracked objects, each with 'track_id', 'bbox', 'score', 'class_id'
"""
if not BYTETRACK_AVAILABLE:
return []
# Convert detections to ByteTrack format
boxes = detections['boxes'].cpu().numpy() if isinstance(detections['boxes'], torch.Tensor) else detections['boxes']
scores = detections['scores'].cpu().numpy() if isinstance(detections['scores'], torch.Tensor) else detections['scores']
labels = detections['labels'].cpu().numpy() if isinstance(detections['labels'], torch.Tensor) else detections['labels']
# Convert boxes from [x_min, y_min, x_max, y_max] to [x_center, y_center, w, h]
boxes_center = np.zeros_like(boxes)
boxes_center[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2 # x_center
boxes_center[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2 # y_center
boxes_center[:, 2] = boxes[:, 2] - boxes[:, 0] # width
boxes_center[:, 3] = boxes[:, 3] - boxes[:, 1] # height
# Prepare detections for ByteTrack: [x_center, y_center, w, h, score, class_id]
detections_array = np.zeros((len(boxes), 6))
detections_array[:, :4] = boxes_center
detections_array[:, 4] = scores
detections_array[:, 5] = labels
# Update tracker
tracked_objects = self.tracker.update(detections_array, image_shape)
# Convert back to our format
results = []
for obj in tracked_objects:
track_id = int(obj.track_id)
bbox_center = obj.tlbr # ByteTrack returns [x_min, y_min, x_max, y_max]
score = float(obj.score)
class_id = int(obj.cls)
results.append({
'track_id': track_id,
'bbox': bbox_center, # [x_min, y_min, x_max, y_max]
'score': score,
'class_id': class_id
})
self.frame_id += 1
return results
def filter_short_tracks(self, tracked_objects: List[Dict], min_frames: int = 3) -> List[Dict]:
"""
Filter out tracks that exist for less than min_frames
Args:
tracked_objects: List of tracked objects
min_frames: Minimum frames for a track to be valid
Returns:
Filtered list of tracked objects
"""
# This is a simplified version - full implementation would track frame counts
# For now, return all tracks (ByteTrack handles this internally)
return tracked_objects