|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Load the ONNX Runtime Web bundle into this worker's global scope.
importScripts('https://cdn.jsdelivr.net/npm/onnxruntime-web@1.21.0/dist/ort.all.min.js');

// Point ORT at the CDN directory hosting its .wasm binaries
// (must match the bundle version loaded above).
ort.env.wasm.wasmPaths = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.21.0/dist/';

// Cache API bucket used to persist downloaded model files across page loads.
const MODEL_CACHE_NAME = 'moonshine-models-v1';
|
|
|
|
|
|
|
|
/**
 * Download a model file, reporting progress to the main thread and
 * persisting the bytes in the Cache API so later page loads are instant.
 *
 * Progress is reported via postMessage as
 * { type: 'progress', model, loaded, total, done, cached? }.
 *
 * @param {string} url - Absolute URL of the model file.
 * @param {string} modelName - Human-readable name used in progress messages.
 * @returns {Promise<ArrayBuffer>} The raw model bytes.
 * @throws {Error} If the network request fails with a non-OK status.
 */
async function fetchModelWithProgress(url, modelName) {
  // Serve from the Cache API when possible. The Cache API may be
  // unavailable (e.g. some private-browsing modes), so failures here are
  // non-fatal and we simply fall through to a network fetch.
  try {
    const cache = await caches.open(MODEL_CACHE_NAME);
    const cachedResponse = await cache.match(url);
    if (cachedResponse) {
      const buffer = await cachedResponse.arrayBuffer();
      self.postMessage({
        type: 'progress',
        model: modelName,
        loaded: buffer.byteLength,
        total: buffer.byteLength,
        done: true,
        cached: true
      });
      console.log(`${modelName} loaded from cache`);
      return buffer;
    }
  } catch (e) {
    console.warn('Cache API not available:', e.message);
  }

  const response = await fetch(url);
  if (!response.ok) {
    throw new Error(`Failed to fetch ${modelName}: ${response.status}`);
  }

  const contentLength = response.headers.get('Content-Length');
  const total = contentLength ? parseInt(contentLength, 10) : 0;

  // Without a readable body or a known size we cannot stream incremental
  // progress; fall back to a single arrayBuffer() read.
  if (!response.body || !total) {
    const buffer = await response.arrayBuffer();
    self.postMessage({
      type: 'progress',
      model: modelName,
      loaded: buffer.byteLength,
      total: buffer.byteLength,
      done: true
    });
    await cacheModelBytes(url, buffer.slice(0), modelName);
    return buffer;
  }

  // Stream the body so we can emit incremental progress events.
  const reader = response.body.getReader();
  const chunks = [];
  let loaded = 0;

  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    chunks.push(value);
    loaded += value.length;
    self.postMessage({
      type: 'progress',
      model: modelName,
      loaded,
      total,
      done: false
    });
  }

  // Report the byte count actually received. With Content-Encoding (e.g.
  // gzip) the Content-Length header is the *compressed* size, so `total`
  // can disagree with the decoded byte count; reporting `loaded` keeps the
  // final progress event truthful. (Previously this reported `total`.)
  self.postMessage({
    type: 'progress',
    model: modelName,
    loaded,
    total,
    done: true
  });

  // Concatenate the streamed chunks into one contiguous buffer.
  const result = new Uint8Array(loaded);
  let offset = 0;
  for (const chunk of chunks) {
    result.set(chunk, offset);
    offset += chunk.length;
  }

  await cacheModelBytes(url, result.slice(0), modelName, true);
  return result.buffer;
}

/**
 * Best-effort write of model bytes into the Cache API; never throws.
 * @param {string} url - Cache key (the model's URL).
 * @param {ArrayBuffer|Uint8Array} bytes - A copy of the bytes to store.
 * @param {string} modelName - Name used for logging.
 * @param {boolean} [logSuccess=false] - Log a confirmation when stored.
 */
async function cacheModelBytes(url, bytes, modelName, logSuccess = false) {
  try {
    const cache = await caches.open(MODEL_CACHE_NAME);
    await cache.put(url, new Response(bytes));
    if (logSuccess) {
      console.log(`${modelName} cached`);
    }
  } catch (e) {
    console.warn('Failed to cache model:', e.message);
  }
}
|
|
|
|
|
|
|
|
// Decoder configuration (depth, dims, vocab size, ...) received from the
// main thread in the 'init' message; null until init completes.
let cfg = null;
// Number of trailing encoder frames that are still "unsettled" because of
// the encoder's look-ahead; recomputed on init as n_future * encoder_depth.
let tailLatency = 0;

// Rough speech rate used to bound how many tokens we decode per segment.
const TOKENS_PER_SECOND = 6.5;
// Duration of a single encoder feature frame in milliseconds.
const FRAME_DURATION_MS = 20;
|
|
|
|
|
|
|
|
/**
 * Detect degenerate repetition at the tail of a token sequence — a common
 * greedy-decoding failure mode. Flags three patterns:
 *   - the same token five times in a row,
 *   - the same pair (A B A B A B) over the last six tokens,
 *   - the same triple (A B C A B C) over the last six tokens.
 * @param {number[]} tokens - Decoded token ids so far.
 * @returns {boolean} True when a repetition loop is detected.
 */
function hasRepetition(tokens) {
  const n = tokens.length;
  if (n < 5) return false;

  // Same single token repeated over the last five positions.
  const anchor = tokens[n - 1];
  let allSame = true;
  for (let i = n - 5; i < n - 1; i++) {
    if (tokens[i] !== anchor) {
      allSame = false;
      break;
    }
  }
  if (allSame) return true;

  if (n >= 6) {
    // A-B A-B A-B over the last six tokens.
    const pairLoop =
      tokens[n - 6] === tokens[n - 4] && tokens[n - 4] === tokens[n - 2] &&
      tokens[n - 5] === tokens[n - 3] && tokens[n - 3] === tokens[n - 1];
    if (pairLoop) return true;

    // A-B-C A-B-C over the last six tokens.
    const tripleLoop =
      tokens[n - 6] === tokens[n - 3] &&
      tokens[n - 5] === tokens[n - 2] &&
      tokens[n - 4] === tokens[n - 1];
    if (tripleLoop) return true;
  }

  return false;
}
|
|
|
|
|
|
|
|
// ONNX inference sessions, created during the 'init' message.
let adapterSession = null;      // encoder features -> decoder context
let decoderInitSession = null;  // builds the cross-attention KV cache from context
let decoderStepSession = null;  // single-token autoregressive decode step

// Decoder KV caches. crossCache is fixed per segment (computed by
// decoderInitSession); selfCache grows by one position per decoded token.
let crossCache = null;
let selfCache = null;

// Token-id <-> text mapping, loaded from tokenizer.json during 'init'.
let tokenizer = null;

// Encoder features accumulated for the segment currently being transcribed,
// and the id of that segment (null when idle).
let accumulatedFeatures = null;
let currentSegmentId = null;

// Live-caption throttling state: at most one decode runs at a time, and at
// most one every MIN_DECODE_INTERVAL_MS; pendingDecode remembers that a
// decode was requested while one was already running or too recent.
let isDecoding = false;
let lastDecodeTime = 0;
let pendingDecode = false;
const MIN_DECODE_INTERVAL_MS = 500;
|
|
/**
 * Minimal decode-only tokenizer for Moonshine: maps token ids back to text
 * using the vocab from tokenizer.json (no encoding support).
 */
class MoonshineTokenizer {
  constructor() {
    // id -> token string lookup, built in load().
    this.decoder = null;
    // token string -> id map, taken straight from tokenizer.json.
    this.vocab = null;
  }

  /**
   * Build the id->token lookup from a parsed tokenizer.json object.
   * @param {object} tokenizerJson - Parsed tokenizer.json (expects .model.vocab).
   */
  load(tokenizerJson) {
    this.vocab = tokenizerJson.model.vocab;
    this.decoder = Object.fromEntries(
      Object.entries(this.vocab).map(([k, v]) => [v, k])
    );
  }

  /**
   * Convert token ids to a display string.
   * @param {number[]} tokenIds - Token ids to decode.
   * @param {boolean} [skipSpecial=true] - Drop special ids 0/1/2 (pad/BOS/EOS).
   * @returns {string} Decoded, trimmed text.
   */
  decode(tokenIds, skipSpecial = true) {
    const specialTokens = new Set([0, 1, 2]);
    let text = '';

    for (const id of tokenIds) {
      if (skipSpecial && specialTokens.has(id)) continue;
      // Unknown ids decode to the empty string rather than throwing.
      const token = this.decoder[id] || '';
      text += token;
    }

    // Undo byte-level / SentencePiece markers:
    //   U+0120 'Ġ' and U+2581 '▁' mark a leading space; U+010A 'Ċ' is a newline.
    // Note: the original code replaced /\u0120/g and /Ġ/g separately — those
    // are the same character, so a single replacement suffices.
    text = text.replace(/\u0120/g, ' ');
    text = text.replace(/▁/g, ' ');
    text = text.replace(/\u010a/g, '\n');

    return text.trim();
  }
}
|
|
|
|
|
/**
 * Run the adapter model: projects encoder features into the decoder's
 * context representation.
 * @param {Float32Array} features - Flattened encoder output values.
 * @param {number[]} dims - Tensor shape of `features`.
 * @returns {Promise<ort.Tensor>} The adapter's `context` output tensor.
 */
async function runAdapter(features, dims) {
  const encoderOutput = new ort.Tensor('float32', features, dims);
  const results = await adapterSession.run({ encoder_output: encoderOutput });
  return results.context;
}
|
|
|
|
|
/**
 * Prime the decoder's KV caches for a new segment.
 * Runs the decoder-init model to obtain the cross-attention K/V entries
 * (the odd-numbered cache slots) and resets the self-attention cache to
 * zero-length tensors that grow as tokens are decoded.
 * @param {ort.Tensor} context - Adapter output for the current segment.
 */
async function initDecoderCache(context) {
  const results = await decoderInitSession.run({ context });

  // Cache slots alternate self/cross per layer; the odd indices hold the
  // cross-attention entries produced by the init model.
  crossCache = [];
  for (let slot = 1; slot < cfg.depth * 2; slot += 2) {
    crossCache.push({
      k: results[`cache_${slot}_k`],
      v: results[`cache_${slot}_v`]
    });
  }

  // Self-attention cache starts empty (sequence length 0).
  const emptyKV = () =>
    new ort.Tensor('float32', new Float32Array(0), [1, cfg.nheads, 0, cfg.head_dim]);
  selfCache = [];
  for (let layer = 0; layer < cfg.depth; layer++) {
    selfCache.push({ k: emptyKV(), v: emptyKV() });
  }
}
|
|
|
|
|
/**
 * Execute one autoregressive decoder step.
 * Feeds the previous token plus the current KV caches, refreshes the
 * self-attention cache from the outputs, and returns the logits tensor.
 * @param {number} tokenId - Previously emitted token id.
 * @param {number} position - Zero-based decode position.
 * @returns {Promise<ort.Tensor>} Logits over the vocabulary.
 */
async function decodeStep(tokenId, position) {
  const feeds = {
    token_id: new ort.Tensor('int64', BigInt64Array.from([BigInt(tokenId)]), [1, 1]),
    position: new ort.Tensor('int64', BigInt64Array.from([BigInt(position)]), [1])
  };

  // Interleave caches: even slots carry self-attention, odd slots carry
  // cross-attention (matching the layout produced by initDecoderCache).
  for (let slot = 0; slot < cfg.depth * 2; slot++) {
    const layer = Math.floor(slot / 2);
    const entry = slot % 2 === 0 ? selfCache[layer] : crossCache[layer];
    feeds[`in_cache_${slot}_k`] = entry.k;
    feeds[`in_cache_${slot}_v`] = entry.v;
  }

  const results = await decoderStepSession.run(feeds);

  // Only the self-attention (even) slots are re-emitted with the new
  // position appended; cross-attention entries never change mid-segment.
  for (let slot = 0; slot < cfg.depth * 2; slot += 2) {
    selfCache[slot / 2] = {
      k: results[`out_cache_${slot}_k`],
      v: results[`out_cache_${slot}_v`]
    };
  }

  return results.logits;
}
|
|
|
|
|
/**
 * Decode the full accumulated feature buffer into text via greedy argmax.
 * The token budget is proportional to the audio duration; decoding stops
 * early on EOS (id 2) or when a repetition loop is detected.
 * @returns {Promise<string>} Transcribed text ('' when there is nothing to
 *   decode or an error occurs — errors are logged, never rethrown).
 */
async function decodeAccumulated() {
  if (!accumulatedFeatures || accumulatedFeatures.dims[1] === 0) {
    return '';
  }

  try {
    const context = await runAdapter(accumulatedFeatures.data, accumulatedFeatures.dims);
    await initDecoderCache(context);

    // Bound the decode length by the audio duration at a rough speech rate.
    const numFrames = accumulatedFeatures.dims[1];
    const durationSeconds = (numFrames * FRAME_DURATION_MS) / 1000;
    const maxTokens = Math.max(10, Math.floor(durationSeconds * TOKENS_PER_SECOND));

    // Greedy argmax over the first cfg.vocab_size logits.
    const argmax = (logits) => {
      let best = 0;
      let bestVal = logits.data[0];
      for (let i = 1; i < cfg.vocab_size; i++) {
        if (logits.data[i] > bestVal) {
          bestVal = logits.data[i];
          best = i;
        }
      }
      return best;
    };

    const tokens = [1]; // start from BOS
    for (let step = 0; step < maxTokens; step++) {
      const logits = await decodeStep(tokens[tokens.length - 1], step);
      const next = argmax(logits);
      tokens.push(next);

      if (next === 2) break; // EOS

      if (hasRepetition(tokens)) {
        console.log('Stopping decode due to repetition detected');
        break;
      }
    }

    return tokenizer.decode(tokens, true);
  } catch (e) {
    console.error('Decode error:', e);
    return '';
  }
}
|
|
|
|
|
|
|
|
/**
 * Merge a new chunk of encoder features into the segment-wide buffer.
 *
 * The last `tailLatency` frames of the previous accumulation were produced
 * before the encoder had seen its full look-ahead, so they are dropped and
 * superseded by the newly arrived frames.
 *
 * @param {{features: ArrayBufferLike|number[], dims: number[]}} data -
 *   Incoming features; dims is [1, numFrames, cfg.dim].
 */
function accumulateFeaturesData(data) {
  const newFeatures = {
    data: new Float32Array(data.features),
    dims: data.dims
  };

  if (accumulatedFeatures === null) {
    accumulatedFeatures = newFeatures;
    return;
  }

  // Drop the unsettled tail of the previous accumulation.
  const numFrames = accumulatedFeatures.dims[1];
  const keepFrames = Math.max(0, numFrames - tailLatency);

  if (keepFrames === 0) {
    // Nothing stable to keep; the new chunk replaces everything.
    accumulatedFeatures = newFeatures;
    return;
  }

  const totalFrames = keepFrames + newFeatures.dims[1];
  const combined = new Float32Array(totalFrames * cfg.dim);

  // Bulk-copy the stable prefix and append the new frames (replaces the
  // original per-element nested loops with TypedArray.set).
  combined.set(accumulatedFeatures.data.subarray(0, keepFrames * cfg.dim), 0);
  combined.set(newFeatures.data, keepFrames * cfg.dim);

  accumulatedFeatures = {
    data: combined,
    dims: [1, totalFrames, cfg.dim]
  };
}
|
|
|
|
|
|
|
|
// FIFO of raw onmessage events, drained strictly in order by processQueue
// so the async handlers (init, decode) never interleave.
const messageQueue = [];
// Re-entrancy guard: true while a processQueue drain loop is running.
let isProcessingQueue = false;
|
|
|
|
|
/**
 * Handle a single message from the main thread.
 *
 * Message types:
 *  - 'init':          load tokenizer + the three ONNX sessions, then post 'ready'.
 *  - 'segment_start': reset per-segment state for a new utterance.
 *  - 'features':      accumulate encoder features and (throttled) decode a live caption.
 *  - 'segment_end':   run the final decode for the segment and post 'transcript'.
 *
 * Runs under processQueue, so messages are handled strictly one at a time.
 * @param {MessageEvent} e - Raw worker message event.
 */
async function processMessage(e) {
  const { type, data } = e.data;

  switch (type) {
    case 'init': {
      try {
        cfg = data.cfg;
        const onnxUrl = data.onnxUrl;
        const modelName = data.modelName;
        const backend = data.backend || 'wasm';
        const dtype = 'fp32';

        const sessionOptions = { executionProviders: [backend] };

        // Frames at the end of the feature stream that are still unsettled
        // because of the encoder's look-ahead.
        tailLatency = cfg.n_future * cfg.encoder_depth;

        // Tokenizer: plain JSON, no progress streaming needed.
        self.postMessage({ type: 'status', message: 'Loading tokenizer...' });
        self.postMessage({ type: 'model_start', model: 'Tokenizer' });
        const tokenizerResponse = await fetch(`${onnxUrl}/tokenizer.json`);
        const tokenizerJson = await tokenizerResponse.json();
        tokenizer = new MoonshineTokenizer();
        tokenizer.load(tokenizerJson);
        self.postMessage({ type: 'model_done', model: 'Tokenizer' });

        // Adapter model (encoder features -> decoder context).
        const adapterUrl = `${onnxUrl}/adapter_${modelName}_${dtype}.onnx`;
        self.postMessage({ type: 'status', message: 'Loading adapter...' });
        self.postMessage({ type: 'model_start', model: 'Adapter' });
        const adapterBuffer = await fetchModelWithProgress(adapterUrl, 'Adapter');
        adapterSession = await ort.InferenceSession.create(adapterBuffer, sessionOptions);
        self.postMessage({ type: 'model_done', model: 'Adapter' });

        // Decoder init model (builds the cross-attention KV cache).
        const decInitUrl = `${onnxUrl}/decoder_init_${modelName}_${dtype}.onnx`;
        self.postMessage({ type: 'status', message: 'Loading decoder (init)...' });
        self.postMessage({ type: 'model_start', model: 'Decoder Init' });
        const decInitBuffer = await fetchModelWithProgress(decInitUrl, 'Decoder Init');
        decoderInitSession = await ort.InferenceSession.create(decInitBuffer, sessionOptions);
        self.postMessage({ type: 'model_done', model: 'Decoder Init' });

        // Decoder step model (one token per run).
        const decStepUrl = `${onnxUrl}/decoder_step_${modelName}_${dtype}.onnx`;
        self.postMessage({ type: 'status', message: 'Loading decoder (step)...' });
        self.postMessage({ type: 'model_start', model: 'Decoder Step' });
        const decStepBuffer = await fetchModelWithProgress(decStepUrl, 'Decoder Step');
        decoderStepSession = await ort.InferenceSession.create(decStepBuffer, sessionOptions);
        self.postMessage({ type: 'model_done', model: 'Decoder Step' });

        self.postMessage({ type: 'ready', backend: backend });
      } catch (err) {
        self.postMessage({ type: 'error', message: err.message });
      }
      break;
    }

    case 'segment_start': {
      // New utterance: clear accumulated features and throttle state, and
      // blank the live caption.
      accumulatedFeatures = null;
      currentSegmentId = data.segmentId;
      isDecoding = false;
      lastDecodeTime = 0;
      pendingDecode = false;
      self.postMessage({ type: 'live_caption', text: '' });
      break;
    }

    case 'segment_end': {
      // Ignore stale end events for a segment we already abandoned.
      if (data.segmentId !== currentSegmentId) break;

      // Wait for any in-flight live-caption decode to finish before the
      // final decode (decodes share the decoder KV caches).
      while (isDecoding) {
        await new Promise(resolve => setTimeout(resolve, 50));
      }

      isDecoding = true;
      const text = await decodeAccumulated();
      isDecoding = false;

      self.postMessage({
        type: 'transcript',
        segmentId: data.segmentId,
        text: text
      });

      accumulatedFeatures = null;
      currentSegmentId = null;
      self.postMessage({ type: 'live_caption', text: '' });
      break;
    }

    case 'features': {
      if (data.segmentId !== currentSegmentId) break;

      accumulateFeaturesData(data);

      // Drain any further feature messages already queued so a single
      // decode covers all of them (decoding is far slower than
      // accumulation).
      while (messageQueue.length > 0 && messageQueue[0].data.type === 'features') {
        const nextMsg = messageQueue.shift();
        const nextData = nextMsg.data.data;
        if (nextData.segmentId === currentSegmentId) {
          accumulateFeaturesData(nextData);
        }
      }

      console.log(`Decoder accumulated features, total: ${accumulatedFeatures ? accumulatedFeatures.dims[1] : 0} frames`);

      // Throttle live-caption decodes to at most one per
      // MIN_DECODE_INTERVAL_MS.
      const now = Date.now();
      const timeSinceLastDecode = now - lastDecodeTime;

      if (isDecoding) {
        // A decode is already running; remember to run another afterwards.
        pendingDecode = true;
      } else if (timeSinceLastDecode >= MIN_DECODE_INTERVAL_MS) {
        isDecoding = true;
        lastDecodeTime = now;

        try {
          const partialText = await decodeAccumulated();
          self.postMessage({ type: 'live_caption', text: partialText });
        } finally {
          isDecoding = false;

          // If more features arrived while we were decoding, schedule one
          // follow-up decode after the throttle interval.
          if (pendingDecode) {
            pendingDecode = false;
            setTimeout(async () => {
              if (!isDecoding && currentSegmentId !== null) {
                isDecoding = true;
                lastDecodeTime = Date.now();
                try {
                  const text = await decodeAccumulated();
                  self.postMessage({ type: 'live_caption', text: text });
                } finally {
                  isDecoding = false;
                }
              }
            }, MIN_DECODE_INTERVAL_MS);
          }
        }
      } else {
        // Too soon since the last decode; defer to a later message.
        pendingDecode = true;
      }
      break;
    }
  }
}
|
|
|
|
|
/**
 * Drain the message queue sequentially. The guard flag ensures only one
 * drain loop runs at a time, so messages never process concurrently.
 *
 * Fix: the flag is now reset in a finally block — previously a rejection
 * from processMessage left isProcessingQueue stuck at true, permanently
 * wedging the worker (every later message would queue but never run).
 */
async function processQueue() {
  if (isProcessingQueue) return;
  isProcessingQueue = true;

  try {
    while (messageQueue.length > 0) {
      const msg = messageQueue.shift();
      await processMessage(msg);
    }
  } finally {
    isProcessingQueue = false;
  }
}
|
|
|
|
|
// Worker entry point: enqueue every message and kick the (single) drain
// loop. Queueing keeps the async handlers from interleaving and lets the
// 'features' case batch messages that arrive while a decode is running.
// processQueue's promise is intentionally not awaited here — onmessage must
// return immediately.
self.onmessage = function(e) {
  messageQueue.push(e);
  processQueue();
};
|
|
|