/** * Keet v2.0 - Model Manager * * Handles loading, caching, and managing parakeet.js model lifecycle. * Supports WebGPU with WASM fallback. * Stories 2.2 & 2.3: Progress UI + Cache API */ import type { ModelState, BackendType, ModelConfig, ModelProgress, ModelManagerCallbacks } from './types'; import { ParakeetModel, getParakeetModel, getModelConfig } from 'parakeet.js'; // Default model configuration (Parakeet TDT 0.6B) const DEFAULT_MODEL_ID = 'parakeet-tdt-0.6b-v2'; const CACHE_NAME = 'keet-model-cache-v1'; const PARAKEET_DB_NAME = 'parakeet-cache-db'; type ModelConfigResolver = ((modelKeyOrRepoId: string) => { repoId?: string } | null) | undefined; type ResolvedModelAssets = { urls: { encoderUrl: string; decoderUrl: string; tokenizerUrl: string; preprocessorUrl?: string; encoderDataUrl?: string | null; decoderDataUrl?: string | null; }; filenames?: { encoder: string; decoder: string; }; preprocessorBackend?: string; }; export class ModelManager { private _state: ModelState = 'unloaded'; private _progress: number = 0; private _backend: BackendType = 'webgpu'; private _model: any = null; // ParakeetModel instance private _callbacks: ModelManagerCallbacks = {}; private _isOfflineReady: boolean = false; private _isCached: boolean = false; constructor(callbacks: ModelManagerCallbacks = {}) { this._callbacks = callbacks; } // Getters getState(): ModelState { return this._state; } getProgress(): number { return this._progress; } getBackend(): BackendType { return this._backend; } getModel(): any { return this._model; } isOfflineReady(): boolean { return this._isOfflineReady; } isCached(): boolean { return this._isCached; } /** * Check if model is already cached (partial check) */ async checkCache(): Promise { // In v2.0 we rely on parakeet.js/IndexedDB cache, but we can do a quick check return this._isCached; } /** * Load the model with WebGPU/WASM fallback */ async loadModel(config: { modelId?: string } = {}): Promise { const modelId = config.modelId || DEFAULT_MODEL_ID; this._setState('loading'); this._setProgress({ stage: 'init', progress: 0, message: 'Initializing...' }); try { // 1. Detect WebGPU support const hasWebGPU = await this._detectWebGPU(); this._backend = hasWebGPU ? 'webgpu' : 'wasm'; this._setProgress({ stage: 'backend', progress: 10, message: `Using ${this._backend.toUpperCase()} backend` }); // 2. Import parakeet.js symbols this._setProgress({ stage: 'import', progress: 15, message: 'Loading parakeet.js...' }); const createModelFromAssets = async (assets: ResolvedModelAssets): Promise => { const preprocessorBackend = assets.preprocessorBackend || 'js'; console.log(`[ModelManager] Loading model with backend=${this._backend}, preprocessorBackend=${preprocessorBackend}`); return ParakeetModel.fromUrls({ ...assets.urls, filenames: assets.filenames, preprocessorBackend, backend: this._backend === 'webgpu' ? 'webgpu-hybrid' : 'wasm', verbose: false, }); }; // 3. Resolve model URLs via parakeet.js Hub (handles .data files correctly) this._setProgress({ stage: 'resolve', progress: 20, message: 'Resolving model assets...' }); const modelAssets = await getParakeetModel(modelId, { backend: this._backend, preprocessorBackend: 'js', // Use pure JS mel — faster, no ONNX download needed progress: (p: any) => { // Map parakeet.js progress to our UI const pct = Math.round(20 + (p.loaded / p.total) * 70); this._setProgress({ stage: 'download', progress: pct, message: 'Downloading assets...', file: `${p.file} (${Math.round(p.loaded / 1024 / 1024)}MB)` }); } }); // 4. Load the model into ONNX Runtime this._setProgress({ stage: 'compile', progress: 90, message: 'Compiling model (this may take a moment)...' }); try { this._model = await createModelFromAssets(modelAssets as ResolvedModelAssets); } catch (loadError) { if (!this._isRecoverableFetchError(loadError)) { throw loadError; } console.warn('[ModelManager] Hub blob URL load failed, clearing stale cache and retrying with direct URLs'); this._setProgress({ stage: 'recover', progress: 35, message: 'Recovering stale model cache...' }); await this._clearParakeetIndexedDbCache(); const directAssets = this._buildDirectModelAssets( modelId, this._backend, getModelConfig as ModelConfigResolver ); this._setProgress({ stage: 'recover', progress: 55, message: 'Retrying model load...' }); this._model = await createModelFromAssets(directAssets); } // Log which preprocessor the model is actually using const ppBackend = this._model.getPreprocessorBackend?.() || 'unknown'; console.log(`[ModelManager] Model ready. Preprocessor: ${ppBackend === 'js' ? 'JS (mel.js) — no ONNX preprocessor loaded' : 'ONNX (nemo128.onnx)'}`); this._setProgress({ stage: 'complete', progress: 100, message: 'Model ready' }); this._setState('ready'); // Mark as offline ready this._isOfflineReady = true; this._isCached = true; } catch (error) { console.error('Model loading failed:', error); this._setState('error'); this._setProgress({ stage: 'error', progress: 0, message: error instanceof Error ? error.message : 'Failed to load model' }); this._callbacks.onError?.(error as Error); throw error; } } /** * Side-load model from local files */ async loadLocalModel(files: FileList): Promise { this._setState('loading'); this._setProgress({ stage: 'init', progress: 0, message: 'Processing local files...' }); try { const fileArray = Array.from(files); const findFile = (pattern: RegExp) => fileArray.find(f => pattern.test(f.name)); // Map files to assets const assets = { encoder: findFile(/encoder.*\.onnx$/i), decoder: findFile(/decoder.*\.onnx$/i), tokenizer: findFile(/vocab.*\.txt$/i), preprocessor: findFile(/nemo.*\.onnx$/i), encoderData: findFile(/encoder.*\.onnx\.data$/i), decoderData: findFile(/decoder.*\.onnx\.data$/i), }; // Validation — preprocessor ONNX is optional (JS backend is default) if (!assets.encoder || !assets.decoder || !assets.tokenizer) { const missing = []; if (!assets.encoder) missing.push('encoder-model.onnx'); if (!assets.decoder) missing.push('decoder_joint-model.onnx'); if (!assets.tokenizer) missing.push('vocab.txt'); throw new Error(`Missing required files: ${missing.join(', ')}`); } const hasWebGPU = await this._detectWebGPU(); this._backend = hasWebGPU ? 'webgpu' : 'wasm'; this._setProgress({ stage: 'import', progress: 20, message: 'Initialising parakeet.js...' }); this._setProgress({ stage: 'compile', progress: 40, message: 'Compiling local model...' }); // Use JS preprocessor by default; fall back to ONNX if preprocessor file is provided const useOnnxPreprocessor = !!assets.preprocessor; const urls: Record = { encoderUrl: URL.createObjectURL(assets.encoder), decoderUrl: URL.createObjectURL(assets.decoder), tokenizerUrl: URL.createObjectURL(assets.tokenizer), encoderDataUrl: assets.encoderData ? URL.createObjectURL(assets.encoderData) : undefined, decoderDataUrl: assets.decoderData ? URL.createObjectURL(assets.decoderData) : undefined, }; if (useOnnxPreprocessor) { urls.preprocessorUrl = URL.createObjectURL(assets.preprocessor!); } this._model = await ParakeetModel.fromUrls({ ...urls, filenames: { encoder: assets.encoder.name, decoder: assets.decoder.name }, preprocessorBackend: useOnnxPreprocessor ? 'onnx' : 'js', backend: this._backend === 'webgpu' ? 'webgpu-hybrid' : 'wasm', verbose: false, }); this._setProgress({ stage: 'complete', progress: 100, message: 'Local model ready' }); this._setState('ready'); this._isOfflineReady = true; } catch (error) { console.error('Local model loading failed:', error); this._setState('error'); this._setProgress({ stage: 'error', progress: 0, message: error instanceof Error ? error.message : 'Failed to load local model' }); this._callbacks.onError?.(error as Error); throw error; } } /** * Detect WebGPU availability */ private async _detectWebGPU(): Promise { // Cast navigator to any to access WebGPU API (not in all TypeScript defs) const nav = navigator as any; if (!nav.gpu) { console.log('WebGPU not supported in this browser'); return false; } try { const adapter = await nav.gpu.requestAdapter(); if (!adapter) { console.log('No WebGPU adapter found'); return false; } const device = await adapter.requestDevice(); device.destroy(); console.log('WebGPU is available'); return true; } catch (e) { console.log('WebGPU check failed:', e); return false; } } /** * Update state and notify callbacks */ private _setState(state: ModelState): void { this._state = state; this._callbacks.onStateChange?.(state); } /** * Update progress and notify callbacks */ private _setProgress(progress: ModelProgress): void { this._progress = progress.progress; this._callbacks.onProgress?.(progress); } /** * Clear cached model data */ async clearCache(): Promise { try { await caches.delete(CACHE_NAME); await this._clearParakeetIndexedDbCache(); this._isCached = false; console.log('Model cache cleared'); } catch (e) { console.error('Failed to clear cache:', e); } } /** * Dispose model and free resources */ dispose(): void { this._model = null; this._state = 'unloaded'; this._progress = 0; } private _isRecoverableFetchError(error: unknown): boolean { if (!(error instanceof Error)) { return false; } const message = error.message.toLowerCase(); return message.includes('failed to fetch') || message.includes('fetch failed'); } private _buildDirectModelAssets( modelId: string, backend: BackendType, getModelConfig: ModelConfigResolver ): ResolvedModelAssets { const repoId = getModelConfig?.(modelId)?.repoId || modelId; const revision = 'main'; const encoderName = backend === 'webgpu' ? 'encoder-model.onnx' : 'encoder-model.int8.onnx'; const decoderName = 'decoder_joint-model.int8.onnx'; const baseUrl = `https://huggingface.co/${repoId}/resolve/${revision}`; return { urls: { encoderUrl: `${baseUrl}/${encoderName}`, decoderUrl: `${baseUrl}/${decoderName}`, tokenizerUrl: `${baseUrl}/vocab.txt`, }, preprocessorBackend: 'js', }; } private async _clearParakeetIndexedDbCache(): Promise { if (typeof indexedDB === 'undefined') return; await new Promise((resolve) => { const request = indexedDB.deleteDatabase(PARAKEET_DB_NAME); request.onsuccess = () => resolve(); request.onerror = () => { console.warn('[ModelManager] Failed to clear parakeet IndexedDB cache'); resolve(); }; request.onblocked = () => { console.warn('[ModelManager] Parakeet IndexedDB cache clear blocked'); resolve(); }; }); } }