PeytonT committed on
Commit
74e04d2
·
verified ·
1 Parent(s): 26989ee

Update packed BitNet WASM fallback runtime

Browse files
Files changed (1) hide show
  1. runtime/encdec_runtime.js +187 -0
runtime/encdec_runtime.js CHANGED
@@ -1,4 +1,5 @@
1
  import { BitNetLinearWebGPU } from "./bitnet_webgpu.js";
 
2
 
3
  function resolveUrl(path, baseUrl) {
4
  return new URL(path, baseUrl).toString();
@@ -31,6 +32,14 @@ function addInPlace(dst, src) {
31
  return dst;
32
  }
33
 
 
 
 
 
 
 
 
 
34
  function layerNorm(x, rows, cols, weight, bias, eps = 1e-5) {
35
  const out = new Float32Array(x.length);
36
  for (let r = 0; r < rows; r += 1) {
@@ -262,6 +271,31 @@ export class BitNetEncoderDecoderWebGPU {
262
  return this.linear(`${prefix}.w_o`).run(merged, seqLen);
263
  }
264
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
  async mlp(prefix, x, seqLen) {
266
  const wIn = this.linear(`${prefix}.w_in`);
267
  const wOut = this.linear(`${prefix}.w_out`);
@@ -333,6 +367,48 @@ export class BitNetEncoderDecoderWebGPU {
333
  return addInPlace(x, await this.mlp(`decoder.${index}.cross_block.mlp`, n, seqLen));
334
  }
335
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
  async encode(encInputIds) {
337
  let x = embed(encInputIds, this.tensor("enc_embed.weight"), this.graph.d_model);
338
  for (let i = 0; i < this.graph.n_layers; i += 1) {
@@ -366,4 +442,115 @@ export class BitNetEncoderDecoderWebGPU {
366
  const hidden = await this.decode(decInputIds, memory, encInputIds.length);
367
  return this.linear("lm_head").run(hidden, decInputIds.length);
368
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
369
  }
 
1
  import { BitNetLinearWebGPU } from "./bitnet_webgpu.js";
2
+ import { BitNetLinearWASM } from "./bitnet_wasm_runtime.js";
3
 
4
  function resolveUrl(path, baseUrl) {
5
  return new URL(path, baseUrl).toString();
 
32
  return dst;
33
  }
34
 
35
/**
 * Returns a new Float32Array holding `existing` followed by `next`.
 *
 * Neither input is modified and the result never aliases `next`, so callers
 * may keep reusing their buffers after the call.
 *
 * @param {ArrayLike<number>|null|undefined} existing - Values accumulated so
 *   far; a missing or empty value means "nothing accumulated yet".
 * @param {ArrayLike<number>} next - Values to append.
 * @returns {Float32Array} Freshly allocated concatenation.
 */
function appendRows(existing, next) {
  if (!existing || existing.length === 0) {
    // Copy into a Float32Array explicitly: `next.slice()` would hand back a
    // plain Array (or another typed-array flavour) when `next` is one, making
    // the result type depend on whether this was the first append.
    return Float32Array.from(next);
  }
  const out = new Float32Array(existing.length + next.length);
  out.set(existing, 0);
  out.set(next, existing.length);
  return out;
}
42
+
43
  function layerNorm(x, rows, cols, weight, bias, eps = 1e-5) {
44
  const out = new Float32Array(x.length);
45
  for (let r = 0; r < rows; r += 1) {
 
271
  return this.linear(`${prefix}.w_o`).run(merged, seqLen);
272
  }
273
 
274
+ async selfAttentionIncremental(prefix, x, layerCache) {
275
+ const nHeads = this.graph.n_heads;
276
+ const headDim = this.graph.head_dim;
277
+ const q = await this.linear(`${prefix}.w_q`).run(x, 1);
278
+ const kNew = await this.linear(`${prefix}.w_k`).run(x, 1);
279
+ const vNew = await this.linear(`${prefix}.w_v`).run(x, 1);
280
+ layerCache.selfK = appendRows(layerCache.selfK, kNew);
281
+ layerCache.selfV = appendRows(layerCache.selfV, vNew);
282
+ layerCache.selfLen = Number(layerCache.selfLen || 0) + 1;
283
+ const merged = attention(q, layerCache.selfK, layerCache.selfV, 1, layerCache.selfLen, nHeads, headDim, false);
284
+ return this.linear(`${prefix}.w_o`).run(merged, 1);
285
+ }
286
+
287
+ async crossAttentionCached(prefix, x, memory, memoryLen, layerCache) {
288
+ const nHeads = this.graph.n_heads;
289
+ const headDim = this.graph.head_dim;
290
+ const q = await this.linear(`${prefix}.w_q`).run(x, 1);
291
+ if (!layerCache.crossK || !layerCache.crossV) {
292
+ layerCache.crossK = await this.linear(`${prefix}.w_k`).run(memory, memoryLen);
293
+ layerCache.crossV = await this.linear(`${prefix}.w_v`).run(memory, memoryLen);
294
+ }
295
+ const merged = attention(q, layerCache.crossK, layerCache.crossV, 1, memoryLen, nHeads, headDim, false);
296
+ return this.linear(`${prefix}.w_o`).run(merged, 1);
297
+ }
298
+
299
  async mlp(prefix, x, seqLen) {
300
  const wIn = this.linear(`${prefix}.w_in`);
301
  const wOut = this.linear(`${prefix}.w_out`);
 
367
  return addInPlace(x, await this.mlp(`decoder.${index}.cross_block.mlp`, n, seqLen));
368
  }
369
 
370
+ async decoderLayerIncremental(index, x, memory, memoryLen, layerCache) {
371
+ const dModel = this.graph.d_model;
372
+ let n = layerNorm(
373
+ x,
374
+ 1,
375
+ dModel,
376
+ this.tensor(`decoder.${index}.self_attn_block.n1.weight`),
377
+ this.dense[`decoder.${index}.self_attn_block.n1.bias`]?.data,
378
+ );
379
+ x = addInPlace(
380
+ x.slice(),
381
+ await this.selfAttentionIncremental(`decoder.${index}.self_attn_block.attn`, n, layerCache),
382
+ );
383
+ n = layerNorm(
384
+ x,
385
+ 1,
386
+ dModel,
387
+ this.tensor(`decoder.${index}.self_attn_block.n2.weight`),
388
+ this.dense[`decoder.${index}.self_attn_block.n2.bias`]?.data,
389
+ );
390
+ x = addInPlace(x, await this.mlp(`decoder.${index}.self_attn_block.mlp`, n, 1));
391
+ n = layerNorm(
392
+ x,
393
+ 1,
394
+ dModel,
395
+ this.tensor(`decoder.${index}.cross_block.n1.weight`),
396
+ this.dense[`decoder.${index}.cross_block.n1.bias`]?.data,
397
+ );
398
+ x = addInPlace(
399
+ x.slice(),
400
+ await this.crossAttentionCached(`decoder.${index}.cross_block.cross`, n, memory, memoryLen, layerCache),
401
+ );
402
+ n = layerNorm(
403
+ x,
404
+ 1,
405
+ dModel,
406
+ this.tensor(`decoder.${index}.cross_block.n2.weight`),
407
+ this.dense[`decoder.${index}.cross_block.n2.bias`]?.data,
408
+ );
409
+ return addInPlace(x, await this.mlp(`decoder.${index}.cross_block.mlp`, n, 1));
410
+ }
411
+
412
  async encode(encInputIds) {
413
  let x = embed(encInputIds, this.tensor("enc_embed.weight"), this.graph.d_model);
414
  for (let i = 0; i < this.graph.n_layers; i += 1) {
 
442
  const hidden = await this.decode(decInputIds, memory, encInputIds.length);
443
  return this.linear("lm_head").run(hidden, decInputIds.length);
444
  }
445
+
446
+ createGenerationSession(encInputIds) {
447
+ return new BitNetEncoderDecoderGenerationSession(this, encInputIds);
448
+ }
449
+ }
450
+
451
/**
 * Token-by-token generation state for one encoder input.
 *
 * Holds the encoder output (`memory`, computed lazily by `prepare`) and one
 * mutable cache object per decoder layer that the runtime's incremental
 * decoder methods fill in as tokens are fed through `next`.
 */
export class BitNetEncoderDecoderGenerationSession {
  /**
   * @param {object} runtime - Runtime exposing `graph`, `encode`, `tensor`,
   *   `dense`, `linear` and `decoderLayerIncremental`.
   * @param {Iterable<number>|ArrayLike<number>|null|undefined} encInputIds -
   *   Encoder input ids; each entry is coerced with `Number`.
   */
  constructor(runtime, encInputIds) {
    this.runtime = runtime;
    this.encInputIds = Array.from(encInputIds || [], Number);
    this.memory = null;
    this.memoryLen = this.encInputIds.length;
    this.layerCaches = Array.from({ length: runtime.graph.n_layers }, () => ({}));
    // In-flight encoder call, shared by overlapping prepare() callers.
    this.pendingMemory = null;
  }

  /**
   * Ensures the encoder output is available, running the encoder at most once
   * even when several callers overlap (e.g. a warm-up `prepare()` racing the
   * first `next()`).
   *
   * If the encoder call rejects, the in-flight promise is dropped so a later
   * call can retry.
   *
   * @returns {Promise<this>} This session, once `memory` is set.
   */
  async prepare() {
    if (this.memory) {
      return this;
    }
    const pending = this.pendingMemory ?? this.runtime.encode(this.encInputIds);
    this.pendingMemory = pending;
    try {
      this.memory = await pending;
    } finally {
      // Only clear our own promise; a newer retry may already have replaced it.
      if (this.pendingMemory === pending) {
        this.pendingMemory = null;
      }
    }
    return this;
  }

  /**
   * Feeds one decoder token through every decoder layer and the final norm.
   *
   * @param {number|string} tokenId - Decoder token id; coerced with `Number`.
   * @returns {Promise<*>} Whatever the `lm_head` linear's `run` resolves to for
   *   this single position.
   */
  async next(tokenId) {
    await this.prepare();
    const { runtime } = this;
    const dModel = runtime.graph.d_model;
    let x = embed([Number(tokenId)], runtime.tensor("dec_embed.weight"), dModel);
    for (let i = 0; i < runtime.graph.n_layers; i += 1) {
      x = await runtime.decoderLayerIncremental(i, x, this.memory, this.memoryLen, this.layerCaches[i]);
    }
    const hidden = layerNorm(
      x,
      1,
      dModel,
      runtime.tensor("dec_norm.weight"),
      runtime.dense["dec_norm.bias"]?.data,
    );
    return runtime.linear("lm_head").run(hidden, 1);
  }
}
483
+
484
/**
 * Encoder-decoder runtime whose linear layers are `BitNetLinearWASM` instances.
 *
 * Everything except construction and loading is inherited from
 * `BitNetEncoderDecoderWebGPU`.
 */
export class BitNetEncoderDecoderWASM extends BitNetEncoderDecoderWebGPU {
  /**
   * @param {object} manifest - Parsed model manifest.
   * @param {string} manifestUrl - URL the manifest was loaded from.
   * @param {object} denseTensors - Dense tensors keyed by name.
   * @param {object} linears - Linear layers keyed by layer name.
   */
  constructor(manifest, manifestUrl, denseTensors, linears) {
    // NOTE(review): the leading `null` presumably fills the parent's GPU
    // device slot, which this backend does not need — confirm against the
    // parent constructor.
    super(null, manifest, manifestUrl, denseTensors, linears);
  }

  /**
   * Loads the manifest, its dense tensors (one at a time) and its BitNet
   * layers (with bounded parallelism), then builds the runtime.
   *
   * @param {string} manifestUrl - URL of the manifest; also the base for
   *   resolving tensor locations.
   * @param {object} [options]
   * @param {Function} [options.progress] - Called with progress records.
   * @param {object} [options.manifest] - Pre-parsed manifest; skips the fetch.
   * @param {number} [options.layerConcurrency=4] - Maximum number of layers
   *   loaded in parallel. Non-numeric values fall back to the default and
   *   fractional values are rounded down.
   * @returns {Promise<BitNetEncoderDecoderWASM>} The ready runtime.
   */
  static async fromManifestUrl(manifestUrl, options = {}) {
    const progress = typeof options.progress === "function" ? options.progress : () => {};
    progress({ phase: "manifest", message: "Loading model manifest" });
    const manifest = options.manifest || await fetchJson(manifestUrl);
    const baseUrl = new URL(".", manifestUrl).toString();
    const dense = {};
    const denseEntries = Object.entries(manifest.dense_tensors || {});
    for (const [index, [name, entry]] of denseEntries.entries()) {
      progress({
        phase: "dense",
        index: index + 1,
        total: denseEntries.length,
        name,
        message: `Loading dense tensor ${index + 1}/${denseEntries.length}: ${name}`,
      });
      dense[name] = await fetchFloatTensor(entry, baseUrl);
    }
    progress({
      phase: "dense_ready",
      index: denseEntries.length,
      total: denseEntries.length,
      message: "Dense tensors ready",
    });

    const linears = {};
    const layers = manifest.layers || [];
    const defaultLayerConcurrency = 4;
    // A non-numeric option would otherwise turn into NaN, which makes the
    // worker array below empty and silently loads no layers at all.
    const requestedConcurrency = Math.floor(Number(options.layerConcurrency || defaultLayerConcurrency));
    const layerConcurrency = Math.max(
      1,
      Math.min(
        Number.isNaN(requestedConcurrency) ? defaultLayerConcurrency : requestedConcurrency,
        layers.length || 1,
      ),
    );
    progress({
      phase: "prepare_layers",
      index: 0,
      total: layers.length,
      message: `Preparing ${layers.length} BitNet WASM layers (${layerConcurrency} parallel)`,
    });
    let nextLayer = 0;
    let completedLayers = 0;
    // Each worker repeatedly claims the next unloaded layer index. Claiming
    // happens synchronously before the first await, so no index is loaded twice.
    const loadLayerWorker = async () => {
      while (nextLayer < layers.length) {
        const index = nextLayer;
        nextLayer += 1;
        const layer = layers[index];
        progress({
          phase: "layer",
          index: index + 1,
          total: layers.length,
          name: layer.name,
          message: `Loading BitNet WASM layer ${index + 1}/${layers.length}: ${layer.name}`,
        });
        linears[layer.name] = await BitNetLinearWASM.fromManifestLayer(manifest, layer, manifestUrl, {
          progress,
          index: index + 1,
          total: layers.length,
          name: layer.name,
        });
        completedLayers += 1;
        progress({
          phase: "layer_ready",
          index: completedLayers,
          total: layers.length,
          name: layer.name,
          message: `BitNet WASM layer ${completedLayers}/${layers.length} ready: ${layer.name}`,
        });
      }
    };
    const workerCount = Math.min(layerConcurrency, layers.length);
    await Promise.all(Array.from({ length: workerCount }, () => loadLayerWorker()));
    progress({ phase: "wasm_ready", message: "BitNet WASM runtime ready" });
    return new BitNetEncoderDecoderWASM(manifest, manifestUrl, dense, linears);
  }
}