// tensorrt-qkv-runnerstate-poc / qkv_runnerstate_fixed_harness.cpp
// Author: shiyamganesh — "Upload 11 files" (commit 23abecf, verified)
#include <cstdint>
#include <cstring>
#include <iostream>
#include <stdexcept>
#include <string>
#include <type_traits>
// Simplified stand-ins for the plugin-field types consumed by the harness
// below; only the enumerators this file actually uses are declared.
enum PluginFieldType { kUNKNOWN = 0, kINT32 = 1 };
// A single named field: `name` is matched by string comparison, `data` is an
// untyped pointer to the payload, and `type`/`length` are whatever the
// producer of the field claims (nothing here validates them).
struct PluginField { const char* name; const void* data; PluginFieldType type; int32_t length; };
// Non-owning view over `nbFields` consecutive PluginField entries starting at
// `fields`; the caller keeps the array alive.
struct PluginFieldCollection { int32_t nbFields; const PluginField* fields; };
/// Reads one value of type T from the front of a byte cursor and advances
/// the cursor past it.
///
/// @param buffer       In/out cursor; on success it points sizeof(T) bytes further.
/// @param buffer_size  In/out count of bytes remaining at *buffer; reduced by
///                     sizeof(T) on success.
/// @param value        Destination the raw bytes are copied into.
/// @throws std::runtime_error("buffer too small") when fewer than sizeof(T)
///         bytes remain; *buffer, *buffer_size and *value are left untouched.
template <typename T>
void deserialize_value(void const** buffer, size_t* buffer_size, T* value) {
    // Copying raw bytes into an object is only well-defined for trivially
    // copyable types, so reject anything else at compile time.
    static_assert(std::is_trivially_copyable<T>::value,
                  "deserialize_value requires a trivially copyable T");
    if (*buffer_size < sizeof(T)) throw std::runtime_error("buffer too small");
    std::memcpy(value, *buffer, sizeof(T));
    // Advance through a char pointer value instead of reinterpret_cast'ing the
    // `void const*` object itself to `char const*&`, which would access it
    // through an unrelated pointer type (strict-aliasing violation).
    *buffer = static_cast<char const*>(*buffer) + sizeof(T);
    *buffer_size -= sizeof(T);
}
/// Toy stand-in for the real runner state: five int fields that are
/// serialized back-to-back in declaration order.
struct FakeDispatcher {
    int a{}, b{}, c{}, d{}, e{};

    /// Total bytes deserialize() consumes: one slot per field.
    size_t getSerializationSize() const {
        size_t total = sizeof(a);
        total += sizeof(b);
        total += sizeof(c);
        total += sizeof(d);
        total += sizeof(e);
        return total;
    }

    /// Fills a, b, c, d, e (in that order) from `data`. Throws from
    /// deserialize_value as soon as `length` runs out, leaving any
    /// earlier fields already overwritten.
    void deserialize(const void* data, size_t length) {
        int* const targets[] = {&a, &b, &c, &d, &e};
        for (int* target : targets) {
            deserialize_value(&data, &length, target);
        }
    }
};
void fixed_qkv_like_createPlugin_runtime(const PluginFieldCollection* fc) {
const void* runnerStateBuffer = nullptr;
int32_t runnerStateBufferLength = -1;
PluginFieldType runnerStateBufferType = kINT32;
for (int i=0; i<fc->nbFields; ++i) {
if (std::string(fc->fields[i].name) == "runnerStateBuffer") {
runnerStateBuffer = fc->fields[i].data;
runnerStateBufferLength = fc->fields[i].length;
runnerStateBufferType = fc->fields[i].type;
}
}
if (!runnerStateBuffer) throw std::runtime_error("missing runnerStateBuffer");
FakeDispatcher d;
auto expected = d.getSerializationSize();
if (runnerStateBufferType != kUNKNOWN) throw std::runtime_error("invalid runnerStateBuffer type");
if (runnerStateBufferLength != static_cast<int32_t>(expected)) {
throw std::runtime_error("invalid runnerStateBuffer length: rejected before deserialize");
}
d.deserialize(runnerStateBuffer, expected);
}
/// PoC driver: supplies a 1-byte runnerStateBuffer (shorter than the
/// dispatcher's serialized size) and expects it to be rejected.
/// @return 0 if the input was rejected, 1 if it was unexpectedly accepted.
int main() {
    // Stack storage: no manual new[]/delete[] that could leak on an escape path.
    char tiny[1] = {0x41};
    PluginField f{"runnerStateBuffer", tiny, kUNKNOWN, static_cast<int32_t>(sizeof(tiny))};
    PluginFieldCollection fc{1, &f};
    try {
        fixed_qkv_like_createPlugin_runtime(&fc);
        std::cerr << "unexpected success\n";
        // A harness that verifies a fix must fail visibly if the fix regresses.
        return 1;
    } catch (const std::exception& e) {
        std::cerr << "Rejected safely: " << e.what() << "\n";
    }
    return 0;
}