wtpsplit-kit / scripts /run_segmentation.py
krmanik's picture
Upload folder using huggingface_hub
357ae2c verified
#!/usr/bin/env python3
"""
Pick an ONNX SaT model and segment text with it (local CPU test).
Examples:
# interactive: choose a model from a menu, then type/paste text
python scripts/run_segmentation.py
# one-shot
python scripts/run_segmentation.py --model sat-1l-sm-en_zh-int8 \
--text "Your text here. 这是中文。" --max-length 80
# from a file, tighter Chinese-style budget
python scripts/run_segmentation.py -m sat-3l-sm-en_zh-int8 -f test.txt \
--max-length 40 --min-length 15
Notes:
- "*-en_zh-*" models use a pruned vocab; the id-remap is recomputed on the fly
(deterministic from the tokenizer), so no extra files are needed.
- onnxruntime needs the conda libstdc++ on this box; the script auto-preloads it
and re-execs once if needed.
"""
import argparse
import math
import os
import re
import string
import sys
from pathlib import Path
# --- bootstrap: onnxruntime needs conda's libstdc++ preloaded on this machine ---
def _ensure_onnxruntime():
    """Check that ``onnxruntime`` imports; if not, retry once with libstdc++ preloaded.

    When the import fails and ``<prefix>/lib/libstdc++.so.6`` exists (prefix is
    ``$CONDA_PREFIX`` or, failing that, ``sys.prefix``), the library is put at
    the front of ``LD_PRELOAD`` and the interpreter is re-executed with the
    same ``sys.argv``. The ``_ORT_PRELOADED`` environment variable marks that
    the retry already happened, so a second failure re-raises the import error
    instead of looping.
    """
    import contextlib
    import io

    try:
        # Hide stderr while probing: a failing import prints a long
        # numpy/GLIBCXX message that would only confuse the user.
        with contextlib.redirect_stderr(io.StringIO()):
            import onnxruntime  # noqa: F401
    except Exception:
        env = os.environ
        conda_root = env.get("CONDA_PREFIX") or sys.prefix
        libstdcxx = Path(conda_root) / "lib" / "libstdc++.so.6"
        already_retried = env.get("_ORT_PRELOADED") == "1"
        if already_retried or not libstdcxx.exists():
            # Nothing (more) we can do: surface the original import failure.
            raise
        inherited = env.get("LD_PRELOAD", "")
        env["LD_PRELOAD"] = f"{libstdcxx}:{inherited}".strip(":")
        env["_ORT_PRELOADED"] = "1"
        # Replaces the current process; does not return on success.
        os.execv(sys.executable, [sys.executable, *sys.argv])
        raise
# Run the probe before the real `import onnxruntime` further down, so a needed
# LD_PRELOAD re-exec happens before anything else is set up.
_ensure_onnxruntime()
import importlib.util # noqa: E402
import types # noqa: E402
import numpy as np # noqa: E402
import onnxruntime as ort # noqa: E402
# Column of the model's logits that boundary_probs() turns into the boundary score.
NEWLINE_INDEX = 0
# Parent of the directory that contains this script.
ROOT = Path(__file__).resolve().parent.parent
# Searched by find_models() for *.onnx files; get_remap() also keeps its cache here.
MODELS_DIR = ROOT / "onnx_models"
# --- load the two tiny pure-numpy helper modules WITHOUT importing the heavy
# wtpsplit package (which pulls torch/onnx/skops and costs ~5s on startup).
# constraints.py references wtpsplit.utils.indices_to_sentences but
# constrained_segmentation() never calls it, so we stub that one symbol. ---
def _load_light(path, name):
if "wtpsplit" not in sys.modules:
pkg = types.ModuleType("wtpsplit"); pkg.__path__ = []
utils = types.ModuleType("wtpsplit.utils"); utils.__path__ = []
utils.indices_to_sentences = lambda *a, **k: None # unused here
sys.modules["wtpsplit"] = pkg
sys.modules["wtpsplit.utils"] = utils
spec = importlib.util.spec_from_file_location(name, path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
return mod
# Pull the two helper callables straight from their source files under
# <ROOT>/wtpsplit/utils, one file at a time.
_WT_UTILS = ROOT / "wtpsplit" / "utils"
_constraints_mod = _load_light(_WT_UTILS / "constraints.py", "onnxseg_constraints")
constrained_segmentation = _constraints_mod.constrained_segmentation
_priors_mod = _load_light(_WT_UTILS / "priors.py", "onnxseg_priors")
create_prior_function = _priors_mod.create_prior_function
def get_token_spans(offsets_mapping, tokens, special_tokens):
    """Pick out the non-special tokens together with their character offsets.

    Args:
        offsets_mapping: sequence of ``(start, end)`` pairs, one per token.
        tokens: token strings, parallel to ``offsets_mapping``.
        special_tokens: container of token strings to leave out.

    Returns:
        ``(valid, offsets)`` where ``valid`` is a 1-D integer array of the
        kept token positions and ``offsets`` is the matching ``(len(valid), 2)``
        integer array of ``(start, end)`` pairs. Both are empty (but still
        correctly shaped) when nothing is kept.
    """
    # Explicit integer dtype: an empty list would otherwise become a float64
    # array, which NumPy refuses to use as an index.
    valid = np.array(
        [i for i, tok in enumerate(tokens)
         if i < len(offsets_mapping) and tok not in special_tokens],
        dtype=np.intp,
    )
    # reshape keeps the (n, 2) layout even when offsets_mapping is empty.
    offsets = np.asarray(offsets_mapping, dtype=np.int64).reshape(-1, 2)
    return valid, offsets[valid]
def token_to_char_probs(text, tokens, token_logits, special_tokens, offsets_mapping):
    """Spread per-token logits onto characters.

    Each kept (non-special) token writes its logit row onto the *last*
    character it covers; every other character keeps ``-inf``.

    Args:
        text: the string that was tokenized.
        tokens: token strings for ``text``.
        token_logits: 2-D array indexed ``[token, label]``.
        special_tokens: container of token strings to ignore.
        offsets_mapping: ``(start, end)`` character offsets, one per token.

    Returns:
        Array of shape ``(len(text), token_logits.shape[1])``.
    """
    char_probs = np.full((len(text), token_logits.shape[1]), -np.inf)
    vi, vo = get_token_spans(offsets_mapping, tokens, special_tokens)
    last_char = vo[:, 1] - 1
    # A zero-width token (end offset 0) would give index -1, which NumPy
    # interprets as "last row" and would silently overwrite the final
    # character. Such tokens cover no character, so skip them.
    covers_char = last_char >= 0
    char_probs[last_char[covers_char]] = token_logits[vi[covers_char]]
    return char_probs
# tokenizer.json kept in a hidden ".xlmr_tokenizer" folder next to this script;
# load_tokenizer() creates it when missing and reads it on every run.
_TOK_CACHE = Path(__file__).resolve().parent / ".xlmr_tokenizer" / "tokenizer.json"
class FastTok:
    """Minimal adapter around a `tokenizers` Tokenizer object.

    Only the handful of attributes and methods this script uses are provided
    (the point being to load via the Rust `tokenizers` lib, ~0.4s, instead of
    transformers + AutoTokenizer, ~4.3s).
    """

    def __init__(self, tok):
        self._t = tok
        self.special_tokens = {"<s>", "</s>", "<pad>", "<unk>", "<mask>"}
        self.unk_token_id = tok.token_to_id("<unk>")
        # Ids of the special tokens the wrapped tokenizer actually knows.
        known_ids = []
        for token in self.special_tokens:
            token_id = tok.token_to_id(token)
            if token_id is not None:
                known_ids.append(token_id)
        self.all_special_ids = known_ids

    def encode(self, text):
        """Return ``(ids, offsets, tokens)`` for ``text``."""
        encoding = self._t.encode(text)  # XLM-R template adds <s> ... </s>
        return encoding.ids, encoding.offsets, encoding.tokens

    def get_vocab(self):
        """Return the wrapped tokenizer's vocabulary mapping."""
        return self._t.get_vocab()
def load_tokenizer():
    """Build a FastTok from the cached tokenizer.json, creating the cache if needed.

    When the cache file is missing, `transformers` is imported (only then) to
    fetch "xlm-roberta-base" and save it into the cache folder. Every later
    run finds the file and needs only the `tokenizers` library.
    """
    from tokenizers import Tokenizer

    cache_file = _TOK_CACHE
    if not cache_file.exists():
        from transformers import AutoTokenizer  # lazy: only on first build

        pretrained = AutoTokenizer.from_pretrained("xlm-roberta-base")
        pretrained.save_pretrained(str(cache_file.parent))
    fast = Tokenizer.from_file(str(cache_file))
    return FastTok(fast)
def compute_keep_ids(tokenizer):
    """Return the sorted token ids kept for the EN+ZH vocabulary.

    A token is kept when, after turning the SentencePiece "▁" marker into a
    space, it is pure ASCII or contains at least one character accepted by
    ``_is_cjk``. All of the tokenizer's special ids are always kept.
    """
    def _wanted(token):
        surface = token.replace("▁", " ")  # SP underscore -> space
        return surface.isascii() or any(_is_cjk(ch) for ch in surface)

    kept = set(tokenizer.all_special_ids)
    kept.update(idx for token, idx in tokenizer.get_vocab().items() if _wanted(token))
    return sorted(kept)
def get_remap(tokenizer):
    """Return ``(remap, unk_new)`` for the pruned EN+ZH vocabulary.

    ``remap`` is an int64 array indexed by original token id; kept ids map to
    their rank in the sorted keep-list and everything else is -1. The array is
    read from ``remap_en_zh.npy`` under MODELS_DIR when that file exists,
    otherwise it is computed and written there. ``unk_new`` is the remapped id
    of the tokenizer's unknown token.
    """
    cache = MODELS_DIR / "remap_en_zh.npy"
    if not cache.exists():
        kept_old_ids = np.asarray(compute_keep_ids(tokenizer), dtype=np.int64)
        remap = np.full(250002, -1, dtype=np.int64)
        # The keep-list is sorted and duplicate-free, so position == new id.
        remap[kept_old_ids] = np.arange(len(kept_old_ids), dtype=np.int64)
        MODELS_DIR.mkdir(parents=True, exist_ok=True)
        np.save(cache, remap)
    else:
        remap = np.load(cache)
    unk_new = int(remap[tokenizer.unk_token_id])
    return remap, unk_new
def find_models(root: Path):
    """Map a display name to the path of each ``.onnx`` file found below ``root``.

    The display name is ``<parent folder name>-<int8|fp32>``; a file counts as
    int8 when its name contains ``.int8.``. Files are visited in sorted path
    order, and a later file with the same display name replaces an earlier one.
    """
    def _display_name(onnx_file):
        precision = "fp32"
        if ".int8." in onnx_file.name:
            precision = "int8"
        # parent folder is the variant, e.g. sat-1l-sm-en_zh
        return f"{onnx_file.parent.name}-{precision}"

    return {_display_name(p): p for p in sorted(root.rglob("*.onnx"))}
def choose_model(models: dict):
    """Print a numbered menu of ``models`` and prompt until a valid pick is made.

    Args:
        models: mapping of display name -> path of the model file (the path's
            size on disk is shown in the menu).

    Returns:
        The chosen display name. The user may answer with the 1-based menu
        number or with the exact name.
    """
    names = list(models)
    print("\nAvailable ONNX models:")
    for i, n in enumerate(names, 1):
        mb = models[n].stat().st_size / 1e6
        print(f" {i:2d}) {n:30s} {mb:7.1f} MB")
    while True:
        sel = input("\nSelect model [number or name]: ").strip()
        # isdecimal(), not isdigit(): isdigit() is also true for characters
        # such as "²" that int() cannot parse, which would crash the menu.
        if sel.isdecimal() and 1 <= int(sel) <= len(names):
            return names[int(sel) - 1]
        if sel in models:
            return sel
        print(" invalid choice, try again")
def get_text(args):
    """Return the text to segment, trying ``--text``, then ``--file``, then stdin.

    A non-empty ``args.text`` wins. Otherwise a non-empty ``args.file`` is read
    as UTF-8. Otherwise everything on stdin is read; if that is blank, a
    built-in English/Chinese sample sentence is returned instead.
    """
    if args.text:
        return args.text
    if args.file:
        with open(args.file, encoding="utf-8") as handle:
            return handle.read()

    print("\nEnter/paste text, then Ctrl-D (Ctrl-Z on Windows) to finish:")
    typed = sys.stdin.read()
    if typed.strip():
        return typed
    return "Breaking News: Scientists announced a discovery. 这是一个测试。It works well!"
CJK_RANGES = [(0x4E00, 0x9FFF), (0x3400, 0x4DBF), (0xF900, 0xFAFF),
(0x3000, 0x303F), (0xFF00, 0xFFEF)]
def _is_cjk(ch):
cp = ord(ch)
return any(a <= cp <= b for a, b in CJK_RANGES)
# Punctuation that marks a prosodic pause, by strength (used as break-priority
# floors when a long sentence must be split below max_length). Western
# sentence-ending punctuation (. ! ?) is intentionally NOT floored here -- the
# model already predicts those boundaries well, and overriding it would create
# false breaks after abbreviations like "A.I.".
#
# The non-ASCII members are written as \u escapes so that editors or pipelines
# which fold fullwidth characters to ASCII cannot silently turn them into
# duplicates of the ASCII entries.
CLAUSE_PUNCT = set(
    ",;:)]}\u2014\u2013"        # , ; : ) ] } em-dash en-dash
    "\uff0c\u3001\uff1b\uff1a"  # fullwidth comma, ideographic comma, fullwidth ; and :
    "\u201d\u2019"              # closing double / single curly quote
)
# Ideographic full stop, fullwidth ! and ?, horizontal ellipsis. ASCII "!" and
# "?" are deliberately absent (see the note on sentence-ending punctuation).
CJK_SENT_PUNCT = set("\u3002\uff01\uff1f\u2026")
# Words that introduce a clause/phrase: breaking *before* one of these sounds
# more natural than a random word gap when a long span has no punctuation.
CONNECTORS = {
    "and", "but", "or", "nor", "yet", "so", "for",
    "which", "that", "who", "whom", "whose", "where", "when", "while",
    "because", "although", "though", "since", "if", "unless", "until",
    "after", "before", "as", "than", "whether",
}
FLOOR_CLAUSE = 0.25     # comma / semicolon / colon -> strongly preferred
FLOOR_CONNECTOR = 0.05  # break before "and/which/that..." in a comma-free span
FLOOR_HANZI = 5e-3      # between two Chinese chars (no spaces in zh)
FLOOR_SPACE = 1e-4      # plain word gap -> last-resort break
FORBID = 1e-9           # mid-word -> effectively never
def _connector_break_positions(text):
"""Indices i (break after char i) that sit right before a connector word."""
pos = set()
for m in re.finditer(r"\s+(\S+)", text):
word = m.group(1).strip(string.punctuation).lower()
if word in CONNECTORS and m.start() - 1 >= 0:
pos.add(m.start() - 1) # last char of the preceding word
return pos
def pause_aware_mask(probs, text):
    """Return a copy of ``probs`` re-weighted toward natural pause positions.

    ``probs[i]`` is the boundary probability *after* character ``i``. For every
    position except the last one, the first matching rule applies:

    1. ``text[i]`` is in CLAUSE_PUNCT and the next char is whitespace or CJK
       -> raised to at least FLOOR_CLAUSE.
    2. ``text[i]`` is in CJK_SENT_PUNCT -> raised to at least 0.9.
    3. position sits right before a connector word -> at least FLOOR_CONNECTOR.
    4. a whitespace char is on either side -> at least FLOOR_SPACE.
    5. both neighbours are CJK -> at least FLOOR_HANZI.
    6. anything else (inside a word/abbreviation) -> capped at FORBID.

    Rules 1-5 never lower a value, so a higher model score at such a position
    is kept. The final position (end of text) is left untouched.
    """
    adjusted = probs.copy()
    before_connector = _connector_break_positions(text)
    # Walk adjacent character pairs; the last char has no successor and is skipped.
    for i, (cur, following) in enumerate(zip(text, text[1:])):
        next_is_space = following.isspace()
        if cur in CLAUSE_PUNCT and (next_is_space or _is_cjk(following)):
            floor = FLOOR_CLAUSE
        elif cur in CJK_SENT_PUNCT:  # zh sentence end
            floor = 0.9
        elif i in before_connector:  # break before connector
            floor = FLOOR_CONNECTOR
        elif next_is_space or cur.isspace():  # plain word boundary
            floor = FLOOR_SPACE
        elif _is_cjk(cur) and _is_cjk(following):  # between hanzi
            floor = FLOOR_HANZI
        else:  # mid-word/abbreviation
            adjusted[i] = min(adjusted[i], FORBID)
            continue
        adjusted[i] = max(adjusted[i], floor)
    return adjusted


# kept as an alias so existing imports (benchmark) keep working
word_safe_mask = pause_aware_mask
def boundary_probs(session, tokenizer, text, remap, unk_new):
    """Run the model on ``text`` and return one boundary probability per character.

    Args:
        session: object with an ONNX-Runtime-style ``run(output_names, feeds)``.
        tokenizer: object whose ``encode`` returns ``(ids, offsets, tokens)``
            and that exposes ``special_tokens``.
        text: the string to score.
        remap: optional array mapping original token ids to pruned ids
            (``None`` to feed the ids unchanged).
        unk_new: id substituted for tokens that ``remap`` maps to -1.

    Returns:
        1-D float array of length ``len(text)``: the sigmoid of the
        NEWLINE_INDEX logit column after projecting tokens onto characters.

    NOTE(review): the text is fed as one single sequence; nothing here splits
    long inputs -- confirm the model accepts the resulting length.
    """
    token_ids, offsets, tokens = tokenizer.encode(text)
    input_ids = np.array([token_ids], dtype=np.int64)  # batch of one
    attention = np.ones_like(input_ids)
    if remap is not None:
        input_ids = remap[input_ids]
        input_ids[input_ids == -1] = unk_new
    feeds = {"input_ids": input_ids, "attention_mask": attention}
    outputs = session.run(["logits"], feeds)
    token_logits = outputs[0][0]  # first output, first (only) batch row
    char_logits = token_to_char_probs(text, tokens, token_logits,
                                      tokenizer.special_tokens, offsets)
    boundary_logit = char_logits[:, NEWLINE_INDEX]
    return 1.0 / (1.0 + np.exp(-boundary_logit))
def main():
    """Command-line entry point: pick a model, segment the text, print the chunks.

    Steps: parse arguments, locate the ONNX models under MODELS_DIR, choose one
    (from ``--model`` or interactively), score every character with
    ``boundary_probs``, optionally apply the pause-aware mask, run the
    constrained segmentation with the requested prior, and print one line per
    chunk. Exits with an error message if no model is found, the model name is
    unknown, or the chunks do not rejoin to the original text.
    """
    ap = argparse.ArgumentParser(description="Segment text with a local ONNX SaT model")
    ap.add_argument("-m", "--model", help="model name (see menu if omitted)")
    ap.add_argument("-t", "--text", help="text to segment")
    ap.add_argument("-f", "--file", help="read text from this file")
    ap.add_argument("--max-length", type=int, default=80, help="target max chars per chunk")
    ap.add_argument("--min-length", type=int, default=40, help="min chars per chunk")
    ap.add_argument("--overflow", type=int, default=0,
                    help="chars a chunk may exceed --max-length to reach a comma/"
                         "clause/sentence pause (soft cap; 0 = hard cap)")
    ap.add_argument("--prior", default="gaussian",
                    choices=["uniform", "gaussian", "clipped_polynomial"])
    ap.add_argument("--target", type=int, default=70, help="gaussian target length")
    ap.add_argument("--spread", type=int, default=12, help="gaussian spread")
    ap.add_argument("--algorithm", default="viterbi", choices=["viterbi", "greedy"])
    ap.add_argument("--allow-midword", action="store_true",
                    help="permit breaks inside words/abbreviations (off by default)")
    args = ap.parse_args()

    models = find_models(MODELS_DIR)
    if not models:
        sys.exit(f"No ONNX models found under {MODELS_DIR}. Run build_and_test_onnx.py first.")
    name = args.model or choose_model(models)
    if name not in models:
        sys.exit(f"Unknown model '{name}'. Choices: {', '.join(models)}")
    path = models[name]

    tokenizer = load_tokenizer()
    remap = unk_new = None
    if "en_zh" in name:
        # Pruned-vocab models need the original ids translated first.
        remap, unk_new = get_remap(tokenizer)
    session = ort.InferenceSession(str(path), providers=["CPUExecutionProvider"])

    text = get_text(args)
    probs = boundary_probs(session, tokenizer, text, remap, unk_new)
    if not args.allow_midword:
        probs = word_safe_mask(probs, text)

    # Hard ceiling for the DP. With --overflow, allow chunks past --max-length up
    # to this ceiling; a decay tail past --max-length keeps plain spaces from
    # exploiting the slack while still letting a strong pause (comma/sentence)
    # pull the break into the overflow zone.
    hard_max = args.max_length + max(0, args.overflow)
    prior_kwargs = {"max_length": hard_max}
    if args.prior != "uniform":
        prior_kwargs.update(target_length=args.target, spread=args.spread)
    base_prior = create_prior_function(args.prior, prior_kwargs)
    if args.overflow > 0:
        soft, decay = args.max_length, float(args.overflow)

        def prior(L):
            # Unchanged up to the soft cap, Gaussian-shaped decay beyond it.
            tail = 1.0 if L <= soft else math.exp(-((L - soft) / decay) ** 2)
            return base_prior(L) * tail
    else:
        prior = base_prior

    idx = constrained_segmentation(probs, prior, min_length=args.min_length,
                                   max_length=hard_max, algorithm=args.algorithm)
    cuts = [0] + list(idx) + [len(text)]
    chunks = [text[cuts[i]:cuts[i + 1]] for i in range(len(cuts) - 1)]

    print(f"\nModel: {name} ({path.stat().st_size/1e6:.1f} MB)")
    print(f"Config: max={args.max_length} overflow={args.overflow} "
          f"min={args.min_length} prior={args.prior} algo={args.algorithm}")
    print(f"Input: {len(text)} chars -> {len(chunks)} chunks\n")
    for c in chunks:
        n = len(c)
        # "!" = over the hard ceiling, "+" = inside the overflow zone.
        flag = "!" if n > hard_max else ("+" if n > args.max_length else " ")
        print(f" {flag}[{n:3d}] {c.strip()[:90]}")

    # Explicit check instead of `assert`: assertions are removed under
    # `python -O`, which would let the success line print unconditionally.
    if "".join(chunks) != text:
        sys.exit("TEXT NOT PRESERVED")
    print("\n ✓ text preserved (chunks rejoin to original)")


if __name__ == "__main__":
    main()