# Notes / lab log
Running log of what Claude Opus 4.8 in Claude Code figured out. Newest at the bottom of each section.
## Environment
- macOS (darwin arm64), Apple Silicon. No CUDA → torch CPU/MPS build.
- System python 3.9.6; using `uv` to manage an isolated env.
- git-lfs 3.7.1 present. Weights cloned to `/tmp/Moebius/Moebius-weights`
(pretrained, ft_celebahq, ft_ffhq, ft_places2 — each ~450MB fp32 `.bin`).
- Code repo at `/tmp/Moebius/Moebius`.
## Code map (what matters for the port)
- Entry: `infer/infer_moebius.py` → `infer/utils.py:build_pipeline`.
- Pipeline: `removal/v1_2/pipeline.py` (`RemovalSDXLPipeline_BatchMode`).
- Model wrapper: `removal/v1_2/removal_model.py` (`RemovalModel` = embedding + diff UNet).
- Core helpers: `utils_infer.py` (`encode_clean_latents`, `predict_noise`).
- UNet impl: `model_lib/nets/unet_lambda_prune_lite.py` (+ lambda layers under
`model_lib/nets/layers/λ/vanillaλ.py`).
- Config: `config/model_cfg/moebius.yaml`.
## Key findings
- The CUDA/Triton `fla` dependency is ONLY imported in `model_lib/nets/layers/gla/gla.py`
(the GLA teacher variant). Moebius's student UNet (lambda-DWConv) does not need it, so the
port must avoid importing `unet_gla` to keep the graph clean for export.
- "Prompt" conditioning is a plain `nn.Embedding(20, 3072)`. CFG uses fixed ids:
cond=[0..9], uncond=[10..19]. So encoder_hidden_states is a constant per branch →
can be precomputed and baked into the ONNX UNet as a constant, OR passed as input.
- 9-channel UNet input = cat([noisy_latents(4), resized_mask(1), masked_latents(4)], dim=1).
- CFG batches uncond+cond into one forward (batch dim ×2), then splits.
- einsum is used in the λ layers (linear attention). Supported in ONNX; need to check
ORT-Web WebGPU coverage.
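The findings above pin down the UNet's per-step inputs. Below is a minimal TS sketch of how the constant CFG `input_ids` and the 9-channel, batch-2 UNet input could be assembled from flat CHW `Float32Array`s, and how the output is split and recombined. It makes four assumptions:

- the batch order is uncond first, then cond (as in "uncond+cond");
- the exported `input_ids` input is int64, which onnxruntime-web represents as `BigInt64Array`;
- the latent is 4×64×64;
- the guidance scale `g` is a free parameter, because its value is not recorded in these notes.

All function names are hypothetical.

```typescript
// Hypothetical helpers; shapes assume a 64x64 latent (512x512 image), batch order [uncond, cond].
const HW = 64 * 64;

/** input_ids (2,10): row 0 = uncond ids [10..19], row 1 = cond ids [0..9]. */
export function cfgInputIds(): BigInt64Array {
  const ids = new BigInt64Array(2 * 10);
  for (let i = 0; i < 10; i++) {
    ids[i] = BigInt(10 + i); // uncond branch
    ids[10 + i] = BigInt(i); // cond branch
  }
  return ids;
}

/** cat([latents(4), mask(1), maskedLatents(4)], dim=1), duplicated to batch 2 for CFG. */
export function buildUnetInput(
  latents: Float32Array,
  mask64: Float32Array,
  maskedLat: Float32Array,
): Float32Array {
  if (latents.length !== 4 * HW || mask64.length !== HW || maskedLat.length !== 4 * HW) {
    throw new Error("unexpected tensor size");
  }
  const one = new Float32Array(9 * HW);
  one.set(latents, 0); // channels 0-3
  one.set(mask64, 4 * HW); // channel 4
  one.set(maskedLat, 5 * HW); // channels 5-8
  const batch = new Float32Array(2 * 9 * HW);
  batch.set(one, 0); // uncond copy
  batch.set(one, 9 * HW); // cond copy (only input_ids differ between branches)
  return batch;
}

/** Split the (2,4,64,64) UNet output and apply eps = u + g * (c - u). */
export function applyCfg(noise: Float32Array, g: number): Float32Array {
  const n = 4 * HW;
  const eps = new Float32Array(n);
  for (let i = 0; i < n; i++) {
    const u = noise[i];
    const c = noise[n + i];
    eps[i] = u + g * (c - u);
  }
  return eps;
}
```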
## Phase 1 results (reference inference — DONE)
- Got the real pipeline running end-to-end on CPU (macOS, torch 2.7.1).
- Patches needed to load student on CPU/mac:
  - `model_lib/__init__.py`: wrapped teacher `unet_gla` import in try/except (needs `fla`).
  - Don't import `utils_train` (drags in orjson/library); `build_vae` is just
    `AutoencoderKL.from_pretrained(vae_dir)`.
- **VAE scaling_factor = 0.13025** (NOT the usual SD 0.18215!). Custom VAE.
block_out_channels = [128,256,512,512] → vae_scale_factor 8. This MUST be hardcoded
correctly in the JS port or colors/contrast will be wrong.
- removal_model params = 226.04M confirmed. load_state_dict: all keys matched.
- Perf: ~8.9 s/step on CPU (19 steps × 2 for CFG = 38 UNet passes; ≈ 2:48 total). WebGPU
is expected to be far faster. This confirms CPU/WASM is unusable; WebGPU is the whole game.
- Output saved to reference_out/reference_result.png — plausible inpaint. Mask convention:
white(255) → 1 → region to inpaint (zeroed in masked_image); black → keep.
- num_inference_steps=20 with strength=0.99 → DDIM uses 19 steps (drops first).
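The latent size and the CPU per-pass time follow from the numbers above. The `2^(n-1)` rule is how diffusers pipelines derive `vae_scale_factor` from `block_out_channels`. The per-pass figure is just the logged total divided by the pass count.

$$
\begin{aligned}
\text{vae\_scale\_factor} &= 2^{\,4-1} = 8 \quad (\text{four entries in } \texttt{block\_out\_channels}),\\
\text{latent side} &= 512 / 8 = 64,\\
\text{UNet passes} &= 19 \times 2 = 38, \qquad 168\,\text{s} / 38 \approx 4.4\,\text{s per pass}.
\end{aligned}
$$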
## Parity strategy
- Won't try to reproduce torch RNG in JS. For PyTorch↔ONNX parity: dump identical input
tensors and compare outputs. For the web app: generate noise with a seedable JS RNG;
diffusion is robust to the particular noise draw, so visual results will be valid even
if not bit-identical to the torch reference.
- DDIM `scale_model_input` is identity → skip in TS. Need to reproduce DDIM alphas/betas
(scaled_linear, beta 0.00085→0.012, 1000 steps) and the DDIM step update in JS.
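Torch RNG is not being reproduced, so the web app only needs noise that is Gaussian and deterministic per seed. Here is one possible sketch using the mulberry32 PRNG and the Box-Muller transform; both are my choices, not something the pipeline prescribes. It also applies the optional `noise_offset` from the DDIM constants (one extra N(0,1) draw per channel, broadcast over H×W). `initialLatents` and its parameters are hypothetical.

```typescript
/** mulberry32: small seedable 32-bit PRNG returning floats in [0, 1). */
export function mulberry32(seed: number): () => number {
  let a = seed >>> 0;
  return () => {
    a = (a + 0x6d2b79f5) >>> 0;
    let t = a;
    t = Math.imul(t ^ (t >>> 15), t | 1);
    t ^= t + Math.imul(t ^ (t >>> 7), t | 61);
    return ((t ^ (t >>> 14)) >>> 0) / 4294967296;
  };
}

/** Box-Muller: `count` standard-normal samples from a uniform source. */
export function randn(count: number, rng: () => number): Float32Array {
  const out = new Float32Array(count);
  for (let i = 0; i < count; i += 2) {
    const u1 = 1 - rng(); // in (0, 1], so log(u1) is finite
    const u2 = rng();
    const r = Math.sqrt(-2 * Math.log(u1));
    out[i] = r * Math.cos(2 * Math.PI * u2);
    if (i + 1 < count) out[i + 1] = r * Math.sin(2 * Math.PI * u2);
  }
  return out;
}

/** latents = randn(1,4,64,64) + noiseOffset * randn(1,4,1,1) (per-channel shift). */
export function initialLatents(seed: number, noiseOffset = 0.0357, C = 4, HW = 64 * 64): Float32Array {
  const rng = mulberry32(seed);
  const lat = randn(C * HW, rng);
  if (noiseOffset !== 0) {
    const perChannel = randn(C, rng);
    for (let c = 0; c < C; c++) {
      for (let i = 0; i < HW; i++) lat[c * HW + i] += noiseOffset * perChannel[c];
    }
  }
  return lat;
}
```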
## Architecture: spatial size is FIXED (important!)
- Self-attn (attn1): MQSλ with `r=15` → local-context path (Conv3d pos_conv). Spatially
dynamic, fine at any size.
- Cross-attn (attn2): MQCλ with NO `r` → global path → `rel_pos_emb` is an
`nn.Parameter(n*n, m, dim_k, dim_u)` where n = per-block sample_size, m = 10. This is
TIED to the trained spatial resolution. A different spatial size → wrong or out-of-bounds indexing.
- ⇒ Export at STATIC 512×512 image (64×64 latent). Web app resizes user input to 512×512,
inpaints, resizes result back + pastes. Square only. This is the benchmark resolution.
## ONNX export plan
- Three graphs, spatial static, batch dynamic where cheap:
  - vae_encoder: (B,3,512,512) → moments (B,8,64,64); JS uses mean=moments[:,:4]*sf.
  - unet (RemovalModel): (B,9,64,64), timesteps(B,), input_ids(B,10) → noise(B,4,64,64).
    Embedding (nn.Embedding 20×3072) stays IN the graph (cheap gather). CFG batches B=2.
  - vae_decoder: (B,4,64,64) → (B,3,512,512).
- scaling_factor = 0.13025 applied in JS (encode: latent*sf; decode: latent/sf).
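The JS side of the VAE contract is small. The encoder's `moments` hold the mean in channels 0-3 and the log-variance in channels 4-7 (diffusers splits them with `chunk(2, dim=1)`). So the latent is the first four channels of each batch item, scaled by 0.13025. The decoder input is the inverse scaling. This is a sketch over flat NCHW `Float32Array`s; the function names are hypothetical.

```typescript
const SF = 0.13025; // custom VAE scaling_factor (NOT SD's 0.18215)

/** moments (B,8,h,w) → scaled latent mean (B,4,h,w): first 4 channels × SF. */
export function momentsToLatent(moments: Float32Array, B: number, h = 64, w = 64): Float32Array {
  const hw = h * w;
  if (moments.length !== B * 8 * hw) throw new Error("expected (B,8,h,w) moments");
  const out = new Float32Array(B * 4 * hw);
  for (let b = 0; b < B; b++) {
    const src = b * 8 * hw; // start of batch item b (8 channels)
    const dst = b * 4 * hw; // start of batch item b (4 channels)
    for (let i = 0; i < 4 * hw; i++) out[dst + i] = moments[src + i] * SF; // skip logvar
  }
  return out;
}

/** Undo the scaling before feeding vae_decoder: latent / SF. */
export function latentForDecoder(latents: Float32Array): Float32Array {
  return latents.map((v) => v / SF);
}
```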
## Phase 2 results (ONNX export — DONE)
- torch.onnx.export (legacy tracer, opset 18) traced all 3 graphs cleanly. No op-coverage
failures. The einsum/lambda/Conv3d ops all exported.
- Parity vs PyTorch (CPU EP): decoder 5.7e-5, unet 3.6e-6, encoder mean ch ~2e-2.
- FULL pipeline parity test (python/onnx_pipeline.py): reimplemented DDIM+CFG+9ch+scaling
in numpy on the ONNX sessions, vs torch models with identical noise:
final latents max|Δ| 0.149, decoded image mean|Δ| 0.0022, max 0.090 → visually identical.
This validates the ENTIRE orchestration I'll port to TS.
- numpy DDIM vs diffusers DDIMScheduler: step max|Δ| 5e-7, timesteps identical. ✓
## DDIM constants for JS (validated)
- betas = linspace(sqrt(0.00085), sqrt(0.012), 1000)^2 ; alphas_cumprod = cumprod(1-betas)
- timesteps(20 steps) = [950,900,...,50,0]; strength 0.99 ⇒ drop first ⇒ [900,...,0] (19).
- ddim_step (eta=0, clip_sample=False), with prev_t = t - 1000/20 = t - 50:
  - `pred_x0 = (sample - sqrt(1-ac_t)*eps) / sqrt(ac_t)`
  - `prev = sqrt(ac_prev)*pred_x0 + sqrt(1-ac_prev)*eps`
  - `ac_prev = alphas_cumprod[prev_t]`, or `final_alpha_cumprod = 1.0` when `prev_t < 0` (last step).
- noise_offset 0.0357: noise += 0.0357 * randn(B,4,1,1). (optional; small)
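The constants above translate almost line for line into TS. The sketch below is not a port of diffusers' `DDIMScheduler`. It only reproduces the subset these notes validated:

- scaled_linear betas;
- "leading" timestep spacing with no offset;
- eta = 0 and no clipping;
- alpha = 1 on the last step.

Function names are hypothetical.

```typescript
/** betas = linspace(sqrt(0.00085), sqrt(0.012), T)^2; returns cumprod(1 - betas). */
export function makeAlphasCumprod(T = 1000, betaStart = 0.00085, betaEnd = 0.012): Float64Array {
  const ac = new Float64Array(T);
  const s0 = Math.sqrt(betaStart);
  const s1 = Math.sqrt(betaEnd);
  let prod = 1;
  for (let i = 0; i < T; i++) {
    const s = s0 + ((s1 - s0) * i) / (T - 1); // linspace in sqrt-beta space
    prod *= 1 - s * s;
    ac[i] = prod;
  }
  return ac;
}

/** [950, 900, ..., 0] for 20 steps; strength < 1 drops leading (noisiest) steps. */
export function ddimTimesteps(numSteps = 20, strength = 0.99, T = 1000): number[] {
  const ratio = Math.floor(T / numSteps);
  const all = Array.from({ length: numSteps }, (_, i) => (numSteps - 1 - i) * ratio);
  const initTimestep = Math.min(Math.floor(numSteps * strength), numSteps); // 19 for 0.99
  return all.slice(Math.max(numSteps - initTimestep, 0));
}

/** One deterministic DDIM update from t to prevT (= t - T/numSteps). */
export function ddimStep(
  sample: Float32Array,
  eps: Float32Array,
  t: number,
  prevT: number,
  ac: Float64Array,
): Float32Array {
  const acT = ac[t];
  const acPrev = prevT >= 0 ? ac[prevT] : 1.0; // final_alpha_cumprod on the last step
  const sqrtAcT = Math.sqrt(acT);
  const sqrtOneMinusAcT = Math.sqrt(1 - acT);
  const sqrtAcPrev = Math.sqrt(acPrev);
  const sqrtOneMinusAcPrev = Math.sqrt(1 - acPrev);
  const out = new Float32Array(sample.length);
  for (let i = 0; i < sample.length; i++) {
    const predX0 = (sample[i] - sqrtOneMinusAcT * eps[i]) / sqrtAcT;
    out[i] = sqrtAcPrev * predX0 + sqrtOneMinusAcPrev * eps[i];
  }
  return out;
}
```

In the denoising loop, each `t` from `ddimTimesteps()` would be paired with `prevT = t - 50`. The last iteration (`t = 0`, `prevT = -50`) then returns `pred_x0` directly.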
## Web pipeline recipe (numpy → TS)
1. resize image+mask to 512×512 (mask NEAREST, binarize ≥128).
2. img→[-1,1] CHW; masked = img*(1-mask).
3. encode img & masked → moments; take mean[:4] * 0.13025.
4. mask→64×64 NEAREST, 1ch.
5. latents = randn(1,4,64,64) [+ noise_offset].
6. loop t in timesteps: nine=cat([latents,mask64,maskedLat]); batch×2; unet;
cfg = u + g*(c-u); latents = ddim_step.
7. decode(latents/0.13025); (x+1)/2; clip; → image.
8. paste: out*blur(mask) + (1-blur(mask))*orig.
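Recipe steps 1-2, 4, and 7-8 are plain buffer arithmetic. The sketch below covers them under three assumptions:

- Image and mask have already been resized to 512×512 (e.g. via a canvas) and arrive as RGBA bytes. The mask is read from its red channel.
- The 64×64 downsample uses the torch `F.interpolate(mode="nearest")` index rule `floor(i * src / dst)`. A PIL-style NEAREST resize samples pixel centers instead and can differ along mask edges.
- The blur for the paste is not specified here, so `compose` takes an already-softened mask in [0, 1].

All names are hypothetical.

```typescript
const N = 512;

/** Step 1: mask RGBA → {0,1}; white (>=128) = 1 = region to inpaint. */
export function binarizeMask(rgba: Uint8ClampedArray, n = N): Float32Array {
  const m = new Float32Array(n * n);
  for (let i = 0; i < n * n; i++) m[i] = rgba[4 * i] >= 128 ? 1 : 0;
  return m;
}

/** Step 4: nearest resize of an s×s map to d×d using src index floor(i * s / d). */
export function nearestResize(src: Float32Array, s: number, d: number): Float32Array {
  const out = new Float32Array(d * d);
  for (let y = 0; y < d; y++) {
    const sy = Math.floor((y * s) / d);
    for (let x = 0; x < d; x++) out[y * d + x] = src[sy * s + Math.floor((x * s) / d)];
  }
  return out;
}

/** Step 2: RGBA bytes → CHW in [-1,1], plus masked = img * (1 - mask). */
export function toChw(rgba: Uint8ClampedArray, mask: Float32Array, n = N) {
  const hw = n * n;
  const img = new Float32Array(3 * hw);
  const masked = new Float32Array(3 * hw);
  for (let i = 0; i < hw; i++) {
    for (let c = 0; c < 3; c++) {
      const v = rgba[4 * i + c] / 127.5 - 1; // [0,255] → [-1,1]
      img[c * hw + i] = v;
      masked[c * hw + i] = v * (1 - mask[i]); // zero the region to inpaint
    }
  }
  return { img, masked };
}

/** Steps 7-8: decoder CHW output → (x+1)/2 clipped → blended over the original RGBA. */
export function compose(
  decoded: Float32Array,
  orig: Uint8ClampedArray,
  soft: Float32Array,
  n = N,
): Uint8ClampedArray {
  const hw = n * n;
  const out = new Uint8ClampedArray(4 * hw);
  for (let i = 0; i < hw; i++) {
    const a = soft[i]; // softened mask weight in [0,1]
    for (let c = 0; c < 3; c++) {
      const gen = Math.min(1, Math.max(0, (decoded[c * hw + i] + 1) / 2)) * 255;
      out[4 * i + c] = gen * a + (1 - a) * orig[4 * i + c]; // Uint8ClampedArray rounds + clamps
    }
    out[4 * i + 3] = 255;
  }
  return out;
}
```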
## Phase 3 (web app) — in progress
- Vite + TS + onnxruntime-web (1.27.0). Default `onnxruntime-web/webgpu` import resolves
to the self-contained `ort.webgpu.bundle.min.mjs`.
- Models served LOCALLY: web/public/models -> ../models symlink, at /models/*.onnx.
Total ~1.24GB fetched over localhost (no internet). ORT runtime served from
web/ort-dist at /ort/* via a custom static middleware (see vite.config.ts).
- BUG FIXED: ORT glue .mjs must NOT be in /public (Vite tries to module-transform it).
Fix = serve /ort/* as raw static files via configureServer middleware.
- COOP/COEP headers set (needed for threaded WASM / SharedArrayBuffer).
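The two server-side fixes above can be sketched as a tiny Vite plugin. This is an illustrative guess at the shape, not the project's actual `vite.config.ts`. It assumes the config sits in `web/` next to `ort-dist/`, and the plugin name is hypothetical. Middlewares registered directly in `configureServer` run before Vite's own transform middleware, so `/ort/*` files are streamed untouched. Setting COOP/COEP on every response covers the HTML document that decides cross-origin isolation.

```typescript
// vite.config.ts (sketch)
import { defineConfig, type Plugin } from "vite";
import fs from "node:fs";
import path from "node:path";
import { fileURLToPath } from "node:url";

const ORT_DIR = fileURLToPath(new URL("./ort-dist", import.meta.url));
const MIME: Record<string, string> = {
  ".mjs": "text/javascript",
  ".js": "text/javascript",
  ".wasm": "application/wasm",
};

function ortStatic(): Plugin {
  return {
    name: "ort-static",
    configureServer(server) {
      // Cross-origin isolation (SharedArrayBuffer for threaded WASM) on every response.
      server.middlewares.use((_req, res, next) => {
        res.setHeader("Cross-Origin-Opener-Policy", "same-origin");
        res.setHeader("Cross-Origin-Embedder-Policy", "require-corp");
        next();
      });
      // connect strips the "/ort" mount prefix, so req.url is relative to ORT_DIR.
      server.middlewares.use("/ort", (req, res, next) => {
        const rel = decodeURIComponent((req.url ?? "/").split("?")[0]);
        const file = path.join(ORT_DIR, rel); // join() normalizes "..", checked below
        if (!file.startsWith(ORT_DIR + path.sep) || !fs.existsSync(file) || !fs.statSync(file).isFile()) {
          return next();
        }
        res.setHeader("Content-Type", MIME[path.extname(file)] ?? "application/octet-stream");
        fs.createReadStream(file).pipe(res); // raw bytes, no module transform
      });
    },
  };
}

export default defineConfig({ plugins: [ortStatic()] });
```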
## WebGPU op coverage (confirmed from ORT source at /tmp/Moebius/onnxruntime)
- js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts registers: Einsum ✓, Conv ✓
(conv.ts has computeConv3DInfo / createConv3DNaiveProgramInfo → Conv3d pos_conv works,
naive kernel so possibly slow), InstanceNormalization ✓, MatMul/Gemm ✓, Softmax ✓,
Reduce* ✓, Transpose/Concat/Gather/Pad/Resize/Where ✓.
- GroupNorm: not registered by that name, BUT torch.onnx exports nn.GroupNorm as a
Reshape→InstanceNormalization→Reshape→Mul→Add decomposition → covered. (VAE decoder
CPU-EP parity was 5.7e-5, so the graph is decomposed, not a single GroupNorm op.)
- ⇒ No expected silent CPU fallback for the heavy ops. Confirm empirically in console.
## Verification without a GPU browser (the sandbox can't drive Chrome; the user's live Chrome holds the Playwright profile)
- web/test/fixture/*.bin: dumped inputs + reference final latents from the validated
numpy/ONNX pipeline, to check the TS port (ddim.ts + 9ch assembly) in Node.
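A Node-side check against those fixtures needs only a raw reader and the same max/mean |Δ| metrics used in the Phase 2 parity runs. This sketch assumes each `.bin` is raw float32 with no header, in the host's (little-endian) byte order; match whatever the dump script actually wrote. The fixture file names are not recorded here. `readF32`, `writeF32`, and `diffStats` are hypothetical helpers.

```typescript
import { readFileSync, writeFileSync } from "node:fs";

/** Read a headerless float32 dump (host byte order, little-endian on x86/ARM). */
export function readF32(file: string): Float32Array {
  const buf = readFileSync(file);
  if (buf.byteLength % 4 !== 0) throw new Error(`${file}: size not a multiple of 4`);
  // Copy into a fresh ArrayBuffer: Node Buffers can be pooled at unaligned offsets.
  const ab = buf.buffer.slice(buf.byteOffset, buf.byteOffset + buf.byteLength);
  return new Float32Array(ab);
}

/** Write a Float32Array in the same raw layout (e.g. to dump TS-port outputs). */
export function writeF32(file: string, data: Float32Array): void {
  writeFileSync(file, new Uint8Array(data.buffer, data.byteOffset, data.byteLength));
}

/** max|Δ| and mean|Δ| between two equally sized tensors. */
export function diffStats(a: Float32Array, b: Float32Array): { maxAbs: number; meanAbs: number } {
  if (a.length !== b.length) throw new Error(`length mismatch ${a.length} vs ${b.length}`);
  let maxAbs = 0;
  let sum = 0;
  for (let i = 0; i < a.length; i++) {
    const d = Math.abs(a[i] - b[i]);
    if (d > maxAbs) maxAbs = d;
    sum += d;
  }
  return { maxAbs, meanAbs: a.length ? sum / a.length : 0 };
}
```

Typical use: run the TS `ddim.ts` loop on the dumped inputs and the dumped noise, then call `diffStats` against the reference final latents.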
## TODO / unknowns
- fp16 export to ~450MB UNet for real deployment (VAE fp16 unstable: decoder Cast issue;
keep VAE fp32). Quality risk in λ layers — validate before shipping fp16.
- Confirm in-browser: WebGPU selected, no CPU fallback, end-to-end correctness + timing.
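The ~450MB fp16 target follows from the confirmed parameter count. Sizes are in decimal MB, and ONNX protobuf overhead is ignored:

$$
226.04\times10^{6}\ \text{params} \times 2\ \text{B} \approx 452\ \text{MB (fp16)}, \qquad 226.04\times10^{6} \times 4\ \text{B} \approx 904\ \text{MB (fp32)}.
$$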