File size: 2,075 Bytes
814c07e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import * as ort from 'onnxruntime-web/wasm';
import { ENCODER_URL, DECODER_URL } from './config';

// onnxruntime-web needs to know where its WASM binaries are served from;
// vite.config.ts copies them into /ort/, so resolve them there.
// Force single-threaded execution as well: the threaded path makes ORT
// dynamically import a .mjs worker shim, and Vite's dev server blocks that
// import for files served out of /public.
const wasmEnv = ort.env.wasm;
wasmEnv.wasmPaths = '/ort/';
wasmEnv.numThreads = 1;

/**
 * The pair of ONNX Runtime inference sessions returned by `loadSessions`.
 */
export interface NeedleSessions {
  /** Session created from `ENCODER_URL`. */
  encoder: ort.InferenceSession;
  /** Session created from `DECODER_URL`. */
  decoder: ort.InferenceSession;
}

/**
 * Download and initialise the encoder and decoder ONNX sessions, one after the
 * other, on the `wasm` execution provider.
 *
 * @param onProgress Optional callback that receives a short status message
 *   before each model download starts.
 * @returns Both sessions once they are ready.
 * @throws Whatever `ort.InferenceSession.create` (or `onProgress`) throws. If
 *   the decoder step fails, the already-created encoder session is released
 *   before the original error is rethrown, so its resources are not leaked.
 */
export async function loadSessions(onProgress?: (m: string) => void): Promise<NeedleSessions> {
  onProgress?.('downloading encoder…');
  const encoder = await ort.InferenceSession.create(ENCODER_URL, { executionProviders: ['wasm'] });
  try {
    onProgress?.('downloading decoder…');
    const decoder = await ort.InferenceSession.create(DECODER_URL, { executionProviders: ['wasm'] });
    return { encoder, decoder };
  } catch (err: unknown) {
    // The pair can't be completed: free the encoder session. A failure while
    // releasing is ignored so the original error is the one that propagates.
    await encoder.release().catch(() => undefined);
    throw err;
  }
}

/**
 * Run the encoder over a single token sequence (batch size 1).
 *
 * @param sess Encoder inference session.
 * @param inputIds Token ids. Each is converted with `BigInt`, so each must be
 *   an integer. They are fed as the int64 `input_ids` input with shape
 *   [1, inputIds.length].
 * @returns The session's `encoder_out` output tensor.
 * @throws RangeError (from `BigInt`) if any id is not an integer.
 * @throws Error if the session produced no `encoder_out` output.
 */
export async function runEncoder(sess: ort.InferenceSession, inputIds: readonly number[]): Promise<ort.Tensor> {
  const ids = new ort.Tensor(
    'int64',
    // Map straight into the typed array; no intermediate bigint[] is built.
    BigInt64Array.from(inputIds, (id) => BigInt(id)),
    [1, inputIds.length],
  );
  const out = await sess.run({ input_ids: ids });
  const encoderOut = out.encoder_out;
  // Fail loudly on an output-name mismatch instead of returning `undefined`
  // under a `Tensor` type and breaking somewhere downstream.
  if (encoderOut === undefined) {
    throw new Error(`encoder produced no 'encoder_out' output (got: ${Object.keys(out).join(', ')})`);
  }
  return encoderOut;
}

/** Outputs of a single decoder step, as returned by `stepDecoder`. */
export interface DecoderStepOut {
  /** The decoder's `logits` output for this step. */
  logits: ort.Tensor;
  /**
   * The decoder's `present_self_kv` output. Presumably fed back as
   * `pastSelfKv` on the next step — confirm against the generation loop.
   */
  presentSelfKv: ort.Tensor;
}

/**
 * Run one decoder step for a single token (batch size 1).
 *
 * @param sess Decoder inference session.
 * @param decoderInputId Token id for this step. It is converted with `BigInt`,
 *   so it must be an integer. It is fed as the int64 `decoder_input_ids` input
 *   with shape [1, 1].
 * @param encoderOut Tensor passed through as the `encoder_out` input.
 * @param pastSelfKv Tensor passed through as the `past_self_kv` input.
 * @returns The session's `logits` and `present_self_kv` outputs.
 * @throws RangeError (from `BigInt`) if `decoderInputId` is not an integer.
 * @throws Error if either expected output is missing from the session result.
 */
export async function stepDecoder(
  sess: ort.InferenceSession,
  decoderInputId: number,
  encoderOut: ort.Tensor,
  pastSelfKv: ort.Tensor,
): Promise<DecoderStepOut> {
  const dec = new ort.Tensor('int64', BigInt64Array.of(BigInt(decoderInputId)), [1, 1]);
  const out = await sess.run({
    decoder_input_ids: dec,
    encoder_out: encoderOut,
    past_self_kv: pastSelfKv,
  });
  const { logits, present_self_kv: presentSelfKv } = out;
  // Surface output-name mismatches here rather than returning `undefined`
  // fields typed as `Tensor`.
  const got = (): string => Object.keys(out).join(', ');
  if (logits === undefined) {
    throw new Error(`decoder produced no 'logits' output (got: ${got()})`);
  }
  if (presentSelfKv === undefined) {
    throw new Error(`decoder produced no 'present_self_kv' output (got: ${got()})`);
  }
  return { logits, presentSelfKv };
}

/**
 * Build the initial, empty `past_self_kv` tensor for decoder step 0.
 *
 * The tensor has shape (layers, 2, batch, kvHeads, 0, headDim), following the
 * ONNX export's dynamic axes (layers, k|v, batch, kv_heads, past_seq, head_dim).
 * The `past_seq` axis is 0, so the tensor holds no elements.
 *
 * @param dims Optional overrides for the cache geometry. The defaults (8 layers,
 *   batch 1, 4 KV heads, head_dim 64) reproduce the shape used with the
 *   current export: (8, 2, 1, 4, 0, 64). They must match the loaded decoder.
 * @returns A zero-length float32 tensor with the shape described above.
 */
export function initialPastKv(
  dims: { layers?: number; batch?: number; kvHeads?: number; headDim?: number } = {},
): ort.Tensor {
  const { layers = 8, batch = 1, kvHeads = 4, headDim = 64 } = dims;
  return new ort.Tensor('float32', new Float32Array(0), [layers, 2, batch, kvHeads, 0, headDim]);
}