Spaces:
Running
Running
| import * as ort from 'onnxruntime-web/wasm'; | |
| import { Tokenizer } from './tokenizer'; | |
| import { runEncoder, stepDecoder, initialPastKv } from './runtime'; | |
| import type { NeedleSessions } from './runtime'; | |
/** Decoding options accepted by `generate`. */
export interface GenerateOpts {
  /** End-of-sequence token ID (= 1 per tokenizer-specials.json). */
  eosTokenId: number;
  /**
   * Token fed to the decoder on its very first step. Cactus seeds the decoder
   * with EOS (id=1) instead of a dedicated BOS token, so pass the EOS ID here.
   */
  bosOrPrefixTokenId: number;
  /** ID of the `<tools>` separator token (= 5 per tokenizer-specials.json). */
  toolsTokenId: number;
  /** Maximum number of decoder steps; `generate` uses 256 when omitted. */
  maxNewTokens?: number;
}
/**
 * Format (query, tools) into the encoder input token list, matching Cactus's
 * `_build_encoder_input`:
 *
 *   [query_tokens..., <tools>(id=5), tools_tokens...]
 *
 * Tools are stringified to JSON before encoding. The result is truncated to at
 * most `maxEncLen` tokens (the Python side uses max_enc_len=1024):
 * - the query keeps at most `maxEncLen - 2` tokens;
 * - the `<tools>` separator is always emitted;
 * - tool tokens fill whatever budget remains.
 *
 * @param tokenizer    Tokenizer used for both the query and the tools JSON.
 * @param query        User query text.
 * @param tools        Tool definitions; serialized with `JSON.stringify`.
 * @param toolsTokenId ID of the `<tools>` separator token.
 * @param maxEncLen    Maximum length of the returned sequence (default 1024).
 * @returns Encoder input IDs, never longer than `maxEncLen`.
 * @throws RangeError if `maxEncLen` is below 1 (or NaN): the separator alone would not fit.
 */
export function buildEncoderInput(
  tokenizer: Tokenizer,
  query: string,
  tools: unknown[],
  toolsTokenId: number,
  maxEncLen = 1024,
): number[] {
  // `!(x >= 1)` also rejects NaN, which would otherwise disable truncation entirely.
  if (!(maxEncLen >= 1)) {
    throw new RangeError(`buildEncoderInput: maxEncLen must be >= 1, got ${maxEncLen}`);
  }
  const queryTokens = tokenizer.encode(query);
  const toolTokens = tokenizer.encode(JSON.stringify(tools));
  // Reserve one slot for the separator and (when maxEncLen >= 2) at least one for
  // tool tokens. Clamp at 0: a negative bound would make slice() count from the
  // end and keep most of the query instead of none of it.
  const maxQuery = Math.max(0, maxEncLen - 2);
  const q = queryTokens.slice(0, maxQuery);
  const remaining = maxEncLen - q.length - 1; // >= 0 given the checks above
  return [...q, toolsTokenId, ...toolTokens.slice(0, remaining)];
}
/**
 * Encode (query, tools) once, then autoregressively decode a tool-call string.
 *
 * Each step does the following:
 * - feeds the previously chosen token (initially `opts.bosOrPrefixTokenId`),
 *   the encoder output and the key/value cache returned by the previous step
 *   (`presentSelfKv`) to the decoder;
 * - picks the next token with `sampleNextToken`;
 * - halts when `shouldStop` says so or after `opts.maxNewTokens` steps
 *   (default 256).
 *
 * @param sessions  Encoder/decoder inference sessions.
 * @param tokenizer Tokenizer used for encoding inputs and decoding outputs.
 * @param query     User query text.
 * @param tools     Tool definitions passed to `buildEncoderInput`.
 * @param opts      Special token IDs and the step budget.
 * @param onToken   Optional callback invoked after every accepted token with its
 *                  ID and the decode of all accepted tokens so far (a leading
 *                  `<tool_call>` marker is not stripped in this intermediate text).
 * @returns The accepted token IDs (the stop token is not included) and their
 *          decoded text with a leading `<tool_call>` marker removed.
 */
export async function generate(
  sessions: NeedleSessions,
  tokenizer: Tokenizer,
  query: string,
  tools: unknown[],
  opts: GenerateOpts,
  onToken?: (id: number, decodedSoFar: string) => void,
): Promise<{ ids: number[]; text: string }> {
  const toolCallMarker = '<tool_call>';

  const encoderOut = await runEncoder(
    sessions.encoder,
    buildEncoderInput(tokenizer, query, tools, opts.toolsTokenId),
  );

  let cache = initialPastKv();
  // Cactus convention: the decoder is seeded with EOS (id=1) rather than a BOS token.
  let current = opts.bosOrPrefixTokenId;
  const ids: number[] = [];
  const budget = opts.maxNewTokens ?? 256;

  for (let step = 0; step < budget; step += 1) {
    const out = await stepDecoder(sessions.decoder, current, encoderOut, cache);
    cache = out.presentSelfKv;
    current = sampleNextToken(out.logits);
    if (shouldStop(current, ids, opts.eosTokenId, tokenizer)) {
      break;
    }
    ids.push(current);
    if (onToken != null) {
      onToken(current, tokenizer.decode(ids));
    }
  }

  const decoded = tokenizer.decode(ids);
  // Mirror Cactus's generate(), which also strips a leading <tool_call> marker.
  const text = decoded.startsWith(toolCallMarker)
    ? decoded.slice(toolCallMarker.length)
    : decoded;
  return { ids, text };
}
// ────────────────────────────────────────────────────────────────────────────
// USER CONTRIBUTION POINT #1 — sampleNextToken
//
// Pick the next token from a float32 logits tensor whose last axis is the
// vocabulary (a single decoder step normally yields shape (1, 1, vocab_size)).
//
// Choices:
//   (a) argmax — deterministic, repeatable. Function calling has a narrow
//       correct answer; argmax is what Cactus's native generate() uses.
//   (b) temperature sampling — softmax(logits / T), then sample. T<1 = sharper,
//       T>1 = more varied. Non-deterministic without a seed.
//   (c) top-p (nucleus) — softmax, sort, keep tokens until cumulative ≥ p,
//       sample. Most "natural" sampling but adds two hyperparams.
//
// Default (if no preference given): (a) argmax.
// ────────────────────────────────────────────────────────────────────────────
/**
 * Argmax over the logits of the final position.
 *
 * @param logits float32 tensor whose last dimension is the vocabulary size. For a
 *               (1, 1, vocab_size) tensor this is an argmax over the whole buffer;
 *               if more positions are present, only the last `vocab_size` values
 *               (the final position) are considered.
 * @returns Vocabulary index of the largest logit. Ties go to the lowest index;
 *          NaN entries never win (an all-NaN row yields 0).
 * @throws TypeError if the tensor is not float32 (e.g. float16 data may arrive as
 *         raw 16-bit integer bits, which would compare incorrectly).
 * @throws RangeError if the tensor holds no logits.
 */
function sampleNextToken(logits: ort.Tensor): number {
  // USER: replace this body with your choice from (a)/(b)/(c). The default is (a) argmax.
  if (logits.type !== 'float32') {
    throw new TypeError(`sampleNextToken: expected float32 logits, got ${logits.type}`);
  }
  const data = logits.data as Float32Array;
  // Row-major layout: the last `vocabSize` values belong to the final position.
  const vocabSize = logits.dims[logits.dims.length - 1] ?? data.length;
  if (vocabSize <= 0 || data.length < vocabSize) {
    throw new RangeError(
      `sampleNextToken: expected a logit row of size ${vocabSize}, got ${data.length} values`,
    );
  }
  const lastRow = data.subarray(data.length - vocabSize);
  let bestIdx = 0;
  let bestVal = -Infinity;
  lastRow.forEach((value, idx) => {
    if (value > bestVal) {
      bestVal = value;
      bestIdx = idx;
    }
  });
  return bestIdx;
}
// ────────────────────────────────────────────────────────────────────────────
// USER CONTRIBUTION POINT #2 — shouldStop
//
// Decide whether to halt generation after sampling `nextId` on top of `soFar`.
// The decode loop's `maxNewTokens` cap bounds the loop independently.
//
// Choices:
//   (a) EOS-only — matches Cactus's native generate() exactly. Simplest, safest.
//   (b) EOS OR balanced-JSON — if the decoded text since <tool_call> is a
//       valid, parseable JSON array (e.g. ']' at top level with brace balance
//       at zero), stop. Crisper exit when the model trails into padding.
//   (c) EOS OR balanced-JSON OR token == ']' — same as (b) but cheaper to
//       check, since the tokenizer's ']' token ID is fixed.
//
// Default: (a) EOS-only.
// ────────────────────────────────────────────────────────────────────────────
/**
 * Stop predicate consulted by `generate` right after each token is sampled.
 *
 * @param nextId    The token just sampled; it has not been appended to `soFar` yet.
 * @param soFar     Tokens accepted before `nextId`.
 * @param eosId     End-of-sequence token ID.
 * @param tokenizer Available for text-based stop rules; unused by the EOS-only default.
 * @returns `true` exactly when `nextId` equals `eosId`.
 */
function shouldStop(
  nextId: number,
  soFar: number[],
  eosId: number,
  tokenizer: Tokenizer,
): boolean {
  // USER: replace this body with your choice from (a)/(b)/(c). Default is (a).
  // The EOS-only rule ignores these; referencing them keeps unused-parameter lints quiet.
  void soFar;
  void tokenizer;
  const reachedEos = eosId === nextId;
  return reachedEos;
}