---
license: apache-2.0
library_name: litert
pipeline_tag: image-feature-extraction
tags:
- litert
- tflite
- sam2
- segment-anything
- image-encoder
- on-device
- gpu
base_model: facebook/sam2.1-hiera-tiny
---
# SAM 2.1 (Hiera-Tiny) image encoder for LiteRT GPU
On-device **LiteRT / TFLite** conversion of the **image encoder** of
[**SAM 2.1 Hiera-Tiny**](https://huggingface.co/facebook/sam2.1-hiera-tiny) (Meta, Apache-2.0),
running **fully on the mobile GPU** via the LiteRT `CompiledModel` API (ML Drift / `LITERT_CL` delegate).
The whole graph is GPU-resident, with no CPU/XNNPACK fallback ops.
This is the heavy backbone of the Segment Anything 2 image path: it turns an RGB image into the
multi-scale feature pyramid that a small prompt encoder and mask decoder then query for each click or box prompt.
| Property | Value |
|---|---|
| Task | Image encoder for promptable segmentation (SAM 2 image path) |
| Backbone | Hiera-Tiny (hierarchical ViT, window + global attention) + FPN neck |
| Input | `[1, 3, 1024, 1024]` NCHW float32, ImageNet-normalized |
| Outputs | 3 FPN feature maps: `[1,256,256,256]`, `[1,256,128,128]`, `[1,256,64,64]` |
| Precision / size | FP16, **80 MB** |
| Device | Pixel 8a, LiteRT GPU (`Accelerator.GPU`), **~7 ms / image** |
| Residency | **`Replacing 862 out of 862 node(s) with delegate (LITERT_CL)`** (full, single partition) |
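The shapes in the table translate directly into host-side buffer sizes, which matters when planning memory on a phone. A minimal Kotlin sketch of that arithmetic; the helper names are hypothetical, and float32 host arrays are an assumption (the model itself stores FP16 weights):

```kotlin
// Number of elements in a dense tensor of the given shape.
fun elementCount(shape: IntArray): Long = shape.fold(1L) { acc, d -> acc * d }

// Bytes needed to hold a tensor as a float32 array on the host (4 bytes per element).
fun float32Bytes(shape: IntArray): Long = elementCount(shape) * 4L

// Shapes taken from the table above.
val INPUT_SHAPE = intArrayOf(1, 3, 1024, 1024)
val OUTPUT_SHAPES = listOf(
    intArrayOf(1, 256, 256, 256), // FPN-0
    intArrayOf(1, 256, 128, 128), // FPN-1
    intArrayOf(1, 256, 64, 64),   // FPN-2
)
```

With these shapes the input is 3,145,728 floats (12 MiB as float32) and the three outputs together are 22,020,096 floats (84 MiB as float32), most of it in the 256×256 map.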
## Preprocessing (must match)
```
resize to 1024x1024 (bilinear) -> x/255 -> (x - mean) / std
mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225] # ImageNet, RGB, NCHW
```
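On Android, the normalization can be done in plain Kotlin once the image has been resized to 1024×1024 (for example with `Bitmap.createScaledBitmap(src, 1024, 1024, true)`, where `filter = true` requests bilinear filtering) and read into packed ARGB pixels with `Bitmap.getPixels`. The sketch below assumes that packed `0xAARRGGBB` layout; `toNormalizedNchw` is a hypothetical helper, not part of LiteRT.

```kotlin
// ImageNet statistics in RGB order, from the preprocessing spec above.
val IMAGENET_MEAN = floatArrayOf(0.485f, 0.456f, 0.406f)
val IMAGENET_STD = floatArrayOf(0.229f, 0.224f, 0.225f)

/**
 * Converts packed ARGB pixels (0xAARRGGBB, row-major, already resized) into a
 * [1, 3, height, width] NCHW float array: x / 255, then (x - mean) / std per channel.
 * Alpha is ignored.
 */
fun toNormalizedNchw(pixels: IntArray, width: Int, height: Int): FloatArray {
    require(pixels.size == width * height) { "expected ${width * height} pixels, got ${pixels.size}" }
    val plane = width * height
    val out = FloatArray(3 * plane)
    for (i in 0 until plane) {
        val p = pixels[i]
        val r = ((p shr 16) and 0xFF) / 255f
        val g = ((p shr 8) and 0xFF) / 255f
        val b = (p and 0xFF) / 255f
        // Each channel occupies its own contiguous plane in NCHW layout.
        out[i] = (r - IMAGENET_MEAN[0]) / IMAGENET_STD[0]
        out[plane + i] = (g - IMAGENET_MEAN[1]) / IMAGENET_STD[1]
        out[2 * plane + i] = (b - IMAGENET_MEAN[2]) / IMAGENET_STD[2]
    }
    return out
}
```

Writing planes rather than interleaved pixels is what makes the result NCHW; an HWC loop would silently feed the encoder a transposed image.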
## GPU-clean conversion (what was re-authored)
Converted with `litert-torch`. SAM 2's Hiera encoder is not GPU-clean out of the box; these exact,
weights-faithful rewrites were applied (model-side only, **no converter patch**):
1. **`window_partition` / `window_unpartition`**: the 6-D `view`+`permute` window reshape, which the GPU
delegate rejects because it is >4-D, is re-expressed as a sequence of **≤4-D** `reshape`/`transpose` ops
(numerically exact, verified against the original).
2. **`Sam2MultiScaleAttention`**: the 5-D fused-QKV reshape is decomposed into separate q/k/v, and
attention runs as a **3-D batched SDPA** (`[B*heads, N, d]`). With a 4-D SDPA the delegate emits a
`[C,C]->[nW,ws,C,C]` `BROADCAST_TO` on every windowed block; the 3-D form removes all 9 of them.
3. **Windowed positional embedding**: the bicubic interpolation + tiling of the constant `pos_embed` is
**baked to a buffer**, so the graph only adds it; this removes a runtime interpolation of a constant.
4. **Neck**: the sine FPN position encodings (constant, depending only on shape) are dropped from the
graph (compute them host-side); this removes the remaining `BROADCAST_TO` ops.
5. **Overflow-safe LayerNorm** (scale-before-square) as an fp16 safety margin for the deep stages.
Net: `banned ops = NONE`, `>4-D tensors = 0`, full GPU residency.
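To make rewrite 1 concrete: for a row-major `[B, H, W, C]` tensor, the 6-D `view(B, H/ws, ws, W/ws, ws, C)` + `permute(0, 1, 3, 2, 4, 5)` only swaps two axes once neighbouring axes are merged, so it can be written as a 4-D view `[B*H/ws, ws, W/ws, ws*C]`, one axis swap, and a final view to `[-1, ws, ws, C]`. The Kotlin sketch below checks that equivalence, with flat-array index arithmetic standing in for PyTorch `reshape`/`transpose`. It shows one possible ≤4-D decomposition, not necessarily the exact op sequence used in this conversion; the function names are hypothetical.

```kotlin
// Reference: the 6-D view(B, H/ws, ws, W/ws, ws, C) + permute(0, 1, 3, 2, 4, 5),
// written as direct index arithmetic. x is row-major [B, H, W, C]; the result is
// row-major [B * (H/ws) * (W/ws), ws, ws, C].
fun windowPartitionReference(x: FloatArray, b: Int, h: Int, w: Int, c: Int, ws: Int): FloatArray {
    val nH = h / ws
    val nW = w / ws
    val out = FloatArray(x.size)
    var o = 0
    for (bi in 0 until b) for (wy in 0 until nH) for (wx in 0 until nW) {
        for (i in 0 until ws) for (j in 0 until ws) for (ci in 0 until c) {
            val y = wy * ws + i
            val xx = wx * ws + j
            out[o++] = x[((bi * h + y) * w + xx) * c + ci]
        }
    }
    return out
}

// Swaps axes 1 and 2 of a row-major 4-D array: [d0, d1, d2, d3] -> [d0, d2, d1, d3].
fun swapAxes12(x: FloatArray, d0: Int, d1: Int, d2: Int, d3: Int): FloatArray {
    val out = FloatArray(x.size)
    for (a in 0 until d0) for (p in 0 until d1) for (q in 0 until d2) {
        for (r in 0 until d3) {
            out[((a * d2 + q) * d1 + p) * d3 + r] = x[((a * d1 + p) * d2 + q) * d3 + r]
        }
    }
    return out
}

// <=4-D formulation: view [B, H, W, C] as [B*H/ws, ws, W/ws, ws*C] (free for a contiguous
// buffer), swap axes 1 and 2, then view the result as [-1, ws, ws, C] (also free).
// Assumes H and W have already been padded to multiples of ws.
fun windowPartition4d(x: FloatArray, b: Int, h: Int, w: Int, c: Int, ws: Int): FloatArray {
    require(h % ws == 0 && w % ws == 0) { "pad H and W to multiples of ws first" }
    return swapAxes12(x, b * (h / ws), ws, w / ws, ws * c)
}
```

Un-partitioning is the same axis swap applied in the other direction (`[B*H/ws, W/ws, ws, ws*C]` back to `[B*H/ws, ws, W/ws, ws*C]`), so both functions in rewrite 1 fit the same 4-D pattern.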
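Rewrite 2 rests on a layout change. Assuming the usual `[B, N, 3, heads, d]` layout of the 5-D fused-QKV reshape, q, k and v can each be gathered into a `[B*heads, N, d]` array, after which attention is an ordinary 3-D batched SDPA. The Kotlin below is a host-side reference sketch (useful for checking outputs), not the converted graph; query pooling is omitted apart from allowing `nq != nk`, and `splitQkv` / `sdpa3d` are hypothetical names.

```kotlin
import kotlin.math.exp
import kotlin.math.sqrt

// Gathers q, k and v out of a fused QKV buffer laid out as [B, N, 3, heads, d]
// into three separate buffers, each laid out as [B*heads, N, d].
fun splitQkv(qkv: FloatArray, b: Int, n: Int, heads: Int, d: Int): Triple<FloatArray, FloatArray, FloatArray> {
    require(qkv.size == b * n * 3 * heads * d) { "unexpected QKV size ${qkv.size}" }
    val parts = Array(3) { FloatArray(b * heads * n * d) }
    for (bi in 0 until b) for (tok in 0 until n) for (s in 0 until 3) for (h in 0 until heads) {
        val src = (((bi * n + tok) * 3 + s) * heads + h) * d
        val dst = ((bi * heads + h) * n + tok) * d
        qkv.copyInto(parts[s], destinationOffset = dst, startIndex = src, endIndex = src + d)
    }
    return Triple(parts[0], parts[1], parts[2])
}

// Scaled dot-product attention over the 3-D layout: q is [BH, nq, d], k and v are
// [BH, nk, d] (nq may be smaller than nk after query pooling); returns [BH, nq, d].
fun sdpa3d(q: FloatArray, k: FloatArray, v: FloatArray, bh: Int, nq: Int, nk: Int, d: Int): FloatArray {
    val out = FloatArray(bh * nq * d)
    val scale = 1f / sqrt(d.toFloat())
    val scores = FloatArray(nk)
    for (b in 0 until bh) for (i in 0 until nq) {
        val qOff = (b * nq + i) * d
        var maxScore = Float.NEGATIVE_INFINITY
        for (j in 0 until nk) {
            val kOff = (b * nk + j) * d
            var dot = 0f
            for (t in 0 until d) dot += q[qOff + t] * k[kOff + t]
            scores[j] = dot * scale
            if (scores[j] > maxScore) maxScore = scores[j]
        }
        // Numerically stable softmax: subtract the row maximum before exponentiating.
        var sum = 0f
        for (j in 0 until nk) { scores[j] = exp(scores[j] - maxScore); sum += scores[j] }
        for (j in 0 until nk) {
            val weight = scores[j] / sum
            val vOff = (b * nk + j) * d
            for (t in 0 until d) out[qOff + t] += weight * v[vOff + t]
        }
    }
    return out
}
```

Because batch and heads share one leading axis, every tensor in this path stays 3-D, which is the property the rewrite relies on.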
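For rewrite 4, if downstream code needs the neck's position encodings, they depend only on the feature-map size and can be computed once on the host. The sketch below assumes the DETR-style `PositionEmbeddingSine` configuration used by SAM 2's FPN neck (256 channels, temperature 10000, 1-based coordinates normalized and scaled by 2π with eps 1e-6, sin on even and cos on odd channels, y channels before x channels). Treat every one of those settings as an assumption and compare against the upstream Python output before relying on it; `sinePositionEncoding` is a hypothetical name.

```kotlin
import kotlin.math.PI
import kotlin.math.cos
import kotlin.math.pow
import kotlin.math.sin

// Host-side sine position encoding for an h x w feature map, returned as a row-major
// [channels, h, w] array. Assumed configuration (DETR-style): 1-based coordinates
// normalized to roughly (0, 2*pi], per-channel frequency temperature^(2*floor(i/2)/f),
// sin on even / cos on odd channels, y channels first, then x channels.
fun sinePositionEncoding(h: Int, w: Int, channels: Int = 256, temperature: Double = 10000.0): FloatArray {
    require(channels % 2 == 0) { "channels must be even" }
    val f = channels / 2
    val eps = 1e-6
    val scale = 2 * PI
    val dimT = DoubleArray(f) { i -> temperature.pow(2.0 * (i / 2) / f) }
    val out = FloatArray(channels * h * w)
    for (yi in 0 until h) for (xi in 0 until w) {
        val yEmbed = (yi + 1) / (h + eps) * scale
        val xEmbed = (xi + 1) / (w + eps) * scale
        for (i in 0 until f) {
            val py = yEmbed / dimT[i]
            val px = xEmbed / dimT[i]
            out[(i * h + yi) * w + xi] = (if (i % 2 == 0) sin(py) else cos(py)).toFloat()
            out[((f + i) * h + yi) * w + xi] = (if (i % 2 == 0) sin(px) else cos(px)).toFloat()
        }
    }
    return out
}
```

Since the result depends only on `(h, w)`, it can be computed once per FPN level and cached.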
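Rewrite 5 guards the variance term against overflow in half precision: fp16 tops out at 65504, so an activation of magnitude 256 already squares to infinity. One way to write a scale-before-square LayerNorm is to divide by `s = max|x|` first, using `(x - mu) / sqrt(var + eps) == (x/s - mu/s) / sqrt(var/s^2 + eps/s^2)`. The Kotlin below sketches that idea over a single vector without the affine weight/bias; it is not necessarily the exact formulation used in this conversion. Kotlin's standard library has no half-precision type, so the check uses float32 values near 1e20, whose squares overflow float32 in the same way.

```kotlin
import kotlin.math.abs
import kotlin.math.sqrt

// Plain LayerNorm over one vector (no affine weight/bias): squares raw centered values.
fun layerNormNaive(x: FloatArray, eps: Float = 1e-5f): FloatArray {
    val mean = x.sum() / x.size
    val variance = x.fold(0f) { acc, v -> acc + (v - mean) * (v - mean) } / x.size
    val denom = sqrt(variance + eps)
    return FloatArray(x.size) { (x[it] - mean) / denom }
}

// Scale-before-square LayerNorm: divide by s = max|x| first so the sum, the squares and
// the variance all stay O(1); the result is mathematically identical to the naive form.
fun layerNormScaled(x: FloatArray, eps: Float = 1e-5f): FloatArray {
    val s = x.maxOf { abs(it) }
    if (s == 0f) return FloatArray(x.size) // all-zero input normalizes to zeros
    val z = FloatArray(x.size) { x[it] / s }
    val mean = z.sum() / z.size
    val variance = z.fold(0f) { acc, v -> acc + (v - mean) * (v - mean) } / z.size
    // eps / s / s rather than eps / (s * s), so the rescaled epsilon never overflows.
    val denom = sqrt(variance + eps / s / s)
    return FloatArray(z.size) { (z[it] - mean) / denom }
}
```

For inputs around 1e20 the naive variance becomes infinite in float32 and the output collapses to zeros, while the scaled version stays finite and matches the naive one on well-conditioned inputs.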
## Fidelity (honest)
In eager PyTorch, the re-authored model is **numerically exact** against the original (`cos = 1.000`,
`mae = 0`). On-device GPU output compared with the CPU reference, per FPN level:
| Output | Cosine similarity (GPU vs CPU) |
|---|---|
| FPN-0 `256x256` (high-res, drives mask detail) | **0.99998** |
| FPN-1 `128x128` | **0.99994** |
| FPN-2 `64x64` (coarse image embedding) | **0.99253** |
The deepest 64×64 feature drifts slightly on the GPU. This is **not** LayerNorm overflow: switching to
scale-before-square LayerNorm does not change it, and the CPU fp16 model matches PyTorch fp32 at
corr 0.999999. The drift comes from the mobile GPU computing the deep-stage global attention
(64×64 = 4096 tokens) in true fp16, whereas the CPU path upcasts to fp32. The high-resolution features
that carry mask boundaries are near-exact, so mask quality is preserved in practice.
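The cosine and MAE figures in this section compare flattened output tensors. A minimal Kotlin sketch of both metrics; the helper names are hypothetical, and accumulating in double precision is a choice made here so that sums over millions of elements do not lose precision:

```kotlin
import kotlin.math.abs
import kotlin.math.sqrt

// Cosine similarity between two flattened tensors, accumulated in double precision.
fun cosineSimilarity(a: FloatArray, b: FloatArray): Double {
    require(a.size == b.size) { "size mismatch: ${a.size} vs ${b.size}" }
    var dot = 0.0
    var normA = 0.0
    var normB = 0.0
    for (i in a.indices) {
        dot += a[i].toDouble() * b[i]
        normA += a[i].toDouble() * a[i]
        normB += b[i].toDouble() * b[i]
    }
    return dot / (sqrt(normA) * sqrt(normB))
}

// Mean absolute error between two flattened tensors.
fun meanAbsError(a: FloatArray, b: FloatArray): Double {
    require(a.size == b.size) { "size mismatch: ${a.size} vs ${b.size}" }
    var sum = 0.0
    for (i in a.indices) sum += abs(a[i].toDouble() - b[i])
    return sum / a.size
}
```

Cosine similarity ignores a uniform scale difference between the two outputs, so reporting MAE alongside it (as done for the eager check) catches errors that cosine alone would hide.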
## Usage (Android / LiteRT CompiledModel)
```kotlin
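val model = CompiledModel.create(context.assets, "sam2_tiny_image_encoder_fp16.tflite",
    CompiledModel.Options(Accelerator.GPU), null)
val inputs = model.createInputBuffers()
val outputs = model.createOutputBuffers()
inputs[0].writeFloat(nchw) // nchw: FloatArray, [1,3,1024,1024] NCHW, ImageNet-normalized
model.run(inputs, outputs)
// outputs: 3 FPN feature maps -> feed to the SAM 2 prompt encoder + mask decoder
val fpn = outputs.map { it.readFloat() }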
```
## Training data & PII
SAM 2 was trained by Meta on **SA-1B** (licensed photos) and **SA-V** (licensed videos) with
model-in-the-loop mask annotation. No new training was performed for this conversion; it is a
weights-faithful format change of the public `facebook/sam2.1-hiera-tiny` checkpoint. Because the
source data is real-world imagery, it may incidentally contain people, faces, vehicles, signage and
other PII; no PII was deliberately collected and this conversion adds none. Apply your own content/PII
filtering as appropriate. See the [SAM 2 release](https://github.com/facebookresearch/sam2) and
[paper](https://arxiv.org/abs/2408.00714) for full dataset details.
## License
Apache-2.0, inherited from the upstream [SAM 2.1](https://huggingface.co/facebook/sam2.1-hiera-tiny).
This is a format conversion; all credit to the original authors (Meta AI).