Spaces:
Runtime error
Runtime error
| from enum import Enum | |
class TrackerType(Enum):
    """Association strategies selectable for tracking."""

    NONE = 0
    CONF_BOOST = 1
    BYTETRACK = 2

    def toString(val):
        """Return the display label for tracker type *val*.

        Works both as ``TrackerType.toString(member)`` and as
        ``member.toString()``. Falls through and returns ``None`` when *val*
        is not one of the known members.
        """
        labels = (
            (TrackerType.NONE, "None"),
            (TrackerType.CONF_BOOST, "Confidence Boost"),
            (TrackerType.BYTETRACK, "ByteTrack"),
        )
        # Compare with == (not a dict lookup) so unhashable inputs are
        # tolerated and simply produce None.
        for member, label in labels:
            if val == member:
                return label
        return None
### Configuration options
# Module-level defaults; InferenceConfig uses most of them as constructor
# defaults.

# Default model weights path.
WEIGHTS = 'models/v5m_896_300best.pt'
# will need to configure these based on GPU hardware
BATCH_SIZE = 32

CONF_THRES = 0.05 # detection confidence threshold
NMS_IOU = 0.25 # NMS IOU
MAX_AGE = 20 # time until a missing fish gets a new id
MIN_HITS = 11 # minimum number of frames with a specific fish for it to count
MIN_LENGTH = 0.3 # minimum fish length, in meters
# NOTE(review): 0 presumably means "no upper bound" - confirm against the
# code that applies the length filter.
MAX_LENGTH = 0 # maximum fish length, in meters
IOU_THRES = 0.01 # IOU threshold for tracking
MIN_TRAVEL = 0 # Minimum distance a track has to travel

# Tracker a new InferenceConfig starts with.
DEFAULT_TRACKER = TrackerType.BYTETRACK
class InferenceConfig:
    """Detection, tracking and filtering settings for an inference run.

    Constructor arguments default to the module-level configuration
    constants. The association strategy starts as ``DEFAULT_TRACKER`` and is
    changed through the ``enable_*`` methods.
    """

    def __init__(self,
                 weights=WEIGHTS, conf_thresh=CONF_THRES, nms_iou=NMS_IOU,
                 min_hits=MIN_HITS, max_age=MAX_AGE, min_length=MIN_LENGTH,
                 max_length=MAX_LENGTH, min_travel=MIN_TRAVEL):
        """Store the given settings and the default tracker parameters."""
        self.weights = weights
        self.conf_thresh = conf_thresh
        self.nms_iou = nms_iou
        self.min_hits = min_hits
        self.max_age = max_age
        self.min_length = min_length
        self.max_length = max_length
        self.min_travel = min_travel

        # Tracker selection plus the parameters of each strategy. Parameters
        # of the strategies that are not selected keep their defaults.
        self.associative_tracker = DEFAULT_TRACKER
        self.boost_power = 2
        self.boost_decay = 0.1
        self.byte_low_conf = 0.1
        self.byte_high_conf = 0.3

    def enable_sort_track(self):
        """Select ``TrackerType.NONE`` (no associative tracker)."""
        self.associative_tracker = TrackerType.NONE

    def enable_conf_boost(self, power=2, decay=0.1):
        """Select confidence boosting and record its *power* and *decay*."""
        self.associative_tracker = TrackerType.CONF_BOOST
        self.boost_power = power
        self.boost_decay = decay

    def enable_byte_track(self, low=0.1, high=0.3):
        """Select ByteTrack and record its *low* / *high* confidence bounds."""
        self.associative_tracker = TrackerType.BYTETRACK
        self.byte_low_conf = low
        self.byte_high_conf = high

    @staticmethod
    def _parse_pair(spec):
        """Return the two floats of a ``name:a:b`` spec, or None if malformed."""
        fields = spec.split(":")
        if len(fields) != 3:
            return None
        try:
            return float(fields[1]), float(fields[2])
        except ValueError:
            # Non-numeric field: report it like any other malformed spec
            # instead of letting the exception escape to the caller.
            return None

    def enable_tracker_from_string(self, associativity):
        """Configure the tracker from a textual spec.

        Accepted forms:
          * ``""`` - no associative tracker
          * ``"boost:<power>:<decay>"`` - confidence boost
          * ``"bytetrack:<low>:<high>"`` - ByteTrack

        Returns True when the spec was applied. On an unknown tracker name,
        a wrong number of fields or a non-numeric field, prints a message,
        leaves the configuration untouched and returns False.
        """
        if associativity == "":
            self.enable_sort_track()
            return True

        if associativity.startswith("boost"):
            pair = self._parse_pair(associativity)
            if pair is None:
                print("INVALID PARAMETERS FOR CONFIDENCE BOOST:", associativity)
                return False
            self.enable_conf_boost(power=pair[0], decay=pair[1])
            return True

        if associativity.startswith("bytetrack"):
            pair = self._parse_pair(associativity)
            if pair is None:
                print("INVALID PARAMETERS FOR BYTETRACK:", associativity)
                return False
            self.enable_byte_track(low=pair[0], high=pair[1])
            return True

        print("INVALID ASSOCIATIVITY TYPE:", associativity)
        return False

    def find_model(self, model_list):
        """Return the name in *model_list* whose path equals ``self.weights``.

        *model_list* maps model name -> weights path. Each candidate is
        printed while searching. Returns None when no path matches.
        """
        print("weights", self.weights)
        for model_name, path in model_list.items():
            print("Path", path, "->", model_name)
            if path == self.weights:
                return model_name
        print("not found")
        return None

    def to_dict(self):
        """Return the active settings as a plain dictionary.

        Common settings are always present. ``tracker`` and the parameters
        specific to the selected tracker are added on top; ``conf_thresh`` is
        included for every tracker except ByteTrack, which reports its
        low/high confidence bounds instead.
        """
        config = {
            'weights': self.weights,
            'nms_iou': self.nms_iou,
            'min_hits': self.min_hits,
            'max_age': self.max_age,
            'min_length': self.min_length,
            'max_length': self.max_length,
            'min_travel': self.min_travel,
        }

        # Add tracker specific parameters
        tracker = self.associative_tracker
        if tracker == TrackerType.BYTETRACK:
            config['tracker'] = "ByteTrack"
            config['byte_low_conf'] = self.byte_low_conf
            config['byte_high_conf'] = self.byte_high_conf
        elif tracker == TrackerType.CONF_BOOST:
            config['tracker'] = "Confidence Boost"
            config['conf_thresh'] = self.conf_thresh
            config['boost_power'] = self.boost_power
            config['boost_decay'] = self.boost_decay
        elif tracker == TrackerType.NONE:
            config['tracker'] = "None"
            config['conf_thresh'] = self.conf_thresh

        return config