/**
* LFM2-Audio Model Runner for ONNX Runtime Web
*
* Runs audio model inference using ONNX models:
* 1. decoder.onnx - LFM2 backbone (shared with text)
* 2. audio_encoder.onnx - Conformer encoder for ASR (mel → embeddings)
* 3. audio_embedding.onnx - Audio code embeddings for TTS
* 4. audio_detokenizer.onnx - Audio codes → STFT features
* 5. vocoder_depthformer.onnx - Autoregressive codebook prediction
*
 * Supports ASR (transcription), TTS (speech synthesis), interleaved
 * text/audio generation, and stateful text-only chat.
*/
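//
// Minimal usage sketch (hedged: the model URL and callback wiring are
// illustrative, not part of this module's contract):
//
//   const model = new AudioModel();
//   await model.load('https://example.com/lfm2-audio', { device: 'webgpu' });
//   const text = await model.transcribe(samples, 16000, { onToken: (t) => console.log(t) });
//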
import * as ort from 'onnxruntime-web';
import { AutoTokenizer, env } from '@huggingface/transformers';
import { loadMelConfig, computeMelSpectrogram, loadAudioFile } from './audio-processor.js';
// Cache configuration
const CACHE_NAME = 'onnx-models-v1';
const IDB_NAME = 'onnx-model-cache';
const IDB_STORE = 'models';
// IndexedDB helpers for fallback caching
let idbPromise = null;
function openIDB() {
if (idbPromise) return idbPromise;
idbPromise = new Promise((resolve, reject) => {
const request = indexedDB.open(IDB_NAME, 1);
request.onerror = () => reject(request.error);
request.onsuccess = () => resolve(request.result);
request.onupgradeneeded = (event) => {
const db = event.target.result;
if (!db.objectStoreNames.contains(IDB_STORE)) {
db.createObjectStore(IDB_STORE);
}
};
});
return idbPromise;
}
async function idbGet(key) {
try {
const db = await openIDB();
return new Promise((resolve, reject) => {
const tx = db.transaction(IDB_STORE, 'readonly');
const store = tx.objectStore(IDB_STORE);
const request = store.get(key);
request.onerror = () => reject(request.error);
request.onsuccess = () => resolve(request.result);
});
} catch (e) {
return null;
}
}
async function idbSet(key, value) {
try {
const db = await openIDB();
return new Promise((resolve, reject) => {
const tx = db.transaction(IDB_STORE, 'readwrite');
const store = tx.objectStore(IDB_STORE);
const request = store.put(value, key);
request.onerror = () => reject(request.error);
request.onsuccess = () => resolve();
});
} catch (e) {
// Ignore cache write failures
}
}
// Special tokens for audio model
const SPECIAL_TOKENS = {
AUDIO_START: 128, // <|audio_start|>
TEXT_START: 129, // <|text_start|>
TEXT_END: 130, // <|text_end|>
MIXED_START: 131, // <|mixed_start|>
MIXED_END: 132, // <|mixed_end|>
IM_END: 7, // <|im_end|>
};
// Audio codebook constants
const NUM_CODEBOOKS = 8;
const CODEBOOK_VOCAB = 2049;
const END_OF_AUDIO_TOKEN = 2048;
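// Flattened embedding index for codebook cb and code c: cb * CODEBOOK_VOCAB + c,
// covering [0, 8 * 2049) = [0, 16392), the range the audio embedding table expects.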
// Default system prompts (matching Python lfm2-audio-infer)
const DEFAULT_SYSTEM_PROMPT_ASR = 'Perform ASR.';
const DEFAULT_SYSTEM_PROMPT_TTS = 'Perform TTS. Use the UK female voice.';
const DEFAULT_SYSTEM_PROMPT_INTERLEAVED = 'Respond with interleaved text and audio.';
// Max tokens defaults (matching liquid-audio)
// Each audio frame = 80ms (6x upsampling in detokenizer, 320 hop, 24kHz)
// 1024 frames ≈ 82 seconds of audio
const DEFAULT_MAX_TOKENS_AUDIO = 1024; // TTS and interleaved modes
const DEFAULT_MAX_TOKENS_TEXT = 100; // ASR mode
// Timestamped logging helper
let _logStartTime = null;
function log(...args) {
if (_logStartTime === null) {
_logStartTime = performance.now();
}
const elapsed = ((performance.now() - _logStartTime) / 1000).toFixed(2);
console.log(`[${elapsed}s]`, ...args);
}
function logReset() {
_logStartTime = performance.now();
}
/**
* Fetch with caching support using Cache API or IndexedDB fallback
*/
async function fetchWithCache(url, options = {}) {
if (!url.startsWith('http://') && !url.startsWith('https://')) {
return fetch(url, options);
}
const fileName = url.split('/').pop();
// Try Cache API first
if (typeof caches !== 'undefined') {
try {
const cache = await caches.open(CACHE_NAME);
const cached = await cache.match(url);
if (cached) {
console.log(`[Cache HIT] ${fileName}`);
return cached;
}
console.log(`[Cache MISS] Fetching ${fileName}...`);
const response = await fetch(url, options);
if (response.ok) {
// Fire-and-forget; swallow write failures (e.g. quota) so they never break the fetch
cache.put(url, response.clone()).catch(() => {});
}
return response;
} catch (e) {
// Fall through to IndexedDB
}
}
// Try IndexedDB fallback
if (typeof indexedDB !== 'undefined') {
try {
const cached = await idbGet(url);
if (cached) {
console.log(`[IDB Cache HIT] ${fileName}`);
return new Response(cached.data, {
status: 200,
headers: { 'Content-Type': cached.contentType || 'application/octet-stream' },
});
}
console.log(`[IDB Cache MISS] Fetching ${fileName}...`);
const response = await fetch(url, options);
if (response.ok) {
const clone = response.clone();
const data = await clone.arrayBuffer();
const contentType = response.headers.get('Content-Type') || 'application/octet-stream';
await idbSet(url, { data, contentType });
}
return response;
} catch (e) {
console.warn('IndexedDB cache failed:', e);
}
}
// Direct fetch as last resort
console.log(`[No Cache] Fetching ${fileName}...`);
return fetch(url, options);
}
/**
* Clear the model cache (both Cache API and IndexedDB)
*/
export async function clearModelCache() {
let deleted = false;
// Clear Cache API
if (typeof caches !== 'undefined') {
try {
deleted = await caches.delete(CACHE_NAME);
} catch (e) {
// Ignore
}
}
// Clear IndexedDB
if (typeof indexedDB !== 'undefined') {
try {
await new Promise((resolve, reject) => {
const request = indexedDB.deleteDatabase(IDB_NAME);
request.onerror = () => reject(request.error);
request.onsuccess = () => resolve();
});
idbPromise = null; // Reset the cached promise
deleted = true;
} catch (e) {
// Ignore
}
}
console.log(deleted ? 'Model cache cleared' : 'No cache to clear');
return deleted;
}
/**
* Get cache storage usage info
*/
export async function getCacheInfo() {
if ('storage' in navigator && 'estimate' in navigator.storage) {
const estimate = await navigator.storage.estimate();
return {
used: estimate.usage || 0,
available: estimate.quota || 0,
};
}
return null;
}
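// Usage sketch for the cache utilities (e.g. behind a "Clear cache" button):
//   const info = await getCacheInfo(); // { used, available } in bytes, or null
//   if (info) console.log(`${(info.used / 1e6).toFixed(0)} MB used`);
//   await clearModelCache();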
/**
* Load tokenizer from model path
*/
async function loadTokenizerFromPath(modelPath) {
const isRemote = modelPath.startsWith('http://') || modelPath.startsWith('https://');
console.log(`Loading tokenizer from ${isRemote ? 'remote' : 'local'}: ${modelPath}`);
const fetchOptions = isRemote ? { mode: 'cors', credentials: 'omit' } : {};
const [tokenizerResponse, configResponse] = await Promise.all([
fetchWithCache(`${modelPath}/tokenizer.json`, fetchOptions),
fetchWithCache(`${modelPath}/tokenizer_config.json`, fetchOptions),
]);
if (!tokenizerResponse.ok) {
throw new Error(`Failed to fetch tokenizer.json: ${tokenizerResponse.status}`);
}
if (!configResponse.ok) {
throw new Error(`Failed to fetch tokenizer_config.json: ${configResponse.status}`);
}
const tokenizerJSON = await tokenizerResponse.text();
const configJSON = await configResponse.text();
// Parse tokenizer.json to extract special token IDs
const tokenizerData = JSON.parse(tokenizerJSON);
const specialTokens = {};
if (tokenizerData.added_tokens) {
for (const token of tokenizerData.added_tokens) {
specialTokens[token.content] = token.id;
}
console.log('Found special tokens:', Object.keys(specialTokens).length);
}
// Create tokenizer using transformers.js
const fakeModelId = `tokenizer-${Date.now()}`;
const fileCache = {
'tokenizer.json': tokenizerJSON,
'tokenizer_config.json': configJSON,
};
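// Temporarily intercept fetch so AutoTokenizer.from_pretrained() resolves the
// fake model ID against the in-memory files instead of the Hugging Face Hub.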
const originalFetch = globalThis.fetch;
globalThis.fetch = async (input, init) => {
const url = typeof input === 'string' ? input : (input instanceof URL ? input.href : input.url);
if (url.includes(fakeModelId)) {
for (const [filename, content] of Object.entries(fileCache)) {
if (url.includes(filename)) {
return new Response(content, {
status: 200,
headers: { 'Content-Type': 'application/json' },
});
}
}
return new Response('Not found', { status: 404 });
}
return originalFetch(input, init);
};
const originalAllowLocal = env.allowLocalModels;
env.allowLocalModels = false;
try {
const tokenizer = await AutoTokenizer.from_pretrained(fakeModelId);
console.log('Tokenizer created successfully');
return { tokenizer, specialTokens };
} finally {
globalThis.fetch = originalFetch;
env.allowLocalModels = originalAllowLocal;
}
}
export class AudioModel {
constructor() {
this.tokenizer = null;
this.decoderSession = null;
this.audioEncoderSession = null;
this.audioEmbeddingSession = null;
this.audioEmbeddingWeight = null; // Direct lookup (faster than ONNX)
this.audioDetokenizerSession = null;
this.vocoderSession = null;
this.config = null;
this.embedTokensWeight = null;
// Model config
this.hiddenSize = 2048;
this.numLayers = 16;
this.numKVHeads = 8;
this.headDim = 64;
this.convL = 3;
this.layerTypes = [];
this.vocabSize = 65536;
// === Stateful cache for multi-turn conversation ===
this.cache = null;
this.cacheSeqLen = 0;
}
/**
* Reset conversation state (KV cache).
* Call this to start a new conversation.
*/
reset() {
this.cache = null;
this.cacheSeqLen = 0;
log('Conversation state reset');
}
/**
* Load the audio model from a directory
* @param {string} modelPath - Path to model directory
* @param {object} options - Loading options
*/
async load(modelPath, options = {}) {
const { progressCallback, device = 'webgpu', quantization = null } = options;
const report = (status, progress = 0, file = '') => {
if (progressCallback) {
progressCallback({ status, progress, file });
}
};
const executionProviders = device === 'webgpu'
? ['webgpu', 'wasm']
: ['wasm'];
try {
// Load mel config for audio processing
await loadMelConfig(modelPath);
// Load tokenizer
report('loading', 0, 'tokenizer');
const { tokenizer } = await loadTokenizerFromPath(modelPath);
this.tokenizer = tokenizer;
// Load config
report('loading', 5, 'config');
const configResponse = await fetchWithCache(`${modelPath}/config.json`, {
mode: 'cors',
credentials: 'omit',
});
if (!configResponse.ok) {
throw new Error(`Failed to fetch config.json: ${configResponse.status}`);
}
this.config = await configResponse.json();
// Extract model dimensions from config
const lfmConfig = this.config.lfm || {};
this.hiddenSize = lfmConfig.hidden_size || 2048;
this.numLayers = lfmConfig.num_hidden_layers || 16;
this.numKVHeads = lfmConfig.num_key_value_heads || 8;
this.headDim = Math.floor(this.hiddenSize / (lfmConfig.num_attention_heads || 32));
this.convL = lfmConfig.conv_L_cache || 3;
this.layerTypes = lfmConfig.layer_types || [];
this.vocabSize = lfmConfig.vocab_size || 65536;
console.log('Model config:', {
hiddenSize: this.hiddenSize,
numLayers: this.numLayers,
numKVHeads: this.numKVHeads,
headDim: this.headDim,
});
// Parse quantization config
const quantConfig = typeof quantization === 'object' ? quantization : {
decoder: quantization,
audioEncoder: quantization,
audioEmbedding: quantization,
audioDetokenizer: quantization,
vocoder: quantization,
};
// Helper to load ONNX model with external data
const loadOnnxWithExternalData = async (name, progress, quantSuffix = null, extraOptions = {}) => {
const suffix = quantSuffix ? `_${quantSuffix}` : '';
const fileName = `${name}${suffix}`;
report('loading', progress, `${fileName}.onnx`);
const onnxPath = `${modelPath}/onnx/${fileName}.onnx`;
const fetchOptions = { mode: 'cors', credentials: 'omit' };
console.log(`Loading ${fileName}...`);
const sessionOptions = { executionProviders, ...extraOptions };
const onnxResponse = await fetchWithCache(onnxPath, fetchOptions);
if (!onnxResponse.ok) {
throw new Error(`Failed to fetch ${fileName}.onnx: ${onnxResponse.status}`);
}
const onnxBuffer = await onnxResponse.arrayBuffer();
console.log(`Loaded ${fileName}.onnx: ${(onnxBuffer.byteLength / 1024 / 1024).toFixed(1)} MB`);
// Load external data files
sessionOptions.externalData = [];
// Try single .onnx_data file
const singleDataPath = `${modelPath}/onnx/${fileName}.onnx_data`;
try {
const dataResponse = await fetchWithCache(singleDataPath, fetchOptions);
const contentType = dataResponse.headers.get('content-type') || '';
if (dataResponse.ok && !contentType.includes('text/html')) {
const dataBuffer = await dataResponse.arrayBuffer();
if (dataBuffer.byteLength > 1000) { // Sanity check
console.log(`Loaded ${fileName}.onnx_data: ${(dataBuffer.byteLength / 1024 / 1024).toFixed(1)} MB`);
sessionOptions.externalData.push({
path: `${fileName}.onnx_data`,
data: new Uint8Array(dataBuffer),
});
}
}
} catch (e) {
// File doesn't exist
}
// Try numbered files - stop on first 404
for (let i = 1; ; i++) {
const numberedDataPath = `${modelPath}/onnx/${fileName}.onnx_data_${i}`;
const dataResponse = await fetchWithCache(numberedDataPath, fetchOptions); // cache like the single-file path
if (dataResponse.status === 404 || !dataResponse.ok) break;
const contentType = dataResponse.headers.get('content-type') || '';
if (contentType.includes('text/html')) break;
const dataBuffer = await dataResponse.arrayBuffer();
if (dataBuffer.byteLength < 1000) break;
console.log(`Loaded ${fileName}.onnx_data_${i}: ${(dataBuffer.byteLength / 1024 / 1024).toFixed(1)} MB`);
sessionOptions.externalData.push({
path: `${fileName}.onnx_data_${i}`,
data: new Uint8Array(dataBuffer),
});
}
if (sessionOptions.externalData.length === 0) {
delete sessionOptions.externalData;
}
const session = await ort.InferenceSession.create(new Uint8Array(onnxBuffer), sessionOptions);
console.log(`Session created for ${fileName}`);
return session;
};
// Load decoder
// On WebGPU: keep KV cache outputs on GPU to avoid GPU→CPU→GPU roundtrips between steps
const decoderOpts = device === 'webgpu' ? (() => {
const loc = {};
for (let i = 0; i < this.layerTypes.length; i++) {
if (this.layerTypes[i] === 'conv') {
loc[`present_conv.${i}`] = 'gpu-buffer';
} else {
loc[`present.${i}.key`] = 'gpu-buffer';
loc[`present.${i}.value`] = 'gpu-buffer';
}
}
return { preferredOutputLocation: loc };
})() : {};
this.decoderSession = await loadOnnxWithExternalData('decoder', 10, quantConfig.decoder, decoderOpts);
// Load embed_tokens weight for text embedding lookup
report('loading', 30, 'embed_tokens');
this.embedTokensWeight = await this.loadEmbedTokensWeight(modelPath);
// Load audio encoder (for ASR)
this.audioEncoderSession = await loadOnnxWithExternalData('audio_encoder', 50, quantConfig.audioEncoder);
// Load audio embedding (for TTS) - try binary first, fallback to ONNX
report('loading', 65, 'audio_embedding');
this.audioEmbeddingWeight = await this.loadAudioEmbeddingWeight(modelPath);
if (!this.audioEmbeddingWeight) {
// Fallback to ONNX model
this.audioEmbeddingSession = await loadOnnxWithExternalData('audio_embedding', 70, quantConfig.audioEmbedding);
} else {
console.log('Using direct audio embedding lookup (binary)');
}
// Load audio detokenizer (for TTS output)
try {
this.audioDetokenizerSession = await loadOnnxWithExternalData('audio_detokenizer', 85, quantConfig.audioDetokenizer);
} catch (e) {
console.warn('Audio detokenizer not available:', e);
}
// Load vocoder (for TTS)
// On WebGPU: keep KV cache on GPU to avoid GPU→CPU→GPU roundtrips between steps
try {
const vocoderOpts = device === 'webgpu'
? { preferredOutputLocation: { new_keys: 'gpu-buffer', new_values: 'gpu-buffer', depth_slices: 'gpu-buffer' } }
: {};
this.vocoderSession = await loadOnnxWithExternalData('vocoder_depthformer', 95, quantConfig.vocoder, vocoderOpts);
} catch (e) {
console.warn('Vocoder not available:', e);
}
report('done', 100, '');
return true;
} catch (error) {
// Normalize to a string; ORT's WASM backend sometimes throws raw numeric codes
let errorMessage;
if (typeof error === 'number') {
errorMessage = `ONNX Runtime error code: ${error}`;
} else if (error instanceof Error) {
errorMessage = error.message;
} else {
errorMessage = String(error);
}
console.error('Failed to load audio model:', errorMessage);
throw new Error(errorMessage);
}
}
/**
* Get audio embeddings for given token indices and sum across codebooks
*
* Uses direct array indexing if binary weight available (fast),
* falls back to ONNX session otherwise.
*
* @param {number[]} audioTokens - Array of 8 pre-offset token indices (codebook_idx * CODEBOOK_VOCAB + code)
* @returns {Float32Array} Summed embedding [hiddenSize]
*/
async getAudioEmbedding(audioTokens) {
const NUM_CODEBOOKS = 8;
const hiddenSize = this.hiddenSize;
const summedEmbeds = new Float32Array(hiddenSize);
if (this.audioEmbeddingWeight) {
// Direct lookup (much faster - no ONNX call)
const weight = this.audioEmbeddingWeight.weight;
for (let cb = 0; cb < NUM_CODEBOOKS; cb++) {
const tokenIdx = audioTokens[cb];
const offset = tokenIdx * hiddenSize;
for (let h = 0; h < hiddenSize; h++) {
summedEmbeds[h] += weight[offset + h];
}
}
} else {
// Fallback to ONNX session
const audioTokensTensor = new ort.Tensor('int64', new BigInt64Array(audioTokens.map(BigInt)), [1, NUM_CODEBOOKS]);
const result = await this.audioEmbeddingSession.run({ audio_codes: audioTokensTensor });
const embeddings = result.audio_embeds.data;
for (let cb = 0; cb < NUM_CODEBOOKS; cb++) {
for (let h = 0; h < hiddenSize; h++) {
summedEmbeds[h] += embeddings[cb * hiddenSize + h];
}
}
}
return summedEmbeds;
}
/**
* Load audio_embedding.weight from raw binary file for direct lookup
*
* This eliminates ONNX model calls (352 per generation → 0).
* Falls back to ONNX session if binary not available.
*/
async loadAudioEmbeddingWeight(modelPath) {
const fetchOptions = { mode: 'cors', credentials: 'omit' };
try {
// Load metadata
const metaResponse = await fetchWithCache(`${modelPath}/onnx/audio_embedding.json`, fetchOptions);
if (!metaResponse.ok) {
console.log('audio_embedding.json not found, will use ONNX model');
return null;
}
const meta = await metaResponse.json();
console.log('audio_embedding metadata:', meta);
// Load binary weight
const binResponse = await fetchWithCache(`${modelPath}/onnx/audio_embedding.bin`, fetchOptions);
if (!binResponse.ok) {
console.log('audio_embedding.bin not found, will use ONNX model');
return null;
}
const buffer = await binResponse.arrayBuffer();
const weight = new Float32Array(buffer);
if (weight.length !== meta.vocab_size * meta.hidden_size) {
console.error('audio_embedding size mismatch:', weight.length, 'expected:', meta.vocab_size * meta.hidden_size);
return null;
}
console.log(`Loaded audio_embedding: [${meta.vocab_size}, ${meta.hidden_size}] (${(buffer.byteLength / 1e6).toFixed(1)} MB)`);
return { weight, vocabSize: meta.vocab_size, hiddenSize: meta.hidden_size };
} catch (e) {
console.log('Failed to load audio_embedding.bin:', e);
return null;
}
}
/**
* Load embed_tokens.weight from raw binary file for text embedding lookup
*
* The Python export saves embed_tokens.weight as:
* - embed_tokens.bin: raw float32 binary (vocab_size * hidden_size * 4 bytes)
* - embed_tokens.json: metadata (vocab_size, hidden_size)
*/
async loadEmbedTokensWeight(modelPath) {
const fetchOptions = { mode: 'cors', credentials: 'omit' };
// Load metadata
const metaResponse = await fetchWithCache(`${modelPath}/onnx/embed_tokens.json`, fetchOptions);
if (!metaResponse.ok) {
console.warn('embed_tokens.json not found, TTS will be unavailable');
return null;
}
const meta = await metaResponse.json();
console.log('embed_tokens metadata:', meta);
// Load binary weight
const binResponse = await fetchWithCache(`${modelPath}/onnx/embed_tokens.bin`, fetchOptions);
if (!binResponse.ok) {
console.warn('embed_tokens.bin not found, TTS will be unavailable');
return null;
}
const buffer = await binResponse.arrayBuffer();
const weight = new Float32Array(buffer);
if (weight.length !== meta.vocab_size * meta.hidden_size) {
console.error('embed_tokens size mismatch:', weight.length, 'expected:', meta.vocab_size * meta.hidden_size);
return null;
}
console.log(`Loaded embed_tokens: [${meta.vocab_size}, ${meta.hidden_size}] (${(buffer.byteLength / 1e6).toFixed(1)} MB)`);
return { weight, vocabSize: meta.vocab_size, hiddenSize: meta.hidden_size };
}
/**
* Get text embeddings for token IDs using pre-loaded weight
* @param {number[]} tokenIds - Array of token IDs
* @returns {ort.Tensor} - Embeddings tensor [1, seq_len, hidden_size]
*/
getTextEmbeddings(tokenIds) {
if (!this.embedTokensWeight) {
throw new Error('embed_tokens weight not loaded');
}
const { weight, hiddenSize } = this.embedTokensWeight;
const seqLen = tokenIds.length;
const embeddings = new Float32Array(seqLen * hiddenSize);
for (let i = 0; i < seqLen; i++) {
const tokenId = tokenIds[i];
const srcOffset = tokenId * hiddenSize;
const dstOffset = i * hiddenSize;
embeddings.set(weight.subarray(srcOffset, srcOffset + hiddenSize), dstOffset);
}
return new ort.Tensor('float32', embeddings, [1, seqLen, hiddenSize]);
}
/**
* Initialize KV cache for generation
*/
initializeCache() {
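// Conv layers get a fixed-size rolling buffer [1, hidden_size, conv_L];
// attention layers start with empty K/V of shape [1, num_kv_heads, 0, head_dim]
// that grows along the sequence axis as present.* outputs are fed back.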
const cache = {};
for (let idx = 0; idx < this.layerTypes.length; idx++) {
const layerType = this.layerTypes[idx];
if (layerType === 'conv') {
cache[`past_conv.${idx}`] = new ort.Tensor(
'float32',
new Float32Array(1 * this.hiddenSize * this.convL),
[1, this.hiddenSize, this.convL]
);
} else {
cache[`past_key_values.${idx}.key`] = new ort.Tensor(
'float32',
new Float32Array(0),
[1, this.numKVHeads, 0, this.headDim]
);
cache[`past_key_values.${idx}.value`] = new ort.Tensor(
'float32',
new Float32Array(0),
[1, this.numKVHeads, 0, this.headDim]
);
}
}
return cache;
}
/**
* Update cache from decoder outputs
*/
updateCache(cache, outputs) {
for (const name of Object.keys(outputs)) {
if (name.startsWith('present_conv.')) {
const cacheName = name.replace('present_conv', 'past_conv');
if (cacheName in cache) {
cache[cacheName] = outputs[name];
}
} else if (name.startsWith('present.')) {
const cacheName = name.replace('present.', 'past_key_values.');
if (cacheName in cache) {
cache[cacheName] = outputs[name];
}
}
}
}
/**
* Run decoder with embeddings
*/
async runDecoder(embeds, attentionMask, cache) {
const feeds = {
inputs_embeds: embeds,
attention_mask: attentionMask,
...cache,
};
const outputs = await this.decoderSession.run(feeds);
return {
logits: outputs.logits,
hiddenStates: outputs.hidden_states,
outputs,
};
}
/**
* Sample next token
*/
sampleToken(logits, temperature = 0.7) {
if (temperature === 0) {
let maxIdx = 0;
let maxVal = logits[0];
for (let i = 1; i < logits.length; i++) {
if (logits[i] > maxVal) {
maxVal = logits[i];
maxIdx = i;
}
}
return maxIdx;
}
// Temperature sampling
const scaledLogits = new Float32Array(logits.length);
let maxLogit = -Infinity;
for (let i = 0; i < logits.length; i++) {
scaledLogits[i] = logits[i] / temperature;
maxLogit = Math.max(maxLogit, scaledLogits[i]);
}
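// Numerically stable softmax: subtract the max before exp (softmax is shift-invariant)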
let sumExp = 0;
for (let i = 0; i < scaledLogits.length; i++) {
scaledLogits[i] = Math.exp(scaledLogits[i] - maxLogit);
sumExp += scaledLogits[i];
}
const probs = new Float32Array(scaledLogits.length);
for (let i = 0; i < probs.length; i++) {
probs[i] = scaledLogits[i] / sumExp;
}
// Sample from distribution
const r = Math.random();
let cumsum = 0;
for (let i = 0; i < probs.length; i++) {
cumsum += probs[i];
if (r < cumsum) return i;
}
return probs.length - 1;
}
/**
* Transcribe audio to text (ASR mode)
* @param {Float32Array} audioData - Audio samples in [-1, 1]
* @param {number} sampleRate - Audio sample rate
* @param {object} options - Generation options
*/
async transcribe(audioData, sampleRate, options = {}) {
const {
maxNewTokens = DEFAULT_MAX_TOKENS_TEXT,
temperature = 0,
systemPrompt = DEFAULT_SYSTEM_PROMPT_ASR,
onToken,
} = options;
if (!this.audioEncoderSession) {
throw new Error('Audio encoder not loaded');
}
if (!this.embedTokensWeight) {
throw new Error('embed_tokens not loaded - required for ASR');
}
logReset();
log('=== ASR Transcription ===');
log('Audio samples:', audioData.length, 'Sample rate:', sampleRate);
// 1. Compute mel spectrogram
const { melFeatures, numFrames } = computeMelSpectrogram(audioData, sampleRate);
log('Mel spectrogram frames:', numFrames);
// 2. Run audio encoder
const melTensor = new ort.Tensor(
'float32',
melFeatures,
[1, numFrames, 128] // [batch, time, n_mels]
);
const melLengths = new ort.Tensor(
'int64',
new BigInt64Array([BigInt(numFrames)]),
[1]
);
const encoderOutputs = await this.audioEncoderSession.run({
mel_spectrogram: melTensor,
mel_lengths: melLengths,
});
const audioEmbeds = encoderOutputs.audio_embeddings;
log('Audio embeddings shape:', audioEmbeds.dims);
// 3. Build prompt: prefix + audio + suffix
const prefixText = `<|startoftext|><|im_start|>system\n${systemPrompt}<|im_end|>\n<|im_start|>user\n`;
const suffixText = '<|im_end|>\n<|im_start|>assistant\n';
// Use add_special_tokens: false to match Python behavior (prompt already has special tokens)
const prefixIds = Array.from(this.tokenizer.encode(prefixText, { add_special_tokens: false }));
const suffixIds = Array.from(this.tokenizer.encode(suffixText, { add_special_tokens: false }));
log('Prefix tokens:', prefixIds.length, 'Suffix tokens:', suffixIds.length);
// Get text embeddings
const prefixEmbeds = this.getTextEmbeddings(prefixIds);
const suffixEmbeds = this.getTextEmbeddings(suffixIds);
// 4. Concatenate embeddings: prefix + audio + suffix
const prefixLen = prefixIds.length;
const audioLen = audioEmbeds.dims[1];
const suffixLen = suffixIds.length;
const totalLen = prefixLen + audioLen + suffixLen;
const { hiddenSize } = this.embedTokensWeight;
const allEmbeds = new Float32Array(totalLen * hiddenSize);
// Copy prefix embeddings
allEmbeds.set(prefixEmbeds.data, 0);
// Copy audio embeddings
allEmbeds.set(new Float32Array(audioEmbeds.data.buffer, audioEmbeds.data.byteOffset, audioLen * hiddenSize), prefixLen * hiddenSize);
// Copy suffix embeddings
allEmbeds.set(suffixEmbeds.data, (prefixLen + audioLen) * hiddenSize);
const inputEmbeds = new ort.Tensor('float32', allEmbeds, [1, totalLen, hiddenSize]);
const attentionMask = new ort.Tensor('int64', new BigInt64Array(totalLen).fill(1n), [1, totalLen]);
// 5. Initialize cache and run prefill
const cache = this.initializeCache();
let { logits, hiddenStates, outputs } = await this.runDecoder(inputEmbeds, attentionMask, cache);
this.updateCache(cache, outputs);
// 6. Generate tokens
const generatedTokens = [];
let currentLen = totalLen;
for (let i = 0; i < maxNewTokens; i++) {
// Get logits for last position - shape is [1, seq_len, vocab_size]
const logitsData = logits.data;
const seqLen = logits.dims[1];
const lastLogits = new Float32Array(this.vocabSize);
const offset = (seqLen - 1) * this.vocabSize;
for (let j = 0; j < this.vocabSize; j++) {
lastLogits[j] = logitsData[offset + j];
}
const nextToken = this.sampleToken(lastLogits, temperature);
// Check for stop tokens
if (nextToken === this.tokenizer.eos_token_id || nextToken === SPECIAL_TOKENS.IM_END) {
log('Stop token reached');
break;
}
generatedTokens.push(nextToken);
if (onToken) {
const text = this.tokenizer.decode(generatedTokens);
onToken(text, nextToken);
}
// Get embedding for next token
const nextEmbeds = this.getTextEmbeddings([nextToken]);
currentLen++;
const nextMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]);
// Run decoder with single token
({ logits, hiddenStates, outputs } = await this.runDecoder(nextEmbeds, nextMask, cache));
this.updateCache(cache, outputs);
}
const result = this.tokenizer.decode(generatedTokens);
log(`Generated ${generatedTokens.length} tokens: "${result}"`);
return result;
}
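// Usage sketch (hedged: assumes loadAudioFile from ./audio-processor.js yields
// decoded samples plus a sample rate; adapt to its actual return shape):
//   const { samples, sampleRate } = await loadAudioFile(file);
//   const text = await model.transcribe(samples, sampleRate, { onToken: console.log });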
/**
* Generate response from messages
* @param {Array} messages - Chat messages
* @param {object} options - Generation options
*/
async generate(messages, options = {}) {
const { maxNewTokens = 256, onToken, audioData = null, sampleRate = null } = options;
// If audio data provided, do ASR
if (audioData && sampleRate) {
return this.transcribe(audioData, sampleRate, {
maxNewTokens,
onToken,
});
}
// Text-only generation (simplified)
const prompt = this.tokenizer.apply_chat_template(messages, {
add_generation_prompt: true,
tokenize: false,
});
const inputIds = this.tokenizer.encode(prompt);
console.log('Input tokens:', inputIds.length);
// Full text-only chat (with KV cache) lives in generateTextOnly();
// this path is a placeholder that confirms the model is loaded.
return '[Text generation requires full embedding support - model loaded successfully]';
}
/**
* Initialize reusable vocoder tensors to reduce allocation overhead
*/
_initVocoderCache() {
if (this._vocoderCache) return;
// Vocoder dims (6 layers, 8 KV heads, head_dim 32) are fixed for this model;
// they live in sampleAudioCodes, so here we only pre-allocate reusable data arrays.
const stepIdxData = new BigInt64Array(1);
const prevTokenData = new BigInt64Array(1);
const seqlensKData = new Int32Array(1);
const totalSeqLenData = new Int32Array(1);
// Pre-allocate tensors that can be reused
this._vocoderCache = {
hiddenTensor: null, // Created per-call since hiddenState changes
stepIdxData,
prevTokenData,
seqlensKData,
totalSeqLenData,
// Pre-create reusable tensors (ONNX Runtime reads from the data array)
stepIdxTensor: new ort.Tensor('int64', stepIdxData, []),
prevTokenTensor: new ort.Tensor('int64', prevTokenData, [1]),
seqlensKTensor: new ort.Tensor('int32', seqlensKData, [1]),
totalSeqLenTensor: new ort.Tensor('int32', totalSeqLenData, []),
emptyKeysData: new Float32Array(0),
emptyValuesData: new Float32Array(0),
emptyDepthSlicesData: new Float32Array(8 * 1024), // zeros for step 0
// Reusable sampling arrays
scaledLogits: new Float32Array(2049), // codebook vocab size
indices: new Uint16Array(2049), // Use typed array for faster reset
probs: new Float32Array(64), // top-k size
};
// Initialize indices
for (let i = 0; i < 2049; i++) {
this._vocoderCache.indices[i] = i;
}
}
/**
* Sample audio codes using vocoder depthformer
* Optimized to reduce tensor creation overhead
* @param {Float32Array} hiddenState - [hidden_size] hidden state
* @param {number} temperature - Sampling temperature
* @param {number} topK - Top-k sampling
* @returns {number[]} - 8 codebook values
*/
async sampleAudioCodes(hiddenState, temperature = 0.8, topK = 64) {
if (!this.vocoderSession) {
throw new Error('Vocoder not loaded');
}
// Initialize cache on first call
this._initVocoderCache();
const cache = this._vocoderCache;
const numCodebooks = 8;
const numLayers = 6;
const numKvHeads = 8;
const headDim = 32;
// Guard: the reusable probs buffer in _initVocoderCache is sized for top-k <= 64
topK = Math.min(topK, cache.probs.length);
const codes = [];
let prevToken = 0;
// Create hidden state tensor (must be new since data changes)
const hiddenTensor = new ort.Tensor('float32', hiddenState, [1, this.hiddenSize]);
// Initialize empty KV cache
let pastKeys = new ort.Tensor(
'float32',
cache.emptyKeysData,
[numLayers, 1, numKvHeads, 0, headDim]
);
let pastValues = new ort.Tensor(
'float32',
cache.emptyValuesData,
[numLayers, 1, numKvHeads, 0, headDim]
);
// step_idx and prev_token tensors are reused; the loop updates their backing arrays in place
// depth_slices_in: zeros at step 0 (model ignores it), then fed back from output
let depthSlicesIn = new ort.Tensor('float32', cache.emptyDepthSlicesData, [1, 8, 1024]);
for (let i = 0; i < numCodebooks; i++) {
// Update mutable tensor data (tensor objects reuse the underlying data arrays)
cache.stepIdxData[0] = BigInt(i);
cache.prevTokenData[0] = BigInt(prevToken);
cache.seqlensKData[0] = i;
cache.totalSeqLenData[0] = i + 1;
const feeds = {
hidden_states: hiddenTensor,
depth_slices_in: depthSlicesIn,
step_idx: cache.stepIdxTensor,
prev_token: cache.prevTokenTensor,
past_keys: pastKeys,
past_values: pastValues,
seqlens_k: cache.seqlensKTensor,
total_seq_len: cache.totalSeqLenTensor,
};
const outputs = await this.vocoderSession.run(feeds);
depthSlicesIn = outputs.depth_slices;
const logits = outputs.logits.data;
const vocabSize = logits.length;
// Sample with temperature and top-k (reusing cached arrays)
let token;
if (temperature <= 0) {
// Greedy
token = 0;
let maxVal = logits[0];
for (let j = 1; j < vocabSize; j++) {
if (logits[j] > maxVal) {
maxVal = logits[j];
token = j;
}
}
} else {
// Top-k sampling with reused arrays
const scaledLogits = cache.scaledLogits;
const indices = cache.indices;
const probs = cache.probs;
// Scale logits by temperature and find top-k in single pass
// Use partial selection sort (O(k*n) which is fast for small k)
for (let j = 0; j < vocabSize; j++) {
scaledLogits[j] = logits[j] / temperature;
indices[j] = j;
}
// Partial sort to get top-k
for (let j = 0; j < topK; j++) {
let maxIdx = j;
for (let k = j + 1; k < vocabSize; k++) {
if (scaledLogits[indices[k]] > scaledLogits[indices[maxIdx]]) {
maxIdx = k;
}
}
// Swap
const tmp = indices[j];
indices[j] = indices[maxIdx];
indices[maxIdx] = tmp;
}
// Softmax over top-k
const maxLogit = scaledLogits[indices[0]];
let sumExp = 0;
for (let j = 0; j < topK; j++) {
probs[j] = Math.exp(scaledLogits[indices[j]] - maxLogit);
sumExp += probs[j];
}
for (let j = 0; j < topK; j++) {
probs[j] /= sumExp;
}
// Sample
const r = Math.random();
let cumsum = 0;
token = indices[topK - 1]; // Default to last
for (let j = 0; j < topK; j++) {
cumsum += probs[j];
if (r < cumsum) {
token = indices[j];
break;
}
}
}
codes.push(token);
prevToken = token;
// Update KV cache
pastKeys = outputs.new_keys;
pastValues = outputs.new_values;
}
return codes;
}
/**
* Generate speech from text (TTS mode)
* @param {string} text - Text to convert to speech
* @param {object} options - Generation options
* @returns {object} - { audioCodes: number[][], textTokens: number[] }
*/
async generateSpeech(text, options = {}) {
const {
maxNewTokens = DEFAULT_MAX_TOKENS_AUDIO,
textTemperature = 0.7,
audioTemperature = 0.8,
audioTopK = 64,
systemPrompt = DEFAULT_SYSTEM_PROMPT_TTS,
onToken,
onAudioFrame,
} = options;
logReset();
log('=== TTS Generation ===');
log('Text:', text);
if (!this.embedTokensWeight) {
throw new Error('embed_tokens not loaded - required for TTS');
}
if (!this.vocoderSession) {
throw new Error('Vocoder not loaded - required for TTS');
}
// Build TTS prompt
const prompt = `<|startoftext|><|im_start|>system\n${systemPrompt}<|im_end|>\n<|im_start|>user\n${text}<|im_end|>\n<|im_start|>assistant\n`;
// Use add_special_tokens: false to match Python behavior (prompt already has special tokens)
const inputIds = Array.from(this.tokenizer.encode(prompt, { add_special_tokens: false }));
log('TTS prompt tokens:', inputIds.length);
// Get embeddings and run prefill
const inputEmbeds = this.getTextEmbeddings(inputIds);
const cache = this.initializeCache();
const attentionMask = new ort.Tensor('int64', new BigInt64Array(inputIds.length).fill(1n), [1, inputIds.length]);
let { logits, hiddenStates, outputs } = await this.runDecoder(inputEmbeds, attentionMask, cache);
this.updateCache(cache, outputs);
// Phase 1: Generate text until <|audio_start|> token
const textTokens = [];
let currentLen = inputIds.length;
let inAudioMode = false;
let tokensGenerated = 0;
while (tokensGenerated < maxNewTokens && !inAudioMode) {
const logitsData = logits.data;
const seqLen = logits.dims[1];
// Get logits for last position - shape is [1, seq_len, vocab_size]
const lastLogits = new Float32Array(this.vocabSize);
const offset = (seqLen - 1) * this.vocabSize;
for (let i = 0; i < this.vocabSize; i++) {
lastLogits[i] = logitsData[offset + i];
}
const nextToken = this.sampleToken(lastLogits, textTemperature);
tokensGenerated++;
if (nextToken === this.tokenizer.eos_token_id) {
console.warn('Model produced EOS before audio, TTS may not work');
break;
}
if (nextToken === SPECIAL_TOKENS.AUDIO_START) {
log('Model entered audio mode');
inAudioMode = true;
// Feed audio_start token to get hidden states for first audio frame
const nextEmbeds = this.getTextEmbeddings([SPECIAL_TOKENS.AUDIO_START]);
currentLen++;
const nextMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]);
({ logits, hiddenStates, outputs } = await this.runDecoder(nextEmbeds, nextMask, cache));
this.updateCache(cache, outputs);
break;
}
textTokens.push(nextToken);
// Continue text generation
const nextEmbeds = this.getTextEmbeddings([nextToken]);
currentLen++;
const nextMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]);
({ logits, hiddenStates, outputs } = await this.runDecoder(nextEmbeds, nextMask, cache));
this.updateCache(cache, outputs);
}
if (!inAudioMode) {
console.warn('Model did not enter audio mode, forcing audio generation');
// Force audio start token
const nextEmbeds = this.getTextEmbeddings([SPECIAL_TOKENS.AUDIO_START]);
currentLen++;
const nextMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]);
({ logits, hiddenStates, outputs } = await this.runDecoder(nextEmbeds, nextMask, cache));
this.updateCache(cache, outputs);
tokensGenerated++;
}
// Phase 2: Generate audio frames using depthformer
const audioCodes = [];
const startTime = performance.now();
while (tokensGenerated < maxNewTokens) {
// Get hidden state for last position
const hiddenData = hiddenStates.data;
const seqLen = hiddenStates.dims[1];
const lastHidden = hiddenData.slice((seqLen - 1) * this.hiddenSize, seqLen * this.hiddenSize);
// Sample audio codes
const frameCodes = await this.sampleAudioCodes(lastHidden, audioTemperature, audioTopK);
// Check for end-of-audio
// Only check first codebook (matching liquid-audio TTS behavior)
if (frameCodes[0] >= END_OF_AUDIO_TOKEN) {
log(`End of audio at frame ${audioCodes.length}`);
break;
}
audioCodes.push(frameCodes);
tokensGenerated++;
// Log progress periodically (after the push, so the first frame isn't logged as 0)
if (audioCodes.length % 50 === 0) {
log(`Generated ${audioCodes.length} audio frames`);
}
if (onAudioFrame) {
onAudioFrame(frameCodes, audioCodes.length);
}
// Feed back audio codes to continue generation
// Audio embedding expects tokens in range [0, 16392) where:
// token = codebook_idx * 2049 + code_value
const clampedCodes = frameCodes.map(c => Math.min(c, 2047));
const audioTokens = clampedCodes.map((code, idx) => idx * CODEBOOK_VOCAB + code);
// Get summed embeddings for all 8 codebooks
const summedEmbeds = await this.getAudioEmbedding(audioTokens);
const nextEmbeds = new ort.Tensor('float32', summedEmbeds, [1, 1, this.hiddenSize]);
currentLen++;
const nextMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]);
({ logits, hiddenStates, outputs } = await this.runDecoder(nextEmbeds, nextMask, cache));
this.updateCache(cache, outputs);
}
const elapsed = (performance.now() - startTime) / 1000;
const framesPerSec = audioCodes.length / elapsed;
log(`Generated ${audioCodes.length} audio frames in ${elapsed.toFixed(2)}s (${framesPerSec.toFixed(1)} frames/s)`);
const textOutput = textTokens.length > 0 ? this.tokenizer.decode(textTokens) : '';
return { audioCodes, textTokens, textOutput };
}
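// End-to-end TTS sketch: generate codes, then decode them to a waveform
// (decodeAudioCodes is defined below; playback wiring is up to the caller):
//   const { audioCodes } = await model.generateSpeech('Hello there!');
//   const samples = await model.decodeAudioCodes(audioCodes); // Float32Array at 24 kHz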
/**
* Generate interleaved response (mixed text/audio) with stateful KV cache.
*
* The cache is preserved between calls for multi-turn conversation.
* Call reset() to start a new conversation.
*
* @param {Float32Array} audioData - Input audio samples
* @param {number} sampleRate - Audio sample rate
* @param {string} textPrompt - Optional text prompt (unused, for API compatibility)
* @param {object} options - Generation options
* @returns {object} - { text: string, audioCodes: number[][] }
*/
async generateInterleaved(audioData, sampleRate, textPrompt = '', options = {}) {
const {
maxNewTokens = DEFAULT_MAX_TOKENS_AUDIO,
textTemperature = 1.0,
audioTemperature = 1.0,
audioTopK = 4,
systemPrompt = DEFAULT_SYSTEM_PROMPT_INTERLEAVED,
onToken,
onAudioFrame,
} = options;
// Counter-based mode switching (matching liquid-audio)
const INTERLEAVED_N_TEXT = 6;
const INTERLEAVED_N_AUDIO = 12;
logReset();
log('=== Interleaved Generation ===');
log('Cache state:', this.cache ? `exists (seq_len=${this.cacheSeqLen})` : 'null (new conversation)');
log('Audio samples:', audioData.length, 'Sample rate:', sampleRate);
if (!this.audioEncoderSession) {
throw new Error('Audio encoder not loaded - required for interleaved mode');
}
if (!this.embedTokensWeight) {
throw new Error('embed_tokens not loaded - required for interleaved mode');
}
if (!this.vocoderSession) {
throw new Error('Vocoder not loaded - required for interleaved mode');
}
// Timing accumulators
let timeAudioEncode = 0;
let timePrefill = 0;
let timeTextDecode = 0;
let timeAudioDecode = 0;
let timeVocoder = 0;
let timeAudioEmbed = 0;
// 1. Compute mel spectrogram and encode audio
let tStep = performance.now();
const { melFeatures, numFrames } = computeMelSpectrogram(audioData, sampleRate);
const timeMel = performance.now() - tStep;
const melTensor = new ort.Tensor('float32', melFeatures, [1, numFrames, 128]);
const melLengths = new ort.Tensor('int64', new BigInt64Array([BigInt(numFrames)]), [1]);
tStep = performance.now();
const encoderOutputs = await this.audioEncoderSession.run({
mel_spectrogram: melTensor,
mel_lengths: melLengths,
});
timeAudioEncode = performance.now() - tStep;
const audioEmbeds = encoderOutputs.audio_embeddings;
log(`Mel: ${timeMel.toFixed(0)}ms, AudioEnc: ${timeAudioEncode.toFixed(0)}ms, frames: ${numFrames}`);
const { hiddenSize } = this.embedTokensWeight;
// 2. Build prompt based on whether this is first turn or continuation
let inputEmbeds;
let newSeqLen;
if (this.cache === null) {
// === First turn: full prompt with system message ===
log('First turn - initializing conversation');
this.cache = this.initializeCache();
this.cacheSeqLen = 0;
const prefixText = `<|startoftext|><|im_start|>system\n${systemPrompt}<|im_end|>\n<|im_start|>user\n`;
const suffixText = '<|im_end|>\n<|im_start|>assistant\n';
const prefixIds = Array.from(this.tokenizer.encode(prefixText, { add_special_tokens: false }));
const suffixIds = Array.from(this.tokenizer.encode(suffixText, { add_special_tokens: false }));
const prefixEmbeds = this.getTextEmbeddings(prefixIds);
const suffixEmbeds = this.getTextEmbeddings(suffixIds);
const prefixLen = prefixIds.length;
const audioLen = audioEmbeds.dims[1];
const suffixLen = suffixIds.length;
newSeqLen = prefixLen + audioLen + suffixLen;
const allEmbeds = new Float32Array(newSeqLen * hiddenSize);
allEmbeds.set(prefixEmbeds.data, 0);
allEmbeds.set(
new Float32Array(audioEmbeds.data.buffer, audioEmbeds.data.byteOffset, audioLen * hiddenSize),
prefixLen * hiddenSize
);
allEmbeds.set(suffixEmbeds.data, (prefixLen + audioLen) * hiddenSize);
inputEmbeds = new ort.Tensor('float32', allEmbeds, [1, newSeqLen, hiddenSize]);
} else {
// === Continuation: user turn only ===
log(`Continuing conversation (cache seq_len=${this.cacheSeqLen})`);
const userPrefixText = '<|im_start|>user\n';
const suffixText = '<|im_end|>\n<|im_start|>assistant\n';
const userPrefixIds = Array.from(this.tokenizer.encode(userPrefixText, { add_special_tokens: false }));
const suffixIds = Array.from(this.tokenizer.encode(suffixText, { add_special_tokens: false }));
const userPrefixEmbeds = this.getTextEmbeddings(userPrefixIds);
const suffixEmbeds = this.getTextEmbeddings(suffixIds);
const userPrefixLen = userPrefixIds.length;
const audioLen = audioEmbeds.dims[1];
const suffixLen = suffixIds.length;
newSeqLen = userPrefixLen + audioLen + suffixLen;
const allEmbeds = new Float32Array(newSeqLen * hiddenSize);
allEmbeds.set(userPrefixEmbeds.data, 0);
allEmbeds.set(
new Float32Array(audioEmbeds.data.buffer, audioEmbeds.data.byteOffset, audioLen * hiddenSize),
userPrefixLen * hiddenSize
);
allEmbeds.set(suffixEmbeds.data, (userPrefixLen + audioLen) * hiddenSize);
inputEmbeds = new ort.Tensor('float32', allEmbeds, [1, newSeqLen, hiddenSize]);
}
// 3. Run prefill with attention mask covering full sequence
const totalLen = this.cacheSeqLen + newSeqLen;
const attentionMask = new ort.Tensor('int64', new BigInt64Array(totalLen).fill(1n), [1, totalLen]);
tStep = performance.now();
let { logits, hiddenStates, outputs } = await this.runDecoder(inputEmbeds, attentionMask, this.cache);
timePrefill = performance.now() - tStep;
this.updateCache(this.cache, outputs);
this.cacheSeqLen = totalLen;
log(`Prefill: ${timePrefill.toFixed(0)}ms, new tokens: ${newSeqLen}, total: ${totalLen}`);
// 4. Generate with counter-based mode switching
const textTokens = [];
const audioCodes = [];
let currentLen = totalLen;
let inAudioMode = false;
let modalityLeft = INTERLEAVED_N_TEXT;
let textDone = false;
const startTime = performance.now();
for (let step = 0; step < maxNewTokens; step++) {
modalityLeft--;
if (inAudioMode) {
// Generate audio frame using depthformer
const hiddenData = hiddenStates.data;
const seqLen = hiddenStates.dims[1];
const lastHidden = hiddenData.slice((seqLen - 1) * hiddenSize, seqLen * hiddenSize);
tStep = performance.now();
const frameCodes = await this.sampleAudioCodes(lastHidden, audioTemperature, audioTopK);
timeVocoder += performance.now() - tStep;
// Switch back to text after N audio frames (if text not done)
if (modalityLeft <= 0 && !textDone) {
inAudioMode = false;
modalityLeft = INTERLEAVED_N_TEXT;
}
// Check for end of audio - first codebook == 2048 (matching liquid-audio)
if (frameCodes[0] === END_OF_AUDIO_TOKEN) {
log(`End of audio at step ${step}`);
// Set all codes to 2048 (matching liquid-audio)
for (let i = 0; i < NUM_CODEBOOKS; i++) {
frameCodes[i] = END_OF_AUDIO_TOKEN;
}
inAudioMode = false;
// Don't save this frame, but still feed it back
} else {
// Save valid frame (clamped to 0-2047)
const clampedFrame = frameCodes.map(c => Math.min(c, 2047));
audioCodes.push(clampedFrame);
if (onAudioFrame) {
onAudioFrame(clampedFrame, audioCodes.length);
}
if (audioCodes.length % 50 === 0) {
log(`Generated ${audioCodes.length} audio frames`);
}
}
// Get embeddings for next step (always feed back, even for 2048 frames)
tStep = performance.now();
const feedCodes = frameCodes.map(c => c === END_OF_AUDIO_TOKEN ? END_OF_AUDIO_TOKEN : Math.min(c, 2047));
const audioTokens = feedCodes.map((code, idx) => idx * CODEBOOK_VOCAB + code);
// Get summed embeddings for all 8 codebooks
const summedEmbeds = await this.getAudioEmbedding(audioTokens);
timeAudioEmbed += performance.now() - tStep;
const nextEmbeds = new ort.Tensor('float32', summedEmbeds, [1, 1, hiddenSize]);
currentLen++;
const nextMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]);
tStep = performance.now();
({ logits, hiddenStates, outputs } = await this.runDecoder(nextEmbeds, nextMask, this.cache));
timeAudioDecode += performance.now() - tStep;
this.updateCache(this.cache, outputs);
} else {
// Generate text token
const logitsData = logits.data;
const seqLen = logits.dims[1];
// Get logits for last position - shape is [1, seq_len, vocab_size]
const lastLogits = new Float32Array(this.vocabSize);
const offset = (seqLen - 1) * this.vocabSize;
for (let i = 0; i < this.vocabSize; i++) {
lastLogits[i] = logitsData[offset + i];
}
const token = this.sampleToken(lastLogits, textTemperature);
// Check for end of turn
if (token === this.tokenizer.eos_token_id || token === SPECIAL_TOKENS.IM_END) {
log(`End of turn at step ${step}`);
break;
}
// Check for <|text_end|> token (130)
if (token === SPECIAL_TOKENS.TEXT_END) {
log(`Text end at step ${step}`);
textDone = true;
}
// Switch to audio after N text tokens OR text_end
if (modalityLeft <= 0 || textDone) {
inAudioMode = true;
modalityLeft = INTERLEAVED_N_AUDIO;
}
textTokens.push(token);
if (onToken) {
const decodedText = this.tokenizer.decode(textTokens, { skip_special_tokens: true });
onToken(decodedText, token);
}
// Get embedding for next step
const nextEmbeds = this.getTextEmbeddings([token]);
currentLen++;
const nextMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]);
tStep = performance.now();
({ logits, hiddenStates, outputs } = await this.runDecoder(nextEmbeds, nextMask, this.cache));
timeTextDecode += performance.now() - tStep;
this.updateCache(this.cache, outputs);
}
}
// 5. Feed <|im_end|> token to close assistant turn in cache
const imEndEmbeds = this.getTextEmbeddings([SPECIAL_TOKENS.IM_END]);
currentLen++;
const finalMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]);
({ outputs } = await this.runDecoder(imEndEmbeds, finalMask, this.cache));
this.updateCache(this.cache, outputs);
this.cacheSeqLen = currentLen;
// Decode with skip_special_tokens to clean up special tokens like <|text_end|>
const text = this.tokenizer.decode(textTokens, { skip_special_tokens: true });
// Print timing summary
log(`=== Summary ===`);
log(` Mel: ${timeMel.toFixed(0)}ms, AudioEnc: ${timeAudioEncode.toFixed(0)}ms, Prefill: ${timePrefill.toFixed(0)}ms`);
log(` TextDec: ${timeTextDecode.toFixed(0)}ms (${textTokens.length} tok), AudioDec: ${timeAudioDecode.toFixed(0)}ms`);
log(` Vocoder: ${timeVocoder.toFixed(0)}ms, AudioEmbed: ${timeAudioEmbed.toFixed(0)}ms`);
log(`Output: ${textTokens.length} text tokens, ${audioCodes.length} audio frames`);
log(`Text: "${text}"`);
log(`Cache seq_len: ${this.cacheSeqLen}`);
return { text, audioCodes };
}
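// Multi-turn sketch: the KV cache persists on the instance, so alternating
// audio and text turns continue one conversation until reset() is called:
//   const turn1 = await model.generateInterleaved(audioData, 16000);
//   const turn2 = await model.generateInterleavedFromText('Tell me more.');
//   model.reset(); // start a new conversation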
/**
* Generate interleaved response from text-only input (continuation turn).
* Uses the stateful KV cache from previous turns. Produces both text AND audio.
*
* @param {string} userText - User's text message
* @param {object} options - Generation options
* @returns {object} - { text: string, audioCodes: number[][] }
*/
async generateInterleavedFromText(userText, options = {}) {
const {
maxNewTokens = DEFAULT_MAX_TOKENS_AUDIO,
textTemperature = 1.0,
audioTemperature = 1.0,
audioTopK = 4,
systemPrompt = DEFAULT_SYSTEM_PROMPT_INTERLEAVED,
onToken,
onAudioFrame,
} = options;
// Counter-based mode switching (matching liquid-audio)
const INTERLEAVED_N_TEXT = 6;
const INTERLEAVED_N_AUDIO = 12;
logReset();
log('=== Text-Only Interleaved Generation ===');
log(`Cache state: ${this.cache ? `exists (seq_len=${this.cacheSeqLen})` : 'null (new conversation)'}`);
log(`User text: ${userText}`);
if (!this.embedTokensWeight) {
throw new Error('embed_tokens not loaded - required for text generation');
}
if (!this.vocoderSession) {
throw new Error('Vocoder not loaded - required for interleaved mode');
}
// Timing accumulators
let timePrefill = 0;
let timeTextDecode = 0;
let timeAudioDecode = 0;
let timeVocoder = 0;
let timeAudioEmbed = 0;
let tStep;
const { hiddenSize } = this.embedTokensWeight;
// Build prompt based on whether this is first turn or continuation
let inputEmbeds;
let newSeqLen;
if (this.cache === null) {
// === First turn: full prompt with system message ===
log('First turn - initializing conversation');
this.cache = this.initializeCache();
this.cacheSeqLen = 0;
const prefixText = `<|startoftext|><|im_start|>system\n${systemPrompt}<|im_end|>\n<|im_start|>user\n${userText}<|im_end|>\n<|im_start|>assistant\n`;
const prefixIds = Array.from(this.tokenizer.encode(prefixText, { add_special_tokens: false }));
newSeqLen = prefixIds.length;
inputEmbeds = this.getTextEmbeddings(prefixIds); // already shaped [1, seq_len, hidden_size]
} else {
// === Continuation: user turn only ===
log(`Continuing conversation (cache seq_len=${this.cacheSeqLen})`);
const userTurnText = `<|im_start|>user\n${userText}<|im_end|>\n<|im_start|>assistant\n`;
const userTurnIds = Array.from(this.tokenizer.encode(userTurnText, { add_special_tokens: false }));
newSeqLen = userTurnIds.length;
inputEmbeds = this.getTextEmbeddings(userTurnIds); // already shaped [1, seq_len, hidden_size]
}
// Run prefill with attention mask covering full sequence
const totalLen = this.cacheSeqLen + newSeqLen;
const attentionMask = new ort.Tensor('int64', new BigInt64Array(totalLen).fill(1n), [1, totalLen]);
tStep = performance.now();
let { logits, hiddenStates, outputs } = await this.runDecoder(inputEmbeds, attentionMask, this.cache);
timePrefill = performance.now() - tStep;
this.updateCache(this.cache, outputs);
this.cacheSeqLen = totalLen;
log(`Prefill: ${timePrefill.toFixed(0)}ms, new tokens: ${newSeqLen}, total: ${totalLen}`);
// Generate with counter-based mode switching
const textTokens = [];
const audioCodes = [];
let currentLen = totalLen;
let inAudioMode = false;
let modalityLeft = INTERLEAVED_N_TEXT;
let textDone = false;
for (let step = 0; step < maxNewTokens; step++) {
modalityLeft--;
if (inAudioMode) {
// Generate audio frame using depthformer
const hiddenData = hiddenStates.data;
const seqLen = hiddenStates.dims[1];
const lastHidden = hiddenData.slice((seqLen - 1) * hiddenSize, seqLen * hiddenSize);
tStep = performance.now();
const frameCodes = await this.sampleAudioCodes(lastHidden, audioTemperature, audioTopK);
timeVocoder += performance.now() - tStep;
// Switch back to text after N audio frames (if text not done)
if (modalityLeft <= 0 && !textDone) {
inAudioMode = false;
modalityLeft = INTERLEAVED_N_TEXT;
}
// Check for end of audio
if (frameCodes[0] === END_OF_AUDIO_TOKEN) {
log(`End of audio at step ${step}`);
for (let i = 0; i < NUM_CODEBOOKS; i++) {
frameCodes[i] = END_OF_AUDIO_TOKEN;
}
inAudioMode = false;
} else {
const clampedFrame = frameCodes.map(c => Math.min(c, 2047));
audioCodes.push(clampedFrame);
if (onAudioFrame) {
onAudioFrame(clampedFrame, audioCodes.length);
}
if (audioCodes.length % 50 === 0) {
log(`Generated ${audioCodes.length} audio frames`);
}
}
// Get embeddings for next step
tStep = performance.now();
const feedCodes = frameCodes.map(c => c === END_OF_AUDIO_TOKEN ? END_OF_AUDIO_TOKEN : Math.min(c, 2047));
const audioTokens = feedCodes.map((code, idx) => idx * CODEBOOK_VOCAB + code);
const summedEmbeds = await this.getAudioEmbedding(audioTokens);
timeAudioEmbed += performance.now() - tStep;
const nextEmbeds = new ort.Tensor('float32', summedEmbeds, [1, 1, hiddenSize]);
currentLen++;
const nextMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]);
tStep = performance.now();
({ logits, hiddenStates, outputs } = await this.runDecoder(nextEmbeds, nextMask, this.cache));
timeAudioDecode += performance.now() - tStep;
this.updateCache(this.cache, outputs);
} else {
// Generate text token
const logitsData = logits.data;
const seqLen = logits.dims[1];
const lastLogits = new Float32Array(this.vocabSize);
const offset = (seqLen - 1) * this.vocabSize;
for (let i = 0; i < this.vocabSize; i++) {
lastLogits[i] = logitsData[offset + i];
}
const token = this.sampleToken(lastLogits, textTemperature);
// Check for end of turn
if (token === this.tokenizer.eos_token_id || token === SPECIAL_TOKENS.IM_END) {
log(`End of turn at step ${step}`);
break;
}
// Check for <|text_end|> token
if (token === SPECIAL_TOKENS.TEXT_END) {
log(`Text end at step ${step}`);
textDone = true;
}
// Switch to audio after N text tokens OR text_end
if (modalityLeft <= 0 || textDone) {
inAudioMode = true;
modalityLeft = INTERLEAVED_N_AUDIO;
}
textTokens.push(token);
if (onToken) {
const decodedText = this.tokenizer.decode(textTokens, { skip_special_tokens: true });
onToken(decodedText, token);
}
// Get embedding for next step
const nextEmbeds = this.getTextEmbeddings([token]);
currentLen++;
const nextMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]);
tStep = performance.now();
({ logits, hiddenStates, outputs } = await this.runDecoder(nextEmbeds, nextMask, this.cache));
timeTextDecode += performance.now() - tStep;
this.updateCache(this.cache, outputs);
}
}
// Feed <|im_end|> token to close assistant turn in cache
const imEndEmbeds = this.getTextEmbeddings([SPECIAL_TOKENS.IM_END]);
currentLen++;
const finalMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]);
({ outputs } = await this.runDecoder(imEndEmbeds, finalMask, this.cache));
this.updateCache(this.cache, outputs);
this.cacheSeqLen = currentLen;
const text = this.tokenizer.decode(textTokens, { skip_special_tokens: true });
log(`=== Summary ===`);
log(` Prefill: ${timePrefill.toFixed(0)}ms`);
log(` TextDec: ${timeTextDecode.toFixed(0)}ms (${textTokens.length} tok), AudioDec: ${timeAudioDecode.toFixed(0)}ms`);
log(` Vocoder: ${timeVocoder.toFixed(0)}ms, AudioEmbed: ${timeAudioEmbed.toFixed(0)}ms`);
log(`Output: ${textTokens.length} text tokens, ${audioCodes.length} audio frames`);
log(`Text: "${text}"`);
log(`Cache seq_len: ${this.cacheSeqLen}`);
return { text, audioCodes };
}
/**
* Generate text-only response (for follow-up turns without audio).
* Uses the stateful KV cache from previous interleaved turns.
*
* @param {string} userText - User's text input
* @param {object} options - Generation options
* @returns {object} - { text: string }
*/
async generateTextOnly(userText, options = {}) {
const {
maxNewTokens = 256,
temperature = 0.7,
systemPrompt = 'You are a helpful assistant.',
onToken,
} = options;
logReset();
log('=== Text-Only Generation ===');
log('Cache state:', this.cache ? `exists (seq_len=${this.cacheSeqLen})` : 'null (new conversation)');
log('User text:', userText);
if (!this.embedTokensWeight) {
throw new Error('embed_tokens not loaded');
}
const { hiddenSize } = this.embedTokensWeight;
// Build prompt based on whether we have existing cache
let inputEmbeds;
let newSeqLen;
if (this.cache === null) {
// First turn: include system message
log('First turn - initializing conversation');
this.cache = this.initializeCache();
this.cacheSeqLen = 0;
const promptText = `<|startoftext|><|im_start|>system\n${systemPrompt}<|im_end|>\n<|im_start|>user\n${userText}<|im_end|>\n<|im_start|>assistant\n`;
const promptIds = Array.from(this.tokenizer.encode(promptText, { add_special_tokens: false }));
inputEmbeds = this.getTextEmbeddings(promptIds);
newSeqLen = promptIds.length;
} else {
// Continuation: just user turn
log(`Continuing conversation (cache seq_len=${this.cacheSeqLen})`);
const turnText = `<|im_start|>user\n${userText}<|im_end|>\n<|im_start|>assistant\n`;
const turnIds = Array.from(this.tokenizer.encode(turnText, { add_special_tokens: false }));
inputEmbeds = this.getTextEmbeddings(turnIds);
newSeqLen = turnIds.length;
}
// Run prefill
const totalLen = this.cacheSeqLen + newSeqLen;
const attentionMask = new ort.Tensor('int64', new BigInt64Array(totalLen).fill(1n), [1, totalLen]);
let { logits, outputs } = await this.runDecoder(inputEmbeds, attentionMask, this.cache);
this.updateCache(this.cache, outputs);
this.cacheSeqLen = totalLen;
// Generate tokens
const textTokens = [];
let currentLen = totalLen;
for (let i = 0; i < maxNewTokens; i++) {
const logitsData = logits.data;
const seqLen = logits.dims[1];
const lastLogits = new Float32Array(this.vocabSize);
const offset = (seqLen - 1) * this.vocabSize;
for (let j = 0; j < this.vocabSize; j++) {
lastLogits[j] = logitsData[offset + j];
}
const nextToken = this.sampleToken(lastLogits, temperature);
// Check for stop tokens
if (nextToken === this.tokenizer.eos_token_id || nextToken === SPECIAL_TOKENS.IM_END) {
log('Stop token reached');
break;
}
textTokens.push(nextToken);
if (onToken) {
const text = this.tokenizer.decode(textTokens, { skip_special_tokens: true });
onToken(text, nextToken);
}
// Get embedding for next token
const nextEmbeds = this.getTextEmbeddings([nextToken]);
currentLen++;
const nextMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]);
({ logits, outputs } = await this.runDecoder(nextEmbeds, nextMask, this.cache));
this.updateCache(this.cache, outputs);
}
// Feed <|im_end|> to close turn
const imEndEmbeds = this.getTextEmbeddings([SPECIAL_TOKENS.IM_END]);
currentLen++;
const finalMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]);
({ outputs } = await this.runDecoder(imEndEmbeds, finalMask, this.cache));
this.updateCache(this.cache, outputs);
this.cacheSeqLen = currentLen;
const text = this.tokenizer.decode(textTokens, { skip_special_tokens: true });
log(`Generated ${textTokens.length} tokens: "${text}"`);
log(`Cache seq_len: ${this.cacheSeqLen}`);
return { text };
}
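/*
 * Minimal multi-turn sketch (hypothetical call site; assumes a loaded model
 * instance). The second call reuses the KV cache kept by the first:
 *
 *   const first = await model.generateTextOnly('What does the ISTFT step do?');
 *   const second = await model.generateTextOnly('Shorter, please.', {
 *     onToken: (partial) => console.log(partial), // streaming callback
 *   });
 */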
/**
* Decode audio codes to waveform using audio detokenizer + ISTFT
* @param {number[][]} audioCodes - Array of [8] codebook values per frame
* @returns {Float32Array} - Audio waveform samples, peak-normalized to 0.9 (within [-1, 1])
*/
async decodeAudioCodes(audioCodes) {
if (!this.audioDetokenizerSession) {
throw new Error('Audio detokenizer not loaded');
}
if (audioCodes.length < 2) {
console.warn('Not enough audio codes to decode');
return new Float32Array(0);
}
const decodeStart = performance.now();
log(`Decoding ${audioCodes.length} audio frames...`);
// ISTFT parameters (fixed for this model)
const nFft = 1280;
const hopLength = 320;
const winLength = 1280;
const nFftBins = nFft / 2 + 1;
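// hop = 320 samples at 24 kHz, so each STFT frame advances ~13.3 ms; after the
// 'same'-padding trim, the output length is exactly (STFT frames) * hopLength samples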
// Stack codes: [T, 8] -> [8, T] and add batch -> [1, 8, T]
const T = audioCodes.length;
const codesTransposed = new BigInt64Array(8 * T);
for (let t = 0; t < T; t++) {
for (let cb = 0; cb < 8; cb++) {
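// Clamp to the last valid codebook entry (2047); 2048 is the end-of-audio marker, not a code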
codesTransposed[cb * T + t] = BigInt(Math.min(audioCodes[t][cb], 2047));
}
}
// Run detokenizer: [1, 8, T] -> [1, T, 1282]
const codesTensor = new ort.Tensor('int64', codesTransposed, [1, 8, T]);
const detokStart = performance.now();
const detokOutputs = await this.audioDetokenizerSession.run({ audio_codes: codesTensor });
const stftFeatures = detokOutputs.stft_features;
log(`Detokenizer: ${(performance.now() - detokStart).toFixed(0)}ms, STFT frames: ${stftFeatures.dims[1]}`);
// Raw data has shape [1, T, 1282]; with batch size 1 the flat buffer indexes directly as [T, 1282]
const stftData = stftFeatures.data;
const actualT = stftFeatures.dims[1];
// Convert to complex STFT: per-frame layout is [log|S| (641 bins) | phase (641 bins)],
// reconstructed as S = exp(log|S|) * (cos(angle) + i*sin(angle))
const complexStft = new Array(nFftBins);
for (let f = 0; f < nFftBins; f++) {
complexStft[f] = new Array(actualT);
for (let t = 0; t < actualT; t++) {
const logMag = stftData[t * 1282 + f];
const angle = stftData[t * 1282 + nFftBins + f];
const mag = Math.exp(logMag);
// Store as [real, imag]
complexStft[f][t] = [mag * Math.cos(angle), mag * Math.sin(angle)];
}
}
// ISTFT with 'same' padding
const istftStart = performance.now();
const waveform = this.istftSamePadding(complexStft, nFft, hopLength, winLength, actualT);
log(`ISTFT: ${(performance.now() - istftStart).toFixed(0)}ms`);
// Find max/min without spread operator (avoid stack overflow on large arrays)
let waveMax = -Infinity, waveMin = Infinity;
for (let i = 0; i < waveform.length; i++) {
if (waveform[i] > waveMax) waveMax = waveform[i];
if (waveform[i] < waveMin) waveMin = waveform[i];
}
log('ISTFT output - length:', waveform.length, 'max:', waveMax.toFixed(4), 'min:', waveMin.toFixed(4));
// Check for invalid values
if (!isFinite(waveMax) || !isFinite(waveMin)) {
console.error('ISTFT produced invalid values (NaN/Inf)');
return new Float32Array(0);
}
// Normalize to [-1, 1]
let maxVal = Math.max(Math.abs(waveMax), Math.abs(waveMin));
if (maxVal > 0) {
for (let i = 0; i < waveform.length; i++) {
waveform[i] = (waveform[i] / maxVal) * 0.9;
}
} else {
console.warn('ISTFT produced all-zero waveform');
}
log(`Decoded audio: ${waveform.length} samples (${(waveform.length / 24000).toFixed(2)}s)`);
return waveform;
}
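/*
 * Playback sketch (hypothetical call site; standard Web Audio API, 24 kHz mono):
 *
 *   const waveform = await model.decodeAudioCodes(audioCodes);
 *   const ctx = new AudioContext({ sampleRate: 24000 });
 *   const buffer = ctx.createBuffer(1, waveform.length, 24000);
 *   buffer.copyToChannel(waveform, 0);
 *   const source = ctx.createBufferSource();
 *   source.buffer = buffer;
 *   source.connect(ctx.destination);
 *   source.start();
 */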
/**
* ISTFT with 'same' padding matching liquid_audio.
* Uses a Bluestein FFT so the O(N log N) IRFFT works at the non-power-of-2
* transform size n_fft = 1280.
*
* Matches Python: np.fft.irfft(spec, n_fft, axis=0, norm="backward")
*/
istftSamePadding(complexStft, nFft, hopLength, winLength, T) {
const N = complexStft.length; // nFftBins = nFft/2 + 1 = 641
const pad = Math.floor((winLength - hopLength) / 2);
// Generate symmetric Hann window: w[i] = 0.5 * (1 - cos(2*PI*i / (winLength - 1)))
const window = new Float32Array(winLength);
for (let i = 0; i < winLength; i++) {
window[i] = 0.5 * (1 - Math.cos(2 * Math.PI * i / (winLength - 1)));
}
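// hopLength = winLength / 4 (75% overlap), so the squared-Hann envelope used for
// normalization in the overlap-add step is near-constant away from the edges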
// Initialize Bluestein FFT for size nFft (cached for reuse)
if (!this._bluesteinCache || this._bluesteinCache.n !== nFft) {
this._bluesteinCache = this._initBluestein(nFft);
}
const bluestein = this._bluesteinCache;
// Pre-allocate buffers for IFFT
const fullRe = new Float32Array(nFft);
const fullIm = new Float32Array(nFft);
// Process all frames
const ifftFrames = new Array(T);
for (let t = 0; t < T; t++) {
// Build full spectrum from one-sided (conjugate symmetry)
fullRe.fill(0);
fullIm.fill(0);
// Copy positive frequencies
for (let k = 0; k < N; k++) {
fullRe[k] = complexStft[k][t][0];
fullIm[k] = complexStft[k][t][1];
}
// Mirror negative frequencies (conjugate symmetry for real signal)
for (let k = 1; k < N - 1; k++) {
fullRe[nFft - k] = fullRe[k];
fullIm[nFft - k] = -fullIm[k];
}
// IFFT using Bluestein: IFFT(X) = conj(FFT(conj(X))) / N
// Conjugate input
for (let i = 0; i < nFft; i++) fullIm[i] = -fullIm[i];
// FFT
this._bluesteinFFT(fullRe, fullIm, bluestein);
// Conjugate and scale
for (let i = 0; i < nFft; i++) {
fullRe[i] /= nFft;
fullIm[i] = -fullIm[i] / nFft;
}
// Apply window (take first winLength samples)
const windowedFrame = new Float32Array(winLength);
for (let n = 0; n < winLength; n++) {
windowedFrame[n] = fullRe[n] * window[n];
}
ifftFrames[t] = windowedFrame;
// Sanity-check the first frame for NaN/Inf from the IFFT
if (t === 0) {
for (let n = 0; n < winLength; n++) {
if (!isFinite(windowedFrame[n])) {
console.error('IRFFT frame 0 contains NaN/Inf values!');
break;
}
}
}
}
// Overlap-add synthesis: y[n] = sum_t frame_t[n - t*hop], then normalized below
// by the squared-window envelope sum_t w^2[n - t*hop]
const outputSize = (T - 1) * hopLength + winLength;
const audio = new Float32Array(outputSize);
const windowEnvelope = new Float32Array(outputSize);
const windowSq = new Float32Array(winLength);
for (let i = 0; i < winLength; i++) {
windowSq[i] = window[i] * window[i];
}
for (let t = 0; t < T; t++) {
const start = t * hopLength;
for (let n = 0; n < winLength; n++) {
audio[start + n] += ifftFrames[t][n];
windowEnvelope[start + n] += windowSq[n];
}
}
// Normalize and trim padding
const trimmedLength = outputSize - 2 * pad;
const trimmed = new Float32Array(trimmedLength);
for (let i = 0; i < trimmedLength; i++) {
const srcIdx = i + pad;
if (windowEnvelope[srcIdx] > 1e-8) {
trimmed[i] = audio[srcIdx] / windowEnvelope[srcIdx];
} else {
trimmed[i] = audio[srcIdx];
}
}
return trimmed;
}
/**
* Initialize Bluestein FFT for size n (any size, not just power of 2)
*/
_initBluestein(n) {
// Bluestein's algorithm: rewrites an arbitrary-size DFT as a convolution via the
// identity nk = (n^2 + k^2 - (k - n)^2) / 2, which a power-of-2 FFT can evaluate.
// FFT size for the convolution: next power of 2 >= 2n - 1
let m = 1;
while (m < 2 * n - 1) m <<= 1;
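// e.g. n = 1280 -> 2n - 1 = 2559 -> m = 4096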
// Chirp sequence: W_n^(k^2/2) = exp(-πi * k² / n)
const chirpRe = new Float32Array(n);
const chirpIm = new Float32Array(n);
for (let k = 0; k < n; k++) {
const angle = Math.PI * k * k / n;
chirpRe[k] = Math.cos(angle);
chirpIm[k] = -Math.sin(angle); // exp(-i*angle)
}
// Precompute FFT of chirp filter (b sequence)
// b[k] = conj(chirp[k]) for k in [0, n-1]
// b[m-k] = conj(chirp[k]) for k in [1, n-1]
// conj(chirp[k]) = chirpRe[k] - i*chirpIm[k]
const bRe = new Float32Array(m);
const bIm = new Float32Array(m);
bRe[0] = chirpRe[0];
bIm[0] = -chirpIm[0]; // conjugate
for (let k = 1; k < n; k++) {
bRe[k] = chirpRe[k];
bIm[k] = -chirpIm[k]; // conjugate
bRe[m - k] = chirpRe[k];
bIm[m - k] = -chirpIm[k]; // conjugate
}
// FFT of b (in-place)
this._fftRadix2InPlace(bRe, bIm, m, false);
// Precompute twiddle factors for radix-2 FFT of size m
const twiddleRe = new Float32Array(m / 2);
const twiddleIm = new Float32Array(m / 2);
for (let i = 0; i < m / 2; i++) {
const angle = -2 * Math.PI * i / m;
twiddleRe[i] = Math.cos(angle);
twiddleIm[i] = Math.sin(angle);
}
return { n, m, chirpRe, chirpIm, bRe, bIm, twiddleRe, twiddleIm };
}
/**
* Bluestein FFT for any size
*/
_bluesteinFFT(re, im, cache) {
const { n, m, chirpRe, chirpIm, bRe, bIm, twiddleRe, twiddleIm } = cache;
// a[k] = x[k] * chirp[k] for k in [0, n-1], zero-padded to m
// chirp[k] = chirpRe[k] + i*chirpIm[k]
// (re + i*im) * (chirpRe + i*chirpIm) = (re*chirpRe - im*chirpIm) + i*(im*chirpRe + re*chirpIm)
const aRe = new Float32Array(m);
const aIm = new Float32Array(m);
for (let k = 0; k < n; k++) {
aRe[k] = re[k] * chirpRe[k] - im[k] * chirpIm[k];
aIm[k] = im[k] * chirpRe[k] + re[k] * chirpIm[k];
}
// FFT of a
this._fftRadix2(aRe, aIm, twiddleRe, twiddleIm);
// Pointwise multiply: a = a * b
for (let k = 0; k < m; k++) {
const tmpRe = aRe[k] * bRe[k] - aIm[k] * bIm[k];
const tmpIm = aRe[k] * bIm[k] + aIm[k] * bRe[k];
aRe[k] = tmpRe;
aIm[k] = tmpIm;
}
// IFFT of a (using FFT with conjugate trick)
for (let k = 0; k < m; k++) aIm[k] = -aIm[k];
this._fftRadix2(aRe, aIm, twiddleRe, twiddleIm);
for (let k = 0; k < m; k++) {
aRe[k] /= m;
aIm[k] = -aIm[k] / m;
}
// X[k] = chirp[k] * y[k]
// Same multiplication as for a: (aRe + i*aIm) * (chirpRe + i*chirpIm)
for (let k = 0; k < n; k++) {
re[k] = aRe[k] * chirpRe[k] - aIm[k] * chirpIm[k];
im[k] = aIm[k] * chirpRe[k] + aRe[k] * chirpIm[k];
}
}
/**
* In-place radix-2 FFT (Cooley-Tukey) with precomputed twiddles.
* The input length must be a power of two.
*/
_fftRadix2(re, im, twiddleRe, twiddleIm) {
const n = re.length;
// Bit-reversal permutation
for (let i = 0, j = 0; i < n; i++) {
if (i < j) {
let tmp = re[i]; re[i] = re[j]; re[j] = tmp;
tmp = im[i]; im[i] = im[j]; im[j] = tmp;
}
let k = n >> 1;
while (k > 0 && k <= j) { j -= k; k >>= 1; }
j += k;
}
// Cooley-Tukey butterflies
for (let len = 2; len <= n; len <<= 1) {
const halfLen = len >> 1;
const step = n / len;
for (let i = 0; i < n; i += len) {
for (let j = 0; j < halfLen; j++) {
const twIdx = j * step;
const wRe = twiddleRe[twIdx];
const wIm = twiddleIm[twIdx];
const u = i + j;
const v = u + halfLen;
const tRe = wRe * re[v] - wIm * im[v];
const tIm = wRe * im[v] + wIm * re[v];
re[v] = re[u] - tRe;
im[v] = im[u] - tIm;
re[u] += tRe;
im[u] += tIm;
}
}
}
}
/**
* In-place radix-2 FFT without precomputed twiddles (used once at initialization).
* n must be a power of two.
*/
_fftRadix2InPlace(re, im, n, inverse = false) {
// Bit-reversal
for (let i = 0, j = 0; i < n; i++) {
if (i < j) {
let tmp = re[i]; re[i] = re[j]; re[j] = tmp;
tmp = im[i]; im[i] = im[j]; im[j] = tmp;
}
let k = n >> 1;
while (k > 0 && k <= j) { j -= k; k >>= 1; }
j += k;
}
// Butterflies
const sign = inverse ? 1 : -1;
for (let len = 2; len <= n; len <<= 1) {
const halfLen = len >> 1;
const angle = sign * 2 * Math.PI / len;
const wRe = Math.cos(angle);
const wIm = Math.sin(angle);
for (let i = 0; i < n; i += len) {
let curRe = 1, curIm = 0;
for (let j = 0; j < halfLen; j++) {
const u = i + j;
const v = u + halfLen;
const tRe = curRe * re[v] - curIm * im[v];
const tIm = curRe * im[v] + curIm * re[v];
re[v] = re[u] - tRe;
im[v] = im[u] - tIm;
re[u] += tRe;
im[u] += tIm;
const newRe = curRe * wRe - curIm * wIm;
curIm = curRe * wIm + curIm * wRe;
curRe = newRe;
}
}
}
if (inverse) {
for (let i = 0; i < n; i++) {
re[i] /= n;
im[i] /= n;
}
}
}
/**
* Free resources: release ONNX Runtime sessions (frees their WASM-side memory;
* nulling the JS reference alone does not) and drop weight references.
*/
async dispose() {
for (const session of [
this.decoderSession,
this.audioEncoderSession,
this.audioEmbeddingSession,
this.audioDetokenizerSession,
this.vocoderSession,
]) {
if (session) {
try { await session.release(); } catch (e) { /* already released */ }
}
}
this.tokenizer = null;
this.decoderSession = null;
this.audioEncoderSession = null;
this.audioEmbeddingSession = null;
this.audioEmbeddingWeight = null;
this.audioDetokenizerSession = null;
this.vocoderSession = null;
this.embedTokensWeight = null;
}
}
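/*
 * Teardown sketch (hypothetical call site): release the ONNX Runtime sessions
 * once the app no longer needs the model.
 *
 *   await model.dispose();
 */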
// Re-export audio utilities
export { loadAudioFile };
export default AudioModel;