---
license: apache-2.0
library_name: litert
pipeline_tag: image-feature-extraction
base_model: timm/vit_base_patch16_siglip_224.v2_webli
tags:
- litert
- tflite
- on-device
- android
- gpu
- clip
- siglip
- siglip2
- image-encoder
- vit
---
# SigLIP 2 (ViT-B/16, 224) — LiteRT (TFLite) GPU
On-device [LiteRT](https://ai.google.dev/edge/litert) (`.tflite`) conversion of the
**SigLIP 2** (Google, 2025) image tower, a CLIP-style image encoder, converted
from [`timm/vit_base_patch16_siglip_224.v2_webli`](https://huggingface.co/timm/vit_base_patch16_siglip_224.v2_webli)
(ViT-B/16, 93M params; the image tower of `ViT-B-16-SigLIP2` / `google/siglip2`).
A single forward pass turns one RGB image into a **768-d L2-normalized image
embedding** for zero-shot classification, retrieval, and similarity search. The model runs
**fully on the LiteRT `CompiledModel` GPU accelerator** (ML Drift): **all ops are
GPU-native (`Replacing 809 out of 809 node(s) … LITERT_CL`), with no CPU fallback and no
Flex ops**, and the GPU output matches PyTorch (corr ≈ 1.0).
## Files
| File | Size | Description |
|------|------|-------------|
| `siglip2_base_224_fp16.tflite` | 185 MB | FP16 single-graph model, GPU full-residency |
| `convert_siglip2.py` | — | Reproducible conversion script (timm → tflite) |
## I/O
- **Input**: `[1, 3, 224, 224]` float32, **NCHW**, RGB normalized to **`[-1, 1]`**
(`(pixel/255 - 0.5) / 0.5`). Normalization is applied by the caller.
- **Output**: `[1, 768]` float32, **L2-normalized** image embedding.
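
The input layout above can be sketched in plain Python as a reference for checking an on-device implementation. This is a minimal sketch, assuming the image has already been resized to 224×224 and is available as a flat list of 8-bit RGB values in row-major HWC order; the helper name `to_nchw` is hypothetical and not part of this repository.

```python
def to_nchw(pixels_hwc, height=224, width=224):
    """Convert flat 8-bit RGB pixels (row-major HWC) to flat NCHW floats in [-1, 1]."""
    plane = height * width
    if len(pixels_hwc) != 3 * plane:
        raise ValueError("expected height * width * 3 values")
    out = [0.0] * (3 * plane)
    for i in range(plane):      # i indexes pixels in row-major order
        for c in range(3):      # c = 0 (R), 1 (G), 2 (B)
            # Same normalization as stated above: (pixel/255 - 0.5) / 0.5
            out[c * plane + i] = (pixels_hwc[3 * i + c] / 255 - 0.5) / 0.5
    return out


# Tiny 1x2 example: one pure-red pixel followed by one pure-green pixel.
tiny = to_nchw([255, 0, 0, 0, 255, 0], height=1, width=2)
print(tiny)  # R plane, then G plane, then B plane
```

The same index arithmetic (`c * plane + i`) carries over directly to a Kotlin `FloatArray` filled from bitmap pixels.
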
For zero-shot classification, precompute text-label embeddings with the SigLIP 2
text tower (`open_clip` `ViT-B-16-SigLIP2`, prompt `"This is a photo of {label}."`)
and take the dot product on device.
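
The scoring step can be illustrated with a short sketch. It assumes the label embeddings were exported ahead of time and are L2-normalized like the image embedding; the names `l2_normalize` and `rank_labels` and the 3-d toy vectors are placeholders (real embeddings are 768-d). Only the dot-product ranking is shown; turning scores into SigLIP-style probabilities would additionally need the checkpoint's learned logit scale and bias, which this card does not cover.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / norm for v in vec]


def rank_labels(image_emb, label_embs):
    """Return (label, score) pairs sorted by dot product, highest first."""
    scores = {
        label: sum(a * b for a, b in zip(image_emb, emb))
        for label, emb in label_embs.items()
    }
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)


# Toy placeholders standing in for real 768-d embeddings.
image = l2_normalize([0.9, 0.1, 0.0])
labels = {
    "cat": l2_normalize([1.0, 0.0, 0.0]),
    "dog": l2_normalize([0.0, 1.0, 0.0]),
}
ranked = rank_labels(image, labels)
print(ranked[0][0])  # label with the highest similarity
```
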
## Usage (Android, LiteRT CompiledModel)
```kotlin
val model = CompiledModel.create(
    context.assets, "siglip2_base_224_fp16.tflite",
    CompiledModel.Options(Accelerator.GPU), null
)
val inputs = model.createInputBuffers()
val outputs = model.createOutputBuffers()
inputs[0].writeFloat(nchwFloatArray) // [1,3,224,224], RGB scaled to [-1,1]
model.run(inputs, outputs)
val embedding = outputs[0].readFloat() // [768], already L2-normalized
```
## Performance
- **~60 ms / image steady-state** on a Pixel 8a (Mali-G715) GPU (best ~9 ms),
full GPU residency, FP16.
## Conversion notes
Converted with [litert-torch / ai-edge-torch](https://github.com/google-ai-edge/ai-edge-torch).
Making the ViT image tower run fully on the GPU delegate **and produce correct
output on device** required three model-side rewrites. Each keeps the weights exact
and the output numerically equivalent (corr ≈ 1.0); full GPU residency alone does
**not** imply a correct result:
1. **Fused-qkv → 4-D manual attention** — the fused `qkv` reshape emits a 5-D
head-split the delegate rejects; decompose into separate q/k/v. Self-attention
uses `scaled_dot_product_attention`, whose lowering keeps the batch-matmul 3-D
with a materialized transpose (both required for residency).
2. **Attention-pool single-query attention → broadcast-multiply + reduce-sum** —
the pooling query is a constant latent, so a batch-matmul there is
`const @ non-const` (rejected / mis-computed); express as `(q·k).sum` + softmax
+ `(attn·v).sum`.
3. **Overflow-safe LayerNorm** — the delegate computes the LayerNorm variance
reduction in fp16 even for an fp32 graph; deep-ViT massive activations make
`sum((x-mean)²)` exceed the fp16 max (65504), corrupting normalization (output
correlation collapses with depth while still reporting full residency). Scaling
by 1/32 before squaring keeps the sum in range.
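
The overflow in item 3 can be reproduced off-device with a small simulation. This is a sketch under stated assumptions, not the code in `convert_siglip2.py`: the `fp16` helper emulates half-precision rounding with Python's `struct` module, `layernorm_fp16` omits the affine weight and bias, `eps=1e-6` and the toy input with one large activation are illustrative, and the delegate's real reduction order is not modeled.

```python
import math
import struct


def fp16(x):
    """Round x to the nearest IEEE half-precision value; overflow becomes +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except (OverflowError, struct.error):
        return math.copysign(math.inf, x)


def layernorm_fp16(xs, scale=1.0, eps=1e-6):
    """LayerNorm (no affine terms) with intermediates rounded to fp16.

    scale=1.0 is the naive form; scale=1/32 shrinks the centered values
    before squaring so the sum of squares stays below the fp16 max (65504).
    """
    n = len(xs)
    mean = fp16(sum(xs) / n)
    ys = [fp16(fp16(x - mean) * scale) for x in xs]  # centered, then scaled
    sum_sq = 0.0
    for y in ys:
        sum_sq = fp16(sum_sq + fp16(y * y))          # fp16 sum of squares
    var_scaled = fp16(sum_sq / n)                    # equals var * scale**2
    # (x - mean) / sqrt(var + eps) == y / sqrt(var_scaled + eps * scale**2)
    denom = math.sqrt(var_scaled + eps * scale * scale)
    return sum_sq, [y / denom for y in ys]


def layernorm_ref(xs, eps=1e-6):
    """Full-precision reference."""
    n = len(xs)
    mean = sum(xs) / n
    var = sum((x - mean) ** 2 for x in xs) / n
    return [(x - mean) / math.sqrt(var + eps) for x in xs]


# Toy 768-wide token with one "massive activation" (hypothetical values).
xs = [((i % 7) - 3) * 0.5 for i in range(768)]
xs[5] = 300.0

naive_sum, naive_out = layernorm_fp16(xs, scale=1.0)
safe_sum, safe_out = layernorm_fp16(xs, scale=1 / 32)
ref = layernorm_ref(xs)
print("naive sum of squares:", naive_sum)
print("scaled sum of squares:", safe_sum)
print("max |scaled - ref|:", max(abs(a - b) for a, b in zip(safe_out, ref)))
```

In this simulation the naive sum of squares overflows and the normalized output degenerates, while the scaled variant stays finite and tracks the full-precision reference to within fp16 rounding error.
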
Verified **on a Pixel 8a GPU**: zero banned ops, zero >4D tensors, full residency,
and GPU-vs-PyTorch output correlation ≈ 1.0 (the on-device GPU result, not just the
host CPU result).
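
A correlation check of this kind can be sketched as follows. The card does not say which correlation measure was used, so Pearson correlation is an assumption here, and the two short vectors are placeholders for a 768-d embedding read back from the device and the matching PyTorch output.

```python
import math


def pearson(a, b):
    """Pearson correlation coefficient between two equal-length sequences."""
    if len(a) != len(b) or len(a) < 2:
        raise ValueError("need two sequences of equal length >= 2")
    mean_a = sum(a) / len(a)
    mean_b = sum(b) / len(b)
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    var_a = sum((x - mean_a) ** 2 for x in a)
    var_b = sum((y - mean_b) ** 2 for y in b)
    return cov / math.sqrt(var_a * var_b)


# Placeholder values; in practice these are the two 768-d embeddings.
device_out = [0.12, -0.05, 0.33, 0.08]
reference_out = [0.121, -0.049, 0.329, 0.081]
print(pearson(device_out, reference_out))
```

A value close to 1.0 on the actual device output is the signal to look for; a lower value despite full residency would point to a numerical issue such as the LayerNorm overflow described above.
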
## Training data & PII
SigLIP 2 was pretrained by Google on the **WebLI** dataset (billions of
web-crawled image–text pairs, multilingual, sigmoid contrastive objective). No new
training was performed for this conversion — it is a weights-exact format change of
the public `timm` checkpoint. Because the source data is web-scraped, it may
incidentally contain people, faces, text, and other PII; no PII was deliberately
collected, and this conversion adds none. Apply your own content/PII filtering as
appropriate. See the original [SigLIP 2 model card](https://huggingface.co/google/siglip2-base-patch16-224)
and [paper](https://arxiv.org/abs/2502.14786) for full dataset details.
## License & attribution
- **Apache-2.0** (original SigLIP 2 / [timm checkpoint](https://huggingface.co/timm/vit_base_patch16_siglip_224.v2_webli)).
- This is a format conversion; all credit to the original authors (Google).