Spaces:
Sleeping
Sleeping
File size: 5,174 Bytes
e1832f4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import cv2
import torch
import gdown
import numpy as np
from abc import ABC, abstractmethod
from boxmot.utils import logger as LOGGER
from boxmot.appearance.reid.registry import ReIDModelRegistry
from boxmot.utils.checks import RequirementsChecker
class BaseModelBackend:
    """Common plumbing shared by ReID appearance-model backends.

    The base class resolves (and, if needed, downloads) the weights, builds
    the model through ``ReIDModelRegistry`` and offers the crop / pre-process /
    post-process helpers around the backend-specific ``forward`` call.

    Subclasses must implement ``forward`` and ``load_model``.

    NOTE: the class does not derive from ``abc.ABC``, so the
    ``@abstractmethod`` markers below are not enforced at instantiation time;
    the ``NotImplementedError`` raised in the method bodies is the effective
    guard.
    """

    # Default tensor layout is NCHW. Backends that expect channels-last input
    # override this (class- or instance-level) with ``True``. Having a default
    # here keeps ``inference_preprocess`` from raising AttributeError when a
    # subclass never sets it.
    nhwc = False

    def __init__(self, weights, device, half):
        """Resolve the weights, build the model and load it.

        Args:
            weights: Path-like object (``suffix``, ``exists`` and ``is_file``
                are used) or a list whose first element is such an object.
            device: Object exposing ``.type`` (compared against ``"cpu"``);
                also used as the target device for tensors in ``get_crops``.
            half: When truthy, inputs are cast to float16 before inference.
        """
        self.weights = weights[0] if isinstance(weights, list) else weights
        self.device = device
        self.half = half
        self.model = None
        self.cuda = torch.cuda.is_available() and self.device.type != "cpu"

        # Fetch the checkpoint first so the ``is_file`` check below reflects
        # the result of the download attempt.
        self.download_model(self.weights)
        self.model_name = ReIDModelRegistry.get_model_name(self.weights)

        self.model = ReIDModelRegistry.build_model(
            self.model_name,
            num_classes=ReIDModelRegistry.get_nr_classes(self.weights),
            # Only request pretrained initialisation when no local weight
            # file is available to load.
            pretrained=not (self.weights and self.weights.is_file()),
            use_gpu=device,
        )

        self.checker = RequirementsChecker()
        self.load_model(self.weights)

    def get_crops(self, xyxys, img):
        """Cut, resize and standardise one crop per bounding box.

        Args:
            xyxys: Iterable of NumPy box rows ``(x1, y1, x2, y2)`` in pixels.
            img: Image array indexed as ``img[y, x]``; converted with
                ``cv2.COLOR_BGR2RGB``, i.e. treated as BGR.

        Returns:
            Tensor of shape ``(len(xyxys), 3, 256, 128)`` on ``self.device``,
            scaled to ``[0, 1]`` and standardised per channel.
        """
        h, w = img.shape[:2]
        resize_dims = (128, 256)  # cv2.resize takes (width, height)
        interpolation_method = cv2.INTER_LINEAR
        crop_dtype = torch.half if self.half else torch.float

        # Per-channel mean / std (the standard ImageNet RGB statistics),
        # shaped for broadcasting over an NCHW batch. They are float32, so
        # the standardised batch is promoted to float32 even when
        # ``self.half`` is set; ``inference_preprocess`` casts it back.
        mean_array = torch.tensor([0.485, 0.456, 0.406], device=self.device).view(1, 3, 1, 1)
        std_array = torch.tensor([0.229, 0.224, 0.225], device=self.device).view(1, 3, 1, 1)

        # Preallocate the batch so each crop is written in place.
        crops = torch.empty(
            (len(xyxys), 3, resize_dims[1], resize_dims[0]),
            dtype=crop_dtype,
            device=self.device,
        )

        for i, box in enumerate(xyxys):
            x1, y1, x2, y2 = box.round().astype('int')
            # Clamp to the image while guaranteeing a crop of at least 1x1.
            # Without the lower bound on x2 / y2, a negative value would be
            # interpreted as a "from the end" slice index and select the
            # wrong region, and a box outside the image would produce an
            # empty crop that cv2.resize rejects. Valid boxes are unchanged.
            x1 = min(max(0, x1), w - 1)
            y1 = min(max(0, y1), h - 1)
            x2 = max(min(w, x2), x1 + 1)
            y2 = max(min(h, y2), y1 + 1)
            crop = img[y1:y2, x1:x2]

            # Resize to the model input size, then swap BGR -> RGB.
            crop = cv2.resize(crop, resize_dims, interpolation=interpolation_method)
            crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)

            # HWC uint8 array -> CHW tensor; scaling happens on the batch.
            crop = torch.from_numpy(crop).to(self.device, dtype=crop_dtype)
            crops[i] = torch.permute(crop, (2, 0, 1))

        # Scale to [0, 1] and standardise the whole batch at once.
        crops = crops / 255.0
        crops = (crops - mean_array) / std_array
        return crops

    @torch.no_grad()
    def get_features(self, xyxys, img):
        """Return L2-normalised appearance features, one row per box.

        Args:
            xyxys: NumPy array of boxes ``(x1, y1, x2, y2)``.
            img: Image the boxes refer to.

        Returns:
            NumPy array of unit-length feature rows, or an empty array when
            ``xyxys`` contains no boxes. All-zero feature rows stay zero
            instead of turning into NaN.
        """
        if xyxys.size == 0:
            return np.array([])

        crops = self.get_crops(xyxys, img)
        crops = self.inference_preprocess(crops)
        features = self.forward(crops)
        features = self.inference_postprocess(features)

        norms = np.linalg.norm(features, axis=-1, keepdims=True)
        # Floor the norm at the smallest normal float so a zero vector does
        # not divide by zero; non-degenerate norms are left untouched.
        norms = np.maximum(norms, np.finfo(norms.dtype).tiny)
        return features / norms

    def warmup(self, imgsz=((256, 128, 3),)):
        """Run one throw-away inference pass on non-CPU devices.

        Args:
            imgsz: One-element sequence holding the ``(H, W, C)`` shape of
                the random image used for the warm-up pass.
        """
        if self.device.type == "cpu":
            return

        im = np.random.randint(0, 255, *imgsz, dtype=np.uint8)
        crops = self.get_crops(
            xyxys=np.array([[0, 0, 64, 64], [0, 0, 128, 128]]),
            img=im,
        )
        crops = self.inference_preprocess(crops)
        # The output is discarded, so there is no need to track gradients.
        with torch.no_grad():
            self.forward(crops)  # warmup

    def to_numpy(self, x):
        """Return ``x`` as a NumPy array if it is a tensor, else unchanged."""
        return x.cpu().numpy() if isinstance(x, torch.Tensor) else x

    def inference_preprocess(self, x):
        """Cast and re-layout a batch to what the backend expects.

        Args:
            x: NCHW batch as a ``torch.Tensor`` or ``np.ndarray``; any other
                type is passed through untouched.

        Returns:
            The batch cast to float16 when ``self.half`` is set and permuted
            to NHWC when ``self.nhwc`` is set.
        """
        if self.half:
            if isinstance(x, torch.Tensor):
                if x.dtype != torch.float16:
                    x = x.half()
            elif isinstance(x, np.ndarray):
                if x.dtype != np.float16:
                    x = x.astype(np.float16)

        if self.nhwc:
            if isinstance(x, torch.Tensor):
                x = x.permute(0, 2, 3, 1)  # NCHW -> NHWC
            elif isinstance(x, np.ndarray):
                x = np.transpose(x, (0, 2, 3, 1))  # NCHW -> NHWC
        return x

    def inference_postprocess(self, features):
        """Convert backend output to NumPy.

        A list / tuple holding a single output is unwrapped to that output;
        a longer one is returned as a list of converted outputs.
        """
        if isinstance(features, (list, tuple)):
            if len(features) == 1:
                return self.to_numpy(features[0])
            return [self.to_numpy(x) for x in features]
        return self.to_numpy(features)

    @abstractmethod
    def forward(self, im_batch):
        """Run the backend on a pre-processed batch (subclass hook)."""
        raise NotImplementedError("This method should be implemented by subclasses.")

    @abstractmethod
    def load_model(self, w):
        """Load the weights ``w`` into the backend (subclass hook)."""
        raise NotImplementedError("This method should be implemented by subclasses.")

    def download_model(self, w):
        """Download ``.pt`` weights that are missing locally.

        Args:
            w: Path-like weights location (``suffix`` and ``exists`` are
                used). Non-``.pt`` files are ignored.

        Raises:
            SystemExit: With status 1 when the file is missing and the
                registry has no download URL for it.
        """
        if w.suffix != ".pt":
            return

        model_url = ReIDModelRegistry.get_model_url(w)
        if w.exists():
            return

        if model_url is not None:
            gdown.download(model_url, str(w), quiet=False)
            return

        LOGGER.error(
            f"No URL associated with the chosen StrongSORT weights ({w}). Choose between:"
        )
        ReIDModelRegistry.show_downloadable_models()
        # The ``exit`` builtin comes from ``site`` and is not always present,
        # and it reports success; signal the failure explicitly instead.
        raise SystemExit(1)