File size: 8,344 Bytes
357ae2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
#!/usr/bin/env python3
"""
Build iOS Core ML variants of SaT models.

Produces, for each base model and vocab variant, one .mlpackage per requested
quantization level (default: fp16, int8 (linear), 6-bit palettized,
4-bit palettized).

Vocab pruning (EN+ZH): keeps only XLM-R tokens whose surface form is pure-ASCII
(covers English) or contains CJK characters (covers Chinese), plus all special
tokens. The word-embedding matrix is sliced to the kept rows and an
old->new id remap table is emitted so the app can feed remapped ids.

On Linux you can CONVERT and measure .mlpackage sizes, but not run them
(no Core ML runtime). This script does not measure accuracy: validate the
pruned torch model separately (remap_ids() maps original token ids into the
pruned vocabulary), and check quantization fidelity on a Mac.

Usage:
  python scripts/build_ios_coreml.py --out ios_models \\
      --models sat-1l-sm sat-3l-sm --seq-len 256
"""
import argparse
import json
import shutil
import unicodedata
from pathlib import Path

import numpy as np
import torch

import wtpsplit.models  # noqa: F401  (registers SubwordXLM* with Auto classes)
from transformers import AutoModelForTokenClassification, AutoTokenizer

# Inclusive (low, high) Unicode code-point ranges counted as "CJK" when pruning.
CJK_RANGES = [
    (0x4E00, 0x9FFF),  # CJK Unified Ideographs
    (0x3400, 0x4DBF),  # CJK Unified Ideographs Extension A
    (0xF900, 0xFAFF),  # CJK Compatibility Ideographs
    (0x3000, 0x303F),  # CJK Symbols and Punctuation
    (0xFF00, 0xFFEF),  # Halfwidth and Fullwidth Forms
]


def is_cjk(s: str) -> bool:
    """Return True when any character of *s* falls inside one of CJK_RANGES.

    An empty string contains no such character and therefore yields False.
    """
    for ch in s:
        code_point = ord(ch)
        for low, high in CJK_RANGES:
            if low <= code_point <= high:
                return True
    return False


def compute_keep_ids(tokenizer) -> list[int]:
    """Return the sorted token ids retained by the EN+ZH vocabulary.

    A vocabulary entry is retained when its surface form (the token with the
    SentencePiece word-boundary marker "▁" read as a space) is entirely ASCII,
    or contains at least one character accepted by ``is_cjk``. Every id in
    ``tokenizer.all_special_ids`` is retained unconditionally.
    """
    vocab = tokenizer.get_vocab()  # token string -> id

    def _wanted(token: str) -> bool:
        surface = token.replace("▁", " ")
        return surface.isascii() or is_cjk(surface)

    retained = set(tokenizer.all_special_ids)
    retained.update(idx for token, idx in vocab.items() if _wanted(token))
    return sorted(retained)


def prune_embedding(model, keep_ids: list[int]) -> np.ndarray:
    """Shrink the word-embedding table to *keep_ids* and return the old->new id map.

    After the call, row ``i`` of ``model.roberta.embeddings.word_embeddings``
    is row ``keep_ids[i]`` of the old table, and ``model.config.vocab_size`` is
    ``len(keep_ids)``. The model is modified in place. No other layer is
    touched; this assumes nothing else (e.g. the classifier) depends on the
    vocabulary size.

    If the old embedding declared a ``padding_idx`` and that token is kept, the
    new embedding declares the remapped index. This keeps the pruned
    embedding consistent with the unpruned one instead of silently dropping
    the padding row marker.

    Args:
        model: model exposing ``roberta.embeddings.word_embeddings`` and ``config``.
        keep_ids: old token ids to keep, in the order of their new ids
            (``compute_keep_ids`` returns them sorted and distinct).

    Returns:
        ``np.int64`` array ``remap`` with one entry per old token id:
        ``remap[old_id] = new_id``, or ``-1`` if the token was dropped.
    """
    old_emb = model.roberta.embeddings.word_embeddings
    old_weight = old_emb.weight.data

    remap = np.full(old_weight.shape[0], -1, dtype=np.int64)
    for new_id, old_id in enumerate(keep_ids):
        remap[old_id] = new_id

    # Carry the padding row over into the new id space when it survives pruning.
    old_pad = old_emb.padding_idx
    if old_pad is not None and remap[old_pad] >= 0:
        new_pad = int(remap[old_pad])
    else:
        new_pad = None

    keep = torch.as_tensor(keep_ids, dtype=torch.long, device=old_weight.device)
    # Advanced indexing returns a copy, so the new table shares no storage
    # with the old one.
    new_weight = old_weight[keep]
    model.roberta.embeddings.word_embeddings = torch.nn.Embedding.from_pretrained(
        new_weight, freeze=False, padding_idx=new_pad)
    model.config.vocab_size = len(keep_ids)
    return remap


class LogitsWrapper(torch.nn.Module):
    """Adapter exposing a plain ``(input_ids, attention_mask) -> logits`` signature.

    ``torch.jit.trace`` is called with a positional ``(ids, mask)`` tuple, so
    this wrapper forwards those tensors to the wrapped model as keyword
    arguments. It requests dict-style output and returns only ``logits``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask):
        call_kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        return self.model(**call_kwargs).logits


def remap_ids(input_ids: torch.Tensor, remap: np.ndarray, unk_new_id: int) -> torch.Tensor:
    """Translate original (full-vocabulary) token ids into pruned-vocabulary ids.

    Args:
        input_ids: integer tensor of original token ids, any shape.
            Ids must lie in ``range(len(remap))``.
        remap: lookup table with ``remap[old_id] = new_id`` and ``-1`` for
            dropped tokens (the array ``prune_embedding`` returns).
        unk_new_id: pruned-vocabulary id substituted for dropped tokens.

    Returns:
        A ``torch.long`` tensor shaped like *input_ids*, on the same device as
        *input_ids*. The lookup runs on that device, so GPU inputs are not
        bounced through the CPU. *remap* itself is never modified.
    """
    table = torch.as_tensor(remap, dtype=torch.long, device=input_ids.device)
    # Advanced indexing yields a fresh tensor, so the fill cannot touch `remap`.
    new_ids = table[input_ids.long()]
    return new_ids.masked_fill(new_ids == -1, unk_new_id)


def dir_size_mb(path: Path) -> float:
    """Return the total size of every regular file beneath *path*, in MB.

    Files are found recursively and sizes are summed in bytes, then divided
    by 10**6 (decimal megabytes).
    """
    byte_count = 0
    for entry in path.rglob("*"):
        if entry.is_file():
            byte_count += entry.stat().st_size
    return byte_count / 1e6


def main():
    """Build Core ML packages for every requested model / vocab / quantization.

    For each ``--models`` entry and each ``--vocabs`` variant this:
      1. loads ``segment-any-text/<model>``. For ``en_zh`` it prunes the word
         embeddings and writes ``old_to_new_ids.npy`` and ``prune_info.json``
         into the variant directory;
      2. traces the model at a fixed ``(1, --seq-len)`` input shape and
         converts it to an fp16 ML Program targeting iOS 16;
      3. derives each ``--quants`` variant from that fp16 model and saves it as
         ``<out>/<model>-<vocab>/<model>-<vocab>-<quant>.mlpackage``.

    A failure while quantizing or saving one variant is recorded as
    ``"FAIL: ..."`` and the remaining variants continue. A failure while
    loading, pruning, tracing or converting aborts the run. Package sizes are
    printed and written to ``<out>/sizes.json``.
    """
    ap = argparse.ArgumentParser(
        description="Build iOS Core ML variants of SaT models.")
    ap.add_argument("--out", default="ios_models",
                    help="output directory (created if missing)")
    ap.add_argument("--models", nargs="+", default=["sat-1l-sm", "sat-3l-sm"],
                    help="model names under the segment-any-text/ namespace")
    ap.add_argument("--vocabs", nargs="+", default=["full", "en_zh"],
                    help='vocab variants; "en_zh" prunes to EN+ZH tokens, '
                         'any other value (e.g. "full") keeps the full vocab')
    ap.add_argument("--quants", nargs="+",
                    default=["fp16", "int8", "palette6", "palette4"],
                    help="quantization levels: fp16, int8, or palette<nbits>")
    ap.add_argument("--seq-len", type=int, default=256,
                    help="fixed sequence length baked into the Core ML inputs")
    args = ap.parse_args()

    import coremltools as ct
    from coremltools.optimize.coreml import (
        OpPalettizerConfig, OpLinearQuantizerConfig, OptimizationConfig,
        palettize_weights, linear_quantize_weights,
    )

    out_root = Path(args.out)
    out_root.mkdir(parents=True, exist_ok=True)
    results = []  # (variant tag, vocab size, size in MB or "FAIL: ..." message)

    # The same XLM-R tokenizer is used for every model, so load it only once.
    tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")

    for short in args.models:
        repo = f"segment-any-text/{short}"

        for vocab in args.vocabs:
            print(f"\n=== {short} / vocab={vocab} ===", flush=True)
            # Reload per vocab variant: prune_embedding mutates the model in place.
            model = AutoModelForTokenClassification.from_pretrained(repo).eval()

            variant_dir = out_root / f"{short}-{vocab}"
            variant_dir.mkdir(parents=True, exist_ok=True)

            remap = None  # old->new id table; only set for the pruned variant
            if vocab == "en_zh":
                keep_ids = compute_keep_ids(tokenizer)
                remap = prune_embedding(model, keep_ids)
                unk_new_id = int(remap[tokenizer.unk_token_id])
                np.save(variant_dir / "old_to_new_ids.npy", remap)
                prune_info = {
                    "kept_vocab_size": len(keep_ids),
                    "orig_vocab_size": int(len(remap)),
                    "unk_new_id": unk_new_id,
                    "pad_new_id": int(remap[tokenizer.pad_token_id]),
                }
                with open(variant_dir / "prune_info.json", "w", encoding="utf-8") as f:
                    json.dump(prune_info, f, indent=2)
                print(f"  kept {len(keep_ids)} / {len(remap)} tokens", flush=True)

            vocab_size = model.config.vocab_size
            n_params = sum(p.numel() for p in model.parameters())
            print(f"  vocab_size={vocab_size}  params={n_params/1e6:.1f}M", flush=True)

            # --- trace at a fixed (1, seq_len) shape ---
            wrapper = LogitsWrapper(model).eval()
            ids = torch.randint(0, vocab_size, (1, args.seq_len), dtype=torch.long)
            # Start the dummy sequence with BOS, expressed in this model's id space.
            bos_id = model.config.bos_token_id
            ids[:, 0] = bos_id if remap is None else int(remap[bos_id])
            mask = torch.ones((1, args.seq_len), dtype=torch.long)
            with torch.no_grad():
                traced = torch.jit.trace(wrapper, (ids, mask))

            # --- convert to fp16 mlmodel (base for quantization) ---
            mlmodel_fp16 = ct.convert(
                traced,
                inputs=[
                    ct.TensorType(name="input_ids", shape=(1, args.seq_len), dtype=np.int32),
                    ct.TensorType(name="attention_mask", shape=(1, args.seq_len), dtype=np.int32),
                ],
                outputs=[ct.TensorType(name="logits")],
                minimum_deployment_target=ct.target.iOS16,
                compute_precision=ct.precision.FLOAT16,
                convert_to="mlprogram",
            )

            for quant in args.quants:
                tag = f"{short}-{vocab}-{quant}"
                try:
                    if quant == "fp16":
                        mlm = mlmodel_fp16
                    elif quant == "int8":
                        cfg = OptimizationConfig(
                            global_config=OpLinearQuantizerConfig(
                                mode="linear_symmetric", dtype="int8", weight_threshold=512))
                        mlm = linear_quantize_weights(mlmodel_fp16, cfg)
                    elif quant.startswith("palette"):
                        nbits = int(quant.replace("palette", ""))
                        cfg = OptimizationConfig(
                            global_config=OpPalettizerConfig(
                                nbits=nbits, mode="kmeans", weight_threshold=512))
                        mlm = palettize_weights(mlmodel_fp16, cfg)
                    else:
                        raise ValueError(quant)

                    pkg = variant_dir / f"{tag}.mlpackage"
                    if pkg.exists():
                        shutil.rmtree(pkg)
                    mlm.save(str(pkg))
                    size = dir_size_mb(pkg)
                    results.append((tag, vocab_size, round(size, 1)))
                    print(f"  [{quant:9s}] {size:7.1f} MB  -> {pkg.name}", flush=True)
                except Exception as e:
                    # Deliberate best effort: record the failure in the summary
                    # and keep building the remaining variants.
                    results.append((tag, vocab_size, f"FAIL: {e}"))
                    print(f"  [{quant:9s}] FAILED: {e}", flush=True)

    print("\n================ SUMMARY ================")
    print(f"{'variant':32s}{'vocab':>8s}{'size MB':>10s}")
    for tag, v, s in results:
        print(f"{tag:32s}{v:8d}{str(s):>10s}")
    sizes = [{"variant": t, "vocab": v, "size_mb": s} for t, v, s in results]
    with open(out_root / "sizes.json", "w", encoding="utf-8") as f:
        json.dump(sizes, f, indent=2)
    print(f"\nWrote {out_root/'sizes.json'}")


if __name__ == "__main__":
    # Script entry point: run the build only when executed directly, not on import.
    main()