#!/usr/bin/env python3
"""
Validate that EN+ZH vocab pruning does not change segmentation on
target-language text.

Compares newline-boundary probabilities of the pruned model against the
original full-vocab model on English + Chinese samples. Reports max/mean
probability delta and whether the predicted sentence boundaries
(threshold 0.025) are identical.

This validates the *pruning* step (lossless for in-vocab text). Quantization
fidelity (int8/palettization) must be measured on a Mac with the Core ML
runtime.

The process exits with status 0 when every sample yields identical boundaries
and with status 1 otherwise, so the script can gate an automated pipeline.
"""
import sys
from pathlib import Path

import numpy as np
import torch

# Make the sibling build script importable regardless of the current directory.
sys.path.insert(0, str(Path(__file__).resolve().parent))

import wtpsplit.models  # noqa: F401
from transformers import AutoModelForTokenClassification, AutoTokenizer

from build_ios_coreml import compute_keep_ids, prune_embedding, remap_ids

NEWLINE_INDEX = 0  # Constants.NEWLINE_INDEX in wtpsplit

# Probability above which a token position counts as a sentence boundary.
BOUNDARY_THRESHOLD = 0.025

SAMPLES = [
    "The quick brown fox jumps over the lazy dog. Pack my box with five dozen liquor jugs. How vexingly quick daft zebras jump!",
    "Anthropic has been on the rise. The company released a powerful model. It passed its rivals on Thursday.",
    "今天天气很好。我们一起去公园散步吧。下午还可以喝杯咖啡。",
    "他说:“这是最重要的发现。”大家都很激动。会议在日内瓦举行。",
    # NOTE(review): the original literal carried a raw, invisible line-break
    # character between "too! " and the Chinese text. It is spelled out here as
    # an explicit escape; U+2028 is assumed -- confirm against the original file.
    "This is English. 这是中文。Mixed text works too! \u2028混合文本也可以。",
]


def boundary_probs(model, tokenizer, text, remap=None, unk_new_id=None, seq_len=256):
    """Return per-position newline probabilities and the attention mask.

    Args:
        model: Token-classification model; called with ``input_ids`` and
            ``attention_mask`` and expected to expose ``.logits``.
        tokenizer: Tokenizer used to encode ``text``; the input is padded
            and truncated to ``seq_len`` positions.
        text: The single string to score.
        remap: Optional id mapping handed to ``remap_ids``. When ``None`` the
            token ids are fed to the model unchanged.
        unk_new_id: Passed through to ``remap_ids`` together with ``remap``
            (presumably the replacement id for tokens that were pruned away
            -- verify against ``build_ios_coreml``).
        seq_len: Fixed sequence length used for padding and truncation.

    Returns:
        A ``(probs, mask)`` pair of NumPy arrays for the first (only) batch
        item: the sigmoid of the ``NEWLINE_INDEX`` logit at every position,
        and the attention mask that marks the non-padding positions.
    """
    enc = tokenizer(
        [text],
        return_tensors="pt",
        padding="max_length",
        max_length=seq_len,
        truncation=True,
    )
    ids, mask = enc["input_ids"], enc["attention_mask"]
    if remap is not None:
        ids = remap_ids(ids, remap, unk_new_id)
    with torch.no_grad():
        logits = model(input_ids=ids, attention_mask=mask, return_dict=True).logits
    probs = torch.sigmoid(logits[0, :, NEWLINE_INDEX]).cpu().numpy()
    return probs, mask[0].cpu().numpy()


def main():
    """Compare the original and pruned models on every entry of ``SAMPLES``.

    Loads the checkpoint twice, prunes the second copy, and prints one report
    line per sample plus a final verdict.

    Returns:
        ``True`` when the thresholded boundaries match on every sample,
        ``False`` otherwise.
    """
    repo = "segment-any-text/sat-1l-sm"
    tok = AutoTokenizer.from_pretrained("xlm-roberta-base")
    orig = AutoModelForTokenClassification.from_pretrained(repo).eval()
    pruned = AutoModelForTokenClassification.from_pretrained(repo).eval()

    keep = compute_keep_ids(tok)
    # presumably prunes ``pruned`` in place and returns the old->new id table;
    # verify against build_ios_coreml.prune_embedding.
    remap = prune_embedding(pruned, keep)
    unk_new = int(remap[tok.unk_token_id])
    print(f"Pruned vocab: {len(keep)} / {len(remap)} tokens "
          f"({100*len(keep)/len(remap):.1f}%)\n")

    thr = BOUNDARY_THRESHOLD
    all_ok = True
    for text in SAMPLES:
        po, mo = boundary_probs(orig, tok, text)
        pp, _ = boundary_probs(pruned, tok, text, remap, unk_new)
        # Only compare real token positions; padding is ignored.
        valid = mo.astype(bool)
        delta = np.abs(po[valid] - pp[valid])
        bo = po[valid] > thr
        bp = pp[valid] > thr
        same = bool(np.array_equal(bo, bp))
        all_ok = all_ok and same
        print(f"{'OK ' if same else 'DIFF'} | max Δp={delta.max():.2e} "
              f"mean Δp={delta.mean():.2e} | boundaries identical={same}")
        print(f" {text[:60]!r}")

    print(f"\n{'ALL BOUNDARIES IDENTICAL ✓' if all_ok else 'SOME DIFFERENCES ✗'}")
    return all_ok


if __name__ == "__main__":
    # Surface a failed validation through the exit status, not just the log.
    sys.exit(0 if main() else 1)