# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorRT Engine wrapper for GR00T inference.
Loads serialized TRT engines, manages input/output tensor bindings, and
executes inference. Supports dynamic shapes and BF16/FP16/FP32 dtypes.
"""
import atexit
import ctypes
import os
import tensorrt as trt
import torch
def torch_type(trt_type):
    """Map a TensorRT data type onto the matching PyTorch dtype.

    Args:
        trt_type: A TensorRT data type such as ``trt.float32``.

    Returns:
        The ``torch.dtype`` registered for ``trt_type`` in the lookup table.

    Raises:
        TypeError: If ``trt_type`` has no entry in the lookup table.
    """
    trt_to_torch = {
        trt.float32: torch.float32,
        trt.float16: torch.float16,
        trt.bfloat16: torch.bfloat16,
        trt.int8: torch.int8,
        trt.int32: torch.int32,
        trt.bool: torch.bool,
        trt.uint8: torch.uint8,
        trt.int64: torch.int64,
    }
    # No table value is None, so None unambiguously means "not found".
    resolved = trt_to_torch.get(trt_type)
    if resolved is None:
        raise TypeError(
            f"Could not resolve TensorRT datatype to an equivalent PyTorch datatype. {trt_type}"
        )
    return resolved
class Engine(object):
    """TensorRT engine wrapper for loading and executing inference.

    After construction the instance exposes:

    * ``handle`` - the deserialized TensorRT engine.
    * ``execution_context`` - the execution context created from ``handle``.
    * ``in_meta`` / ``out_meta`` - lists of ``[tensor_name, shape, torch_dtype]``
      entries for the engine's input and output tensors.
    """

    def __init__(self, file, plugins=None):
        """Load the engine stored in ``file`` and print a summary of its tensors.

        Args:
            file: Path of the serialized TensorRT engine.
            plugins: Optional iterable of shared-library paths. Each one is
                loaded with ``ctypes.CDLL(..., ctypes.RTLD_GLOBAL)`` before the
                engine is deserialized. ``None`` (the default) loads nothing.
        """
        super().__init__()
        self.logger = trt.Logger(trt.Logger.ERROR)
        trt.init_libnvinfer_plugins(self.logger, "")
        # The CDLL handles are stored on the instance alongside the engine.
        self.plugins = [
            ctypes.CDLL(plugin, ctypes.RTLD_GLOBAL) for plugin in (plugins or [])
        ]
        self.file = file
        self.load(file)

        def destroy(self):
            # Drop the execution context first, then the engine it came from.
            del self.execution_context
            del self.handle

        # NOTE: atexit stores ``self`` as a call argument, so this instance
        # stays referenced until the interpreter exits.
        atexit.register(destroy, self)
        self.print()

    def print(self):
        """Print the engine file and its input/output tensors.

        Nothing is printed unless the ``LOCAL_RANK`` environment variable is
        unset or equal to ``0``.
        """
        if int(os.getenv("LOCAL_RANK", -1)) not in [0, -1]:
            return
        print("============= TRT Engine Detail =============")
        print(f"Engine file: {self.file}")
        for title, meta in (("Inputs", self.in_meta), ("Outputs", self.out_meta)):
            print(f"{title}: {len(meta)}")
            for ib, item in enumerate(meta):
                tensor_name, shape, dtype = item[:3]
                print(f" {ib}. {tensor_name}: {'x'.join(map(str, shape))} [{dtype}]")
        print("=============================================")

    def load(self, file):
        """Deserialize and load a TensorRT engine from file.

        Populates ``handle``, ``execution_context``, ``in_meta`` and
        ``out_meta``. ``meta`` is reset to an empty list.

        Args:
            file: Path of the serialized TensorRT engine.

        Raises:
            AssertionError: If deserialization returns ``None``.
            TypeError: If a tensor dtype has no PyTorch equivalent
                (raised by ``torch_type``).
        """
        runtime = trt.Runtime(self.logger)
        with open(file, "rb") as f:
            self.handle = runtime.deserialize_cuda_engine(f.read())
        assert self.handle is not None, (
            f"Failed to deserialize the cuda engine from file: {file}"
        )
        self.execution_context = self.handle.create_execution_context()
        self.meta, self.in_meta, self.out_meta = [], [], []
        for tensor_name in self.handle:
            shape = self.handle.get_tensor_shape(tensor_name)
            dtype = torch_type(self.handle.get_tensor_dtype(tensor_name))
            if self.handle.get_tensor_mode(tensor_name) == trt.TensorIOMode.INPUT:
                self.in_meta.append([tensor_name, shape, dtype])
            else:
                self.out_meta.append([tensor_name, shape, dtype])

    def __call__(self, *args, **inputs):
        """Forward all arguments to :meth:`forward`."""
        return self.forward(*args, **inputs)

    def dtype_of(self, tensor_name: str) -> torch.dtype:
        """Return the expected PyTorch dtype for a named input tensor.

        Raises:
            KeyError: If ``tensor_name`` is not one of the engine's inputs.
        """
        for name, _shape, dtype in self.in_meta:
            if name == tensor_name:
                return dtype
        raise KeyError(f"Input tensor '{tensor_name}' not found in engine.")

    def set_runtime_tensor_shape(self, name, shape):
        """Set runtime input shape for dynamic dimensions."""
        self.execution_context.set_input_shape(name, shape)

    def forward(self, *args, **kwargs):
        """Execute TRT inference with the given input tensors.

        Positional tensors are matched to ``in_meta`` by position; keyword
        tensors are matched by input tensor name. Keyword names that are not
        engine inputs are ignored.

        Args:
            *args: Input tensors in ``in_meta`` order.
            **kwargs: Input tensors keyed by tensor name, plus the optional
                ``return_list`` flag (default ``False``).

        Returns:
            A dict mapping each output tensor name to its tensor, or, when
            ``return_list=True``, a list of output tensors in ``out_meta`` order.

        Raises:
            AssertionError: If a tensor has the wrong type, shape, dtype or
                device, or if the number of bound inputs does not match the
                engine's inputs. The count is checked before execution.
            IndexError: If more positional tensors are given than the engine
                has inputs.
        """
        return_list = kwargs.pop("return_list", False)
        stream = torch.cuda.current_stream()
        # Bound input tensors are held here until after the stream sync below.
        inputs = []

        # Process positional arguments
        for iarg, x in enumerate(args):
            name, _shape, dtype = self.in_meta[iarg]
            runtime_shape = self.execution_context.get_tensor_shape(name)
            assert isinstance(x, torch.Tensor), f"Unsupported tensor type: {type(x)}"
            assert runtime_shape == x.shape, f"Invalid input shape: {runtime_shape} != {x.shape}"
            assert dtype == x.dtype, (
                f"Invalid tensor dtype, expected dtype is {dtype}, but got {x.dtype}"
            )
            assert x.is_cuda, f"Invalid tensor device, expected device is cuda, but got {x.device}"
            x = x.cuda().contiguous()
            self.execution_context.set_tensor_address(name, x.data_ptr())
            inputs.append(x)

        # Process keyword arguments
        for name, _shape, dtype in self.in_meta:
            if name not in kwargs:
                continue
            runtime_shape = self.execution_context.get_tensor_shape(name)
            x = kwargs[name]
            assert isinstance(x, torch.Tensor), f"Unsupported tensor[{name}] type: {type(x)}"
            assert runtime_shape == x.shape, (
                f"Invalid input[{name}] shape: {x.shape}, but the expected shape is: {runtime_shape}"
            )
            assert dtype == x.dtype, (
                f"Invalid tensor[{name}] dtype, expected dtype is {dtype}, but got {x.dtype}"
            )
            assert x.is_cuda, (
                f"Invalid tensor[{name}] device, expected device is cuda, but got {x.device}"
            )
            x = x.cuda().contiguous()
            self.execution_context.set_tensor_address(name, x.data_ptr())
            inputs.append(x)

        # Check the binding count BEFORE running the engine, so a missing or
        # doubly-bound input is reported instead of executing with unset or
        # stale addresses. One output tensor is always allocated per out_meta
        # entry, so the reported totals match an I/O-wide count.
        expected_io = len(self.in_meta) + len(self.out_meta)
        bound_io = len(inputs) + len(self.out_meta)
        assert bound_io == expected_io, (
            f"Invalid input tensors. The expected I/O tensors are "
            f"{expected_io}, but got {bound_io}"
        )

        # Allocate output tensors
        device = inputs[0].device if inputs else torch.device("cuda")
        outputs = []
        for item in self.out_meta:
            name = item[0]
            runtime_shape = self.execution_context.get_tensor_shape(name)
            # Pass the shape as one tuple so a rank-0 output is also valid.
            output_tensor = torch.zeros(tuple(runtime_shape), dtype=item[2], device=device)
            self.execution_context.set_tensor_address(name, output_tensor.data_ptr())
            outputs.append(output_tensor)

        # Execute
        self.execution_context.execute_async_v3(stream.cuda_stream)
        stream.synchronize()

        if return_list:
            return outputs
        return {item[0]: tensor for item, tensor in zip(self.out_meta, outputs)}
|