import cv2
import torch
import gdown
import numpy as np
from abc import ABC, abstractmethod
from boxmot.utils import logger as LOGGER
from boxmot.appearance.reid.registry import ReIDModelRegistry
from boxmot.utils.checks import RequirementsChecker


class BaseModelBackend:
    """Common base for ReID inference backends.

    The constructor downloads the weights if needed, builds the model through
    ``ReIDModelRegistry`` and then delegates the actual weight loading to
    ``load_model``. Subclasses must implement ``forward`` and ``load_model``.

    Note: the class does not derive from ``ABC``, so the ``@abstractmethod``
    decorators below are not enforced at instantiation time; the base
    implementations raise ``NotImplementedError`` when called instead.
    """

    # Layout flag read by ``inference_preprocess``. Defaults to NCHW (False);
    # a backend that needs NHWC input overrides it (class- or instance-level).
    nhwc = False

    def __init__(self, weights, device, half):
        """Download (if necessary), build and load the ReID model.

        Args:
            weights: Weights location, or a list whose first element is used.
                The object must provide ``suffix``, ``exists()`` and
                ``is_file()`` (presumably a ``pathlib.Path`` -- confirm
                against callers).
            device: Device object exposing a ``type`` attribute; also passed
                to tensor constructors and to the registry as ``use_gpu``.
            half: When true, crops and inputs are handled as float16.
        """
        self.weights = weights[0] if isinstance(weights, list) else weights
        self.device = device
        self.half = half
        self.model = None
        self.cuda = torch.cuda.is_available() and self.device.type != "cpu"

        self.download_model(self.weights)
        self.model_name = ReIDModelRegistry.get_model_name(self.weights)

        self.model = ReIDModelRegistry.build_model(
            self.model_name,
            num_classes=ReIDModelRegistry.get_nr_classes(self.weights),
            # Only ask for pretrained weights when no local weights file exists.
            pretrained=not (self.weights and self.weights.is_file()),
            use_gpu=device,
        )
        self.checker = RequirementsChecker()
        self.load_model(self.weights)

    def get_crops(self, xyxys, img):
        """Cut, resize and normalise one crop per bounding box.

        Args:
            xyxys: Iterable of boxes ``(x1, y1, x2, y2)``; each box must
                support ``.round().astype('int')`` (NumPy rows).
            img: Image array indexed as ``img[y, x]``; it is converted with
                ``cv2.COLOR_BGR2RGB``, so BGR channel order is expected.

        Returns:
            Tensor of shape ``(len(xyxys), 3, 256, 128)`` on ``self.device``,
            scaled to [0, 1] and standardised with the mean/std below.

        NOTE(review): a box that is empty after clipping to the image yields
        an empty slice that is handed to ``cv2.resize`` -- confirm that
        callers filter such boxes.
        """
        h, w = img.shape[:2]
        resize_dims = (128, 256)  # (width, height), the order cv2.resize takes
        interpolation_method = cv2.INTER_LINEAR
        # Loop-invariant: resolve the working dtype once.
        dtype = torch.half if self.half else torch.float
        mean_array = torch.tensor(
            [0.485, 0.456, 0.406], device=self.device
        ).view(1, 3, 1, 1)
        std_array = torch.tensor(
            [0.229, 0.224, 0.225], device=self.device
        ).view(1, 3, 1, 1)

        # Preallocate the output batch as (N, C, H, W).
        num_crops = len(xyxys)
        crops = torch.empty(
            (num_crops, 3, resize_dims[1], resize_dims[0]),
            dtype=dtype,
            device=self.device,
        )

        for i, box in enumerate(xyxys):
            x1, y1, x2, y2 = box.round().astype('int')
            # Clip the box to the image bounds.
            x1, y1, x2, y2 = max(0, x1), max(0, y1), min(w, x2), min(h, y2)
            crop = img[y1:y2, x1:x2]

            # Resize to the fixed input size, then convert BGR -> RGB.
            crop = cv2.resize(
                crop,
                resize_dims,
                interpolation=interpolation_method,
            )
            crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)

            # Move to the target device; scaling to [0, 1] is done on the
            # whole batch after the loop.
            crop = torch.from_numpy(crop).to(self.device, dtype=dtype)
            crops[i] = torch.permute(crop, (2, 0, 1))  # (H, W, C) -> (C, H, W)

        # Scale the entire batch to [0, 1] in one go.
        crops = crops / 255.0

        # Standardize the batch.
        crops = (crops - mean_array) / std_array

        return crops

    @torch.no_grad()
    def get_features(self, xyxys, img):
        """Return L2-normalised appearance features for the given boxes.

        Args:
            xyxys: NumPy array of boxes; an array with ``size == 0`` short-
                circuits to an empty result.
            img: Image the boxes refer to (see ``get_crops``).

        Returns:
            The post-processed model output divided by its L2 norm along the
            last axis, or an empty ``np.ndarray`` when there are no boxes.
        """
        if xyxys.size == 0:
            # Nothing to embed; normalising an empty array is a no-op anyway.
            return np.array([])

        crops = self.get_crops(xyxys, img)
        crops = self.inference_preprocess(crops)
        features = self.forward(crops)
        features = self.inference_postprocess(features)
        features = features / np.linalg.norm(features, axis=-1, keepdims=True)
        return features

    def warmup(self, imgsz=((256, 128, 3),)):
        """Run one forward pass on random data when not on CPU.

        Args:
            imgsz: Sequence that is unpacked into ``np.random.randint`` after
                ``low`` and ``high``; the default supplies a single
                ``(256, 128, 3)`` size. An immutable tuple is used as the
                default to avoid the shared-mutable-default pitfall.
        """
        # warmup model by running inference once
        if self.device.type != "cpu":
            im = np.random.randint(0, 255, *imgsz, dtype=np.uint8)
            crops = self.get_crops(
                xyxys=np.array([[0, 0, 64, 64], [0, 0, 128, 128]]),
                img=im,
            )
            crops = self.inference_preprocess(crops)
            self.forward(crops)  # warmup

    def to_numpy(self, x):
        """Return ``x`` as a NumPy array if it is a tensor, else unchanged."""
        return x.cpu().numpy() if isinstance(x, torch.Tensor) else x

    def inference_preprocess(self, x):
        """Cast to float16 and/or switch to NHWC as configured.

        Args:
            x: ``torch.Tensor`` or ``np.ndarray`` batch in NCHW layout; any
                other type is returned untouched.

        Returns:
            ``x`` cast to float16 when ``self.half`` is set, and permuted to
            NHWC when ``self.nhwc`` is set.
        """
        if self.half:
            if isinstance(x, torch.Tensor):
                if x.dtype != torch.float16:
                    x = x.half()
            elif isinstance(x, np.ndarray):
                if x.dtype != np.float16:
                    x = x.astype(np.float16)

        if self.nhwc:
            if isinstance(x, torch.Tensor):
                x = x.permute(0, 2, 3, 1)  # Convert from NCHW to NHWC
            elif isinstance(x, np.ndarray):
                x = np.transpose(x, (0, 2, 3, 1))  # Convert from NCHW to NHWC
        return x

    def inference_postprocess(self, features):
        """Convert model output to NumPy.

        Args:
            features: A single output, or a list/tuple of outputs.

        Returns:
            For a list/tuple with exactly one element, that element as NumPy;
            for a longer (or empty) list/tuple, a list of converted elements;
            otherwise the converted value itself.
        """
        if isinstance(features, (list, tuple)):
            return (
                self.to_numpy(features[0])
                if len(features) == 1
                else [self.to_numpy(x) for x in features]
            )
        else:
            return self.to_numpy(features)

    @abstractmethod
    def forward(self, im_batch):
        """Run the model on a preprocessed batch; implemented by subclasses."""
        raise NotImplementedError("This method should be implemented by subclasses.")

    @abstractmethod
    def load_model(self, w):
        """Load weights ``w`` into the backend; implemented by subclasses."""
        raise NotImplementedError("This method should be implemented by subclasses.")

    def download_model(self, w):
        """Fetch ``.pt`` weights via gdown when they are missing locally.

        Args:
            w: Weights location providing ``suffix`` and ``exists()``.

        Raises:
            SystemExit: With status 1 when the ``.pt`` file is missing and the
                registry has no URL for it. Raised directly rather than via
                the ``site``-provided ``exit()`` helper, which is not
                guaranteed to exist (e.g. ``python -S``) and would report
                success after an error was logged.
        """
        if w.suffix == ".pt":
            model_url = ReIDModelRegistry.get_model_url(w)
            if not w.exists() and model_url is not None:
                gdown.download(model_url, str(w), quiet=False)
            elif not w.exists():
                LOGGER.error(
                    f"No URL associated with the chosen StrongSORT weights ({w}). Choose between:"
                )
                ReIDModelRegistry.show_downloadable_models()
                raise SystemExit(1)