File size: 2,107 Bytes
a31bf96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import supervision as sv
from supervision import ByteTrack
from trackers import SORTTracker, DeepSORTTracker


def get_tracker(
    tracker_name,
    track_activation_threshold=0.25,
    lost_track_buffer=30,
    frame_rate=30.0,
    minimum_consecutive_frames=3,
    minimum_iou_threshold=0.3,
):
    """
    Build and return the tracker selected by ``tracker_name``.

    The name is matched exactly (case-sensitive) against 'bytetrack', 'sort'
    and 'deepsort'. All remaining arguments are forwarded as keyword
    arguments to the chosen tracker's constructor.

    Args:
        tracker_name (str): Name of the tracker ('bytetrack', 'sort', 'deepsort').
        track_activation_threshold (float): Min confidence to activate a track.
        lost_track_buffer (int): Number of frames to keep a lost track.
        frame_rate (float): Frame rate of the video.
        minimum_consecutive_frames (int): Minimum frames an object must persist to be tracked.
        minimum_iou_threshold (float): IOU threshold for SORT/DeepSORT. For
            'bytetrack' the same value is passed as ``minimum_matching_threshold``.

    Returns:
        A new ``sv.ByteTrack``, ``SORTTracker`` or ``DeepSORTTracker`` instance,
        depending on ``tracker_name``.

    Raises:
        ValueError: If the tracker_name is unsupported.
    """
    # Guard clause: reject anything that is not one of the supported names
    # before any tracker class is touched.
    if tracker_name not in ('bytetrack', 'sort', 'deepsort'):
        raise ValueError(f"Unsupported tracker: {tracker_name}")

    # Keyword arguments that every supported tracker receives under the same name.
    shared_kwargs = {
        "track_activation_threshold": track_activation_threshold,
        "lost_track_buffer": lost_track_buffer,
        "frame_rate": frame_rate,
        "minimum_consecutive_frames": minimum_consecutive_frames,
    }

    if tracker_name == 'bytetrack':
        # ByteTrack takes the threshold under a different keyword name.
        return sv.ByteTrack(
            minimum_matching_threshold=minimum_iou_threshold,
            **shared_kwargs,
        )

    # NOTE(review): DeepSORTTracker is constructed with exactly the same
    # arguments as SORTTracker. Confirm that the installed `trackers` release
    # does not require additional constructor arguments (e.g. a
    # re-identification model) for DeepSORT.
    tracker_cls = SORTTracker if tracker_name == 'sort' else DeepSORTTracker
    return tracker_cls(
        minimum_iou_threshold=minimum_iou_threshold,
        **shared_kwargs,
    )