Upload README.md with huggingface_hub
README.md
ADDED
|
@@ -0,0 +1,134 @@
| 1 |
+
---
|
| 2 |
+
library_name: webgpu
|
| 3 |
+
tags:
|
| 4 |
+
- mamba
|
| 5 |
+
- ssm
|
| 6 |
+
- webgpu
|
| 7 |
+
- browser-inference
|
| 8 |
+
- wgsl
|
| 9 |
+
- falcon-mamba
|
| 10 |
+
- state-space-model
|
| 11 |
+
- first-of-its-kind
|
| 12 |
+
language:
|
| 13 |
+
- en
|
| 14 |
+
license: apache-2.0
|
| 15 |
+
pipeline_tag: text-generation
|
| 16 |
+
base_model: tiiuae/falcon-mamba-7b-instruct
|
| 17 |
+
---
|
| 18 |
+
|
| 19 |
+
# Mamba WebGPU -- First Browser-Native SSM Inference Engine
|
| 20 |
+
|
| 21 |
+
**Falcon-Mamba 7B running in a browser tab.** Pure WebGPU compute shaders: no MLC, no TVM, no WASM, no compilation step. 12 hand-written WGSL shaders. To the authors' knowledge, the first browser-native Mamba/SSM inference engine.
|
| 22 |
+
|
| 23 |
+
## What This Is
|
| 24 |
+
|
| 25 |
+
A complete inference runtime for [Falcon-Mamba-7B-Instruct](https://huggingface.co/tiiuae/falcon-mamba-7b-instruct) that runs entirely in the browser using WebGPU compute shaders. No server-side inference -- the model loads into GPU memory via the browser's WebGPU API and generates text using hand-written WGSL compute shaders.
|
| 26 |
+
|
| 27 |
+
This is NOT a transformer runtime. This is an **SSM (State Space Model) runtime** -- the Mamba architecture, which uses a persistent recurrent state instead of a KV cache. The state has a fixed size (38MB) regardless of context length.
|
| 28 |
+
|
| 29 |
+
## Why This Matters
|
| 30 |
+
|
| 31 |
+
- **WebLLM ships transformer models to the browser. This ships SSM models.**
|
| 32 |
+
- MLC/TVM do not support the Mamba architecture ([confirmed](https://github.com/nicknlsn/mlc-llm/issues/1))
|
| 33 |
+
- The SSM state IS persistent memory -- save it, restore it, and the model picks up where it left off
|
| 34 |
+
- Fixed 38MB state vs unbounded KV cache growth
|
| 35 |
+
- No server needed for inference
|
| 36 |
+
|
| 37 |
+
## Quick Start
|
| 38 |
+
|
| 39 |
+
```bash
|
| 40 |
+
# Clone this repo
|
| 41 |
+
git clone https://huggingface.co/LJTSG/mamba-webgpu
|
| 42 |
+
|
| 43 |
+
# Start the dev server (serves weights from HF cache via byte-range requests)
|
| 44 |
+
node serve_mamba.js
|
| 45 |
+
|
| 46 |
+
# Open http://localhost:8140
|
| 47 |
+
# Click: Initialize -> Load Weights -> Generate
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
**Requirements:**
|
| 51 |
+
- Falcon-Mamba-7B-Instruct weights in your HuggingFace cache (`~/.cache/huggingface/hub/models--tiiuae--falcon-mamba-7b-instruct/`)
|
| 52 |
+
- Node.js (for the dev server)
|
| 53 |
+
- Python + `transformers` (for tokenization)
|
| 54 |
+
- Chrome/Edge with WebGPU support
|
| 55 |
+
- GPU with >= 16GB accessible via WebGPU (tested on AMD Strix Halo iGPU with 64GB unified memory)
|
| 56 |
+
|
| 57 |
+
## Architecture
|
| 58 |
+
|
| 59 |
+
```
|
| 60 |
+
Token -> Embedding lookup
|
| 61 |
+
-> 64x Mamba Layer:
|
| 62 |
+
RMSNorm -> in_proj GEMV -> split(x, gate)
|
| 63 |
+
-> conv1d_step (persistent state)
|
| 64 |
+
-> SiLU
|
| 65 |
+
-> x_proj GEMV -> RMSNorm(dt_pre, B, C) [Falcon-Mamba specific]
|
| 66 |
+
-> dt_proj GEMV -> softplus
|
| 67 |
+
-> SSU (selective state update, persistent state)
|
| 68 |
+
-> SiLU(gate) * hidden_y
|
| 69 |
+
-> out_proj GEMV -> residual add
|
| 70 |
+
-> Final RMSNorm -> lm_head GEMV -> Temperature sampling
|
| 71 |
+
```
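
The SSM part of the pipeline above can be sketched on the CPU as follows. This is a minimal reference under stated assumptions, not the code in `ssu.wgsl`: the function names, the `eps` value, the row-major `[dInner x dState]` layout, and the `D` skip term (taken from the standard Mamba formulation) are all assumptions made for illustration.

```javascript
// CPU reference sketch of one decode step. All names and layouts here are
// illustrative assumptions; they are not taken from the repository's shaders.

// RMSNorm without learned weights, as applied to dt_pre, B, and C.
// The eps value is an assumption.
function rmsNormNoWeight(v, eps = 1e-6) {
  let sumSq = 0;
  for (const x of v) sumSq += x * x;
  const scale = 1 / Math.sqrt(sumSq / v.length + eps);
  return Float32Array.from(v, (x) => x * scale);
}

// Softplus, log(1 + exp(x)), with a cutoff to avoid overflow for large x.
function softplus(x) {
  return x > 20 ? x : Math.log1p(Math.exp(x));
}

// Selective state update for a single token.
// h: persistent state [dInner x dState], updated in place.
// x, dt, D: length dInner. B, C: length dState. A: [dInner x dState].
function ssuStep(h, x, dt, A, B, C, D, dState) {
  const dInner = x.length;
  const y = new Float32Array(dInner);
  for (let i = 0; i < dInner; i++) {
    let acc = 0;
    for (let n = 0; n < dState; n++) {
      const idx = i * dState + n;
      // h <- exp(dt * A) * h + dt * B * x
      h[idx] = Math.exp(dt[i] * A[idx]) * h[idx] + dt[i] * B[n] * x[i];
      acc += C[n] * h[idx];
    }
    // y = C . h + D * x
    y[i] = acc + D[i] * x[i];
  }
  return y;
}
```

Because `h` is mutated in place and nothing else is carried between tokens, the memory cost of a step does not depend on how many tokens came before it.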
|
| 72 |
+
|
| 73 |
+
**12 WGSL Compute Shaders:**
|
| 74 |
+
| Shader | Purpose | Workgroup |
|
| 75 |
+
|--------|---------|-----------|
|
| 76 |
+
| `rmsnorm.wgsl` | Root mean square normalization (with weights) | 64 threads |
|
| 77 |
+
| `rmsnorm_noweight.wgsl` | RMSNorm without learned weights (B/C/dt normalization) | 64 threads |
|
| 78 |
+
| `matmul_gemv.wgsl` | Matrix-vector product (M=1 specialized) | 64 threads/row |
|
| 79 |
+
| `conv1d_step.wgsl` | Autoregressive depthwise conv1d with state cache | 64 threads |
|
| 80 |
+
| `ssu.wgsl` | Selective state update (SSM scan, the core Mamba op) | 16 threads |
|
| 81 |
+
| `silu.wgsl` | SiLU/Swish activation, in-place | 64 threads |
|
| 82 |
+
| `softplus.wgsl` | Softplus activation | 64 threads |
|
| 83 |
+
| `embedding.wgsl` | Embedding table lookup | - |
|
| 84 |
+
| `elementwise_mul.wgsl` | Element-wise multiply, in-place | 64 threads |
|
| 85 |
+
| `add_residual.wgsl` | Residual connection add, in-place | 64 threads |
|
| 86 |
+
| `sample.wgsl` | Temperature-based multinomial sampling | 256 threads |
|
| 87 |
+
| `bf16_to_f32.wgsl` | BFloat16 to Float32 conversion | 64 threads |
|
| 88 |
+
|
| 89 |
+
## Performance
|
| 90 |
+
|
| 91 |
+
Tested on AMD Strix Halo (Radeon 8060S iGPU, RDNA 3.5, 64GB unified memory):
|
| 92 |
+
- **~3 tok/s** (~180ms per token)
|
| 93 |
+
- **~60s** weight loading (14GB of BF16 safetensors via byte-range fetch, converted to F32 on load)
|
| 94 |
+
- **38MB** persistent SSM state (64 layers x 608KB)
|
| 95 |
+
- **~960 shader dispatches per token** (15 ops x 64 layers)
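
The state-size and dispatch figures above can be reproduced with a few lines of arithmetic. The layer count comes from this README; the remaining dimensions (`dInner = 8192`, `dState = 16`, `dConv = 4`) and the choice to keep only the last `dConv - 1` conv inputs are assumptions that happen to match the 608KB-per-layer figure.

```javascript
// Back-of-the-envelope check of the persistent state size.
// dInner, dState, and dConv are assumed values, not stated in this README.
const cfg = { layers: 64, dInner: 8192, dState: 16, dConv: 4, bytesPerF32: 4 };

function stateBytesPerLayer(c) {
  const ssm = c.dInner * c.dState * c.bytesPerF32;       // recurrent state h
  const conv = c.dInner * (c.dConv - 1) * c.bytesPerF32; // last dConv-1 inputs
  return ssm + conv;
}

const perLayerBytes = stateBytesPerLayer(cfg);
const totalBytes = perLayerBytes * cfg.layers;
const dispatchesPerToken = 15 * cfg.layers; // 15 ops per layer, from the list above

console.log(perLayerBytes / 1024, 'KiB per layer');  // 608 KiB per layer
console.log(totalBytes / (1024 * 1024), 'MiB total'); // 38 MiB total
console.log(dispatchesPerToken, 'dispatches per token'); // 960 dispatches per token
```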
|
| 96 |
+
|
| 97 |
+
## The Build Story
|
| 98 |
+
|
| 99 |
+
Built over 36 hours across two sessions. Six bugs stood between "all zeros" and coherent output:
|
| 100 |
+
|
| 101 |
+
1. **Buffer alignment** -- WebGPU requires storage buffer binding offsets to be a multiple of `minStorageBufferOffsetAlignment` (256 bytes by default). An unaligned offset silently invalidated entire command encoders.
|
| 102 |
+
2. **A_log transform** -- Falcon-Mamba stores A_log; the SSU needs A = -exp(A_log) for proper state decay.
|
| 103 |
+
3. **Storage buffer limit** -- The SSU shader uses 9 storage buffers; the default `maxStorageBuffersPerShaderStage` limit is 8, so a higher limit has to be requested when creating the device.
|
| 104 |
+
4. **Illegal buffer flags** -- `MAP_READ` cannot be combined with `STORAGE` usage (it is only valid together with `COPY_DST`), so readback needs a separate staging buffer.
|
| 105 |
+
5. **Diagnostic overhead** -- Per-token GPU readbacks for debugging were causing device timeouts.
|
| 106 |
+
6. **Missing RMSNorm on B, C, dt_pre** -- Falcon-Mamba applies weightless RMSNorm to B, C, and dt_pre before the SSU. Standard Mamba does not. This was the final bug -- every shader was correct, but we were implementing the wrong model.
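
Three of the fixes above (alignment, the `A_log` transform, and the storage buffer limit) reduce to small helpers. The sketch below shows one way they might look; the function names are hypothetical and are not taken from `mamba_runtime.js`. The only WebGPU name used is the real limit `maxStorageBuffersPerShaderStage`.

```javascript
// Hypothetical helpers illustrating bugs 1, 2, and 3. Names are assumptions.

// Bug 1: round a byte offset up to the binding alignment (256 by default).
function alignUp(offset, alignment = 256) {
  return Math.ceil(offset / alignment) * alignment;
}

// Bug 2: the checkpoint stores A_log; the state update needs A = -exp(A_log),
// which is always negative, so exp(dt * A) decays the state.
function aFromALog(aLog) {
  return Float32Array.from(aLog, (v) => -Math.exp(v));
}

// Bug 3: build the requiredLimits object for device creation, failing early
// if the adapter cannot provide enough storage buffers per shader stage.
function requiredLimitsFor(adapterLimits, storageBuffersNeeded = 9) {
  const available = adapterLimits.maxStorageBuffersPerShaderStage;
  if (available < storageBuffersNeeded) {
    throw new Error(
      `adapter supports ${available} storage buffers per stage, need ${storageBuffersNeeded}`
    );
  }
  return { maxStorageBuffersPerShaderStage: storageBuffersNeeded };
}
```

In a browser, the last helper would be used as `adapter.requestDevice({ requiredLimits: requiredLimitsFor(adapter.limits) })`.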
|
| 107 |
+
|
| 108 |
+
The debugging involved systematic golden-value comparison against PyTorch, checking each intermediate buffer across all 8192 elements. Every single shader operation matched to 6 decimal places. The divergence was in the model architecture, not the compute.
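
A golden-value check of this kind can be as simple as the following sketch. It assumes both buffers have already been read back into plain or typed arrays; the function name and the returned fields are illustrative, not part of `golden_dump.py` or the runtime.

```javascript
// Compare a GPU readback against PyTorch golden values (illustrative sketch).
// Returns the largest absolute difference and where it occurred.
function compareToGolden(actual, golden, tol = 1e-6) {
  if (actual.length !== golden.length) {
    throw new Error(`length mismatch: ${actual.length} vs ${golden.length}`);
  }
  let maxAbsDiff = 0;
  let worstIndex = -1;
  for (let i = 0; i < actual.length; i++) {
    const diff = Math.abs(actual[i] - golden[i]);
    // A NaN anywhere is an immediate failure; report its position.
    if (Number.isNaN(diff)) return { ok: false, maxAbsDiff: NaN, worstIndex: i };
    if (diff > maxAbsDiff) {
      maxAbsDiff = diff;
      worstIndex = i;
    }
  }
  return { ok: maxAbsDiff <= tol, maxAbsDiff, worstIndex };
}
```

Reporting the worst index, rather than only pass or fail, makes it easier to tell a localized indexing bug from a systematic numerical one.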
|
| 109 |
+
|
| 110 |
+
## Files
|
| 111 |
+
|
| 112 |
+
- `mamba_runtime.js` -- WebGPU device init, shader compilation, safetensors weight loading (byte-range fetch, BF16->F32 CPU conversion), forward pass orchestration, generation loop
|
| 113 |
+
- `serve_mamba.js` -- Node.js dev server with Range request support for weights, tokenize/detokenize endpoints
|
| 114 |
+
- `index.html` -- Test page with Initialize/Load/Generate buttons
|
| 115 |
+
- `shaders/*.wgsl` -- 12 WGSL compute shaders
|
| 116 |
+
- `golden_dump.py` -- PyTorch golden value dumper for debugging
|
| 117 |
+
- `REPORT.md` -- Detailed build report
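
The weight-loading path described for `mamba_runtime.js` (byte-range fetch, then BF16->F32 conversion) can be sketched roughly as below. The function names are hypothetical. The sketch relies on two facts: a safetensors file starts with an 8-byte little-endian header length followed by a JSON header whose `data_offsets` are relative to the start of the data section, and a BF16 value is the upper 16 bits of the corresponding F32 value.

```javascript
// Hypothetical helpers for loading one tensor; names are assumptions.

// Build the HTTP Range header value for a tensor in a safetensors file.
// headerLen: the JSON header length read from the first 8 bytes of the file.
// dataOffsets: [begin, end) from the tensor's entry in the JSON header.
function tensorRangeHeader(headerLen, dataOffsets) {
  const base = 8 + headerLen; // start of the raw data section
  const [begin, end] = dataOffsets; // end is exclusive
  return `bytes=${base + begin}-${base + end - 1}`; // HTTP ranges are inclusive
}

// Widen BF16 values (given as a Uint16Array) to F32 by placing each
// 16-bit pattern in the upper half of a 32-bit word.
function bf16ToF32(bf16) {
  const bits = new Uint32Array(bf16.length);
  for (let i = 0; i < bf16.length; i++) {
    bits[i] = bf16[i] << 16; // stored modulo 2^32, so the sign bit is preserved
  }
  return new Float32Array(bits.buffer);
}
```

The widening is exact: every BF16 value is representable in F32, so no rounding is involved.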
|
| 118 |
+
|
| 119 |
+
## Limitations
|
| 120 |
+
|
| 121 |
+
- Single-token decode only (no batch/prefill optimization)
|
| 122 |
+
- F32 weights (no quantization yet -- the full 14GB checkpoint is loaded and widened to F32)
|
| 123 |
+
- Tokenization runs server-side in Python (no in-browser tokenizer)
|
| 124 |
+
- Tested only on an AMD RDNA 3.5 iGPU; other GPUs may need limit adjustments
|
| 125 |
+
|
| 126 |
+
## License
|
| 127 |
+
|
| 128 |
+
Apache 2.0
|
| 129 |
+
|
| 130 |
+
## Credits
|
| 131 |
+
|
| 132 |
+
Built by Joshua ([@LJTSG](https://huggingface.co/LJTSG)) and Claude (Anthropic Opus 4.6).
|
| 133 |
+
Model: [tiiuae/falcon-mamba-7b-instruct](https://huggingface.co/tiiuae/falcon-mamba-7b-instruct).
|
| 134 |
+
Shaders ported from [gfx1151_runtime](https://github.com/) (Vulkan compute SSM runtime for AMD iGPU).
|