| --- |
| library_name: webgpu |
| tags: |
| - mamba |
| - ssm |
| - webgpu |
| - browser-inference |
| - wgsl |
| - falcon-mamba |
| - state-space-model |
| - first-of-its-kind |
| language: |
| - en |
| license: apache-2.0 |
| pipeline_tag: text-generation |
| base_model: tiiuae/falcon-mamba-7b-instruct |
| --- |
| |
| # Mamba WebGPU -- First Browser-Native SSM Inference Engine |
|
|
| **Falcon-Mamba 7B running in a browser tab.** Pure WebGPU compute shaders. No MLC, no TVM, no WASM, no compilation step. 12 hand-written WGSL shaders. First ever browser-native Mamba/SSM inference. |
|
|
| ## What This Is |
|
|
| A complete inference runtime for [Falcon-Mamba-7B-Instruct](https://huggingface.co/tiiuae/falcon-mamba-7b-instruct) that runs entirely in the browser using WebGPU compute shaders. No server-side inference -- the model loads into GPU memory via the browser's WebGPU API and generates text using hand-written WGSL compute shaders. |
|
|
| This is NOT a transformer runtime. This is an **SSM (State Space Model) runtime** -- the Mamba architecture, which uses a persistent recurrent state instead of a KV cache. The state has a fixed size (38MB) regardless of context length.
|
|
| ## Why This Matters |
|
|
| - **WebLLM ships transformer models to the browser. This ships SSM models.** |
| - MLC/TVM don't support Mamba architecture ([confirmed](https://github.com/nicknlsn/mlc-llm/issues/1)) |
| - The SSM state IS persistent memory -- save it, restore it, and the model resumes with its context intact
| - Fixed 38MB state vs unbounded KV cache growth |
| - No server needed for inference |
|
|
| ## Quick Start |
|
|
| ```bash |
| # Clone this repo |
| git clone https://huggingface.co/LJTSG/mamba-webgpu |
| |
| # Start the dev server (serves weights from HF cache via byte-range requests) |
| node serve_mamba.js |
| |
| # Open http://localhost:8140 |
| # Click: Initialize -> Load Weights -> Generate |
| ``` |
|
|
| **Requirements:** |
| - Falcon-Mamba-7B-Instruct weights in your HuggingFace cache (`~/.cache/huggingface/hub/models--tiiuae--falcon-mamba-7b-instruct/`) |
| - Node.js (for the dev server) |
| - Python + `transformers` (for tokenization) |
| - Chrome/Edge with WebGPU support |
| - GPU with >= 16GB accessible via WebGPU (tested on AMD Strix Halo iGPU with 64GB unified memory) |
|
|
| ## Architecture |
|
|
| ``` |
| Token -> Embedding lookup |
| -> 64x Mamba Layer: |
| RMSNorm -> in_proj GEMV -> split(x, gate) |
| -> conv1d_step (persistent state) |
| -> SiLU |
| -> x_proj GEMV -> RMSNorm(dt_pre, B, C) [Falcon-Mamba specific] |
| -> dt_proj GEMV -> softplus |
| -> SSU (selective state update, persistent state) |
| -> SiLU(gate) * hidden_y |
| -> out_proj GEMV -> residual add |
| -> Final RMSNorm -> lm_head GEMV -> Temperature sampling |
| ``` |
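
The two stateful steps in the diagram (`conv1d_step` and the SSU) can be sketched as a CPU reference in plain JavaScript. This is not the repo's shader code: it follows the standard Mamba decode recurrence, and the function names, the `D` skip term, and the buffer layouts (`convState` as `[channels][kernel-1]`, `h` as `[channels][dState]`) are assumptions made for illustration.

```javascript
// CPU reference for one decode step of the two stateful ops (illustrative sketch).
// Assumed layouts: convState is [channels][kernel-1], h and A are [channels][dState].

// Depthwise causal conv1d for one new token. convState holds the previous
// kernel-1 inputs per channel and is updated in place.
function conv1dStep(x, convState, weight, bias, kernel) {
  const channels = x.length;
  const hist = kernel - 1;
  const y = new Float32Array(channels);
  for (let c = 0; c < channels; c++) {
    let acc = bias[c];
    for (let k = 0; k < hist; k++) {
      acc += weight[c * kernel + k] * convState[c * hist + k];
    }
    acc += weight[c * kernel + hist] * x[c]; // newest input uses the last tap
    // Shift the history left by one slot and append the new input.
    for (let k = 0; k + 1 < hist; k++) {
      convState[c * hist + k] = convState[c * hist + k + 1];
    }
    if (hist > 0) convState[c * hist + hist - 1] = x[c];
    y[c] = acc;
  }
  return y;
}

// Selective state update for one token:
//   h = exp(dt * A) * h + dt * B * x,   y = sum_n(h * C) + D * x
// h is the persistent SSM state and is updated in place.
function ssuStep(x, dt, A, B, C, D, h, dState) {
  const channels = x.length;
  const y = new Float32Array(channels);
  for (let c = 0; c < channels; c++) {
    let acc = 0;
    for (let n = 0; n < dState; n++) {
      const i = c * dState + n;
      h[i] = Math.exp(dt[c] * A[i]) * h[i] + dt[c] * B[n] * x[c];
      acc += h[i] * C[n];
    }
    y[c] = acc + D[c] * x[c];
  }
  return y;
}
```

Both functions mutate their state arrays, which mirrors why the state size stays constant per token: nothing is appended, only overwritten.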
|
|
| **12 WGSL Compute Shaders:** |
| | Shader | Purpose | Workgroup | |
| |--------|---------|-----------| |
| | `rmsnorm.wgsl` | Root mean square normalization (with weights) | 64 threads | |
| | `rmsnorm_noweight.wgsl` | RMSNorm without learned weights (B/C/dt normalization) | 64 threads | |
| | `matmul_gemv.wgsl` | Matrix-vector product (M=1 specialized) | 64 threads/row | |
| | `conv1d_step.wgsl` | Autoregressive depthwise conv1d with state cache | 64 threads | |
| | `ssu.wgsl` | Selective state update (SSM scan, the core Mamba op) | 16 threads | |
| | `silu.wgsl` | SiLU/Swish activation, in-place | 64 threads | |
| | `softplus.wgsl` | Softplus activation | 64 threads | |
| | `embedding.wgsl` | Embedding table lookup | - | |
| | `elementwise_mul.wgsl` | Element-wise multiply, in-place | 64 threads | |
| | `add_residual.wgsl` | Residual connection add, in-place | 64 threads | |
| | `sample.wgsl` | Temperature-based multinomial sampling | 256 threads | |
| | `bf16_to_f32.wgsl` | BFloat16 to Float32 conversion | 64 threads | |
|
|
| ## Performance |
|
|
| Tested on AMD Strix Halo (Radeon 8060S iGPU, RDNA 3.5, 64GB unified memory):
| - **~3 tok/s** (~180ms per token) |
| - **~60s** weight loading (14GB F32 via byte-range fetch) |
| - **38MB** persistent SSM state (64 layers x 608KB) |
| - **~960 shader dispatches per token** (15 ops x 64 layers) |
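
The state and dispatch figures can be reproduced with a few lines of arithmetic. The 64 layers, 608KB per layer, and 15 ops per layer come from the numbers above; the breakdown into `dInner = 8192`, `dState = 16`, and three cached conv inputs per channel is an assumption that happens to match them, not something read from the runtime.

```javascript
// Back-of-envelope check of the per-token state and dispatch counts.
// dInner, dState and convHistory are assumed values consistent with 608KB/layer.
const stateConfig = { layers: 64, dInner: 8192, dState: 16, convHistory: 3, bytesPerFloat: 4 };

function stateBytesPerLayer(c) {
  const ssm = c.dInner * c.dState * c.bytesPerFloat;       // SSM state h
  const conv = c.dInner * c.convHistory * c.bytesPerFloat; // cached conv1d inputs
  return ssm + conv;
}

const perLayerBytes = stateBytesPerLayer(stateConfig);
const totalBytes = perLayerBytes * stateConfig.layers;
const dispatchesPerToken = 15 * stateConfig.layers;

console.log(perLayerBytes / 1024, "KiB per layer");   // 608
console.log(totalBytes / (1024 * 1024), "MiB total"); // 38
console.log(dispatchesPerToken, "dispatches");        // 960
```

None of these terms depends on sequence length, which is the point of the fixed-size state.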
|
|
| ## The Build Story |
|
|
| Built over 36 hours across two sessions. Six bugs stood between "all zeros" and coherent output: |
|
|
| 1. **Buffer alignment** -- WebGPU requires storage buffer binding offsets to be a multiple of `minStorageBufferOffsetAlignment` (256 bytes by default). An unaligned offset silently invalidated entire command encoders. |
| 2. **A_log transform** -- Falcon-Mamba stores A_log; the SSU needs A = -exp(A_log) for proper state decay. |
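| 3. **Storage buffer limit** -- The SSU shader uses 9 storage buffers; the default `maxStorageBuffersPerShaderStage` limit is 8, so a higher limit must be requested at device creation. |
| 4. **Illegal buffer flags** -- `MAP_READ` may only be combined with `COPY_DST`, so it cannot share a buffer with `STORAGE` usage; readbacks need a separate staging buffer. |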
| 5. **Diagnostic overhead** -- Per-token GPU readbacks for debugging were causing device timeouts. |
| 6. **Missing RMSNorm on B, C, dt_pre** -- Falcon-Mamba applies weightless RMSNorm to B, C, and dt_pre before the SSU. Standard Mamba does not. This was the final bug -- every shader was correct, but we were implementing the wrong model. |
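
For readers hitting the same walls, bugs 1, 2, 3, 4 and 6 from the list above can be sketched as small JavaScript helpers. All function names here are hypothetical and are not taken from `mamba_runtime.js`; the default `eps` value is an assumption and should be checked against the model config. The WebGPU calls (`requestAdapter`, `requestDevice` with `requiredLimits`, `createBuffer`) are standard API, but the two functions that use them need a browser with WebGPU to actually run.

```javascript
// Bug 1: round a byte offset up to the binding alignment. 256 is the WebGPU
// default for minStorageBufferOffsetAlignment; device.limits has the real value.
function alignUp(offset, alignment = 256) {
  return Math.ceil(offset / alignment) * alignment;
}

// Bug 2: the checkpoint stores A_log; the state update needs A = -exp(A_log).
function aFromALog(aLog) {
  const a = new Float32Array(aLog.length);
  for (let i = 0; i < aLog.length; i++) a[i] = -Math.exp(aLog[i]);
  return a;
}

// Bug 3: request 9 storage buffers per stage instead of the default 8.
// `gpu` is navigator.gpu in a browser.
async function requestMambaDevice(gpu, storageBuffers = 9) {
  const adapter = await gpu.requestAdapter();
  if (!adapter) throw new Error("No WebGPU adapter available");
  if (adapter.limits.maxStorageBuffersPerShaderStage < storageBuffers) {
    throw new Error("Adapter cannot provide enough storage buffers per stage");
  }
  return adapter.requestDevice({
    requiredLimits: { maxStorageBuffersPerShaderStage: storageBuffers },
  });
}

// Bug 4: a readback buffer is MAP_READ | COPY_DST only. Results are copied
// into it from the STORAGE buffer with copyBufferToBuffer, then mapped.
function createReadbackBuffer(device, size, usage = globalThis.GPUBufferUsage) {
  return device.createBuffer({ size, usage: usage.MAP_READ | usage.COPY_DST });
}

// Bug 6: RMSNorm without a learned weight, applied to B, C and dt_pre.
// The eps default is an assumption, not a value confirmed by this README.
function rmsnormNoWeight(x, eps = 1e-6) {
  let sumSq = 0;
  for (let i = 0; i < x.length; i++) sumSq += x[i] * x[i];
  const scale = 1 / Math.sqrt(sumSq / x.length + eps);
  const out = new Float32Array(x.length);
  for (let i = 0; i < x.length; i++) out[i] = x[i] * scale;
  return out;
}
```

Bug 5 has no code fix beyond removing the per-token readbacks once debugging is done.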
| |
| The debugging involved systematic golden-value comparison against PyTorch, checking each intermediate buffer across all 8192 elements. Every single shader operation matched to 6 decimal places. The divergence was in the model architecture, not the compute. |
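
A golden-value check of this kind boils down to a maximum-absolute-difference comparison per buffer. The sketch below is a hypothetical helper, not the repo's tooling: it assumes the GPU readback and the PyTorch dump (for example from `golden_dump.py`) are both already available as numeric arrays of equal length.

```javascript
// Compare a GPU readback against golden values and report the worst element.
// tol = 1e-6 corresponds to "matches to 6 decimal places".
function compareToGolden(actual, golden, tol = 1e-6) {
  if (actual.length !== golden.length) {
    throw new Error(`Length mismatch: ${actual.length} vs ${golden.length}`);
  }
  let maxAbsDiff = 0;
  let worstIndex = -1;
  for (let i = 0; i < actual.length; i++) {
    const d = Math.abs(actual[i] - golden[i]);
    if (Number.isNaN(d)) {
      // A NaN on either side is always a failure.
      return { maxAbsDiff: NaN, worstIndex: i, ok: false };
    }
    if (d > maxAbsDiff) {
      maxAbsDiff = d;
      worstIndex = i;
    }
  }
  return { maxAbsDiff, worstIndex, ok: maxAbsDiff <= tol };
}
```

Running it after each op in layer 0 narrows a divergence to the first buffer where `ok` turns false; if every op passes and the output is still wrong, the op sequence itself is the suspect, as it was here.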
| |
| ## Files |
| |
| - `mamba_runtime.js` -- WebGPU device init, shader compilation, safetensors weight loading (byte-range fetch, BF16->F32 CPU conversion), forward pass orchestration, generation loop |
| - `serve_mamba.js` -- Node.js dev server with Range request support for weights, tokenize/detokenize endpoints |
| - `index.html` -- Test page with Initialize/Load/Generate buttons |
| - `shaders/*.wgsl` -- 12 WGSL compute shaders |
| - `golden_dump.py` -- PyTorch golden value dumper for debugging |
| - `REPORT.md` -- Detailed build report |
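
The byte-range fetch and BF16->F32 conversion mentioned for `mamba_runtime.js` can be illustrated as follows. This is a sketch with hypothetical function names, not the file's actual code. It relies on two facts: HTTP `Range` headers use inclusive end offsets, and a BF16 value is the upper 16 bits of the corresponding F32, so widening is exact.

```javascript
// Build an HTTP Range header value for `length` bytes starting at `offset`.
// The end offset is inclusive.
function rangeHeader(offset, length) {
  return `bytes=${offset}-${offset + length - 1}`;
}

// Fetch one byte range; a server that honors Range replies 206 Partial Content.
async function fetchRange(url, offset, length) {
  const res = await fetch(url, { headers: { Range: rangeHeader(offset, length) } });
  if (res.status !== 206) {
    throw new Error(`Expected 206 Partial Content, got ${res.status}`);
  }
  return res.arrayBuffer();
}

// Widen BF16 (as a Uint16Array of raw bits) to F32 by shifting each value
// into the high half of a 32-bit word. Assumes a little-endian host, which
// matches how safetensors stores tensor data.
function bf16ToF32(bf16) {
  const bits = new Uint32Array(bf16.length);
  for (let i = 0; i < bf16.length; i++) {
    bits[i] = bf16[i] << 16; // the Uint32Array store wraps the signed result
  }
  return new Float32Array(bits.buffer);
}
```

A typical use would be `bf16ToF32(new Uint16Array(await fetchRange(url, start, byteLength)))`, with `start` and `byteLength` taken from the safetensors header.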
| |
| ## Limitations |
| |
| - Single-token decode only (no batch/prefill optimization) |
| - F32 weights (no quantization yet -- loads full 14GB) |
| - Tokenization requires Python server-side (no in-browser tokenizer) |
| - Tested only on an AMD RDNA 3.5 iGPU; other GPUs may need limit adjustments |
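
To see whether another GPU needs limit adjustments, one option is to compare the adapter's reported limits against what the model needs before requesting a device. The helper below is a hypothetical sketch. The limit names are standard WebGPU, but the required values are assumptions: 9 storage buffers comes from the build notes, while the binding size assumes the largest tensor (an F32 `lm_head` of 65024 x 4096) is bound as a single storage buffer, which this README does not confirm.

```javascript
// Return the limits an adapter cannot satisfy. `adapterLimits` is
// adapter.limits in a browser; a plain object works for testing.
function missingLimits(adapterLimits, required) {
  const problems = [];
  for (const [name, need] of Object.entries(required)) {
    const have = adapterLimits[name];
    if (have === undefined || have < need) {
      problems.push({ name, need, have });
    }
  }
  return problems;
}

// Assumed requirements (see the caveats above).
const assumedRequirements = {
  maxStorageBuffersPerShaderStage: 9,
  maxStorageBufferBindingSize: 65024 * 4096 * 4, // bytes, F32 lm_head
};
```

In a browser this would be called as `missingLimits(adapter.limits, assumedRequirements)`; an empty array means the same values can be passed as `requiredLimits` to `requestDevice`.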
| |
| ## License |
| |
| Apache 2.0 |
| |
| ## Credits |
| |
| Built by Joshua ([@LJTSG](https://huggingface.co/LJTSG)) and Claude (Anthropic Opus 4.6). |
| Model: [tiiuae/falcon-mamba-7b-instruct](https://huggingface.co/tiiuae/falcon-mamba-7b-instruct). |
| Shaders ported from [gfx1151_runtime](https://github.com/) (Vulkan compute SSM runtime for AMD iGPU). |
| |