| |
| """ |
| Pick an ONNX SaT model and segment text with it (local CPU test). |
| |
| Examples: |
| # interactive: choose a model from a menu, then type/paste text |
| python scripts/run_segmentation.py |
| |
| # one-shot |
| python scripts/run_segmentation.py --model sat-1l-sm-en_zh-int8 \ |
| --text "Your text here. 这是中文。" --max-length 80 |
| |
| # from a file, tighter Chinese-style budget |
| python scripts/run_segmentation.py -m sat-3l-sm-en_zh-int8 -f test.txt \ |
| --max-length 40 --min-length 15 |
| |
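Notes:
  - "*-en_zh-*" models use a pruned vocabulary. The old->new token-id remap is
    derived deterministically from the tokenizer on first use and cached as
    onnx_models/remap_en_zh.npy, so no extra files need to be shipped.
  - onnxruntime may need a newer libstdc++ than the system one. If the first
    import fails, the script prepends $CONDA_PREFIX/lib/libstdc++.so.6 (or the
    same file under sys.prefix) to LD_PRELOAD and re-execs itself once.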
| """ |
| import argparse |
| import math |
| import os |
| import re |
| import string |
| import sys |
| from pathlib import Path |
|
|
| |
def _ensure_onnxruntime():
    """Make ``import onnxruntime`` work, re-exec'ing once with LD_PRELOAD if needed.

    A probe import is attempted with Python-level stderr captured, so a failed
    first attempt stays quiet. If it fails, and
    ``<CONDA_PREFIX or sys.prefix>/lib/libstdc++.so.6`` exists, and the env flag
    ``_ORT_PRELOADED`` is not yet ``"1"``:
    - that library is prepended to ``LD_PRELOAD``;
    - the interpreter is replaced via ``os.execv`` with the same argv.
    In every other failure case the original exception propagates.
    """
    import contextlib
    import io

    try:
        with contextlib.redirect_stderr(io.StringIO()):
            import onnxruntime  # noqa: F401  (probe only)
    except Exception:
        env = os.environ
        libstdcxx = Path(env.get("CONDA_PREFIX") or sys.prefix) / "lib" / "libstdc++.so.6"
        if libstdcxx.exists() and env.get("_ORT_PRELOADED") != "1":
            # LD_PRELOAD is only honoured at process start, hence the re-exec.
            env["LD_PRELOAD"] = f"{libstdcxx}:{env.get('LD_PRELOAD','')}".strip(":")
            # Flag survives the exec and stops an infinite re-exec loop.
            env["_ORT_PRELOADED"] = "1"
            os.execv(sys.executable, [sys.executable, *sys.argv])
        raise
|
|
|
|
# Must run before `import onnxruntime` below: if that import would fail, this
# call may re-exec the whole script with LD_PRELOAD set (see _ensure_onnxruntime).
_ensure_onnxruntime()
|
|
| import importlib.util |
| import types |
|
|
| import numpy as np |
| import onnxruntime as ort |
|
|
# Column of the model's logits used as the boundary score (boundary_probs takes
# the sigmoid of char_logits[:, NEWLINE_INDEX]).
NEWLINE_INDEX = 0
# Parent of the directory containing this file (the repo root when the script
# lives in scripts/, as in the usage examples of the module docstring).
ROOT = Path(__file__).resolve().parent.parent
# Searched recursively for *.onnx by find_models; get_remap also caches here.
MODELS_DIR = ROOT / "onnx_models"
|
|
|
|
| |
| |
| |
| |
def _load_light(path, name):
    """Execute one wtpsplit source file as a standalone module, without the package.

    On the first call, placeholder ``wtpsplit`` and ``wtpsplit.utils`` modules are
    registered in ``sys.modules``. The utils stub has a no-op
    ``indices_to_sentences``, so an import of that name inside the loaded file
    resolves without importing the real wtpsplit package.

    Args:
        path: filesystem path of the ``.py`` file to execute.
        name: module name given to the freshly created module object.

    Returns:
        The executed module. It is not itself added to ``sys.modules``.
    """
    if "wtpsplit" not in sys.modules:
        stub_pkg = types.ModuleType("wtpsplit")
        stub_pkg.__path__ = []  # a __path__ marks it as a package
        stub_utils = types.ModuleType("wtpsplit.utils")
        stub_utils.__path__ = []
        stub_utils.indices_to_sentences = lambda *a, **k: None
        sys.modules.update({"wtpsplit": stub_pkg, "wtpsplit.utils": stub_utils})

    spec = importlib.util.spec_from_file_location(name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
|
|
|
|
# The two segmentation helpers below are loaded directly from the wtpsplit
# sources under ROOT/wtpsplit/utils via _load_light, i.e. without importing the
# wtpsplit package itself (which _load_light replaces with stubs).
_WT_UTILS = ROOT / "wtpsplit" / "utils"
# constrained_segmentation(probs, prior, min_length=..., max_length=..., algorithm=...)
# returns the character indices where new chunks start (main() turns them into
# cuts); see constraints.py for the exact contract.
constrained_segmentation = _load_light(_WT_UTILS / "constraints.py",
                                       "onnxseg_constraints").constrained_segmentation
# create_prior_function(name, kwargs) returns a callable of a chunk length
# (presumably a non-negative weight -- confirm in priors.py).
create_prior_function = _load_light(_WT_UTILS / "priors.py",
                                    "onnxseg_priors").create_prior_function
|
|
def get_token_spans(offsets_mapping, tokens, special_tokens):
    """Select the non-special tokens and their character spans.

    Args:
        offsets_mapping: per-token ``(start, end)`` character offsets.
        tokens: token strings parallel to *offsets_mapping*. Tokens without an
            offset entry (index ``>= len(offsets_mapping)``) are ignored.
        special_tokens: container of token strings to skip (e.g. ``"<s>"``).

    Returns:
        ``(indices, spans)``:
        - ``indices`` is an integer array of kept token positions;
        - ``spans`` is the matching ``(len(indices), 2)`` integer array of offsets.
        Both are well-typed, possibly empty, arrays even when no token qualifies
        (e.g. empty text).
    """
    n_offsets = len(offsets_mapping)
    # Explicit integer dtype: np.array([]) defaults to float64, which numpy
    # rejects as an index, so an all-special sequence used to raise IndexError.
    indices = np.array(
        [i for i, tok in enumerate(tokens) if i < n_offsets and tok not in special_tokens],
        dtype=np.intp,
    )
    # reshape keeps a (0, 2) shape for an empty mapping so callers can do spans[:, 1].
    spans = np.asarray(offsets_mapping, dtype=np.intp).reshape(-1, 2)[indices]
    return indices, spans
|
|
|
|
def token_to_char_probs(text, tokens, token_logits, special_tokens, offsets_mapping):
    """Scatter per-token logits onto characters.

    Each non-special token's logit row is written to the LAST character it
    covers (offset ``end - 1``). Every other character keeps ``-inf``, which is
    probability 0 after a sigmoid.

    Args:
        text: the encoded string; fixes the number of output rows.
        tokens: token strings from the tokenizer.
        token_logits: ``(num_tokens, num_classes)`` logits, row-aligned with *tokens*.
        special_tokens: token strings to ignore (``<s>``, ``</s>``, ...).
        offsets_mapping: per-token ``(start, end)`` character offsets.

    Returns:
        Float array of shape ``(len(text), token_logits.shape[1])``.
    """
    char_logits = np.full((len(text), token_logits.shape[1]), -np.inf)
    tok_idx, spans = get_token_spans(offsets_mapping, tokens, special_tokens)
    tok_idx = np.asarray(tok_idx, dtype=np.intp)
    ends = np.asarray(spans, dtype=np.intp).reshape(-1, 2)[:, 1]
    # A zero-width span ending at 0 would yield index -1 and silently overwrite
    # the LAST character's row through negative indexing; skip such spans.
    valid = ends > 0
    char_logits[ends[valid] - 1] = token_logits[tok_idx[valid]]
    return char_logits
|
|
|
|
# Fast-tokenizer file cached next to this script in .xlmr_tokenizer/; created
# once by load_tokenizer (via transformers) and read directly on later runs.
_TOK_CACHE = Path(__file__).resolve().parent / ".xlmr_tokenizer" / "tokenizer.json"
|
|
|
|
class FastTok:
    """Minimal facade over a ``tokenizers.Tokenizer``.

    Using the Rust ``tokenizers`` library directly loads in ~0.4 s vs ~4.3 s for
    transformers + AutoTokenizer. Only what this script needs is exposed:

    * ``special_tokens``  -- set of special-token strings skipped when logits
      are mapped back to characters;
    * ``unk_token_id``    -- id of ``<unk>`` (``None`` if absent);
    * ``all_special_ids`` -- ids of those special tokens present in the vocab;
    * ``encode(text)``    -- ``(ids, offsets, tokens)``;
    * ``get_vocab()``     -- token -> id mapping.
    """

    def __init__(self, tok):
        self._t = tok
        self.special_tokens = {"<s>", "</s>", "<pad>", "<unk>", "<mask>"}
        self.unk_token_id = tok.token_to_id("<unk>")
        # Look each special token up once; drop the ones the vocab lacks.
        looked_up = (tok.token_to_id(label) for label in self.special_tokens)
        self.all_special_ids = [tid for tid in looked_up if tid is not None]

    def encode(self, text):
        """Return ``(ids, offsets, tokens)`` as produced by the wrapped tokenizer."""
        encoding = self._t.encode(text)
        return (encoding.ids, encoding.offsets, encoding.tokens)

    def get_vocab(self):
        """Return the wrapped tokenizer's token -> id mapping."""
        return self._t.get_vocab()
|
|
|
|
def load_tokenizer():
    """Return a :class:`FastTok` for ``xlm-roberta-base``.

    The first run ever writes ``tokenizer.json`` into the ``.xlmr_tokenizer``
    directory next to this script, using transformers' ``AutoTokenizer``. Every
    later run reads that file with the lightweight ``tokenizers`` package alone,
    so transformers/torch are never imported on the normal path.
    """
    from tokenizers import Tokenizer

    cache_file = _TOK_CACHE
    if not cache_file.exists():
        # One-off slow path; from_pretrained may download the tokenizer files.
        from transformers import AutoTokenizer

        hf_tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
        hf_tokenizer.save_pretrained(str(cache_file.parent))
    return FastTok(Tokenizer.from_file(str(cache_file)))
|
|
|
|
def compute_keep_ids(tokenizer):
    """Return the sorted token ids kept by the EN+ZH vocabulary pruning.

    After the SentencePiece word marker "▁" is turned into a space, a token is
    kept when it is pure ASCII or contains at least one ``_is_cjk`` character.
    All special-token ids are always kept. The result defines the old->new id
    mapping, so it must stay in sync with how the pruned models were exported
    (TODO confirm against the export script).
    """
    def _wanted(piece):
        surface = piece.replace("▁", " ")
        ascii_only = all(ord(c) < 128 for c in surface)
        return ascii_only or any(_is_cjk(c) for c in surface)

    kept = set(tokenizer.all_special_ids)
    kept.update(idx for piece, idx in tokenizer.get_vocab().items() if _wanted(piece))
    return sorted(kept)
|
|
|
|
def get_remap(tokenizer):
    """Return ``(remap, unk_new)`` for the EN+ZH pruned-vocabulary models.

    ``remap`` is an int64 array indexed by original token id. ``remap[old]`` is
    the token's id in the pruned vocabulary, or ``-1`` if the token was pruned.
    ``unk_new`` is the pruned id of the tokenizer's ``<unk>`` token;
    boundary_probs substitutes it for ``-1`` entries.

    New ids follow ascending old-id order of :func:`compute_keep_ids`.

    The array is cached at ``MODELS_DIR/remap_en_zh.npy``, and an existing cache
    file is trusted as-is. Writing the cache is best-effort: on failure (e.g. a
    read-only directory) a warning goes to stderr and the computed array is
    still returned, instead of aborting the run.
    """
    cache = MODELS_DIR / "remap_en_zh.npy"
    if cache.exists():
        remap = np.load(cache)
    else:
        keep = compute_keep_ids(tokenizer)
        # 250002 is the XLM-R vocabulary size. Grow the table if a kept id is
        # larger, so the scatter below cannot go out of bounds.
        size = max(250002, keep[-1] + 1) if keep else 250002
        remap = np.full(size, -1, dtype=np.int64)
        # keep is sorted and unique, so a token's position in keep is its new id.
        remap[np.asarray(keep, dtype=np.int64)] = np.arange(len(keep), dtype=np.int64)
        try:
            MODELS_DIR.mkdir(parents=True, exist_ok=True)
            np.save(cache, remap)
        except OSError as err:
            print(f"warning: could not cache id remap at {cache}: {err}", file=sys.stderr)
    return remap, int(remap[tokenizer.unk_token_id])
|
|
|
|
def find_models(root: Path):
    """Map a display name to every ``*.onnx`` file anywhere under *root*.

    The name is ``"<parent-dir-name>-<int8|fp32>"``: ``int8`` when the file name
    contains ``.int8.``, otherwise ``fp32``. Files are visited in sorted path
    order. If two files produce the same name, the later path wins, while the
    key keeps its first position in the returned dict.
    """
    def _display_name(onnx_path):
        precision = "int8" if ".int8." in onnx_path.name else "fp32"
        return f"{onnx_path.parent.name}-{precision}"

    return {_display_name(p): p for p in sorted(root.rglob("*.onnx"))}
|
|
|
|
def choose_model(models: dict):
    """Interactively pick a model and return its display name.

    Prints each entry of *models* (name -> .onnx path) with its file size in MB.
    It then prompts until the user enters either a 1-based menu number or an
    exact model name.
    """
    names = list(models)
    print("\nAvailable ONNX models:")
    for number, label in enumerate(names, start=1):
        size_mb = models[label].stat().st_size / 1e6
        print(f" {number:2d}) {label:30s} {size_mb:7.1f} MB")
    while True:
        answer = input("\nSelect model [number or name]: ").strip()
        if answer.isdigit():
            pick = int(answer)
            if 1 <= pick <= len(names):
                return names[pick - 1]
        # Out-of-range numbers still get a chance to match a literal name.
        if answer in models:
            return answer
        print(" invalid choice, try again")
|
|
|
|
def get_text(args):
    """Return the text to segment.

    The first available source wins:
    1. ``args.text`` if non-empty;
    2. the UTF-8 contents of ``args.file``;
    3. stdin, read until EOF.
    If stdin contains only whitespace, a built-in mixed EN/ZH sample is
    returned instead.
    """
    if args.text:
        return args.text
    if args.file:
        return Path(args.file).read_text(encoding="utf-8")

    print("\nEnter/paste text, then Ctrl-D (Ctrl-Z on Windows) to finish:")
    typed = sys.stdin.read()
    if typed.strip():
        return typed
    return "Breaking News: Scientists announced a discovery. 这是一个测试。It works well!"
|
|
|
|
# Code-point ranges treated as "CJK":
# - CJK Unified Ideographs;
# - Extension A;
# - Compatibility Ideographs;
# - CJK Symbols and Punctuation;
# - Halfwidth and Fullwidth Forms.
# compute_keep_ids uses this to select the EN+ZH pruned vocabulary, so editing
# it would presumably desync the remap from already-exported models.
CJK_RANGES = [(0x4E00, 0x9FFF), (0x3400, 0x4DBF), (0xF900, 0xFAFF),
              (0x3000, 0x303F), (0xFF00, 0xFFEF)]


def _is_cjk(ch):
    """Return True if the single character *ch* lies in any ``CJK_RANGES`` span."""
    code = ord(ch)
    for lo, hi in CJK_RANGES:
        if lo <= code <= hi:
            return True
    return False
|
|
|
|
| |
| |
| |
| |
| |
# Characters after which a clause-level pause is natural (the break goes right
# after the character). Non-ASCII members are written as escapes so the
# fullwidth CJK forms cannot be silently normalized to ASCII by an editor or
# copy/paste. The Chinese comma/semicolon/colon are fullwidth (U+FF0C/FF1B/FF1A);
# without them a Chinese clause got only the weak FLOOR_HANZI bias.
CLAUSE_PUNCT = set(
    ",;:)]}\u2014\u2013"          # ASCII , ; : ) ] } plus em dash and en dash
    "\u3001\uff0c\uff1b\uff1a"     # ideographic comma, fullwidth , ; :
    "\u201d\u2019"                 # right double / single quotation marks
)
# Sentence-final punctuation that pause_aware_mask boosts to a near-certain
# break: ideographic full stop, fullwidth ! and ?, horizontal ellipsis. ASCII !
# and ? are kept as well, for mixed text such as "测试!接下来".
CJK_SENT_PUNCT = set("\u3002\uff01\uff1f\u2026!?")
|
|
| |
| |
# Lower-cased words that typically open a new clause. A break just before one
# of them (after the preceding word) is preferred over a bare word gap.
CONNECTORS = set(
    "and but or nor yet so for "
    "which that who whom whose where when while "
    "because although though since if unless until "
    "after before as than whether".split()
)

# Boundary-probability floors/caps applied by pause_aware_mask, strongest first.
FLOOR_CLAUSE = 1 / 4       # after clause punctuation (CLAUSE_PUNCT)
FLOOR_CONNECTOR = 5e-2     # just before a connector word
FLOOR_HANZI = 0.005        # between two CJK characters
FLOOR_SPACE = 0.0001       # any other whitespace gap
FORBID = 1e-09             # upper cap for mid-word positions
|
|
|
|
def _connector_break_positions(text):
    """Return break positions that sit immediately before a connector word.

    For every whitespace run followed by a non-space word:
    - the word is stripped of ASCII punctuation and lower-cased;
    - if it is in ``CONNECTORS``, the index of the character just before the
      whitespace is collected (i.e. "break after char i").
    A match whose whitespace starts at index 0 has no preceding character and
    is ignored.
    """
    hits = set()
    for match in re.finditer(r"\s+(\S+)", text):
        gap_start = match.start()
        if gap_start < 1:
            continue
        candidate = match.group(1).strip(string.punctuation).lower()
        if candidate in CONNECTORS:
            hits.add(gap_start - 1)
    return hits
|
|
|
|
def pause_aware_mask(probs, text):
    """Bias forced breaks toward natural prosodic pauses so TTS doesn't pause mid-phrase.

    ``probs[i]`` is the boundary probability *after* character ``i`` (between
    ``i`` and ``i + 1``). A modified copy is returned; *probs* is left untouched.
    For each position ``i < len(text) - 1``, the first matching rule applies:

    1. clause punctuation (``CLAUSE_PUNCT``: , ; : 、 closing brackets/quotes,
       dashes) followed by whitespace or a CJK character -> floor ``FLOOR_CLAUSE``;
    2. sentence punctuation (``CJK_SENT_PUNCT``) -> floor 0.9. The exception is a
       non-CJK mark glued to a following non-CJK letter/digit (URL ``?q=1``,
       ``Yahoo!Japan``, ``wait…what``), which is inside a token and falls
       through to the rules below;
    3. the last character before a connector word (and/which/that/...) ->
       floor ``FLOOR_CONNECTOR``;
    4. whitespace on either side -> floor ``FLOOR_SPACE``;
    5. between two CJK characters -> floor ``FLOOR_HANZI``;
    6. anything else (inside a word/abbreviation/token) -> capped at ``FORBID``.

    Floors only raise values, so strong model-predicted boundaries survive; only
    rule 6 lowers a value. The final position (end of text) is not modified.
    Net effect: long sentences break at the nearest comma/clause in range, then
    before a connecting word, and only at a bare space as a last resort.
    """
    p = probs.copy()
    n = len(text)
    connectors = _connector_break_positions(text)
    for i in range(n - 1):
        ch, nxt = text[i], text[i + 1]
        ends_token = nxt.isspace() or _is_cjk(nxt)
        # True unless a Latin-style sentence mark runs straight into a word.
        sentence_break_ok = _is_cjk(ch) or _is_cjk(nxt) or not nxt.isalnum()
        if ch in CLAUSE_PUNCT and ends_token:
            p[i] = max(p[i], FLOOR_CLAUSE)
        elif ch in CJK_SENT_PUNCT and sentence_break_ok:
            p[i] = max(p[i], 0.9)
        elif i in connectors:
            p[i] = max(p[i], FLOOR_CONNECTOR)
        elif nxt.isspace() or ch.isspace():
            p[i] = max(p[i], FLOOR_SPACE)
        elif _is_cjk(ch) and _is_cjk(nxt):
            p[i] = max(p[i], FLOOR_HANZI)
        else:
            p[i] = min(p[i], FORBID)
    return p
|
|
|
|
| |
# Alias of pause_aware_mask; main() applies the mask under this name
# (presumably an older name kept for compatibility).
word_safe_mask = pause_aware_mask
|
|
|
|
def boundary_probs(session, tokenizer, text, remap, unk_new, block_size=512, stride=256):
    """Return the per-character boundary probability for *text*.

    ``result[i]`` is the sigmoid of the model's ``NEWLINE_INDEX`` logit for the
    token ending at character ``i``. Characters that end no (non-special) token
    get 0.

    Args:
        session: onnxruntime ``InferenceSession`` with ``input_ids`` /
            ``attention_mask`` inputs and a ``logits`` output.
        tokenizer: FastTok-like object (``encode``, ``special_tokens``).
        text: the string to score.
        remap: old->new id array for pruned-vocab models (``-1`` = pruned), or
            ``None`` for full-vocab models.
        unk_new: id substituted for pruned tokens when *remap* is given.
        block_size: maximum ids per forward pass, including ``<s>``/``</s>``.
            XLM-R-style encoders have 512 usable positions. Longer inputs are
            scored with overlapping windows instead of overflowing them.
        stride: step, in tokens, between consecutive window starts when the
            input exceeds *block_size*. Overlapping logits are averaged.
    """
    ids_list, offsets, tokens = tokenizer.encode(text)
    ids = np.asarray(ids_list, dtype=np.int64)
    if remap is not None:
        ids = remap[ids]
        ids[ids == -1] = unk_new  # pruned tokens -> pruned-vocab <unk>
    logits = _windowed_logits(session, ids, block_size, stride)
    char_logits = token_to_char_probs(text, tokens, logits,
                                      tokenizer.special_tokens, offsets)
    return 1.0 / (1.0 + np.exp(-char_logits[:, NEWLINE_INDEX]))


def _windowed_logits(session, ids, block_size, stride):
    """Run *session* on the 1-D id array *ids*; return ``(len(ids), C)`` logits.

    Inputs that fit in *block_size* go through a single forward pass.

    Longer inputs are assumed to be ``<s> body </s>`` (first/last id are the
    specials added by the tokenizer). They are scored as follows:
    - the body is cut into windows of ``block_size - 2`` ids, starting every
      *stride* ids, with the last window aligned to the end;
    - each window is re-wrapped with the original first/last id;
    - per-position logits are averaged over all windows that cover them.
    """
    def run(window):
        batch = window[np.newaxis, :]
        out = session.run(["logits"], {"input_ids": batch,
                                       "attention_mask": np.ones_like(batch)})[0]
        return out[0]

    if len(ids) <= block_size:
        return run(ids)

    bos, body, eos = ids[:1], ids[1:-1], ids[-1:]
    width = max(1, block_size - 2)
    step = max(1, min(stride, width))  # step <= width keeps coverage gap-free
    starts = list(range(0, len(body) - width, step)) + [len(body) - width]
    total = None
    counts = np.zeros(len(body))
    first_out = last_out = None
    for start in starts:
        out = run(np.concatenate([bos, body[start:start + width], eos]))
        if total is None:
            total = np.zeros((len(body), out.shape[1]), dtype=np.float64)
            first_out = out
        total[start:start + width] += out[1:-1]
        counts[start:start + width] += 1
        last_out = out
    body_logits = total / counts[:, None]
    # Special positions take their logits from the first/last window.
    return np.concatenate([first_out[:1], body_logits, last_out[-1:]], axis=0)
|
|
|
|
def _build_prior(args, hard_max):
    """Return the chunk-length prior ``L -> weight`` for constrained_segmentation.

    The base prior comes from ``create_prior_function(args.prior, kwargs)``.
    ``kwargs`` always holds ``max_length=hard_max``; non-uniform priors also get
    ``target_length``/``spread``.

    With ``--overflow > 0`` the base prior is additionally multiplied by:
    - 1 for lengths up to ``--max-length``;
    - ``exp(-((L - max_length) / overflow) ** 2)`` beyond that.
    Lengths past the soft cap are therefore allowed but increasingly penalized.
    """
    prior_kwargs = {"max_length": hard_max}
    if args.prior != "uniform":
        prior_kwargs.update(target_length=args.target, spread=args.spread)
    base_prior = create_prior_function(args.prior, prior_kwargs)
    if args.overflow <= 0:
        return base_prior

    soft, decay = args.max_length, float(args.overflow)

    def prior(L):
        return base_prior(L) * (1.0 if L <= soft else math.exp(-((L - soft) / decay) ** 2))

    return prior


def main():
    """CLI entry point: choose a model, score the text, segment it, print the chunks.

    Exits with a message in these cases:
    - no models are found;
    - the model name is unknown;
    - the length options are inconsistent;
    - the input text is empty;
    - the chunks do not concatenate back to the input.
    """
    ap = argparse.ArgumentParser(description="Segment text with a local ONNX SaT model")
    ap.add_argument("-m", "--model", help="model name (see menu if omitted)")
    ap.add_argument("-t", "--text", help="text to segment")
    ap.add_argument("-f", "--file", help="read text from this file")
    ap.add_argument("--max-length", type=int, default=80, help="target max chars per chunk")
    ap.add_argument("--min-length", type=int, default=40, help="min chars per chunk")
    ap.add_argument("--overflow", type=int, default=0,
                    help="chars a chunk may exceed --max-length to reach a comma/"
                         "clause/sentence pause (soft cap; 0 = hard cap)")
    ap.add_argument("--prior", default="gaussian",
                    choices=["uniform", "gaussian", "clipped_polynomial"])
    ap.add_argument("--target", type=int, default=70, help="gaussian target length")
    ap.add_argument("--spread", type=int, default=12, help="gaussian spread")
    ap.add_argument("--algorithm", default="viterbi", choices=["viterbi", "greedy"])
    ap.add_argument("--allow-midword", action="store_true",
                    help="permit breaks inside words/abbreviations (off by default)")
    args = ap.parse_args()

    # Cap actually enforced by the segmenter; --overflow only softens --max-length.
    hard_max = args.max_length + max(0, args.overflow)
    # Fail fast with a usage error rather than handing the segmenter an
    # infeasible length window.
    if hard_max < 1:
        ap.error("--max-length (plus --overflow) must be at least 1")
    if args.min_length > hard_max:
        ap.error("--min-length must not exceed --max-length (plus --overflow)")

    models = find_models(MODELS_DIR)
    if not models:
        sys.exit(f"No ONNX models found under {MODELS_DIR}. Run build_and_test_onnx.py first.")

    name = args.model or choose_model(models)
    if name not in models:
        sys.exit(f"Unknown model '{name}'. Choices: {', '.join(models)}")
    path = models[name]

    tokenizer = load_tokenizer()
    remap = unk_new = None
    if "en_zh" in name:  # pruned-vocab models need old->new id remapping
        remap, unk_new = get_remap(tokenizer)

    session = ort.InferenceSession(str(path), providers=["CPUExecutionProvider"])
    text = get_text(args)
    if not text:  # e.g. an empty --file; nothing to tokenize or segment
        sys.exit("No input text to segment.")

    probs = boundary_probs(session, tokenizer, text, remap, unk_new)
    if not args.allow_midword:
        probs = word_safe_mask(probs, text)

    prior = _build_prior(args, hard_max)
    idx = constrained_segmentation(probs, prior, min_length=args.min_length,
                                   max_length=hard_max, algorithm=args.algorithm)
    cuts = [0, *idx, len(text)]
    chunks = [text[a:b] for a, b in zip(cuts, cuts[1:])]

    print(f"\nModel: {name} ({path.stat().st_size/1e6:.1f} MB)")
    print(f"Config: max={args.max_length} overflow={args.overflow} "
          f"min={args.min_length} prior={args.prior} algo={args.algorithm}")
    print(f"Input: {len(text)} chars -> {len(chunks)} chunks\n")
    for c in chunks:
        n = len(c)
        # "!" = over the hard cap, "+" = in the overflow zone, " " = within target.
        flag = "!" if n > hard_max else ("+" if n > args.max_length else " ")
        print(f" {flag}[{n:3d}] {c.strip()[:90]}")
    # Explicit check instead of `assert`: `python -O` strips asserts, and the
    # success line below would then be printed unverified.
    if "".join(chunks) != text:
        sys.exit("TEXT NOT PRESERVED")
    print("\n ✓ text preserved (chunks rejoin to original)")
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|
|