// soprano-web-onnx / inference-worker.js
// Soprano 1.1 — KevinAHM (commit 9b19787)
// ONNX Runtime Web Worker (Classic Script)
// Announce startup to the console and the main thread, then pull in the
// onnxruntime-web bundle synchronously via importScripts.
console.log('Worker Script Starting (Classic)...');
self.postMessage({ type: 'status', status: 'Worker Thread Started', state: 'idle' });
try {
  const ORT_VERSION = '1.20.0';
  const ortUrl = `https://cdn.jsdelivr.net/npm/onnxruntime-web@${ORT_VERSION}/dist/ort.min.js`;
  importScripts(ortUrl);
} catch (loadError) {
  console.error('Failed to load ORT in worker:', loadError);
}
// Configuration
// Model artifact paths, relative to the worker's serving location.
const MODELS = {
backbone: './onnx/soprano_backbone_kv_fp32.onnx',
decoder: './onnx/soprano_decoder_int8.onnx',
tokenizer: './' // Tokenizer loading still needs context, we'll see if it works in worker or needs to be passed
};
// We need to import the Hugging Face tokenizer library appropriately for a worker.
// The main file import was dynamic. We'll try to do the same here.
// Note: Transformers.js usually works in workers.
// Tokens of decoder context kept on each side of an audio chunk (see the
// hidden-state windowing in generationLoop).
const RECEPTIVE_FIELD = 4;
// Audio samples the decoder emits per generated token (used to slice its output).
const TOKEN_SIZE = 2048;
// Output sample rate in Hz; used for chunk-duration metrics sent to the main thread.
const SAMPLE_RATE = 32000;
// State
let backboneSession = null; // ort.InferenceSession for the LM backbone
let decoderSession = null; // ort.InferenceSession for the audio decoder
let tokenizer = null; // transformers.js tokenizer instance
let isGenerating = false; // true while a generate request is streaming
let isReady = false; // set once sessions + tokenizer are all loaded
// FP16 Lookup Table: one float32 entry per possible half-precision bit pattern,
// so fp16 tensors can be decoded with a single indexed read per element.
let fp16Lookup = new Float32Array(65536);
let isFp16Backbone = false;
// Helpers
// Populate fp16Lookup by decoding every 16-bit pattern per IEEE 754 half
// precision: subnormals (exp 0), infinities/NaN (exp 31), and normals.
function initFp16Lookup() {
  for (let bits = 0; bits < 65536; bits++) {
    const sign = (bits & 0x8000) ? -1 : 1;
    const exponent = (bits >> 10) & 0x1F;
    const mantissa = bits & 0x03FF;
    let value;
    if (exponent === 0x1F) {
      // Exponent all-ones: NaN when mantissa is set, otherwise +/- infinity.
      value = mantissa ? NaN : sign * Infinity;
    } else if (exponent === 0) {
      // Subnormal range (includes +/- zero when mantissa is 0).
      value = sign * Math.pow(2, -14) * (mantissa / 1024);
    } else {
      value = sign * Math.pow(2, exponent - 15) * (1 + mantissa / 1024);
    }
    fp16Lookup[bits] = value;
  }
}
// ----------------------------------------------------------------------------
// Text Preprocessing (ported from onnx-streaming.js)
// ----------------------------------------------------------------------------
// The full text-preprocessing logic is duplicated here so the worker stays
// self-contained. Ideally these functions would live in a shared utils module
// imported by both files; keep the two copies in sync until then.
const ONES = ['', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten', 'eleven', 'twelve', 'thirteen', 'fourteen', 'fifteen', 'sixteen', 'seventeen', 'eighteen', 'nineteen'];
const TENS = ['', '', 'twenty', 'thirty', 'forty', 'fifty', 'sixty', 'seventy', 'eighty', 'ninety'];
const ORDINAL_ONES = ['', 'first', 'second', 'third', 'fourth', 'fifth', 'sixth', 'seventh', 'eighth', 'ninth', 'tenth', 'eleventh', 'twelfth', 'thirteenth', 'fourteenth', 'fifteenth', 'sixteenth', 'seventeenth', 'eighteenth', 'nineteenth'];
const ORDINAL_TENS = ['', '', 'twentieth', 'thirtieth', 'fortieth', 'fiftieth', 'sixtieth', 'seventieth', 'eightieth', 'ninetieth'];
// Spell a non-negative integer as English words.
// options.andword: word inserted after "hundred" (e.g. 'and' -> "one hundred and five").
// options.zero: word for 0 and for the padding digit in group-2 year style.
// options.group: 2 enables year-style reading for 1001..9999 ("nineteen oh five").
function numberToWords(num, options = {}) {
  const { andword = '', zero = 'zero', group = 0 } = options;
  if (num === 0) return zero;
  // 0..999 without scale words.
  const belowThousand = (n) => {
    if (n < 20) return ONES[n];
    if (n < 100) {
      const unit = n % 10;
      const tensWord = TENS[Math.floor(n / 10)];
      return unit ? `${tensWord} ${ONES[unit]}` : tensWord;
    }
    const head = `${ONES[Math.floor(n / 100)]} hundred`;
    const rest = n % 100;
    if (!rest) return head;
    const sep = andword ? ` ${andword} ` : ' ';
    return head + sep + belowThousand(rest);
  };
  // Full recursive conversion with thousand/million/billion scales.
  const convert = (n) => {
    if (n < 1000) return belowThousand(n);
    let scale = 1000;
    let label = 'thousand';
    if (n >= 1000000000) { scale = 1000000000; label = 'billion'; }
    else if (n >= 1000000) { scale = 1000000; label = 'million'; }
    const head = `${convert(Math.floor(n / scale))} ${label}`;
    const rest = n % scale;
    return rest ? `${head} ${convert(rest)}` : head;
  };
  if (group === 2 && num > 1000 && num < 10000) {
    // Year style: split into hundreds pairs ("1905" -> "nineteen oh five").
    const upper = Math.floor(num / 100);
    const lower = num % 100;
    if (lower === 0) return `${convert(upper)} hundred`;
    if (lower < 10) return `${convert(upper)} ${zero} ${ONES[lower]}`;
    return `${convert(upper)} ${convert(lower)}`;
  }
  return convert(num);
}
// Spell an integer as an English ordinal ("21" -> "twenty first").
function ordinalToWords(num) {
  if (num < 20) return ORDINAL_ONES[num] || `${numberToWords(num)}th`;
  if (num < 100) {
    const unit = num % 10;
    const tensIdx = (num - unit) / 10;
    return unit === 0 ? ORDINAL_TENS[tensIdx] : `${TENS[tensIdx]} ${ORDINAL_ONES[unit]}`;
  }
  // Above 99: take the cardinal form and fix up its final word.
  // Order matters: e.g. "twelve" must match 've' before the generic 'e' rule.
  const words = numberToWords(num);
  const suffixRules = [
    ['y', -1, 'ieth'],
    ['one', -3, 'first'],
    ['two', -3, 'second'],
    ['three', -5, 'third'],
    ['ve', -2, 'fth'],
    ['e', -1, 'th'],
    ['t', 0, 'h'],
  ];
  for (const [ending, cut, replacement] of suffixRules) {
    if (words.endsWith(ending)) {
      return (cut ? words.slice(0, cut) : words) + replacement;
    }
  }
  return words + 'th';
}
// ASCII transliteration table: Latin-1 accented letters and a few typographic
// punctuation marks (curly quotes, ellipsis, en/em dashes) mapped to ASCII.
// Characters not listed here are handled by NFD stripping in convertToAscii.
const UNICODE_MAP = {
'à': 'a', 'á': 'a', 'â': 'a', 'ã': 'a', 'ä': 'a', 'å': 'a', 'æ': 'ae', 'ç': 'c', 'è': 'e', 'é': 'e', 'ê': 'e', 'ë': 'e', 'ì': 'i', 'í': 'i', 'î': 'i', 'ï': 'i', 'ñ': 'n', 'ò': 'o', 'ó': 'o', 'ô': 'o', 'õ': 'o', 'ö': 'o', 'ø': 'o', 'ù': 'u', 'ú': 'u', 'û': 'u', 'ü': 'u', 'ý': 'y', 'ÿ': 'y', 'ß': 'ss', 'œ': 'oe', 'ð': 'd', 'þ': 'th', 'À': 'A', 'Á': 'A', 'Â': 'A', 'Ã': 'A', 'Ä': 'A', 'Å': 'A', 'Æ': 'AE', 'Ç': 'C', 'È': 'E', 'É': 'E', 'Ê': 'E', 'Ë': 'E', 'Ì': 'I', 'Í': 'I', 'Î': 'I', 'Ï': 'I', 'Ñ': 'N', 'Ò': 'O', 'Ó': 'O', 'Ô': 'O', 'Õ': 'O', 'Ö': 'O', 'Ø': 'O', 'Ù': 'U', 'Ú': 'U', 'Û': 'U', 'Ü': 'U', 'Ý': 'Y', '\u201C': '"', '\u201D': '"', '\u2018': "'", '\u2019': "'", '\u2026': '...', '\u2013': '-', '\u2014': '-'
};
// Transliterate text to ASCII: map known characters through UNICODE_MAP, then
// NFD-normalize and strip any remaining combining diacritics.
function convertToAscii(text) {
  let mapped = '';
  for (const ch of text) {
    mapped += UNICODE_MAP[ch] || ch;
  }
  return mapped.normalize('NFD').replace(/[\u0300-\u036f]/g, '');
}
// Dotted abbreviations, matched case-insensitively ("Mr." -> "mister").
// Order matters: "mrs." must be tried before "mr.".
const ABBREVIATIONS = [
[/\bmrs\./gi, 'misuss'], [/\bms\./gi, 'miss'], [/\bmr\./gi, 'mister'], [/\bdr\./gi, 'doctor'], [/\bst\./gi, 'saint'], [/\bco\./gi, 'company'], [/\bjr\./gi, 'junior'], [/\bmaj\./gi, 'major'], [/\bgen\./gi, 'general'], [/\bdrs\./gi, 'doctors'], [/\brev\./gi, 'reverend'], [/\blt\./gi, 'lieutenant'], [/\bhon\./gi, 'honorable'], [/\bsgt\./gi, 'sergeant'], [/\bcapt\./gi, 'captain'], [/\besq\./gi, 'esquire'], [/\bltd\./gi, 'limited'], [/\bcol\./gi, 'colonel'], [/\bft\./gi, 'fort']
];
// Case-sensitive acronym expansions; these must run before the pipeline
// lowercases the text (see cleanText), otherwise they never match.
const CASED_ABBREVIATIONS = [
[/\bTTS\b/g, 'text to speech'], [/\bHz\b/g, 'hertz'], [/\bkHz\b/g, 'kilohertz'], [/\bKBs\b/g, 'kilobytes'], [/\bKB\b/g, 'kilobyte'], [/\bMBs\b/g, 'megabytes'], [/\bMB\b/g, 'megabyte'], [/\bGBs\b/g, 'gigabytes'], [/\bGB\b/g, 'gigabyte'], [/\bTBs\b/g, 'terabytes'], [/\bTB\b/g, 'terabyte'], [/\bAPIs\b/g, "a p i's"], [/\bAPI\b/g, 'a p i'], [/\bCLIs\b/g, "c l i's"], [/\bCLI\b/g, 'c l i'], [/\bCPUs\b/g, "c p u's"], [/\bCPU\b/g, 'c p u'], [/\bGPUs\b/g, "g p u's"], [/\bGPU\b/g, 'g p u'], [/\bAve\b/g, 'avenue'], [/\betc\b/g, 'etcetera']
];
// Apply every abbreviation substitution in declaration order:
// case-insensitive dotted forms first, then case-sensitive acronyms.
function expandAbbreviations(text) {
  let result = text;
  for (const [pattern, expansion] of ABBREVIATIONS) {
    result = result.replace(pattern, expansion);
  }
  for (const [pattern, expansion] of CASED_ABBREVIATIONS) {
    result = result.replace(pattern, expansion);
  }
  return result;
}
// Patterns used by normalizeNumbers, applied in the order listed there.
// "#5" -> "number 5"
const NUM_PREFIX_RE = /#(\d)/g;
// "5K" / "2M" / "1B" / "3T" magnitude suffixes
const NUM_SUFFIX_RE = /(\d)([KMBT])/gi;
// Digit/letter boundaries, split in either direction ("4x" / "x4")
const NUM_LETTER_SPLIT_RE = /(\d)([a-z])|([a-z])(\d)/gi;
// Numbers containing thousands separators ("1,234,567")
const COMMA_NUMBER_RE = /(\d[\d,]+\d)/g;
// Slash- or dash-separated dates; guards reject longer slash runs (URLs)
const DATE_RE = /(^|[^/])(\d\d?[/-]\d\d?[/-]\d\d(?:\d\d)?)($|[^/])/g;
// 10-digit US-style phone numbers with common separators
const PHONE_NUMBER_RE = /\(?\d{3}\)?[-.\s]\d{3}[-.\s]?\d{4}/g;
// Clock times h:mm or h:mm:ss
const TIME_RE = /(\d\d?):(\d\d)(?::(\d\d))?/g;
// Currency amounts
const POUNDS_RE = /£([\d,]*\d+)/g;
const DOLLARS_RE = /\$([\d.,]*\d+)/g;
// Numbers with one or more decimal points
const DECIMAL_NUMBER_RE = /(\d+(?:\.\d+)+)/g;
// Arithmetic operators between digits
const MULTIPLY_RE = /(\d)\s?\*\s?(\d)/g;
const DIVIDE_RE = /(\d)\s?\/\s?(\d)/g;
const ADD_RE = /(\d)\s?\+\s?(\d)/g;
// Leading digit optional, so unary minus ("-5") also matches
const SUBTRACT_RE = /(\d)?\s?-\s?(\d)/g;
const FRACTION_RE = /(\d+)\/(\d+)/g;
// Ordinal suffixes ("1st", "22nd")
const ORDINAL_RE = /(\d+)(st|nd|rd|th)/gi;
// Any remaining bare integer
const NUMBER_RE = /\d+/g;
// Expand every numeric construct into spoken words. The substitution order is
// significant: money/dates/times/phones must be handled before the generic
// NUMBER_RE pass consumes their digits.
function normalizeNumbers(text) {
// "#5" -> "number 5"
text = text.replace(NUM_PREFIX_RE, (_, d) => `number ${d}`);
// Magnitude suffixes: "5K" -> "5 thousand", etc.
text = text.replace(NUM_SUFFIX_RE, (_, num, suffix) => {
const map = { k: 'thousand', m: 'million', b: 'billion', t: 'trillion' };
return `${num} ${map[suffix.toLowerCase()]}`;
});
// Insert spaces at digit/letter boundaries; two passes so adjacent
// alternating runs ("a1b2") are fully separated.
// NOTE(review): this runs BEFORE ORDINAL_RE, so "1st" becomes "1 st" and the
// ordinal rule below can never match plain text ordinals — confirm intended.
for (let i = 0; i < 2; i++) {
text = text.replace(NUM_LETTER_SPLIT_RE, (m, d1, l1, l2, d2) => {
if (d1 && l1) return `${d1} ${l1}`;
if (l2 && d2) return `${l2} ${d2}`;
return m;
});
}
// Strip thousands separators: "1,234" -> "1234".
text = text.replace(COMMA_NUMBER_RE, m => m.replace(/,/g, ''));
// Dates: read each digit group separated by the word "dash".
text = text.replace(DATE_RE, (_, pre, date, post) => pre + date.split(/[./-]/).join(' dash ') + post);
// Phone numbers: read as three comma-separated groups of spaced digits.
text = text.replace(PHONE_NUMBER_RE, m => {
const digits = m.replace(/\D/g, '');
return digits.length === 10 ? `${digits.slice(0, 3).split('').join(' ')}, ${digits.slice(3, 6).split('').join(' ')}, ${digits.slice(6).split('').join(' ')}` : m;
});
// Clock times: "3:00" -> "3 o'clock", "3:05" -> "3 oh 5", "15:30" -> "15 30".
// NOTE(review): in the h:mm:ss branch with h === 0, `res` already ends with
// the seconds and the final concatenation appends them again — verify.
text = text.replace(TIME_RE, (_, hours, minutes, seconds) => {
const h = parseInt(hours), m = parseInt(minutes), s = seconds ? parseInt(seconds) : 0;
if (!seconds) return m === 0 ? (h === 0 ? '0' : h > 12 ? `${hours} minutes` : `${hours} o'clock`) : minutes.startsWith('0') ? `${hours} oh ${minutes[1]}` : `${hours} ${minutes}`;
let res = '';
if (h !== 0) res = hours + ' ' + (m === 0 ? 'oh oh' : minutes.startsWith('0') ? `oh ${minutes[1]}` : minutes);
else if (m !== 0) res = minutes + ' ' + (s === 0 ? 'oh oh' : seconds.startsWith('0') ? `oh ${seconds[1]}` : seconds);
else res = seconds;
return res + ' ' + (s === 0 ? '' : seconds.startsWith('0') ? `oh ${seconds[1]}` : seconds);
});
// Currency: "£1,000" -> "1000 pounds"; dollars split into dollars and cents.
text = text.replace(POUNDS_RE, (_, amount) => `${amount.replace(/,/g, '')} pounds`);
text = text.replace(DOLLARS_RE, (_, amount) => {
const parts = amount.replace(/,/g, '').split('.');
const dollars = parseInt(parts[0]) || 0;
const cents = parts[1] ? parseInt(parts[1]) : 0;
if (dollars && cents) return `${dollars} ${dollars === 1 ? 'dollar' : 'dollars'}, ${cents} ${cents === 1 ? 'cent' : 'cents'}`;
if (dollars) return `${dollars} ${dollars === 1 ? 'dollar' : 'dollars'}`;
if (cents) return `${cents} ${cents === 1 ? 'cent' : 'cents'}`;
return 'zero dollars';
});
// Decimals: each character spaced out around the word "point".
// NOTE(review): the second split('') also spaces out the letters of the word
// "point" itself ("3.1" -> "3   p o i n t   1") — confirm this is intended.
text = text.replace(DECIMAL_NUMBER_RE, m => m.split('.').join(' point ').split('').join(' ')); // Simplified
// Arithmetic operators between digits.
text = text.replace(MULTIPLY_RE, '$1 times $2');
text = text.replace(DIVIDE_RE, '$1 over $2');
text = text.replace(ADD_RE, '$1 plus $2');
// Leading digit optional: also converts unary minus ("-5" -> " minus 5").
text = text.replace(SUBTRACT_RE, (_, a, b) => (a ? a : '') + ' minus ' + b);
text = text.replace(FRACTION_RE, '$1 over $2');
text = text.replace(ORDINAL_RE, (_, num) => ordinalToWords(parseInt(num)));
// Remaining bare integers; 1001..2999 get year-style reading where sensible.
text = text.replace(NUMBER_RE, m => {
const num = parseInt(m);
if (num > 1000 && num < 3000) {
if (num === 2000) return 'two thousand';
if (num > 2000 && num < 2010) return 'two thousand ' + numberToWords(num % 100);
if (num % 100 === 0) return numberToWords(Math.floor(num / 100)) + ' hundred';
return numberToWords(num, { zero: 'oh', group: 2 });
}
return numberToWords(num);
});
return text;
}
// Symbol -> spoken-word substitutions applied in order by
// expandSpecialCharacters (after normalizeSpecial has run).
const SPECIAL_CHARACTERS = [
[/@/g, ' at '], [/&/g, ' and '], [/%/g, ' percent '], [/:/g, '.'], [/;/g, ','], [/\+/g, ' plus '], [/\\/g, ' backslash '], [/~/g, ' about '], [/(^| )<3/g, ' heart '], [/<=/g, ' less than or equal to '], [/>=/g, ' greater than or equal to '], [/</g, ' less than '], [/>/g, ' greater than '], [/=/g, ' equals '], [/\//g, ' slash '], [/_/g, ' '],
];
// URL scheme prefix, spelled out letter by letter in normalizeSpecial.
const LINK_HEADER_RE = /https?:\/\//gi;
// " - " between two characters becomes a comma pause.
const DASH_RE = /(.) - (.)/g;
// Letters joined by a dot ("U.S") are read with the word "dot".
const DOT_RE = /([A-Z])\.([A-Z])/gi;
// A bracketed span plus (optionally) the single character following it.
const PARENTHESES_RE = /[\(\[\{][^\)\]\}]*[\)\]\}](.)?/g;
// Spoken-form rewrites for URLs, dashes, dotted letters and bracketed asides.
function normalizeSpecial(text) {
  // Spell out URL schemes so "https://" is read aloud.
  let result = text.replace(LINK_HEADER_RE, 'h t t p s colon slash slash ');
  // " - " becomes a comma pause; "A.B" becomes "A dot B".
  result = result.replace(DASH_RE, '$1, $2');
  result = result.replace(DOT_RE, '$1 dot $2');
  // Bracketed asides become comma-delimited clauses. If punctuation directly
  // follows the closing bracket, it replaces the trailing ", ".
  result = result.replace(PARENTHESES_RE, (match, trailing) => {
    let replaced = match.replace(/[\(\[\{]/g, ', ').replace(/[\)\]\}]/g, ', ');
    if (trailing && /[$.!?,]/.test(trailing)) {
      replaced = replaced.slice(0, -2) + trailing;
    }
    return replaced;
  });
  return result;
}
// Apply every symbol -> spoken-word substitution in declaration order.
function expandSpecialCharacters(text) {
  return SPECIAL_CHARACTERS.reduce(
    (acc, [pattern, spoken]) => acc.replace(pattern, spoken),
    text
  );
}
// Treat each input line as a sentence: trim it and guarantee it ends with
// terminal punctuation, then join all lines into a single spaced string.
function normalizeNewlines(text) {
  const sentences = [];
  for (const rawLine of text.split('\n')) {
    const trimmed = rawLine.trim();
    if (!trimmed) {
      sentences.push('');
      continue;
    }
    sentences.push(/[.!?]$/.test(trimmed) ? trimmed : trimmed + '.');
  }
  return sentences.join(' ');
}
// Restrict text to the character set the model was trained on.
function removeUnknownCharacters(text) {
  // First pass: drop anything outside the allowed alphabet.
  const allowed = text.replace(/[^A-Za-z !\$%&'\*\+,\-./0123456789<>\?_]/g, '');
  // Second pass: remove characters that earlier normalization steps use as
  // intermediates but that should never reach the model.
  return allowed.replace(/[<>\/_+]/g, '');
}
// Squash whitespace runs to single spaces, then pull sentence punctuation
// back onto the preceding word ("word ." -> "word.").
function collapseWhitespace(text) {
  const squashed = text.replace(/\s+/g, ' ');
  return squashed.replace(/ ([.\?!,])/g, '$1');
}
// Collapse runs of punctuation, letting stronger marks win (! beats . and ,;
// ? beats all three). Genuine ellipses are protected with a placeholder so
// they survive the dot-collapsing pass.
function dedupPunctuation(text) {
  let result = text.replace(/\.\.\.+/g, '[ELLIPSIS]');
  result = result.replace(/,+/g, ',');
  result = result.replace(/[.,]*\.[.,]*/g, '.');
  result = result.replace(/[.,!]*![.,!]*/g, '!');
  result = result.replace(/[.,!?]*\?[.,!?]*/g, '?');
  return result.replace(/\[ELLIPSIS\]/g, '...');
}
// Run the full text-normalization pipeline in its required order: numbers are
// expanded before symbol substitution, and lowercasing happens only after the
// case-sensitive acronym expansion inside expandAbbreviations.
function cleanText(text) {
  const pipeline = [
    convertToAscii,
    normalizeNewlines,
    normalizeNumbers,
    normalizeSpecial,
    expandAbbreviations,
    expandSpecialCharacters,
    (t) => t.toLowerCase(),
    removeUnknownCharacters,
    collapseWhitespace,
    dedupPunctuation,
  ];
  return pipeline.reduce((acc, step) => step(acc), text).trim();
}
// Clean the raw text, split it into sentences, merge sentences shorter than
// minLength into a neighbor, and wrap batches of batchSize sentences in the
// model's prompt delimiters.
function preprocessText(text, batchSize = 3, minLength = 30) {
  const cleaned = cleanText(text.trim());
  let sentences = cleaned.split(/(?<=[.!?])\s+/).filter((s) => s.trim());
  if (sentences.length === 0) {
    return cleaned ? [`[STOP][TEXT]${cleaned}[START]`] : [];
  }
  if (minLength > 0 && sentences.length > 1) {
    const merged = [];
    sentences.forEach((sentence, idx) => {
      if (sentence.length >= minLength) {
        merged.push(sentence);
      } else if (merged.length > 0) {
        // Attach a short sentence to the previously accepted one.
        merged[merged.length - 1] = `${merged[merged.length - 1]} ${sentence}`.trim();
      } else if (idx + 1 < sentences.length) {
        // Nothing accepted yet: fold it into the following sentence instead.
        sentences[idx + 1] = `${sentence} ${sentences[idx + 1]}`.trim();
      } else {
        // Short and alone: keep it as-is.
        merged.push(sentence);
      }
    });
    sentences = merged;
  }
  const prompts = [];
  for (let start = 0; start < sentences.length; start += batchSize) {
    const batchText = sentences.slice(start, start + batchSize).join(' ');
    prompts.push(`[STOP][TEXT]${batchText}[START]`);
  }
  return prompts;
}
// ----------------------------------------------------------------------------
// Worker Logic
// ----------------------------------------------------------------------------
// Message dispatcher for the main thread: 'load' initializes the models,
// 'generate' streams audio for a text, 'stop' aborts an in-flight stream.
self.onmessage = async (e) => {
  const { type, data } = e.data;
  console.log('Worker received message:', type);
  switch (type) {
    case 'load':
      try {
        await loadModels();
        postMessage({ type: 'loaded' });
      } catch (err) {
        postMessage({ type: 'error', error: err.toString() });
      }
      break;
    case 'generate':
      if (!isReady) {
        postMessage({ type: 'error', error: 'Models are not loaded yet.' });
        return;
      }
      // Ignore overlapping generate requests.
      if (isGenerating) return;
      try {
        await startGeneration(data.text);
      } catch (err) {
        console.error('Generation Error:', err);
        postMessage({ type: 'error', error: err.toString() });
      }
      break;
    case 'stop':
      // generationLoop polls this flag and bails out at the next token.
      isGenerating = false;
      postMessage({ type: 'status', status: 'Stopped', state: 'idle' });
      break;
  }
};
// Load both ONNX sessions and the tokenizer. Idempotent: returns immediately
// if the backbone session already exists. Posts status messages throughout.
// NOTE(review): this posts 'loaded' itself AND the onmessage 'load' handler
// posts 'loaded' again after it returns — the main thread receives it twice;
// confirm the listener is idempotent.
async function loadModels() {
if (backboneSession) return;
postMessage({ type: 'status', status: 'Loading models...', state: 'loading' });
// Configure WASM Paths to use EXACT same version as loader
const version = '1.20.0';
const cdnBase = `https://cdn.jsdelivr.net/npm/onnxruntime-web@${version}/dist/`;
ort.env.wasm.wasmPaths = cdnBase;
// Disable multi-threading if not in cross-origin isolated environment to avoid ERR_WASM_FILE_NOT_FOUND
if (!self.crossOriginIsolated) {
console.warn('Environment is not cross-origin isolated. Disabling WASM multi-threading.');
ort.env.wasm.numThreads = 1;
} else if (typeof navigator !== 'undefined' && navigator.hardwareConcurrency) {
// Cap WASM threads at 8 regardless of core count.
ort.env.wasm.numThreads = Math.min(navigator.hardwareConcurrency, 8);
}
try {
const backboneOptions = {
executionProviders: ['wasm'],
freeDimensionOverrides: { 'batch': 1 },
graphOptimizationLevel: 'all'
};
// Initialize FP16 Lookup
// The backbone's dtype is inferred from its filename.
isFp16Backbone = MODELS.backbone.includes('fp16');
if (isFp16Backbone) initFp16Lookup();
console.log('Loading Backbone...');
backboneSession = await ort.InferenceSession.create(MODELS.backbone, backboneOptions);
console.log('Loading Decoder...');
// Fetch the decoder manually so optional external weights can be attached.
const decoderBuf = await fetch(MODELS.decoder).then(r => {
if (!r.ok) throw new Error(`Failed to load decoder: ${r.statusText}`);
return r.arrayBuffer();
});
// External data check
// Best-effort: a missing "<model>.data" sidecar just means the weights are
// embedded in the .onnx file itself, so fetch failures are swallowed here.
let dataBuf = null;
try {
const dataUrl = MODELS.decoder + '.data';
const dataRes = await fetch(dataUrl);
if (dataRes.ok) {
dataBuf = await dataRes.arrayBuffer();
}
} catch (e) { }
const decoderOptions = {
executionProviders: ['wasm'],
freeDimensionOverrides: { 'batch': 1 }
};
if (dataBuf) {
// The path must match the filename recorded inside the ONNX model.
decoderOptions.externalData = [{
data: new Uint8Array(dataBuf),
path: MODELS.decoder.split('/').pop() + '.data'
}];
}
decoderSession = await ort.InferenceSession.create(new Uint8Array(decoderBuf), decoderOptions);
console.log('Loading Tokenizer...');
// NOTE(review): dynamic import() inside a classic (non-module) worker is not
// supported by every browser — verify this works on all deployment targets.
const transformers = await import('https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.0');
const { AutoTokenizer, env } = transformers;
// Serve tokenizer files from this origin only; never hit the HF Hub.
env.allowLocalModels = true;
env.allowRemoteModels = false;
env.localModelPath = new URL('.', self.location.href).pathname;
tokenizer = await AutoTokenizer.from_pretrained(MODELS.tokenizer, {
local_files_only: true
});
isReady = true;
postMessage({ type: 'status', status: 'Ready', state: 'idle' });
postMessage({ type: 'model_status', status: 'ready', text: 'Ready' });
postMessage({ type: 'loaded' });
} catch (err) {
console.error('Model load failed in worker:', err);
throw err;
}
}
/**
 * Run the full TTS pipeline for one request: preprocess the text into batched
 * prompts, tokenize each prompt, and stream audio through generationLoop.
 *
 * Fix: the body is wrapped in try/finally so `isGenerating` is always reset.
 * Previously an exception from the tokenizer or generationLoop propagated to
 * the onmessage handler with `isGenerating` left true, and the worker then
 * silently ignored every subsequent 'generate' message.
 *
 * @param {string} text - Raw user text to synthesize.
 */
async function startGeneration(text) {
  isGenerating = true;
  postMessage({ type: 'status', status: 'Generating...', state: 'running' });
  const prompts = preprocessText(text);
  const overallStartTime = performance.now();
  let isFirstBatch = true;
  let cumulativeSamples = 0;
  try {
    for (const prompt of prompts) {
      // A 'stop' message clears the flag mid-stream; bail before the next batch.
      if (!isGenerating) break;
      const { input_ids } = await tokenizer(prompt);
      // Note: tokenizer runs in worker, so input_ids.data is available
      const batchSamples = await generationLoop(input_ids.data, overallStartTime, isFirstBatch, cumulativeSamples);
      cumulativeSamples += batchSamples;
      isFirstBatch = false;
    }
    if (isGenerating) {
      postMessage({ type: 'stream_ended' });
      postMessage({ type: 'status', status: 'Finished', state: 'idle' });
    }
  } finally {
    // Always release the flag, even on error, so new requests are accepted.
    isGenerating = false;
  }
}
// Sampling Cache
// Scratch buffers reused by sample() to avoid per-token allocations; they are
// (re)allocated whenever the effective top-k size changes.
let _topKIndices = null; // token id held in each heap slot
let _topKScores = null; // temperature-scaled (and penalized) score per slot
let _topKOrder = null; // slot indices sorted by descending score
let _topKExp = null; // softmax weight per sorted slot
// Decoding hyper-parameters applied by sample().
const samplingParams = { temperature: 0.3, topK: 50, topP: 0.95, repetitionPenalty: 1.2 };
/**
 * Autoregressive generation for one tokenized prompt.
 *
 * Runs the backbone one token at a time with an ONNX KV cache, buffers the
 * last hidden state of each step, and periodically hands a sliding window of
 * hidden states to the decoder in a chained promise so decoding overlaps the
 * following backbone steps. Audio chunks are posted to the main thread with
 * their buffers transferred (zero-copy).
 *
 * Fix: `totalSamples` is now accumulated from each emitted audio chunk. It
 * was previously declared and returned but never incremented, so the caller's
 * cumulative-sample accounting always received 0.
 *
 * @param {BigInt64Array|ArrayLike<bigint>} promptTokens - Prompt token ids.
 * @param {number} startTime - performance.now() at overall generation start.
 * @param {boolean} isFirstBatch - True for the first prompt of a request.
 * @param {number} cumulativeSamples - Samples emitted by earlier batches
 *   (currently unused here; kept for interface stability).
 * @returns {Promise<number>} Total audio samples emitted by this batch.
 */
async function generationLoop(promptTokens, startTime, isFirstBatch = true, cumulativeSamples = 0) {
  const batch = 1;
  const numLayers = 17;
  const hiddenDim = 128; // KV-cache per-layer dim; the decoder input dim is 512 (see below)
  const promptLen = promptTokens.length;
  const vocabSize = 8192;
  const maxNewTokens = 512;
  // Mark prompt tokens as "seen" so the repetition penalty applies to them too.
  const seenTokenMask = new Uint8Array(vocabSize);
  for (let i = 0; i < promptTokens.length; i++) {
    const tid = Number(promptTokens[i]);
    if (tid >= 0 && tid < vocabSize) seenTokenMask[tid] = 1;
  }
  // Empty (seq-len 0) KV cache tensors matching the backbone's dtype.
  const kvType = isFp16Backbone ? 'float16' : 'float32';
  const kvData = isFp16Backbone ? new Uint16Array(0) : new Float32Array(0);
  let pastKeyValues = {};
  for (let i = 0; i < numLayers; i++) {
    pastKeyValues[`past_key_values.${i}.key`] = new ort.Tensor(kvType, kvData, [batch, 1, 0, hiddenDim]);
    pastKeyValues[`past_key_values.${i}.value`] = new ort.Tensor(kvType, kvData, [batch, 1, 0, hiddenDim]);
  }
  // Preallocated all-ones attention mask, reused via subarray views each step.
  const maxSeqLen = promptLen + maxNewTokens;
  const attentionMaskData = new BigInt64Array(maxSeqLen);
  attentionMaskData.fill(1n);
  let currentSeqLen = promptLen;
  // Single-token tensors reused for every decode step after the prefill.
  const nextInputIdData = new BigInt64Array(1);
  const nextPositionIdData = new BigInt64Array(1);
  const nextInputIdsTensor = new ort.Tensor('int64', nextInputIdData, [batch, 1]);
  const nextPositionIdsTensor = new ort.Tensor('int64', nextPositionIdData, [batch, 1]);
  let currentInputIds = new ort.Tensor('int64', BigInt64Array.from(promptTokens), [batch, promptLen]);
  let currentAttentionMask = new ort.Tensor('int64', attentionMaskData.subarray(0, currentSeqLen), [batch, currentSeqLen]);
  let currentPositionIds = new ort.Tensor('int64', BigInt64Array.from({ length: promptLen }, (_, i) => BigInt(i)), [batch, promptLen]);
  const hiddenStatesBuffer = [];
  let totalSamples = 0;
  const targetChunkSize = 8; // hidden states consumed per decoder call
  let chunkCounter = targetChunkSize;
  let firstChunk = true;
  // Pipelining: decoder runs are chained so they overlap later backbone steps.
  let lastDecoderPromise = Promise.resolve();
  let chunkBackboneTime = 0;
  if (isFirstBatch) {
    postMessage({ type: 'generation_started', data: { time: performance.now() } });
  }
  for (let i = 0; i < maxNewTokens; i++) {
    if (!isGenerating) break;
    // Yield to the worker event loop occasionally so 'stop' messages arrive.
    if (i % 4 === 0) {
      await new Promise(resolve => setTimeout(resolve, 0));
    }
    const inputs = {
      input_ids: currentInputIds,
      attention_mask: currentAttentionMask,
      position_ids: currentPositionIds,
      ...pastKeyValues
    };
    const bbStart = performance.now();
    const outputs = await backboneSession.run(inputs);
    chunkBackboneTime += (performance.now() - bbStart);
    // Output layout: [logits, key0, value0, ..., key16, value16, hidden_state].
    const backboneNames = backboneSession.outputNames;
    const logits = outputs[backboneNames[0]];
    const lastHiddenState = outputs[backboneNames[backboneNames.length - 1]];
    for (let j = 0; j < numLayers; j++) {
      pastKeyValues[`past_key_values.${j}.key`] = outputs[backboneNames[1 + j * 2]];
      pastKeyValues[`past_key_values.${j}.value`] = outputs[backboneNames[2 + j * 2]];
    }
    const nextTokenId = sample(logits, seenTokenMask);
    // Token id 3 terminates generation (presumably the stop/EOS id — confirm).
    const finished = (nextTokenId === 3n);
    const nextTokenIdNum = Number(nextTokenId);
    if (nextTokenIdNum >= 0 && nextTokenIdNum < vocabSize) seenTokenMask[nextTokenIdNum] = 1;
    // Extract the final position's hidden state, decoding fp16 bits if needed.
    const seqLen = lastHiddenState.dims[1];
    const hiddenDimSize = lastHiddenState.dims[2];
    const lastTokenStateRaw = lastHiddenState.data.subarray((seqLen - 1) * hiddenDimSize, seqLen * hiddenDimSize);
    let lastTokenState;
    if (lastTokenStateRaw instanceof Uint16Array) {
      lastTokenState = new Float32Array(hiddenDimSize);
      for (let j = 0; j < hiddenDimSize; j++) {
        lastTokenState[j] = fp16Lookup[lastTokenStateRaw[j]];
      }
    } else {
      lastTokenState = new Float32Array(lastTokenStateRaw);
    }
    // Buffer hidden states, skipping step 0 (NOTE(review): presumably it
    // corresponds to the prompt prefill rather than audio — confirm), and cap
    // the buffer at one chunk plus decoder context on both sides.
    if (i > 0 && !finished) {
      hiddenStatesBuffer.push(new Float32Array(lastTokenState));
      if (hiddenStatesBuffer.length > 2 * RECEPTIVE_FIELD + targetChunkSize) {
        hiddenStatesBuffer.splice(0, hiddenStatesBuffer.length - (2 * RECEPTIVE_FIELD + targetChunkSize));
      }
    }
    // Decode when a full chunk (plus context) is buffered, or at end of stream.
    if (finished || hiddenStatesBuffer.length >= RECEPTIVE_FIELD + targetChunkSize) {
      if (finished || chunkCounter === targetChunkSize) {
        const window = hiddenStatesBuffer.slice(-hiddenStatesBuffer.length); // full snapshot copy
        const currentWindowSize = window.length;
        // Transpose [window, 512] -> [512, window] for the decoder's layout.
        const decoderInput = new Float32Array(512 * currentWindowSize);
        for (let w = 0; w < currentWindowSize; w++) {
          for (let d = 0; d < 512; d++) {
            decoderInput[d * currentWindowSize + w] = window[w][d];
          }
        }
        const isLast = finished;
        const captureChunkCounter = chunkCounter;
        const captureFirstChunk = firstChunk;
        const captureBBTime = chunkBackboneTime;
        chunkBackboneTime = 0;
        // Chain the decoder run so it pipelines with subsequent backbone steps.
        lastDecoderPromise = lastDecoderPromise.then(async () => {
          const decStart = performance.now();
          const decoderOutputs = await decoderSession.run({
            [decoderSession.inputNames[0]]: new ort.Tensor('float32', decoderInput, [1, 512, currentWindowSize])
          });
          const decDuration = performance.now() - decStart;
          const audio = decoderOutputs[decoderSession.outputNames[0]].data;
          // Trim the receptive-field context off the decoder output.
          let audioChunk;
          if (isLast) {
            const startIdx = audio.length - (RECEPTIVE_FIELD + captureChunkCounter - 1) * TOKEN_SIZE + TOKEN_SIZE;
            audioChunk = audio.subarray(startIdx);
          } else {
            const startIdx = audio.length - (RECEPTIVE_FIELD + targetChunkSize) * TOKEN_SIZE + TOKEN_SIZE;
            const endIdx = audio.length - RECEPTIVE_FIELD * TOKEN_SIZE + TOKEN_SIZE;
            audioChunk = audio.subarray(startIdx, endIdx);
          }
          // FIX: count emitted samples BEFORE the buffer is transferred below.
          totalSamples += audioChunk.length;
          postMessage({
            type: 'audio_chunk',
            data: audioChunk,
            metrics: {
              bbTime: captureBBTime,
              decTime: decDuration,
              chunkDuration: audioChunk.length / SAMPLE_RATE,
              isFirst: captureFirstChunk && isFirstBatch
            }
          }, [audioChunk.buffer]); // Transferable: zero-copy handoff to main thread
        });
        firstChunk = false;
        chunkCounter = 0;
      }
      chunkCounter++;
    }
    if (finished) break;
    // Feed the sampled token back in, reusing the preallocated 1-token tensors.
    nextInputIdData[0] = nextTokenId;
    currentInputIds = nextInputIdsTensor;
    currentSeqLen += 1;
    currentAttentionMask = new ort.Tensor('int64', attentionMaskData.subarray(0, currentSeqLen), [1, currentSeqLen]);
    nextPositionIdData[0] = BigInt(currentSeqLen - 1);
    currentPositionIds = nextPositionIdsTensor;
  }
  // Drain in-flight decoder work so all chunks (and the sample count) are final.
  await lastDecoderPromise;
  return totalSamples;
}
/**
 * Sample the next token id from the final step of the backbone logits.
 *
 * Applies temperature scaling, an optional repetition penalty for tokens in
 * seenTokenMask, top-k filtering via a fixed-size min-heap, and top-p
 * (nucleus) truncation, then draws one token from the kept prefix.
 *
 * Fixes vs. previous revision:
 *  - removed a dead `Math.random()` draw whose result was unconditionally
 *    overwritten on the next line (one wasted RNG call per token);
 *  - the nucleus mass is captured during the top-p scan instead of being
 *    recomputed afterwards with `slice().reduce()` (one fewer allocation
 *    and O(k) pass per token).
 *
 * @param {object} logitsTensor - ort.Tensor shaped [batch, seq, vocab]; data
 *   is Float32Array, or Uint16Array holding fp16 bit patterns.
 * @param {Uint8Array} seenTokenMask - 1 for token ids already emitted/prompted.
 * @returns {bigint} Sampled token id.
 */
function sample(logitsTensor, seenTokenMask) {
  let rawData = logitsTensor.data;
  const vocabSize = logitsTensor.dims[2];
  const lastStepOffset = (logitsTensor.dims[1] - 1) * vocabSize;
  // Float32 view of the logits for the last sequence position.
  let data;
  if (rawData instanceof Uint16Array) {
    // fp16 model output: decode through the precomputed lookup table.
    data = new Float32Array(vocabSize);
    for (let j = 0; j < vocabSize; j++) {
      data[j] = fp16Lookup[rawData[lastStepOffset + j]];
    }
  } else {
    data = rawData.subarray ? rawData.subarray(lastStepOffset) : rawData.slice(lastStepOffset);
  }
  const { temperature, topK, topP, repetitionPenalty } = samplingParams;
  const useRepetitionPenalty = repetitionPenalty !== 1.0;
  const invTemperature = 1.0 / temperature;
  const k = Math.min(topK, vocabSize);
  if (k > 0 && k < vocabSize) {
    // (Re)allocate scratch buffers only when k changes.
    if (!_topKIndices || _topKIndices.length !== k) {
      _topKIndices = new Int32Array(k);
      _topKScores = new Float32Array(k);
      _topKExp = new Float64Array(k);
      _topKOrder = Array.from({ length: k }, (_, i) => i);
    }
    const heapIndices = _topKIndices;
    const heapScores = _topKScores;
    let heapSize = 0;
    // Min-heap of the k best scores: the root holds the worst survivor, so a
    // candidate only enters if it beats the root.
    for (let tokenId = 0; tokenId < vocabSize; tokenId++) {
      let s = data[tokenId] * invTemperature;
      // Repetition penalty: shrink positive scores, amplify negative ones.
      if (useRepetitionPenalty && seenTokenMask[tokenId]) s = s < 0 ? (s * repetitionPenalty) : (s / repetitionPenalty);
      if (heapSize < k) {
        // Sift-up insert while the heap is still filling.
        let pos = heapSize++;
        while (pos > 0) {
          const parent = (pos - 1) >> 1;
          if (heapScores[parent] <= s) break;
          heapScores[pos] = heapScores[parent];
          heapIndices[pos] = heapIndices[parent];
          pos = parent;
        }
        heapScores[pos] = s;
        heapIndices[pos] = tokenId;
      } else if (s > heapScores[0]) {
        // Replace the root (current minimum) and sift-down.
        let pos = 0;
        while (pos < (k >> 1)) {
          let left = (pos << 1) + 1;
          let right = left + 1;
          let smallest = left;
          if (right < k && heapScores[right] < heapScores[left]) smallest = right;
          if (heapScores[smallest] >= s) break;
          heapScores[pos] = heapScores[smallest];
          heapIndices[pos] = heapIndices[smallest];
          pos = smallest;
        }
        heapScores[pos] = s;
        heapIndices[pos] = tokenId;
      }
    }
    const expBuf = _topKExp;
    const order = _topKOrder;
    // Sort heap slots by score, descending.
    for (let i = 0; i < k; i++) order[i] = i;
    order.sort((a, b) => heapScores[b] - heapScores[a]);
    // Softmax weights, shifted by the max score for numerical stability.
    const maxScore = heapScores[order[0]];
    let sumExp = 0;
    for (let i = 0; i < k; i++) {
      const w = Math.exp(heapScores[order[i]] - maxScore);
      expBuf[i] = w;
      sumExp += w;
    }
    // Top-p truncation: keep the smallest prefix whose mass reaches topP,
    // remembering that prefix's mass so the draw below is renormalized to it.
    let keep = k;
    let keptMass = sumExp;
    if (topP < 1.0) {
      const threshold = topP * sumExp;
      let cumulative = 0;
      for (let i = 0; i < k; i++) {
        cumulative += expBuf[i];
        if (cumulative >= threshold) {
          keep = i + 1;
          keptMass = cumulative;
          break;
        }
      }
    }
    // Single RNG draw over the kept prefix (implicit renormalization).
    let r = Math.random() * keptMass;
    for (let i = 0; i < keep; i++) {
      r -= expBuf[i];
      if (r <= 0) return BigInt(heapIndices[order[i]]);
    }
    // Floating-point slack: fall back to the most likely token.
    return BigInt(heapIndices[order[0]]);
  }
  // NOTE(review): there is no full-vocabulary sampling path; if topK >= vocab
  // size this always returns token 0. Unreachable with the current params
  // (topK=50, vocab=8192), but worth fixing if the parameters become tunable.
  return 0n; // Fallback
}