File size: 6,620 Bytes
357ae2c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 | #!/usr/bin/env python3
"""
Automated timing benchmark for the local ONNX SaT variants (CPU).
For every model under onnx_models/ it measures, over N repeats with warmup:
- session load time (one-off)
- model-only inference time (onnxruntime .run)
- end-to-end time (tokenize + run + word-safe mask + Viterbi max_length chunking)
on short / medium / long inputs, and reports mean ± std and throughput (chars/s).
Results print as a table and are written to onnx_models/benchmark.json.
Usage:
python scripts/benchmark_onnx.py [--repeats 20] [--threads 1] [--max-ms 0]
--threads onnxruntime intra-op threads (1 = closest to a phone core).
--max-ms if > 0, exit non-zero when any end-to-end mean exceeds this many ms
(turns the benchmark into a pass/fail latency regression test).
"""
import argparse
import json
import statistics
import sys
import time
from pathlib import Path
# reuse the segmentation pipeline + onnxruntime bootstrap from run_segmentation
sys.path.insert(0, str(Path(__file__).resolve().parent))
import run_segmentation as rs # noqa: E402 (its import preloads libstdc++ for onnxruntime)
import numpy as np # noqa: E402
import onnxruntime as ort # noqa: E402
# Lightweight helpers borrowed from run_segmentation. Binding them from `rs`
# (instead of importing the library that defines them) avoids pulling in the
# heavy wtpsplit/torch stack, which rs deliberately stubs out.
(
    constrained_segmentation,
    create_prior_function,
    compute_keep_ids,
    load_tokenizer,
) = (
    rs.constrained_segmentation,
    rs.create_prior_function,
    rs.compute_keep_ids,
    rs.load_tokenizer,
)
# Benchmark inputs, keyed by size label; dict insertion order is the order
# in which the end-to-end timings are measured.
SHORT = "This is a test. 这是一个测试。It works well!"  # mixed English / Chinese
MEDIUM = (
    "Breaking News: Scientists at CERN have announced a groundbreaking discovery "
    "that could revolutionize our understanding of particle physics. The team observed "
    "unexpected behavior in proton collisions at energies never before achieved."
)
# Four back-to-back copies of MEDIUM (~930 chars; no separator is inserted
# between copies) to exercise the windowing-sized path.
LONG = "".join([MEDIUM] * 4)
INPUTS = dict(short=SHORT, medium=MEDIUM, long=LONG)
def _stats(ts_ms):
mean = statistics.fmean(ts_ms)
std = statistics.pstdev(ts_ms) if len(ts_ms) > 1 else 0.0
return mean, std
def make_remap(tokenizer):
    """Return whatever ``rs.get_remap`` yields for *tokenizer*.

    main() unpacks the result as ``remap, unk_new``. Per the original note,
    rs.get_remap caches its result on disk (that logic lives in run_segmentation).
    """
    fetch_remap = rs.get_remap
    return fetch_remap(tokenizer)
def segment_full(session, tokenizer, text, remap, unk_new, max_length=80, min_length=40):
    """Run the whole segmentation pipeline on *text*; used for end-to-end timing.

    Steps (intended to mirror run_segmentation): model boundary probabilities,
    word-safe masking, a Gaussian length prior (target 70, spread 12) and
    Viterbi constrained segmentation bounded by *min_length*/*max_length*.

    Returns:
        ``len(split_indices) + 1`` -- presumably the number of segments.
    """
    masked = rs.word_safe_mask(
        rs.boundary_probs(session, tokenizer, text, remap, unk_new), text)
    length_prior = create_prior_function(
        "gaussian", {"target_length": 70, "spread": 12, "max_length": max_length})
    split_indices = constrained_segmentation(
        masked,
        length_prior,
        min_length=min_length,
        max_length=max_length,
        algorithm="viterbi",
    )
    return 1 + len(split_indices)
def time_block(fn, repeats, warmup=2):
    """Call ``fn()`` repeatedly and return per-call wall-clock durations in ms.

    Args:
        fn: zero-argument callable to benchmark; its return value is ignored.
        repeats: number of timed calls, i.e. the length of the returned list.
        warmup: untimed calls made before measuring, so first-call overheads
            are excluded from the samples. Defaults to 2 (the value that was
            previously hard-coded).

    Returns:
        List of ``repeats`` floats, each one call's duration in milliseconds
        measured with ``time.perf_counter``.
    """
    for _ in range(warmup):
        fn()
    ts = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        fn()
        ts.append((time.perf_counter() - t0) * 1000.0)  # seconds -> ms
    return ts
def _pretokenize_feeds(tokenizer, remap, unk_new):
    """Build one onnxruntime feed dict per entry of INPUTS.

    Tokenizing up front lets the model-only timings cover just ``session.run``.
    When *remap* is given, token ids are translated through it and ids that map
    to -1 are replaced by *unk_new*.
    """
    feeds = {}
    for key, text in INPUTS.items():
        ids_list, _, _ = tokenizer.encode(text)
        ids = np.array([ids_list], dtype=np.int64)  # batch of one sequence
        mask = np.ones_like(ids)
        if remap is not None:
            ids = remap[ids]
            ids[ids == -1] = unk_new
        feeds[key] = {"input_ids": ids, "attention_mask": mask}
    return feeds


def main():
    """Benchmark every ONNX model found by ``rs.find_models`` and report timings.

    For each model this measures session load time, model-only ``.run`` time on
    the short and long inputs, and end-to-end ``segment_full`` time on every
    entry of INPUTS. It prints a table and writes ``benchmark.json`` under
    ``rs.MODELS_DIR``. It exits with status 1 when ``--max-ms`` > 0 and any
    end-to-end mean exceeds it, and with a usage error when ``--repeats`` < 1.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--repeats", type=int, default=20)
    ap.add_argument("--threads", type=int, default=1,
                    help="onnxruntime intra-op threads (1 ~ single phone core)")
    ap.add_argument("--max-ms", type=float, default=0.0,
                    help="fail if any end-to-end mean exceeds this (0 = report only)")
    args = ap.parse_args()
    if args.repeats < 1:
        # _stats() cannot average an empty sample (statistics.fmean raises).
        ap.error("--repeats must be >= 1")

    models = rs.find_models(rs.MODELS_DIR)
    if not models:
        sys.exit(f"No ONNX models under {rs.MODELS_DIR}. Run build_and_test_onnx.py first.")
    tokenizer = load_tokenizer()
    so = ort.SessionOptions()
    so.intra_op_num_threads = args.threads
    so.inter_op_num_threads = 1

    print(f"onnxruntime {ort.__version__} | threads={args.threads} | repeats={args.repeats}\n")
    header = (f"{'model':28s}{'load':>8s}{'infer-S':>9s}{'infer-L':>9s}"
              f"{'e2e-S':>8s}{'e2e-M':>8s}{'e2e-L':>8s}{'long c/s':>10s}")
    print(header)
    print("-" * len(header))

    results, failures = [], []
    for name, path in models.items():
        remap = unk_new = None
        if "en_zh" in name:
            # en_zh models get their token ids remapped (see make_remap)
            remap, unk_new = make_remap(tokenizer)

        t0 = time.perf_counter()
        sess = ort.InferenceSession(str(path), sess_options=so,
                                    providers=["CPUExecutionProvider"])
        load_ms = (time.perf_counter() - t0) * 1000.0

        # model-only inference: inputs are pre-tokenized so only .run is timed
        feeds = _pretokenize_feeds(tokenizer, remap, unk_new)
        infer_s, infer_s_std = _stats(
            time_block(lambda: sess.run(["logits"], feeds["short"]), args.repeats))
        infer_l, infer_l_std = _stats(
            time_block(lambda: sess.run(["logits"], feeds["long"]), args.repeats))

        e2e, e2e_std = {}, {}
        for key, text in INPUTS.items():
            # bind text as a default so each lambda sees its own input
            ts = time_block(lambda t=text: segment_full(sess, tokenizer, t, remap, unk_new),
                            args.repeats)
            e2e[key], e2e_std[key] = _stats(ts)

        cps = len(LONG) / (e2e["long"] / 1000.0)  # chars/sec end-to-end on long input
        print(f"{name:28s}{load_ms:7.0f}m{infer_s:8.1f}m{infer_l:8.1f}m"
              f"{e2e['short']:7.1f}m{e2e['medium']:7.1f}m{e2e['long']:7.1f}m{cps:10.0f}")

        rec = {"model": name, "size_mb": round(path.stat().st_size / 1e6, 1),
               "load_ms": round(load_ms, 1), "infer_short_ms": round(infer_s, 2),
               "infer_long_ms": round(infer_l, 2),
               "e2e_short_ms": round(e2e["short"], 2), "e2e_medium_ms": round(e2e["medium"], 2),
               "e2e_long_ms": round(e2e["long"], 2), "long_chars_per_s": round(cps),
               # std devs from _stats (previously computed and discarded);
               # extra keys only, existing keys are unchanged
               "infer_short_std_ms": round(infer_s_std, 2),
               "infer_long_std_ms": round(infer_l_std, 2),
               "e2e_short_std_ms": round(e2e_std["short"], 2),
               "e2e_medium_std_ms": round(e2e_std["medium"], 2),
               "e2e_long_std_ms": round(e2e_std["long"], 2)}
        results.append(rec)

        # Gate on the slowest end-to-end mean, as the docs/--help promise
        # ("any end-to-end mean"); only active for a positive threshold.
        if args.max_ms > 0:
            worst = max(e2e, key=e2e.get)
            if e2e[worst] > args.max_ms:
                failures.append((name, worst, e2e[worst]))

    print("\n(load = session init, one-off; infer = onnxruntime.run only; "
          "e2e = tokenize+run+mask+Viterbi. S/M/L = short/medium/long.)")
    out = rs.MODELS_DIR / "benchmark.json"
    # context manager guarantees the file is flushed and closed before sys.exit
    with open(out, "w", encoding="utf-8") as fh:
        json.dump(results, fh, indent=2)
    print(f"Wrote {out}")

    if failures:
        print(f"\nLATENCY FAIL (end-to-end mean > {args.max_ms} ms):")
        for n, key, ms in failures:
            print(f" {n} [{key}]: {ms:.1f} ms")
        sys.exit(1)
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when imported.
    main()
|