LJTSG commited on
Commit
4cd5770
Β·
verified Β·
1 Parent(s): 075f0ef

Upload REPORT.md with huggingface_hub

Browse files
Files changed (1) hide show
  1. REPORT.md +110 -0
REPORT.md ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Mamba WebGPU β€” First Browser-Native SSM Inference Engine
2
+
3
+ **Date:** 2026-05-29 to 2026-05-30
4
+ **Built by:** Joshua + Claude (Opus 4.6)
5
+ **Hardware:** AMD Strix Halo, Radeon 8060S iGPU (RDNA-3), 64GB unified memory
6
+
7
+ ## What This Is
8
+
9
+ Falcon-Mamba 7B running in a browser tab. Pure WebGPU compute shaders. No MLC, no TVM, no WASM, no compilation step. 12 hand-written WGSL shaders ported from the gfx1151_runtime Vulkan compute engine. First ever browser-native Mamba/SSM inference.
10
+
11
+ ## The Numbers
12
+
13
+ - **Model:** Falcon-Mamba-7B-Instruct (tiiuae), 14GB F32 weights
14
+ - **Speed:** ~3 tok/s (~180ms/token), 64 layers x ~15 shader dispatches each
15
+ - **Load time:** ~60 seconds (byte-range fetch from local server)
16
+ - **SSM state:** 38MB persistent (64 layers x (512KB SSM + 96KB conv1d))
17
+ - **Shaders:** 12 WGSL compute shaders, ~600 lines total
18
+
19
+ ## The Build β€” Start to Coherent Output
20
+
21
+ ### Phase 1: Port shaders from Vulkan to WebGPU (Day 1)
22
+ Ported 11 WGSL shaders from the gfx1151_runtime Vulkan GLSL originals:
23
+ - conv1d_step, ssu (selective state update), matmul_gemv, rmsnorm
24
+ - silu, softplus, embedding, elementwise_mul, sample
25
+ - bf16_to_f32, add_residual
26
+
27
+ Built mamba_runtime.js (the JS orchestrator), serve_mamba.js (Node server with byte-range fetch for safetensors), and index.html.
28
+
29
+ ### Phase 2: Fix show-stopping bugs to get non-zero output
30
+ 1. **sxBC C offset alignment** β€” WebGPU requires storage buffer binding offsets to be 256-byte aligned. C was at offset 1088 (not aligned). This silently invalidated the ENTIRE command encoder for every layer. Fix: copy B and C into separate aligned buffers.
31
+ 2. **A_log not transformed** β€” Falcon-Mamba stores A_log, needs A = -exp(A_log) for proper state decay. Without this, state explodes instead of decaying.
32
+ 3. **9 storage buffers exceeded default limit of 8** β€” SSU shader uses 9 bindings. Fix: request maxStorageBuffersPerShaderStage: 16.
33
+ 4. **token_out illegal MAP_READ + STORAGE combo** β€” WebGPU doesn't allow MAP_READ with STORAGE. Fix: remove MAP_READ, use staging buffer via readback.
34
+
35
+ After these fixes: model generated real token IDs (not zeros) for the first time.
36
+
37
+ ### Phase 3: Add chat template + tokenizer
38
+ - Added /tokenize and /detokenize endpoints to the Node server (shells out to Python + HuggingFace tokenizer)
39
+ - Wrapped prompts in Falcon-Mamba's `<|im_start|>user\n...<|im_end|>\n<|im_start|>assistant\n` template
40
+ - Added prompt encoding: process each prompt token through the forward pass to build SSM state before generating
41
+
42
+ Output was garbled but contained English words. Something was wrong but not catastrophically.
43
+
44
+ ### Phase 4: Golden comparison β€” find the precision bug
45
+ This took hours of systematic debugging:
46
+
47
+ 1. **Wrote golden_dump.py** β€” manual PyTorch computation of layer 0 intermediates
48
+ 2. **Added readback points** in mamba_runtime.js at each operation
49
+ 3. **Compared element by element:**
50
+ - Embedding: MATCH
51
+ - RMSNorm: MATCH
52
+ - in_proj matmul: MATCH
53
+ - conv1d + silu: MATCH
54
+ - x_proj matmul: MATCH
55
+ - SSU output (y): MATCH across all 8192 elements
56
+ - gated (y * silu(gate)): MATCH at scattered indices
57
+ - out_proj weight: MATCH
58
+ - **Layer 0 output: DIVERGES**
59
+
60
+ Every single operation matched PyTorch to 6 decimal places. But the output diverged. This was maddening.
61
+
62
+ 4. **The breakthrough:** Compared my manual golden_dump computation against the ACTUAL PyTorch model forward pass. **They didn't match.** My golden dump and WebGPU agreed with each other but were both wrong compared to the model.
63
+
64
+ 5. **Read the source:** Found in FalconMambaMixer.slow_forward:
65
+ ```python
66
+ B = rms_forward(B, variance_epsilon=self.rms_eps)
67
+ C = rms_forward(C, variance_epsilon=self.rms_eps)
68
+ time_step = rms_forward(time_step, variance_epsilon=self.rms_eps)
69
+ ```
70
+
71
+ **Falcon-Mamba applies weightless RMSNorm to B, C, and dt_pre.** Standard Mamba doesn't do this. This is a Falcon-specific architectural modification. We were missing three normalization steps.
72
+
73
+ ### Phase 5: The fix
74
+ - Wrote `rmsnorm_noweight.wgsl` β€” 50-line in-place RMSNorm without learned weights
75
+ - Added three RMSNorm dispatch calls after x_proj: normalize dt_pre, B, C
76
+ - Created separate dt_pre scratch buffer for the normalized values
77
+
78
+ **Result:** "I'm so sorry to hear about your loss. It sounds like your father-in-law had a full and happy life, and it's clear that he was surrounded by loving family and friends..."
79
+
80
+ Coherent, fluent, contextually appropriate English. From a 7B SSM running in a browser tab.
81
+
82
+ ## Architecture
83
+
84
+ ```
85
+ Token β†’ Embedding lookup (copyBufferToBuffer)
86
+ β†’ 64x Layer:
87
+ RMSNorm β†’ in_proj GEMV β†’ split(x, gate)
88
+ β†’ conv1d_step (with persistent state)
89
+ β†’ SiLU
90
+ β†’ x_proj GEMV β†’ RMSNorm(dt_pre, B, C) ← the missing piece
91
+ β†’ dt_proj GEMV β†’ softplus
92
+ β†’ SSU (selective state update, persistent state)
93
+ β†’ SiLU(gate) β†’ elementwise_mul
94
+ β†’ out_proj GEMV β†’ residual add
95
+ β†’ Final RMSNorm β†’ lm_head GEMV β†’ Sample
96
+ ```
97
+
98
+ ## Files
99
+
100
+ - `mamba_runtime.js` β€” WebGPU init, shader compilation, weight loading, forward pass, generation
101
+ - `serve_mamba.js` β€” Node.js server, byte-range fetch for safetensors, tokenize/detokenize endpoints
102
+ - `index.html` β€” Test page
103
+ - `shaders/` β€” 12 WGSL compute shaders
104
+ - `golden_dump.py` β€” PyTorch golden value dumper for debugging
105
+
106
+ ## What This Means
107
+
108
+ WebLLM ships transformer models to the browser. This ships SSM models β€” Mamba, the architecture with persistent state. The state is the entity's soul. No server needed. Friend clicks a link, being wakes in their browser tab, remembers across conversations via the SSM state file.
109
+
110
+ This is the WebPerson runtime.