#!/usr/bin/env python3
"""
Automated timing benchmark for the local ONNX SaT variants (CPU).

For every model under onnx_models/ it measures, over N repeats with warmup:
  - session load time (one-off)
  - model-only inference time (onnxruntime .run)
  - end-to-end time (tokenize + run + word-safe mask + Viterbi max_length chunking)
on short / medium / long inputs, and reports the mean time per measurement plus
end-to-end throughput on the long input (chars/s).

Results print as a table and are written to onnx_models/benchmark.json.

Usage:
  python scripts/benchmark_onnx.py [--repeats 20] [--threads 1] [--max-ms 0]

  --threads  onnxruntime intra-op threads (1 = closest to a phone core).
  --max-ms   if > 0, exit non-zero when any end-to-end mean exceeds this many ms
             (turns the benchmark into a pass/fail latency regression test).
"""
import argparse
import json
import statistics
import sys
import time
from pathlib import Path

# reuse the segmentation pipeline + onnxruntime bootstrap from run_segmentation
sys.path.insert(0, str(Path(__file__).resolve().parent))
import run_segmentation as rs  # noqa: E402  (its import preloads libstdc++ for onnxruntime)

import numpy as np  # noqa: E402
import onnxruntime as ort  # noqa: E402

# reuse the lightweight helpers from run_segmentation (avoids importing the
# heavy wtpsplit/torch stack, which rs deliberately stubs out)
constrained_segmentation = rs.constrained_segmentation
create_prior_function = rs.create_prior_function
compute_keep_ids = rs.compute_keep_ids
load_tokenizer = rs.load_tokenizer

SHORT = "This is a test. 这是一个测试。It works well!"
MEDIUM = ("Breaking News: Scientists at CERN have announced a groundbreaking discovery "
          "that could revolutionize our understanding of particle physics. The team observed "
          "unexpected behavior in proton collisions at energies never before achieved.")
LONG = MEDIUM * 4  # ~930 chars, exercises the windowing-sized path
INPUTS = {"short": SHORT, "medium": MEDIUM, "long": LONG}


def _stats(ts_ms):
    """Return (mean, population std) of a list of timings in ms.

    The std is 0.0 for a single sample. An empty list raises
    statistics.StatisticsError (from statistics.fmean).
    """
    mean = statistics.fmean(ts_ms)
    std = statistics.pstdev(ts_ms) if len(ts_ms) > 1 else 0.0
    return mean, std


def make_remap(tokenizer):
    """Return rs.get_remap(tokenizer) — used for the en_zh model variants."""
    return rs.get_remap(tokenizer)  # cached to disk


def segment_full(session, tokenizer, text, remap, unk_new, max_length=80, min_length=40):
    """Full pipeline used for end-to-end timing (mirrors run_segmentation).

    Runs boundary_probs -> word_safe_mask -> gaussian prior -> Viterbi
    constrained segmentation, and returns len(idx) + 1, where idx is the
    result of constrained_segmentation (presumably the split points, making
    this the segment count — confirm against run_segmentation).
    """
    probs = rs.boundary_probs(session, tokenizer, text, remap, unk_new)
    probs = rs.word_safe_mask(probs, text)
    prior = create_prior_function("gaussian", {"target_length": 70, "spread": 12,
                                               "max_length": max_length})
    idx = constrained_segmentation(probs, prior, min_length=min_length,
                                   max_length=max_length, algorithm="viterbi")
    return len(idx) + 1


def time_block(fn, repeats, warmup=2):
    """Call ``fn`` ``warmup`` times untimed, then ``repeats`` times timed.

    Returns the list of per-call wall-clock durations in milliseconds
    (measured with time.perf_counter).
    """
    for _ in range(warmup):
        fn()
    ts = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        fn()
        ts.append((time.perf_counter() - t0) * 1000.0)  # ms
    return ts


def _encode_feeds(tokenizer, texts, remap, unk_new):
    """Pre-tokenize ``texts`` ({key: text}) into onnxruntime feed dicts.

    Done ahead of timing so the model-only measurement covers just .run.
    When ``remap`` is given, token ids are mapped through it and ids that
    map to -1 are replaced with ``unk_new``. The attention mask is all ones.
    """
    feeds = {}
    for key, text in texts.items():
        ids_list, _, _ = tokenizer.encode(text)
        ids = np.array([ids_list], dtype=np.int64)
        mask = np.ones_like(ids)
        if remap is not None:
            ids = remap[ids]
            ids[ids == -1] = unk_new
        feeds[key] = {"input_ids": ids, "attention_mask": mask}
    return feeds


def main():
    """Benchmark every ONNX model, print a table and write benchmark.json.

    Exits with status 1 when --max-ms > 0 and some model's slowest
    end-to-end mean (over short/medium/long) exceeds it.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--repeats", type=int, default=20)
    ap.add_argument("--threads", type=int, default=1,
                    help="onnxruntime intra-op threads (1 ~ single phone core)")
    ap.add_argument("--max-ms", type=float, default=0.0,
                    help="fail if any end-to-end mean exceeds this (0 = report only)")
    args = ap.parse_args()
    if args.repeats < 1:
        # a mean over zero timed runs is undefined (statistics.fmean([]) raises)
        ap.error("--repeats must be >= 1")

    models = rs.find_models(rs.MODELS_DIR)
    if not models:
        sys.exit(f"No ONNX models under {rs.MODELS_DIR}. Run build_and_test_onnx.py first.")

    tokenizer = load_tokenizer()
    so = ort.SessionOptions()
    so.intra_op_num_threads = args.threads
    so.inter_op_num_threads = 1

    print(f"onnxruntime {ort.__version__} | threads={args.threads} | repeats={args.repeats}\n")
    header = (f"{'model':28s}{'load':>8s}{'infer-S':>9s}{'infer-L':>9s}"
              f"{'e2e-S':>8s}{'e2e-M':>8s}{'e2e-L':>8s}{'long c/s':>10s}")
    print(header)
    print("-" * len(header))

    results, failures = [], []
    for name, path in models.items():
        remap = unk_new = None
        if "en_zh" in name:
            # en_zh models need their token ids remapped (see rs.get_remap)
            remap, unk_new = make_remap(tokenizer)

        t0 = time.perf_counter()
        sess = ort.InferenceSession(str(path), sess_options=so,
                                    providers=["CPUExecutionProvider"])
        load_ms = (time.perf_counter() - t0) * 1000.0

        # model-only inference: pre-tokenize so we time just .run
        feeds = _encode_feeds(tokenizer, {k: INPUTS[k] for k in ("short", "long")},
                              remap, unk_new)
        infer_s, _ = _stats(time_block(
            lambda f=feeds["short"]: sess.run(["logits"], f), args.repeats))
        infer_l, _ = _stats(time_block(
            lambda f=feeds["long"]: sess.run(["logits"], f), args.repeats))

        e2e = {}
        for key, text in INPUTS.items():
            ts = time_block(lambda t=text: segment_full(sess, tokenizer, t, remap, unk_new),
                            args.repeats)
            e2e[key] = _stats(ts)[0]

        cps = len(LONG) / (e2e["long"] / 1000.0)  # chars/sec end-to-end on long input
        print(f"{name:28s}{load_ms:7.0f}m{infer_s:8.1f}m{infer_l:8.1f}m"
              f"{e2e['short']:7.1f}m{e2e['medium']:7.1f}m{e2e['long']:7.1f}m{cps:10.0f}")

        rec = {"model": name, "size_mb": round(path.stat().st_size / 1e6, 1),
               "load_ms": round(load_ms, 1),
               "infer_short_ms": round(infer_s, 2), "infer_long_ms": round(infer_l, 2),
               "e2e_short_ms": round(e2e["short"], 2), "e2e_medium_ms": round(e2e["medium"], 2),
               "e2e_long_ms": round(e2e["long"], 2), "long_chars_per_s": round(cps)}
        results.append(rec)

        # --max-ms applies to every end-to-end mean, as documented; only a
        # strictly positive limit enables the check (negative = report only).
        if args.max_ms > 0:
            worst = max(e2e, key=e2e.get)
            if e2e[worst] > args.max_ms:
                failures.append((name, worst, e2e[worst]))

    print("\n(load = session init, one-off; infer = onnxruntime.run only; "
          "e2e = tokenize+run+mask+Viterbi. S/M/L = short/medium/long.)")

    out = rs.MODELS_DIR / "benchmark.json"
    with open(out, "w", encoding="utf-8") as fh:
        json.dump(results, fh, indent=2)
    print(f"Wrote {out}")

    if failures:
        print(f"\nLATENCY FAIL (end-to-end mean > {args.max_ms} ms):")
        for model_name, key, ms in failures:
            print(f"  {model_name} [{key}]: {ms:.1f} ms")
        sys.exit(1)


if __name__ == "__main__":
    main()