import torch
import numpy as np
from pathlib import Path
from collections import OrderedDict, namedtuple

from boxmot.utils import logger as LOGGER
from boxmot.appearance.backends.base_backend import BaseModelBackend


class TensorRTBackend(BaseModelBackend):

    def __init__(self, weights, device, half):
        self.is_trt10 = False
        super().__init__(weights, device, half)
        self.nhwc = False
        self.half = half
        self.device = device
        self.weights = weights
        self.fp16 = False  # Will be updated in load_model
        self.load_model(self.weights)
    def load_model(self, w):
        LOGGER.info(f"Loading {w} for TensorRT inference...")
        self.checker.check_packages(("nvidia-tensorrt",))
        try:
            import tensorrt as trt  # TensorRT library
        except ImportError:
            raise ImportError("Please install tensorrt to use this backend.")

        if self.device.type == "cpu":
            if torch.cuda.is_available():
                self.device = torch.device("cuda:0")
            else:
                raise ValueError("CUDA device not available for TensorRT inference.")

        Binding = namedtuple("Binding", ("name", "dtype", "shape", "data", "ptr"))
        logger = trt.Logger(trt.Logger.INFO)

        # Deserialize the engine
        with open(w, "rb") as f, trt.Runtime(logger) as runtime:
            self.model_ = runtime.deserialize_cuda_engine(f.read())

        # Execution context
        self.context = self.model_.create_execution_context()
        self.bindings = OrderedDict()

        # TensorRT >= 10 removed the binding-index API in favour of named I/O tensors
        self.is_trt10 = not hasattr(self.model_, "num_bindings")
        num = range(self.model_.num_io_tensors) if self.is_trt10 else range(self.model_.num_bindings)

        # Parse bindings
        for index in num:
            if self.is_trt10:
                name = self.model_.get_tensor_name(index)
                dtype = trt.nptype(self.model_.get_tensor_dtype(name))
                is_input = self.model_.get_tensor_mode(name) == trt.TensorIOMode.INPUT
                if is_input and -1 in tuple(self.model_.get_tensor_shape(name)):
                    self.context.set_input_shape(name, tuple(self.model_.get_tensor_profile_shape(name, 0)[1]))
                if is_input and dtype == np.float16:
                    self.fp16 = True
                shape = tuple(self.context.get_tensor_shape(name))
            else:
                name = self.model_.get_binding_name(index)
                dtype = trt.nptype(self.model_.get_binding_dtype(index))
                is_input = self.model_.binding_is_input(index)
                # Handle dynamic shapes
                if is_input and -1 in self.model_.get_binding_shape(index):
                    profile_index = 0
                    min_shape, opt_shape, max_shape = self.model_.get_profile_shape(profile_index, index)
                    self.context.set_binding_shape(index, opt_shape)
                if is_input and dtype == np.float16:
                    self.fp16 = True
                shape = tuple(self.context.get_binding_shape(index))

            data = torch.from_numpy(np.empty(shape, dtype=dtype)).to(self.device)
            self.bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))

        self.binding_addrs = OrderedDict((n, d.ptr) for n, d in self.bindings.items())
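        # binding_addrs maps each I/O tensor name to a raw device pointer; forward()
        # repoints the "images" entry at the current sub-batch before execute_v2().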
    def forward(self, im_batch):
        temp_im_batch = im_batch.clone()
        batch_array = []
        inp_batch = im_batch.shape[0]
        out_batch = self.bindings["output"].shape[0]
        resultant_features = []

        # Split the input into sub-batches no larger than the engine's output batch size
        while inp_batch > out_batch:
            batch_array.append(temp_im_batch[:out_batch])
            temp_im_batch = temp_im_batch[out_batch:]
            inp_batch = temp_im_batch.shape[0]
        if temp_im_batch.shape[0] > 0:
            batch_array.append(temp_im_batch)

        for temp_batch in batch_array:
            # Adjust for dynamic shapes
            if temp_batch.shape != self.bindings["images"].shape:
                if self.is_trt10:
                    self.context.set_input_shape("images", temp_batch.shape)
                    self.bindings["images"] = self.bindings["images"]._replace(shape=temp_batch.shape)
                    self.bindings["output"].data.resize_(tuple(self.context.get_tensor_shape("output")))
                else:
                    i_in = self.model_.get_binding_index("images")
                    i_out = self.model_.get_binding_index("output")
                    self.context.set_binding_shape(i_in, temp_batch.shape)
                    self.bindings["images"] = self.bindings["images"]._replace(shape=temp_batch.shape)
                    output_shape = tuple(self.context.get_binding_shape(i_out))
                    self.bindings["output"].data.resize_(output_shape)

            s = self.bindings["images"].shape
            assert temp_batch.shape == s, f"Input size {temp_batch.shape} does not match model size {s}"
            self.binding_addrs["images"] = int(temp_batch.data_ptr())

            # Execute inference
            self.context.execute_v2(list(self.binding_addrs.values()))
            features = self.bindings["output"].data
            resultant_features.append(features.clone())

        if len(resultant_features) == 1:
            return resultant_features[0]
        else:
            rslt_features = torch.cat(resultant_features, dim=0)
            rslt_features = rslt_features[:im_batch.shape[0]]
            return rslt_features
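

# Minimal usage sketch (not part of the original backend): the engine filename, crop
# size and batch size below are illustrative assumptions, and calling forward() directly
# bypasses whatever preprocessing the BaseModelBackend pipeline normally applies.
if __name__ == "__main__":
    # Hypothetical re-ID engine exported for TensorRT; the "images" binding of the
    # engine dictates the real input layout.
    engine_path = Path("osnet_x0_25_msmt17.engine")
    backend = TensorRTBackend(engine_path, torch.device("cuda:0"), half=True)

    # Dummy batch of person crops in NCHW layout; replace with real preprocessed crops.
    dtype = torch.float16 if backend.fp16 else torch.float32
    crops = torch.zeros((8, 3, 256, 128), device="cuda:0", dtype=dtype)

    features = backend.forward(crops)  # -> (8, embedding_dim) appearance features
    print(features.shape)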