---
library_name: webgpu
tags:
- mamba
- ssm
- webgpu
- browser-inference
- wgsl
- falcon-mamba
- state-space-model
- first-of-its-kind
language:
- en
license: apache-2.0
pipeline_tag: text-generation
base_model: tiiuae/falcon-mamba-7b-instruct
---

# Mamba WebGPU -- First Browser-Native SSM Inference Engine

**Falcon-Mamba 7B running in a browser tab.** Pure WebGPU compute shaders. No MLC, no TVM, no WASM, no ahead-of-time model compilation step. 12 hand-written WGSL shaders. First ever browser-native Mamba/SSM inference.

## What This Is

A complete inference runtime for [Falcon-Mamba-7B-Instruct](https://huggingface.co/tiiuae/falcon-mamba-7b-instruct) that runs entirely in the browser using WebGPU compute shaders. There is no server-side inference: the model loads into GPU memory through the browser's WebGPU API and generates text with hand-written WGSL compute shaders.

This is NOT a transformer runtime. It is an **SSM (State Space Model) runtime** for the Mamba architecture, which keeps a persistent recurrent state instead of a KV cache. The state has a fixed size (38MB) regardless of context length.

## Why This Matters

- **WebLLM ships transformer models to the browser.
  This ships SSM models.**
- MLC/TVM don't support the Mamba architecture ([confirmed](https://github.com/nicknlsn/mlc-llm/issues/1)).
- The SSM state IS persistent memory: save it, restore it, and the entity remembers.
- Fixed 38MB state vs. unbounded KV cache growth.
- No server needed for inference.

## Quick Start

```bash
# Clone this repo
git clone https://huggingface.co/LJTSG/mamba-webgpu
cd mamba-webgpu

# Start the dev server (serves weights from the HF cache via byte-range requests)
node serve_mamba.js

# Open http://localhost:8140
# Click: Initialize -> Load Weights -> Generate
```

**Requirements:**

- Falcon-Mamba-7B-Instruct weights in your HuggingFace cache (`~/.cache/huggingface/hub/models--tiiuae--falcon-mamba-7b-instruct/`)
- Node.js (for the dev server)
- Python + `transformers` (for tokenization)
- Chrome/Edge with WebGPU support
- GPU with >= 16GB accessible via WebGPU (tested on AMD Strix Halo iGPU with 64GB unified memory)

## Architecture

```
Token
  -> Embedding lookup
  -> 64x Mamba Layer:
       RMSNorm
       -> in_proj GEMV -> split(x, gate)
       -> conv1d_step (persistent state) -> SiLU
       -> x_proj GEMV -> RMSNorm(dt_pre, B, C)   [Falcon-Mamba specific]
       -> dt_proj GEMV -> softplus
       -> SSU (selective state update, persistent state)
       -> SiLU(gate) * hidden_y
       -> out_proj GEMV -> residual add
  -> Final RMSNorm
  -> lm_head GEMV
  -> Temperature sampling
```

**12 WGSL Compute Shaders:**

| Shader | Purpose | Workgroup size |
|--------|---------|----------------|
| `rmsnorm.wgsl` | Root mean square normalization (with weights) | 64 threads |
| `rmsnorm_noweight.wgsl` | RMSNorm without learned weights (B/C/dt normalization) | 64 threads |
| `matmul_gemv.wgsl` | Matrix-vector product (M=1 specialized) | 64 threads/row |
| `conv1d_step.wgsl` | Autoregressive depthwise conv1d with state cache | 64 threads |
| `ssu.wgsl` | Selective state update (SSM scan, the core Mamba op) | 16 threads |
| `silu.wgsl` | SiLU/Swish activation, in-place | 64 threads |
| `softplus.wgsl` | Softplus activation | 64 threads |
| `embedding.wgsl` | Embedding table lookup | - |
| `elementwise_mul.wgsl` | Element-wise multiply, in-place | 64 threads |
| `add_residual.wgsl` | Residual connection add, in-place | 64 threads |
| `sample.wgsl` | Temperature-based multinomial sampling | 256 threads |
| `bf16_to_f32.wgsl` | BFloat16 to Float32 conversion | 64 threads |

## Performance

Tested on AMD Strix Halo (Radeon 8060S iGPU, RDNA-3.5, 64GB unified memory):

- **~3 tok/s** (~180ms per token)
- **~60s** weight loading (the ~14GB BF16 checkpoint, fetched via byte-range requests and expanded to F32)
- **38MB** persistent SSM state (64 layers x 608KB)
- **~960 shader dispatches per token** (15 ops x 64 layers)

## The Build Story

Built over 36 hours across two sessions. Six bugs stood between "all zeros" and coherent output:

1. **Buffer alignment** -- WebGPU requires storage buffer binding offsets to be multiples of 256 bytes (the default `minStorageBufferOffsetAlignment`). An unaligned offset silently invalidated entire command encoders.
2. **A_log transform** -- Falcon-Mamba stores A_log; the SSU needs A = -exp(A_log) for proper state decay.
3. **Storage buffer limit** -- The SSU shader uses 9 storage buffers, but the default WebGPU limit (`maxStorageBuffersPerShaderStage`) is 8.
4. **Illegal buffer flags** -- MAP_READ cannot be combined with STORAGE usage; it may only be paired with COPY_DST.
5. **Diagnostic overhead** -- Per-token GPU readbacks for debugging were causing device timeouts.
6. **Missing RMSNorm on B, C, dt_pre** -- Falcon-Mamba applies a weightless RMSNorm to B, C, and dt_pre before the SSU; standard Mamba does not. This was the final bug: every shader was correct, but we were implementing the wrong model.

Debugging relied on systematic golden-value comparison against PyTorch, checking each intermediate buffer across all 8192 elements. Every shader operation matched to 6 decimal places; the divergence was in the model architecture, not the compute.
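Bug 1 comes down to offset arithmetic. Below is a minimal sketch, assuming all tensors are packed into one large `GPUBuffer` and bound as slices: round each slice's start up to the storage offset alignment before recording it. `alignUp` and `layoutTensors` are hypothetical helpers, not functions from `mamba_runtime.js`, and the hard-coded 256 is only the WebGPU default; a live device reports its actual value in `device.limits.minStorageBufferOffsetAlignment`.

```javascript
// Hypothetical helpers (not from mamba_runtime.js): pack tensors into one
// buffer so every storage binding offset is a multiple of the alignment.
// 256 is WebGPU's default minStorageBufferOffsetAlignment; on a real device
// read device.limits.minStorageBufferOffsetAlignment instead.
const DEFAULT_STORAGE_OFFSET_ALIGNMENT = 256;

// Round `offset` up to the next multiple of `alignment`.
function alignUp(offset, alignment = DEFAULT_STORAGE_OFFSET_ALIGNMENT) {
  return Math.ceil(offset / alignment) * alignment;
}

// tensors: [{ name, byteLength }]. Returns each tensor's aligned start offset
// and the total number of bytes the packed buffer needs.
function layoutTensors(tensors, alignment = DEFAULT_STORAGE_OFFSET_ALIGNMENT) {
  const offsets = new Map();
  let cursor = 0;
  for (const { name, byteLength } of tensors) {
    cursor = alignUp(cursor, alignment); // pad so this is a legal binding offset
    offsets.set(name, cursor);
    cursor += byteLength;
  }
  return { offsets, totalBytes: cursor };
}

// Two 100-byte tensors: the second one is pushed from byte 100 to byte 256.
const { offsets, totalBytes } = layoutTensors([
  { name: "a", byteLength: 100 },
  { name: "b", byteLength: 100 },
]);
console.log(offsets.get("b"), totalBytes); // 256 356
```

The returned offsets are what would go into a bind group entry's `resource: { buffer, offset, size }`. Padding costs at most 255 bytes per tensor, which is negligible next to the weights.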
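For reference while reading bugs 2 and 6, here is a minimal CPU sketch of one decode step of the selective state update, written the way a golden-value check might be. It is not the code in `ssu.wgsl` or `golden_dump.py`: the function names, the argument layout, and the `eps = 1e-6` default are assumptions, and `D` is the standard Mamba skip term. It shows where A = -exp(A_log) enters and that B and C arrive already RMS-normalized without learned weights.

```javascript
// CPU reference for ONE decode step of Falcon-Mamba's selective state update
// (SSU). A sketch for golden-value checks, not the code in ssu.wgsl; names,
// argument layout and eps are assumptions. State layout: h[d * dState + n].

// Weightless RMSNorm: Falcon-Mamba applies it to dt_pre, B and C.
function rmsNormNoWeight(v, eps = 1e-6) {
  let sumSq = 0;
  for (const x of v) sumSq += x * x;
  const scale = 1 / Math.sqrt(sumSq / v.length + eps);
  return v.map((x) => x * scale);
}

// softplus(x) = log(1 + e^x); large inputs pass through (threshold 20, like PyTorch).
function softplus(x) {
  return x > 20 ? x : Math.log1p(Math.exp(x));
}

// h:     Float32Array(dInner * dState), the persistent SSM state, updated in place
// x:     conv1d + SiLU output, length dInner
// delta: softplus(dt_proj output), length dInner
// B, C:  x_proj outputs AFTER rmsNormNoWeight, length dState
// ALog:  stored A_log, length dInner * dState; D: skip weights, length dInner
// Returns y before the SiLU(gate) multiply.
function ssuStep(h, x, delta, B, C, ALog, D, dInner, dState) {
  const y = new Float32Array(dInner);
  for (let d = 0; d < dInner; d++) {
    let acc = 0;
    for (let n = 0; n < dState; n++) {
      const i = d * dState + n;
      const A = -Math.exp(ALog[i]); // bug 2: use A = -exp(A_log), not A_log itself
      h[i] = Math.exp(delta[d] * A) * h[i] + delta[d] * B[n] * x[d];
      acc += h[i] * C[n];
    }
    y[d] = acc + D[d] * x[d];
  }
  return y;
}
```

In this sketch, `dt_pre` would go through `rmsNormNoWeight` before the `dt_proj` GEMV and `softplus`, just as B and C do before `ssuStep`. Leaving those norms out still yields a well-formed forward pass with plausible-looking numbers, which fits the build story's observation that every shader matched while the model itself was wrong.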
## Files

- `mamba_runtime.js` -- WebGPU device init, shader compilation, safetensors weight loading (byte-range fetch, BF16->F32 CPU conversion), forward pass orchestration, generation loop
- `serve_mamba.js` -- Node.js dev server with Range request support for weights, plus tokenize/detokenize endpoints
- `index.html` -- Test page with Initialize/Load/Generate buttons
- `shaders/*.wgsl` -- 12 WGSL compute shaders
- `golden_dump.py` -- PyTorch golden value dumper for debugging
- `REPORT.md` -- Detailed build report

## Limitations

- Single-token decode only (no batch/prefill optimization)
- F32 weights (no quantization yet; the full ~14GB BF16 checkpoint is loaded and expanded to F32)
- Tokenization requires a Python process on the server side (no in-browser tokenizer)
- Tested only on an AMD RDNA-3.5 iGPU; other GPUs may need limit adjustments

## License

Apache 2.0

## Credits

Built by Joshua ([@LJTSG](https://huggingface.co/LJTSG)) and Claude (Anthropic Opus 4.6). Model: [tiiuae/falcon-mamba-7b-instruct](https://huggingface.co/tiiuae/falcon-mamba-7b-instruct). Shaders ported from [gfx1151_runtime](https://github.com/) (Vulkan compute SSM runtime for AMD iGPU).