wtpsplit-kit / scripts /benchmark_onnx.py
krmanik's picture
Upload folder using huggingface_hub
357ae2c verified
#!/usr/bin/env python3
"""
Automated timing benchmark for the local ONNX SaT variants (CPU).
For every model under onnx_models/ it measures, over N repeats with warmup:
- session load time (one-off)
- model-only inference time (onnxruntime .run)
- end-to-end time (tokenize + run + word-safe mask + Viterbi max_length chunking)
on short / medium / long inputs, and reports mean ± std and throughput (chars/s).
Results print as a table and are written to onnx_models/benchmark.json.
Usage:
python scripts/benchmark_onnx.py [--repeats 20] [--threads 1] [--max-ms 0]
--threads onnxruntime intra-op threads (1 = closest to a phone core).
--max-ms if > 0, exit non-zero when any end-to-end mean exceeds this many ms
(turns the benchmark into a pass/fail latency regression test).
"""
import argparse
import json
import statistics
import sys
import time
from pathlib import Path
# reuse the segmentation pipeline + onnxruntime bootstrap from run_segmentation
sys.path.insert(0, str(Path(__file__).resolve().parent))
import run_segmentation as rs # noqa: E402 (its import preloads libstdc++ for onnxruntime)
import numpy as np # noqa: E402
import onnxruntime as ort # noqa: E402
# Module-level shortcuts to helpers that live in run_segmentation.  Pulling
# them from there means the heavy wtpsplit/torch stack is never imported
# here (rs deliberately stubs it out).
load_tokenizer = rs.load_tokenizer
compute_keep_ids = rs.compute_keep_ids
create_prior_function = rs.create_prior_function
constrained_segmentation = rs.constrained_segmentation
# Benchmark texts in three sizes: a mixed English/Chinese one-liner, a short
# news paragraph, and that paragraph repeated.
SHORT = "This is a test. 这是一个测试。It works well!"
MEDIUM = (
    "Breaking News: Scientists at CERN have announced a groundbreaking "
    "discovery that could revolutionize our understanding of particle "
    "physics. The team observed unexpected behavior in proton collisions "
    "at energies never before achieved."
)
# Four copies of MEDIUM (~930 chars); exercises the windowing-sized path.
LONG = 4 * MEDIUM
# Label -> text, iterated in short / medium / long order by the benchmark.
INPUTS = dict(short=SHORT, medium=MEDIUM, long=LONG)
def _stats(ts_ms):
    """Summarise a list of timings as a ``(mean, std)`` tuple.

    ``std`` is the population standard deviation.  When there are fewer
    than two samples the spread is reported as ``0.0``; an empty list still
    raises from ``statistics.fmean``.
    """
    average = statistics.fmean(ts_ms)
    if len(ts_ms) <= 1:
        # No spread to speak of with a single measurement.
        return average, 0.0
    return average, statistics.pstdev(ts_ms)
def make_remap(tokenizer):
    """Fetch the token-id remap for ``tokenizer`` from ``rs.get_remap``.

    This is a plain pass-through: whatever ``rs.get_remap`` returns is
    handed back unchanged (``main`` unpacks it as ``remap, unk_new``).
    """
    # rs.get_remap caches its result to disk.
    remap_info = rs.get_remap(tokenizer)
    return remap_info
def segment_full(session, tokenizer, text, remap, unk_new, max_length=80, min_length=40):
    """Run the complete pipeline once; this is what the end-to-end timing measures.

    Steps, in order (mirrors run_segmentation):
      1. ``rs.boundary_probs``   - boundary probabilities for ``text``
      2. ``rs.word_safe_mask``   - mask applied to those probabilities
      3. ``create_prior_function("gaussian", ...)`` - length prior
      4. ``constrained_segmentation(..., algorithm="viterbi")``

    Returns the number of indices produced by step 4 plus one
    (presumably the segment count - confirm against run_segmentation).
    """
    raw_probs = rs.boundary_probs(session, tokenizer, text, remap, unk_new)
    masked_probs = rs.word_safe_mask(raw_probs, text)

    # Gaussian prior; only max_length is caller-controlled.
    prior_params = {"target_length": 70, "spread": 12, "max_length": max_length}
    length_prior = create_prior_function("gaussian", prior_params)

    boundaries = constrained_segmentation(
        masked_probs,
        length_prior,
        min_length=min_length,
        max_length=max_length,
        algorithm="viterbi",
    )
    return 1 + len(boundaries)
def time_block(fn, repeats, warmup=2):
    """Time repeated calls of ``fn`` and return the per-call durations.

    Args:
        fn: Zero-argument callable to measure.
        repeats: Number of timed calls.
        warmup: Number of untimed calls made first so that one-off costs
            (lazy initialisation, caches) do not skew the samples.
            Defaults to 2, the value that used to be hard-coded.

    Returns:
        A list of ``repeats`` floats, each the wall-clock duration of one
        call in milliseconds (measured with ``time.perf_counter``).
    """
    # Untimed warmup calls.
    for _ in range(warmup):
        fn()

    durations_ms = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        durations_ms.append((time.perf_counter() - start) * 1000.0)  # s -> ms
    return durations_ms
def main():
    """Benchmark every ONNX model found under ``rs.MODELS_DIR``.

    For each model this measures session load time, model-only inference
    time on the short and long inputs, and end-to-end time on all three
    inputs.  One table row is printed per model and all records are
    written to ``benchmark.json`` inside ``rs.MODELS_DIR``.

    Command-line flags:
        --repeats  number of timed calls per measurement (must be >= 1)
        --threads  onnxruntime intra-op thread count
        --max-ms   if > 0, exit with status 1 when any end-to-end mean
                   (short, medium or long) exceeds this many milliseconds

    Exits via ``sys.exit`` when no models are found or when the latency
    threshold is exceeded.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--repeats", type=int, default=20)
    ap.add_argument("--threads", type=int, default=1,
                    help="onnxruntime intra-op threads (1 ~ single phone core)")
    ap.add_argument("--max-ms", type=float, default=0.0,
                    help="fail if any end-to-end mean exceeds this (0 = report only)")
    args = ap.parse_args()
    if args.repeats < 1:
        # With no samples the mean is undefined and statistics.fmean raises.
        ap.error("--repeats must be at least 1")

    models = rs.find_models(rs.MODELS_DIR)
    if not models:
        sys.exit(f"No ONNX models under {rs.MODELS_DIR}. Run build_and_test_onnx.py first.")
    tokenizer = load_tokenizer()

    # One set of session options shared by every model.
    so = ort.SessionOptions()
    so.intra_op_num_threads = args.threads
    so.inter_op_num_threads = 1

    print(f"onnxruntime {ort.__version__} | threads={args.threads} | repeats={args.repeats}\n")
    header = (f"{'model':28s}{'load':>8s}{'infer-S':>9s}{'infer-L':>9s}"
              f"{'e2e-S':>8s}{'e2e-M':>8s}{'e2e-L':>8s}{'long c/s':>10s}")
    print(header)
    print("-" * len(header))

    results, failures = [], []
    for name, path in models.items():
        # Only models whose name contains "en_zh" get a token-id remap.
        remap = unk_new = None
        if "en_zh" in name:
            remap, unk_new = make_remap(tokenizer)

        t0 = time.perf_counter()
        sess = ort.InferenceSession(str(path), sess_options=so,
                                    providers=["CPUExecutionProvider"])
        load_ms = (time.perf_counter() - t0) * 1000.0

        # model-only inference: pre-tokenize so we time just .run
        feeds = {}
        for key, text in INPUTS.items():
            ids_list, _, _ = tokenizer.encode(text)
            ids = np.array([ids_list], dtype=np.int64)
            mask = np.ones_like(ids)
            if remap is not None:
                # Translate ids through the remap; -1 entries become unk_new.
                ids = remap[ids]
                ids[ids == -1] = unk_new
            feeds[key] = {"input_ids": ids, "attention_mask": mask}

        # Only the short and long feeds are timed for model-only inference.
        infer_s, _ = _stats(time_block(lambda: sess.run(["logits"], feeds["short"]), args.repeats))
        infer_l, _ = _stats(time_block(lambda: sess.run(["logits"], feeds["long"]), args.repeats))

        e2e = {}
        for key, text in INPUTS.items():
            ts = time_block(lambda t=text: segment_full(sess, tokenizer, t, remap, unk_new),
                            args.repeats)
            e2e[key] = _stats(ts)[0]

        cps = len(LONG) / (e2e["long"] / 1000.0)  # chars/sec end-to-end on long input
        # Each value is one column narrower than its header; the trailing "m" fills it.
        print(f"{name:28s}{load_ms:7.0f}m{infer_s:8.1f}m{infer_l:8.1f}m"
              f"{e2e['short']:7.1f}m{e2e['medium']:7.1f}m{e2e['long']:7.1f}m{cps:10.0f}")

        rec = {"model": name, "size_mb": round(path.stat().st_size / 1e6, 1),
               "load_ms": round(load_ms, 1), "infer_short_ms": round(infer_s, 2),
               "infer_long_ms": round(infer_l, 2),
               "e2e_short_ms": round(e2e["short"], 2), "e2e_medium_ms": round(e2e["medium"], 2),
               "e2e_long_ms": round(e2e["long"], 2), "long_chars_per_s": round(cps)}
        results.append(rec)

        # Threshold applies to every input size, and only when it is positive.
        if args.max_ms > 0:
            failures.extend((name, key, ms) for key, ms in e2e.items()
                            if ms > args.max_ms)

    print("\n(load = session init, one-off; infer = onnxruntime.run only; "
          "e2e = tokenize+run+mask+Viterbi. S/M/L = short/medium/long.)")

    out = rs.MODELS_DIR / "benchmark.json"
    # Context manager so the file is flushed and closed before we report it.
    with open(out, "w", encoding="utf-8") as fh:
        json.dump(results, fh, indent=2)
    print(f"Wrote {out}")

    if failures:
        print(f"\nLATENCY FAIL (end-to-end mean > {args.max_ms} ms):")
        for n, key, ms in failures:
            print(f" {n} [{key}]: {ms:.1f} ms")
        sys.exit(1)
# Run the benchmark only when this file is executed as a script.
if __name__ == "__main__":
    main()