#!/usr/bin/env python3
"""
Pick an ONNX SaT model and segment text with it (local CPU test).

Examples:
    # interactive: choose a model from a menu, then type/paste text
    python scripts/run_segmentation.py

    # one-shot
    python scripts/run_segmentation.py --model sat-1l-sm-en_zh-int8 \
        --text "Your text here. 这是中文。" --max-length 80

    # from a file, tighter Chinese-style budget
    python scripts/run_segmentation.py -m sat-3l-sm-en_zh-int8 -f test.txt \
        --max-length 40 --min-length 15

Notes:
  - "*-en_zh-*" models use a pruned vocab; the id-remap is recomputed on the
    fly (deterministic from the tokenizer), so no extra files are needed.
  - onnxruntime needs the conda libstdc++ on this box; the script auto-preloads
    it and re-execs once if needed.
"""
import argparse
import math
import os
import re
import string
import sys
from pathlib import Path


# --- bootstrap: onnxruntime needs conda's libstdc++ preloaded on this machine ---
def _ensure_onnxruntime():
    """Make sure ``import onnxruntime`` works, re-exec'ing once if it does not.

    If the import fails and ``<prefix>/lib/libstdc++.so.6`` exists (prefix is
    ``$CONDA_PREFIX`` or ``sys.prefix``), that library is prepended to
    ``LD_PRELOAD`` and the interpreter is replaced via ``os.execv`` with the
    same argv.  ``_ORT_PRELOADED=1`` marks the retry so it happens only once.
    In every other failure case the original import error is re-raised.
    """
    import contextlib
    import io

    # Probe quietly: a failed import dumps a long numpy/GLIBCXX message to stderr.
    try:
        with contextlib.redirect_stderr(io.StringIO()):
            import onnxruntime  # noqa: F401
    except Exception:
        env_root = os.environ.get("CONDA_PREFIX") or sys.prefix
        libstdcxx = Path(env_root) / "lib" / "libstdc++.so.6"
        can_retry = libstdcxx.exists() and os.environ.get("_ORT_PRELOADED") != "1"
        if not can_retry:
            raise
        previous = os.environ.get("LD_PRELOAD", "")
        os.environ["LD_PRELOAD"] = f"{libstdcxx}:{previous}".strip(":")
        os.environ["_ORT_PRELOADED"] = "1"
        # Replaces the current process; on success nothing below runs.
        os.execv(sys.executable, [sys.executable] + sys.argv)
        raise


_ensure_onnxruntime()

import importlib.util  # noqa: E402
import types  # noqa: E402

import numpy as np  # noqa: E402
import onnxruntime as ort  # noqa: E402

NEWLINE_INDEX = 0
ROOT = Path(__file__).resolve().parent.parent
MODELS_DIR = ROOT / "onnx_models"


# --- load the two tiny pure-numpy helper modules WITHOUT importing the heavy
# wtpsplit package (which pulls torch/onnx/skops and costs ~5s on startup).
# constraints.py references wtpsplit.utils.indices_to_sentences but
# constrained_segmentation() never calls it, so we stub that one symbol. ---
def _load_light(path, name):
    """Execute the source file at *path* as a module called *name*.

    Before loading, stub ``wtpsplit`` and ``wtpsplit.utils`` entries are put in
    ``sys.modules`` (only if ``wtpsplit`` is not already there) so that the
    loaded file's ``wtpsplit.utils`` import resolves without the real package.

    Raises:
        ImportError: if no import spec/loader can be built for *path*.
    """
    if "wtpsplit" not in sys.modules:
        pkg = types.ModuleType("wtpsplit")
        pkg.__path__ = []
        utils = types.ModuleType("wtpsplit.utils")
        utils.__path__ = []
        utils.indices_to_sentences = lambda *a, **k: None  # unused here
        sys.modules["wtpsplit"] = pkg
        sys.modules["wtpsplit.utils"] = utils
    spec = importlib.util.spec_from_file_location(name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load {name!r} from {path}")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod


_WT_UTILS = ROOT / "wtpsplit" / "utils"
constrained_segmentation = _load_light(
    _WT_UTILS / "constraints.py", "onnxseg_constraints").constrained_segmentation
create_prior_function = _load_light(
    _WT_UTILS / "priors.py", "onnxseg_priors").create_prior_function


def get_token_spans(offsets_mapping, tokens, special_tokens):
    """Return ``(valid_indices, valid_offsets)`` for the non-special tokens.

    ``valid_indices`` holds the positions ``i`` with ``i < len(offsets_mapping)``
    whose token is not in *special_tokens*; ``valid_offsets`` holds the matching
    rows of *offsets_mapping*.  Both arrays are integer-typed and
    ``valid_offsets`` always has shape ``(k, 2)``, so the ``k == 0`` case
    (empty text, or only special tokens) indexes cleanly instead of raising.
    """
    valid = np.array(
        [i for i, t in enumerate(tokens)
         if i < len(offsets_mapping) and t not in special_tokens],
        dtype=np.intp)
    offsets = np.asarray(offsets_mapping, dtype=np.intp).reshape(-1, 2)
    return valid, offsets[valid]


def token_to_char_probs(text, tokens, token_logits, special_tokens, offsets_mapping):
    """Spread per-token logits onto a ``(len(text), n_labels)`` array.

    Each non-special token's logit row is written at index ``offset[1] - 1``
    (presumably the token's last character, if offsets are ``(start, end)``
    pairs -- confirm against the tokenizer); every other row stays ``-inf``.
    """
    char_probs = np.full((len(text), token_logits.shape[1]), -np.inf)
    vi, vo = get_token_spans(offsets_mapping, tokens, special_tokens)
    char_probs[vo[:, 1] - 1] = token_logits[vi]
    return char_probs


_TOK_CACHE = Path(__file__).resolve().parent / ".xlmr_tokenizer" / "tokenizer.json"


class FastTok:
    """Thin wrapper over the `tokenizers` Rust lib (loads in ~0.4s vs ~4.3s
    for transformers + AutoTokenizer). Exposes only what this script needs."""

    def __init__(self, tok):
        self._t = tok
        # XLM-R special tokens.  The literals must be the real token strings:
        # with empty strings, unk_token_id is not the unknown-token id and
        # special tokens are never filtered out of the logits.
        self.special_tokens = {"<s>", "</s>", "<pad>", "<unk>", "<mask>"}
        self.unk_token_id = tok.token_to_id("<unk>")
        looked_up = (tok.token_to_id(s) for s in self.special_tokens)
        self.all_special_ids = [i for i in looked_up if i is not None]

    def encode(self, text):
        """Return ``(ids, offsets, tokens)`` for *text*."""
        e = self._t.encode(text)  # XLM-R template adds <s> ... </s>
        return e.ids, e.offsets, e.tokens

    def get_vocab(self):
        """Return the underlying tokenizer's token -> id mapping."""
        return self._t.get_vocab()


def load_tokenizer():
    """Return a FastTok. Builds the fast tokenizer.json cache via transformers
    only once (first ever run); afterwards loads via the `tokenizers` lib
    alone, so transformers/torch are never imported."""
    from tokenizers import Tokenizer
    if not _TOK_CACHE.exists():
        from transformers import AutoTokenizer  # lazy: only on first build
        AutoTokenizer.from_pretrained("xlm-roberta-base").save_pretrained(
            str(_TOK_CACHE.parent))
    return FastTok(Tokenizer.from_file(str(_TOK_CACHE)))


def compute_keep_ids(tokenizer):
    """EN+ZH keep-set: ASCII or CJK tokens, plus specials (pure-stdlib, fast).

    Returns the kept token ids as a sorted list.
    """
    keep = set(tokenizer.all_special_ids)
    for tok, idx in tokenizer.get_vocab().items():
        s = tok.replace("▁", " ")  # SP underscore -> space
        if all(ord(c) < 128 for c in s) or any(_is_cjk(c) for c in s):
            keep.add(idx)
    return sorted(keep)


# Size of the old (unpruned) id space covered by the remap table.
# NOTE(review): presumably the xlm-roberta-base vocabulary size -- confirm.
FULL_VOCAB_SIZE = 250002


def get_remap(tokenizer):
    """old->new id map for EN+ZH pruning, cached to disk (.npy).

    Returns ``(remap, unk_new)``: *remap* is an int64 array indexed by old id
    with ``-1`` for dropped tokens, and *unk_new* is the new id of the
    tokenizer's unknown token.
    """
    cache = MODELS_DIR / "remap_en_zh.npy"
    if cache.exists():
        remap = np.load(cache)
    else:
        keep = compute_keep_ids(tokenizer)
        remap = np.full(FULL_VOCAB_SIZE, -1, dtype=np.int64)
        for new_id, old_id in enumerate(keep):
            remap[old_id] = new_id
        MODELS_DIR.mkdir(parents=True, exist_ok=True)
        np.save(cache, remap)
    return remap, int(remap[tokenizer.unk_token_id])


def find_models(root: Path):
    """Return {display_name: onnx_path} for every .onnx under onnx_models/.

    The display name is ``<parent dir name>-<int8|fp32>``; a file counts as
    int8 when its name contains ``.int8.``.
    """
    out = {}
    for p in sorted(root.rglob("*.onnx")):
        variant = p.parent.name  # e.g. sat-1l-sm-en_zh
        quant = "int8" if ".int8." in p.name else "fp32"
        out[f"{variant}-{quant}"] = p
    return out


def choose_model(models: dict):
    """Print a numbered menu of *models* and prompt until a valid pick is made.

    Accepts either the 1-based menu number or the exact model name and returns
    the chosen name.
    """
    names = list(models)
    print("\nAvailable ONNX models:")
    for i, n in enumerate(names, 1):
        mb = models[n].stat().st_size / 1e6
        print(f" {i:2d}) {n:30s} {mb:7.1f} MB")
    while True:
        sel = input("\nSelect model [number or name]: ").strip()
        # isdecimal(), not isdigit(): isdigit() also accepts characters such
        # as "²" that int() cannot parse.
        if sel.isdecimal() and 1 <= int(sel) <= len(names):
            return names[int(sel) - 1]
        if sel in models:
            return sel
        print(" invalid choice, try again")


def get_text(args):
    """Return the text to segment: ``--text``, else ``--file``, else stdin.

    When stdin yields only whitespace, a built-in EN/ZH sample is returned.
    """
    if args.text:
        return args.text
    if args.file:
        return Path(args.file).read_text(encoding="utf-8")
    print("\nEnter/paste text, then Ctrl-D (Ctrl-Z on Windows) to finish:")
    data = sys.stdin.read()
    return data if data.strip() else (
        "Breaking News: Scientists announced a discovery. 这是一个测试。It works well!")


CJK_RANGES = [(0x4E00, 0x9FFF), (0x3400, 0x4DBF), (0xF900, 0xFAFF),
              (0x3000, 0x303F), (0xFF00, 0xFFEF)]


def _is_cjk(ch):
    """True if *ch*'s code point falls in one of ``CJK_RANGES``."""
    cp = ord(ch)
    return any(a <= cp <= b for a, b in CJK_RANGES)


# Punctuation that marks a prosodic pause, by strength (used as break-priority
# floors when a long sentence must be split below max_length). Sentence-ending
# punctuation is intentionally NOT floored here -- the model already predicts
# those boundaries well, and overriding it would create false breaks after
# abbreviations like "A.I.".
#
# The CJK entries are written as \u escapes so they stay fullwidth; if they
# degrade to ASCII, "!" and "?" end up in CJK_SENT_PUNCT and get floored even
# in the middle of a token.
CLAUSE_PUNCT = set(",;:)]}—–"                    # , ; : ) ] } em/en-dash
                   "\uff0c\u3001\uff1b\uff1a"    # CJK fullwidth comma, 、, fullwidth ; :
                   "”’")                         # closing ” ’
CJK_SENT_PUNCT = set("\u3002\uff01\uff1f\u2026")  # 。 fullwidth ! fullwidth ? …

# Words that introduce a clause/phrase: breaking *before* one of these sounds
# more natural than a random word gap when a long span has no punctuation.
CONNECTORS = {
    "and", "but", "or", "nor", "yet", "so", "for",
    "which", "that", "who", "whom", "whose", "where", "when", "while",
    "because", "although", "though", "since", "if", "unless", "until",
    "after", "before", "as", "than", "whether",
}

FLOOR_CJK_SENT = 0.9     # after CJK sentence-ending punctuation
FLOOR_CLAUSE = 0.25      # comma / semicolon / colon -> strongly preferred
FLOOR_CONNECTOR = 0.05   # break before "and/which/that..." in a comma-free span
FLOOR_HANZI = 5e-3       # between two Chinese chars (no spaces in zh)
FLOOR_SPACE = 1e-4       # plain word gap -> last-resort break
FORBID = 1e-9            # mid-word -> effectively never

# A whitespace run followed by the next whitespace-free word.
_GAP_THEN_WORD = re.compile(r"\s+(\S+)")


def _connector_break_positions(text):
    """Indices i (break after char i) that sit right before a connector word."""
    pos = set()
    for m in _GAP_THEN_WORD.finditer(text):
        word = m.group(1).strip(string.punctuation).lower()
        if word in CONNECTORS and m.start() > 0:
            pos.add(m.start() - 1)  # last char of the preceding word
    return pos


def pause_aware_mask(probs, text):
    """Bias forced breaks toward natural prosodic pauses so TTS doesn't pause
    mid-phrase. probs[i] = boundary prob *after* char i (between i and i+1).

    Model-predicted sentence boundaries (high prob) are preserved as-is and
    keep dominating. For everything else we raise a floor by pause strength:
    clause punctuation (ASCII and CJK commas, semicolons, colons, closers)
    > connector word (and/which/that) > plain word gap, and mid-word positions
    are driven to ~0 so words/abbreviations are never cut. The result: long
    sentences break at the nearest comma/clause in range, then before a
    connecting word, and only at a bare space as a last resort.

    Returns a modified copy; *probs* itself is not changed.
    """
    p = probs.copy()
    # Guard against a length mismatch between text and probs.
    n = min(len(text), len(p))
    connectors = _connector_break_positions(text)
    for i in range(n - 1):  # never break before end-of-text marker
        ch, nxt = text[i], text[i + 1]
        ends_token = nxt.isspace() or _is_cjk(nxt)
        if ch in CLAUSE_PUNCT and ends_token:
            p[i] = max(p[i], FLOOR_CLAUSE)
        elif ch in CJK_SENT_PUNCT:                 # zh sentence end
            p[i] = max(p[i], FLOOR_CJK_SENT)
        elif i in connectors:                      # break before connector
            p[i] = max(p[i], FLOOR_CONNECTOR)
        elif nxt.isspace() or ch.isspace():        # plain word boundary
            p[i] = max(p[i], FLOOR_SPACE)
        elif _is_cjk(ch) and _is_cjk(nxt):         # between hanzi
            p[i] = max(p[i], FLOOR_HANZI)
        else:                                      # mid-word/abbreviation
            p[i] = min(p[i], FORBID)
    return p


# kept as an alias so existing imports (benchmark) keep working
word_safe_mask = pause_aware_mask


def boundary_probs(session, tokenizer, text, remap, unk_new):
    """Run the ONNX model on *text* and return one boundary prob per character.

    Token ids are passed through *remap* when it is not None, with dropped ids
    (``-1``) replaced by *unk_new*.  The ``NEWLINE_INDEX`` column of the
    per-character logits is squashed with a sigmoid; characters that received
    no token logit (``-inf``) come out as 0.
    """
    ids_list, offsets, tokens = tokenizer.encode(text)
    ids = np.array([ids_list], dtype=np.int64)
    mask = np.ones_like(ids)
    if remap is not None:
        ids = remap[ids]
        ids[ids == -1] = unk_new
    logits = session.run(["logits"], {"input_ids": ids, "attention_mask": mask})[0]
    char_logits = token_to_char_probs(text, tokens, logits[0],
                                      tokenizer.special_tokens, offsets)
    # Very negative logits overflow exp(); the result (prob 0) is still right.
    with np.errstate(over="ignore"):
        return 1.0 / (1.0 + np.exp(-char_logits[:, NEWLINE_INDEX]))


def main():
    """CLI entry point: pick a model, segment the text, print the chunks."""
    ap = argparse.ArgumentParser(description="Segment text with a local ONNX SaT model")
    ap.add_argument("-m", "--model", help="model name (see menu if omitted)")
    ap.add_argument("-t", "--text", help="text to segment")
    ap.add_argument("-f", "--file", help="read text from this file")
    ap.add_argument("--max-length", type=int, default=80, help="target max chars per chunk")
    ap.add_argument("--min-length", type=int, default=40, help="min chars per chunk")
    ap.add_argument("--overflow", type=int, default=0,
                    help="chars a chunk may exceed --max-length to reach a comma/"
                         "clause/sentence pause (soft cap; 0 = hard cap)")
    ap.add_argument("--prior", default="gaussian",
                    choices=["uniform", "gaussian", "clipped_polynomial"])
    ap.add_argument("--target", type=int, default=70, help="gaussian target length")
    ap.add_argument("--spread", type=int, default=12, help="gaussian spread")
    ap.add_argument("--algorithm", default="viterbi", choices=["viterbi", "greedy"])
    ap.add_argument("--allow-midword", action="store_true",
                    help="permit breaks inside words/abbreviations (off by default)")
    args = ap.parse_args()

    models = find_models(MODELS_DIR)
    if not models:
        sys.exit(f"No ONNX models found under {MODELS_DIR}. Run build_and_test_onnx.py first.")
    name = args.model or choose_model(models)
    if name not in models:
        sys.exit(f"Unknown model '{name}'. Choices: {', '.join(models)}")
    path = models[name]

    tokenizer = load_tokenizer()
    remap = unk_new = None
    if "en_zh" in name:
        remap, unk_new = get_remap(tokenizer)

    session = ort.InferenceSession(str(path), providers=["CPUExecutionProvider"])
    text = get_text(args)
    if not text:
        # e.g. an empty --file: nothing to tokenize or segment.
        sys.exit("No text to segment.")

    probs = boundary_probs(session, tokenizer, text, remap, unk_new)
    if not args.allow_midword:
        probs = word_safe_mask(probs, text)

    # Hard ceiling for the DP. With --overflow, allow chunks past --max-length up
    # to this ceiling; a decay tail past --max-length keeps plain spaces from
    # exploiting the slack while still letting a strong pause (comma/sentence)
    # pull the break into the overflow zone.
    hard_max = args.max_length + max(0, args.overflow)
    prior_kwargs = {"max_length": hard_max}
    if args.prior != "uniform":
        prior_kwargs.update(target_length=args.target, spread=args.spread)
    base_prior = create_prior_function(args.prior, prior_kwargs)

    if args.overflow > 0:
        soft, decay = args.max_length, float(args.overflow)

        def prior(length):
            """Base prior, damped by a gaussian tail past the soft cap."""
            if length <= soft:
                return base_prior(length)
            return base_prior(length) * math.exp(-((length - soft) / decay) ** 2)
    else:
        prior = base_prior

    idx = constrained_segmentation(probs, prior, min_length=args.min_length,
                                   max_length=hard_max, algorithm=args.algorithm)
    cuts = [0] + list(idx) + [len(text)]
    chunks = [text[a:b] for a, b in zip(cuts, cuts[1:])]

    print(f"\nModel: {name} ({path.stat().st_size/1e6:.1f} MB)")
    print(f"Config: max={args.max_length} overflow={args.overflow} "
          f"min={args.min_length} prior={args.prior} algo={args.algorithm}")
    print(f"Input: {len(text)} chars -> {len(chunks)} chunks\n")
    for c in chunks:
        n = len(c)
        flag = "!" if n > hard_max else ("+" if n > args.max_length else " ")
        print(f" {flag}[{n:3d}] {c.strip()[:90]}")

    # Explicit raise rather than `assert`, so the check survives `python -O`.
    if "".join(chunks) != text:
        raise AssertionError("TEXT NOT PRESERVED")
    print("\n ✓ text preserved (chunks rejoin to original)")


if __name__ == "__main__":
    main()