import { BitNetLinearWebGPU } from "./bitnet_webgpu.js"; import { BitNetLinearWASM } from "./bitnet_wasm_runtime.js"; function resolveUrl(path, baseUrl) { return new URL(path, baseUrl).toString(); } function sleep(ms) { return new Promise((resolve) => setTimeout(resolve, ms)); } async function fetchWithRetry(url, options = {}) { const attempts = Math.max(1, Number(options.attempts || 5)); let lastError = null; for (let attempt = 0; attempt < attempts; attempt += 1) { try { const response = await fetch(url); if (response.ok) return response; if (response.status < 500 && response.status !== 408 && response.status !== 429) { throw new Error(`failed to fetch ${url}: ${response.status}`); } lastError = new Error(`failed to fetch ${url}: ${response.status}`); } catch (error) { lastError = error; } if (attempt < attempts - 1) { await sleep(Math.min(2000, 150 * 2 ** attempt)); } } throw lastError || new Error(`failed to fetch ${url}`); } async function fetchJson(url) { const response = await fetchWithRetry(url); if (!response.ok) { throw new Error(`failed to fetch ${url}: ${response.status}`); } return response.json(); } async function fetchFloatTensor(entry, baseUrl) { const url = resolveUrl(entry.path, baseUrl); const response = await fetchWithRetry(url); if (!response.ok) { throw new Error(`failed to fetch ${entry.path}: ${response.status}`); } const buffer = await response.arrayBuffer(); const dtype = String(entry.dtype || "float32").toLowerCase(); if (dtype === "float16" || dtype === "fp16" || dtype === "f16") { return { data: float16ArrayToFloat32(new Uint16Array(buffer)), shape: entry.shape }; } return { data: new Float32Array(buffer), shape: entry.shape }; } function float16ArrayToFloat32(values) { const out = new Float32Array(values.length); for (let i = 0; i < values.length; i += 1) { out[i] = float16ToFloat32(values[i]); } return out; } function float16ToFloat32(value) { const sign = (value & 0x8000) ? -1 : 1; const exponent = (value >> 10) & 0x1f; const fraction = value & 0x03ff; if (exponent === 0) { return fraction === 0 ? sign * 0 : sign * 2 ** -14 * (fraction / 1024); } if (exponent === 0x1f) { return fraction === 0 ? sign * Infinity : NaN; } return sign * 2 ** (exponent - 15) * (1 + fraction / 1024); } function zeros(length) { return new Float32Array(length); } function toUint32IdArray(ids) { if (ids instanceof Uint32Array) return ids; return Uint32Array.from(Array.from(ids || [], Number).filter((id) => Number.isFinite(id))); } function addInPlace(dst, src) { for (let i = 0; i < dst.length; i += 1) { dst[i] += src[i]; } return dst; } function l2Normalize(values) { let norm = 0; for (let i = 0; i < values.length; i += 1) norm += values[i] * values[i]; norm = Math.sqrt(Math.max(norm, 1e-12)); const out = new Float32Array(values.length); for (let i = 0; i < values.length; i += 1) out[i] = values[i] / norm; return out; } function meanPoolRows(x, rows, cols, attentionMask = null) { const out = new Float32Array(cols); let denom = 0; for (let r = 0; r < rows; r += 1) { const weight = attentionMask ? Number(attentionMask[r] || 0) : 1; if (weight <= 0) continue; denom += weight; const rowOffset = r * cols; for (let c = 0; c < cols; c += 1) out[c] += x[rowOffset + c] * weight; } denom = Math.max(denom, 1); for (let c = 0; c < cols; c += 1) out[c] /= denom; return out; } function appendRows(existing, next) { if (!existing || existing.length === 0) return next.slice(); const out = new Float32Array(existing.length + next.length); out.set(existing, 0); out.set(next, existing.length); return out; } function appendCachedRows(cache, field, next) { const source = next instanceof Float32Array ? next : new Float32Array(next); const lengthField = `${field}Length`; const capacityField = `${field}Capacity`; const currentLength = Number(cache[lengthField] || 0); const required = currentLength + source.length; let storage = cache[field]; if (!storage || storage.length < required) { let capacity = Math.max(required, Number(cache[capacityField] || 0), source.length * 8); while (capacity < required) capacity *= 2; const grown = new Float32Array(capacity); if (storage && currentLength > 0) grown.set(storage.subarray(0, currentLength), 0); storage = grown; cache[field] = storage; cache[capacityField] = capacity; } storage.set(source, currentLength); cache[lengthField] = required; return storage.subarray(0, required); } function layerNorm(x, rows, cols, weight, bias, eps = 1e-5) { const out = new Float32Array(x.length); for (let r = 0; r < rows; r += 1) { let mean = 0; for (let c = 0; c < cols; c += 1) mean += x[r * cols + c]; mean /= cols; let variance = 0; for (let c = 0; c < cols; c += 1) { const d = x[r * cols + c] - mean; variance += d * d; } const inv = 1 / Math.sqrt(variance / cols + eps); for (let c = 0; c < cols; c += 1) { out[r * cols + c] = (x[r * cols + c] - mean) * inv * weight[c] + (bias ? bias[c] : 0); } } return out; } function rmsNorm(x, rows, cols, weight, eps = 1e-6) { const out = new Float32Array(x.length); for (let r = 0; r < rows; r += 1) { let meanSq = 0; const rowOffset = r * cols; for (let c = 0; c < cols; c += 1) { const value = x[rowOffset + c]; meanSq += value * value; } const inv = 1 / Math.sqrt(meanSq / cols + eps); for (let c = 0; c < cols; c += 1) { out[rowOffset + c] = x[rowOffset + c] * inv * weight[c]; } } return out; } function silu(x) { const out = new Float32Array(x.length); for (let i = 0; i < x.length; i += 1) { out[i] = x[i] / (1 + Math.exp(-x[i])); } return out; } function gelu(x) { const out = new Float32Array(x.length); for (let i = 0; i < x.length; i += 1) { const v = x[i]; out[i] = 0.5 * v * (1 + Math.tanh(Math.sqrt(2 / Math.PI) * (v + 0.044715 * v * v * v))); } return out; } function activate(x, name) { const normalized = String(name || "silu").toLowerCase(); if (normalized === "gelu") return gelu(x); return silu(x); } function gatedActivation(x, rows, cols, name) { const out = new Float32Array(rows * cols); const gateName = String(name || "swiglu").toLowerCase(); for (let row = 0; row < rows; row += 1) { const inputOffset = row * cols * 2; const outputOffset = row * cols; for (let i = 0; i < cols; i += 1) { const a = x[inputOffset + i]; const b = x[inputOffset + cols + i]; const activated = gateName === "geglu" ? 0.5 * a * (1 + Math.tanh(Math.sqrt(2 / Math.PI) * (a + 0.044715 * a * a * a))) : a / (1 + Math.exp(-a)); out[outputOffset + i] = activated * b; } } return out; } function embed(tokens, embedding, dModel) { const out = new Float32Array(tokens.length * dModel); for (let t = 0; t < tokens.length; t += 1) { const token = tokens[t]; out.set(embedding.subarray(token * dModel, token * dModel + dModel), t * dModel); } return out; } function addPositionEmbeddingInPlace(x, positionEmbedding, dModel) { if (!positionEmbedding) return x; for (let t = 0; t < x.length / dModel; t += 1) { const src = t * dModel; for (let c = 0; c < dModel; c += 1) { x[src + c] += positionEmbedding[src + c]; } } return x; } function traceTensor(name, tensor, shape) { let maxAbs = 0; let sum = 0; let sumSq = 0; for (let i = 0; i < tensor.length; i += 1) { const value = Number(tensor[i]); const abs = Math.abs(value); if (abs > maxAbs) maxAbs = abs; sum += value; sumSq += value * value; } return { name, shape, len: tensor.length, maxAbs, mean: tensor.length ? sum / tensor.length : 0, rms: tensor.length ? Math.sqrt(sumSq / tensor.length) : 0, values: Array.from(tensor), }; } class DenseLinear { constructor(name, weightTensor, biasTensor = null) { if (!weightTensor?.data || !Array.isArray(weightTensor.shape) || weightTensor.shape.length !== 2) { throw new Error(`dense linear ${name} is missing a rank-2 weight tensor`); } this.name = name; this.weight = weightTensor.data; this.bias = biasTensor?.data || null; this.outFeatures = Number(weightTensor.shape[0]); this.inFeatures = Number(weightTensor.shape[1]); this.layout = { logicalIn: this.inFeatures, logicalOut: this.outFeatures, }; } async run(input, rows) { const rowCount = Number(rows || 0); if (rowCount <= 0) return new Float32Array(0); if (input.length < rowCount * this.inFeatures) { throw new Error( `dense linear ${this.name} input too small: got ${input.length}, expected ${rowCount * this.inFeatures}`, ); } const out = new Float32Array(rowCount * this.outFeatures); for (let r = 0; r < rowCount; r += 1) { const inputOffset = r * this.inFeatures; const outputOffset = r * this.outFeatures; for (let o = 0; o < this.outFeatures; o += 1) { let sum = this.bias ? this.bias[o] : 0; const weightOffset = o * this.inFeatures; for (let i = 0; i < this.inFeatures; i += 1) { sum += input[inputOffset + i] * this.weight[weightOffset + i]; } out[outputOffset + o] = sum; } } return out; } } function splitHeads(x, seqLen, nHeads, headDim) { const out = []; for (let h = 0; h < nHeads; h += 1) { const head = new Float32Array(seqLen * headDim); for (let t = 0; t < seqLen; t += 1) { const src = t * nHeads * headDim + h * headDim; head.set(x.subarray(src, src + headDim), t * headDim); } out.push(head); } return out; } function mergeHeads(heads, seqLen, nHeads, headDim) { const out = new Float32Array(seqLen * nHeads * headDim); for (let h = 0; h < nHeads; h += 1) { for (let t = 0; t < seqLen; t += 1) { out.set(heads[h].subarray(t * headDim, t * headDim + headDim), t * nHeads * headDim + h * headDim); } } return out; } function attention(q, k, v, qLen, kvLen, nHeads, headDim, causal, pastLen = 0) { const qh = splitHeads(q, qLen, nHeads, headDim); const kh = splitHeads(k, kvLen, nHeads, headDim); const vh = splitHeads(v, kvLen, nHeads, headDim); const outHeads = []; const scale = 1 / Math.sqrt(headDim); for (let h = 0; h < nHeads; h += 1) { const out = new Float32Array(qLen * headDim); for (let i = 0; i < qLen; i += 1) { const scores = new Float32Array(kvLen); let maxScore = -Infinity; for (let j = 0; j < kvLen; j += 1) { let score = causal && j > pastLen + i ? -1e30 : 0; if (score > -1e20) { for (let d = 0; d < headDim; d += 1) { score += qh[h][i * headDim + d] * kh[h][j * headDim + d] * scale; } } scores[j] = score; maxScore = Math.max(maxScore, score); } let denom = 0; for (let j = 0; j < kvLen; j += 1) { scores[j] = Math.exp(scores[j] - maxScore); denom += scores[j]; } for (let d = 0; d < headDim; d += 1) { let sum = 0; for (let j = 0; j < kvLen; j += 1) { sum += (scores[j] / denom) * vh[h][j * headDim + d]; } out[i * headDim + d] = sum; } } outHeads.push(out); } return mergeHeads(outHeads, qLen, nHeads, headDim); } function decoderUsesRotary(manifest, graph) { const positional = String(manifest?.model?.positional || graph?.positional || "").toLowerCase(); return positional === "apply_rotary" || positional === "rotary" || positional === "rope"; } function rotaryBase(manifest, graph) { const value = Number(manifest?.model?.rope_theta || graph?.rope_theta || 1000000); return Number.isFinite(value) && value > 0 ? value : 1000000; } function applyRotaryMergedInPlace(q, k, seqLen, nHeads, headDim, baseTheta, startPosition = 0) { if (headDim % 2 !== 0) throw new Error("RoPE head_dim must be even"); const half = headDim / 2; const invFreq = new Float32Array(half); for (let i = 0; i < half; i += 1) { invFreq[i] = 1 / (baseTheta ** ((2 * i) / headDim)); } for (let t = 0; t < seqLen; t += 1) { const position = startPosition + t; for (let h = 0; h < nHeads; h += 1) { const baseOffset = t * nHeads * headDim + h * headDim; for (let i = 0; i < half; i += 1) { const angle = position * invFreq[i]; const cos = Math.cos(angle); const sin = Math.sin(angle); const left = baseOffset + i; const right = baseOffset + i + half; const q1 = q[left]; const q2 = q[right]; const k1 = k[left]; const k2 = k[right]; q[left] = q1 * cos - q2 * sin; q[right] = q2 * cos + q1 * sin; k[left] = k1 * cos - k2 * sin; k[right] = k2 * cos + k1 * sin; } } } return [q, k]; } export class BitNetEncoderDecoderWebGPU { constructor(device, manifest, manifestUrl, denseTensors, linears) { if (manifest.graph?.architecture !== "encoder_decoder") { throw new Error("manifest is not an encoder_decoder browser BitNet bundle"); } this.device = device; this.manifest = manifest; this.manifestUrl = manifestUrl; this.dense = denseTensors; this.linears = linears; this.denseLinears = {}; this.graph = manifest.graph; this.wasmOps = null; this.decoderRotary = decoderUsesRotary(manifest, this.graph); this.decoderRotaryBase = rotaryBase(manifest, this.graph); } static async fromManifestUrl(device, manifestUrl, options = {}) { const progress = typeof options.progress === "function" ? options.progress : () => {}; progress({ phase: "manifest", message: "Loading model manifest" }); const manifest = options.manifest || await fetchJson(manifestUrl); const baseUrl = new URL(".", manifestUrl).toString(); const dense = {}; const denseEntries = Object.entries(manifest.dense_tensors || {}); for (const [index, [name, entry]] of denseEntries.entries()) { progress({ phase: "dense", index: index + 1, total: denseEntries.length, name, message: `Loading dense tensor ${index + 1}/${denseEntries.length}: ${name}`, }); dense[name] = await fetchFloatTensor(entry, baseUrl); } progress({ phase: "dense_ready", index: denseEntries.length, total: denseEntries.length, message: "Dense tensors ready", }); const linears = {}; const layers = manifest.layers || []; const layerConcurrency = Math.max(1, Math.min(Number(options.layerConcurrency || 4), layers.length || 1)); progress({ phase: "prepare_layers", index: 0, total: layers.length, message: `Preparing ${layers.length} BitNet layers (${layerConcurrency} parallel)`, }); let nextLayer = 0; let completedLayers = 0; async function loadLayerWorker() { while (nextLayer < layers.length) { const index = nextLayer; nextLayer += 1; const layer = layers[index]; progress({ phase: "layer", index: index + 1, total: layers.length, name: layer.name, message: `Loading BitNet layer ${index + 1}/${layers.length}: ${layer.name}`, }); linears[layer.name] = await BitNetLinearWebGPU.fromManifestLayer(device, manifest, layer, manifestUrl, { progress, index: index + 1, total: layers.length, name: layer.name, }); completedLayers += 1; progress({ phase: "layer_ready", index: completedLayers, total: layers.length, name: layer.name, message: `BitNet layer ${completedLayers}/${layers.length} ready: ${layer.name}`, }); } } await Promise.all(Array.from({ length: Math.min(layerConcurrency, layers.length) }, () => loadLayerWorker())); progress({ phase: "ready", message: "Model runtime ready" }); return new BitNetEncoderDecoderWebGPU(device, manifest, manifestUrl, dense, linears); } linear(name) { const layer = this.linears[name]; if (layer) return layer; if (this.denseLinears[name]) return this.denseLinears[name]; const weight = this.dense[`${name}.weight`]; if (!weight) throw new Error(`missing linear layer: ${name}`); const bias = this.dense[`${name}.bias`] || null; const denseLayer = new DenseLinear(name, weight, bias); this.denseLinears[name] = denseLayer; return denseLayer; } linear3(firstName, secondName, thirdName, input, rows) { const first = this.linear(firstName); const second = this.linear(secondName); const third = this.linear(thirdName); if (this.wasmOps?.bitnet_linear3_f32 && first.handle && second.handle && third.handle) { const merged = this.wasmOps.bitnet_linear3_f32(first.handle, second.handle, third.handle, input, rows); const firstLen = rows * first.layout.logicalOut; const secondLen = rows * second.layout.logicalOut; return [ merged.slice(0, firstLen), merged.slice(firstLen, firstLen + secondLen), merged.slice(firstLen + secondLen), ]; } return [first.run(input, rows), second.run(input, rows), third.run(input, rows)]; } linear2(firstName, secondName, input, rows) { const first = this.linear(firstName); const second = this.linear(secondName); if (this.wasmOps?.bitnet_linear2_f32 && first.handle && second.handle) { const merged = this.wasmOps.bitnet_linear2_f32(first.handle, second.handle, input, rows); const firstLen = rows * first.layout.logicalOut; return [merged.slice(0, firstLen), merged.slice(firstLen)]; } return [first.run(input, rows), second.run(input, rows)]; } decoderLayerHandle(index) { if (!this.wasmOps?.DecoderLayerHandle) return null; const names = [ `decoder.${index}.self_attn_block.attn.w_q`, `decoder.${index}.self_attn_block.attn.w_k`, `decoder.${index}.self_attn_block.attn.w_v`, `decoder.${index}.self_attn_block.attn.w_o`, `decoder.${index}.self_attn_block.mlp.w_in`, `decoder.${index}.self_attn_block.mlp.w_out`, `decoder.${index}.cross_block.cross.w_q`, `decoder.${index}.cross_block.cross.w_k`, `decoder.${index}.cross_block.cross.w_v`, `decoder.${index}.cross_block.cross.w_o`, `decoder.${index}.cross_block.mlp.w_in`, `decoder.${index}.cross_block.mlp.w_out`, ]; const layers = names.map((name) => this.linear(name)); if (!layers.every((layer) => layer?.handle)) return null; return new this.wasmOps.DecoderLayerHandle( ...layers.map((layer) => layer.handle), this.tensor(`decoder.${index}.self_attn_block.n1.weight`), this.dense[`decoder.${index}.self_attn_block.n1.bias`]?.data || new Float32Array(0), this.tensor(`decoder.${index}.self_attn_block.n2.weight`), this.dense[`decoder.${index}.self_attn_block.n2.bias`]?.data || new Float32Array(0), this.tensor(`decoder.${index}.cross_block.n1.weight`), this.dense[`decoder.${index}.cross_block.n1.bias`]?.data || new Float32Array(0), this.tensor(`decoder.${index}.cross_block.n2.weight`), this.dense[`decoder.${index}.cross_block.n2.bias`]?.data || new Float32Array(0), String(this.graph.activation || "silu"), this.graph.d_model, this.graph.n_heads, this.graph.head_dim, this.decoderRotary ? this.decoderRotaryBase : 0, ); } tensor(name) { const tensor = this.dense[name]; if (!tensor) throw new Error(`missing dense tensor: ${name}`); return tensor.data; } norm(prefix, x, rows) { const weight = this.tensor(`${prefix}.weight`); const bias = this.dense[`${prefix}.bias`]?.data || null; if (this.wasmOps?.layer_norm_f32 && bias) { return this.wasmOps.layer_norm_f32(x, weight, bias, rows, this.graph.d_model, 1e-5); } if (bias) { return layerNorm(x, rows, this.graph.d_model, weight, bias); } return rmsNorm( x, rows, this.graph.d_model, weight, Number(this.manifest?.model?.rms_norm_eps || this.graph?.rms_norm_eps || 1e-6), ); } attention(q, k, v, qLen, kvLen, causal, pastLen = 0) { if (this.wasmOps?.attention_f32) { return this.wasmOps.attention_f32( q, k, v, qLen, kvLen, this.graph.n_heads, this.graph.head_dim, Boolean(causal), Number(pastLen || 0), ); } return attention(q, k, v, qLen, kvLen, this.graph.n_heads, this.graph.head_dim, causal, pastLen); } async attentionBlock(prefix, x, seqLen, kv, kvLen, causal) { const dModel = this.graph.d_model; const nHeads = this.graph.n_heads; const headDim = this.graph.head_dim; let q; let k; let v; const kInput = kv || x; const kRows = kvLen || seqLen; if (!kv) { [q, k, v] = this.linear3(`${prefix}.w_q`, `${prefix}.w_k`, `${prefix}.w_v`, x, seqLen); } else { q = await this.linear(`${prefix}.w_q`).run(x, seqLen); [k, v] = this.linear2(`${prefix}.w_k`, `${prefix}.w_v`, kInput, kRows); } if (causal && this.decoderRotary) { applyRotaryMergedInPlace(q, k, seqLen, nHeads, headDim, this.decoderRotaryBase, 0); } const merged = this.attention(q, k, v, seqLen, kRows, causal); return this.linear(`${prefix}.w_o`).run(merged, seqLen); } async selfAttentionIncremental(prefix, x, layerCache) { const nHeads = this.graph.n_heads; const headDim = this.graph.head_dim; const [q, kNew, vNew] = this.linear3(`${prefix}.w_q`, `${prefix}.w_k`, `${prefix}.w_v`, x, 1); const position = Number(layerCache.selfLen || 0); if (this.decoderRotary) { applyRotaryMergedInPlace(q, kNew, 1, nHeads, headDim, this.decoderRotaryBase, position); } if (this.wasmOps?.AttentionKvCache) { layerCache.selfAttention ??= new this.wasmOps.AttentionKvCache(nHeads, headDim); const merged = layerCache.selfAttention.append_self_attention(q, kNew, vNew, 1, false); layerCache.selfLen = layerCache.selfAttention.len(); return this.linear(`${prefix}.w_o`).run(merged, 1); } layerCache.selfK = appendCachedRows(layerCache, "selfK", kNew); layerCache.selfV = appendCachedRows(layerCache, "selfV", vNew); layerCache.selfLen = Number(layerCache.selfLen || 0) + 1; const merged = this.attention(q, layerCache.selfK, layerCache.selfV, 1, layerCache.selfLen, false); return this.linear(`${prefix}.w_o`).run(merged, 1); } async selfAttentionIncrementalSpan(prefix, x, seqLen, layerCache) { const nHeads = this.graph.n_heads; const headDim = this.graph.head_dim; const [q, kNew, vNew] = this.linear3(`${prefix}.w_q`, `${prefix}.w_k`, `${prefix}.w_v`, x, seqLen); const position = Number(layerCache.selfLen || 0); if (this.decoderRotary) { applyRotaryMergedInPlace(q, kNew, seqLen, nHeads, headDim, this.decoderRotaryBase, position); } if (this.wasmOps?.AttentionKvCache) { layerCache.selfAttention ??= new this.wasmOps.AttentionKvCache(nHeads, headDim); const merged = layerCache.selfAttention.append_self_attention(q, kNew, vNew, seqLen, true); layerCache.selfLen = layerCache.selfAttention.len(); return this.linear(`${prefix}.w_o`).run(merged, seqLen); } layerCache.selfK = appendCachedRows(layerCache, "selfK", kNew); layerCache.selfV = appendCachedRows(layerCache, "selfV", vNew); layerCache.selfLen = Number(layerCache.selfLen || 0) + seqLen; const merged = this.attention(q, layerCache.selfK, layerCache.selfV, seqLen, layerCache.selfLen, true, position); return this.linear(`${prefix}.w_o`).run(merged, seqLen); } async crossAttentionCached(prefix, x, memory, memoryLen, layerCache) { const nHeads = this.graph.n_heads; const headDim = this.graph.head_dim; const q = await this.linear(`${prefix}.w_q`).run(x, 1); if (this.wasmOps?.AttentionKvCache) { layerCache.crossAttention ??= new this.wasmOps.AttentionKvCache(nHeads, headDim); if (!layerCache.crossReady) { const [crossK, crossV] = this.linear2(`${prefix}.w_k`, `${prefix}.w_v`, memory, memoryLen); layerCache.crossAttention.set_cross(crossK, crossV, memoryLen); layerCache.crossReady = true; } const merged = layerCache.crossAttention.attention(q, 1, false, 0); return this.linear(`${prefix}.w_o`).run(merged, 1); } if (!layerCache.crossK || !layerCache.crossV) { [layerCache.crossK, layerCache.crossV] = this.linear2(`${prefix}.w_k`, `${prefix}.w_v`, memory, memoryLen); } const merged = this.attention(q, layerCache.crossK, layerCache.crossV, 1, memoryLen, false); return this.linear(`${prefix}.w_o`).run(merged, 1); } async crossAttentionCachedSpan(prefix, x, seqLen, memory, memoryLen, layerCache) { const nHeads = this.graph.n_heads; const headDim = this.graph.head_dim; const q = await this.linear(`${prefix}.w_q`).run(x, seqLen); if (this.wasmOps?.AttentionKvCache) { layerCache.crossAttention ??= new this.wasmOps.AttentionKvCache(nHeads, headDim); if (!layerCache.crossReady) { const [crossK, crossV] = this.linear2(`${prefix}.w_k`, `${prefix}.w_v`, memory, memoryLen); layerCache.crossAttention.set_cross(crossK, crossV, memoryLen); layerCache.crossReady = true; } const merged = layerCache.crossAttention.attention(q, seqLen, false, 0); return this.linear(`${prefix}.w_o`).run(merged, seqLen); } if (!layerCache.crossK || !layerCache.crossV) { [layerCache.crossK, layerCache.crossV] = this.linear2(`${prefix}.w_k`, `${prefix}.w_v`, memory, memoryLen); } const merged = this.attention(q, layerCache.crossK, layerCache.crossV, seqLen, memoryLen, false); return this.linear(`${prefix}.w_o`).run(merged, seqLen); } async mlp(prefix, x, seqLen) { const wIn = this.linear(`${prefix}.w_in`); const wOut = this.linear(`${prefix}.w_out`); if (this.wasmOps?.bitnet_mlp_f32 && wIn.handle && wOut.handle) { return this.wasmOps.bitnet_mlp_f32(wIn.handle, wOut.handle, x, seqLen, String(this.graph.activation || "silu")); } const hidden = await wIn.run(x, seqLen); const activation = String(this.graph.activation || "silu").toLowerCase(); const isGated = wIn.layout.logicalOut === wOut.layout.logicalIn * 2 || hidden.length === seqLen * wOut.layout.logicalIn * 2; const activated = isGated || ["swiglu", "gated-silu", "geglu", "reglu"].includes(activation) ? (this.wasmOps?.gated_activation_f32 ? this.wasmOps.gated_activation_f32(hidden, seqLen, wOut.layout.logicalIn, activation) : gatedActivation(hidden, seqLen, wOut.layout.logicalIn, activation)) : (this.wasmOps?.activate_f32 ? this.wasmOps.activate_f32(hidden, activation) : activate(hidden, activation)); return wOut.run(activated, seqLen); } async encoderLayer(index, x, seqLen) { const n1 = this.norm(`encoder.${index}.n1`, x, seqLen); const attnOut = await this.attentionBlock(`encoder.${index}.attn`, n1, seqLen, null, null, false); x = addInPlace(x.slice(), attnOut); const n2 = this.norm(`encoder.${index}.n2`, x, seqLen); return addInPlace(x, await this.mlp(`encoder.${index}.mlp`, n2, seqLen)); } async decoderLayer(index, x, seqLen, memory, memoryLen) { let n = this.norm(`decoder.${index}.self_attn_block.n1`, x, seqLen); x = addInPlace(x.slice(), await this.attentionBlock(`decoder.${index}.self_attn_block.attn`, n, seqLen, null, null, true)); n = this.norm(`decoder.${index}.self_attn_block.n2`, x, seqLen); x = addInPlace(x, await this.mlp(`decoder.${index}.self_attn_block.mlp`, n, seqLen)); n = this.norm(`decoder.${index}.cross_block.n1`, x, seqLen); x = addInPlace(x.slice(), await this.attentionBlock(`decoder.${index}.cross_block.cross`, n, seqLen, memory, memoryLen, false)); n = this.norm(`decoder.${index}.cross_block.n2`, x, seqLen); return addInPlace(x, await this.mlp(`decoder.${index}.cross_block.mlp`, n, seqLen)); } async decoderLayerIncremental(index, x, memory, memoryLen, layerCache) { if (this.wasmOps?.DecoderLayerHandle) { layerCache.decoderLayer ??= this.decoderLayerHandle(index); if (layerCache.decoderLayer?.next) { const out = layerCache.decoderLayer.next(x, memory, memoryLen); layerCache.selfLen = layerCache.decoderLayer.self_len(); return out; } } let n = this.norm(`decoder.${index}.self_attn_block.n1`, x, 1); x = addInPlace( x.slice(), await this.selfAttentionIncremental(`decoder.${index}.self_attn_block.attn`, n, layerCache), ); n = this.norm(`decoder.${index}.self_attn_block.n2`, x, 1); x = addInPlace(x, await this.mlp(`decoder.${index}.self_attn_block.mlp`, n, 1)); n = this.norm(`decoder.${index}.cross_block.n1`, x, 1); x = addInPlace( x.slice(), await this.crossAttentionCached(`decoder.${index}.cross_block.cross`, n, memory, memoryLen, layerCache), ); n = this.norm(`decoder.${index}.cross_block.n2`, x, 1); return addInPlace(x, await this.mlp(`decoder.${index}.cross_block.mlp`, n, 1)); } async decoderLayerIncrementalSpan(index, x, seqLen, memory, memoryLen, layerCache) { let n = this.norm(`decoder.${index}.self_attn_block.n1`, x, seqLen); x = addInPlace( x.slice(), await this.selfAttentionIncrementalSpan(`decoder.${index}.self_attn_block.attn`, n, seqLen, layerCache), ); n = this.norm(`decoder.${index}.self_attn_block.n2`, x, seqLen); x = addInPlace(x, await this.mlp(`decoder.${index}.self_attn_block.mlp`, n, seqLen)); n = this.norm(`decoder.${index}.cross_block.n1`, x, seqLen); x = addInPlace( x.slice(), await this.crossAttentionCachedSpan(`decoder.${index}.cross_block.cross`, n, seqLen, memory, memoryLen, layerCache), ); n = this.norm(`decoder.${index}.cross_block.n2`, x, seqLen); return addInPlace(x, await this.mlp(`decoder.${index}.cross_block.mlp`, n, seqLen)); } async encode(encInputIds) { let x = embed(encInputIds, this.tensor("enc_embed.weight"), this.graph.d_model); if (this.graph.encoder_position_embeddings) { x = addPositionEmbeddingInPlace(x, this.tensor("enc_pos_embed.weight"), this.graph.d_model); } for (let i = 0; i < this.graph.n_layers; i += 1) { x = await this.encoderLayer(i, x, encInputIds.length); } return layerNorm( x, encInputIds.length, this.graph.d_model, this.tensor("enc_norm.weight"), this.dense["enc_norm.bias"]?.data, ); } async retrievalEmbedding(encInputIds, options = {}) { const retrieval = this.graph.retrieval || {}; const headName = options.kind === "doc" ? retrieval.doc_head : retrieval.query_head; if (!headName) { throw new Error("model manifest does not expose retrieval heads"); } const inputIds = Array.from(encInputIds || [], Number); const memory = await this.encode(inputIds); const pooled = meanPoolRows( memory, inputIds.length, this.graph.d_model, options.attentionMask || inputIds.map((id) => (id === 0 ? 0 : 1)), ); const projected = await this.linear(headName).run(pooled, 1); return l2Normalize(projected); } async retrievalQueryEmbedding(encInputIds, options = {}) { return this.retrievalEmbedding(encInputIds, { ...options, kind: "query" }); } async retrievalDocEmbedding(encInputIds, options = {}) { return this.retrievalEmbedding(encInputIds, { ...options, kind: "doc" }); } async decode(decInputIds, memory, memoryLen) { let x = embed(decInputIds, this.tensor("dec_embed.weight"), this.graph.d_model); for (let i = 0; i < this.graph.n_layers; i += 1) { x = await this.decoderLayer(i, x, decInputIds.length, memory, memoryLen); } return layerNorm( x, decInputIds.length, this.graph.d_model, this.tensor("dec_norm.weight"), this.dense["dec_norm.bias"]?.data, ); } async forward(encInputIds, decInputIds) { const memory = await this.encode(encInputIds); const hidden = await this.decode(decInputIds, memory, encInputIds.length); return this.linear("lm_head").run(hidden, decInputIds.length); } async debugTrace(encInputIds, decInputIds) { const traces = []; let x = embed(encInputIds, this.tensor("enc_embed.weight"), this.graph.d_model); if (this.graph.encoder_position_embeddings) { x = addPositionEmbeddingInPlace(x, this.tensor("enc_pos_embed.weight"), this.graph.d_model); } traces.push(traceTensor("enc_embed", x, [encInputIds.length, this.graph.d_model])); for (let i = 0; i < this.graph.n_layers; i += 1) { const n1 = this.norm(`encoder.${i}.n1`, x, encInputIds.length); traces.push(traceTensor(`encoder.${i}.n1`, n1, [encInputIds.length, this.graph.d_model])); const attnOut = await this.attentionBlock(`encoder.${i}.attn`, n1, encInputIds.length, null, null, false); traces.push(traceTensor(`encoder.${i}.attn`, attnOut, [encInputIds.length, this.graph.d_model])); x = addInPlace(x.slice(), attnOut); traces.push(traceTensor(`encoder.${i}.attn_resid`, x, [encInputIds.length, this.graph.d_model])); const n2 = this.norm(`encoder.${i}.n2`, x, encInputIds.length); traces.push(traceTensor(`encoder.${i}.n2`, n2, [encInputIds.length, this.graph.d_model])); const mlpOut = await this.mlp(`encoder.${i}.mlp`, n2, encInputIds.length); traces.push(traceTensor(`encoder.${i}.mlp`, mlpOut, [encInputIds.length, this.graph.d_model])); x = addInPlace(x, mlpOut); traces.push(traceTensor(`encoder.${i}`, x, [encInputIds.length, this.graph.d_model])); } const memory = layerNorm( x, encInputIds.length, this.graph.d_model, this.tensor("enc_norm.weight"), this.dense["enc_norm.bias"]?.data, ); traces.push(traceTensor("enc_norm", memory, [encInputIds.length, this.graph.d_model])); let hidden = embed(decInputIds, this.tensor("dec_embed.weight"), this.graph.d_model); traces.push(traceTensor("dec_embed", hidden, [decInputIds.length, this.graph.d_model])); for (let i = 0; i < this.graph.n_layers; i += 1) { hidden = await this.decoderLayer(i, hidden, decInputIds.length, memory, encInputIds.length); traces.push(traceTensor(`decoder.${i}`, hidden, [decInputIds.length, this.graph.d_model])); } hidden = layerNorm( hidden, decInputIds.length, this.graph.d_model, this.tensor("dec_norm.weight"), this.dense["dec_norm.bias"]?.data, ); traces.push(traceTensor("dec_norm", hidden, [decInputIds.length, this.graph.d_model])); const logits = await this.linear("lm_head").run(hidden, decInputIds.length); traces.push(traceTensor("logits", logits, [decInputIds.length, this.graph.vocab_size])); return { traces }; } createGenerationSession(encInputIds) { return new BitNetEncoderDecoderGenerationSession(this, encInputIds); } } export class BitNetEncoderDecoderGenerationSession { constructor(runtime, encInputIds) { this.runtime = runtime; this.encInputIds = Array.from(encInputIds || [], Number); this.memory = null; this.memoryLen = this.encInputIds.length; this.layerCaches = Array.from({ length: runtime.graph.n_layers }, () => ({})); } async prepare() { if (!this.memory) { this.memory = await this.runtime.encode(this.encInputIds); } return this; } async next(tokenId) { const hidden = await this.nextHidden(tokenId); return this.runtime.linear("lm_head").run(hidden, 1); } async nextHidden(tokenId) { await this.prepare(); let x = embed([Number(tokenId)], this.runtime.tensor("dec_embed.weight"), this.runtime.graph.d_model); for (let i = 0; i < this.runtime.graph.n_layers; i += 1) { x = await this.runtime.decoderLayerIncremental(i, x, this.memory, this.memoryLen, this.layerCaches[i]); } return this.runtime.norm("dec_norm", x, 1); } async sampleNext(tokenId, generatedIds, options = {}) { if (!this.runtime.wasmOps?.bitnet_sample_token_f32) return null; const lmHead = this.runtime.linear("lm_head"); if (!lmHead.handle) return null; const hidden = await this.nextHidden(tokenId); const sample = this.runtime.wasmOps.bitnet_sample_token_f32( lmHead.handle, hidden, toUint32IdArray(generatedIds), toUint32IdArray(options.blockedIds), Number(options.temperature ?? 0.35), Number(options.topP ?? 0.9), Number(options.repetitionPenalty ?? 1.16), Number(options.randomValue ?? Math.random()), ); return { tokenId: Number(sample.token_id), probability: Number(sample.probability), topProbability: Number(sample.top_probability), rank: Number(sample.rank), }; } cloneState() { return this.layerCaches.map((cache) => { const cloned = { ...cache }; if (cache.selfAttention?.clone_cache) { cloned.selfAttention = cache.selfAttention.clone_cache(); } if (cache.crossAttention?.clone_cache) { cloned.crossAttention = cache.crossAttention.clone_cache(); } if (cache.decoderLayer?.clone_cache) { cloned.decoderLayer = cache.decoderLayer.clone_cache(); } if (cache.selfK) { cloned.selfK = cache.selfK.slice(); cloned.selfKLength = cloned.selfK.length; cloned.selfKCapacity = cloned.selfK.length; } if (cache.selfV) { cloned.selfV = cache.selfV.slice(); cloned.selfVLength = cloned.selfV.length; cloned.selfVCapacity = cloned.selfV.length; } return cloned; }); } restoreState(layerCaches) { this.layerCaches = layerCaches; } async nextMany(tokenIds) { const ids = Array.from(tokenIds || [], Number).filter((id) => Number.isFinite(id)); if (!ids.length) return new Float32Array(0); await this.prepare(); let x = embed(ids, this.runtime.tensor("dec_embed.weight"), this.runtime.graph.d_model); for (let i = 0; i < this.runtime.graph.n_layers; i += 1) { x = await this.runtime.decoderLayerIncrementalSpan(i, x, ids.length, this.memory, this.memoryLen, this.layerCaches[i]); } const hidden = layerNorm( x, ids.length, this.runtime.graph.d_model, this.runtime.tensor("dec_norm.weight"), this.runtime.dense["dec_norm.bias"]?.data, ); return this.runtime.linear("lm_head").run(hidden, ids.length); } } export class BitNetEncoderDecoderWASM extends BitNetEncoderDecoderWebGPU { constructor(manifest, manifestUrl, denseTensors, linears) { super(null, manifest, manifestUrl, denseTensors, linears); this.wasmOps = Object.values(linears || {}).find((layer) => layer?.wasm)?.wasm || null; } static async fromManifestUrl(manifestUrl, options = {}) { const progress = typeof options.progress === "function" ? options.progress : () => {}; progress({ phase: "manifest", message: "Loading model manifest" }); const manifest = options.manifest || await fetchJson(manifestUrl); const baseUrl = new URL(".", manifestUrl).toString(); const dense = {}; const denseEntries = Object.entries(manifest.dense_tensors || {}); for (const [index, [name, entry]] of denseEntries.entries()) { progress({ phase: "dense", index: index + 1, total: denseEntries.length, name, message: `Loading dense tensor ${index + 1}/${denseEntries.length}: ${name}`, }); dense[name] = await fetchFloatTensor(entry, baseUrl); } progress({ phase: "dense_ready", index: denseEntries.length, total: denseEntries.length, message: "Dense tensors ready", }); const linears = {}; const layers = manifest.layers || []; const layerConcurrency = Math.max(1, Math.min(Number(options.layerConcurrency || 4), layers.length || 1)); progress({ phase: "prepare_layers", index: 0, total: layers.length, message: `Preparing ${layers.length} BitNet WASM layers (${layerConcurrency} parallel)`, }); let nextLayer = 0; let completedLayers = 0; async function loadLayerWorker() { while (nextLayer < layers.length) { const index = nextLayer; nextLayer += 1; const layer = layers[index]; progress({ phase: "layer", index: index + 1, total: layers.length, name: layer.name, message: `Loading BitNet WASM layer ${index + 1}/${layers.length}: ${layer.name}`, }); linears[layer.name] = await BitNetLinearWASM.fromManifestLayer(manifest, layer, manifestUrl, { progress, index: index + 1, total: layers.length, name: layer.name, }); completedLayers += 1; progress({ phase: "layer_ready", index: completedLayers, total: layers.length, name: layer.name, message: `BitNet WASM layer ${completedLayers}/${layers.length} ready: ${layer.name}`, }); } } await Promise.all(Array.from({ length: Math.min(layerConcurrency, layers.length) }, () => loadLayerWorker())); progress({ phase: "wasm_ready", message: "BitNet WASM runtime ready" }); return new BitNetEncoderDecoderWASM(manifest, manifestUrl, dense, linears); } }