# 3d_model/ylff/utils/tensorrt_export.py
# Author: Azan
# Clean deployment build (squashed) — commit 7a87926
"""
TensorRT export utilities for optimized inference.
TensorRT provides 5-10x speedup over standard PyTorch inference for production deployment.
Requires: NVIDIA GPU, TensorRT SDK, and ONNX model.
"""
import logging
from pathlib import Path
from typing import Dict, List, Optional
import numpy as np
logger = logging.getLogger(__name__)
# Try to import TensorRT
try:
import pycuda.driver as cuda
import tensorrt as trt
TENSORRT_AVAILABLE = True
except ImportError:
TENSORRT_AVAILABLE = False
logger.warning("TensorRT not available. Install with: " "pip install nvidia-tensorrt pycuda")
def check_tensorrt_available() -> bool:
    """Report whether the optional TensorRT/PyCUDA stack imported successfully."""
    return bool(TENSORRT_AVAILABLE)
def build_tensorrt_engine(
    onnx_path: Path,
    engine_path: Path,
    precision: str = "fp16",
    max_batch_size: int = 1,
    max_workspace_size: int = 1 << 30,  # 1GB
    min_timing_iterations: int = 1,
    avg_timing_iterations: int = 8,
    int8_calibration_cache: Optional[Path] = None,
) -> Path:
    """
    Build a TensorRT engine from an ONNX model and serialize it to disk.

    Args:
        onnx_path: Path to the ONNX model.
        engine_path: Destination path for the serialized engine.
        precision: Precision mode: "fp32", "fp16", or "int8".
        max_batch_size: Maximum batch size baked into the optimization profile.
        max_workspace_size: Maximum workspace size in bytes.
        min_timing_iterations: Minimum kernel-timing iterations (applied only
            on TensorRT versions that still expose the setting).
        avg_timing_iterations: Average kernel-timing iterations (applied only
            on TensorRT versions that expose the setting).
        int8_calibration_cache: Path to an INT8 calibration cache (INT8 mode).

    Returns:
        Path to the saved TensorRT engine.

    Raises:
        RuntimeError: If TensorRT is unavailable, ONNX parsing fails, or the
            engine cannot be built.
        FileNotFoundError: If the ONNX model does not exist.
    """
    if not TENSORRT_AVAILABLE:
        raise RuntimeError(
            "TensorRT not available. Install with: pip install nvidia-tensorrt pycuda"
        )
    if not onnx_path.exists():
        raise FileNotFoundError(f"ONNX model not found: {onnx_path}")
    logger.info(f"Building TensorRT engine from {onnx_path}")
    logger.info(f"Precision: {precision}, Max batch size: {max_batch_size}")
    # Create TensorRT logger, builder, network, and ONNX parser
    trt_logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(trt_logger)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, trt_logger)
    # Parse ONNX model
    with open(onnx_path, "rb") as model:
        if not parser.parse(model.read()):
            logger.error("Failed to parse ONNX model")
            for error in range(parser.num_errors):
                logger.error(parser.get_error(error))
            raise RuntimeError("Failed to parse ONNX model")
    logger.info(f"ONNX model parsed successfully. Inputs: {network.num_inputs}")
    # Configure builder
    config = builder.create_builder_config()
    # max_workspace_size was deprecated in TensorRT 8.4 and removed in 10.x;
    # prefer the memory-pool API when available.
    if hasattr(config, "set_memory_pool_limit"):
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, max_workspace_size)
    else:
        config.max_workspace_size = max_workspace_size
    # Wire the timing-iteration knobs (previously accepted but silently ignored).
    # Guarded with hasattr: availability varies across TensorRT versions.
    if hasattr(config, "min_timing_iterations"):
        config.min_timing_iterations = min_timing_iterations
    if hasattr(config, "avg_timing_iterations"):
        config.avg_timing_iterations = avg_timing_iterations
    # Set precision
    if precision == "fp16":
        if builder.platform_has_fast_fp16:
            config.set_flag(trt.BuilderFlag.FP16)
            logger.info("FP16 precision enabled")
        else:
            logger.warning("FP16 not supported on this platform, using FP32")
    elif precision == "int8":
        if builder.platform_has_fast_int8:
            config.set_flag(trt.BuilderFlag.INT8)
            logger.info("INT8 precision enabled")
            if int8_calibration_cache:
                # NOTE(review): IBuilderConfig has no writable
                # `int8_calibration_cache` attribute (the old code's assignment
                # would raise AttributeError on the pybind object). A cache must
                # be supplied through an IInt8Calibrator assigned to
                # `config.int8_calibrator`. Warn loudly so callers know the
                # cache is not consumed here.
                logger.warning(
                    "INT8 calibration cache supplied but no IInt8Calibrator is "
                    "configured; cache at %s is not applied",
                    int8_calibration_cache,
                )
        else:
            logger.warning("INT8 not supported on this platform, using FP32")
    # Set optimization profile (for dynamic shapes)
    profile = builder.create_optimization_profile()
    for i in range(network.num_inputs):
        input_tensor = network.get_input(i)
        shape = input_tensor.shape
        # Set min, opt, max shapes (assuming batch dimension is first)
        profile.set_shape(
            input_tensor.name,
            (1, *shape[1:]),  # min
            (max_batch_size, *shape[1:]),  # opt
            (max_batch_size, *shape[1:]),  # max
        )
    config.add_optimization_profile(profile)
    # Build engine. build_engine was deprecated in TensorRT 8 and removed in
    # 10.x; prefer build_serialized_network, falling back for old runtimes.
    logger.info("Building TensorRT engine (this may take a while)...")
    if hasattr(builder, "build_serialized_network"):
        serialized = builder.build_serialized_network(network, config)
        if serialized is None:
            raise RuntimeError("Failed to build TensorRT engine")
        engine_bytes = bytes(serialized)  # IHostMemory supports the buffer protocol
    else:
        engine = builder.build_engine(network, config)
        if engine is None:
            raise RuntimeError("Failed to build TensorRT engine")
        engine_bytes = bytes(engine.serialize())
    # Save engine
    engine_path.parent.mkdir(parents=True, exist_ok=True)
    engine_path.write_bytes(engine_bytes)
    logger.info(f"TensorRT engine saved to {engine_path}")
    logger.info(f"Engine size: {engine_path.stat().st_size / 1024 / 1024:.2f} MB")
    return engine_path
def load_tensorrt_engine(engine_path: Path):
    """
    Deserialize a TensorRT engine from a file on disk.

    Args:
        engine_path: Path to TensorRT engine file

    Returns:
        The deserialized TensorRT engine.

    Raises:
        RuntimeError: If TensorRT is unavailable or deserialization fails.
        FileNotFoundError: If the engine file does not exist.
    """
    if not TENSORRT_AVAILABLE:
        raise RuntimeError("TensorRT not available")
    if not engine_path.exists():
        raise FileNotFoundError(f"TensorRT engine not found: {engine_path}")
    logger.info(f"Loading TensorRT engine from {engine_path}")
    runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))
    engine = runtime.deserialize_cuda_engine(engine_path.read_bytes())
    if engine is None:
        raise RuntimeError("Failed to load TensorRT engine")
    logger.info("TensorRT engine loaded successfully")
    return engine
class TensorRTInference:
    """
    TensorRT inference wrapper.

    Provides a simple interface for running inference with TensorRT engines:
    pinned host buffers and device buffers are allocated once at construction
    and reused for every call.
    """

    def __init__(self, engine_path: Path, device: int = 0):
        """
        Initialize TensorRT inference.

        Args:
            engine_path: Path to TensorRT engine file
            device: CUDA device ID

        Raises:
            RuntimeError: If TensorRT is not available or the engine fails to load.
        """
        if not TENSORRT_AVAILABLE:
            raise RuntimeError("TensorRT not available")
        self.engine = load_tensorrt_engine(engine_path)
        self.context = self.engine.create_execution_context()
        self.device = device
        # Pre-allocated host/device buffers, one pair per I/O tensor
        self.inputs = []
        self.outputs = []
        self.bindings = []
        self.stream = cuda.Stream()
        for i in range(self.engine.num_io_tensors):
            name = self.engine.get_tensor_name(i)
            shape = self.engine.get_tensor_shape(name)
            dtype = trt.nptype(self.engine.get_tensor_dtype(name))
            # BUG FIX: pagelocked_empty takes an ELEMENT count and sizes the
            # buffer from dtype itself. The old code multiplied by itemsize,
            # allocating itemsize-times too many elements and breaking the
            # np.copyto() in __call__ with a shape mismatch.
            # NOTE(review): trt.volume() assumes a fully static shape here;
            # dynamic (-1) dimensions would need per-call allocation — confirm
            # engines used with this wrapper have static shapes.
            num_elements = trt.volume(shape)
            host_mem = cuda.pagelocked_empty(num_elements, dtype)
            device_mem = cuda.mem_alloc(host_mem.nbytes)
            self.bindings.append(int(device_mem))
            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
                self.inputs.append({"name": name, "host": host_mem, "device": device_mem})
            else:
                self.outputs.append({"name": name, "host": host_mem, "device": device_mem})
        logger.info(
            f"TensorRT inference initialized: {len(self.inputs)} inputs, "
            f"{len(self.outputs)} outputs"
        )

    def __call__(self, *inputs: np.ndarray) -> List[np.ndarray]:
        """
        Run inference.

        Args:
            *inputs: Input arrays (numpy), one per engine input, in engine order

        Returns:
            List of output arrays, reshaped to the context's output shapes
        """
        # Set input shapes first so the context reflects this call's shapes
        # before any data movement.
        for i, inp in enumerate(self.inputs):
            self.context.set_input_shape(inp["name"], inputs[i].shape)
        # Copy inputs into pinned host memory, then async to device
        for i, inp in enumerate(self.inputs):
            np.copyto(inp["host"], inputs[i].ravel())
            cuda.memcpy_htod_async(inp["device"], inp["host"], self.stream)
        # Run inference
        self.context.execute_async_v2(bindings=self.bindings, stream_handle=self.stream.handle)
        # Copy outputs from device
        outputs = []
        for out in self.outputs:
            cuda.memcpy_dtoh_async(out["host"], out["device"], self.stream)
            outputs.append(out["host"])
        # Block until all async copies and the inference kernel finish
        self.stream.synchronize()
        # Reshape flat host buffers to the context's output shapes
        reshaped_outputs = []
        for i, out in enumerate(self.outputs):
            shape = self.context.get_tensor_shape(out["name"])
            reshaped_outputs.append(outputs[i].reshape(shape))
        return reshaped_outputs

    def __del__(self):
        """Cleanup CUDA resources."""
        # Dropping the stream reference lets pycuda release it; device
        # allocations are freed when their DeviceAllocation objects die.
        if hasattr(self, "stream"):
            del self.stream
def benchmark_tensorrt(
    engine_path: Path,
    sample_inputs: List[np.ndarray],
    num_runs: int = 100,
    warmup_runs: int = 10,
) -> Dict[str, float]:
    """
    Benchmark TensorRT inference.

    Args:
        engine_path: Path to TensorRT engine
        sample_inputs: Sample input arrays
        num_runs: Number of benchmark runs
        warmup_runs: Number of warmup runs (not timed; lets caches/clocks settle)

    Returns:
        Dict with benchmark results: "fps", "latency_ms", "latency_std_ms",
        "min_latency_ms", "max_latency_ms".

    Raises:
        RuntimeError: If TensorRT is not available.
    """
    if not TENSORRT_AVAILABLE:
        raise RuntimeError("TensorRT not available")
    import time

    logger.info(f"Benchmarking TensorRT engine: {engine_path}")
    inference = TensorRTInference(engine_path)
    # Warmup (untimed)
    for _ in range(warmup_runs):
        _ = inference(*sample_inputs)
    # Benchmark with a monotonic high-resolution clock; time.time() is
    # wall-clock and can jump (NTP), skewing latency numbers.
    times = []
    for _ in range(num_runs):
        start = time.perf_counter()
        _ = inference(*sample_inputs)
        times.append(time.perf_counter() - start)
    avg_time = np.mean(times)
    std_time = np.std(times)
    min_time = np.min(times)
    max_time = np.max(times)
    fps = 1.0 / avg_time
    results = {
        "fps": fps,
        "latency_ms": avg_time * 1000,
        "latency_std_ms": std_time * 1000,
        "min_latency_ms": min_time * 1000,
        "max_latency_ms": max_time * 1000,
    }
    logger.info("TensorRT Benchmark Results:")
    logger.info(f"  FPS: {fps:.2f}")
    logger.info(f"  Latency: {avg_time * 1000:.2f}ms ± {std_time * 1000:.2f}ms")
    logger.info(f"  Min: {min_time * 1000:.2f}ms, " f"Max: {max_time * 1000:.2f}ms")
    return results