| import warnings
|
| from collections import namedtuple
|
| from functools import partial
|
| from pathlib import Path
|
| from typing import List, Optional, Union
|
|
|
| import numpy as np
|
| import onnxruntime
|
|
|
| try:
|
| import tensorrt as trt
|
| except Exception:
|
| trt = None
|
| import torch
|
|
|
# Silence DeprecationWarning (the TensorRT binding-index API used below emits
# them). NOTE: this installs a process-wide filter as a side effect of
# importing this module, not just for the code in this file.
warnings.simplefilter('ignore', DeprecationWarning)
|
|
|
|
|
class TRTWrapper(torch.nn.Module):
    """Run a serialized TensorRT engine behind a ``torch.nn.Module`` API.

    Inputs and outputs are CUDA ``torch.Tensor`` objects whose device
    pointers are handed directly to the TensorRT execution context.

    Note:
        For engines whose bindings are all static, the output tensors are
        allocated once and reused by every ``forward`` call. Clone the
        returned tensors if you need to keep them across calls.
    """

    # TensorRT dtype -> torch dtype. Filled in ``__update_mapping`` because
    # ``trt`` may be unavailable when this module is imported. The dict is
    # class-level (shared by all instances); every instance writes the same
    # entries.
    dtype_mapping = {}

    def __init__(self, weight: Union[str, Path],
                 device: Optional[Union[torch.device, str, int]]):
        """Load the engine and prepare bindings.

        Args:
            weight: Path to a serialized engine (``.engine`` or ``.plan``).
            device: CUDA device to run on, given as a ``torch.device``, a
                device string (e.g. ``'cuda:0'``) or a CUDA index. ``None``
                selects the current CUDA device.

        Raises:
            AssertionError: if ``weight`` does not exist or has the wrong
                suffix.
            ImportError: if the ``tensorrt`` package could not be imported.
            RuntimeError: if the engine cannot be deserialized.
        """
        super().__init__()
        weight = Path(weight) if isinstance(weight, str) else weight
        assert weight.exists() and weight.suffix in ('.engine', '.plan'), \
            f'expected an existing .engine/.plan file, got {weight}'
        if trt is None:
            raise ImportError(
                'TensorRT could not be imported; install `tensorrt` to use '
                'TRTWrapper.')
        self.weight = weight
        self.device = self._normalize_device(device)
        # Private stream on which the engine is enqueued.
        self.stream = torch.cuda.Stream(device=self.device)
        self.__update_mapping()
        self.__init_engine()
        self.__init_bindings()

    @staticmethod
    def _normalize_device(device) -> torch.device:
        """Convert a device spec (str / int / torch.device / None) to a
        ``torch.device``; ``None`` becomes the current CUDA device."""
        if device is None:
            # TensorRT only runs on CUDA. Without this, torch would allocate
            # the output buffers on the CPU and TensorRT would write to host
            # pointers.
            return torch.device('cuda', torch.cuda.current_device())
        if isinstance(device, str):
            return torch.device(device)
        if isinstance(device, int):
            return torch.device(f'cuda:{device}')
        return device

    def __update_mapping(self):
        """Register the TensorRT -> torch dtype pairs this wrapper supports."""
        self.dtype_mapping.update({
            trt.bool: torch.bool,
            trt.int8: torch.int8,
            trt.int32: torch.int32,
            trt.float16: torch.float16,
            trt.float32: torch.float32,
        })

    def __init_engine(self):
        """Deserialize the engine, create an execution context and record
        the binding layout (names, counts and per-direction binding indices).
        """
        logger = trt.Logger(trt.Logger.ERROR)
        self.log = partial(logger.log, trt.Logger.ERROR)
        trt.init_libnvinfer_plugins(logger, namespace='')
        self.logger = logger

        # Held on ``self`` so the runtime stays alive at least as long as the
        # engine deserialized from it (defensive; costs nothing).
        self.runtime = trt.Runtime(logger)
        model = self.runtime.deserialize_cuda_engine(self.weight.read_bytes())
        if model is None:
            raise RuntimeError(
                f'failed to deserialize TensorRT engine {self.weight}')
        context = model.create_execution_context()

        num_bindings = model.num_bindings
        names = [model.get_binding_name(i) for i in range(num_bindings)]
        # Split binding indices by direction instead of assuming every input
        # precedes every output in the engine's binding order.
        input_indices = [
            i for i in range(num_bindings) if model.binding_is_input(i)
        ]
        output_indices = [
            i for i in range(num_bindings) if not model.binding_is_input(i)
        ]

        # Any unresolved (-1) dimension on any binding means shapes must be
        # set per call, and outputs cannot be preallocated.
        self.is_dynamic = any(
            -1 in model.get_binding_shape(i) for i in range(num_bindings))

        self.model = model
        self.context = context
        self.input_indices = input_indices
        self.output_indices = output_indices
        self.input_names = [names[i] for i in input_indices]
        self.output_names = [names[i] for i in output_indices]
        self.num_inputs = len(input_indices)
        self.num_outputs = len(output_indices)
        self.num_bindings = num_bindings
        # Device pointers indexed by TensorRT binding index; filled in forward.
        self.bindings: List[int] = [0] * num_bindings

    def __init_bindings(self):
        """Describe every binding as ``Binding(name, dtype, shape)``, where
        ``dtype`` is a torch dtype. Preallocate outputs for static engines.
        """
        Binding = namedtuple('Binding', ('name', 'dtype', 'shape'))

        def describe(index):
            return Binding(
                self.model.get_binding_name(index),
                self.dtype_mapping[self.model.get_binding_dtype(index)],
                tuple(self.model.get_binding_shape(index)))

        self.inputs_info = [describe(i) for i in self.input_indices]
        self.outputs_info = [describe(i) for i in self.output_indices]
        if not self.is_dynamic:
            self.output_tensor = [
                torch.empty(o.shape, dtype=o.dtype, device=self.device)
                for o in self.outputs_info
            ]

    def forward(self, *inputs):
        """Run the engine on ``inputs`` (one tensor per input binding, in
        input-binding order).

        Returns:
            A tuple of output tensors on ``self.device``, in output-binding
            order. For static engines these are reused buffers (see class
            note).

        Raises:
            AssertionError: on a wrong number of inputs, or a shape mismatch
                for static engines.
            RuntimeError: if TensorRT fails to enqueue the inference.
        """
        assert len(inputs) == self.num_inputs, \
            f'expected {self.num_inputs} inputs, got {len(inputs)}'

        # TensorRT reads raw device pointers, so each input must be on the
        # engine's device, use the binding's dtype and be contiguous.
        contiguous_inputs: List[torch.Tensor] = [
            tensor.to(device=self.device, dtype=info.dtype).contiguous()
            for tensor, info in zip(inputs, self.inputs_info)
        ]

        for tensor, info, index in zip(contiguous_inputs, self.inputs_info,
                                       self.input_indices):
            self.bindings[index] = tensor.data_ptr()
            if self.is_dynamic:
                self.context.set_binding_shape(index, tuple(tensor.shape))
            else:
                assert tuple(tensor.shape) == info.shape, \
                    f'input {info.name}: expected shape {info.shape}, ' \
                    f'got {tuple(tensor.shape)}'

        outputs: List[torch.Tensor] = []
        for k, index in enumerate(self.output_indices):
            if self.is_dynamic:
                # Output shapes are only known once input shapes are set.
                shape = tuple(self.context.get_binding_shape(index))
                output = torch.empty(
                    size=shape,
                    dtype=self.outputs_info[k].dtype,
                    device=self.device)
            else:
                output = self.output_tensor[k]
            outputs.append(output)
            self.bindings[index] = output.data_ptr()

        # Inputs (and fresh outputs) were produced on the caller's current
        # stream; order the private stream after that work before TensorRT
        # touches the buffers.
        self.stream.wait_stream(torch.cuda.current_stream(self.device))
        if not self.context.execute_async_v2(self.bindings,
                                             self.stream.cuda_stream):
            raise RuntimeError('TensorRT failed to enqueue inference')
        self.stream.synchronize()

        return tuple(outputs)
|
|
|
|
|
class ORTWrapper(torch.nn.Module):
    """Run an ONNX model through onnxruntime behind a ``torch.nn.Module`` API.

    Input tensors may live on any device. They are copied to host numpy
    arrays for onnxruntime, and results are returned as torch tensors on
    ``self.device``.
    """

    def __init__(self, weight: Union[str, Path],
                 device: Optional[Union[torch.device, str, int]]):
        """Create the inference session and collect binding metadata.

        Args:
            weight: Path to an ``.onnx`` file.
            device: ``torch.device``, device string or CUDA index. CUDA
                devices enable the CUDA execution provider (CPU stays as
                fallback). ``None`` means CPU.

        Raises:
            AssertionError: if ``weight`` does not exist or is not ``.onnx``.
        """
        super().__init__()
        weight = Path(weight) if isinstance(weight, str) else weight
        assert weight.exists() and weight.suffix == '.onnx', \
            f'expected an existing .onnx file, got {weight}'
        self.weight = weight
        self.device = self._normalize_device(device)
        self.__init_session()
        self.__init_bindings()

    @staticmethod
    def _normalize_device(device) -> torch.device:
        """Convert a device spec (str / int / torch.device / None) to a
        ``torch.device``; ``None`` becomes the CPU."""
        if device is None:
            return torch.device('cpu')
        if isinstance(device, str):
            return torch.device(device)
        if isinstance(device, int):
            return torch.device(f'cuda:{device}')
        return device

    @staticmethod
    def _build_providers(device: torch.device) -> list:
        """Return the onnxruntime provider list for ``device``.

        The CPU provider is always kept as a fallback. For an indexed CUDA
        device, the index is forwarded as ``device_id``; otherwise
        onnxruntime would use its default GPU regardless of the requested one.
        """
        providers = ['CPUExecutionProvider']
        if device.type == 'cuda':
            if device.index is None:
                providers.insert(0, 'CUDAExecutionProvider')
            else:
                providers.insert(
                    0, ('CUDAExecutionProvider', {'device_id': device.index}))
        return providers

    def __init_session(self):
        """Create the onnxruntime session for ``self.weight``."""
        self.session = onnxruntime.InferenceSession(
            str(self.weight), providers=self._build_providers(self.device))

    def __init_bindings(self):
        """Record ``Binding(name, dtype, shape)`` for each input and output.

        ``dtype`` is the type string reported by onnxruntime (``NodeArg.type``),
        not a torch dtype. ``shape`` may contain non-int entries for dynamic
        dimensions.
        """
        Binding = namedtuple('Binding', ('name', 'dtype', 'shape'))
        inputs_info = [
            Binding(arg.name, arg.type, tuple(arg.shape))
            for arg in self.session.get_inputs()
        ]
        outputs_info = [
            Binding(arg.name, arg.type, tuple(arg.shape))
            for arg in self.session.get_outputs()
        ]
        # Any non-int input dimension (symbolic name / unknown) means the
        # model accepts variable shapes.
        self.is_dynamic = any(
            not isinstance(dim, int)
            for binding in inputs_info for dim in binding.shape)
        self.inputs_info = inputs_info
        self.outputs_info = outputs_info
        self.num_inputs = len(inputs_info)
        self.num_outputs = len(outputs_info)

    def forward(self, *inputs):
        """Run the model on ``inputs`` (one tensor per model input, in order).

        Returns:
            A tuple of output tensors on ``self.device``, in model output
            order.

        Raises:
            AssertionError: on a wrong number of inputs, or a shape mismatch
                for static models.
        """
        assert len(inputs) == self.num_inputs, \
            f'expected {self.num_inputs} inputs, got {len(inputs)}'

        # ``detach`` lets tensors that require grad be converted to numpy.
        contiguous_inputs: List[np.ndarray] = [
            tensor.detach().contiguous().cpu().numpy() for tensor in inputs
        ]

        if not self.is_dynamic:
            for array, info in zip(contiguous_inputs, self.inputs_info):
                assert array.shape == info.shape, \
                    f'input {info.name}: expected shape {info.shape}, ' \
                    f'got {array.shape}'

        feed = {
            info.name: array
            for info, array in zip(self.inputs_info, contiguous_inputs)
        }
        outputs = self.session.run([o.name for o in self.outputs_info], feed)

        return tuple(torch.from_numpy(o).to(self.device) for o in outputs)
|
|
|