#!/usr/bin/env python3
"""Convert SigLIP2's text encoder to Core ML — companion to the image
converter (convert_to_coreml_mobileclip.py).

Why this script is separate from the image converter:
- The image branch converts cleanly because Vision Transformer attention is
  written with primitive matmul ops in open_clip's TimmModel wrapper.
- The text branch uses PyTorch's nn.MultiheadAttention, which dispatches to
  F._native_multi_head_attention — a fused C++ kernel that coremltools has
  no converter for. This script applies the standard workaround: replace
  nn.MultiheadAttention with manual matmul attention before tracing.
  Mathematically identical, traces cleanly, and converts at 0.999996 cosine
  vs PyTorch.

Caveat (read before publishing): SigLIP2 uses the Gemma tokenizer's 256K-token
vocabulary. The token embedding alone (256K tokens × 768 dims × 2 bytes) is
~393 MB at fp16, pushing the full text encoder to ~565 MB. Even at 8-bit
palettization it is ~280 MB. For most use cases, running the text encoder via
open_clip on PyTorch (one-shot, ~50 ms per query) is more practical than
shipping this artifact; see _pytorch_text_query_example below for a sketch.

Usage:
    # fp16 conversion (~565 MB)
    ./convert_to_coreml_text_siglip2.py

    # 8-bit palettized (~280 MB) — slow, ~30 minutes for k-means
    ./convert_to_coreml_text_siglip2.py --palettize 8

    # Use a different SigLIP variant (untested, but should work for any
    # open_clip model that exposes encode_text and uses nn.MultiheadAttention)
    ./convert_to_coreml_text_siglip2.py --model ViT-B-16-SigLIP2-256
"""

import argparse
import sys
import time
from pathlib import Path

import coremltools as ct
import numpy as np
import open_clip
import torch
import torch.nn as nn
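

# Illustrative sketch of the docstring's suggested alternative (not called by
# this script): encode queries with open_clip in PyTorch directly instead of
# shipping the Core ML artifact. Model name and pretrained tag match the CLI
# defaults below; the default query string is arbitrary.
def _pytorch_text_query_example(query: str = "a photo of a cat") -> np.ndarray:
    model, _, _ = open_clip.create_model_and_transforms(
        "ViT-B-16-SigLIP2", pretrained="webli")
    tokenizer = open_clip.get_tokenizer("ViT-B-16-SigLIP2")
    model.eval()
    with torch.no_grad():
        f = model.encode_text(tokenizer([query]))
        f = f / f.norm(dim=-1, keepdim=True)  # L2-normalize, matching the Core ML output
    return f[0].numpy()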


class SeparateProjMHA(nn.Module):
    """Drop-in replacement for nn.MultiheadAttention using three separate
    Q/K/V Linear layers + manual matmul attention. Traces cleanly through
    coremltools while remaining mathematically identical to the source.

    Constructed from an existing nn.MultiheadAttention; weights are sliced
    from its fused in_proj_weight into three separate Linears.
    """

    def __init__(self, src: nn.MultiheadAttention):
        super().__init__()
        self.embed_dim = src.embed_dim
        self.num_heads = src.num_heads
        self.head_dim = src.embed_dim // src.num_heads
        self.scale = self.head_dim ** -0.5
        self.batch_first = src.batch_first

        # Slice the fused (3E, E) in_proj_weight into its Q/K/V row blocks.
        E = src.embed_dim
        w = src.in_proj_weight.detach()
        b = src.in_proj_bias.detach() if src.in_proj_bias is not None else None
        self.q_proj = nn.Linear(E, E, bias=b is not None)
        self.k_proj = nn.Linear(E, E, bias=b is not None)
        self.v_proj = nn.Linear(E, E, bias=b is not None)
        with torch.no_grad():
            self.q_proj.weight.copy_(w[:E])
            self.k_proj.weight.copy_(w[E:2 * E])
            self.v_proj.weight.copy_(w[2 * E:])
            if b is not None:
                self.q_proj.bias.copy_(b[:E])
                self.k_proj.bias.copy_(b[E:2 * E])
                self.v_proj.bias.copy_(b[2 * E:])

        self.out_proj = nn.Linear(E, E, bias=src.out_proj.bias is not None)
        self.out_proj.weight = nn.Parameter(src.out_proj.weight.detach().clone())
        if src.out_proj.bias is not None:
            self.out_proj.bias = nn.Parameter(src.out_proj.bias.detach().clone())

    def forward(self, query, key, value, **kwargs):
        # Assumes self-attention (query is key is value), which holds for the
        # text transformer here; attn_mask / key_padding_mask / need_weights
        # arriving via **kwargs are deliberately ignored.
        if self.batch_first:
            B, N, C = query.shape
        else:
            N, B, C = query.shape
            query = query.transpose(0, 1)
        q = self.q_proj(query).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(query).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(query).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        out = self.out_proj(out)
        if not self.batch_first:
            out = out.transpose(0, 1)
        # Match nn.MultiheadAttention's (output, attn_weights) return shape.
        return out, None
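

# Equivalence check for the claim above (illustrative; not called anywhere in
# this script). On random input, SeparateProjMHA should match the module it
# replaces to fp32 rounding error (~1e-6). Dimensions here are arbitrary.
def _self_test_mha(embed_dim: int = 64, num_heads: int = 4) -> float:
    src = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).eval()
    patched = SeparateProjMHA(src).eval()
    x = torch.randn(2, 7, embed_dim)
    with torch.no_grad():
        ref, _ = src(x, x, x, need_weights=False)
        out, _ = patched(x, x, x)
    return (ref - out).abs().max().item()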


def patch_mha_modules(text_branch: nn.Module) -> int:
    """In-place replacement of every nn.MultiheadAttention with SeparateProjMHA.
    Returns the number of replacements made."""
    targets = [(name, m) for name, m in text_branch.named_modules()
               if isinstance(m, nn.MultiheadAttention)]
    for name, mod in targets:
        # Walk the dotted module path to the parent, then swap the leaf.
        parent = text_branch
        parts = name.split(".")
        for part in parts[:-1]:
            parent = getattr(parent, part)
        setattr(parent, parts[-1], SeparateProjMHA(mod))
    return len(targets)


class L2NormTextEncoder(nn.Module):
    """Wraps an open_clip model so the converted Core ML output is already
    L2-normalized (matches the convention of our image-encoder converter)."""

    def __init__(self, m):
        super().__init__()
        self.m = m

    def forward(self, tokens):
        f = self.m.encode_text(tokens)
        return f / f.norm(dim=-1, keepdim=True)


def convert(model_name: str, pretrained: str, output_dir: Path,
            context_length: int = 64) -> tuple[Path, int]:
    """End-to-end: load → patch MHA → trace → convert → save."""
    print(f"[1/4] loading {model_name} ({pretrained}) …", flush=True)
    t0 = time.perf_counter()
    model, _, _ = open_clip.create_model_and_transforms(model_name, pretrained=pretrained)
    tokenizer = open_clip.get_tokenizer(model_name)
    model.eval()
    print(f" {time.perf_counter()-t0:.1f}s", flush=True)

    print("[2/4] patching nn.MultiheadAttention modules → SeparateProjMHA …",
          flush=True)
    n_patched = patch_mha_modules(model.text)
    print(f" patched {n_patched} attention modules", flush=True)
    if n_patched == 0:
        sys.exit("no nn.MultiheadAttention modules found — model may already be "
                 "convertible, or use a different attention implementation")

    txt_enc = L2NormTextEncoder(model).eval()
    sample_tokens = tokenizer(["calibration"])
    if sample_tokens.shape[1] != context_length:
        # Tokenizer output length differs from the compiled context length:
        # truncate or zero-pad to match the fixed input shape.
        if sample_tokens.shape[1] > context_length:
            sample_tokens = sample_tokens[:, :context_length]
        else:
            pad = torch.zeros(1, context_length - sample_tokens.shape[1],
                              dtype=sample_tokens.dtype)
            sample_tokens = torch.cat([sample_tokens, pad], dim=1)
    print(f"[3/4] tracing at context_length={context_length} …", flush=True)
    t0 = time.perf_counter()
    with torch.no_grad():
        traced = torch.jit.trace(txt_enc, sample_tokens)
    print(f" {time.perf_counter()-t0:.1f}s", flush=True)

    print("[4/4] converting to Core ML (fp16) …", flush=True)
    t0 = time.perf_counter()
    ml = ct.convert(
        traced,
        inputs=[ct.TensorType(name="text", shape=(1, context_length), dtype=np.int32)],
        outputs=[ct.TensorType(name="embedding")],
        compute_units=ct.ComputeUnit.CPU_AND_NE,
        minimum_deployment_target=ct.target.macOS14,
    )
    print(f" {time.perf_counter()-t0:.1f}s", flush=True)

    output_dir.mkdir(parents=True, exist_ok=True)
    out_path = output_dir / f"{model_name}_text.mlpackage"
    ml.save(str(out_path))
    out_size = sum(f.stat().st_size for f in out_path.rglob("*") if f.is_file())
    print(f" saved → {out_path} ({out_size/1e6:.1f} MB)", flush=True)
    return out_path, out_size


def palettize(src_path: Path, nbits: int) -> tuple[Path, int]:
    # Imported lazily: only needed when --palettize is requested.
    import coremltools.optimize.coreml as cto
    print(f"[palettize] {nbits}-bit k-means on {src_path.name} (slow — dominated "
          f"by the very large 256K-vocab embedding table) …", flush=True)
    t0 = time.perf_counter()
    src = ct.models.MLModel(str(src_path), compute_units=ct.ComputeUnit.CPU_ONLY)
    config = cto.OptimizationConfig(
        global_config=cto.OpPalettizerConfig(nbits=nbits, mode="kmeans"),
    )
    out = cto.palettize_weights(src, config)
    out_path = src_path.parent / f"{src_path.stem}_{nbits}bit.mlpackage"
    out.save(str(out_path))
    sz = sum(f.stat().st_size for f in out_path.rglob("*") if f.is_file())
    print(f" {time.perf_counter()-t0:.1f}s → {sz/1e6:.1f} MB", flush=True)
    return out_path, sz
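

# What palettization does, in miniature (illustrative sketch, not used by this
# script): cluster a weight tensor's values into 2**nbits centroids with 1-D
# k-means (Lloyd's algorithm), then store only the centroid table plus an
# index per weight. coremltools' implementation is the real thing; this just
# shows why nbits=8 gives roughly a 2x saving over fp16.
def _palettize_toy(w: np.ndarray, nbits: int = 2, iters: int = 10) -> np.ndarray:
    k = 2 ** nbits
    flat = w.reshape(-1)
    # Seed centroids on evenly spaced quantiles, then refine with Lloyd steps.
    centroids = np.quantile(flat, np.linspace(0.0, 1.0, k))
    for _ in range(iters):
        idx = np.abs(flat[:, None] - centroids[None, :]).argmin(axis=1)
        for j in range(k):
            members = flat[idx == j]
            if members.size:
                centroids[j] = members.mean()
    # Reconstruction: every weight replaced by its nearest centroid.
    return centroids[idx].reshape(w.shape)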


def verify(coreml_path: Path, model_name: str, pretrained: str,
           context_length: int) -> float:
    """Encode a sample query both ways, return cosine similarity."""
    # The reference path uses the unpatched PyTorch model, so this also
    # exercises the claim that SeparateProjMHA is mathematically identical.
    pt_model, _, _ = open_clip.create_model_and_transforms(
        model_name, pretrained=pretrained)
    tokenizer = open_clip.get_tokenizer(model_name)
    pt_model.eval()
    tokens = tokenizer(["a photo of a cat"])
    if tokens.shape[1] > context_length:
        tokens = tokens[:, :context_length]
    elif tokens.shape[1] < context_length:
        pad = torch.zeros(1, context_length - tokens.shape[1], dtype=tokens.dtype)
        tokens = torch.cat([tokens, pad], dim=1)

    with torch.no_grad():
        pt_emb = pt_model.encode_text(tokens)
        pt_emb = (pt_emb / pt_emb.norm(dim=-1, keepdim=True))[0].numpy().astype(np.float32)

    cm = ct.models.MLModel(str(coreml_path), compute_units=ct.ComputeUnit.CPU_AND_NE)
    cm_out = next(iter(cm.predict({"text": tokens.numpy().astype(np.int32)}).values()))
    cm_emb = cm_out[0].astype(np.float32)
    cm_emb /= np.linalg.norm(cm_emb)
    return float(np.dot(pt_emb, cm_emb))
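

# Minimal app-side consumption sketch (illustrative, not called by this
# script): load the saved .mlpackage once, then encode queries against it.
# The input name "text" and output name "embedding" match the ct.convert call
# above; the hard-coded context length of 64 matches the CLI default, and the
# slice assumes the tokenizer emits at least that many tokens.
def _coreml_text_query_example(mlpackage_path: Path, query: str,
                               model_name: str = "ViT-B-16-SigLIP2") -> np.ndarray:
    tokenizer = open_clip.get_tokenizer(model_name)
    tokens = tokenizer([query])[:, :64].numpy().astype(np.int32)
    cm = ct.models.MLModel(str(mlpackage_path), compute_units=ct.ComputeUnit.CPU_AND_NE)
    return cm.predict({"text": tokens})["embedding"][0]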


def main():
    p = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=__doc__.split("\n\n")[0],
    )
    p.add_argument("--model", default="ViT-B-16-SigLIP2",
                   help="open_clip model name (default: ViT-B-16-SigLIP2)")
    p.add_argument("--pretrained", default="webli",
                   help="open_clip pretrained tag (default: webli)")
    p.add_argument("--context-length", type=int, default=64,
                   help="Token context length baked into the converted model "
                        "(default: 64)")
    p.add_argument("-o", "--output-dir", type=Path,
                   default=Path.home() / ".cache" / "mobileclip-coreml",
                   help="Where to save the .mlpackage")
    p.add_argument("--palettize", type=int, choices=[2, 4, 6, 8], default=None,
                   help="Also produce a palettized variant (slow on the 256K "
                        "embedding table — ~30 min). 8-bit recommended.")
    p.add_argument("--no-verify", action="store_true",
                   help="Skip cosine verification vs PyTorch (saves ~30s).")
    args = p.parse_args()

    fp16_path, fp16_size = convert(args.model, args.pretrained,
                                   args.output_dir, args.context_length)

    pal_path, pal_size = (None, 0)
    if args.palettize:
        pal_path, pal_size = palettize(fp16_path, args.palettize)

    if not args.no_verify:
        print("\n[verify] cosine vs PyTorch:", flush=True)
        cos_fp16 = verify(fp16_path, args.model, args.pretrained, args.context_length)
        print(f" fp16: {cos_fp16:.6f}", flush=True)
        if pal_path is not None:
            cos_pal = verify(pal_path, args.model, args.pretrained, args.context_length)
            print(f" {args.palettize}-bit palettized: {cos_pal:.6f}", flush=True)

    print(f"\n[done] artifacts in {args.output_dir}", flush=True)
    print(f" fp16: {fp16_path.name} ({fp16_size/1e6:.0f} MB)", flush=True)
    if pal_path is not None:
        print(f" {args.palettize}-bit palett.: {pal_path.name} ({pal_size/1e6:.0f} MB, "
              f"{fp16_size/pal_size:.1f}x smaller)", flush=True)
    print("\nNote: SigLIP2's 256K-token Gemma vocab makes this artifact ~565 MB "
          "at fp16. Consider whether you really need on-device text encoding "
          "vs running open_clip in PyTorch (~50ms per query, no shipping cost).",
          flush=True)


if __name__ == "__main__":
    main()