wtpsplit-kit / scripts /build_ios_coreml.py
krmanik's picture
Upload folder using huggingface_hub
357ae2c verified
#!/usr/bin/env python3
"""
Build iOS Core ML variants of SaT models.
Produces, for each base model and vocab variant, four quantization levels:
fp16, int8 (linear), 6-bit palettized, 4-bit palettized.
Vocab pruning (EN+ZH): keeps only XLM-R tokens whose surface form is pure-ASCII
(covers English) or contains CJK characters (covers Chinese), plus all special
tokens. The word-embedding matrix is sliced to the kept rows and an
old->new id remap table is emitted so the app can feed remapped ids.
On Linux you can CONVERT and measure .mlpackage sizes, but not run them
(no Core ML runtime). This script does not itself measure the accuracy of the
pruned torch model (the remap_ids helper is available for feeding it remapped
ids), and quantization fidelity must be checked on a Mac.
Usage:
python scripts/build_ios_coreml.py --out ios_models \
--models sat-1l-sm sat-3l-sm --seq-len 256
"""
import argparse
import json
import shutil
import unicodedata
from pathlib import Path
import numpy as np
import torch
import wtpsplit.models # noqa: F401 (registers SubwordXLM* with Auto classes)
from transformers import AutoModelForTokenClassification, AutoTokenizer
# Inclusive Unicode code-point ranges that is_cjk() treats as "CJK".
CJK_RANGES = [
    (0x4E00, 0x9FFF),  # CJK Unified Ideographs
    (0x3400, 0x4DBF),  # CJK Unified Ideographs Extension A
    (0xF900, 0xFAFF),  # CJK Compatibility Ideographs
    (0x3000, 0x303F),  # CJK Symbols and Punctuation
    (0xFF00, 0xFFEF),  # Halfwidth and Fullwidth Forms
]


def is_cjk(s: str) -> bool:
    """Return True if at least one character of ``s`` falls in a CJK_RANGES span.

    The empty string yields False.
    """
    for ch in s:
        code_point = ord(ch)
        for low, high in CJK_RANGES:
            if low <= code_point <= high:
                return True
    return False
def compute_keep_ids(tokenizer) -> list[int]:
    """Return the sorted token ids retained for the English + Chinese vocabulary.

    A token is kept when its surface form (with the SentencePiece "▁" marker
    read as a space) is entirely ASCII, or contains at least one CJK character
    according to ``is_cjk``. Every id in ``tokenizer.all_special_ids`` is kept
    unconditionally.
    """

    def wanted(token: str) -> bool:
        surface = token.replace("▁", " ")
        # ASCII check first: is_cjk is only consulted for non-ASCII tokens.
        return surface.isascii() or is_cjk(surface)

    selected = {
        token_id
        for token, token_id in tokenizer.get_vocab().items()
        if wanted(token)
    }
    return sorted(selected | set(tokenizer.all_special_ids))
def prune_embedding(model, keep_ids: list[int]):
    """Shrink the model's word-embedding table to the rows listed in ``keep_ids``.

    The module at ``model.roberta.embeddings.word_embeddings`` is replaced by a
    new ``torch.nn.Embedding`` (created with ``padding_idx=None``) whose row
    ``i`` is a copy of the old row ``keep_ids[i]``, and
    ``model.config.vocab_size`` is set to ``len(keep_ids)``. The model is
    modified in place. ``keep_ids`` is expected to hold distinct ids.

    Returns:
        An int64 NumPy array with one entry per original embedding row:
        ``remap[old_id]`` is the new id, or -1 if that row was dropped.
    """
    embeddings = model.roberta.embeddings
    original = embeddings.word_embeddings
    n_old = original.weight.shape[0]
    n_new = len(keep_ids)

    # Gather the surviving rows; clone so the new table owns its storage.
    row_index = torch.tensor(keep_ids, dtype=torch.long)
    kept_rows = original.weight.data[row_index].clone()

    replacement = torch.nn.Embedding(n_new, kept_rows.shape[1], padding_idx=None)
    replacement.weight.data = kept_rows
    embeddings.word_embeddings = replacement
    model.config.vocab_size = n_new

    # Old-id -> new-id lookup table; -1 marks rows that were removed.
    remap = np.full(n_old, -1, dtype=np.int64)
    remap[np.asarray(keep_ids, dtype=np.int64)] = np.arange(n_new, dtype=np.int64)
    return remap
class LogitsWrapper(torch.nn.Module):
    """Adapter exposing ``(input_ids, attention_mask) -> logits`` for tracing.

    The wrapped model is called with keyword arguments and
    ``return_dict=True``; only the ``logits`` attribute of its output is
    returned.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask):
        """Run the wrapped model and return the ``logits`` of its output."""
        return self.model(
            return_dict=True,
            attention_mask=attention_mask,
            input_ids=input_ids,
        ).logits
def remap_ids(input_ids: torch.Tensor, remap: np.ndarray, unk_new_id: int):
    """Translate original token ids into pruned-vocabulary ids.

    Each id is looked up in ``remap``; ids whose entry is -1 (dropped tokens)
    are replaced by ``unk_new_id``. The result is a new ``torch.long`` tensor;
    ``remap`` itself is left untouched.
    """
    old_ids = input_ids.cpu().numpy()
    # Array indexing yields a fresh array, so the in-place edit below stays local.
    new_ids = remap[old_ids]
    np.putmask(new_ids, new_ids == -1, unk_new_id)
    return torch.tensor(new_ids, dtype=torch.long)
def dir_size_mb(path: Path) -> float:
    """Return the combined size of every file under ``path`` in MB (1 MB = 1e6 bytes).

    The directory is walked recursively; entries that are not files are ignored.
    """
    n_bytes = 0
    for entry in path.rglob("*"):
        if entry.is_file():
            n_bytes += entry.stat().st_size
    return n_bytes / 1e6
def main():
    """Convert each requested SaT model / vocab variant to Core ML packages.

    Command-line options:
        --out      output root directory (default ``ios_models``)
        --models   model short names, loaded from ``segment-any-text/<name>``
        --vocabs   vocabulary variants; ``en_zh`` prunes the embedding table,
                   any other value leaves the model's vocabulary untouched
        --quants   ``fp16``, ``int8`` or ``palette<N>`` (N-bit k-means palette)
        --seq-len  fixed sequence length used for tracing and conversion

    For every model/vocab pair a directory ``<out>/<model>-<vocab>/`` is
    created holding one ``<model>-<vocab>-<quant>.mlpackage`` per quantization
    level. The ``en_zh`` variant additionally gets ``old_to_new_ids.npy`` and
    ``prune_info.json``. A size summary is printed and written to
    ``<out>/sizes.json``; a quantization that fails is recorded there as a
    ``"FAIL: ..."`` entry instead of aborting the run.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--out", default="ios_models")
    ap.add_argument("--models", nargs="+", default=["sat-1l-sm", "sat-3l-sm"])
    ap.add_argument("--vocabs", nargs="+", default=["full", "en_zh"])
    ap.add_argument("--quants", nargs="+",
                    default=["fp16", "int8", "palette6", "palette4"])
    ap.add_argument("--seq-len", type=int, default=256)
    args = ap.parse_args()

    # Imported here rather than at module level, so merely importing this file
    # (or asking for --help) does not require coremltools.
    import coremltools as ct
    from coremltools.optimize.coreml import (
        OpPalettizerConfig, OpLinearQuantizerConfig, OptimizationConfig,
        palettize_weights, linear_quantize_weights,
    )

    out_root = Path(args.out)
    out_root.mkdir(parents=True, exist_ok=True)
    results = []  # (tag, vocab_size, size in MB or "FAIL: ..." message)

    # The same tokenizer is used for every model, so load it once up front
    # instead of once per model.
    tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")

    for short in args.models:
        repo = f"segment-any-text/{short}"
        for vocab in args.vocabs:
            print(f"\n=== {short} / vocab={vocab} ===", flush=True)
            # Reload per variant: prune_embedding modifies the model in place.
            model = AutoModelForTokenClassification.from_pretrained(repo).eval()
            variant_dir = out_root / f"{short}-{vocab}"
            variant_dir.mkdir(parents=True, exist_ok=True)

            if vocab == "en_zh":
                keep_ids = compute_keep_ids(tokenizer)
                remap = prune_embedding(model, keep_ids)
                unk_new_id = int(remap[tokenizer.unk_token_id])
                np.save(variant_dir / "old_to_new_ids.npy", remap)
                prune_info = {
                    "kept_vocab_size": len(keep_ids),
                    "orig_vocab_size": int(len(remap)),
                    "unk_new_id": unk_new_id,
                    "pad_new_id": int(remap[tokenizer.pad_token_id]),
                }
                # Context manager guarantees the file is flushed and closed.
                with open(variant_dir / "prune_info.json", "w",
                          encoding="utf-8") as fh:
                    json.dump(prune_info, fh, indent=2)
                print(f" kept {len(keep_ids)} / {len(remap)} tokens", flush=True)
            else:
                remap, unk_new_id = None, None

            vocab_size = model.config.vocab_size
            n_params = sum(p.numel() for p in model.parameters())
            print(f" vocab_size={vocab_size} params={n_params/1e6:.1f}M", flush=True)

            # --- trace ---
            wrapper = LogitsWrapper(model).eval()
            # Random in-vocabulary ids serve as the example input for tracing.
            ids = torch.randint(0, vocab_size, (1, args.seq_len), dtype=torch.long)
            # NOTE(review): the pruned variant hard-codes 0 for the first
            # position, presumably the remapped BOS id -- confirm.
            ids[:, 0] = model.config.bos_token_id if vocab == "full" else 0
            mask = torch.ones((1, args.seq_len), dtype=torch.long)
            with torch.no_grad():
                traced = torch.jit.trace(wrapper, (ids, mask))

            # --- convert to fp16 mlmodel (base for quantization) ---
            mlmodel_fp16 = ct.convert(
                traced,
                inputs=[
                    ct.TensorType(name="input_ids", shape=(1, args.seq_len), dtype=np.int32),
                    ct.TensorType(name="attention_mask", shape=(1, args.seq_len), dtype=np.int32),
                ],
                outputs=[ct.TensorType(name="logits")],
                minimum_deployment_target=ct.target.iOS16,
                compute_precision=ct.precision.FLOAT16,
                convert_to="mlprogram",
            )

            for quant in args.quants:
                tag = f"{short}-{vocab}-{quant}"
                try:
                    if quant == "fp16":
                        mlm = mlmodel_fp16
                    elif quant == "int8":
                        cfg = OptimizationConfig(
                            global_config=OpLinearQuantizerConfig(
                                mode="linear_symmetric", dtype="int8", weight_threshold=512))
                        mlm = linear_quantize_weights(mlmodel_fp16, cfg)
                    elif quant.startswith("palette"):
                        nbits = int(quant.replace("palette", ""))
                        cfg = OptimizationConfig(
                            global_config=OpPalettizerConfig(
                                nbits=nbits, mode="kmeans", weight_threshold=512))
                        mlm = palettize_weights(mlmodel_fp16, cfg)
                    else:
                        raise ValueError(quant)

                    pkg = variant_dir / f"{tag}.mlpackage"
                    # Remove any package left over from a previous run first.
                    if pkg.exists():
                        shutil.rmtree(pkg)
                    mlm.save(str(pkg))
                    size = dir_size_mb(pkg)
                    results.append((tag, vocab_size, round(size, 1)))
                    print(f" [{quant:9s}] {size:7.1f} MB -> {pkg.name}", flush=True)
                except Exception as e:
                    # Deliberate best-effort: record the failure and carry on
                    # with the remaining quantization levels and variants.
                    results.append((tag, vocab_size, f"FAIL: {e}"))
                    print(f" [{quant:9s}] FAILED: {e}", flush=True)

    print("\n================ SUMMARY ================")
    print(f"{'variant':32s}{'vocab':>8s}{'size MB':>10s}")
    for tag, v, s in results:
        print(f"{tag:32s}{v:8d}{str(s):>10s}")

    with open(out_root / "sizes.json", "w", encoding="utf-8") as fh:
        json.dump([{"variant": t, "vocab": v, "size_mb": s} for t, v, s in results],
                  fh, indent=2)
    print(f"\nWrote {out_root/'sizes.json'}")


if __name__ == "__main__":
    main()