---
license: apache-2.0
library_name: litert
pipeline_tag: image-feature-extraction
tags:
- litert
- tflite
- sam2
- segment-anything
- image-encoder
- on-device
- gpu
base_model: facebook/sam2.1-hiera-tiny
---

# SAM 2.1 (Hiera-Tiny) image encoder for LiteRT GPU

On-device **LiteRT / TFLite** conversion of the **image encoder** of [**SAM 2.1 Hiera-Tiny**](https://huggingface.co/facebook/sam2.1-hiera-tiny) (Meta, Apache-2.0), running **fully on the mobile GPU** via the LiteRT `CompiledModel` API (ML Drift / `LITERT_CL` delegate). The whole graph is GPU-resident, with no ops falling back to CPU/XNNPACK.

This is the heavy backbone of the Segment Anything 2 image path: it turns an RGB image into the multi-scale feature pyramid that a small prompt encoder and mask decoder then query for each click or box.

| | |
|---|---|
| Task | Image encoder for promptable segmentation (SAM 2 image path) |
| Backbone | Hiera-Tiny (hierarchical ViT, window + global attention) + FPN neck |
| Input | `[1, 3, 1024, 1024]` NCHW float32, ImageNet-normalized |
| Outputs | 3 FPN feature maps: `[1,256,256,256]`, `[1,256,128,128]`, `[1,256,64,64]` |
| Precision / size | FP16, **80 MB** |
| Device | Pixel 8a, LiteRT GPU (`Accelerator.GPU`), **~7 ms / image** |
| Residency | **`Replacing 862 out of 862 node(s) with delegate (LITERT_CL)`** (full, single partition) |

## Preprocessing (must match)

```
resize to 1024x1024 (bilinear) -> x/255 -> (x - mean) / std
mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]   # ImageNet, RGB, NCHW
```

## GPU-clean conversion (what was re-authored)

Converted with `litert-torch`. SAM 2's Hiera encoder is not GPU-clean out of the box, so the following exact, weights-faithful rewrites were applied (model-side only, **no converter patch**):

1. **`window_partition` / `window_unpartition`**: the 6-D `view`+`permute` window reshape, which the GPU delegate rejects (>4-D), is re-expressed as a sequence of **≤4-D** `reshape`/`transpose` ops (numerically exact, verified against the original).
2. **`Sam2MultiScaleAttention`**: the 5-D fused-QKV reshape is decomposed into separate q/k/v tensors, and attention runs as a **3-D batched SDPA** (`[B*heads, N, d]`). With a 4-D SDPA, the delegate emits a `[C,C]->[nW,ws,C,C]` `BROADCAST_TO` in every windowed block; the 3-D form removes all 9 of them.
3. **Windowed positional embedding**: the bicubic interpolation and tiling of the constant `pos_embed` are **baked into a buffer**, leaving only an add in the graph; this removes a runtime interpolation of a constant.
4. **Neck**: the sine FPN position encodings, which are constant and depend only on shape, are dropped from the graph (compute them host-side instead); this removes the remaining `BROADCAST_TO` ops.
5. **Overflow-safe LayerNorm** (scale before squaring), as an fp16 safety margin for the deep stages.

Net result: `banned ops = NONE`, `>4-D tensors = 0`, full GPU residency.

## Fidelity (honest)

The eager (PyTorch) re-authoring is **numerically exact** against the original model (`cos = 1.000`, `mae = 0`). On-device GPU output compared with the CPU reference, per FPN level:

| Output | Cosine similarity (GPU vs CPU) |
|---|---|
| FPN-0 `256x256` (high-res, drives mask detail) | **0.99998** |
| FPN-1 `128x128` | **0.99994** |
| FPN-2 `64x64` (coarse image embedding) | **0.99253** |

The deepest (64×64) feature map drifts slightly on the GPU. This is **not** LayerNorm overflow: the scale-before-square LayerNorm doesn't change it, and the CPU fp16 model matches PyTorch fp32 with a correlation of 0.999999. The cause is the mobile GPU computing the deep-stage global attention (64×64 = 4096 tokens) in true fp16, whereas the CPU path upcasts to fp32. The high-resolution features that carry mask boundaries are near-exact, so mask quality is preserved in practice.
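To reproduce a per-level comparison like the Fidelity table on your own device, a minimal Kotlin sketch of the two metrics quoted there: cosine similarity and mean absolute error between two flattened feature maps. It assumes you already have both tensors as `FloatArray`s of equal length (for example, one FPN level read back from the GPU run and the same level from a CPU reference run); how you obtain the reference is up to your setup, and the function names `cosineSimilarity` / `meanAbsoluteError` are hypothetical helpers, not part of any LiteRT API.

```kotlin
import kotlin.math.abs
import kotlin.math.sqrt

/**
 * Cosine similarity between two flattened feature maps of equal length.
 * Hypothetical helper: accumulates in Double so the check itself does not
 * add fp32 rounding on top of the GPU/CPU difference being measured.
 */
fun cosineSimilarity(a: FloatArray, b: FloatArray): Double {
    require(a.size == b.size) { "size mismatch: ${a.size} vs ${b.size}" }
    var dot = 0.0
    var normA = 0.0
    var normB = 0.0
    for (i in a.indices) {
        val x = a[i].toDouble()
        val y = b[i].toDouble()
        dot += x * y
        normA += x * x
        normB += y * y
    }
    require(normA > 0.0 && normB > 0.0) { "cosine is undefined for an all-zero vector" }
    return dot / (sqrt(normA) * sqrt(normB))
}

/** Mean absolute error between two flattened feature maps (hypothetical helper). */
fun meanAbsoluteError(a: FloatArray, b: FloatArray): Double {
    require(a.size == b.size) { "size mismatch: ${a.size} vs ${b.size}" }
    require(a.isNotEmpty()) { "MAE is undefined for empty arrays" }
    var sum = 0.0
    for (i in a.indices) sum += abs(a[i].toDouble() - b[i].toDouble())
    return sum / a.size
}
```

Run it once per output (`[1,256,256,256]`, `[1,256,128,128]`, `[1,256,64,64]`) rather than on the concatenated outputs, so drift in the coarse 64×64 level is not masked by the much larger high-resolution level.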
## Usage (Android / LiteRT CompiledModel)

```kotlin
val model = CompiledModel.create(
    context.assets,
    "sam2_tiny_image_encoder_fp16.tflite",
    CompiledModel.Options(Accelerator.GPU),
    null
)
// input:   [1,3,1024,1024] NCHW, ImageNet-normalized
// outputs: 3 FPN feature maps -> feed to the SAM 2 prompt encoder + mask decoder
```

## Training data & PII

SAM 2 was trained by Meta on **SA-1B** (licensed photos) and **SA-V** (licensed videos) with model-in-the-loop mask annotation. No new training was performed for this conversion; it is a weights-faithful format change of the public `facebook/sam2.1-hiera-tiny` checkpoint.

Because the source data is real-world imagery, it may incidentally contain people, faces, vehicles, signage, and other PII. No PII was deliberately collected, and this conversion adds none. Apply your own content/PII filtering as appropriate. See the [SAM 2 release](https://github.com/facebookresearch/sam2) and [paper](https://arxiv.org/abs/2408.00714) for full dataset details.

## License

Apache-2.0, inherited from the upstream [SAM 2.1](https://huggingface.co/facebook/sam2.1-hiera-tiny). This is a format conversion; all credit goes to the original authors (Meta AI).
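As a companion to the Usage snippet, a minimal Kotlin sketch of the Preprocessing recipe (`x/255`, then `(x - mean) / std` per RGB channel, laid out as planar NCHW). It assumes the image has already been resized to 1024×1024 (on Android, for example with `Bitmap.createScaledBitmap(src, 1024, 1024, true)`, where `true` enables bilinear filtering) and read out as packed ARGB_8888 ints, row-major, as `Bitmap.getPixels` returns them. The helper name `toNchwImageNet` is hypothetical; it is plain Kotlin so it can be unit-tested off-device.

```kotlin
// ImageNet statistics in RGB order, as listed in the Preprocessing section.
val IMAGENET_MEAN = floatArrayOf(0.485f, 0.456f, 0.406f)
val IMAGENET_STD = floatArrayOf(0.229f, 0.224f, 0.225f)

/**
 * Hypothetical helper: converts packed ARGB_8888 pixels (row-major, already
 * resized to size x size) into a planar NCHW float array of length 3 * size * size.
 * Alpha is ignored.
 */
fun toNchwImageNet(argb: IntArray, size: Int = 1024): FloatArray {
    require(argb.size == size * size) { "expected ${size * size} pixels, got ${argb.size}" }
    val plane = size * size
    val out = FloatArray(3 * plane)
    for (i in 0 until plane) {
        val p = argb[i]
        // Unpack 8-bit channels and scale to [0, 1].
        val r = ((p shr 16) and 0xFF) / 255f
        val g = ((p shr 8) and 0xFF) / 255f
        val b = (p and 0xFF) / 255f
        // Channel-major (NCHW) layout: all R values, then all G, then all B.
        out[i] = (r - IMAGENET_MEAN[0]) / IMAGENET_STD[0]
        out[plane + i] = (g - IMAGENET_MEAN[1]) / IMAGENET_STD[1]
        out[2 * plane + i] = (b - IMAGENET_MEAN[2]) / IMAGENET_STD[2]
    }
    return out
}
```

The resulting array (3 × 1024 × 1024 floats) is what the encoder's `[1,3,1024,1024]` input expects. Keeping the layout planar here avoids a separate HWC-to-CHW transpose before writing the input buffer.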