ykhrustalev committed on
Commit
303ba09
·
unverified ·
1 Parent(s): baf104b

correct the render

Browse files
Files changed (3) hide show
  1. .gitignore +1 -0
  2. audio-model.js +110 -0
  3. main.js +74 -29
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ node_modules/
audio-model.js CHANGED
@@ -1289,6 +1289,7 @@ export class AudioModel {
1289
 
1290
  logReset();
1291
  log('=== Interleaved Generation ===');
 
1292
  log('Audio samples:', audioData.length, 'Sample rate:', sampleRate);
1293
 
1294
  if (!this.audioEncoderSession) {
@@ -1544,6 +1545,115 @@ export class AudioModel {
1544
  return { text, audioCodes };
1545
  }
1546
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1547
  /**
1548
  * Decode audio codes to waveform using audio detokenizer + ISTFT
1549
  * @param {number[][]} audioCodes - Array of [8] codebook values per frame
 
1289
 
1290
  logReset();
1291
  log('=== Interleaved Generation ===');
1292
+ log('Cache state:', this.cache ? `exists (seq_len=${this.cacheSeqLen})` : 'null (new conversation)');
1293
  log('Audio samples:', audioData.length, 'Sample rate:', sampleRate);
1294
 
1295
  if (!this.audioEncoderSession) {
 
1545
  return { text, audioCodes };
1546
  }
1547
 
1548
+ /**
1549
+ * Generate text-only response (for follow-up turns without audio).
1550
+ * Uses the stateful KV cache from previous interleaved turns.
1551
+ *
1552
+ * @param {string} userText - User's text input
1553
+ * @param {object} options - Generation options
1554
+ * @returns {object} - { text: string }
1555
+ */
1556
+ async generateTextOnly(userText, options = {}) {
1557
+ const {
1558
+ maxNewTokens = 256,
1559
+ temperature = 0.7,
1560
+ systemPrompt = 'You are a helpful assistant.',
1561
+ onToken,
1562
+ } = options;
1563
+
1564
+ logReset();
1565
+ log('=== Text-Only Generation ===');
1566
+ log('Cache state:', this.cache ? `exists (seq_len=${this.cacheSeqLen})` : 'null (new conversation)');
1567
+ log('User text:', userText);
1568
+
1569
+ if (!this.embedTokensWeight) {
1570
+ throw new Error('embed_tokens not loaded');
1571
+ }
1572
+
1573
+ const { hiddenSize } = this.embedTokensWeight;
1574
+
1575
+ // Build prompt based on whether we have existing cache
1576
+ let inputEmbeds;
1577
+ let newSeqLen;
1578
+
1579
+ if (this.cache === null) {
1580
+ // First turn: include system message
1581
+ log('First turn - initializing conversation');
1582
+ this.cache = this.initializeCache();
1583
+ this.cacheSeqLen = 0;
1584
+
1585
+ const promptText = `<|startoftext|><|im_start|>system\n${systemPrompt}<|im_end|>\n<|im_start|>user\n${userText}<|im_end|>\n<|im_start|>assistant\n`;
1586
+ const promptIds = Array.from(this.tokenizer.encode(promptText, { add_special_tokens: false }));
1587
+ inputEmbeds = this.getTextEmbeddings(promptIds);
1588
+ newSeqLen = promptIds.length;
1589
+ } else {
1590
+ // Continuation: just user turn
1591
+ log(`Continuing conversation (cache seq_len=${this.cacheSeqLen})`);
1592
+
1593
+ const turnText = `<|im_start|>user\n${userText}<|im_end|>\n<|im_start|>assistant\n`;
1594
+ const turnIds = Array.from(this.tokenizer.encode(turnText, { add_special_tokens: false }));
1595
+ inputEmbeds = this.getTextEmbeddings(turnIds);
1596
+ newSeqLen = turnIds.length;
1597
+ }
1598
+
1599
+ // Run prefill
1600
+ const totalLen = this.cacheSeqLen + newSeqLen;
1601
+ const attentionMask = new ort.Tensor('int64', new BigInt64Array(totalLen).fill(1n), [1, totalLen]);
1602
+
1603
+ let { logits, outputs } = await this.runDecoder(inputEmbeds, attentionMask, this.cache);
1604
+ this.updateCache(this.cache, outputs);
1605
+ this.cacheSeqLen = totalLen;
1606
+
1607
+ // Generate tokens
1608
+ const textTokens = [];
1609
+ let currentLen = totalLen;
1610
+
1611
+ for (let i = 0; i < maxNewTokens; i++) {
1612
+ const logitsData = logits.data;
1613
+ const seqLen = logits.dims[1];
1614
+ const lastLogits = new Float32Array(this.vocabSize);
1615
+ const offset = (seqLen - 1) * this.vocabSize;
1616
+ for (let j = 0; j < this.vocabSize; j++) {
1617
+ lastLogits[j] = logitsData[offset + j];
1618
+ }
1619
+ const nextToken = this.sampleToken(lastLogits, temperature);
1620
+
1621
+ // Check for stop tokens
1622
+ if (nextToken === this.tokenizer.eos_token_id || nextToken === SPECIAL_TOKENS.IM_END) {
1623
+ log('Stop token reached');
1624
+ break;
1625
+ }
1626
+
1627
+ textTokens.push(nextToken);
1628
+
1629
+ if (onToken) {
1630
+ const text = this.tokenizer.decode(textTokens, { skip_special_tokens: true });
1631
+ onToken(text, nextToken);
1632
+ }
1633
+
1634
+ // Get embedding for next token
1635
+ const nextEmbeds = this.getTextEmbeddings([nextToken]);
1636
+ currentLen++;
1637
+ const nextMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]);
1638
+ ({ logits, outputs } = await this.runDecoder(nextEmbeds, nextMask, this.cache));
1639
+ this.updateCache(this.cache, outputs);
1640
+ }
1641
+
1642
+ // Feed <|im_end|> to close turn
1643
+ const imEndEmbeds = this.getTextEmbeddings([SPECIAL_TOKENS.IM_END]);
1644
+ currentLen++;
1645
+ const finalMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]);
1646
+ ({ outputs } = await this.runDecoder(imEndEmbeds, finalMask, this.cache));
1647
+ this.updateCache(this.cache, outputs);
1648
+ this.cacheSeqLen = currentLen;
1649
+
1650
+ const text = this.tokenizer.decode(textTokens, { skip_special_tokens: true });
1651
+ log(`Generated ${textTokens.length} tokens: "${text}"`);
1652
+ log(`Cache seq_len: ${this.cacheSeqLen}`);
1653
+
1654
+ return { text };
1655
+ }
1656
+
1657
  /**
1658
  * Decode audio codes to waveform using audio detokenizer + ISTFT
1659
  * @param {number[][]} audioCodes - Array of [8] codebook values per frame
main.js CHANGED
@@ -62,6 +62,24 @@ let audioChunks = [];
62
  // ============================================================================
63
 
64
  function createWavBlob(samples, sampleRate) {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  const numChannels = 1;
66
  const bitsPerSample = 16;
67
  const bytesPerSample = bitsPerSample / 8;
@@ -104,6 +122,16 @@ function createWavBlob(samples, sampleRate) {
104
  return new Blob([buffer], { type: 'audio/wav' });
105
  }
106
 
 
 
 
 
 
 
 
 
 
 
107
  // ============================================================================
108
  // UI Helpers
109
  // ============================================================================
@@ -441,22 +469,19 @@ async function generate(userMessage) {
441
  if (waveform.length > 0) {
442
  generatedText = result.textOutput || `Generated ${result.audioCodes.length} audio frames (${(waveform.length / 24000).toFixed(2)}s)`;
443
 
444
- // Create audio player
445
- const audioMsgEl = document.createElement('div');
446
- audioMsgEl.className = 'message assistant';
447
  const wavBlob = createWavBlob(waveform, 24000);
448
- console.log('TTS WAV blob created:', wavBlob.size, 'bytes');
449
  const audioUrl = URL.createObjectURL(wavBlob);
450
 
451
- const audioEl = document.createElement('audio');
452
- audioEl.controls = true;
453
- const sourceEl = document.createElement('source');
454
- sourceEl.src = audioUrl;
455
- sourceEl.type = 'audio/wav';
456
- audioEl.appendChild(sourceEl);
457
-
458
- audioMsgEl.appendChild(audioEl);
459
- chatContainer.appendChild(audioMsgEl);
460
  chatContainer.scrollTop = chatContainer.scrollHeight;
461
  } else {
462
  generatedText = '[Audio decoding failed - no waveform generated]';
@@ -503,34 +528,54 @@ async function generate(userMessage) {
503
  generatedText = `Generated ${result.audioCodes.length} audio frames`;
504
  }
505
 
506
- // Create audio player
507
- const audioMsgEl = document.createElement('div');
508
- audioMsgEl.className = 'message assistant';
509
  const wavBlob = createWavBlob(waveform, 24000);
510
- console.log('WAV blob created:', wavBlob.size, 'bytes');
511
  const audioUrl = URL.createObjectURL(wavBlob);
512
 
513
- const audioEl = document.createElement('audio');
514
- audioEl.controls = true;
515
- const sourceEl = document.createElement('source');
516
- sourceEl.src = audioUrl;
517
- sourceEl.type = 'audio/wav';
518
- audioEl.appendChild(sourceEl);
519
-
520
- audioMsgEl.appendChild(audioEl);
521
- chatContainer.appendChild(audioMsgEl);
522
  chatContainer.scrollTop = chatContainer.scrollHeight;
523
  } else {
524
  console.warn('Waveform decoding returned empty result');
525
  }
526
  }
527
 
528
- } else {
 
529
  showSpinner('Generating response...');
530
- generatedText = await audioModel.generate(messages, {
531
  maxNewTokens: 256,
532
- onToken: onTokenCallback,
 
 
 
 
 
 
 
533
  });
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
534
  }
535
 
536
  generatedText = generatedText.replace(/<\|im_end\|>$/g, '').trim();
 
62
  // ============================================================================
63
 
64
  function createWavBlob(samples, sampleRate) {
65
+ // Debug: check waveform statistics
66
+ let min = Infinity, max = -Infinity, sum = 0, nonZero = 0;
67
+ for (let i = 0; i < samples.length; i++) {
68
+ const v = samples[i];
69
+ if (v < min) min = v;
70
+ if (v > max) max = v;
71
+ sum += Math.abs(v);
72
+ if (Math.abs(v) > 0.001) nonZero++;
73
+ }
74
+ console.log('WAV input stats:', {
75
+ length: samples.length,
76
+ min: min.toFixed(6),
77
+ max: max.toFixed(6),
78
+ avgAbs: (sum / samples.length).toFixed(6),
79
+ nonZeroSamples: nonZero,
80
+ percentNonZero: ((nonZero / samples.length) * 100).toFixed(1) + '%'
81
+ });
82
+
83
  const numChannels = 1;
84
  const bitsPerSample = 16;
85
  const bytesPerSample = bitsPerSample / 8;
 
122
  return new Blob([buffer], { type: 'audio/wav' });
123
  }
124
 
125
// Diagnostic helper: synthesize a sine tone and wrap it in a WAV blob, to
// verify the WAV-encoding path independently of the model's output.
function createTestToneBlob(durationSec = 1, frequency = 440, sampleRate = 24000) {
  const totalSamples = Math.floor(durationSec * sampleRate);
  const tone = Float32Array.from(
    { length: totalSamples },
    (_, n) => 0.5 * Math.sin(2 * Math.PI * frequency * n / sampleRate),
  );
  return createWavBlob(tone, sampleRate);
}
134
+
135
  // ============================================================================
136
  // UI Helpers
137
  // ============================================================================
 
469
  if (waveform.length > 0) {
470
  generatedText = result.textOutput || `Generated ${result.audioCodes.length} audio frames (${(waveform.length / 24000).toFixed(2)}s)`;
471
 
472
+ // Create audio player inline with the message
 
 
473
  const wavBlob = createWavBlob(waveform, 24000);
474
+ console.log('TTS WAV blob created:', wavBlob.size, 'bytes, duration:', (waveform.length / 24000).toFixed(2), 's');
475
  const audioUrl = URL.createObjectURL(wavBlob);
476
 
477
+ // Add audio element to the existing message
478
+ const audioContainer = document.createElement('div');
479
+ audioContainer.style.marginTop = '0.75rem';
480
+ audioContainer.innerHTML = `
481
+ <audio controls preload="auto" src="${audioUrl}" style="width:100%;max-width:360px;display:block;"></audio>
482
+ <a href="${audioUrl}" download="generated_audio.wav" style="display:block;font-size:0.7rem;margin-top:0.25rem;color:#666;">Download WAV (${(waveform.length / 24000).toFixed(1)}s)</a>
483
+ `;
484
+ msgEl.appendChild(audioContainer);
 
485
  chatContainer.scrollTop = chatContainer.scrollHeight;
486
  } else {
487
  generatedText = '[Audio decoding failed - no waveform generated]';
 
528
  generatedText = `Generated ${result.audioCodes.length} audio frames`;
529
  }
530
 
531
+ // Create audio player inline with the message
 
 
532
  const wavBlob = createWavBlob(waveform, 24000);
533
+ console.log('WAV blob created:', wavBlob.size, 'bytes, duration:', (waveform.length / 24000).toFixed(2), 's');
534
  const audioUrl = URL.createObjectURL(wavBlob);
535
 
536
+ // Add audio element to the existing message
537
+ const audioContainer = document.createElement('div');
538
+ audioContainer.style.marginTop = '0.75rem';
539
+ audioContainer.innerHTML = `
540
+ <audio controls preload="auto" src="${audioUrl}" style="width:100%;max-width:360px;display:block;"></audio>
541
+ <a href="${audioUrl}" download="generated_audio.wav" style="display:block;font-size:0.7rem;margin-top:0.25rem;color:#666;">Download WAV (${(waveform.length / 24000).toFixed(1)}s)</a>
542
+ `;
543
+ msgEl.appendChild(audioContainer);
 
544
  chatContainer.scrollTop = chatContainer.scrollHeight;
545
  } else {
546
  console.warn('Waveform decoding returned empty result');
547
  }
548
  }
549
 
550
+ } else if (currentMode === 'interleaved' && userMessage) {
551
+ // Text-only follow-up in interleaved mode
552
  showSpinner('Generating response...');
553
+ const result = await audioModel.generateTextOnly(userMessage, {
554
  maxNewTokens: 256,
555
+ onToken: (text, tokenId) => {
556
+ generatedText = text;
557
+ tokenCount = text.length;
558
+ textEl.textContent = text;
559
+ chatContainer.scrollTop = chatContainer.scrollHeight;
560
+ const elapsed = ((performance.now() - startTime) / 1000).toFixed(1);
561
+ updateSpinner('Generating...', `${tokenCount} chars · ${elapsed}s`);
562
+ },
563
  });
564
+ generatedText = result.text || '';
565
+
566
+ } else if (userMessage) {
567
+ // Fallback text-only generation
568
+ showSpinner('Generating response...');
569
+ const result = await audioModel.generateTextOnly(userMessage, {
570
+ maxNewTokens: 256,
571
+ onToken: (text, tokenId) => {
572
+ generatedText = text;
573
+ tokenCount = text.length;
574
+ textEl.textContent = text;
575
+ chatContainer.scrollTop = chatContainer.scrollHeight;
576
+ },
577
+ });
578
+ generatedText = result.text || '';
579
  }
580
 
581
  generatedText = generatedText.replace(/<\|im_end\|>$/g, '').trim();