mlboydaisuke's picture
NAFNet-GoPro-width32 LiteRT fp16 (fully-GPU deblur, Pixel 8a corr 1.0)
metadata
license: mit
library_name: LiteRT
pipeline_tag: image-to-image
tags:
  - litert
  - tflite
  - on-device
  - android
  - gpu
  - image-restoration
  - deblurring
  - nafnet
base_model: megvii-research/NAFNet

NAFNet-GoPro-width32 — LiteRT (on-device image deblur, fully-GPU)

NAFNet (Nonlinear Activation Free Network, ECCV 2022) is an image-restoration network. Here it is converted to LiteRT and runs entirely on the GPU through the CompiledModel API (ML Drift) on Android. NAFNet is a U-Net of NAFBlocks with no nonlinear activation functions at all: SimpleGate is a channel split followed by an elementwise multiply. As a result, the whole network lowers to a plain CNN on the GPU delegate. This is the GoPro-width32 variant, which targets motion deblurring.
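SimpleGate takes the place of an activation function inside each NAFBlock. It splits the channels into two halves and multiplies them elementwise. The sketch below shows this on one CHW feature map stored as a flat FloatArray. Batch size 1, the flat CHW layout, and the name `simpleGate` are illustrative assumptions; this is not the NAFNet or LiteRT code.

```kotlin
// Hypothetical helper: SimpleGate on one CHW feature map (batch 1), flattened row-major.
// Channels [0, C/2) form the first half and [C/2, C) the second half.
// The output has C/2 channels.
fun simpleGate(x: FloatArray, c: Int, h: Int, w: Int): FloatArray {
    require(c % 2 == 0) { "channel count must be even" }
    require(x.size == c * h * w) { "x must hold c*h*w values" }
    val half = (c / 2) * h * w  // number of values in each channel half
    // Elementwise product of the two halves; there is no activation function anywhere.
    return FloatArray(half) { i -> x[i] * x[half + i] }
}
```

Because the gate is only a split plus a multiply, there is no activation op for the GPU delegate to lower.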

NAFNet — blurry input | restored (on-device LiteRT GPU)

On-device (Pixel 8a, Tensor G3, verified)

Nodes on GPU: 2179 / 2179 (LITERT_CL, full residency)
Inference: ~42 ms at 256×256
Size: 38 MB (fp16)
Accuracy: device output matches PyTorch (corr 1.000000); the re-authoring is numerically exact
image[1,3,256,256] (RGB [0,1]) →[GPU: NAFNet U-Net]→ restored[1,3,256,256]

Usage (Android, LiteRT CompiledModel)

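import com.google.ai.edge.litert.Accelerator
import com.google.ai.edge.litert.CompiledModel

val model = CompiledModel.create(modelPath, CompiledModel.Options(Accelerator.GPU), null)
val input = model.createInputBuffers()
val output = model.createOutputBuffers()
input[0].writeFloat(chw)              // chw: FloatArray of 1*3*256*256, RGB in [0,1], NCHW
model.run(input, output)
val restored = output[0].readFloat()  // FloatArray, [1,3,256,256], RGB in [0,1]
(input + output).forEach { it.close() }; model.close()  // release native buffers and model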
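The snippet above assumes `chw` is already a planar float array. The sketch below shows one way to build it from Android-style packed ARGB pixels, the `IntArray` layout that `Bitmap.getPixels` fills, and to turn the output back into pixels. Several details are assumptions: the helper names are hypothetical, the image is assumed to be resized to 256×256 beforehand, and output values are clamped to [0,1] as a precaution before conversion.

```kotlin
// Hypothetical helper: packed ARGB (0xAARRGGBB, row-major) -> planar RGB, NCHW, values in [0,1].
fun argbToChw(pixels: IntArray, w: Int, h: Int): FloatArray {
    require(pixels.size == w * h)
    val plane = w * h
    val chw = FloatArray(3 * plane)
    for (i in 0 until plane) {
        val p = pixels[i]
        chw[i] = ((p shr 16) and 0xFF) / 255f          // R plane
        chw[plane + i] = ((p shr 8) and 0xFF) / 255f   // G plane
        chw[2 * plane + i] = (p and 0xFF) / 255f       // B plane
    }
    return chw
}

// Hypothetical helper: planar RGB -> opaque packed ARGB; values are clamped to [0,1] and rounded.
fun chwToArgb(chw: FloatArray, w: Int, h: Int): IntArray {
    val plane = w * h
    require(chw.size == 3 * plane)
    fun to8(v: Float): Int = (v.coerceIn(0f, 1f) * 255f + 0.5f).toInt()
    return IntArray(plane) { i ->
        (0xFF shl 24) or (to8(chw[i]) shl 16) or (to8(chw[plane + i]) shl 8) or to8(chw[2 * plane + i])
    }
}
```

With these helpers, `argbToChw(pixels, 256, 256)` produces the value written into `input[0]`. Similarly, `chwToArgb(restored, 256, 256)` yields pixels that can be copied back into a Bitmap.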

A complete Android sample (image picker + before/after) is in the official google-ai-edge/litert-samples repo under compiled_model_api/image_restoration.

How it converts (litert-torch)

NAFNet is fully convolutional, so it accepts any input size that is a multiple of 16; the U-Net downsamples by 2 four times. This model is exported at 256×256. The conversion applies three numerically exact re-authorings for the GPU:

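  1. LayerNorm2d → fp16-safe channel LayerNorm. NAFNet's residual stream grows large (|x| ≈ 175 at the bottleneck). As a result, the LayerNorm channel reductions Σ_c x and Σ_c (x−μ)² (the latter ~15M) overflow fp16 (max 65504) on the Mali GPU delegate. That delegate computes in fp16 regardless of the model dtype, and the overflow shows up as a grid artifact. The fix is to do the reductions in a down-scaled x/S domain (S = 128) and then rescale. This keeps them fp16-safe: Σ_c (x−μ)² shrinks by S² = 16384, to under a thousand. Because S is a power of two, the scaling only shifts exponents, so the rewrite is numerically exact.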
  2. Simplified Channel Attention: AdaptiveAvgPool2d(1) → mean(3).mean(2), i.e. two single-axis means (over W, then over H). These equal the global average because every row has the same width.
  3. Upsample: Conv2d(1×1) + PixelShuffle(2) → Conv2d + depth-to-space via ZeroStuffConvT2d.
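The three rewrites can be checked against reference semantics on the CPU. The sketch below is an illustration in plain Kotlin, not the litert-torch conversion code, and all function names are hypothetical. `f16Range` is a deliberately crude fp16 model: it reproduces overflow (magnitudes ≥ 65520 round to infinity) but not fp16's reduced precision. Dividing `eps` by S² is one way to make the scaled LayerNorm match the unscaled one; this is an assumption, not necessarily how the converted model handles it.

```kotlin
import kotlin.math.abs
import kotlin.math.sign
import kotlin.math.sqrt

// Crude fp16 range model: magnitudes >= 65520 round to +/-inf (fp16 max finite is 65504).
// Precision loss from fp16's 10-bit mantissa is NOT modeled.
fun f16Range(v: Float): Float = if (abs(v) >= 65520f) Float.POSITIVE_INFINITY * sign(v) else v

// 1. Channel LayerNorm at one pixel (x = the C channel values), without affine weights.
//    s = 1f gives the plain formulation; s = 128f does the reductions on x / s.
//    acc is applied after each accumulation step; pass ::f16Range to mimic fp16 overflow.
fun channelLayerNorm(x: FloatArray, eps: Float, s: Float = 1f, acc: (Float) -> Float = { it }): FloatArray {
    val xs = FloatArray(x.size) { x[it] / s }          // exact for power-of-two s (barring underflow)
    var sum = 0f
    for (v in xs) sum = acc(sum + v)
    val mu = acc(sum / xs.size)                         // = mean(x) / s
    var sq = 0f
    for (v in xs) { val d = v - mu; sq = acc(sq + d * d) }
    val variance = acc(sq / xs.size)                    // = var(x) / s^2
    val inv = 1f / sqrt(variance + eps / (s * s))       // assumption: eps rescaled by 1/s^2
    return FloatArray(xs.size) { (xs[it] - mu) * inv }  // equals (x - mean) / sqrt(var + eps)
}

// 2. AdaptiveAvgPool2d(1) on a CHW map: one global mean per channel.
fun globalAvgPool(x: FloatArray, c: Int, h: Int, w: Int): FloatArray =
    FloatArray(c) { ch -> (0 until h * w).sumOf { x[ch * h * w + it].toDouble() }.toFloat() / (h * w) }

//    Re-authored form: mean over W (dim 3), then mean over H (dim 2).
fun meanWThenH(x: FloatArray, c: Int, h: Int, w: Int): FloatArray = FloatArray(c) { ch ->
    val rowMeans = FloatArray(h) { y -> (0 until w).sumOf { x[(ch * h + y) * w + it].toDouble() }.toFloat() / w }
    rowMeans.sum() / h
}

// 3. PixelShuffle(r) reference (PyTorch semantics, batch 1) that any depth-to-space
//    replacement must reproduce: out[c, y*r + i, x*r + j] = in[c*r*r + i*r + j, y, x].
fun pixelShuffle(x: FloatArray, cIn: Int, h: Int, w: Int, r: Int): FloatArray {
    require(cIn % (r * r) == 0 && x.size == cIn * h * w)
    val cOut = cIn / (r * r)
    val oh = h * r
    val ow = w * r
    val out = FloatArray(x.size)
    for (c in 0 until cOut) for (i in 0 until r) for (j in 0 until r)
        for (y in 0 until h) for (xx in 0 until w)
            out[(c * oh + y * r + i) * ow + xx * r + j] = x[((c * r * r + i * r + j) * h + y) * w + xx]
    return out
}
```

Consider a 512-channel vector with |x| = 175, matching the bottleneck width of width32 (32 × 16 channels). With `acc = ::f16Range`, Σ_c (x−μ)² overflows and the plain LayerNorm collapses to zeros. With `s = 128f` and the same accumulator, the result matches the fp32 reference.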

Result: no banned ops, all tensors ≤ 4D, tflite-vs-torch corr 1.0, and device-vs-torch corr 1.0.
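The corr figures are presumably Pearson correlation coefficients between the flattened output tensors. The card does not say which correlation is used, so that is an assumption. Below is a minimal sketch for comparing a device output against a reference dump; `pearson` and `maxAbsDiff` are hypothetical helpers.

```kotlin
import kotlin.math.abs
import kotlin.math.sqrt

// Pearson correlation over all elements of two equally sized (flattened) tensors.
// Accumulates in Double to limit rounding error on large tensors.
fun pearson(a: FloatArray, b: FloatArray): Double {
    require(a.size == b.size && a.isNotEmpty())
    val ma = a.average()
    val mb = b.average()
    var sab = 0.0
    var saa = 0.0
    var sbb = 0.0
    for (i in a.indices) {
        val da = a[i] - ma
        val db = b[i] - mb
        sab += da * db
        saa += da * da
        sbb += db * db
    }
    return sab / sqrt(saa * sbb)  // NaN if either tensor is constant
}

// Largest elementwise deviation between two tensors.
fun maxAbsDiff(a: FloatArray, b: FloatArray): Float {
    require(a.size == b.size)
    var m = 0f
    for (i in a.indices) m = maxOf(m, abs(a[i] - b[i]))
    return m
}
```

Pearson correlation is invariant to a positive scale and offset. A value of 1.0 alone would therefore not catch, for example, a constant brightness shift; `maxAbsDiff` covers that case.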

License

MIT. Upstream: megvii-research/NAFNet. Original weights: NAFNet-GoPro-width32 from the official release.