BechusRantus's picture
Upload folder using huggingface_hub
7134ce7 verified
import warnings
from collections import namedtuple
from functools import partial
from pathlib import Path
from typing import List, Optional, Union
import numpy as np
import onnxruntime
# TensorRT is an optional dependency. If importing it fails for any reason,
# ``trt`` is bound to ``None`` so that this module can still be imported;
# TRTWrapper references ``trt`` and cannot work in that case.
try:
    import tensorrt as trt
except Exception:
    trt = None
import torch

# Install a global filter that ignores every DeprecationWarning.
warnings.filterwarnings(action='ignore', category=DeprecationWarning)
class TRTWrapper(torch.nn.Module):
    """Run a serialized TensorRT engine behind a ``torch.nn.Module`` API.

    The engine file is deserialized once in ``__init__``. ``forward`` passes
    the raw data pointers of the (contiguous) input tensors and of the output
    tensors to TensorRT, launches the engine on a private CUDA stream, waits
    for that stream and returns the outputs as torch tensors.

    Args:
        weight: Path to the serialized engine; the suffix must be
            ``.engine`` or ``.plan``.
        device: Device used for the private CUDA stream and for the output
            tensors. A device string or an integer CUDA index is accepted as
            well. ``None`` selects the current CUDA device.

    Raises:
        AssertionError: If ``weight`` does not exist or has another suffix.
        ImportError: If the ``tensorrt`` package could not be imported.
        RuntimeError: If the engine file cannot be deserialized.
    """

    # TensorRT dtype -> torch dtype. Class-level, hence shared by all
    # instances; filled by ``__update_mapping`` once TensorRT is known to be
    # importable.
    dtype_mapping = {}

    def __init__(self, weight: Union[str, Path],
                 device: Optional[torch.device]):
        super().__init__()
        weight = Path(weight) if isinstance(weight, str) else weight
        assert weight.exists() and weight.suffix in ('.engine', '.plan'), (
            f'{weight} must be an existing .engine or .plan file')
        if trt is None:
            # The module-level import is best-effort; fail with a readable
            # message instead of "NoneType has no attribute ..." later on.
            raise ImportError(
                'TRTWrapper requires the `tensorrt` package, which could '
                'not be imported')

        if isinstance(device, str):
            device = torch.device(device)
        elif isinstance(device, int):
            device = torch.device(f'cuda:{device}')
        elif device is None:
            # torch.empty(device=None) would allocate the outputs on the
            # default (CPU) device, whose pointers the engine cannot write
            # to. Pin the wrapper to the current CUDA device instead.
            device = torch.device('cuda', torch.cuda.current_device())

        self.weight = weight
        self.device = device
        # Private stream on which the engine is executed in ``forward``.
        self.stream = torch.cuda.Stream(device=device)
        self.__update_mapping()
        self.__init_engine()
        self.__init_bindings()

    def __update_mapping(self):
        """Register the TensorRT -> torch dtype pairs in ``dtype_mapping``."""
        self.dtype_mapping.update({
            trt.bool: torch.bool,
            trt.int8: torch.int8,
            trt.int32: torch.int32,
            trt.float16: torch.float16,
            trt.float32: torch.float32
        })

    def __init_engine(self):
        """Deserialize the engine and collect its binding layout.

        Sets ``logger``, ``log``, ``model``, ``context``, ``input_names``,
        ``output_names``, ``num_inputs``, ``num_outputs``, ``num_bindings``,
        ``is_dynamic`` and the pointer list ``bindings``.
        """
        logger = trt.Logger(trt.Logger.ERROR)
        self.log = partial(logger.log, trt.Logger.ERROR)
        trt.init_libnvinfer_plugins(logger, namespace='')
        self.logger = logger

        with trt.Runtime(logger) as runtime:
            model = runtime.deserialize_cuda_engine(self.weight.read_bytes())
        if model is None:
            raise RuntimeError(
                f'Failed to deserialize TensorRT engine from {self.weight}')

        context = model.create_execution_context()

        names = [model.get_binding_name(i) for i in range(model.num_bindings)]

        num_inputs = sum(
            1 for i in range(model.num_bindings) if model.binding_is_input(i))
        num_outputs = model.num_bindings - num_inputs

        # The name slicing below (and the index arithmetic in the other
        # methods) relies on all input bindings preceding the output
        # bindings; ``__init_bindings`` asserts the resulting name order.
        # The engine is dynamic as soon as ANY input has an unspecified
        # (-1) dimension, not only the first binding.
        self.is_dynamic = any(
            -1 in model.get_binding_shape(i) for i in range(num_inputs))

        self.model = model
        self.context = context
        self.input_names = names[:num_inputs]
        self.output_names = names[num_inputs:]
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.num_bindings = num_inputs + num_outputs
        # One device pointer slot per binding, refreshed on every forward.
        self.bindings: List[int] = [0] * self.num_bindings

    def __init_bindings(self):
        """Record name, torch dtype and shape of every binding.

        Fills ``inputs_info`` and ``outputs_info`` with ``Binding`` tuples.
        For a static engine the output tensors are allocated once here
        (``output_tensor``) and reused by every ``forward`` call.
        """
        Binding = namedtuple('Binding', ('name', 'dtype', 'shape'))

        def describe(index, name):
            # Guard against a binding order that differs from the one
            # assumed when the names were split into inputs and outputs.
            assert self.model.get_binding_name(index) == name
            dtype = self.dtype_mapping[self.model.get_binding_dtype(index)]
            shape = tuple(self.model.get_binding_shape(index))
            return Binding(name, dtype, shape)

        inputs_info = [
            describe(i, name) for i, name in enumerate(self.input_names)
        ]
        outputs_info = [
            describe(i + self.num_inputs, name)
            for i, name in enumerate(self.output_names)
        ]

        self.inputs_info = inputs_info
        self.outputs_info = outputs_info
        if not self.is_dynamic:
            self.output_tensor = [
                torch.empty(o.shape, dtype=o.dtype, device=self.device)
                for o in outputs_info
            ]

    def forward(self, *inputs):
        """Execute the engine on ``inputs`` and return the output tensors.

        Args:
            *inputs: One tensor per engine input, in binding order. Their
                data pointers are handed to TensorRT unchanged.
                NOTE(review): nothing here checks device or dtype of the
                inputs -- presumably callers pass tensors that already
                match the engine; confirm.

        Returns:
            Tuple with one tensor per engine output. For a static engine
            these are the SAME preallocated tensors on every call, so a
            later call overwrites earlier results.

        Raises:
            AssertionError: If the number of inputs differs from the number
                of engine inputs.
        """
        assert len(inputs) == self.num_inputs, (
            f'expected {self.num_inputs} inputs, got {len(inputs)}')

        # Keep the contiguous copies referenced until execution finished.
        contiguous_inputs: List[torch.Tensor] = [
            tensor.contiguous() for tensor in inputs
        ]

        for i, tensor in enumerate(contiguous_inputs):
            self.bindings[i] = tensor.data_ptr()
            if self.is_dynamic:
                self.context.set_binding_shape(i, tuple(tensor.shape))

        # create output tensors
        outputs: List[torch.Tensor] = []
        for i in range(self.num_outputs):
            j = i + self.num_inputs
            if self.is_dynamic:
                # Output shapes are known only after the input shapes have
                # been set on the context above.
                shape = tuple(self.context.get_binding_shape(j))
                output = torch.empty(
                    size=shape,
                    dtype=self.outputs_info[i].dtype,
                    device=self.device)
            else:
                output = self.output_tensor[i]
            outputs.append(output)
            self.bindings[j] = output.data_ptr()

        self.context.execute_async_v2(self.bindings, self.stream.cuda_stream)
        # Block until the engine finished so the outputs are safe to read.
        self.stream.synchronize()

        return tuple(outputs)
class ORTWrapper(torch.nn.Module):
    """Run an ONNX model through ONNX Runtime behind a ``torch.nn.Module`` API.

    ``forward`` converts the input tensors to NumPy arrays on the host, runs
    the inference session and moves the results to ``device`` as torch
    tensors.

    Args:
        weight: Path to the model file; the suffix must be ``.onnx``.
        device: Device the outputs are moved to. A CUDA device additionally
            puts ``CUDAExecutionProvider`` in front of the CPU provider.
            A device string or an integer CUDA index is accepted as well.
            ``None`` selects the CPU.

    Raises:
        AssertionError: If ``weight`` does not exist or is not ``.onnx``.
    """

    def __init__(self, weight: Union[str, Path],
                 device: Optional[torch.device]):
        super().__init__()
        weight = Path(weight) if isinstance(weight, str) else weight
        assert weight.exists() and weight.suffix == '.onnx', (
            f'{weight} must be an existing .onnx file')

        if isinstance(device, str):
            device = torch.device(device)
        elif isinstance(device, int):
            device = torch.device(f'cuda:{device}')
        elif device is None:
            # The signature allows None, but the session setup reads
            # ``device.type``; fall back to the CPU instead of crashing.
            device = torch.device('cpu')

        self.weight = weight
        self.device = device
        self.__init_session()
        self.__init_bindings()

    def __init_session(self):
        """Create the inference session with the providers for ``device``."""
        providers = ['CPUExecutionProvider']
        if 'cuda' in self.device.type:
            # Prefer CUDA but keep the CPU provider as a fallback.
            providers.insert(0, 'CUDAExecutionProvider')

        self.session = onnxruntime.InferenceSession(
            str(self.weight), providers=providers)

    def __init_bindings(self):
        """Record name, type and shape of the session's inputs and outputs.

        Fills ``inputs_info`` / ``outputs_info`` with ``Binding`` tuples
        (the ``dtype`` field holds the type exactly as the session reports
        it) and sets ``num_inputs`` and ``is_dynamic``. The model counts as
        dynamic when any input dimension is not a plain ``int``.
        """
        Binding = namedtuple('Binding', ('name', 'dtype', 'shape'))

        session_inputs = self.session.get_inputs()
        session_outputs = self.session.get_outputs()

        self.is_dynamic = any(
            not isinstance(dim, int)
            for tensor in session_inputs
            for dim in tensor.shape)

        self.inputs_info = [
            Binding(tensor.name, tensor.type, tuple(tensor.shape))
            for tensor in session_inputs
        ]
        self.outputs_info = [
            Binding(tensor.name, tensor.type, tuple(tensor.shape))
            for tensor in session_outputs
        ]
        self.num_inputs = len(self.inputs_info)

    def forward(self, *inputs):
        """Run the session on ``inputs`` and return the output tensors.

        Args:
            *inputs: One tensor per model input, in the order reported by
                the session. They are copied to the host; tensors that
                require grad are detached first.

        Returns:
            Tuple with one tensor per model output, located on ``device``.

        Raises:
            AssertionError: If the number of inputs is wrong, or if the
                model is static and an input shape does not match.
        """
        assert len(inputs) == self.num_inputs, (
            f'expected {self.num_inputs} inputs, got {len(inputs)}')

        # ``Tensor.numpy()`` refuses tensors that require grad, so detach
        # before converting; inference output carries no graph anyway.
        contiguous_inputs: List[np.ndarray] = [
            tensor.detach().contiguous().cpu().numpy() for tensor in inputs
        ]

        if not self.is_dynamic:
            # make sure input shape is right for static input shape
            for array, info in zip(contiguous_inputs, self.inputs_info):
                assert array.shape == info.shape, (
                    f'input {info.name!r} has shape {array.shape}, '
                    f'expected {info.shape}')

        feeds = {
            info.name: array
            for info, array in zip(self.inputs_info, contiguous_inputs)
        }
        outputs = self.session.run([o.name for o in self.outputs_info], feeds)

        return tuple(torch.from_numpy(o).to(self.device) for o in outputs)