Notes / lab log
Running log of what Claude Opus 4.8 in Claude Code figured out. Newest at the bottom of each section.
Environment
- macOS (darwin arm64), Apple Silicon. No CUDA → torch CPU/MPS build.
- System python 3.9.6; using `uv` to manage an isolated env.
- git-lfs 3.7.1 present.
- Weights cloned to `/tmp/Moebius/Moebius-weights` (pretrained, ft_celebahq, ft_ffhq, ft_places2; each ~450MB fp32.bin).
- Code repo at `/tmp/Moebius/Moebius`.
Code map (what matters for the port)
- Entry: `infer/infer_moebius.py` → `infer/utils.py:build_pipeline`.
- Pipeline: `removal/v1_2/pipeline.py` (`RemovalSDXLPipeline_BatchMode`).
- Model wrapper: `removal/v1_2/removal_model.py` (`RemovalModel` = embedding + diff UNet).
- Core helpers: `utils_infer.py` (`encode_clean_latents`, `predict_noise`).
- UNet impl: `model_lib/nets/unet_lambda_prune_lite.py` (+ lambda layers under `model_lib/nets/layers/λ/vanillaλ.py`).
- Config: `config/model_cfg/moebius.yaml`.
Key findings
- The CUDA/Triton `fla` dependency is ONLY imported in `model_lib/nets/layers/gla/gla.py` (the GLA teacher variant). Moebius's student UNet (lambda-DWConv) does not need it; must avoid importing `unet_gla` to keep the graph clean for export.
- "Prompt" conditioning is a plain `nn.Embedding(20, 3072)`. CFG uses fixed ids: cond=[0..9], uncond=[10..19]. So `encoder_hidden_states` is a constant per branch → can be precomputed and baked into the ONNX UNet as a constant, OR passed as input.
- 9-channel UNet input = cat([noisy_latents(4), resized_mask(1), masked_latents(4)], dim=1).
- CFG batches uncond+cond into one forward (batch dim ×2), then splits.
- einsum is used in the λ layers (linear attention). Supported in ONNX; need to check ORT-Web WebGPU coverage.
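A minimal TS sketch of the 9-channel concat and the batch-2 CFG feed described above. It assumes NCHW `Float32Array` buffers at 64×64 and an uncond-first batch order (read off "uncond+cond" above); the helper names are hypothetical, and treating `input_ids` as int64 (hence `BigInt64Array`) is an assumption to check against the exported graph's input types.

```typescript
// Sketch: build the 9-channel UNet input and the batch-2 CFG feeds.
// Assumptions: NCHW Float32Array buffers at 64x64, uncond row first,
// input_ids exported as int64. Helper names are hypothetical.
const LAT_H = 64;
const LAT_W = 64;
const PLANE = LAT_H * LAT_W; // elements in one channel plane

// cat([noisy_latents(4), mask(1), masked_latents(4)], dim=1) for batch 1.
function assembleNineChannel(
  latents: Float32Array, // (1,4,64,64)
  mask64: Float32Array, // (1,1,64,64), 1 = inpaint
  maskedLatents: Float32Array, // (1,4,64,64)
): Float32Array {
  if (latents.length !== 4 * PLANE || mask64.length !== PLANE || maskedLatents.length !== 4 * PLANE) {
    throw new Error("unexpected input size");
  }
  // With batch 1 in NCHW, a channel concat is three back-to-back copies.
  const out = new Float32Array(9 * PLANE);
  out.set(latents, 0);
  out.set(mask64, 4 * PLANE);
  out.set(maskedLatents, 5 * PLANE);
  return out;
}

// Duplicate the sample along the batch dim and attach the fixed prompt ids:
// row 0 = uncond (ids 10..19), row 1 = cond (ids 0..9).
function batchForCfg(nine: Float32Array): { sample: Float32Array; inputIds: BigInt64Array } {
  const sample = new Float32Array(2 * nine.length);
  sample.set(nine, 0);
  sample.set(nine, nine.length);
  const inputIds = new BigInt64Array(20);
  for (let i = 0; i < 10; i++) {
    inputIds[i] = BigInt(10 + i); // uncond row
    inputIds[10 + i] = BigInt(i); // cond row
  }
  return { sample, inputIds };
}

// Split the (2,4,64,64) UNet output and apply cfg = u + g * (c - u).
function cfgCombine(unetOut: Float32Array, guidanceScale: number): Float32Array {
  if (unetOut.length % 2 !== 0) throw new Error("expected an even-length batch-2 output");
  const half = unetOut.length / 2;
  const out = new Float32Array(half);
  for (let i = 0; i < half; i++) {
    const u = unetOut[i];
    const c = unetOut[half + i];
    out[i] = u + guidanceScale * (c - u);
  }
  return out;
}
```

The `timesteps` feed (the same t repeated for both rows) is left out here because its exported dtype isn't recorded in these notes.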
Phase 1 results (reference inference — DONE)
- Got the real pipeline running end-to-end on CPU (macOS, torch 2.7.1).
- Patches needed to load student on CPU/mac:
  - `model_lib/__init__.py`: wrapped teacher `unet_gla` import in try/except (needs `fla`).
  - Don't import `utils_train` (drags in orjson/library); `build_vae` is just `AutoencoderKL.from_pretrained(vae_dir)`.
- VAE scaling_factor = 0.13025 (NOT the usual SD 0.18215!). Custom VAE. block_out_channels = [128,256,512,512] → vae_scale_factor 8. This MUST be hardcoded correctly in the JS port or colors/contrast will be wrong.
- removal_model params = 226.04M confirmed. load_state_dict: all keys matched.
- Perf: ~8.9 s/step on CPU (19 steps, each one batch-2 CFG forward, i.e. 38 single-sample UNet passes; ≈ 2:48 total). WebGPU is expected to be far faster. Confirms CPU/WASM is unusable; WebGPU is the whole game.
- Output saved to reference_out/reference_result.png — plausible inpaint. Mask convention: white(255) → 1 → region to inpaint (zeroed in masked_image); black → keep.
- num_inference_steps=20 with strength=0.99 → DDIM uses 19 steps (drops first).
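The 19-step count is consistent with diffusers' img2img/inpaint `get_timesteps` rule. The sketch below assumes the Moebius pipeline uses that same rule and DDIM "leading" spacing with steps_offset = 0 (inferred from the validated [950, ..., 0] schedule); the function names are hypothetical.

```typescript
// Sketch: DDIM "leading" timesteps (steps_offset = 0 assumed) and the
// strength cut, following diffusers' get_timesteps rule (assumed to match
// the Moebius pipeline).
function ddimTimesteps(numInferenceSteps: number, numTrainTimesteps = 1000): number[] {
  const stepRatio = Math.floor(numTrainTimesteps / numInferenceSteps); // 50 for 20 steps
  const ts: number[] = [];
  for (let i = numInferenceSteps - 1; i >= 0; i--) ts.push(i * stepRatio);
  return ts;
}

function applyStrength(timesteps: number[], strength: number): number[] {
  const n = timesteps.length;
  const initTimestep = Math.min(Math.floor(n * strength), n); // floor(20 * 0.99) = 19
  const tStart = Math.max(n - initTimestep, 0); // 1 -> drop the first timestep
  return timesteps.slice(tStart);
}
```

With 20 steps and strength 0.99 this yields [900, 850, ..., 0], 19 entries; strength 1.0 keeps all 20.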
Parity strategy
- Won't try to reproduce torch RNG in JS. For PyTorch↔ONNX parity: dump identical input tensors and compare outputs. For the web app: generate noise with a seedable JS RNG; diffusion is robust to the particular noise draw, so visual results will be valid even if not bit-identical to the torch reference.
- DDIM `scale_model_input` is identity → skip in TS. Need to reproduce DDIM alphas/betas (scaled_linear, beta 0.00085→0.012, 1000 steps) and the DDIM step update in JS.
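A possible seedable noise source for the web app, as a sketch: mulberry32 as the PRNG and Box-Muller for the Gaussian transform. Both are swapped-in choices, not torch's generator, in line with the decision above not to reproduce torch RNG; the function names are illustrative.

```typescript
// Sketch: seedable N(0,1) noise. mulberry32 + Box-Muller are swapped-in
// choices (not torch's RNG); only determinism per seed is intended.
function mulberry32(seed: number): () => number {
  let a = seed >>> 0;
  return () => {
    a = (a + 0x6d2b79f5) >>> 0;
    let t = a;
    t = Math.imul(t ^ (t >>> 15), t | 1);
    t ^= t + Math.imul(t ^ (t >>> 7), t | 61);
    return ((t ^ (t >>> 14)) >>> 0) / 4294967296; // uniform in [0, 1)
  };
}

// Box-Muller: two uniforms -> two independent standard normals.
function randn(count: number, rng: () => number): Float32Array {
  const out = new Float32Array(count);
  for (let i = 0; i < count; i += 2) {
    const u1 = 1 - rng(); // in (0, 1], so log(u1) is finite
    const u2 = rng();
    const r = Math.sqrt(-2 * Math.log(u1));
    out[i] = r * Math.cos(2 * Math.PI * u2);
    if (i + 1 < count) out[i + 1] = r * Math.sin(2 * Math.PI * u2);
  }
  return out;
}
```

For example, `randn(4 * 64 * 64, mulberry32(seed))` would give the (1,4,64,64) initial latents for a given seed.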
Architecture: spatial size is FIXED (important!)
- Self-attn (attn1): MQSλ with `r=15` → local-context path (Conv3d `pos_conv`). Spatially dynamic, fine at any size.
- Cross-attn (attn2): MQCλ with NO `r` → global path → `rel_pos_emb` is an `nn.Parameter(n*n, m, dim_k, dim_u)` where n = per-block sample_size, m = 10. This is TIED to the trained spatial resolution. Different spatial size → wrong/out-of-bounds indexing.
- ⇒ Export at STATIC 512×512 image (64×64 latent). Web app resizes user input to 512×512, inpaints, resizes result back + pastes. Square only. This is the benchmark resolution.
ONNX export plan
- Three graphs, spatial static, batch dynamic where cheap:
- vae_encoder: (B,3,512,512) → moments (B,8,64,64); JS uses mean=moments[:,:4]*sf.
- unet (RemovalModel): (B,9,64,64), timesteps(B,), input_ids(B,10) → noise(B,4,64,64). Embedding (nn.Embedding 20×3072) stays IN the graph (cheap gather). CFG batches B=2.
- vae_decoder: (B,4,64,64) → (B,3,512,512).
- scaling_factor = 0.13025 applied in JS (encode: latent*sf; decode: latent/sf).
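A small sketch of the JS side of the VAE scaling: take the mean half of the encoder's moments and multiply by 0.13025, then divide by the same factor before decoding. It assumes NCHW layout and that channels 0..3 of the moments are the mean (as in `mean=moments[:,:4]` above) with 4..7 ignored; the function names are hypothetical.

```typescript
// Sketch: VAE latent scaling for the port. NCHW layout assumed; names are hypothetical.
const SCALING_FACTOR = 0.13025; // custom VAE value, NOT SD's 0.18215
const LATENT_C = 4;
const LATENT_HW = 64 * 64; // 512 / vae_scale_factor 8 = 64

// moments (B,8,64,64) -> mean (B,4,64,64) * sf. In NCHW the 4 mean channels
// are one contiguous run per batch item; channels 4..7 are skipped.
function momentsToLatents(moments: Float32Array, batch: number): Float32Array {
  const perBatchIn = 2 * LATENT_C * LATENT_HW;
  const perBatchOut = LATENT_C * LATENT_HW;
  if (moments.length !== batch * perBatchIn) throw new Error("unexpected moments size");
  const out = new Float32Array(batch * perBatchOut);
  for (let b = 0; b < batch; b++) {
    for (let i = 0; i < perBatchOut; i++) {
      out[b * perBatchOut + i] = moments[b * perBatchIn + i] * SCALING_FACTOR;
    }
  }
  return out;
}

// Undo the scaling before feeding the decoder: latent / sf.
function latentsForDecoder(latents: Float32Array): Float32Array {
  return latents.map((v) => v / SCALING_FACTOR);
}
```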
Phase 2 results (ONNX export — DONE)
- torch.onnx.export (legacy tracer, opset 18) traced all 3 graphs cleanly. No op-coverage failures. The einsum/lambda/Conv3d ops all exported.
- Parity vs PyTorch (CPU EP): decoder 5.7e-5, unet 3.6e-6, encoder mean ch ~2e-2.
- FULL pipeline parity test (python/onnx_pipeline.py): reimplemented DDIM+CFG+9ch+scaling in numpy on the ONNX sessions, vs torch models with identical noise: final latents max|Δ| 0.149, decoded image mean|Δ| 0.0022, max 0.090 → visually identical. This validates the ENTIRE orchestration I'll port to TS.
- numpy DDIM vs diffusers DDIMScheduler: step max|Δ| 5e-7, timesteps identical. ✓
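The parity numbers above are max|Δ| and mean|Δ| over flat tensors. A minimal TS equivalent, as a sketch, could serve the same role when checking the TS port against dumped references; the name `diffStats` is hypothetical, and mapping NaN to Infinity is a design choice so a broken backend cannot silently report a small diff.

```typescript
// Sketch: max|Δ| and mean|Δ| between two flat tensors.
interface DiffStats {
  maxAbs: number;
  meanAbs: number;
}

function diffStats(a: Float32Array, b: Float32Array): DiffStats {
  if (a.length !== b.length) throw new Error(`length mismatch: ${a.length} vs ${b.length}`);
  if (a.length === 0) return { maxAbs: 0, meanAbs: 0 };
  let maxAbs = 0;
  let sum = 0;
  for (let i = 0; i < a.length; i++) {
    const raw = Math.abs(a[i] - b[i]);
    const d = Number.isNaN(raw) ? Infinity : raw; // NaN would otherwise be skipped by `>`
    if (d > maxAbs) maxAbs = d;
    sum += d;
  }
  return { maxAbs, meanAbs: sum / a.length };
}
```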
DDIM constants for JS (validated)
- betas = linspace(sqrt(0.00085), sqrt(0.012), 1000)^2 ; alphas_cumprod = cumprod(1-betas)
- timesteps(20 steps) = [950,900,...,50,0]; strength 0.99 ⇒ drop first ⇒ [900,...,0] (19).
- ddim_step (eta=0, clip_sample=False):
  - pred_x0 = (sample - sqrt(1-ac_t)*eps) / sqrt(ac_t)
  - prev = sqrt(ac_prev)*pred_x0 + sqrt(1-ac_prev)*eps
  - ac_prev = alphas_cumprod[prev_t], or final_alpha_cumprod = 1.0 when prev_t < 0 (last step).
- noise_offset 0.0357: noise += 0.0357 * randn(B,4,1,1). (optional; small)
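The constants above translate fairly directly into TS. This is a sketch with hypothetical names: it builds the schedule in float64 (so tiny differences from a float32 implementation are expected), implements the eta = 0 update exactly as written above, and includes the optional noise offset with the per-(batch, channel) draws passed in.

```typescript
// Sketch of the validated DDIM pieces (scaled_linear betas, eta = 0,
// clip_sample = False). Computed in float64; names are hypothetical.

// betas = linspace(sqrt(0.00085), sqrt(0.012), 1000)^2; alphas_cumprod = cumprod(1 - betas)
function alphasCumprod(numTrain = 1000, betaStart = 0.00085, betaEnd = 0.012): Float64Array {
  const s0 = Math.sqrt(betaStart);
  const s1 = Math.sqrt(betaEnd);
  const out = new Float64Array(numTrain);
  let prod = 1;
  for (let i = 0; i < numTrain; i++) {
    const s = s0 + ((s1 - s0) * i) / (numTrain - 1); // inclusive endpoints, like linspace
    prod *= 1 - s * s;
    out[i] = prod;
  }
  return out;
}

// One deterministic DDIM update from timestep t to prevT (prevT < 0 on the last step).
function ddimStep(sample: Float32Array, eps: Float32Array, t: number, prevT: number, ac: Float64Array): Float32Array {
  if (eps.length !== sample.length) throw new Error("eps/sample length mismatch");
  const acT = ac[t];
  const acPrev = prevT >= 0 ? ac[prevT] : 1.0; // final_alpha_cumprod = 1.0
  const sqrtAcT = Math.sqrt(acT);
  const sqrtOneMinusAcT = Math.sqrt(1 - acT);
  const sqrtAcPrev = Math.sqrt(acPrev);
  const sqrtOneMinusAcPrev = Math.sqrt(1 - acPrev);
  const out = new Float32Array(sample.length);
  for (let i = 0; i < sample.length; i++) {
    const predX0 = (sample[i] - sqrtOneMinusAcT * eps[i]) / sqrtAcT;
    out[i] = sqrtAcPrev * predX0 + sqrtOneMinusAcPrev * eps[i];
  }
  return out;
}

// Optional noise offset: latents += scale * offsets[k] over each H*W plane,
// where `offsets` holds one N(0,1) draw per (batch, channel) plane.
function addNoiseOffset(latents: Float32Array, offsets: ArrayLike<number>, plane = 64 * 64, scale = 0.0357): void {
  if (latents.length !== offsets.length * plane) throw new Error("offset count does not match latent planes");
  for (let k = 0; k < offsets.length; k++) {
    const delta = scale * offsets[k];
    for (let i = 0; i < plane; i++) latents[k * plane + i] += delta;
  }
}
```

In the loop, prevT = t - 50 (1000 / 20 steps), so the final t = 0 step lands on prevT = -50 and uses alpha 1.0.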
Web pipeline recipe (numpy → TS)
- resize image+mask to 512×512 (mask NEAREST, binarize ≥128).
- img→[-1,1] CHW; masked = img*(1-mask).
- encode img & masked → moments; take mean[:4] * 0.13025.
- mask→64×64 NEAREST, 1ch.
- latents = randn(1,4,64,64) [+ noise_offset].
- loop t in timesteps: nine=cat([latents,mask64,maskedLat]); batch×2; unet; cfg = u + g*(c-u); latents = ddim_step.
- decode(latents/0.13025); (x+1)/2; clip; → image.
- paste: out*blur(mask) + (1-blur(mask))*orig.
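The pixel-side steps of the recipe might look like this in TS. This is a sketch assuming 512×512 RGBA buffers (e.g. `ImageData.data`) that are already resized, with the mask read from the red channel; the 512→64 downsample is assumed to match torch's default "nearest" indexing (source index floor(dst × 8)). The blurred paste-back is left out because the blur kernel isn't specified, and the helper names are hypothetical.

```typescript
// Sketch of the recipe's pixel-side steps on 512x512 RGBA buffers.
// Assumptions: inputs already resized, mask read from the red channel.
const SIZE = 512;
const LAT = 64;
const PIXELS = SIZE * SIZE;

// RGBA bytes -> CHW float in [-1, 1] (alpha dropped).
function rgbaToChw(rgba: Uint8ClampedArray | Uint8Array): Float32Array {
  const out = new Float32Array(3 * PIXELS);
  for (let p = 0; p < PIXELS; p++) {
    for (let c = 0; c < 3; c++) out[c * PIXELS + p] = (rgba[p * 4 + c] / 255) * 2 - 1;
  }
  return out;
}

// Binarize at >= 128: 1 = region to inpaint (white), 0 = keep (black).
function binarizeMask(rgba: Uint8ClampedArray | Uint8Array): Float32Array {
  const out = new Float32Array(PIXELS);
  for (let p = 0; p < PIXELS; p++) out[p] = rgba[p * 4] >= 128 ? 1 : 0;
  return out;
}

// masked = img * (1 - mask), with the 1-channel mask broadcast over RGB.
function maskImage(img: Float32Array, mask: Float32Array): Float32Array {
  const out = new Float32Array(img.length);
  for (let c = 0; c < 3; c++) {
    for (let p = 0; p < PIXELS; p++) out[c * PIXELS + p] = img[c * PIXELS + p] * (1 - mask[p]);
  }
  return out;
}

// Nearest 512 -> 64: picks source pixel (y * 8, x * 8).
function downsampleMaskNearest(mask: Float32Array): Float32Array {
  const f = SIZE / LAT;
  const out = new Float32Array(LAT * LAT);
  for (let y = 0; y < LAT; y++) {
    for (let x = 0; x < LAT; x++) out[y * LAT + x] = mask[y * f * SIZE + x * f];
  }
  return out;
}

// Decoder output CHW in [-1, 1] -> RGBA bytes: (x + 1) / 2, clip to [0, 1], scale to 0..255.
function chwToRgba(chw: Float32Array): Uint8ClampedArray {
  const out = new Uint8ClampedArray(PIXELS * 4);
  for (let p = 0; p < PIXELS; p++) {
    for (let c = 0; c < 3; c++) {
      const v = Math.min(Math.max((chw[c * PIXELS + p] + 1) / 2, 0), 1);
      out[p * 4 + c] = Math.round(v * 255);
    }
    out[p * 4 + 3] = 255; // opaque alpha
  }
  return out;
}
```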
Phase 3 (web app) — in progress
- Vite + TS + onnxruntime-web (1.27.0). Default `onnxruntime-web/webgpu` import resolves to the self-contained `ort.webgpu.bundle.min.mjs`.
- Models served LOCALLY: `web/public/models -> ../models` symlink, at /models/.onnx. Total ~1.24GB fetched over localhost (no internet). ORT runtime served from `web/ort-dist` at `/ort/` via a custom static middleware (see `vite.config.ts`).
- BUG FIXED: ORT glue .mjs must NOT be in /public (Vite tries to module-transform it). Fix = serve /ort/* as raw static files via configureServer middleware.
- COOP/COEP headers set (needed for threaded WASM / SharedArrayBuffer).
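For reference, a hypothetical `vite.config.ts` along the lines of the fix above: a `configureServer` middleware that streams `/ort/*` from `web/ort-dist` untouched, plus COOP/COEP headers. This is a sketch, not the project's actual config; the MIME map and path check are illustrative, and the headers are set inside the middleware as well in case `server.headers` does not reach responses from a custom middleware that runs before Vite's own.

```typescript
// vite.config.ts (hypothetical sketch, not the project's actual file)
import { defineConfig, type Plugin } from "vite";
import { createReadStream, existsSync, statSync } from "node:fs";
import path from "node:path";
import { fileURLToPath } from "node:url";

const ORT_DIR = fileURLToPath(new URL("./ort-dist", import.meta.url));
const MIME: Record<string, string> = {
  ".mjs": "text/javascript",
  ".js": "text/javascript",
  ".wasm": "application/wasm",
};
// Cross-origin isolation, needed for SharedArrayBuffer / threaded WASM.
const ISOLATION_HEADERS = {
  "Cross-Origin-Opener-Policy": "same-origin",
  "Cross-Origin-Embedder-Policy": "require-corp",
};

// Serve /ort/* as raw files so Vite never module-transforms ORT's glue .mjs.
function serveOrtRaw(): Plugin {
  return {
    name: "serve-ort-raw",
    configureServer(server) {
      server.middlewares.use((req, res, next) => {
        if (!req.url?.startsWith("/ort/")) return next();
        const rel = decodeURIComponent(req.url.slice("/ort/".length).split("?")[0]);
        const file = path.resolve(ORT_DIR, rel);
        // Reject anything that escapes ORT_DIR or is not a regular file.
        if (!file.startsWith(ORT_DIR + path.sep) || !existsSync(file) || !statSync(file).isFile()) {
          return next();
        }
        for (const [k, v] of Object.entries(ISOLATION_HEADERS)) res.setHeader(k, v);
        res.setHeader("Content-Type", MIME[path.extname(file)] ?? "application/octet-stream");
        createReadStream(file).pipe(res);
      });
    },
  };
}

export default defineConfig({
  plugins: [serveOrtRaw()],
  server: { headers: ISOLATION_HEADERS },
});
```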
WebGPU op coverage (confirmed from ORT source at /tmp/Moebius/onnxruntime)
- js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts registers: Einsum ✓, Conv ✓ (conv.ts has computeConv3DInfo / createConv3DNaiveProgramInfo → Conv3d pos_conv works, naive kernel so possibly slow), InstanceNormalization ✓, MatMul/Gemm ✓, Softmax ✓, Reduce* ✓, Transpose/Concat/Gather/Pad/Resize/Where ✓.
- GroupNorm: not registered by that name, BUT torch.onnx exports nn.GroupNorm as a Reshape→InstanceNormalization→Reshape→Mul→Add decomposition → covered. (VAE decoder CPU-EP parity was 5.7e-5, so the graph is decomposed, not a single GroupNorm op.)
- ⇒ No expected silent CPU fallback for the heavy ops. Confirm empirically in console.
Verification without a GPU browser (sandbox can't drive Chrome: the user's live Chrome holds the Playwright profile)
- web/test/fixture/*.bin: dumped inputs + reference final latents from the validated numpy/ONNX pipeline, to check the TS port (ddim.ts + 9ch assembly) in Node.
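A sketch of how the Node-side check might read those fixtures, assuming each `.bin` is a headerless little-endian float32 dump (as numpy's `ndarray.tofile` writes on arm64/x86); if they were saved with `np.save` instead, the `.npy` header would need skipping. The names `parseF32` and `loadF32` are hypothetical.

```typescript
import { readFileSync } from "node:fs";

// Parse a headerless float32 dump. Typed arrays use host byte order, which is
// little-endian on the arm64/x86 machines assumed here.
function parseF32(bytes: Uint8Array, expectedLength?: number): Float32Array {
  if (bytes.byteLength % 4 !== 0) throw new Error(`byte length ${bytes.byteLength} is not a multiple of 4`);
  // Copy into a fresh buffer: a Buffer or view may start at an offset that is
  // not 4-byte aligned, which a Float32Array view requires.
  const aligned = new Uint8Array(bytes);
  const out = new Float32Array(aligned.buffer);
  if (expectedLength !== undefined && out.length !== expectedLength) {
    throw new Error(`expected ${expectedLength} floats, got ${out.length}`);
  }
  return out;
}

function loadF32(path: string, expectedLength?: number): Float32Array {
  return parseF32(readFileSync(path), expectedLength);
}
```

Passing the expected element count (e.g. 4 * 64 * 64 for a (1,4,64,64) latent) catches a wrong or truncated fixture before any comparison runs.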
TODO / unknowns
- fp16 export to ~450MB UNet for real deployment (VAE fp16 unstable: decoder Cast issue; keep VAE fp32). Quality risk in λ layers — validate before shipping fp16.
- Confirm in-browser: WebGPU selected, no CPU fallback, end-to-end correctness + timing.