tensorrt-polygraphy-serialized-plugin-ace-poc / trt_serialized_plugin_marker_probe.py
noshkas's picture
Add TensorRT Polygraphy serialized plugin ACE PoC
5c17b14 verified
#!/usr/bin/env python3
import argparse
import json
import os
import shutil
import subprocess
import sys
import textwrap
from pathlib import Path
import tensorrt as trt
# C++ source of the marker payload library (written to disk and compiled with
# g++ by compile_marker_lib). It defines a load-time constructor that appends
# one "marker_constructor pid=... time=..." line to the file named by the
# TRT_PLUGIN_MARKER environment variable; the constructor returns silently
# when the variable is unset or empty, or when the file cannot be opened.
# It also exports a trivial C-linkage symbol, trt_marker_export, returning 1337.
C_SOURCE = r"""
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <unistd.h>
__attribute__((constructor))
static void marker_constructor(void) {
const char *path = getenv("TRT_PLUGIN_MARKER");
if (!path || !path[0]) {
return;
}
FILE *f = fopen(path, "a");
if (!f) {
return;
}
time_t now = time(NULL);
fprintf(f, "marker_constructor pid=%ld time=%ld\n", (long)getpid(), (long)now);
fclose(f);
}
extern "C" int trt_marker_export(void) {
return 1337;
}
"""
def run(cmd, **kwargs):
    """Execute *cmd* and summarise the outcome as a JSON-friendly dict.

    Output is captured in text mode; any extra keyword arguments are passed
    straight through to ``subprocess.run``.

    Returns:
        A dict holding the command, its return code, and the last 4000
        characters of stdout and stderr.
    """
    completed = subprocess.run(cmd, text=True, capture_output=True, **kwargs)
    summary = {"cmd": cmd}
    summary["returncode"] = completed.returncode
    # Keep only the tail of each stream so the report stays small.
    summary["stdout_tail"] = completed.stdout[-4000:]
    summary["stderr_tail"] = completed.stderr[-4000:]
    return summary
def compile_marker_lib(work: Path):
    """Write the marker payload source into *work* and compile it with g++.

    Returns:
        A ``(source_path, library_path, compile_result)`` tuple, where
        ``compile_result`` is the summary dict produced by ``run``.

    Raises:
        RuntimeError: if g++ exits with a non-zero status; the message is the
            JSON-encoded compile summary.
    """
    source_path = work / "marker_payload.cpp"
    library_path = work / "libmarker_payload.so"
    source_path.write_text(C_SOURCE)

    gxx_cmd = [
        "g++",
        "-shared",
        "-fPIC",
        "-O2",
        str(source_path),
        "-o",
        str(library_path),
    ]
    outcome = run(gxx_cmd)
    if outcome["returncode"] == 0:
        return source_path, library_path, outcome
    raise RuntimeError(json.dumps(outcome, indent=2))
def build_engine(path: Path, plugin_lib: Path):
    """Build a tiny engine that embeds *plugin_lib* and save it to *path*.

    The network is a single identity layer mapping a (1, 1) float32 input
    named "x" to an output named "y". The builder config enables the
    VERSION_COMPATIBLE flag and lists *plugin_lib* in ``plugins_to_serialize``.

    Returns:
        A dict with the engine path, its size on disk, and the plugin list.

    Raises:
        RuntimeError: if the builder returns no serialized network.
    """
    trt_logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(trt_logger)

    explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(explicit_batch)
    source_tensor = network.add_input("x", trt.float32, (1, 1))
    identity_layer = network.add_identity(source_tensor)
    identity_layer.get_output(0).name = "y"
    network.mark_output(identity_layer.get_output(0))

    plugin_paths = [str(plugin_lib)]
    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 20)
    config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)
    config.plugins_to_serialize = plugin_paths

    blob = builder.build_serialized_network(network, config)
    if blob is None:
        raise RuntimeError("failed to build serialized plugin engine")
    path.write_bytes(bytes(blob))

    report = {"path": str(path)}
    report["size"] = path.stat().st_size
    report["plugins_to_serialize"] = list(plugin_paths)
    return report
def try_deserialize(path: Path, marker: Path, allow_host_code: bool):
    """Deserialize the engine at *path* in-process and watch the marker file.

    When the runtime exposes ``engine_host_code_allowed`` it is set to
    *allow_host_code* before loading. Any exception raised while loading (or
    while touching the loaded engine) is recorded instead of propagated.

    Returns:
        A dict with the requested flag, whether an engine was obtained, the
        exception text (or ``None``), whether the marker file content changed,
        and the marker content after the attempt.
    """

    def marker_text():
        # A missing marker file is treated as empty content.
        return marker.read_text() if marker.exists() else ""

    trt_logger = trt.Logger(trt.Logger.WARNING)
    runtime = trt.Runtime(trt_logger)
    if hasattr(runtime, "engine_host_code_allowed"):
        runtime.engine_host_code_allowed = allow_host_code

    snapshot = marker_text()
    failure = None
    loaded = False
    try:
        engine = runtime.deserialize_cuda_engine(path.read_bytes())
        loaded = engine is not None
        if loaded:
            # Touch an attribute so a broken engine object surfaces here.
            _ = engine.num_io_tensors
    except Exception as err:
        loaded = False
        failure = f"{type(err).__name__}: {err}"
    current = marker_text()

    return {
        "allow_host_code": allow_host_code,
        "ok": loaded,
        "exception": failure,
        "marker_changed": current != snapshot,
        "marker_after": current,
    }
def polygraphy_inspect(path: Path, marker: Path):
    """Run ``polygraphy inspect model`` on the engine and watch the marker file.

    The child process receives a copy of the current environment with
    ``TRT_PLUGIN_MARKER`` pointing at *marker*, and is given 90 seconds.

    A missing/unlaunchable ``polygraphy`` executable or a timeout is recorded
    in the result (``returncode`` is ``None`` and an ``"exception"`` entry is
    added) rather than raised, so results gathered earlier are not lost.

    Returns:
        A dict with the command, return code, the last 4000 characters of
        stdout/stderr, whether the marker file content changed, and the
        marker content afterwards.
    """

    def tail(data, limit=4000):
        # Output attached to a TimeoutExpired may be None or raw bytes.
        if data is None:
            return ""
        if isinstance(data, bytes):
            data = data.decode(errors="replace")
        return data[-limit:]

    env = os.environ.copy()
    env["TRT_PLUGIN_MARKER"] = str(marker)
    # Built once so the executed command and the reported command cannot drift.
    cmd = [
        "polygraphy",
        "inspect",
        "model",
        str(path),
        "--model-type=engine",
        "--show",
        "attrs",
    ]

    before = marker.read_text() if marker.exists() else ""
    returncode = None
    stdout = stderr = None
    failure = None
    try:
        proc = subprocess.run(
            cmd,
            text=True,
            capture_output=True,
            timeout=90,
            env=env,
        )
    except (OSError, subprocess.TimeoutExpired) as err:
        failure = f"{type(err).__name__}: {err}"
        # TimeoutExpired carries partial output; OSError has none.
        stdout = getattr(err, "stdout", None)
        stderr = getattr(err, "stderr", None)
    else:
        returncode = proc.returncode
        stdout, stderr = proc.stdout, proc.stderr
    after = marker.read_text() if marker.exists() else ""

    report = {
        "cmd": cmd,
        "returncode": returncode,
        "stdout_tail": tail(stdout),
        "stderr_tail": tail(stderr),
        "marker_changed": after != before,
        "marker_after": after,
    }
    if failure is not None:
        report["exception"] = failure
    return report
def main():
    """Compile the payload, build the engine, probe it, and write a report.

    The report is written as JSON to ``--out`` and echoed to stdout. Working
    directories ("plugin-work" and "cases") are created next to the parent
    of the output file's directory.

    Returns:
        2 if building the engine failed, otherwise 0.
    """
    parser = argparse.ArgumentParser(
        description="Test whether TensorRT serialized plugin libraries execute from an engine file."
    )
    parser.add_argument("--out", default="results/trt_serialized_plugin_marker_probe.json")
    out = Path(parser.parse_args().out)

    results_dir = out.parent
    base = results_dir.parent
    work = base / "plugin-work"
    cases = base / "cases"
    for directory in (results_dir, work, cases):
        directory.mkdir(parents=True, exist_ok=True)

    src, lib, compile_result = compile_marker_lib(work)

    # Start from a clean slate so any marker content comes from this run.
    marker = results_dir / "serialized_plugin_marker.txt"
    if marker.exists():
        marker.unlink()

    engine = cases / "vc_serialized_marker_plugin.engine"
    result = {
        "python": sys.version,
        "tensorrt_version": trt.__version__,
        "compile": compile_result,
        "source": str(src),
        "plugin_lib": str(lib),
    }

    def emit(exit_code):
        # Persist the report, echo it, and hand back the exit code.
        rendered = json.dumps(result, indent=2, sort_keys=True)
        out.write_text(rendered)
        print(rendered)
        return exit_code

    try:
        result["build"] = build_engine(engine, lib)
    except Exception as err:
        result["build_error"] = f"{type(err).__name__}: {err}"
        return emit(2)

    # Move the library away so a marker write can only come from the copy
    # carried by the engine file, not from the original path on disk.
    removed_lib = work / "libmarker_payload.removed"
    shutil.move(lib, removed_lib)
    result["plugin_removed_before_load"] = {
        "original_exists": lib.exists(),
        "moved_to": str(removed_lib),
        "moved_exists": removed_lib.exists(),
    }

    # Only now expose the marker path to this process, i.e. after the build.
    os.environ["TRT_PLUGIN_MARKER"] = str(marker)
    for key, allowed in (("deserialize_false", False), ("deserialize_true", True)):
        result[key] = try_deserialize(engine, marker, allow_host_code=allowed)
    result["polygraphy_inspect"] = polygraphy_inspect(engine, marker)

    return emit(0)
# Script entry point: the value returned by main() becomes the exit status.
if __name__ == "__main__":
    sys.exit(main())