moonshine-streaming-demo / encoder_worker.js
keveman's picture
Upload 7 files
bc36801 verified
raw
history blame
6.71 kB
/**
* Encoder Worker - Runs preprocessor + encoder in a separate thread
*/
// Single onnxruntime-web release used for both the runtime script and its WASM
// binaries. Keeping the version in one place stops the two URLs drifting apart.
const ORT_CDN_BASE = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.17.0/dist/';
importScripts(`${ORT_CDN_BASE}ort.min.js`);
// Configure ONNX Runtime to find WASM files from the same CDN release
ort.env.wasm.wasmPaths = ORT_CDN_BASE;

// ---- Model config (populated by the 'init' message) ----
let cfg = null;
// NOTE(review): `preprocessor` and `encoder` are never read or written in this
// file; they look like leftovers — confirm before removing.
let preprocessor = null;
let encoder = null;
// cfg.n_future * cfg.encoder_depth (set in 'init'). processEncoder holds back
// this many trailing frames when not flushing, and re-emits this many
// previously sent frames when flushing.
let tailLatency = 0;

// ---- Preprocessor state ----
// ONNX session for the streaming preprocessor model.
let prepSession = null;
// cfg.dim; last axis of the [1, 4, prepDim] `state_c1` tensor.
let prepDim = 0;
// 2 * cfg.dim; last axis of the [1, 4, prepC1] `state_c2` tensor.
let prepC1 = 0;
// State buffers fed to the preprocessor on every chunk and overwritten with its
// new_state_c1 / new_state_c2 outputs. resetPreprocessor zeroes them.
let prepStateC1 = null;
let prepStateC2 = null;

// ---- Encoder state ----
// ONNX session for the encoder model.
let encSession = null;
// cfg.dim; number of values per frame in the encoder input and output.
let encDim = 0;
// Copied from cfg in 'init'. NOTE(review): not read anywhere else in this file.
let encNPast = 0;
let encNFuture = 0;
let encEncoderDepth = 0;
// encoder_depth * (n_past + n_future): the maximum number of input frames kept
// in encInputBuffer between processEncoder calls.
let encContextSize = 0;
// Retained input frames, one Float32Array(encDim) per frame, oldest first.
let encInputBuffer = [];
// Frames appended since the last resetEncoder().
let encTotalInputFrames = 0;
// Absolute frame index up to which encoder output has been emitted.
let encTotalOutputFrames = 0;
/**
 * Zero the preprocessor's carried state buffers in place.
 * Buffers that have not been allocated yet (still null before 'init') are skipped.
 */
function resetPreprocessor() {
  for (const stateBuffer of [prepStateC1, prepStateC2]) {
    if (stateBuffer) {
      stateBuffer.fill(0);
    }
  }
}
/**
 * Start a fresh encoder stream: replace the retained input frames with a new
 * empty array and rewind both frame counters to zero.
 */
function resetEncoder() {
  [encInputBuffer, encTotalInputFrames, encTotalOutputFrames] = [[], 0, 0];
}
/**
 * Run one chunk of audio samples through the streaming preprocessor session.
 *
 * The persistent buffers prepStateC1 / prepStateC2 are fed in as `state_c1` /
 * `state_c2`. They are then overwritten in place with the session's
 * `new_state_c1` / `new_state_c2` outputs, so the next call continues from them.
 *
 * @param {Float32Array} audioChunk - Samples for this chunk, sent as a [1, N] tensor.
 * @returns {Promise<{data: *, dims: number[]}>} The `data` and `dims` of the
 *   session's `features` output, passed through unchanged.
 */
async function processPreprocessor(audioChunk) {
  const audioTensor = new ort.Tensor('float32', audioChunk, [1, audioChunk.length]);
  const stateC1Tensor = new ort.Tensor('float32', prepStateC1, [1, 4, prepDim]);
  const stateC2Tensor = new ort.Tensor('float32', prepStateC2, [1, 4, prepC1]);

  const {
    new_state_c1: nextStateC1,
    new_state_c2: nextStateC2,
    features,
  } = await prepSession.run({
    'audio_chunk': audioTensor,
    'state_c1': stateC1Tensor,
    'state_c2': stateC2Tensor,
  });

  // Carry the returned state forward into the next chunk.
  prepStateC1.set(nextStateC1.data);
  prepStateC2.set(nextStateC2.data);

  const { data, dims } = features;
  return { data, dims };
}
/**
 * Copy `frameCount` consecutive frames of width encDim out of the flat,
 * frame-major `melData` and append each one to encInputBuffer as its own Float32Array.
 */
function appendEncoderFrames(melData, frameCount) {
  for (let f = 0; f < frameCount; f++) {
    const start = f * encDim;
    const frame = new Float32Array(encDim);
    frame.set(melData.slice(start, start + encDim));
    encInputBuffer.push(frame);
  }
}

/**
 * Concatenate the frames in encInputBuffer into one [frames * encDim] Float32Array.
 */
function stackEncoderFrames() {
  const stacked = new Float32Array(encInputBuffer.length * encDim);
  encInputBuffer.forEach((frame, f) => stacked.set(frame, f * encDim));
  return stacked;
}

/**
 * Feed newly arrived feature frames to the encoder and return the output frames
 * that are ready to emit.
 *
 * The encoder runs over the retained context frames plus the new frames. Only
 * the requested slice of its output is returned:
 *  - flush=false: output up to (total input - tailLatency), holding back the
 *    newest tailLatency frames;
 *  - flush=true: output up to the newest input frame, and also re-emit the last
 *    tailLatency frames already sent (recomputed now that newer frames are in
 *    the buffer).
 * Afterwards the buffer is trimmed to the most recent encContextSize frames.
 *
 * @param {ArrayLike<number>} melData - Flat frame-major values; frame f occupies
 *   [f * encDim, (f + 1) * encDim).
 * @param {number[]} melDims - Dimensions of melData; only melDims[1] (frame count) is read.
 * @param {boolean} [flush=true] - See above.
 * @returns {Promise<{data: Float32Array, dims: number[]}>} Output frames with
 *   dims [1, count, encDim]; count may be 0.
 * @throws {RangeError} If melData holds fewer than melDims[1] * encDim values,
 *   or the requested output frames are not covered by the buffer or the
 *   encoder output (previously these silently became NaN).
 */
async function processEncoder(melData, melDims, flush = true) {
  const newFrames = melDims[1];
  const neededValues = newFrames * encDim;
  if (melData.length < neededValues) {
    throw new RangeError(
      `processEncoder: expected ${neededValues} values for ${newFrames} frames, got ${melData.length}`
    );
  }

  appendEncoderFrames(melData, newFrames);
  encTotalInputFrames += newFrames;

  // Calculate output range
  const canOutput = flush
    ? encTotalInputFrames
    : Math.max(0, encTotalInputFrames - tailLatency);
  const outputFrom = flush
    ? Math.max(0, encTotalOutputFrames - tailLatency)
    : encTotalOutputFrames;
  const newOutputCount = canOutput - outputFrom;
  if (newOutputCount <= 0) {
    return { data: new Float32Array(0), dims: [1, 0, encDim] };
  }

  // Absolute frame index of encInputBuffer[0]. It is computed before awaiting the
  // session, so it describes exactly the buffer being fed in.
  const bufferFrames = encInputBuffer.length;
  const bufStartFrame = encTotalInputFrames - bufferFrames;
  const outputStart = outputFrom - bufStartFrame;
  if (outputStart < 0) {
    throw new RangeError(
      `processEncoder: frame ${outputFrom} is no longer buffered (buffer starts at ${bufStartFrame})`
    );
  }

  const results = await encSession.run({
    'input': new ort.Tensor('float32', stackEncoderFrames(), [1, bufferFrames, encDim])
  });
  const outputData = results.output.data;

  // Extract the subset of output corresponding to [outputFrom, canOutput).
  const startIdx = outputStart * encDim;
  const endIdx = startIdx + newOutputCount * encDim;
  if (endIdx > outputData.length) {
    throw new RangeError(
      `processEncoder: encoder returned ${outputData.length} values, need ${endIdx}`
    );
  }
  const resultData = new Float32Array(newOutputCount * encDim);
  resultData.set(outputData.slice(startIdx, endIdx));

  // Trim input buffer to context size. Index from the front: slice(-0) would
  // keep the whole buffer when encContextSize is 0.
  if (encInputBuffer.length > encContextSize) {
    encInputBuffer = encInputBuffer.slice(encInputBuffer.length - encContextSize);
  }
  encTotalOutputFrames = canOutput;
  return { data: resultData, dims: [1, newOutputCount, encDim] };
}
/**
 * Tail of the message-processing chain. Messages are handled strictly one at a
 * time, in arrival order.
 *
 * Every handler awaits ONNX Runtime, and an async onmessage lets the next
 * message start while the previous one is still suspended. Without this chain:
 *  - two 'audio' chunks could run the sessions concurrently against the same
 *    mutable state (prepStateC1/C2, encInputBuffer, frame counters);
 *  - 'segment_start' could reset that state mid-chunk;
 *  - 'segment_end' could be echoed before the segment's last 'features';
 *  - 'audio' could run before 'init' had created the sessions.
 */
let messageQueue = Promise.resolve();

/**
 * Handle one message from the main thread.
 *
 * @param {{type: string, data: object}} message
 *   'init'          data: { cfg, onnxUrl, modelName }. Loads both models, posts
 *                   'status' updates, then 'ready' (or 'error').
 *   'segment_start' data: { segmentId }. Resets streaming state and echoes back.
 *   'segment_end'   data: { segmentId }. Echoes back.
 *   'audio'         data: { segmentId, audio }. Runs preprocessor + encoder and
 *                   posts 'features' when any encoder frames were produced.
 *   Other types are ignored.
 */
async function handleMessage({ type, data }) {
  switch (type) {
    case 'init': {
      try {
        cfg = data.cfg;
        const onnxUrl = data.onnxUrl;
        const modelName = data.modelName;
        const dtype = 'fp32';
        // Frames processEncoder holds back (non-flush) or re-emits (flush).
        tailLatency = cfg.n_future * cfg.encoder_depth;
        // Initialize preprocessor
        self.postMessage({ type: 'status', message: 'Loading preprocessor...' });
        prepSession = await ort.InferenceSession.create(
          `${onnxUrl}/preprocessor_streaming_${modelName}_${dtype}.onnx`
        );
        prepDim = cfg.dim;
        prepC1 = 2 * cfg.dim;
        // Zero-initialised state, sized for the [1, 4, prepDim] and [1, 4, prepC1]
        // tensors built by processPreprocessor.
        prepStateC1 = new Float32Array(4 * cfg.dim);
        prepStateC2 = new Float32Array(4 * prepC1);
        // Initialize encoder
        self.postMessage({ type: 'status', message: 'Loading encoder...' });
        encSession = await ort.InferenceSession.create(
          `${onnxUrl}/encoder_${modelName}_${dtype}.onnx`
        );
        encDim = cfg.dim;
        encNPast = cfg.n_past;
        encNFuture = cfg.n_future;
        encEncoderDepth = cfg.encoder_depth;
        encContextSize = cfg.encoder_depth * (cfg.n_past + cfg.n_future);
        self.postMessage({ type: 'ready' });
      } catch (err) {
        // Non-Error throwables have no .message; stringify them instead of
        // sending `undefined` to the main thread.
        self.postMessage({
          type: 'error',
          message: err instanceof Error ? err.message : String(err)
        });
      }
      break;
    }
    case 'segment_start': {
      resetPreprocessor();
      resetEncoder();
      self.postMessage({
        type: 'segment_start',
        segmentId: data.segmentId
      });
      break;
    }
    case 'segment_end': {
      self.postMessage({
        type: 'segment_end',
        segmentId: data.segmentId
      });
      break;
    }
    case 'audio': {
      try {
        // Process through preprocessor
        const mel = await processPreprocessor(new Float32Array(data.audio));
        // Duration is for logging only; assumes 16 kHz input (TODO confirm against cfg).
        const audioMs = (data.audio.length / 16000 * 1000).toFixed(0);
        console.log(`Audio ${data.audio.length} samples (${audioMs}ms) → Mel ${mel.dims[1]} frames`);
        // flush=true: emit up to the newest frame and re-emit the last tailLatency frames
        const enc = await processEncoder(mel.data, mel.dims, true);
        console.log(`Mel ${mel.dims[1]} frames → Encoder ${enc.dims[1]} frames (accumulated: ${encTotalOutputFrames})`);
        if (enc.dims[1] > 0) {
          self.postMessage({
            type: 'features',
            segmentId: data.segmentId,
            features: enc.data,
            dims: enc.dims
          }, [enc.data.buffer]); // Transfer ownership
        }
      } catch (err) {
        // Best effort: a failed chunk is logged and dropped; the worker keeps running.
        console.error('Encoder error:', err);
      }
      break;
    }
    default:
      // Unknown message types are ignored.
      break;
  }
}

self.onmessage = function(e) {
  // Append to the chain so this message starts only after all earlier ones finish.
  messageQueue = messageQueue
    .then(() => handleMessage(e.data))
    .catch((err) => {
      // Keep the chain alive so later messages are still processed.
      console.error('Encoder worker failed to handle message:', err);
    });
};