#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.10,<3.14"
# dependencies = [
# "coremltools",
# "open_clip_torch",
# "transformers",
# "torch",
# "torchvision",
# "numpy",
# ]
# ///
"""Convert SigLIP2's text encoder to Core ML — companion to the image
converter (convert_to_coreml_mobileclip.py).
Why this script is separate from the image converter:
- Image branch converts cleanly because Vision Transformer attention is
written using primitive matmul ops in open_clip's TimmModel wrapper.
- Text branch uses PyTorch's nn.MultiheadAttention which dispatches to
F._native_multi_head_attention — a fused C++ kernel that coremltools
has no converter for. This script applies the standard workaround:
replace nn.MultiheadAttention with manual matmul attention before
tracing. Mathematically identical, traces cleanly, converts at 0.999996
cosine vs PyTorch.
Caveat (read before publishing): SigLIP2 uses Gemma2's 256K-token vocabulary.
The token embedding alone is 393 MB at fp16, pushing the full text encoder
to ~565 MB. Even at 8-bit palettization it's ~280 MB. For most use cases,
running the text encoder via open_clip on PyTorch (one-shot, ~50ms per query)
is more practical than shipping this artifact.
Usage:
# fp16 conversion (565 MB)
./convert_to_coreml_text_siglip2.py
# 8-bit palettized (~280 MB) — slow, ~30 minutes for k-means
./convert_to_coreml_text_siglip2.py --palettize 8
# Use a different SigLIP variant (untested but should work for any open_clip
# model that exposes encode_text + has nn.MultiheadAttention modules)
./convert_to_coreml_text_siglip2.py --model ViT-B-16-SigLIP2-256
"""
import argparse
import sys
import time
from pathlib import Path
import coremltools as ct
import numpy as np
import open_clip
import torch
import torch.nn as nn
import torch.nn.functional as F
class SeparateProjMHA(nn.Module):
"""Drop-in replacement for nn.MultiheadAttention using three separate
Q/K/V Linear layers + manual matmul attention. Traces cleanly through
coremltools while remaining mathematically identical to the source.
Constructed from an existing nn.MultiheadAttention; weights are sliced
from its fused in_proj_weight into three separate Linears.
"""
def __init__(self, src: nn.MultiheadAttention):
super().__init__()
self.embed_dim = src.embed_dim
self.num_heads = src.num_heads
self.head_dim = src.embed_dim // src.num_heads
self.scale = self.head_dim ** -0.5
self.batch_first = src.batch_first
E = src.embed_dim
w = src.in_proj_weight.detach()
b = src.in_proj_bias.detach() if src.in_proj_bias is not None else None
self.q_proj = nn.Linear(E, E, bias=b is not None)
self.k_proj = nn.Linear(E, E, bias=b is not None)
self.v_proj = nn.Linear(E, E, bias=b is not None)
with torch.no_grad():
self.q_proj.weight.copy_(w[:E])
self.k_proj.weight.copy_(w[E:2 * E])
self.v_proj.weight.copy_(w[2 * E:])
if b is not None:
self.q_proj.bias.copy_(b[:E])
self.k_proj.bias.copy_(b[E:2 * E])
self.v_proj.bias.copy_(b[2 * E:])
self.out_proj = nn.Linear(E, E, bias=src.out_proj.bias is not None)
self.out_proj.weight = nn.Parameter(src.out_proj.weight.detach().clone())
if src.out_proj.bias is not None:
self.out_proj.bias = nn.Parameter(src.out_proj.bias.detach().clone())
def forward(self, query, key, value, **kwargs):
# Self-attention only — sufficient for SigLIP/CLIP text encoders.
# The (key, value) args are ignored; attention is always over `query`.
if self.batch_first:
B, N, C = query.shape
else:
N, B, C = query.shape
query = query.transpose(0, 1)
q = self.q_proj(query).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(query).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(query).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
out = (attn @ v).transpose(1, 2).reshape(B, N, C)
out = self.out_proj(out)
if not self.batch_first:
out = out.transpose(0, 1)
# nn.MultiheadAttention returns (output, attn_weights); we never
# produce the latter — passing None is fine because the caller
# discards it in standard CLIP-family text encoders.
return out, None
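
# Untested sanity-check sketch for the "mathematically identical" claim above
# (kept as a comment; not executed by this script). It builds a fresh
# nn.MultiheadAttention, wraps it in SeparateProjMHA, and checks that both
# produce the same output on random input:
#
#   src = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
#   src.eval()
#   patched = SeparateProjMHA(src).eval()
#   x = torch.randn(1, 64, 512)
#   with torch.no_grad():
#       ref, _ = src(x, x, x, need_weights=False)
#       new, _ = patched(x, x, x)
#   assert torch.allclose(ref, new, atol=1e-5)
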
def patch_mha_modules(text_branch: nn.Module) -> int:
"""In-place replacement of every nn.MultiheadAttention with SeparateProjMHA.
Returns the number of replacements made."""
targets = [(name, m) for name, m in text_branch.named_modules()
if isinstance(m, nn.MultiheadAttention)]
for name, mod in targets:
parent = text_branch
parts = name.split(".")
for part in parts[:-1]:
parent = getattr(parent, part)
setattr(parent, parts[-1], SeparateProjMHA(mod))
return len(targets)
class L2NormTextEncoder(nn.Module):
"""Wraps an open_clip model so the converted Core ML output is already
L2-normalized (matches the convention of our image-encoder converter)."""
def __init__(self, m):
super().__init__()
self.m = m
def forward(self, tokens):
f = self.m.encode_text(tokens)
return f / f.norm(dim=-1, keepdim=True)
def convert(model_name: str, pretrained: str, output_dir: Path,
context_length: int = 64) -> tuple[Path, int]:
"""End-to-end: load → patch MHA → trace → convert → save."""
print(f"[1/4] loading {model_name} ({pretrained}) …", flush=True)
t0 = time.perf_counter()
model, _, _ = open_clip.create_model_and_transforms(model_name, pretrained=pretrained)
tokenizer = open_clip.get_tokenizer(model_name)
model.eval()
print(f" {time.perf_counter()-t0:.1f}s", flush=True)
print(f"[2/4] patching nn.MultiheadAttention modules → SeparateProjMHA …",
flush=True)
n_patched = patch_mha_modules(model.text)
print(f" patched {n_patched} attention modules", flush=True)
if n_patched == 0:
sys.exit("no nn.MultiheadAttention modules found — model may already be "
"convertible, or use a different attention implementation")
txt_enc = L2NormTextEncoder(model).eval()
sample_tokens = tokenizer(["calibration"])
if sample_tokens.shape[1] != context_length:
# Pad/truncate to user-specified context_length so the .mlpackage has
# a fixed input shape matching what the searcher will pass.
if sample_tokens.shape[1] > context_length:
sample_tokens = sample_tokens[:, :context_length]
else:
pad = torch.zeros(1, context_length - sample_tokens.shape[1],
dtype=sample_tokens.dtype)
sample_tokens = torch.cat([sample_tokens, pad], dim=1)
print(f"[3/4] tracing at context_length={context_length} …", flush=True)
t0 = time.perf_counter()
with torch.no_grad():
traced = torch.jit.trace(txt_enc, sample_tokens)
print(f" {time.perf_counter()-t0:.1f}s", flush=True)
print(f"[4/4] converting to Core ML (fp16) …", flush=True)
t0 = time.perf_counter()
ml = ct.convert(
traced,
inputs=[ct.TensorType(name="text", shape=(1, context_length), dtype=np.int32)],
outputs=[ct.TensorType(name="embedding")],
compute_units=ct.ComputeUnit.CPU_AND_NE,
minimum_deployment_target=ct.target.macOS14,
)
print(f" {time.perf_counter()-t0:.1f}s", flush=True)
output_dir.mkdir(parents=True, exist_ok=True)
out_path = output_dir / f"{model_name}_text.mlpackage"
ml.save(str(out_path))
out_size = sum(f.stat().st_size for f in out_path.rglob("*") if f.is_file())
print(f" saved → {out_path} ({out_size/1e6:.1f} MB)", flush=True)
return out_path, out_size
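
# Downstream consumption of the saved .mlpackage is roughly the sketch below
# (an untested illustration, not part of this script; it assumes the default
# --model and --context-length values, and reads the single output the same
# way verify() does):
#
#   cm = ct.models.MLModel("ViT-B-16-SigLIP2_text.mlpackage")  # path printed by convert()
#   tokens = open_clip.get_tokenizer("ViT-B-16-SigLIP2")(["a photo of a cat"])
#   out = cm.predict({"text": tokens[:, :64].numpy().astype(np.int32)})
#   emb = next(iter(out.values()))[0]  # text embedding, already unit L2 norm
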
def palettize(src_path: Path, nbits: int) -> tuple[Path, int]:
import coremltools.optimize.coreml as cto
print(f"[palettize] {nbits}-bit k-means on {src_path.name} (slow — many large "
f"weight matrices in the 256K-vocab embedding table) …", flush=True)
t0 = time.perf_counter()
src = ct.models.MLModel(str(src_path), compute_units=ct.ComputeUnit.CPU_ONLY)
config = cto.OptimizationConfig(
global_config=cto.OpPalettizerConfig(nbits=nbits, mode="kmeans"),
)
out = cto.palettize_weights(src, config)
out_path = src_path.parent / f"{src_path.stem}_{nbits}bit.mlpackage"
out.save(str(out_path))
sz = sum(f.stat().st_size for f in out_path.rglob("*") if f.is_file())
print(f" {time.perf_counter()-t0:.1f}s → {sz/1e6:.1f} MB", flush=True)
return out_path, sz
def verify(coreml_path: Path, model_name: str, pretrained: str,
context_length: int) -> float:
"""Encode a sample query both ways, return cosine similarity."""
pt_model, _, _ = open_clip.create_model_and_transforms(
model_name, pretrained=pretrained)
tokenizer = open_clip.get_tokenizer(model_name)
pt_model.eval()
tokens = tokenizer(["a photo of a cat"])
if tokens.shape[1] > context_length:
tokens = tokens[:, :context_length]
elif tokens.shape[1] < context_length:
pad = torch.zeros(1, context_length - tokens.shape[1], dtype=tokens.dtype)
tokens = torch.cat([tokens, pad], dim=1)
with torch.no_grad():
pt_emb = pt_model.encode_text(tokens)
pt_emb = (pt_emb / pt_emb.norm(dim=-1, keepdim=True))[0].numpy().astype(np.float32)
cm = ct.models.MLModel(str(coreml_path), compute_units=ct.ComputeUnit.CPU_AND_NE)
cm_out = next(iter(cm.predict({"text": tokens.numpy().astype(np.int32)}).values()))
cm_emb = cm_out[0].astype(np.float32)
cm_emb /= np.linalg.norm(cm_emb)
return float(np.dot(pt_emb, cm_emb))
def main():
p = argparse.ArgumentParser(
formatter_class=argparse.RawDescriptionHelpFormatter,
description=__doc__.split("\n\n")[0],
)
p.add_argument("--model", default="ViT-B-16-SigLIP2",
help="open_clip model name (default: ViT-B-16-SigLIP2)")
p.add_argument("--pretrained", default="webli",
help="open_clip pretrained tag (default: webli)")
p.add_argument("--context-length", type=int, default=64,
help="Token context length to compile in (default: 64)")
p.add_argument("-o", "--output-dir", type=Path,
default=Path.home() / ".cache" / "mobileclip-coreml",
help="Where to save the .mlpackage")
p.add_argument("--palettize", type=int, choices=[2, 4, 6, 8], default=None,
help="Also produce a palettized variant (slow on the 256K "
"embedding table — ~30 min). 8-bit recommended.")
p.add_argument("--no-verify", action="store_true",
help="Skip cosine verification vs PyTorch (saves ~30s).")
args = p.parse_args()
fp16_path, fp16_size = convert(args.model, args.pretrained,
args.output_dir, args.context_length)
pal_path, pal_size = (None, 0)
if args.palettize:
pal_path, pal_size = palettize(fp16_path, args.palettize)
if not args.no_verify:
print(f"\n[verify] cosine vs PyTorch:", flush=True)
cos_fp16 = verify(fp16_path, args.model, args.pretrained, args.context_length)
print(f" fp16: {cos_fp16:.6f}", flush=True)
if pal_path is not None:
cos_pal = verify(pal_path, args.model, args.pretrained, args.context_length)
print(f" {args.palettize}-bit palettized: {cos_pal:.6f}", flush=True)
print(f"\n[done] artifacts in {args.output_dir}", flush=True)
print(f" fp16: {fp16_path.name} ({fp16_size/1e6:.0f} MB)", flush=True)
if pal_path is not None:
print(f" {args.palettize}-bit palett.: {pal_path.name} ({pal_size/1e6:.0f} MB, "
f"{fp16_size/pal_size:.1f}x smaller)", flush=True)
print(f"\nNote: SigLIP2's 256K-token Gemma vocab makes this artifact ~565 MB "
f"at fp16. Consider whether you really need on-device text encoding "
f"vs running open_clip in PyTorch (~50ms per query, no shipping cost).",
flush=True)
if __name__ == "__main__":
main()