import torch
import numpy as np
from pathlib import Path
from collections import OrderedDict, namedtuple

from boxmot.utils import logger as LOGGER
from boxmot.appearance.backends.base_backend import BaseModelBackend


class TensorRTBackend(BaseModelBackend):
    """Appearance-model backend that executes a serialized TensorRT engine.

    Both the TensorRT 10 named-tensor API (``num_io_tensors``) and the older
    binding-index API (``num_bindings``) are supported. ``forward`` feeds the
    engine input named ``"images"`` and reads the output named ``"output"``.
    """

    def __init__(self, weights, device, half):
        # Initialised before the base initializer so these attributes exist even
        # if BaseModelBackend.__init__ reaches load_model/forward first.
        # NOTE(review): assumption about the base class — confirm.
        self.is_trt10 = False
        self.is_dynamic = False
        super().__init__(weights, device, half)
        self.nhwc = False
        self.half = half
        self.device = device
        self.weights = weights
        self.fp16 = False  # set to True by load_model when an input tensor is float16
        self.load_model(self.weights)

    def load_model(self, w):
        """Deserialize the TensorRT engine stored at ``w`` and allocate I/O buffers.

        Populates:
            ``self.model_``: the deserialized engine.
            ``self.context``: its execution context.
            ``self.bindings``: name -> ``Binding`` holding a pre-allocated device
                tensor, in engine I/O index order.
            ``self.binding_addrs``: name -> device pointer, in the same order;
                its values are the list passed to ``execute_v2``.
            ``self.is_trt10`` / ``self.is_dynamic`` / ``self.fp16``: engine traits.

        Dynamic input dimensions are fixed to the "opt" shape of optimization
        profile 0, so buffers are allocated at that size.

        Raises:
            ImportError: if the ``tensorrt`` package cannot be imported.
            ValueError: if no CUDA device is available.
        """
        LOGGER.info(f"Loading {w} for TensorRT inference...")
        self.checker.check_packages(("nvidia-tensorrt",))
        try:
            import tensorrt as trt  # TensorRT library
        except ImportError as err:
            raise ImportError("Please install tensorrt to use this backend.") from err

        # TensorRT engines only run on CUDA: promote a CPU device to cuda:0.
        if self.device.type == "cpu":
            if torch.cuda.is_available():
                self.device = torch.device("cuda:0")
            else:
                raise ValueError("CUDA device not available for TensorRT inference.")

        Binding = namedtuple("Binding", ("name", "dtype", "shape", "data", "ptr"))
        logger = trt.Logger(trt.Logger.INFO)

        # Deserialize the engine
        with open(w, "rb") as f, trt.Runtime(logger) as runtime:
            self.model_ = runtime.deserialize_cuda_engine(f.read())

        # Execution context
        self.context = self.model_.create_execution_context()
        self.bindings = OrderedDict()
        # TensorRT 10 dropped num_bindings in favour of the named I/O tensor API.
        self.is_trt10 = not hasattr(self.model_, "num_bindings")
        self.is_dynamic = False
        num = range(self.model_.num_io_tensors) if self.is_trt10 else range(self.model_.num_bindings)

        # Parse bindings
        for index in num:
            if self.is_trt10:
                name = self.model_.get_tensor_name(index)
                dtype = trt.nptype(self.model_.get_tensor_dtype(name))
                is_input = self.model_.get_tensor_mode(name) == trt.TensorIOMode.INPUT
                if is_input and -1 in tuple(self.model_.get_tensor_shape(name)):
                    self.is_dynamic = True
                    # Profile 0 yields (min, opt, max); allocate at the opt shape.
                    opt_shape = self.model_.get_tensor_profile_shape(name, 0)[1]
                    self.context.set_input_shape(name, tuple(opt_shape))
                if is_input and dtype == np.float16:
                    self.fp16 = True
                shape = tuple(self.context.get_tensor_shape(name))
            else:
                name = self.model_.get_binding_name(index)
                dtype = trt.nptype(self.model_.get_binding_dtype(index))
                is_input = self.model_.binding_is_input(index)
                # Handle dynamic shapes: pin them to the opt shape of profile 0.
                if is_input and -1 in self.model_.get_binding_shape(index):
                    self.is_dynamic = True
                    profile_index = 0
                    _, opt_shape, _ = self.model_.get_profile_shape(profile_index, index)
                    self.context.set_binding_shape(index, opt_shape)
                if is_input and dtype == np.float16:
                    self.fp16 = True
                shape = tuple(self.context.get_binding_shape(index))
            data = torch.from_numpy(np.empty(shape, dtype=dtype)).to(self.device)
            self.bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
        self.binding_addrs = OrderedDict((n, d.ptr) for n, d in self.bindings.items())

    def _set_input_shape(self, shape):
        """Reshape the dynamic ``"images"`` input and resize ``"output"`` to match."""
        shape = tuple(shape)
        if self.is_trt10:
            self.context.set_input_shape("images", shape)
            output_shape = tuple(self.context.get_tensor_shape("output"))
        else:
            i_in = self.model_.get_binding_index("images")
            i_out = self.model_.get_binding_index("output")
            self.context.set_binding_shape(i_in, shape)
            output_shape = tuple(self.context.get_binding_shape(i_out))
        self.bindings["images"] = self.bindings["images"]._replace(shape=shape)

        out = self.bindings["output"]
        out.data.resize_(output_shape)
        # resize_ may reallocate storage, so refresh the pointer handed to
        # execute_v2; otherwise TensorRT could write through a stale address.
        # Binding.shape is deliberately left at the originally allocated size:
        # forward() uses it as the maximum rows per engine call.
        self.bindings["output"] = out._replace(ptr=int(out.data.data_ptr()))
        self.binding_addrs["output"] = self.bindings["output"].ptr

    def forward(self, im_batch):
        """Run the engine on ``im_batch`` and return one feature row per input.

        The batch is processed in chunks no larger than the batch dimension
        allocated for ``"output"`` in ``load_model``.
        - Dynamic-shape engines: each chunk's shape is set on the execution context.
        - Static engines: a short final chunk is zero-padded to the engine batch
          size, and the padded rows are removed from the result.

        Args:
            im_batch: input tensor whose first dimension is the batch. Presumably
                already preprocessed, on ``self.device``, and in the engine's
                dtype. NOTE(review): confirm against the caller.

        Returns:
            torch.Tensor: features of the ``im_batch.shape[0]`` inputs,
            concatenated along dim 0. An empty input yields an empty tensor.
        """
        n_inputs = im_batch.shape[0]
        out_data = self.bindings["output"].data
        if n_inputs == 0:
            # torch.cat would raise on an empty result list.
            return out_data.new_empty((0, *out_data.shape[1:]))

        # Maximum rows per engine call: the originally allocated output batch.
        chunk_size = self.bindings["output"].shape[0]
        resultant_features = []
        for chunk in torch.split(im_batch, chunk_size):
            # TensorRT reads the raw device pointer, so memory must be dense.
            chunk = chunk.contiguous()
            engine_shape = self.bindings["images"].shape
            if self.is_dynamic:
                if chunk.shape != engine_shape:
                    self._set_input_shape(chunk.shape)
            else:
                # Static engine: it always consumes exactly engine_shape[0] rows.
                pad = engine_shape[0] - chunk.shape[0]
                if pad > 0:
                    filler = chunk.new_zeros((pad, *chunk.shape[1:]))
                    chunk = torch.cat((chunk, filler), dim=0)

            s = self.bindings["images"].shape
            assert chunk.shape == s, f"Input size {chunk.shape} does not match model size {s}"
            self.binding_addrs["images"] = int(chunk.data_ptr())

            # Execute inference
            self.context.execute_v2(list(self.binding_addrs.values()))
            resultant_features.append(self.bindings["output"].data.clone())

        if len(resultant_features) == 1:
            features = resultant_features[0]
        else:
            features = torch.cat(resultant_features, dim=0)
        # Drop any zero-padded rows added for static engines.
        return features[:n_inputs]