#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.10,<3.14"
# dependencies = [
# "coremltools",
# "open_clip_torch",
# "transformers",
# "torch",
# "torchvision",
# "pillow",
# "numpy",
# ]
# ///
"""Convert any open_clip image encoder to Core ML for ANE acceleration.
Used to produce the .mlpackage files that ./embed_media_mobileclip.py loads
when --force-coreml is passed. The indexer also has an inline copy of the
fp16 conversion logic for lazy auto-build on first use; this standalone
script adds palettization + correctness verification + benchmarking, suitable
for producing artifacts to publish on HuggingFace.
Tested model coverage:
- MobileCLIP2-B / dfndr2b → fp16: 0.999 cosine vs PyTorch (drop-in)
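- ViT-B-16-SigLIP2 / webli → fp16: 0.976, 8-bit palettized: 0.966 (0.976 matches the pre-fix figure quoted in preprocess_normalization — re-run verification to confirm current numbers)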
Untested (expect to work but verify cosine):
- other MobileCLIP2-* (S0/S2/S3/S4, L-14)
- other SigLIP/SigLIP2 model sizes and input resolutions
- EVA02-*, ViTamin-*, PE-Core-*
Usage:
# fp16 conversion (default)
./convert_to_coreml_mobileclip.py ViT-B-16-SigLIP2
# 8-bit palettized — half the disk size, near-identical fidelity
./convert_to_coreml_mobileclip.py ViT-B-16-SigLIP2 --palettize 8
# Custom pretrained tag + output dir
./convert_to_coreml_mobileclip.py MobileCLIP2-B --pretrained dfndr2b -o ./out
# Skip the cosine verification + benchmark to convert faster
./convert_to_coreml_mobileclip.py ViT-B-16-SigLIP2 --no-verify
"""
import argparse
import sys
import time
from pathlib import Path
import numpy as np
import torch
from PIL import Image, ImageDraw
import open_clip
import coremltools as ct
def default_pretrained_for(model_name: str) -> str:
    """Choose the open_clip pretrained tag used when --pretrained is omitted.

    Rules are evaluated top to bottom and the first match wins:
      * name starts with "MobileCLIP2-"  -> "dfndr2b"
      * name contains "SigLIP" (incl. SigLIP2) -> "webli"
      * name contains "EVA02"            -> "merged2b_s8b_b131k"
      * anything else                    -> "datacompdr"
    """
    rules = (
        (lambda name: name.startswith("MobileCLIP2-"), "dfndr2b"),
        # "SigLIP2" always contains "SigLIP", so one rule covers both families.
        (lambda name: "SigLIP" in name, "webli"),
        (lambda name: "EVA02" in name, "merged2b_s8b_b131k"),
    )
    return next((tag for matches, tag in rules if matches(model_name)), "datacompdr")
def preprocess_image_size(preprocess) -> int:
    """Return the (square) input resolution the model is fed by ``preprocess``.

    Scans ``preprocess.transforms`` for objects exposing a non-None ``size``
    attribute (Resize, CenterCrop, ...) and uses the LAST one: the final spatial
    transform (typically the crop) determines the tensor the model actually
    sees. Picking the first one would return the resize target, which differs
    from the crop size in Resize(256) + CenterCrop(224)-style pipelines.

    ``size`` may be an int or a sequence; for sequences the first entry
    (height) is returned, i.e. square inputs are assumed.

    Exits the process via sys.exit if no transform exposes a usable size.
    """
    final_size = None
    for tf in preprocess.transforms:
        s = getattr(tf, "size", None)
        # e.g. torchvision Resize(size=None, max_size=...) has no fixed target.
        if s is not None:
            final_size = s
    if final_size is None:
        sys.exit("could not determine input size from preprocess transform")
    return final_size if isinstance(final_size, int) else int(final_size[0])
def preprocess_normalization(preprocess) -> tuple[float, list[float]]:
    """Derive ct.ImageType scale/bias from the model's Normalize transform.

    For Normalize(mean, std), the math is:
        normalized = (pixel/255 - mean) / std
                   = pixel * (1/(255*std)) + (-mean/std)
    So Core ML's ImageType params are:
        scale = 1 / (255 * std)
        bias  = -mean / std

    Examples:
        SigLIP2     (mean=0.5, std=0.5): scale=2/255, bias=[-1,-1,-1] → [-1, 1]
        MobileCLIP2 (mean=0,   std=1):   scale=1/255, bias=[0, 0, 0]  → [0, 1]

    ct.ImageType takes a single scalar ``scale`` (with a per-channel bias), so
    only a channel-uniform std can be expressed exactly. A per-channel std
    (e.g. OpenAI CLIP, std≈(0.27, 0.26, 0.28)) is rejected via sys.exit rather
    than silently approximated.

    Getting this wrong silently degrades the embedding. Our SigLIP2 was at
    0.976 cosine vs PyTorch for weeks because we hardcoded the [0,1] mapping
    that worked for MobileCLIP2 but not SigLIP2.

    Args:
        preprocess: object with a ``transforms`` sequence (e.g. a torchvision
            Compose). The first transform whose class is named "Normalize"
            is used.

    Returns:
        (scale, bias) with one bias entry per RGB channel. Without a Normalize
        transform, returns (1/255, [0, 0, 0]), i.e. pixels mapped to [0, 1].
    """

    def _as_rgb(values) -> list[float]:
        # Normalize broadcasts a 1-element mean/std over all channels, but the
        # Core ML RGB bias needs one entry per channel, so expand it to three.
        try:
            vals = [float(v) for v in values]
        except TypeError:  # scalar (incl. 0-d tensor)
            vals = [float(values)]
        return vals * 3 if len(vals) == 1 else vals

    for tf in preprocess.transforms:
        if type(tf).__name__ != "Normalize":
            continue
        mean = _as_rgb(tf.mean)
        std = _as_rgb(tf.std)
        ref = std[0]
        # Tolerance instead of exact equality: equal stds may differ by float noise.
        if any(abs(s - ref) > 1e-6 * abs(ref) for s in std):
            sys.exit(f"non-uniform std {std} not supported by ct.ImageType "
                     "(would need per-channel scale)")
        if ref == 0:
            sys.exit(f"invalid std {std}: zero std cannot be inverted")
        scale = 1.0 / (255.0 * ref)
        bias = [-m / ref for m in mean]
        return scale, bias
    # No Normalize transform → assume [0, 1] direct
    return 1.0 / 255.0, [0.0, 0.0, 0.0]
class L2NormImageEncoder(torch.nn.Module):
    """Image-only wrapper whose output has unit L2 length along the last dim.

    Tracing this module instead of the raw open_clip model bakes the
    normalization into the exported graph, so the Core ML embedding comes out
    already normalized and callers can skip that step at search time.
    """

    def __init__(self, m):
        super(L2NormImageEncoder, self).__init__()
        # The wrapped open_clip model; only its encode_image() is called.
        self.m = m

    def forward(self, x):
        embedding = self.m.encode_image(x)
        length = embedding.norm(dim=-1, keepdim=True)
        return embedding / length
def convert(model_name: str, pretrained: str, output_dir: Path) -> tuple[Path, int]:
    """Export the open_clip image tower to an fp16 Core ML .mlpackage.

    The steps are:
      1. Load the model and its eval preprocess.
      2. Trace an L2-normalized image encoder at the preprocess input size.
      3. Convert with an ImageType input whose scale/bias reproduce the model's
         Normalize transform.
      4. Save the result as ``<output_dir>/<model_name>_image.mlpackage``.

    Returns:
        (path of the saved package, total size in bytes of the files inside it)
    """

    def report_elapsed(since: float) -> None:
        print(f" {time.perf_counter()-since:.1f}s", flush=True)

    print(f"[1/3] loading {model_name} ({pretrained}) …", flush=True)
    since = time.perf_counter()
    model, _, preprocess = open_clip.create_model_and_transforms(
        model_name, pretrained=pretrained)
    model.eval()
    report_elapsed(since)

    side = preprocess_image_size(preprocess)
    scale, bias = preprocess_normalization(preprocess)
    print(f"[2/3] tracing at {side}x{side} (input scale={scale:.5f}, bias={bias}) …",
          flush=True)
    since = time.perf_counter()
    wrapper = L2NormImageEncoder(model).eval()
    example = torch.zeros(1, 3, side, side)
    with torch.no_grad():
        traced = torch.jit.trace(wrapper, example)
    report_elapsed(since)

    print("[3/3] converting to Core ML (fp16) …", flush=True)
    since = time.perf_counter()
    input_shape = (1, 3, side, side)
    mlmodel = ct.convert(
        traced,
        inputs=[ct.ImageType(name="image", shape=input_shape, scale=scale, bias=bias)],
        outputs=[ct.TensorType(name="embedding")],
        compute_units=ct.ComputeUnit.CPU_AND_NE,
        minimum_deployment_target=ct.target.macOS14,
    )
    report_elapsed(since)

    output_dir.mkdir(parents=True, exist_ok=True)
    destination = output_dir / f"{model_name}_image.mlpackage"
    mlmodel.save(str(destination))
    # An .mlpackage is a directory; report the sum of all contained files.
    total_bytes = 0
    for entry in destination.rglob("*"):
        if entry.is_file():
            total_bytes += entry.stat().st_size
    print(f" saved → {destination} ({total_bytes/1e6:.1f} MB)", flush=True)
    return destination, total_bytes
def palettize(src_path: Path, nbits: int) -> tuple[Path, int]:
    """Write a k-means weight-palettized copy of an existing .mlpackage.

    Uses one global ``OpPalettizerConfig(nbits=nbits, mode="kmeans")``. The
    result is saved next to the source as ``<stem>_<nbits>bit.mlpackage``.

    Returns:
        (path of the palettized package, total size in bytes of its files)
    """
    import coremltools.optimize.coreml as cto

    print(f"[palettize] loading {src_path.name} …", flush=True)
    source_model = ct.models.MLModel(str(src_path), compute_units=ct.ComputeUnit.CPU_ONLY)

    print(f"[palettize] {nbits}-bit k-means clustering (this scales with model depth) …",
          flush=True)
    since = time.perf_counter()
    op_config = cto.OpPalettizerConfig(nbits=nbits, mode="kmeans")
    compressed = cto.palettize_weights(
        source_model, cto.OptimizationConfig(global_config=op_config))
    print(f" {time.perf_counter()-since:.1f}s", flush=True)

    destination = src_path.with_name(f"{src_path.stem}_{nbits}bit.mlpackage")
    compressed.save(str(destination))
    files = filter(lambda entry: entry.is_file(), destination.rglob("*"))
    total_bytes = sum(entry.stat().st_size for entry in files)
    print(f" saved → {destination} ({total_bytes/1e6:.1f} MB)", flush=True)
    return destination, total_bytes
def verify(coreml_path: Path, model_name: str, pretrained: str,
           pytorch_model, preprocess) -> float:
    """Return the cosine similarity between PyTorch and Core ML embeddings.

    A synthetic test image (green ellipse on a dark-grey background) is drawn
    at the Core ML model's declared input resolution. It is encoded both by
    ``pytorch_model`` (after ``preprocess``) and by the .mlpackage at
    ``coreml_path``. Both embeddings are L2-normalized before the dot product,
    so 1.0 means identical direction.

    The image is drawn at the model's own resolution because it is passed to
    Core ML as-is, and Core ML expects it to match the fixed ImageType shape.
    The previous hard-coded 224x224 image did not match models traced at
    other sizes (e.g. 256). At the native size, both back ends also see the
    same pixels.

    ``model_name`` and ``pretrained`` are unused; they are kept so existing
    callers keep working.
    """
    cm = ct.models.MLModel(str(coreml_path), compute_units=ct.ComputeUnit.CPU_AND_NE)
    image_spec = cm.get_spec().description.input[0].type.imageType
    # Fall back to the historical 224px if the spec carries no fixed size.
    width = int(image_spec.width) or 224
    height = int(image_spec.height) or 224

    # Same pattern as the original 224px image (ellipse inset by 40px), scaled;
    # at 224x224 this reproduces the original image exactly.
    inset_x = round(40 * width / 224)
    inset_y = round(40 * height / 224)
    img = Image.new("RGB", (width, height), (40, 40, 40))
    ImageDraw.Draw(img).ellipse(
        [inset_x, inset_y, width - inset_x, height - inset_y], fill=(0, 255, 0))

    with torch.no_grad():
        pt = pytorch_model.encode_image(preprocess(img).unsqueeze(0))
    pt = (pt / pt.norm(dim=-1, keepdim=True))[0].numpy().astype(np.float32)

    prediction = cm.predict({"image": img})
    # Take the single output regardless of its name.
    cm_out = next(iter(prediction.values())).squeeze().astype(np.float32)
    cm_out /= np.linalg.norm(cm_out)
    return float(np.dot(pt, cm_out))
def benchmark(coreml_path: Path, n: int = 200) -> float:
    """Measure batch-prediction throughput (images/sec) of a saved .mlpackage.

    The model is loaded with ComputeUnit.CPU_AND_NE, so Core ML may schedule
    work on the Neural Engine or the CPU. The procedure is:
      1. Build ``n`` solid-colour square images at the model's input width.
      2. Run three single-image predictions to warm the model up.
      3. Time one batched ``predict`` call over all ``n`` images.
    Image creation and warm-up happen before the timer starts.
    """
    mlmodel = ct.models.MLModel(str(coreml_path), compute_units=ct.ComputeUnit.CPU_AND_NE)
    side = mlmodel.get_spec().description.input[0].type.imageType.width

    frames = []
    for idx in range(n):
        colour = (idx % 255, (idx * 3) % 255, (idx * 7) % 255)
        frames.append(Image.new("RGB", (side, side), colour))

    for _ in range(3):
        mlmodel.predict({"image": frames[0]})

    started = time.perf_counter()
    mlmodel.predict([{"image": frame} for frame in frames])
    elapsed = time.perf_counter() - started
    return n / elapsed
def _parse_args() -> argparse.Namespace:
    """Build the command-line parser and return the parsed arguments."""
    p = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        # __doc__ is None under `python -OO`; fall back to an empty description.
        description=(__doc__ or "").split("\n\n")[0],
    )
    p.add_argument("model", help="open_clip model name (e.g. ViT-B-16-SigLIP2, MobileCLIP2-B)")
    p.add_argument("--pretrained", default=None,
                   help="open_clip pretrained tag (auto-detected from model name if omitted)")
    p.add_argument("-o", "--output-dir", type=Path,
                   default=Path.home() / ".cache" / "mobileclip-coreml",
                   help="Where to save the .mlpackage (default: ~/.cache/mobileclip-coreml)")
    p.add_argument("--palettize", type=int, choices=[2, 4, 6, 8], default=None,
                   help="After fp16 conversion, also produce a palettized version "
                        "with this bit-depth. 8-bit ≈ 2x smaller, near-zero quality "
                        "loss (recommended). 6-bit ≈ 2.7x but degrades ViT models. "
                        "4/2-bit only for non-critical layers.")
    p.add_argument("--no-verify", action="store_true",
                   help="Skip cosine-similarity verification vs PyTorch (saves ~30s).")
    p.add_argument("--no-benchmark", action="store_true",
                   help="Skip throughput benchmark (saves ~5s).")
    return p.parse_args()


def _report_verification(args, fp16_path: Path, pal_path: Path | None) -> None:
    """Print PyTorch-vs-Core ML cosine similarity for each produced artifact."""
    print("\n[verify] cosine similarity vs PyTorch:", flush=True)
    # Reload PyTorch model once for verification.
    pt_model, _, preprocess = open_clip.create_model_and_transforms(
        args.model, pretrained=args.pretrained)
    pt_model.eval()
    cos_fp16 = verify(fp16_path, args.model, args.pretrained, pt_model, preprocess)
    print(f" fp16: {cos_fp16:.4f}", flush=True)
    if pal_path is not None:
        cos_pal = verify(pal_path, args.model, args.pretrained, pt_model, preprocess)
        print(f" {args.palettize}-bit palettized: {cos_pal:.4f} (compounded vs PyTorch)", flush=True)


def _report_benchmark(args, fp16_path: Path, pal_path: Path | None) -> None:
    """Print images/sec for each produced artifact."""
    n_images = 200  # keep the label and the benchmark call in sync
    print(f"\n[benchmark] throughput on ANE ({n_images} in-memory images):", flush=True)
    fps_fp16 = benchmark(fp16_path, n_images)
    print(f" fp16: {fps_fp16:6.1f} img/s", flush=True)
    if pal_path is not None:
        fps_pal = benchmark(pal_path, n_images)
        print(f" {args.palettize}-bit palettized: {fps_pal:6.1f} img/s", flush=True)


def _print_summary(args, fp16_path: Path, fp16_size: int,
                   pal_path: Path | None, pal_size: int) -> None:
    """Print the produced artifacts and what to do with them next."""
    print(f"\n[done] artifacts in {args.output_dir}", flush=True)
    print(f" fp16: {fp16_path.name} ({fp16_size/1e6:.0f} MB)", flush=True)
    if pal_path is None:
        # convert() already saved the fp16 model as '<model>_image.mlpackage',
        # the name embed_media_mobileclip.py uses, so there is nothing to rename.
        print(f"\nNext: '{fp16_path.name}' in the output dir is already named for "
              f"embed_media_mobileclip.py.", flush=True)
        return
    print(f" {args.palettize}-bit palett.: {pal_path.name} ({pal_size/1e6:.0f} MB, "
          f"{fp16_size/pal_size:.1f}x smaller)", flush=True)
    print(f"\nNext: to use the palettized model, move the fp16 package aside and rename "
          f"'{pal_path.name}' to '{args.model}_image.mlpackage' inside the output dir to "
          f"make embed_media_mobileclip.py use it.", flush=True)


def main():
    """CLI entry point: convert, optionally palettize, verify, benchmark, summarize."""
    args = _parse_args()
    if args.pretrained is None:
        args.pretrained = default_pretrained_for(args.model)
        print(f"[setup] auto-detected pretrained tag: {args.pretrained}", file=sys.stderr)

    fp16_path, fp16_size = convert(args.model, args.pretrained, args.output_dir)
    pal_path, pal_size = (None, 0)
    if args.palettize:
        pal_path, pal_size = palettize(fp16_path, args.palettize)

    if not args.no_verify:
        _report_verification(args, fp16_path, pal_path)
    if not args.no_benchmark:
        _report_benchmark(args, fp16_path, pal_path)
    _print_summary(args, fp16_path, fp16_size, pal_path, pal_size)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|