| |
| """ |
| Build iOS Core ML variants of SaT models. |
| |
Produces, for each base model and vocab variant, one package per quantization
level (four by default, selectable with --quants): fp16, int8 (linear),
6-bit palettized, 4-bit palettized.
| |
| Vocab pruning (EN+ZH): keeps only XLM-R tokens whose surface form is pure-ASCII |
| (covers English) or contains CJK characters (covers Chinese), plus all special |
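tokens. The word-embedding matrix is sliced to the kept rows, and an
old->new id remap table (old_to_new_ids.npy; -1 marks dropped tokens, which
should be replaced with unk_new_id from prune_info.json) is emitted so the
app can feed remapped ids.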
| |
| On Linux you can CONVERT and measure .mlpackage sizes, but not run them |
| (no Core ML runtime). Accuracy of the pruned torch model is validated here; |
| quantization fidelity must be checked on a Mac. |
| |
| Usage: |
| python scripts/build_ios_coreml.py --out ios_models \ |
| --models sat-1l-sm sat-3l-sm --seq-len 256 |
| """ |
| import argparse |
| import json |
| import shutil |
| import unicodedata |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
|
|
| import wtpsplit.models |
| from transformers import AutoModelForTokenClassification, AutoTokenizer |
|
|
# Inclusive (first, last) code-point ranges that count as "CJK" when pruning.
CJK_RANGES = [
    (0x4E00, 0x9FFF),  # CJK Unified Ideographs
    (0x3400, 0x4DBF),  # CJK Unified Ideographs Extension A
    (0xF900, 0xFAFF),  # CJK Compatibility Ideographs
    (0x3000, 0x303F),  # CJK Symbols and Punctuation
    (0xFF00, 0xFFEF),  # Halfwidth and Fullwidth Forms
]


def is_cjk(s: str) -> bool:
    """Return True when any character of *s* lies inside one of CJK_RANGES.

    A string mixing CJK and non-CJK characters counts as CJK; the empty
    string does not.
    """
    for ch in s:
        code_point = ord(ch)
        for first, last in CJK_RANGES:
            if first <= code_point <= last:
                return True
    return False
|
|
|
|
def compute_keep_ids(tokenizer) -> list[int]:
    """Return, sorted ascending, the token ids an EN+ZH vocabulary retains.

    A vocabulary entry survives when its surface form (with the SentencePiece
    word marker "▁" read as a plain space) is entirely ASCII or contains at
    least one CJK character. Every id in ``tokenizer.all_special_ids`` is
    retained unconditionally, even if it is absent from ``get_vocab()``.
    """
    def _wanted(token: str) -> bool:
        surface = token.replace("▁", " ")
        return surface.isascii() or is_cjk(surface)

    kept = {idx for token, idx in tokenizer.get_vocab().items() if _wanted(token)}
    kept.update(tokenizer.all_special_ids)
    return sorted(kept)
|
|
|
|
def prune_embedding(model, keep_ids: list[int]) -> np.ndarray:
    """Shrink the model's word-embedding table to the rows listed in *keep_ids*.

    ``model.roberta.embeddings.word_embeddings`` is replaced in place by a new
    ``nn.Embedding`` whose row ``i`` is the old row ``keep_ids[i]``, and
    ``model.config.vocab_size`` is updated to match. The old embedding's
    options (``padding_idx`` remapped to the new id space, ``max_norm``,
    ``norm_type``, ``scale_grad_by_freq``, ``sparse``) are carried over; if the
    padding row is not kept, the new embedding has no ``padding_idx``. No other
    layer (e.g. the classifier) is touched.

    Args:
        model: Object exposing ``roberta.embeddings.word_embeddings`` and
            ``config.vocab_size``.
        keep_ids: Original token ids to keep; expected sorted and unique (as
            produced by ``compute_keep_ids``).

    Returns:
        int64 array with one entry per ORIGINAL id holding its new id, or -1
        for ids that were dropped.
    """
    old_emb = model.roberta.embeddings.word_embeddings

    # old id -> new id; -1 marks rows that were pruned away.
    remap = np.full(old_emb.weight.shape[0], -1, dtype=np.int64)
    remap[np.asarray(keep_ids, dtype=np.int64)] = np.arange(len(keep_ids), dtype=np.int64)

    # Keep the padding row flagged as padding, expressed in the new id space.
    old_pad = old_emb.padding_idx
    new_pad = int(remap[old_pad]) if old_pad is not None and remap[old_pad] >= 0 else None

    keep = torch.tensor(keep_ids, dtype=torch.long)
    new_weight = old_emb.weight.detach()[keep]  # advanced indexing copies the rows
    new_emb = torch.nn.Embedding.from_pretrained(
        new_weight,
        freeze=False,
        padding_idx=new_pad,
        max_norm=old_emb.max_norm,
        norm_type=old_emb.norm_type,
        scale_grad_by_freq=old_emb.scale_grad_by_freq,
        sparse=old_emb.sparse,
    )
    model.roberta.embeddings.word_embeddings = new_emb
    model.config.vocab_size = len(keep_ids)
    return remap
|
|
|
|
class LogitsWrapper(torch.nn.Module):
    """Tracing adapter with the signature ``(input_ids, attention_mask) -> logits``.

    The wrapped model is called with keyword arguments and
    ``return_dict=True``; only the ``.logits`` field of its output is
    returned, so ``torch.jit.trace`` sees two tensors in and one tensor out.
    """

    def __init__(self, model):
        super().__init__()
        # Plain attribute assignment; nn.Module registers it as a submodule
        # when *model* is itself an nn.Module.
        self.model = model

    def forward(self, input_ids, attention_mask):
        model_kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        outputs = self.model(**model_kwargs)
        return outputs.logits
|
|
|
|
def remap_ids(input_ids: torch.Tensor, remap: np.ndarray, unk_new_id: int):
    """Translate original tokenizer ids into a pruned vocabulary's id space.

    Args:
        input_ids: Integer tensor of original ids (any shape, any device).
        remap: old->new id table as returned by ``prune_embedding``; -1 marks
            ids that were pruned away.
        unk_new_id: Pruned-vocab id substituted for every pruned-away id.

    Returns:
        A ``torch.long`` tensor with the shape of *input_ids*, placed on the
        same device as *input_ids*.

    Raises:
        ValueError: if some id was pruned away and *unk_new_id* is not a valid
            (non-negative) id, which would otherwise leave -1 in the output.
    """
    # Fancy indexing returns a copy, so the shared remap table is not modified.
    new_ids = remap[input_ids.cpu().numpy()]
    dropped = new_ids == -1
    if dropped.any():
        if unk_new_id is None or unk_new_id < 0:
            raise ValueError(f"invalid unk_new_id {unk_new_id!r} for dropped token ids")
        new_ids[dropped] = unk_new_id
    return torch.as_tensor(new_ids, dtype=torch.long, device=input_ids.device)
|
|
|
|
def dir_size_mb(path: Path) -> float:
    """Return the combined size of every regular file under *path*, recursively.

    Sizes are in decimal megabytes (bytes / 1e6), matching how the summary
    table reports package sizes.
    """
    total_bytes = 0
    for entry in path.rglob("*"):
        if not entry.is_file():
            continue
        total_bytes += entry.stat().st_size
    return total_bytes / 1e6
|
|
|
|
# Short EN+ZH sample for the pruning sanity check. Its tokens are expected to
# be ASCII or CJK, so they should all survive pruning and logits before and
# after pruning should agree; the check reports how many were actually dropped.
_PRUNE_CHECK_TEXT = (
    "The quick brown fox jumps over the lazy dog. Is this sentence split? "
    "今天天气很好。我们去公园散步吧！"
)


def _prune_en_zh(model, tokenizer, variant_dir: Path) -> np.ndarray:
    """Prune *model* in place to the EN+ZH vocabulary and write its artefacts.

    Writes ``old_to_new_ids.npy`` and ``prune_info.json`` into *variant_dir*.
    As a sanity check of the pruned torch model, logits on
    ``_PRUNE_CHECK_TEXT`` are computed before pruning and again afterwards
    (with ids remapped through ``remap_ids``). The number of sample tokens that
    were dropped and the max absolute logit difference are printed and stored
    in ``prune_info.json``.

    Returns:
        The old->new id remap table (-1 for dropped ids).
    """
    keep_ids = compute_keep_ids(tokenizer)

    wrapper = LogitsWrapper(model).eval()
    sample = tokenizer(_PRUNE_CHECK_TEXT, return_tensors="pt")
    sample_ids, sample_mask = sample["input_ids"], sample["attention_mask"]
    with torch.no_grad():
        ref_logits = wrapper(sample_ids, sample_mask)

    remap = prune_embedding(model, keep_ids)
    unk_new_id = int(remap[tokenizer.unk_token_id])

    # The wrapper references *model*, which now holds the pruned embedding.
    n_dropped = int((remap[sample_ids.cpu().numpy()] == -1).sum())
    with torch.no_grad():
        pruned_logits = wrapper(remap_ids(sample_ids, remap, unk_new_id), sample_mask)
    max_diff = float((ref_logits - pruned_logits).abs().max())

    np.save(variant_dir / "old_to_new_ids.npy", remap)
    prune_info = {
        "kept_vocab_size": len(keep_ids),
        "orig_vocab_size": int(len(remap)),
        "unk_new_id": unk_new_id,
        "pad_new_id": int(remap[tokenizer.pad_token_id]),
        "check_tokens_dropped": n_dropped,
        "check_max_abs_logit_diff": max_diff,
    }
    with open(variant_dir / "prune_info.json", "w") as f:
        json.dump(prune_info, f, indent=2)

    print(f" kept {len(keep_ids)} / {len(remap)} tokens", flush=True)
    print(f" pruning check: {n_dropped} sample token(s) dropped, "
          f"max |logit diff| = {max_diff:.3g}", flush=True)
    return remap


def _quantize(cto, mlmodel_fp16, quant: str):
    """Return *mlmodel_fp16* compressed as requested by *quant*.

    Args:
        cto: The ``coremltools.optimize.coreml`` module (imported by ``main``).
        mlmodel_fp16: The converted fp16 ML Program.
        quant: ``"fp16"`` (model returned unchanged), ``"int8"`` (symmetric
            linear weight quantization) or ``"palette<N>"`` (N-bit k-means
            palettization). Both compressed modes pass weight_threshold=512,
            so small weight tensors are left uncompressed.

    Raises:
        ValueError: if *quant* is not one of the forms above.
    """
    if quant == "fp16":
        return mlmodel_fp16
    if quant == "int8":
        cfg = cto.OptimizationConfig(
            global_config=cto.OpLinearQuantizerConfig(
                mode="linear_symmetric", dtype="int8", weight_threshold=512))
        return cto.linear_quantize_weights(mlmodel_fp16, cfg)
    if quant.startswith("palette"):
        nbits = int(quant.removeprefix("palette"))
        cfg = cto.OptimizationConfig(
            global_config=cto.OpPalettizerConfig(
                nbits=nbits, mode="kmeans", weight_threshold=512))
        return cto.palettize_weights(mlmodel_fp16, cfg)
    raise ValueError(quant)


def main():
    """Build every requested (model, vocab, quant) Core ML package.

    For each SaT model and vocab variant the HF checkpoint is loaded,
    optionally pruned to EN+ZH (and sanity-checked), traced at a fixed
    ``--seq-len``, converted to an fp16 ML Program, then saved once per
    ``--quants`` entry under ``<out>/<model>-<vocab>/``. Sizes are printed and
    written to ``<out>/sizes.json``; a quantization that fails is recorded as
    ``FAIL: ...`` rather than aborting the run.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--out", default="ios_models")
    ap.add_argument("--models", nargs="+", default=["sat-1l-sm", "sat-3l-sm"])
    # Only these variants are implemented; anything else would silently build
    # an unpruned model under a misleading directory name.
    ap.add_argument("--vocabs", nargs="+", default=["full", "en_zh"],
                    choices=["full", "en_zh"])
    ap.add_argument("--quants", nargs="+",
                    default=["fp16", "int8", "palette6", "palette4"])
    ap.add_argument("--seq-len", type=int, default=256)
    args = ap.parse_args()

    # Imported here rather than at module level, so importing this file does
    # not require coremltools.
    import coremltools as ct
    import coremltools.optimize.coreml as cto

    out_root = Path(args.out)
    out_root.mkdir(parents=True, exist_ok=True)
    results = []

    # The same XLM-R tokenizer is used for every model; load it once.
    tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")

    for short in args.models:
        repo = f"segment-any-text/{short}"

        for vocab in args.vocabs:
            print(f"\n=== {short} / vocab={vocab} ===", flush=True)
            # Reload per vocab variant: pruning mutates the model in place.
            model = AutoModelForTokenClassification.from_pretrained(repo).eval()

            variant_dir = out_root / f"{short}-{vocab}"
            variant_dir.mkdir(parents=True, exist_ok=True)

            remap = _prune_en_zh(model, tokenizer, variant_dir) if vocab == "en_zh" else None

            vocab_size = model.config.vocab_size
            n_params = sum(p.numel() for p in model.parameters())
            print(f" vocab_size={vocab_size} params={n_params/1e6:.1f}M", flush=True)

            # Trace on random in-vocab ids; position 0 holds <s>, expressed in
            # this variant's id space.
            wrapper = LogitsWrapper(model).eval()
            ids = torch.randint(0, vocab_size, (1, args.seq_len), dtype=torch.long)
            bos_id = model.config.bos_token_id
            ids[:, 0] = bos_id if remap is None else int(remap[bos_id])
            mask = torch.ones((1, args.seq_len), dtype=torch.long)
            with torch.no_grad():
                traced = torch.jit.trace(wrapper, (ids, mask))

            mlmodel_fp16 = ct.convert(
                traced,
                inputs=[
                    ct.TensorType(name="input_ids", shape=(1, args.seq_len), dtype=np.int32),
                    ct.TensorType(name="attention_mask", shape=(1, args.seq_len), dtype=np.int32),
                ],
                outputs=[ct.TensorType(name="logits")],
                minimum_deployment_target=ct.target.iOS16,
                compute_precision=ct.precision.FLOAT16,
                convert_to="mlprogram",
            )

            for quant in args.quants:
                tag = f"{short}-{vocab}-{quant}"
                pkg = variant_dir / f"{tag}.mlpackage"
                # Best effort per quantization: record the failure and move on.
                try:
                    mlm = _quantize(cto, mlmodel_fp16, quant)
                    if pkg.exists():
                        shutil.rmtree(pkg)
                    mlm.save(str(pkg))
                    size = dir_size_mb(pkg)
                except Exception as e:
                    results.append((tag, vocab_size, f"FAIL: {e}"))
                    print(f" [{quant:9s}] FAILED: {e}", flush=True)
                    continue
                results.append((tag, vocab_size, round(size, 1)))
                print(f" [{quant:9s}] {size:7.1f} MB -> {pkg.name}", flush=True)

    print("\n================ SUMMARY ================")
    print(f"{'variant':32s}{'vocab':>8s}{'size MB':>10s}")
    for tag, v, s in results:
        print(f"{tag:32s}{v:8d}{str(s):>10s}")
    sizes_path = out_root / "sizes.json"
    with open(sizes_path, "w") as f:
        json.dump([{"variant": t, "vocab": v, "size_mb": s} for t, v, s in results],
                  f, indent=2)
    print(f"\nWrote {sizes_path}")
|
|
|
|
if __name__ == "__main__":
    # Run the build only when executed as a script, not on import.
    main()
|
|