---
library_name: webgpu
tags:
- mamba
- ssm
- webgpu
- browser-inference
- wgsl
- falcon-mamba
- state-space-model
- first-of-its-kind
language:
- en
license: apache-2.0
pipeline_tag: text-generation
base_model: tiiuae/falcon-mamba-7b-instruct
---
# Mamba WebGPU -- First Browser-Native SSM Inference Engine
**Falcon-Mamba 7B running in a browser tab.** Pure WebGPU compute shaders. No MLC, no TVM, no WASM, no ahead-of-time compilation step: 12 hand-written WGSL shaders that the browser compiles at load time. First ever browser-native Mamba/SSM inference.
## What This Is
A complete inference runtime for [Falcon-Mamba-7B-Instruct](https://huggingface.co/tiiuae/falcon-mamba-7b-instruct) that runs entirely in the browser using WebGPU compute shaders. No server-side inference -- the model loads into GPU memory via the browser's WebGPU API and generates text using hand-written WGSL compute shaders.
This is NOT a transformer runtime. This is an **SSM (State Space Model) runtime** -- the Mamba architecture, which uses persistent recurrent state instead of KV cache. The state is fixed-size (38MB) regardless of context length.
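Why the state stays at 38MB: a rough size calculation, assuming Falcon-Mamba-7B's config values d_inner = 8192, d_state = 16 and d_conv = 4, and assuming the conv cache keeps only the last d_conv - 1 = 3 inputs per channel (which is what the 608KB-per-layer figure in Performance implies). None of the terms depend on how many tokens have been processed.

```python
# Size of the persistent decode state for a Falcon-Mamba-style model in F32.
# Assumed config values: d_inner = 8192, d_state = 16, d_conv = 4, 64 layers.
D_INNER, D_STATE, D_CONV, N_LAYERS = 8192, 16, 4, 64
BYTES_F32 = 4

def layer_state_bytes(d_inner=D_INNER, d_state=D_STATE, d_conv=D_CONV):
    ssm_state = d_inner * d_state * BYTES_F32        # h: one d_state vector per channel
    conv_state = d_inner * (d_conv - 1) * BYTES_F32  # assumed: last d_conv - 1 inputs per channel
    return ssm_state + conv_state

per_layer = layer_state_bytes()
total = per_layer * N_LAYERS
print(per_layer // 1024, "KiB per layer")  # 608 KiB per layer
print(total / 2**20, "MiB total")          # 38.0 MiB total
# No term depends on sequence length, unlike a KV cache.
```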
## Why This Matters
- **WebLLM ships transformer models to the browser. This ships SSM models.**
- MLC/TVM don't support the Mamba architecture ([confirmed](https://github.com/nicknlsn/mlc-llm/issues/1))
- The SSM state IS persistent memory -- save it, restore it, and the model resumes with everything it has already read
- Fixed 38MB state vs a KV cache that grows linearly with context length
- No server needed for inference
## Quick Start
```bash
# Clone this repo
git clone https://huggingface.co/LJTSG/mamba-webgpu
# Start the dev server (serves weights from HF cache via byte-range requests)
node serve_mamba.js
# Open http://localhost:8140
# Click: Initialize -> Load Weights -> Generate
```
**Requirements:**
- Falcon-Mamba-7B-Instruct weights in your HuggingFace cache (`~/.cache/huggingface/hub/models--tiiuae--falcon-mamba-7b-instruct/`)
- Node.js (for the dev server)
- Python + `transformers` (for tokenization)
- Chrome/Edge with WebGPU support
- GPU with enough memory accessible via WebGPU to hold the weights in F32 (roughly 28GB, twice the 14GB BF16 checkpoint); tested on AMD Strix Halo iGPU with 64GB unified memory
## Architecture
```
Token -> Embedding lookup
-> 64x Mamba Layer:
RMSNorm -> in_proj GEMV -> split(x, gate)
-> conv1d_step (persistent state)
-> SiLU
-> x_proj GEMV -> RMSNorm(dt_pre, B, C) [Falcon-Mamba specific]
-> dt_proj GEMV -> softplus
-> SSU (selective state update, persistent state)
-> SiLU(gate) * hidden_y
-> out_proj GEMV -> residual add
-> Final RMSNorm -> lm_head GEMV -> Temperature sampling
```
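The Falcon-Mamba-specific step in the diagram (weightless RMSNorm on dt_pre, B and C after x_proj) can be written as a small pure-Python reference, the kind of CPU check that golden-value debugging relies on. This is a sketch: the split order [dt_rank, d_state, d_state] follows the usual Mamba x_proj layout, and eps = 1e-6 is an assumed default; check the model config for the actual epsilon.

```python
import math

def rmsnorm(x, weight=None, eps=1e-6):
    """x / sqrt(mean(x^2) + eps), optionally scaled element-wise by a learned weight."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    if weight is None:
        return [v * inv_rms for v in x]                   # weightless, like rmsnorm_noweight.wgsl
    return [v * inv_rms * w for v, w in zip(x, weight)]  # learned weight, like rmsnorm.wgsl

def split_x_proj(x_dbl, dt_rank, d_state, eps=1e-6):
    """Split x_proj output into (dt_pre, B, C) and normalize each without weights."""
    assert len(x_dbl) == dt_rank + 2 * d_state
    dt_pre = x_dbl[:dt_rank]
    B = x_dbl[dt_rank:dt_rank + d_state]
    C = x_dbl[dt_rank + d_state:]
    # Falcon-Mamba normalizes all three; standard Mamba passes them through unchanged.
    return rmsnorm(dt_pre, eps=eps), rmsnorm(B, eps=eps), rmsnorm(C, eps=eps)

# Toy sizes: dt_rank = 2, d_state = 2.
dt_pre, B, C = split_x_proj([0.5, -1.0, 2.0, 4.0, -3.0, 1.0], dt_rank=2, d_state=2)
```

Skipping the three `rmsnorm` calls in `split_x_proj` gives the standard Mamba path, which is exactly the difference described in bug 6 of the build story.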
**12 WGSL Compute Shaders:**
| Shader | Purpose | Workgroup |
|--------|---------|-----------|
| `rmsnorm.wgsl` | Root mean square normalization (with weights) | 64 threads |
| `rmsnorm_noweight.wgsl` | RMSNorm without learned weights (B/C/dt normalization) | 64 threads |
| `matmul_gemv.wgsl` | Matrix-vector product (M=1 specialized) | 64 threads/row |
| `conv1d_step.wgsl` | Autoregressive depthwise conv1d with state cache | 64 threads |
| `ssu.wgsl` | Selective state update (SSM scan, the core Mamba op) | 16 threads |
| `silu.wgsl` | SiLU/Swish activation, in-place | 64 threads |
| `softplus.wgsl` | Softplus activation | 64 threads |
| `embedding.wgsl` | Embedding table lookup | - |
| `elementwise_mul.wgsl` | Element-wise multiply, in-place | 64 threads |
| `add_residual.wgsl` | Residual connection add, in-place | 64 threads |
| `sample.wgsl` | Temperature-based multinomial sampling | 256 threads |
| `bf16_to_f32.wgsl` | BFloat16 to Float32 conversion | 64 threads |
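As a reference for what `conv1d_step.wgsl`, `silu.wgsl`, `softplus.wgsl` and `ssu.wgsl` compute for a single channel, here is a pure-Python sketch of the standard Mamba decode recurrence, including the A = -exp(A_log) transform and the D skip term. It is not a transcription of the shaders: the conv cache layout (last d_conv - 1 inputs), the softplus threshold of 20 (PyTorch's default) and the function names are assumptions.

```python
import math

def silu(v):
    # Numerically stable v * sigmoid(v).
    if v >= 0:
        return v / (1.0 + math.exp(-v))
    e = math.exp(v)
    return v * e / (1.0 + e)

def softplus(v, threshold=20.0):
    # log(1 + e^v); above the threshold it returns v unchanged.
    return v if v > threshold else math.log1p(math.exp(v))

def conv1d_step(conv_state, x_new, weight, bias):
    """Causal depthwise conv for one channel; conv_state holds the last d_conv - 1 inputs."""
    window = conv_state + [x_new]                 # oldest ... newest, aligned with weight
    out = sum(w * v for w, v in zip(weight, window)) + bias
    return out, window[1:]                        # new state drops the oldest input

def ssu_step(h, u, dt, A_log, B, C, D):
    """Selective state update for one channel (h, A_log, B, C have length d_state)."""
    new_h, y = [], 0.0
    for n in range(len(h)):
        A = -math.exp(A_log[n])                   # weights store A_log; A must be negative
        # exp(dt * A) is the zero-order-hold discretization of A; dt * B is Mamba's simplified B.
        h_n = math.exp(dt * A) * h[n] + dt * B[n] * u
        new_h.append(h_n)
        y += h_n * C[n]
    return new_h, y + D * u                       # D: skip connection

# Toy decode step for one channel (d_conv = 4, d_state = 4), then the output gate.
conv_state = [0.0, 0.0, 0.0]
x_conv, conv_state = conv1d_step(conv_state, 1.0, weight=[0.1, 0.2, 0.3, 0.4], bias=0.0)
u = silu(x_conv)
dt = softplus(-1.0)                               # toy dt_proj output for this channel
h, y = ssu_step([0.0] * 4, u, dt, A_log=[0.0] * 4, B=[1.0] * 4, C=[1.0] * 4, D=1.0)
out = silu(0.5) * y                               # SiLU(gate) * hidden_y with toy gate z = 0.5
```

Because A < 0, exp(dt * A) < 1 and a state that receives no new input decays toward zero. Using A_log itself where A_log > 0 gives exp(dt * A_log) > 1, so the state grows instead.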
## Performance
Tested on AMD Strix Halo (Radeon 8060S iGPU, RDNA 3.5, 64GB unified memory):
- **~3 tok/s** (~180ms per token)
- **~60s** weight loading (14GB BF16 checkpoint via byte-range fetch, converted to F32 on the CPU)
- **38MB** persistent SSM state (64 layers x 608KB)
- **~960 shader dispatches per token** (15 ops x 64 layers)
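The 15-ops-per-layer figure can be read off the Architecture diagram if `split(x, gate)` is just a buffer offset and the weightless RMSNorm runs once each for dt_pre, B and C. The schedule below is a hypothetical reconstruction under those assumptions, not the actual dispatch list in `mamba_runtime.js`; it also counts the once-per-token embedding, final norm, lm_head and sampling dispatches.

```python
# Hypothetical per-layer dispatch schedule reconstructed from the Architecture diagram.
# Assumption: split(x, gate) is a buffer offset, not a dispatch.
LAYER_SCHEDULE = [
    "rmsnorm",                  # pre-norm with learned weight
    "matmul_gemv:in_proj",
    "conv1d_step",
    "silu",
    "matmul_gemv:x_proj",
    "rmsnorm_noweight:dt_pre",  # Falcon-Mamba specific
    "rmsnorm_noweight:B",
    "rmsnorm_noweight:C",
    "matmul_gemv:dt_proj",
    "softplus",
    "ssu",
    "silu:gate",
    "elementwise_mul",
    "matmul_gemv:out_proj",
    "add_residual",
]
# Once-per-token work outside the layer stack (also assumed).
HEAD = ["embedding"]
TAIL = ["rmsnorm:final", "matmul_gemv:lm_head", "sample"]

def dispatches_per_token(n_layers=64):
    return len(HEAD) + n_layers * len(LAYER_SCHEDULE) + len(TAIL)

print(len(LAYER_SCHEDULE))     # 15
print(dispatches_per_token())  # 964
```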
## The Build Story
Built over 36 hours across two sessions. Six bugs stood between "all zeros" and coherent output:
1. **Buffer alignment** -- WebGPU requires storage buffer binding offsets to be multiples of `minStorageBufferOffsetAlignment` (256 bytes by default). An unaligned offset silently invalidated entire command encoders.
2. **A_log transform** -- Falcon-Mamba stores A_log; the SSU needs A = -exp(A_log), which is strictly negative, so exp(dt * A) < 1 and the state decays.
3. **Storage buffer limit** -- The SSU shader uses 9 storage buffers, but the default `maxStorageBuffersPerShaderStage` limit is 8, so the device has to be requested with a higher limit.
4. **Illegal buffer flags** -- MAP_READ cannot be combined with STORAGE usage (MAP_READ may only be paired with COPY_DST), so readbacks need a separate staging buffer.
5. **Diagnostic overhead** -- Per-token GPU readbacks for debugging were causing device timeouts.
6. **Missing RMSNorm on B, C, dt_pre** -- Falcon-Mamba applies weightless RMSNorm to B, C, and dt_pre before the SSU. Standard Mamba does not. This was the final bug -- every shader was correct, but we were implementing the wrong model.
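The alignment bug is the usual trap when several tensors share one large GPUBuffer. A sketch of an offset planner, assuming the 256-byte default for `minStorageBufferOffsetAlignment` (the real value should be read from the adapter's limits); `align_up` and `pack_offsets` are hypothetical helpers, not part of `mamba_runtime.js`:

```python
# Default WebGPU minStorageBufferOffsetAlignment; query the adapter for the real value.
MIN_STORAGE_OFFSET_ALIGNMENT = 256

def align_up(n, alignment=MIN_STORAGE_OFFSET_ALIGNMENT):
    """Round n up to the next multiple of alignment (a power of two)."""
    assert alignment > 0 and alignment & (alignment - 1) == 0
    return (n + alignment - 1) & ~(alignment - 1)

def pack_offsets(sizes_in_bytes, alignment=MIN_STORAGE_OFFSET_ALIGNMENT):
    """Place sub-buffers back to back so every start offset is legal to bind."""
    offsets, cursor = [], 0
    for size in sizes_in_bytes:
        cursor = align_up(cursor, alignment)
        offsets.append(cursor)
        cursor += size
    return offsets, cursor  # cursor = total bytes needed

# B and C (16 floats = 64 bytes each) followed by a 1KiB scratch region.
offsets, total_bytes = pack_offsets([64, 64, 1024])
```

Packed naively, the three regions would start at 0, 64 and 128, and binding at offset 64 fails validation; the planner pads them to 0, 256 and 512 at the cost of a few hundred unused bytes.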
The debugging involved systematic golden-value comparison against PyTorch, checking each intermediate buffer across all 8192 elements. Every single shader operation matched to 6 decimal places. The divergence was in the model architecture, not the compute.
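A minimal sketch of that per-buffer comparison, assuming the GPU readback and the PyTorch dump (for example from `golden_dump.py`) have already been loaded as flat lists of floats; `atol = 1e-6` mirrors the "6 decimal places" criterion. The function name and report format are illustrative.

```python
import math

def compare_to_golden(name, gpu, golden, atol=1e-6):
    """Return (ok, report) comparing a GPU readback against a PyTorch golden dump."""
    if len(gpu) != len(golden):
        return False, f"{name}: length {len(gpu)} != {len(golden)}"
    worst_i, worst = 0, 0.0
    for i, (a, b) in enumerate(zip(gpu, golden)):
        d = abs(a - b)
        if math.isnan(d):  # a NaN on either side (or inf - inf) would otherwise slip past '>'
            return False, f"{name}: non-finite difference at index {i}"
        if d > worst:
            worst_i, worst = i, d
    ok = worst <= atol
    status = "OK" if ok else "MISMATCH"
    return ok, f"{name}: {status} max|diff|={worst:.2e} at index {worst_i}"

ok, report = compare_to_golden("layer0.ssu_y", [0.1, 0.2], [0.1, 0.2000003])
print(report)
```

Running this over every intermediate of every layer localizes the first buffer that diverges; if all of them pass, as happened here, the remaining suspect is the model graph itself.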
## Files
- `mamba_runtime.js` -- WebGPU device init, shader compilation, safetensors weight loading (byte-range fetch, BF16->F32 CPU conversion), forward pass orchestration, generation loop
- `serve_mamba.js` -- Node.js dev server with Range request support for weights, tokenize/detokenize endpoints
- `index.html` -- Test page with Initialize/Load/Generate buttons
- `shaders/*.wgsl` -- 12 WGSL compute shaders
- `golden_dump.py` -- PyTorch golden value dumper for debugging
- `REPORT.md` -- Detailed build report
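The byte-range loading relies on the safetensors layout: an 8-byte little-endian header length, a JSON header whose `data_offsets` are relative to the end of the header, then raw tensor bytes. A Python sketch of the two calculations involved (the HTTP Range for one tensor and the BF16 to F32 widening that `mamba_runtime.js` does on the CPU), using a tiny in-memory file; the helper names and the toy tensor name are hypothetical:

```python
import json
import struct

def parse_safetensors_header(prefix: bytes):
    """Parse the header from the first bytes of a .safetensors file.

    prefix must contain the 8-byte length plus the full JSON header.
    Returns (header_dict, data_start), where data_start = 8 + header length.
    """
    (n,) = struct.unpack("<Q", prefix[:8])
    header = json.loads(prefix[8:8 + n])
    return header, 8 + n

def range_header(header, data_start, name):
    # data_offsets are relative to the data section; HTTP byte ranges are inclusive.
    begin, end = header[name]["data_offsets"]
    return f"bytes={data_start + begin}-{data_start + end - 1}"

def bf16_to_f32(raw: bytes):
    # A BF16 value is the top 16 bits of an F32: shift left by 16 and reinterpret.
    halves = struct.unpack(f"<{len(raw) // 2}H", raw)
    return [struct.unpack("<f", struct.pack("<I", h << 16))[0] for h in halves]

# Tiny in-memory file holding one BF16 tensor [1.0, -2.0].
payload = struct.pack("<2H", 0x3F80, 0xC000)  # BF16 bit patterns of 1.0 and -2.0
hdr = json.dumps({"toy.A_log": {"dtype": "BF16", "shape": [2], "data_offsets": [0, 4]}}).encode()
blob = struct.pack("<Q", len(hdr)) + hdr + payload

header, data_start = parse_safetensors_header(blob)
print(range_header(header, data_start, "toy.A_log"))  # Range header a client would send
print(bf16_to_f32(blob[data_start:data_start + 4]))   # [1.0, -2.0]
```

Over HTTP the same steps become three requests: `bytes=0-7` for the length, one range for the JSON header, then one range per tensor.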
## Limitations
- Single-token decode only (no batch/prefill optimization)
- F32 weights (no quantization yet -- the full 14GB BF16 checkpoint is fetched and expanded to roughly 28GB in F32)
- Tokenization requires Python server-side (no in-browser tokenizer)
- Tested only on an AMD RDNA 3.5 iGPU; other GPUs may need WebGPU limit adjustments
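Which limits need raising can be estimated before requesting the device. A sketch, assuming the WebGPU spec defaults for three limits and an F32 `in_proj` weight of shape (2 x 8192) x 4096 bound as a single buffer. In the browser, the resulting values would go into `requiredLimits` for `adapter.requestDevice()` and must not exceed what `adapter.limits` reports. The helper names are hypothetical.

```python
# WebGPU spec defaults for three limits that a 7B F32 model can hit.
WEBGPU_DEFAULT_LIMITS = {
    "maxStorageBuffersPerShaderStage": 8,
    "maxStorageBufferBindingSize": 128 * 2**20,  # 128 MiB
    "maxBufferSize": 256 * 2**20,                # 256 MiB
}

def required_limits(largest_binding_bytes, storage_buffers_per_stage):
    """Limits needed to bind the largest tensor whole and run the widest shader."""
    return {
        "maxStorageBuffersPerShaderStage": storage_buffers_per_stage,
        "maxStorageBufferBindingSize": largest_binding_bytes,
        "maxBufferSize": largest_binding_bytes,
    }

def limits_to_raise(required, defaults=WEBGPU_DEFAULT_LIMITS):
    """Limits whose required value exceeds the spec default (must go in requiredLimits)."""
    return sorted(k for k, v in required.items() if v > defaults[k])

def unsupported(required, adapter_limits):
    """Limits this adapter cannot provide at all (missing keys count as unsupported)."""
    return sorted(k for k, v in required.items() if v > adapter_limits.get(k, 0))

# Assumed example: F32 in_proj weight of shape (2 * 8192) x 4096 bound as one buffer,
# plus the 9 storage buffers used by the SSU shader.
in_proj_bytes = 2 * 8192 * 4096 * 4
req = required_limits(in_proj_bytes, storage_buffers_per_stage=9)
print(limits_to_raise(req))  # ['maxStorageBufferBindingSize', 'maxStorageBuffersPerShaderStage']
```

An adapter whose `limits` fall short of `req` cannot run the model with whole-tensor bindings; splitting large weights into several bindings is the usual workaround.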
## License
Apache 2.0
## Credits
Built by Joshua ([@LJTSG](https://huggingface.co/LJTSG)) and Claude (Anthropic Opus 4.6).
Model: [tiiuae/falcon-mamba-7b-instruct](https://huggingface.co/tiiuae/falcon-mamba-7b-instruct).
Shaders ported from [gfx1151_runtime](https://github.com/) (Vulkan compute SSM runtime for AMD iGPU).