File size: 6,620 Bytes
357ae2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
#!/usr/bin/env python3
"""
Automated timing benchmark for the local ONNX SaT variants (CPU).

For every model under onnx_models/ it measures, over N repeats with warmup:
  - session load time (one-off)
  - model-only inference time (onnxruntime .run)
  - end-to-end time (tokenize + run + word-safe mask + Viterbi max_length chunking)
on short / medium / long inputs, and reports mean ± std and throughput (chars/s).

Results print as a table and are written to onnx_models/benchmark.json.

Usage:
  python scripts/benchmark_onnx.py [--repeats 20] [--threads 1] [--max-ms 0]

  --threads  onnxruntime intra-op threads (1 = closest to a phone core).
  --max-ms   if > 0, exit non-zero when any end-to-end mean exceeds this many ms
             (turns the benchmark into a pass/fail latency regression test).
"""
import argparse
import json
import statistics
import sys
import time
from pathlib import Path

# reuse the segmentation pipeline + onnxruntime bootstrap from run_segmentation
sys.path.insert(0, str(Path(__file__).resolve().parent))
import run_segmentation as rs  # noqa: E402  (its import preloads libstdc++ for onnxruntime)

import numpy as np  # noqa: E402
import onnxruntime as ort  # noqa: E402

# Bind the few lightweight helpers this script needs from run_segmentation
# under module-level names. Going through rs keeps us clear of the heavy
# wtpsplit/torch stack, which rs deliberately stubs out.
(constrained_segmentation,
 create_prior_function,
 compute_keep_ids,
 load_tokenizer) = (rs.constrained_segmentation,
                    rs.create_prior_function,
                    rs.compute_keep_ids,
                    rs.load_tokenizer)

# Mixed English/Chinese one-liner with three sentence boundaries.
SHORT = "This is a test. 这是一个测试。It works well!"
# A three-sentence English news paragraph.
MEDIUM = (
    "Breaking News: Scientists at CERN have announced a groundbreaking discovery "
    "that could revolutionize our understanding of particle physics. The team observed "
    "unexpected behavior in proton collisions at energies never before achieved."
)
# MEDIUM concatenated four times back to back (~930 chars): long enough to
# exercise the windowing-sized path.
LONG = "".join([MEDIUM] * 4)
# Benchmark inputs keyed by the size label used in the report (S/M/L columns).
INPUTS = dict(short=SHORT, medium=MEDIUM, long=LONG)


def _stats(ts_ms):
    mean = statistics.fmean(ts_ms)
    std = statistics.pstdev(ts_ms) if len(ts_ms) > 1 else 0.0
    return mean, std


def make_remap(tokenizer):
    """Return the ``(remap, unk_new)`` pair for *tokenizer* from ``rs.get_remap``.

    ``rs.get_remap`` caches its result on disk.
    """
    remap_pair = rs.get_remap(tokenizer)
    return remap_pair


def segment_full(session, tokenizer, text, remap, unk_new, max_length=80, min_length=40):
    """Run the whole segmentation pipeline on *text*; return the segment count.

    This mirrors run_segmentation: model boundary probabilities, then a
    word-safe mask, then a Gaussian length prior (target 70, spread 12),
    then Viterbi segmentation constrained to [min_length, max_length].
    The count is the number of split points plus one.
    """
    raw_probs = rs.boundary_probs(session, tokenizer, text, remap, unk_new)
    safe_probs = rs.word_safe_mask(raw_probs, text)
    prior_config = {"target_length": 70, "spread": 12, "max_length": max_length}
    length_prior = create_prior_function("gaussian", prior_config)
    split_points = constrained_segmentation(
        safe_probs,
        length_prior,
        min_length=min_length,
        max_length=max_length,
        algorithm="viterbi",
    )
    return 1 + len(split_points)


def time_block(fn, repeats, warmup=2):
    """Time repeated calls of *fn*, after some untimed warmup calls.

    Args:
        fn: zero-argument callable to measure; its return value is ignored.
        repeats: number of timed calls, which is also the length of the
            returned list (0 or less yields an empty list).
        warmup: number of untimed calls made first, so that one-off costs
            (e.g. lazy initialisation) stay out of the samples. Defaults to
            2, the previously hard-coded value.

    Returns:
        A list of per-call wall-clock durations in milliseconds, measured with
        ``time.perf_counter``.
    """
    for _ in range(warmup):
        fn()
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append((time.perf_counter() - start) * 1000.0)  # ms
    return timings


def _encode_feeds(tokenizer, remap, unk_new):
    """Pre-tokenize every entry of INPUTS into onnxruntime feed dicts.

    Args:
        tokenizer: object whose ``encode(text)`` returns a 3-tuple; only the
            first item (the list of token ids) is used here.
        remap: None to feed ids unchanged, or a lookup indexed by the original
            ids (presumably a NumPy array from make_remap -- confirm).
        unk_new: id substituted wherever ``remap`` yields -1.

    Returns:
        A dict mapping each INPUTS key to ``{"input_ids", "attention_mask"}``.
        Both arrays hold a single row (batch of 1), and the mask is all ones.
    """
    feeds = {}
    for key, text in INPUTS.items():
        ids_list, _, _ = tokenizer.encode(text)
        ids = np.array([ids_list], dtype=np.int64)
        mask = np.ones_like(ids)
        if remap is not None:
            # -1 marks ids the remap table has no entry for; use the model's unk id
            ids = remap[ids]
            ids[ids == -1] = unk_new
        feeds[key] = {"input_ids": ids, "attention_mask": mask}
    return feeds


def main():
    """Benchmark every ONNX model found by rs.find_models and report timings.

    Prints a per-model table and writes one record per model to
    ``rs.MODELS_DIR / "benchmark.json"``. Each record holds the mean timings
    and their population std devs. The script exits with status 1 when
    ``--max-ms`` is > 0 and any end-to-end mean (short, medium or long input)
    exceeds it. It exits with a message when no models are found, and with
    an argparse error when ``--repeats`` < 1.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--repeats", type=int, default=20,
                    help="timed runs per measurement, after untimed warmup runs")
    ap.add_argument("--threads", type=int, default=1,
                    help="onnxruntime intra-op threads (1 ~ single phone core)")
    ap.add_argument("--max-ms", type=float, default=0.0,
                    help="fail if any end-to-end mean exceeds this (0 = report only)")
    args = ap.parse_args()
    if args.repeats < 1:
        # every statistic needs at least one timed sample (fmean rejects [])
        ap.error("--repeats must be >= 1")

    models = rs.find_models(rs.MODELS_DIR)
    if not models:
        sys.exit(f"No ONNX models under {rs.MODELS_DIR}. Run build_and_test_onnx.py first.")

    tokenizer = load_tokenizer()
    so = ort.SessionOptions()
    so.intra_op_num_threads = args.threads
    so.inter_op_num_threads = 1

    print(f"onnxruntime {ort.__version__} | threads={args.threads} | repeats={args.repeats}\n")
    header = (f"{'model':28s}{'load':>8s}{'infer-S':>9s}{'infer-L':>9s}"
              f"{'e2e-S':>8s}{'e2e-M':>8s}{'e2e-L':>8s}{'long c/s':>10s}")
    print(header)
    print("-" * len(header))

    results, failures = [], []
    for name, path in models.items():
        remap = unk_new = None
        if "en_zh" in name:
            # these models are fed remapped token ids (see _encode_feeds)
            remap, unk_new = make_remap(tokenizer)

        t0 = time.perf_counter()
        sess = ort.InferenceSession(str(path), sess_options=so,
                                    providers=["CPUExecutionProvider"])
        load_ms = (time.perf_counter() - t0) * 1000.0

        # model-only inference: pre-tokenize so we time just .run
        feeds = _encode_feeds(tokenizer, remap, unk_new)
        infer_s, infer_s_std = _stats(
            time_block(lambda: sess.run(["logits"], feeds["short"]), args.repeats))
        infer_l, infer_l_std = _stats(
            time_block(lambda: sess.run(["logits"], feeds["long"]), args.repeats))

        # end-to-end: mean and population std per input size
        e2e, e2e_std = {}, {}
        for key, text in INPUTS.items():
            ts = time_block(lambda t=text: segment_full(sess, tokenizer, t, remap, unk_new),
                            args.repeats)
            e2e[key], e2e_std[key] = _stats(ts)

        cps = len(LONG) / (e2e["long"] / 1000.0)  # chars/sec end-to-end on long input
        print(f"{name:28s}{load_ms:7.0f}m{infer_s:8.1f}m{infer_l:8.1f}m"
              f"{e2e['short']:7.1f}m{e2e['medium']:7.1f}m{e2e['long']:7.1f}m{cps:10.0f}")

        rec = {"model": name, "size_mb": round(path.stat().st_size / 1e6, 1),
               "load_ms": round(load_ms, 1), "infer_short_ms": round(infer_s, 2),
               "infer_long_ms": round(infer_l, 2),
               "e2e_short_ms": round(e2e["short"], 2), "e2e_medium_ms": round(e2e["medium"], 2),
               "e2e_long_ms": round(e2e["long"], 2), "long_chars_per_s": round(cps),
               # spread of the same measurements (added after the original keys)
               "infer_short_std_ms": round(infer_s_std, 2),
               "infer_long_std_ms": round(infer_l_std, 2),
               "e2e_short_std_ms": round(e2e_std["short"], 2),
               "e2e_medium_std_ms": round(e2e_std["medium"], 2),
               "e2e_long_std_ms": round(e2e_std["long"], 2)}
        results.append(rec)

        # --max-ms gates on the slowest end-to-end mean across all input sizes
        if args.max_ms > 0:
            worst = max(e2e, key=e2e.get)
            if e2e[worst] > args.max_ms:
                failures.append((name, worst, e2e[worst]))

    print("\n(load = session init, one-off; infer = onnxruntime.run only; "
          "e2e = tokenize+run+mask+Viterbi. S/M/L = short/medium/long.)")
    out = rs.MODELS_DIR / "benchmark.json"
    with open(out, "w", encoding="utf-8") as fh:
        json.dump(results, fh, indent=2)
    print(f"Wrote {out}")

    if failures:
        print(f"\nLATENCY FAIL (end-to-end mean > {args.max_ms} ms):")
        for n, key, ms in failures:
            print(f"  {n} [{key}]: {ms:.1f} ms")
        sys.exit(1)


if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when imported.
    main()