Upload mamba_runtime.js with huggingface_hub
Browse files- mamba_runtime.js +954 -0
mamba_runtime.js
ADDED
|
@@ -0,0 +1,954 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* mamba_runtime.js β Browser-native Falcon-Mamba inference via WebGPU.
|
| 3 |
+
*
|
| 4 |
+
* The first browser-native Mamba/SSM inference engine.
|
| 5 |
+
* No MLC, no TVM β pure WebGPU compute shaders ported from gfx1151_runtime.
|
| 6 |
+
*
|
| 7 |
+
* Architecture: Falcon-Mamba 7B
|
| 8 |
+
* 64 layers, each: RMSNorm β in_proj β conv1d β SSU β out_proj
|
| 9 |
+
* Final: RMSNorm β lm_head β sample
|
| 10 |
+
*
|
| 11 |
+
* Weight format: safetensors (HF standard), loaded directly into WebGPU buffers.
|
| 12 |
+
* Shaders: WGSL compute shaders in ./shaders/ (ported from Vulkan GLSL).
|
| 13 |
+
*
|
| 14 |
+
* Usage:
|
| 15 |
+
* const mamba = new MambaRuntime();
|
| 16 |
+
* await mamba.init();
|
| 17 |
+
* await mamba.loadWeights('./weights/');
|
| 18 |
+
* const text = await mamba.generate("Hello Grandma", 100);
|
| 19 |
+
*/
|
| 20 |
+
|
| 21 |
+
// Falcon-Mamba 7B constants
|
| 22 |
+
const CONFIG = {
|
| 23 |
+
hidden_size: 4096,
|
| 24 |
+
intermediate_size: 8192, // 2 * hidden
|
| 25 |
+
num_layers: 64,
|
| 26 |
+
vocab_size: 65024,
|
| 27 |
+
state_size: 16, // SSM d_state
|
| 28 |
+
conv_kernel: 4,
|
| 29 |
+
dt_rank: 256,
|
| 30 |
+
rms_eps: 1e-5,
|
| 31 |
+
};
|
| 32 |
+
|
| 33 |
+
class MambaRuntime {
|
| 34 |
+
constructor() {
|
| 35 |
+
this.device = null;
|
| 36 |
+
this.pipelines = {}; // shader name β GPUComputePipeline
|
| 37 |
+
this.bindLayouts = {}; // shader name β GPUBindGroupLayout
|
| 38 |
+
this.weights = {}; // parameter name β GPUBuffer
|
| 39 |
+
this.state = {}; // per-layer SSM state + conv1d state buffers
|
| 40 |
+
this.ready = false;
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
// ββ Init: get WebGPU device + compile all shaders ββββββββββββββββββββββ
|
| 44 |
+
async init() {
|
| 45 |
+
if (!navigator.gpu) throw new Error('WebGPU not supported in this browser');
|
| 46 |
+
const adapter = await navigator.gpu.requestAdapter();
|
| 47 |
+
if (!adapter) throw new Error('No WebGPU adapter found');
|
| 48 |
+
|
| 49 |
+
// Request max buffer size the device supports
|
| 50 |
+
const limits = adapter.limits;
|
| 51 |
+
console.log('[mamba] maxBufferSize:', limits.maxBufferSize,
|
| 52 |
+
'=', (limits.maxBufferSize / 1024 / 1024 / 1024).toFixed(2), 'GB');
|
| 53 |
+
|
| 54 |
+
this.device = await adapter.requestDevice({
|
| 55 |
+
requiredLimits: {
|
| 56 |
+
maxBufferSize: limits.maxBufferSize,
|
| 57 |
+
maxStorageBufferBindingSize: limits.maxStorageBufferBindingSize,
|
| 58 |
+
maxComputeWorkgroupStorageSize: limits.maxComputeWorkgroupStorageSize,
|
| 59 |
+
maxStorageBuffersPerShaderStage: Math.min(limits.maxStorageBuffersPerShaderStage, 16),
|
| 60 |
+
}
|
| 61 |
+
});
|
| 62 |
+
|
| 63 |
+
this.device.lost.then((info) => {
|
| 64 |
+
console.error('[mamba] DEVICE LOST:', info.reason, info.message);
|
| 65 |
+
});
|
| 66 |
+
this.device.addEventListener('uncapturederror', (e) => {
|
| 67 |
+
console.error('[mamba] GPU ERROR:', e.error.message);
|
| 68 |
+
});
|
| 69 |
+
|
| 70 |
+
console.log('[mamba] device ready, compiling shaders...');
|
| 71 |
+
await this._compileShaders();
|
| 72 |
+
console.log('[mamba] shaders compiled');
|
| 73 |
+
return this;
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
// ββ Compile all WGSL shaders into compute pipelines ββββββββββββββββββββ
|
| 77 |
+
async _compileShaders() {
|
| 78 |
+
const shaderNames = [
|
| 79 |
+
'conv1d_step', 'ssu', 'matmul_gemv', 'rmsnorm', 'rmsnorm_noweight',
|
| 80 |
+
'silu', 'softplus', 'embedding', 'elementwise_mul', 'sample',
|
| 81 |
+
'bf16_to_f32', 'add_residual'
|
| 82 |
+
];
|
| 83 |
+
|
| 84 |
+
for (const name of shaderNames) {
|
| 85 |
+
const resp = await fetch(`./shaders/${name}.wgsl`);
|
| 86 |
+
if (!resp.ok) throw new Error(`Failed to load shader: ${name}.wgsl`);
|
| 87 |
+
const code = await resp.text();
|
| 88 |
+
|
| 89 |
+
const shaderModule = this.device.createShaderModule({ code, label: name });
|
| 90 |
+
|
| 91 |
+
// Create bind group layouts based on shader requirements
|
| 92 |
+
// Group 0 = storage buffers (data), Group 1 = uniforms (params)
|
| 93 |
+
const pipeline = this.device.createComputePipeline({
|
| 94 |
+
layout: 'auto',
|
| 95 |
+
compute: { module: shaderModule, entryPoint: 'main' },
|
| 96 |
+
label: name,
|
| 97 |
+
});
|
| 98 |
+
|
| 99 |
+
this.pipelines[name] = pipeline;
|
| 100 |
+
}
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
// ββ Create a GPU buffer ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 104 |
+
_createBuffer(size, usage, label) {
|
| 105 |
+
return this.device.createBuffer({
|
| 106 |
+
size: Math.max(size, 4), // WebGPU requires min 4 bytes
|
| 107 |
+
usage,
|
| 108 |
+
label,
|
| 109 |
+
mappedAtCreation: false,
|
| 110 |
+
});
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
// ββ Upload data to a GPU buffer ββββββββββββββββββββββββββββββββββββββββ
|
| 114 |
+
_upload(buffer, data) {
|
| 115 |
+
this.device.queue.writeBuffer(buffer, 0, data);
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
// ββ Read data back from GPU buffer βββββββββββββββββββββββββββββββββββββ
|
| 119 |
+
async _readback(buffer, size) {
|
| 120 |
+
const staging = this.device.createBuffer({
|
| 121 |
+
size,
|
| 122 |
+
usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
|
| 123 |
+
});
|
| 124 |
+
const encoder = this.device.createCommandEncoder();
|
| 125 |
+
encoder.copyBufferToBuffer(buffer, 0, staging, 0, size);
|
| 126 |
+
this.device.queue.submit([encoder.finish()]);
|
| 127 |
+
await staging.mapAsync(GPUMapMode.READ);
|
| 128 |
+
const result = new Float32Array(staging.getMappedRange().slice(0));
|
| 129 |
+
staging.unmap();
|
| 130 |
+
staging.destroy();
|
| 131 |
+
return result;
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
// ββ Dispatch a compute shader ββββββββββββββββββββββββββββββββββββββββββ
|
| 135 |
+
_dispatch(shaderName, bindGroup, uniformBindGroup, workgroupsX, workgroupsY = 1, workgroupsZ = 1) {
|
| 136 |
+
const encoder = this.device.createCommandEncoder();
|
| 137 |
+
const pass = encoder.beginComputePass();
|
| 138 |
+
pass.setPipeline(this.pipelines[shaderName]);
|
| 139 |
+
pass.setBindGroup(0, bindGroup);
|
| 140 |
+
if (uniformBindGroup) pass.setBindGroup(1, uniformBindGroup);
|
| 141 |
+
pass.dispatchWorkgroups(workgroupsX, workgroupsY, workgroupsZ);
|
| 142 |
+
pass.end();
|
| 143 |
+
this.device.queue.submit([encoder.finish()]);
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
// ββ Load safetensors weights into GPU buffers ββββββββββββββββββββββββββ
|
| 147 |
+
async loadWeights(basePath) {
|
| 148 |
+
console.log('[mamba] loading weights from', basePath);
|
| 149 |
+
|
| 150 |
+
// Get the shard index
|
| 151 |
+
const indexResp = await fetch(`${basePath}/model.safetensors.index.json`);
|
| 152 |
+
let fileMap; // tensor_name β filename
|
| 153 |
+
let files;
|
| 154 |
+
if (indexResp.ok) {
|
| 155 |
+
const index = await indexResp.json();
|
| 156 |
+
fileMap = index.weight_map;
|
| 157 |
+
files = [...new Set(Object.values(fileMap))];
|
| 158 |
+
console.log(`[mamba] multi-shard: ${files.length} files, ${Object.keys(fileMap).length} tensors`);
|
| 159 |
+
} else {
|
| 160 |
+
files = ['model.safetensors'];
|
| 161 |
+
fileMap = null;
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
// For each shard, fetch ONLY the header first (small), then load tensors by byte-range
|
| 165 |
+
for (const file of files) {
|
| 166 |
+
console.log(`[mamba] parsing ${file} header...`);
|
| 167 |
+
|
| 168 |
+
// Fetch first 8 bytes to get header length
|
| 169 |
+
const headResp = await fetch(`${basePath}/${file}`, {
|
| 170 |
+
headers: { 'Range': 'bytes=0-7' }
|
| 171 |
+
});
|
| 172 |
+
let headerLen;
|
| 173 |
+
if (headResp.status === 206) {
|
| 174 |
+
// Range request supported
|
| 175 |
+
const headBuf = await headResp.arrayBuffer();
|
| 176 |
+
headerLen = new DataView(headBuf).getUint32(0, true);
|
| 177 |
+
} else {
|
| 178 |
+
// Range not supported β fall back to full fetch but only read header
|
| 179 |
+
const fullBuf = await headResp.arrayBuffer();
|
| 180 |
+
headerLen = new DataView(fullBuf).getUint32(0, true);
|
| 181 |
+
}
|
| 182 |
+
console.log(`[mamba] header: ${headerLen} bytes`);
|
| 183 |
+
|
| 184 |
+
// Fetch header JSON
|
| 185 |
+
const hdrResp = await fetch(`${basePath}/${file}`, {
|
| 186 |
+
headers: { 'Range': `bytes=8-${8 + headerLen - 1}` }
|
| 187 |
+
});
|
| 188 |
+
let headerStr;
|
| 189 |
+
if (hdrResp.status === 206) {
|
| 190 |
+
headerStr = await hdrResp.text();
|
| 191 |
+
} else {
|
| 192 |
+
const fullBuf = await hdrResp.arrayBuffer();
|
| 193 |
+
headerStr = new TextDecoder().decode(new Uint8Array(fullBuf, 8, headerLen));
|
| 194 |
+
}
|
| 195 |
+
const header = JSON.parse(headerStr);
|
| 196 |
+
const dataOffset = 8 + headerLen;
|
| 197 |
+
|
| 198 |
+
// Load each tensor individually
|
| 199 |
+
const tensorNames = Object.keys(header).filter(n => n !== '__metadata__');
|
| 200 |
+
console.log(`[mamba] ${tensorNames.length} tensors in this shard`);
|
| 201 |
+
|
| 202 |
+
let loaded = 0;
|
| 203 |
+
for (const name of tensorNames) {
|
| 204 |
+
const meta = header[name];
|
| 205 |
+
const dtype = meta.dtype;
|
| 206 |
+
const shape = meta.shape;
|
| 207 |
+
const [start, end] = meta.data_offsets;
|
| 208 |
+
const byteLen = end - start;
|
| 209 |
+
|
| 210 |
+
if (byteLen > 2_000_000_000) {
|
| 211 |
+
console.log(`[mamba] SKIP ${name} (${(byteLen/1e9).toFixed(2)} GB β exceeds buffer limit)`);
|
| 212 |
+
continue;
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
// Fetch this tensor's bytes via Range request
|
| 216 |
+
const absStart = dataOffset + start;
|
| 217 |
+
const absEnd = dataOffset + end - 1;
|
| 218 |
+
const tResp = await fetch(`${basePath}/${file}`, {
|
| 219 |
+
headers: { 'Range': `bytes=${absStart}-${absEnd}` }
|
| 220 |
+
});
|
| 221 |
+
|
| 222 |
+
let tensorBuf;
|
| 223 |
+
if (tResp.status === 206) {
|
| 224 |
+
tensorBuf = await tResp.arrayBuffer();
|
| 225 |
+
} else {
|
| 226 |
+
// No range support β need full file (expensive)
|
| 227 |
+
console.log(`[mamba] WARN: no range support, loading full file for ${name}`);
|
| 228 |
+
const fullBuf = await tResp.arrayBuffer();
|
| 229 |
+
tensorBuf = fullBuf.slice(absStart, absStart + byteLen);
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
// For BF16 weights: convert to F32 during upload (no double-buffering)
|
| 233 |
+
let gpuBuf;
|
| 234 |
+
let finalDtype = dtype;
|
| 235 |
+
let finalByteLen = byteLen;
|
| 236 |
+
|
| 237 |
+
if (dtype === 'BF16') {
|
| 238 |
+
// Convert CPU-side: BF16 β F32 before uploading
|
| 239 |
+
const bf16 = new Uint16Array(tensorBuf);
|
| 240 |
+
const f32 = new Float32Array(bf16.length);
|
| 241 |
+
const tmpU32 = new Uint32Array(f32.buffer);
|
| 242 |
+
for (let j = 0; j < bf16.length; j++) {
|
| 243 |
+
tmpU32[j] = bf16[j] << 16; // BF16 is top 16 bits of F32
|
| 244 |
+
}
|
| 245 |
+
finalByteLen = f32.byteLength;
|
| 246 |
+
gpuBuf = this._createBuffer(
|
| 247 |
+
finalByteLen,
|
| 248 |
+
GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
|
| 249 |
+
name
|
| 250 |
+
);
|
| 251 |
+
this._upload(gpuBuf, f32);
|
| 252 |
+
finalDtype = 'F32';
|
| 253 |
+
} else {
|
| 254 |
+
gpuBuf = this._createBuffer(
|
| 255 |
+
byteLen,
|
| 256 |
+
GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
|
| 257 |
+
name
|
| 258 |
+
);
|
| 259 |
+
this._upload(gpuBuf, new Uint8Array(tensorBuf));
|
| 260 |
+
}
|
| 261 |
+
this.weights[name] = { buffer: gpuBuf, shape, dtype: finalDtype, byteLen: finalByteLen };
|
| 262 |
+
|
| 263 |
+
loaded++;
|
| 264 |
+
if (loaded % 20 === 0) {
|
| 265 |
+
console.log(`[mamba] loaded ${loaded}/${tensorNames.length} tensors`);
|
| 266 |
+
}
|
| 267 |
+
}
|
| 268 |
+
console.log(`[mamba] shard done: ${loaded} tensors loaded`);
|
| 269 |
+
}
|
| 270 |
+
|
| 271 |
+
console.log(`[mamba] TOTAL: ${Object.keys(this.weights).length} tensors loaded`);
|
| 272 |
+
|
| 273 |
+
// Allocate per-layer state buffers
|
| 274 |
+
this._allocateState();
|
| 275 |
+
this.ready = true;
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
// ββ Allocate persistent SSM state + conv1d cache per layer βββββββββββββ
|
| 279 |
+
_allocateState() {
|
| 280 |
+
const H = CONFIG.intermediate_size; // 8192
|
| 281 |
+
const S = CONFIG.state_size; // 16
|
| 282 |
+
const K = CONFIG.conv_kernel; // 4
|
| 283 |
+
|
| 284 |
+
for (let l = 0; l < CONFIG.num_layers; l++) {
|
| 285 |
+
// SSM state: [H, S] = 8192 * 16 = 131072 floats = 512 KB per layer
|
| 286 |
+
this.state[`layer.${l}.ssm`] = this._createBuffer(
|
| 287 |
+
H * S * 4,
|
| 288 |
+
GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
|
| 289 |
+
`ssm_state_${l}`
|
| 290 |
+
);
|
| 291 |
+
|
| 292 |
+
// Conv1d cache: [H, K-1] = 8192 * 3 = 24576 floats = 96 KB per layer
|
| 293 |
+
this.state[`layer.${l}.conv`] = this._createBuffer(
|
| 294 |
+
H * (K - 1) * 4,
|
| 295 |
+
GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
|
| 296 |
+
`conv_state_${l}`
|
| 297 |
+
);
|
| 298 |
+
}
|
| 299 |
+
|
| 300 |
+
// Total state: 64 layers Γ (512 + 96) KB = ~38 MB
|
| 301 |
+
console.log(`[mamba] allocated ${CONFIG.num_layers} layers of SSM + conv1d state (~38 MB)`);
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
// ββ Save/restore SSM state (the entity's persistent soul) ββββββββββββββ
|
| 305 |
+
async saveState() {
|
| 306 |
+
const state = {};
|
| 307 |
+
for (const [key, buf] of Object.entries(this.state)) {
|
| 308 |
+
state[key] = await this._readback(buf, buf.size);
|
| 309 |
+
}
|
| 310 |
+
return state;
|
| 311 |
+
}
|
| 312 |
+
|
| 313 |
+
async restoreState(state) {
|
| 314 |
+
for (const [key, data] of Object.entries(state)) {
|
| 315 |
+
if (this.state[key]) {
|
| 316 |
+
this._upload(this.state[key], data);
|
| 317 |
+
}
|
| 318 |
+
}
|
| 319 |
+
}
|
| 320 |
+
|
| 321 |
+
// ββ Allocate intermediate scratch buffers for forward pass ββββββββββββββ
|
| 322 |
+
_allocateScratch() {
|
| 323 |
+
if (this.scratch) return; // already allocated
|
| 324 |
+
const H = CONFIG.hidden_size; // 4096
|
| 325 |
+
const I = CONFIG.intermediate_size; // 8192
|
| 326 |
+
const DR = CONFIG.dt_rank; // 256
|
| 327 |
+
const S = CONFIG.state_size; // 16
|
| 328 |
+
const F = 4; // sizeof(float32)
|
| 329 |
+
|
| 330 |
+
this.scratch = {
|
| 331 |
+
norm_out: this._createBuffer(H * F, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, 'norm_out'),
|
| 332 |
+
projected: this._createBuffer(2 * I * F, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, 'projected'),
|
| 333 |
+
hidden: this._createBuffer(I * F, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, 'hidden'),
|
| 334 |
+
gate: this._createBuffer(I * F, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, 'gate'),
|
| 335 |
+
hidden_c: this._createBuffer(I * F, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, 'hidden_c'),
|
| 336 |
+
sxBC: this._createBuffer((DR + 2*S)*F, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, 'sxBC'),
|
| 337 |
+
B_proj: this._createBuffer(S * F, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, 'B_proj'),
|
| 338 |
+
C_proj: this._createBuffer(S * F, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, 'C_proj'),
|
| 339 |
+
dt_pre: this._createBuffer(DR * F, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, 'dt_pre'),
|
| 340 |
+
dt: this._createBuffer(I * F, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, 'dt'),
|
| 341 |
+
hidden_y: this._createBuffer(I * F, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, 'hidden_y'),
|
| 342 |
+
gate_silu: this._createBuffer(I * F, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, 'gate_silu'),
|
| 343 |
+
out_proj_o: this._createBuffer(H * F, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, 'out_proj_o'),
|
| 344 |
+
logits: this._createBuffer(CONFIG.vocab_size * F, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, 'logits'),
|
| 345 |
+
token_out: this._createBuffer(4, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, 'token_out'),
|
| 346 |
+
hidden_state: this._createBuffer(H * F, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, 'hidden_state'),
|
| 347 |
+
token_id: this._createBuffer(4, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, 'token_id'),
|
| 348 |
+
};
|
| 349 |
+
this._tokenCount = 0;
|
| 350 |
+
console.log('[mamba] scratch buffers allocated');
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
// ββ Single-token forward pass through all 64 layers βββββββββββββββββββββ
|
| 354 |
+
async _forwardOneToken(tokenId) {
|
| 355 |
+
const H = CONFIG.hidden_size; // 4096
|
| 356 |
+
const I = CONFIG.intermediate_size; // 8192
|
| 357 |
+
const DR = CONFIG.dt_rank; // 256
|
| 358 |
+
const S = CONFIG.state_size; // 16
|
| 359 |
+
const V = CONFIG.vocab_size; // 65024
|
| 360 |
+
|
| 361 |
+
// Step 1: Embedding lookup β copy one row from embedding table to hidden_state
|
| 362 |
+
this._upload(this.scratch.token_id, new Uint32Array([tokenId]));
|
| 363 |
+
const embBuf = await this._getF32Weight('backbone.embeddings.weight');
|
| 364 |
+
const encoder1 = this.device.createCommandEncoder();
|
| 365 |
+
encoder1.copyBufferToBuffer(embBuf, tokenId * H * 4, this.scratch.hidden_state, 0, H * 4);
|
| 366 |
+
this.device.queue.submit([encoder1.finish()]);
|
| 367 |
+
|
| 368 |
+
// Step 2: For each layer (0..63)
|
| 369 |
+
for (let l = 0; l < CONFIG.num_layers; l++) {
|
| 370 |
+
const prefix = `backbone.layers.${l}`;
|
| 371 |
+
|
| 372 |
+
// rmsnorm(hidden_state, norm.weight) β norm_out
|
| 373 |
+
const normW = await this._getF32Weight(`${prefix}.norm.weight`);
|
| 374 |
+
let encoder = this.device.createCommandEncoder();
|
| 375 |
+
let pass = encoder.beginComputePass();
|
| 376 |
+
pass.setPipeline(this.pipelines['rmsnorm']);
|
| 377 |
+
pass.setBindGroup(0, this.device.createBindGroup({
|
| 378 |
+
layout: this.pipelines['rmsnorm'].getBindGroupLayout(0),
|
| 379 |
+
entries: [
|
| 380 |
+
{ binding: 0, resource: { buffer: this.scratch.hidden_state } },
|
| 381 |
+
{ binding: 1, resource: { buffer: normW } },
|
| 382 |
+
{ binding: 2, resource: { buffer: this.scratch.norm_out } },
|
| 383 |
+
],
|
| 384 |
+
}));
|
| 385 |
+
const rmsnormParams = new ArrayBuffer(12);
|
| 386 |
+
new DataView(rmsnormParams).setUint32(0, 1, true);
|
| 387 |
+
new DataView(rmsnormParams).setUint32(4, H, true);
|
| 388 |
+
new DataView(rmsnormParams).setFloat32(8, CONFIG.rms_eps, true);
|
| 389 |
+
pass.setBindGroup(1, this.device.createBindGroup({
|
| 390 |
+
layout: this.pipelines['rmsnorm'].getBindGroupLayout(1),
|
| 391 |
+
entries: [{ binding: 0, resource: { buffer: this._createUniform(new Uint8Array(rmsnormParams)) }}],
|
| 392 |
+
}));
|
| 393 |
+
pass.dispatchWorkgroups(1); // one workgroup per row, 1 row
|
| 394 |
+
pass.end();
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
// matmul_gemv(norm_out, in_proj.weight) β projected [I*2 = 16384]
|
| 399 |
+
const inProjW = await this._getF32Weight(`${prefix}.mixer.in_proj.weight`);
|
| 400 |
+
pass = encoder.beginComputePass();
|
| 401 |
+
pass.setPipeline(this.pipelines['matmul_gemv']);
|
| 402 |
+
pass.setBindGroup(0, this.device.createBindGroup({
|
| 403 |
+
layout: this.pipelines['matmul_gemv'].getBindGroupLayout(0),
|
| 404 |
+
entries: [
|
| 405 |
+
{ binding: 0, resource: { buffer: this.scratch.norm_out } },
|
| 406 |
+
{ binding: 1, resource: { buffer: inProjW } },
|
| 407 |
+
{ binding: 2, resource: { buffer: this.scratch.projected } },
|
| 408 |
+
],
|
| 409 |
+
}));
|
| 410 |
+
const gemvParams1 = new ArrayBuffer(8);
|
| 411 |
+
new DataView(gemvParams1).setUint32(0, I * 2, true); // N
|
| 412 |
+
new DataView(gemvParams1).setUint32(4, H, true); // K
|
| 413 |
+
pass.setBindGroup(1, this.device.createBindGroup({
|
| 414 |
+
layout: this.pipelines['matmul_gemv'].getBindGroupLayout(1),
|
| 415 |
+
entries: [{ binding: 0, resource: { buffer: this._createUniform(new Uint8Array(gemvParams1)) }}],
|
| 416 |
+
}));
|
| 417 |
+
pass.dispatchWorkgroups(I * 2); // one workgroup per output element
|
| 418 |
+
pass.end();
|
| 419 |
+
|
| 420 |
+
// Split projected β hidden[0:I], gate[I:2I] via buffer copies
|
| 421 |
+
encoder.copyBufferToBuffer(this.scratch.projected, 0, this.scratch.hidden, 0, I * 4);
|
| 422 |
+
encoder.copyBufferToBuffer(this.scratch.projected, I * 4, this.scratch.gate, 0, I * 4);
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
// conv1d_step(conv_state, hidden, conv1d.weight, conv1d.bias) β hidden_c
|
| 427 |
+
const conv1dW = await this._getF32Weight(`${prefix}.mixer.conv1d.weight`);
|
| 428 |
+
const conv1dB = await this._getF32Weight(`${prefix}.mixer.conv1d.bias`);
|
| 429 |
+
pass = encoder.beginComputePass();
|
| 430 |
+
pass.setPipeline(this.pipelines['conv1d_step']);
|
| 431 |
+
pass.setBindGroup(0, this.device.createBindGroup({
|
| 432 |
+
layout: this.pipelines['conv1d_step'].getBindGroupLayout(0),
|
| 433 |
+
entries: [
|
| 434 |
+
{ binding: 0, resource: { buffer: this.state[`layer.${l}.conv`] } },
|
| 435 |
+
{ binding: 1, resource: { buffer: this.scratch.hidden } },
|
| 436 |
+
{ binding: 2, resource: { buffer: conv1dW } },
|
| 437 |
+
{ binding: 3, resource: { buffer: conv1dB } },
|
| 438 |
+
{ binding: 4, resource: { buffer: this.scratch.hidden_c } },
|
| 439 |
+
],
|
| 440 |
+
}));
|
| 441 |
+
const conv1dParams = new ArrayBuffer(4);
|
| 442 |
+
new DataView(conv1dParams).setUint32(0, I, true);
|
| 443 |
+
pass.setBindGroup(1, this.device.createBindGroup({
|
| 444 |
+
layout: this.pipelines['conv1d_step'].getBindGroupLayout(1),
|
| 445 |
+
entries: [{ binding: 0, resource: { buffer: this._createUniform(new Uint8Array(conv1dParams)) }}],
|
| 446 |
+
}));
|
| 447 |
+
pass.dispatchWorkgroups(Math.ceil(I / 64));
|
| 448 |
+
pass.end();
|
| 449 |
+
|
| 450 |
+
// silu(hidden_c) in-place β hidden_a
|
| 451 |
+
pass = encoder.beginComputePass();
|
| 452 |
+
pass.setPipeline(this.pipelines['silu']);
|
| 453 |
+
pass.setBindGroup(0, this.device.createBindGroup({
|
| 454 |
+
layout: this.pipelines['silu'].getBindGroupLayout(0),
|
| 455 |
+
entries: [{ binding: 0, resource: { buffer: this.scratch.hidden_c } }],
|
| 456 |
+
}));
|
| 457 |
+
const siluParams = new ArrayBuffer(4);
|
| 458 |
+
new DataView(siluParams).setUint32(0, I, true);
|
| 459 |
+
pass.setBindGroup(1, this.device.createBindGroup({
|
| 460 |
+
layout: this.pipelines['silu'].getBindGroupLayout(1),
|
| 461 |
+
entries: [{ binding: 0, resource: { buffer: this._createUniform(new Uint8Array(siluParams)) }}],
|
| 462 |
+
}));
|
| 463 |
+
pass.dispatchWorkgroups(Math.ceil(I / 64));
|
| 464 |
+
pass.end();
|
| 465 |
+
// hidden_c is now silu'd (= hidden_a)
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
// matmul_gemv(hidden_c, x_proj.weight) β sxBC [DR+2*S = 288]
|
| 470 |
+
const xProjW = await this._getF32Weight(`${prefix}.mixer.x_proj.weight`);
|
| 471 |
+
pass = encoder.beginComputePass();
|
| 472 |
+
pass.setPipeline(this.pipelines['matmul_gemv']);
|
| 473 |
+
pass.setBindGroup(0, this.device.createBindGroup({
|
| 474 |
+
layout: this.pipelines['matmul_gemv'].getBindGroupLayout(0),
|
| 475 |
+
entries: [
|
| 476 |
+
{ binding: 0, resource: { buffer: this.scratch.hidden_c } },
|
| 477 |
+
{ binding: 1, resource: { buffer: xProjW } },
|
| 478 |
+
{ binding: 2, resource: { buffer: this.scratch.sxBC } },
|
| 479 |
+
],
|
| 480 |
+
}));
|
| 481 |
+
const gemvParams2 = new ArrayBuffer(8);
|
| 482 |
+
new DataView(gemvParams2).setUint32(0, DR + 2 * S, true);
|
| 483 |
+
new DataView(gemvParams2).setUint32(4, I, true);
|
| 484 |
+
pass.setBindGroup(1, this.device.createBindGroup({
|
| 485 |
+
layout: this.pipelines['matmul_gemv'].getBindGroupLayout(1),
|
| 486 |
+
entries: [{ binding: 0, resource: { buffer: this._createUniform(new Uint8Array(gemvParams2)) }}],
|
| 487 |
+
}));
|
| 488 |
+
pass.dispatchWorkgroups(DR + 2 * S);
|
| 489 |
+
pass.end();
|
| 490 |
+
|
| 491 |
+
// Copy dt_pre, B, C from sxBC into separate buffers
|
| 492 |
+
encoder.copyBufferToBuffer(this.scratch.sxBC, 0, this.scratch.dt_pre, 0, DR * 4);
|
| 493 |
+
encoder.copyBufferToBuffer(this.scratch.sxBC, DR * 4, this.scratch.B_proj, 0, S * 4);
|
| 494 |
+
encoder.copyBufferToBuffer(this.scratch.sxBC, (DR + S) * 4, this.scratch.C_proj, 0, S * 4);
|
| 495 |
+
|
| 496 |
+
// Falcon-Mamba: RMSNorm(dt_pre), RMSNorm(B), RMSNorm(C) before use
|
| 497 |
+
const rmsNwParams_dt = new ArrayBuffer(8);
|
| 498 |
+
new DataView(rmsNwParams_dt).setUint32(0, DR, true);
|
| 499 |
+
new DataView(rmsNwParams_dt).setFloat32(4, CONFIG.rms_eps, true);
|
| 500 |
+
pass = encoder.beginComputePass();
|
| 501 |
+
pass.setPipeline(this.pipelines['rmsnorm_noweight']);
|
| 502 |
+
pass.setBindGroup(0, this.device.createBindGroup({
|
| 503 |
+
layout: this.pipelines['rmsnorm_noweight'].getBindGroupLayout(0),
|
| 504 |
+
entries: [{ binding: 0, resource: { buffer: this.scratch.dt_pre } }],
|
| 505 |
+
}));
|
| 506 |
+
pass.setBindGroup(1, this.device.createBindGroup({
|
| 507 |
+
layout: this.pipelines['rmsnorm_noweight'].getBindGroupLayout(1),
|
| 508 |
+
entries: [{ binding: 0, resource: { buffer: this._createUniform(new Uint8Array(rmsNwParams_dt)) }}],
|
| 509 |
+
}));
|
| 510 |
+
pass.dispatchWorkgroups(1);
|
| 511 |
+
pass.end();
|
| 512 |
+
|
| 513 |
+
const rmsNwParams_s = new ArrayBuffer(8);
|
| 514 |
+
new DataView(rmsNwParams_s).setUint32(0, S, true);
|
| 515 |
+
new DataView(rmsNwParams_s).setFloat32(4, CONFIG.rms_eps, true);
|
| 516 |
+
pass = encoder.beginComputePass();
|
| 517 |
+
pass.setPipeline(this.pipelines['rmsnorm_noweight']);
|
| 518 |
+
pass.setBindGroup(0, this.device.createBindGroup({
|
| 519 |
+
layout: this.pipelines['rmsnorm_noweight'].getBindGroupLayout(0),
|
| 520 |
+
entries: [{ binding: 0, resource: { buffer: this.scratch.B_proj } }],
|
| 521 |
+
}));
|
| 522 |
+
pass.setBindGroup(1, this.device.createBindGroup({
|
| 523 |
+
layout: this.pipelines['rmsnorm_noweight'].getBindGroupLayout(1),
|
| 524 |
+
entries: [{ binding: 0, resource: { buffer: this._createUniform(new Uint8Array(rmsNwParams_s)) }}],
|
| 525 |
+
}));
|
| 526 |
+
pass.dispatchWorkgroups(1);
|
| 527 |
+
pass.end();
|
| 528 |
+
|
| 529 |
+
pass = encoder.beginComputePass();
|
| 530 |
+
pass.setPipeline(this.pipelines['rmsnorm_noweight']);
|
| 531 |
+
pass.setBindGroup(0, this.device.createBindGroup({
|
| 532 |
+
layout: this.pipelines['rmsnorm_noweight'].getBindGroupLayout(0),
|
| 533 |
+
entries: [{ binding: 0, resource: { buffer: this.scratch.C_proj } }],
|
| 534 |
+
}));
|
| 535 |
+
pass.setBindGroup(1, this.device.createBindGroup({
|
| 536 |
+
layout: this.pipelines['rmsnorm_noweight'].getBindGroupLayout(1),
|
| 537 |
+
entries: [{ binding: 0, resource: { buffer: this._createUniform(new Uint8Array(rmsNwParams_s)) }}],
|
| 538 |
+
}));
|
| 539 |
+
pass.dispatchWorkgroups(1);
|
| 540 |
+
pass.end();
|
| 541 |
+
|
| 542 |
+
// matmul_gemv(dt_pre_normalized, dt_proj.weight) β dt [I]
|
| 543 |
+
const dtProjW = await this._getF32Weight(`${prefix}.mixer.dt_proj.weight`);
|
| 544 |
+
pass = encoder.beginComputePass();
|
| 545 |
+
pass.setPipeline(this.pipelines['matmul_gemv']);
|
| 546 |
+
pass.setBindGroup(0, this.device.createBindGroup({
|
| 547 |
+
layout: this.pipelines['matmul_gemv'].getBindGroupLayout(0),
|
| 548 |
+
entries: [
|
| 549 |
+
{ binding: 0, resource: { buffer: this.scratch.dt_pre } },
|
| 550 |
+
{ binding: 1, resource: { buffer: dtProjW } },
|
| 551 |
+
{ binding: 2, resource: { buffer: this.scratch.dt } },
|
| 552 |
+
],
|
| 553 |
+
}));
|
| 554 |
+
const gemvParams3 = new ArrayBuffer(8);
|
| 555 |
+
new DataView(gemvParams3).setUint32(0, I, true);
|
| 556 |
+
new DataView(gemvParams3).setUint32(4, DR, true);
|
| 557 |
+
pass.setBindGroup(1, this.device.createBindGroup({
|
| 558 |
+
layout: this.pipelines['matmul_gemv'].getBindGroupLayout(1),
|
| 559 |
+
entries: [{ binding: 0, resource: { buffer: this._createUniform(new Uint8Array(gemvParams3)) }}],
|
| 560 |
+
}));
|
| 561 |
+
pass.dispatchWorkgroups(I);
|
| 562 |
+
pass.end();
|
| 563 |
+
|
| 564 |
+
// SSU: selective_state_update
|
| 565 |
+
// ssu(state, hidden_c, dt, A, B, C, D, dt_bias) β hidden_y
|
| 566 |
+
const aLog = await this._getF32Weight(`${prefix}.mixer.A_log`);
|
| 567 |
+
const dWeight = await this._getF32Weight(`${prefix}.mixer.D`);
|
| 568 |
+
const dtBias = await this._getF32Weight(`${prefix}.mixer.dt_proj.bias`);
|
| 569 |
+
pass = encoder.beginComputePass();
|
| 570 |
+
pass.setPipeline(this.pipelines['ssu']);
|
| 571 |
+
pass.setBindGroup(0, this.device.createBindGroup({
|
| 572 |
+
layout: this.pipelines['ssu'].getBindGroupLayout(0),
|
| 573 |
+
entries: [
|
| 574 |
+
{ binding: 0, resource: { buffer: this.state[`layer.${l}.ssm`] } },
|
| 575 |
+
{ binding: 1, resource: { buffer: this.scratch.hidden_c } }, // x (silu'd)
|
| 576 |
+
{ binding: 2, resource: { buffer: this.scratch.dt } },
|
| 577 |
+
{ binding: 3, resource: { buffer: aLog } }, // A (needs -exp transform)
|
| 578 |
+
{ binding: 4, resource: { buffer: this.scratch.B_proj } }, // B
|
| 579 |
+
{ binding: 5, resource: { buffer: this.scratch.C_proj } }, // C
|
| 580 |
+
{ binding: 6, resource: { buffer: dWeight } },
|
| 581 |
+
{ binding: 7, resource: { buffer: dtBias } },
|
| 582 |
+
{ binding: 8, resource: { buffer: this.scratch.hidden_y } },
|
| 583 |
+
],
|
| 584 |
+
}));
|
| 585 |
+
const ssuParams = new ArrayBuffer(8);
|
| 586 |
+
new DataView(ssuParams).setUint32(0, I, true); // H
|
| 587 |
+
new DataView(ssuParams).setUint32(4, S, true); // S
|
| 588 |
+
pass.setBindGroup(1, this.device.createBindGroup({
|
| 589 |
+
layout: this.pipelines['ssu'].getBindGroupLayout(1),
|
| 590 |
+
entries: [{ binding: 0, resource: { buffer: this._createUniform(new Uint8Array(ssuParams)) }}],
|
| 591 |
+
}));
|
| 592 |
+
pass.dispatchWorkgroups(I); // one workgroup per h
|
| 593 |
+
pass.end();
|
| 594 |
+
|
| 595 |
+
|
| 596 |
+
|
| 597 |
+
// silu(gate) in-place
|
| 598 |
+
pass = encoder.beginComputePass();
|
| 599 |
+
pass.setPipeline(this.pipelines['silu']);
|
| 600 |
+
pass.setBindGroup(0, this.device.createBindGroup({
|
| 601 |
+
layout: this.pipelines['silu'].getBindGroupLayout(0),
|
| 602 |
+
entries: [{ binding: 0, resource: { buffer: this.scratch.gate } }],
|
| 603 |
+
}));
|
| 604 |
+
pass.setBindGroup(1, this.device.createBindGroup({
|
| 605 |
+
layout: this.pipelines['silu'].getBindGroupLayout(1),
|
| 606 |
+
entries: [{ binding: 0, resource: { buffer: this._createUniform(new Uint8Array(siluParams)) }}],
|
| 607 |
+
}));
|
| 608 |
+
pass.dispatchWorkgroups(Math.ceil(I / 64));
|
| 609 |
+
pass.end();
|
| 610 |
+
|
| 611 |
+
// elementwise_mul: hidden_y *= gate (in-place into hidden_y)
|
| 612 |
+
pass = encoder.beginComputePass();
|
| 613 |
+
pass.setPipeline(this.pipelines['elementwise_mul']);
|
| 614 |
+
pass.setBindGroup(0, this.device.createBindGroup({
|
| 615 |
+
layout: this.pipelines['elementwise_mul'].getBindGroupLayout(0),
|
| 616 |
+
entries: [
|
| 617 |
+
{ binding: 0, resource: { buffer: this.scratch.hidden_y } },
|
| 618 |
+
{ binding: 1, resource: { buffer: this.scratch.gate } },
|
| 619 |
+
],
|
| 620 |
+
}));
|
| 621 |
+
pass.setBindGroup(1, this.device.createBindGroup({
|
| 622 |
+
layout: this.pipelines['elementwise_mul'].getBindGroupLayout(1),
|
| 623 |
+
entries: [{ binding: 0, resource: { buffer: this._createUniform(new Uint8Array(siluParams)) }}],
|
| 624 |
+
}));
|
| 625 |
+
pass.dispatchWorkgroups(Math.ceil(I / 64));
|
| 626 |
+
pass.end();
|
| 627 |
+
|
| 628 |
+
// matmul_gemv(hidden_y, out_proj.weight) β out_proj_o [H]
|
| 629 |
+
const outProjW = await this._getF32Weight(`${prefix}.mixer.out_proj.weight`);
|
| 630 |
+
pass = encoder.beginComputePass();
|
| 631 |
+
pass.setPipeline(this.pipelines['matmul_gemv']);
|
| 632 |
+
pass.setBindGroup(0, this.device.createBindGroup({
|
| 633 |
+
layout: this.pipelines['matmul_gemv'].getBindGroupLayout(0),
|
| 634 |
+
entries: [
|
| 635 |
+
{ binding: 0, resource: { buffer: this.scratch.hidden_y } },
|
| 636 |
+
{ binding: 1, resource: { buffer: outProjW } },
|
| 637 |
+
{ binding: 2, resource: { buffer: this.scratch.out_proj_o } },
|
| 638 |
+
],
|
| 639 |
+
}));
|
| 640 |
+
const gemvParams4 = new ArrayBuffer(8);
|
| 641 |
+
new DataView(gemvParams4).setUint32(0, H, true);
|
| 642 |
+
new DataView(gemvParams4).setUint32(4, I, true);
|
| 643 |
+
pass.setBindGroup(1, this.device.createBindGroup({
|
| 644 |
+
layout: this.pipelines['matmul_gemv'].getBindGroupLayout(1),
|
| 645 |
+
entries: [{ binding: 0, resource: { buffer: this._createUniform(new Uint8Array(gemvParams4)) }}],
|
| 646 |
+
}));
|
| 647 |
+
pass.dispatchWorkgroups(H);
|
| 648 |
+
pass.end();
|
| 649 |
+
|
| 650 |
+
// Submit this layer's command buffer
|
| 651 |
+
this.device.queue.submit([encoder.finish()]);
|
| 652 |
+
|
| 653 |
+
// Debug: readback hidden_state after residual for select layers
|
| 654 |
+
// Residual add: hidden_state += out_proj_o
|
| 655 |
+
{
|
| 656 |
+
const enc2 = this.device.createCommandEncoder();
|
| 657 |
+
const addPass = enc2.beginComputePass();
|
| 658 |
+
addPass.setPipeline(this.pipelines['add_residual']);
|
| 659 |
+
addPass.setBindGroup(0, this.device.createBindGroup({
|
| 660 |
+
layout: this.pipelines['add_residual'].getBindGroupLayout(0),
|
| 661 |
+
entries: [
|
| 662 |
+
{ binding: 0, resource: { buffer: this.scratch.hidden_state } },
|
| 663 |
+
{ binding: 1, resource: { buffer: this.scratch.out_proj_o } },
|
| 664 |
+
],
|
| 665 |
+
}));
|
| 666 |
+
const addParams = new ArrayBuffer(4);
|
| 667 |
+
new DataView(addParams).setUint32(0, H, true);
|
| 668 |
+
addPass.setBindGroup(1, this.device.createBindGroup({
|
| 669 |
+
layout: this.pipelines['add_residual'].getBindGroupLayout(1),
|
| 670 |
+
entries: [{ binding: 0, resource: { buffer: this._createUniform(new Uint8Array(addParams)) }}],
|
| 671 |
+
}));
|
| 672 |
+
addPass.dispatchWorkgroups(Math.ceil(H / 64));
|
| 673 |
+
addPass.end();
|
| 674 |
+
this.device.queue.submit([enc2.finish()]);
|
| 675 |
+
}
|
| 676 |
+
|
| 677 |
+
|
| 678 |
+
}
|
| 679 |
+
|
| 680 |
+
// Final: rmsnorm + lm_head + sample
|
| 681 |
+
await this.device.queue.onSubmittedWorkDone();
|
| 682 |
+
|
| 683 |
+
// rmsnorm(hidden_state, backbone.norm_f.weight) β norm_out
|
| 684 |
+
const normFW = await this._getF32Weight('backbone.norm_f.weight');
|
| 685 |
+
let encoder = this.device.createCommandEncoder();
|
| 686 |
+
let pass = encoder.beginComputePass();
|
| 687 |
+
pass.setPipeline(this.pipelines['rmsnorm']);
|
| 688 |
+
pass.setBindGroup(0, this.device.createBindGroup({
|
| 689 |
+
layout: this.pipelines['rmsnorm'].getBindGroupLayout(0),
|
| 690 |
+
entries: [
|
| 691 |
+
{ binding: 0, resource: { buffer: this.scratch.hidden_state } },
|
| 692 |
+
{ binding: 1, resource: { buffer: normFW } },
|
| 693 |
+
{ binding: 2, resource: { buffer: this.scratch.norm_out } },
|
| 694 |
+
],
|
| 695 |
+
}));
|
| 696 |
+
const finalNormParams = new ArrayBuffer(12);
|
| 697 |
+
new DataView(finalNormParams).setUint32(0, 1, true);
|
| 698 |
+
new DataView(finalNormParams).setUint32(4, H, true);
|
| 699 |
+
new DataView(finalNormParams).setFloat32(8, CONFIG.rms_eps, true);
|
| 700 |
+
pass.setBindGroup(1, this.device.createBindGroup({
|
| 701 |
+
layout: this.pipelines['rmsnorm'].getBindGroupLayout(1),
|
| 702 |
+
entries: [{ binding: 0, resource: { buffer: this._createUniform(new Uint8Array(finalNormParams)) }}],
|
| 703 |
+
}));
|
| 704 |
+
pass.dispatchWorkgroups(1);
|
| 705 |
+
pass.end();
|
| 706 |
+
|
| 707 |
+
// matmul_gemv(norm_out, lm_head.weight) β logits [V]
|
| 708 |
+
const lmHeadW = await this._getF32Weight('lm_head.weight');
|
| 709 |
+
pass = encoder.beginComputePass();
|
| 710 |
+
pass.setPipeline(this.pipelines['matmul_gemv']);
|
| 711 |
+
pass.setBindGroup(0, this.device.createBindGroup({
|
| 712 |
+
layout: this.pipelines['matmul_gemv'].getBindGroupLayout(0),
|
| 713 |
+
entries: [
|
| 714 |
+
{ binding: 0, resource: { buffer: this.scratch.norm_out } },
|
| 715 |
+
{ binding: 1, resource: { buffer: lmHeadW } },
|
| 716 |
+
{ binding: 2, resource: { buffer: this.scratch.logits } },
|
| 717 |
+
],
|
| 718 |
+
}));
|
| 719 |
+
const gemvFinal = new ArrayBuffer(8);
|
| 720 |
+
new DataView(gemvFinal).setUint32(0, V, true);
|
| 721 |
+
new DataView(gemvFinal).setUint32(4, H, true);
|
| 722 |
+
pass.setBindGroup(1, this.device.createBindGroup({
|
| 723 |
+
layout: this.pipelines['matmul_gemv'].getBindGroupLayout(1),
|
| 724 |
+
entries: [{ binding: 0, resource: { buffer: this._createUniform(new Uint8Array(gemvFinal)) }}],
|
| 725 |
+
}));
|
| 726 |
+
pass.dispatchWorkgroups(V);
|
| 727 |
+
pass.end();
|
| 728 |
+
|
| 729 |
+
// sample(logits, temperature) β token_out
|
| 730 |
+
pass = encoder.beginComputePass();
|
| 731 |
+
pass.setPipeline(this.pipelines['sample']);
|
| 732 |
+
pass.setBindGroup(0, this.device.createBindGroup({
|
| 733 |
+
layout: this.pipelines['sample'].getBindGroupLayout(0),
|
| 734 |
+
entries: [
|
| 735 |
+
{ binding: 0, resource: { buffer: this.scratch.logits } },
|
| 736 |
+
{ binding: 1, resource: { buffer: this.scratch.token_out } },
|
| 737 |
+
],
|
| 738 |
+
}));
|
| 739 |
+
const sampleParams = new ArrayBuffer(12);
|
| 740 |
+
new DataView(sampleParams).setUint32(0, V, true);
|
| 741 |
+
new DataView(sampleParams).setFloat32(4, 1.0 / 0.75, true); // inv_temperature
|
| 742 |
+
new DataView(sampleParams).setUint32(8, Math.floor(Math.random() * 0xFFFFFFFF), true); // rng_seed
|
| 743 |
+
pass.setBindGroup(1, this.device.createBindGroup({
|
| 744 |
+
layout: this.pipelines['sample'].getBindGroupLayout(1),
|
| 745 |
+
entries: [{ binding: 0, resource: { buffer: this._createUniform(new Uint8Array(sampleParams)) }}],
|
| 746 |
+
}));
|
| 747 |
+
pass.dispatchWorkgroups(1);
|
| 748 |
+
pass.end();
|
| 749 |
+
|
| 750 |
+
this.device.queue.submit([encoder.finish()]);
|
| 751 |
+
await this.device.queue.onSubmittedWorkDone();
|
| 752 |
+
|
| 753 |
+
// Read back the sampled token
|
| 754 |
+
const tokenResult = await this._readback(this.scratch.token_out, 4);
|
| 755 |
+
this._tokenCount++;
|
| 756 |
+
return new Uint32Array(tokenResult.buffer)[0];
|
| 757 |
+
}
|
| 758 |
+
|
| 759 |
+
// ββ Tokenize/detokenize via server ββββββββββββββββββββββββββββββββββββββ
|
| 760 |
+
async tokenize(text) {
|
| 761 |
+
const resp = await fetch('/tokenize', {
|
| 762 |
+
method: 'POST',
|
| 763 |
+
headers: { 'Content-Type': 'application/json' },
|
| 764 |
+
body: JSON.stringify({ text }),
|
| 765 |
+
});
|
| 766 |
+
const data = await resp.json();
|
| 767 |
+
return data.result;
|
| 768 |
+
}
|
| 769 |
+
|
| 770 |
+
async detokenize(tokens) {
|
| 771 |
+
const resp = await fetch('/detokenize', {
|
| 772 |
+
method: 'POST',
|
| 773 |
+
headers: { 'Content-Type': 'application/json' },
|
| 774 |
+
body: JSON.stringify({ tokens }),
|
| 775 |
+
});
|
| 776 |
+
const data = await resp.json();
|
| 777 |
+
return data.result;
|
| 778 |
+
}
|
| 779 |
+
|
| 780 |
+
// ββ Generate text ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 781 |
+
async generate(prompt, maxTokens = 100, temperature = 0.75, onToken = null) {
|
| 782 |
+
if (!this.ready) throw new Error('Call loadWeights() first');
|
| 783 |
+
this._allocateScratch();
|
| 784 |
+
|
| 785 |
+
console.log('[mamba] generate:', prompt, 'max_tokens:', maxTokens);
|
| 786 |
+
|
| 787 |
+
// Tokenize the prompt
|
| 788 |
+
const promptTokens = await this.tokenize(prompt);
|
| 789 |
+
console.log(`[mamba] prompt tokens (${promptTokens.length}):`, promptTokens);
|
| 790 |
+
|
| 791 |
+
// Process prompt tokens through forward pass to build SSM state
|
| 792 |
+
console.log('[mamba] encoding prompt...');
|
| 793 |
+
for (let i = 0; i < promptTokens.length; i++) {
|
| 794 |
+
const t0 = performance.now();
|
| 795 |
+
await this._forwardOneToken(promptTokens[i]);
|
| 796 |
+
const elapsed = performance.now() - t0;
|
| 797 |
+
if (i === 0 || i === promptTokens.length - 1) {
|
| 798 |
+
console.log(`[mamba] prompt token ${i}/${promptTokens.length}: ${promptTokens[i]} (${elapsed.toFixed(0)}ms)`);
|
| 799 |
+
}
|
| 800 |
+
}
|
| 801 |
+
console.log('[mamba] prompt encoded, generating...');
|
| 802 |
+
|
| 803 |
+
// Get the last prompt token's output as first generation input
|
| 804 |
+
const generated = [];
|
| 805 |
+
// The last _forwardOneToken already produced the next-token prediction
|
| 806 |
+
// We need to read it back
|
| 807 |
+
const firstResult = await this._readback(this.scratch.token_out, 4);
|
| 808 |
+
let inputToken = new Uint32Array(firstResult.buffer)[0];
|
| 809 |
+
generated.push(inputToken);
|
| 810 |
+
console.log(`[mamba] first generated token: ${inputToken}`);
|
| 811 |
+
if (onToken) onToken(inputToken, 0);
|
| 812 |
+
|
| 813 |
+
for (let step = 1; step < maxTokens; step++) {
|
| 814 |
+
const t0 = performance.now();
|
| 815 |
+
try {
|
| 816 |
+
const nextToken = await this._forwardOneToken(inputToken);
|
| 817 |
+
const elapsed = performance.now() - t0;
|
| 818 |
+
if (step < 5 || step % 20 === 0) {
|
| 819 |
+
console.log(`[mamba] step ${step}: token=${nextToken} (${elapsed.toFixed(0)}ms)`);
|
| 820 |
+
}
|
| 821 |
+
generated.push(nextToken);
|
| 822 |
+
inputToken = nextToken;
|
| 823 |
+
if (onToken) onToken(nextToken, step);
|
| 824 |
+
if (nextToken === 11 || nextToken === 10 || nextToken === 0) break; // EOS=11, im_end=10, PAD=0
|
| 825 |
+
} catch (e) {
|
| 826 |
+
console.error(`[mamba] step ${step} failed:`, e.message);
|
| 827 |
+
break;
|
| 828 |
+
}
|
| 829 |
+
}
|
| 830 |
+
|
| 831 |
+
// Decode the generated tokens
|
| 832 |
+
const text = await this.detokenize(generated);
|
| 833 |
+
console.log(`[mamba] generated ${generated.length} tokens`);
|
| 834 |
+
return text;
|
| 835 |
+
}
|
| 836 |
+
|
| 837 |
+
// ββ Helper: get weight buffer by name ββββββββββββββββββββββββββββββββββ
|
| 838 |
+
_getWeight(name) {
|
| 839 |
+
return this.weights[name] || null;
|
| 840 |
+
}
|
| 841 |
+
|
| 842 |
+
// ββ Create a uniform buffer with typed data ββββββββββββββββββββββββββββ
|
| 843 |
+
_createUniform(data) {
|
| 844 |
+
const buf = this.device.createBuffer({
|
| 845 |
+
size: data.byteLength,
|
| 846 |
+
usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
|
| 847 |
+
mappedAtCreation: true,
|
| 848 |
+
});
|
| 849 |
+
new Uint8Array(buf.getMappedRange()).set(new Uint8Array(data.buffer));
|
| 850 |
+
buf.unmap();
|
| 851 |
+
return buf;
|
| 852 |
+
}
|
| 853 |
+
|
| 854 |
+
// ββ BF16 β F32 conversion for a weight tensor βββββββββββββββββββββββββ
|
| 855 |
+
async _convertBF16toF32(weightInfo) {
|
| 856 |
+
if (weightInfo.dtype !== 'BF16' || weightInfo.f32buffer) return weightInfo;
|
| 857 |
+
|
| 858 |
+
const numBF16 = weightInfo.byteLen / 2; // each bf16 is 2 bytes
|
| 859 |
+
const numPairs = weightInfo.byteLen / 4; // each u32 holds 2 bf16
|
| 860 |
+
const f32Bytes = numBF16 * 4;
|
| 861 |
+
|
| 862 |
+
// Create output F32 buffer
|
| 863 |
+
const f32Buf = this._createBuffer(
|
| 864 |
+
f32Bytes,
|
| 865 |
+
GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
|
| 866 |
+
weightInfo.buffer.label + '_f32'
|
| 867 |
+
);
|
| 868 |
+
|
| 869 |
+
// Create uniform for params
|
| 870 |
+
const paramBuf = this._createUniform(new Uint32Array([numPairs]));
|
| 871 |
+
|
| 872 |
+
// Create bind groups
|
| 873 |
+
const pipeline = this.pipelines['bf16_to_f32'];
|
| 874 |
+
const bg0 = this.device.createBindGroup({
|
| 875 |
+
layout: pipeline.getBindGroupLayout(0),
|
| 876 |
+
entries: [
|
| 877 |
+
{ binding: 0, resource: { buffer: weightInfo.buffer } },
|
| 878 |
+
{ binding: 1, resource: { buffer: f32Buf } },
|
| 879 |
+
],
|
| 880 |
+
});
|
| 881 |
+
const bg1 = this.device.createBindGroup({
|
| 882 |
+
layout: pipeline.getBindGroupLayout(1),
|
| 883 |
+
entries: [
|
| 884 |
+
{ binding: 0, resource: { buffer: paramBuf } },
|
| 885 |
+
],
|
| 886 |
+
});
|
| 887 |
+
|
| 888 |
+
// Dispatch
|
| 889 |
+
const encoder = this.device.createCommandEncoder();
|
| 890 |
+
const pass = encoder.beginComputePass();
|
| 891 |
+
pass.setPipeline(pipeline);
|
| 892 |
+
pass.setBindGroup(0, bg0);
|
| 893 |
+
pass.setBindGroup(1, bg1);
|
| 894 |
+
pass.dispatchWorkgroups(Math.ceil(numPairs / 64));
|
| 895 |
+
pass.end();
|
| 896 |
+
this.device.queue.submit([encoder.finish()]);
|
| 897 |
+
await this.device.queue.onSubmittedWorkDone();
|
| 898 |
+
|
| 899 |
+
// Cache the F32 buffer
|
| 900 |
+
weightInfo.f32buffer = f32Buf;
|
| 901 |
+
weightInfo.f32size = f32Bytes;
|
| 902 |
+
paramBuf.destroy();
|
| 903 |
+
return weightInfo;
|
| 904 |
+
}
|
| 905 |
+
|
| 906 |
+
// ββ Get F32 weight buffer (already converted during load) βββββββββββββββ
|
| 907 |
+
async _getF32Weight(name) {
|
| 908 |
+
const w = this.weights[name];
|
| 909 |
+
if (!w) throw new Error(`Missing weight: ${name}`);
|
| 910 |
+
return w.buffer;
|
| 911 |
+
}
|
| 912 |
+
|
| 913 |
+
// ββ Dispatch a shader with auto bind group creation βββββββββββββββββββββ
|
| 914 |
+
_dispatchShader(encoder, shaderName, storageBuffers, uniformData) {
|
| 915 |
+
const pipeline = this.pipelines[shaderName];
|
| 916 |
+
const pass = encoder.beginComputePass();
|
| 917 |
+
pass.setPipeline(pipeline);
|
| 918 |
+
|
| 919 |
+
// Bind group 0: storage buffers
|
| 920 |
+
const entries0 = storageBuffers.map((buf, i) => ({
|
| 921 |
+
binding: i, resource: { buffer: buf }
|
| 922 |
+
}));
|
| 923 |
+
const bg0 = this.device.createBindGroup({
|
| 924 |
+
layout: pipeline.getBindGroupLayout(0),
|
| 925 |
+
entries: entries0,
|
| 926 |
+
});
|
| 927 |
+
pass.setBindGroup(0, bg0);
|
| 928 |
+
|
| 929 |
+
// Bind group 1: uniforms (if provided)
|
| 930 |
+
if (uniformData) {
|
| 931 |
+
const ubuf = this._createUniform(uniformData);
|
| 932 |
+
const bg1 = this.device.createBindGroup({
|
| 933 |
+
layout: pipeline.getBindGroupLayout(1),
|
| 934 |
+
entries: [{ binding: 0, resource: { buffer: ubuf } }],
|
| 935 |
+
});
|
| 936 |
+
pass.setBindGroup(1, bg1);
|
| 937 |
+
// Note: ubuf leaks β for production, cache these. Fine for proof-of-concept.
|
| 938 |
+
}
|
| 939 |
+
|
| 940 |
+
return pass; // caller sets dispatch count and calls pass.end()
|
| 941 |
+
}
|
| 942 |
+
|
| 943 |
+
// ββ Cleanup ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 944 |
+
destroy() {
|
| 945 |
+
for (const w of Object.values(this.weights)) w.buffer.destroy();
|
| 946 |
+
for (const s of Object.values(this.state)) s.destroy();
|
| 947 |
+
this.weights = {};
|
| 948 |
+
this.state = {};
|
| 949 |
+
this.ready = false;
|
| 950 |
+
}
|
| 951 |
+
}
|
| 952 |
+
|
| 953 |
+
// ES module export for browser <script type="module">
|
| 954 |
+
export { MambaRuntime, CONFIG };
|