# Tracking_HUB / scripts/tracker_factory.py
# Panagiota Moraiti
# Initial upload
# a31bf96
import supervision as sv
from supervision import ByteTrack
from trackers import SORTTracker, DeepSORTTracker
def get_tracker(tracker_name, track_activation_threshold=0.25, lost_track_buffer=30,
                frame_rate=30.0, minimum_consecutive_frames=3, minimum_iou_threshold=0.3,
                *, reid_model=None):
    """
    Factory method to return the correct tracking algorithm based on name.

    Args:
        tracker_name (str): Name of the tracker ('bytetrack', 'sort', 'deepsort').
            Matching ignores letter case and surrounding whitespace.
        track_activation_threshold (float): Min confidence to activate a track.
        lost_track_buffer (int): Number of frames to keep a lost track.
        frame_rate (float): Frame rate of the video.
        minimum_consecutive_frames (int): Minimum frames an object must persist
            to be tracked.
        minimum_iou_threshold (float): Minimum IoU for a detection to be matched
            to an existing track. SORT/DeepSORT receive it directly; for ByteTrack
            it is converted to the equivalent matching-cost threshold (1 - IoU).
        reid_model: Re-identification model required by DeepSORT (keyword-only;
            e.g. a ``trackers.ReIDModel``). Not used by the other trackers.

    Returns:
        sv.ByteTrack | SORTTracker | DeepSORTTracker: A tracker instance.

    Raises:
        ValueError: If the tracker_name is unsupported, or if 'deepsort' is
            requested without a reid_model.
    """
    # Normalise strings so 'ByteTrack' or ' sort ' resolve as expected; any
    # non-string value simply falls through to the "unsupported" error below.
    name = tracker_name.strip().lower() if isinstance(tracker_name, str) else tracker_name

    if name == 'bytetrack':
        return sv.ByteTrack(
            track_activation_threshold=track_activation_threshold,
            lost_track_buffer=lost_track_buffer,
            # supervision's ByteTrack compares this value against a matching
            # *cost* (1 - IoU, fused with the detection score) and accepts a
            # match when cost <= threshold. An IoU floor therefore maps to
            # 1 - IoU; passing the IoU value directly would demand IoU >= 0.7
            # for the default 0.3 and silently break associations.
            minimum_matching_threshold=1.0 - minimum_iou_threshold,
            minimum_consecutive_frames=minimum_consecutive_frames,
            frame_rate=frame_rate,
        )

    if name == 'sort':
        return SORTTracker(
            track_activation_threshold=track_activation_threshold,
            lost_track_buffer=lost_track_buffer,
            frame_rate=frame_rate,
            minimum_consecutive_frames=minimum_consecutive_frames,
            minimum_iou_threshold=minimum_iou_threshold,
        )

    if name == 'deepsort':
        # DeepSORTTracker's constructor requires an appearance (ReID) model;
        # without one it fails with an opaque TypeError, so fail early and clearly.
        if reid_model is None:
            raise ValueError(
                "DeepSORT requires a re-identification model: "
                "call get_tracker('deepsort', reid_model=...)"
            )
        return DeepSORTTracker(
            reid_model=reid_model,
            track_activation_threshold=track_activation_threshold,
            lost_track_buffer=lost_track_buffer,
            frame_rate=frame_rate,
            minimum_consecutive_frames=minimum_consecutive_frames,
            minimum_iou_threshold=minimum_iou_threshold,
        )

    raise ValueError(f"Unsupported tracker: {tracker_name}")