Update AgentKernel Lite WASM decoder kernel
Browse files- runtime/encdec_runtime.js +193 -27
- runtime/model_stack_bitnet_wasm_bg.wasm +2 -2
runtime/encdec_runtime.js
CHANGED
|
@@ -77,6 +77,11 @@ function zeros(length) {
|
|
| 77 |
return new Float32Array(length);
|
| 78 |
}
|
| 79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
function addInPlace(dst, src) {
|
| 81 |
for (let i = 0; i < dst.length; i += 1) {
|
| 82 |
dst[i] += src[i];
|
|
@@ -412,6 +417,7 @@ export class BitNetEncoderDecoderWebGPU {
|
|
| 412 |
this.linears = linears;
|
| 413 |
this.denseLinears = {};
|
| 414 |
this.graph = manifest.graph;
|
|
|
|
| 415 |
this.decoderRotary = decoderUsesRotary(manifest, this.graph);
|
| 416 |
this.decoderRotaryBase = rotaryBase(manifest, this.graph);
|
| 417 |
}
|
|
@@ -495,6 +501,70 @@ export class BitNetEncoderDecoderWebGPU {
|
|
| 495 |
return denseLayer;
|
| 496 |
}
|
| 497 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 498 |
tensor(name) {
|
| 499 |
const tensor = this.dense[name];
|
| 500 |
if (!tensor) throw new Error(`missing dense tensor: ${name}`);
|
|
@@ -504,6 +574,9 @@ export class BitNetEncoderDecoderWebGPU {
|
|
| 504 |
norm(prefix, x, rows) {
|
| 505 |
const weight = this.tensor(`${prefix}.weight`);
|
| 506 |
const bias = this.dense[`${prefix}.bias`]?.data || null;
|
|
|
|
|
|
|
|
|
|
| 507 |
if (bias) {
|
| 508 |
return layerNorm(x, rows, this.graph.d_model, weight, bias);
|
| 509 |
}
|
|
@@ -516,53 +589,84 @@ export class BitNetEncoderDecoderWebGPU {
|
|
| 516 |
);
|
| 517 |
}
|
| 518 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 519 |
async attentionBlock(prefix, x, seqLen, kv, kvLen, causal) {
|
| 520 |
const dModel = this.graph.d_model;
|
| 521 |
const nHeads = this.graph.n_heads;
|
| 522 |
const headDim = this.graph.head_dim;
|
| 523 |
-
|
|
|
|
|
|
|
| 524 |
const kInput = kv || x;
|
| 525 |
const kRows = kvLen || seqLen;
|
| 526 |
-
|
| 527 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 528 |
if (causal && this.decoderRotary) {
|
| 529 |
applyRotaryMergedInPlace(q, k, seqLen, nHeads, headDim, this.decoderRotaryBase, 0);
|
| 530 |
}
|
| 531 |
-
const merged = attention(q, k, v, seqLen, kRows,
|
| 532 |
return this.linear(`${prefix}.w_o`).run(merged, seqLen);
|
| 533 |
}
|
| 534 |
|
| 535 |
async selfAttentionIncremental(prefix, x, layerCache) {
|
| 536 |
const nHeads = this.graph.n_heads;
|
| 537 |
const headDim = this.graph.head_dim;
|
| 538 |
-
const q
|
| 539 |
-
const kNew = await this.linear(`${prefix}.w_k`).run(x, 1);
|
| 540 |
-
const vNew = await this.linear(`${prefix}.w_v`).run(x, 1);
|
| 541 |
const position = Number(layerCache.selfLen || 0);
|
| 542 |
if (this.decoderRotary) {
|
| 543 |
applyRotaryMergedInPlace(q, kNew, 1, nHeads, headDim, this.decoderRotaryBase, position);
|
| 544 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 545 |
layerCache.selfK = appendCachedRows(layerCache, "selfK", kNew);
|
| 546 |
layerCache.selfV = appendCachedRows(layerCache, "selfV", vNew);
|
| 547 |
layerCache.selfLen = Number(layerCache.selfLen || 0) + 1;
|
| 548 |
-
const merged = attention(q, layerCache.selfK, layerCache.selfV, 1, layerCache.selfLen,
|
| 549 |
return this.linear(`${prefix}.w_o`).run(merged, 1);
|
| 550 |
}
|
| 551 |
|
| 552 |
async selfAttentionIncrementalSpan(prefix, x, seqLen, layerCache) {
|
| 553 |
const nHeads = this.graph.n_heads;
|
| 554 |
const headDim = this.graph.head_dim;
|
| 555 |
-
const q
|
| 556 |
-
const kNew = await this.linear(`${prefix}.w_k`).run(x, seqLen);
|
| 557 |
-
const vNew = await this.linear(`${prefix}.w_v`).run(x, seqLen);
|
| 558 |
const position = Number(layerCache.selfLen || 0);
|
| 559 |
if (this.decoderRotary) {
|
| 560 |
applyRotaryMergedInPlace(q, kNew, seqLen, nHeads, headDim, this.decoderRotaryBase, position);
|
| 561 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 562 |
layerCache.selfK = appendCachedRows(layerCache, "selfK", kNew);
|
| 563 |
layerCache.selfV = appendCachedRows(layerCache, "selfV", vNew);
|
| 564 |
layerCache.selfLen = Number(layerCache.selfLen || 0) + seqLen;
|
| 565 |
-
const merged = attention(q, layerCache.selfK, layerCache.selfV, seqLen, layerCache.selfLen,
|
| 566 |
return this.linear(`${prefix}.w_o`).run(merged, seqLen);
|
| 567 |
}
|
| 568 |
|
|
@@ -570,11 +674,20 @@ export class BitNetEncoderDecoderWebGPU {
|
|
| 570 |
const nHeads = this.graph.n_heads;
|
| 571 |
const headDim = this.graph.head_dim;
|
| 572 |
const q = await this.linear(`${prefix}.w_q`).run(x, 1);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 573 |
if (!layerCache.crossK || !layerCache.crossV) {
|
| 574 |
-
layerCache.crossK =
|
| 575 |
-
layerCache.crossV = await this.linear(`${prefix}.w_v`).run(memory, memoryLen);
|
| 576 |
}
|
| 577 |
-
const merged = attention(q, layerCache.crossK, layerCache.crossV, 1, memoryLen,
|
| 578 |
return this.linear(`${prefix}.w_o`).run(merged, 1);
|
| 579 |
}
|
| 580 |
|
|
@@ -582,25 +695,39 @@ export class BitNetEncoderDecoderWebGPU {
|
|
| 582 |
const nHeads = this.graph.n_heads;
|
| 583 |
const headDim = this.graph.head_dim;
|
| 584 |
const q = await this.linear(`${prefix}.w_q`).run(x, seqLen);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 585 |
if (!layerCache.crossK || !layerCache.crossV) {
|
| 586 |
-
layerCache.crossK =
|
| 587 |
-
layerCache.crossV = await this.linear(`${prefix}.w_v`).run(memory, memoryLen);
|
| 588 |
}
|
| 589 |
-
const merged = attention(q, layerCache.crossK, layerCache.crossV, seqLen, memoryLen,
|
| 590 |
return this.linear(`${prefix}.w_o`).run(merged, seqLen);
|
| 591 |
}
|
| 592 |
|
| 593 |
async mlp(prefix, x, seqLen) {
|
| 594 |
const wIn = this.linear(`${prefix}.w_in`);
|
| 595 |
const wOut = this.linear(`${prefix}.w_out`);
|
|
|
|
|
|
|
|
|
|
| 596 |
const hidden = await wIn.run(x, seqLen);
|
| 597 |
const activation = String(this.graph.activation || "silu").toLowerCase();
|
| 598 |
const isGated =
|
| 599 |
wIn.layout.logicalOut === wOut.layout.logicalIn * 2 ||
|
| 600 |
hidden.length === seqLen * wOut.layout.logicalIn * 2;
|
| 601 |
const activated = isGated || ["swiglu", "gated-silu", "geglu", "reglu"].includes(activation)
|
| 602 |
-
?
|
| 603 |
-
|
|
|
|
|
|
|
| 604 |
return wOut.run(activated, seqLen);
|
| 605 |
}
|
| 606 |
|
|
@@ -624,6 +751,14 @@ export class BitNetEncoderDecoderWebGPU {
|
|
| 624 |
}
|
| 625 |
|
| 626 |
async decoderLayerIncremental(index, x, memory, memoryLen, layerCache) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 627 |
let n = this.norm(`decoder.${index}.self_attn_block.n1`, x, 1);
|
| 628 |
x = addInPlace(
|
| 629 |
x.slice(),
|
|
@@ -791,24 +926,54 @@ export class BitNetEncoderDecoderGenerationSession {
|
|
| 791 |
}
|
| 792 |
|
| 793 |
async next(tokenId) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 794 |
await this.prepare();
|
| 795 |
let x = embed([Number(tokenId)], this.runtime.tensor("dec_embed.weight"), this.runtime.graph.d_model);
|
| 796 |
for (let i = 0; i < this.runtime.graph.n_layers; i += 1) {
|
| 797 |
x = await this.runtime.decoderLayerIncremental(i, x, this.memory, this.memoryLen, this.layerCaches[i]);
|
| 798 |
}
|
| 799 |
-
|
| 800 |
-
|
| 801 |
-
|
| 802 |
-
|
| 803 |
-
|
| 804 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 805 |
);
|
| 806 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 807 |
}
|
| 808 |
|
| 809 |
cloneState() {
|
| 810 |
return this.layerCaches.map((cache) => {
|
| 811 |
const cloned = { ...cache };
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 812 |
if (cache.selfK) {
|
| 813 |
cloned.selfK = cache.selfK.slice();
|
| 814 |
cloned.selfKLength = cloned.selfK.length;
|
|
@@ -849,6 +1014,7 @@ export class BitNetEncoderDecoderGenerationSession {
|
|
| 849 |
export class BitNetEncoderDecoderWASM extends BitNetEncoderDecoderWebGPU {
|
| 850 |
constructor(manifest, manifestUrl, denseTensors, linears) {
|
| 851 |
super(null, manifest, manifestUrl, denseTensors, linears);
|
|
|
|
| 852 |
}
|
| 853 |
|
| 854 |
static async fromManifestUrl(manifestUrl, options = {}) {
|
|
|
|
| 77 |
return new Float32Array(length);
|
| 78 |
}
|
| 79 |
|
| 80 |
+
function toUint32IdArray(ids) {
|
| 81 |
+
if (ids instanceof Uint32Array) return ids;
|
| 82 |
+
return Uint32Array.from(Array.from(ids || [], Number).filter((id) => Number.isFinite(id)));
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
function addInPlace(dst, src) {
|
| 86 |
for (let i = 0; i < dst.length; i += 1) {
|
| 87 |
dst[i] += src[i];
|
|
|
|
| 417 |
this.linears = linears;
|
| 418 |
this.denseLinears = {};
|
| 419 |
this.graph = manifest.graph;
|
| 420 |
+
this.wasmOps = null;
|
| 421 |
this.decoderRotary = decoderUsesRotary(manifest, this.graph);
|
| 422 |
this.decoderRotaryBase = rotaryBase(manifest, this.graph);
|
| 423 |
}
|
|
|
|
| 501 |
return denseLayer;
|
| 502 |
}
|
| 503 |
|
| 504 |
+
linear3(firstName, secondName, thirdName, input, rows) {
|
| 505 |
+
const first = this.linear(firstName);
|
| 506 |
+
const second = this.linear(secondName);
|
| 507 |
+
const third = this.linear(thirdName);
|
| 508 |
+
if (this.wasmOps?.bitnet_linear3_f32 && first.handle && second.handle && third.handle) {
|
| 509 |
+
const merged = this.wasmOps.bitnet_linear3_f32(first.handle, second.handle, third.handle, input, rows);
|
| 510 |
+
const firstLen = rows * first.layout.logicalOut;
|
| 511 |
+
const secondLen = rows * second.layout.logicalOut;
|
| 512 |
+
return [
|
| 513 |
+
merged.slice(0, firstLen),
|
| 514 |
+
merged.slice(firstLen, firstLen + secondLen),
|
| 515 |
+
merged.slice(firstLen + secondLen),
|
| 516 |
+
];
|
| 517 |
+
}
|
| 518 |
+
return [first.run(input, rows), second.run(input, rows), third.run(input, rows)];
|
| 519 |
+
}
|
| 520 |
+
|
| 521 |
+
linear2(firstName, secondName, input, rows) {
|
| 522 |
+
const first = this.linear(firstName);
|
| 523 |
+
const second = this.linear(secondName);
|
| 524 |
+
if (this.wasmOps?.bitnet_linear2_f32 && first.handle && second.handle) {
|
| 525 |
+
const merged = this.wasmOps.bitnet_linear2_f32(first.handle, second.handle, input, rows);
|
| 526 |
+
const firstLen = rows * first.layout.logicalOut;
|
| 527 |
+
return [merged.slice(0, firstLen), merged.slice(firstLen)];
|
| 528 |
+
}
|
| 529 |
+
return [first.run(input, rows), second.run(input, rows)];
|
| 530 |
+
}
|
| 531 |
+
|
| 532 |
+
decoderLayerHandle(index) {
|
| 533 |
+
if (!this.wasmOps?.DecoderLayerHandle) return null;
|
| 534 |
+
const names = [
|
| 535 |
+
`decoder.${index}.self_attn_block.attn.w_q`,
|
| 536 |
+
`decoder.${index}.self_attn_block.attn.w_k`,
|
| 537 |
+
`decoder.${index}.self_attn_block.attn.w_v`,
|
| 538 |
+
`decoder.${index}.self_attn_block.attn.w_o`,
|
| 539 |
+
`decoder.${index}.self_attn_block.mlp.w_in`,
|
| 540 |
+
`decoder.${index}.self_attn_block.mlp.w_out`,
|
| 541 |
+
`decoder.${index}.cross_block.cross.w_q`,
|
| 542 |
+
`decoder.${index}.cross_block.cross.w_k`,
|
| 543 |
+
`decoder.${index}.cross_block.cross.w_v`,
|
| 544 |
+
`decoder.${index}.cross_block.cross.w_o`,
|
| 545 |
+
`decoder.${index}.cross_block.mlp.w_in`,
|
| 546 |
+
`decoder.${index}.cross_block.mlp.w_out`,
|
| 547 |
+
];
|
| 548 |
+
const layers = names.map((name) => this.linear(name));
|
| 549 |
+
if (!layers.every((layer) => layer?.handle)) return null;
|
| 550 |
+
return new this.wasmOps.DecoderLayerHandle(
|
| 551 |
+
...layers.map((layer) => layer.handle),
|
| 552 |
+
this.tensor(`decoder.${index}.self_attn_block.n1.weight`),
|
| 553 |
+
this.dense[`decoder.${index}.self_attn_block.n1.bias`]?.data || new Float32Array(0),
|
| 554 |
+
this.tensor(`decoder.${index}.self_attn_block.n2.weight`),
|
| 555 |
+
this.dense[`decoder.${index}.self_attn_block.n2.bias`]?.data || new Float32Array(0),
|
| 556 |
+
this.tensor(`decoder.${index}.cross_block.n1.weight`),
|
| 557 |
+
this.dense[`decoder.${index}.cross_block.n1.bias`]?.data || new Float32Array(0),
|
| 558 |
+
this.tensor(`decoder.${index}.cross_block.n2.weight`),
|
| 559 |
+
this.dense[`decoder.${index}.cross_block.n2.bias`]?.data || new Float32Array(0),
|
| 560 |
+
String(this.graph.activation || "silu"),
|
| 561 |
+
this.graph.d_model,
|
| 562 |
+
this.graph.n_heads,
|
| 563 |
+
this.graph.head_dim,
|
| 564 |
+
this.decoderRotary ? this.decoderRotaryBase : 0,
|
| 565 |
+
);
|
| 566 |
+
}
|
| 567 |
+
|
| 568 |
tensor(name) {
|
| 569 |
const tensor = this.dense[name];
|
| 570 |
if (!tensor) throw new Error(`missing dense tensor: ${name}`);
|
|
|
|
| 574 |
norm(prefix, x, rows) {
|
| 575 |
const weight = this.tensor(`${prefix}.weight`);
|
| 576 |
const bias = this.dense[`${prefix}.bias`]?.data || null;
|
| 577 |
+
if (this.wasmOps?.layer_norm_f32 && bias) {
|
| 578 |
+
return this.wasmOps.layer_norm_f32(x, weight, bias, rows, this.graph.d_model, 1e-5);
|
| 579 |
+
}
|
| 580 |
if (bias) {
|
| 581 |
return layerNorm(x, rows, this.graph.d_model, weight, bias);
|
| 582 |
}
|
|
|
|
| 589 |
);
|
| 590 |
}
|
| 591 |
|
| 592 |
+
attention(q, k, v, qLen, kvLen, causal, pastLen = 0) {
|
| 593 |
+
if (this.wasmOps?.attention_f32) {
|
| 594 |
+
return this.wasmOps.attention_f32(
|
| 595 |
+
q,
|
| 596 |
+
k,
|
| 597 |
+
v,
|
| 598 |
+
qLen,
|
| 599 |
+
kvLen,
|
| 600 |
+
this.graph.n_heads,
|
| 601 |
+
this.graph.head_dim,
|
| 602 |
+
Boolean(causal),
|
| 603 |
+
Number(pastLen || 0),
|
| 604 |
+
);
|
| 605 |
+
}
|
| 606 |
+
return attention(q, k, v, qLen, kvLen, this.graph.n_heads, this.graph.head_dim, causal, pastLen);
|
| 607 |
+
}
|
| 608 |
+
|
| 609 |
async attentionBlock(prefix, x, seqLen, kv, kvLen, causal) {
|
| 610 |
const dModel = this.graph.d_model;
|
| 611 |
const nHeads = this.graph.n_heads;
|
| 612 |
const headDim = this.graph.head_dim;
|
| 613 |
+
let q;
|
| 614 |
+
let k;
|
| 615 |
+
let v;
|
| 616 |
const kInput = kv || x;
|
| 617 |
const kRows = kvLen || seqLen;
|
| 618 |
+
if (!kv) {
|
| 619 |
+
[q, k, v] = this.linear3(`${prefix}.w_q`, `${prefix}.w_k`, `${prefix}.w_v`, x, seqLen);
|
| 620 |
+
} else {
|
| 621 |
+
q = await this.linear(`${prefix}.w_q`).run(x, seqLen);
|
| 622 |
+
[k, v] = this.linear2(`${prefix}.w_k`, `${prefix}.w_v`, kInput, kRows);
|
| 623 |
+
}
|
| 624 |
if (causal && this.decoderRotary) {
|
| 625 |
applyRotaryMergedInPlace(q, k, seqLen, nHeads, headDim, this.decoderRotaryBase, 0);
|
| 626 |
}
|
| 627 |
+
const merged = this.attention(q, k, v, seqLen, kRows, causal);
|
| 628 |
return this.linear(`${prefix}.w_o`).run(merged, seqLen);
|
| 629 |
}
|
| 630 |
|
| 631 |
async selfAttentionIncremental(prefix, x, layerCache) {
|
| 632 |
const nHeads = this.graph.n_heads;
|
| 633 |
const headDim = this.graph.head_dim;
|
| 634 |
+
const [q, kNew, vNew] = this.linear3(`${prefix}.w_q`, `${prefix}.w_k`, `${prefix}.w_v`, x, 1);
|
|
|
|
|
|
|
| 635 |
const position = Number(layerCache.selfLen || 0);
|
| 636 |
if (this.decoderRotary) {
|
| 637 |
applyRotaryMergedInPlace(q, kNew, 1, nHeads, headDim, this.decoderRotaryBase, position);
|
| 638 |
}
|
| 639 |
+
if (this.wasmOps?.AttentionKvCache) {
|
| 640 |
+
layerCache.selfAttention ??= new this.wasmOps.AttentionKvCache(nHeads, headDim);
|
| 641 |
+
const merged = layerCache.selfAttention.append_self_attention(q, kNew, vNew, 1, false);
|
| 642 |
+
layerCache.selfLen = layerCache.selfAttention.len();
|
| 643 |
+
return this.linear(`${prefix}.w_o`).run(merged, 1);
|
| 644 |
+
}
|
| 645 |
layerCache.selfK = appendCachedRows(layerCache, "selfK", kNew);
|
| 646 |
layerCache.selfV = appendCachedRows(layerCache, "selfV", vNew);
|
| 647 |
layerCache.selfLen = Number(layerCache.selfLen || 0) + 1;
|
| 648 |
+
const merged = this.attention(q, layerCache.selfK, layerCache.selfV, 1, layerCache.selfLen, false);
|
| 649 |
return this.linear(`${prefix}.w_o`).run(merged, 1);
|
| 650 |
}
|
| 651 |
|
| 652 |
async selfAttentionIncrementalSpan(prefix, x, seqLen, layerCache) {
|
| 653 |
const nHeads = this.graph.n_heads;
|
| 654 |
const headDim = this.graph.head_dim;
|
| 655 |
+
const [q, kNew, vNew] = this.linear3(`${prefix}.w_q`, `${prefix}.w_k`, `${prefix}.w_v`, x, seqLen);
|
|
|
|
|
|
|
| 656 |
const position = Number(layerCache.selfLen || 0);
|
| 657 |
if (this.decoderRotary) {
|
| 658 |
applyRotaryMergedInPlace(q, kNew, seqLen, nHeads, headDim, this.decoderRotaryBase, position);
|
| 659 |
}
|
| 660 |
+
if (this.wasmOps?.AttentionKvCache) {
|
| 661 |
+
layerCache.selfAttention ??= new this.wasmOps.AttentionKvCache(nHeads, headDim);
|
| 662 |
+
const merged = layerCache.selfAttention.append_self_attention(q, kNew, vNew, seqLen, true);
|
| 663 |
+
layerCache.selfLen = layerCache.selfAttention.len();
|
| 664 |
+
return this.linear(`${prefix}.w_o`).run(merged, seqLen);
|
| 665 |
+
}
|
| 666 |
layerCache.selfK = appendCachedRows(layerCache, "selfK", kNew);
|
| 667 |
layerCache.selfV = appendCachedRows(layerCache, "selfV", vNew);
|
| 668 |
layerCache.selfLen = Number(layerCache.selfLen || 0) + seqLen;
|
| 669 |
+
const merged = this.attention(q, layerCache.selfK, layerCache.selfV, seqLen, layerCache.selfLen, true, position);
|
| 670 |
return this.linear(`${prefix}.w_o`).run(merged, seqLen);
|
| 671 |
}
|
| 672 |
|
|
|
|
| 674 |
const nHeads = this.graph.n_heads;
|
| 675 |
const headDim = this.graph.head_dim;
|
| 676 |
const q = await this.linear(`${prefix}.w_q`).run(x, 1);
|
| 677 |
+
if (this.wasmOps?.AttentionKvCache) {
|
| 678 |
+
layerCache.crossAttention ??= new this.wasmOps.AttentionKvCache(nHeads, headDim);
|
| 679 |
+
if (!layerCache.crossReady) {
|
| 680 |
+
const [crossK, crossV] = this.linear2(`${prefix}.w_k`, `${prefix}.w_v`, memory, memoryLen);
|
| 681 |
+
layerCache.crossAttention.set_cross(crossK, crossV, memoryLen);
|
| 682 |
+
layerCache.crossReady = true;
|
| 683 |
+
}
|
| 684 |
+
const merged = layerCache.crossAttention.attention(q, 1, false, 0);
|
| 685 |
+
return this.linear(`${prefix}.w_o`).run(merged, 1);
|
| 686 |
+
}
|
| 687 |
if (!layerCache.crossK || !layerCache.crossV) {
|
| 688 |
+
[layerCache.crossK, layerCache.crossV] = this.linear2(`${prefix}.w_k`, `${prefix}.w_v`, memory, memoryLen);
|
|
|
|
| 689 |
}
|
| 690 |
+
const merged = this.attention(q, layerCache.crossK, layerCache.crossV, 1, memoryLen, false);
|
| 691 |
return this.linear(`${prefix}.w_o`).run(merged, 1);
|
| 692 |
}
|
| 693 |
|
|
|
|
| 695 |
const nHeads = this.graph.n_heads;
|
| 696 |
const headDim = this.graph.head_dim;
|
| 697 |
const q = await this.linear(`${prefix}.w_q`).run(x, seqLen);
|
| 698 |
+
if (this.wasmOps?.AttentionKvCache) {
|
| 699 |
+
layerCache.crossAttention ??= new this.wasmOps.AttentionKvCache(nHeads, headDim);
|
| 700 |
+
if (!layerCache.crossReady) {
|
| 701 |
+
const [crossK, crossV] = this.linear2(`${prefix}.w_k`, `${prefix}.w_v`, memory, memoryLen);
|
| 702 |
+
layerCache.crossAttention.set_cross(crossK, crossV, memoryLen);
|
| 703 |
+
layerCache.crossReady = true;
|
| 704 |
+
}
|
| 705 |
+
const merged = layerCache.crossAttention.attention(q, seqLen, false, 0);
|
| 706 |
+
return this.linear(`${prefix}.w_o`).run(merged, seqLen);
|
| 707 |
+
}
|
| 708 |
if (!layerCache.crossK || !layerCache.crossV) {
|
| 709 |
+
[layerCache.crossK, layerCache.crossV] = this.linear2(`${prefix}.w_k`, `${prefix}.w_v`, memory, memoryLen);
|
|
|
|
| 710 |
}
|
| 711 |
+
const merged = this.attention(q, layerCache.crossK, layerCache.crossV, seqLen, memoryLen, false);
|
| 712 |
return this.linear(`${prefix}.w_o`).run(merged, seqLen);
|
| 713 |
}
|
| 714 |
|
| 715 |
async mlp(prefix, x, seqLen) {
|
| 716 |
const wIn = this.linear(`${prefix}.w_in`);
|
| 717 |
const wOut = this.linear(`${prefix}.w_out`);
|
| 718 |
+
if (this.wasmOps?.bitnet_mlp_f32 && wIn.handle && wOut.handle) {
|
| 719 |
+
return this.wasmOps.bitnet_mlp_f32(wIn.handle, wOut.handle, x, seqLen, String(this.graph.activation || "silu"));
|
| 720 |
+
}
|
| 721 |
const hidden = await wIn.run(x, seqLen);
|
| 722 |
const activation = String(this.graph.activation || "silu").toLowerCase();
|
| 723 |
const isGated =
|
| 724 |
wIn.layout.logicalOut === wOut.layout.logicalIn * 2 ||
|
| 725 |
hidden.length === seqLen * wOut.layout.logicalIn * 2;
|
| 726 |
const activated = isGated || ["swiglu", "gated-silu", "geglu", "reglu"].includes(activation)
|
| 727 |
+
? (this.wasmOps?.gated_activation_f32
|
| 728 |
+
? this.wasmOps.gated_activation_f32(hidden, seqLen, wOut.layout.logicalIn, activation)
|
| 729 |
+
: gatedActivation(hidden, seqLen, wOut.layout.logicalIn, activation))
|
| 730 |
+
: (this.wasmOps?.activate_f32 ? this.wasmOps.activate_f32(hidden, activation) : activate(hidden, activation));
|
| 731 |
return wOut.run(activated, seqLen);
|
| 732 |
}
|
| 733 |
|
|
|
|
| 751 |
}
|
| 752 |
|
| 753 |
async decoderLayerIncremental(index, x, memory, memoryLen, layerCache) {
|
| 754 |
+
if (this.wasmOps?.DecoderLayerHandle) {
|
| 755 |
+
layerCache.decoderLayer ??= this.decoderLayerHandle(index);
|
| 756 |
+
if (layerCache.decoderLayer?.next) {
|
| 757 |
+
const out = layerCache.decoderLayer.next(x, memory, memoryLen);
|
| 758 |
+
layerCache.selfLen = layerCache.decoderLayer.self_len();
|
| 759 |
+
return out;
|
| 760 |
+
}
|
| 761 |
+
}
|
| 762 |
let n = this.norm(`decoder.${index}.self_attn_block.n1`, x, 1);
|
| 763 |
x = addInPlace(
|
| 764 |
x.slice(),
|
|
|
|
| 926 |
}
|
| 927 |
|
| 928 |
async next(tokenId) {
|
| 929 |
+
const hidden = await this.nextHidden(tokenId);
|
| 930 |
+
return this.runtime.linear("lm_head").run(hidden, 1);
|
| 931 |
+
}
|
| 932 |
+
|
| 933 |
+
async nextHidden(tokenId) {
|
| 934 |
await this.prepare();
|
| 935 |
let x = embed([Number(tokenId)], this.runtime.tensor("dec_embed.weight"), this.runtime.graph.d_model);
|
| 936 |
for (let i = 0; i < this.runtime.graph.n_layers; i += 1) {
|
| 937 |
x = await this.runtime.decoderLayerIncremental(i, x, this.memory, this.memoryLen, this.layerCaches[i]);
|
| 938 |
}
|
| 939 |
+
return this.runtime.norm("dec_norm", x, 1);
|
| 940 |
+
}
|
| 941 |
+
|
| 942 |
+
async sampleNext(tokenId, generatedIds, options = {}) {
|
| 943 |
+
if (!this.runtime.wasmOps?.bitnet_sample_token_f32) return null;
|
| 944 |
+
const lmHead = this.runtime.linear("lm_head");
|
| 945 |
+
if (!lmHead.handle) return null;
|
| 946 |
+
const hidden = await this.nextHidden(tokenId);
|
| 947 |
+
const sample = this.runtime.wasmOps.bitnet_sample_token_f32(
|
| 948 |
+
lmHead.handle,
|
| 949 |
+
hidden,
|
| 950 |
+
toUint32IdArray(generatedIds),
|
| 951 |
+
toUint32IdArray(options.blockedIds),
|
| 952 |
+
Number(options.temperature ?? 0.35),
|
| 953 |
+
Number(options.topP ?? 0.9),
|
| 954 |
+
Number(options.repetitionPenalty ?? 1.16),
|
| 955 |
+
Number(options.randomValue ?? Math.random()),
|
| 956 |
);
|
| 957 |
+
return {
|
| 958 |
+
tokenId: Number(sample.token_id),
|
| 959 |
+
probability: Number(sample.probability),
|
| 960 |
+
topProbability: Number(sample.top_probability),
|
| 961 |
+
rank: Number(sample.rank),
|
| 962 |
+
};
|
| 963 |
}
|
| 964 |
|
| 965 |
cloneState() {
|
| 966 |
return this.layerCaches.map((cache) => {
|
| 967 |
const cloned = { ...cache };
|
| 968 |
+
if (cache.selfAttention?.clone_cache) {
|
| 969 |
+
cloned.selfAttention = cache.selfAttention.clone_cache();
|
| 970 |
+
}
|
| 971 |
+
if (cache.crossAttention?.clone_cache) {
|
| 972 |
+
cloned.crossAttention = cache.crossAttention.clone_cache();
|
| 973 |
+
}
|
| 974 |
+
if (cache.decoderLayer?.clone_cache) {
|
| 975 |
+
cloned.decoderLayer = cache.decoderLayer.clone_cache();
|
| 976 |
+
}
|
| 977 |
if (cache.selfK) {
|
| 978 |
cloned.selfK = cache.selfK.slice();
|
| 979 |
cloned.selfKLength = cloned.selfK.length;
|
|
|
|
| 1014 |
export class BitNetEncoderDecoderWASM extends BitNetEncoderDecoderWebGPU {
|
| 1015 |
constructor(manifest, manifestUrl, denseTensors, linears) {
|
| 1016 |
super(null, manifest, manifestUrl, denseTensors, linears);
|
| 1017 |
+
this.wasmOps = Object.values(linears || {}).find((layer) => layer?.wasm)?.wasm || null;
|
| 1018 |
}
|
| 1019 |
|
| 1020 |
static async fromManifestUrl(manifestUrl, options = {}) {
|
runtime/model_stack_bitnet_wasm_bg.wasm
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3016e4b7ef05f1b94bb19b7017280407281d219341a0d593460d8ba57268cee4
|
| 3 |
+
size 75132
|