#!/usr/bin/env python3
"""
Build iOS Core ML variants of SaT models.
Produces, for each base model and vocab variant, four quantization levels:
fp16, int8 (linear), 6-bit palettized, 4-bit palettized.
Vocab pruning (EN+ZH): keeps only XLM-R tokens whose surface form is pure-ASCII
(covers English) or contains CJK characters (covers Chinese), plus all special
tokens. The word-embedding matrix is sliced to the kept rows and an
old->new id remap table is emitted so the app can feed remapped ids.
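On Linux you can CONVERT and measure .mlpackage sizes, but not run them
(no Core ML runtime). This script does not itself compare the pruned torch
model against the original; the `remap_ids` helper can be used to feed
remapped ids to the pruned model for such a check.
Quantization fidelity must be checked on a Mac.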
Usage:
python scripts/build_ios_coreml.py --out ios_models \
--models sat-1l-sm sat-3l-sm --seq-len 256
"""
import argparse
import json
import shutil
import unicodedata
from pathlib import Path
import numpy as np
import torch
import wtpsplit.models # noqa: F401 (registers SubwordXLM* with Auto classes)
from transformers import AutoModelForTokenClassification, AutoTokenizer
# Inclusive Unicode code-point ranges that count as "CJK" for vocab pruning.
CJK_RANGES = [
    (0x4E00, 0x9FFF),  # CJK Unified Ideographs
    (0x3400, 0x4DBF),  # CJK Unified Ideographs Extension A
    (0xF900, 0xFAFF),  # CJK Compatibility Ideographs
    (0x3000, 0x303F),  # CJK Symbols and Punctuation
    (0xFF00, 0xFFEF),  # Halfwidth and Fullwidth Forms
]


def is_cjk(s: str) -> bool:
    """Return True if at least one character of *s* falls in a CJK_RANGES span.

    The empty string has no characters and therefore yields False.
    """
    for ch in s:
        code_point = ord(ch)
        for low, high in CJK_RANGES:
            if low <= code_point <= high:
                return True
    return False
def compute_keep_ids(tokenizer) -> list[int]:
    """Select the token ids retained by the English + Chinese pruned vocab.

    A token is retained when its surface form (with the "▁" marker replaced
    by a space) is entirely ASCII or contains at least one CJK character as
    judged by ``is_cjk``. Every id in ``tokenizer.all_special_ids`` is always
    retained.

    Returns:
        The retained ids, de-duplicated and sorted ascending.
    """

    def _is_wanted(token: str) -> bool:
        surface = token.replace("▁", " ")  # SP underscore -> space
        return surface.isascii() or is_cjk(surface)

    token_to_id = tokenizer.get_vocab()  # token -> id
    selected = set(tokenizer.all_special_ids)
    selected.update(
        token_id for token, token_id in token_to_id.items() if _is_wanted(token)
    )
    return sorted(selected)
def prune_embedding(model, keep_ids: list[int]):
    """Shrink the word-embedding table of *model* to the rows in *keep_ids*.

    Mutates *model* in place: ``model.roberta.embeddings.word_embeddings`` is
    replaced by a new ``torch.nn.Embedding`` (created with ``padding_idx=None``)
    holding copies of the kept rows in the order given, and
    ``model.config.vocab_size`` is set to ``len(keep_ids)``.

    Returns:
        An int64 NumPy array with one entry per ORIGINAL token id, where
        ``remap[old_id]`` is the new id, or -1 if that token was dropped.
    """
    embeddings = model.roberta.embeddings
    original = embeddings.word_embeddings
    kept_count = len(keep_ids)

    # Copy the surviving rows out of the original table.
    row_index = torch.tensor(keep_ids, dtype=torch.long)
    kept_rows = original.weight.data[row_index].clone()

    pruned = torch.nn.Embedding(kept_count, kept_rows.shape[1], padding_idx=None)
    pruned.weight.data = kept_rows
    embeddings.word_embeddings = pruned
    model.config.vocab_size = kept_count

    # Position i in keep_ids becomes new id i; everything else stays -1.
    remap = np.full(original.weight.shape[0], -1, dtype=np.int64)
    remap[np.asarray(keep_ids, dtype=np.int64)] = np.arange(kept_count, dtype=np.int64)
    return remap
class LogitsWrapper(torch.nn.Module):
    """Adapter giving the wrapped model a flat, trace-friendly signature.

    ``forward`` accepts ``input_ids`` and ``attention_mask`` positionally,
    forwards them to the wrapped model as keyword arguments with
    ``return_dict=True``, and returns only the ``logits`` attribute of the
    result.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask):
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        ).logits
def remap_ids(input_ids: torch.Tensor, remap: np.ndarray, unk_new_id: int):
    """Translate original token ids into pruned-vocab ids.

    Each id in *input_ids* is looked up in *remap* (indexed by original id).
    Entries that map to -1 (dropped tokens) are replaced by *unk_new_id*.
    *remap* itself is not modified.

    Returns:
        A new ``torch.long`` tensor with the same shape as *input_ids*.
    """
    original_ids = input_ids.cpu().numpy()
    looked_up = remap[original_ids]
    translated = np.where(looked_up == -1, unk_new_id, looked_up)
    return torch.tensor(translated, dtype=torch.long)
def dir_size_mb(path: Path) -> float:
    """Return the combined size of all files under *path*, in megabytes.

    The directory is walked recursively; only entries for which
    ``is_file()`` is true are counted. One megabyte here is 10**6 bytes.
    """
    byte_count = 0
    for entry in path.rglob("*"):
        if entry.is_file():
            byte_count += entry.stat().st_size
    return byte_count / 1e6
def main():
    """Convert SaT checkpoints to Core ML packages and report their sizes.

    Command-line options:
        --out      Output root directory (default ``ios_models``).
        --models   Short model names; each is loaded from
                   ``segment-any-text/<name>``.
        --vocabs   Vocab variants. ``en_zh`` prunes the embedding table via
                   ``compute_keep_ids`` / ``prune_embedding``; any other value
                   leaves the vocabulary untouched.
        --quants   Any of ``fp16``, ``int8``, ``palette<N>`` (N-bit k-means).
        --seq-len  Fixed sequence length used for tracing and conversion.

    Files written under ``<out>/<model>-<vocab>/``:
        ``<model>-<vocab>-<quant>.mlpackage`` for every quant that succeeds,
        plus ``old_to_new_ids.npy`` and ``prune_info.json`` for ``en_zh``.
    A summary table is printed and also saved to ``<out>/sizes.json``.
    A failure in one quantization is recorded in the summary and does not
    stop the remaining conversions.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--out", default="ios_models")
    ap.add_argument("--models", nargs="+", default=["sat-1l-sm", "sat-3l-sm"])
    ap.add_argument("--vocabs", nargs="+", default=["full", "en_zh"])
    ap.add_argument("--quants", nargs="+",
                    default=["fp16", "int8", "palette6", "palette4"])
    ap.add_argument("--seq-len", type=int, default=256)
    args = ap.parse_args()

    # Imported inside main(), so merely importing this module does not
    # require coremltools to be installed.
    import coremltools as ct
    from coremltools.optimize.coreml import (
        OpPalettizerConfig, OpLinearQuantizerConfig, OptimizationConfig,
        palettize_weights, linear_quantize_weights,
    )

    out_root = Path(args.out)
    out_root.mkdir(parents=True, exist_ok=True)
    results = []

    # The tokenizer (and therefore the kept-id list) is the same for every
    # model, so load it once instead of once per model.
    tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
    keep_ids = None  # computed on the first "en_zh" variant, then reused

    for short in args.models:
        repo = f"segment-any-text/{short}"
        for vocab in args.vocabs:
            print(f"\n=== {short} / vocab={vocab} ===", flush=True)
            # Reload per variant: pruning mutates the model in place.
            model = AutoModelForTokenClassification.from_pretrained(repo).eval()
            variant_dir = out_root / f"{short}-{vocab}"
            variant_dir.mkdir(parents=True, exist_ok=True)

            if vocab == "en_zh":
                if keep_ids is None:
                    keep_ids = compute_keep_ids(tokenizer)
                remap = prune_embedding(model, keep_ids)
                unk_new_id = int(remap[tokenizer.unk_token_id])
                np.save(variant_dir / "old_to_new_ids.npy", remap)
                prune_info = {
                    "kept_vocab_size": len(keep_ids),
                    "orig_vocab_size": int(len(remap)),
                    "unk_new_id": unk_new_id,
                    "pad_new_id": int(remap[tokenizer.pad_token_id]),
                }
                # Context manager so the handle is closed deterministically.
                with open(variant_dir / "prune_info.json", "w", encoding="utf-8") as fh:
                    json.dump(prune_info, fh, indent=2)
                print(f" kept {len(keep_ids)} / {len(remap)} tokens", flush=True)

            vocab_size = model.config.vocab_size
            n_params = sum(p.numel() for p in model.parameters())
            print(f" vocab_size={vocab_size} params={n_params/1e6:.1f}M", flush=True)

            # --- trace ---
            wrapper = LogitsWrapper(model).eval()
            ids = torch.randint(0, vocab_size, (1, args.seq_len), dtype=torch.long)
            ids[:, 0] = model.config.bos_token_id if vocab == "full" else 0
            mask = torch.ones((1, args.seq_len), dtype=torch.long)
            with torch.no_grad():
                traced = torch.jit.trace(wrapper, (ids, mask))

            # --- convert to fp16 mlmodel (base for quantization) ---
            mlmodel_fp16 = ct.convert(
                traced,
                inputs=[
                    ct.TensorType(name="input_ids", shape=(1, args.seq_len), dtype=np.int32),
                    ct.TensorType(name="attention_mask", shape=(1, args.seq_len), dtype=np.int32),
                ],
                outputs=[ct.TensorType(name="logits")],
                minimum_deployment_target=ct.target.iOS16,
                compute_precision=ct.precision.FLOAT16,
                convert_to="mlprogram",
            )

            for quant in args.quants:
                tag = f"{short}-{vocab}-{quant}"
                # Deliberately broad: one failing quantization is reported in
                # the summary and must not abort the remaining variants.
                try:
                    if quant == "fp16":
                        mlm = mlmodel_fp16
                    elif quant == "int8":
                        cfg = OptimizationConfig(
                            global_config=OpLinearQuantizerConfig(
                                mode="linear_symmetric", dtype="int8", weight_threshold=512))
                        mlm = linear_quantize_weights(mlmodel_fp16, cfg)
                    elif quant.startswith("palette"):
                        nbits = int(quant.replace("palette", ""))
                        cfg = OptimizationConfig(
                            global_config=OpPalettizerConfig(
                                nbits=nbits, mode="kmeans", weight_threshold=512))
                        mlm = palettize_weights(mlmodel_fp16, cfg)
                    else:
                        raise ValueError(quant)
                    pkg = variant_dir / f"{tag}.mlpackage"
                    if pkg.exists():
                        # Remove any package left over from a previous run.
                        shutil.rmtree(pkg)
                    mlm.save(str(pkg))
                    size = dir_size_mb(pkg)
                    results.append((tag, vocab_size, round(size, 1)))
                    print(f" [{quant:9s}] {size:7.1f} MB -> {pkg.name}", flush=True)
                except Exception as e:
                    results.append((tag, vocab_size, f"FAIL: {e}"))
                    print(f" [{quant:9s}] FAILED: {e}", flush=True)

    print("\n================ SUMMARY ================")
    print(f"{'variant':32s}{'vocab':>8s}{'size MB':>10s}")
    for tag, v, s in results:
        print(f"{tag:32s}{v:8d}{str(s):>10s}")
    summary = [{"variant": t, "vocab": v, "size_mb": s} for t, v, s in results]
    with open(out_root / "sizes.json", "w", encoding="utf-8") as fh:
        json.dump(summary, fh, indent=2)
    print(f"\nWrote {out_root/'sizes.json'}")


if __name__ == "__main__":
    main()
|