| |
| |
| |
| #include "WeightsContext.hpp" |
| #include "onnx/onnx-ml.pb.h" |
| #include <NvInfer.h> |
| #include <fstream> |
| #include <iostream> |
|
|
| struct Logger : public nvinfer1::ILogger |
| { |
| void log(Severity s, char const* msg) noexcept override |
| { |
| if (s <= Severity::kWARNING) std::cerr << "[trt] " << msg << "\n"; |
| } |
| }; |
|
|
| int main(int argc, char** argv) |
| { |
| GOOGLE_PROTOBUF_VERIFY_VERSION; |
| std::string onnxPath = (argc > 1) ? argv[1] : "malicious.onnx"; |
|
|
| |
| onnx::ModelProto model; |
| { |
| std::ifstream in(onnxPath, std::ios::binary); |
| if (!in || !model.ParseFromIstream(&in)) { std::cerr << "failed to read " << onnxPath << "\n"; return 1; } |
| } |
| if (model.graph().initializer_size() == 0) { std::cerr << "no initializer\n"; return 1; } |
| onnx::TensorProto const& t = model.graph().initializer(0); |
|
|
| Logger logger; |
| onnx2trt::WeightsContext ctx(&logger); |
| ctx.setOnnxFileLocation(onnxPath); |
|
|
| onnx2trt::ShapedWeights w{}; |
| std::cerr << "[*] loading " << onnxPath << " -> WeightsContext::convertOnnxWeights (real onnx-tensorrt path)\n"; |
| std::cerr << "[*] initializer '" << t.name() << "' external_data offset is attacker-controlled and unchecked...\n"; |
| bool ok = ctx.convertOnnxWeights(t, &w); |
| std::cerr << "[*] returned " << ok << " WITHOUT crashing -> offset was bounded (no bug)\n"; |
| return 0; |
| } |
|
|