agentkernel-lite-100m-bitnet / runtime /bitnet_wasm_runtime.js
PeytonT's picture
Update AgentKernel Lite BitNet WASM bundle
6eea347 verified
// Cached promise for the initialized BitNet WASM glue module. It is assigned by
// ensureBitNetWasm() and is null when no load is cached.
let wasmModulePromise = null;
/**
 * Validate a BitNet layout header and expose its fields by name.
 *
 * The header is coerced entry-by-entry with Number() and stored as an
 * Int32Array. Only headers whose entries 0, 1, 2 and 9 equal 1, 16, 32 and 1
 * are accepted (the error text calls this "v1 16x32 interleave mode 1").
 *
 * @param {ArrayLike<number>|Iterable<number>} layoutHeader - Raw header with at least 13 entries.
 * @returns {{header: Int32Array, logicalOut: number, logicalIn: number,
 *   paddedOut: number, paddedIn: number, scaleGranularity: number,
 *   scaleGroupSize: number, segmentCount: number}} The Int32Array header plus
 *   the named fields read from entries 3-8 and 11.
 * @throws {Error} If the header is missing/too short or has an unsupported layout.
 */
function normalizeLayout(layoutHeader) {
  const isTooShort = !layoutHeader || layoutHeader.length < 13;
  if (isTooShort) {
    throw new Error("BitNet layout_header must contain at least 13 entries");
  }

  const numericEntries = Array.from(layoutHeader, Number);
  const header = Int32Array.from(numericEntries);

  const isSupported =
    header[0] === 1 && header[1] === 16 && header[2] === 32 && header[9] === 1;
  if (!isSupported) {
    throw new Error("Unsupported BitNet WASM layout; expected v1 16x32 interleave mode 1");
  }

  // Pull the named fields out by their header index.
  const {
    3: logicalOut,
    4: logicalIn,
    5: paddedOut,
    6: paddedIn,
    7: scaleGranularity,
    8: scaleGroupSize,
    11: segmentCount,
  } = header;

  return {
    header,
    logicalOut,
    logicalIn,
    paddedOut,
    paddedIn,
    scaleGranularity,
    scaleGroupSize,
    segmentCount,
  };
}
/**
 * Resolve a (possibly relative) path against a base URL.
 *
 * @param {string} path - Path or URL to resolve.
 * @param {string|URL} baseUrl - Base the path is resolved against.
 * @returns {string} The serialized absolute URL.
 * @throws {TypeError} If the URL constructor rejects the inputs.
 */
function resolveUrl(path, baseUrl) {
  const resolved = new URL(path, baseUrl);
  return resolved.href;
}
/**
 * Return a promise that resolves (with no value) after a setTimeout delay.
 *
 * @param {number} ms - Delay passed to setTimeout, in milliseconds.
 * @returns {Promise<void>}
 */
function sleep(ms) {
  return new Promise((wake) => {
    setTimeout(wake, ms);
  });
}
/**
 * Fetch a URL, retrying transient failures with exponential backoff.
 *
 * Network errors (a rejected fetch) and HTTP statuses that are >= 500, 408 or
 * 429 are retried. Any other non-OK status is treated as permanent and is
 * thrown immediately without further attempts.
 *
 * The wait between attempts is 150ms * 2^attempt, capped at 2000ms; there is
 * no wait after the final attempt.
 *
 * @param {string} url - URL to fetch.
 * @param {{attempts?: number}} [options] - `attempts` is the maximum number of
 *   tries (default 5, minimum 1). Non-finite values fall back to the default.
 * @returns {Promise<Response>} A response whose `ok` flag is true.
 * @throws {Error} The last failure once attempts are exhausted, or the status
 *   error for a non-retryable response.
 */
async function fetchWithRetry(url, options = {}) {
  const requestedAttempts = Number(options.attempts || 5);
  // A NaN/Infinity attempt count would either skip the loop entirely or never
  // terminate, so fall back to the default instead.
  const attempts = Number.isFinite(requestedAttempts) ? Math.max(1, requestedAttempts) : 5;
  let lastError = null;
  for (let attempt = 0; attempt < attempts; attempt += 1) {
    let response = null;
    try {
      response = await fetch(url);
    } catch (error) {
      // Only network-level failures are caught here; status handling happens
      // below so that a permanent failure is not swallowed and retried.
      lastError = error;
    }
    if (response) {
      if (response.ok) return response;
      const statusError = new Error(`failed to fetch ${url}: ${response.status}`);
      const isPermanent = response.status < 500 && response.status !== 408 && response.status !== 429;
      if (isPermanent) {
        throw statusError;
      }
      lastError = statusError;
    }
    if (attempt < attempts - 1) {
      await sleep(Math.min(2000, 150 * 2 ** attempt));
    }
  }
  throw lastError || new Error(`failed to fetch ${url}`);
}
/**
 * Download one tensor file and view its bytes as the given typed array.
 *
 * @param {{path: string}} entry - Tensor entry; `path` is resolved against `baseUrl`.
 * @param {string} baseUrl - Base URL the tensor path is relative to.
 * @param {Function} TypedArray - Typed-array constructor applied to the downloaded ArrayBuffer.
 * @returns {Promise<object>} A `TypedArray` instance over the response body.
 * @throws {Error} If the fetch fails or the response is not OK.
 */
async function fetchTensor(entry, baseUrl, TypedArray) {
  const response = await fetchWithRetry(resolveUrl(entry.path, baseUrl));
  if (response.ok) {
    const bytes = await response.arrayBuffer();
    return new TypedArray(bytes);
  }
  throw new Error(`failed to fetch ${entry.path}: ${response.status}`);
}
/**
 * Map a tensor entry's `dtype` string to the matching typed-array constructor.
 *
 * @param {{dtype: string}} entry - Tensor entry carrying a `dtype` field.
 * @returns {Function} Uint8Array, Int32Array or Float32Array.
 * @throws {Error} If the dtype is not "uint8", "int32" or "float32".
 */
function tensorType(entry) {
  switch (entry.dtype) {
    case "uint8":
      return Uint8Array;
    case "int32":
      return Int32Array;
    case "float32":
      return Float32Array;
    default:
      throw new Error(`unsupported tensor dtype: ${entry.dtype}`);
  }
}
/**
 * Load and initialize the BitNet WASM glue module once, sharing the result.
 *
 * The module is imported from `model_stack_bitnet_wasm.js` next to this file,
 * falling back to `pkg/model_stack_bitnet_wasm.js`, and its default export is
 * awaited before the module is handed out. Concurrent and later callers reuse
 * the same cached promise.
 *
 * If loading fails, the cached promise is cleared so that a later call can
 * retry instead of replaying the same rejection forever.
 *
 * @returns {Promise<object>} The initialized module namespace.
 * @throws {Error} Whatever the fallback import or the module initializer throws.
 */
async function ensureBitNetWasm() {
  if (!wasmModulePromise) {
    const pending = (async () => {
      let module;
      try {
        module = await import(new URL("model_stack_bitnet_wasm.js", import.meta.url).href);
      } catch {
        // Deliberate fallback: the glue file may live under pkg/ instead.
        module = await import(new URL("pkg/model_stack_bitnet_wasm.js", import.meta.url).href);
      }
      await module.default();
      return module;
    })();
    wasmModulePromise = pending;
    // Drop a failed load from the cache (only if it is still the cached one)
    // so a transient failure does not poison every future call.
    pending.catch(() => {
      if (wasmModulePromise === pending) {
        wasmModulePromise = null;
      }
    });
  }
  return wasmModulePromise;
}
/**
 * A single BitNet linear layer evaluated through the BitNet WASM module.
 *
 * Instances hold the layer's tensors as typed arrays plus the normalized
 * layout header, and forward them to the module's `bitnet_linear_f32` export.
 */
export class BitNetLinearWASM {
  /**
   * Build a layer from already-loaded tensor data. Values that are not already
   * the expected typed array are copied/converted into one.
   *
   * @param {object} bundle
   * @param {ArrayLike<number>} bundle.layoutHeader - Layout header, validated by normalizeLayout().
   * @param {Uint8Array|ArrayLike<number>|ArrayBuffer} bundle.packedWeight - Stored as a Uint8Array.
   * @param {Float32Array|ArrayLike<number>|ArrayBuffer} bundle.scaleValues - Stored as a Float32Array.
   * @param {Int32Array|ArrayLike<number>} [bundle.segmentOffsets] - Stored as an Int32Array; empty when omitted.
   * @param {Float32Array|ArrayLike<number>|null} [bundle.bias] - Stored as a Float32Array; empty when omitted.
   * @param {Float32Array|ArrayLike<number>|null} [bundle.inputScales] - Stored as a Float32Array; `[1]` when omitted.
   * @param {number} [bundle.inputQuantMode=0]
   * @param {number} [bundle.inputQuantBits=8]
   * @param {number} [bundle.inputScaleRows=1]
   * @throws {Error} If the layout header is rejected by normalizeLayout().
   */
  constructor(bundle) {
    this.layout = normalizeLayout(bundle.layoutHeader);
    this.packedWeight = bundle.packedWeight instanceof Uint8Array
      ? bundle.packedWeight
      : new Uint8Array(bundle.packedWeight);
    this.scaleValues = bundle.scaleValues instanceof Float32Array
      ? bundle.scaleValues
      : new Float32Array(bundle.scaleValues);
    this.segmentOffsets = bundle.segmentOffsets instanceof Int32Array
      ? bundle.segmentOffsets
      : Int32Array.from(bundle.segmentOffsets || []);
    this.bias = bundle.bias
      ? (bundle.bias instanceof Float32Array ? bundle.bias : new Float32Array(bundle.bias))
      : new Float32Array(0);
    this.inputScales = bundle.inputScales
      ? (bundle.inputScales instanceof Float32Array ? bundle.inputScales : new Float32Array(bundle.inputScales))
      : new Float32Array([1]);
    this.inputQuantMode = bundle.inputQuantMode ?? 0;
    this.inputQuantBits = bundle.inputQuantBits ?? 8;
    this.inputScaleRows = bundle.inputScaleRows ?? 1;
  }

  /**
   * Download a layer's tensors as described by a manifest entry and build the layer.
   *
   * Tensor paths are resolved against the `layers/` directory next to the
   * manifest. The `bias` tensor is optional. The `act_scale` tensor may be
   * omitted only when `layer.act_quant_mode` is `"none"`; the constructor's
   * default input scale is used in that case.
   *
   * @param {object} manifest - Accepted for interface compatibility; not read by this method.
   * @param {object} layer - Manifest layer entry (`tensors`, `layout_header`,
   *   `act_quant_mode`, `act_quant_bits`, optional `name`).
   * @param {string|URL} manifestUrl - URL of the manifest, used as the base for tensor paths.
   * @param {object} [options]
   * @param {Function} [options.progress] - Called with `{phase, index, total, name, message}`
   *   before the downloads start ("layer_tensors") and after they finish ("layer_ready").
   * @param {number} [options.index] - Position of this layer, used in progress messages.
   * @param {number} [options.total] - Total layer count, used in progress messages.
   * @param {string} [options.name] - Display name; defaults to `layer.name`, then "layer".
   * @returns {Promise<BitNetLinearWASM>}
   * @throws {Error} If a tensor cannot be fetched, `act_scale` is missing for a
   *   quantized mode, or the layout header is unsupported.
   */
  static async fromManifestLayer(manifest, layer, manifestUrl, options = {}) {
    const progress = typeof options.progress === "function" ? options.progress : () => {};
    const index = Number(options.index || 0);
    const total = Number(options.total || 0);
    const name = String(options.name || layer.name || "layer");
    const label = total ? `${index}/${total}: ${name}` : name;
    const baseUrl = new URL(".", manifestUrl).toString();
    const tensors = layer.tensors;
    const layersBaseUrl = resolveUrl("layers/", baseUrl);
    const hasActScale = Boolean(tensors.act_scale);
    if (!hasActScale && layer.act_quant_mode !== "none") {
      // Fail with a descriptive message instead of an opaque TypeError on
      // `undefined.dtype` when a quantized layer lacks its scale tensor.
      throw new Error(
        `BitNet layer ${name} is missing the act_scale tensor required by act_quant_mode ${layer.act_quant_mode}`,
      );
    }
    progress({ phase: "layer_tensors", index, total, name, message: `Loading BitNet WASM tensors ${label}` });
    const [packedWeight, scaleValues, segmentOffsets, bias, inputScales] = await Promise.all([
      fetchTensor(tensors.packed_weight, layersBaseUrl, Uint8Array),
      fetchTensor(tensors.scale_values, layersBaseUrl, Float32Array),
      fetchTensor(tensors.segment_offsets, layersBaseUrl, Int32Array),
      tensors.bias ? fetchTensor(tensors.bias, layersBaseUrl, Float32Array) : Promise.resolve(null),
      hasActScale
        ? fetchTensor(tensors.act_scale, layersBaseUrl, tensorType(tensors.act_scale))
        : Promise.resolve(null),
    ]);
    progress({ phase: "layer_ready", index, total, name, message: `BitNet WASM layer ${label} ready` });
    return new BitNetLinearWASM({
      layoutHeader: layer.layout_header,
      packedWeight,
      scaleValues,
      segmentOffsets,
      bias,
      inputScales,
      inputQuantMode: layer.act_quant_mode === "none" ? 0 : 1,
      inputQuantBits: layer.act_quant_bits,
      // Always 1: this loader passes a single input-scale row regardless of
      // act_quant_mode.
      inputScaleRows: 1,
    });
  }

  /**
   * Apply the layer to `rows` input rows via the WASM `bitnet_linear_f32` export.
   *
   * @param {Float32Array|ArrayLike<number>} input - Flat input of length
   *   `rows * layout.logicalIn`; converted to a Float32Array if needed.
   * @param {number} [rows=1] - Number of input rows contained in `input`.
   * @returns {Promise<*>} Whatever `bitnet_linear_f32` returns.
   * @throws {Error} If the input length does not match, or the WASM module fails to load.
   */
  async run(input, rows = 1) {
    const x = input instanceof Float32Array ? input : new Float32Array(input);
    if (x.length !== rows * this.layout.logicalIn) {
      throw new Error(`BitNet input length mismatch: got ${x.length}, expected ${rows * this.layout.logicalIn}`);
    }
    const wasm = await ensureBitNetWasm();
    return wasm.bitnet_linear_f32(
      x,
      this.packedWeight,
      this.scaleValues,
      this.segmentOffsets,
      this.bias,
      this.layout.header,
      this.inputScales,
      rows,
      this.inputQuantMode,
      this.inputQuantBits,
      this.inputScaleRows,
    );
  }
}