#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.10,<3.14"
# dependencies = [
#     "coremltools",
#     "open_clip_torch",
#     "transformers",
#     "torch",
#     "torchvision",
#     "numpy",
# ]
# ///
"""Convert SigLIP2's text encoder to Core ML — companion to the image converter
(convert_to_coreml_mobileclip.py).

Why this script is separate from the image converter:

- The image branch converts cleanly because Vision Transformer attention is
  written with primitive matmul ops in open_clip's TimmModel wrapper.
- The text branch uses PyTorch's nn.MultiheadAttention, which dispatches to
  F._native_multi_head_attention — a fused C++ kernel that coremltools has no
  converter for.

This script applies the standard workaround: replace nn.MultiheadAttention with
manual matmul attention before tracing. Mathematically identical, traces
cleanly, and converts at 0.999996 cosine similarity vs PyTorch.

Caveat (read before publishing): SigLIP2 uses Gemma2's 256K-token vocabulary.
The token embedding alone is 393 MB at fp16, pushing the full text encoder to
~565 MB. Even at 8-bit palettization it's ~280 MB. For most use cases, running
the text encoder via open_clip on PyTorch (one-shot, ~50ms per query) is more
practical than shipping this artifact.

Usage:

    # fp16 conversion (565 MB)
    ./convert_to_coreml_text_siglip2.py

    # 8-bit palettized (~280 MB) — slow, ~30 minutes for k-means
    ./convert_to_coreml_text_siglip2.py --palettize 8

    # Use a different SigLIP variant (untested but should work for any open_clip
    # model that exposes encode_text + has nn.MultiheadAttention modules)
    ./convert_to_coreml_text_siglip2.py --model ViT-B-16-SigLIP2-256
"""

import argparse
import sys
import time
from pathlib import Path

import coremltools as ct
import numpy as np
import open_clip
import torch
import torch.nn as nn
import torch.nn.functional as F


class SeparateProjMHA(nn.Module):
    """Drop-in replacement for nn.MultiheadAttention using three separate
    Q/K/V Linear layers + manual matmul attention. Traces cleanly through
    coremltools while remaining mathematically identical to the source.

    Constructed from an existing nn.MultiheadAttention; weights are sliced
    from its fused in_proj_weight into three separate Linears.
    """

    def __init__(self, src: nn.MultiheadAttention):
        super().__init__()
        self.embed_dim = src.embed_dim
        self.num_heads = src.num_heads
        self.head_dim = src.embed_dim // src.num_heads
        self.scale = self.head_dim ** -0.5
        self.batch_first = src.batch_first

        E = src.embed_dim
        w = src.in_proj_weight.detach()
        b = src.in_proj_bias.detach() if src.in_proj_bias is not None else None

        self.q_proj = nn.Linear(E, E, bias=b is not None)
        self.k_proj = nn.Linear(E, E, bias=b is not None)
        self.v_proj = nn.Linear(E, E, bias=b is not None)
        with torch.no_grad():
            self.q_proj.weight.copy_(w[:E])
            self.k_proj.weight.copy_(w[E:2 * E])
            self.v_proj.weight.copy_(w[2 * E:])
            if b is not None:
                self.q_proj.bias.copy_(b[:E])
                self.k_proj.bias.copy_(b[E:2 * E])
                self.v_proj.bias.copy_(b[2 * E:])

        self.out_proj = nn.Linear(E, E, bias=src.out_proj.bias is not None)
        self.out_proj.weight = nn.Parameter(src.out_proj.weight.detach().clone())
        if src.out_proj.bias is not None:
            self.out_proj.bias = nn.Parameter(src.out_proj.bias.detach().clone())

    def forward(self, query, key, value, **kwargs):
        # Self-attention only — sufficient for SigLIP/CLIP text encoders.
        # The (key, value) args are ignored; attention is always over `query`.
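        # Additional kwargs (need_weights, attn_mask, ...) are dropped as well;
        # this assumes the text tower attends without a mask, which holds for
        # SigLIP2 but would silently break a causally-masked text encoder.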
        if self.batch_first:
            B, N, C = query.shape
        else:
            N, B, C = query.shape
            query = query.transpose(0, 1)

        q = self.q_proj(query).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(query).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(query).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        out = self.out_proj(out)

        if not self.batch_first:
            out = out.transpose(0, 1)

        # nn.MultiheadAttention returns (output, attn_weights); we never
        # produce the latter — passing None is fine because the caller
        # discards it in standard CLIP-family text encoders.
        return out, None


def patch_mha_modules(text_branch: nn.Module) -> int:
    """In-place replacement of every nn.MultiheadAttention with SeparateProjMHA.

    Returns the number of replacements made.
    """
    targets = [(name, m) for name, m in text_branch.named_modules()
               if isinstance(m, nn.MultiheadAttention)]
    for name, mod in targets:
        parent = text_branch
        parts = name.split(".")
        for part in parts[:-1]:
            parent = getattr(parent, part)
        setattr(parent, parts[-1], SeparateProjMHA(mod))
    return len(targets)


class L2NormTextEncoder(nn.Module):
    """Wraps an open_clip model so the converted Core ML output is already
    L2-normalized (matches the convention of our image-encoder converter)."""

    def __init__(self, m):
        super().__init__()
        self.m = m

    def forward(self, tokens):
        f = self.m.encode_text(tokens)
        return f / f.norm(dim=-1, keepdim=True)


def convert(model_name: str, pretrained: str, output_dir: Path,
            context_length: int = 64) -> tuple[Path, int]:
    """End-to-end: load → patch MHA → trace → convert → save."""
    print(f"[1/4] loading {model_name} ({pretrained}) …", flush=True)
    t0 = time.perf_counter()
    model, _, _ = open_clip.create_model_and_transforms(model_name, pretrained=pretrained)
    tokenizer = open_clip.get_tokenizer(model_name)
    model.eval()
    print(f" {time.perf_counter()-t0:.1f}s", flush=True)

    print("[2/4] patching nn.MultiheadAttention modules → SeparateProjMHA …", flush=True)
    n_patched = patch_mha_modules(model.text)
    print(f" patched {n_patched} attention modules", flush=True)
    if n_patched == 0:
        sys.exit("no nn.MultiheadAttention modules found — model may already be "
                 "convertible, or use a different attention implementation")

    txt_enc = L2NormTextEncoder(model).eval()
    sample_tokens = tokenizer(["calibration"])
    if sample_tokens.shape[1] != context_length:
        # Pad/truncate to user-specified context_length so the .mlpackage has
        # a fixed input shape matching what the searcher will pass.
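        # (With the default of 64 this branch is normally a no-op, since the
        #  SigLIP2 tokenizer already emits a fixed 64-token context; it only
        #  kicks in when --context-length is overridden.)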
        if sample_tokens.shape[1] > context_length:
            sample_tokens = sample_tokens[:, :context_length]
        else:
            pad = torch.zeros(1, context_length - sample_tokens.shape[1],
                              dtype=sample_tokens.dtype)
            sample_tokens = torch.cat([sample_tokens, pad], dim=1)

    print(f"[3/4] tracing at context_length={context_length} …", flush=True)
    t0 = time.perf_counter()
    with torch.no_grad():
        traced = torch.jit.trace(txt_enc, sample_tokens)
    print(f" {time.perf_counter()-t0:.1f}s", flush=True)

    print("[4/4] converting to Core ML (fp16) …", flush=True)
    t0 = time.perf_counter()
    ml = ct.convert(
        traced,
        inputs=[ct.TensorType(name="text", shape=(1, context_length), dtype=np.int32)],
        outputs=[ct.TensorType(name="embedding")],
        compute_units=ct.ComputeUnit.CPU_AND_NE,
        minimum_deployment_target=ct.target.macOS14,
    )
    print(f" {time.perf_counter()-t0:.1f}s", flush=True)

    output_dir.mkdir(parents=True, exist_ok=True)
    out_path = output_dir / f"{model_name}_text.mlpackage"
    ml.save(str(out_path))
    out_size = sum(f.stat().st_size for f in out_path.rglob("*") if f.is_file())
    print(f" saved → {out_path} ({out_size/1e6:.1f} MB)", flush=True)
    return out_path, out_size


def palettize(src_path: Path, nbits: int) -> tuple[Path, int]:
    import coremltools.optimize.coreml as cto

    print(f"[palettize] {nbits}-bit k-means on {src_path.name} (slow — many large "
          f"weight matrices in the 256K-vocab embedding table) …", flush=True)
    t0 = time.perf_counter()
    src = ct.models.MLModel(str(src_path), compute_units=ct.ComputeUnit.CPU_ONLY)
    config = cto.OptimizationConfig(
        global_config=cto.OpPalettizerConfig(nbits=nbits, mode="kmeans"),
    )
    out = cto.palettize_weights(src, config)
    out_path = src_path.parent / f"{src_path.stem}_{nbits}bit.mlpackage"
    out.save(str(out_path))
    sz = sum(f.stat().st_size for f in out_path.rglob("*") if f.is_file())
    print(f" {time.perf_counter()-t0:.1f}s → {sz/1e6:.1f} MB", flush=True)
    return out_path, sz


def verify(coreml_path: Path, model_name: str, pretrained: str,
           context_length: int) -> float:
    """Encode a sample query both ways, return cosine similarity."""
    pt_model, _, _ = open_clip.create_model_and_transforms(
        model_name, pretrained=pretrained)
    tokenizer = open_clip.get_tokenizer(model_name)
    pt_model.eval()

    tokens = tokenizer(["a photo of a cat"])
    if tokens.shape[1] > context_length:
        tokens = tokens[:, :context_length]
    elif tokens.shape[1] < context_length:
        pad = torch.zeros(1, context_length - tokens.shape[1], dtype=tokens.dtype)
        tokens = torch.cat([tokens, pad], dim=1)

    with torch.no_grad():
        pt_emb = pt_model.encode_text(tokens)
        pt_emb = (pt_emb / pt_emb.norm(dim=-1, keepdim=True))[0].numpy().astype(np.float32)

    cm = ct.models.MLModel(str(coreml_path), compute_units=ct.ComputeUnit.CPU_AND_NE)
    cm_out = next(iter(cm.predict({"text": tokens.numpy().astype(np.int32)}).values()))
    cm_emb = cm_out[0].astype(np.float32)
    cm_emb /= np.linalg.norm(cm_emb)

    return float(np.dot(pt_emb, cm_emb))


def main():
    p = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=__doc__.split("\n\n")[0],
    )
    p.add_argument("--model", default="ViT-B-16-SigLIP2",
                   help="open_clip model name (default: ViT-B-16-SigLIP2)")
    p.add_argument("--pretrained", default="webli",
                   help="open_clip pretrained tag (default: webli)")
    p.add_argument("--context-length", type=int, default=64,
                   help="Token context length to compile in (default: 64)")
    p.add_argument("-o", "--output-dir", type=Path,
                   default=Path.home() / ".cache" / "mobileclip-coreml",
                   help="Where to save the .mlpackage")
    p.add_argument("--palettize", type=int, choices=[2, 4, 6, 8],
                   default=None,
                   help="Also produce a palettized variant (slow on the 256K "
                        "embedding table — ~30 min). 8-bit recommended.")
    p.add_argument("--no-verify", action="store_true",
                   help="Skip cosine verification vs PyTorch (saves ~30s).")
    args = p.parse_args()

    fp16_path, fp16_size = convert(args.model, args.pretrained, args.output_dir,
                                   args.context_length)

    pal_path, pal_size = (None, 0)
    if args.palettize:
        pal_path, pal_size = palettize(fp16_path, args.palettize)

    if not args.no_verify:
        print("\n[verify] cosine vs PyTorch:", flush=True)
        cos_fp16 = verify(fp16_path, args.model, args.pretrained, args.context_length)
        print(f" fp16: {cos_fp16:.6f}", flush=True)
        if pal_path is not None:
            cos_pal = verify(pal_path, args.model, args.pretrained, args.context_length)
            print(f" {args.palettize}-bit palettized: {cos_pal:.6f}", flush=True)

    print(f"\n[done] artifacts in {args.output_dir}", flush=True)
    print(f" fp16: {fp16_path.name} ({fp16_size/1e6:.0f} MB)", flush=True)
    if pal_path is not None:
        print(f" {args.palettize}-bit palett.: {pal_path.name} ({pal_size/1e6:.0f} MB, "
              f"{fp16_size/pal_size:.1f}x smaller)", flush=True)
    print("\nNote: SigLIP2's 256K-token Gemma vocab makes this artifact ~565 MB "
          "at fp16. Consider whether you really need on-device text encoding "
          "vs running open_clip in PyTorch (~50ms per query, no shipping cost).",
          flush=True)


if __name__ == "__main__":
    main()