LJTSG commited on
Commit
4fe528c
·
verified ·
1 Parent(s): 78234c3

Upload README.md with huggingface_hub

Browse files
Files changed (1) hide show
  1. README.md +134 -0
README.md ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: webgpu
3
+ tags:
4
+ - mamba
5
+ - ssm
6
+ - webgpu
7
+ - browser-inference
8
+ - wgsl
9
+ - falcon-mamba
10
+ - state-space-model
11
+ - first-of-its-kind
12
+ language:
13
+ - en
14
+ license: apache-2.0
15
+ pipeline_tag: text-generation
16
+ base_model: tiiuae/falcon-mamba-7b-instruct
17
+ ---
18
+
19
+ # Mamba WebGPU -- First Browser-Native SSM Inference Engine
20
+
21
+ **Falcon-Mamba 7B running in a browser tab.** Pure WebGPU compute shaders. No MLC, no TVM, no WASM, no compilation step. 12 hand-written WGSL shaders. First ever browser-native Mamba/SSM inference.
22
+
23
+ ## What This Is
24
+
25
+ A complete inference runtime for [Falcon-Mamba-7B-Instruct](https://huggingface.co/tiiuae/falcon-mamba-7b-instruct) that runs entirely in the browser using WebGPU compute shaders. No server-side inference -- the model loads into GPU memory via the browser's WebGPU API and generates text using hand-written WGSL compute shaders.
26
+
27
+ This is NOT a transformer runtime. This is an **SSM (State Space Model) runtime** -- the Mamba architecture, which uses persistent recurrent state instead of KV cache. The state is fixed-size (38MB) regardless of context length.
28
+
29
+ ## Why This Matters
30
+
31
+ - **WebLLM ships transformer models to the browser. This ships SSM models.**
32
+ - MLC/TVM don't support Mamba architecture ([confirmed](https://github.com/nicknlsn/mlc-llm/issues/1))
33
+ - The SSM state IS persistent memory -- save it, restore it, the entity remembers
34
+ - Fixed 38MB state vs unbounded KV cache growth
35
+ - No server needed for inference
36
+
37
+ ## Quick Start
38
+
39
+ ```bash
40
+ # Clone this repo
41
+ git clone https://huggingface.co/LJTSG/mamba-webgpu
42
+
43
+ # Start the dev server (serves weights from HF cache via byte-range requests)
44
+ node serve_mamba.js
45
+
46
+ # Open http://localhost:8140
47
+ # Click: Initialize -> Load Weights -> Generate
48
+ ```
49
+
50
+ **Requirements:**
51
+ - Falcon-Mamba-7B-Instruct weights in your HuggingFace cache (`~/.cache/huggingface/hub/models--tiiuae--falcon-mamba-7b-instruct/`)
52
+ - Node.js (for the dev server)
53
+ - Python + `transformers` (for tokenization)
54
+ - Chrome/Edge with WebGPU support
55
+ - GPU with >= 16GB accessible via WebGPU (tested on AMD Strix Halo iGPU with 64GB unified memory)
56
+
57
+ ## Architecture
58
+
59
+ ```
60
+ Token -> Embedding lookup
61
+ -> 64x Mamba Layer:
62
+ RMSNorm -> in_proj GEMV -> split(x, gate)
63
+ -> conv1d_step (persistent state)
64
+ -> SiLU
65
+ -> x_proj GEMV -> RMSNorm(dt_pre, B, C) [Falcon-Mamba specific]
66
+ -> dt_proj GEMV -> softplus
67
+ -> SSU (selective state update, persistent state)
68
+ -> SiLU(gate) * hidden_y
69
+ -> out_proj GEMV -> residual add
70
+ -> Final RMSNorm -> lm_head GEMV -> Temperature sampling
71
+ ```
72
+
73
+ **12 WGSL Compute Shaders:**
74
+ | Shader | Purpose | Workgroup |
75
+ |--------|---------|-----------|
76
+ | `rmsnorm.wgsl` | Root mean square normalization (with weights) | 64 threads |
77
+ | `rmsnorm_noweight.wgsl` | RMSNorm without learned weights (B/C/dt normalization) | 64 threads |
78
+ | `matmul_gemv.wgsl` | Matrix-vector product (M=1 specialized) | 64 threads/row |
79
+ | `conv1d_step.wgsl` | Autoregressive depthwise conv1d with state cache | 64 threads |
80
+ | `ssu.wgsl` | Selective state update (SSM scan, the core Mamba op) | 16 threads |
81
+ | `silu.wgsl` | SiLU/Swish activation, in-place | 64 threads |
82
+ | `softplus.wgsl` | Softplus activation | 64 threads |
83
+ | `embedding.wgsl` | Embedding table lookup | - |
84
+ | `elementwise_mul.wgsl` | Element-wise multiply, in-place | 64 threads |
85
+ | `add_residual.wgsl` | Residual connection add, in-place | 64 threads |
86
+ | `sample.wgsl` | Temperature-based multinomial sampling | 256 threads |
87
+ | `bf16_to_f32.wgsl` | BFloat16 to Float32 conversion | 64 threads |
88
+
89
+ ## Performance
90
+
91
+ Tested on AMD Strix Halo (Radeon 8060S iGPU, RDNA-3, 64GB unified memory):
92
+ - **~3 tok/s** (~180ms per token)
93
+ - **~60s** weight loading (14GB F32 via byte-range fetch)
94
+ - **38MB** persistent SSM state (64 layers x 608KB)
95
+ - **~960 shader dispatches per token** (15 ops x 64 layers)
96
+
97
+ ## The Build Story
98
+
99
+ Built over 36 hours across two sessions. Six bugs stood between "all zeros" and coherent output:
100
+
101
+ 1. **Buffer alignment** -- WebGPU requires storage buffer binding offsets to be 256-byte aligned. An unaligned offset silently invalidated entire command encoders.
102
+ 2. **A_log transform** -- Falcon-Mamba stores A_log; the SSU needs A = -exp(A_log) for proper state decay.
103
+ 3. **Storage buffer limit** -- The SSU shader uses 9 storage buffers; default WebGPU limit is 8.
104
+ 4. **Illegal buffer flags** -- MAP_READ cannot be combined with STORAGE usage.
105
+ 5. **Diagnostic overhead** -- Per-token GPU readbacks for debugging were causing device timeouts.
106
+ 6. **Missing RMSNorm on B, C, dt_pre** -- Falcon-Mamba applies weightless RMSNorm to B, C, and dt_pre before the SSU. Standard Mamba does not. This was the final bug -- every shader was correct, but we were implementing the wrong model.
107
+
108
+ The debugging involved systematic golden-value comparison against PyTorch, checking each intermediate buffer across all 8192 elements. Every single shader operation matched to 6 decimal places. The divergence was in the model architecture, not the compute.
109
+
110
+ ## Files
111
+
112
+ - `mamba_runtime.js` -- WebGPU device init, shader compilation, safetensors weight loading (byte-range fetch, BF16->F32 CPU conversion), forward pass orchestration, generation loop
113
+ - `serve_mamba.js` -- Node.js dev server with Range request support for weights, tokenize/detokenize endpoints
114
+ - `index.html` -- Test page with Initialize/Load/Generate buttons
115
+ - `shaders/*.wgsl` -- 12 WGSL compute shaders
116
+ - `golden_dump.py` -- PyTorch golden value dumper for debugging
117
+ - `REPORT.md` -- Detailed build report
118
+
119
+ ## Limitations
120
+
121
+ - Single-token decode only (no batch/prefill optimization)
122
+ - F32 weights (no quantization yet -- loads full 14GB)
123
+ - Tokenization requires Python server-side (no in-browser tokenizer)
124
+ - Tested only on AMD RDNA-3 iGPU; other GPUs may need limit adjustments
125
+
126
+ ## License
127
+
128
+ Apache 2.0
129
+
130
+ ## Credits
131
+
132
+ Built by Joshua ([@LJTSG](https://huggingface.co/LJTSG)) and Claude (Anthropic Opus 4.6).
133
+ Model: [tiiuae/falcon-mamba-7b-instruct](https://huggingface.co/tiiuae/falcon-mamba-7b-instruct).
134
+ Shaders ported from [gfx1151_runtime](https://github.com/) (Vulkan compute SSM runtime for AMD iGPU).