/**
 * Decoder Worker - Runs the adapter + autoregressive decoder in a separate thread.
 *
 * Message protocol (postMessage types):
 *   in:  'init' {cfg, onnxUrl, modelName}
 *        'segment_start' {segmentId}
 *        'features' {segmentId, features, dims}
 *        'segment_end' {segmentId}
 *   out: 'status', 'model_start', 'progress', 'model_done', 'ready',
 *        'live_caption', 'transcript', 'error'
 */
importScripts('https://cdn.jsdelivr.net/npm/onnxruntime-web@1.17.0/dist/ort.min.js');

// Configure ONNX Runtime to find its WASM binaries on the same CDN.
ort.env.wasm.wasmPaths = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.17.0/dist/';

/**
 * Download a model binary, streaming chunks so 'progress' messages can be
 * posted to the main thread as bytes arrive.
 * @param {string} url       Model URL.
 * @param {string} modelName Label echoed in progress messages.
 * @returns {Promise<ArrayBuffer>} The complete model bytes.
 * @throws {Error} When the HTTP response is not ok.
 */
async function fetchModelWithProgress(url, modelName) {
  const response = await fetch(url);
  if (!response.ok) {
    throw new Error(`Failed to fetch ${modelName}: ${response.status}`);
  }

  const contentLength = response.headers.get('Content-Length');
  const total = contentLength ? Number.parseInt(contentLength, 10) : 0;

  if (!response.body || !total) {
    // No streaming support or unknown size - download in one shot.
    const buffer = await response.arrayBuffer();
    self.postMessage({
      type: 'progress',
      model: modelName,
      loaded: buffer.byteLength,
      total: buffer.byteLength,
      done: true,
    });
    return buffer;
  }

  const reader = response.body.getReader();
  const chunks = [];
  let loaded = 0;
  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    chunks.push(value);
    loaded += value.length;
    self.postMessage({ type: 'progress', model: modelName, loaded, total, done: false });
  }
  self.postMessage({ type: 'progress', model: modelName, loaded: total, total, done: true });

  // Combine chunks into a single contiguous buffer.
  const result = new Uint8Array(loaded);
  let offset = 0;
  for (const chunk of chunks) {
    result.set(chunk, offset);
    offset += chunk.length;
  }
  return result.buffer;
}

// Model config (sent by the main thread in 'init').
let cfg = null;
// Number of trailing encoder frames considered unstable; they are re-sent
// with the next 'features' chunk, so we drop them before appending.
let tailLatency = 0;

// Decoding config
const TOKENS_PER_SECOND = 6.5; // Max tokens per second of audio
const FRAME_DURATION_MS = 20;  // Each encoder frame is 20ms

/**
 * Detect degenerate repetition loops so greedy decoding can bail out early.
 * Triggers on: 5 identical trailing tokens, 3 identical trailing pairs
 * ([A,B,A,B,A,B]), or 2 identical trailing triples ([A,B,C,A,B,C]).
 * @param {number[]} tokens Token ids decoded so far.
 * @returns {boolean} True when a repetition pattern is present.
 */
function hasRepetition(tokens) {
  const len = tokens.length;
  if (len < 5) return false;

  // Last 5 tokens all identical.
  const last5 = tokens.slice(-5);
  if (last5.every((t) => t === last5[0])) return true;

  if (len >= 6) {
    const tail = tokens.slice(-6);
    // Three identical pairs: [A,B,A,B,A,B].
    if (tail[0] === tail[2] && tail[2] === tail[4] &&
        tail[1] === tail[3] && tail[3] === tail[5]) {
      return true;
    }
    // Two identical triples: [A,B,C,A,B,C].
    if (tail[0] === tail[3] && tail[1] === tail[4] && tail[2] === tail[5]) {
      return true;
    }
  }
  return false;
}

// ONNX sessions
let adapterSession = null;
let decoderInitSession = null;
let decoderStepSession = null;

// Decoder KV caches (rebuilt on every decode pass)
let crossCache = null; // cross-attention K/V, fixed for a given context
let selfCache = null;  // self-attention K/V, grows one step per token

// Tokenizer
let tokenizer = null;

// Accumulated encoder features for the current speech segment
let accumulatedFeatures = null;
let currentSegmentId = null;

/**
 * Minimal detokenizer built from a HuggingFace tokenizer.json vocabulary.
 */
class MoonshineTokenizer {
  constructor() {
    this.decoder = null;
    this.vocab = null;
  }

  /** Build the id -> token table from a parsed tokenizer.json. */
  load(tokenizerJson) {
    this.vocab = tokenizerJson.model.vocab;
    this.decoder = Object.fromEntries(
      Object.entries(this.vocab).map(([token, id]) => [id, token])
    );
  }

  /**
   * Convert token ids to text.
   * @param {number[]} tokenIds
   * @param {boolean} [skipSpecial=true] Drop special tokens (ids 0, 1, 2).
   * @returns {string} Decoded, whitespace-trimmed text.
   */
  decode(tokenIds, skipSpecial = true) {
    const specialTokens = new Set([0, 1, 2]);
    let text = '';
    for (const id of tokenIds) {
      if (skipSpecial && specialTokens.has(id)) continue;
      text += this.decoder[id] || '';
    }
    // Handle various space placeholder representations.
    text = text.replace(/\u0120/g, ' ');  // Ġ, U+0120 (GPT-2 style)
    text = text.replace(/\u2581/g, ' ');  // ▁ (SentencePiece style)
    text = text.replace(/\u010a/g, '\n'); // Newline marker
    return text.trim();
  }
}

/**
 * Run the adapter over raw encoder features.
 * @returns {Promise<ort.Tensor>} The decoder context tensor.
 */
async function runAdapter(features, dims) {
  const feeds = { encoder_output: new ort.Tensor('float32', features, dims) };
  const results = await adapterSession.run(feeds);
  return results.context;
}

/**
 * Prime the decoder caches from a fresh context: cross-attention K/V come
 * from the init session (odd cache slots), self-attention K/V start empty.
 */
async function initDecoderCache(context) {
  const results = await decoderInitSession.run({ context });

  // Cross-attention cache lives in the odd-indexed cache slots.
  crossCache = [];
  for (let i = 0; i < cfg.depth * 2; i++) {
    if ((i + 1) % 2 === 0) {
      crossCache.push({ k: results[`cache_${i}_k`], v: results[`cache_${i}_v`] });
    }
  }

  // Self-attention cache starts with zero-length sequences.
  selfCache = [];
  for (let i = 0; i < cfg.depth; i++) {
    selfCache.push({
      k: new ort.Tensor('float32', new Float32Array(0), [1, cfg.nheads, 0, cfg.head_dim]),
      v: new ort.Tensor('float32', new Float32Array(0), [1, cfg.nheads, 0, cfg.head_dim]),
    });
  }
}

/**
 * Run one autoregressive decoder step.
 * @param {number} tokenId  Previously emitted token id.
 * @param {number} position Step index (0-based).
 * @returns {Promise<ort.Tensor>} Logits over the vocabulary.
 */
async function decodeStep(tokenId, position) {
  const feeds = {
    token_id: new ort.Tensor('int64', BigInt64Array.from([BigInt(tokenId)]), [1, 1]),
    position: new ort.Tensor('int64', BigInt64Array.from([BigInt(position)]), [1]),
  };

  // Even slots carry the (growing) self-attention cache, odd slots the
  // (fixed) cross-attention cache.
  let selfIdx = 0;
  let crossIdx = 0;
  for (let i = 0; i < cfg.depth * 2; i++) {
    if ((i + 1) % 2 !== 0) {
      feeds[`in_cache_${i}_k`] = selfCache[selfIdx].k;
      feeds[`in_cache_${i}_v`] = selfCache[selfIdx].v;
      selfIdx++;
    } else {
      feeds[`in_cache_${i}_k`] = crossCache[crossIdx].k;
      feeds[`in_cache_${i}_v`] = crossCache[crossIdx].v;
      crossIdx++;
    }
  }

  const results = await decoderStepSession.run(feeds);

  // Only the self-attention cache changes between steps.
  selfIdx = 0;
  for (let i = 0; i < cfg.depth * 2; i++) {
    if ((i + 1) % 2 !== 0) {
      selfCache[selfIdx] = { k: results[`out_cache_${i}_k`], v: results[`out_cache_${i}_v`] };
      selfIdx++;
    }
  }
  return results.logits;
}

/** Greedy argmax over the first cfg.vocab_size logit entries. */
function argmaxToken(logits) {
  let maxIdx = 0;
  let maxVal = logits.data[0];
  for (let i = 1; i < cfg.vocab_size; i++) {
    if (logits.data[i] > maxVal) {
      maxVal = logits.data[i];
      maxIdx = i;
    }
  }
  return maxIdx;
}

/**
 * Greedy-decode the accumulated features into text. The token budget is
 * proportional to the audio duration; decoding stops early on EOS (id 2)
 * or when a repetition loop is detected.
 * @returns {Promise<string>} Transcript text ('' on error or no features).
 */
async function decodeAccumulated() {
  if (!accumulatedFeatures || accumulatedFeatures.dims[1] === 0) {
    return '';
  }
  try {
    const context = await runAdapter(accumulatedFeatures.data, accumulatedFeatures.dims);
    await initDecoderCache(context);

    const numFrames = accumulatedFeatures.dims[1];
    // Duration in seconds bounds the token budget.
    const durationSeconds = (numFrames * FRAME_DURATION_MS) / 1000;
    const maxTokens = Math.max(10, Math.floor(durationSeconds * TOKENS_PER_SECOND));

    const tokens = [1]; // BOS
    for (let step = 0; step < maxTokens; step++) {
      const logits = await decodeStep(tokens[tokens.length - 1], step);
      const next = argmaxToken(logits);
      tokens.push(next);
      if (next === 2) break; // EOS
      if (hasRepetition(tokens)) {
        console.log('Stopping decode due to repetition detected');
        break;
      }
    }
    return tokenizer.decode(tokens, true);
  } catch (e) {
    console.error('Decode error:', e);
    return '';
  }
}

/**
 * Append new feature frames to the accumulator, first dropping the unstable
 * `tailLatency` tail frames of the previous accumulation (they are re-sent
 * as part of the new chunk).
 * @param {{data: Float32Array, dims: number[]}} newFeatures
 */
function appendFeatures(newFeatures) {
  if (accumulatedFeatures === null) {
    accumulatedFeatures = newFeatures;
    return;
  }
  const numFrames = accumulatedFeatures.dims[1];
  const keepFrames = Math.max(0, numFrames - tailLatency);
  if (keepFrames === 0) {
    accumulatedFeatures = newFeatures;
    return;
  }
  const totalFrames = keepFrames + newFeatures.dims[1];
  const combined = new Float32Array(totalFrames * cfg.dim);
  combined.set(accumulatedFeatures.data.subarray(0, keepFrames * cfg.dim), 0);
  combined.set(newFeatures.data, keepFrames * cfg.dim);
  accumulatedFeatures = { data: combined, dims: [1, totalFrames, cfg.dim] };
}

/**
 * Load the tokenizer and the three ONNX sessions, reporting progress to the
 * main thread. Posts 'ready' on success, 'error' on failure.
 */
async function handleInit(data) {
  try {
    cfg = data.cfg;
    const onnxUrl = data.onnxUrl;
    const modelName = data.modelName;
    const dtype = 'fp32';
    tailLatency = cfg.n_future * cfg.encoder_depth;

    // Tokenizer
    self.postMessage({ type: 'status', message: 'Loading tokenizer...' });
    self.postMessage({ type: 'model_start', model: 'Tokenizer' });
    const tokenizerResponse = await fetch(`${onnxUrl}/tokenizer.json`);
    const tokenizerJson = await tokenizerResponse.json();
    tokenizer = new MoonshineTokenizer();
    tokenizer.load(tokenizerJson);
    self.postMessage({ type: 'model_done', model: 'Tokenizer' });

    // Adapter
    const adapterUrl = `${onnxUrl}/adapter_${modelName}_${dtype}.onnx`;
    self.postMessage({ type: 'status', message: 'Loading adapter...' });
    self.postMessage({ type: 'model_start', model: 'Adapter' });
    const adapterBuffer = await fetchModelWithProgress(adapterUrl, 'Adapter');
    adapterSession = await ort.InferenceSession.create(adapterBuffer);
    self.postMessage({ type: 'model_done', model: 'Adapter' });

    // Decoder init
    const decInitUrl = `${onnxUrl}/decoder_init_${modelName}_${dtype}.onnx`;
    self.postMessage({ type: 'status', message: 'Loading decoder (init)...' });
    self.postMessage({ type: 'model_start', model: 'Decoder Init' });
    const decInitBuffer = await fetchModelWithProgress(decInitUrl, 'Decoder Init');
    decoderInitSession = await ort.InferenceSession.create(decInitBuffer);
    self.postMessage({ type: 'model_done', model: 'Decoder Init' });

    // Decoder step
    const decStepUrl = `${onnxUrl}/decoder_step_${modelName}_${dtype}.onnx`;
    self.postMessage({ type: 'status', message: 'Loading decoder (step)...' });
    self.postMessage({ type: 'model_start', model: 'Decoder Step' });
    const decStepBuffer = await fetchModelWithProgress(decStepUrl, 'Decoder Step');
    decoderStepSession = await ort.InferenceSession.create(decStepBuffer);
    self.postMessage({ type: 'model_done', model: 'Decoder Step' });

    self.postMessage({ type: 'ready' });
  } catch (err) {
    self.postMessage({ type: 'error', message: err.message });
  }
}

/** Dispatch one message from the main thread. Runs serialized (see below). */
async function handleMessage(e) {
  const { type, data } = e.data;
  switch (type) {
    case 'init':
      await handleInit(data);
      break;

    case 'segment_start':
      accumulatedFeatures = null;
      currentSegmentId = data.segmentId;
      self.postMessage({ type: 'live_caption', text: '' });
      break;

    case 'segment_end': {
      if (data.segmentId !== currentSegmentId) break;
      const text = await decodeAccumulated();
      self.postMessage({ type: 'transcript', segmentId: data.segmentId, text });
      accumulatedFeatures = null;
      currentSegmentId = null;
      self.postMessage({ type: 'live_caption', text: '' });
      break;
    }

    case 'features': {
      if (data.segmentId !== currentSegmentId) break;
      const newFeatures = { data: new Float32Array(data.features), dims: data.dims };
      console.log(`Decoder received ${data.dims[1]} frames, accumulated: ${accumulatedFeatures ? accumulatedFeatures.dims[1] : 0}`);
      appendFeatures(newFeatures);
      // Live caption: re-decode everything accumulated so far.
      const partialText = await decodeAccumulated();
      self.postMessage({ type: 'live_caption', text: partialText });
      break;
    }
  }
}

// FIX: worker 'message' events fire even while an earlier async handler is
// still awaiting, which previously let two decodeAccumulated() calls run
// concurrently and corrupt the shared self/cross caches and accumulator.
// Chain handlers so they run strictly one at a time, in arrival order.
let messageChain = Promise.resolve();
self.onmessage = (e) => {
  messageChain = messageChain
    .then(() => handleMessage(e))
    .catch((err) => {
      console.error('Worker handler error:', err);
    });
};