// mamba-webgpu / mamba_runtime.js
// Uploaded by LJTSG with huggingface_hub (commit 1a8f20a, verified).
/**
* mamba_runtime.js — Browser-native Falcon-Mamba inference via WebGPU.
*
* The first browser-native Mamba/SSM inference engine.
* No MLC, no TVM β€” pure WebGPU compute shaders ported from gfx1151_runtime.
*
* Architecture: Falcon-Mamba 7B
* 64 layers, each: RMSNorm → in_proj → conv1d → SiLU → x_proj/dt_proj → SSU → SiLU-gate → out_proj → residual add
* Final: RMSNorm → lm_head → sample
*
* Weight format: safetensors (HF standard), loaded directly into WebGPU buffers.
* Shaders: WGSL compute shaders in ./shaders/ (ported from Vulkan GLSL).
*
* Usage:
* const mamba = new MambaRuntime();
* await mamba.init();
* await mamba.loadWeights('./weights/');
* const text = await mamba.generate("Hello Grandma", 100);
*/
/**
 * Falcon-Mamba 7B hyperparameters shared by the loader and the forward pass.
 * Key order is part of the exported object and is kept stable.
 */
const CONFIG = {
  hidden_size: 4096, // model width
  intermediate_size: 2 * 4096, // expanded inner width (2 × hidden_size)
  num_layers: 64, // number of Mamba blocks
  vocab_size: 65024, // rows in the embedding / lm_head tables
  state_size: 16, // SSM d_state
  conv_kernel: 4, // causal conv1d width
  dt_rank: 256, // low-rank width of the dt projection
  rms_eps: 0.00001, // RMSNorm epsilon
};
class MambaRuntime {
  constructor() {
    this.device = null;
    this.pipelines = {}; // shader name → GPUComputePipeline
    this.bindLayouts = {}; // shader name → GPUBindGroupLayout (unused: pipelines use layout 'auto')
    this.weights = {}; // parameter name → { buffer: GPUBuffer, shape, dtype, byteLen }
    this.state = {}; // per-layer SSM state + conv1d state buffers
    this.scratch = null; // forward-pass intermediates, created lazily by _allocateScratch()
    this.ready = false;
  }

  // ── Init: get WebGPU device + compile all shaders ──────────────────────
  /**
   * Acquire a WebGPU device with the adapter's maximum buffer limits and
   * compile every compute shader.
   * @returns {Promise<MambaRuntime>} this, for chaining.
   * @throws {Error} if WebGPU or an adapter is unavailable, or a shader fails to load.
   */
  async init() {
    if (!navigator.gpu) throw new Error('WebGPU not supported in this browser');
    const adapter = await navigator.gpu.requestAdapter();
    if (!adapter) throw new Error('No WebGPU adapter found');
    // Request the largest buffers the adapter supports: F32 weight tensors are
    // far larger than the WebGPU default limits.
    const limits = adapter.limits;
    console.log('[mamba] maxBufferSize:', limits.maxBufferSize,
      '=', (limits.maxBufferSize / 1024 / 1024 / 1024).toFixed(2), 'GB');
    this.device = await adapter.requestDevice({
      requiredLimits: {
        maxBufferSize: limits.maxBufferSize,
        maxStorageBufferBindingSize: limits.maxStorageBufferBindingSize,
        maxComputeWorkgroupStorageSize: limits.maxComputeWorkgroupStorageSize,
        // The ssu pass binds 9 buffers in group 0 (presumably all storage),
        // more than the WebGPU default of 8 per stage.
        maxStorageBuffersPerShaderStage: Math.min(limits.maxStorageBuffersPerShaderStage, 16),
      },
    });
    this.device.lost.then((info) => {
      console.error('[mamba] DEVICE LOST:', info.reason, info.message);
    });
    this.device.addEventListener('uncapturederror', (e) => {
      console.error('[mamba] GPU ERROR:', e.error.message);
    });
    console.log('[mamba] device ready, compiling shaders...');
    await this._compileShaders();
    console.log('[mamba] shaders compiled');
    return this;
  }

  // ── Compile all WGSL shaders into compute pipelines ────────────────────
  /**
   * Fetch every WGSL kernel from ./shaders/ and build a compute pipeline for
   * it (entry point `main`, bind-group layouts inferred with layout 'auto').
   * @throws {Error} if any shader source cannot be fetched.
   */
  async _compileShaders() {
    const shaderNames = [
      'conv1d_step', 'ssu', 'matmul_gemv', 'rmsnorm', 'rmsnorm_noweight',
      'silu', 'softplus', 'embedding', 'elementwise_mul', 'sample',
      'bf16_to_f32', 'add_residual',
    ];
    // The sources are independent, so download them concurrently.
    const sources = await Promise.all(shaderNames.map(async (name) => {
      const resp = await fetch(`./shaders/${name}.wgsl`);
      if (!resp.ok) throw new Error(`Failed to load shader: ${name}.wgsl`);
      return resp.text();
    }));
    for (const [i, name] of shaderNames.entries()) {
      const module = this.device.createShaderModule({ code: sources[i], label: name });
      // Convention used by every dispatch in this file: group 0 = storage
      // buffers (data), group 1 = one uniform buffer (params).
      this.pipelines[name] = this.device.createComputePipeline({
        layout: 'auto',
        compute: { module, entryPoint: 'main' },
        label: name,
      });
    }
  }

  // ── Create a GPU buffer ────────────────────────────────────────────────
  _createBuffer(size, usage, label) {
    return this.device.createBuffer({
      size: Math.max(size, 4), // zero-sized tensors still get a bindable buffer
      usage,
      label,
      mappedAtCreation: false,
    });
  }

  // ── Upload data to a GPU buffer ────────────────────────────────────────
  _upload(buffer, data) {
    this.device.queue.writeBuffer(buffer, 0, data);
  }

  // ── Read data back from GPU buffer ─────────────────────────────────────
  /**
   * Copy the first `size` bytes of `buffer` to the CPU.
   * @param {GPUBuffer} buffer - must have COPY_SRC usage.
   * @param {number} size - byte count (a multiple of 4, as copyBufferToBuffer requires).
   * @returns {Promise<Float32Array>} a CPU-owned copy of the bytes.
   */
  async _readback(buffer, size) {
    const staging = this.device.createBuffer({
      size,
      usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
    });
    try {
      const encoder = this.device.createCommandEncoder();
      encoder.copyBufferToBuffer(buffer, 0, staging, 0, size);
      this.device.queue.submit([encoder.finish()]);
      await staging.mapAsync(GPUMapMode.READ);
      const result = new Float32Array(staging.getMappedRange().slice(0));
      staging.unmap();
      return result;
    } finally {
      // Release the staging buffer even if mapping failed.
      staging.destroy();
    }
  }

  // ── Dispatch a compute shader ──────────────────────────────────────────
  // Standalone one-pass submit helper (not used by the forward pass below).
  _dispatch(shaderName, bindGroup, uniformBindGroup, workgroupsX, workgroupsY = 1, workgroupsZ = 1) {
    const encoder = this.device.createCommandEncoder();
    const pass = encoder.beginComputePass();
    pass.setPipeline(this.pipelines[shaderName]);
    pass.setBindGroup(0, bindGroup);
    if (uniformBindGroup) pass.setBindGroup(1, uniformBindGroup);
    pass.dispatchWorkgroups(workgroupsX, workgroupsY, workgroupsZ);
    pass.end();
    this.device.queue.submit([encoder.finish()]);
  }

  // ── Load safetensors weights into GPU buffers ──────────────────────────
  /**
   * Load every tensor of a (possibly sharded) safetensors checkpoint into
   * GPU buffers, widening BF16 to F32 on the CPU, then allocate per-layer state.
   * Tensors larger than the device's maxBufferSize (after widening) are skipped.
   * @param {string} basePath - directory URL holding the checkpoint; trailing slashes are ignored.
   * @throws {Error} on HTTP failures while fetching shard data.
   */
  async loadWeights(basePath) {
    console.log('[mamba] loading weights from', basePath);
    const base = String(basePath).replace(/\/+$/, '');
    const files = await this._listWeightShards(base);
    for (const file of files) {
      await this._loadShard(`${base}/${file}`, file);
    }
    console.log(`[mamba] TOTAL: ${Object.keys(this.weights).length} tensors loaded`);
    this._allocateState();
    this.ready = true;
  }

  /**
   * Resolve the shard file names: from model.safetensors.index.json when the
   * server has it, otherwise the single file model.safetensors.
   * @param {string} base - checkpoint directory URL without a trailing slash.
   * @returns {Promise<string[]>} unique shard file names.
   */
  async _listWeightShards(base) {
    const indexResp = await fetch(`${base}/model.safetensors.index.json`);
    if (!indexResp.ok) return ['model.safetensors'];
    const index = await indexResp.json();
    const fileMap = index.weight_map; // tensor_name → filename
    const files = [...new Set(Object.values(fileMap))];
    console.log(`[mamba] multi-shard: ${files.length} files, ${Object.keys(fileMap).length} tensors`);
    return files;
  }

  /**
   * Return a reader `(start, end) => Promise<ArrayBuffer>` for the byte range
   * [start, end) of `url`. It uses HTTP Range requests. If the server ignores
   * Range (answers 200 instead of 206), the full body is cached once and sliced,
   * instead of re-downloading the whole file for every tensor.
   */
  _makeRangeReader(url) {
    let wholeFile = null;
    return async (start, end) => {
      if (end <= start) return new ArrayBuffer(0); // "bytes=a-(a-1)" is not a valid range
      if (wholeFile === null) {
        const resp = await fetch(url, { headers: { Range: `bytes=${start}-${end - 1}` } });
        if (!resp.ok) throw new Error(`Failed to fetch ${url}: HTTP ${resp.status}`);
        if (resp.status === 206) return resp.arrayBuffer();
        console.log(`[mamba] WARN: no range support, loading full file ${url}`);
        wholeFile = await resp.arrayBuffer();
      }
      return wholeFile.slice(start, end);
    };
  }

  /**
   * Parse one safetensors shard and upload each of its tensors.
   * Layout: 8-byte little-endian u64 header length, JSON header, then data.
   * @param {string} url - shard URL.
   * @param {string} file - shard file name (used for logging).
   */
  async _loadShard(url, file) {
    console.log(`[mamba] parsing ${file} header...`);
    const read = this._makeRangeReader(url);
    const headerLen = Number(new DataView(await read(0, 8)).getBigUint64(0, true));
    console.log(`[mamba] header: ${headerLen} bytes`);
    const header = JSON.parse(new TextDecoder().decode(await read(8, 8 + headerLen)));
    const dataOffset = 8 + headerLen;

    const tensorNames = Object.keys(header).filter((n) => n !== '__metadata__');
    console.log(`[mamba] ${tensorNames.length} tensors in this shard`);
    const maxBytes = this.device.limits.maxBufferSize;
    let loaded = 0;
    for (const name of tensorNames) {
      const { dtype, shape, data_offsets: [start, end] } = header[name];
      const byteLen = end - start;
      // BF16 doubles in size when widened to F32, so test the size actually allocated.
      const gpuBytes = dtype === 'BF16' ? byteLen * 2 : byteLen;
      if (gpuBytes > maxBytes) {
        console.log(`[mamba] SKIP ${name} (${(gpuBytes / 1e9).toFixed(2)} GB — exceeds buffer limit)`);
        continue;
      }
      const bytes = await read(dataOffset + start, dataOffset + end);
      this._storeTensor(name, dtype, shape, bytes);
      loaded++;
      if (loaded % 20 === 0) {
        console.log(`[mamba] loaded ${loaded}/${tensorNames.length} tensors`);
      }
    }
    console.log(`[mamba] shard done: ${loaded} tensors loaded`);
  }

  /**
   * Upload one tensor into a new storage buffer and register it in
   * this.weights. BF16 is widened to F32 on the CPU (BF16 is the top 16 bits
   * of an IEEE F32). Other dtypes are uploaded unchanged.
   */
  _storeTensor(name, dtype, shape, bytes) {
    let data;
    let finalDtype = dtype;
    if (dtype === 'BF16') {
      const bf16 = new Uint16Array(bytes);
      const bits = new Uint32Array(bf16.length);
      for (let j = 0; j < bf16.length; j++) bits[j] = bf16[j] << 16;
      data = new Float32Array(bits.buffer);
      finalDtype = 'F32';
    } else {
      if (dtype !== 'F32') {
        console.warn(`[mamba] ${name}: dtype ${dtype} uploaded as-is; kernels presumably expect F32`);
      }
      data = new Uint8Array(bytes);
    }
    // Reloading: release the previous copy instead of leaking it.
    this.weights[name]?.buffer.destroy();
    const buffer = this._createBuffer(
      data.byteLength,
      GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
      name,
    );
    this._upload(buffer, data);
    this.weights[name] = { buffer, shape, dtype: finalDtype, byteLen: data.byteLength };
  }

  // ── Allocate persistent SSM state + conv1d cache per layer ─────────────
  /**
   * (Re)create the zero-initialised per-layer SSM state [I, S] and conv1d
   * cache [I, K-1] buffers, destroying any previous ones.
   */
  _allocateState() {
    const H = CONFIG.intermediate_size; // 8192
    const S = CONFIG.state_size; // 16
    const K = CONFIG.conv_kernel; // 4
    const usage = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST;
    for (const buf of Object.values(this.state)) buf.destroy();
    this.state = {};
    for (let l = 0; l < CONFIG.num_layers; l++) {
      // SSM state: [H, S] = 8192 * 16 = 131072 floats = 512 KB per layer
      this.state[`layer.${l}.ssm`] = this._createBuffer(H * S * 4, usage, `ssm_state_${l}`);
      // Conv1d cache: [H, K-1] = 8192 * 3 = 24576 floats = 96 KB per layer
      this.state[`layer.${l}.conv`] = this._createBuffer(H * (K - 1) * 4, usage, `conv_state_${l}`);
    }
    // Total state: 64 layers × (512 + 96) KB = ~38 MB
    console.log(`[mamba] allocated ${CONFIG.num_layers} layers of SSM + conv1d state (~38 MB)`);
  }

  // ── Save/restore SSM state (the entity's persistent soul) ──────────────
  /** @returns {Promise<Object<string, Float32Array>>} CPU copies of every state buffer, keyed like this.state. */
  async saveState() {
    const state = {};
    for (const [key, buf] of Object.entries(this.state)) {
      state[key] = await this._readback(buf, buf.size);
    }
    return state;
  }

  /** Upload a snapshot from saveState(); keys unknown to this.state are ignored. */
  async restoreState(state) {
    for (const [key, data] of Object.entries(state)) {
      if (this.state[key]) {
        this._upload(this.state[key], data);
      }
    }
  }

  // ── Allocate intermediate scratch buffers for forward pass ──────────────
  /** Create the forward-pass scratch buffers once; later calls are no-ops. */
  _allocateScratch() {
    if (this.scratch) return;
    const H = CONFIG.hidden_size; // 4096
    const I = CONFIG.intermediate_size; // 8192
    const DR = CONFIG.dt_rank; // 256
    const S = CONFIG.state_size; // 16
    const F = 4; // sizeof(float32)
    const usage = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST;
    const byteSizes = {
      norm_out: H * F,
      projected: 2 * I * F,
      hidden: I * F,
      gate: I * F,
      hidden_c: I * F,
      sxBC: (DR + 2 * S) * F,
      B_proj: S * F,
      C_proj: S * F,
      dt_pre: DR * F,
      dt: I * F,
      hidden_y: I * F,
      gate_silu: I * F,
      out_proj_o: H * F,
      logits: CONFIG.vocab_size * F,
      token_out: 4,
      hidden_state: H * F,
      token_id: 4,
    };
    this.scratch = Object.fromEntries(
      Object.entries(byteSizes).map(([label, size]) => [label, this._createBuffer(size, usage, label)]),
    );
    this._tokenCount = 0;
    console.log('[mamba] scratch buffers allocated');
  }

  // ── Single-token forward pass through all 64 layers ─────────────────────
  /**
   * Run one token through the model, updating every layer's SSM and conv
   * state, and sample the next token.
   * @param {number} tokenId - input token, in [0, vocab_size).
   * @param {number} [temperature=0.75] - sampling temperature (> 0); the sampler receives 1/temperature.
   * @returns {Promise<number>} the sampled next token id.
   * @throws {RangeError} on an out-of-vocabulary token or a non-positive temperature.
   */
  async _forwardOneToken(tokenId, temperature = 0.75) {
    const H = CONFIG.hidden_size;
    const V = CONFIG.vocab_size;
    if (!Number.isInteger(tokenId) || tokenId < 0 || tokenId >= V) {
      throw new RangeError(`token id ${tokenId} outside vocabulary [0, ${V})`);
    }
    if (!(temperature > 0 && Number.isFinite(temperature))) {
      throw new RangeError(`temperature must be a positive finite number, got ${temperature}`);
    }
    this._allocateScratch();
    const uniforms = []; // per-dispatch uniform buffers, released once this token is done
    try {
      // Step 1: embedding lookup — copy row `tokenId` of the table into hidden_state.
      this._upload(this.scratch.token_id, new Uint32Array([tokenId]));
      const embBuf = await this._getF32Weight('backbone.embeddings.weight');
      const embEncoder = this.device.createCommandEncoder();
      embEncoder.copyBufferToBuffer(embBuf, tokenId * H * 4, this.scratch.hidden_state, 0, H * 4);
      this.device.queue.submit([embEncoder.finish()]);

      // Step 2: every Mamba layer (one command buffer each; queue order serialises them).
      for (let l = 0; l < CONFIG.num_layers; l++) {
        await this._runLayer(l, uniforms);
      }

      // Step 3: final norm, lm_head and sampling.
      await this._runHead(temperature, uniforms);
      await this.device.queue.onSubmittedWorkDone();

      const tokenResult = await this._readback(this.scratch.token_out, 4);
      this._tokenCount++;
      return new Uint32Array(tokenResult.buffer)[0];
    } finally {
      for (const u of uniforms) u.destroy();
    }
  }

  /**
   * Record and submit one Mamba block for layer `l`, reading and updating
   * scratch.hidden_state in place (residual included).
   * @param {number} l - layer index.
   * @param {GPUBuffer[]} uniforms - collects the uniform buffers created here.
   */
  async _runLayer(l, uniforms) {
    const H = CONFIG.hidden_size; // 4096
    const I = CONFIG.intermediate_size; // 8192
    const DR = CONFIG.dt_rank; // 256
    const S = CONFIG.state_size; // 16
    const eps = CONFIG.rms_eps;
    const sc = this.scratch;
    const prefix = `backbone.layers.${l}`;
    const weight = (suffix) => this._getF32Weight(`${prefix}.${suffix}`);
    const [normW, inProjW, conv1dW, conv1dB, xProjW, dtProjW, aLog, dWeight, dtBias, outProjW] =
      await Promise.all([
        weight('norm.weight'),
        weight('mixer.in_proj.weight'),
        weight('mixer.conv1d.weight'),
        weight('mixer.conv1d.bias'),
        weight('mixer.x_proj.weight'),
        weight('mixer.dt_proj.weight'),
        weight('mixer.A_log'),
        weight('mixer.D'),
        weight('mixer.dt_proj.bias'),
        weight('mixer.out_proj.weight'),
      ]);

    const encoder = this.device.createCommandEncoder();
    const run = (shader, buffers, params, workgroups) =>
      this._encodeKernel(encoder, shader, buffers, params, workgroups, uniforms);

    // rmsnorm(hidden_state) * norm.weight → norm_out (params: rows=1, cols=H, eps)
    run('rmsnorm', [sc.hidden_state, normW, sc.norm_out], [['u32', 1], ['u32', H], ['f32', eps]], 1);
    // in_proj: norm_out [H] → projected [2I]; gemv params are (N outputs, K inputs),
    // one workgroup per output element.
    run('matmul_gemv', [sc.norm_out, inProjW, sc.projected], [['u32', 2 * I], ['u32', H]], 2 * I);
    // Split projected → hidden = projected[0:I], gate = projected[I:2I]
    encoder.copyBufferToBuffer(sc.projected, 0, sc.hidden, 0, I * 4);
    encoder.copyBufferToBuffer(sc.projected, I * 4, sc.gate, 0, I * 4);
    // conv1d step over the cached window (updates this layer's conv state) → hidden_c
    run('conv1d_step', [this.state[`layer.${l}.conv`], sc.hidden, conv1dW, conv1dB, sc.hidden_c],
      [['u32', I]], Math.ceil(I / 64));
    // hidden_c = silu(hidden_c), in place
    run('silu', [sc.hidden_c], [['u32', I]], Math.ceil(I / 64));
    // x_proj: hidden_c [I] → sxBC [DR + 2S] = [dt_pre | B | C]
    run('matmul_gemv', [sc.hidden_c, xProjW, sc.sxBC], [['u32', DR + 2 * S], ['u32', I]], DR + 2 * S);
    encoder.copyBufferToBuffer(sc.sxBC, 0, sc.dt_pre, 0, DR * 4);
    encoder.copyBufferToBuffer(sc.sxBC, DR * 4, sc.B_proj, 0, S * 4);
    encoder.copyBufferToBuffer(sc.sxBC, (DR + S) * 4, sc.C_proj, 0, S * 4);
    // Falcon-Mamba: weightless RMSNorm of dt_pre, B and C before use
    run('rmsnorm_noweight', [sc.dt_pre], [['u32', DR], ['f32', eps]], 1);
    run('rmsnorm_noweight', [sc.B_proj], [['u32', S], ['f32', eps]], 1);
    run('rmsnorm_noweight', [sc.C_proj], [['u32', S], ['f32', eps]], 1);
    // dt_proj: dt_pre [DR] → dt [I]
    run('matmul_gemv', [sc.dt_pre, dtProjW, sc.dt], [['u32', I], ['u32', DR]], I);
    // Selective state update (updates this layer's SSM state) → hidden_y.
    // A is passed as A_log; the kernel presumably applies the -exp transform.
    run('ssu', [
      this.state[`layer.${l}.ssm`], sc.hidden_c, sc.dt, aLog,
      sc.B_proj, sc.C_proj, dWeight, dtBias, sc.hidden_y,
    ], [['u32', I], ['u32', S]], I);
    // hidden_y *= silu(gate)
    run('silu', [sc.gate], [['u32', I]], Math.ceil(I / 64));
    run('elementwise_mul', [sc.hidden_y, sc.gate], [['u32', I]], Math.ceil(I / 64));
    // out_proj: hidden_y [I] → out_proj_o [H]
    run('matmul_gemv', [sc.hidden_y, outProjW, sc.out_proj_o], [['u32', H], ['u32', I]], H);
    // Residual: hidden_state += out_proj_o
    run('add_residual', [sc.hidden_state, sc.out_proj_o], [['u32', H]], Math.ceil(H / 64));

    this.device.queue.submit([encoder.finish()]);
  }

  /**
   * Record and submit the final RMSNorm, lm_head and sampling passes, leaving
   * the sampled token id in scratch.token_out.
   * @param {number} temperature - sampling temperature (> 0).
   * @param {GPUBuffer[]} uniforms - collects the uniform buffers created here.
   */
  async _runHead(temperature, uniforms) {
    const H = CONFIG.hidden_size;
    const V = CONFIG.vocab_size;
    const sc = this.scratch;
    const normFW = await this._getF32Weight('backbone.norm_f.weight');
    const lmHeadW = await this._getF32Weight('lm_head.weight');

    const encoder = this.device.createCommandEncoder();
    const run = (shader, buffers, params, workgroups) =>
      this._encodeKernel(encoder, shader, buffers, params, workgroups, uniforms);

    run('rmsnorm', [sc.hidden_state, normFW, sc.norm_out],
      [['u32', 1], ['u32', H], ['f32', CONFIG.rms_eps]], 1);
    run('matmul_gemv', [sc.norm_out, lmHeadW, sc.logits], [['u32', V], ['u32', H]], V);
    // sample params: vocab size, inverse temperature, fresh RNG seed per call
    run('sample', [sc.logits, sc.token_out], [
      ['u32', V],
      ['f32', 1.0 / temperature],
      ['u32', Math.floor(Math.random() * 0xFFFFFFFF)],
    ], 1);

    this.device.queue.submit([encoder.finish()]);
  }

  // ── Tokenize/detokenize via server ──────────────────────────────────────
  /**
   * POST a JSON payload to the tokenizer server and return its `result` field.
   * @throws {Error} when the server answers with a non-2xx status.
   */
  async _postJSON(path, payload) {
    const resp = await fetch(path, {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify(payload),
    });
    if (!resp.ok) throw new Error(`${path} failed: HTTP ${resp.status} ${resp.statusText}`);
    const data = await resp.json();
    return data.result;
  }

  /** @returns {Promise<number[]>} token ids for `text` (with optional system prompt) from POST /tokenize. */
  async tokenize(text, system = '') {
    return this._postJSON('/tokenize', { text, system });
  }

  /** @returns {Promise<string>} text for `tokens` from POST /detokenize. */
  async detokenize(tokens) {
    return this._postJSON('/detokenize', { tokens });
  }

  // ── Generate text ──────────────────────────────────────────────────────
  /**
   * Encode `prompt` into the persistent SSM state, then sample up to
   * `maxTokens` tokens. Generation stops early after a stop token.
   * @param {string} prompt
   * @param {number} [maxTokens=100] - maximum number of generated tokens.
   * @param {number} [temperature=0.75] - sampling temperature (> 0).
   * @param {?function(number, number): void} [onToken] - called with (token, step) for each token.
   * @param {string} [system=''] - optional system prompt passed to the tokenizer.
   * @returns {Promise<string>} detokenized text of the generated tokens.
   * @throws {Error} if weights are not loaded or the prompt yields no tokens.
   */
  async generate(prompt, maxTokens = 100, temperature = 0.75, onToken = null, system = '') {
    if (!this.ready) throw new Error('Call loadWeights() first');
    this._allocateScratch();
    console.log('[mamba] generate:', prompt, 'max_tokens:', maxTokens);
    // Ids carried over from the original loop: EOS=11, im_end=10, PAD=0.
    // NOTE(review): confirm against the tokenizer behind /tokenize.
    const stopTokens = new Set([11, 10, 0]);

    const promptTokens = await this.tokenize(prompt, system);
    if (!Array.isArray(promptTokens) || promptTokens.length === 0) {
      throw new Error('Prompt produced no tokens');
    }
    console.log(`[mamba] prompt tokens (${promptTokens.length}):`, promptTokens);

    // Run the prompt through the model to build the SSM state; the sample
    // produced after the last prompt token is the first generated token.
    console.log('[mamba] encoding prompt...');
    let token;
    for (let i = 0; i < promptTokens.length; i++) {
      const t0 = performance.now();
      token = await this._forwardOneToken(promptTokens[i], temperature);
      const elapsed = performance.now() - t0;
      if (i === 0 || i === promptTokens.length - 1) {
        console.log(`[mamba] prompt token ${i}/${promptTokens.length}: ${promptTokens[i]} (${elapsed.toFixed(0)}ms)`);
      }
    }
    console.log('[mamba] prompt encoded, generating...');

    const generated = [];
    for (let step = 0; step < maxTokens; step++) {
      try {
        if (step === 0) {
          console.log(`[mamba] first generated token: ${token}`);
        } else {
          const t0 = performance.now();
          token = await this._forwardOneToken(token, temperature);
          const elapsed = performance.now() - t0;
          if (step < 5 || step % 20 === 0) {
            console.log(`[mamba] step ${step}: token=${token} (${elapsed.toFixed(0)}ms)`);
          }
        }
        generated.push(token);
        if (onToken) onToken(token, step);
      } catch (e) {
        // Best effort: keep whatever was generated before the failure.
        console.error(`[mamba] step ${step} failed:`, e.message);
        break;
      }
      if (stopTokens.has(token)) break;
    }

    const text = await this.detokenize(generated);
    console.log(`[mamba] generated ${generated.length} tokens`);
    return text;
  }

  // ── Helper: get weight buffer by name ──────────────────────────────────
  _getWeight(name) {
    return this.weights[name] || null;
  }

  // ── Create a uniform buffer with typed data ────────────────────────────
  /**
   * Create a UNIFORM buffer initialised with `data`.
   * @param {ArrayBuffer|ArrayBufferView} data - only the view's own bytes are copied.
   * @returns {GPUBuffer} buffer sized to the data rounded up to 4 bytes, which
   *   mappedAtCreation requires.
   */
  _createUniform(data) {
    const src = ArrayBuffer.isView(data)
      ? new Uint8Array(data.buffer, data.byteOffset, data.byteLength)
      : new Uint8Array(data);
    const size = Math.max(4, Math.ceil(src.byteLength / 4) * 4);
    const buf = this.device.createBuffer({
      size,
      usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
      mappedAtCreation: true,
    });
    new Uint8Array(buf.getMappedRange()).set(src);
    buf.unmap();
    return buf;
  }

  /**
   * Pack kernel parameters as consecutive little-endian 4-byte fields.
   * @param {Array<['u32'|'f32', number]>} fields - (type, value) pairs in struct order.
   * @returns {Uint8Array}
   */
  static _packParams(fields) {
    const bytes = new ArrayBuffer(fields.length * 4);
    const view = new DataView(bytes);
    fields.forEach(([type, value], i) => {
      if (type === 'f32') view.setFloat32(i * 4, value, true);
      else view.setUint32(i * 4, value, true);
    });
    return new Uint8Array(bytes);
  }

  /**
   * Record one 1-D compute pass on `encoder`. `buffers[i]` is bound at
   * group 0 / binding i, and the packed `params` at group 1 / binding 0.
   * @param {GPUCommandEncoder} encoder
   * @param {string} shaderName - key into this.pipelines.
   * @param {GPUBuffer[]} buffers - group-0 buffers in binding order.
   * @param {Array<['u32'|'f32', number]>} params - uniform fields (see _packParams).
   * @param {number} workgroups - workgroup count along X.
   * @param {GPUBuffer[]} uniforms - receives the created uniform buffer; the
   *   caller destroys it after the GPU work completes.
   */
  _encodeKernel(encoder, shaderName, buffers, params, workgroups, uniforms) {
    const pipeline = this.pipelines[shaderName];
    const uniform = this._createUniform(MambaRuntime._packParams(params));
    uniforms.push(uniform);
    const pass = encoder.beginComputePass();
    pass.setPipeline(pipeline);
    pass.setBindGroup(0, this.device.createBindGroup({
      layout: pipeline.getBindGroupLayout(0),
      entries: buffers.map((buffer, binding) => ({ binding, resource: { buffer } })),
    }));
    pass.setBindGroup(1, this.device.createBindGroup({
      layout: pipeline.getBindGroupLayout(1),
      entries: [{ binding: 0, resource: { buffer: uniform } }],
    }));
    pass.dispatchWorkgroups(workgroups);
    pass.end();
  }

  // ── BF16 → F32 conversion for a weight tensor ─────────────────────────
  // GPU-side widening via the bf16_to_f32 kernel. loadWeights() already
  // widens on the CPU and records dtype 'F32', so this returns early for
  // tensors loaded that way.
  async _convertBF16toF32(weightInfo) {
    if (weightInfo.dtype !== 'BF16' || weightInfo.f32buffer) return weightInfo;
    const numBF16 = weightInfo.byteLen / 2; // each bf16 is 2 bytes
    const numPairs = weightInfo.byteLen / 4; // each u32 holds 2 bf16
    const f32Bytes = numBF16 * 4;
    const f32Buf = this._createBuffer(
      f32Bytes,
      GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
      weightInfo.buffer.label + '_f32',
    );
    const paramBuf = this._createUniform(new Uint32Array([numPairs]));
    const pipeline = this.pipelines['bf16_to_f32'];
    const bg0 = this.device.createBindGroup({
      layout: pipeline.getBindGroupLayout(0),
      entries: [
        { binding: 0, resource: { buffer: weightInfo.buffer } },
        { binding: 1, resource: { buffer: f32Buf } },
      ],
    });
    const bg1 = this.device.createBindGroup({
      layout: pipeline.getBindGroupLayout(1),
      entries: [{ binding: 0, resource: { buffer: paramBuf } }],
    });
    const encoder = this.device.createCommandEncoder();
    const pass = encoder.beginComputePass();
    pass.setPipeline(pipeline);
    pass.setBindGroup(0, bg0);
    pass.setBindGroup(1, bg1);
    pass.dispatchWorkgroups(Math.ceil(numPairs / 64));
    pass.end();
    this.device.queue.submit([encoder.finish()]);
    await this.device.queue.onSubmittedWorkDone();
    // Cache the F32 buffer on the weight record.
    weightInfo.f32buffer = f32Buf;
    weightInfo.f32size = f32Bytes;
    paramBuf.destroy();
    return weightInfo;
  }

  // ── Get F32 weight buffer (already converted during load) ───────────────
  /** @throws {Error} if `name` was not loaded (missing from the checkpoint or skipped as too large). */
  async _getF32Weight(name) {
    const w = this.weights[name];
    if (!w) throw new Error(`Missing weight: ${name}`);
    return w.buffer;
  }

  // ── Dispatch a shader with auto bind group creation ─────────────────────
  // Begins a pass with bind groups set; the caller dispatches and ends it.
  // Not used by the forward pass (see _encodeKernel). The uniform buffer it
  // creates is never destroyed.
  _dispatchShader(encoder, shaderName, storageBuffers, uniformData) {
    const pipeline = this.pipelines[shaderName];
    const pass = encoder.beginComputePass();
    pass.setPipeline(pipeline);
    const bg0 = this.device.createBindGroup({
      layout: pipeline.getBindGroupLayout(0),
      entries: storageBuffers.map((buf, i) => ({ binding: i, resource: { buffer: buf } })),
    });
    pass.setBindGroup(0, bg0);
    if (uniformData) {
      const ubuf = this._createUniform(uniformData);
      const bg1 = this.device.createBindGroup({
        layout: pipeline.getBindGroupLayout(1),
        entries: [{ binding: 0, resource: { buffer: ubuf } }],
      });
      pass.setBindGroup(1, bg1);
    }
    return pass;
  }

  // ── Cleanup ────────────────────────────────────────────────────────────
  /** Destroy all weight, state and scratch buffers and mark the runtime not ready. */
  destroy() {
    for (const w of Object.values(this.weights)) {
      w.buffer.destroy();
      if (w.f32buffer) w.f32buffer.destroy();
    }
    for (const s of Object.values(this.state)) s.destroy();
    if (this.scratch) {
      for (const b of Object.values(this.scratch)) b.destroy();
    }
    this.weights = {};
    this.state = {};
    this.scratch = null; // so _allocateScratch() rebuilds instead of reusing destroyed buffers
    this.ready = false;
  }
}
// ES module export for browser <script type="module">
export { MambaRuntime, CONFIG };