/**
 * Encoder Worker - Runs preprocessor + encoder in a separate thread
 */
importScripts('https://cdn.jsdelivr.net/npm/onnxruntime-web@1.21.0/dist/ort.all.min.js');

// Configure ONNX Runtime paths
ort.env.wasm.wasmPaths = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.21.0/dist/';

const MODEL_CACHE_NAME = 'moonshine-models-v1';

// Helper to fetch a model with progress reporting and caching
async function fetchModelWithProgress(url, modelName) {
    // Try the cache first
    try {
        const cache = await caches.open(MODEL_CACHE_NAME);
        const cachedResponse = await cache.match(url);
        if (cachedResponse) {
            const buffer = await cachedResponse.arrayBuffer();
            self.postMessage({
                type: 'progress', model: modelName,
                loaded: buffer.byteLength, total: buffer.byteLength,
                done: true, cached: true
            });
            console.log(`${modelName} loaded from cache`);
            return buffer;
        }
    } catch (e) {
        console.warn('Cache API not available:', e.message);
    }

    // Fetch from the network
    const response = await fetch(url);
    if (!response.ok) {
        throw new Error(`Failed to fetch ${modelName}: ${response.status}`);
    }

    const contentLength = response.headers.get('Content-Length');
    const total = contentLength ? parseInt(contentLength, 10) : 0;

    if (!response.body || !total) {
        // No streaming support or unknown size - just download in one shot
        const buffer = await response.arrayBuffer();
        self.postMessage({
            type: 'progress', model: modelName,
            loaded: buffer.byteLength, total: buffer.byteLength, done: true
        });
        // Cache the response
        try {
            const cache = await caches.open(MODEL_CACHE_NAME);
            await cache.put(url, new Response(buffer.slice(0)));
        } catch (e) {
            console.warn('Failed to cache model:', e.message);
        }
        return buffer;
    }

    // Stream the body so download progress can be reported
    const reader = response.body.getReader();
    const chunks = [];
    let loaded = 0;
    while (true) {
        const { done, value } = await reader.read();
        if (done) break;
        chunks.push(value);
        loaded += value.length;
        self.postMessage({ type: 'progress', model: modelName, loaded, total, done: false });
    }
    self.postMessage({ type: 'progress', model: modelName, loaded: total, total, done: true });

    // Combine chunks into a single buffer
    const result = new Uint8Array(loaded);
    let offset = 0;
    for (const chunk of chunks) {
        result.set(chunk, offset);
        offset += chunk.length;
    }

    // Cache the result
    try {
        const cache = await caches.open(MODEL_CACHE_NAME);
        await cache.put(url, new Response(result.slice(0)));
        console.log(`${modelName} cached`);
    } catch (e) {
        console.warn('Failed to cache model:', e.message);
    }

    return result.buffer;
}

// Model config
let cfg = null;
let tailLatency = 0;

// Preprocessor state
let prepSession = null;
let prepDim = 0;
let prepC1 = 0;
let prepStateC1 = null;
let prepStateC2 = null;

// Encoder state
let encSession = null;
let encDim = 0;
let encNPast = 0;
let encNFuture = 0;
let encEncoderDepth = 0;
let encContextSize = 0;
let encInputBuffer = [];
let encTotalInputFrames = 0;
let encTotalOutputFrames = 0;

function resetPreprocessor() {
    if (prepStateC1) prepStateC1.fill(0);
    if (prepStateC2) prepStateC2.fill(0);
}

function resetEncoder() {
    encInputBuffer = [];
    encTotalInputFrames = 0;
    encTotalOutputFrames = 0;
}

async function processPreprocessor(audioChunk) {
    const feeds = {
        'audio_chunk': new ort.Tensor('float32', audioChunk, [1, audioChunk.length]),
        'state_c1': new ort.Tensor('float32', prepStateC1, [1, 4, prepDim]),
        'state_c2': new ort.Tensor('float32', prepStateC2, [1, 4, prepC1])
    };
    const results = await prepSession.run(feeds);
    // Carry the recurrent states over to the next chunk
    prepStateC1.set(results.new_state_c1.data);
    prepStateC2.set(results.new_state_c2.data);
    return { data: results.features.data, dims: results.features.dims };
}

async function processEncoder(melData, melDims, flush = true) {
    const newFrames = melDims[1];

    // Append new frames to the rolling input buffer
    for (let f = 0; f < newFrames; f++) {
        const frame = new Float32Array(encDim);
        for (let d = 0; d < encDim; d++) {
            frame[d] = melData[f * encDim + d];
        }
        encInputBuffer.push(frame);
    }
    encTotalInputFrames += newFrames;

    // Calculate the output range. Without a flush, the last tailLatency frames
    // are held back (they still lack future context); with a flush, everything
    // is emitted and the previously held-back tail is re-emitted.
    const canOutput = flush ? encTotalInputFrames : Math.max(0, encTotalInputFrames - tailLatency);
    const outputFrom = flush ? Math.max(0, encTotalOutputFrames - tailLatency) : encTotalOutputFrames;
    const newOutputCount = canOutput - outputFrom;
    if (newOutputCount <= 0) {
        return { data: new Float32Array(0), dims: [1, 0, encDim] };
    }

    // Pack the buffered frames into an input tensor
    const bufferFrames = encInputBuffer.length;
    const bufferData = new Float32Array(bufferFrames * encDim);
    for (let f = 0; f < bufferFrames; f++) {
        bufferData.set(encInputBuffer[f], f * encDim);
    }

    const feeds = { 'input': new ort.Tensor('float32', bufferData, [1, bufferFrames, encDim]) };
    const results = await encSession.run(feeds);
    const fullOutput = results.output;

    // Calculate which frames of the full output to return
    const bufStartFrame = encTotalInputFrames - bufferFrames;
    const outputStart = outputFrom - bufStartFrame;

    // Extract that subset of the output
    const resultData = new Float32Array(newOutputCount * encDim);
    for (let f = 0; f < newOutputCount; f++) {
        for (let d = 0; d < encDim; d++) {
            resultData[f * encDim + d] = fullOutput.data[(outputStart + f) * encDim + d];
        }
    }

    // Trim the input buffer to the context size
    if (encInputBuffer.length > encContextSize) {
        encInputBuffer = encInputBuffer.slice(-encContextSize);
    }

    encTotalOutputFrames = canOutput;
    return { data: resultData, dims: [1, newOutputCount, encDim] };
}

// Message queue so messages are processed sequentially, in arrival order
const messageQueue = [];
let isProcessing = false;

async function processMessage(e) {
    const { type, data } = e.data;
    switch (type) {
        case 'init': {
            try {
                cfg = data.cfg;
                const onnxUrl = data.onnxUrl;
                const modelName = data.modelName;
                const backend = data.backend || 'wasm';
                const dtype = 'fp32';

                // Check WebGPU availability before committing to that backend
                if (backend === 'webgpu') {
                    if (typeof navigator !== 'undefined' && navigator.gpu) {
                        console.log('WebGPU navigator.gpu is available');
                        const adapter = await navigator.gpu.requestAdapter();
                        if (adapter) {
                            console.log('WebGPU adapter obtained:', adapter);
                        } else {
                            throw new Error('WebGPU adapter not available');
                        }
                    } else {
                        throw new Error('WebGPU not supported (navigator.gpu is undefined)');
                    }
                }

                const sessionOptions = { executionProviders: [backend] };
                console.log(`Creating sessions with backend: ${backend}`);

                tailLatency = cfg.n_future * cfg.encoder_depth;

                // Initialize the preprocessor
                const prepUrl = `${onnxUrl}/preprocessor_streaming_${modelName}_${dtype}.onnx`;
                self.postMessage({ type: 'status', message: 'Loading preprocessor...' });
                self.postMessage({ type: 'model_start', model: 'Preprocessor' });
                const prepBuffer = await fetchModelWithProgress(prepUrl, 'Preprocessor');
                prepSession = await ort.InferenceSession.create(prepBuffer, sessionOptions);
                self.postMessage({ type: 'model_done', model: 'Preprocessor' });

                prepDim = cfg.dim;
                prepC1 = 2 * cfg.dim;
                prepStateC1 = new Float32Array(4 * cfg.dim);
                prepStateC2 = new Float32Array(4 * prepC1);

                // Initialize the encoder
                const encUrl = `${onnxUrl}/encoder_${modelName}_${dtype}.onnx`;
                self.postMessage({ type: 'status', message: 'Loading encoder...' });
                self.postMessage({ type: 'model_start', model: 'Encoder' });
                const encBuffer = await fetchModelWithProgress(encUrl, 'Encoder');
                encSession = await ort.InferenceSession.create(encBuffer, sessionOptions);
                self.postMessage({ type: 'model_done', model: 'Encoder' });

                encDim = cfg.dim;
                encNPast = cfg.n_past;
                encNFuture = cfg.n_future;
                encEncoderDepth = cfg.encoder_depth;
                encContextSize = cfg.encoder_depth * (cfg.n_past + cfg.n_future);

                self.postMessage({ type: 'ready', backend: backend });
            } catch (err) {
                self.postMessage({ type: 'error', message: err.message });
            }
            break;
        }

        case 'segment_start': {
            resetPreprocessor();
            resetEncoder();
            self.postMessage({ type: 'segment_start', segmentId: data.segmentId });
            break;
        }

        case 'segment_end': {
            self.postMessage({ type: 'segment_end', segmentId: data.segmentId });
            break;
        }

        case 'audio': {
            try {
                // Normalize to a Float32Array so .length is the sample count
                // whether the caller sent a typed array or a raw ArrayBuffer
                const audio = new Float32Array(data.audio);

                // Run the preprocessor
                const mel = await processPreprocessor(audio);
                const audioMs = (audio.length / 16000 * 1000).toFixed(0);
                console.log(`Audio ${audio.length} samples (${audioMs}ms) → Mel ${mel.dims[1]} frames`);

                // Run the encoder with flush=true
                const enc = await processEncoder(mel.data, mel.dims, true);
                console.log(`Mel ${mel.dims[1]} frames → Encoder ${enc.dims[1]} frames (accumulated: ${encTotalOutputFrames})`);

                if (enc.dims[1] > 0) {
                    self.postMessage({
                        type: 'features',
                        segmentId: data.segmentId,
                        features: enc.data,
                        dims: enc.dims
                    }, [enc.data.buffer]); // Transfer ownership instead of copying
                }
            } catch (err) {
                console.error('Encoder error:', err);
            }
            break;
        }
    }
}

async function processQueue() {
    if (isProcessing) return;
    isProcessing = true;
    while (messageQueue.length > 0) {
        const msg = messageQueue.shift();
        await processMessage(msg);
    }
    isProcessing = false;
}

self.onmessage = function(e) {
    messageQueue.push(e);
    processQueue();
};
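
/*
 * Usage sketch (main thread) - a minimal, hypothetical example of driving this
 * worker. The filename, model URL, model name, cfg values, and the handler
 * functions below are assumptions for illustration; they must match your
 * actual Moonshine export and application code.
 *
 *     const worker = new Worker('encoder-worker.js'); // hypothetical filename
 *     worker.onmessage = (e) => {
 *         const m = e.data;
 *         if (m.type === 'progress') updateProgressBar(m.model, m.loaded, m.total); // placeholder
 *         if (m.type === 'ready')    startStreamingAudio();                         // placeholder
 *         if (m.type === 'features') decode(m.features, m.dims); // dims = [1, frames, dim]
 *         if (m.type === 'error')    console.error(m.message);
 *     };
 *     worker.postMessage({
 *         type: 'init',
 *         data: {
 *             onnxUrl: 'https://example.com/models', // hypothetical URL
 *             modelName: 'base',                     // hypothetical model name
 *             backend: 'webgpu',                     // or 'wasm'
 *             cfg: { dim: 416, n_past: 16, n_future: 4, encoder_depth: 8 } // assumed values
 *         }
 *     });
 *     // Per utterance:
 *     worker.postMessage({ type: 'segment_start', data: { segmentId: 1 } });
 *     const audio = new Float32Array(16000); // 1 s of 16 kHz mono PCM
 *     worker.postMessage({ type: 'audio', data: { segmentId: 1, audio } });
 *     worker.postMessage({ type: 'segment_end', data: { segmentId: 1 } });
 */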