---
license: apache-2.0
library_name: litert
pipeline_tag: depth-estimation
tags:
- litert
- tflite
- depth-estimation
- monocular-depth
- on-device
- gpu
- depth-anything
base_model: depth-anything/DA3-SMALL
---
# Depth Anything 3 (Small) — LiteRT GPU, monocular depth
On-device **LiteRT / TFLite** conversion of [**Depth Anything 3 — Small**](https://huggingface.co/depth-anything/DA3-SMALL)
(ByteDance-Seed, Apache-2.0) for **monocular depth**, running fully on the mobile **GPU** via the LiteRT
`CompiledModel` API (ML Drift delegate). No op falls back to the CPU: the whole graph is GPU-compatible.
| Property | Value |
|---|---|
| Task | Monocular depth (single RGB → depth) |
| Backbone | DINOv2 ViT-S + RoPE, DPT/DualDPT depth head |
| Input | `[1, 3, 896, 504]` NCHW float32, ImageNet-normalized, **native portrait aspect** |
| Output | `[1, 1, 896, 504]` depth |
| Precision / size | FP16, **55 MB** |
| Device | Pixel 8a, LiteRT GPU (`Accelerator.GPU`), **~1.8 s / image** |
| Fidelity | **Pearson corr 0.99948** vs the official PyTorch DA3-Small pipeline |
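The fidelity figures in this card are Pearson correlations against the PyTorch output. Below is a minimal Kotlin sketch of that metric. It assumes the correlation is taken over the flattened per-pixel values of both depth maps; the card does not say how pixels were selected. `pearson` is a hypothetical helper name.

```kotlin
import kotlin.math.sqrt

// Hypothetical helper: Pearson correlation between two equally sized, flattened depth maps.
fun pearson(a: FloatArray, b: FloatArray): Double {
    require(a.size == b.size && a.isNotEmpty()) { "maps must be non-empty and the same size" }
    val n = a.size
    var meanA = 0.0
    var meanB = 0.0
    for (i in 0 until n) { meanA += a[i]; meanB += b[i] }
    meanA /= n
    meanB /= n
    var cov = 0.0
    var varA = 0.0
    var varB = 0.0
    for (i in 0 until n) {
        val da = a[i] - meanA
        val db = b[i] - meanB
        cov += da * db
        varA += da * da
        varB += db * db
    }
    require(varA > 0.0 && varB > 0.0) { "correlation is undefined for a constant map" }
    return cov / sqrt(varA * varB)
}
```

Pearson correlation is invariant to a positive scale and shift of either map. A value near 1 therefore means the two depth maps agree up to an affine transform, not necessarily in absolute values.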
## Why a fixed 896×504 (native aspect, not square)
DA3 processes images at their **native aspect ratio** (`upper_bound_resize`, longer side → 896, multiple of 14).
Forcing a square `896×896` and letterbox-padding drops the match to corr **0.977** (the black padding leaks into
the content through global attention). Converting at the native rectangle restores **corr 0.9994** and is also
faster, because there are fewer patch tokens (64 × 36 = 2,304 vs 64 × 64 = 4,096 at patch size 14). This checkpoint
is built for **portrait ~9:16**. For another aspect ratio, re-convert at that shape (e.g. your camera's fixed aspect)
with the same GPU-clean rewrites listed below.
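The sizing rule can be sketched roughly as follows. This sketch assumes the shorter side is scaled by the same factor as the longer side and then rounded to the nearest multiple of 14; the exact rounding inside DA3's `upper_bound_resize` may differ. `nativeInputShape` and `InputShape` are hypothetical names.

```kotlin
import kotlin.math.roundToInt

// Hypothetical container for a model input size (H × W) in pixels.
data class InputShape(val height: Int, val width: Int, val patch: Int = 14) {
    // ViT patch tokens for this size (special tokens such as the camera token are not counted).
    val patchTokens: Int get() = (height / patch) * (width / patch)
}

// Hypothetical helper: longer side -> longSide; shorter side scaled by the same factor and
// rounded to the nearest multiple of the patch size (assumed rounding rule).
fun nativeInputShape(imageHeight: Int, imageWidth: Int, longSide: Int = 896, patch: Int = 14): InputShape {
    require(imageHeight > 0 && imageWidth > 0) { "image dimensions must be positive" }
    val longer = maxOf(imageHeight, imageWidth)
    val shorter = minOf(imageHeight, imageWidth)
    val scaledShort = shorter.toDouble() * longSide / longer
    val snappedShort = maxOf(patch, (scaledShort / patch).roundToInt() * patch)
    return if (imageHeight >= imageWidth) InputShape(longSide, snappedShort, patch)
    else InputShape(snappedShort, longSide, patch)
}
```

For a 1080×1920 (W×H) portrait frame, this reproduces the 896×504 (H×W) shape this checkpoint was converted at.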
## Preprocessing (must match)
```
resize to 504×896 (W×H) → x/255 → (x - mean) / std
mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225] # ImageNet, RGB, NCHW
```
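Here is a minimal Kotlin sketch of the normalization step. It assumes the frame has already been resized to 504×896 (for example with Android's `Bitmap.createScaledBitmap`) and that its pixels have been read into a packed-ARGB `IntArray` in row-major order, as `Bitmap.getPixels` produces. `toNchwNormalized` is a hypothetical helper.

```kotlin
// Hypothetical helper: packed ARGB pixels (row-major) -> ImageNet-normalized NCHW float tensor.
fun toNchwNormalized(argb: IntArray, width: Int, height: Int): FloatArray {
    require(argb.size == width * height) { "expected ${width * height} pixels, got ${argb.size}" }
    val mean = floatArrayOf(0.485f, 0.456f, 0.406f) // ImageNet mean, RGB order
    val std = floatArrayOf(0.229f, 0.224f, 0.225f)  // ImageNet std, RGB order
    val plane = width * height
    val out = FloatArray(3 * plane)
    for (i in 0 until plane) {
        val p = argb[i]
        val r = ((p shr 16) and 0xFF) / 255f // x / 255
        val g = ((p shr 8) and 0xFF) / 255f
        val b = (p and 0xFF) / 255f
        out[i] = (r - mean[0]) / std[0]             // R plane
        out[plane + i] = (g - mean[1]) / std[1]     // G plane
        out[2 * plane + i] = (b - mean[2]) / std[2] // B plane
    }
    return out
}
```

The result is laid out as three consecutive 896 × 504 planes (R, then G, then B), which matches the model's `[1, 3, 896, 504]` NCHW input.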
## GPU-clean conversion (what was patched)
Converted with `litert-torch`. DA3 is not GPU-clean out of the box, so the following GPU-clean rewrites were
applied; all are numerically exact unless noted:
1. strip the `model.` key prefix from the checkpoint (loading fix)
2. RoPE `max_position = int(positions.max())+1` → constant (`torch.export` cannot trace this data-dependent value)
3. fused-QKV attention → 3 separate Linears + 4D attention (avoids a 5D `RESHAPE`; exact to ~1e-6)
4. **LayerScale** `gamma` folded into `attn.proj` / `mlp.fc2` (otherwise the separate LayerScale MUL makes the GPU
   delegate lay out the token dim incorrectly: `fully_connected {1,1,N,C} vs {N,1,1,C}`)
5. `pos_embed` bicubic interpolation **baked** into a constant (interpolating a constant tensor emits `GATHER_ND`
   on desktop and, on device, a `RESIZE_BILINEAR` with 0 runtime inputs)
6. **ConvTranspose2d(k=s, stride=s)** → zero-stuff (nearest-upsample × top-left mask) + `Conv2d` (flipped
   weight), an exact equivalent (~1e-7). Needed because the Pixel 8a GPU rejects `TRANSPOSE_CONV`, and the
   conv + depth-to-space alternative needs >4D tensors
7. DPT-head `custom_interpolate` `align_corners=True → False` (the GPU delegate rejects `align_corners=True`
   resizes). **The only non-exact rewrite**; it is the source of the residual ~0.05 % vs the official model
8. the head's UV "pos-embed-again" step disabled (its `make_sincos` broadcast emits `BROADCAST_TO`; the step is only a ratio-0.1 refinement)
9. camera-token insertion `x[:, :, 0] = cam_token` → `torch.cat` (the in-place index assignment lowers to `SELECT_V2`)
Net result: zero `GATHER_ND` ops, no tensors above 4D, and no `TRANSPOSE_CONV`, `BROADCAST_TO`, or other banned ops.
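Rewrite 4 relies on the identity gamma ⊙ (W x + b) = (diag(gamma) W) x + gamma ⊙ b. Scaling output row o of the preceding linear layer's weight and bias by gamma[o] reproduces the LayerScale multiply, so the separate MUL disappears from the graph. Below is a minimal sketch on plain arrays, assuming `torch.nn.Linear`'s `[out][in]` weight layout. `linear` and `foldLayerScale` are hypothetical helpers.

```kotlin
// Hypothetical helper: y = W x + b, with W stored as [out][in] like torch.nn.Linear.weight.
fun linear(w: Array<FloatArray>, b: FloatArray, x: FloatArray): FloatArray =
    FloatArray(w.size) { o ->
        var acc = b[o]
        for (i in x.indices) acc += w[o][i] * x[i]
        acc
    }

// Hypothetical helper: fold a per-channel LayerScale gamma into the preceding linear layer,
// so that linear(W', b', x) == gamma ⊙ linear(W, b, x).
fun foldLayerScale(
    w: Array<FloatArray>,
    b: FloatArray,
    gamma: FloatArray,
): Pair<Array<FloatArray>, FloatArray> {
    require(w.size == gamma.size && b.size == gamma.size) { "gamma must match the output dim" }
    val wFolded = Array(w.size) { o -> FloatArray(w[o].size) { i -> gamma[o] * w[o][i] } }
    val bFolded = FloatArray(b.size) { o -> gamma[o] * b[o] }
    return wFolded to bFolded
}
```

In the converted graph, this folding means `attn.proj` and `mlp.fc2` are ordinary fully connected layers whose outputs already include the scale.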
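Rewrite 6 can be checked numerically. The sketch below uses one input and one output channel. It compares a direct `ConvTranspose2d(kernel = s, stride = s)` against zero-stuffing followed by a plain convolution with the kernel flipped in both axes. The card does not state the padding of its replacement `Conv2d`; this sketch assumes s - 1 zeros on the top and left only, which is one arrangement that makes the identity exact. All function names are hypothetical.

```kotlin
// Direct ConvTranspose2d(kernel = s, stride = s, padding = 0), single channel.
fun convTranspose(x: Array<DoubleArray>, w: Array<DoubleArray>, s: Int): Array<DoubleArray> {
    val y = Array(x.size * s) { DoubleArray(x[0].size * s) }
    for (i in x.indices) for (j in x[0].indices)
        for (a in 0 until s) for (b in 0 until s)
            y[i * s + a][j * s + b] += x[i][j] * w[a][b] // blocks never overlap when k == stride
    return y
}

// Zero-stuff: nearest-neighbour upsample by s, then keep only the top-left element of each s×s block.
fun zeroStuff(x: Array<DoubleArray>, s: Int): Array<DoubleArray> =
    Array(x.size * s) { p ->
        DoubleArray(x[0].size * s) { q ->
            val nearest = x[p / s][q / s]
            val mask = if (p % s == 0 && q % s == 0) 1.0 else 0.0
            nearest * mask
        }
    }

// Plain Conv2d (cross-correlation) with the kernel flipped in both axes and
// s - 1 zeros of padding on the top and left only (assumed padding; same output size as z).
fun conv2dFlipped(z: Array<DoubleArray>, w: Array<DoubleArray>, s: Int): Array<DoubleArray> =
    Array(z.size) { p ->
        DoubleArray(z[0].size) { q ->
            var acc = 0.0
            for (u in 0 until s) for (v in 0 until s) {
                val zp = p + u - (s - 1)
                val zq = q + v - (s - 1)
                if (zp >= 0 && zq >= 0) acc += z[zp][zq] * w[s - 1 - u][s - 1 - v]
            }
            acc
        }
    }
```

Within every kernel window only one stuffed input value is non-zero, so each output is the single product x·w that the transposed convolution would produce. No op in this path needs more than 4D.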
## Fidelity note (honest)
corr **0.99948** vs the official FP32 PyTorch pipeline. FP16 is **not** a factor: the FP32 and FP16 conversions agree
at corr 1.0. The residual ~0.05 % comes from the `align_corners=True→False` change in (7), which the mobile GPU
forces. It is an irreducible hardware constraint, not a conversion error. Structure and edge sharpness are visually
identical.
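To see why rewrite 7 cannot be exact, compare the source coordinate each output pixel samples under the two conventions. This is a 1-D linear sketch (a bilinear resize applies it per axis) using the PyTorch coordinate formulas. For `align_corners=True` the source coordinate is `src = dst * (in - 1) / (out - 1)`. For `align_corners=False` it is `src = (dst + 0.5) * in / out - 0.5`, clamped at 0. `resizeLinear1d` is a hypothetical name.

```kotlin
import kotlin.math.floor

// Hypothetical helper: 1-D linear resize under either align_corners convention.
fun resizeLinear1d(input: DoubleArray, outSize: Int, alignCorners: Boolean): DoubleArray {
    val n = input.size
    return DoubleArray(outSize) { d ->
        val src = if (alignCorners) {
            if (outSize > 1) d * (n - 1).toDouble() / (outSize - 1) else 0.0
        } else {
            ((d + 0.5) * n / outSize - 0.5).coerceAtLeast(0.0) // half-pixel centers, clamped at 0
        }
        val i0 = minOf(floor(src).toInt(), n - 1)
        val i1 = minOf(i0 + 1, n - 1) // clamp the right neighbour at the border
        val t = src - i0
        input[i0] * (1 - t) + input[i1] * t
    }
}
```

Upsampling `[0, 1]` to four samples gives `[0, 1/3, 2/3, 1]` with `align_corners=True` but `[0, 0.25, 0.75, 1]` with `False`. The interior samples shift slightly and systematically, which is the kind of difference that survives as the small residual.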
## Usage (Android / LiteRT CompiledModel)
```kotlin
val model = CompiledModel.create(context.assets, "da3_small_gpu_fp16.tflite",
    CompiledModel.Options(Accelerator.GPU), null)
// input: [1,3,896,504] NCHW, ImageNet-normalized; output: [1,1,896,504] depth
```
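To visualize the output, read the `[1, 1, 896, 504]` depth tensor back as a row-major `FloatArray` of 896 × 504 values. You can then min-max normalize it into opaque grayscale ARGB pixels, a format that Android's `Bitmap.createBitmap(colors, width, height, Bitmap.Config.ARGB_8888)` accepts. This is a minimal sketch. `depthToGrayscaleArgb` is a hypothetical helper, and it makes no assumption about whether larger values mean nearer or farther.

```kotlin
import kotlin.math.roundToInt

// Hypothetical helper: min-max normalize a depth map into opaque grayscale ARGB pixels.
fun depthToGrayscaleArgb(depth: FloatArray, width: Int = 504, height: Int = 896): IntArray {
    require(depth.size == width * height) { "expected ${width * height} values, got ${depth.size}" }
    var lo = Float.POSITIVE_INFINITY
    var hi = Float.NEGATIVE_INFINITY
    for (d in depth) {
        if (d < lo) lo = d
        if (d > hi) hi = d
    }
    val range = if (hi > lo) hi - lo else 1f // a constant map becomes all black
    return IntArray(depth.size) { i ->
        val v = ((depth[i] - lo) / range * 255f).roundToInt().coerceIn(0, 255)
        (0xFF shl 24) or (v shl 16) or (v shl 8) or v // alpha = 255, R = G = B = v
    }
}
```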
## License
Apache-2.0, inherited from the upstream [Depth Anything 3](https://github.com/ByteDance-Seed/depth-anything-3).