// trigo/trigo-web/backend/dist/inc/modelInferencer.d.ts
// Pre-built declaration file, committed to the repo by k-l-lambda
// (commit 63a8db2: "Deploy: fix build by keeping pre-built dist folder").
/**
* ONNX Model Inferencer (Frontend/Backend Common)
*
* Platform-agnostic inference logic that accepts ONNX session from platform-specific code.
* No direct dependency on onnxruntime packages - uses dependency injection pattern.
*
* Adapted from Node.js test_inference.js for cross-platform use
* Provides causal language model inference using GPT-2 ONNX model
*
* Vocabulary Design (128 tokens):
* 0-3: Special tokens (PAD=0, START=1, END=2, VALUE=3)
* 4-7: Reserved for future use
* 10: LF (newline) for multi-line game records
* 32-127: ASCII printable characters (direct identity mapping)
*
* This design uses direct identity mapping: token_id = ascii_value
* No complex formulas needed - simple and efficient.
*/
/**
 * Minimal ONNX Tensor interface (platform-agnostic).
 *
 * Structural subset of the Tensor type exposed by both onnxruntime-web and
 * onnxruntime-node, so either runtime's tensors satisfy it without this
 * module taking a direct package dependency.
 */
export interface OnnxTensor {
    /** Flat tensor payload; the concrete array type corresponds to `type`. */
    readonly data: number[] | Float32Array | Int32Array | BigInt64Array | Uint8Array;
    /** Tensor shape, one entry per dimension. */
    readonly dims: readonly number[];
    /** ONNX element-type identifier (e.g. 'float32', 'int64') — matches onnxruntime's Tensor.type. */
    readonly type: string;
}
/**
 * Minimal ONNX Session interface (platform-agnostic).
 *
 * Structural subset of onnxruntime's InferenceSession. Concrete sessions are
 * created by platform-specific code and injected via ModelInferencer.setSession().
 */
export interface OnnxSession {
    /** Names of the model graph's inputs. */
    readonly inputNames: readonly string[];
    /** Names of the model graph's outputs. */
    readonly outputNames: readonly string[];
    /**
     * Run the model with the given named input tensors.
     * @param feeds - Map from input name to tensor.
     * @returns Map from output name to tensor.
     */
    run(feeds: Record<string, OnnxTensor>): Promise<Record<string, OnnxTensor>>;
}
/**
 * Tensor constructor interface (platform-specific).
 *
 * Injected into ModelInferencer so this module can build input tensors
 * without importing an onnxruntime package directly (dependency injection).
 */
export interface TensorConstructor {
    /**
     * @param type - ONNX element-type identifier (e.g. 'int64', 'float32').
     * @param data - Flat typed-array payload for the tensor.
     * @param dims - Tensor shape, one entry per dimension.
     */
    new (type: string, data: BigInt64Array | Float32Array | Int32Array | Uint8Array, dims: number[]): OnnxTensor;
}
/**
 * Configuration for the inferencer.
 */
export interface InferencerConfig {
    /** Vocabulary size — 128 per the token design documented in the file header. */
    vocabSize: number;
    /** Model sequence length; presumably the fixed length inputs are padded to (see padSequence) — verify against the implementation. */
    seqLen: number;
    /** Optional path to the ONNX model file; informational only, since the session itself is injected via setSession(). */
    modelPath?: string;
}
/**
 * Inference result containing generated tokens and metadata.
 */
export interface InferenceResult {
    /** Generated token IDs (identity-mapped to ASCII codes per the vocabulary design above). */
    tokens: number[];
    /** Decoded text corresponding to `tokens`. */
    text: string;
    /** Raw logits from the final model run. */
    logits: Float32Array;
    /** Wall-clock inference duration — presumably milliseconds; confirm against the implementation. */
    inferenceTime: number;
}
/**
 * Evaluation mode inputs for tree attention.
 *
 * Used with models exported with the --evaluation flag (see
 * runEvaluationInference).
 */
export interface EvaluationInputs {
    /** Token IDs of the common prefix shared by all evaluated candidates. */
    prefixIds: number[];
    /** Token IDs of the candidate positions to evaluate. */
    evaluatedIds: number[];
    /** Attention mask over the evaluated tokens; presumably encodes the tree structure — verify against the model export. */
    evaluatedMask: number[];
}
/**
 * Evaluation mode output.
 */
export interface EvaluationOutput {
    /** Logits covering each evaluated position; slice per position with ModelInferencer.softmax(). */
    logits: Float32Array;
    /** Number of evaluated positions covered by `logits`. */
    numEvaluated: number;
}
/**
 * Model Inferencer for Causal Language Model.
 *
 * Compatible with both frontend (onnxruntime-web) and backend
 * (onnxruntime-node): the Tensor constructor is injected at construction
 * time and the session is injected via setSession(), so this class never
 * imports an onnxruntime package itself.
 */
export declare class ModelInferencer {
    /** Injected ONNX session; provided by platform-specific code via setSession(). */
    private session;
    /** Resolved configuration (constructor defaults merged with the partial override). */
    private config;
    /** Injected platform-specific tensor constructor, used to build model inputs. */
    private TensorClass;
    /** Padding token ID (0 per the vocabulary design in the file header). */
    private readonly PAD_TOKEN;
    /** Sequence-start token ID (1). */
    private readonly START_TOKEN;
    /** Sequence-end token ID (2). */
    private readonly END_TOKEN;
    /** Value-prediction token ID (3), used by evaluation-mode models. */
    private readonly VALUE_TOKEN;
    /**
     * @param TensorClass - Platform-specific tensor constructor (dependency injection).
     * @param config - Optional overrides for vocabSize / seqLen / modelPath.
     */
    constructor(TensorClass: TensorConstructor, config?: Partial<InferencerConfig>);
    /**
     * Set the inference session (created by platform-specific code).
     * Call before any inference method; readiness can be checked with isReady().
     */
    setSession(session: OnnxSession): void;
    /**
     * Run a basic inference smoke test.
     * @returns The test run's tokens, text, logits, and timing.
     */
    testBasicInference(): Promise<InferenceResult>;
    /**
     * Generate tokens autoregressively from a prompt.
     * @param prompt - Seed text, encoded via the identity ASCII-to-token mapping.
     * @param numTokens - Number of tokens to generate; default lives in the implementation — see the emitted .js.
     * @returns Generated tokens, decoded text, final logits, and timing.
     */
    generateText(prompt: string, numTokens?: number): Promise<InferenceResult>;
    /**
     * Get model information.
     * @returns Input/output names of the loaded model; null presumably when no session is set — confirm in the implementation.
     */
    getModelInfo(): {
        inputs: string[];
        outputs: string[];
    } | null;
    /**
     * Get the resolved configuration.
     */
    getConfig(): InferencerConfig;
    /**
     * Run inference with token array input.
     * @param tokens - Input token IDs.
     * @returns Raw logits as a Float32Array.
     */
    runInference(tokens: number[]): Promise<Float32Array>;
    /**
     * Run tree attention inference (evaluation mode).
     * For models exported with --evaluation flag.
     * @param inputs - Prefix, evaluated tokens, and attention mask.
     * @returns Logits for each evaluated position.
     */
    runEvaluationInference(inputs: EvaluationInputs): Promise<EvaluationOutput>;
    /**
     * Run value prediction inference (for evaluation mode models).
     * For models exported with --evaluation-mode flag.
     * @param tokens - Token IDs (already includes START/END tokens and padding).
     * @returns Predicted game outcome value in range [-1, 1].
     */
    runValuePrediction(tokens: number[]): Promise<number>;
    /**
     * Compute softmax for a single position's logits.
     * @param logits - Full logits array (e.g. from runEvaluationInference).
     * @param position - Which evaluated position (0 = last prefix, 1-m = evaluated tokens).
     * @returns Probability distribution over vocabulary.
     */
    softmax(logits: Float32Array, position: number): Float32Array;
    /**
     * Check if inferencer is ready (i.e. a session has been set).
     */
    isReady(): boolean;
    /**
     * Destroy the session and free resources.
     */
    destroy(): void;
    /** Debug helper — presumably logs the session's input/output names; see the .js. */
    private printModelInfo;
    /** Builds a random input for testBasicInference — confirm in the .js. */
    private createRandomInput;
    /** Pads a token sequence — presumably with PAD_TOKEN up to seqLen; confirm in the .js. */
    private padSequence;
    /** Validates model output tensors — details in the .js. */
    private validateOutput;
    /** Extracts predictions from logits — details in the .js. */
    private getPredictions;
}
//# sourceMappingURL=modelInferencer.d.ts.map