/**
 * mamba_runtime.js — Browser-native Falcon-Mamba inference via WebGPU.
 *
 * The first browser-native Mamba/SSM inference engine.
 * No MLC, no TVM — pure WebGPU compute shaders ported from gfx1151_runtime.
 *
 * Architecture: Falcon-Mamba 7B
 *   64 layers, each: RMSNorm → in_proj → conv1d → SSU → out_proj
 *   Final: RMSNorm → lm_head → sample
 *
 * Weight format: safetensors (HF standard), loaded directly into WebGPU buffers.
 * Shaders: WGSL compute shaders in ./shaders/ (ported from Vulkan GLSL).
 *
 * Usage:
 *   const mamba = new MambaRuntime();
 *   await mamba.init();
 *   await mamba.loadWeights('./weights/');
 *   const text = await mamba.generate("Hello Grandma", 100);
 */

// Falcon-Mamba 7B constants
const CONFIG = {
  hidden_size: 4096,
  intermediate_size: 8192, // 2 * hidden
  num_layers: 64,
  vocab_size: 65024,
  state_size: 16, // SSM d_state
  conv_kernel: 4,
  dt_rank: 256,
  rms_eps: 1e-5,
};

// Token ids that end generation: EOS=11, im_end=10, PAD=0.
const STOP_TOKENS = new Set([11, 10, 0]);

// Sampling temperature used when the caller passes none (or an unusable one).
const DEFAULT_TEMPERATURE = 0.75;

class MambaRuntime {
  constructor() {
    this.device = null;
    this.pipelines = {}; // shader name → GPUComputePipeline
    this.bindLayouts = {}; // shader name → GPUBindGroupLayout
    this.weights = {}; // parameter name → { buffer, shape, dtype, byteLen }
    this.state = {}; // per-layer SSM state + conv1d state buffers
    this.scratch = null; // intermediate buffers, allocated lazily by _allocateScratch()
    this.ready = false;
    this._uniformCache = new Map(); // param bytes (as string) → GPUBuffer
    this._tokenCount = 0;
  }

  // ── Init: get WebGPU device + compile all shaders ──────────────────────
  /**
   * Acquire a WebGPU device with the adapter's maximum buffer limits and
   * compile every shader.
   * @returns {Promise<MambaRuntime>} this, for chaining.
   * @throws {Error} if WebGPU or an adapter is unavailable, or a shader fails to load.
   */
  async init() {
    if (!navigator.gpu) throw new Error('WebGPU not supported in this browser');

    const adapter = await navigator.gpu.requestAdapter();
    if (!adapter) throw new Error('No WebGPU adapter found');

    // Request max buffer size the device supports
    const limits = adapter.limits;
    console.log('[mamba] maxBufferSize:', limits.maxBufferSize,
      '=', (limits.maxBufferSize / 1024 / 1024 / 1024).toFixed(2), 'GB');

    this.device = await adapter.requestDevice({
      requiredLimits: {
        maxBufferSize: limits.maxBufferSize,
        maxStorageBufferBindingSize: limits.maxStorageBufferBindingSize,
        maxComputeWorkgroupStorageSize: limits.maxComputeWorkgroupStorageSize,
        // The ssu shader binds 9 storage buffers in one group.
        maxStorageBuffersPerShaderStage: Math.min(limits.maxStorageBuffersPerShaderStage, 16),
      },
    });

    this.device.lost.then((info) => {
      console.error('[mamba] DEVICE LOST:', info.reason, info.message);
    });
    this.device.addEventListener('uncapturederror', (e) => {
      console.error('[mamba] GPU ERROR:', e.error.message);
    });

    console.log('[mamba] device ready, compiling shaders...');
    await this._compileShaders();
    console.log('[mamba] shaders compiled');
    return this;
  }

  // ── Compile all WGSL shaders into compute pipelines ────────────────────
  /**
   * Fetch ./shaders/<name>.wgsl for every shader and build one compute
   * pipeline each (layout 'auto', entry point 'main'). Fetches and pipeline
   * compilation run concurrently.
   * @throws {Error} if any shader file cannot be fetched.
   */
  async _compileShaders() {
    const shaderNames = [
      'conv1d_step', 'ssu', 'matmul_gemv', 'rmsnorm', 'rmsnorm_noweight',
      'silu', 'softplus', 'embedding', 'elementwise_mul', 'sample',
      'bf16_to_f32', 'add_residual',
    ];

    const sources = await Promise.all(shaderNames.map(async (name) => {
      const resp = await fetch(`./shaders/${name}.wgsl`);
      if (!resp.ok) throw new Error(`Failed to load shader: ${name}.wgsl`);
      return resp.text();
    }));

    // Group 0 = storage buffers (data), Group 1 = uniforms (params);
    // the layouts themselves are derived from the shader ('auto').
    const pipelines = await Promise.all(shaderNames.map((name, i) => {
      const shaderModule = this.device.createShaderModule({ code: sources[i], label: name });
      return this.device.createComputePipelineAsync({
        layout: 'auto',
        compute: { module: shaderModule, entryPoint: 'main' },
        label: name,
      });
    }));

    shaderNames.forEach((name, i) => {
      this.pipelines[name] = pipelines[i];
    });
  }

  // ── Create a GPU buffer ────────────────────────────────────────────────
  /**
   * @param {number} size  Requested size in bytes; rounded up to a non-zero
   *                       multiple of 4 so copies/writes/bindings stay valid.
   * @param {number} usage GPUBufferUsage flags.
   * @param {string} label Debug label.
   * @returns {GPUBuffer}
   */
  _createBuffer(size, usage, label) {
    return this.device.createBuffer({
      size: Math.max(4, Math.ceil(size / 4) * 4),
      usage,
      label,
      mappedAtCreation: false,
    });
  }

  // ── Upload data to a GPU buffer ────────────────────────────────────────
  _upload(buffer, data) {
    this.device.queue.writeBuffer(buffer, 0, data);
  }

  // ── Read data back from GPU buffer ─────────────────────────────────────
  /**
   * Copy `size` bytes of `buffer` into a staging buffer and return them as a
   * Float32Array (callers may reinterpret the underlying ArrayBuffer).
   * @param {GPUBuffer} buffer Source; must have COPY_SRC usage.
   * @param {number} size Byte count (multiple of 4).
   * @returns {Promise<Float32Array>}
   */
  async _readback(buffer, size) {
    const staging = this.device.createBuffer({
      size,
      usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
    });
    try {
      const encoder = this.device.createCommandEncoder();
      encoder.copyBufferToBuffer(buffer, 0, staging, 0, size);
      this.device.queue.submit([encoder.finish()]);

      await staging.mapAsync(GPUMapMode.READ);
      // slice(0) detaches the result from the mapping before unmap().
      const result = new Float32Array(staging.getMappedRange().slice(0));
      staging.unmap();
      return result;
    } finally {
      // Release the staging buffer even when mapAsync rejects.
      staging.destroy();
    }
  }

  // ── Dispatch a compute shader ──────────────────────────────────────────
  _dispatch(shaderName, bindGroup, uniformBindGroup, workgroupsX, workgroupsY = 1, workgroupsZ = 1) {
    const encoder = this.device.createCommandEncoder();
    const pass = encoder.beginComputePass();
    pass.setPipeline(this.pipelines[shaderName]);
    pass.setBindGroup(0, bindGroup);
    if (uniformBindGroup) pass.setBindGroup(1, uniformBindGroup);
    pass.dispatchWorkgroups(workgroupsX, workgroupsY, workgroupsZ);
    pass.end();
    this.device.queue.submit([encoder.finish()]);
  }

  // ── Fetch a byte range of a file (with no-Range fallback) ──────────────
  /**
   * Fetch bytes [start, endInclusive] of `url`. If the server ignores Range
   * requests, the full body is downloaded once, remembered in `cache.full`,
   * and sliced for this and every later call with the same cache.
   * @param {string} url
   * @param {number} start
   * @param {number} endInclusive
   * @param {{full: ArrayBuffer|null}} cache Per-file cache object.
   * @returns {Promise<ArrayBuffer>}
   * @throws {Error} on a non-success HTTP status.
   */
  async _fetchRange(url, start, endInclusive, cache) {
    if (cache.full) return cache.full.slice(start, endInclusive + 1);

    const resp = await fetch(url, { headers: { 'Range': `bytes=${start}-${endInclusive}` } });
    if (resp.status === 206) return resp.arrayBuffer();
    if (!resp.ok) throw new Error(`Failed to fetch ${url}: HTTP ${resp.status}`);

    // Range not supported — the whole file came back (expensive); keep it so
    // the remaining reads of this file do not download it again.
    console.log(`[mamba] WARN: no range support, loading full file ${url}`);
    cache.full = await resp.arrayBuffer();
    return cache.full.slice(start, endInclusive + 1);
  }

  // ── BF16 → F32 on the CPU ──────────────────────────────────────────────
  /**
   * Widen a buffer of bfloat16 values to float32. BF16 is the top 16 bits of
   * the F32 bit pattern, so each value is shifted into the high half.
   * @param {ArrayBuffer} tensorBuf Raw BF16 bytes.
   * @returns {Float32Array}
   */
  _bf16ToF32(tensorBuf) {
    const bf16 = new Uint16Array(tensorBuf, 0, tensorBuf.byteLength >> 1);
    const f32 = new Float32Array(bf16.length);
    const bits = new Uint32Array(f32.buffer);
    for (let j = 0; j < bf16.length; j++) {
      bits[j] = bf16[j] << 16;
    }
    return f32;
  }

  // ── Load safetensors weights into GPU buffers ──────────────────────────
  /**
   * Load every tensor of a (possibly sharded) safetensors checkpoint into GPU
   * storage buffers, converting BF16 to F32 on the CPU, then allocate the
   * per-layer recurrent state and mark the runtime ready.
   * @param {string} basePath Directory URL holding model.safetensors[.index.json];
   *                          a trailing slash is accepted.
   * @throws {Error} on HTTP failures or a malformed safetensors header.
   */
  async loadWeights(basePath) {
    console.log('[mamba] loading weights from', basePath);
    const base = String(basePath).replace(/\/+$/, '');

    // Get the shard index
    const indexResp = await fetch(`${base}/model.safetensors.index.json`);
    let files;
    if (indexResp.ok) {
      const index = await indexResp.json();
      const fileMap = index.weight_map; // tensor_name → filename
      files = [...new Set(Object.values(fileMap))];
      console.log(`[mamba] multi-shard: ${files.length} files, ${Object.keys(fileMap).length} tensors`);
    } else {
      files = ['model.safetensors'];
    }

    const bufferUsage = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST;
    const maxBufferSize = this.device.limits.maxBufferSize;

    // For each shard, fetch ONLY the header first (small), then load tensors by byte-range
    for (const file of files) {
      const url = `${base}/${file}`;
      const cache = { full: null };
      console.log(`[mamba] parsing ${file} header...`);

      // safetensors layout: u64 LE header length, JSON header, tensor data.
      const head = await this._fetchRange(url, 0, 7, cache);
      if (head.byteLength < 8) throw new Error(`Truncated safetensors file: ${file}`);
      const headerLen = Number(new DataView(head).getBigUint64(0, true));
      if (!Number.isSafeInteger(headerLen) || headerLen === 0) {
        throw new Error(`Invalid safetensors header length in ${file}: ${headerLen}`);
      }
      console.log(`[mamba] header: ${headerLen} bytes`);

      const headerBytes = await this._fetchRange(url, 8, 8 + headerLen - 1, cache);
      const header = JSON.parse(new TextDecoder().decode(headerBytes));
      const dataOffset = 8 + headerLen;

      // Load each tensor individually
      const tensorNames = Object.keys(header).filter((n) => n !== '__metadata__');
      console.log(`[mamba] ${tensorNames.length} tensors in this shard`);

      let loaded = 0;
      for (const name of tensorNames) {
        const { dtype, shape, data_offsets: [start, end] } = header[name];
        const byteLen = end - start;
        // BF16 doubles in size once widened to F32 for the shaders.
        const gpuBytes = dtype === 'BF16' ? byteLen * 2 : byteLen;

        if (byteLen > 2_000_000_000 || gpuBytes > maxBufferSize) {
          console.log(`[mamba] SKIP ${name} (${(byteLen / 1e9).toFixed(2)} GB — exceeds buffer limit)`);
          continue;
        }

        // Fetch this tensor's bytes via Range request (empty tensors have no range).
        const tensorBuf = byteLen > 0
          ? await this._fetchRange(url, dataOffset + start, dataOffset + end - 1, cache)
          : new ArrayBuffer(0);

        let payload;
        let finalDtype = dtype;
        let finalByteLen = byteLen;
        if (dtype === 'BF16') {
          // Convert CPU-side before uploading (no double-buffering on the GPU).
          payload = this._bf16ToF32(tensorBuf);
          finalByteLen = payload.byteLength;
          finalDtype = 'F32';
        } else {
          if (dtype !== 'F32') {
            console.log(`[mamba] WARN: ${name} has dtype ${dtype}; uploaded as raw bytes`);
          }
          payload = new Uint8Array(tensorBuf);
          if (payload.byteLength % 4 !== 0) {
            // queue.writeBuffer needs a byte size that is a multiple of 4.
            const padded = new Uint8Array(Math.ceil(payload.byteLength / 4) * 4);
            padded.set(payload);
            payload = padded;
          }
        }

        // Reloading a tensor must not leak the buffer it replaces.
        this.weights[name]?.buffer.destroy();

        const gpuBuf = this._createBuffer(finalByteLen, bufferUsage, name);
        if (payload.byteLength > 0) this._upload(gpuBuf, payload);
        this.weights[name] = { buffer: gpuBuf, shape, dtype: finalDtype, byteLen: finalByteLen };

        loaded++;
        if (loaded % 20 === 0) {
          console.log(`[mamba] loaded ${loaded}/${tensorNames.length} tensors`);
        }
      }
      console.log(`[mamba] shard done: ${loaded} tensors loaded`);
    }

    console.log(`[mamba] TOTAL: ${Object.keys(this.weights).length} tensors loaded`);

    // Allocate per-layer state buffers
    this._allocateState();
    this.ready = true;
  }

  // ── Allocate persistent SSM state + conv1d cache per layer ─────────────
  /**
   * (Re)create the per-layer recurrent buffers. Any previously allocated
   * state buffers are destroyed first so a reload does not leak them.
   */
  _allocateState() {
    const H = CONFIG.intermediate_size; // 8192
    const S = CONFIG.state_size; // 16
    const K = CONFIG.conv_kernel; // 4
    const usage = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST;

    for (const old of Object.values(this.state)) old.destroy();
    this.state = {};

    for (let l = 0; l < CONFIG.num_layers; l++) {
      // SSM state: [H, S] = 8192 * 16 = 131072 floats = 512 KB per layer
      this.state[`layer.${l}.ssm`] = this._createBuffer(H * S * 4, usage, `ssm_state_${l}`);
      // Conv1d cache: [H, K-1] = 8192 * 3 = 24576 floats = 96 KB per layer
      this.state[`layer.${l}.conv`] = this._createBuffer(H * (K - 1) * 4, usage, `conv_state_${l}`);
    }
    // Total state: 64 layers × (512 + 96) KB = ~38 MB
    console.log(`[mamba] allocated ${CONFIG.num_layers} layers of SSM + conv1d state (~38 MB)`);
  }

  // ── Save/restore SSM state (the entity's persistent soul) ──────────────
  /**
   * Read every state buffer back to the CPU.
   * @returns {Promise<Object<string, Float32Array>>} state key → contents.
   */
  async saveState() {
    const state = {};
    for (const [key, buf] of Object.entries(this.state)) {
      state[key] = await this._readback(buf, buf.size);
    }
    return state;
  }

  /**
   * Upload a snapshot produced by saveState(). Unknown keys are ignored;
   * entries whose byte size does not match the target buffer are skipped
   * with a warning rather than partially overwriting the state.
   * @param {Object<string, ArrayBufferView>} state
   */
  async restoreState(state) {
    for (const [key, data] of Object.entries(state)) {
      const target = this.state[key];
      if (!target) continue;
      if (data.byteLength !== target.size) {
        console.warn(`[mamba] restoreState: size mismatch for ${key}, skipped`);
        continue;
      }
      this._upload(target, data);
    }
  }

  // ── Allocate intermediate scratch buffers for forward pass ─────────────
  _allocateScratch() {
    if (this.scratch) return; // already allocated

    const H = CONFIG.hidden_size; // 4096
    const I = CONFIG.intermediate_size; // 8192
    const DR = CONFIG.dt_rank; // 256
    const S = CONFIG.state_size; // 16
    const F = 4; // sizeof(float32)
    const usage = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST;

    // name → element count (4 bytes each)
    const layout = {
      norm_out: H,
      projected: 2 * I,
      hidden: I,
      gate: I,
      hidden_c: I,
      sxBC: DR + 2 * S,
      B_proj: S,
      C_proj: S,
      dt_pre: DR,
      dt: I,
      hidden_y: I,
      gate_silu: I,
      out_proj_o: H,
      logits: CONFIG.vocab_size,
      token_out: 1,
      hidden_state: H,
      token_id: 1,
    };

    const scratch = {};
    for (const [name, count] of Object.entries(layout)) {
      scratch[name] = this._createBuffer(count * F, usage, name);
    }
    // Sampler params carry a fresh RNG seed per token, so they live in one
    // reusable uniform buffer instead of a new buffer per token.
    scratch.sample_params = this._createBuffer(
      12, GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST, 'sample_params');

    this.scratch = scratch;
    this._tokenCount = 0;
    console.log('[mamba] scratch buffers allocated');
  }

  // ── Pack shader params into little-endian bytes ────────────────────────
  /**
   * @param {Array<['u32'|'f32', number]>} fields 4-byte fields, in order.
   * @returns {Uint8Array}
   */
  _packParams(fields) {
    const bytes = new ArrayBuffer(fields.length * 4);
    const view = new DataView(bytes);
    fields.forEach(([type, value], i) => {
      if (type === 'f32') view.setFloat32(i * 4, value, true);
      else view.setUint32(i * 4, value, true);
    });
    return new Uint8Array(bytes);
  }

  // ── Record one compute pass ────────────────────────────────────────────
  /**
   * Record a complete compute pass: storage buffers at group 0 (binding i =
   * storageBuffers[i]), one uniform buffer at group 1 binding 0.
   */
  _encodePass(encoder, shaderName, storageBuffers, uniformBuffer, workgroups) {
    const pipeline = this.pipelines[shaderName];
    const pass = encoder.beginComputePass();
    pass.setPipeline(pipeline);
    pass.setBindGroup(0, this.device.createBindGroup({
      layout: pipeline.getBindGroupLayout(0),
      entries: storageBuffers.map((buffer, binding) => ({ binding, resource: { buffer } })),
    }));
    pass.setBindGroup(1, this.device.createBindGroup({
      layout: pipeline.getBindGroupLayout(1),
      entries: [{ binding: 0, resource: { buffer: uniformBuffer } }],
    }));
    pass.dispatchWorkgroups(workgroups);
    pass.end();
  }

  // ── Single-token forward pass through all 64 layers ────────────────────
  /**
   * Run one token through the network, updating the per-layer recurrent
   * state, and sample the next token on the GPU.
   * @param {number} tokenId Input token id in [0, vocab_size).
   * @param {number} [temperature=0.75] Sampling temperature; non-finite or
   *        non-positive values fall back to the default.
   * @returns {Promise<number>} The sampled next-token id.
   * @throws {RangeError} if tokenId is outside the vocabulary.
   * @throws {Error} if a required weight is missing.
   */
  async _forwardOneToken(tokenId, temperature = DEFAULT_TEMPERATURE) {
    const H = CONFIG.hidden_size; // 4096
    const I = CONFIG.intermediate_size; // 8192
    const DR = CONFIG.dt_rank; // 256
    const S = CONFIG.state_size; // 16
    const V = CONFIG.vocab_size; // 65024
    const F = 4; // sizeof(float32)
    const s = this.scratch;

    if (!Number.isInteger(tokenId) || tokenId < 0 || tokenId >= V) {
      throw new RangeError(`Token id out of range: ${tokenId}`);
    }
    const temp = Number.isFinite(temperature) && temperature > 0
      ? temperature
      : DEFAULT_TEMPERATURE;

    const u32 = (v) => ['u32', v];
    const f32 = (v) => ['f32', v];
    // Params are identical for every layer and token, so the uniform
    // buffers are cached by content instead of being recreated (and leaked).
    const uniform = (...fields) => this._cachedUniform(this._packParams(fields));
    const params = {
      normH: uniform(u32(1), u32(H), f32(CONFIG.rms_eps)), // rows, cols, eps
      gemvInProj: uniform(u32(I * 2), u32(H)), // N, K
      gemvXProj: uniform(u32(DR + 2 * S), u32(I)),
      gemvDtProj: uniform(u32(I), u32(DR)),
      gemvOutProj: uniform(u32(H), u32(I)),
      gemvLmHead: uniform(u32(V), u32(H)),
      countI: uniform(u32(I)),
      countH: uniform(u32(H)),
      normDt: uniform(u32(DR), f32(CONFIG.rms_eps)),
      normBC: uniform(u32(S), f32(CONFIG.rms_eps)),
      ssu: uniform(u32(I), u32(S)),
    };
    const perElementI = Math.ceil(I / 64);

    // Step 1: Embedding lookup — copy one row from embedding table to hidden_state
    this._upload(s.token_id, new Uint32Array([tokenId]));
    const embBuf = this._weightBuffer('backbone.embeddings.weight');
    const embedEncoder = this.device.createCommandEncoder();
    embedEncoder.copyBufferToBuffer(embBuf, tokenId * H * F, s.hidden_state, 0, H * F);
    this.device.queue.submit([embedEncoder.finish()]);

    // Step 2: For each layer (0..63)
    for (let l = 0; l < CONFIG.num_layers; l++) {
      const prefix = `backbone.layers.${l}`;
      const weight = (suffix) => this._weightBuffer(`${prefix}.${suffix}`);
      const convState = this.state[`layer.${l}.conv`];
      const ssmState = this.state[`layer.${l}.ssm`];
      const encoder = this.device.createCommandEncoder();

      // rmsnorm(hidden_state, norm.weight) → norm_out (one workgroup per row, 1 row)
      this._encodePass(encoder, 'rmsnorm',
        [s.hidden_state, weight('norm.weight'), s.norm_out], params.normH, 1);

      // matmul_gemv(norm_out, in_proj.weight) → projected [I*2 = 16384]
      // (one workgroup per output element)
      this._encodePass(encoder, 'matmul_gemv',
        [s.norm_out, weight('mixer.in_proj.weight'), s.projected], params.gemvInProj, I * 2);

      // Split projected → hidden[0:I], gate[I:2I] via buffer copies
      encoder.copyBufferToBuffer(s.projected, 0, s.hidden, 0, I * F);
      encoder.copyBufferToBuffer(s.projected, I * F, s.gate, 0, I * F);

      // conv1d_step(conv_state, hidden, conv1d.weight, conv1d.bias) → hidden_c
      this._encodePass(encoder, 'conv1d_step',
        [convState, s.hidden, weight('mixer.conv1d.weight'), weight('mixer.conv1d.bias'), s.hidden_c],
        params.countI, perElementI);

      // silu(hidden_c) in-place
      this._encodePass(encoder, 'silu', [s.hidden_c], params.countI, perElementI);

      // matmul_gemv(hidden_c, x_proj.weight) → sxBC [DR+2*S = 288]
      this._encodePass(encoder, 'matmul_gemv',
        [s.hidden_c, weight('mixer.x_proj.weight'), s.sxBC], params.gemvXProj, DR + 2 * S);

      // Copy dt_pre, B, C from sxBC into separate buffers
      encoder.copyBufferToBuffer(s.sxBC, 0, s.dt_pre, 0, DR * F);
      encoder.copyBufferToBuffer(s.sxBC, DR * F, s.B_proj, 0, S * F);
      encoder.copyBufferToBuffer(s.sxBC, (DR + S) * F, s.C_proj, 0, S * F);

      // Falcon-Mamba: RMSNorm(dt_pre), RMSNorm(B), RMSNorm(C) before use
      this._encodePass(encoder, 'rmsnorm_noweight', [s.dt_pre], params.normDt, 1);
      this._encodePass(encoder, 'rmsnorm_noweight', [s.B_proj], params.normBC, 1);
      this._encodePass(encoder, 'rmsnorm_noweight', [s.C_proj], params.normBC, 1);

      // matmul_gemv(dt_pre_normalized, dt_proj.weight) → dt [I]
      this._encodePass(encoder, 'matmul_gemv',
        [s.dt_pre, weight('mixer.dt_proj.weight'), s.dt], params.gemvDtProj, I);

      // SSU: selective_state_update
      // ssu(state, hidden_c, dt, A, B, C, D, dt_bias) → hidden_y (one workgroup per h)
      this._encodePass(encoder, 'ssu', [
        ssmState,
        s.hidden_c, // x (silu'd)
        s.dt,
        weight('mixer.A_log'), // A (needs -exp transform)
        s.B_proj,
        s.C_proj,
        weight('mixer.D'),
        weight('mixer.dt_proj.bias'),
        s.hidden_y,
      ], params.ssu, I);

      // silu(gate) in-place
      this._encodePass(encoder, 'silu', [s.gate], params.countI, perElementI);

      // elementwise_mul: hidden_y *= gate (in-place into hidden_y)
      this._encodePass(encoder, 'elementwise_mul', [s.hidden_y, s.gate], params.countI, perElementI);

      // matmul_gemv(hidden_y, out_proj.weight) → out_proj_o [H]
      this._encodePass(encoder, 'matmul_gemv',
        [s.hidden_y, weight('mixer.out_proj.weight'), s.out_proj_o], params.gemvOutProj, H);

      // Residual add: hidden_state += out_proj_o
      this._encodePass(encoder, 'add_residual',
        [s.hidden_state, s.out_proj_o], params.countH, Math.ceil(H / 64));

      // Submit this layer's command buffer
      this.device.queue.submit([encoder.finish()]);
    }

    // Final: rmsnorm + lm_head + sample. Queue submissions execute in order,
    // so no explicit wait is needed between the layers and this tail.
    const encoder = this.device.createCommandEncoder();

    // rmsnorm(hidden_state, backbone.norm_f.weight) → norm_out
    this._encodePass(encoder, 'rmsnorm',
      [s.hidden_state, this._weightBuffer('backbone.norm_f.weight'), s.norm_out], params.normH, 1);

    // matmul_gemv(norm_out, lm_head.weight) → logits [V]
    this._encodePass(encoder, 'matmul_gemv',
      [s.norm_out, this._weightBuffer('lm_head.weight'), s.logits], params.gemvLmHead, V);

    // sample(logits, temperature) → token_out
    this._upload(s.sample_params, this._packParams([
      u32(V),
      f32(1.0 / temp), // inv_temperature
      u32(Math.floor(Math.random() * 0xFFFFFFFF)), // rng_seed
    ]));
    this._encodePass(encoder, 'sample', [s.logits, s.token_out], s.sample_params, 1);

    this.device.queue.submit([encoder.finish()]);

    // Read back the sampled token (mapping waits for the submitted work).
    const tokenResult = await this._readback(s.token_out, 4);
    this._tokenCount++;
    return new Uint32Array(tokenResult.buffer)[0];
  }

  // ── Tokenize/detokenize via server ─────────────────────────────────────
  /**
   * POST a JSON payload and return the `result` field of the JSON reply.
   * @throws {Error} on a non-success HTTP status.
   */
  async _postJson(url, payload) {
    const resp = await fetch(url, {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify(payload),
    });
    if (!resp.ok) throw new Error(`${url} failed: HTTP ${resp.status}`);
    const data = await resp.json();
    return data.result;
  }

  async tokenize(text, system = '') {
    return this._postJson('/tokenize', { text, system });
  }

  async detokenize(tokens) {
    return this._postJson('/detokenize', { tokens });
  }

  // ── Generate text ──────────────────────────────────────────────────────
  /**
   * Encode the prompt through the model, then sample up to `maxTokens`
   * tokens, stopping early at EOS / im_end / PAD.
   * @param {string} prompt
   * @param {number} [maxTokens=100] Maximum tokens to emit; values below 1
   *        encode the prompt only and return ''.
   * @param {number} [temperature=0.75] Sampling temperature (> 0).
   * @param {?function(number, number)} [onToken] Called with (token, step).
   * @param {string} [system=''] Optional system prompt for the tokenizer.
   * @returns {Promise<string>} The detokenized generated text.
   * @throws {Error} if weights are not loaded or the prompt yields no tokens.
   */
  async generate(prompt, maxTokens = 100, temperature = DEFAULT_TEMPERATURE, onToken = null, system = '') {
    if (!this.ready) throw new Error('Call loadWeights() first');

    this._allocateScratch();
    console.log('[mamba] generate:', prompt, 'max_tokens:', maxTokens);

    let temp = temperature;
    if (!(Number.isFinite(temp) && temp > 0)) {
      console.warn(`[mamba] invalid temperature ${temperature}, using ${DEFAULT_TEMPERATURE}`);
      temp = DEFAULT_TEMPERATURE;
    }

    // Tokenize the prompt (with optional system prompt)
    const promptTokens = await this.tokenize(prompt, system);
    if (!Array.isArray(promptTokens) || promptTokens.length === 0) {
      // Without at least one forward pass there is no prediction to start from.
      throw new Error('Tokenizer returned no prompt tokens');
    }
    console.log(`[mamba] prompt tokens (${promptTokens.length}):`, promptTokens);

    // Process prompt tokens through forward pass to build SSM state
    console.log('[mamba] encoding prompt...');
    let predicted;
    for (let i = 0; i < promptTokens.length; i++) {
      const t0 = performance.now();
      predicted = await this._forwardOneToken(promptTokens[i], temp);
      const elapsed = performance.now() - t0;
      if (i === 0 || i === promptTokens.length - 1) {
        console.log(`[mamba] prompt token ${i}/${promptTokens.length}: ${promptTokens[i]} (${elapsed.toFixed(0)}ms)`);
      }
    }
    console.log('[mamba] prompt encoded, generating...');

    if (maxTokens < 1) return '';

    // The last prompt step already produced the first next-token prediction.
    const generated = [];
    let inputToken = predicted;
    generated.push(inputToken);
    console.log(`[mamba] first generated token: ${inputToken}`);
    if (onToken) onToken(inputToken, 0);

    for (let step = 1; step < maxTokens && !STOP_TOKENS.has(inputToken); step++) {
      const t0 = performance.now();
      try {
        const nextToken = await this._forwardOneToken(inputToken, temp);
        const elapsed = performance.now() - t0;
        if (step < 5 || step % 20 === 0) {
          console.log(`[mamba] step ${step}: token=${nextToken} (${elapsed.toFixed(0)}ms)`);
        }
        generated.push(nextToken);
        inputToken = nextToken;
        if (onToken) onToken(nextToken, step);
      } catch (e) {
        // Best effort: keep whatever was generated before the failure.
        console.error(`[mamba] step ${step} failed:`, e.message);
        break;
      }
    }

    // Decode the generated tokens
    const text = await this.detokenize(generated);
    console.log(`[mamba] generated ${generated.length} tokens`);
    return text;
  }

  // ── Helper: get weight record by name ──────────────────────────────────
  _getWeight(name) {
    return this.weights[name] || null;
  }

  // ── View any params value as bytes ─────────────────────────────────────
  /**
   * @param {ArrayBuffer|ArrayBufferView} data
   * @returns {Uint8Array} Exactly the bytes `data` covers (honours view offsets).
   */
  _asBytes(data) {
    return ArrayBuffer.isView(data)
      ? new Uint8Array(data.buffer, data.byteOffset, data.byteLength)
      : new Uint8Array(data);
  }

  // ── Create a uniform buffer with typed data ────────────────────────────
  /**
   * Create a new uniform buffer initialised with `data`. The caller owns the
   * buffer and is responsible for destroying it.
   * @param {ArrayBuffer|ArrayBufferView} data
   * @returns {GPUBuffer}
   */
  _createUniform(data) {
    const bytes = this._asBytes(data);
    const buf = this.device.createBuffer({
      // mappedAtCreation requires a size that is a multiple of 4.
      size: Math.max(4, Math.ceil(bytes.byteLength / 4) * 4),
      usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
      mappedAtCreation: true,
    });
    new Uint8Array(buf.getMappedRange()).set(bytes);
    buf.unmap();
    return buf;
  }

  // ── Get (or create) a shared uniform buffer for constant params ────────
  /**
   * Return a uniform buffer holding `data`, reusing one created earlier for
   * identical bytes. Cached buffers are owned by the runtime and released in
   * destroy(); callers must not destroy them.
   */
  _cachedUniform(data) {
    const bytes = this._asBytes(data);
    const key = bytes.join(',');
    let buf = this._uniformCache.get(key);
    if (!buf) {
      buf = this._createUniform(bytes);
      this._uniformCache.set(key, buf);
    }
    return buf;
  }

  // ── BF16 → F32 conversion for a weight tensor ──────────────────────────
  /**
   * GPU-side BF16 → F32 conversion using the bf16_to_f32 shader; the result
   * is cached on `weightInfo.f32buffer`. loadWeights() already converts on
   * the CPU, so this only applies to records that still have dtype 'BF16'.
   * @throws {RangeError} if the tensor needs more workgroups than one
   *         dispatch dimension allows.
   */
  async _convertBF16toF32(weightInfo) {
    if (weightInfo.dtype !== 'BF16' || weightInfo.f32buffer) return weightInfo;

    const numBF16 = weightInfo.byteLen / 2; // each bf16 is 2 bytes
    const numPairs = weightInfo.byteLen / 4; // each u32 holds 2 bf16
    const f32Bytes = numBF16 * 4;

    const workgroups = Math.ceil(numPairs / 64);
    const maxWorkgroups = this.device.limits.maxComputeWorkgroupsPerDimension;
    if (workgroups > maxWorkgroups) {
      throw new RangeError(
        `bf16_to_f32 needs ${workgroups} workgroups; device allows ${maxWorkgroups} per dimension`);
    }

    // Create output F32 buffer
    const f32Buf = this._createBuffer(
      f32Bytes,
      GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
      weightInfo.buffer.label + '_f32'
    );

    // Create uniform for params
    const paramBuf = this._createUniform(new Uint32Array([numPairs]));

    const encoder = this.device.createCommandEncoder();
    this._encodePass(encoder, 'bf16_to_f32', [weightInfo.buffer, f32Buf], paramBuf, workgroups);
    this.device.queue.submit([encoder.finish()]);
    await this.device.queue.onSubmittedWorkDone();

    // Cache the F32 buffer
    weightInfo.f32buffer = f32Buf;
    weightInfo.f32size = f32Bytes;

    paramBuf.destroy();
    return weightInfo;
  }

  // ── Look up a weight's GPU buffer ──────────────────────────────────────
  /** @throws {Error} if no tensor with this name was loaded. */
  _weightBuffer(name) {
    const w = this.weights[name];
    if (!w) throw new Error(`Missing weight: ${name}`);
    return w.buffer;
  }

  // ── Get F32 weight buffer (already converted during load) ──────────────
  async _getF32Weight(name) {
    return this._weightBuffer(name);
  }

  // ── Dispatch a shader with auto bind group creation ────────────────────
  /**
   * Begin a compute pass with bind groups set up; the caller sets the
   * dispatch count and calls pass.end().
   * @returns {GPUComputePassEncoder}
   */
  _dispatchShader(encoder, shaderName, storageBuffers, uniformData) {
    const pipeline = this.pipelines[shaderName];
    const pass = encoder.beginComputePass();
    pass.setPipeline(pipeline);

    // Bind group 0: storage buffers
    pass.setBindGroup(0, this.device.createBindGroup({
      layout: pipeline.getBindGroupLayout(0),
      entries: storageBuffers.map((buf, i) => ({ binding: i, resource: { buffer: buf } })),
    }));

    // Bind group 1: uniforms (if provided); shared via the cache so the
    // buffer is released in destroy() instead of leaking.
    if (uniformData) {
      pass.setBindGroup(1, this.device.createBindGroup({
        layout: pipeline.getBindGroupLayout(1),
        entries: [{ binding: 0, resource: { buffer: this._cachedUniform(uniformData) } }],
      }));
    }

    return pass; // caller sets dispatch count and calls pass.end()
  }

  // ── Cleanup ────────────────────────────────────────────────────────────
  /**
   * Release every GPU buffer owned by the runtime (weights, state, scratch,
   * cached uniforms). The device itself is kept, so loadWeights() can be
   * called again afterwards.
   */
  destroy() {
    for (const w of Object.values(this.weights)) {
      w.buffer.destroy();
      w.f32buffer?.destroy();
    }
    for (const s of Object.values(this.state)) s.destroy();
    for (const b of Object.values(this.scratch ?? {})) b.destroy();
    for (const u of this._uniformCache.values()) u.destroy();

    this.weights = {};
    this.state = {};
    this.scratch = null;
    this._uniformCache.clear();
    this.ready = false;
  }
}

// ES module export for browser