import * as ort from 'onnxruntime-web/wasm';
import { Tokenizer } from './tokenizer';
import { runEncoder, stepDecoder, initialPastKv } from './runtime';
import type { NeedleSessions } from './runtime';

/** Number of decode steps used when `GenerateOpts.maxNewTokens` is omitted. */
const DEFAULT_MAX_NEW_TOKENS = 256;

/** Options controlling a single `generate` call. */
export interface GenerateOpts {
  /** Upper bound on the number of decode steps. Default 256. */
  maxNewTokens?: number;
  /** = 1 per tokenizer-specials.json. */
  eosTokenId: number;
  /** Cactus uses EOS (id=1) as the decoder seed; pass eos here. */
  bosOrPrefixTokenId: number;
  /** = 5 per tokenizer-specials.json. */
  toolsTokenId: number;
}

/**
 * Format (query, tools) into the encoder input token list, matching Cactus's
 * `_build_encoder_input`:
 *
 *   [query_tokens..., toolsTokenId (id=5), tools_tokens...]
 *
 * The Python side truncates to max_enc_len=1024; we do the same here. Tools are
 * stringified to JSON before encoding.
 *
 * The query is truncated first (to at most `maxEncLen - 2` tokens), then the
 * tools tokens fill whatever room is left after the query and the separator.
 *
 * @param tokenizer Tokenizer used to encode both the query and the tools JSON.
 * @param query User query text.
 * @param tools Tool descriptions; serialized with `JSON.stringify`.
 * @param toolsTokenId Separator token id placed between query and tools tokens.
 * @param maxEncLen Maximum length of the returned list. The separator is always
 *   emitted, so for `maxEncLen < 1` the result still has length 1.
 * @returns The encoder input token ids.
 */
export function buildEncoderInput(
  tokenizer: Tokenizer,
  query: string,
  tools: unknown[],
  toolsTokenId: number,
  maxEncLen = 1024,
): number[] {
  const qTokens = tokenizer.encode(query);
  const tTokens = tokenizer.encode(JSON.stringify(tools));

  // Clamp both budgets at zero: a negative `end` passed to slice() counts from
  // the END of the array, which would keep tokens instead of dropping them and
  // let the result grow past maxEncLen for very small limits.
  const maxQuery = Math.max(0, maxEncLen - 2);
  const q = qTokens.length > maxQuery ? qTokens.slice(0, maxQuery) : qTokens;

  const remaining = Math.max(0, maxEncLen - q.length - 1);
  const t = tTokens.slice(0, remaining);

  return [...q, toolsTokenId, ...t];
}

/**
 * Run the encoder once, then decode token by token until `shouldStop` fires or
 * `maxNewTokens` steps have been taken.
 *
 * @param sessions Encoder/decoder sessions handed to `runEncoder` / `stepDecoder`.
 * @param tokenizer Tokenizer used to build the encoder input and decode output.
 * @param query User query text.
 * @param tools Tool descriptions forwarded to `buildEncoderInput`.
 * @param opts Special-token ids and the decode-step cap.
 * @param onToken Optional callback invoked after each accepted token with the
 *   token id and the decoded text of everything generated so far.
 * @returns The generated token ids (the stop token is not included) and their
 *   decoded text.
 */
export async function generate(
  sessions: NeedleSessions,
  tokenizer: Tokenizer,
  query: string,
  tools: unknown[],
  opts: GenerateOpts,
  onToken?: (id: number, decodedSoFar: string) => void,
): Promise<{ ids: number[]; text: string }> {
  const encoderInputIds = buildEncoderInput(tokenizer, query, tools, opts.toolsTokenId);
  const encoderOut = await runEncoder(sessions.encoder, encoderInputIds);

  let pastKv = initialPastKv();
  let nextId = opts.bosOrPrefixTokenId; // Cactus convention: decoder seeded with EOS (id=1)
  const generated: number[] = [];
  const maxNew = opts.maxNewTokens ?? DEFAULT_MAX_NEW_TOKENS;

  for (let i = 0; i < maxNew; i++) {
    const { logits, presentSelfKv } = await stepDecoder(
      sessions.decoder,
      nextId,
      encoderOut,
      pastKv,
    );
    // Carry this step's self-attention KV forward as the next step's past.
    pastKv = presentSelfKv;
    nextId = sampleNextToken(logits);

    // Check before pushing so the stop token never reaches the output.
    if (shouldStop(nextId, generated, opts.eosTokenId, tokenizer)) break;

    generated.push(nextId);
    onToken?.(nextId, tokenizer.decode(generated));
  }

  let text = tokenizer.decode(generated);
  // Strip the leading marker that Cactus's generate() also strips.
  // NOTE(review): the marker literal below is the empty string, so this strip is
  // currently a no-op — the marker text appears to have been lost. Restore it
  // from the Cactus reference implementation.
  if (text.startsWith('')) text = text.slice(''.length);

  return { ids: generated, text };
}

// ────────────────────────────────────────────────────────────────────────────
// USER CONTRIBUTION POINT #1 — sampleNextToken
//
// Pick the next token from a (1, 1, vocab_size) float32 logits tensor.
//
// Choices:
//   (a) argmax — deterministic, repeatable. Function calling has a narrow
//       correct answer; argmax is what Cactus's native generate() uses.
//   (b) temperature sampling — softmax(logits / T), then sample. T<1 = sharper,
//       T>1 = more varied. Non-deterministic without a seed.
//   (c) top-p (nucleus) — softmax, sort, keep tokens until cumulative ≥ p,
//       sample. Most "natural" sampling but adds two hyperparams.
//
// Default (if no preference given): (a) argmax.
// ────────────────────────────────────────────────────────────────────────────
/**
 * Greedy (argmax) token selection over the flattened logits buffer.
 *
 * Ties resolve to the lowest index because the comparison is strict. Returns 0
 * when the buffer is empty or no element compares greater than -Infinity.
 *
 * @param logits Decoder logits; the whole flat buffer is scanned, so this
 *   assumes a single position, i.e. shape (1, 1, vocab_size).
 * @returns Index of the largest logit.
 */
function sampleNextToken(logits: ort.Tensor): number {
  const data = logits.data as Float32Array; // shape (1, 1, vocab_size) → flat array of vocab_size

  // USER: replace this body with your choice from (a)/(b)/(c). The default is (a) argmax.
  let bestIdx = 0;
  let bestVal = -Infinity;
  for (let i = 0; i < data.length; i++) {
    if (data[i] > bestVal) {
      bestVal = data[i];
      bestIdx = i;
    }
  }
  return bestIdx;
}

// ────────────────────────────────────────────────────────────────────────────
// USER CONTRIBUTION POINT #2 — shouldStop
//
// Decide whether to halt generation after emitting `nextId` on top of `soFar`.
// The decode loop's `maxNewTokens` cap also bounds the loop independently.
//
// Choices:
//   (a) EOS-only — matches Cactus's native generate() exactly. Simplest, safest.
//   (b) EOS OR balanced-JSON — if the decoded text since the leading marker
//       (marker name missing from this comment — TODO restore) is a valid
//       parseable JSON array (e.g. ']' at top level with brace balance at
//       zero), stop. Crisper exit when the model trails into padding.
//   (c) EOS OR balanced-JSON OR token == ']' — same as (b) but cheaper to
//       check, since tokenizer's ']' token ID is fixed.
//
// Default: (a) EOS-only.
// ────────────────────────────────────────────────────────────────────────────
/**
 * Stop predicate for the decode loop; currently EOS-only.
 *
 * @param nextId Token just selected for this step.
 * @param soFar Tokens accepted before `nextId` (unused by the EOS-only policy).
 * @param eosId End-of-sequence token id.
 * @param tokenizer Tokenizer (unused by the EOS-only policy).
 * @returns `true` when `nextId` is the EOS token.
 */
function shouldStop(
  nextId: number,
  soFar: number[],
  eosId: number,
  tokenizer: Tokenizer,
): boolean {
  // USER: replace this body with your choice from (a)/(b)/(c). Default is (a).
  // Reference the unused parameters so the signature stays stable for (b)/(c).
  void soFar;
  void tokenizer;
  return nextId === eosId;
}