---
license: apache-2.0
library_name: litert
pipeline_tag: image-feature-extraction
base_model: timm/vit_pe_core_base_patch16_224.fb
tags:
- litert
- tflite
- on-device
- android
- gpu
- clip
- perception-encoder
- image-encoder
- vit
- rope
---
# Perception Encoder (PE-Core-B16-224) — LiteRT (TFLite) GPU
On-device [LiteRT](https://ai.google.dev/edge/litert) (`.tflite`) conversion of
**Perception Encoder Core** (PE-Core, Meta 2025), a state-of-the-art CLIP-style image tower,
from [`timm/vit_pe_core_base_patch16_224.fb`](https://huggingface.co/timm/vit_pe_core_base_patch16_224.fb)
(ViT-B/16, 94M params; original [facebook/PE-Core-B16-224](https://huggingface.co/facebook/PE-Core-B16-224)).
A single forward pass turns one RGB image into a **1024-d L2-normalized image
embedding** for zero-shot classification, retrieval, and similarity — running
**fully on the LiteRT `CompiledModel` GPU accelerator** (ML Drift): **all 1028
ops are GPU-native (`Replacing 1028 out of 1028 node(s) ... LITERT_CL`), no CPU
fallback, no Flex ops.**
## Files
| File | Size | Description |
|------|------|-------------|
| `pe_core_base_224_fp16.tflite` | 187 MB | FP16 single-graph model, GPU full-residency |
| `convert_pecore.py` | — | Reproducible conversion script (timm → tflite) |
## I/O
- **Input**: `[1, 3, 224, 224]` float32, **NCHW**, RGB normalized to **`[-1, 1]`**
i.e. `(pixel/255 - 0.5) / 0.5` (timm mean/std = `(0.5, 0.5, 0.5)`). Normalization
is applied by the caller (not baked into the graph).
- **Output**: `[1, 1024]` float32, **L2-normalized** image embedding.
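A minimal preprocessing sketch in plain Kotlin, assuming the image has already been resized/cropped to 224×224 and its pixels read as packed ARGB ints (for example via Android's `Bitmap.getPixels`). `argbToNchw` and `INPUT_SIZE` are hypothetical names: the function writes the R, G and B planes one after another (NCHW) and applies exactly the `(pixel/255 - 0.5) / 0.5` normalization described above.

```kotlin
// Hypothetical constant: the model's square input resolution.
val INPUT_SIZE = 224

/**
 * Converts packed ARGB pixels (row-major, INPUT_SIZE x INPUT_SIZE) into a
 * flat [1, 3, INPUT_SIZE, INPUT_SIZE] NCHW array with RGB scaled to [-1, 1].
 * Resizing/cropping to INPUT_SIZE x INPUT_SIZE is assumed to be done already.
 */
fun argbToNchw(pixels: IntArray): FloatArray {
    val plane = INPUT_SIZE * INPUT_SIZE
    require(pixels.size == plane) { "expected $plane pixels, got ${pixels.size}" }
    val out = FloatArray(3 * plane)
    for (i in 0 until plane) {
        val p = pixels[i]
        val r = (p shr 16) and 0xFF
        val g = (p shr 8) and 0xFF
        val b = p and 0xFF
        out[i] = (r / 255f - 0.5f) / 0.5f             // R plane
        out[plane + i] = (g / 255f - 0.5f) / 0.5f     // G plane
        out[2 * plane + i] = (b / 255f - 0.5f) / 0.5f // B plane
    }
    return out
}
```

The result can be passed straight to `inputs[0].writeFloat(...)` in the usage snippet.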
## Usage (Android, LiteRT CompiledModel)
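```kotlin
import com.google.ai.edge.litert.Accelerator
import com.google.ai.edge.litert.CompiledModel

// Load the model from app assets and compile it for the GPU accelerator (ML Drift).
val model = CompiledModel.create(
    context.assets, "pe_core_base_224_fp16.tflite",
    CompiledModel.Options(Accelerator.GPU), null
)
val inputs = model.createInputBuffers()
val outputs = model.createOutputBuffers()
// nchwFloatArray: 1*3*224*224 floats, NCHW, RGB scaled to [-1, 1] by the caller
inputs[0].writeFloat(nchwFloatArray)
model.run(inputs, outputs)
val embedding = outputs[0].readFloat() // 1024 floats, already L2-normalized
```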
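For zero-shot classification, precompute text-label embeddings offline with the PE-Core
text tower (not included here), L2-normalize them, and take the dot product with the
image embedding on device; since both vectors are unit-length, this is their cosine similarity.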
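A minimal on-device scoring sketch, assuming the label embeddings were computed offline with the PE-Core text tower, L2-normalized, and shipped with the app as one `FloatArray` per label. `dot`, `zeroShotProbs` and the `logitScale` parameter are hypothetical; the scale would have to be taken from the original checkpoint, since this `.tflite` only outputs the image embedding.

```kotlin
/** Dot product; equals cosine similarity when both vectors are L2-normalized. */
fun dot(a: FloatArray, b: FloatArray): Float {
    require(a.size == b.size) { "dimension mismatch: ${a.size} vs ${b.size}" }
    var s = 0f
    for (i in a.indices) s += a[i] * b[i]
    return s
}

/**
 * Softmax over labels of logitScale * <image, text_i>.
 * Assumptions: textEmbeddings are L2-normalized text-tower outputs (one per
 * label) and logitScale comes from the original checkpoint.
 */
fun zeroShotProbs(image: FloatArray, textEmbeddings: List<FloatArray>, logitScale: Float): FloatArray {
    val logits = FloatArray(textEmbeddings.size) { logitScale * dot(image, textEmbeddings[it]) }
    val max = logits.maxOrNull() ?: return FloatArray(0)
    // Subtract the max before exponentiating for numerical stability.
    val exps = FloatArray(logits.size) { kotlin.math.exp(logits[it] - max) }
    val total = exps.sum()
    return FloatArray(exps.size) { exps[it] / total }
}
```

For plain ranking or retrieval the scale and softmax do not change the order, so sorting labels by `dot(embedding, label)` is enough.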
## Performance
- **~66 ms / image steady-state** on a Pixel 8a (Mali-G615) GPU (best 12.5 ms),
full GPU residency, FP16.
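One way to report steady-state and best-case latency, sketched in plain Kotlin; it is not necessarily how the figures above were measured. `benchmark` is a hypothetical helper: it discards warm-up runs (GPU kernel compilation and buffer setup), then returns the median and minimum of the timed runs. It assumes the callback, e.g. `model.run(inputs, outputs)`, blocks until the GPU work finishes; otherwise it would only time submission.

```kotlin
/**
 * Runs `warmup` untimed iterations, then `iters` timed ones.
 * Returns (median ms, min ms).
 */
fun benchmark(iters: Int = 50, warmup: Int = 10, runOnce: () -> Unit): Pair<Double, Double> {
    require(iters > 0 && warmup >= 0)
    repeat(warmup) { runOnce() }
    val ms = DoubleArray(iters) {
        val t0 = System.nanoTime()
        runOnce()
        (System.nanoTime() - t0) / 1e6 // elapsed milliseconds
    }
    ms.sort()
    val median = if (iters % 2 == 1) ms[iters / 2] else (ms[iters / 2 - 1] + ms[iters / 2]) / 2
    return median to ms[0]
}
```

Usage: `val r = benchmark { model.run(inputs, outputs) }`, where `r.first` is the steady-state median and `r.second` the best run.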
## Conversion notes
Converted with [litert-torch / ai-edge-torch](https://github.com/google-ai-edge/ai-edge-torch).
Making a RoPE ViT image tower **fully GPU-resident *and* numerically correct** on
the ML Drift GPU delegate required four weights-exact, mathematically equivalent
model-side rewrites (output corr ≈ 1.0). The first three fix residency; the last
fixes on-device numerical correctness:
1. **Fused-qkv → 4D manual attention** — the fused `qkv` reshape emits a 5D
head-split the GPU delegate rejects; decompose into separate q/k/v projections.
Self-attention uses `scaled_dot_product_attention`, whose lowering keeps the
batch-matmul 3D with a materialized transpose (both required for residency).
2. **Interleaved 2D-RoPE → rotate-half** — PE-Core's interleaved rotary uses a
strided `x[..., ::2]` that lowers to `GATHER_ND` (GPU-banned). Bake an
even→odd channel permutation into the q/k weights (preserves q·k exactly) and
apply the rotate-half form with constant cos/sin → clean
`MUL`/`ADD`/`SLICE`/`CONCAT`.
3. **Attention-pool single-query attention → broadcast-multiply + reduce-sum**:
the pooling query is a constant latent, so a batch-matmul there is
`const @ non-const`, which is rejected at compile time (and the reordered
`const-RHS` form is mis-computed on device). Expressing it as `(q·k).sum` +
softmax + `(attn·v).sum` is exact and GPU-correct.
4. **Overflow-safe LayerNorm** — the delegate computes the LayerNorm variance
reduction in **fp16 even for an fp32 graph**; deep-ViT "massive activations"
(|x|~50+) make `sum((x-mean)²)` exceed fp16 max (65504), so the normalization
is wrong and the error compounds with depth (output correlation collapses to
~0.28 over 12 blocks while *still reporting full GPU residency*). Scaling `x - mean`
by 1/32 before squaring, then multiplying the variance by 32² = 1024, keeps the
running sum in range and is mathematically identical to `nn.LayerNorm`.
Verified **on a Pixel 8a GPU**: zero banned ops, zero >4D tensors, full residency,
and TFLite(GPU)-vs-PyTorch output correlation = 1.0 (the on-device GPU result —
not just the host CPU result — matches the reference).
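Rewrite 1 is weights-exact because a fused `qkv` projection is just the q, k and v projections stacked along the output dimension. A plain-Kotlin illustration (hypothetical helpers `splitQkv` and `linear`, not code from `convert_pecore.py`), assuming the usual q, k, v row order of a PyTorch `Linear` weight stored row-major as `[out, in]`:

```kotlin
/** y = W x + b for a row-major [out, in] weight. */
fun linear(w: FloatArray, b: FloatArray, x: FloatArray): FloatArray {
    val inDim = x.size
    return FloatArray(b.size) { o ->
        var s = b[o]
        for (i in 0 until inDim) s += w[o * inDim + i] * x[i]
        s
    }
}

/**
 * Splits a fused qkv weight [3*dim, dim] and bias [3*dim] into
 * (weight, bias) pairs for q, k and v (row order assumed q, k, v).
 */
fun splitQkv(weight: FloatArray, bias: FloatArray, dim: Int): List<Pair<FloatArray, FloatArray>> {
    require(weight.size == 3 * dim * dim && bias.size == 3 * dim)
    return (0 until 3).map { part ->
        val w = weight.copyOfRange(part * dim * dim, (part + 1) * dim * dim)
        val b = bias.copyOfRange(part * dim, (part + 1) * dim)
        w to b
    }
}
```

Running the three split projections and concatenating their outputs reproduces the fused output bit-for-bit, which is why this change needs no retraining.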
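Rewrite 2 relies on an identity: applying the even→odd permutation `P` (`[x0, x2, …, x1, x3, …]`) after interleaved RoPE gives the same vector as applying rotate-half RoPE to `P(x)`, with each pair's cos/sin repeated across both halves. Because `P` is a permutation applied to both q and k, q·k is unchanged. The sketch below checks this per head; the helper names are hypothetical, and the angles are arbitrary per-pair values, so the identity covers 2D RoPE (where each pair's angle comes from the row or column position) as well. Baking `P` into the q/k weights means permuting their output rows within each head, which yields `P(q)` directly.

```kotlin
/** Interleaved RoPE: rotates each pair (x[2i], x[2i+1]) by angle theta[i]. */
fun ropeInterleaved(x: FloatArray, theta: FloatArray): FloatArray {
    val y = FloatArray(x.size)
    for (i in theta.indices) {
        val c = kotlin.math.cos(theta[i])
        val s = kotlin.math.sin(theta[i])
        y[2 * i] = x[2 * i] * c - x[2 * i + 1] * s
        y[2 * i + 1] = x[2 * i] * s + x[2 * i + 1] * c
    }
    return y
}

/** Even->odd channel permutation: [x0, x2, x4, ..., x1, x3, x5, ...]. */
fun evenOddPermute(x: FloatArray): FloatArray {
    val h = x.size / 2
    return FloatArray(x.size) { j -> if (j < h) x[2 * j] else x[2 * (j - h) + 1] }
}

/** Rotate-half RoPE: x * cos + rotateHalf(x) * sin, rotateHalf([a, b]) = [-b, a]. */
fun ropeRotateHalf(x: FloatArray, theta: FloatArray): FloatArray {
    val h = x.size / 2
    return FloatArray(x.size) { j ->
        val i = if (j < h) j else j - h              // same angle for both halves
        val rot = if (j < h) -x[j + h] else x[j - h] // rotateHalf(x)[j]
        x[j] * kotlin.math.cos(theta[i]) + rot * kotlin.math.sin(theta[i])
    }
}
```

Only elementwise `MUL`/`ADD` plus contiguous slices and a concat are needed on the rotate-half side; no strided gather appears.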
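For rewrite 3, the single-query pooling can be written per attention head as below, where `d` is the per-head dimension, `N` the number of key/value tokens, and the 1/√d factor is the standard scaled-dot-product scaling (assumed here). The first sum is `(q·k).sum`: broadcast the constant query over the N key rows, multiply, and reduce over channels. The last is `(attn·v).sum`: broadcast the weights over channels and reduce over tokens. These give the same numbers as `q @ Kᵀ` and `attn @ V`, without a batch-matmul that has a constant operand.

$$
s_j = \frac{1}{\sqrt{d}} \sum_{c=1}^{d} q_c \, K_{jc}, \qquad
a_j = \frac{\exp(s_j)}{\sum_{m=1}^{N} \exp(s_m)}, \qquad
o_c = \sum_{j=1}^{N} a_j \, V_{jc}
$$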
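Rewrite 4 in plain Kotlin, as an arithmetic illustration only: Kotlin `Float` is fp32, so nothing actually overflows here. The test below only shows that the unscaled sum of squares would exceed fp16's 65504 for |x - mean| ~ 50, while the 1/32-scaled sum stays far below it. `layerNormScaled` is a hypothetical helper, not the conversion script's module; `eps = 1e-5` is PyTorch's `nn.LayerNorm` default and should be set to the original layer's value.

```kotlin
/**
 * LayerNorm whose variance reduction runs on (x - mean) / scale and is then
 * rescaled by scale^2 (32^2 = 1024). Mathematically identical to nn.LayerNorm.
 */
fun layerNormScaled(
    x: FloatArray, gamma: FloatArray, beta: FloatArray,
    eps: Float = 1e-5f, scale: Float = 32f
): FloatArray {
    val n = x.size
    val mean = x.sum() / n
    var sumSq = 0f
    for (v in x) {
        val d = (v - mean) / scale // shrink before squaring: squares shrink by scale^2
        sumSq += d * d
    }
    val variance = sumSq / n * scale * scale // undo the scaling
    val inv = 1f / kotlin.math.sqrt(variance + eps)
    return FloatArray(n) { i -> (x[i] - mean) * inv * gamma[i] + beta[i] }
}
```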
## Training data & PII
PE-Core was pretrained by Meta on a large-scale **web-crawled image–text dataset**
(billions of image–caption pairs, CLIP-style contrastive objective). No new
training was performed for this conversion — it is a weights-exact format change
of the public `timm`/`facebook` checkpoint. Because the source data is
web-scraped, it may incidentally contain people, faces, text, and other PII;
no PII was deliberately collected, and this conversion adds none. Users deploying
the encoder should apply their own content/PII filtering as appropriate. See the
original [PE model card](https://huggingface.co/facebook/PE-Core-B16-224) and
[paper](https://arxiv.org/abs/2504.13181) for full dataset details.
## License & attribution
- **Apache-2.0** (original [PE-Core](https://huggingface.co/facebook/PE-Core-B16-224) /
[timm checkpoint](https://huggingface.co/timm/vit_pe_core_base_patch16_224.fb)).
- This is a format conversion; all credit to the original authors (Meta / FAIR).