---
library_name: webgpu
tags:
- mamba
- ssm
- webgpu
- browser-inference
- wgsl
- falcon-mamba
- state-space-model
- first-of-its-kind
language:
- en
license: apache-2.0
pipeline_tag: text-generation
base_model: tiiuae/falcon-mamba-7b-instruct
---
# Mamba WebGPU -- First Browser-Native SSM Inference Engine
**Falcon-Mamba 7B running in a browser tab.** Pure WebGPU compute shaders: no MLC, no TVM, no WASM, no ahead-of-time model compilation step. 12 hand-written WGSL shaders. First ever browser-native Mamba/SSM inference.
## What This Is
A complete inference runtime for [Falcon-Mamba-7B-Instruct](https://huggingface.co/tiiuae/falcon-mamba-7b-instruct) that runs entirely in the browser using WebGPU compute shaders. No server-side inference -- the model loads into GPU memory via the browser's WebGPU API and generates text using hand-written WGSL compute shaders.
This is NOT a transformer runtime. This is an **SSM (State Space Model) runtime** for the Mamba architecture, which uses a persistent recurrent state instead of a KV cache. The state is fixed-size (38MB) regardless of context length.
## Why This Matters
- **WebLLM ships transformer models to the browser. This ships SSM models.**
- MLC/TVM don't support Mamba architecture ([confirmed](https://github.com/nicknlsn/mlc-llm/issues/1))
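- The SSM state IS persistent memory -- save it, restore it later, and generation resumes where it left off without re-processing earlier tokens
- Fixed 38MB state, versus a KV cache that grows linearly with context length
- No server needed for inference (the dev server only serves weights and handles tokenization)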
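Because the recurrent state has a fixed shape, saving it amounts to serializing a fixed number of floats, no matter how many tokens have been processed. The Python sketch below is illustrative only: it assumes each layer's conv and SSM state has already been read back from the GPU into a flat list of floats, and the `save_state`/`restore_state` helpers and their blob layout are hypothetical, not the runtime's actual format.

```python
import struct
from typing import List

# Hypothetical blob layout: uint32 layer count, then per layer a uint32
# float count followed by that many little-endian F32 values.

def save_state(layers: List[List[float]]) -> bytes:
    """Pack per-layer flattened F32 state (conv + SSM) into one blob."""
    blob = bytearray(struct.pack("<I", len(layers)))      # layer count
    for layer in layers:
        blob += struct.pack("<I", len(layer))             # floats in this layer
        blob += struct.pack(f"<{len(layer)}f", *layer)    # state values as F32
    return bytes(blob)

def restore_state(blob: bytes) -> List[List[float]]:
    """Inverse of save_state: rebuild the per-layer float lists."""
    (n_layers,) = struct.unpack_from("<I", blob, 0)
    offset = 4
    layers = []
    for _ in range(n_layers):
        (count,) = struct.unpack_from("<I", blob, offset)
        offset += 4
        layers.append(list(struct.unpack_from(f"<{count}f", blob, offset)))
        offset += 4 * count
    return layers

# Usage: the blob size depends only on the state shape, never on context length.
snapshot = [[0.5, -1.25, 3.0], [2.0, 0.0, -0.75]]
blob = save_state(snapshot)
print(len(blob), restore_state(blob) == snapshot)  # 36 True
```

In the browser, the equivalent readback would copy the state buffers into a `MAP_READ | COPY_DST` staging buffer and read it with `mapAsync`; that wiring is not shown here.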
## Quick Start
```bash
# Clone this repo
git clone https://huggingface.co/LJTSG/mamba-webgpu
cd mamba-webgpu
# Start the dev server (serves weights from HF cache via byte-range requests)
node serve_mamba.js
# Open http://localhost:8140
# Click: Initialize -> Load Weights -> Generate
```
**Requirements:**
- Falcon-Mamba-7B-Instruct weights in your HuggingFace cache (`~/.cache/huggingface/hub/models--tiiuae--falcon-mamba-7b-instruct/`)
- Node.js (for the dev server)
- Python + `transformers` (for tokenization)
- Chrome/Edge with WebGPU support
- GPU memory accessible via WebGPU large enough for the F32 weights, roughly 29GB for ~7.3B parameters at 4 bytes each (tested on AMD Strix Halo iGPU with 64GB unified memory)
## Architecture
```
Token -> Embedding lookup
-> 64x Mamba Layer:
RMSNorm -> in_proj GEMV -> split(x, gate)
-> conv1d_step (persistent state)
-> SiLU
-> x_proj GEMV -> RMSNorm(dt_pre, B, C) [Falcon-Mamba specific]
-> dt_proj GEMV -> softplus
-> SSU (selective state update, persistent state)
-> SiLU(gate) * hidden_y
-> out_proj GEMV -> residual add
-> Final RMSNorm -> lm_head GEMV -> Temperature sampling
```
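The SSU step in the diagram is the Mamba selective-scan recurrence evaluated for a single token. The scalar Python sketch below shows that math for reference only; it is not a port of `ssu.wgsl`, which parallelizes across channels. For each channel `d` and state index `n` it computes `h = h * exp(dt * A) + dt * B[n] * x` and `y = sum_n(h * C[n]) + D * x`, with `A = -exp(A_log)` and `dt` already passed through softplus. The argument names, the list-of-lists layout, and the inclusion of the standard Mamba skip term `D * x` are assumptions.

```python
import math
from typing import List

def softplus(v: float) -> float:
    """Numerically stable softplus, log(1 + exp(v))."""
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))

def ssu_step(h: List[List[float]], x: List[float], dt: List[float],
             A_log: List[List[float]], B: List[float], C: List[float],
             D: List[float]) -> List[float]:
    """One selective state update; mutates h in place and returns y.

    h:     [d_inner][d_state] recurrent state
    x:     [d_inner] conv1d + SiLU output for this token
    dt:    [d_inner] step sizes, already passed through softplus
    A_log: [d_inner][d_state] stored parameter (A = -exp(A_log))
    B, C:  [d_state] input-dependent projections for this token
    D:     [d_inner] skip-connection weights (assumed standard Mamba term)
    """
    y = []
    for d in range(len(x)):
        acc = 0.0
        for n in range(len(B)):
            A = -math.exp(A_log[d][n])       # negative, so exp(dt * A) < 1 for dt > 0
            decay = math.exp(dt[d] * A)      # discretized state transition
            h[d][n] = h[d][n] * decay + dt[d] * B[n] * x[d]
            acc += h[d][n] * C[n]            # read out the state
        y.append(acc + D[d] * x[d])          # skip connection
    return y

# Usage: one channel, one state element, starting from a zero state.
h = [[0.0]]
print(ssu_step(h, x=[1.0], dt=[1.0], A_log=[[0.0]], B=[2.0], C=[3.0], D=[0.5]))  # [6.5]
```

Forgetting the `-exp(...)` transform (bug 2 in the build story) turns `decay` into a growth factor, so the state explodes instead of fading.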
**12 WGSL Compute Shaders:**
| Shader | Purpose | Workgroup |
|--------|---------|-----------|
| `rmsnorm.wgsl` | Root mean square normalization (with weights) | 64 threads |
| `rmsnorm_noweight.wgsl` | RMSNorm without learned weights (B/C/dt normalization) | 64 threads |
| `matmul_gemv.wgsl` | Matrix-vector product (M=1 specialized) | 64 threads/row |
| `conv1d_step.wgsl` | Autoregressive depthwise conv1d with state cache | 64 threads |
| `ssu.wgsl` | Selective state update (SSM scan, the core Mamba op) | 16 threads |
| `silu.wgsl` | SiLU/Swish activation, in-place | 64 threads |
| `softplus.wgsl` | Softplus activation | 64 threads |
| `embedding.wgsl` | Embedding table lookup | - |
| `elementwise_mul.wgsl` | Element-wise multiply, in-place | 64 threads |
| `add_residual.wgsl` | Residual connection add, in-place | 64 threads |
| `sample.wgsl` | Temperature-based multinomial sampling | 256 threads |
| `bf16_to_f32.wgsl` | BFloat16 to Float32 conversion | 64 threads |
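During decode, `conv1d_step` replaces the causal depthwise convolution with a rolling window per channel. The Python sketch below illustrates that update under stated assumptions: the cache holds the previous `d_conv - 1 = 3` inputs per channel (an inference from the 608KB per-layer state figure in the Performance section), the kernel has width 4 and is stored oldest-tap-first, and the function name and layout are illustrative rather than the shader's actual bindings.

```python
from typing import List

def conv1d_step(state: List[List[float]], x: List[float],
                weight: List[List[float]], bias: List[float]) -> List[float]:
    """Depthwise causal conv1d for one new token; updates state in place.

    state:  [d_inner][d_conv - 1] previous inputs, oldest first (assumed layout)
    x:      [d_inner] new input for this token (the x half of in_proj)
    weight: [d_inner][d_conv] per-channel kernel, oldest tap first
    bias:   [d_inner]
    """
    out = []
    for d in range(len(x)):
        window = state[d] + [x[d]]            # last d_conv inputs, oldest to newest
        out.append(sum(w * v for w, v in zip(weight[d], window)) + bias[d])
        state[d] = window[1:]                 # drop the oldest input
    return out

# Usage: one channel, kernel [1, 2, 3, 4], empty history.
conv_state = [[0.0, 0.0, 0.0]]
w, b = [[1.0, 2.0, 3.0, 4.0]], [0.0]
print(conv1d_step(conv_state, [1.0], w, b))  # [4.0]
print(conv1d_step(conv_state, [2.0], w, b))  # [11.0]
```

The SiLU that follows in the diagram is applied to this output by the separate `silu.wgsl` pass.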
## Performance
Tested on AMD Strix Halo (Radeon 8060S iGPU, RDNA 3.5, 64GB unified memory):
- **~3 tok/s** (~180ms per token)
- **~60s** weight loading (~14GB BF16 checkpoint via byte-range fetch, converted to F32)
- **38MB** persistent SSM state (64 layers x 608KB)
- **~960 shader dispatches per token** (15 ops x 64 layers)
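The state-size and dispatch figures above follow from the model shape. A quick Python check, assuming Falcon-Mamba-7B's config values (`d_model = 4096`, `expand = 2`, `d_state = 16`, `d_conv = 4`, 64 layers; verify against the model's `config.json`), F32 storage, and a conv cache of `d_conv - 1` inputs per channel. The per-layer op list is one reading of the architecture diagram, not taken from the runtime's code.

```python
# Assumed Falcon-Mamba-7B shape (verify against the model's config.json)
N_LAYERS = 64
D_MODEL = 4096
D_INNER = 2 * D_MODEL      # expand = 2 -> 8192 channels
D_STATE = 16
D_CONV = 4
F32_BYTES = 4

def state_bytes_per_layer() -> int:
    ssm = D_INNER * D_STATE            # SSM state h: 8192 x 16
    conv = D_INNER * (D_CONV - 1)      # conv cache: last 3 inputs per channel (assumed)
    return (ssm + conv) * F32_BYTES

# One reading of the per-layer ops in the architecture diagram
LAYER_OPS = [
    "rmsnorm", "in_proj", "conv1d_step", "silu", "x_proj",
    "rmsnorm_dt", "rmsnorm_B", "rmsnorm_C", "dt_proj", "softplus",
    "ssu", "silu_gate", "mul_gate", "out_proj", "add_residual",
]

per_layer = state_bytes_per_layer()
print(per_layer // 1024, "KiB per layer")             # 608 KiB per layer
print(per_layer * N_LAYERS / 2**20, "MiB total")      # 38.0 MiB total
print(len(LAYER_OPS) * N_LAYERS, "layer dispatches")  # 960 layer dispatches
```

The embedding lookup, final RMSNorm, `lm_head` GEMV, and sampling add a handful of dispatches on top, hence the "~" in the figure.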
## The Build Story
Built over 36 hours across two sessions. Six bugs stood between "all zeros" and coherent output:
1. **Buffer alignment** -- WebGPU requires storage buffer binding offsets to be 256-byte aligned. An unaligned offset silently invalidated entire command encoders.
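2. **A_log transform** -- Falcon-Mamba stores A_log; the SSU needs A = -exp(A_log), which keeps A negative so the per-step decay factor exp(dt * A) stays between 0 and 1.
3. **Storage buffer limit** -- The SSU shader uses 9 storage buffers, but the default `maxStorageBuffersPerShaderStage` limit is 8, so the device must be requested with a higher limit.
4. **Illegal buffer flags** -- MAP_READ can only be combined with COPY_DST, not STORAGE, so GPU readbacks need a separate staging buffer.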
5. **Diagnostic overhead** -- Per-token GPU readbacks for debugging were causing device timeouts.
6. **Missing RMSNorm on B, C, dt_pre** -- Falcon-Mamba applies weightless RMSNorm to B, C, and dt_pre before the SSU. Standard Mamba does not. This was the final bug -- every shader was correct, but we were implementing the wrong model.
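Bug 6 is easiest to see in code. The Python sketch below shows the Falcon-Mamba-specific step, assuming the `x_proj` output is laid out as `[dt_pre | B | C]` with widths `dt_rank`, `d_state`, `d_state` (256, 16, 16 for this model according to its config, to be verified) and an epsilon of `1e-6`. The helper names are hypothetical. Standard Mamba would pass the three slices on without normalization.

```python
import math
from typing import List, Tuple

def rmsnorm_noweight(v: List[float], eps: float = 1e-6) -> List[float]:
    """RMSNorm with no learned scale: v / sqrt(mean(v^2) + eps)."""
    inv_rms = 1.0 / math.sqrt(sum(t * t for t in v) / len(v) + eps)
    return [t * inv_rms for t in v]

def split_x_proj(out: List[float], dt_rank: int, d_state: int,
                 eps: float = 1e-6) -> Tuple[List[float], List[float], List[float]]:
    """Split x_proj output into (dt_pre, B, C) and normalize each slice separately."""
    if len(out) != dt_rank + 2 * d_state:
        raise ValueError("unexpected x_proj output width")
    dt_pre = out[:dt_rank]                       # assumed layout: [dt_pre | B | C]
    B = out[dt_rank:dt_rank + d_state]
    C = out[dt_rank + d_state:]
    # Falcon-Mamba only: weightless RMSNorm on all three before dt_proj / SSU
    return (rmsnorm_noweight(dt_pre, eps),
            rmsnorm_noweight(B, eps),
            rmsnorm_noweight(C, eps))

# Usage with the assumed Falcon-Mamba widths: 256 + 16 + 16 = 288 values.
dt_pre, B, C = split_x_proj([float(i) for i in range(1, 289)], dt_rank=256, d_state=16)
print(len(dt_pre), len(B), len(C))  # 256 16 16
```

Each slice is normalized on its own, so `B` and `C` end up with unit RMS regardless of how large `dt_pre` is.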
The debugging involved systematic golden-value comparison against PyTorch, checking each intermediate buffer across all 8192 elements. Every single shader operation matched to 6 decimal places. The divergence was in the model architecture, not the compute.
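The per-buffer comparison described above can be sketched in a few lines of Python. This assumes the GPU readback and the PyTorch golden values (for example from `golden_dump.py`) have already been loaded as flat float lists. The helper name and the default `1e-6` tolerance, chosen to match the reported 6-decimal-place agreement, are assumptions and not the project's actual tooling.

```python
from typing import Optional, Sequence, Tuple

def compare_to_golden(gpu: Sequence[float], golden: Sequence[float],
                      atol: float = 1e-6) -> Tuple[float, Optional[int]]:
    """Return (max absolute error, index of first element exceeding atol or None)."""
    if len(gpu) != len(golden):
        raise ValueError(f"length mismatch: {len(gpu)} vs {len(golden)}")
    max_err = 0.0
    first_bad: Optional[int] = None
    for i, (a, b) in enumerate(zip(gpu, golden)):
        err = abs(a - b)
        max_err = max(max_err, err)
        if first_bad is None and err > atol:
            first_bad = i
    return max_err, first_bad

# Usage: in practice each list would be one intermediate buffer (e.g. 8192 elements).
max_err, first_bad = compare_to_golden([1.0, 2.0, 3.0], [1.0, 2.0000001, 3.5])
print(max_err, first_bad)  # 0.5 2
```

A mismatch localized to one buffer points at a shader; when every buffer matches but the final output is still wrong, as happened here, the model definition is the suspect.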
## Files
- `mamba_runtime.js` -- WebGPU device init, shader compilation, safetensors weight loading (byte-range fetch, BF16->F32 CPU conversion), forward pass orchestration, generation loop
- `serve_mamba.js` -- Node.js dev server with Range request support for weights, tokenize/detokenize endpoints
- `index.html` -- Test page with Initialize/Load/Generate buttons
- `shaders/*.wgsl` -- 12 WGSL compute shaders
- `golden_dump.py` -- PyTorch golden value dumper for debugging
- `REPORT.md` -- Detailed build report
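`mamba_runtime.js` is described as loading safetensors weights with byte-range fetches and converting BF16 to F32 on the CPU. The Python sketch below shows the same idea against the safetensors file layout (an 8-byte little-endian header length, a JSON header with `dtype`, `shape`, and `data_offsets` relative to the end of the header, then raw tensor bytes). `fetch_range` is a hypothetical callback; against the dev server it could be an HTTP request carrying a `Range: bytes=start-(end-1)` header. This is not the runtime's actual code.

```python
import json
import struct
from typing import Callable, Dict, List

# Hypothetical reader: returns bytes [start, end) of the weights file
FetchRange = Callable[[int, int], bytes]

def read_safetensors_index(fetch_range: FetchRange) -> Dict[str, dict]:
    """Read only the header and return absolute byte ranges for every tensor."""
    (header_len,) = struct.unpack("<Q", fetch_range(0, 8))   # u64 little-endian
    header = json.loads(fetch_range(8, 8 + header_len))
    header.pop("__metadata__", None)                         # optional, not a tensor
    data_start = 8 + header_len
    index = {}
    for name, info in header.items():
        begin, end = info["data_offsets"]                    # relative to data section
        index[name] = {"dtype": info["dtype"], "shape": info["shape"],
                       "start": data_start + begin, "end": data_start + end}
    return index

def bf16_to_f32(raw: bytes) -> List[float]:
    """Widen BF16 to F32: a BF16 value is the top 16 bits of an IEEE-754 float32."""
    count = len(raw) // 2
    halves = struct.unpack(f"<{count}H", raw[:2 * count])
    widened = struct.pack(f"<{count}I", *(h << 16 for h in halves))
    return list(struct.unpack(f"<{count}f", widened))

# Usage with an in-memory file holding one BF16 tensor [1.0, -2.0]
data = struct.pack("<2H", 0x3F80, 0xC000)
hdr = json.dumps({"w": {"dtype": "BF16", "shape": [2], "data_offsets": [0, 4]}}).encode()
blob = struct.pack("<Q", len(hdr)) + hdr + data

def fetch(start: int, end: int) -> bytes:
    return blob[start:end]

info = read_safetensors_index(fetch)["w"]
print(bf16_to_f32(fetch(info["start"], info["end"])))  # [1.0, -2.0]
```

Reading the header first means each tensor can be fetched individually with one range request instead of downloading whole shard files up front.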
## Limitations
- Single-token decode only (no batch/prefill optimization)
- F32 weights (no quantization yet -- the full ~14GB BF16 checkpoint is loaded and widened to F32, doubling its size in memory)
- Tokenization requires Python server-side (no in-browser tokenizer)
- Tested only on an AMD RDNA 3.5 iGPU; other GPUs may need WebGPU limit adjustments (e.g. `maxStorageBuffersPerShaderStage`)
## License
Apache 2.0
## Credits
Built by Joshua ([@LJTSG](https://huggingface.co/LJTSG)) and Claude (Anthropic Opus 4.6).
Model: [tiiuae/falcon-mamba-7b-instruct](https://huggingface.co/tiiuae/falcon-mamba-7b-instruct).
Shaders ported from [gfx1151_runtime](https://github.com/) (Vulkan compute SSM runtime for AMD iGPU).