| |
| import argparse |
| import json |
| import os |
| import shutil |
| import subprocess |
| import sys |
| import textwrap |
| from pathlib import Path |
|
|
| import tensorrt as trt |
|
|
|
|
# C++ source for the probe payload shared library.
#
# On load, the `constructor` attribute makes marker_constructor run. It reads
# the TRT_PLUGIN_MARKER environment variable. If that is set and non-empty, it
# appends a "marker_constructor pid=<pid> time=<unix time>" line to the named
# file. A marker-file write is therefore observable evidence that the library
# was loaded in a process with that variable set.
#
# trt_marker_export is a trivial C-linkage export. `extern "C"` is C++ syntax,
# which is why compile_marker_lib writes this to a .cpp file and builds it
# with g++.
C_SOURCE = r"""
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <unistd.h>

__attribute__((constructor))
static void marker_constructor(void) {
    const char *path = getenv("TRT_PLUGIN_MARKER");
    if (!path || !path[0]) {
        return;
    }
    FILE *f = fopen(path, "a");
    if (!f) {
        return;
    }
    time_t now = time(NULL);
    fprintf(f, "marker_constructor pid=%ld time=%ld\n", (long)getpid(), (long)now);
    fclose(f);
}

extern "C" int trt_marker_export(void) {
    return 1337;
}
"""
|
|
|
|
def run(cmd, **kwargs):
    """Run *cmd* to completion and summarize it as a JSON-friendly dict.

    Both streams are captured as text. Only the last 4000 characters of each
    are kept so the summary stays small. Extra keyword arguments are forwarded
    unchanged to ``subprocess.run``.

    Returns:
        dict with keys ``cmd``, ``returncode``, ``stdout_tail``, ``stderr_tail``.
    """
    completed = subprocess.run(cmd, capture_output=True, text=True, **kwargs)
    summary = {"cmd": cmd, "returncode": completed.returncode}
    summary["stdout_tail"] = completed.stdout[-4000:]
    summary["stderr_tail"] = completed.stderr[-4000:]
    return summary
|
|
|
|
def compile_marker_lib(work: Path):
    """Write the marker C++ source into *work* and build it as a shared library.

    Returns:
        ``(source_path, library_path, compile_result)``, where
        ``compile_result`` is the summary dict produced by ``run``.

    Raises:
        RuntimeError: if ``g++`` exits non-zero. The message is the
            JSON-encoded compile summary.
    """
    source_path = work / "marker_payload.cpp"
    library_path = work / "libmarker_payload.so"
    source_path.write_text(C_SOURCE)
    compile_cmd = [
        "g++",
        "-shared",
        "-fPIC",
        "-O2",
        str(source_path),
        "-o",
        str(library_path),
    ]
    outcome = run(compile_cmd)
    if outcome["returncode"] == 0:
        return source_path, library_path, outcome
    raise RuntimeError(json.dumps(outcome, indent=2))
|
|
|
|
def build_engine(path: Path, plugin_lib: Path):
    """Build a trivial identity engine that packages *plugin_lib* and save it.

    The network is a single identity layer mapping input ``x`` to output ``y``
    on a (1, 1) float32 tensor. The builder config enables
    ``VERSION_COMPATIBLE`` and lists *plugin_lib* in ``plugins_to_serialize``,
    so the library is requested to be serialized into the plan written to
    *path*.

    Returns:
        dict with the engine ``path``, its on-disk ``size`` in bytes, and the
        ``plugins_to_serialize`` list that was requested.

    Raises:
        RuntimeError: if TensorRT does not produce a serialized network.
    """
    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)

    # EXPLICIT_BATCH is deprecated in TensorRT 10 (explicit batch is always on
    # there) and may be removed. Only request it when this build still has it.
    explicit_batch = getattr(trt.NetworkDefinitionCreationFlag, "EXPLICIT_BATCH", None)
    network_flags = 0 if explicit_batch is None else 1 << int(explicit_batch)
    network = builder.create_network(network_flags)

    inp = network.add_input("x", trt.float32, (1, 1))
    identity = network.add_identity(inp)
    identity.get_output(0).name = "y"
    network.mark_output(identity.get_output(0))

    plugins = [str(plugin_lib)]
    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 20)
    config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)
    config.plugins_to_serialize = plugins

    serialized = builder.build_serialized_network(network, config)
    if serialized is None:
        raise RuntimeError("failed to build serialized plugin engine")
    path.write_bytes(bytes(serialized))
    return {
        "path": str(path),
        "size": path.stat().st_size,
        "plugins_to_serialize": list(plugins),
    }
|
|
|
|
def try_deserialize(path: Path, marker: Path, allow_host_code: bool):
    """Deserialize the engine at *path* in-process and report marker activity.

    If the runtime exposes ``engine_host_code_allowed``, it is set to
    *allow_host_code* before deserialization. The marker file is read before
    and after the attempt, so the caller can see whether it was written during
    deserialization.

    Returns:
        dict with keys:
        - ``allow_host_code``: the requested setting.
        - ``host_code_flag_applied``: whether the runtime exposes the attribute,
          i.e. whether the requested setting was actually applied.
        - ``ok``: True if an engine object was obtained and queried without error.
        - ``exception``: ``"<Type>: <message>"`` or None.
        - ``marker_changed``: whether the marker contents differ before vs. after.
        - ``marker_after``: the marker contents after the attempt.
    """
    logger = trt.Logger(trt.Logger.WARNING)
    runtime = trt.Runtime(logger)
    # Older TensorRT releases lack this attribute. Record whether the requested
    # setting was applied, so "allow_host_code=False" is not misread as enforced.
    flag_applied = hasattr(runtime, "engine_host_code_allowed")
    if flag_applied:
        runtime.engine_host_code_allowed = allow_host_code

    def read_marker():
        # A missing marker file is treated the same as an empty one.
        return marker.read_text() if marker.exists() else ""

    before = read_marker()
    exc = None
    try:
        engine = runtime.deserialize_cuda_engine(path.read_bytes())
        ok = engine is not None
        if ok:
            # Presumably a sanity access confirming the engine object is usable.
            _ = engine.num_io_tensors
    except Exception as err:  # probe boundary: record any failure, don't crash
        ok = False
        exc = f"{type(err).__name__}: {err}"
    after = read_marker()
    return {
        "allow_host_code": allow_host_code,
        "host_code_flag_applied": flag_applied,
        "ok": ok,
        "exception": exc,
        "marker_changed": after != before,
        "marker_after": after,
    }
|
|
|
|
def polygraphy_inspect(path: Path, marker: Path):
    """Run ``polygraphy inspect model`` on the engine and report marker activity.

    The child process inherits the current environment, plus
    ``TRT_PLUGIN_MARKER`` pointing at *marker*. The marker file is read before
    and after the run.

    A missing or unlaunchable ``polygraphy`` executable, or a run exceeding
    90 seconds, is recorded in ``error`` instead of raised. This keeps the rest
    of the probe's results.

    Returns:
        dict with keys:
        - ``cmd``: the command that was run.
        - ``returncode``: the exit code, or None if the process did not complete.
        - ``stdout_tail`` / ``stderr_tail``: the last 4000 characters of each stream.
        - ``error``: ``"<Type>: <message>"`` or None.
        - ``marker_changed``: whether the marker contents differ before vs. after.
        - ``marker_after``: the marker contents after the run.
    """
    cmd = [
        "polygraphy",
        "inspect",
        "model",
        str(path),
        "--model-type=engine",
        "--show",
        "attrs",
    ]
    env = os.environ.copy()
    env["TRT_PLUGIN_MARKER"] = str(marker)

    def read_marker():
        # A missing marker file is treated the same as an empty one.
        return marker.read_text() if marker.exists() else ""

    def tail(stream):
        # TimeoutExpired exposes captured output as bytes (or None) even when
        # text=True, so normalize to str before slicing.
        if stream is None:
            return ""
        if isinstance(stream, bytes):
            stream = stream.decode("utf-8", errors="replace")
        return stream[-4000:]

    before = read_marker()
    returncode = None
    stdout = stderr = None
    error = None
    try:
        proc = subprocess.run(
            cmd,
            text=True,
            capture_output=True,
            timeout=90,
            env=env,
        )
    except OSError as err:  # e.g. polygraphy not installed / not executable
        error = f"{type(err).__name__}: {err}"
    except subprocess.TimeoutExpired as err:
        error = f"{type(err).__name__}: {err}"
        stdout, stderr = err.stdout, err.stderr
    else:
        returncode = proc.returncode
        stdout, stderr = proc.stdout, proc.stderr
    after = read_marker()
    return {
        "cmd": cmd,
        "returncode": returncode,
        "stdout_tail": tail(stdout),
        "stderr_tail": tail(stderr),
        "error": error,
        "marker_changed": after != before,
        "marker_after": after,
    }
|
|
|
|
def _write_result(out: Path, result: dict):
    """Persist *result* to *out* as sorted, indented JSON and echo it to stdout."""
    text = json.dumps(result, indent=2, sort_keys=True)
    out.write_text(text)
    print(text)


def main():
    """Run the serialized-plugin marker probe and write a JSON report.

    Steps:
    1. Compile the marker payload library.
    2. Build an engine that lists the library in ``plugins_to_serialize``.
    3. Move the library away from its build path.
    4. Set ``TRT_PLUGIN_MARKER`` and deserialize the engine with
       ``allow_host_code`` False, then True.
    5. Inspect the engine with polygraphy.

    The report is written to ``--out`` and printed.

    Returns:
        0 on success, 2 if the engine build fails.

    Raises:
        Any exception from the deserialize/inspect steps is re-raised. The
        partial report is written first, with a ``probe_error`` key.
    """
    parser = argparse.ArgumentParser(
        description="Test whether TensorRT serialized plugin libraries execute from an engine file."
    )
    parser.add_argument("--out", default="results/trt_serialized_plugin_marker_probe.json")
    args = parser.parse_args()

    out = Path(args.out)
    # Work and case directories sit next to the output's parent directory
    # (e.g. "./plugin-work" for the default "results/..." output).
    base = out.parent.parent
    work = base / "plugin-work"
    cases = base / "cases"
    out.parent.mkdir(parents=True, exist_ok=True)
    work.mkdir(parents=True, exist_ok=True)
    cases.mkdir(parents=True, exist_ok=True)

    src, lib, compile_result = compile_marker_lib(work)
    marker = out.parent / "serialized_plugin_marker.txt"
    # Start from a clean marker so any content comes from this run.
    if marker.exists():
        marker.unlink()

    engine = cases / "vc_serialized_marker_plugin.engine"
    result = {
        "python": sys.version,
        "tensorrt_version": trt.__version__,
        "compile": compile_result,
        "source": str(src),
        "plugin_lib": str(lib),
    }

    try:
        result["build"] = build_engine(engine, lib)
    except Exception as err:
        result["build_error"] = f"{type(err).__name__}: {err}"
        _write_result(out, result)
        return 2

    # Move the built library away so the original .so path no longer exists
    # when the engine is loaded.
    removed_lib = work / "libmarker_payload.removed"
    shutil.move(str(lib), str(removed_lib))
    result["plugin_removed_before_load"] = {
        "original_exists": lib.exists(),
        "moved_to": str(removed_lib),
        "moved_exists": removed_lib.exists(),
    }

    # Set only after the build, so the payload constructor writes the marker
    # only when the library is loaded from here on in this process (or in
    # children that inherit the environment).
    os.environ["TRT_PLUGIN_MARKER"] = str(marker)
    try:
        result["deserialize_false"] = try_deserialize(engine, marker, allow_host_code=False)
        result["deserialize_true"] = try_deserialize(engine, marker, allow_host_code=True)
        result["polygraphy_inspect"] = polygraphy_inspect(engine, marker)
    except Exception as err:
        # Keep the evidence gathered so far before propagating the failure.
        result["probe_error"] = f"{type(err).__name__}: {err}"
        _write_result(out, result)
        raise

    _write_result(out, result)
    return 0
|
|
|
|
if __name__ == "__main__":
    # Script entry point: main()'s return value becomes the process exit status.
    sys.exit(main())
|
|