PeytonT commited on
Commit
8a84b2f
·
verified ·
1 Parent(s): d10b171

Update AgentKernel Lite WASM decoder kernel

Browse files
runtime/encdec_runtime.js CHANGED
@@ -77,6 +77,11 @@ function zeros(length) {
77
  return new Float32Array(length);
78
  }
79
 
 
 
 
 
 
80
  function addInPlace(dst, src) {
81
  for (let i = 0; i < dst.length; i += 1) {
82
  dst[i] += src[i];
@@ -412,6 +417,7 @@ export class BitNetEncoderDecoderWebGPU {
412
  this.linears = linears;
413
  this.denseLinears = {};
414
  this.graph = manifest.graph;
 
415
  this.decoderRotary = decoderUsesRotary(manifest, this.graph);
416
  this.decoderRotaryBase = rotaryBase(manifest, this.graph);
417
  }
@@ -495,6 +501,70 @@ export class BitNetEncoderDecoderWebGPU {
495
  return denseLayer;
496
  }
497
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
498
  tensor(name) {
499
  const tensor = this.dense[name];
500
  if (!tensor) throw new Error(`missing dense tensor: ${name}`);
@@ -504,6 +574,9 @@ export class BitNetEncoderDecoderWebGPU {
504
  norm(prefix, x, rows) {
505
  const weight = this.tensor(`${prefix}.weight`);
506
  const bias = this.dense[`${prefix}.bias`]?.data || null;
 
 
 
507
  if (bias) {
508
  return layerNorm(x, rows, this.graph.d_model, weight, bias);
509
  }
@@ -516,53 +589,84 @@ export class BitNetEncoderDecoderWebGPU {
516
  );
517
  }
518
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
519
  async attentionBlock(prefix, x, seqLen, kv, kvLen, causal) {
520
  const dModel = this.graph.d_model;
521
  const nHeads = this.graph.n_heads;
522
  const headDim = this.graph.head_dim;
523
- const q = await this.linear(`${prefix}.w_q`).run(x, seqLen);
 
 
524
  const kInput = kv || x;
525
  const kRows = kvLen || seqLen;
526
- const k = await this.linear(`${prefix}.w_k`).run(kInput, kRows);
527
- const v = await this.linear(`${prefix}.w_v`).run(kInput, kRows);
 
 
 
 
528
  if (causal && this.decoderRotary) {
529
  applyRotaryMergedInPlace(q, k, seqLen, nHeads, headDim, this.decoderRotaryBase, 0);
530
  }
531
- const merged = attention(q, k, v, seqLen, kRows, nHeads, headDim, causal);
532
  return this.linear(`${prefix}.w_o`).run(merged, seqLen);
533
  }
534
 
535
  async selfAttentionIncremental(prefix, x, layerCache) {
536
  const nHeads = this.graph.n_heads;
537
  const headDim = this.graph.head_dim;
538
- const q = await this.linear(`${prefix}.w_q`).run(x, 1);
539
- const kNew = await this.linear(`${prefix}.w_k`).run(x, 1);
540
- const vNew = await this.linear(`${prefix}.w_v`).run(x, 1);
541
  const position = Number(layerCache.selfLen || 0);
542
  if (this.decoderRotary) {
543
  applyRotaryMergedInPlace(q, kNew, 1, nHeads, headDim, this.decoderRotaryBase, position);
544
  }
 
 
 
 
 
 
545
  layerCache.selfK = appendCachedRows(layerCache, "selfK", kNew);
546
  layerCache.selfV = appendCachedRows(layerCache, "selfV", vNew);
547
  layerCache.selfLen = Number(layerCache.selfLen || 0) + 1;
548
- const merged = attention(q, layerCache.selfK, layerCache.selfV, 1, layerCache.selfLen, nHeads, headDim, false);
549
  return this.linear(`${prefix}.w_o`).run(merged, 1);
550
  }
551
 
552
  async selfAttentionIncrementalSpan(prefix, x, seqLen, layerCache) {
553
  const nHeads = this.graph.n_heads;
554
  const headDim = this.graph.head_dim;
555
- const q = await this.linear(`${prefix}.w_q`).run(x, seqLen);
556
- const kNew = await this.linear(`${prefix}.w_k`).run(x, seqLen);
557
- const vNew = await this.linear(`${prefix}.w_v`).run(x, seqLen);
558
  const position = Number(layerCache.selfLen || 0);
559
  if (this.decoderRotary) {
560
  applyRotaryMergedInPlace(q, kNew, seqLen, nHeads, headDim, this.decoderRotaryBase, position);
561
  }
 
 
 
 
 
 
562
  layerCache.selfK = appendCachedRows(layerCache, "selfK", kNew);
563
  layerCache.selfV = appendCachedRows(layerCache, "selfV", vNew);
564
  layerCache.selfLen = Number(layerCache.selfLen || 0) + seqLen;
565
- const merged = attention(q, layerCache.selfK, layerCache.selfV, seqLen, layerCache.selfLen, nHeads, headDim, true, position);
566
  return this.linear(`${prefix}.w_o`).run(merged, seqLen);
567
  }
568
 
@@ -570,11 +674,20 @@ export class BitNetEncoderDecoderWebGPU {
570
  const nHeads = this.graph.n_heads;
571
  const headDim = this.graph.head_dim;
572
  const q = await this.linear(`${prefix}.w_q`).run(x, 1);
 
 
 
 
 
 
 
 
 
 
573
  if (!layerCache.crossK || !layerCache.crossV) {
574
- layerCache.crossK = await this.linear(`${prefix}.w_k`).run(memory, memoryLen);
575
- layerCache.crossV = await this.linear(`${prefix}.w_v`).run(memory, memoryLen);
576
  }
577
- const merged = attention(q, layerCache.crossK, layerCache.crossV, 1, memoryLen, nHeads, headDim, false);
578
  return this.linear(`${prefix}.w_o`).run(merged, 1);
579
  }
580
 
@@ -582,25 +695,39 @@ export class BitNetEncoderDecoderWebGPU {
582
  const nHeads = this.graph.n_heads;
583
  const headDim = this.graph.head_dim;
584
  const q = await this.linear(`${prefix}.w_q`).run(x, seqLen);
 
 
 
 
 
 
 
 
 
 
585
  if (!layerCache.crossK || !layerCache.crossV) {
586
- layerCache.crossK = await this.linear(`${prefix}.w_k`).run(memory, memoryLen);
587
- layerCache.crossV = await this.linear(`${prefix}.w_v`).run(memory, memoryLen);
588
  }
589
- const merged = attention(q, layerCache.crossK, layerCache.crossV, seqLen, memoryLen, nHeads, headDim, false);
590
  return this.linear(`${prefix}.w_o`).run(merged, seqLen);
591
  }
592
 
593
  async mlp(prefix, x, seqLen) {
594
  const wIn = this.linear(`${prefix}.w_in`);
595
  const wOut = this.linear(`${prefix}.w_out`);
 
 
 
596
  const hidden = await wIn.run(x, seqLen);
597
  const activation = String(this.graph.activation || "silu").toLowerCase();
598
  const isGated =
599
  wIn.layout.logicalOut === wOut.layout.logicalIn * 2 ||
600
  hidden.length === seqLen * wOut.layout.logicalIn * 2;
601
  const activated = isGated || ["swiglu", "gated-silu", "geglu", "reglu"].includes(activation)
602
- ? gatedActivation(hidden, seqLen, wOut.layout.logicalIn, activation)
603
- : activate(hidden, activation);
 
 
604
  return wOut.run(activated, seqLen);
605
  }
606
 
@@ -624,6 +751,14 @@ export class BitNetEncoderDecoderWebGPU {
624
  }
625
 
626
  async decoderLayerIncremental(index, x, memory, memoryLen, layerCache) {
 
 
 
 
 
 
 
 
627
  let n = this.norm(`decoder.${index}.self_attn_block.n1`, x, 1);
628
  x = addInPlace(
629
  x.slice(),
@@ -791,24 +926,54 @@ export class BitNetEncoderDecoderGenerationSession {
791
  }
792
 
793
  async next(tokenId) {
 
 
 
 
 
794
  await this.prepare();
795
  let x = embed([Number(tokenId)], this.runtime.tensor("dec_embed.weight"), this.runtime.graph.d_model);
796
  for (let i = 0; i < this.runtime.graph.n_layers; i += 1) {
797
  x = await this.runtime.decoderLayerIncremental(i, x, this.memory, this.memoryLen, this.layerCaches[i]);
798
  }
799
- const hidden = layerNorm(
800
- x,
801
- 1,
802
- this.runtime.graph.d_model,
803
- this.runtime.tensor("dec_norm.weight"),
804
- this.runtime.dense["dec_norm.bias"]?.data,
 
 
 
 
 
 
 
 
 
 
 
805
  );
806
- return this.runtime.linear("lm_head").run(hidden, 1);
 
 
 
 
 
807
  }
808
 
809
  cloneState() {
810
  return this.layerCaches.map((cache) => {
811
  const cloned = { ...cache };
 
 
 
 
 
 
 
 
 
812
  if (cache.selfK) {
813
  cloned.selfK = cache.selfK.slice();
814
  cloned.selfKLength = cloned.selfK.length;
@@ -849,6 +1014,7 @@ export class BitNetEncoderDecoderGenerationSession {
849
  export class BitNetEncoderDecoderWASM extends BitNetEncoderDecoderWebGPU {
850
  constructor(manifest, manifestUrl, denseTensors, linears) {
851
  super(null, manifest, manifestUrl, denseTensors, linears);
 
852
  }
853
 
854
  static async fromManifestUrl(manifestUrl, options = {}) {
 
77
  return new Float32Array(length);
78
  }
79
 
80
+ function toUint32IdArray(ids) {
81
+ if (ids instanceof Uint32Array) return ids;
82
+ return Uint32Array.from(Array.from(ids || [], Number).filter((id) => Number.isFinite(id)));
83
+ }
84
+
85
  function addInPlace(dst, src) {
86
  for (let i = 0; i < dst.length; i += 1) {
87
  dst[i] += src[i];
 
417
  this.linears = linears;
418
  this.denseLinears = {};
419
  this.graph = manifest.graph;
420
+ this.wasmOps = null;
421
  this.decoderRotary = decoderUsesRotary(manifest, this.graph);
422
  this.decoderRotaryBase = rotaryBase(manifest, this.graph);
423
  }
 
501
  return denseLayer;
502
  }
503
 
504
+ linear3(firstName, secondName, thirdName, input, rows) {
505
+ const first = this.linear(firstName);
506
+ const second = this.linear(secondName);
507
+ const third = this.linear(thirdName);
508
+ if (this.wasmOps?.bitnet_linear3_f32 && first.handle && second.handle && third.handle) {
509
+ const merged = this.wasmOps.bitnet_linear3_f32(first.handle, second.handle, third.handle, input, rows);
510
+ const firstLen = rows * first.layout.logicalOut;
511
+ const secondLen = rows * second.layout.logicalOut;
512
+ return [
513
+ merged.slice(0, firstLen),
514
+ merged.slice(firstLen, firstLen + secondLen),
515
+ merged.slice(firstLen + secondLen),
516
+ ];
517
+ }
518
+ return [first.run(input, rows), second.run(input, rows), third.run(input, rows)];
519
+ }
520
+
521
+ linear2(firstName, secondName, input, rows) {
522
+ const first = this.linear(firstName);
523
+ const second = this.linear(secondName);
524
+ if (this.wasmOps?.bitnet_linear2_f32 && first.handle && second.handle) {
525
+ const merged = this.wasmOps.bitnet_linear2_f32(first.handle, second.handle, input, rows);
526
+ const firstLen = rows * first.layout.logicalOut;
527
+ return [merged.slice(0, firstLen), merged.slice(firstLen)];
528
+ }
529
+ return [first.run(input, rows), second.run(input, rows)];
530
+ }
531
+
532
+ decoderLayerHandle(index) {
533
+ if (!this.wasmOps?.DecoderLayerHandle) return null;
534
+ const names = [
535
+ `decoder.${index}.self_attn_block.attn.w_q`,
536
+ `decoder.${index}.self_attn_block.attn.w_k`,
537
+ `decoder.${index}.self_attn_block.attn.w_v`,
538
+ `decoder.${index}.self_attn_block.attn.w_o`,
539
+ `decoder.${index}.self_attn_block.mlp.w_in`,
540
+ `decoder.${index}.self_attn_block.mlp.w_out`,
541
+ `decoder.${index}.cross_block.cross.w_q`,
542
+ `decoder.${index}.cross_block.cross.w_k`,
543
+ `decoder.${index}.cross_block.cross.w_v`,
544
+ `decoder.${index}.cross_block.cross.w_o`,
545
+ `decoder.${index}.cross_block.mlp.w_in`,
546
+ `decoder.${index}.cross_block.mlp.w_out`,
547
+ ];
548
+ const layers = names.map((name) => this.linear(name));
549
+ if (!layers.every((layer) => layer?.handle)) return null;
550
+ return new this.wasmOps.DecoderLayerHandle(
551
+ ...layers.map((layer) => layer.handle),
552
+ this.tensor(`decoder.${index}.self_attn_block.n1.weight`),
553
+ this.dense[`decoder.${index}.self_attn_block.n1.bias`]?.data || new Float32Array(0),
554
+ this.tensor(`decoder.${index}.self_attn_block.n2.weight`),
555
+ this.dense[`decoder.${index}.self_attn_block.n2.bias`]?.data || new Float32Array(0),
556
+ this.tensor(`decoder.${index}.cross_block.n1.weight`),
557
+ this.dense[`decoder.${index}.cross_block.n1.bias`]?.data || new Float32Array(0),
558
+ this.tensor(`decoder.${index}.cross_block.n2.weight`),
559
+ this.dense[`decoder.${index}.cross_block.n2.bias`]?.data || new Float32Array(0),
560
+ String(this.graph.activation || "silu"),
561
+ this.graph.d_model,
562
+ this.graph.n_heads,
563
+ this.graph.head_dim,
564
+ this.decoderRotary ? this.decoderRotaryBase : 0,
565
+ );
566
+ }
567
+
568
  tensor(name) {
569
  const tensor = this.dense[name];
570
  if (!tensor) throw new Error(`missing dense tensor: ${name}`);
 
574
  norm(prefix, x, rows) {
575
  const weight = this.tensor(`${prefix}.weight`);
576
  const bias = this.dense[`${prefix}.bias`]?.data || null;
577
+ if (this.wasmOps?.layer_norm_f32 && bias) {
578
+ return this.wasmOps.layer_norm_f32(x, weight, bias, rows, this.graph.d_model, 1e-5);
579
+ }
580
  if (bias) {
581
  return layerNorm(x, rows, this.graph.d_model, weight, bias);
582
  }
 
589
  );
590
  }
591
 
592
+ attention(q, k, v, qLen, kvLen, causal, pastLen = 0) {
593
+ if (this.wasmOps?.attention_f32) {
594
+ return this.wasmOps.attention_f32(
595
+ q,
596
+ k,
597
+ v,
598
+ qLen,
599
+ kvLen,
600
+ this.graph.n_heads,
601
+ this.graph.head_dim,
602
+ Boolean(causal),
603
+ Number(pastLen || 0),
604
+ );
605
+ }
606
+ return attention(q, k, v, qLen, kvLen, this.graph.n_heads, this.graph.head_dim, causal, pastLen);
607
+ }
608
+
609
  async attentionBlock(prefix, x, seqLen, kv, kvLen, causal) {
610
  const dModel = this.graph.d_model;
611
  const nHeads = this.graph.n_heads;
612
  const headDim = this.graph.head_dim;
613
+ let q;
614
+ let k;
615
+ let v;
616
  const kInput = kv || x;
617
  const kRows = kvLen || seqLen;
618
+ if (!kv) {
619
+ [q, k, v] = this.linear3(`${prefix}.w_q`, `${prefix}.w_k`, `${prefix}.w_v`, x, seqLen);
620
+ } else {
621
+ q = await this.linear(`${prefix}.w_q`).run(x, seqLen);
622
+ [k, v] = this.linear2(`${prefix}.w_k`, `${prefix}.w_v`, kInput, kRows);
623
+ }
624
  if (causal && this.decoderRotary) {
625
  applyRotaryMergedInPlace(q, k, seqLen, nHeads, headDim, this.decoderRotaryBase, 0);
626
  }
627
+ const merged = this.attention(q, k, v, seqLen, kRows, causal);
628
  return this.linear(`${prefix}.w_o`).run(merged, seqLen);
629
  }
630
 
631
  async selfAttentionIncremental(prefix, x, layerCache) {
632
  const nHeads = this.graph.n_heads;
633
  const headDim = this.graph.head_dim;
634
+ const [q, kNew, vNew] = this.linear3(`${prefix}.w_q`, `${prefix}.w_k`, `${prefix}.w_v`, x, 1);
 
 
635
  const position = Number(layerCache.selfLen || 0);
636
  if (this.decoderRotary) {
637
  applyRotaryMergedInPlace(q, kNew, 1, nHeads, headDim, this.decoderRotaryBase, position);
638
  }
639
+ if (this.wasmOps?.AttentionKvCache) {
640
+ layerCache.selfAttention ??= new this.wasmOps.AttentionKvCache(nHeads, headDim);
641
+ const merged = layerCache.selfAttention.append_self_attention(q, kNew, vNew, 1, false);
642
+ layerCache.selfLen = layerCache.selfAttention.len();
643
+ return this.linear(`${prefix}.w_o`).run(merged, 1);
644
+ }
645
  layerCache.selfK = appendCachedRows(layerCache, "selfK", kNew);
646
  layerCache.selfV = appendCachedRows(layerCache, "selfV", vNew);
647
  layerCache.selfLen = Number(layerCache.selfLen || 0) + 1;
648
+ const merged = this.attention(q, layerCache.selfK, layerCache.selfV, 1, layerCache.selfLen, false);
649
  return this.linear(`${prefix}.w_o`).run(merged, 1);
650
  }
651
 
652
  async selfAttentionIncrementalSpan(prefix, x, seqLen, layerCache) {
653
  const nHeads = this.graph.n_heads;
654
  const headDim = this.graph.head_dim;
655
+ const [q, kNew, vNew] = this.linear3(`${prefix}.w_q`, `${prefix}.w_k`, `${prefix}.w_v`, x, seqLen);
 
 
656
  const position = Number(layerCache.selfLen || 0);
657
  if (this.decoderRotary) {
658
  applyRotaryMergedInPlace(q, kNew, seqLen, nHeads, headDim, this.decoderRotaryBase, position);
659
  }
660
+ if (this.wasmOps?.AttentionKvCache) {
661
+ layerCache.selfAttention ??= new this.wasmOps.AttentionKvCache(nHeads, headDim);
662
+ const merged = layerCache.selfAttention.append_self_attention(q, kNew, vNew, seqLen, true);
663
+ layerCache.selfLen = layerCache.selfAttention.len();
664
+ return this.linear(`${prefix}.w_o`).run(merged, seqLen);
665
+ }
666
  layerCache.selfK = appendCachedRows(layerCache, "selfK", kNew);
667
  layerCache.selfV = appendCachedRows(layerCache, "selfV", vNew);
668
  layerCache.selfLen = Number(layerCache.selfLen || 0) + seqLen;
669
+ const merged = this.attention(q, layerCache.selfK, layerCache.selfV, seqLen, layerCache.selfLen, true, position);
670
  return this.linear(`${prefix}.w_o`).run(merged, seqLen);
671
  }
672
 
 
674
  const nHeads = this.graph.n_heads;
675
  const headDim = this.graph.head_dim;
676
  const q = await this.linear(`${prefix}.w_q`).run(x, 1);
677
+ if (this.wasmOps?.AttentionKvCache) {
678
+ layerCache.crossAttention ??= new this.wasmOps.AttentionKvCache(nHeads, headDim);
679
+ if (!layerCache.crossReady) {
680
+ const [crossK, crossV] = this.linear2(`${prefix}.w_k`, `${prefix}.w_v`, memory, memoryLen);
681
+ layerCache.crossAttention.set_cross(crossK, crossV, memoryLen);
682
+ layerCache.crossReady = true;
683
+ }
684
+ const merged = layerCache.crossAttention.attention(q, 1, false, 0);
685
+ return this.linear(`${prefix}.w_o`).run(merged, 1);
686
+ }
687
  if (!layerCache.crossK || !layerCache.crossV) {
688
+ [layerCache.crossK, layerCache.crossV] = this.linear2(`${prefix}.w_k`, `${prefix}.w_v`, memory, memoryLen);
 
689
  }
690
+ const merged = this.attention(q, layerCache.crossK, layerCache.crossV, 1, memoryLen, false);
691
  return this.linear(`${prefix}.w_o`).run(merged, 1);
692
  }
693
 
 
695
  const nHeads = this.graph.n_heads;
696
  const headDim = this.graph.head_dim;
697
  const q = await this.linear(`${prefix}.w_q`).run(x, seqLen);
698
+ if (this.wasmOps?.AttentionKvCache) {
699
+ layerCache.crossAttention ??= new this.wasmOps.AttentionKvCache(nHeads, headDim);
700
+ if (!layerCache.crossReady) {
701
+ const [crossK, crossV] = this.linear2(`${prefix}.w_k`, `${prefix}.w_v`, memory, memoryLen);
702
+ layerCache.crossAttention.set_cross(crossK, crossV, memoryLen);
703
+ layerCache.crossReady = true;
704
+ }
705
+ const merged = layerCache.crossAttention.attention(q, seqLen, false, 0);
706
+ return this.linear(`${prefix}.w_o`).run(merged, seqLen);
707
+ }
708
  if (!layerCache.crossK || !layerCache.crossV) {
709
+ [layerCache.crossK, layerCache.crossV] = this.linear2(`${prefix}.w_k`, `${prefix}.w_v`, memory, memoryLen);
 
710
  }
711
+ const merged = this.attention(q, layerCache.crossK, layerCache.crossV, seqLen, memoryLen, false);
712
  return this.linear(`${prefix}.w_o`).run(merged, seqLen);
713
  }
714
 
715
  async mlp(prefix, x, seqLen) {
716
  const wIn = this.linear(`${prefix}.w_in`);
717
  const wOut = this.linear(`${prefix}.w_out`);
718
+ if (this.wasmOps?.bitnet_mlp_f32 && wIn.handle && wOut.handle) {
719
+ return this.wasmOps.bitnet_mlp_f32(wIn.handle, wOut.handle, x, seqLen, String(this.graph.activation || "silu"));
720
+ }
721
  const hidden = await wIn.run(x, seqLen);
722
  const activation = String(this.graph.activation || "silu").toLowerCase();
723
  const isGated =
724
  wIn.layout.logicalOut === wOut.layout.logicalIn * 2 ||
725
  hidden.length === seqLen * wOut.layout.logicalIn * 2;
726
  const activated = isGated || ["swiglu", "gated-silu", "geglu", "reglu"].includes(activation)
727
+ ? (this.wasmOps?.gated_activation_f32
728
+ ? this.wasmOps.gated_activation_f32(hidden, seqLen, wOut.layout.logicalIn, activation)
729
+ : gatedActivation(hidden, seqLen, wOut.layout.logicalIn, activation))
730
+ : (this.wasmOps?.activate_f32 ? this.wasmOps.activate_f32(hidden, activation) : activate(hidden, activation));
731
  return wOut.run(activated, seqLen);
732
  }
733
 
 
751
  }
752
 
753
  async decoderLayerIncremental(index, x, memory, memoryLen, layerCache) {
754
+ if (this.wasmOps?.DecoderLayerHandle) {
755
+ layerCache.decoderLayer ??= this.decoderLayerHandle(index);
756
+ if (layerCache.decoderLayer?.next) {
757
+ const out = layerCache.decoderLayer.next(x, memory, memoryLen);
758
+ layerCache.selfLen = layerCache.decoderLayer.self_len();
759
+ return out;
760
+ }
761
+ }
762
  let n = this.norm(`decoder.${index}.self_attn_block.n1`, x, 1);
763
  x = addInPlace(
764
  x.slice(),
 
926
  }
927
 
928
  async next(tokenId) {
929
+ const hidden = await this.nextHidden(tokenId);
930
+ return this.runtime.linear("lm_head").run(hidden, 1);
931
+ }
932
+
933
+ async nextHidden(tokenId) {
934
  await this.prepare();
935
  let x = embed([Number(tokenId)], this.runtime.tensor("dec_embed.weight"), this.runtime.graph.d_model);
936
  for (let i = 0; i < this.runtime.graph.n_layers; i += 1) {
937
  x = await this.runtime.decoderLayerIncremental(i, x, this.memory, this.memoryLen, this.layerCaches[i]);
938
  }
939
+ return this.runtime.norm("dec_norm", x, 1);
940
+ }
941
+
942
+ async sampleNext(tokenId, generatedIds, options = {}) {
943
+ if (!this.runtime.wasmOps?.bitnet_sample_token_f32) return null;
944
+ const lmHead = this.runtime.linear("lm_head");
945
+ if (!lmHead.handle) return null;
946
+ const hidden = await this.nextHidden(tokenId);
947
+ const sample = this.runtime.wasmOps.bitnet_sample_token_f32(
948
+ lmHead.handle,
949
+ hidden,
950
+ toUint32IdArray(generatedIds),
951
+ toUint32IdArray(options.blockedIds),
952
+ Number(options.temperature ?? 0.35),
953
+ Number(options.topP ?? 0.9),
954
+ Number(options.repetitionPenalty ?? 1.16),
955
+ Number(options.randomValue ?? Math.random()),
956
  );
957
+ return {
958
+ tokenId: Number(sample.token_id),
959
+ probability: Number(sample.probability),
960
+ topProbability: Number(sample.top_probability),
961
+ rank: Number(sample.rank),
962
+ };
963
  }
964
 
965
  cloneState() {
966
  return this.layerCaches.map((cache) => {
967
  const cloned = { ...cache };
968
+ if (cache.selfAttention?.clone_cache) {
969
+ cloned.selfAttention = cache.selfAttention.clone_cache();
970
+ }
971
+ if (cache.crossAttention?.clone_cache) {
972
+ cloned.crossAttention = cache.crossAttention.clone_cache();
973
+ }
974
+ if (cache.decoderLayer?.clone_cache) {
975
+ cloned.decoderLayer = cache.decoderLayer.clone_cache();
976
+ }
977
  if (cache.selfK) {
978
  cloned.selfK = cache.selfK.slice();
979
  cloned.selfKLength = cloned.selfK.length;
 
1014
  export class BitNetEncoderDecoderWASM extends BitNetEncoderDecoderWebGPU {
1015
  constructor(manifest, manifestUrl, denseTensors, linears) {
1016
  super(null, manifest, manifestUrl, denseTensors, linears);
1017
+ this.wasmOps = Object.values(linears || {}).find((layer) => layer?.wasm)?.wasm || null;
1018
  }
1019
 
1020
  static async fromManifestUrl(manifestUrl, options = {}) {
runtime/model_stack_bitnet_wasm_bg.wasm CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:6dd3b40517258c5ef044cdde24c0a7b7a812fc3b13b55cad0b89794f9a0559c5
3
- size 69044
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3016e4b7ef05f1b94bb19b7017280407281d219341a0d593460d8ba57268cee4
3
+ size 75132