Update packed BitNet WASM fallback runtime
Browse files- runtime/encdec_runtime.js +187 -0
runtime/encdec_runtime.js
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import { BitNetLinearWebGPU } from "./bitnet_webgpu.js";
|
|
|
|
| 2 |
|
| 3 |
function resolveUrl(path, baseUrl) {
|
| 4 |
return new URL(path, baseUrl).toString();
|
|
@@ -31,6 +32,14 @@ function addInPlace(dst, src) {
|
|
| 31 |
return dst;
|
| 32 |
}
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
function layerNorm(x, rows, cols, weight, bias, eps = 1e-5) {
|
| 35 |
const out = new Float32Array(x.length);
|
| 36 |
for (let r = 0; r < rows; r += 1) {
|
|
@@ -262,6 +271,31 @@ export class BitNetEncoderDecoderWebGPU {
|
|
| 262 |
return this.linear(`${prefix}.w_o`).run(merged, seqLen);
|
| 263 |
}
|
| 264 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
async mlp(prefix, x, seqLen) {
|
| 266 |
const wIn = this.linear(`${prefix}.w_in`);
|
| 267 |
const wOut = this.linear(`${prefix}.w_out`);
|
|
@@ -333,6 +367,48 @@ export class BitNetEncoderDecoderWebGPU {
|
|
| 333 |
return addInPlace(x, await this.mlp(`decoder.${index}.cross_block.mlp`, n, seqLen));
|
| 334 |
}
|
| 335 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
async encode(encInputIds) {
|
| 337 |
let x = embed(encInputIds, this.tensor("enc_embed.weight"), this.graph.d_model);
|
| 338 |
for (let i = 0; i < this.graph.n_layers; i += 1) {
|
|
@@ -366,4 +442,115 @@ export class BitNetEncoderDecoderWebGPU {
|
|
| 366 |
const hidden = await this.decode(decInputIds, memory, encInputIds.length);
|
| 367 |
return this.linear("lm_head").run(hidden, decInputIds.length);
|
| 368 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
}
|
|
|
|
| 1 |
import { BitNetLinearWebGPU } from "./bitnet_webgpu.js";
|
| 2 |
+
import { BitNetLinearWASM } from "./bitnet_wasm_runtime.js";
|
| 3 |
|
| 4 |
function resolveUrl(path, baseUrl) {
|
| 5 |
return new URL(path, baseUrl).toString();
|
|
|
|
| 32 |
return dst;
|
| 33 |
}
|
| 34 |
|
| 35 |
+
/**
 * Concatenate `next` onto the end of `existing`.
 *
 * When `existing` is missing or has length 0, the result is simply a copy of
 * `next` (via `next.slice()`); otherwise a new Float32Array holding the
 * elements of `existing` followed by the elements of `next` is returned.
 * Neither input is modified.
 *
 * @param {Float32Array|null|undefined} existing - Previously accumulated values.
 * @param {Float32Array} next - Values to place after `existing`.
 * @returns {Float32Array} Newly allocated buffer with the combined contents.
 */
function appendRows(existing, next) {
  const hasPrior = Boolean(existing) && existing.length !== 0;
  if (!hasPrior) {
    return next.slice();
  }

  const priorLength = existing.length;
  const combined = new Float32Array(priorLength + next.length);
  combined.set(existing, 0);
  combined.set(next, priorLength);
  return combined;
}
|
| 42 |
+
|
| 43 |
function layerNorm(x, rows, cols, weight, bias, eps = 1e-5) {
|
| 44 |
const out = new Float32Array(x.length);
|
| 45 |
for (let r = 0; r < rows; r += 1) {
|
|
|
|
| 271 |
return this.linear(`${prefix}.w_o`).run(merged, seqLen);
|
| 272 |
}
|
| 273 |
|
| 274 |
+
async selfAttentionIncremental(prefix, x, layerCache) {
|
| 275 |
+
const nHeads = this.graph.n_heads;
|
| 276 |
+
const headDim = this.graph.head_dim;
|
| 277 |
+
const q = await this.linear(`${prefix}.w_q`).run(x, 1);
|
| 278 |
+
const kNew = await this.linear(`${prefix}.w_k`).run(x, 1);
|
| 279 |
+
const vNew = await this.linear(`${prefix}.w_v`).run(x, 1);
|
| 280 |
+
layerCache.selfK = appendRows(layerCache.selfK, kNew);
|
| 281 |
+
layerCache.selfV = appendRows(layerCache.selfV, vNew);
|
| 282 |
+
layerCache.selfLen = Number(layerCache.selfLen || 0) + 1;
|
| 283 |
+
const merged = attention(q, layerCache.selfK, layerCache.selfV, 1, layerCache.selfLen, nHeads, headDim, false);
|
| 284 |
+
return this.linear(`${prefix}.w_o`).run(merged, 1);
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
async crossAttentionCached(prefix, x, memory, memoryLen, layerCache) {
|
| 288 |
+
const nHeads = this.graph.n_heads;
|
| 289 |
+
const headDim = this.graph.head_dim;
|
| 290 |
+
const q = await this.linear(`${prefix}.w_q`).run(x, 1);
|
| 291 |
+
if (!layerCache.crossK || !layerCache.crossV) {
|
| 292 |
+
layerCache.crossK = await this.linear(`${prefix}.w_k`).run(memory, memoryLen);
|
| 293 |
+
layerCache.crossV = await this.linear(`${prefix}.w_v`).run(memory, memoryLen);
|
| 294 |
+
}
|
| 295 |
+
const merged = attention(q, layerCache.crossK, layerCache.crossV, 1, memoryLen, nHeads, headDim, false);
|
| 296 |
+
return this.linear(`${prefix}.w_o`).run(merged, 1);
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
async mlp(prefix, x, seqLen) {
|
| 300 |
const wIn = this.linear(`${prefix}.w_in`);
|
| 301 |
const wOut = this.linear(`${prefix}.w_out`);
|
|
|
|
| 367 |
return addInPlace(x, await this.mlp(`decoder.${index}.cross_block.mlp`, n, seqLen));
|
| 368 |
}
|
| 369 |
|
| 370 |
+
async decoderLayerIncremental(index, x, memory, memoryLen, layerCache) {
|
| 371 |
+
const dModel = this.graph.d_model;
|
| 372 |
+
let n = layerNorm(
|
| 373 |
+
x,
|
| 374 |
+
1,
|
| 375 |
+
dModel,
|
| 376 |
+
this.tensor(`decoder.${index}.self_attn_block.n1.weight`),
|
| 377 |
+
this.dense[`decoder.${index}.self_attn_block.n1.bias`]?.data,
|
| 378 |
+
);
|
| 379 |
+
x = addInPlace(
|
| 380 |
+
x.slice(),
|
| 381 |
+
await this.selfAttentionIncremental(`decoder.${index}.self_attn_block.attn`, n, layerCache),
|
| 382 |
+
);
|
| 383 |
+
n = layerNorm(
|
| 384 |
+
x,
|
| 385 |
+
1,
|
| 386 |
+
dModel,
|
| 387 |
+
this.tensor(`decoder.${index}.self_attn_block.n2.weight`),
|
| 388 |
+
this.dense[`decoder.${index}.self_attn_block.n2.bias`]?.data,
|
| 389 |
+
);
|
| 390 |
+
x = addInPlace(x, await this.mlp(`decoder.${index}.self_attn_block.mlp`, n, 1));
|
| 391 |
+
n = layerNorm(
|
| 392 |
+
x,
|
| 393 |
+
1,
|
| 394 |
+
dModel,
|
| 395 |
+
this.tensor(`decoder.${index}.cross_block.n1.weight`),
|
| 396 |
+
this.dense[`decoder.${index}.cross_block.n1.bias`]?.data,
|
| 397 |
+
);
|
| 398 |
+
x = addInPlace(
|
| 399 |
+
x.slice(),
|
| 400 |
+
await this.crossAttentionCached(`decoder.${index}.cross_block.cross`, n, memory, memoryLen, layerCache),
|
| 401 |
+
);
|
| 402 |
+
n = layerNorm(
|
| 403 |
+
x,
|
| 404 |
+
1,
|
| 405 |
+
dModel,
|
| 406 |
+
this.tensor(`decoder.${index}.cross_block.n2.weight`),
|
| 407 |
+
this.dense[`decoder.${index}.cross_block.n2.bias`]?.data,
|
| 408 |
+
);
|
| 409 |
+
return addInPlace(x, await this.mlp(`decoder.${index}.cross_block.mlp`, n, 1));
|
| 410 |
+
}
|
| 411 |
+
|
| 412 |
async encode(encInputIds) {
|
| 413 |
let x = embed(encInputIds, this.tensor("enc_embed.weight"), this.graph.d_model);
|
| 414 |
for (let i = 0; i < this.graph.n_layers; i += 1) {
|
|
|
|
| 442 |
const hidden = await this.decode(decInputIds, memory, encInputIds.length);
|
| 443 |
return this.linear("lm_head").run(hidden, decInputIds.length);
|
| 444 |
}
|
| 445 |
+
|
| 446 |
+
createGenerationSession(encInputIds) {
|
| 447 |
+
return new BitNetEncoderDecoderGenerationSession(this, encInputIds);
|
| 448 |
+
}
|
| 449 |
+
}
|
| 450 |
+
|
| 451 |
+
/**
 * Token-by-token decoding session for an encoder-decoder runtime.
 *
 * The encoder output ("memory") is computed lazily on first use and kept for
 * the session's lifetime; each decoder layer gets its own mutable cache object
 * that the runtime's incremental layer methods fill in.
 */
export class BitNetEncoderDecoderGenerationSession {
  /**
   * @param {object} runtime - Runtime exposing `graph`, `encode`, `tensor`,
   *   `dense`, `linear` and `decoderLayerIncremental`.
   * @param {Iterable<number>|ArrayLike<number>|null|undefined} encInputIds -
   *   Encoder input token ids; each entry is coerced with `Number`.
   */
  constructor(runtime, encInputIds) {
    this.runtime = runtime;
    this.encInputIds = Array.from(encInputIds || [], Number);
    this.memory = null;
    this.memoryLen = this.encInputIds.length;
    this.layerCaches = Array.from({ length: runtime.graph.n_layers }, () => ({}));
    // In-flight encoder run, shared by overlapping prepare() calls.
    this.pendingEncode = null;
  }

  /**
   * Ensure the encoder output is available.
   *
   * Overlapping calls share a single `runtime.encode` invocation instead of
   * each starting their own. If encoding fails, the in-flight marker is
   * cleared so a later call can retry.
   *
   * @returns {Promise<this>} Resolves to this session once `memory` is set.
   */
  async prepare() {
    if (this.memory) {
      return this;
    }
    if (!this.pendingEncode) {
      this.pendingEncode = this.runtime.encode(this.encInputIds);
    }
    const pending = this.pendingEncode;
    try {
      this.memory = await pending;
    } finally {
      if (this.pendingEncode === pending) {
        this.pendingEncode = null;
      }
    }
    return this;
  }

  /**
   * Feed one decoder token and return the output of the `lm_head` layer for it.
   *
   * The token id is validated before any work is done so that an invalid id
   * cannot leave partially updated per-layer caches behind.
   *
   * @param {number|string} tokenId - Decoder token id; coerced with `Number`.
   * @returns {Promise<*>} Whatever the `lm_head` linear's `run` resolves to.
   * @throws {RangeError} If `tokenId` does not coerce to a non-negative integer.
   */
  async next(tokenId) {
    const id = Number(tokenId);
    if (!Number.isInteger(id) || id < 0) {
      throw new RangeError(`Invalid decoder token id: ${String(tokenId)}`);
    }

    await this.prepare();
    const { runtime } = this;
    let x = embed([id], runtime.tensor("dec_embed.weight"), runtime.graph.d_model);
    for (let i = 0; i < runtime.graph.n_layers; i += 1) {
      x = await runtime.decoderLayerIncremental(i, x, this.memory, this.memoryLen, this.layerCaches[i]);
    }
    const hidden = layerNorm(
      x,
      1,
      runtime.graph.d_model,
      runtime.tensor("dec_norm.weight"),
      runtime.dense["dec_norm.bias"]?.data,
    );
    return runtime.linear("lm_head").run(hidden, 1);
  }
}
|
| 483 |
+
|
| 484 |
+
/**
 * Encoder-decoder runtime whose linear layers are `BitNetLinearWASM` instances.
 *
 * Reuses the forward-pass logic of `BitNetEncoderDecoderWebGPU`; only layer
 * construction differs.
 */
export class BitNetEncoderDecoderWASM extends BitNetEncoderDecoderWebGPU {
  /**
   * @param {object} manifest - Parsed model manifest.
   * @param {string} manifestUrl - URL the manifest was loaded from.
   * @param {object} denseTensors - Dense tensors keyed by name.
   * @param {object} linears - Linear layers keyed by layer name.
   */
  constructor(manifest, manifestUrl, denseTensors, linears) {
    // The base constructor's first argument is passed as null here —
    // presumably the GPU device slot, unused on the WASM path; TODO confirm.
    super(null, manifest, manifestUrl, denseTensors, linears);
  }

  /**
   * Load a model from its manifest and build a WASM-backed runtime.
   *
   * Dense tensors are fetched one at a time; linear layers are loaded by a
   * small pool of workers pulling from a shared index. Progress events are
   * reported through `options.progress` when it is a function.
   *
   * @param {string} manifestUrl - URL of the manifest; also the base for relative tensor paths.
   * @param {object} [options]
   * @param {Function} [options.progress] - Called with `{ phase, message, ... }` objects.
   * @param {object} [options.manifest] - Pre-parsed manifest; skips the fetch when provided.
   * @param {number} [options.layerConcurrency=4] - Maximum layers loaded in parallel
   *   (clamped to at least 1 and at most the layer count; non-numeric values fall back to 4).
   * @returns {Promise<BitNetEncoderDecoderWASM>} The ready runtime.
   */
  static async fromManifestUrl(manifestUrl, options = {}) {
    // The default parameter only covers `undefined`; tolerate an explicit null too.
    const opts = options ?? {};
    const progress = typeof opts.progress === "function" ? opts.progress : () => {};
    progress({ phase: "manifest", message: "Loading model manifest" });
    const manifest = opts.manifest || await fetchJson(manifestUrl);
    const baseUrl = new URL(".", manifestUrl).toString();
    const dense = {};
    const denseEntries = Object.entries(manifest.dense_tensors || {});
    for (const [index, [name, entry]] of denseEntries.entries()) {
      progress({
        phase: "dense",
        index: index + 1,
        total: denseEntries.length,
        name,
        message: `Loading dense tensor ${index + 1}/${denseEntries.length}: ${name}`,
      });
      dense[name] = await fetchFloatTensor(entry, baseUrl);
    }
    progress({
      phase: "dense_ready",
      index: denseEntries.length,
      total: denseEntries.length,
      message: "Dense tensors ready",
    });

    const linears = {};
    const layers = manifest.layers || [];
    // A non-numeric setting coerces to NaN, which would propagate through
    // Math.min/Math.max and yield a zero-length worker pool (no layers loaded).
    // Fall back to the default instead, and use a whole number of workers.
    const requestedConcurrency = Math.floor(Number(opts.layerConcurrency || 4));
    const layerConcurrency = Math.max(
      1,
      Math.min(Number.isNaN(requestedConcurrency) ? 4 : requestedConcurrency, layers.length || 1),
    );
    progress({
      phase: "prepare_layers",
      index: 0,
      total: layers.length,
      message: `Preparing ${layers.length} BitNet WASM layers (${layerConcurrency} parallel)`,
    });
    let nextLayer = 0;
    let completedLayers = 0;
    // Set once any worker fails so the others stop claiming new layers.
    let failed = false;
    async function loadLayerWorker() {
      while (!failed && nextLayer < layers.length) {
        // Claiming the index and advancing the counter happen without an
        // intervening await, so two workers never take the same layer.
        const index = nextLayer;
        nextLayer += 1;
        const layer = layers[index];
        try {
          progress({
            phase: "layer",
            index: index + 1,
            total: layers.length,
            name: layer.name,
            message: `Loading BitNet WASM layer ${index + 1}/${layers.length}: ${layer.name}`,
          });
          linears[layer.name] = await BitNetLinearWASM.fromManifestLayer(manifest, layer, manifestUrl, {
            progress,
            index: index + 1,
            total: layers.length,
            name: layer.name,
          });
          completedLayers += 1;
          progress({
            phase: "layer_ready",
            index: completedLayers,
            total: layers.length,
            name: layer.name,
            message: `BitNet WASM layer ${completedLayers}/${layers.length} ready: ${layer.name}`,
          });
        } catch (err) {
          failed = true;
          throw err;
        }
      }
    }
    await Promise.all(Array.from({ length: Math.min(layerConcurrency, layers.length) }, () => loadLayerWorker()));
    progress({ phase: "wasm_ready", message: "BitNet WASM runtime ready" });
    return new BitNetEncoderDecoderWASM(manifest, manifestUrl, dense, linears);
  }
}
|