File size: 15,005 Bytes
357ae2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
#!/usr/bin/env python3
"""
Pick an ONNX SaT model and segment text with it (local CPU test).

Examples:
    # interactive: choose a model from a menu, then type/paste text
    python scripts/run_segmentation.py

    # one-shot
    python scripts/run_segmentation.py --model sat-1l-sm-en_zh-int8 \
        --text "Your text here. 这是中文。" --max-length 80

    # from a file, tighter Chinese-style budget
    python scripts/run_segmentation.py -m sat-3l-sm-en_zh-int8 -f test.txt \
        --max-length 40 --min-length 15

Notes:
- "*-en_zh-*" models use a pruned vocab; the id-remap is derived
  deterministically from the tokenizer on first use and cached as
  onnx_models/remap_en_zh.npy, so no extra files need to be shipped.
- onnxruntime needs the conda libstdc++ on this box; the script auto-preloads it
  and re-execs once if needed.
"""
import argparse
import math
import os
import re
import string
import sys
from pathlib import Path

# --- bootstrap: onnxruntime needs conda's libstdc++ preloaded on this machine ---
def _ensure_onnxruntime():
    """Make `import onnxruntime` work, re-executing this script once if needed.

    On this machine onnxruntime only imports when conda's libstdc++ is
    preloaded. The import is first probed with Python-level stderr silenced.
    If the probe raises, ``<CONDA_PREFIX or sys.prefix>/lib/libstdc++.so.6``
    exists, and no preload has been attempted yet (``_ORT_PRELOADED`` != "1"),
    that library is prepended to ``LD_PRELOAD`` and the current interpreter is
    re-exec'd with the same argv (``os.execv`` does not return on success).
    Otherwise the exception raised by the probe is re-raised.
    """
    import contextlib
    import io
    # Probe quietly: a failed import dumps a long numpy/GLIBCXX message to stderr.
    try:
        with contextlib.redirect_stderr(io.StringIO()):
            import onnxruntime  # noqa
        return
    except Exception:
        prefix = os.environ.get("CONDA_PREFIX") or sys.prefix
        lib = Path(prefix) / "lib" / "libstdc++.so.6"
        # _ORT_PRELOADED guards against an exec loop when the preload doesn't help.
        if lib.exists() and os.environ.get("_ORT_PRELOADED") != "1":
            # LD_PRELOAD is read by the dynamic loader at process start, so the
            # new value only takes effect in a fresh process -- hence execv.
            os.environ["LD_PRELOAD"] = f"{lib}:{os.environ.get('LD_PRELOAD','')}".strip(":")
            os.environ["_ORT_PRELOADED"] = "1"
            os.execv(sys.executable, [sys.executable] + sys.argv)
        raise


# Must run before `import onnxruntime` below: it may re-exec this process with
# LD_PRELOAD set (see _ensure_onnxruntime).
_ensure_onnxruntime()

import importlib.util  # noqa: E402
import types  # noqa: E402

import numpy as np  # noqa: E402
import onnxruntime as ort  # noqa: E402

# Column of the model's logits that is read as the boundary score.
NEWLINE_INDEX = 0
# Repository root: this script lives one directory below it (e.g. scripts/).
ROOT = Path(__file__).resolve().parent.parent
# Searched recursively for *.onnx files; also holds the en_zh remap cache.
MODELS_DIR = ROOT / "onnx_models"


# --- load the two tiny pure-numpy helper modules WITHOUT importing the heavy
#     wtpsplit package (which pulls torch/onnx/skops and costs ~5s on startup).
#     constraints.py references wtpsplit.utils.indices_to_sentences but
#     constrained_segmentation() never calls it, so we stub that one symbol. ---
def _load_light(path, name):
    if "wtpsplit" not in sys.modules:
        pkg = types.ModuleType("wtpsplit"); pkg.__path__ = []
        utils = types.ModuleType("wtpsplit.utils"); utils.__path__ = []
        utils.indices_to_sentences = lambda *a, **k: None  # unused here
        sys.modules["wtpsplit"] = pkg
        sys.modules["wtpsplit.utils"] = utils
    spec = importlib.util.spec_from_file_location(name, path)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod


# Load the two helpers straight from the repo's wtpsplit/utils sources by file
# path (via _load_light); only these two functions are used by this script.
_WT_UTILS = ROOT / "wtpsplit" / "utils"
constrained_segmentation = _load_light(_WT_UTILS / "constraints.py",
                                       "onnxseg_constraints").constrained_segmentation
create_prior_function = _load_light(_WT_UTILS / "priors.py",
                                    "onnxseg_priors").create_prior_function

def get_token_spans(offsets_mapping, tokens, special_tokens):
    """Select the non-special tokens and their character spans.

    Args:
        offsets_mapping: per-token ``(start, end)`` character offsets.
        tokens: token strings, parallel to ``offsets_mapping``.
        special_tokens: token strings to skip (e.g. ``<s>``, ``</s>``).

    Returns:
        ``(indices, spans)``: an integer array of the kept token positions and a
        ``(k, 2)`` array of their offsets. Both are empty but correctly typed and
        shaped when no token qualifies (e.g. for empty text).
    """
    n_offsets = len(offsets_mapping)
    # Explicit integer dtype: np.array([]) would be float64, and indexing with
    # a float array raises IndexError.
    valid = np.array(
        [i for i, t in enumerate(tokens)
         if i < n_offsets and t not in special_tokens],
        dtype=np.intp,
    )
    # reshape keeps the (k, 2) layout even when offsets_mapping is empty.
    spans = np.asarray(offsets_mapping, dtype=np.intp).reshape(-1, 2)[valid]
    return valid, spans


def token_to_char_probs(text, tokens, token_logits, special_tokens, offsets_mapping):
    """Scatter token-level logits onto characters.

    Returns a ``(len(text), token_logits.shape[1])`` float array in which the
    row of the last character of every non-special token holds that token's
    logits; every other row is -inf.
    """
    n_classes = token_logits.shape[1]
    per_char = np.full((len(text), n_classes), -np.inf)
    token_idx, spans = get_token_spans(offsets_mapping, tokens, special_tokens)
    last_char_idx = spans[:, 1] - 1  # span end is exclusive
    per_char[last_char_idx] = token_logits[token_idx]
    return per_char


# Cached fast-tokenizer export, kept next to this script in .xlmr_tokenizer/;
# written on the first run by load_tokenizer().
_TOK_CACHE = Path(__file__).resolve().parent / ".xlmr_tokenizer" / "tokenizer.json"


class FastTok:
    """Minimal facade over a `tokenizers.Tokenizer` for XLM-R.

    Loading through the Rust `tokenizers` library takes ~0.4s versus ~4.3s for
    transformers' AutoTokenizer. Only what this script needs is exposed:
    ``special_tokens``, ``unk_token_id``, ``all_special_ids``, ``encode()`` and
    ``get_vocab()``.
    """

    def __init__(self, tok):
        self._t = tok
        self.special_tokens = {"<s>", "</s>", "<pad>", "<unk>", "<mask>"}
        self.unk_token_id = tok.token_to_id("<unk>")
        # Specials the vocab does not know (token_to_id -> None) are skipped.
        looked_up = (tok.token_to_id(name) for name in self.special_tokens)
        self.all_special_ids = [tid for tid in looked_up if tid is not None]

    def encode(self, text):
        """Return ``(ids, offsets, tokens)``; the XLM-R template adds <s> ... </s>."""
        encoding = self._t.encode(text)
        return encoding.ids, encoding.offsets, encoding.tokens

    def get_vocab(self):
        """Return the wrapped tokenizer's token-string -> id mapping."""
        return self._t.get_vocab()


def load_tokenizer():
    """Return a FastTok backed by the cached XLM-R ``tokenizer.json``.

    The first run ever exports the fast tokenizer via transformers'
    ``AutoTokenizer.from_pretrained("xlm-roberta-base")`` into the cache
    directory. Every later run loads it with the `tokenizers` library alone,
    so transformers/torch are never imported.
    """
    from tokenizers import Tokenizer

    if not _TOK_CACHE.exists():
        # One-time export; transformers is imported lazily only here.
        from transformers import AutoTokenizer
        hf_tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
        hf_tokenizer.save_pretrained(str(_TOK_CACHE.parent))

    raw = Tokenizer.from_file(str(_TOK_CACHE))
    return FastTok(raw)


def compute_keep_ids(tokenizer):
    """Return the sorted token ids kept in the pruned EN+ZH vocabulary.

    A vocab entry survives when, after mapping the SentencePiece word marker
    "▁" to a space, it is pure ASCII or contains at least one CJK character.
    All special-token ids are always kept. Pure stdlib, no model needed.

    NOTE(review): this presumably has to match the pruning used when the
    *-en_zh-* models were built; changing it would desync the id remap.
    """
    def _wanted(piece):
        surface = piece.replace("▁", " ")  # SP underscore -> space
        return surface.isascii() or any(map(_is_cjk, surface))

    kept = set(tokenizer.all_special_ids)
    kept.update(idx for piece, idx in tokenizer.get_vocab().items() if _wanted(piece))
    return sorted(kept)


def get_remap(tokenizer):
    """Return ``(remap, unk_new)`` for models using the pruned EN+ZH vocab.

    ``remap[old_id]`` is the token's id in the pruned vocab, or -1 if the token
    was dropped; ``unk_new`` is the pruned id of ``<unk>``. The table is built
    from compute_keep_ids() once and cached as ``onnx_models/remap_en_zh.npy``.

    NOTE(review): an existing cache file is trusted as-is (not re-validated
    against the tokenizer).
    """
    cache_file = MODELS_DIR / "remap_en_zh.npy"
    if not cache_file.exists():
        kept = compute_keep_ids(tokenizer)
        # 250002: presumably the xlm-roberta-base id space; every tokenizer id
        # must be below it.
        table = np.full(250002, -1, dtype=np.int64)
        table[kept] = np.arange(len(kept), dtype=np.int64)
        MODELS_DIR.mkdir(parents=True, exist_ok=True)
        np.save(cache_file, table)
    else:
        table = np.load(cache_file)
    return table, int(table[tokenizer.unk_token_id])


def find_models(root: Path):
    """Map display names to every ``*.onnx`` file found (recursively) under ``root``.

    A file's display name is ``<parent-directory>-<quant>``, where quant is
    "int8" if the file name contains ".int8." and "fp32" otherwise. Files are
    visited in sorted path order; if two files get the same name, the later one
    wins.
    """
    def _label(onnx_path):
        quant = "int8" if ".int8." in onnx_path.name else "fp32"
        return f"{onnx_path.parent.name}-{quant}"

    return {_label(p): p for p in sorted(root.rglob("*.onnx"))}


def choose_model(models: dict):
    """Interactively pick one of ``models`` and return its name.

    Prints a numbered menu (name and on-disk size in MB), then loops until the
    user enters either a 1-based menu number or an exact model name.

    Raises:
        SystemExit: if stdin is closed (EOF, e.g. a non-interactive run without
            --model) or the user presses Ctrl-C at the prompt.
    """
    names = list(models)
    print("\nAvailable ONNX models:")
    for i, n in enumerate(names, 1):
        mb = models[n].stat().st_size / 1e6
        print(f"  {i:2d}) {n:30s} {mb:7.1f} MB")
    while True:
        try:
            sel = input("\nSelect model [number or name]: ").strip()
        except (EOFError, KeyboardInterrupt):
            # No answer is coming; exit cleanly instead of dumping a traceback.
            sys.exit("\nNo model selected (pass --model NAME to skip the menu).")
        # isdecimal, not isdigit: "²".isdigit() is True but int("²") raises.
        if sel.isdecimal() and 1 <= int(sel) <= len(names):
            return names[int(sel) - 1]
        if sel in models:
            return sel
        print("  invalid choice, try again")


def get_text(args):
    """Return the text to segment.

    Sources, in order: a non-empty ``--text``; else ``--file`` read as UTF-8;
    else stdin until EOF. If stdin yields only whitespace, a short built-in
    EN/ZH demo sentence is returned instead.
    """
    if args.text:
        return args.text
    if not args.file:
        print("\nEnter/paste text, then Ctrl-D (Ctrl-Z on Windows) to finish:")
        pasted = sys.stdin.read()
        if pasted.strip():
            return pasted
        return "Breaking News: Scientists announced a discovery. 这是一个测试。It works well!"
    return Path(args.file).read_text(encoding="utf-8")


CJK_RANGES = [(0x4E00, 0x9FFF), (0x3400, 0x4DBF), (0xF900, 0xFAFF),
              (0x3000, 0x303F), (0xFF00, 0xFFEF)]


def _is_cjk(ch):
    cp = ord(ch)
    return any(a <= cp <= b for a, b in CJK_RANGES)


# Punctuation that marks a prosodic pause (used as break-priority floors when a
# long sentence must be split below max_length). Latin sentence-ending
# punctuation (. ! ?) is intentionally NOT floored here -- the model already
# predicts those boundaries well, and overriding it would create false breaks
# after abbreviations like "A.I." or inside clusters such as "?!" and URLs.
# Non-ASCII characters are written as \u escapes so that tools which fold
# full-width forms to ASCII cannot silently change these sets.
CLAUSE_PUNCT = set(",;:)]}"                     # , ; : ) ] }
                   "\u2014\u2013"               # em dash, en dash
                   "\uff0c\u3001\uff1b\uff1a"   # CJK: full-width comma, 、, full-width ; :
                   "\u201d\u2019")              # right double / single quotation mark
# CJK sentence enders: ideographic full stop, full-width ! and ?, ellipsis.
# Unlike the Latin ones, these get a high break floor in pause_aware_mask.
CJK_SENT_PUNCT = set("\u3002\uff01\uff1f\u2026")

# Words that introduce a clause/phrase: breaking *before* one of these sounds
# more natural than a random word gap when a long span has no punctuation.
CONNECTORS = {
    "and", "but", "or", "nor", "yet", "so", "for",
    "which", "that", "who", "whom", "whose", "where", "when", "while",
    "because", "although", "though", "since", "if", "unless", "until",
    "after", "before", "as", "than", "whether",
}

FLOOR_CLAUSE = 0.25       # comma / semicolon / colon -> strongly preferred
FLOOR_CONNECTOR = 0.05    # break before "and/which/that..." in a comma-free span
FLOOR_HANZI = 5e-3        # between two Chinese chars (no spaces in zh)
FLOOR_SPACE = 1e-4        # plain word gap -> last-resort break
FORBID = 1e-9             # mid-word -> effectively never


def _connector_break_positions(text):
    """Return char indices i (break *after* i) that sit right before a connector.

    For each whitespace run followed by a word, the word is stripped of ASCII
    punctuation and lower-cased. If it is in CONNECTORS, the index of the last
    character before the whitespace run is recorded. Whitespace at the very
    start of the text never yields a position.
    """
    positions = set()
    for match in re.finditer(r"\s+(\S+)", text):
        gap_start = match.start()
        if gap_start < 1:
            continue
        candidate = match.group(1).strip(string.punctuation).lower()
        if candidate in CONNECTORS:
            positions.add(gap_start - 1)  # last char of the preceding word
    return positions


def pause_aware_mask(probs, text):
    """Re-weight per-character boundary probabilities toward natural pauses.

    Bias forced breaks toward natural prosodic pauses so TTS doesn't pause
    mid-phrase. ``probs[i]`` is the boundary probability *after* char i
    (between i and i+1). A copy is returned; ``probs`` itself is not modified.

    For every position except the last, the first matching rule applies:
      1. char in CLAUSE_PUNCT and the next char is whitespace or CJK
         -> floor at FLOOR_CLAUSE
      2. char in CJK_SENT_PUNCT (CJK sentence end) -> floor at 0.9
      3. position right before a connector word (and/which/that...)
         -> floor at FLOOR_CONNECTOR
      4. whitespace on either side -> floor at FLOOR_SPACE
      5. CJK characters on both sides -> floor at FLOOR_HANZI
      6. anything else (mid-word / abbreviation) -> capped at FORBID
    "Floor" means max(model prob, floor), so high model-predicted boundaries
    keep dominating. Rule 6 uses min() instead, which effectively forbids the
    break so words/abbreviations are never cut. The result: long sentences
    break at the nearest clause pause in range, then before a connecting word,
    and only at a bare space as a last resort. The last position (end of text)
    keeps its model probability.
    """
    p = probs.copy()
    n = len(text)
    connectors = _connector_break_positions(text)
    for i in range(n - 1):  # last position (end of text) is left untouched
        ch, nxt = text[i], text[i + 1]
        ends_token = nxt.isspace() or _is_cjk(nxt)
        if ch in CLAUSE_PUNCT and ends_token:
            p[i] = max(p[i], FLOOR_CLAUSE)
        elif ch in CJK_SENT_PUNCT:                       # zh sentence end
            p[i] = max(p[i], 0.9)
        elif i in connectors:                            # break before connector
            p[i] = max(p[i], FLOOR_CONNECTOR)
        elif nxt.isspace() or ch.isspace():              # plain word boundary
            p[i] = max(p[i], FLOOR_SPACE)
        elif _is_cjk(ch) and _is_cjk(nxt):               # between hanzi
            p[i] = max(p[i], FLOOR_HANZI)
        else:                                            # mid-word/abbreviation
            p[i] = min(p[i], FORBID)
    return p


# Backward-compatible alias: existing importers (the benchmark) use this name,
# and main() calls the mask through it as well.
word_safe_mask = pause_aware_mask


def boundary_probs(session, tokenizer, text, remap, unk_new):
    """Run the SaT model on ``text`` and return one boundary probability per char.

    The text is encoded in a single pass. When ``remap`` is given (pruned-vocab
    models), token ids are translated through it, and ids that map to -1
    (dropped from the pruned vocab) become ``unk_new``. The sigmoid of the
    NEWLINE_INDEX logit lands on the last character of each non-special token;
    every other character gets probability 0.0.

    NOTE(review): the whole text goes through the model as one sequence;
    XLM-R-style encoders usually cap positions (~512 tokens) -- confirm that
    very long inputs are handled.
    """
    token_ids, offsets, tokens = tokenizer.encode(text)
    input_ids = np.array([token_ids], dtype=np.int64)
    attention_mask = np.ones_like(input_ids)
    if remap is not None:
        input_ids = remap[input_ids]          # fancy indexing -> fresh array
        input_ids[input_ids == -1] = unk_new  # token missing from pruned vocab
    feeds = {"input_ids": input_ids, "attention_mask": attention_mask}
    outputs = session.run(["logits"], feeds)
    per_char = token_to_char_probs(text, tokens, outputs[0][0],
                                   tokenizer.special_tokens, offsets)
    newline_logit = per_char[:, NEWLINE_INDEX]
    return 1.0 / (1.0 + np.exp(-newline_logit))


def _make_prior(kind, max_length, overflow, target, spread):
    """Build the chunk-length prior for constrained_segmentation.

    Returns ``(prior, hard_max)``. ``hard_max`` is ``max_length`` plus any
    positive ``overflow`` and is the absolute ceiling handed to the DP. With
    ``overflow > 0`` the base prior is multiplied by a Gaussian-shaped decay past
    ``max_length`` (width ``overflow``). This keeps plain word gaps from
    exploiting the slack while still letting a strong pause (comma/sentence
    end) pull a break into the overflow zone.
    """
    hard_max = max_length + max(0, overflow)
    prior_kwargs = {"max_length": hard_max}
    if kind != "uniform":
        prior_kwargs.update(target_length=target, spread=spread)
    base_prior = create_prior_function(kind, prior_kwargs)
    if overflow <= 0:
        return base_prior, hard_max
    soft, decay = max_length, float(overflow)

    def prior(L):
        return base_prior(L) * (
            1.0 if L <= soft else math.exp(-((L - soft) / decay) ** 2))

    return prior, hard_max


def main():
    """CLI entry point: choose a model, segment the text, and print the chunks.

    Exits with an error message (via argparse or sys.exit) when the length
    options contradict each other, when no/unknown model is found, or when
    the chunks do not rejoin to the original text.
    """
    ap = argparse.ArgumentParser(description="Segment text with a local ONNX SaT model")
    ap.add_argument("-m", "--model", help="model name (see menu if omitted)")
    ap.add_argument("-t", "--text", help="text to segment")
    ap.add_argument("-f", "--file", help="read text from this file")
    ap.add_argument("--max-length", type=int, default=80, help="target max chars per chunk")
    ap.add_argument("--min-length", type=int, default=40, help="min chars per chunk")
    ap.add_argument("--overflow", type=int, default=0,
                    help="chars a chunk may exceed --max-length to reach a comma/"
                         "clause/sentence pause (soft cap; 0 = hard cap)")
    ap.add_argument("--prior", default="gaussian",
                    choices=["uniform", "gaussian", "clipped_polynomial"])
    ap.add_argument("--target", type=int, default=70, help="gaussian target length")
    ap.add_argument("--spread", type=int, default=12, help="gaussian spread")
    ap.add_argument("--algorithm", default="viterbi", choices=["viterbi", "greedy"])
    ap.add_argument("--allow-midword", action="store_true",
                    help="permit breaks inside words/abbreviations (off by default)")
    args = ap.parse_args()

    # Reject contradictory length budgets before any expensive loading; e.g.
    # "--max-length 30" alone would otherwise keep the default min of 40.
    if args.max_length <= 0:
        ap.error("--max-length must be a positive number of characters")
    if args.min_length < 0:
        ap.error("--min-length must not be negative")
    ceiling = args.max_length + max(0, args.overflow)
    if args.min_length > ceiling:
        ap.error(f"--min-length ({args.min_length}) exceeds the largest allowed "
                 f"chunk ({ceiling}); lower --min-length or raise "
                 f"--max-length/--overflow")

    models = find_models(MODELS_DIR)
    if not models:
        sys.exit(f"No ONNX models found under {MODELS_DIR}. Run build_and_test_onnx.py first.")

    name = args.model or choose_model(models)
    if name not in models:
        sys.exit(f"Unknown model '{name}'. Choices: {', '.join(models)}")
    path = models[name]

    tokenizer = load_tokenizer()
    remap = unk_new = None
    if "en_zh" in name:
        remap, unk_new = get_remap(tokenizer)

    session = ort.InferenceSession(str(path), providers=["CPUExecutionProvider"])
    text = get_text(args)

    probs = boundary_probs(session, tokenizer, text, remap, unk_new)
    if not args.allow_midword:
        probs = word_safe_mask(probs, text)

    prior, hard_max = _make_prior(args.prior, args.max_length, args.overflow,
                                  args.target, args.spread)

    idx = constrained_segmentation(probs, prior, min_length=args.min_length,
                                   max_length=hard_max, algorithm=args.algorithm)
    cuts = [0] + list(idx) + [len(text)]
    chunks = [text[cuts[i]:cuts[i + 1]] for i in range(len(cuts) - 1)]

    print(f"\nModel: {name}  ({path.stat().st_size/1e6:.1f} MB)")
    print(f"Config: max={args.max_length} overflow={args.overflow} "
          f"min={args.min_length} prior={args.prior} algo={args.algorithm}")
    print(f"Input: {len(text)} chars -> {len(chunks)} chunks\n")
    for c in chunks:
        n = len(c)
        # "!" = above the hard ceiling, "+" = inside the overflow zone.
        flag = "!" if n > hard_max else ("+" if n > args.max_length else " ")
        print(f"  {flag}[{n:3d}] {c.strip()[:90]}")
    # Explicit check rather than `assert`, which `python -O` strips.
    if "".join(chunks) != text:
        sys.exit("TEXT NOT PRESERVED")
    print("\n  ✓ text preserved (chunks rejoin to original)")


# Run the CLI only when executed directly, not when imported (e.g. by the
# benchmark, which imports helpers from this module).
if __name__ == "__main__":
    main()