// Lazily created promise for the BitNet WASM module; ensures the wasm binary
// is fetched and instantiated at most once per realm.
let wasmModulePromise = null;

/**
 * Validate a BitNet layout header and expand it into named fields.
 *
 * Only layout v1 with a 16x32 tile and interleave mode 1 is supported
 * (header[0] === 1, header[1] === 16, header[2] === 32, header[9] === 1).
 * header[10] and header[12] are carried in `header` but not interpreted here.
 *
 * @param {ArrayLike<number>} layoutHeader - at least 13 integer entries.
 * @returns {{header: Int32Array, logicalOut: number, logicalIn: number,
 *            paddedOut: number, paddedIn: number, scaleGranularity: number,
 *            scaleGroupSize: number, segmentCount: number}}
 * @throws {Error} when the header is missing, too short, or unsupported.
 */
function normalizeLayout(layoutHeader) {
  if (!layoutHeader || layoutHeader.length < 13) {
    throw new Error("BitNet layout_header must contain at least 13 entries");
  }
  // Coerce each entry through Number so plain arrays, typed arrays, and
  // JSON-decoded values all normalize to int32.
  const header = Int32Array.from(Array.from(layoutHeader, Number));
  if (header[0] !== 1 || header[1] !== 16 || header[2] !== 32 || header[9] !== 1) {
    throw new Error("Unsupported BitNet WASM layout; expected v1 16x32 interleave mode 1");
  }
  return {
    header,
    logicalOut: header[3],
    logicalIn: header[4],
    paddedOut: header[5],
    paddedIn: header[6],
    scaleGranularity: header[7],
    scaleGroupSize: header[8],
    segmentCount: header[11],
  };
}

/** Resolve a relative path against a base URL, returning a string. */
function resolveUrl(path, baseUrl) {
  return new URL(path, baseUrl).toString();
}

/** Promise-based delay helper used for retry backoff. */
function sleep(ms) {
  return new Promise((resolve) => setTimeout(resolve, ms));
}

/**
 * Fetch a URL with exponential backoff.
 *
 * Network errors and retryable statuses (5xx, 408, 429) are retried up to
 * `options.attempts` times (default 5) with backoff capped at 2000 ms.
 * Any other non-ok status is treated as permanent and fails immediately.
 * (The original implementation threw the permanent-status error inside the
 * `try`, so it was caught and retried anyway — fixed here.)
 *
 * @param {string} url
 * @param {{attempts?: number}} [options]
 * @returns {Promise<Response>} an ok response.
 * @throws {Error} the last failure after retries are exhausted, or the
 *   permanent-status error on a non-retryable status.
 */
async function fetchWithRetry(url, options = {}) {
  const attempts = Math.max(1, Number(options.attempts || 5));
  let lastError = null;
  for (let attempt = 0; attempt < attempts; attempt += 1) {
    let response = null;
    try {
      response = await fetch(url);
    } catch (error) {
      // Network-level failures are considered transient and retried.
      lastError = error;
    }
    if (response) {
      if (response.ok) return response;
      const statusError = new Error(`failed to fetch ${url}: ${response.status}`);
      if (response.status < 500 && response.status !== 408 && response.status !== 429) {
        // Permanent client error — retrying cannot help, so fail now.
        throw statusError;
      }
      lastError = statusError;
    }
    if (attempt < attempts - 1) {
      await sleep(Math.min(2000, 150 * 2 ** attempt));
    }
  }
  throw lastError || new Error(`failed to fetch ${url}`);
}

/**
 * Fetch a manifest tensor entry and wrap its bytes in the given typed array.
 * `fetchWithRetry` only resolves with ok responses, so no status re-check is
 * needed here (the original had an unreachable `!response.ok` branch).
 *
 * @param {{path: string}} entry - manifest tensor descriptor.
 * @param {string} baseUrl - base URL the entry path is resolved against.
 * @param {Function} TypedArray - constructor (Uint8Array, Int32Array, ...).
 * @returns {Promise<ArrayBufferView>}
 */
async function fetchTensor(entry, baseUrl, TypedArray) {
  const url = resolveUrl(entry.path, baseUrl);
  const response = await fetchWithRetry(url);
  return new TypedArray(await response.arrayBuffer());
}

/**
 * Map a manifest tensor dtype string to its typed-array constructor.
 * @throws {Error} on any dtype other than uint8 / int32 / float32.
 */
function tensorType(entry) {
  if (entry.dtype === "uint8") return Uint8Array;
  if (entry.dtype === "int32") return Int32Array;
  if (entry.dtype === "float32") return Float32Array;
  throw new Error(`unsupported tensor dtype: ${entry.dtype}`);
}

/**
 * Load and initialise the BitNet WASM glue module exactly once.
 *
 * Tries the sibling `model_stack_bitnet_wasm.js` first, then falls back to
 * `pkg/model_stack_bitnet_wasm.js`. NOTE(review): presumably the two paths
 * cover bundled vs. wasm-pack output layouts — confirm against the build.
 * A failed load is no longer cached: the promise is cleared on rejection so
 * a later call can retry instead of being stuck on a transient error.
 *
 * @returns {Promise<object>} the initialised wasm module namespace.
 */
async function ensureBitNetWasm() {
  if (!wasmModulePromise) {
    wasmModulePromise = (async () => {
      let module;
      try {
        module = await import(new URL("model_stack_bitnet_wasm.js", import.meta.url).href);
      } catch (error) {
        module = await import(new URL("pkg/model_stack_bitnet_wasm.js", import.meta.url).href);
      }
      // wasm-bindgen convention: the default export initialises the instance.
      await module.default();
      return module;
    })();
    // Do not cache a rejected promise forever; allow subsequent retries.
    wasmModulePromise.catch(() => {
      wasmModulePromise = null;
    });
  }
  return wasmModulePromise;
}

/**
 * A single BitNet ternary linear layer executed through the WASM kernel.
 */
export class BitNetLinearWASM {
  /**
   * @param {object} bundle - tensors and metadata for one layer:
   *   layoutHeader, packedWeight, scaleValues, segmentOffsets, and optional
   *   bias, inputScales, inputQuantMode, inputQuantBits, inputScaleRows.
   */
  constructor(bundle) {
    this.layout = normalizeLayout(bundle.layoutHeader);
    // Normalise each tensor to its expected typed-array class, accepting
    // either a typed array or anything its constructor can consume.
    this.packedWeight = bundle.packedWeight instanceof Uint8Array
      ? bundle.packedWeight
      : new Uint8Array(bundle.packedWeight);
    this.scaleValues = bundle.scaleValues instanceof Float32Array
      ? bundle.scaleValues
      : new Float32Array(bundle.scaleValues);
    this.segmentOffsets = bundle.segmentOffsets instanceof Int32Array
      ? bundle.segmentOffsets
      : Int32Array.from(bundle.segmentOffsets || []);
    // Missing bias becomes an empty array; missing input scales default to 1.
    this.bias = bundle.bias
      ? (bundle.bias instanceof Float32Array ? bundle.bias : new Float32Array(bundle.bias))
      : new Float32Array(0);
    this.inputScales = bundle.inputScales
      ? (bundle.inputScales instanceof Float32Array
        ? bundle.inputScales
        : new Float32Array(bundle.inputScales))
      : new Float32Array([1]);
    this.inputQuantMode = bundle.inputQuantMode ?? 0;
    this.inputQuantBits = bundle.inputQuantBits ?? 8;
    this.inputScaleRows = bundle.inputScaleRows ?? 1;
  }

  /**
   * Download one layer's tensors described by a manifest and build a
   * BitNetLinearWASM instance.
   *
   * @param {object} manifest - layer manifest (currently unread here; kept
   *   for interface compatibility with callers).
   * @param {object} layer - manifest entry with `tensors`, `layout_header`,
   *   `act_quant_mode`, `act_quant_bits`, and optionally `name`.
   * @param {string} manifestUrl - URL the manifest was loaded from; tensor
   *   paths resolve relative to its `layers/` sibling directory.
   * @param {{progress?: Function, index?: number, total?: number,
   *          name?: string}} [options]
   * @returns {Promise<BitNetLinearWASM>}
   */
  static async fromManifestLayer(manifest, layer, manifestUrl, options = {}) {
    const progress = typeof options.progress === "function" ? options.progress : () => {};
    const index = Number(options.index || 0);
    const total = Number(options.total || 0);
    const name = String(options.name || layer.name || "layer");
    const label = total ? `${index}/${total}: ${name}` : name;
    const baseUrl = new URL(".", manifestUrl).toString();
    const tensors = layer.tensors;
    const layersBaseUrl = resolveUrl("layers/", baseUrl);
    progress({ phase: "layer_tensors", index, total, name, message: `Loading BitNet WASM tensors ${label}` });
    // All tensors are independent; fetch them concurrently.
    const [packedWeight, scaleValues, segmentOffsets, bias, inputScales] = await Promise.all([
      fetchTensor(tensors.packed_weight, layersBaseUrl, Uint8Array),
      fetchTensor(tensors.scale_values, layersBaseUrl, Float32Array),
      fetchTensor(tensors.segment_offsets, layersBaseUrl, Int32Array),
      tensors.bias ? fetchTensor(tensors.bias, layersBaseUrl, Float32Array) : Promise.resolve(null),
      fetchTensor(tensors.act_scale, layersBaseUrl, tensorType(tensors.act_scale)),
    ]);
    progress({ phase: "layer_ready", index, total, name, message: `BitNet WASM layer ${label} ready` });
    return new BitNetLinearWASM({
      layoutHeader: layer.layout_header,
      packedWeight,
      scaleValues,
      segmentOffsets,
      bias,
      inputScales,
      inputQuantMode: layer.act_quant_mode === "none" ? 0 : 1,
      inputQuantBits: layer.act_quant_bits,
      // NOTE(review): the original ternary on act_quant_mode returned 1 in
      // both branches, so this is a constant; confirm whether per-row scales
      // were intended for a non-static mode.
      inputScaleRows: 1,
    });
  }

  /**
   * Run the layer on `rows` stacked input rows of length `logicalIn`.
   *
   * @param {Float32Array|ArrayLike<number>} input - row-major activations.
   * @param {number} [rows=1]
   * @returns {Promise<Float32Array>} whatever the wasm kernel returns;
   *   presumably `rows * logicalOut` floats — confirm against the kernel.
   * @throws {Error} when input length !== rows * logicalIn (checked before
   *   the wasm module is loaded).
   */
  async run(input, rows = 1) {
    const x = input instanceof Float32Array ? input : new Float32Array(input);
    if (x.length !== rows * this.layout.logicalIn) {
      throw new Error(`BitNet input length mismatch: got ${x.length}, expected ${rows * this.layout.logicalIn}`);
    }
    const wasm = await ensureBitNetWasm();
    return wasm.bitnet_linear_f32(
      x,
      this.packedWeight,
      this.scaleValues,
      this.segmentOffsets,
      this.bias,
      this.layout.header,
      this.inputScales,
      rows,
      this.inputQuantMode,
      this.inputQuantBits,
      this.inputScaleRows,
    );
  }
}