| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
/**
 * Model hyper-parameters shared by every MambaRuntime method.
 *
 * The object is exported and intentionally left mutable. Sizes are element
 * counts; the runtime multiplies by 4 wherever it needs a byte size.
 */
const CONFIG = {
  // Width of the residual stream: embedding rows, norm_out, out_proj output.
  hidden_size: 4096,
  // Width of the mixer's inner activations (in_proj emits 2x this: hidden + gate).
  intermediate_size: 8192,
  // Number of `backbone.layers.N` blocks walked per token.
  num_layers: 64,
  // Row count of the lm_head projection / length of the logits buffer.
  vocab_size: 65024,
  // Per-channel SSM state width; also the length of the B and C projections.
  state_size: 16,
  // Conv kernel size; the rolling conv state keeps (conv_kernel - 1) columns.
  conv_kernel: 4,
  // Length of the dt slice taken from the x_proj output before dt_proj.
  dt_rank: 256,
  // Epsilon written into the parameter block of both rmsnorm shaders.
  rms_eps: 1e-5,
};
|
|
/**
 * WebGPU inference runtime for a Mamba-style state-space language model.
 *
 * Lifecycle:
 *   1. `await init()`            - acquire the GPU device and compile the WGSL pipelines.
 *   2. `await loadWeights(path)` - stream safetensors shards into GPU buffers and
 *                                  allocate the per-layer recurrent state.
 *   3. `await generate(prompt)`  - tokenise, run the model one token at a time, detokenise.
 *
 * Instances are not re-entrant: `generate()` mutates the shared scratch and
 * state buffers, so run one generation at a time per instance.
 */
class MambaRuntime {
  constructor() {
    this.device = null;
    this.pipelines = {};
    this.bindLayouts = {};
    this.weights = {};
    this.state = {};
    this.scratch = null;
    this.ready = false;
    this._tokenCount = 0;
    // Uniform buffers keyed by their byte content. Almost every parameter block
    // is constant for a given CONFIG, so it is created once and reused.
    this._uniformCache = new Map();
  }

  /**
   * Acquire a WebGPU device with the adapter's maximum buffer limits and
   * compile all compute pipelines.
   *
   * @returns {Promise<MambaRuntime>} this instance, for chaining.
   * @throws {Error} when WebGPU or an adapter is unavailable, or a shader fails to load.
   */
  async init() {
    if (!navigator.gpu) throw new Error('WebGPU not supported in this browser');
    const adapter = await navigator.gpu.requestAdapter();
    if (!adapter) throw new Error('No WebGPU adapter found');

    const { limits } = adapter;
    console.log('[mamba] maxBufferSize:', limits.maxBufferSize,
      '=', (limits.maxBufferSize / 1024 / 1024 / 1024).toFixed(2), 'GB');

    this.device = await adapter.requestDevice({
      requiredLimits: {
        maxBufferSize: limits.maxBufferSize,
        maxStorageBufferBindingSize: limits.maxStorageBufferBindingSize,
        maxComputeWorkgroupStorageSize: limits.maxComputeWorkgroupStorageSize,
        // The `ssu` pass binds 9 storage buffers, above the WebGPU default of 8.
        maxStorageBuffersPerShaderStage: Math.min(limits.maxStorageBuffersPerShaderStage, 16),
      },
    });

    this.device.lost.then((info) => {
      console.error('[mamba] DEVICE LOST:', info.reason, info.message);
      // Every buffer and pipeline is invalid once the device is gone.
      this.ready = false;
    });
    this.device.addEventListener('uncapturederror', (e) => {
      console.error('[mamba] GPU ERROR:', e.error.message);
    });

    console.log('[mamba] device ready, compiling shaders...');
    await this._compileShaders();
    console.log('[mamba] shaders compiled');
    return this;
  }

  /**
   * Fetch every WGSL source from `./shaders/` and build one compute pipeline
   * per shader (entry point `main`, auto layout) into `this.pipelines`.
   *
   * @throws {Error} when a shader file cannot be fetched.
   */
  async _compileShaders() {
    const shaderNames = [
      'conv1d_step', 'ssu', 'matmul_gemv', 'rmsnorm', 'rmsnorm_noweight',
      'silu', 'softplus', 'embedding', 'elementwise_mul', 'sample',
      'bf16_to_f32', 'add_residual',
    ];

    // The downloads are independent, so fetch them concurrently.
    const sources = await Promise.all(shaderNames.map(async (name) => {
      const resp = await fetch(`./shaders/${name}.wgsl`);
      if (!resp.ok) throw new Error(`Failed to load shader: ${name}.wgsl`);
      return resp.text();
    }));

    await Promise.all(shaderNames.map(async (name, i) => {
      const shaderModule = this.device.createShaderModule({ code: sources[i], label: name });
      // The async variant reports compilation failures as a rejection here
      // instead of surfacing later as an uncaptured validation error.
      this.pipelines[name] = await this.device.createComputePipelineAsync({
        layout: 'auto',
        compute: { module: shaderModule, entryPoint: 'main' },
        label: name,
      });
    }));
  }

  /**
   * Create an unmapped GPU buffer.
   *
   * @param {number} size  requested size in bytes; rounded up to a multiple of 4
   *                       (minimum 4) so padded uploads always fit.
   * @param {number} usage GPUBufferUsage flags.
   * @param {string} label debug label.
   * @returns {GPUBuffer}
   */
  _createBuffer(size, usage, label) {
    return this.device.createBuffer({
      size: Math.max(4, Math.ceil(size / 4) * 4),
      usage,
      label,
      mappedAtCreation: false,
    });
  }

  /**
   * Copy CPU data to offset 0 of a GPU buffer.
   *
   * `GPUQueue.writeBuffer` only accepts byte sizes that are a multiple of 4,
   * so payloads with a ragged tail are zero-padded first.
   *
   * @param {GPUBuffer} buffer destination (must have COPY_DST usage).
   * @param {ArrayBuffer|ArrayBufferView} data source bytes.
   */
  _upload(buffer, data) {
    const bytes = MambaRuntime._asBytes(data);
    if (bytes.byteLength % 4 === 0) {
      this.device.queue.writeBuffer(buffer, 0, data);
      return;
    }
    const padded = new Uint8Array(Math.ceil(bytes.byteLength / 4) * 4);
    padded.set(bytes);
    this.device.queue.writeBuffer(buffer, 0, padded);
  }

  /**
   * Read the first `size` bytes of a GPU buffer back to the CPU.
   *
   * @param {GPUBuffer} buffer source (must have COPY_SRC usage).
   * @param {number} size byte count; must be a multiple of 4.
   * @returns {Promise<Float32Array>} a detached copy of the bytes viewed as f32.
   */
  async _readback(buffer, size) {
    const staging = this.device.createBuffer({
      size,
      usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
    });
    try {
      const encoder = this.device.createCommandEncoder();
      encoder.copyBufferToBuffer(buffer, 0, staging, 0, size);
      this.device.queue.submit([encoder.finish()]);
      // mapAsync resolves only after the copy (and all earlier queue work) is done.
      await staging.mapAsync(GPUMapMode.READ);
      const result = new Float32Array(staging.getMappedRange().slice(0));
      staging.unmap();
      return result;
    } finally {
      // Release the staging buffer even if mapping failed.
      staging.destroy();
    }
  }

  /**
   * Record and immediately submit a single compute dispatch with pre-built
   * bind groups.
   *
   * @param {string} shaderName key into `this.pipelines`.
   * @param {GPUBindGroup} bindGroup group 0.
   * @param {GPUBindGroup|null} uniformBindGroup optional group 1.
   * @param {number} workgroupsX
   * @param {number} [workgroupsY=1]
   * @param {number} [workgroupsZ=1]
   */
  _dispatch(shaderName, bindGroup, uniformBindGroup, workgroupsX, workgroupsY = 1, workgroupsZ = 1) {
    const encoder = this.device.createCommandEncoder();
    const pass = encoder.beginComputePass();
    pass.setPipeline(this.pipelines[shaderName]);
    pass.setBindGroup(0, bindGroup);
    if (uniformBindGroup) pass.setBindGroup(1, uniformBindGroup);
    pass.dispatchWorkgroups(workgroupsX, workgroupsY, workgroupsZ);
    pass.end();
    this.device.queue.submit([encoder.finish()]);
  }

  /**
   * Load all tensors from `basePath` into GPU storage buffers and allocate the
   * recurrent state. Uses `model.safetensors.index.json` when present
   * (sharded checkpoint), otherwise a single `model.safetensors`.
   *
   * BF16 tensors are widened to F32 on the CPU before upload. Tensors whose
   * GPU footprint exceeds the device buffer limits are skipped with a log line.
   *
   * @param {string} basePath URL prefix of the checkpoint directory.
   * @throws {Error} when called before `init()`, or when a shard is
   *                 unreachable or malformed.
   */
  async loadWeights(basePath) {
    if (!this.device) throw new Error('Call init() first');
    console.log('[mamba] loading weights from', basePath);

    let files = ['model.safetensors'];
    const indexResp = await fetch(`${basePath}/model.safetensors.index.json`);
    if (indexResp.ok) {
      const index = await indexResp.json();
      const weightMap = index.weight_map ?? {};
      const shardFiles = [...new Set(Object.values(weightMap))];
      if (shardFiles.length > 0) files = shardFiles;
      console.log(`[mamba] multi-shard: ${files.length} files, ${Object.keys(weightMap).length} tensors`);
    }

    for (const file of files) {
      await this._loadShard(`${basePath}/${file}`, file);
    }

    console.log(`[mamba] TOTAL: ${Object.keys(this.weights).length} tensors loaded`);

    this._allocateState();
    this.ready = true;
  }

  /**
   * Parse one safetensors file and upload each of its tensors.
   *
   * Layout relied on: an 8-byte little-endian length prefix, a JSON header of
   * that many bytes, then the raw tensor data; `data_offsets` are relative to
   * the end of the header.
   *
   * @param {string} url  full URL of the shard.
   * @param {string} file shard file name (for log and error messages).
   */
  async _loadShard(url, file) {
    console.log(`[mamba] parsing ${file} header...`);
    // `full` is filled lazily if the server ignores Range requests, so the
    // shard is downloaded at most once instead of once per tensor.
    const source = { url, full: null };

    const prefix = await this._fetchBytes(source, 0, 8);
    const headerLen = Number(new DataView(prefix).getBigUint64(0, true));
    if (!Number.isSafeInteger(headerLen) || headerLen <= 0) {
      throw new Error(`Invalid safetensors header length in ${file}: ${headerLen}`);
    }
    console.log(`[mamba] header: ${headerLen} bytes`);

    const headerBytes = await this._fetchBytes(source, 8, headerLen);
    const header = JSON.parse(new TextDecoder().decode(headerBytes));
    const dataOffset = 8 + headerLen;

    const tensorNames = Object.keys(header).filter((n) => n !== '__metadata__');
    console.log(`[mamba] ${tensorNames.length} tensors in this shard`);

    // A tensor must fit both in a buffer and in a single storage binding.
    const limits = this.device.limits ?? {};
    const maxBytes = Math.min(
      limits.maxBufferSize ?? 2_000_000_000,
      limits.maxStorageBufferBindingSize ?? 2_000_000_000,
    );
    const usage = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST;

    let loaded = 0;
    for (const name of tensorNames) {
      const { dtype, shape, data_offsets: [start, end] } = header[name];
      const byteLen = end - start;
      // BF16 doubles in size once widened to F32; test the size actually allocated.
      const gpuByteLen = dtype === 'BF16' ? byteLen * 2 : byteLen;

      if (gpuByteLen > maxBytes) {
        console.log(`[mamba] SKIP ${name} (${(gpuByteLen / 1e9).toFixed(2)} GB - exceeds buffer limit)`);
        continue;
      }

      // A zero-length range (`bytes=N-(N-1)`) is not a valid request; skip the fetch.
      const raw = byteLen > 0
        ? await this._fetchBytes(source, dataOffset + start, byteLen)
        : new ArrayBuffer(0);

      const isBF16 = dtype === 'BF16';
      const payload = isBF16 ? MambaRuntime._bf16ToF32(raw) : new Uint8Array(raw);

      // Reloading a checkpoint must not leak the buffers of the previous one.
      this.weights[name]?.buffer.destroy();
      const gpuBuf = this._createBuffer(payload.byteLength, usage, name);
      if (payload.byteLength > 0) this._upload(gpuBuf, payload);
      this.weights[name] = {
        buffer: gpuBuf,
        shape,
        dtype: isBF16 ? 'F32' : dtype,
        byteLen: payload.byteLength,
      };

      loaded++;
      if (loaded % 20 === 0) {
        console.log(`[mamba] loaded ${loaded}/${tensorNames.length} tensors`);
      }
    }
    console.log(`[mamba] shard done: ${loaded} tensors loaded`);
  }

  /**
   * Fetch `length` bytes starting at `start` from a shard.
   *
   * Tries an HTTP Range request first. If the server answers with the whole
   * file instead of 206, the file is cached on `source.full` and this and all
   * later reads are sliced from it.
   *
   * @param {{url: string, full: ArrayBuffer|null}} source
   * @param {number} start  absolute byte offset.
   * @param {number} length byte count (> 0).
   * @returns {Promise<ArrayBuffer>} exactly `length` bytes.
   * @throws {Error} on an HTTP error or a short/truncated response.
   */
  async _fetchBytes(source, start, length) {
    if (!source.full) {
      const resp = await fetch(source.url, {
        headers: { 'Range': `bytes=${start}-${start + length - 1}` },
      });
      if (resp.status === 206) {
        const part = await resp.arrayBuffer();
        if (part.byteLength !== length) {
          throw new Error(`Short range response from ${source.url}: wanted ${length} bytes, got ${part.byteLength}`);
        }
        return part;
      }
      if (!resp.ok) throw new Error(`Failed to fetch ${source.url}: HTTP ${resp.status}`);
      console.log(`[mamba] WARN: no range support, loading full file ${source.url}`);
      source.full = await resp.arrayBuffer();
    }
    if (start + length > source.full.byteLength) {
      throw new Error(`Truncated file ${source.url}: need byte ${start + length}, have ${source.full.byteLength}`);
    }
    return source.full.slice(start, start + length);
  }

  /**
   * Widen raw bfloat16 data to float32. A bf16 value is the upper 16 bits of
   * the corresponding f32, so conversion is a 16-bit left shift.
   *
   * @param {ArrayBuffer} raw little-endian bf16 values; a trailing odd byte is ignored.
   * @returns {Float32Array}
   */
  static _bf16ToF32(raw) {
    const bf16 = new Uint16Array(raw, 0, Math.floor(raw.byteLength / 2));
    const f32 = new Float32Array(bf16.length);
    const bits = new Uint32Array(f32.buffer);
    for (let j = 0; j < bf16.length; j++) {
      bits[j] = bf16[j] << 16;
    }
    return f32;
  }

  /**
   * Normalise an ArrayBuffer or any typed-array/DataView to a Uint8Array over
   * exactly the bytes that view covers (respecting byteOffset/byteLength).
   *
   * @param {ArrayBuffer|ArrayBufferView} data
   * @returns {Uint8Array}
   */
  static _asBytes(data) {
    return ArrayBuffer.isView(data)
      ? new Uint8Array(data.buffer, data.byteOffset, data.byteLength)
      : new Uint8Array(data);
  }

  /**
   * Pack a list of 32-bit fields into a little-endian parameter block.
   *
   * @param {Array<['u32'|'f32', number]>} fields in struct order.
   * @returns {Uint8Array} 4 bytes per field.
   */
  static _packParams(fields) {
    const block = new ArrayBuffer(fields.length * 4);
    const view = new DataView(block);
    fields.forEach(([type, value], i) => {
      if (type === 'f32') view.setFloat32(i * 4, value, true);
      else view.setUint32(i * 4, value, true);
    });
    return new Uint8Array(block);
  }

  /**
   * (Re)allocate the per-layer recurrent state: an SSM state of
   * intermediate_size x state_size floats and a conv state of
   * intermediate_size x (conv_kernel - 1) floats for every layer.
   * Any previously allocated state buffers are destroyed first.
   */
  _allocateState() {
    const {
      intermediate_size: I,
      state_size: S,
      conv_kernel: K,
      num_layers: layerCount,
    } = CONFIG;
    const usage = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST;
    const ssmBytes = I * S * 4;
    const convBytes = I * (K - 1) * 4;

    for (const old of Object.values(this.state)) old.destroy();
    this.state = {};

    for (let l = 0; l < layerCount; l++) {
      this.state[`layer.${l}.ssm`] = this._createBuffer(ssmBytes, usage, `ssm_state_${l}`);
      this.state[`layer.${l}.conv`] = this._createBuffer(convBytes, usage, `conv_state_${l}`);
    }

    const totalMB = Math.round((layerCount * (ssmBytes + convBytes)) / (1024 * 1024));
    console.log(`[mamba] allocated ${layerCount} layers of SSM + conv1d state (~${totalMB} MB)`);
  }

  /**
   * Snapshot every recurrent state buffer to the CPU.
   *
   * @returns {Promise<Object<string, Float32Array>>} keyed like `this.state`.
   */
  async saveState() {
    const snapshot = {};
    for (const [key, buf] of Object.entries(this.state)) {
      snapshot[key] = await this._readback(buf, buf.size);
    }
    return snapshot;
  }

  /**
   * Restore a snapshot produced by `saveState()`. Keys that do not match a
   * current state buffer are ignored.
   *
   * @param {Object<string, ArrayBufferView>} state
   * @throws {Error} when an entry's size differs from its target buffer. All
   *   entries are checked before anything is written, so a bad snapshot
   *   cannot leave the state half-restored.
   */
  async restoreState(state) {
    const pending = Object.entries(state).filter(([key]) => this.state[key]);
    for (const [key, data] of pending) {
      const target = this.state[key];
      if (data.byteLength !== target.size) {
        throw new Error(`State size mismatch for ${key}: got ${data.byteLength} bytes, expected ${target.size}`);
      }
    }
    for (const [key, data] of pending) {
      this._upload(this.state[key], data);
    }
  }

  /**
   * Allocate the intermediate activation buffers used by `_forwardOneToken`.
   * Idempotent: does nothing when the scratch set already exists.
   */
  _allocateScratch() {
    if (this.scratch) return;
    const {
      hidden_size: H,
      intermediate_size: I,
      dt_rank: DR,
      state_size: S,
      vocab_size: V,
    } = CONFIG;
    const usage = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST;

    // Element counts (f32 / u32) per scratch buffer; insertion order is preserved.
    const elementCounts = {
      norm_out: H,
      projected: 2 * I,
      hidden: I,
      gate: I,
      hidden_c: I,
      sxBC: DR + 2 * S,
      B_proj: S,
      C_proj: S,
      dt_pre: DR,
      dt: I,
      hidden_y: I,
      gate_silu: I,
      out_proj_o: H,
      logits: V,
      token_out: 1,
      hidden_state: H,
      token_id: 1,
    };

    this.scratch = {};
    for (const [label, count] of Object.entries(elementCounts)) {
      this.scratch[label] = this._createBuffer(count * 4, usage, label);
    }
    this._tokenCount = 0;
    console.log('[mamba] scratch buffers allocated');
  }

  /**
   * Run the whole model for one input token, advance the recurrent state,
   * and sample the next token on the GPU.
   *
   * Pass order per layer: rmsnorm -> in_proj -> split(hidden, gate) ->
   * conv1d_step -> silu -> x_proj -> split(dt, B, C) -> rmsnorm_noweight x3 ->
   * dt_proj -> ssu -> silu(gate) -> elementwise_mul -> out_proj -> add_residual.
   * After the last layer: final rmsnorm -> lm_head -> sample.
   *
   * @param {number} tokenId input token; must index a row of the embedding table.
   * @param {number} [temperature=0.75] sampling temperature. The shader is
   *   given 1/temperature; values that are not a positive finite number are
   *   clamped to 1e-4 (effectively greedy) to avoid dividing by zero.
   * @returns {Promise<number>} the sampled next-token id.
   * @throws {RangeError} when `tokenId` is outside the embedding table.
   * @throws {Error} when a required weight is missing.
   */
  async _forwardOneToken(tokenId, temperature = 0.75) {
    const {
      hidden_size: H,
      intermediate_size: I,
      dt_rank: DR,
      state_size: S,
      vocab_size: V,
      rms_eps: eps,
      num_layers: layerCount,
    } = CONFIG;
    const F = 4; // bytes per f32 / u32 element
    const pack = MambaRuntime._packParams;

    this._allocateScratch();
    const s = this.scratch;

    // Parameter blocks; the field order must match each shader's uniform struct.
    const rmsParams = pack([['u32', 1], ['u32', H], ['f32', eps]]);
    const gemv = (rows, cols) => pack([['u32', rows], ['u32', cols]]);
    const countI = pack([['u32', I]]);
    const countH = pack([['u32', H]]);
    const rmsDt = pack([['u32', DR], ['f32', eps]]);
    const rmsBC = pack([['u32', S], ['f32', eps]]);
    const ssuParams = pack([['u32', I], ['u32', S]]);
    const groupsI = Math.ceil(I / 64);

    // Embedding lookup: copy one row of the table into the residual stream.
    const embBuf = await this._getF32Weight('backbone.embeddings.weight');
    if (!Number.isInteger(tokenId) || tokenId < 0 || (tokenId + 1) * H * F > embBuf.size) {
      throw new RangeError(`Token id out of range for embedding table: ${tokenId}`);
    }
    this._upload(s.token_id, new Uint32Array([tokenId]));
    const embedEncoder = this.device.createCommandEncoder();
    embedEncoder.copyBufferToBuffer(embBuf, tokenId * H * F, s.hidden_state, 0, H * F);
    this.device.queue.submit([embedEncoder.finish()]);

    for (let l = 0; l < layerCount; l++) {
      const prefix = `backbone.layers.${l}`;
      // Resolve all weights first so a missing tensor fails before any pass of
      // this layer is recorded.
      const [
        normW, inProjW, conv1dW, conv1dB, xProjW,
        dtProjW, aLog, dWeight, dtBias, outProjW,
      ] = await Promise.all([
        'norm.weight',
        'mixer.in_proj.weight',
        'mixer.conv1d.weight',
        'mixer.conv1d.bias',
        'mixer.x_proj.weight',
        'mixer.dt_proj.weight',
        'mixer.A_log',
        'mixer.D',
        'mixer.dt_proj.bias',
        'mixer.out_proj.weight',
      ].map((suffix) => this._getF32Weight(`${prefix}.${suffix}`)));

      const enc = this.device.createCommandEncoder();

      this._encodePass(enc, 'rmsnorm', [s.hidden_state, normW, s.norm_out], rmsParams, 1);

      // in_proj emits [hidden | gate] back to back; split into separate buffers.
      this._encodePass(enc, 'matmul_gemv', [s.norm_out, inProjW, s.projected], gemv(I * 2, H), I * 2);
      enc.copyBufferToBuffer(s.projected, 0, s.hidden, 0, I * F);
      enc.copyBufferToBuffer(s.projected, I * F, s.gate, 0, I * F);

      this._encodePass(enc, 'conv1d_step',
        [this.state[`layer.${l}.conv`], s.hidden, conv1dW, conv1dB, s.hidden_c], countI, groupsI);
      this._encodePass(enc, 'silu', [s.hidden_c], countI, groupsI);

      // x_proj emits [dt (DR) | B (S) | C (S)].
      this._encodePass(enc, 'matmul_gemv', [s.hidden_c, xProjW, s.sxBC], gemv(DR + 2 * S, I), DR + 2 * S);
      enc.copyBufferToBuffer(s.sxBC, 0, s.dt_pre, 0, DR * F);
      enc.copyBufferToBuffer(s.sxBC, DR * F, s.B_proj, 0, S * F);
      enc.copyBufferToBuffer(s.sxBC, (DR + S) * F, s.C_proj, 0, S * F);

      this._encodePass(enc, 'rmsnorm_noweight', [s.dt_pre], rmsDt, 1);
      this._encodePass(enc, 'rmsnorm_noweight', [s.B_proj], rmsBC, 1);
      this._encodePass(enc, 'rmsnorm_noweight', [s.C_proj], rmsBC, 1);

      this._encodePass(enc, 'matmul_gemv', [s.dt_pre, dtProjW, s.dt], gemv(I, DR), I);

      this._encodePass(enc, 'ssu', [
        this.state[`layer.${l}.ssm`], s.hidden_c, s.dt, aLog,
        s.B_proj, s.C_proj, dWeight, dtBias, s.hidden_y,
      ], ssuParams, I);

      this._encodePass(enc, 'silu', [s.gate], countI, groupsI);
      this._encodePass(enc, 'elementwise_mul', [s.hidden_y, s.gate], countI, groupsI);

      this._encodePass(enc, 'matmul_gemv', [s.hidden_y, outProjW, s.out_proj_o], gemv(H, I), H);

      // Passes in one command buffer run in order, so the residual add can
      // share the layer's encoder.
      this._encodePass(enc, 'add_residual', [s.hidden_state, s.out_proj_o], countH, Math.ceil(H / 64));

      this.device.queue.submit([enc.finish()]);
    }

    // Final norm, vocabulary projection and sampling.
    const enc = this.device.createCommandEncoder();
    const normFW = await this._getF32Weight('backbone.norm_f.weight');
    this._encodePass(enc, 'rmsnorm', [s.hidden_state, normFW, s.norm_out], rmsParams, 1);

    const lmHeadW = await this._getF32Weight('lm_head.weight');
    this._encodePass(enc, 'matmul_gemv', [s.norm_out, lmHeadW, s.logits], gemv(V, H), V);

    const invTemperature = 1.0 / Math.max(Number(temperature) || 0, 1e-4);
    const seed = Math.floor(Math.random() * 0xFFFFFFFF);
    // The seed changes every call, so this block bypasses the uniform cache.
    const sampleUniform = this._createUniform(
      pack([['u32', V], ['f32', invTemperature], ['u32', seed]]),
    );
    try {
      const samplePass = this._bindPass(enc, 'sample', [s.logits, s.token_out], sampleUniform);
      samplePass.dispatchWorkgroups(1);
      samplePass.end();
      this.device.queue.submit([enc.finish()]);

      // The readback's mapAsync waits for all submitted work, so no explicit
      // onSubmittedWorkDone() stall is needed.
      const tokenResult = await this._readback(s.token_out, 4);
      this._tokenCount++;
      return new Uint32Array(tokenResult.buffer)[0];
    } finally {
      sampleUniform.destroy();
    }
  }

  /**
   * Tokenise text through the companion `/tokenize` HTTP endpoint.
   *
   * @param {string} text
   * @param {string} [system=''] optional system prompt.
   * @returns {Promise<number[]>} the `result` field of the JSON response.
   * @throws {Error} on a non-2xx response.
   */
  async tokenize(text, system = '') {
    const resp = await fetch('/tokenize', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ text, system }),
    });
    if (!resp.ok) throw new Error(`Tokenize request failed: HTTP ${resp.status}`);
    const data = await resp.json();
    return data.result;
  }

  /**
   * Convert token ids back to text through the `/detokenize` HTTP endpoint.
   *
   * @param {number[]} tokens
   * @returns {Promise<string>} the `result` field of the JSON response.
   * @throws {Error} on a non-2xx response.
   */
  async detokenize(tokens) {
    const resp = await fetch('/detokenize', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ tokens }),
    });
    if (!resp.ok) throw new Error(`Detokenize request failed: HTTP ${resp.status}`);
    const data = await resp.json();
    return data.result;
  }

  /**
   * Generate a completion for `prompt`.
   *
   * The prompt is fed through the model token by token; the token sampled
   * after the last prompt token becomes the first generated token. Generation
   * stops at `maxTokens`, at a stop token, or at the first failing step
   * (which is logged, and the partial output is returned).
   *
   * @param {string} prompt
   * @param {number} [maxTokens=100] maximum number of generated tokens.
   * @param {number} [temperature=0.75] sampling temperature.
   * @param {?function(number, number): void} [onToken=null] called with
   *   (tokenId, stepIndex) for every generated token, including a stop token.
   * @param {string} [system=''] optional system prompt for the tokenizer.
   * @returns {Promise<string>} detokenised text ('' when `maxTokens` < 1).
   * @throws {Error} when weights are not loaded or the prompt yields no tokens.
   */
  async generate(prompt, maxTokens = 100, temperature = 0.75, onToken = null, system = '') {
    if (!this.ready) throw new Error('Call loadWeights() first');
    this._allocateScratch();

    console.log('[mamba] generate:', prompt, 'max_tokens:', maxTokens);

    const promptTokens = await this.tokenize(prompt, system);
    if (!Array.isArray(promptTokens) || promptTokens.length === 0) {
      // Without at least one input token there is nothing to sample from.
      throw new Error('Tokenizer returned no tokens for the prompt');
    }
    console.log(`[mamba] prompt tokens (${promptTokens.length}):`, promptTokens);

    // Ids treated as end-of-generation. NOTE(review): their meaning
    // (EOS / newline / padding?) depends on the tokenizer - confirm.
    const isStopToken = (t) => t === 11 || t === 10 || t === 0;

    console.log('[mamba] encoding prompt...');
    let inputToken = 0;
    for (let i = 0; i < promptTokens.length; i++) {
      const t0 = performance.now();
      // The value returned for the final prompt token is the first generated token.
      inputToken = await this._forwardOneToken(promptTokens[i], temperature);
      const elapsed = performance.now() - t0;
      if (i === 0 || i === promptTokens.length - 1) {
        console.log(`[mamba] prompt token ${i}/${promptTokens.length}: ${promptTokens[i]} (${elapsed.toFixed(0)}ms)`);
      }
    }
    console.log('[mamba] prompt encoded, generating...');

    if (maxTokens < 1) return '';

    const generated = [inputToken];
    console.log(`[mamba] first generated token: ${inputToken}`);
    if (onToken) onToken(inputToken, 0);

    // The first token can already be a stop token; do not generate past it.
    if (!isStopToken(inputToken)) {
      for (let step = 1; step < maxTokens; step++) {
        const t0 = performance.now();
        try {
          const nextToken = await this._forwardOneToken(inputToken, temperature);
          const elapsed = performance.now() - t0;
          if (step < 5 || step % 20 === 0) {
            console.log(`[mamba] step ${step}: token=${nextToken} (${elapsed.toFixed(0)}ms)`);
          }
          generated.push(nextToken);
          inputToken = nextToken;
          if (onToken) onToken(nextToken, step);
          if (isStopToken(nextToken)) break;
        } catch (e) {
          // Best effort: keep what was generated so far.
          console.error(`[mamba] step ${step} failed:`, e);
          break;
        }
      }
    }

    const text = await this.detokenize(generated);
    console.log(`[mamba] generated ${generated.length} tokens`);
    return text;
  }

  /**
   * Look up a loaded weight record.
   *
   * @param {string} name tensor name.
   * @returns {?{buffer: GPUBuffer, shape: number[], dtype: string, byteLen: number}}
   *   the record, or null when the tensor was not loaded.
   */
  _getWeight(name) {
    return this.weights[name] ?? null;
  }

  /**
   * Create a uniform buffer initialised with `data`.
   *
   * @param {ArrayBuffer|ArrayBufferView} data parameter bytes. Only the bytes
   *   covered by the view are copied, so subarray views are safe to pass.
   * @returns {GPUBuffer} a new buffer owned by the caller.
   */
  _createUniform(data) {
    const bytes = MambaRuntime._asBytes(data);
    const buf = this.device.createBuffer({
      // mappedAtCreation requires a size that is a multiple of 4.
      size: Math.max(4, Math.ceil(bytes.byteLength / 4) * 4),
      usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
      mappedAtCreation: true,
    });
    new Uint8Array(buf.getMappedRange()).set(bytes);
    buf.unmap();
    return buf;
  }

  /**
   * Return a shared uniform buffer holding exactly `data`, creating it on
   * first use. Cached buffers are owned by the runtime and released in
   * `destroy()`; callers must not destroy them.
   *
   * @param {ArrayBuffer|ArrayBufferView} data
   * @returns {GPUBuffer}
   */
  _uniformFor(data) {
    const bytes = MambaRuntime._asBytes(data);
    const key = bytes.join(',');
    let buf = this._uniformCache.get(key);
    if (!buf) {
      buf = this._createUniform(bytes);
      this._uniformCache.set(key, buf);
    }
    return buf;
  }

  /**
   * Convert a BF16 weight to F32 on the GPU with the `bf16_to_f32` shader and
   * cache the result on the record as `f32buffer` / `f32size`.
   *
   * NOTE(review): the shader presumably unpacks two bf16 values per u32 word
   * using a 1-D dispatch of 64-wide workgroups - confirm against bf16_to_f32.wgsl.
   *
   * @param {{buffer: GPUBuffer, dtype: string, byteLen: number}} weightInfo
   * @returns {Promise<Object>} the same record (unchanged when it is not BF16
   *   or is already converted).
   * @throws {Error} when the tensor needs more workgroups than one dispatch
   *   dimension allows.
   */
  async _convertBF16toF32(weightInfo) {
    if (weightInfo.dtype !== 'BF16' || weightInfo.f32buffer) return weightInfo;

    const numBF16 = Math.floor(weightInfo.byteLen / 2);
    // Round up so the last element of an odd-length tensor is still converted.
    const numPairs = Math.ceil(numBF16 / 2);
    const f32Bytes = numBF16 * 4;

    const workgroups = Math.ceil(numPairs / 64);
    const maxGroups = this.device.limits?.maxComputeWorkgroupsPerDimension ?? 65535;
    if (workgroups > maxGroups) {
      throw new Error(`BF16 tensor too large for a single dispatch: ${workgroups} workgroups > ${maxGroups}`);
    }

    const f32Buf = this._createBuffer(
      f32Bytes,
      GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
      weightInfo.buffer.label + '_f32',
    );
    const paramBuf = this._createUniform(new Uint32Array([numPairs]));

    try {
      const encoder = this.device.createCommandEncoder();
      const pass = this._bindPass(encoder, 'bf16_to_f32', [weightInfo.buffer, f32Buf], paramBuf);
      pass.dispatchWorkgroups(workgroups);
      pass.end();
      this.device.queue.submit([encoder.finish()]);
      await this.device.queue.onSubmittedWorkDone();
    } catch (err) {
      f32Buf.destroy();
      throw err;
    } finally {
      paramBuf.destroy();
    }

    weightInfo.f32buffer = f32Buf;
    weightInfo.f32size = f32Bytes;
    return weightInfo;
  }

  /**
   * Resolve a weight to a GPU buffer of f32 values.
   *
   * `loadWeights` already widens BF16 on the CPU, so this is normally a plain
   * lookup; a record still tagged BF16 is converted lazily on the GPU.
   *
   * @param {string} name tensor name.
   * @returns {Promise<GPUBuffer>}
   * @throws {Error} when the tensor was not loaded.
   */
  async _getF32Weight(name) {
    const w = this.weights[name];
    if (!w) throw new Error(`Missing weight: ${name}`);
    if (w.dtype === 'BF16') {
      return (await this._convertBF16toF32(w)).f32buffer;
    }
    return w.buffer;
  }

  /**
   * Begin a compute pass with the pipeline and bind groups set.
   *
   * @param {GPUCommandEncoder} encoder
   * @param {string} shaderName key into `this.pipelines`.
   * @param {GPUBuffer[]} storageBuffers bound to group 0 at bindings 0..n-1.
   * @param {?GPUBuffer} uniformBuffer bound to group 1, binding 0, when given.
   * @returns {GPUComputePassEncoder} an open pass; the caller dispatches and ends it.
   * @throws {Error} when the pipeline has not been compiled.
   */
  _bindPass(encoder, shaderName, storageBuffers, uniformBuffer) {
    const pipeline = this.pipelines[shaderName];
    // Fail before opening a pass so the encoder is not left with a dangling one.
    if (!pipeline) throw new Error(`Shader pipeline not compiled: ${shaderName}`);

    const pass = encoder.beginComputePass();
    pass.setPipeline(pipeline);
    pass.setBindGroup(0, this.device.createBindGroup({
      layout: pipeline.getBindGroupLayout(0),
      entries: storageBuffers.map((buffer, binding) => ({ binding, resource: { buffer } })),
    }));
    if (uniformBuffer) {
      pass.setBindGroup(1, this.device.createBindGroup({
        layout: pipeline.getBindGroupLayout(1),
        entries: [{ binding: 0, resource: { buffer: uniformBuffer } }],
      }));
    }
    return pass;
  }

  /**
   * Begin a compute pass for `shaderName` with bind groups built from the
   * given buffers. The pass is returned open: the caller must call
   * `dispatchWorkgroups()` and `end()`.
   *
   * @param {GPUCommandEncoder} encoder
   * @param {string} shaderName
   * @param {GPUBuffer[]} storageBuffers bound to group 0 in order.
   * @param {?(ArrayBuffer|ArrayBufferView)} uniformData optional parameter
   *   block for group 1 (served from the uniform cache).
   * @returns {GPUComputePassEncoder}
   */
  _dispatchShader(encoder, shaderName, storageBuffers, uniformData) {
    const uniformBuffer = uniformData ? this._uniformFor(uniformData) : null;
    return this._bindPass(encoder, shaderName, storageBuffers, uniformBuffer);
  }

  /**
   * Record one complete compute pass (bind, dispatch, end) on `encoder`.
   *
   * @param {GPUCommandEncoder} encoder
   * @param {string} shaderName
   * @param {GPUBuffer[]} storageBuffers
   * @param {?(ArrayBuffer|ArrayBufferView)} uniformData
   * @param {number} workgroups workgroup count along X.
   */
  _encodePass(encoder, shaderName, storageBuffers, uniformData, workgroups) {
    const pass = this._dispatchShader(encoder, shaderName, storageBuffers, uniformData);
    pass.dispatchWorkgroups(workgroups);
    pass.end();
  }

  /**
   * Release every GPU buffer owned by the runtime (weights, converted
   * weights, recurrent state, scratch, cached uniforms). The device and the
   * compiled pipelines are kept, so `loadWeights()` can be called again.
   */
  destroy() {
    for (const w of Object.values(this.weights)) {
      w.buffer.destroy();
      w.f32buffer?.destroy();
    }
    for (const stateBuf of Object.values(this.state)) stateBuf.destroy();
    for (const scratchBuf of Object.values(this.scratch ?? {})) scratchBuf.destroy();
    for (const uniformBuf of this._uniformCache.values()) uniformBuf.destroy();

    this._uniformCache.clear();
    this.weights = {};
    this.state = {};
    // Reset so the next generate() allocates fresh scratch instead of reusing
    // destroyed buffers.
    this.scratch = null;
    this.ready = false;
  }
}
|
|
| |
| export { MambaRuntime, CONFIG }; |
|
|