"""
TensorRT export utilities for optimized inference.

TensorRT provides 5-10x speedup over standard PyTorch inference
for production deployment.

Requires: NVIDIA GPU, TensorRT SDK, and ONNX model.
"""

import logging
import time
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np

logger = logging.getLogger(__name__)

# Try to import TensorRT
try:
    import pycuda.driver as cuda
    import tensorrt as trt

    TENSORRT_AVAILABLE = True
except ImportError:
    TENSORRT_AVAILABLE = False
    logger.warning("TensorRT not available. Install with: " "pip install nvidia-tensorrt pycuda")

# Precision modes accepted by build_tensorrt_engine.
_SUPPORTED_PRECISIONS = ("fp32", "fp16", "int8")


def check_tensorrt_available() -> bool:
    """Check if TensorRT is available."""
    return TENSORRT_AVAILABLE


def _make_cache_calibrator(cache: bytes):
    """Create an INT8 calibrator that only replays a pre-computed calibration cache."""

    # Defined lazily because ``trt`` only exists when the import above succeeded.
    class _CacheOnlyCalibrator(trt.IInt8EntropyCalibrator2):
        """Calibrator without calibration data; all scales come from the cache."""

        def __init__(self) -> None:
            # The pybind11 base class has to be initialised explicitly.
            trt.IInt8EntropyCalibrator2.__init__(self)

        def get_batch_size(self) -> int:
            return 1

        def get_batch(self, names):
            # Returning None signals that there are no calibration batches.
            return None

        def read_calibration_cache(self):
            return cache

        def write_calibration_cache(self, cache_data) -> None:
            # The cache was loaded from disk already; nothing to persist.
            return None

    return _CacheOnlyCalibrator()


def _add_optimization_profile(builder, network, config, max_batch_size: int) -> None:
    """Register one optimization profile covering batch sizes 1..max_batch_size."""
    profile = builder.create_optimization_profile()
    for index in range(network.num_inputs):
        input_tensor = network.get_input(index)
        shape = tuple(input_tensor.shape)

        if any(dim < 0 for dim in shape[1:]):
            raise ValueError(
                f"Input '{input_tensor.name}' has dynamic non-batch dimensions {shape}; "
                "only a dynamic batch (first) dimension is supported"
            )

        if shape and shape[0] < 0:
            # Dynamic batch dimension (assumed to be first): allow 1..max_batch_size.
            min_shape = (1, *shape[1:])
            max_shape = (max_batch_size, *shape[1:])
        else:
            # Static shape: the profile must match the network exactly.
            min_shape = max_shape = shape

        profile.set_shape(input_tensor.name, min_shape, max_shape, max_shape)
    config.add_optimization_profile(profile)


def build_tensorrt_engine(
    onnx_path: Path,
    engine_path: Path,
    precision: str = "fp16",
    max_batch_size: int = 1,
    max_workspace_size: int = 1 << 30,  # 1GB
    min_timing_iterations: int = 1,
    avg_timing_iterations: int = 8,
    int8_calibration_cache: Optional[Path] = None,
) -> Path:
    """
    Build TensorRT engine from ONNX model.

    Args:
        onnx_path: Path to ONNX model
        engine_path: Path the serialized engine is written to (parent
            directories are created if needed)
        precision: Precision mode: "fp32", "fp16", or "int8" (case-insensitive)
        max_batch_size: Maximum batch size
        max_workspace_size: Maximum workspace size in bytes
        min_timing_iterations: Minimum timing iterations for optimization
            (applied only if the installed TensorRT still exposes the setting)
        avg_timing_iterations: Average timing iterations for optimization
            (applied only if the installed TensorRT still exposes the setting)
        int8_calibration_cache: Path to INT8 calibration cache (for INT8 mode)

    Returns:
        Path to saved TensorRT engine

    Raises:
        ValueError: If ``precision`` is unknown, ``max_batch_size`` < 1, or an
            input has dynamic non-batch dimensions.
        RuntimeError: If TensorRT is unavailable, or parsing/building fails.
        FileNotFoundError: If ``onnx_path`` does not exist.
    """
    # Validate arguments first: these checks do not depend on the environment.
    precision = precision.lower()
    if precision not in _SUPPORTED_PRECISIONS:
        raise ValueError(
            f"Unsupported precision '{precision}'; expected one of {_SUPPORTED_PRECISIONS}"
        )
    if max_batch_size < 1:
        raise ValueError(f"max_batch_size must be >= 1, got {max_batch_size}")

    if not TENSORRT_AVAILABLE:
        raise RuntimeError(
            "TensorRT not available. Install with: pip install nvidia-tensorrt pycuda"
        )

    if not onnx_path.exists():
        raise FileNotFoundError(f"ONNX model not found: {onnx_path}")

    logger.info(f"Building TensorRT engine from {onnx_path}")
    logger.info(f"Precision: {precision}, Max batch size: {max_batch_size}")

    # Create TensorRT logger
    trt_logger = trt.Logger(trt.Logger.WARNING)

    # Create builder and network
    builder = trt.Builder(trt_logger)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, trt_logger)

    # Parse ONNX model
    with open(onnx_path, "rb") as model:
        if not parser.parse(model.read()):
            logger.error("Failed to parse ONNX model")
            for error in range(parser.num_errors):
                logger.error(parser.get_error(error))
            raise RuntimeError("Failed to parse ONNX model")

    logger.info(f"ONNX model parsed successfully. Inputs: {network.num_inputs}")

    # Configure builder
    config = builder.create_builder_config()
    if hasattr(config, "set_memory_pool_limit"):
        # Preferred API; ``max_workspace_size`` is gone in newer TensorRT releases.
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, max_workspace_size)
    else:
        config.max_workspace_size = max_workspace_size

    # Timing settings only exist on some TensorRT versions; apply them when present.
    if hasattr(config, "min_timing_iterations"):
        config.min_timing_iterations = min_timing_iterations
    if hasattr(config, "avg_timing_iterations"):
        config.avg_timing_iterations = avg_timing_iterations

    # Set precision
    calibrator = None  # kept in a local so it stays alive for the whole build
    if precision == "fp16":
        if builder.platform_has_fast_fp16:
            config.set_flag(trt.BuilderFlag.FP16)
            logger.info("FP16 precision enabled")
        else:
            logger.warning("FP16 not supported on this platform, using FP32")
    elif precision == "int8":
        if builder.platform_has_fast_int8:
            config.set_flag(trt.BuilderFlag.INT8)
            logger.info("INT8 precision enabled")
            if int8_calibration_cache:
                # A cache is handed to TensorRT through a calibrator object,
                # not through an attribute on the builder config.
                with open(int8_calibration_cache, "rb") as f:
                    calibrator = _make_cache_calibrator(f.read())
                config.int8_calibrator = calibrator
        else:
            logger.warning("INT8 not supported on this platform, using FP32")

    # Set optimization profile (for dynamic shapes)
    _add_optimization_profile(builder, network, config, max_batch_size)

    # Build engine
    logger.info("Building TensorRT engine (this may take a while)...")
    if hasattr(builder, "build_serialized_network"):
        serialized_engine = builder.build_serialized_network(network, config)
    else:
        # Older TensorRT: build the engine object, then serialize it.
        engine = builder.build_engine(network, config)
        serialized_engine = engine.serialize() if engine is not None else None

    if serialized_engine is None:
        raise RuntimeError("Failed to build TensorRT engine")

    # Save engine
    engine_path.parent.mkdir(parents=True, exist_ok=True)
    with open(engine_path, "wb") as f:
        f.write(serialized_engine)

    logger.info(f"TensorRT engine saved to {engine_path}")
    logger.info(f"Engine size: {engine_path.stat().st_size / 1024 / 1024:.2f} MB")

    return engine_path


def load_tensorrt_engine(engine_path: Path):
    """
    Load TensorRT engine from file.

    Args:
        engine_path: Path to TensorRT engine file

    Returns:
        TensorRT engine

    Raises:
        RuntimeError: If TensorRT is unavailable or deserialization fails.
        FileNotFoundError: If ``engine_path`` does not exist.
    """
    if not TENSORRT_AVAILABLE:
        raise RuntimeError("TensorRT not available")

    if not engine_path.exists():
        raise FileNotFoundError(f"TensorRT engine not found: {engine_path}")

    logger.info(f"Loading TensorRT engine from {engine_path}")

    trt_logger = trt.Logger(trt.Logger.WARNING)
    runtime = trt.Runtime(trt_logger)

    with open(engine_path, "rb") as f:
        engine = runtime.deserialize_cuda_engine(f.read())

    if engine is None:
        raise RuntimeError("Failed to load TensorRT engine")

    logger.info("TensorRT engine loaded successfully")
    return engine


class TensorRTInference:
    """
    TensorRT inference wrapper.

    Provides a simple interface for running inference with TensorRT engines.
    Host (page-locked) and device buffers are allocated once per I/O tensor
    and reused by every call.
    """

    def __init__(self, engine_path: Path, device: int = 0):
        """
        Initialize TensorRT inference.

        Args:
            engine_path: Path to TensorRT engine file
            device: CUDA device ID. Stored on the instance only.
                NOTE(review): nothing here selects the device or creates a
                CUDA context; presumably the caller has one active (e.g. via
                ``pycuda.autoinit``) -- confirm.

        Raises:
            RuntimeError: If TensorRT is unavailable or a tensor shape cannot
                be resolved to a concrete size.
        """
        if not TENSORRT_AVAILABLE:
            raise RuntimeError("TensorRT not available")

        self.engine = load_tensorrt_engine(engine_path)
        self.context = self.engine.create_execution_context()
        self.device = device

        # Allocate buffers
        self.inputs = []
        self.outputs = []
        self.bindings = []
        self.stream = cuda.Stream()

        tensor_names = [
            self.engine.get_tensor_name(i) for i in range(self.engine.num_io_tensors)
        ]
        input_names = {
            name
            for name in tensor_names
            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT
        }

        # Resolve dynamic input dimensions to the profile maximum so every
        # buffer is large enough for any shape the engine accepts.
        for name in tensor_names:
            if name not in input_names:
                continue
            if any(dim < 0 for dim in self.engine.get_tensor_shape(name)):
                max_shape = self.engine.get_tensor_profile_shape(name, 0)[2]
                self.context.set_input_shape(name, tuple(max_shape))

        for name in tensor_names:
            # With all input shapes fixed, the context reports concrete shapes.
            shape = tuple(self.context.get_tensor_shape(name))
            if any(dim < 0 for dim in shape):
                raise RuntimeError(
                    f"Cannot determine buffer size for tensor '{name}' with shape {shape}"
                )
            dtype = trt.nptype(self.engine.get_tensor_dtype(name))

            # pagelocked_empty expects an element count, not a byte count.
            num_elements = int(trt.volume(shape))
            host_mem = cuda.pagelocked_empty(num_elements, dtype)
            device_mem = cuda.mem_alloc(host_mem.nbytes)

            self.bindings.append(int(device_mem))
            # Device addresses never change, so register them once.
            self.context.set_tensor_address(name, int(device_mem))

            record = {"name": name, "host": host_mem, "device": device_mem}
            if name in input_names:
                self.inputs.append(record)
            else:
                self.outputs.append(record)

        logger.info(
            f"TensorRT inference initialized: {len(self.inputs)} inputs, "
            f"{len(self.outputs)} outputs"
        )

    def __call__(self, *inputs: np.ndarray) -> List[np.ndarray]:
        """
        Run inference.

        Args:
            *inputs: Input arrays (numpy), one per engine input, in engine order

        Returns:
            List of output arrays. Each array is an independent copy, so it
            stays valid after subsequent calls.

        Raises:
            ValueError: If the number of inputs is wrong, an input does not
                fit its buffer, or the engine rejects an input shape.
            RuntimeError: If TensorRT reports an execution failure.
        """
        if len(inputs) != len(self.inputs):
            raise ValueError(f"Expected {len(self.inputs)} inputs, got {len(inputs)}")

        # Set input shapes and copy inputs to device
        for binding, array in zip(self.inputs, inputs):
            array = np.asarray(array)
            host = binding["host"]
            if array.size > host.size:
                raise ValueError(
                    f"Input '{binding['name']}' has {array.size} elements, "
                    f"buffer holds at most {host.size}"
                )
            if not self.context.set_input_shape(binding["name"], tuple(array.shape)):
                raise ValueError(
                    f"Shape {tuple(array.shape)} is not valid for input '{binding['name']}'"
                )
            staged = host[: array.size]
            np.copyto(staged, array.ravel())
            cuda.memcpy_htod_async(binding["device"], staged, self.stream)

        # Run inference (tensor addresses were registered in __init__)
        if not self.context.execute_async_v3(stream_handle=self.stream.handle):
            raise RuntimeError("TensorRT inference execution failed")

        # Copy outputs from device; only the elements of the actual output shape
        pending = []
        for binding in self.outputs:
            shape = tuple(self.context.get_tensor_shape(binding["name"]))
            staged = binding["host"][: int(trt.volume(shape))]
            cuda.memcpy_dtoh_async(staged, binding["device"], self.stream)
            pending.append((staged, shape))

        self.stream.synchronize()

        # Reshape outputs; copy so the reused host buffers are not aliased
        return [staged.reshape(shape).copy() for staged, shape in pending]

    def __del__(self):
        """Cleanup CUDA resources."""
        if hasattr(self, "stream"):
            del self.stream


def benchmark_tensorrt(
    engine_path: Path,
    sample_inputs: List[np.ndarray],
    num_runs: int = 100,
    warmup_runs: int = 10,
) -> Dict[str, float]:
    """
    Benchmark TensorRT inference.

    Args:
        engine_path: Path to TensorRT engine
        sample_inputs: Sample input arrays
        num_runs: Number of benchmark runs (must be >= 1)
        warmup_runs: Number of warmup runs

    Returns:
        Dict with the keys "fps", "latency_ms", "latency_std_ms",
        "min_latency_ms" and "max_latency_ms". "fps" is calls per second.

    Raises:
        ValueError: If ``num_runs`` < 1.
        RuntimeError: If TensorRT is unavailable.
    """
    # Statistics over zero samples would be NaN, so reject that up front.
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    if not TENSORRT_AVAILABLE:
        raise RuntimeError("TensorRT not available")

    logger.info(f"Benchmarking TensorRT engine: {engine_path}")

    inference = TensorRTInference(engine_path)

    # Warmup
    for _ in range(warmup_runs):
        inference(*sample_inputs)

    # Benchmark; perf_counter is monotonic and high resolution.
    times = []
    for _ in range(num_runs):
        start = time.perf_counter()
        inference(*sample_inputs)
        times.append(time.perf_counter() - start)

    avg_time = float(np.mean(times))
    std_time = float(np.std(times))
    min_time = float(np.min(times))
    max_time = float(np.max(times))
    fps = 1.0 / avg_time

    results = {
        "fps": fps,
        "latency_ms": avg_time * 1000,
        "latency_std_ms": std_time * 1000,
        "min_latency_ms": min_time * 1000,
        "max_latency_ms": max_time * 1000,
    }

    logger.info("TensorRT Benchmark Results:")
    logger.info(f"  FPS: {fps:.2f}")
    logger.info(f"  Latency: {avg_time * 1000:.2f}ms ± {std_time * 1000:.2f}ms")
    logger.info(f"  Min: {min_time * 1000:.2f}ms, " f"Max: {max_time * 1000:.2f}ms")

    return results