File size: 15,005 Bytes
357ae2c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 | #!/usr/bin/env python3
"""
Pick an ONNX SaT model and segment text with it (local CPU test).
Examples:
# interactive: choose a model from a menu, then type/paste text
python scripts/run_segmentation.py
# one-shot
python scripts/run_segmentation.py --model sat-1l-sm-en_zh-int8 \
--text "Your text here. 这是中文。" --max-length 80
# from a file, tighter Chinese-style budget
python scripts/run_segmentation.py -m sat-3l-sm-en_zh-int8 -f test.txt \
--max-length 40 --min-length 15
Notes:
- "*-en_zh-*" models use a pruned vocab; the id-remap is recomputed on the fly
(deterministic from the tokenizer), so no extra files are needed.
- onnxruntime needs the conda libstdc++ on this box; the script auto-preloads it
and re-execs once if needed.
"""
import argparse
import math
import os
import re
import string
import sys
from pathlib import Path
# --- bootstrap: onnxruntime needs conda's libstdc++ preloaded on this machine ---
def _ensure_onnxruntime():
    """Make sure ``onnxruntime`` can be imported, re-executing once if needed.

    The import is attempted with Python-level stderr redirected to a throwaway
    buffer. If it fails and ``lib/libstdc++.so.6`` exists under
    ``$CONDA_PREFIX`` (or ``sys.prefix`` when that variable is unset/empty),
    the library is prepended to ``LD_PRELOAD`` and the interpreter is
    re-executed with the same arguments. The ``_ORT_PRELOADED`` environment
    variable marks that the retry already happened, so a second failure
    re-raises the import error instead of looping.
    """
    import contextlib
    import io

    try:
        # Probe quietly: a failed import can print a long message to stderr.
        with contextlib.redirect_stderr(io.StringIO()):
            import onnxruntime  # noqa
    except Exception:
        conda_root = os.environ.get("CONDA_PREFIX") or sys.prefix
        libstdcxx = Path(conda_root) / "lib" / "libstdc++.so.6"
        already_retried = os.environ.get("_ORT_PRELOADED") == "1"
        if libstdcxx.exists() and not already_retried:
            previous = os.environ.get('LD_PRELOAD', '')
            os.environ["LD_PRELOAD"] = f"{libstdcxx}:{previous}".strip(":")
            os.environ["_ORT_PRELOADED"] = "1"
            # Replaces the current process; control does not return here.
            os.execv(sys.executable, [sys.executable] + sys.argv)
        # No library to preload, or the retry already happened: surface the error.
        raise
# Run the bootstrap before the module-level `import onnxruntime` below; it may
# re-exec the interpreter, so nothing after this line is guaranteed to run in
# the current process.
_ensure_onnxruntime()
import importlib.util # noqa: E402
import types # noqa: E402
import numpy as np # noqa: E402
import onnxruntime as ort # noqa: E402
# Column of the per-character logits that boundary_probs() converts into a
# boundary probability.
NEWLINE_INDEX = 0
# Parent of the directory that contains this script.
ROOT = Path(__file__).resolve().parent.parent
# Directory scanned by find_models() for .onnx files; get_remap() also stores
# its cached id-remap table here.
MODELS_DIR = ROOT / "onnx_models"
# --- load the two tiny pure-numpy helper modules WITHOUT importing the heavy
# wtpsplit package (which pulls torch/onnx/skops and costs ~5s on startup).
# constraints.py references wtpsplit.utils.indices_to_sentences but
# constrained_segmentation() never calls it, so we stub that one symbol. ---
def _load_light(path, name):
    """Execute the Python file at *path* as a module called *name* and return it.

    Before loading, empty stand-in modules named ``wtpsplit`` and
    ``wtpsplit.utils`` are registered in ``sys.modules`` (only if ``wtpsplit``
    is not already present), so that an import of
    ``wtpsplit.utils.indices_to_sentences`` inside the loaded file resolves to
    a do-nothing placeholder instead of importing the real package.

    The loaded module itself is not added to ``sys.modules``.
    """
    if "wtpsplit" not in sys.modules:
        def _unused(*a, **k):
            # Placeholder for indices_to_sentences; never called by this script.
            return None

        stub_pkg = types.ModuleType("wtpsplit")
        stub_pkg.__path__ = []
        stub_utils = types.ModuleType("wtpsplit.utils")
        stub_utils.__path__ = []
        stub_utils.indices_to_sentences = _unused
        sys.modules["wtpsplit"] = stub_pkg
        sys.modules["wtpsplit.utils"] = stub_utils

    module_spec = importlib.util.spec_from_file_location(name, path)
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)
    return module
# Load the two helper files straight from the source tree (via _load_light)
# and expose only the two callables this script uses.
_WT_UTILS = ROOT / "wtpsplit" / "utils"
_constraints_path = _WT_UTILS / "constraints.py"
constrained_segmentation = _load_light(
    _constraints_path, "onnxseg_constraints"
).constrained_segmentation
_priors_path = _WT_UTILS / "priors.py"
create_prior_function = _load_light(
    _priors_path, "onnxseg_priors"
).create_prior_function
def get_token_spans(offsets_mapping, tokens, special_tokens):
    """Select the non-special tokens and return their indices and char spans.

    Args:
        offsets_mapping: sequence of ``(start, end)`` pairs, one per token.
        tokens: token strings, positionally aligned with *offsets_mapping*.
        special_tokens: container of token strings to skip.

    Returns:
        ``(valid, spans)`` where ``valid`` is an integer array of the kept
        token positions and ``spans`` is the matching ``(len(valid), 2)``
        integer array of offsets. Both are empty (shapes ``(0,)`` and
        ``(0, 2)``) when no token qualifies.
    """
    # An explicit integer dtype matters for the empty case: np.array([])
    # defaults to float64, which NumPy refuses to use as an index.
    valid = np.array(
        [i for i, t in enumerate(tokens)
         if i < len(offsets_mapping) and t not in special_tokens],
        dtype=np.int64,
    )
    # reshape(-1, 2) keeps the result two-dimensional even with zero offsets,
    # so callers can always slice a column.
    offsets = np.asarray(offsets_mapping, dtype=np.int64).reshape(-1, 2)
    return valid, offsets[valid]
def token_to_char_probs(text, tokens, token_logits, special_tokens, offsets_mapping):
    """Scatter per-token logits onto the last character of each token.

    Args:
        text: the string that was tokenized.
        tokens: token strings aligned with *token_logits* rows.
        token_logits: 2-D array with one row of logits per token.
        special_tokens: token strings whose rows are ignored.
        offsets_mapping: ``(start, end)`` pair per token.

    Returns:
        Array of shape ``(len(text), token_logits.shape[1])``. Rows for
        characters that do not end a token stay at ``-inf``.
    """
    char_probs = np.full((len(text), token_logits.shape[1]), -np.inf)
    vi, vo = get_token_spans(offsets_mapping, tokens, special_tokens)
    ends = vo[:, 1]
    # A span ending at offset 0 would give index -1, which NumPy interprets as
    # "last row" and would silently overwrite the final character's logits.
    # Such zero-width tokens have no character to attach to, so skip them.
    has_char = ends > 0
    char_probs[ends[has_char] - 1] = token_logits[vi[has_char]]
    return char_probs
# Location of the serialized tokenizer: "tokenizer.json" inside a
# ".xlmr_tokenizer" directory next to this script. load_tokenizer() creates it
# when missing and reads it on every run.
_TOK_CACHE = Path(__file__).resolve().parent / ".xlmr_tokenizer" / "tokenizer.json"
class FastTok:
    """Minimal adapter around a tokenizer object from the `tokenizers` library.

    Only the pieces this script relies on are exposed: the special-token
    strings and ids, the ``<unk>`` id, ``encode`` and ``get_vocab``.
    """

    def __init__(self, tok):
        """Wrap *tok*, which must provide ``token_to_id``, ``encode`` and ``get_vocab``."""
        self._t = tok
        self.special_tokens = {"<s>", "</s>", "<pad>", "<unk>", "<mask>"}
        self.unk_token_id = tok.token_to_id("<unk>")
        # Specials missing from the vocabulary map to None and are dropped.
        looked_up = (tok.token_to_id(name) for name in self.special_tokens)
        self.all_special_ids = [tid for tid in looked_up if tid is not None]

    def encode(self, text):
        """Tokenize *text* and return ``(ids, offsets, tokens)``."""
        # NOTE(review): the wrapped tokenizer is expected to add <s> ... </s>
        # itself via its template; nothing is added here.
        encoding = self._t.encode(text)
        return encoding.ids, encoding.offsets, encoding.tokens

    def get_vocab(self):
        """Return the wrapped tokenizer's vocabulary mapping."""
        return self._t.get_vocab()
def load_tokenizer():
    """Load the cached tokenizer file and return it wrapped in a FastTok.

    When the cache file does not exist yet, it is produced by loading
    "xlm-roberta-base" through ``transformers.AutoTokenizer`` and saving it
    into the cache directory. ``transformers`` is imported only on that path;
    otherwise just the ``tokenizers`` library is used.
    """
    from tokenizers import Tokenizer

    cache_file = _TOK_CACHE
    if not cache_file.exists():
        from transformers import AutoTokenizer  # lazy: only needed to build the cache

        pretrained = AutoTokenizer.from_pretrained("xlm-roberta-base")
        pretrained.save_pretrained(str(cache_file.parent))
    fast = Tokenizer.from_file(str(cache_file))
    return FastTok(fast)
def compute_keep_ids(tokenizer):
    """Return the sorted vocabulary ids retained for the EN+ZH model variant.

    An id is kept when it is one of the tokenizer's special ids, or when its
    token text (with "▁" treated as a space) is entirely ASCII or contains at
    least one character accepted by ``_is_cjk``.
    """
    kept = set(tokenizer.all_special_ids)
    for piece, piece_id in tokenizer.get_vocab().items():
        surface = piece.replace("▁", " ")  # SentencePiece word marker -> space
        only_ascii = not any(ord(ch) >= 128 for ch in surface)
        if only_ascii or any(map(_is_cjk, surface)):
            kept.add(piece_id)
    return sorted(kept)
def get_remap(tokenizer):
    """Return ``(remap, unk_new)`` for translating original ids to pruned ids.

    ``remap`` is an int64 table of 250002 entries in which every kept id (see
    ``compute_keep_ids``) maps to its rank among the kept ids and every other
    entry is -1. ``unk_new`` is the table entry for the tokenizer's ``<unk>``
    id. The table is read from ``remap_en_zh.npy`` under ``MODELS_DIR`` when
    that file exists; otherwise it is computed and written there.
    """
    cache_file = MODELS_DIR / "remap_en_zh.npy"
    if not cache_file.exists():
        kept_ids = np.asarray(compute_keep_ids(tokenizer), dtype=np.int64)
        table = np.full(250002, -1, dtype=np.int64)
        # kept_ids is sorted and unique, so position in the list == new id.
        table[kept_ids] = np.arange(len(kept_ids), dtype=np.int64)
        MODELS_DIR.mkdir(parents=True, exist_ok=True)
        np.save(cache_file, table)
    else:
        table = np.load(cache_file)
    unk_new = int(table[tokenizer.unk_token_id])
    return table, unk_new
def find_models(root: Path):
    """Map a display name to the path of each ``.onnx`` file found below *root*.

    The display name is ``<parent directory name>-<quant>``, where ``quant``
    is "int8" if the file name contains ".int8." and "fp32" otherwise. Files
    are visited in sorted path order; when two files produce the same display
    name, the later one wins.
    """
    return {
        "{}-{}".format(onnx.parent.name,
                       "int8" if ".int8." in onnx.name else "fp32"): onnx
        for onnx in sorted(root.rglob("*.onnx"))
    }
def choose_model(models: dict):
    """Show a numbered menu of *models* and return the name the user picks.

    Args:
        models: mapping of display name to the model file's ``Path``; the
            file size is read with ``stat()`` for the listing.

    Returns:
        The chosen display name. The prompt repeats until the user enters
        either a valid 1-based menu number or an exact model name.
    """
    names = list(models)
    print("\nAvailable ONNX models:")
    for i, n in enumerate(names, 1):
        mb = models[n].stat().st_size / 1e6
        print(f" {i:2d}) {n:30s} {mb:7.1f} MB")
    while True:
        sel = input("\nSelect model [number or name]: ").strip()
        # isdecimal() rather than isdigit(): isdigit() also accepts characters
        # such as superscript "²" that int() cannot parse, which would crash
        # the menu with ValueError instead of re-prompting.
        if sel.isdecimal() and 1 <= int(sel) <= len(names):
            return names[int(sel) - 1]
        if sel in models:
            return sel
        print(" invalid choice, try again")
def get_text(args):
    """Return the text to segment, taken from the first available source.

    Order of precedence: a non-empty ``args.text``; the UTF-8 contents of
    ``args.file``; whatever is read from standard input. If standard input
    yields only whitespace, a built-in English/Chinese sample is returned.
    """
    if args.text:
        return args.text
    if args.file:
        source = Path(args.file)
        return source.read_text(encoding="utf-8")

    print("\nEnter/paste text, then Ctrl-D (Ctrl-Z on Windows) to finish:")
    typed = sys.stdin.read()
    if typed.strip():
        return typed
    return "Breaking News: Scientists announced a discovery. 这是一个测试。It works well!"
# Inclusive code-point ranges that _is_cjk() treats as CJK. compute_keep_ids()
# relies on this test, so changing the ranges changes the pruned vocabulary.
CJK_RANGES = [(0x4E00, 0x9FFF), (0x3400, 0x4DBF), (0xF900, 0xFAFF),
              (0x3000, 0x303F), (0xFF00, 0xFFEF)]


def _is_cjk(ch):
    """Return True if the single character *ch* lies in one of CJK_RANGES."""
    code_point = ord(ch)
    for low, high in CJK_RANGES:
        if low <= code_point <= high:
            return True
    return False
# Punctuation that marks a prosodic pause, by strength (used as break-priority
# floors when a long sentence must be split below max_length). Sentence-ending
# "." is intentionally NOT given a floor -- the model already predicts those
# boundaries well, and overriding it would create false breaks after
# abbreviations like "A.I.".
#
# CLAUSE_PUNCT: pause_aware_mask() applies FLOOR_CLAUSE after one of these
# characters when the next character is whitespace or CJK.
CLAUSE_PUNCT = set(",;:)]}—–"  # comma, semicolon, colon, closing brackets, em/en-dash
                   ",、;:"  # comma, ideographic comma, semicolon, colon
                   "”’")  # closing quotes ” ’
# CJK_SENT_PUNCT: pause_aware_mask() raises the boundary probability after any
# of these characters to at least 0.9, regardless of what follows.
CJK_SENT_PUNCT = set("。!?…")  # 。 ! ? …
# Words that introduce a clause/phrase: breaking *before* one of these sounds
# more natural than a random word gap when a long span has no punctuation.
# Matched case-insensitively after stripping ASCII punctuation (see
# _connector_break_positions).
CONNECTORS = {
    "and", "but", "or", "nor", "yet", "so", "for",
    "which", "that", "who", "whom", "whose", "where", "when", "while",
    "because", "although", "though", "since", "if", "unless", "until",
    "after", "before", "as", "than", "whether",
}
# Minimum boundary probabilities applied by pause_aware_mask(), strongest
# pause first. FORBID is the exception: it is an upper cap, not a floor.
FLOOR_CLAUSE = 0.25  # comma / semicolon / colon -> strongly preferred
FLOOR_CONNECTOR = 0.05  # break before "and/which/that..." in a comma-free span
FLOOR_HANZI = 5e-3  # between two Chinese chars (no spaces in zh)
FLOOR_SPACE = 1e-4  # plain word gap -> last-resort break
FORBID = 1e-9  # mid-word -> effectively never
def _connector_break_positions(text):
    """Return the set of indices ``i`` where a break after char ``i`` precedes a connector.

    Each whitespace run followed by a word is examined; the word is stripped
    of leading/trailing ASCII punctuation and lower-cased before being looked
    up in CONNECTORS. The recorded index is the character immediately before
    the whitespace run; a run at the very start of *text* is ignored.
    """
    breaks = set()
    for match in re.finditer(r"\s+(\S+)", text):
        gap_start = match.start()
        if gap_start < 1:
            # Nothing precedes the whitespace, so there is no char to break after.
            continue
        candidate = match.group(1).strip(string.punctuation).lower()
        if candidate in CONNECTORS:
            breaks.add(gap_start - 1)  # last char of the preceding word
    return breaks
def pause_aware_mask(probs, text):
    """Return a copy of *probs* biased toward natural pause positions.

    ``probs[i]`` is the boundary probability *after* character ``i``. For
    every position except the last, the first matching rule below applies:

    1. clause punctuation followed by whitespace or a CJK char -> at least
       FLOOR_CLAUSE
    2. a CJK_SENT_PUNCT character -> at least 0.9
    3. position right before a connector word -> at least FLOOR_CONNECTOR
    4. next or current char is whitespace -> at least FLOOR_SPACE
    5. both current and next char are CJK -> at least FLOOR_HANZI
    6. anything else (inside a word/abbreviation) -> at most FORBID

    Rules 1-5 only ever raise a value, so high model-predicted probabilities
    are preserved. The last position is left untouched, and the input array
    is not modified.
    """
    masked = probs.copy()
    before_connector = _connector_break_positions(text)
    # zip() stops one short of the end, so the final character is never visited.
    for pos, (cur, following) in enumerate(zip(text, text[1:])):
        next_is_space = following.isspace()
        if cur in CLAUSE_PUNCT and (next_is_space or _is_cjk(following)):
            floor = FLOOR_CLAUSE
        elif cur in CJK_SENT_PUNCT:
            floor = 0.9
        elif pos in before_connector:
            floor = FLOOR_CONNECTOR
        elif next_is_space or cur.isspace():
            floor = FLOOR_SPACE
        elif _is_cjk(cur) and _is_cjk(following):
            floor = FLOOR_HANZI
        else:
            # Mid-word: cap instead of floor so the position is effectively never chosen.
            masked[pos] = min(masked[pos], FORBID)
            continue
        masked[pos] = max(masked[pos], floor)
    return masked
# Backward-compatible alias for pause_aware_mask: kept so existing imports
# (benchmark) keep working; main() also calls the function through this name.
word_safe_mask = pause_aware_mask
def boundary_probs(session, tokenizer, text, remap, unk_new):
    """Run the model on *text* and return one boundary probability per character.

    Args:
        session: object with a ``run(output_names, feeds)`` method that
            returns the "logits" output first (presumably an onnxruntime
            InferenceSession -- verify against caller).
        tokenizer: object whose ``encode(text)`` returns
            ``(ids, offsets, tokens)`` and which has ``special_tokens``.
        text: the string to score.
        remap: optional id translation table; ``None`` skips remapping.
        unk_new: replacement id for entries the table maps to -1.

    Returns:
        1-D array of ``len(text)`` values: the sigmoid of column
        NEWLINE_INDEX of the per-character logits.
    """
    token_ids, offsets, tokens = tokenizer.encode(text)
    input_ids = np.array([token_ids], dtype=np.int64)  # batch of one
    attention = np.ones_like(input_ids)
    if remap is not None:
        input_ids = remap[input_ids]
        # Ids absent from the table fall back to the replacement unknown id.
        input_ids[input_ids == -1] = unk_new

    feeds = {"input_ids": input_ids, "attention_mask": attention}
    outputs = session.run(["logits"], feeds)
    logits = outputs[0]
    per_char = token_to_char_probs(text, tokens, logits[0],
                                   tokenizer.special_tokens, offsets)
    newline_logits = per_char[:, NEWLINE_INDEX]
    return 1.0 / (1.0 + np.exp(-newline_logits))
def _build_arg_parser():
    """Return the argparse parser describing this script's command-line options."""
    ap = argparse.ArgumentParser(description="Segment text with a local ONNX SaT model")
    ap.add_argument("-m", "--model", help="model name (see menu if omitted)")
    ap.add_argument("-t", "--text", help="text to segment")
    ap.add_argument("-f", "--file", help="read text from this file")
    ap.add_argument("--max-length", type=int, default=80, help="target max chars per chunk")
    ap.add_argument("--min-length", type=int, default=40, help="min chars per chunk")
    ap.add_argument("--overflow", type=int, default=0,
                    help="chars a chunk may exceed --max-length to reach a comma/"
                         "clause/sentence pause (soft cap; 0 = hard cap)")
    ap.add_argument("--prior", default="gaussian",
                    choices=["uniform", "gaussian", "clipped_polynomial"])
    ap.add_argument("--target", type=int, default=70, help="gaussian target length")
    ap.add_argument("--spread", type=int, default=12, help="gaussian spread")
    ap.add_argument("--algorithm", default="viterbi", choices=["viterbi", "greedy"])
    ap.add_argument("--allow-midword", action="store_true",
                    help="permit breaks inside words/abbreviations (off by default)")
    return ap


def _build_prior(args, hard_max):
    """Return the chunk-length prior handed to the segmenter.

    The base prior comes from create_prior_function with ``hard_max`` as its
    ``max_length``. With ``--overflow`` > 0 it is multiplied, for lengths
    above ``--max-length``, by a Gaussian-shaped decay whose width is the
    overflow, so lengths in the overflow zone are allowed but penalized.
    """
    prior_kwargs = {"max_length": hard_max}
    if args.prior != "uniform":
        prior_kwargs.update(target_length=args.target, spread=args.spread)
    base_prior = create_prior_function(args.prior, prior_kwargs)
    if args.overflow <= 0:
        return base_prior

    soft, decay = args.max_length, float(args.overflow)

    def prior(L):
        return base_prior(L) * (
            1.0 if L <= soft else math.exp(-((L - soft) / decay) ** 2))

    return prior


def _print_report(name, path, args, text, chunks, hard_max):
    """Print the run configuration followed by one line per chunk.

    Each chunk line is flagged "!" when longer than ``hard_max``, "+" when
    longer than ``--max-length``, and blank otherwise.
    """
    print(f"\nModel: {name} ({path.stat().st_size/1e6:.1f} MB)")
    print(f"Config: max={args.max_length} overflow={args.overflow} "
          f"min={args.min_length} prior={args.prior} algo={args.algorithm}")
    print(f"Input: {len(text)} chars -> {len(chunks)} chunks\n")
    for c in chunks:
        n = len(c)
        flag = "!" if n > hard_max else ("+" if n > args.max_length else " ")
        print(f" {flag}[{n:3d}] {c.strip()[:90]}")


def main():
    """Command-line entry point: pick a model, segment the text, print the chunks.

    Exits via ``sys.exit`` with a message when no model file is found or the
    requested model name is unknown.

    Raises:
        AssertionError: if the produced chunks do not concatenate back to the
            input text.
    """
    args = _build_arg_parser().parse_args()

    models = find_models(MODELS_DIR)
    if not models:
        sys.exit(f"No ONNX models found under {MODELS_DIR}. Run build_and_test_onnx.py first.")
    name = args.model or choose_model(models)
    if name not in models:
        sys.exit(f"Unknown model '{name}'. Choices: {', '.join(models)}")
    path = models[name]

    tokenizer = load_tokenizer()
    remap = unk_new = None
    if "en_zh" in name:
        # Models with "en_zh" in their name need original ids remapped first.
        remap, unk_new = get_remap(tokenizer)
    session = ort.InferenceSession(str(path), providers=["CPUExecutionProvider"])

    text = get_text(args)
    probs = boundary_probs(session, tokenizer, text, remap, unk_new)
    if not args.allow_midword:
        probs = word_safe_mask(probs, text)

    # Hard ceiling for the segmenter. With --overflow, chunks may run past
    # --max-length up to this ceiling; the decay in _build_prior keeps plain
    # spaces from exploiting the slack while still letting a strong pause
    # (comma/sentence) pull the break into the overflow zone.
    hard_max = args.max_length + max(0, args.overflow)
    prior = _build_prior(args, hard_max)
    idx = constrained_segmentation(probs, prior, min_length=args.min_length,
                                   max_length=hard_max, algorithm=args.algorithm)

    # idx is used as the list of cut offsets into text.
    cuts = [0] + list(idx) + [len(text)]
    chunks = [text[cuts[i]:cuts[i + 1]] for i in range(len(cuts) - 1)]

    _print_report(name, path, args, text, chunks, hard_max)

    # Explicit raise rather than `assert`: asserts are removed under
    # `python -O`, which would let the success line below print unverified.
    if "".join(chunks) != text:
        raise AssertionError("TEXT NOT PRESERVED")
    print("\n ✓ text preserved (chunks rejoin to original)")


if __name__ == "__main__":
    main()
|