import * as ort from 'onnxruntime-web/wasm';

import { ENCODER_URL, DECODER_URL } from './config';

// Point onnxruntime-web at the WASM files Vite copies into /ort/ via vite.config.ts.
ort.env.wasm.wasmPaths = '/ort/';

// Disable multi-threading so ORT doesn't try to dynamically import the .mjs
// worker shim (which Vite's dev server blocks for files served from /public).
ort.env.wasm.numThreads = 1;

export interface NeedleSessions {
  encoder: ort.InferenceSession;
  decoder: ort.InferenceSession;
}

export async function loadSessions(onProgress?: (m: string) => void): Promise<NeedleSessions> {
  onProgress?.('downloading encoder…');
  const encoder = await ort.InferenceSession.create(ENCODER_URL, { executionProviders: ['wasm'] });
  onProgress?.('downloading decoder…');
  const decoder = await ort.InferenceSession.create(DECODER_URL, { executionProviders: ['wasm'] });
  return { encoder, decoder };
}

export async function runEncoder(sess: ort.InferenceSession, inputIds: number[]): Promise<ort.Tensor> {
  const ids = new ort.Tensor(
    'int64',
    BigInt64Array.from(inputIds.map(BigInt)),
    [1, inputIds.length],
  );
  const out = await sess.run({ input_ids: ids });
  return out.encoder_out;
}

export interface DecoderStepOut {
  logits: ort.Tensor;
  presentSelfKv: ort.Tensor;
}

export async function stepDecoder(
  sess: ort.InferenceSession,
  decoderInputId: number,
  encoderOut: ort.Tensor,
  pastSelfKv: ort.Tensor,
): Promise<DecoderStepOut> {
  const dec = new ort.Tensor('int64', BigInt64Array.from([BigInt(decoderInputId)]), [1, 1]);
  const out = await sess.run({
    decoder_input_ids: dec,
    encoder_out: encoderOut,
    past_self_kv: pastSelfKv,
  });
  return { logits: out.logits, presentSelfKv: out.present_self_kv };
}

/**
 * Initial empty past_self_kv tensor for step 0. Shape (8, 2, 1, 4, 0, 64) per the
 * ONNX export's dynamic axes (layers, k|v, batch, kv_heads, past_seq=0, head_dim).
 */
export function initialPastKv(): ort.Tensor {
  return new ort.Tensor('float32', new Float32Array(0), [8, 2, 1, 4, 0, 64]);
}
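
// --- Usage sketch ---
// A minimal greedy-decode loop showing how loadSessions, runEncoder, stepDecoder
// and initialPastKv compose. This is an illustrative sketch, not part of the
// original module: BOS_ID, EOS_ID and MAX_STEPS are hypothetical placeholders
// for values that would come from this model's tokenizer/config, and logits is
// assumed to have shape [1, 1, vocab].
export async function greedyDecodeSketch(inputIds: number[]): Promise<number[]> {
  const BOS_ID = 1;      // hypothetical start-of-sequence token id
  const EOS_ID = 2;      // hypothetical end-of-sequence token id
  const MAX_STEPS = 128; // hypothetical generation cap

  const { encoder, decoder } = await loadSessions();

  // Encode the prompt once; the decoder cross-attends to this at every step.
  const encoderOut = await runEncoder(encoder, inputIds);

  let pastKv = initialPastKv(); // empty KV cache for step 0
  let token = BOS_ID;
  const output: number[] = [];

  for (let step = 0; step < MAX_STEPS; step++) {
    const { logits, presentSelfKv } = await stepDecoder(decoder, token, encoderOut, pastKv);

    // Greedy pick: argmax over the vocab axis of the single-step logits.
    const data = logits.data as Float32Array;
    let best = 0;
    for (let i = 1; i < data.length; i++) {
      if (data[i] > data[best]) best = i;
    }

    if (best === EOS_ID) break;
    output.push(best);
    token = best;
    pastKv = presentSelfKv; // feed the grown KV cache into the next step
  }
  return output;
}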