LJTSG commited on
Commit
12bf339
Β·
verified Β·
1 Parent(s): 4fe528c

Upload mamba_runtime.js with huggingface_hub

Browse files
Files changed (1) hide show
  1. mamba_runtime.js +954 -0
mamba_runtime.js ADDED
@@ -0,0 +1,954 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ * mamba_runtime.js β€” Browser-native Falcon-Mamba inference via WebGPU.
3
+ *
4
+ * The first browser-native Mamba/SSM inference engine.
5
+ * No MLC, no TVM β€” pure WebGPU compute shaders ported from gfx1151_runtime.
6
+ *
7
+ * Architecture: Falcon-Mamba 7B
8
+ * 64 layers, each: RMSNorm β†’ in_proj β†’ conv1d β†’ SSU β†’ out_proj
9
+ * Final: RMSNorm β†’ lm_head β†’ sample
10
+ *
11
+ * Weight format: safetensors (HF standard), loaded directly into WebGPU buffers.
12
+ * Shaders: WGSL compute shaders in ./shaders/ (ported from Vulkan GLSL).
13
+ *
14
+ * Usage:
15
+ * const mamba = new MambaRuntime();
16
+ * await mamba.init();
17
+ * await mamba.loadWeights('./weights/');
18
+ * const text = await mamba.generate("Hello Grandma", 100);
19
+ */
20
+
21
+ // Falcon-Mamba 7B constants
22
+ const CONFIG = {
23
+ hidden_size: 4096,
24
+ intermediate_size: 8192, // 2 * hidden
25
+ num_layers: 64,
26
+ vocab_size: 65024,
27
+ state_size: 16, // SSM d_state
28
+ conv_kernel: 4,
29
+ dt_rank: 256,
30
+ rms_eps: 1e-5,
31
+ };
32
+
33
+ class MambaRuntime {
34
+ constructor() {
35
+ this.device = null;
36
+ this.pipelines = {}; // shader name β†’ GPUComputePipeline
37
+ this.bindLayouts = {}; // shader name β†’ GPUBindGroupLayout
38
+ this.weights = {}; // parameter name β†’ GPUBuffer
39
+ this.state = {}; // per-layer SSM state + conv1d state buffers
40
+ this.ready = false;
41
+ }
42
+
43
+ // ── Init: get WebGPU device + compile all shaders ──────────────────────
44
+ async init() {
45
+ if (!navigator.gpu) throw new Error('WebGPU not supported in this browser');
46
+ const adapter = await navigator.gpu.requestAdapter();
47
+ if (!adapter) throw new Error('No WebGPU adapter found');
48
+
49
+ // Request max buffer size the device supports
50
+ const limits = adapter.limits;
51
+ console.log('[mamba] maxBufferSize:', limits.maxBufferSize,
52
+ '=', (limits.maxBufferSize / 1024 / 1024 / 1024).toFixed(2), 'GB');
53
+
54
+ this.device = await adapter.requestDevice({
55
+ requiredLimits: {
56
+ maxBufferSize: limits.maxBufferSize,
57
+ maxStorageBufferBindingSize: limits.maxStorageBufferBindingSize,
58
+ maxComputeWorkgroupStorageSize: limits.maxComputeWorkgroupStorageSize,
59
+ maxStorageBuffersPerShaderStage: Math.min(limits.maxStorageBuffersPerShaderStage, 16),
60
+ }
61
+ });
62
+
63
+ this.device.lost.then((info) => {
64
+ console.error('[mamba] DEVICE LOST:', info.reason, info.message);
65
+ });
66
+ this.device.addEventListener('uncapturederror', (e) => {
67
+ console.error('[mamba] GPU ERROR:', e.error.message);
68
+ });
69
+
70
+ console.log('[mamba] device ready, compiling shaders...');
71
+ await this._compileShaders();
72
+ console.log('[mamba] shaders compiled');
73
+ return this;
74
+ }
75
+
76
+ // ── Compile all WGSL shaders into compute pipelines ────────────────────
77
+ async _compileShaders() {
78
+ const shaderNames = [
79
+ 'conv1d_step', 'ssu', 'matmul_gemv', 'rmsnorm', 'rmsnorm_noweight',
80
+ 'silu', 'softplus', 'embedding', 'elementwise_mul', 'sample',
81
+ 'bf16_to_f32', 'add_residual'
82
+ ];
83
+
84
+ for (const name of shaderNames) {
85
+ const resp = await fetch(`./shaders/${name}.wgsl`);
86
+ if (!resp.ok) throw new Error(`Failed to load shader: ${name}.wgsl`);
87
+ const code = await resp.text();
88
+
89
+ const shaderModule = this.device.createShaderModule({ code, label: name });
90
+
91
+ // Create bind group layouts based on shader requirements
92
+ // Group 0 = storage buffers (data), Group 1 = uniforms (params)
93
+ const pipeline = this.device.createComputePipeline({
94
+ layout: 'auto',
95
+ compute: { module: shaderModule, entryPoint: 'main' },
96
+ label: name,
97
+ });
98
+
99
+ this.pipelines[name] = pipeline;
100
+ }
101
+ }
102
+
103
+ // ── Create a GPU buffer ────────────────────────────────────────────────
104
+ _createBuffer(size, usage, label) {
105
+ return this.device.createBuffer({
106
+ size: Math.max(size, 4), // WebGPU requires min 4 bytes
107
+ usage,
108
+ label,
109
+ mappedAtCreation: false,
110
+ });
111
+ }
112
+
113
+ // ── Upload data to a GPU buffer ────────────────────────────────────────
114
+ _upload(buffer, data) {
115
+ this.device.queue.writeBuffer(buffer, 0, data);
116
+ }
117
+
118
+ // ── Read data back from GPU buffer ─────────────────────────────────────
119
+ async _readback(buffer, size) {
120
+ const staging = this.device.createBuffer({
121
+ size,
122
+ usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
123
+ });
124
+ const encoder = this.device.createCommandEncoder();
125
+ encoder.copyBufferToBuffer(buffer, 0, staging, 0, size);
126
+ this.device.queue.submit([encoder.finish()]);
127
+ await staging.mapAsync(GPUMapMode.READ);
128
+ const result = new Float32Array(staging.getMappedRange().slice(0));
129
+ staging.unmap();
130
+ staging.destroy();
131
+ return result;
132
+ }
133
+
134
+ // ── Dispatch a compute shader ──────────────────────────────────────────
135
+ _dispatch(shaderName, bindGroup, uniformBindGroup, workgroupsX, workgroupsY = 1, workgroupsZ = 1) {
136
+ const encoder = this.device.createCommandEncoder();
137
+ const pass = encoder.beginComputePass();
138
+ pass.setPipeline(this.pipelines[shaderName]);
139
+ pass.setBindGroup(0, bindGroup);
140
+ if (uniformBindGroup) pass.setBindGroup(1, uniformBindGroup);
141
+ pass.dispatchWorkgroups(workgroupsX, workgroupsY, workgroupsZ);
142
+ pass.end();
143
+ this.device.queue.submit([encoder.finish()]);
144
+ }
145
+
146
+ // ── Load safetensors weights into GPU buffers ──────────────────────────
147
+ async loadWeights(basePath) {
148
+ console.log('[mamba] loading weights from', basePath);
149
+
150
+ // Get the shard index
151
+ const indexResp = await fetch(`${basePath}/model.safetensors.index.json`);
152
+ let fileMap; // tensor_name β†’ filename
153
+ let files;
154
+ if (indexResp.ok) {
155
+ const index = await indexResp.json();
156
+ fileMap = index.weight_map;
157
+ files = [...new Set(Object.values(fileMap))];
158
+ console.log(`[mamba] multi-shard: ${files.length} files, ${Object.keys(fileMap).length} tensors`);
159
+ } else {
160
+ files = ['model.safetensors'];
161
+ fileMap = null;
162
+ }
163
+
164
+ // For each shard, fetch ONLY the header first (small), then load tensors by byte-range
165
+ for (const file of files) {
166
+ console.log(`[mamba] parsing ${file} header...`);
167
+
168
+ // Fetch first 8 bytes to get header length
169
+ const headResp = await fetch(`${basePath}/${file}`, {
170
+ headers: { 'Range': 'bytes=0-7' }
171
+ });
172
+ let headerLen;
173
+ if (headResp.status === 206) {
174
+ // Range request supported
175
+ const headBuf = await headResp.arrayBuffer();
176
+ headerLen = new DataView(headBuf).getUint32(0, true);
177
+ } else {
178
+ // Range not supported β€” fall back to full fetch but only read header
179
+ const fullBuf = await headResp.arrayBuffer();
180
+ headerLen = new DataView(fullBuf).getUint32(0, true);
181
+ }
182
+ console.log(`[mamba] header: ${headerLen} bytes`);
183
+
184
+ // Fetch header JSON
185
+ const hdrResp = await fetch(`${basePath}/${file}`, {
186
+ headers: { 'Range': `bytes=8-${8 + headerLen - 1}` }
187
+ });
188
+ let headerStr;
189
+ if (hdrResp.status === 206) {
190
+ headerStr = await hdrResp.text();
191
+ } else {
192
+ const fullBuf = await hdrResp.arrayBuffer();
193
+ headerStr = new TextDecoder().decode(new Uint8Array(fullBuf, 8, headerLen));
194
+ }
195
+ const header = JSON.parse(headerStr);
196
+ const dataOffset = 8 + headerLen;
197
+
198
+ // Load each tensor individually
199
+ const tensorNames = Object.keys(header).filter(n => n !== '__metadata__');
200
+ console.log(`[mamba] ${tensorNames.length} tensors in this shard`);
201
+
202
+ let loaded = 0;
203
+ for (const name of tensorNames) {
204
+ const meta = header[name];
205
+ const dtype = meta.dtype;
206
+ const shape = meta.shape;
207
+ const [start, end] = meta.data_offsets;
208
+ const byteLen = end - start;
209
+
210
+ if (byteLen > 2_000_000_000) {
211
+ console.log(`[mamba] SKIP ${name} (${(byteLen/1e9).toFixed(2)} GB β€” exceeds buffer limit)`);
212
+ continue;
213
+ }
214
+
215
+ // Fetch this tensor's bytes via Range request
216
+ const absStart = dataOffset + start;
217
+ const absEnd = dataOffset + end - 1;
218
+ const tResp = await fetch(`${basePath}/${file}`, {
219
+ headers: { 'Range': `bytes=${absStart}-${absEnd}` }
220
+ });
221
+
222
+ let tensorBuf;
223
+ if (tResp.status === 206) {
224
+ tensorBuf = await tResp.arrayBuffer();
225
+ } else {
226
+ // No range support β€” need full file (expensive)
227
+ console.log(`[mamba] WARN: no range support, loading full file for ${name}`);
228
+ const fullBuf = await tResp.arrayBuffer();
229
+ tensorBuf = fullBuf.slice(absStart, absStart + byteLen);
230
+ }
231
+
232
+ // For BF16 weights: convert to F32 during upload (no double-buffering)
233
+ let gpuBuf;
234
+ let finalDtype = dtype;
235
+ let finalByteLen = byteLen;
236
+
237
+ if (dtype === 'BF16') {
238
+ // Convert CPU-side: BF16 β†’ F32 before uploading
239
+ const bf16 = new Uint16Array(tensorBuf);
240
+ const f32 = new Float32Array(bf16.length);
241
+ const tmpU32 = new Uint32Array(f32.buffer);
242
+ for (let j = 0; j < bf16.length; j++) {
243
+ tmpU32[j] = bf16[j] << 16; // BF16 is top 16 bits of F32
244
+ }
245
+ finalByteLen = f32.byteLength;
246
+ gpuBuf = this._createBuffer(
247
+ finalByteLen,
248
+ GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
249
+ name
250
+ );
251
+ this._upload(gpuBuf, f32);
252
+ finalDtype = 'F32';
253
+ } else {
254
+ gpuBuf = this._createBuffer(
255
+ byteLen,
256
+ GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
257
+ name
258
+ );
259
+ this._upload(gpuBuf, new Uint8Array(tensorBuf));
260
+ }
261
+ this.weights[name] = { buffer: gpuBuf, shape, dtype: finalDtype, byteLen: finalByteLen };
262
+
263
+ loaded++;
264
+ if (loaded % 20 === 0) {
265
+ console.log(`[mamba] loaded ${loaded}/${tensorNames.length} tensors`);
266
+ }
267
+ }
268
+ console.log(`[mamba] shard done: ${loaded} tensors loaded`);
269
+ }
270
+
271
+ console.log(`[mamba] TOTAL: ${Object.keys(this.weights).length} tensors loaded`);
272
+
273
+ // Allocate per-layer state buffers
274
+ this._allocateState();
275
+ this.ready = true;
276
+ }
277
+
278
+ // ── Allocate persistent SSM state + conv1d cache per layer ─────────────
279
+ _allocateState() {
280
+ const H = CONFIG.intermediate_size; // 8192
281
+ const S = CONFIG.state_size; // 16
282
+ const K = CONFIG.conv_kernel; // 4
283
+
284
+ for (let l = 0; l < CONFIG.num_layers; l++) {
285
+ // SSM state: [H, S] = 8192 * 16 = 131072 floats = 512 KB per layer
286
+ this.state[`layer.${l}.ssm`] = this._createBuffer(
287
+ H * S * 4,
288
+ GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
289
+ `ssm_state_${l}`
290
+ );
291
+
292
+ // Conv1d cache: [H, K-1] = 8192 * 3 = 24576 floats = 96 KB per layer
293
+ this.state[`layer.${l}.conv`] = this._createBuffer(
294
+ H * (K - 1) * 4,
295
+ GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
296
+ `conv_state_${l}`
297
+ );
298
+ }
299
+
300
+ // Total state: 64 layers Γ— (512 + 96) KB = ~38 MB
301
+ console.log(`[mamba] allocated ${CONFIG.num_layers} layers of SSM + conv1d state (~38 MB)`);
302
+ }
303
+
304
+ // ── Save/restore SSM state (the entity's persistent soul) ──────────────
305
+ async saveState() {
306
+ const state = {};
307
+ for (const [key, buf] of Object.entries(this.state)) {
308
+ state[key] = await this._readback(buf, buf.size);
309
+ }
310
+ return state;
311
+ }
312
+
313
+ async restoreState(state) {
314
+ for (const [key, data] of Object.entries(state)) {
315
+ if (this.state[key]) {
316
+ this._upload(this.state[key], data);
317
+ }
318
+ }
319
+ }
320
+
321
+ // ── Allocate intermediate scratch buffers for forward pass ──────────────
322
+ _allocateScratch() {
323
+ if (this.scratch) return; // already allocated
324
+ const H = CONFIG.hidden_size; // 4096
325
+ const I = CONFIG.intermediate_size; // 8192
326
+ const DR = CONFIG.dt_rank; // 256
327
+ const S = CONFIG.state_size; // 16
328
+ const F = 4; // sizeof(float32)
329
+
330
+ this.scratch = {
331
+ norm_out: this._createBuffer(H * F, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, 'norm_out'),
332
+ projected: this._createBuffer(2 * I * F, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, 'projected'),
333
+ hidden: this._createBuffer(I * F, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, 'hidden'),
334
+ gate: this._createBuffer(I * F, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, 'gate'),
335
+ hidden_c: this._createBuffer(I * F, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, 'hidden_c'),
336
+ sxBC: this._createBuffer((DR + 2*S)*F, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, 'sxBC'),
337
+ B_proj: this._createBuffer(S * F, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, 'B_proj'),
338
+ C_proj: this._createBuffer(S * F, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, 'C_proj'),
339
+ dt_pre: this._createBuffer(DR * F, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, 'dt_pre'),
340
+ dt: this._createBuffer(I * F, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, 'dt'),
341
+ hidden_y: this._createBuffer(I * F, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, 'hidden_y'),
342
+ gate_silu: this._createBuffer(I * F, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, 'gate_silu'),
343
+ out_proj_o: this._createBuffer(H * F, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, 'out_proj_o'),
344
+ logits: this._createBuffer(CONFIG.vocab_size * F, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, 'logits'),
345
+ token_out: this._createBuffer(4, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, 'token_out'),
346
+ hidden_state: this._createBuffer(H * F, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, 'hidden_state'),
347
+ token_id: this._createBuffer(4, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, 'token_id'),
348
+ };
349
+ this._tokenCount = 0;
350
+ console.log('[mamba] scratch buffers allocated');
351
+ }
352
+
353
+ // ── Single-token forward pass through all 64 layers ─────────────────────
354
+ async _forwardOneToken(tokenId) {
355
+ const H = CONFIG.hidden_size; // 4096
356
+ const I = CONFIG.intermediate_size; // 8192
357
+ const DR = CONFIG.dt_rank; // 256
358
+ const S = CONFIG.state_size; // 16
359
+ const V = CONFIG.vocab_size; // 65024
360
+
361
+ // Step 1: Embedding lookup β€” copy one row from embedding table to hidden_state
362
+ this._upload(this.scratch.token_id, new Uint32Array([tokenId]));
363
+ const embBuf = await this._getF32Weight('backbone.embeddings.weight');
364
+ const encoder1 = this.device.createCommandEncoder();
365
+ encoder1.copyBufferToBuffer(embBuf, tokenId * H * 4, this.scratch.hidden_state, 0, H * 4);
366
+ this.device.queue.submit([encoder1.finish()]);
367
+
368
+ // Step 2: For each layer (0..63)
369
+ for (let l = 0; l < CONFIG.num_layers; l++) {
370
+ const prefix = `backbone.layers.${l}`;
371
+
372
+ // rmsnorm(hidden_state, norm.weight) β†’ norm_out
373
+ const normW = await this._getF32Weight(`${prefix}.norm.weight`);
374
+ let encoder = this.device.createCommandEncoder();
375
+ let pass = encoder.beginComputePass();
376
+ pass.setPipeline(this.pipelines['rmsnorm']);
377
+ pass.setBindGroup(0, this.device.createBindGroup({
378
+ layout: this.pipelines['rmsnorm'].getBindGroupLayout(0),
379
+ entries: [
380
+ { binding: 0, resource: { buffer: this.scratch.hidden_state } },
381
+ { binding: 1, resource: { buffer: normW } },
382
+ { binding: 2, resource: { buffer: this.scratch.norm_out } },
383
+ ],
384
+ }));
385
+ const rmsnormParams = new ArrayBuffer(12);
386
+ new DataView(rmsnormParams).setUint32(0, 1, true);
387
+ new DataView(rmsnormParams).setUint32(4, H, true);
388
+ new DataView(rmsnormParams).setFloat32(8, CONFIG.rms_eps, true);
389
+ pass.setBindGroup(1, this.device.createBindGroup({
390
+ layout: this.pipelines['rmsnorm'].getBindGroupLayout(1),
391
+ entries: [{ binding: 0, resource: { buffer: this._createUniform(new Uint8Array(rmsnormParams)) }}],
392
+ }));
393
+ pass.dispatchWorkgroups(1); // one workgroup per row, 1 row
394
+ pass.end();
395
+
396
+
397
+
398
+ // matmul_gemv(norm_out, in_proj.weight) β†’ projected [I*2 = 16384]
399
+ const inProjW = await this._getF32Weight(`${prefix}.mixer.in_proj.weight`);
400
+ pass = encoder.beginComputePass();
401
+ pass.setPipeline(this.pipelines['matmul_gemv']);
402
+ pass.setBindGroup(0, this.device.createBindGroup({
403
+ layout: this.pipelines['matmul_gemv'].getBindGroupLayout(0),
404
+ entries: [
405
+ { binding: 0, resource: { buffer: this.scratch.norm_out } },
406
+ { binding: 1, resource: { buffer: inProjW } },
407
+ { binding: 2, resource: { buffer: this.scratch.projected } },
408
+ ],
409
+ }));
410
+ const gemvParams1 = new ArrayBuffer(8);
411
+ new DataView(gemvParams1).setUint32(0, I * 2, true); // N
412
+ new DataView(gemvParams1).setUint32(4, H, true); // K
413
+ pass.setBindGroup(1, this.device.createBindGroup({
414
+ layout: this.pipelines['matmul_gemv'].getBindGroupLayout(1),
415
+ entries: [{ binding: 0, resource: { buffer: this._createUniform(new Uint8Array(gemvParams1)) }}],
416
+ }));
417
+ pass.dispatchWorkgroups(I * 2); // one workgroup per output element
418
+ pass.end();
419
+
420
+ // Split projected β†’ hidden[0:I], gate[I:2I] via buffer copies
421
+ encoder.copyBufferToBuffer(this.scratch.projected, 0, this.scratch.hidden, 0, I * 4);
422
+ encoder.copyBufferToBuffer(this.scratch.projected, I * 4, this.scratch.gate, 0, I * 4);
423
+
424
+
425
+
426
+ // conv1d_step(conv_state, hidden, conv1d.weight, conv1d.bias) β†’ hidden_c
427
+ const conv1dW = await this._getF32Weight(`${prefix}.mixer.conv1d.weight`);
428
+ const conv1dB = await this._getF32Weight(`${prefix}.mixer.conv1d.bias`);
429
+ pass = encoder.beginComputePass();
430
+ pass.setPipeline(this.pipelines['conv1d_step']);
431
+ pass.setBindGroup(0, this.device.createBindGroup({
432
+ layout: this.pipelines['conv1d_step'].getBindGroupLayout(0),
433
+ entries: [
434
+ { binding: 0, resource: { buffer: this.state[`layer.${l}.conv`] } },
435
+ { binding: 1, resource: { buffer: this.scratch.hidden } },
436
+ { binding: 2, resource: { buffer: conv1dW } },
437
+ { binding: 3, resource: { buffer: conv1dB } },
438
+ { binding: 4, resource: { buffer: this.scratch.hidden_c } },
439
+ ],
440
+ }));
441
+ const conv1dParams = new ArrayBuffer(4);
442
+ new DataView(conv1dParams).setUint32(0, I, true);
443
+ pass.setBindGroup(1, this.device.createBindGroup({
444
+ layout: this.pipelines['conv1d_step'].getBindGroupLayout(1),
445
+ entries: [{ binding: 0, resource: { buffer: this._createUniform(new Uint8Array(conv1dParams)) }}],
446
+ }));
447
+ pass.dispatchWorkgroups(Math.ceil(I / 64));
448
+ pass.end();
449
+
450
+ // silu(hidden_c) in-place β†’ hidden_a
451
+ pass = encoder.beginComputePass();
452
+ pass.setPipeline(this.pipelines['silu']);
453
+ pass.setBindGroup(0, this.device.createBindGroup({
454
+ layout: this.pipelines['silu'].getBindGroupLayout(0),
455
+ entries: [{ binding: 0, resource: { buffer: this.scratch.hidden_c } }],
456
+ }));
457
+ const siluParams = new ArrayBuffer(4);
458
+ new DataView(siluParams).setUint32(0, I, true);
459
+ pass.setBindGroup(1, this.device.createBindGroup({
460
+ layout: this.pipelines['silu'].getBindGroupLayout(1),
461
+ entries: [{ binding: 0, resource: { buffer: this._createUniform(new Uint8Array(siluParams)) }}],
462
+ }));
463
+ pass.dispatchWorkgroups(Math.ceil(I / 64));
464
+ pass.end();
465
+ // hidden_c is now silu'd (= hidden_a)
466
+
467
+
468
+
469
+ // matmul_gemv(hidden_c, x_proj.weight) β†’ sxBC [DR+2*S = 288]
470
+ const xProjW = await this._getF32Weight(`${prefix}.mixer.x_proj.weight`);
471
+ pass = encoder.beginComputePass();
472
+ pass.setPipeline(this.pipelines['matmul_gemv']);
473
+ pass.setBindGroup(0, this.device.createBindGroup({
474
+ layout: this.pipelines['matmul_gemv'].getBindGroupLayout(0),
475
+ entries: [
476
+ { binding: 0, resource: { buffer: this.scratch.hidden_c } },
477
+ { binding: 1, resource: { buffer: xProjW } },
478
+ { binding: 2, resource: { buffer: this.scratch.sxBC } },
479
+ ],
480
+ }));
481
+ const gemvParams2 = new ArrayBuffer(8);
482
+ new DataView(gemvParams2).setUint32(0, DR + 2 * S, true);
483
+ new DataView(gemvParams2).setUint32(4, I, true);
484
+ pass.setBindGroup(1, this.device.createBindGroup({
485
+ layout: this.pipelines['matmul_gemv'].getBindGroupLayout(1),
486
+ entries: [{ binding: 0, resource: { buffer: this._createUniform(new Uint8Array(gemvParams2)) }}],
487
+ }));
488
+ pass.dispatchWorkgroups(DR + 2 * S);
489
+ pass.end();
490
+
491
+ // Copy dt_pre, B, C from sxBC into separate buffers
492
+ encoder.copyBufferToBuffer(this.scratch.sxBC, 0, this.scratch.dt_pre, 0, DR * 4);
493
+ encoder.copyBufferToBuffer(this.scratch.sxBC, DR * 4, this.scratch.B_proj, 0, S * 4);
494
+ encoder.copyBufferToBuffer(this.scratch.sxBC, (DR + S) * 4, this.scratch.C_proj, 0, S * 4);
495
+
496
+ // Falcon-Mamba: RMSNorm(dt_pre), RMSNorm(B), RMSNorm(C) before use
497
+ const rmsNwParams_dt = new ArrayBuffer(8);
498
+ new DataView(rmsNwParams_dt).setUint32(0, DR, true);
499
+ new DataView(rmsNwParams_dt).setFloat32(4, CONFIG.rms_eps, true);
500
+ pass = encoder.beginComputePass();
501
+ pass.setPipeline(this.pipelines['rmsnorm_noweight']);
502
+ pass.setBindGroup(0, this.device.createBindGroup({
503
+ layout: this.pipelines['rmsnorm_noweight'].getBindGroupLayout(0),
504
+ entries: [{ binding: 0, resource: { buffer: this.scratch.dt_pre } }],
505
+ }));
506
+ pass.setBindGroup(1, this.device.createBindGroup({
507
+ layout: this.pipelines['rmsnorm_noweight'].getBindGroupLayout(1),
508
+ entries: [{ binding: 0, resource: { buffer: this._createUniform(new Uint8Array(rmsNwParams_dt)) }}],
509
+ }));
510
+ pass.dispatchWorkgroups(1);
511
+ pass.end();
512
+
513
+ const rmsNwParams_s = new ArrayBuffer(8);
514
+ new DataView(rmsNwParams_s).setUint32(0, S, true);
515
+ new DataView(rmsNwParams_s).setFloat32(4, CONFIG.rms_eps, true);
516
+ pass = encoder.beginComputePass();
517
+ pass.setPipeline(this.pipelines['rmsnorm_noweight']);
518
+ pass.setBindGroup(0, this.device.createBindGroup({
519
+ layout: this.pipelines['rmsnorm_noweight'].getBindGroupLayout(0),
520
+ entries: [{ binding: 0, resource: { buffer: this.scratch.B_proj } }],
521
+ }));
522
+ pass.setBindGroup(1, this.device.createBindGroup({
523
+ layout: this.pipelines['rmsnorm_noweight'].getBindGroupLayout(1),
524
+ entries: [{ binding: 0, resource: { buffer: this._createUniform(new Uint8Array(rmsNwParams_s)) }}],
525
+ }));
526
+ pass.dispatchWorkgroups(1);
527
+ pass.end();
528
+
529
+ pass = encoder.beginComputePass();
530
+ pass.setPipeline(this.pipelines['rmsnorm_noweight']);
531
+ pass.setBindGroup(0, this.device.createBindGroup({
532
+ layout: this.pipelines['rmsnorm_noweight'].getBindGroupLayout(0),
533
+ entries: [{ binding: 0, resource: { buffer: this.scratch.C_proj } }],
534
+ }));
535
+ pass.setBindGroup(1, this.device.createBindGroup({
536
+ layout: this.pipelines['rmsnorm_noweight'].getBindGroupLayout(1),
537
+ entries: [{ binding: 0, resource: { buffer: this._createUniform(new Uint8Array(rmsNwParams_s)) }}],
538
+ }));
539
+ pass.dispatchWorkgroups(1);
540
+ pass.end();
541
+
542
+ // matmul_gemv(dt_pre_normalized, dt_proj.weight) β†’ dt [I]
543
+ const dtProjW = await this._getF32Weight(`${prefix}.mixer.dt_proj.weight`);
544
+ pass = encoder.beginComputePass();
545
+ pass.setPipeline(this.pipelines['matmul_gemv']);
546
+ pass.setBindGroup(0, this.device.createBindGroup({
547
+ layout: this.pipelines['matmul_gemv'].getBindGroupLayout(0),
548
+ entries: [
549
+ { binding: 0, resource: { buffer: this.scratch.dt_pre } },
550
+ { binding: 1, resource: { buffer: dtProjW } },
551
+ { binding: 2, resource: { buffer: this.scratch.dt } },
552
+ ],
553
+ }));
554
+ const gemvParams3 = new ArrayBuffer(8);
555
+ new DataView(gemvParams3).setUint32(0, I, true);
556
+ new DataView(gemvParams3).setUint32(4, DR, true);
557
+ pass.setBindGroup(1, this.device.createBindGroup({
558
+ layout: this.pipelines['matmul_gemv'].getBindGroupLayout(1),
559
+ entries: [{ binding: 0, resource: { buffer: this._createUniform(new Uint8Array(gemvParams3)) }}],
560
+ }));
561
+ pass.dispatchWorkgroups(I);
562
+ pass.end();
563
+
564
+ // SSU: selective_state_update
565
+ // ssu(state, hidden_c, dt, A, B, C, D, dt_bias) β†’ hidden_y
566
+ const aLog = await this._getF32Weight(`${prefix}.mixer.A_log`);
567
+ const dWeight = await this._getF32Weight(`${prefix}.mixer.D`);
568
+ const dtBias = await this._getF32Weight(`${prefix}.mixer.dt_proj.bias`);
569
+ pass = encoder.beginComputePass();
570
+ pass.setPipeline(this.pipelines['ssu']);
571
+ pass.setBindGroup(0, this.device.createBindGroup({
572
+ layout: this.pipelines['ssu'].getBindGroupLayout(0),
573
+ entries: [
574
+ { binding: 0, resource: { buffer: this.state[`layer.${l}.ssm`] } },
575
+ { binding: 1, resource: { buffer: this.scratch.hidden_c } }, // x (silu'd)
576
+ { binding: 2, resource: { buffer: this.scratch.dt } },
577
+ { binding: 3, resource: { buffer: aLog } }, // A (needs -exp transform)
578
+ { binding: 4, resource: { buffer: this.scratch.B_proj } }, // B
579
+ { binding: 5, resource: { buffer: this.scratch.C_proj } }, // C
580
+ { binding: 6, resource: { buffer: dWeight } },
581
+ { binding: 7, resource: { buffer: dtBias } },
582
+ { binding: 8, resource: { buffer: this.scratch.hidden_y } },
583
+ ],
584
+ }));
585
+ const ssuParams = new ArrayBuffer(8);
586
+ new DataView(ssuParams).setUint32(0, I, true); // H
587
+ new DataView(ssuParams).setUint32(4, S, true); // S
588
+ pass.setBindGroup(1, this.device.createBindGroup({
589
+ layout: this.pipelines['ssu'].getBindGroupLayout(1),
590
+ entries: [{ binding: 0, resource: { buffer: this._createUniform(new Uint8Array(ssuParams)) }}],
591
+ }));
592
+ pass.dispatchWorkgroups(I); // one workgroup per h
593
+ pass.end();
594
+
595
+
596
+
597
+ // silu(gate) in-place
598
+ pass = encoder.beginComputePass();
599
+ pass.setPipeline(this.pipelines['silu']);
600
+ pass.setBindGroup(0, this.device.createBindGroup({
601
+ layout: this.pipelines['silu'].getBindGroupLayout(0),
602
+ entries: [{ binding: 0, resource: { buffer: this.scratch.gate } }],
603
+ }));
604
+ pass.setBindGroup(1, this.device.createBindGroup({
605
+ layout: this.pipelines['silu'].getBindGroupLayout(1),
606
+ entries: [{ binding: 0, resource: { buffer: this._createUniform(new Uint8Array(siluParams)) }}],
607
+ }));
608
+ pass.dispatchWorkgroups(Math.ceil(I / 64));
609
+ pass.end();
610
+
611
+ // elementwise_mul: hidden_y *= gate (in-place into hidden_y)
612
+ pass = encoder.beginComputePass();
613
+ pass.setPipeline(this.pipelines['elementwise_mul']);
614
+ pass.setBindGroup(0, this.device.createBindGroup({
615
+ layout: this.pipelines['elementwise_mul'].getBindGroupLayout(0),
616
+ entries: [
617
+ { binding: 0, resource: { buffer: this.scratch.hidden_y } },
618
+ { binding: 1, resource: { buffer: this.scratch.gate } },
619
+ ],
620
+ }));
621
+ pass.setBindGroup(1, this.device.createBindGroup({
622
+ layout: this.pipelines['elementwise_mul'].getBindGroupLayout(1),
623
+ entries: [{ binding: 0, resource: { buffer: this._createUniform(new Uint8Array(siluParams)) }}],
624
+ }));
625
+ pass.dispatchWorkgroups(Math.ceil(I / 64));
626
+ pass.end();
627
+
628
+ // matmul_gemv(hidden_y, out_proj.weight) β†’ out_proj_o [H]
629
+ const outProjW = await this._getF32Weight(`${prefix}.mixer.out_proj.weight`);
630
+ pass = encoder.beginComputePass();
631
+ pass.setPipeline(this.pipelines['matmul_gemv']);
632
+ pass.setBindGroup(0, this.device.createBindGroup({
633
+ layout: this.pipelines['matmul_gemv'].getBindGroupLayout(0),
634
+ entries: [
635
+ { binding: 0, resource: { buffer: this.scratch.hidden_y } },
636
+ { binding: 1, resource: { buffer: outProjW } },
637
+ { binding: 2, resource: { buffer: this.scratch.out_proj_o } },
638
+ ],
639
+ }));
640
+ const gemvParams4 = new ArrayBuffer(8);
641
+ new DataView(gemvParams4).setUint32(0, H, true);
642
+ new DataView(gemvParams4).setUint32(4, I, true);
643
+ pass.setBindGroup(1, this.device.createBindGroup({
644
+ layout: this.pipelines['matmul_gemv'].getBindGroupLayout(1),
645
+ entries: [{ binding: 0, resource: { buffer: this._createUniform(new Uint8Array(gemvParams4)) }}],
646
+ }));
647
+ pass.dispatchWorkgroups(H);
648
+ pass.end();
649
+
650
+ // Submit this layer's command buffer
651
+ this.device.queue.submit([encoder.finish()]);
652
+
653
+ // Debug: readback hidden_state after residual for select layers
654
+ // Residual add: hidden_state += out_proj_o
655
+ {
656
+ const enc2 = this.device.createCommandEncoder();
657
+ const addPass = enc2.beginComputePass();
658
+ addPass.setPipeline(this.pipelines['add_residual']);
659
+ addPass.setBindGroup(0, this.device.createBindGroup({
660
+ layout: this.pipelines['add_residual'].getBindGroupLayout(0),
661
+ entries: [
662
+ { binding: 0, resource: { buffer: this.scratch.hidden_state } },
663
+ { binding: 1, resource: { buffer: this.scratch.out_proj_o } },
664
+ ],
665
+ }));
666
+ const addParams = new ArrayBuffer(4);
667
+ new DataView(addParams).setUint32(0, H, true);
668
+ addPass.setBindGroup(1, this.device.createBindGroup({
669
+ layout: this.pipelines['add_residual'].getBindGroupLayout(1),
670
+ entries: [{ binding: 0, resource: { buffer: this._createUniform(new Uint8Array(addParams)) }}],
671
+ }));
672
+ addPass.dispatchWorkgroups(Math.ceil(H / 64));
673
+ addPass.end();
674
+ this.device.queue.submit([enc2.finish()]);
675
+ }
676
+
677
+
678
+ }
679
+
680
+ // Final: rmsnorm + lm_head + sample
681
+ await this.device.queue.onSubmittedWorkDone();
682
+
683
+ // rmsnorm(hidden_state, backbone.norm_f.weight) β†’ norm_out
684
+ const normFW = await this._getF32Weight('backbone.norm_f.weight');
685
+ let encoder = this.device.createCommandEncoder();
686
+ let pass = encoder.beginComputePass();
687
+ pass.setPipeline(this.pipelines['rmsnorm']);
688
+ pass.setBindGroup(0, this.device.createBindGroup({
689
+ layout: this.pipelines['rmsnorm'].getBindGroupLayout(0),
690
+ entries: [
691
+ { binding: 0, resource: { buffer: this.scratch.hidden_state } },
692
+ { binding: 1, resource: { buffer: normFW } },
693
+ { binding: 2, resource: { buffer: this.scratch.norm_out } },
694
+ ],
695
+ }));
696
+ const finalNormParams = new ArrayBuffer(12);
697
+ new DataView(finalNormParams).setUint32(0, 1, true);
698
+ new DataView(finalNormParams).setUint32(4, H, true);
699
+ new DataView(finalNormParams).setFloat32(8, CONFIG.rms_eps, true);
700
+ pass.setBindGroup(1, this.device.createBindGroup({
701
+ layout: this.pipelines['rmsnorm'].getBindGroupLayout(1),
702
+ entries: [{ binding: 0, resource: { buffer: this._createUniform(new Uint8Array(finalNormParams)) }}],
703
+ }));
704
+ pass.dispatchWorkgroups(1);
705
+ pass.end();
706
+
707
+ // matmul_gemv(norm_out, lm_head.weight) β†’ logits [V]
708
+ const lmHeadW = await this._getF32Weight('lm_head.weight');
709
+ pass = encoder.beginComputePass();
710
+ pass.setPipeline(this.pipelines['matmul_gemv']);
711
+ pass.setBindGroup(0, this.device.createBindGroup({
712
+ layout: this.pipelines['matmul_gemv'].getBindGroupLayout(0),
713
+ entries: [
714
+ { binding: 0, resource: { buffer: this.scratch.norm_out } },
715
+ { binding: 1, resource: { buffer: lmHeadW } },
716
+ { binding: 2, resource: { buffer: this.scratch.logits } },
717
+ ],
718
+ }));
719
+ const gemvFinal = new ArrayBuffer(8);
720
+ new DataView(gemvFinal).setUint32(0, V, true);
721
+ new DataView(gemvFinal).setUint32(4, H, true);
722
+ pass.setBindGroup(1, this.device.createBindGroup({
723
+ layout: this.pipelines['matmul_gemv'].getBindGroupLayout(1),
724
+ entries: [{ binding: 0, resource: { buffer: this._createUniform(new Uint8Array(gemvFinal)) }}],
725
+ }));
726
+ pass.dispatchWorkgroups(V);
727
+ pass.end();
728
+
729
+ // sample(logits, temperature) β†’ token_out
730
+ pass = encoder.beginComputePass();
731
+ pass.setPipeline(this.pipelines['sample']);
732
+ pass.setBindGroup(0, this.device.createBindGroup({
733
+ layout: this.pipelines['sample'].getBindGroupLayout(0),
734
+ entries: [
735
+ { binding: 0, resource: { buffer: this.scratch.logits } },
736
+ { binding: 1, resource: { buffer: this.scratch.token_out } },
737
+ ],
738
+ }));
739
+ const sampleParams = new ArrayBuffer(12);
740
+ new DataView(sampleParams).setUint32(0, V, true);
741
+ new DataView(sampleParams).setFloat32(4, 1.0 / 0.75, true); // inv_temperature
742
+ new DataView(sampleParams).setUint32(8, Math.floor(Math.random() * 0xFFFFFFFF), true); // rng_seed
743
+ pass.setBindGroup(1, this.device.createBindGroup({
744
+ layout: this.pipelines['sample'].getBindGroupLayout(1),
745
+ entries: [{ binding: 0, resource: { buffer: this._createUniform(new Uint8Array(sampleParams)) }}],
746
+ }));
747
+ pass.dispatchWorkgroups(1);
748
+ pass.end();
749
+
750
+ this.device.queue.submit([encoder.finish()]);
751
+ await this.device.queue.onSubmittedWorkDone();
752
+
753
+ // Read back the sampled token
754
+ const tokenResult = await this._readback(this.scratch.token_out, 4);
755
+ this._tokenCount++;
756
+ return new Uint32Array(tokenResult.buffer)[0];
757
+ }
758
+
759
+ // ── Tokenize/detokenize via server ──────────────────────────────────────
760
+ async tokenize(text) {
761
+ const resp = await fetch('/tokenize', {
762
+ method: 'POST',
763
+ headers: { 'Content-Type': 'application/json' },
764
+ body: JSON.stringify({ text }),
765
+ });
766
+ const data = await resp.json();
767
+ return data.result;
768
+ }
769
+
770
+ async detokenize(tokens) {
771
+ const resp = await fetch('/detokenize', {
772
+ method: 'POST',
773
+ headers: { 'Content-Type': 'application/json' },
774
+ body: JSON.stringify({ tokens }),
775
+ });
776
+ const data = await resp.json();
777
+ return data.result;
778
+ }
779
+
780
+ // ── Generate text ──────────────────────────────────────────────────────
781
+ async generate(prompt, maxTokens = 100, temperature = 0.75, onToken = null) {
782
+ if (!this.ready) throw new Error('Call loadWeights() first');
783
+ this._allocateScratch();
784
+
785
+ console.log('[mamba] generate:', prompt, 'max_tokens:', maxTokens);
786
+
787
+ // Tokenize the prompt
788
+ const promptTokens = await this.tokenize(prompt);
789
+ console.log(`[mamba] prompt tokens (${promptTokens.length}):`, promptTokens);
790
+
791
+ // Process prompt tokens through forward pass to build SSM state
792
+ console.log('[mamba] encoding prompt...');
793
+ for (let i = 0; i < promptTokens.length; i++) {
794
+ const t0 = performance.now();
795
+ await this._forwardOneToken(promptTokens[i]);
796
+ const elapsed = performance.now() - t0;
797
+ if (i === 0 || i === promptTokens.length - 1) {
798
+ console.log(`[mamba] prompt token ${i}/${promptTokens.length}: ${promptTokens[i]} (${elapsed.toFixed(0)}ms)`);
799
+ }
800
+ }
801
+ console.log('[mamba] prompt encoded, generating...');
802
+
803
+ // Get the last prompt token's output as first generation input
804
+ const generated = [];
805
+ // The last _forwardOneToken already produced the next-token prediction
806
+ // We need to read it back
807
+ const firstResult = await this._readback(this.scratch.token_out, 4);
808
+ let inputToken = new Uint32Array(firstResult.buffer)[0];
809
+ generated.push(inputToken);
810
+ console.log(`[mamba] first generated token: ${inputToken}`);
811
+ if (onToken) onToken(inputToken, 0);
812
+
813
+ for (let step = 1; step < maxTokens; step++) {
814
+ const t0 = performance.now();
815
+ try {
816
+ const nextToken = await this._forwardOneToken(inputToken);
817
+ const elapsed = performance.now() - t0;
818
+ if (step < 5 || step % 20 === 0) {
819
+ console.log(`[mamba] step ${step}: token=${nextToken} (${elapsed.toFixed(0)}ms)`);
820
+ }
821
+ generated.push(nextToken);
822
+ inputToken = nextToken;
823
+ if (onToken) onToken(nextToken, step);
824
+ if (nextToken === 11 || nextToken === 10 || nextToken === 0) break; // EOS=11, im_end=10, PAD=0
825
+ } catch (e) {
826
+ console.error(`[mamba] step ${step} failed:`, e.message);
827
+ break;
828
+ }
829
+ }
830
+
831
+ // Decode the generated tokens
832
+ const text = await this.detokenize(generated);
833
+ console.log(`[mamba] generated ${generated.length} tokens`);
834
+ return text;
835
+ }
836
+
837
+ // ── Helper: get weight buffer by name ──────────────────────────────────
838
+ _getWeight(name) {
839
+ return this.weights[name] || null;
840
+ }
841
+
842
+ // ── Create a uniform buffer with typed data ────────────────────────────
843
+ _createUniform(data) {
844
+ const buf = this.device.createBuffer({
845
+ size: data.byteLength,
846
+ usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
847
+ mappedAtCreation: true,
848
+ });
849
+ new Uint8Array(buf.getMappedRange()).set(new Uint8Array(data.buffer));
850
+ buf.unmap();
851
+ return buf;
852
+ }
853
+
854
+ // ── BF16 β†’ F32 conversion for a weight tensor ─────────────────────────
855
+ async _convertBF16toF32(weightInfo) {
856
+ if (weightInfo.dtype !== 'BF16' || weightInfo.f32buffer) return weightInfo;
857
+
858
+ const numBF16 = weightInfo.byteLen / 2; // each bf16 is 2 bytes
859
+ const numPairs = weightInfo.byteLen / 4; // each u32 holds 2 bf16
860
+ const f32Bytes = numBF16 * 4;
861
+
862
+ // Create output F32 buffer
863
+ const f32Buf = this._createBuffer(
864
+ f32Bytes,
865
+ GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
866
+ weightInfo.buffer.label + '_f32'
867
+ );
868
+
869
+ // Create uniform for params
870
+ const paramBuf = this._createUniform(new Uint32Array([numPairs]));
871
+
872
+ // Create bind groups
873
+ const pipeline = this.pipelines['bf16_to_f32'];
874
+ const bg0 = this.device.createBindGroup({
875
+ layout: pipeline.getBindGroupLayout(0),
876
+ entries: [
877
+ { binding: 0, resource: { buffer: weightInfo.buffer } },
878
+ { binding: 1, resource: { buffer: f32Buf } },
879
+ ],
880
+ });
881
+ const bg1 = this.device.createBindGroup({
882
+ layout: pipeline.getBindGroupLayout(1),
883
+ entries: [
884
+ { binding: 0, resource: { buffer: paramBuf } },
885
+ ],
886
+ });
887
+
888
+ // Dispatch
889
+ const encoder = this.device.createCommandEncoder();
890
+ const pass = encoder.beginComputePass();
891
+ pass.setPipeline(pipeline);
892
+ pass.setBindGroup(0, bg0);
893
+ pass.setBindGroup(1, bg1);
894
+ pass.dispatchWorkgroups(Math.ceil(numPairs / 64));
895
+ pass.end();
896
+ this.device.queue.submit([encoder.finish()]);
897
+ await this.device.queue.onSubmittedWorkDone();
898
+
899
+ // Cache the F32 buffer
900
+ weightInfo.f32buffer = f32Buf;
901
+ weightInfo.f32size = f32Bytes;
902
+ paramBuf.destroy();
903
+ return weightInfo;
904
+ }
905
+
906
+ // ── Get F32 weight buffer (already converted during load) ───────────────
907
+ async _getF32Weight(name) {
908
+ const w = this.weights[name];
909
+ if (!w) throw new Error(`Missing weight: ${name}`);
910
+ return w.buffer;
911
+ }
912
+
913
+ // ── Dispatch a shader with auto bind group creation ─────────────────────
914
+ _dispatchShader(encoder, shaderName, storageBuffers, uniformData) {
915
+ const pipeline = this.pipelines[shaderName];
916
+ const pass = encoder.beginComputePass();
917
+ pass.setPipeline(pipeline);
918
+
919
+ // Bind group 0: storage buffers
920
+ const entries0 = storageBuffers.map((buf, i) => ({
921
+ binding: i, resource: { buffer: buf }
922
+ }));
923
+ const bg0 = this.device.createBindGroup({
924
+ layout: pipeline.getBindGroupLayout(0),
925
+ entries: entries0,
926
+ });
927
+ pass.setBindGroup(0, bg0);
928
+
929
+ // Bind group 1: uniforms (if provided)
930
+ if (uniformData) {
931
+ const ubuf = this._createUniform(uniformData);
932
+ const bg1 = this.device.createBindGroup({
933
+ layout: pipeline.getBindGroupLayout(1),
934
+ entries: [{ binding: 0, resource: { buffer: ubuf } }],
935
+ });
936
+ pass.setBindGroup(1, bg1);
937
+ // Note: ubuf leaks β€” for production, cache these. Fine for proof-of-concept.
938
+ }
939
+
940
+ return pass; // caller sets dispatch count and calls pass.end()
941
+ }
942
+
943
+ // ── Cleanup ────────────────────────────────────────────────────────────
944
+ destroy() {
945
+ for (const w of Object.values(this.weights)) w.buffer.destroy();
946
+ for (const s of Object.values(this.state)) s.destroy();
947
+ this.weights = {};
948
+ this.state = {};
949
+ this.ready = false;
950
+ }
951
+ }
952
+
953
+ // ES module export for browser <script type="module">
954
+ export { MambaRuntime, CONFIG };