#!/usr/bin/env python3
r"""
Build iOS Core ML variants of SaT models.

For each base model and vocab variant, produces one .mlpackage per requested
quantization level (default: fp16, int8 linear, 6-bit palettized, 4-bit
palettized).

Vocab pruning (EN+ZH): keeps only XLM-R tokens whose surface form is pure
ASCII (covers English) or contains a CJK character (covers Chinese), plus all
special tokens. The word-embedding matrix is sliced to the kept rows and an
old->new id remap table is written so the app can feed remapped ids.

On Linux you can CONVERT and measure .mlpackage sizes, but not run them (there
is no Core ML runtime). Parity of the pruned torch model against the unpruned
one is checked here on a few sample sentences; quantization fidelity must be
checked on a Mac.

Usage:
    python scripts/build_ios_coreml.py --out ios_models \
        --models sat-1l-sm sat-3l-sm --seq-len 256
"""
import argparse
import json
import shutil
import unicodedata
from pathlib import Path

import numpy as np
import torch

import wtpsplit.models  # noqa: F401 (registers SubwordXLM* with Auto classes)
from transformers import AutoModelForTokenClassification, AutoTokenizer

# Inclusive Unicode code-point ranges counted as "CJK" for vocab pruning.
CJK_RANGES = [
    (0x4E00, 0x9FFF),  # CJK Unified Ideographs
    (0x3400, 0x4DBF),  # CJK Unified Ideographs Extension A
    (0xF900, 0xFAFF),  # CJK Compatibility Ideographs
    (0x3000, 0x303F),  # CJK Symbols and Punctuation
    (0xFF00, 0xFFEF),  # Halfwidth and Fullwidth Forms
]

# Tokenizer shared by every SaT model built here.
TOKENIZER_NAME = "xlm-roberta-base"

# Short EN/ZH samples used to check that pruning leaves logits unchanged.
VALIDATION_TEXTS = [
    "This is a test. It has two sentences and no newline",
    "这是一个测试。它有两个句子，没有换行",
]


def is_cjk(s: str) -> bool:
    """Return True if any character of ``s`` falls in one of ``CJK_RANGES``."""
    return any(any(a <= ord(c) <= b for a, b in CJK_RANGES) for c in s)


def compute_keep_ids(tokenizer) -> list[int]:
    """Return the sorted token ids to keep for English + Chinese.

    A token is kept if its surface form (SentencePiece "▁" read as a space) is
    pure ASCII or contains a CJK character; every special token is kept too.
    """
    vocab = tokenizer.get_vocab()  # token -> id
    keep = set(tokenizer.all_special_ids)
    for tok, idx in vocab.items():
        s = tok.replace("▁", " ")  # SP underscore -> space
        if all(ord(c) < 128 for c in s) or is_cjk(s):
            keep.add(idx)
    return sorted(keep)


def prune_embedding(model, keep_ids: list[int]) -> np.ndarray:
    """Slice the input word-embedding matrix down to ``keep_ids``, in place.

    Only ``model.roberta.embeddings.word_embeddings`` and
    ``model.config.vocab_size`` are modified; nothing else (e.g. the
    classification head) is touched.

    Args:
        model: model exposing ``roberta.embeddings.word_embeddings``.
        keep_ids: sorted, unique old token ids to keep; row ``i`` of the new
            matrix is old row ``keep_ids[i]``.

    Returns:
        int64 array ``remap`` with ``remap[old_id] = new_id``, or -1 if the
        old id was dropped.
    """
    old_emb = model.roberta.embeddings.word_embeddings
    keep = torch.tensor(keep_ids, dtype=torch.long)
    new_weight = old_emb.weight.data[keep].clone()

    remap = np.full(old_emb.weight.shape[0], -1, dtype=np.int64)
    remap[np.asarray(keep_ids, dtype=np.int64)] = np.arange(len(keep_ids), dtype=np.int64)

    # Carry the padding row over under its new id (None if it was dropped).
    old_pad = old_emb.padding_idx
    new_pad = None
    if old_pad is not None and remap[old_pad] >= 0:
        new_pad = int(remap[old_pad])

    model.roberta.embeddings.word_embeddings = torch.nn.Embedding.from_pretrained(
        new_weight, freeze=False, padding_idx=new_pad)
    model.config.vocab_size = len(keep_ids)
    return remap  # remap[old_id] = new_id, or -1 if dropped


class LogitsWrapper(torch.nn.Module):
    """Expose a clean (input_ids, attention_mask) -> logits signature for tracing."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask):
        out = self.model(input_ids=input_ids, attention_mask=attention_mask,
                         return_dict=True)
        return out.logits


def remap_ids(input_ids: torch.Tensor, remap: np.ndarray, unk_new_id: int):
    """Map original token ids to pruned ids; dropped ids become ``unk_new_id``."""
    arr = input_ids.cpu().numpy()
    out = remap[arr]
    out[out == -1] = unk_new_id
    return torch.tensor(out, dtype=torch.long)


def dir_size_mb(path: Path) -> float:
    """Total size of all regular files under ``path``, in megabytes (1e6 bytes)."""
    total = sum(f.stat().st_size for f in path.rglob("*") if f.is_file())
    return total / 1e6


def _write_json(path: Path, obj) -> None:
    """Write ``obj`` to ``path`` as indented JSON, closing the file promptly."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2)


@torch.no_grad()
def _logits(model, input_ids, attention_mask) -> torch.Tensor:
    """Run ``model`` through LogitsWrapper without tracking gradients."""
    return LogitsWrapper(model).eval()(input_ids, attention_mask)


def _prune_en_zh(model, tokenizer, variant_dir: Path) -> np.ndarray:
    """Prune ``model`` to the EN+ZH vocab, write remap artifacts, check parity.

    Writes ``old_to_new_ids.npy`` and ``prune_info.json`` into ``variant_dir``.
    Parity is measured as the max absolute logit difference between the
    unpruned model on original ids and the pruned model on remapped ids, over
    ``VALIDATION_TEXTS``.

    Returns:
        The old->new remap array from ``prune_embedding``.
    """
    enc = tokenizer(VALIDATION_TEXTS, padding=True, return_tensors="pt")
    ref_logits = _logits(model, enc["input_ids"], enc["attention_mask"])

    keep_ids = compute_keep_ids(tokenizer)
    remap = prune_embedding(model, keep_ids)
    unk_new_id = int(remap[tokenizer.unk_token_id])

    new_ids = remap_ids(enc["input_ids"], remap, unk_new_id)
    got_logits = _logits(model, new_ids, enc["attention_mask"])
    max_diff = float((ref_logits - got_logits).abs().max().item())
    n_dropped = int((remap[enc["input_ids"].numpy()] == -1).sum())

    np.save(variant_dir / "old_to_new_ids.npy", remap)
    _write_json(variant_dir / "prune_info.json", {
        "kept_vocab_size": len(keep_ids),
        "orig_vocab_size": int(len(remap)),
        "unk_new_id": unk_new_id,
        "pad_new_id": int(remap[tokenizer.pad_token_id]),
        "parity_max_abs_logit_diff": max_diff,
        "parity_dropped_tokens": n_dropped,
    })
    print(f"  kept {len(keep_ids)} / {len(remap)} tokens", flush=True)
    print(f"  parity: max |logit diff| = {max_diff:.3g} "
          f"({n_dropped} sample tokens dropped)", flush=True)
    return remap


def _trace(model, seq_len: int, remap):
    """TorchScript-trace ``model`` behind LogitsWrapper on a dummy (1, seq_len) batch."""
    wrapper = LogitsWrapper(model).eval()
    ids = torch.randint(0, model.config.vocab_size, (1, seq_len), dtype=torch.long)
    bos = model.config.bos_token_id
    if bos is not None:
        # Start with BOS, expressed in the pruned id space when pruning was applied.
        ids[:, 0] = bos if remap is None else int(remap[bos])
    mask = torch.ones((1, seq_len), dtype=torch.long)
    with torch.no_grad():
        return torch.jit.trace(wrapper, (ids, mask))


def _convert_fp16(traced, seq_len: int):
    """Convert a traced model to an fp16 ML Program (base for quantization)."""
    import coremltools as ct

    return ct.convert(
        traced,
        inputs=[
            ct.TensorType(name="input_ids", shape=(1, seq_len), dtype=np.int32),
            ct.TensorType(name="attention_mask", shape=(1, seq_len), dtype=np.int32),
        ],
        outputs=[ct.TensorType(name="logits")],
        minimum_deployment_target=ct.target.iOS16,
        compute_precision=ct.precision.FLOAT16,
        convert_to="mlprogram",
    )


def _quantize(mlmodel_fp16, quant: str):
    """Return ``mlmodel_fp16`` compressed per ``quant`` (fp16 | int8 | paletteN).

    Raises:
        ValueError: if ``quant`` is not a recognized level.
    """
    from coremltools.optimize.coreml import (
        OpLinearQuantizerConfig,
        OpPalettizerConfig,
        OptimizationConfig,
        linear_quantize_weights,
        palettize_weights,
    )

    if quant == "fp16":
        return mlmodel_fp16
    if quant == "int8":
        cfg = OptimizationConfig(global_config=OpLinearQuantizerConfig(
            mode="linear_symmetric", dtype="int8", weight_threshold=512))
        return linear_quantize_weights(mlmodel_fp16, cfg)
    if quant.startswith("palette"):
        nbits = int(quant.replace("palette", ""))
        cfg = OptimizationConfig(global_config=OpPalettizerConfig(
            nbits=nbits, mode="kmeans", weight_threshold=512))
        return palettize_weights(mlmodel_fp16, cfg)
    raise ValueError(quant)


def _print_summary(results, out_root: Path) -> None:
    """Print the size table and write it to ``out_root/sizes.json``."""
    print("\n================ SUMMARY ================")
    print(f"{'variant':32s}{'vocab':>8s}{'size MB':>10s}")
    for tag, v, s in results:
        print(f"{tag:32s}{v:8d}{str(s):>10s}")
    sizes_path = out_root / "sizes.json"
    _write_json(sizes_path,
                [{"variant": t, "vocab": v, "size_mb": s} for t, v, s in results])
    print(f"\nWrote {sizes_path}")


def main():
    ap = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    ap.add_argument("--out", default="ios_models",
                    help="output directory (created if missing)")
    ap.add_argument("--models", nargs="+", default=["sat-1l-sm", "sat-3l-sm"],
                    help="model names under the segment-any-text/ repo prefix")
    ap.add_argument("--vocabs", nargs="+", default=["full", "en_zh"],
                    choices=["full", "en_zh"], help="vocab variants to build")
    ap.add_argument("--quants", nargs="+",
                    default=["fp16", "int8", "palette6", "palette4"],
                    help="quantization levels: fp16, int8, or paletteN")
    ap.add_argument("--seq-len", type=int, default=256,
                    help="fixed input sequence length baked into the model")
    args = ap.parse_args()

    # Import early so a missing coremltools fails before any model is loaded.
    import coremltools  # noqa: F401

    out_root = Path(args.out)
    out_root.mkdir(parents=True, exist_ok=True)
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)  # same for every model
    results = []

    for short in args.models:
        repo = f"segment-any-text/{short}"
        for vocab in args.vocabs:
            print(f"\n=== {short} / vocab={vocab} ===", flush=True)
            # Reload per variant: pruning modifies the model in place.
            model = AutoModelForTokenClassification.from_pretrained(repo).eval()
            variant_dir = out_root / f"{short}-{vocab}"
            variant_dir.mkdir(parents=True, exist_ok=True)

            remap = _prune_en_zh(model, tokenizer, variant_dir) if vocab == "en_zh" else None

            vocab_size = model.config.vocab_size
            n_params = sum(p.numel() for p in model.parameters())
            print(f"  vocab_size={vocab_size}  params={n_params/1e6:.1f}M", flush=True)

            traced = _trace(model, args.seq_len, remap)
            mlmodel_fp16 = _convert_fp16(traced, args.seq_len)

            for quant in args.quants:
                tag = f"{short}-{vocab}-{quant}"
                try:
                    mlm = _quantize(mlmodel_fp16, quant)
                    pkg = variant_dir / f"{tag}.mlpackage"
                    if pkg.exists():
                        shutil.rmtree(pkg)
                    mlm.save(str(pkg))
                    size = dir_size_mb(pkg)
                    results.append((tag, vocab_size, round(size, 1)))
                    print(f"  [{quant:9s}] {size:7.1f} MB -> {pkg.name}", flush=True)
                except Exception as e:
                    # Best-effort: one failing level is recorded, the rest still build.
                    results.append((tag, vocab_size, f"FAIL: {e}"))
                    print(f"  [{quant:9s}] FAILED: {e}", flush=True)

    _print_summary(results, out_root)


if __name__ == "__main__":
    main()