// trigo/trigo-web — app/src/services/onnxInferencer.ts (commit 502af73, k-l-lambda, "updated")
/**
* ONNX Inferencer for Browser (Frontend Wrapper)
*
* This file imports onnxruntime-web and creates sessions,
* then injects them into the platform-agnostic inferencer.
*/
// IMPORTANT: Import onnxruntime-web first to register backends
import * as ort from "onnxruntime-web";
// Configure ONNX Runtime Web environment BEFORE creating any sessions —
// these globals are read by onnxruntime-web at session-creation time.
// NOTE(review): single-threaded WASM presumably avoids the
// cross-origin-isolation (COOP/COEP) requirement of multi-threaded
// SharedArrayBuffer backends — confirm that was the intent.
ort.env.wasm.numThreads = 1;
// Opt into SIMD-accelerated WASM kernels where the browser supports them.
ort.env.wasm.simd = true;
// Import the common inferencer class AFTER onnxruntime-web is loaded
import { ModelInferencer, type InferencerConfig, type InferenceResult } from "@inc/modelInferencer";
/**
 * Web-specific configuration for the inferencer.
 *
 * Extends the platform-agnostic `InferencerConfig` (all base fields made
 * optional here via `Partial`) with browser/onnxruntime-web specifics.
 */
export interface WebInferencerConfig extends Partial<InferencerConfig> {
  /** URL or path of the ONNX model to load. */
  modelPath: string;
  /**
   * Execution providers in priority order.
   * Defaults to `["wasm"]` when omitted (see the OnnxInferencer constructor).
   */
  executionProviders?: ("wasm" | "webgl" | "webgpu")[];
  /** Extra session options; merged over (and overriding) the web defaults. */
  sessionOptions?: ort.InferenceSession.SessionOptions;
}
/**
 * ONNX Inferencer for Browser with Web-specific defaults.
 *
 * Thin browser wrapper around the platform-agnostic `ModelInferencer`:
 * it owns onnxruntime-web session creation and injects the resulting
 * live session into the base class.
 */
export class OnnxInferencer extends ModelInferencer {
  private modelPath: string;
  private sessionOptions?: ort.InferenceSession.SessionOptions;

  constructor(config: WebInferencerConfig) {
    const {
      modelPath,
      executionProviders = ["wasm"],
      sessionOptions,
      ...baseConfig
    } = config;

    // The base class is platform-agnostic; it is handed the concrete
    // Tensor constructor from onnxruntime-web here.
    // NOTE(review): the `as any` cast is inherited from the original —
    // presumably the base class declares a structural tensor-constructor
    // type; confirm against @inc/modelInferencer before tightening.
    super(ort.Tensor as any, baseConfig);

    this.modelPath = modelPath;

    // Caller-supplied options are spread last so they win over the web
    // defaults (requested providers, full graph optimization).
    this.sessionOptions = {
      executionProviders,
      graphOptimizationLevel: "all",
      ...sessionOptions
    };
  }

  /**
   * Create the onnxruntime-web inference session for `modelPath` and
   * inject it into the base class.
   *
   * @throws Rethrows any session-creation failure after logging it.
   */
  async initialize(): Promise<void> {
    console.log("[OnnxInferencer] Initializing...");
    console.log("[OnnxInferencer] Model path:", this.modelPath);

    try {
      const session = await ort.InferenceSession.create(
        this.modelPath,
        this.sessionOptions
      );
      // Hand the live session to the platform-agnostic base class.
      // NOTE(review): `as any` inherited from the original — verify the
      // session type expected by setSession in @inc/modelInferencer.
      this.setSession(session as any);
    } catch (error) {
      console.error("[OnnxInferencer] Failed to create session:", error);
      throw error;
    }
  }
}
// Re-export the platform-agnostic types so consumers can import them
// from this module without depending on @inc/modelInferencer directly.
export type { InferencerConfig, InferenceResult };