/**
* decoder_worker.js - Decoder Worker: runs adapter + decoder in a separate thread
*/
importScripts('https://cdn.jsdelivr.net/npm/onnxruntime-web@1.21.0/dist/ort.all.min.js');
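// The 'all' bundle ships every execution provider, including WebGPU (needed when backend === 'webgpu')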
// Configure ONNX Runtime paths
ort.env.wasm.wasmPaths = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.21.0/dist/';
const MODEL_CACHE_NAME = 'moonshine-models-v1';
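// Bump the version suffix to invalidate previously cached model files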
// Helper to fetch model with progress reporting and caching
async function fetchModelWithProgress(url, modelName) {
// Try to get from cache first
try {
const cache = await caches.open(MODEL_CACHE_NAME);
const cachedResponse = await cache.match(url);
if (cachedResponse) {
const buffer = await cachedResponse.arrayBuffer();
self.postMessage({
type: 'progress',
model: modelName,
loaded: buffer.byteLength,
total: buffer.byteLength,
done: true,
cached: true
});
console.log(`${modelName} loaded from cache`);
return buffer;
}
} catch (e) {
console.warn('Cache API not available:', e.message);
}
// Fetch from network
const response = await fetch(url);
if (!response.ok) {
throw new Error(`Failed to fetch ${modelName}: ${response.status}`);
}
const contentLength = response.headers.get('Content-Length');
const total = contentLength ? parseInt(contentLength, 10) : 0;
if (!response.body || !total) {
// No streaming support or unknown size - just download
const buffer = await response.arrayBuffer();
self.postMessage({
type: 'progress',
model: modelName,
loaded: buffer.byteLength,
total: buffer.byteLength,
done: true
});
// Cache the response
try {
const cache = await caches.open(MODEL_CACHE_NAME);
await cache.put(url, new Response(buffer.slice(0)));
} catch (e) {
console.warn('Failed to cache model:', e.message);
}
return buffer;
}
const reader = response.body.getReader();
const chunks = [];
let loaded = 0;
while (true) {
const { done, value } = await reader.read();
if (done) break;
chunks.push(value);
loaded += value.length;
self.postMessage({
type: 'progress',
model: modelName,
loaded,
total,
done: false
});
}
self.postMessage({
type: 'progress',
model: modelName,
loaded: total,
total,
done: true
});
// Combine chunks into single ArrayBuffer
const result = new Uint8Array(loaded);
let offset = 0;
for (const chunk of chunks) {
result.set(chunk, offset);
offset += chunk.length;
}
// Cache the result
try {
const cache = await caches.open(MODEL_CACHE_NAME);
await cache.put(url, new Response(result.slice(0)));
console.log(`${modelName} cached`);
} catch (e) {
console.warn('Failed to cache model:', e.message);
}
return result.buffer;
}
// Model config
let cfg = null;
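// Provisional frames at the tail of the encoder output (set in 'init' to cfg.n_future * cfg.encoder_depth)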
let tailLatency = 0;
// Decoding config
const TOKENS_PER_SECOND = 6.5; // Max tokens per second of audio
const FRAME_DURATION_MS = 20; // Each encoder frame is 20ms
// Check for repetitive token patterns that indicate decoding should stop
function hasRepetition(tokens) {
const len = tokens.length;
if (len < 5) return false;
// Check if last 5 tokens are the same
const last5 = tokens.slice(-5);
if (last5.every(t => t === last5[0])) {
return true;
}
// Check for 3 repeated same pairs (e.g., [A,B,A,B,A,B])
if (len >= 6) {
const pair1 = [tokens[len - 6], tokens[len - 5]];
const pair2 = [tokens[len - 4], tokens[len - 3]];
const pair3 = [tokens[len - 2], tokens[len - 1]];
if (pair1[0] === pair2[0] && pair2[0] === pair3[0] &&
pair1[1] === pair2[1] && pair2[1] === pair3[1]) {
return true;
}
}
// Check for 2 repeated same triples (e.g., [A,B,C,A,B,C])
if (len >= 6) {
const triple1 = [tokens[len - 6], tokens[len - 5], tokens[len - 4]];
const triple2 = [tokens[len - 3], tokens[len - 2], tokens[len - 1]];
if (triple1[0] === triple2[0] &&
triple1[1] === triple2[1] &&
triple1[2] === triple2[2]) {
return true;
}
}
return false;
}
// Sessions
let adapterSession = null;
let decoderInitSession = null;
let decoderStepSession = null;
// Decoder state
let crossCache = null;
let selfCache = null;
// Tokenizer
let tokenizer = null;
// Accumulated features
let accumulatedFeatures = null;
let currentSegmentId = null;
// Live caption throttling to prevent pipeline backup
let isDecoding = false;
let lastDecodeTime = 0;
let pendingDecode = false;
const MIN_DECODE_INTERVAL_MS = 500; // Don't decode more often than every 500ms for live captions
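// Minimal detokenizer: inverts the vocab map and rewrites byte-level whitespace markers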
class MoonshineTokenizer {
constructor() {
this.decoder = null;
this.vocab = null;
}
load(tokenizerJson) {
this.vocab = tokenizerJson.model.vocab;
this.decoder = Object.fromEntries(
Object.entries(this.vocab).map(([k, v]) => [v, k])
);
}
decode(tokenIds, skipSpecial = true) {
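// Ids 0-2 are treated as special tokens (1 = BOS and 2 = EOS elsewhere in this file; 0 assumed pad/unk)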
const specialTokens = new Set([0, 1, 2]);
let text = '';
for (const id of tokenIds) {
if (skipSpecial && specialTokens.has(id)) continue;
const token = this.decoder[id] || '';
text += token;
}
// Handle various space placeholder representations
text = text.replace(/\u0120/g, ' '); // Ġ (U+0120): GPT-2 byte-level space marker
text = text.replace(/▁/g, ' '); // ▁ (U+2581): SentencePiece space marker
text = text.replace(/\u010a/g, '\n'); // Ċ (U+010A): GPT-2 byte-level newline marker
return text.trim();
}
}
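// Run the adapter: projects accumulated encoder features into the context tensor the decoder cross-attends to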
async function runAdapter(features, dims) {
const feeds = {
'encoder_output': new ort.Tensor('float32', features, dims)
};
const results = await adapterSession.run(feeds);
return results.context;
}
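// Run the decoder init graph once per decode: precomputes cross-attention K/V from the context and resets the self-attention cache to zero length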
async function initDecoderCache(context) {
const feeds = { 'context': context };
const results = await decoderInitSession.run(feeds);
// Store cross-attention cache (odd-indexed cache slots; even slots are self-attention)
crossCache = [];
for (let i = 0; i < cfg.depth * 2; i++) {
if ((i + 1) % 2 === 0) {
crossCache.push({
k: results[`cache_${i}_k`],
v: results[`cache_${i}_v`]
});
}
}
// Initialize empty self-attention cache
selfCache = [];
for (let i = 0; i < cfg.depth; i++) {
selfCache.push({
k: new ort.Tensor('float32', new Float32Array(0), [1, cfg.nheads, 0, cfg.head_dim]),
v: new ort.Tensor('float32', new Float32Array(0), [1, cfg.nheads, 0, cfg.head_dim])
});
}
}
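// Run one autoregressive step: feed the previous token, its position, and all caches; returns next-token logits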
async function decodeStep(tokenId, position) {
const feeds = {
'token_id': new ort.Tensor('int64', BigInt64Array.from([BigInt(tokenId)]), [1, 1]),
'position': new ort.Tensor('int64', BigInt64Array.from([BigInt(position)]), [1])
};
// Add cache inputs; slots alternate per layer: even index = self-attention, odd index = cross-attention
let selfIdx = 0;
let crossIdx = 0;
for (let i = 0; i < cfg.depth * 2; i++) {
if ((i + 1) % 2 !== 0) {
feeds[`in_cache_${i}_k`] = selfCache[selfIdx].k;
feeds[`in_cache_${i}_v`] = selfCache[selfIdx].v;
selfIdx++;
} else {
feeds[`in_cache_${i}_k`] = crossCache[crossIdx].k;
feeds[`in_cache_${i}_v`] = crossCache[crossIdx].v;
crossIdx++;
}
}
const results = await decoderStepSession.run(feeds);
// Update self-attention cache (the cross-attention cache stays fixed for the segment)
selfIdx = 0;
for (let i = 0; i < cfg.depth * 2; i++) {
if ((i + 1) % 2 !== 0) {
selfCache[selfIdx] = {
k: results[`out_cache_${i}_k`],
v: results[`out_cache_${i}_v`]
};
selfIdx++;
}
}
return results.logits;
}
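// Greedily decode the entire accumulated feature buffer from scratch (argmax at each step, no beam search)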
async function decodeAccumulated() {
if (!accumulatedFeatures || accumulatedFeatures.dims[1] === 0) {
return '';
}
try {
const context = await runAdapter(accumulatedFeatures.data, accumulatedFeatures.dims);
await initDecoderCache(context);
const numFrames = accumulatedFeatures.dims[1];
// Calculate duration in seconds and max tokens based on that
const durationSeconds = (numFrames * FRAME_DURATION_MS) / 1000;
const maxTokens = Math.max(10, Math.floor(durationSeconds * TOKENS_PER_SECOND));
const tokens = [1]; // BOS
for (let step = 0; step < maxTokens; step++) {
const logits = await decodeStep(tokens[tokens.length - 1], step);
let maxIdx = 0;
let maxVal = logits.data[0];
for (let i = 1; i < cfg.vocab_size; i++) {
if (logits.data[i] > maxVal) {
maxVal = logits.data[i];
maxIdx = i;
}
}
tokens.push(maxIdx);
// Stop on EOS
if (maxIdx === 2) break;
// Stop on repetitive patterns
if (hasRepetition(tokens)) {
console.log('Stopping decode due to repetition detected');
break;
}
}
return tokenizer.decode(tokens, true);
} catch (e) {
console.error('Decode error:', e);
return '';
}
}
// Append a new chunk of encoder features to the running per-segment buffer
function accumulateFeaturesData(data) {
const newFeatures = {
data: new Float32Array(data.features),
dims: data.dims
};
if (accumulatedFeatures === null) {
accumulatedFeatures = newFeatures;
} else {
// Trim the last tailLatency frames from the accumulated buffer; they are
// provisional due to encoder lookahead and are superseded by the new chunk
const numFrames = accumulatedFeatures.dims[1];
const keepFrames = Math.max(0, numFrames - tailLatency);
if (keepFrames > 0) {
const totalFrames = keepFrames + newFeatures.dims[1];
const combined = new Float32Array(totalFrames * cfg.dim);
// Copy kept frames in one block copy
combined.set(accumulatedFeatures.data.subarray(0, keepFrames * cfg.dim));
// Copy new frames
combined.set(newFeatures.data, keepFrames * cfg.dim);
accumulatedFeatures = {
data: combined,
dims: [1, totalFrames, cfg.dim]
};
} else {
accumulatedFeatures = newFeatures;
}
}
}
// Message queue: messages are handled strictly one at a time so async handlers never interleave
const messageQueue = [];
let isProcessingQueue = false;
async function processMessage(e) {
const { type, data } = e.data;
switch (type) {
case 'init': {
try {
cfg = data.cfg;
const onnxUrl = data.onnxUrl;
const modelName = data.modelName;
const backend = data.backend || 'wasm';
const dtype = 'fp32';
const sessionOptions = { executionProviders: [backend] };
tailLatency = cfg.n_future * cfg.encoder_depth;
// Load tokenizer
self.postMessage({ type: 'status', message: 'Loading tokenizer...' });
self.postMessage({ type: 'model_start', model: 'Tokenizer' });
const tokenizerResponse = await fetch(`${onnxUrl}/tokenizer.json`);
const tokenizerJson = await tokenizerResponse.json();
tokenizer = new MoonshineTokenizer();
tokenizer.load(tokenizerJson);
self.postMessage({ type: 'model_done', model: 'Tokenizer' });
// Initialize adapter
const adapterUrl = `${onnxUrl}/adapter_${modelName}_${dtype}.onnx`;
self.postMessage({ type: 'status', message: 'Loading adapter...' });
self.postMessage({ type: 'model_start', model: 'Adapter' });
const adapterBuffer = await fetchModelWithProgress(adapterUrl, 'Adapter');
adapterSession = await ort.InferenceSession.create(adapterBuffer, sessionOptions);
self.postMessage({ type: 'model_done', model: 'Adapter' });
// Initialize decoder init
const decInitUrl = `${onnxUrl}/decoder_init_${modelName}_${dtype}.onnx`;
self.postMessage({ type: 'status', message: 'Loading decoder (init)...' });
self.postMessage({ type: 'model_start', model: 'Decoder Init' });
const decInitBuffer = await fetchModelWithProgress(decInitUrl, 'Decoder Init');
decoderInitSession = await ort.InferenceSession.create(decInitBuffer, sessionOptions);
self.postMessage({ type: 'model_done', model: 'Decoder Init' });
// Initialize decoder step
const decStepUrl = `${onnxUrl}/decoder_step_${modelName}_${dtype}.onnx`;
self.postMessage({ type: 'status', message: 'Loading decoder (step)...' });
self.postMessage({ type: 'model_start', model: 'Decoder Step' });
const decStepBuffer = await fetchModelWithProgress(decStepUrl, 'Decoder Step');
decoderStepSession = await ort.InferenceSession.create(decStepBuffer, sessionOptions);
self.postMessage({ type: 'model_done', model: 'Decoder Step' });
self.postMessage({ type: 'ready', backend: backend });
} catch (err) {
self.postMessage({ type: 'error', message: err.message });
}
break;
}
case 'segment_start': {
accumulatedFeatures = null;
currentSegmentId = data.segmentId;
isDecoding = false;
lastDecodeTime = 0;
pendingDecode = false;
self.postMessage({ type: 'live_caption', text: '' });
break;
}
case 'segment_end': {
if (data.segmentId !== currentSegmentId) break;
// Wait for any in-progress decode to finish before final decode
while (isDecoding) {
await new Promise(resolve => setTimeout(resolve, 50));
}
isDecoding = true;
const text = await decodeAccumulated();
isDecoding = false;
self.postMessage({
type: 'transcript',
segmentId: data.segmentId,
text: text
});
accumulatedFeatures = null;
currentSegmentId = null;
self.postMessage({ type: 'live_caption', text: '' });
break;
}
case 'features': {
if (data.segmentId !== currentSegmentId) break;
// Accumulate this message's features
accumulateFeaturesData(data);
// Drain any queued 'features' messages for this segment and accumulate them too
// (queue entries are raw MessageEvents, hence nextMsg.data.data below)
while (messageQueue.length > 0 && messageQueue[0].data.type === 'features') {
const nextMsg = messageQueue.shift();
const nextData = nextMsg.data.data;
if (nextData.segmentId === currentSegmentId) {
accumulateFeaturesData(nextData);
}
}
console.log(`Decoder accumulated features, total: ${accumulatedFeatures ? accumulatedFeatures.dims[1] : 0} frames`);
// Live caption with throttling
const now = Date.now();
const timeSinceLastDecode = now - lastDecodeTime;
if (isDecoding) {
// Already decoding, mark that we need another decode when done
pendingDecode = true;
} else if (timeSinceLastDecode >= MIN_DECODE_INTERVAL_MS) {
// Enough time has passed, decode now
isDecoding = true;
lastDecodeTime = now;
try {
const partialText = await decodeAccumulated();
self.postMessage({ type: 'live_caption', text: partialText });
} finally {
isDecoding = false;
// If there was a pending decode request, schedule it
if (pendingDecode) {
pendingDecode = false;
setTimeout(async () => {
if (!isDecoding && currentSegmentId !== null) {
isDecoding = true;
lastDecodeTime = Date.now();
try {
const text = await decodeAccumulated();
self.postMessage({ type: 'live_caption', text: text });
} finally {
isDecoding = false;
}
}
}, MIN_DECODE_INTERVAL_MS);
}
}
} else {
// Too soon since last decode, mark pending
pendingDecode = true;
}
break;
}
}
}
async function processQueue() {
if (isProcessingQueue) return;
isProcessingQueue = true;
while (messageQueue.length > 0) {
const msg = messageQueue.shift();
await processMessage(msg);
}
isProcessingQueue = false;
}
self.onmessage = function(e) {
messageQueue.push(e);
processQueue();
};
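/*
* Illustrative main-thread usage (a sketch: the message payloads mirror what
* this worker reads, but handler names like showCaption are hypothetical):
*
*   const worker = new Worker('decoder_worker.js');
*   worker.onmessage = (e) => {
*     const m = e.data;
*     if (m.type === 'ready') console.log(`Decoder ready (${m.backend})`);
*     else if (m.type === 'live_caption') showCaption(m.text);
*     else if (m.type === 'transcript') appendTranscript(m.segmentId, m.text);
*     else if (m.type === 'error') console.error(m.message);
*   };
*   worker.postMessage({ type: 'init', data: { cfg, onnxUrl, modelName, backend: 'webgpu' } });
*   // Then, per speech segment:
*   worker.postMessage({ type: 'segment_start', data: { segmentId } });
*   worker.postMessage({ type: 'features', data: { segmentId, features, dims } }); // repeat per chunk
*   worker.postMessage({ type: 'segment_end', data: { segmentId } });
*/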