---
license: mit
library_name: LiteRT
pipeline_tag: image-to-image
tags:
- litert
- tflite
- on-device
- android
- gpu
- image-restoration
- deblurring
- nafnet
base_model: megvii-research/NAFNet
---
# NAFNet-GoPro-width32 on LiteRT (on-device image deblur, fully on GPU)
[NAFNet](https://github.com/megvii-research/NAFNet) (Nonlinear Activation Free Network, ECCV 2022) for image
restoration, converted to **LiteRT** and running **entirely on the GPU** through the `CompiledModel` API (ML Drift)
on Android. NAFNet is a U-Net of **NAFBlocks** with **no activation functions at all**: SimpleGate splits the channels
in half and multiplies the two halves elementwise, so the whole network lowers to plain CNN ops on the GPU delegate.
This is the **GoPro-width32** variant, trained for motion deblurring.
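
For illustration, here is a minimal Kotlin sketch of what SimpleGate computes on one `[C, H, W]` feature map stored as a flat NCHW `FloatArray` (batch 1 assumed). The function name `simpleGate` is hypothetical and this is plain CPU code, not part of LiteRT or the converted graph; it mirrors NAFNet's `x1, x2 = x.chunk(2, dim=1); return x1 * x2`.

```kotlin
/**
 * Hypothetical SimpleGate reference on one feature map.
 * x holds C*H*W values in NCHW order (batch 1); C must be even.
 * Returns C/2 channels: out[k] = x[k] * x[k + C/2] (first half times second half).
 */
fun simpleGate(x: FloatArray, c: Int, h: Int, w: Int): FloatArray {
    require(c % 2 == 0) { "channel count must be even, got $c" }
    require(x.size == c * h * w) { "expected ${c * h * w} values, got ${x.size}" }
    val half = c / 2
    val plane = h * w
    val out = FloatArray(half * plane)
    for (k in 0 until half) {
        for (p in 0 until plane) {
            // Same spatial position, channel k from the first half and k + half from the second.
            out[k * plane + p] = x[k * plane + p] * x[(k + half) * plane + p]
        }
    }
    return out
}
```

Because the gate is only an elementwise multiply, it needs no dedicated activation kernel on the GPU.
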
![NAFNet: blurry input | restored (on-device LiteRT GPU)](samples/sample.png)
## On-device (Pixel 8a, Tensor G3, verified)
| Metric | Value |
|---|---|
| Nodes on GPU | **2179 / 2179** on LITERT_CL (full residency) |
| Inference latency | **~42 ms** (256×256) |
| Model size | 38 MB (fp16) |
| Accuracy | device output **== PyTorch (corr 1.000000)**; the re-authoring is numerically exact |
```
image[1,3,256,256] (RGB [0,1]) →[GPU: NAFNet U-Net]→ restored[1,3,256,256]
```
## Usage (Android, LiteRT CompiledModel)
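```kotlin
import com.google.ai.edge.litert.Accelerator
import com.google.ai.edge.litert.CompiledModel

// modelPath: path to the .tflite file; chw: FloatArray of 1*3*256*256 values (both supplied by the caller)
val model = CompiledModel.create(modelPath, CompiledModel.Options(Accelerator.GPU), null) // no explicit Environment
val input = model.createInputBuffers()
val output = model.createOutputBuffers()
input[0].writeFloat(chw)              // [1,3,256,256] RGB in [0,1], NCHW
model.run(input, output)
val restored = output[0].readFloat()  // FloatArray, [1,3,256,256] in [0,1]
```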
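
The snippet above assumes a ready-made `chw` array. One way to build it, sketched here under the assumption that the image was already resized to 256×256 (for example with `Bitmap.createScaledBitmap`) and read with `Bitmap.getPixels` into packed ARGB_8888 `Int`s, is to split each pixel into three float planes. The helpers `argbToChw` and `chwToArgb` are hypothetical; they are not part of LiteRT or the official sample.

```kotlin
/** Hypothetical: packed ARGB_8888 pixels (row-major) -> NCHW RGB floats in [0,1]. */
fun argbToChw(pixels: IntArray, width: Int, height: Int): FloatArray {
    val plane = width * height
    require(pixels.size == plane) { "expected $plane pixels, got ${pixels.size}" }
    val chw = FloatArray(3 * plane)
    for (i in 0 until plane) {
        val p = pixels[i]
        chw[i] = ((p shr 16) and 0xFF) / 255f          // R plane
        chw[plane + i] = ((p shr 8) and 0xFF) / 255f   // G plane
        chw[2 * plane + i] = (p and 0xFF) / 255f       // B plane (alpha is dropped)
    }
    return chw
}

/** Hypothetical: NCHW RGB floats -> opaque packed ARGB_8888 pixels. */
fun chwToArgb(chw: FloatArray, width: Int, height: Int): IntArray {
    val plane = width * height
    require(chw.size == 3 * plane) { "expected ${3 * plane} floats, got ${chw.size}" }
    // Clamp to [0,1], then round to the nearest 8-bit level.
    fun toByte(v: Float): Int = (v.coerceIn(0f, 1f) * 255f + 0.5f).toInt()
    return IntArray(plane) { i ->
        val r = toByte(chw[i])
        val g = toByte(chw[plane + i])
        val b = toByte(chw[2 * plane + i])
        (0xFF shl 24) or (r shl 16) or (g shl 8) or b
    }
}
```

`chwToArgb(restored, 256, 256)` yields pixels that can be written back with `Bitmap.setPixels`. The clamp is a defensive choice in case the output drifts slightly outside [0,1].
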
A complete Android sample (image picker + before/after) is in the official
[google-ai-edge/litert-samples](https://github.com/google-ai-edge/litert-samples) repo under
`compiled_model_api/image_restoration`.
## How it converts (litert-torch)
NAFNet is fully convolutional, so it accepts any input size that is a multiple of 16 (it downsamples 2× four
times); this export is fixed at 256×256. Running it fully on the GPU took three numerically exact re-authorings:
1. **`LayerNorm2d` → fp16-safe channel LayerNorm.** NAFNet's residual stream grows large (|x| ≈ 175 at the
bottleneck), so the LayerNorm channel reductions `Σ_c x` and `Σ_c (x−μ)²` (~15M) **overflow fp16 (max
65504)** on the Mali delegate, which computes in fp16 regardless of the model dtype; the overflow shows up as a grid artifact.
Doing the reductions in a down-scaled `x/S` domain (S = 128) and rescaling is numerically exact (S is a power of two,
so dividing by it only shifts the exponent) and keeps every sum within fp16 range.
2. **Simplified Channel Attention `AdaptiveAvgPool2d(1)` → `mean(3).mean(2)`** (two single-axis means).
3. **Upsample `Conv2d(1×1)+PixelShuffle(2)` → Conv2d + depth-to-space `ZeroStuffConvT2d`**.
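
Re-authoring 1 can be sketched in Kotlin for one pixel's channel vector. Kotlin has no fp16 type, so this runs in fp32 and instead records the largest partial sum it accumulates, which is the quantity that must stay below 65504 on an fp16 GPU. The names `channelLayerNorm`, `NormResult`, and `FP16_MAX` are hypothetical, the learned `weight`/`bias` is omitted, eps = 1e-6 is an assumed default, and this is not the converter's actual code. For scale, assuming the standard channel doubling puts 512 channels at the width-32 bottleneck, 512 × 175² ≈ 15.7M, in line with the ~15M above; after dividing by S = 128 the same sum is about 957.

```kotlin
import kotlin.math.abs
import kotlin.math.max
import kotlin.math.sqrt

// Largest finite fp16 value; a partial sum above it would overflow on an fp16 GPU.
val FP16_MAX = 65504f

// Hypothetical result holder: normalized channels plus the peak |partial sum| seen.
class NormResult(val y: FloatArray, val maxPartialSum: Float)

/**
 * Channel LayerNorm (no affine) over one pixel's channel vector x, with both
 * reductions done on x / s. s = 1f is the ordinary LayerNorm; s = 128f mirrors
 * re-authoring 1. Runs in fp32 because Kotlin has no fp16 type.
 */
fun channelLayerNorm(x: FloatArray, eps: Float = 1e-6f, s: Float = 1f): NormResult {
    val c = x.size
    var peak = 0f
    var sum = 0f
    for (v in x) {
        sum += v / s                   // running sum over channels of x / s
        peak = max(peak, abs(sum))
    }
    val muS = sum / c                  // mu / s
    var sq = 0f
    for (v in x) {
        val d = v / s - muS
        sq += d * d                    // running sum of ((x - mu) / s)^2
        peak = max(peak, sq)
    }
    val varS = sq / c                  // var(x) / s^2
    // (x - mu) / sqrt(var + eps) == ((x - mu) / s) / sqrt(var / s^2 + eps / s^2)
    val inv = 1f / sqrt(varS + eps / (s * s))
    return NormResult(FloatArray(c) { (x[it] / s - muS) * inv }, peak)
}
```

On an input of alternating ±175 across 512 channels, the S = 1 run peaks at 15,680,000 while the S = 128 run peaks at about 957, and both return the same normalized values (±1).
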

Result: banned ops: none; all tensors ≤ 4D; TFLite-vs-PyTorch correlation **1.0**; device-vs-PyTorch correlation **1.0**.
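
Re-authorings 2 and 3 can be checked against reference semantics with a small Kotlin sketch on a flat `[C, H, W]` array. `globalAvgPool` is what `AdaptiveAvgPool2d(1)` computes, `meanOfMeans` is the `mean(3).mean(2)` replacement (equal to it because every row has the same width), and `pixelShuffle` spells out the index mapping of PyTorch's `PixelShuffle(r)` that the Conv2d + depth-to-space rewrite must reproduce. All three names are hypothetical, and the sketch does not show how `ZeroStuffConvT2d` itself is built.

```kotlin
// AdaptiveAvgPool2d(1): one average per channel over all h*w positions.
fun globalAvgPool(x: FloatArray, c: Int, h: Int, w: Int): FloatArray =
    FloatArray(c) { k -> (0 until h * w).sumOf { x[k * h * w + it].toDouble() }.toFloat() / (h * w) }

// mean(3).mean(2): average over width (axis 3) first, then over height (axis 2).
fun meanOfMeans(x: FloatArray, c: Int, h: Int, w: Int): FloatArray =
    FloatArray(c) { k ->
        var acc = 0.0
        for (row in 0 until h) {
            var rowSum = 0.0
            for (col in 0 until w) rowSum += x[(k * h + row) * w + col]
            acc += rowSum / w          // mean of one row
        }
        (acc / h).toFloat()            // mean of the row means
    }

// PyTorch PixelShuffle(r): [cOut*r*r, h, w] -> [cOut, h*r, w*r] with
// out[ch, y*r + i, x*r + j] = in[ch*r*r + i*r + j, y, x].
fun pixelShuffle(x: FloatArray, cOut: Int, h: Int, w: Int, r: Int): FloatArray {
    require(x.size == cOut * r * r * h * w) { "unexpected input size ${x.size}" }
    val oh = h * r
    val ow = w * r
    val out = FloatArray(cOut * oh * ow)
    for (ch in 0 until cOut) for (i in 0 until r) for (j in 0 until r) {
        val src = ch * r * r + i * r + j   // input channel feeding sub-pixel (i, j)
        for (y in 0 until h) for (xx in 0 until w) {
            out[(ch * oh + y * r + i) * ow + xx * r + j] = x[(src * h + y) * w + xx]
        }
    }
    return out
}
```

For example, with `r = 2` four input channels become one output channel, and each input channel fills one of the four positions in every 2×2 output cell.
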
## License
[MIT](https://github.com/megvii-research/NAFNet/blob/main/LICENSE). Upstream:
[megvii-research/NAFNet](https://github.com/megvii-research/NAFNet). Original weights:
NAFNet-GoPro-width32 from the official release.