#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.10,<3.14"
# dependencies = [
# "coremltools",
# "open_clip_torch",
# "transformers",
# "torch",
# "torchvision",
# "numpy",
# ]
# ///
"""Convert SigLIP2's text encoder to Core ML β companion to the image
converter (convert_to_coreml_mobileclip.py).
Why this script is separate from the image converter:
- Image branch converts cleanly because Vision Transformer attention is
written using primitive matmul ops in open_clip's TimmModel wrapper.
- Text branch uses PyTorch's nn.MultiheadAttention which dispatches to
F._native_multi_head_attention β a fused C++ kernel that coremltools
has no converter for. This script applies the standard workaround:
replace nn.MultiheadAttention with manual matmul attention before
tracing. Mathematically identical, traces cleanly, converts at 0.999996
cosine vs PyTorch.
Caveat (read before publishing): SigLIP2 uses Gemma2's 256K-token vocabulary.
The token embedding alone is 393 MB at fp16, pushing the full text encoder
to ~565 MB. Even at 8-bit palettization it's ~280 MB. For most use cases,
running the text encoder via open_clip on PyTorch (one-shot, ~50ms per query)
is more practical than shipping this artifact.
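
For reference, the PyTorch one-shot path looks roughly like this (a sketch
along the lines of verify() below, using this script's default model and
pretrained tags):

    import open_clip, torch
    model, _, _ = open_clip.create_model_and_transforms(
        "ViT-B-16-SigLIP2", pretrained="webli")
    tokenizer = open_clip.get_tokenizer("ViT-B-16-SigLIP2")
    with torch.no_grad():
        emb = model.encode_text(tokenizer(["a photo of a cat"]))
        emb = emb / emb.norm(dim=-1, keepdim=True)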

Usage:

    # fp16 conversion (565 MB)
    ./convert_to_coreml_text_siglip2.py

    # 8-bit palettized (~280 MB); slow, ~30 minutes for k-means
    ./convert_to_coreml_text_siglip2.py --palettize 8

    # Use a different SigLIP variant (untested, but should work for any
    # open_clip model that exposes encode_text and has nn.MultiheadAttention
    # modules)
    ./convert_to_coreml_text_siglip2.py --model ViT-B-16-SigLIP2-256
"""
import argparse
import sys
import time
from pathlib import Path
import coremltools as ct
import numpy as np
import open_clip
import torch
import torch.nn as nn
import torch.nn.functional as F
class SeparateProjMHA(nn.Module):
"""Drop-in replacement for nn.MultiheadAttention using three separate
Q/K/V Linear layers + manual matmul attention. Traces cleanly through
coremltools while remaining mathematically identical to the source.
Constructed from an existing nn.MultiheadAttention; weights are sliced
from its fused in_proj_weight into three separate Linears.
"""
def __init__(self, src: nn.MultiheadAttention):
super().__init__()
self.embed_dim = src.embed_dim
self.num_heads = src.num_heads
self.head_dim = src.embed_dim // src.num_heads
self.scale = self.head_dim ** -0.5
self.batch_first = src.batch_first
E = src.embed_dim
w = src.in_proj_weight.detach()
b = src.in_proj_bias.detach() if src.in_proj_bias is not None else None
self.q_proj = nn.Linear(E, E, bias=b is not None)
self.k_proj = nn.Linear(E, E, bias=b is not None)
self.v_proj = nn.Linear(E, E, bias=b is not None)
with torch.no_grad():
self.q_proj.weight.copy_(w[:E])
self.k_proj.weight.copy_(w[E:2 * E])
self.v_proj.weight.copy_(w[2 * E:])
if b is not None:
self.q_proj.bias.copy_(b[:E])
self.k_proj.bias.copy_(b[E:2 * E])
self.v_proj.bias.copy_(b[2 * E:])
self.out_proj = nn.Linear(E, E, bias=src.out_proj.bias is not None)
self.out_proj.weight = nn.Parameter(src.out_proj.weight.detach().clone())
if src.out_proj.bias is not None:
self.out_proj.bias = nn.Parameter(src.out_proj.bias.detach().clone())
def forward(self, query, key, value, **kwargs):
        # Self-attention only; sufficient for SigLIP/CLIP text encoders.
# The (key, value) args are ignored; attention is always over `query`.
if self.batch_first:
B, N, C = query.shape
else:
N, B, C = query.shape
query = query.transpose(0, 1)
q = self.q_proj(query).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(query).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(query).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
out = (attn @ v).transpose(1, 2).reshape(B, N, C)
out = self.out_proj(out)
if not self.batch_first:
out = out.transpose(0, 1)
        # nn.MultiheadAttention returns (output, attn_weights); we never
        # produce the latter, and passing None is fine because the caller
        # discards it in standard CLIP-family text encoders.
return out, None
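
# Quick equivalence check (a sketch, not executed by this script): wrap a
# fused module and compare outputs on random input.
#
#   src = nn.MultiheadAttention(768, 12, batch_first=True).eval()
#   x = torch.randn(1, 64, 768)
#   ref, _ = src(x, x, x, need_weights=False)
#   out, _ = SeparateProjMHA(src)(x, x, x)
#   assert torch.allclose(ref, out, atol=1e-5)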
def patch_mha_modules(text_branch: nn.Module) -> int:
"""In-place replacement of every nn.MultiheadAttention with SeparateProjMHA.
Returns the number of replacements made."""
targets = [(name, m) for name, m in text_branch.named_modules()
if isinstance(m, nn.MultiheadAttention)]
for name, mod in targets:
parent = text_branch
parts = name.split(".")
for part in parts[:-1]:
parent = getattr(parent, part)
setattr(parent, parts[-1], SeparateProjMHA(mod))
return len(targets)
class L2NormTextEncoder(nn.Module):
"""Wraps an open_clip model so the converted Core ML output is already
L2-normalized (matches the convention of our image-encoder converter)."""
def __init__(self, m):
super().__init__()
self.m = m
def forward(self, tokens):
f = self.m.encode_text(tokens)
return f / f.norm(dim=-1, keepdim=True)
def convert(model_name: str, pretrained: str, output_dir: Path,
context_length: int = 64) -> tuple[Path, int]:
"""End-to-end: load β patch MHA β trace β convert β save."""
print(f"[1/4] loading {model_name} ({pretrained}) β¦", flush=True)
t0 = time.perf_counter()
model, _, _ = open_clip.create_model_and_transforms(model_name, pretrained=pretrained)
tokenizer = open_clip.get_tokenizer(model_name)
model.eval()
print(f" {time.perf_counter()-t0:.1f}s", flush=True)
print(f"[2/4] patching nn.MultiheadAttention modules β SeparateProjMHA β¦",
flush=True)
n_patched = patch_mha_modules(model.text)
print(f" patched {n_patched} attention modules", flush=True)
if n_patched == 0:
sys.exit("no nn.MultiheadAttention modules found β model may already be "
"convertible, or use a different attention implementation")
txt_enc = L2NormTextEncoder(model).eval()
sample_tokens = tokenizer(["calibration"])
if sample_tokens.shape[1] != context_length:
# Pad/truncate to user-specified context_length so the .mlpackage has
# a fixed input shape matching what the searcher will pass.
if sample_tokens.shape[1] > context_length:
sample_tokens = sample_tokens[:, :context_length]
else:
pad = torch.zeros(1, context_length - sample_tokens.shape[1],
dtype=sample_tokens.dtype)
sample_tokens = torch.cat([sample_tokens, pad], dim=1)
print(f"[3/4] tracing at context_length={context_length} β¦", flush=True)
t0 = time.perf_counter()
with torch.no_grad():
traced = torch.jit.trace(txt_enc, sample_tokens)
print(f" {time.perf_counter()-t0:.1f}s", flush=True)
print(f"[4/4] converting to Core ML (fp16) β¦", flush=True)
t0 = time.perf_counter()
ml = ct.convert(
traced,
inputs=[ct.TensorType(name="text", shape=(1, context_length), dtype=np.int32)],
outputs=[ct.TensorType(name="embedding")],
compute_units=ct.ComputeUnit.CPU_AND_NE,
minimum_deployment_target=ct.target.macOS14,
)
print(f" {time.perf_counter()-t0:.1f}s", flush=True)
output_dir.mkdir(parents=True, exist_ok=True)
out_path = output_dir / f"{model_name}_text.mlpackage"
ml.save(str(out_path))
out_size = sum(f.stat().st_size for f in out_path.rglob("*") if f.is_file())
print(f" saved β {out_path} ({out_size/1e6:.1f} MB)", flush=True)
return out_path, out_size
def palettize(src_path: Path, nbits: int) -> tuple[Path, int]:
    """Produce an nbits-bit k-means palettized copy of a saved .mlpackage;
    returns (path, size in bytes)."""
    import coremltools.optimize.coreml as cto
    print(f"[palettize] {nbits}-bit k-means on {src_path.name} (slow; many large "
          f"weight matrices, including the 256K-vocab embedding table) …", flush=True)
t0 = time.perf_counter()
src = ct.models.MLModel(str(src_path), compute_units=ct.ComputeUnit.CPU_ONLY)
config = cto.OptimizationConfig(
global_config=cto.OpPalettizerConfig(nbits=nbits, mode="kmeans"),
)
out = cto.palettize_weights(src, config)
out_path = src_path.parent / f"{src_path.stem}_{nbits}bit.mlpackage"
out.save(str(out_path))
sz = sum(f.stat().st_size for f in out_path.rglob("*") if f.is_file())
print(f" {time.perf_counter()-t0:.1f}s β {sz/1e6:.1f} MB", flush=True)
return out_path, sz
def verify(coreml_path: Path, model_name: str, pretrained: str,
context_length: int) -> float:
"""Encode a sample query both ways, return cosine similarity."""
pt_model, _, _ = open_clip.create_model_and_transforms(
model_name, pretrained=pretrained)
tokenizer = open_clip.get_tokenizer(model_name)
pt_model.eval()
tokens = tokenizer(["a photo of a cat"])
if tokens.shape[1] > context_length:
tokens = tokens[:, :context_length]
elif tokens.shape[1] < context_length:
pad = torch.zeros(1, context_length - tokens.shape[1], dtype=tokens.dtype)
tokens = torch.cat([tokens, pad], dim=1)
with torch.no_grad():
pt_emb = pt_model.encode_text(tokens)
pt_emb = (pt_emb / pt_emb.norm(dim=-1, keepdim=True))[0].numpy().astype(np.float32)
cm = ct.models.MLModel(str(coreml_path), compute_units=ct.ComputeUnit.CPU_AND_NE)
cm_out = next(iter(cm.predict({"text": tokens.numpy().astype(np.int32)}).values()))
cm_emb = cm_out[0].astype(np.float32)
cm_emb /= np.linalg.norm(cm_emb)
return float(np.dot(pt_emb, cm_emb))
def main():
p = argparse.ArgumentParser(
formatter_class=argparse.RawDescriptionHelpFormatter,
description=__doc__.split("\n\n")[0],
)
p.add_argument("--model", default="ViT-B-16-SigLIP2",
help="open_clip model name (default: ViT-B-16-SigLIP2)")
p.add_argument("--pretrained", default="webli",
help="open_clip pretrained tag (default: webli)")
p.add_argument("--context-length", type=int, default=64,
help="Token context length to compile in (default: 64)")
p.add_argument("-o", "--output-dir", type=Path,
default=Path.home() / ".cache" / "mobileclip-coreml",
help="Where to save the .mlpackage")
p.add_argument("--palettize", type=int, choices=[2, 4, 6, 8], default=None,
help="Also produce a palettized variant (slow on the 256K "
"embedding table β ~30 min). 8-bit recommended.")
p.add_argument("--no-verify", action="store_true",
help="Skip cosine verification vs PyTorch (saves ~30s).")
args = p.parse_args()
fp16_path, fp16_size = convert(args.model, args.pretrained,
args.output_dir, args.context_length)
pal_path, pal_size = (None, 0)
if args.palettize:
pal_path, pal_size = palettize(fp16_path, args.palettize)
if not args.no_verify:
print(f"\n[verify] cosine vs PyTorch:", flush=True)
cos_fp16 = verify(fp16_path, args.model, args.pretrained, args.context_length)
print(f" fp16: {cos_fp16:.6f}", flush=True)
if pal_path is not None:
cos_pal = verify(pal_path, args.model, args.pretrained, args.context_length)
print(f" {args.palettize}-bit palettized: {cos_pal:.6f}", flush=True)
print(f"\n[done] artifacts in {args.output_dir}", flush=True)
print(f" fp16: {fp16_path.name} ({fp16_size/1e6:.0f} MB)", flush=True)
if pal_path is not None:
print(f" {args.palettize}-bit palett.: {pal_path.name} ({pal_size/1e6:.0f} MB, "
f"{fp16_size/pal_size:.1f}x smaller)", flush=True)
print(f"\nNote: SigLIP2's 256K-token Gemma vocab makes this artifact ~565 MB "
f"at fp16. Consider whether you really need on-device text encoding "
f"vs running open_clip in PyTorch (~50ms per query, no shipping cost).",
flush=True)
if __name__ == "__main__":
main()