Spaces:
Sleeping
Sleeping
File size: 2,107 Bytes
a31bf96 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 | import supervision as sv
from supervision import ByteTrack
from trackers import SORTTracker, DeepSORTTracker
def get_tracker(tracker_name, track_activation_threshold=0.25, lost_track_buffer=30,
                frame_rate=30.0, minimum_consecutive_frames=3, minimum_iou_threshold=0.3):
    """Build a multi-object tracker selected by name.

    Args:
        tracker_name (str): Name of the tracker: 'bytetrack', 'sort' or
            'deepsort'. String names are matched case-insensitively and
            surrounding whitespace is ignored.
        track_activation_threshold (float): Minimum detection confidence
            needed to activate a new track.
        lost_track_buffer (int): Number of frames a lost track is kept.
        frame_rate (float): Frame rate of the video.
        minimum_consecutive_frames (int): Minimum number of consecutive frames
            an object must persist before it is reported as tracked.
        minimum_iou_threshold (float): Minimum IoU overlap required to match a
            track with a detection. Forwarded unchanged to SORT/DeepSORT and
            converted to an equivalent cost threshold for ByteTrack.

    Returns:
        A tracker instance: ``sv.ByteTrack``, ``SORTTracker`` or
        ``DeepSORTTracker``.

    Raises:
        ValueError: If ``tracker_name`` is not a supported tracker.
    """
    # Normalize only real strings so that non-string input (e.g. None) still
    # reaches the ValueError below instead of failing with AttributeError.
    key = tracker_name.strip().lower() if isinstance(tracker_name, str) else tracker_name

    if key == 'bytetrack':
        # supervision's ByteTrack treats ``minimum_matching_threshold`` as a
        # maximum matching *cost* derived from IoU (roughly 1 - IoU), not as a
        # minimum IoU. Passing the IoU threshold directly would invert its
        # meaning (0.3 would demand ~0.7 overlap), so convert similarity -> cost.
        return sv.ByteTrack(
            track_activation_threshold=track_activation_threshold,
            lost_track_buffer=lost_track_buffer,
            minimum_matching_threshold=1.0 - minimum_iou_threshold,
            minimum_consecutive_frames=minimum_consecutive_frames,
            frame_rate=frame_rate,
        )

    # SORT and DeepSORT are constructed with the same keyword arguments here.
    sort_family_kwargs = dict(
        track_activation_threshold=track_activation_threshold,
        lost_track_buffer=lost_track_buffer,
        frame_rate=frame_rate,
        minimum_consecutive_frames=minimum_consecutive_frames,
        minimum_iou_threshold=minimum_iou_threshold,
    )
    if key == 'sort':
        return SORTTracker(**sort_family_kwargs)
    if key == 'deepsort':
        # NOTE(review): depending on the installed ``trackers`` version,
        # DeepSORTTracker may also require an appearance feature extractor;
        # confirm this call matches its constructor.
        return DeepSORTTracker(**sort_family_kwargs)

    raise ValueError(f"Unsupported tracker: {tracker_name}")
|