#!/usr/bin/env python3
"""Probe whether a plugin library serialized into a TensorRT engine executes.

A tiny shared library is compiled whose load-time constructor appends a line
to the file named by the ``TRT_PLUGIN_MARKER`` environment variable.  The
library is listed in ``plugins_to_serialize`` when an engine is built, so a
later change to the marker file shows that the library's constructor ran.
"""
import argparse
import json
import os
import shutil
import subprocess
import sys
import textwrap
from pathlib import Path

import tensorrt as trt

# Source of the marker payload.  It is compiled as C++ (see
# compile_marker_lib), hence the `extern "C"` on the exported symbol.
# Headers: stdio.h (fopen/fprintf/fclose), stdlib.h (getenv),
# time.h (time), unistd.h (getpid).
C_SOURCE = r"""
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <unistd.h>

__attribute__((constructor))
static void marker_constructor(void) {
  const char *path = getenv("TRT_PLUGIN_MARKER");
  if (!path || !path[0]) {
    return;
  }
  FILE *f = fopen(path, "a");
  if (!f) {
    return;
  }
  time_t now = time(NULL);
  fprintf(f, "marker_constructor pid=%ld time=%ld\n", (long)getpid(), (long)now);
  fclose(f);
}

extern "C" int trt_marker_export(void) {
  return 1337;
}
"""


def run(cmd, **kwargs):
    """Run *cmd* and return a JSON-friendly summary of the outcome.

    Args:
        cmd: Argument list handed to ``subprocess.run``.
        **kwargs: Extra keyword arguments forwarded to ``subprocess.run``.

    Returns:
        A dict with the command, its return code, and the last 4000
        characters of captured stdout and stderr.
    """
    proc = subprocess.run(cmd, text=True, capture_output=True, **kwargs)
    return {
        "cmd": cmd,
        "returncode": proc.returncode,
        "stdout_tail": proc.stdout[-4000:],
        "stderr_tail": proc.stderr[-4000:],
    }


def compile_marker_lib(work: Path):
    """Write ``C_SOURCE`` into *work* and compile it to a shared library.

    Args:
        work: Existing directory that receives the source and the library.

    Returns:
        A ``(source_path, library_path, compile_result)`` tuple, where
        ``compile_result`` is the dict produced by :func:`run`.

    Raises:
        RuntimeError: If ``g++`` exits with a non-zero status; the message is
            the JSON-encoded compile result.
    """
    src = work / "marker_payload.cpp"
    lib = work / "libmarker_payload.so"
    src.write_text(C_SOURCE)
    result = run(["g++", "-shared", "-fPIC", "-O2", str(src), "-o", str(lib)])
    if result["returncode"] != 0:
        raise RuntimeError(json.dumps(result, indent=2))
    return src, lib, result


def build_engine(path: Path, plugin_lib: Path):
    """Build a one-layer identity engine that lists *plugin_lib* for serialization.

    Args:
        path: Destination file for the serialized engine.
        plugin_lib: Shared library assigned to ``plugins_to_serialize``.

    Returns:
        A dict with the engine path, its size in bytes, and the plugin list.

    Raises:
        RuntimeError: If TensorRT returns no serialized network.
    """
    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )

    # Minimal network: a single float32 (1, 1) input passed through identity.
    inp = network.add_input("x", trt.float32, (1, 1))
    identity = network.add_identity(inp)
    identity.get_output(0).name = "y"
    network.mark_output(identity.get_output(0))

    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 20)
    # NOTE(review): presumably VERSION_COMPATIBLE is what allows host code to
    # travel inside the engine - confirm against the TensorRT documentation.
    config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)
    config.plugins_to_serialize = [str(plugin_lib)]

    serialized = builder.build_serialized_network(network, config)
    if serialized is None:
        raise RuntimeError("failed to build serialized plugin engine")
    path.write_bytes(bytes(serialized))
    return {
        "path": str(path),
        "size": path.stat().st_size,
        "plugins_to_serialize": [str(plugin_lib)],
    }


def try_deserialize(path: Path, marker: Path, allow_host_code: bool):
    """Deserialize the engine at *path* in-process and watch the marker file.

    Args:
        path: Serialized engine file to load.
        marker: File whose contents are compared before and after loading.
        allow_host_code: Value assigned to ``runtime.engine_host_code_allowed``
            when the runtime exposes that attribute.

    Returns:
        A dict recording the flag, whether an engine object was obtained, the
        formatted exception (or ``None``), whether the marker changed, and the
        marker contents after the attempt.
    """
    logger = trt.Logger(trt.Logger.WARNING)
    runtime = trt.Runtime(logger)
    # Runtimes without the attribute are left at their default behaviour.
    if hasattr(runtime, "engine_host_code_allowed"):
        runtime.engine_host_code_allowed = allow_host_code

    before = marker.read_text() if marker.exists() else ""
    try:
        engine = runtime.deserialize_cuda_engine(path.read_bytes())
        ok = engine is not None
        if engine is not None:
            # Read one attribute so the engine object is actually used.
            _ = engine.num_io_tensors
        exc = None
    except Exception as err:
        # Probe boundary: any failure is recorded in the report, not raised.
        ok = False
        exc = f"{type(err).__name__}: {err}"
    after = marker.read_text() if marker.exists() else ""

    return {
        "allow_host_code": allow_host_code,
        "ok": ok,
        "exception": exc,
        "marker_changed": after != before,
        "marker_after": after,
    }


def polygraphy_inspect(path: Path, marker: Path):
    """Run ``polygraphy inspect model`` on the engine and watch the marker file.

    The child process gets ``TRT_PLUGIN_MARKER`` pointing at *marker*.  A
    missing ``polygraphy`` executable or a timeout is recorded in the result
    instead of being raised, so the caller can still write its report.

    Args:
        path: Serialized engine file to inspect.
        marker: File whose contents are compared before and after the run.

    Returns:
        A dict with the command, return code, output tails, and marker state.
        When the command could not be run to completion, ``returncode`` is
        ``None``, both tails are empty, and an ``error`` key holds the
        formatted exception.
    """
    cmd = [
        "polygraphy",
        "inspect",
        "model",
        str(path),
        "--model-type=engine",
        "--show",
        "attrs",
    ]
    env = os.environ.copy()
    env["TRT_PLUGIN_MARKER"] = str(marker)

    before = marker.read_text() if marker.exists() else ""
    record = {"cmd": cmd}
    try:
        proc = subprocess.run(
            cmd,
            text=True,
            capture_output=True,
            timeout=90,
            env=env,
        )
    except (OSError, subprocess.TimeoutExpired) as err:
        # OSError covers a missing or non-executable polygraphy binary.
        record["returncode"] = None
        record["stdout_tail"] = ""
        record["stderr_tail"] = ""
        record["error"] = f"{type(err).__name__}: {err}"
    else:
        record["returncode"] = proc.returncode
        record["stdout_tail"] = proc.stdout[-4000:]
        record["stderr_tail"] = proc.stderr[-4000:]
    after = marker.read_text() if marker.exists() else ""

    record["marker_changed"] = after != before
    record["marker_after"] = after
    return record


def main():
    """Build the probe engine, try to load it, and write a JSON report.

    Returns:
        ``2`` if the engine build fails (the report then holds the build
        error), otherwise ``0``.
    """
    parser = argparse.ArgumentParser(
        description="Test whether TensorRT serialized plugin libraries execute from an engine file."
    )
    parser.add_argument("--out", default="results/trt_serialized_plugin_marker_probe.json")
    args = parser.parse_args()

    out = Path(args.out)
    # Work and case directories sit next to the directory holding the report.
    base = out.parent.parent
    work = base / "plugin-work"
    cases = base / "cases"
    for directory in (out.parent, work, cases):
        directory.mkdir(parents=True, exist_ok=True)

    src, lib, compile_result = compile_marker_lib(work)

    # Start from a clean marker so any content comes from this run.
    marker = out.parent / "serialized_plugin_marker.txt"
    if marker.exists():
        marker.unlink()

    engine = cases / "vc_serialized_marker_plugin.engine"
    result = {
        "python": sys.version,
        "tensorrt_version": trt.__version__,
        "compile": compile_result,
        "source": str(src),
        "plugin_lib": str(lib),
    }

    def emit():
        # Write the report to disk and echo the same text to stdout.
        text = json.dumps(result, indent=2, sort_keys=True)
        out.write_text(text)
        print(text)

    try:
        result["build"] = build_engine(engine, lib)
    except Exception as err:
        result["build_error"] = f"{type(err).__name__}: {err}"
        emit()
        return 2

    # Move the library away so a marker write cannot come from the on-disk
    # file at its original path.
    removed_lib = work / "libmarker_payload.removed"
    shutil.move(lib, removed_lib)
    result["plugin_removed_before_load"] = {
        "original_exists": lib.exists(),
        "moved_to": str(removed_lib),
        "moved_exists": removed_lib.exists(),
    }

    # Set only after the build, so the in-process loads below see it.
    os.environ["TRT_PLUGIN_MARKER"] = str(marker)
    result["deserialize_false"] = try_deserialize(
        engine, marker, allow_host_code=False
    )
    result["deserialize_true"] = try_deserialize(
        engine, marker, allow_host_code=True
    )
    result["polygraphy_inspect"] = polygraphy_inspect(engine, marker)

    emit()
    return 0


if __name__ == "__main__":
    raise SystemExit(main())