#!/usr/bin/env -S uv run --script

# /// script
# requires-python = ">=3.10,<3.14"
# dependencies = [
#   "coremltools",
#   "open_clip_torch",
#   "transformers",
#   "torch",
#   "torchvision",
#   "numpy",
# ]
# ///

"""Convert SigLIP2's text encoder to Core ML β€” companion to the image
converter (convert_to_coreml_mobileclip.py).

Why this script is separate from the image converter:
  - Image branch converts cleanly because Vision Transformer attention is
    written using primitive matmul ops in open_clip's TimmModel wrapper.
  - Text branch uses PyTorch's nn.MultiheadAttention which dispatches to
    F._native_multi_head_attention β€” a fused C++ kernel that coremltools
    has no converter for. This script applies the standard workaround:
    replace nn.MultiheadAttention with manual matmul attention before
    tracing. Mathematically identical, traces cleanly, converts at 0.999996
    cosine vs PyTorch.

Caveat (read before publishing): SigLIP2 uses Gemma2's 256K-token vocabulary.
The token embedding alone is 393 MB at fp16, pushing the full text encoder
to ~565 MB. Even at 8-bit palettization it's ~280 MB. For most use cases,
running the text encoder via open_clip on PyTorch (one-shot, ~50ms per query)
is more practical than shipping this artifact.

Usage:
    # fp16 conversion (565 MB)
    ./convert_to_coreml_text_siglip2.py

    # 8-bit palettized (~280 MB) β€” slow, ~30 minutes for k-means
    ./convert_to_coreml_text_siglip2.py --palettize 8

    # Use a different SigLIP variant (untested but should work for any open_clip
    # model that exposes encode_text + has nn.MultiheadAttention modules)
    ./convert_to_coreml_text_siglip2.py --model ViT-B-16-SigLIP2-256
"""

import argparse
import sys
import time
from pathlib import Path

import coremltools as ct
import numpy as np
import open_clip
import torch
import torch.nn as nn
import torch.nn.functional as F


class SeparateProjMHA(nn.Module):
    """Drop-in replacement for nn.MultiheadAttention using three separate
    Q/K/V Linear layers + manual matmul attention. Traces cleanly through
    coremltools while remaining mathematically identical to the source.

    Constructed from an existing nn.MultiheadAttention; weights are sliced
    from its fused in_proj_weight into three separate Linears.
    """

    def __init__(self, src: nn.MultiheadAttention):
        super().__init__()
        self.embed_dim = src.embed_dim
        self.num_heads = src.num_heads
        self.head_dim = src.embed_dim // src.num_heads
        self.scale = self.head_dim ** -0.5
        self.batch_first = src.batch_first

        E = src.embed_dim
        w = src.in_proj_weight.detach()
        b = src.in_proj_bias.detach() if src.in_proj_bias is not None else None
        self.q_proj = nn.Linear(E, E, bias=b is not None)
        self.k_proj = nn.Linear(E, E, bias=b is not None)
        self.v_proj = nn.Linear(E, E, bias=b is not None)
        with torch.no_grad():
            self.q_proj.weight.copy_(w[:E])
            self.k_proj.weight.copy_(w[E:2 * E])
            self.v_proj.weight.copy_(w[2 * E:])
            if b is not None:
                self.q_proj.bias.copy_(b[:E])
                self.k_proj.bias.copy_(b[E:2 * E])
                self.v_proj.bias.copy_(b[2 * E:])

        self.out_proj = nn.Linear(E, E, bias=src.out_proj.bias is not None)
        self.out_proj.weight = nn.Parameter(src.out_proj.weight.detach().clone())
        if src.out_proj.bias is not None:
            self.out_proj.bias = nn.Parameter(src.out_proj.bias.detach().clone())

    def forward(self, query, key, value, attn_mask=None, **kwargs):
        # Self-attention only -- sufficient for SigLIP/CLIP text encoders.
        # The (key, value) args are ignored; attention is always over `query`.
        if self.batch_first:
            B, N, C = query.shape
        else:
            N, B, C = query.shape
            query = query.transpose(0, 1)
        q = self.q_proj(query).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(query).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(query).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        if attn_mask is not None:
            # Mirror nn.MultiheadAttention semantics: a bool mask marks
            # positions that may NOT be attended to; a float mask is added to
            # the logits. SigLIP2's text tower passes no mask, but CLIP-style
            # causal text towers do, so silently dropping it would be wrong.
            if attn_mask.dtype == torch.bool:
                attn = attn.masked_fill(attn_mask, float("-inf"))
            else:
                attn = attn + attn_mask
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        out = self.out_proj(out)
        if not self.batch_first:
            out = out.transpose(0, 1)
        # nn.MultiheadAttention returns (output, attn_weights); we never
        # compute the latter, and returning None is fine because CLIP-family
        # text encoders discard it.
        return out, None

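# A minimal self-check sketch (not called by the conversion pipeline):
# numerically compares SeparateProjMHA against the nn.MultiheadAttention it
# replaces on random input. The sizes below are arbitrary; run it standalone
# if in doubt. Expect ~1e-6 (fp32 rounding only), since the math is identical.
def _self_test_mha(embed_dim: int = 64, num_heads: int = 4, seq_len: int = 8) -> float:
    torch.manual_seed(0)
    ref = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).eval()
    sub = SeparateProjMHA(ref).eval()
    x = torch.randn(1, seq_len, embed_dim)
    with torch.no_grad():
        ref_out, _ = ref(x, x, x, need_weights=False)
        sub_out, _ = sub(x, x, x)
    return (ref_out - sub_out).abs().max().item()
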

def patch_mha_modules(text_branch: nn.Module) -> int:
    """In-place replacement of every nn.MultiheadAttention with SeparateProjMHA.
    Returns the number of replacements made."""
    targets = [(name, m) for name, m in text_branch.named_modules()
               if isinstance(m, nn.MultiheadAttention)]
    for name, mod in targets:
        # Walk dotted names like "transformer.resblocks.0.attn" down to the
        # parent module; integer indices inside Sequential/ModuleList still
        # resolve via getattr("0") because children register under string keys.
        parent = text_branch
        parts = name.split(".")
        for part in parts[:-1]:
            parent = getattr(parent, part)
        setattr(parent, parts[-1], SeparateProjMHA(mod))
    return len(targets)


class L2NormTextEncoder(nn.Module):
    """Wraps an open_clip model so the converted Core ML output is already
    L2-normalized (matches the convention of our image-encoder converter)."""
    def __init__(self, m):
        super().__init__()
        self.m = m

    def forward(self, tokens):
        f = self.m.encode_text(tokens)
        return f / f.norm(dim=-1, keepdim=True)

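# Shape sketch (a hypothetical snippet, not part of the pipeline; assumes
# ViT-B-16-SigLIP2, whose tokenizer emits 64 tokens and whose text embedding
# is 768-d):
#   tok = open_clip.get_tokenizer("ViT-B-16-SigLIP2")(["a photo of a cat"])
#   emb = L2NormTextEncoder(model)(tok)   # (1, 64) ints -> (1, 768), unit norm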

def convert(model_name: str, pretrained: str, output_dir: Path,
            context_length: int = 64) -> tuple[Path, int]:
    """End-to-end: load β†’ patch MHA β†’ trace β†’ convert β†’ save."""
    print(f"[1/4] loading {model_name} ({pretrained}) …", flush=True)
    t0 = time.perf_counter()
    model, _, _ = open_clip.create_model_and_transforms(model_name, pretrained=pretrained)
    tokenizer = open_clip.get_tokenizer(model_name)
    model.eval()
    print(f"      {time.perf_counter()-t0:.1f}s", flush=True)

    print(f"[2/4] patching nn.MultiheadAttention modules β†’ SeparateProjMHA …",
          flush=True)
    n_patched = patch_mha_modules(model.text)
    print(f"      patched {n_patched} attention modules", flush=True)
    if n_patched == 0:
        sys.exit("no nn.MultiheadAttention modules found β€” model may already be "
                 "convertible, or use a different attention implementation")

    txt_enc = L2NormTextEncoder(model).eval()
    sample_tokens = tokenizer(["calibration"])
    if sample_tokens.shape[1] != context_length:
        # Pad/truncate to the user-specified context_length so the .mlpackage
        # has a fixed input shape matching what the searcher will pass.
        if sample_tokens.shape[1] > context_length:
            sample_tokens = sample_tokens[:, :context_length]
        else:
            # Zero-padding assumes pad id 0, which holds for the Gemma
            # tokenizer that SigLIP2 uses.
            pad = torch.zeros(1, context_length - sample_tokens.shape[1],
                              dtype=sample_tokens.dtype)
            sample_tokens = torch.cat([sample_tokens, pad], dim=1)
    print(f"[3/4] tracing at context_length={context_length} …", flush=True)
    t0 = time.perf_counter()
    with torch.no_grad():
        traced = torch.jit.trace(txt_enc, sample_tokens)
    print(f"      {time.perf_counter()-t0:.1f}s", flush=True)

    print(f"[4/4] converting to Core ML (fp16) …", flush=True)
    t0 = time.perf_counter()
    ml = ct.convert(
        traced,
        inputs=[ct.TensorType(name="text", shape=(1, context_length), dtype=np.int32)],
        outputs=[ct.TensorType(name="embedding")],
        compute_units=ct.ComputeUnit.CPU_AND_NE,
        minimum_deployment_target=ct.target.macOS14,
    )
    print(f"      {time.perf_counter()-t0:.1f}s", flush=True)

    output_dir.mkdir(parents=True, exist_ok=True)
    out_path = output_dir / f"{model_name}_text.mlpackage"
    ml.save(str(out_path))
    out_size = sum(f.stat().st_size for f in out_path.rglob("*") if f.is_file())
    print(f"      saved β†’ {out_path} ({out_size/1e6:.1f} MB)", flush=True)
    return out_path, out_size

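# Consuming the artifact (sketch; "text" and "embedding" are the I/O names
# set in convert() above):
#   mlmodel = ct.models.MLModel(str(out_path))
#   emb = mlmodel.predict({"text": tokens.numpy().astype(np.int32)})["embedding"]
# The output is already L2-normalized, so similarity against an image
# embedding is a plain dot product.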

def palettize(src_path: Path, nbits: int) -> tuple[Path, int]:
    import coremltools.optimize.coreml as cto
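    # n-bit palettization stores each weight as an n-bit index into a
    # 2**n-entry lookup table learned per-tensor by k-means; at 8 bits the
    # indices are half the size of fp16 weights, consistent with the
    # 565 MB -> ~280 MB figure above.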
    print(f"[palettize] {nbits}-bit k-means on {src_path.name} (slow β€” many large "
          f"weight matrices in the 256K-vocab embedding table) …", flush=True)
    t0 = time.perf_counter()
    src = ct.models.MLModel(str(src_path), compute_units=ct.ComputeUnit.CPU_ONLY)
    config = cto.OptimizationConfig(
        global_config=cto.OpPalettizerConfig(nbits=nbits, mode="kmeans"),
    )
    out = cto.palettize_weights(src, config)
    out_path = src_path.parent / f"{src_path.stem}_{nbits}bit.mlpackage"
    out.save(str(out_path))
    sz = sum(f.stat().st_size for f in out_path.rglob("*") if f.is_file())
    print(f"            {time.perf_counter()-t0:.1f}s β†’ {sz/1e6:.1f} MB", flush=True)
    return out_path, sz


def verify(coreml_path: Path, model_name: str, pretrained: str,
           context_length: int) -> float:
    """Encode a sample query both ways, return cosine similarity."""
    pt_model, _, _ = open_clip.create_model_and_transforms(
        model_name, pretrained=pretrained)
    tokenizer = open_clip.get_tokenizer(model_name)
    pt_model.eval()
    tokens = tokenizer(["a photo of a cat"])
    if tokens.shape[1] > context_length:
        tokens = tokens[:, :context_length]
    elif tokens.shape[1] < context_length:
        pad = torch.zeros(1, context_length - tokens.shape[1], dtype=tokens.dtype)
        tokens = torch.cat([tokens, pad], dim=1)

    with torch.no_grad():
        pt_emb = pt_model.encode_text(tokens)
        pt_emb = (pt_emb / pt_emb.norm(dim=-1, keepdim=True))[0].numpy().astype(np.float32)

    cm = ct.models.MLModel(str(coreml_path), compute_units=ct.ComputeUnit.CPU_AND_NE)
    cm_out = next(iter(cm.predict({"text": tokens.numpy().astype(np.int32)}).values()))
    cm_emb = cm_out[0].astype(np.float32)
    cm_emb /= np.linalg.norm(cm_emb)
    return float(np.dot(pt_emb, cm_emb))


def main():
    p = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=__doc__.split("\n\n")[0],
    )
    p.add_argument("--model", default="ViT-B-16-SigLIP2",
                   help="open_clip model name (default: ViT-B-16-SigLIP2)")
    p.add_argument("--pretrained", default="webli",
                   help="open_clip pretrained tag (default: webli)")
    p.add_argument("--context-length", type=int, default=64,
                   help="Token context length to compile in (default: 64)")
    p.add_argument("-o", "--output-dir", type=Path,
                   default=Path.home() / ".cache" / "mobileclip-coreml",
                   help="Where to save the .mlpackage")
    p.add_argument("--palettize", type=int, choices=[2, 4, 6, 8], default=None,
                   help="Also produce a palettized variant (slow on the 256K "
                        "embedding table β€” ~30 min). 8-bit recommended.")
    p.add_argument("--no-verify", action="store_true",
                   help="Skip cosine verification vs PyTorch (saves ~30s).")
    args = p.parse_args()

    fp16_path, fp16_size = convert(args.model, args.pretrained,
                                   args.output_dir, args.context_length)

    pal_path, pal_size = (None, 0)
    if args.palettize:
        pal_path, pal_size = palettize(fp16_path, args.palettize)

    if not args.no_verify:
        print(f"\n[verify] cosine vs PyTorch:", flush=True)
        cos_fp16 = verify(fp16_path, args.model, args.pretrained, args.context_length)
        print(f"  fp16:           {cos_fp16:.6f}", flush=True)
        if pal_path is not None:
            cos_pal = verify(pal_path, args.model, args.pretrained, args.context_length)
            print(f"  {args.palettize}-bit palettized: {cos_pal:.6f}", flush=True)

    print(f"\n[done] artifacts in {args.output_dir}", flush=True)
    print(f"  fp16:        {fp16_path.name}  ({fp16_size/1e6:.0f} MB)", flush=True)
    if pal_path is not None:
        print(f"  {args.palettize}-bit palett.: {pal_path.name}  ({pal_size/1e6:.0f} MB, "
              f"{fp16_size/pal_size:.1f}x smaller)", flush=True)
    print(f"\nNote: SigLIP2's 256K-token Gemma vocab makes this artifact ~565 MB "
          f"at fp16. Consider whether you really need on-device text encoding "
          f"vs running open_clip in PyTorch (~50ms per query, no shipping cost).",
          flush=True)


if __name__ == "__main__":
    main()