// needle-playground/src/generate.ts
// Uploaded by shreyask via huggingface_hub (commit 814c07e, verified).
import * as ort from 'onnxruntime-web/wasm';
import { Tokenizer } from './tokenizer';
import { runEncoder, stepDecoder, initialPastKv } from './runtime';
import type { NeedleSessions } from './runtime';
/** Options controlling {@link generate}. Token ids refer to tokenizer-specials.json. */
export interface GenerateOpts {
  /** Cap on newly decoded tokens; `generate` uses 256 when this is omitted. */
  maxNewTokens?: number;
  /** End-of-sequence token id (1 per tokenizer-specials.json). */
  eosTokenId: number;
  /**
   * First token fed to the decoder. Cactus seeds its decoder with EOS (id=1),
   * so pass the EOS id here.
   */
  bosOrPrefixTokenId: number;
  /** Id of the `<tools>` separator token (5 per tokenizer-specials.json). */
  toolsTokenId: number;
}
/**
 * Format (query, tools) into the encoder input token list, matching Cactus's
 * `_build_encoder_input`:
 *
 *   [query_tokens..., <tools>(id=5), tools_tokens...]
 *
 * Tools are stringified to JSON before encoding. The Python side truncates to
 * max_enc_len=1024; we do the same here. The query keeps at most
 * `maxEncLen - 2` tokens, which reserves room for the separator plus at least
 * one tools token. The tools tokens then fill whatever room is left.
 *
 * @param tokenizer    Tokenizer used to encode both the query and the tools JSON.
 * @param query        User request text.
 * @param tools        Tool definitions, serialized with `JSON.stringify`.
 * @param toolsTokenId Id of the `<tools>` separator token.
 * @param maxEncLen    Maximum encoder sequence length (default 1024). The
 *                     separator is always emitted, so values below 1 still
 *                     yield a one-token sequence.
 * @returns Encoder input ids, at most `max(maxEncLen, 1)` long.
 */
export function buildEncoderInput(
  tokenizer: Tokenizer,
  query: string,
  tools: unknown[],
  toolsTokenId: number,
  maxEncLen = 1024,
): number[] {
  const qTokens = tokenizer.encode(query);
  const tTokens = tokenizer.encode(JSON.stringify(tools));
  // Clamp both budgets at 0. A negative `end` passed to Array#slice counts from
  // the end of the array, which for a tiny `maxEncLen` would silently drop the
  // query's tail and keep far more tools tokens than allowed.
  const queryBudget = Math.max(0, maxEncLen - 2);
  const q = qTokens.slice(0, queryBudget);
  const toolsBudget = Math.max(0, maxEncLen - q.length - 1);
  const t = tTokens.slice(0, toolsBudget);
  return [...q, toolsTokenId, ...t];
}
/**
 * Generate a completion for (query, tools) with the Needle encoder/decoder.
 *
 * The encoder runs once over `buildEncoderInput(...)`. The decoder is then
 * stepped one token at a time:
 * - It starts from `opts.bosOrPrefixTokenId`.
 * - Each step's `presentSelfKv` is passed back in as the next step's past KV.
 * - Each step's token is chosen by `sampleNextToken`.
 *
 * Decoding ends when `shouldStop` returns true (that token is not kept) or
 * after `opts.maxNewTokens` steps (default 256).
 *
 * @param onToken Optional callback, invoked after each kept token with its id
 *   and the decode of all kept tokens so far. A leading `<tool_call>` marker is
 *   still present in that text.
 * @returns The kept token ids, plus their decoded text with one leading
 *   `<tool_call>` marker removed. The ids themselves are not modified.
 */
export async function generate(
  sessions: NeedleSessions,
  tokenizer: Tokenizer,
  query: string,
  tools: unknown[],
  opts: GenerateOpts,
  onToken?: (id: number, decodedSoFar: string) => void,
): Promise<{ ids: number[]; text: string }> {
  const encoderIds = buildEncoderInput(tokenizer, query, tools, opts.toolsTokenId);
  const encoderOut = await runEncoder(sessions.encoder, encoderIds);

  let kvCache = initialPastKv();
  // Cactus convention: the decoder is seeded with EOS (id=1).
  let token = opts.bosOrPrefixTokenId;
  const ids: number[] = [];
  const limit = opts.maxNewTokens ?? 256;

  let steps = 0;
  while (steps < limit) {
    steps++;
    const { logits, presentSelfKv: nextKv } = await stepDecoder(
      sessions.decoder,
      token,
      encoderOut,
      kvCache,
    );
    kvCache = nextKv;
    token = sampleNextToken(logits);
    if (shouldStop(token, ids, opts.eosTokenId, tokenizer)) {
      break;
    }
    ids.push(token);
    if (onToken != null) {
      onToken(token, tokenizer.decode(ids));
    }
  }

  // Cactus's generate() strips this leading marker as well.
  const toolCallMarker = '<tool_call>';
  const raw = tokenizer.decode(ids);
  const text = raw.startsWith(toolCallMarker) ? raw.slice(toolCallMarker.length) : raw;
  return { ids, text };
}
// ────────────────────────────────────────────────────────────────────────────
// USER CONTRIBUTION POINT #1 — sampleNextToken
//
// Pick the next token from the decoder's float32 logits tensor, expected to be
// shaped (1, 1, vocab_size). Only the last vocab_size entries (the final
// position) are scored. A tensor that carries several positions therefore
// still yields a valid token id rather than an index past the vocabulary.
//
// Choices:
//   (a) argmax — deterministic, repeatable. Function calling has a narrow
//       correct answer; argmax is what Cactus's native generate() uses.
//   (b) temperature sampling — softmax(logits / T), then sample. T<1 = sharper,
//       T>1 = more varied. Non-deterministic without a seed.
//   (c) top-p (nucleus) — softmax, sort, keep tokens until cumulative ≥ p,
//       sample. Most "natural" sampling but adds two hyperparams.
//
// Default (if no preference given): (a) argmax.
// ────────────────────────────────────────────────────────────────────────────
/**
 * Argmax over the final position's logits.
 *
 * Ties resolve to the lowest index. NaN scores are skipped, so if no score
 * exceeds -Infinity the result is index 0.
 *
 * @throws Error if the tensor holds fewer values than its last dimension, or
 *   if that dimension is 0 (no token can be chosen).
 */
function sampleNextToken(logits: ort.Tensor): number {
  // Flat buffer of every logit. This assumes the last axis is contiguous
  // (row-major), as ONNX Runtime lays out tensors.
  const data = logits.data as Float32Array;
  const dims = logits.dims;
  // The last axis is the vocabulary. For a rank-0 / shapeless tensor, fall back
  // to treating the whole buffer as the vocabulary.
  const vocabSize = dims[dims.length - 1] ?? data.length;
  if (vocabSize <= 0 || data.length < vocabSize) {
    throw new Error(
      `sampleNextToken: cannot pick a token from logits with dims [${dims.join(', ')}] and ${data.length} values`,
    );
  }
  // For (1, 1, vocab_size) the offset is 0, so the whole buffer is scored.
  // For (1, seq, vocab_size) only the last row is scored.
  const offset = data.length - vocabSize;
  // USER: replace this body with your choice from (a)/(b)/(c). The default is (a) argmax.
  let bestIdx = 0;
  let bestVal = -Infinity;
  for (let i = 0; i < vocabSize; i++) {
    const v = data[offset + i];
    if (v > bestVal) { bestVal = v; bestIdx = i; }
  }
  return bestIdx;
}
// ────────────────────────────────────────────────────────────────────────────
// USER CONTRIBUTION POINT #2 — shouldStop
//
// Called once per decode step with the freshly picked `nextId` and `soFar`,
// the tokens kept so far (which do not yet include `nextId`).
// - Returning true ends generation, and `nextId` is not appended.
// - Returning false keeps `nextId` and decoding continues.
// The decode loop's `maxNewTokens` cap also bounds generation independently.
//
// Choices:
//   (a) EOS-only — matches Cactus's native generate() exactly. Simplest, safest.
//   (b) EOS OR balanced-JSON — if the decoded text since <tool_call> is a
//       valid parseable JSON array (e.g. ']' at top level with brace balance
//       at zero), stop. Crisper exit when the model trails into padding.
//   (c) EOS OR balanced-JSON OR token == ']' — same as (b) but cheaper to
//       check, since tokenizer's ']' token ID is fixed.
//
// Default: (a) EOS-only.
// ────────────────────────────────────────────────────────────────────────────
function shouldStop(
  nextId: number,
  soFar: number[],
  eosId: number,
  tokenizer: Tokenizer,
): boolean {
  // USER: swap in (b) or (c) from the menu above if desired. The default is (a).
  // `soFar` and `tokenizer` are accepted for those richer policies; reference
  // them here so unused-parameter checks stay quiet.
  void [soFar, tokenizer];
  const hitEos = eosId === nextId;
  return hitEos;
}