
Notes / lab log

Running log of what Claude Opus 4.8 in Claude Code figured out. Newest at the bottom of each section.

Environment

  • macOS (darwin arm64), Apple Silicon. No CUDA → torch CPU/MPS build.
  • System python 3.9.6; using uv to manage an isolated env.
  • git-lfs 3.7.1 present. Weights cloned to /tmp/Moebius/Moebius-weights (pretrained, ft_celebahq, ft_ffhq, ft_places2 — each ~450MB fp32 .bin).
  • Code repo at /tmp/Moebius/Moebius.

Code map (what matters for the port)

  • Entry: infer/infer_moebius.py → infer/utils.py:build_pipeline.
  • Pipeline: removal/v1_2/pipeline.py (RemovalSDXLPipeline_BatchMode).
  • Model wrapper: removal/v1_2/removal_model.py (RemovalModel = embedding + diff UNet).
  • Core helpers: utils_infer.py (encode_clean_latents, predict_noise).
  • UNet impl: model_lib/nets/unet_lambda_prune_lite.py (+ lambda layers under model_lib/nets/layers/λ/vanillaλ.py).
  • Config: config/model_cfg/moebius.yaml.

Key findings

  • The CUDA/Triton fla dependency is ONLY imported in model_lib/nets/layers/gla/gla.py (the GLA teacher variant). Moebius's student UNet (lambda-DWConv) does not need it — must avoid importing unet_gla to keep the graph clean for export.
  • "Prompt" conditioning is a plain nn.Embedding(20, 3072). CFG uses fixed ids: cond=[0..9], uncond=[10..19]. So encoder_hidden_states is a constant per branch → can be precomputed and baked into the ONNX UNet as a constant, OR passed as input.
  • 9-channel UNet input = cat([noisy_latents(4), resized_mask(1), masked_latents(4)], dim=1).
  • CFG batches uncond+cond into one forward (batch dim ×2), then splits.
  • einsum is used in the λ layers (linear attention). Supported in ONNX; need to check ORT-Web WebGPU coverage.
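
A minimal sketch of the conditioning and input layout described in the findings above, using plain Python lists so it runs without torch. The function names are hypothetical, and the row order (uncond first, then cond) is an assumption; only the id ranges, the 4+1+4 channel split, and the batch-of-2 CFG layout come from the notes.

```python
# Fixed "prompt" ids from the notes: cond = 0..9, uncond = 10..19.
COND_IDS = list(range(0, 10))
UNCOND_IDS = list(range(10, 20))


def cfg_input_ids():
    """input_ids for one batched CFG forward, shape (2, 10).

    Assumption: row 0 is the uncond branch, row 1 the cond branch.
    """
    return [UNCOND_IDS, COND_IDS]


def unet_channels(noisy_latents, resized_mask, masked_latents):
    """Concatenate channel planes in the order the UNet expects: 4 + 1 + 4 = 9.

    Each argument is a list of channel planes; a plane can be any object here.
    """
    assert len(noisy_latents) == 4
    assert len(resized_mask) == 1
    assert len(masked_latents) == 4
    return noisy_latents + resized_mask + masked_latents


def cfg_combine(uncond, cond, guidance_scale):
    """Classifier-free guidance on flat lists: u + g * (c - u)."""
    return [u + guidance_scale * (c - u) for u, c in zip(uncond, cond)]
```

Because the ids are constant per branch, the same two rows can be reused for every step of the loop.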

Phase 1 results (reference inference — DONE)

  • Got the real pipeline running end-to-end on CPU (macOS, torch 2.7.1).
  • Patches needed to load student on CPU/mac:
    • model_lib/__init__.py: wrapped teacher unet_gla import in try/except (needs fla).
    • Don't import utils_train (drags in orjson/library); build_vae is just AutoencoderKL.from_pretrained(vae_dir).
  • VAE scaling_factor = 0.13025 (the SDXL VAE value, NOT the SD 1.x 0.18215!). Custom VAE. block_out_channels = [128,256,512,512] → vae_scale_factor = 2^(4-1) = 8. This MUST be hardcoded correctly in the JS port or colors/contrast will be wrong.
  • removal_model params = 226.04M confirmed. load_state_dict: all keys matched.
  • Perf: ~8.9 s/step on CPU. 19 steps, each one batched CFG forward (batch 2, i.e. 38 UNet evaluations) ≈ 2:48 total. CPU/WASM is therefore unusable; WebGPU is the whole game and is expected to be far faster.
  • Output saved to reference_out/reference_result.png — plausible inpaint. Mask convention: white(255) → 1 → region to inpaint (zeroed in masked_image); black → keep.
  • num_inference_steps=20 with strength=0.99 → DDIM uses 19 steps (drops first).
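
The 20 → 19 step count can be reproduced with a few lines of arithmetic. This is a sketch that assumes the pipeline follows the usual diffusers img2img convention for `strength` and evenly spaced ("leading") timesteps; the function name is hypothetical.

```python
def img2img_timesteps(num_inference_steps, strength, num_train_timesteps=1000):
    """Timesteps actually visited for a given strength.

    Assumption: leading spacing, then the first (steps - int(steps * strength))
    entries are dropped, as in the diffusers img2img pipelines.
    """
    step_ratio = num_train_timesteps // num_inference_steps
    schedule = [i * step_ratio for i in range(num_inference_steps)][::-1]
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    return schedule[t_start:]


steps = img2img_timesteps(20, 0.99)
print(len(steps), steps[0], steps[-1])
```

With strength 0.99, `int(20 * 0.99)` is 19, so exactly one timestep (950) is skipped.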

Parity strategy

  • Won't try to reproduce torch RNG in JS. For PyTorch↔ONNX parity: dump identical input tensors and compare outputs. For the web app: generate noise with a seedable JS RNG; diffusion is robust to the particular noise draw, so visual results will be valid even if not bit-identical to the torch reference.
  • DDIM scale_model_input is identity → skip in TS. Need to reproduce DDIM alphas/betas (scaled_linear, beta 0.00085→0.012, 1000 steps) and the DDIM step update in JS.
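
One way to get a seedable noise source that is easy to mirror in TS is a small integer PRNG plus Box-Muller. The sketch below uses a Python port of the well-known mulberry32 generator; the choice of generator and every function name are assumptions for illustration, not what the web app necessarily uses. The per-channel `noise_offset` term follows the `randn(B,4,1,1)` form given in the DDIM constants.

```python
import math


def mulberry32(seed):
    """Return a function yielding floats in [0, 1) from a 32-bit state."""
    state = seed & 0xFFFFFFFF

    def next_float():
        nonlocal state
        state = (state + 0x6D2B79F5) & 0xFFFFFFFF
        t = state
        t = ((t ^ (t >> 15)) * (t | 1)) & 0xFFFFFFFF
        t ^= (t + (((t ^ (t >> 7)) * (t | 61)) & 0xFFFFFFFF)) & 0xFFFFFFFF
        return ((t ^ (t >> 14)) & 0xFFFFFFFF) / 4294967296.0

    return next_float


def randn(count, seed):
    """Box-Muller: `count` standard-normal samples, deterministic per seed."""
    rand = mulberry32(seed)
    out = []
    while len(out) < count:
        u1 = 1.0 - rand()  # in (0, 1], so log() is safe
        u2 = rand()
        r = math.sqrt(-2.0 * math.log(u1))
        out.append(r * math.cos(2.0 * math.pi * u2))
        out.append(r * math.sin(2.0 * math.pi * u2))
    return out[:count]


def initial_latents(seed, channels=4, size=64, noise_offset=0.0):
    """Latent noise as `channels` flat planes, plus one offset draw per channel."""
    plane = size * size
    noise = randn(channels * plane + channels, seed)
    offsets = noise[channels * plane:]
    return [
        [noise[c * plane + i] + noise_offset * offsets[c] for i in range(plane)]
        for c in range(channels)
    ]
```

Only integer masks, shifts, and multiplies are used in the generator, so the same sequence of uniform values should be reproducible with `Math.imul` and `>>> 0` in TS; the Gaussian values may still differ in the last bits because `log`/`cos` are not guaranteed bit-identical across platforms.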

Architecture: spatial size is FIXED (important!)

  • Self-attn (attn1): MQSλ with r=15 → local-context path (Conv3d pos_conv). Spatially dynamic, fine at any size.
  • Cross-attn (attn2): MQCλ with NO r → global path → rel_pos_emb is an nn.Parameter(n*n, m, dim_k, dim_u) where n = per-block sample_size, m = 10. This is TIED to the trained spatial resolution. Different spatial size → wrong or out-of-bounds indexing.
  • ⇒ Export at STATIC 512×512 image (64×64 latent). Web app resizes user input to 512×512, inpaints, resizes result back + pastes. Square only. This is the benchmark resolution.

ONNX export plan

  • Three graphs, spatial static, batch dynamic where cheap:
    • vae_encoder: (B,3,512,512) → moments (B,8,64,64); JS uses mean=moments[:,:4]*sf.
    • unet (RemovalModel): (B,9,64,64), timesteps(B,), input_ids(B,10) → noise(B,4,64,64). Embedding (nn.Embedding 20×3072) stays IN the graph (cheap gather). CFG batches B=2.
    • vae_decoder: (B,4,64,64) → (B,3,512,512).
  • scaling_factor = 0.13025 applied in JS (encode: latent*sf; decode: latent/sf).
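
A small sketch of the scaling that stays outside the graphs, on plain lists of channel planes. It assumes the 8-channel `moments` tensor is laid out as in diffusers (first 4 channels mean, last 4 log-variance); the function names are hypothetical.

```python
SCALING_FACTOR = 0.13025  # from the VAE config, per the notes


def moments_to_latent(moments):
    """Take the mean half of the encoder output and apply the scaling factor.

    moments: list of 8 channel planes (each a flat list of floats).
    Assumption: channels 0..3 are the mean, 4..7 the log-variance.
    """
    assert len(moments) == 8
    return [[v * SCALING_FACTOR for v in plane] for plane in moments[:4]]


def latent_to_decoder_input(latent):
    """Undo the scaling before handing latents to vae_decoder."""
    assert len(latent) == 4
    return [[v / SCALING_FACTOR for v in plane] for plane in latent]
```

Using the mean directly (no sampling from the posterior) keeps the encode path deterministic, which also makes parity checks simpler.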

Phase 2 results (ONNX export — DONE)

  • torch.onnx.export (legacy tracer, opset 18) traced all 3 graphs cleanly. No op-coverage failures. The einsum/lambda/Conv3d ops all exported.
  • Parity vs PyTorch (CPU EP): decoder 5.7e-5, unet 3.6e-6, encoder mean ch ~2e-2.
  • FULL pipeline parity test (python/onnx_pipeline.py): reimplemented DDIM+CFG+9ch+scaling in numpy on the ONNX sessions, vs torch models with identical noise: final latents max|Δ| 0.149, decoded image mean|Δ| 0.0022, max 0.090 → visually identical. This validates the ENTIRE orchestration I'll port to TS.
  • numpy DDIM vs diffusers DDIMScheduler: step max|Δ| 5e-7, timesteps identical. ✓
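
The parity figures above are max and mean absolute differences. A dependency-free sketch of those two metrics on flattened tensors (function names are hypothetical, and the tolerance is whatever the caller decides is acceptable):

```python
def max_abs_diff(a, b):
    """max|Δ| between two equally sized flat sequences."""
    assert len(a) == len(b) and len(a) > 0
    return max(abs(x - y) for x, y in zip(a, b))


def mean_abs_diff(a, b):
    """mean|Δ| between two equally sized flat sequences."""
    assert len(a) == len(b) and len(a) > 0
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


def within_tolerance(a, b, tol):
    """True when every element differs by at most `tol`."""
    return max_abs_diff(a, b) <= tol
```

Reporting both numbers is useful here: a large max with a tiny mean (as for the final latents) points at a few outlier positions rather than a systematic drift.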

DDIM constants for JS (validated)

  • betas = linspace(sqrt(0.00085), sqrt(0.012), 1000)^2 ; alphas_cumprod = cumprod(1-betas)
  • timesteps(20 steps) = [950,900,...,50,0]; strength 0.99 ⇒ drop first ⇒ [900,...,0] (19).
  • ddim_step (eta=0, clip_sample=False):
    • pred_x0 = (sample - sqrt(1-ac_t)*eps) / sqrt(ac_t)
    • prev = sqrt(ac_prev)*pred_x0 + sqrt(1-ac_prev)*eps
    • ac_prev = alphas_cumprod[prev_t], or final_alpha_cumprod=1.0 when prev_t<0 (last step).
  • noise_offset 0.0357: noise += 0.0357 * randn(B,4,1,1). (optional; small)
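
The constants above translate almost line for line into code. This is a stdlib-only sketch of the schedule and the eta=0 step on flat lists, intended as a reference to port from rather than the validated numpy implementation itself; the function names are hypothetical.

```python
import math


def make_alphas_cumprod(beta_start=0.00085, beta_end=0.012, n=1000):
    """scaled_linear betas: linspace(sqrt(b0), sqrt(b1), n)^2, then cumprod(1 - beta)."""
    lo, hi = math.sqrt(beta_start), math.sqrt(beta_end)
    out, prod = [], 1.0
    for i in range(n):
        beta = (lo + (hi - lo) * i / (n - 1)) ** 2
        prod *= 1.0 - beta
        out.append(prod)
    return out


def ddim_step(sample, eps, t, prev_t, alphas_cumprod, final_alpha_cumprod=1.0):
    """One deterministic DDIM update (eta=0, no clipping) on flat lists."""
    ac_t = alphas_cumprod[t]
    ac_prev = alphas_cumprod[prev_t] if prev_t >= 0 else final_alpha_cumprod
    a, b = math.sqrt(ac_t), math.sqrt(1.0 - ac_t)
    a_prev, b_prev = math.sqrt(ac_prev), math.sqrt(1.0 - ac_prev)
    out = []
    for x, e in zip(sample, eps):
        pred_x0 = (x - b * e) / a
        out.append(a_prev * pred_x0 + b_prev * e)
    return out
```

In the 19-step loop, `prev_t` is `t - 50`, so the final step (t=0) gets `prev_t = -50` and falls through to `final_alpha_cumprod`, returning `pred_x0` directly.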

Web pipeline recipe (numpy → TS)

  1. resize image+mask to 512×512 (mask NEAREST, binarize ≥128).
  2. img→[-1,1] CHW; masked = img*(1-mask).
  3. encode img & masked → moments; take mean[:4] * 0.13025.
  4. mask→64×64 NEAREST, 1ch.
  5. latents = randn(1,4,64,64) [+ noise_offset].
  6. loop t in timesteps: nine=cat([latents,mask64,maskedLat]); batch×2; unet; cfg = u + g*(c-u); latents = ddim_step.
  7. decode(latents/0.13025); (x+1)/2; clip; → image.
  8. paste: out*blur(mask) + (1-blur(mask))*orig.
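
Steps 1, 2, 7, and 8 of the recipe are elementwise and easy to get subtly wrong, so here is a per-value sketch. The `pixel / 127.5 - 1` mapping is an assumed (common) way to reach [-1, 1], and the blur itself is left out: `soft_mask` stands for an already blurred mask in [0, 1]. All names are hypothetical.

```python
def to_model_range(pixel):
    """Map an 8-bit value 0..255 to [-1, 1] (assumed convention)."""
    return pixel / 127.5 - 1.0


def binarize_mask(value):
    """White (>= 128) → 1.0 = region to inpaint; black → 0.0 = keep."""
    return 1.0 if value >= 128 else 0.0


def mask_image(img, mask):
    """masked = img * (1 - mask): zero out the region to inpaint."""
    return [p * (1.0 - m) for p, m in zip(img, mask)]


def to_pixel(x):
    """Decoder output in [-1, 1] → clipped 8-bit value."""
    y = min(max((x + 1.0) / 2.0, 0.0), 1.0)
    return int(round(y * 255.0))


def paste(out, orig, soft_mask):
    """out * blur(mask) + (1 - blur(mask)) * orig, elementwise."""
    return [o * m + (1.0 - m) * g for o, g, m in zip(out, orig, soft_mask)]
```

Blending with a blurred mask rather than the hard one softens the seam between generated and original pixels after the result is resized back.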

Phase 3 (web app) — in progress

  • Vite + TS + onnxruntime-web (1.27.0). Default onnxruntime-web/webgpu import resolves to the self-contained ort.webgpu.bundle.min.mjs.
  • Models served LOCALLY: web/public/models -> ../models symlink, at /models/.onnx. Total ~1.24GB fetched over localhost (no internet). ORT runtime served from web/ort-dist at /ort/ via a custom static middleware (see vite.config.ts).
  • BUG FIXED: ORT glue .mjs must NOT be in /public (Vite tries to module-transform it). Fix = serve /ort/* as raw static files via configureServer middleware.
  • COOP/COEP headers set (needed for threaded WASM / SharedArrayBuffer).

WebGPU op coverage (confirmed from ORT source at /tmp/Moebius/onnxruntime)

  • js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts registers: Einsum ✓, Conv ✓ (conv.ts has computeConv3DInfo / createConv3DNaiveProgramInfo → Conv3d pos_conv works, naive kernel so possibly slow), InstanceNormalization ✓, MatMul/Gemm ✓, Softmax ✓, Reduce* ✓, Transpose/Concat/Gather/Pad/Resize/Where ✓.
  • GroupNorm: not registered by that name, BUT torch.onnx exports nn.GroupNorm as a Reshape→InstanceNormalization→Reshape→Mul→Add decomposition → covered. (VAE decoder CPU-EP parity was 5.7e-5, so the graph is decomposed, not a single GroupNorm op.)
  • ⇒ No expected silent CPU fallback for the heavy ops. Confirm empirically in console.

Verification without a GPU browser (sandbox can't drive Chrome; user's live Chrome holds the playwright profile)

  • web/test/fixture/*.bin: dumped inputs + reference final latents from the validated numpy/ONNX pipeline, to check the TS port (ddim.ts + 9ch assembly) in Node.
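
A sketch of how such fixtures could be written and checked. The on-disk layout is an assumption (headerless little-endian float32, which a Node test can read as a `Float32Array`); the actual fixture format and file names are not specified in these notes, and the helper names are hypothetical.

```python
import struct


def dump_f32(path, values):
    """Write a flat list of floats as raw little-endian float32 (no header)."""
    with open(path, "wb") as f:
        f.write(struct.pack("<%df" % len(values), *values))


def load_f32(path):
    """Read a raw little-endian float32 file back into a list of floats."""
    with open(path, "rb") as f:
        raw = f.read()
    return list(struct.unpack("<%df" % (len(raw) // 4), raw))


def matches_reference(path, candidate, tol):
    """True when `candidate` has the reference's length and max|Δ| <= tol."""
    ref = load_f32(path)
    if len(ref) != len(candidate):
        return False
    return all(abs(r - c) <= tol for r, c in zip(ref, candidate))
```

Shapes are not stored in this layout, so the test on the TS side has to know them (for example 1×4×64×64 for the final latents).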

TODO / unknowns

  • fp16 export to shrink the UNet to ~450MB for real deployment (VAE fp16 unstable: decoder Cast issue; keep VAE fp32). Quality risk in λ layers: validate before shipping fp16.
  • Confirm in-browser: WebGPU selected, no CPU fallback, end-to-end correctness + timing.