Spaces:
Running
Running
// ONNX Runtime Web Worker (Classic Script)
console.log('Worker Script Starting (Classic)...');
self.postMessage({ type: 'status', status: 'Worker Thread Started', state: 'idle' });
// Pull in onnxruntime-web synchronously; classic workers cannot use static imports.
try {
    const ortVersion = '1.20.0';
    importScripts(`https://cdn.jsdelivr.net/npm/onnxruntime-web@${ortVersion}/dist/ort.min.js`);
} catch (e) {
    console.error('Failed to load ORT in worker:', e);
    // BUG FIX: surface the failure to the main thread; previously it was only
    // console-logged, so the UI would wait forever for a 'loaded' message.
    self.postMessage({ type: 'error', error: 'Failed to load onnxruntime-web: ' + e });
}
// Configuration
// Model asset locations, resolved relative to the page that created the worker.
const MODELS = {
    backbone: './onnx/soprano_backbone_kv_fp32.onnx', // autoregressive backbone with KV-cache I/O
    decoder: './onnx/soprano_decoder_int8.onnx',      // hidden-states -> waveform decoder (int8)
    tokenizer: './' // Tokenizer loading still needs context, we'll see if it works in worker or needs to be passed
};
// The Hugging Face tokenizer library is dynamically import()ed inside
// loadModels(); Transformers.js usually works in workers.
const RECEPTIVE_FIELD = 4;  // decoder context tokens trimmed from chunk edges (see generationLoop)
const TOKEN_SIZE = 2048;    // audio samples produced per hidden-state token
const SAMPLE_RATE = 32000;  // output audio sample rate in Hz
// State
let backboneSession = null; // ort.InferenceSession for the backbone
let decoderSession = null;  // ort.InferenceSession for the decoder
let tokenizer = null;       // HF AutoTokenizer instance (set in loadModels)
let isGenerating = false;   // cleared by the 'stop' message to abort generation
let isReady = false;        // true once sessions + tokenizer are all loaded
// FP16 Lookup Table: maps each 16-bit half-float bit pattern to its float32 value.
let fp16Lookup = new Float32Array(65536);
let isFp16Backbone = false;

// Helpers
/**
 * Populates fp16Lookup with the decoded value of every IEEE-754 half-precision
 * bit pattern (1 sign bit / 5 exponent bits / 10 mantissa bits).
 */
function initFp16Lookup() {
    for (let bits = 0; bits < 65536; bits++) {
        const sign = (bits & 0x8000) ? -1 : 1;
        const exponent = (bits >> 10) & 0x1F;
        const mantissa = bits & 0x03FF;
        let value;
        if (exponent === 0x1F) {
            // All-ones exponent: NaN when mantissa is set, otherwise +/-Infinity.
            value = mantissa !== 0 ? NaN : sign * Infinity;
        } else if (exponent === 0) {
            // Subnormal range (also produces signed zero when mantissa is 0).
            value = sign * Math.pow(2, -14) * (mantissa / 1024);
        } else {
            // Normal range.
            value = sign * Math.pow(2, exponent - 15) * (1 + mantissa / 1024);
        }
        fp16Lookup[bits] = value;
    }
}
// ----------------------------------------------------------------------------
// Text Preprocessing (Ported from onnx-streaming.js)
// ----------------------------------------------------------------------------
// The full text preprocessing pipeline is duplicated here so the worker stays
// self-contained. Ideally these helpers would live in a shared utils module
// used by both the worker and the main thread; TODO: deduplicate.
const ONES = ['', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten', 'eleven', 'twelve', 'thirteen', 'fourteen', 'fifteen', 'sixteen', 'seventeen', 'eighteen', 'nineteen'];
const TENS = ['', '', 'twenty', 'thirty', 'forty', 'fifty', 'sixty', 'seventy', 'eighty', 'ninety'];
const ORDINAL_ONES = ['', 'first', 'second', 'third', 'fourth', 'fifth', 'sixth', 'seventh', 'eighth', 'ninth', 'tenth', 'eleventh', 'twelfth', 'thirteenth', 'fourteenth', 'fifteenth', 'sixteenth', 'seventeenth', 'eighteenth', 'nineteenth'];
const ORDINAL_TENS = ['', '', 'twentieth', 'thirtieth', 'fortieth', 'fiftieth', 'sixtieth', 'seventieth', 'eightieth', 'ninetieth'];
/**
 * Spells out an integer in English words.
 * @param {number} num - value to spell (0 returns the `zero` option).
 * @param {object} [options]
 * @param {string} [options.andword=''] - word inserted after "hundred" (e.g. 'and').
 * @param {string} [options.zero='zero'] - word used for 0 and for the "oh" in year-style reading.
 * @param {number} [options.group=0] - 2 enables year-style reading for 1001..9999 (e.g. 1984 -> "nineteen eighty four").
 * @returns {string} the spelled-out number.
 */
function numberToWords(num, options = {}) {
    const { andword = '', zero = 'zero', group = 0 } = options;
    if (num === 0) return zero;
    const spell = (n) => {
        if (n < 20) return ONES[n];
        if (n < 100) {
            const unit = n % 10;
            const tensWord = TENS[Math.floor(n / 10)];
            return unit ? `${tensWord} ${ONES[unit]}` : tensWord;
        }
        if (n < 1000) {
            const rest = n % 100;
            const head = `${ONES[Math.floor(n / 100)]} hundred`;
            if (!rest) return head;
            const joiner = andword ? ` ${andword} ` : ' ';
            return head + joiner + spell(rest);
        }
        // Largest applicable scale first: billions, then millions, then thousands.
        const scales = [
            [1000000000, 'billion'],
            [1000000, 'million'],
            [1000, 'thousand'],
        ];
        for (const [scale, word] of scales) {
            if (n >= scale) {
                const rest = n % scale;
                return `${spell(Math.floor(n / scale))} ${word}` + (rest ? ` ${spell(rest)}` : '');
            }
        }
        return ONES[n]; // unreachable: every n >= 1000 matched a scale above
    };
    // Year-style reading, e.g. 1905 -> "nineteen oh five", 1900 -> "nineteen hundred".
    if (group === 2 && num > 1000 && num < 10000) {
        const high = Math.floor(num / 100);
        const low = num % 100;
        if (low === 0) return `${spell(high)} hundred`;
        if (low < 10) return `${spell(high)} ${zero === 'oh' ? 'oh' : zero} ${ONES[low]}`;
        return `${spell(high)} ${spell(low)}`;
    }
    return spell(num);
}
/**
 * Converts an integer to its English ordinal word form ("first", "twenty first",
 * "one hundredth", ...). Falls back to cardinal + suffix rewriting for >= 100.
 * @param {number} num - value to convert.
 * @returns {string} the ordinal in words.
 */
function ordinalToWords(num) {
    if (num < 20) {
        const direct = ORDINAL_ONES[num];
        return direct ? direct : numberToWords(num) + 'th';
    }
    if (num < 100) {
        const tensDigit = Math.floor(num / 10);
        const onesDigit = num % 10;
        return onesDigit === 0 ? ORDINAL_TENS[tensDigit] : `${TENS[tensDigit]} ${ORDINAL_ONES[onesDigit]}`;
    }
    const cardinal = numberToWords(num);
    // Suffix rewrite rules, checked in order: [ending, chars to drop, replacement].
    const rules = [
        ['y', 1, 'ieth'],
        ['one', 3, 'first'],
        ['two', 3, 'second'],
        ['three', 5, 'third'],
        ['ve', 2, 'fth'],
        ['e', 1, 'th'],
        ['t', 0, 'h'],
    ];
    for (const [ending, drop, replacement] of rules) {
        if (cardinal.endsWith(ending)) {
            return (drop ? cardinal.slice(0, -drop) : cardinal) + replacement;
        }
    }
    return cardinal + 'th';
}
const UNICODE_MAP = {
'à': 'a', 'á': 'a', 'â': 'a', 'ã': 'a', 'ä': 'a', 'å': 'a', 'æ': 'ae', 'ç': 'c', 'è': 'e', 'é': 'e', 'ê': 'e', 'ë': 'e', 'ì': 'i', 'í': 'i', 'î': 'i', 'ï': 'i', 'ñ': 'n', 'ò': 'o', 'ó': 'o', 'ô': 'o', 'õ': 'o', 'ö': 'o', 'ø': 'o', 'ù': 'u', 'ú': 'u', 'û': 'u', 'ü': 'u', 'ý': 'y', 'ÿ': 'y', 'ß': 'ss', 'œ': 'oe', 'ð': 'd', 'þ': 'th', 'À': 'A', 'Á': 'A', 'Â': 'A', 'Ã': 'A', 'Ä': 'A', 'Å': 'A', 'Æ': 'AE', 'Ç': 'C', 'È': 'E', 'É': 'E', 'Ê': 'E', 'Ë': 'E', 'Ì': 'I', 'Í': 'I', 'Î': 'I', 'Ï': 'I', 'Ñ': 'N', 'Ò': 'O', 'Ó': 'O', 'Ô': 'O', 'Õ': 'O', 'Ö': 'O', 'Ø': 'O', 'Ù': 'U', 'Ú': 'U', 'Û': 'U', 'Ü': 'U', 'Ý': 'Y', '\u201C': '"', '\u201D': '"', '\u2018': "'", '\u2019': "'", '\u2026': '...', '\u2013': '-', '\u2014': '-'
};
/**
 * Transliterates text to plain ASCII: known characters go through UNICODE_MAP
 * (accents, smart quotes, dashes, ellipsis), then any remaining accents are
 * stripped via NFD decomposition + combining-mark removal.
 * @param {string} text - input text.
 * @returns {string} ASCII-folded text.
 */
function convertToAscii(text) {
    let mapped = '';
    for (const ch of text) {
        mapped += UNICODE_MAP[ch] || ch;
    }
    return mapped.normalize('NFD').replace(/[\u0300-\u036f]/g, '');
}
// Dot-terminated title/rank abbreviations, matched case-insensitively.
const ABBREVIATIONS = [
[/\bmrs\./gi, 'misuss'], [/\bms\./gi, 'miss'], [/\bmr\./gi, 'mister'], [/\bdr\./gi, 'doctor'], [/\bst\./gi, 'saint'], [/\bco\./gi, 'company'], [/\bjr\./gi, 'junior'], [/\bmaj\./gi, 'major'], [/\bgen\./gi, 'general'], [/\bdrs\./gi, 'doctors'], [/\brev\./gi, 'reverend'], [/\blt\./gi, 'lieutenant'], [/\bhon\./gi, 'honorable'], [/\bsgt\./gi, 'sergeant'], [/\bcapt\./gi, 'captain'], [/\besq\./gi, 'esquire'], [/\bltd\./gi, 'limited'], [/\bcol\./gi, 'colonel'], [/\bft\./gi, 'fort']
];
// Case-sensitive tech acronyms/units (plural forms listed before singular).
const CASED_ABBREVIATIONS = [
[/\bTTS\b/g, 'text to speech'], [/\bHz\b/g, 'hertz'], [/\bkHz\b/g, 'kilohertz'], [/\bKBs\b/g, 'kilobytes'], [/\bKB\b/g, 'kilobyte'], [/\bMBs\b/g, 'megabytes'], [/\bMB\b/g, 'megabyte'], [/\bGBs\b/g, 'gigabytes'], [/\bGB\b/g, 'gigabyte'], [/\bTBs\b/g, 'terabytes'], [/\bTB\b/g, 'terabyte'], [/\bAPIs\b/g, "a p i's"], [/\bAPI\b/g, 'a p i'], [/\bCLIs\b/g, "c l i's"], [/\bCLI\b/g, 'c l i'], [/\bCPUs\b/g, "c p u's"], [/\bCPU\b/g, 'c p u'], [/\bGPUs\b/g, "g p u's"], [/\bGPU\b/g, 'g p u'], [/\bAve\b/g, 'avenue'], [/\betc\b/g, 'etcetera']
];
/**
 * Replaces known abbreviations and acronyms with their spoken forms.
 * Applies the case-insensitive list first, then the case-sensitive one.
 * @param {string} text - input text.
 * @returns {string} text with abbreviations expanded.
 */
function expandAbbreviations(text) {
    let result = text;
    for (const [pattern, spoken] of ABBREVIATIONS) {
        result = result.replace(pattern, spoken);
    }
    for (const [pattern, spoken] of CASED_ABBREVIATIONS) {
        result = result.replace(pattern, spoken);
    }
    return result;
}
// --- Regexes consumed by normalizeNumbers(), applied in roughly this order ---
const NUM_PREFIX_RE = /#(\d)/g;           // "#1" -> "number 1"
const NUM_SUFFIX_RE = /(\d)([KMBT])/gi;   // "5K"/"2M"/... magnitude suffixes
const NUM_LETTER_SPLIT_RE = /(\d)([a-z])|([a-z])(\d)/gi; // insert space at digit/letter boundaries
const COMMA_NUMBER_RE = /(\d[\d,]+\d)/g;  // comma-grouped numbers like "1,234"
const DATE_RE = /(^|[^/])(\d\d?[/-]\d\d?[/-]\d\d(?:\d\d)?)($|[^/])/g; // "12/31/1999"-style dates (not part of longer slash runs)
const PHONE_NUMBER_RE = /\(?\d{3}\)?[-.\s]\d{3}[-.\s]?\d{4}/g; // US-style 10-digit phone numbers
const TIME_RE = /(\d\d?):(\d\d)(?::(\d\d))?/g; // "hh:mm" or "hh:mm:ss"
const POUNDS_RE = /£([\d,]*\d+)/g;        // GBP amounts
const DOLLARS_RE = /\$([\d.,]*\d+)/g;     // USD amounts (with optional cents)
const DECIMAL_NUMBER_RE = /(\d+(?:\.\d+)+)/g; // decimal numbers like "3.14" (or "1.2.3")
const MULTIPLY_RE = /(\d)\s?\*\s?(\d)/g;  // "2*3" -> "2 times 3"
const DIVIDE_RE = /(\d)\s?\/\s?(\d)/g;    // "6/2" -> "6 over 2"
const ADD_RE = /(\d)\s?\+\s?(\d)/g;       // "1+2" -> "1 plus 2"
const SUBTRACT_RE = /(\d)?\s?-\s?(\d)/g;  // minuend optional, so "-5" also matches
const FRACTION_RE = /(\d+)\/(\d+)/g;      // remaining fractions -> "over"
const ORDINAL_RE = /(\d+)(st|nd|rd|th)/gi; // "1st", "22nd", ... (must run before NUMBER_RE)
const NUMBER_RE = /\d+/g;                 // any digit run left after the rules above
/**
 * Rewrites digits, currency, times, dates, phone numbers, arithmetic symbols
 * and ordinals into speakable words. The replacement order matters: grouping
 * commas are stripped before currency handling, and ORDINAL_RE must run
 * before the generic NUMBER_RE sweep at the end.
 * @param {string} text - cleaned ASCII text.
 * @returns {string} text with numeric content spelled out.
 */
function normalizeNumbers(text) {
    // "#1" -> "number 1".
    text = text.replace(NUM_PREFIX_RE, (_, d) => `number ${d}`);
    // "5K" -> "5 thousand", "2M" -> "2 million", etc.
    text = text.replace(NUM_SUFFIX_RE, (_, num, suffix) => {
        const map = { k: 'thousand', m: 'million', b: 'billion', t: 'trillion' };
        return `${num} ${map[suffix.toLowerCase()]}`;
    });
    // Two passes so alternating runs like "a1b2" get fully separated.
    for (let i = 0; i < 2; i++) {
        text = text.replace(NUM_LETTER_SPLIT_RE, (m, d1, l1, l2, d2) => {
            if (d1 && l1) return `${d1} ${l1}`;
            if (l2 && d2) return `${l2} ${d2}`;
            return m;
        });
    }
    // Strip grouping commas: "1,234" -> "1234".
    text = text.replace(COMMA_NUMBER_RE, m => m.replace(/,/g, ''));
    // NOTE(review): dates are read with "dash" regardless of the original
    // separator (slash, dot or hyphen) — confirm this is the intended voicing.
    text = text.replace(DATE_RE, (_, pre, date, post) => pre + date.split(/[./-]/).join(' dash ') + post);
    // 10-digit phone numbers are read digit-by-digit in 3/3/4 groups.
    text = text.replace(PHONE_NUMBER_RE, m => {
        const digits = m.replace(/\D/g, '');
        return digits.length === 10 ? `${digits.slice(0, 3).split('').join(' ')}, ${digits.slice(3, 6).split('').join(' ')}, ${digits.slice(6).split('').join(' ')}` : m;
    });
    // Clock times. NOTE(review): the h > 12 branch turns e.g. "13:00" into
    // "13 minutes", which looks suspicious — verify intended behavior.
    text = text.replace(TIME_RE, (_, hours, minutes, seconds) => {
        const h = parseInt(hours), m = parseInt(minutes), s = seconds ? parseInt(seconds) : 0;
        if (!seconds) return m === 0 ? (h === 0 ? '0' : h > 12 ? `${hours} minutes` : `${hours} o'clock`) : minutes.startsWith('0') ? `${hours} oh ${minutes[1]}` : `${hours} ${minutes}`;
        let res = '';
        if (h !== 0) res = hours + ' ' + (m === 0 ? 'oh oh' : minutes.startsWith('0') ? `oh ${minutes[1]}` : minutes);
        else if (m !== 0) res = minutes + ' ' + (s === 0 ? 'oh oh' : seconds.startsWith('0') ? `oh ${seconds[1]}` : seconds);
        else res = seconds;
        return res + ' ' + (s === 0 ? '' : seconds.startsWith('0') ? `oh ${seconds[1]}` : seconds);
    });
    // Currency: pounds (whole amounts) and dollars (with optional cents).
    text = text.replace(POUNDS_RE, (_, amount) => `${amount.replace(/,/g, '')} pounds`);
    text = text.replace(DOLLARS_RE, (_, amount) => {
        const parts = amount.replace(/,/g, '').split('.');
        const dollars = parseInt(parts[0]) || 0;
        const cents = parts[1] ? parseInt(parts[1]) : 0;
        if (dollars && cents) return `${dollars} ${dollars === 1 ? 'dollar' : 'dollars'}, ${cents} ${cents === 1 ? 'cent' : 'cents'}`;
        if (dollars) return `${dollars} ${dollars === 1 ? 'dollar' : 'dollars'}`;
        if (cents) return `${cents} ${cents === 1 ? 'cent' : 'cents'}`;
        return 'zero dollars';
    });
    // Decimals are read "point" then every character space-separated
    // (deliberately simplified, per the original implementation).
    text = text.replace(DECIMAL_NUMBER_RE, m => m.split('.').join(' point ').split('').join(' '));
    // Arithmetic operators between digits.
    text = text.replace(MULTIPLY_RE, '$1 times $2');
    text = text.replace(DIVIDE_RE, '$1 over $2');
    text = text.replace(ADD_RE, '$1 plus $2');
    text = text.replace(SUBTRACT_RE, (_, a, b) => (a ? a : '') + ' minus ' + b);
    text = text.replace(FRACTION_RE, '$1 over $2');
    // "1st"/"2nd"/... -> ordinal words (must precede the plain-number sweep).
    text = text.replace(ORDINAL_RE, (_, num) => ordinalToWords(parseInt(num)));
    // Remaining digit runs -> cardinal words, with year-style reading for 1001-2999.
    text = text.replace(NUMBER_RE, m => {
        const num = parseInt(m);
        if (num > 1000 && num < 3000) {
            if (num === 2000) return 'two thousand';
            if (num > 2000 && num < 2010) return 'two thousand ' + numberToWords(num % 100);
            if (num % 100 === 0) return numberToWords(Math.floor(num / 100)) + ' hundred';
            return numberToWords(num, { zero: 'oh', group: 2 });
        }
        return numberToWords(num);
    });
    return text;
}
// Symbol -> spoken-word substitutions, applied in order (longest patterns first
// where overlap matters, e.g. "<=" before "<").
const SPECIAL_CHARACTERS = [
[/@/g, ' at '], [/&/g, ' and '], [/%/g, ' percent '], [/:/g, '.'], [/;/g, ','], [/\+/g, ' plus '], [/\\/g, ' backslash '], [/~/g, ' about '], [/(^| )<3/g, ' heart '], [/<=/g, ' less than or equal to '], [/>=/g, ' greater than or equal to '], [/</g, ' less than '], [/>/g, ' greater than '], [/=/g, ' equals '], [/\//g, ' slash '], [/_/g, ' '],
];
const LINK_HEADER_RE = /https?:\/\//gi;
const DASH_RE = /(.) - (.)/g;
const DOT_RE = /([A-Z])\.([A-Z])/gi;
const PARENTHESES_RE = /[\(\[\{][^\)\]\}]*[\)\]\}](.)?/g;
/**
 * Rewrites URL prefixes, spaced dashes, letter-dot-letter sequences and
 * bracketed asides into speakable punctuation.
 * @param {string} text - input text.
 * @returns {string} normalized text.
 */
function normalizeSpecial(text) {
    let out = text.replace(LINK_HEADER_RE, 'h t t p s colon slash slash ');
    out = out.replace(DASH_RE, '$1, $2');
    out = out.replace(DOT_RE, '$1 dot $2');
    // Brackets become comma pauses; trailing punctuation re-attaches to the text.
    out = out.replace(PARENTHESES_RE, (match, trailing) => {
        let spoken = match.replace(/[\(\[\{]/g, ', ').replace(/[\)\]\}]/g, ', ');
        if (trailing && /[$.!?,]/.test(trailing)) {
            spoken = spoken.slice(0, -2) + trailing;
        }
        return spoken;
    });
    return out;
}
/**
 * Replaces special symbols (@, &, %, math operators, ...) with spoken words.
 * @param {string} text - input text.
 * @returns {string} expanded text.
 */
function expandSpecialCharacters(text) {
    let out = text;
    for (const [pattern, spoken] of SPECIAL_CHARACTERS) {
        out = out.replace(pattern, spoken);
    }
    return out;
}
/**
 * Flattens multi-line text into one line: trims each line, appends a period
 * to lines lacking terminal punctuation, and joins with spaces.
 * @param {string} text - possibly multi-line input.
 * @returns {string} single-line text.
 */
function normalizeNewlines(text) {
    const pieces = [];
    for (const rawLine of text.split('\n')) {
        const line = rawLine.trim();
        if (!line) {
            pieces.push('');
            continue;
        }
        pieces.push(/[.!?]$/.test(line) ? line : line + '.');
    }
    return pieces.join(' ');
}
/**
 * Strips characters outside the model's supported set, then removes the
 * leftover symbols (<, >, /, _, +) that earlier normalization stages consume.
 * @param {string} text - lowercased, normalized text.
 * @returns {string} filtered text.
 */
function removeUnknownCharacters(text) {
    const allowed = text.replace(/[^A-Za-z !\$%&'\*\+,\-./0123456789<>\?_]/g, '');
    return allowed.replace(/[<>\/_+]/g, '');
}
/**
 * Collapses whitespace runs to single spaces and removes the space left
 * before punctuation by earlier substitutions.
 * @param {string} text - input text.
 * @returns {string} tidied text.
 */
function collapseWhitespace(text) {
    const squeezed = text.replace(/\s+/g, ' ');
    return squeezed.replace(/ ([.\?!,])/g, '$1');
}
/**
 * Collapses runs of punctuation to a single terminator with precedence
 * '?' > '!' > '.', while preserving genuine ellipses ("..." and longer).
 * @param {string} text - input text.
 * @returns {string} text with deduplicated punctuation.
 */
function dedupPunctuation(text) {
    return text
        .replace(/\.\.\.+/g, '[ELLIPSIS]')       // shield real ellipses from the '.' rule
        .replace(/,+/g, ',')                     // ",," -> ","
        .replace(/[.,]*\.[.,]*/g, '.')           // any dot/comma mix with a '.' -> "."
        .replace(/[.,!]*![.,!]*/g, '!')          // '!' wins over '.' and ','
        .replace(/[.,!?]*\?[.,!?]*/g, '?')       // '?' wins over everything
        .replace(/\[ELLIPSIS\]/g, '...');        // restore ellipses
}
/**
 * Full text-cleaning pipeline: ASCII folding, newline flattening, number and
 * symbol normalization, abbreviation expansion, lowercasing, character
 * filtering, whitespace collapse and punctuation dedup — in that order.
 * @param {string} text - raw user text.
 * @returns {string} cleaned, speakable text.
 */
function cleanText(text) {
    const pipeline = [
        convertToAscii,
        normalizeNewlines,
        normalizeNumbers,
        normalizeSpecial,
        expandAbbreviations,
        expandSpecialCharacters,
        (t) => t.toLowerCase(),
        removeUnknownCharacters,
        collapseWhitespace,
        dedupPunctuation,
    ];
    let result = text;
    for (const stage of pipeline) {
        result = stage(result);
    }
    return result.trim();
}
/**
 * Cleans the text, splits it into sentences, merges sentences shorter than
 * `minLength` into their neighbors, and wraps groups of `batchSize` sentences
 * in the model's prompt template.
 * @param {string} text - raw user text.
 * @param {number} [batchSize=3] - sentences per prompt.
 * @param {number} [minLength=30] - minimum sentence length before merging.
 * @returns {string[]} prompt strings ("[STOP][TEXT]...[START]").
 */
function preprocessText(text, batchSize = 3, minLength = 30) {
    const cleaned = cleanText(text.trim());
    let sentences = cleaned.split(/(?<=[.!?])\s+/).filter((s) => s.trim());
    if (sentences.length === 0) {
        return cleaned ? [`[STOP][TEXT]${cleaned}[START]`] : [];
    }
    if (minLength > 0 && sentences.length > 1) {
        const merged = [];
        for (let idx = 0; idx < sentences.length; idx++) {
            const sentence = sentences[idx];
            if (sentence.length >= minLength) {
                merged.push(sentence);
            } else if (merged.length > 0) {
                // Short sentence: glue onto the previous kept sentence.
                merged[merged.length - 1] = `${merged[merged.length - 1]} ${sentence}`.trim();
            } else if (idx + 1 < sentences.length) {
                // Nothing kept yet: prepend onto the upcoming sentence.
                sentences[idx + 1] = `${sentence} ${sentences[idx + 1]}`.trim();
            } else {
                // Lone short sentence with no neighbors: keep as-is.
                merged.push(sentence);
            }
        }
        sentences = merged;
    }
    const prompts = [];
    for (let start = 0; start < sentences.length; start += batchSize) {
        const group = sentences.slice(start, start + batchSize).join(' ');
        prompts.push(`[STOP][TEXT]${group}[START]`);
    }
    return prompts;
}
| // ---------------------------------------------------------------------------- | |
| // Worker Logic | |
| // ---------------------------------------------------------------------------- | |
/**
 * Main-thread command dispatcher. Handles:
 *  - 'load': loads models/tokenizer, replies 'loaded' or 'error'.
 *  - 'generate': streams audio for data.text (ignored while already generating).
 *  - 'stop': flags the generation loop to halt.
 */
self.onmessage = async (e) => {
    const { type, data } = e.data;
    console.log('Worker received message:', type);
    switch (type) {
        case 'load':
            try {
                await loadModels();
                postMessage({ type: 'loaded' });
            } catch (err) {
                postMessage({ type: 'error', error: err.toString() });
            }
            break;
        case 'generate':
            if (!isReady) {
                postMessage({ type: 'error', error: 'Models are not loaded yet.' });
                return;
            }
            if (isGenerating) return; // one generation at a time
            try {
                await startGeneration(data.text);
            } catch (err) {
                console.error('Generation Error:', err);
                postMessage({ type: 'error', error: err.toString() });
            }
            break;
        case 'stop':
            isGenerating = false;
            postMessage({ type: 'status', status: 'Stopped', state: 'idle' });
            break;
    }
};
/**
 * Loads the backbone and decoder ONNX sessions plus the HF tokenizer.
 * Idempotent: returns immediately if the backbone is already loaded.
 * Sets isReady and posts status messages; throws on failure (caller reports).
 */
async function loadModels() {
    if (backboneSession) return;
    postMessage({ type: 'status', status: 'Loading models...', state: 'loading' });
    // Configure WASM paths to use the EXACT same ORT version the bootstrap
    // loaded via importScripts, so the runtime fetches matching .wasm binaries.
    const version = '1.20.0';
    const cdnBase = `https://cdn.jsdelivr.net/npm/onnxruntime-web@${version}/dist/`;
    ort.env.wasm.wasmPaths = cdnBase;
    // Multi-threaded WASM requires SharedArrayBuffer (cross-origin isolation);
    // fall back to a single thread to avoid ERR_WASM_FILE_NOT_FOUND.
    if (!self.crossOriginIsolated) {
        console.warn('Environment is not cross-origin isolated. Disabling WASM multi-threading.');
        ort.env.wasm.numThreads = 1;
    } else if (typeof navigator !== 'undefined' && navigator.hardwareConcurrency) {
        ort.env.wasm.numThreads = Math.min(navigator.hardwareConcurrency, 8);
    }
    try {
        const backboneOptions = {
            executionProviders: ['wasm'],
            freeDimensionOverrides: { 'batch': 1 }, // pin the dynamic batch dim for optimization
            graphOptimizationLevel: 'all'
        };
        // The FP16 decode table is only needed when the backbone export is fp16.
        isFp16Backbone = MODELS.backbone.includes('fp16');
        if (isFp16Backbone) initFp16Lookup();
        console.log('Loading Backbone...');
        backboneSession = await ort.InferenceSession.create(MODELS.backbone, backboneOptions);
        console.log('Loading Decoder...');
        const decoderBuf = await fetch(MODELS.decoder).then(r => {
            if (!r.ok) throw new Error(`Failed to load decoder: ${r.statusText}`);
            return r.arrayBuffer();
        });
        // Optionally fetch external weights ("<model>.data"); absence is normal,
        // so fetch failures here are deliberately swallowed.
        let dataBuf = null;
        try {
            const dataUrl = MODELS.decoder + '.data';
            const dataRes = await fetch(dataUrl);
            if (dataRes.ok) {
                dataBuf = await dataRes.arrayBuffer();
            }
        } catch (e) { } // best-effort: external data file is optional
        const decoderOptions = {
            executionProviders: ['wasm'],
            freeDimensionOverrides: { 'batch': 1 }
        };
        if (dataBuf) {
            // Hand the external weight blob to ORT under the filename the model references.
            decoderOptions.externalData = [{
                data: new Uint8Array(dataBuf),
                path: MODELS.decoder.split('/').pop() + '.data'
            }];
        }
        decoderSession = await ort.InferenceSession.create(new Uint8Array(decoderBuf), decoderOptions);
        console.log('Loading Tokenizer...');
        // NOTE(review): dynamic import() inside a classic (non-module) worker is
        // not supported in all browsers (e.g. Firefox throws) — confirm target
        // browser support or load the tokenizer on the main thread.
        const transformers = await import('https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.0');
        const { AutoTokenizer, env } = transformers;
        env.allowLocalModels = true;
        env.allowRemoteModels = false;
        // Resolve tokenizer files relative to the worker's own directory
        // (pathname only — requests stay same-origin).
        env.localModelPath = new URL('.', self.location.href).pathname;
        tokenizer = await AutoTokenizer.from_pretrained(MODELS.tokenizer, {
            local_files_only: true
        });
        isReady = true;
        postMessage({ type: 'status', status: 'Ready', state: 'idle' });
        postMessage({ type: 'model_status', status: 'ready', text: 'Ready' });
        postMessage({ type: 'loaded' });
    } catch (err) {
        console.error('Model load failed in worker:', err);
        throw err; // caller (onmessage 'load') posts the error to the main thread
    }
}
/**
 * Preprocesses `text` into prompts and streams each through the generation
 * loop. Posts 'stream_ended' + a 'Finished' status when all prompts complete
 * without being stopped.
 * @param {string} text - raw user text to synthesize.
 */
async function startGeneration(text) {
    isGenerating = true;
    postMessage({ type: 'status', status: 'Generating...', state: 'running' });
    try {
        const prompts = preprocessText(text);
        const overallStartTime = performance.now();
        let isFirstBatch = true;
        let cumulativeSamples = 0;
        for (const prompt of prompts) {
            if (!isGenerating) break; // honor 'stop' between batches
            const { input_ids } = await tokenizer(prompt);
            // Tokenizer runs in-worker, so input_ids.data is directly available.
            const batchSamples = await generationLoop(input_ids.data, overallStartTime, isFirstBatch, cumulativeSamples);
            cumulativeSamples += batchSamples;
            isFirstBatch = false;
        }
        if (isGenerating) {
            postMessage({ type: 'stream_ended' });
            postMessage({ type: 'status', status: 'Finished', state: 'idle' });
        }
    } finally {
        // BUG FIX: previously an exception escaped before the flag was cleared,
        // leaving isGenerating === true forever and silently dropping every
        // subsequent 'generate' request. Always reset the flag.
        isGenerating = false;
    }
}
// Sampling Cache: scratch buffers reused across sample() calls to avoid
// per-token allocations in the hot generation loop (lazily sized to topK).
let _topKIndices = null; // Int32Array(k): token ids currently held in the min-heap
let _topKScores = null;  // Float32Array(k): heap of the k best adjusted logits
let _topKOrder = null;   // number[]: permutation used to sort heap slots by score
let _topKExp = null;     // Float64Array(k): softmax weights for the final draw
// Decoding hyperparameters used by sample().
const samplingParams = { temperature: 0.3, topK: 50, topP: 0.95, repetitionPenalty: 1.2 };
/**
 * Autoregressive generation for one prompt.
 *
 * Runs the backbone one step at a time with a KV cache, buffers the per-token
 * hidden states, and periodically runs the decoder over a sliding window to
 * produce audio chunks that are posted (as transferables) to the main thread.
 *
 * @param {ArrayLike<bigint|number>} promptTokens - tokenized prompt ids.
 * @param {number} startTime - overall generation start timestamp (currently unused here).
 * @param {boolean} [isFirstBatch=true] - whether this is the first prompt of the request.
 * @param {number} [cumulativeSamples=0] - samples emitted by prior batches (currently unused here).
 * @returns {Promise<number>} number of audio samples emitted by this batch.
 */
async function generationLoop(promptTokens, startTime, isFirstBatch = true, cumulativeSamples = 0) {
    const batch = 1;
    const numLayers = 17;  // backbone layer count (matches the exported past_key_values.* names)
    const hiddenDim = 128; // KV tensor last dim — NOTE(review): distinct from the 512-dim decoder input below; confirm against the export config
    const promptLen = promptTokens.length;
    const vocabSize = 8192;
    const maxNewTokens = 512;
    // Mark prompt tokens as seen so the repetition penalty covers them too.
    const seenTokenMask = new Uint8Array(vocabSize);
    for (let i = 0; i < promptTokens.length; i++) {
        const tid = Number(promptTokens[i]);
        if (tid >= 0 && tid < vocabSize) seenTokenMask[tid] = 1;
    }
    // Zero-length KV tensors seed the first backbone call.
    const kvType = isFp16Backbone ? 'float16' : 'float32';
    const kvData = isFp16Backbone ? new Uint16Array(0) : new Float32Array(0);
    let pastKeyValues = {};
    for (let i = 0; i < numLayers; i++) {
        pastKeyValues[`past_key_values.${i}.key`] = new ort.Tensor(kvType, kvData, [batch, 1, 0, hiddenDim]);
        pastKeyValues[`past_key_values.${i}.value`] = new ort.Tensor(kvType, kvData, [batch, 1, 0, hiddenDim]);
    }
    // Pre-fill the attention mask once; each step just widens the subarray view.
    const maxSeqLen = promptLen + maxNewTokens;
    const attentionMaskData = new BigInt64Array(maxSeqLen);
    attentionMaskData.fill(1n);
    let currentSeqLen = promptLen;
    // Single-element tensors reused for every incremental step (no reallocation).
    const nextInputIdData = new BigInt64Array(1);
    const nextPositionIdData = new BigInt64Array(1);
    const nextInputIdsTensor = new ort.Tensor('int64', nextInputIdData, [batch, 1]);
    const nextPositionIdsTensor = new ort.Tensor('int64', nextPositionIdData, [batch, 1]);
    // First step consumes the whole prompt; later steps feed one token at a time.
    let currentInputIds = new ort.Tensor('int64', BigInt64Array.from(promptTokens), [batch, promptLen]);
    let currentAttentionMask = new ort.Tensor('int64', attentionMaskData.subarray(0, currentSeqLen), [batch, currentSeqLen]);
    let currentPositionIds = new ort.Tensor('int64', BigInt64Array.from({ length: promptLen }, (_, i) => BigInt(i)), [batch, promptLen]);
    const hiddenStatesBuffer = [];
    let totalSamples = 0;      // audio samples emitted by this batch (returned to caller)
    const targetChunkSize = 8; // fresh hidden states per decoded audio chunk
    let chunkCounter = targetChunkSize;
    let firstChunk = true;
    // Pipelining: decoder calls are chained on a promise so they overlap backbone steps
    // while still emitting chunks in order.
    let lastDecoderPromise = Promise.resolve();
    let chunkBackboneTime = 0;
    if (isFirstBatch) {
        postMessage({ type: 'generation_started', data: { time: performance.now() } });
    }
    for (let i = 0; i < maxNewTokens; i++) {
        if (!isGenerating) break;
        // Yield to the event loop periodically so 'stop' messages get processed.
        if (i % 4 === 0) {
            await new Promise(resolve => setTimeout(resolve, 0));
        }
        const inputs = {
            input_ids: currentInputIds,
            attention_mask: currentAttentionMask,
            position_ids: currentPositionIds,
            ...pastKeyValues
        };
        const bbStart = performance.now();
        const outputs = await backboneSession.run(inputs);
        chunkBackboneTime += (performance.now() - bbStart);
        // Output layout: [logits, (key, value) x numLayers, ..., last_hidden_state].
        const backboneNames = backboneSession.outputNames;
        const logits = outputs[backboneNames[0]];
        const lastHiddenState = outputs[backboneNames[backboneNames.length - 1]];
        for (let j = 0; j < numLayers; j++) {
            pastKeyValues[`past_key_values.${j}.key`] = outputs[backboneNames[1 + j * 2]];
            pastKeyValues[`past_key_values.${j}.value`] = outputs[backboneNames[2 + j * 2]];
        }
        const nextTokenId = sample(logits, seenTokenMask);
        const finished = (nextTokenId === 3n); // token id 3 terminates generation (model's stop id)
        const nextTokenIdNum = Number(nextTokenId);
        if (nextTokenIdNum >= 0 && nextTokenIdNum < vocabSize) seenTokenMask[nextTokenIdNum] = 1;
        // Extract the final position's hidden state, decoding fp16 bits if needed.
        const seqLen = lastHiddenState.dims[1];
        const hiddenDimSize = lastHiddenState.dims[2];
        const lastTokenStateRaw = lastHiddenState.data.subarray((seqLen - 1) * hiddenDimSize, seqLen * hiddenDimSize);
        let lastTokenState;
        if (lastTokenStateRaw instanceof Uint16Array) {
            lastTokenState = new Float32Array(hiddenDimSize);
            for (let j = 0; j < hiddenDimSize; j++) {
                lastTokenState[j] = fp16Lookup[lastTokenStateRaw[j]];
            }
        } else {
            lastTokenState = new Float32Array(lastTokenStateRaw);
        }
        // Skip step 0's state and the stop token's state; cap the buffer at the
        // decoder window size (context + chunk).
        if (i > 0 && !finished) {
            hiddenStatesBuffer.push(new Float32Array(lastTokenState));
            if (hiddenStatesBuffer.length > 2 * RECEPTIVE_FIELD + targetChunkSize) {
                hiddenStatesBuffer.splice(0, hiddenStatesBuffer.length - (2 * RECEPTIVE_FIELD + targetChunkSize));
            }
        }
        // Decode a chunk once enough states are buffered (or immediately on finish).
        if (finished || hiddenStatesBuffer.length >= RECEPTIVE_FIELD + targetChunkSize) {
            if (finished || chunkCounter === targetChunkSize) {
                // Snapshot the buffer for the async decoder call (full copy).
                const window = hiddenStatesBuffer.slice(-hiddenStatesBuffer.length);
                const currentWindowSize = window.length;
                // Transpose [window, 512] -> channel-major for the [1, 512, T] decoder input.
                const decoderInput = new Float32Array(512 * currentWindowSize);
                for (let w = 0; w < currentWindowSize; w++) {
                    for (let d = 0; d < 512; d++) {
                        decoderInput[d * currentWindowSize + w] = window[w][d];
                    }
                }
                const isLast = finished;
                const captureChunkCounter = chunkCounter;
                const captureFirstChunk = firstChunk;
                const captureBBTime = chunkBackboneTime;
                chunkBackboneTime = 0;
                // Chain onto the previous decoder call so chunks stay ordered.
                lastDecoderPromise = lastDecoderPromise.then(async () => {
                    const decStart = performance.now();
                    const decoderOutputs = await decoderSession.run({
                        [decoderSession.inputNames[0]]: new ort.Tensor('float32', decoderInput, [1, 512, currentWindowSize])
                    });
                    const decDuration = performance.now() - decStart;
                    const audio = decoderOutputs[decoderSession.outputNames[0]].data;
                    // Trim the receptive-field context so consecutive chunks join seamlessly.
                    let audioChunk;
                    if (isLast) {
                        const startIdx = audio.length - (RECEPTIVE_FIELD + captureChunkCounter - 1) * TOKEN_SIZE + TOKEN_SIZE;
                        audioChunk = audio.subarray(startIdx);
                    } else {
                        const startIdx = audio.length - (RECEPTIVE_FIELD + targetChunkSize) * TOKEN_SIZE + TOKEN_SIZE;
                        const endIdx = audio.length - RECEPTIVE_FIELD * TOKEN_SIZE + TOKEN_SIZE;
                        audioChunk = audio.subarray(startIdx, endIdx);
                    }
                    // BUG FIX: totalSamples was never updated, so this function always
                    // returned 0 and the caller's cumulativeSamples stayed 0. Count the
                    // samples BEFORE postMessage transfers (and detaches) the buffer.
                    totalSamples += audioChunk.length;
                    postMessage({
                        type: 'audio_chunk',
                        data: audioChunk,
                        metrics: {
                            bbTime: captureBBTime,
                            decTime: decDuration,
                            chunkDuration: audioChunk.length / SAMPLE_RATE,
                            isFirst: captureFirstChunk && isFirstBatch
                        }
                    }, [audioChunk.buffer]); // transfer ownership — avoids copying the audio
                });
                firstChunk = false;
                chunkCounter = 0;
            }
            chunkCounter++;
        }
        if (finished) break;
        // Prepare the next single-token step, reusing the preallocated tensors.
        nextInputIdData[0] = nextTokenId;
        currentInputIds = nextInputIdsTensor;
        currentSeqLen += 1;
        currentAttentionMask = new ort.Tensor('int64', attentionMaskData.subarray(0, currentSeqLen), [1, currentSeqLen]);
        nextPositionIdData[0] = BigInt(currentSeqLen - 1);
        currentPositionIds = nextPositionIdsTensor;
    }
    // Wait for the final decoder call so totalSamples is fully accumulated.
    await lastDecoderPromise;
    return totalSamples;
}
/**
 * Samples the next token id from the final step's logits using temperature,
 * repetition penalty, top-k filtering and top-p (nucleus) truncation.
 * Uses the module-level _topK* scratch buffers to avoid per-call allocations.
 *
 * @param {object} logitsTensor - logits tensor of shape [batch, seq, vocab]
 *   (fp16 data arrives as raw Uint16Array bits and is decoded via fp16Lookup).
 * @param {Uint8Array} seenTokenMask - 1 for token ids already in the prompt/output.
 * @returns {bigint} the sampled token id.
 */
function sample(logitsTensor, seenTokenMask) {
    const rawData = logitsTensor.data;
    const vocabSize = logitsTensor.dims[2];
    const lastStepOffset = (logitsTensor.dims[1] - 1) * vocabSize;
    // View of (or fp16-decoded copy of) the last step's logits.
    let data;
    if (rawData instanceof Uint16Array) {
        data = new Float32Array(vocabSize);
        for (let j = 0; j < vocabSize; j++) {
            data[j] = fp16Lookup[rawData[lastStepOffset + j]];
        }
    } else {
        data = rawData.subarray ? rawData.subarray(lastStepOffset) : rawData.slice(lastStepOffset);
    }
    const { temperature, topK, topP, repetitionPenalty } = samplingParams;
    const useRepetitionPenalty = repetitionPenalty !== 1.0;
    const invTemperature = 1.0 / temperature;
    // Fast path: top-k selection via a fixed-size min-heap.
    const k = Math.min(topK, vocabSize);
    if (k > 0 && k < vocabSize) {
        // (Re)allocate scratch buffers when k changes.
        if (!_topKIndices || _topKIndices.length !== k) {
            _topKIndices = new Int32Array(k);
            _topKScores = new Float32Array(k);
            _topKExp = new Float64Array(k);
            _topKOrder = Array.from({ length: k }, (_, i) => i);
        }
        const heapIndices = _topKIndices;
        const heapScores = _topKScores;
        let heapSize = 0;
        // Min-heap of the k highest adjusted scores (root = smallest kept score).
        for (let tokenId = 0; tokenId < vocabSize; tokenId++) {
            let s = data[tokenId] * invTemperature;
            // HF-style repetition penalty: shrink positive, amplify negative scores.
            if (useRepetitionPenalty && seenTokenMask[tokenId]) s = s < 0 ? (s * repetitionPenalty) : (s / repetitionPenalty);
            if (heapSize < k) {
                // Heap not full yet: sift the new score up.
                let pos = heapSize++;
                while (pos > 0) {
                    const parent = (pos - 1) >> 1;
                    if (heapScores[parent] <= s) break;
                    heapScores[pos] = heapScores[parent];
                    heapIndices[pos] = heapIndices[parent];
                    pos = parent;
                }
                heapScores[pos] = s;
                heapIndices[pos] = tokenId;
            } else if (s > heapScores[0]) {
                // Better than the current minimum: replace the root and sift down.
                let pos = 0;
                while (pos < (k >> 1)) {
                    let left = (pos << 1) + 1;
                    let right = left + 1;
                    let smallest = left;
                    if (right < k && heapScores[right] < heapScores[left]) smallest = right;
                    if (heapScores[smallest] >= s) break;
                    heapScores[pos] = heapScores[smallest];
                    heapIndices[pos] = heapIndices[smallest];
                    pos = smallest;
                }
                heapScores[pos] = s;
                heapIndices[pos] = tokenId;
            }
        }
        const expBuf = _topKExp;
        const order = _topKOrder;
        // Sort heap slots by score, descending.
        for (let i = 0; i < k; i++) order[i] = i;
        order.sort((a, b) => heapScores[b] - heapScores[a]);
        // Numerically-stable softmax weights relative to the max score.
        const maxScore = heapScores[order[0]];
        let sumExp = 0;
        for (let i = 0; i < k; i++) {
            const w = Math.exp(heapScores[order[i]] - maxScore);
            expBuf[i] = w;
            sumExp += w;
        }
        // Top-p: keep the smallest prefix whose probability mass reaches topP.
        let keep = k;
        if (topP < 1.0) {
            const threshold = topP * sumExp;
            let cumulative = 0;
            for (let i = 0; i < k; i++) {
                cumulative += expBuf[i];
                if (cumulative >= threshold) { keep = i + 1; break; }
            }
        }
        // Draw from the kept prefix, re-normalized over its mass.
        // (Cleanup: removed a dead duplicate Math.random() draw and replaced a
        // per-call slice().reduce() with an allocation-free prefix sum.)
        let sampleMass = sumExp;
        if (keep < k) {
            sampleMass = 0;
            for (let i = 0; i < keep; i++) sampleMass += expBuf[i];
        }
        let r = Math.random() * sampleMass;
        for (let i = 0; i < keep; i++) {
            r -= expBuf[i];
            if (r <= 0) return BigInt(heapIndices[order[i]]);
        }
        return BigInt(heapIndices[order[0]]); // float round-off fallback
    }
    // NOTE(review): full-softmax fallback (k >= vocabSize) is not implemented;
    // with topK=50 and an 8192 vocab this path is unreachable in practice.
    return 0n;
}