wtpsplit-kit / scripts /validate_pruned_vocab.py
krmanik's picture
Upload folder using huggingface_hub
357ae2c verified
#!/usr/bin/env python3
"""
Validate that EN+ZH vocab pruning does not change segmentation on target-language text.
Compares newline-boundary probabilities of the pruned model against the original
full-vocab model on English + Chinese samples. Reports max/mean probability delta
and whether the predicted sentence boundaries (threshold 0.025) are identical.
This validates the *pruning* step (lossless for in-vocab text). Quantization
fidelity (int8/palettization) must be measured on a Mac with the Core ML runtime.
"""
import sys
from pathlib import Path
import numpy as np
import torch
sys.path.insert(0, str(Path(__file__).resolve().parent))
import wtpsplit.models # noqa: F401
from transformers import AutoModelForTokenClassification, AutoTokenizer
from build_ios_coreml import compute_keep_ids, prune_embedding, remap_ids
# Output-channel index whose logit is read as the newline (sentence-boundary)
# score; boundary_probs applies a sigmoid to this channel.
NEWLINE_INDEX = 0 # Constants.NEWLINE_INDEX in wtpsplit
# Validation texts fed to both models: two English samples, two Chinese
# samples, and one mixed English/Chinese sample.
SAMPLES = [
"The quick brown fox jumps over the lazy dog. Pack my box with five dozen liquor jugs. How vexingly quick daft zebras jump!",
"Anthropic has been on the rise. The company released a powerful model. It passed its rivals on Thursday.",
"今天天气很好。我们一起去公园散步吧。下午还可以喝杯咖啡。",
"他说:“这是最重要的发现。”大家都很激动。会议在日内瓦举行。",
"This is English. 这是中文。Mixed text works too! 混合文本也可以。",
]
def boundary_probs(model, tokenizer, text, remap=None, unk_new_id=None, seq_len=256):
    """Return per-position newline-boundary probabilities for one text.

    The text is tokenized as a batch of one, padded/truncated to ``seq_len``
    positions, and run through ``model`` without gradient tracking.

    Args:
        model: Callable accepting ``input_ids``, ``attention_mask`` and
            ``return_dict`` and returning an object with a ``.logits`` tensor
            indexed as ``[batch, position, channel]``.
        tokenizer: Callable producing ``input_ids`` and ``attention_mask``
            PyTorch tensors for a list of strings.
        text: The string to score.
        remap: Optional id mapping; when given, the token ids are passed
            through ``remap_ids`` before the forward pass.
        unk_new_id: Forwarded to ``remap_ids`` together with ``remap``
            (presumably the fallback id for tokens outside the kept
            vocabulary -- confirm against ``remap_ids``).
        seq_len: Fixed sequence length used for padding and truncation.

    Returns:
        A ``(probs, mask)`` pair of NumPy arrays for the single batch item:
        the sigmoid of the ``NEWLINE_INDEX`` logit at every position, and the
        attention mask for those positions.
    """
    encoded = tokenizer(
        [text],
        return_tensors="pt",
        padding="max_length",
        max_length=seq_len,
        truncation=True,
    )
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]

    # Translate ids into the pruned vocabulary only when a mapping is supplied.
    if remap is not None:
        input_ids = remap_ids(input_ids, remap, unk_new_id)

    with torch.no_grad():
        output = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
    newline_logits = output.logits[0, :, NEWLINE_INDEX]
    newline_probs = torch.sigmoid(newline_logits).cpu().numpy()
    first_item_mask = attention_mask[0].cpu().numpy()
    return newline_probs, first_item_mask
def main():
    """Compare the full-vocab and pruned models on every entry of SAMPLES.

    Loads two copies of the same checkpoint, prunes the embedding of the
    second one, and for each sample prints the max/mean absolute difference
    in newline-boundary probability together with whether the thresholded
    boundary decisions are identical. Padding positions (attention mask 0)
    are excluded from the comparison.

    Returns:
        int: Process exit status -- 0 when every sample produced identical
        boundaries, 1 when at least one sample differed.
    """
    repo = "segment-any-text/sat-1l-sm"
    tok = AutoTokenizer.from_pretrained("xlm-roberta-base")
    orig = AutoModelForTokenClassification.from_pretrained(repo).eval()
    # A second, independent copy of the checkpoint to be pruned, so the
    # reference model above stays untouched.
    pruned = AutoModelForTokenClassification.from_pretrained(repo).eval()

    keep = compute_keep_ids(tok)
    # NOTE(review): prune_embedding presumably modifies `pruned` in place and
    # returns the old-id -> new-id mapping -- confirm in build_ios_coreml.
    remap = prune_embedding(pruned, keep)
    unk_new = int(remap[tok.unk_token_id])
    print(f"Pruned vocab: {len(keep)} / {len(remap)} tokens "
          f"({100*len(keep)/len(remap):.1f}%)\n")

    thr = 0.025  # boundary decision threshold applied to the probabilities
    all_ok = True
    for text in SAMPLES:
        po, mo = boundary_probs(orig, tok, text)
        pp, _ = boundary_probs(pruned, tok, text, remap, unk_new)
        # Compare real token positions only; padding is masked out.
        valid = mo.astype(bool)
        delta = np.abs(po[valid] - pp[valid])
        same = bool(np.array_equal(po[valid] > thr, pp[valid] > thr))
        all_ok = all_ok and same
        print(f"{'OK ' if same else 'DIFF'} | max Δp={delta.max():.2e} "
              f"mean Δp={delta.mean():.2e} | boundaries identical={same}")
        print(f" {text[:60]!r}")

    print(f"\n{'ALL BOUNDARIES IDENTICAL ✓' if all_ok else 'SOME DIFFERENCES ✗'}")
    # Surface the verdict to the caller so automation can detect a failure
    # instead of having to parse the printed report.
    return 0 if all_ok else 1


if __name__ == "__main__":
    # Propagate the validation verdict as the process exit status.
    sys.exit(main())