KevinAHM committed on
Commit
ea8d726
·
verified ·
1 Parent(s): 59b2bdd

Fix deterministic audio preprocessing on Windows

Browse files
Files changed (1) hide show
  1. index.html +156 -22
index.html CHANGED
@@ -367,6 +367,10 @@ function handleFile(file) {
367
 
368
  function updateBtn() { processBtn.disabled = !(fileBuffer && session); }
369
 
 
 
 
 
370
  // -- Detect backend & load model --
371
  async function init() {
372
  // Detect WebGPU and patch device creation to raise storage buffer limits
@@ -411,7 +415,12 @@ async function init() {
411
 
412
  const ep = backend === 'webgpu' ? 'webgpu' : 'wasm';
413
  const opts = { executionProviders: [ep] };
414
- if (ep === 'wasm') {
 
 
 
 
 
415
  opts.executionProviders = [{ name: 'wasm', options: { numThreads: navigator.hardwareConcurrency || 4 } }];
416
  }
417
 
@@ -445,25 +454,148 @@ async function init() {
445
  updateBtn();
446
  }
447
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
448
  // -- Decode audio to mono 16kHz Float32 --
449
  async function decodeToMono16k(arrayBuffer) {
450
- const audioCtx = new OfflineAudioContext(1, 1, INPUT_SR);
451
- const decoded = await audioCtx.decodeAudioData(arrayBuffer.slice(0));
452
- const origSr = decoded.sampleRate;
453
- const origData = decoded.getChannelData(0);
454
-
455
- // Resample to 16kHz
456
- const ratio = INPUT_SR / origSr;
457
- const outLen = Math.round(origData.length * ratio);
458
- const ctx2 = new OfflineAudioContext(1, outLen, INPUT_SR);
459
- const src = ctx2.createBufferSource();
460
- const buf = ctx2.createBuffer(1, origData.length, origSr);
461
- buf.getChannelData(0).set(origData);
462
- src.buffer = buf;
463
- src.connect(ctx2.destination);
464
- src.start();
465
- const rendered = await ctx2.startRendering();
466
- return rendered.getChannelData(0);
467
  }
468
 
469
  // -- Process --
@@ -480,7 +612,7 @@ processBtn.addEventListener('click', async () => {
480
  const totalSamples = audio16k.length;
481
  const audioDuration = totalSamples / INPUT_SR;
482
 
483
- // Chunk sizing: CPU=1000ms, GPU=30s
484
  const chunkMs = backend === 'webgpu' ? 5000 : 1000;
485
  const chunkHops = Math.max(1, Math.floor(chunkMs / 1000 * INPUT_SR / HOP));
486
  const chunkSamples = chunkHops * HOP;
@@ -522,8 +654,10 @@ processBtn.addEventListener('click', async () => {
522
  state_in: stateTensor,
523
  });
524
 
525
- outputs.push(new Float32Array(result.audio_out.data));
526
- state = new Float32Array(result.state_out.data);
 
 
527
 
528
  chunkIdx++;
529
  const pct = Math.round(chunkIdx / numChunks * 100);
@@ -595,7 +729,7 @@ function encodeWav(samples, sr) {
595
 
596
  // Load ORT and init
597
  const script = document.createElement('script');
598
- script.src = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.22.0/dist/ort.min.js';
599
  script.crossOrigin = 'anonymous';
600
  script.onload = () => {
601
  ort.env.wasm.numThreads = navigator.hardwareConcurrency || 4;
 
367
 
368
// Enable the Process button only when both an input file and a model session exist.
function updateBtn() {
  const ready = Boolean(fileBuffer) && Boolean(session);
  processBtn.disabled = !ready;
}
369
 
370
// Read a tensor's contents: GPU-resident ORT tensors expose an async
// getData() download, CPU tensors expose a synchronous `.data` view.
async function readTensorData(tensor) {
  if (typeof tensor.getData === 'function') {
    return tensor.getData();
  }
  return tensor.data;
}
373
+
374
  // -- Detect backend & load model --
375
  async function init() {
376
  // Detect WebGPU and patch device creation to raise storage buffer limits
 
415
 
416
  const ep = backend === 'webgpu' ? 'webgpu' : 'wasm';
417
  const opts = { executionProviders: [ep] };
418
+ if (ep === 'webgpu') {
419
+ opts.preferredOutputLocation = {
420
+ audio_out: 'cpu',
421
+ state_out: 'cpu',
422
+ };
423
+ } else {
424
  opts.executionProviders = [{ name: 'wasm', options: { numThreads: navigator.hardwareConcurrency || 4 } }];
425
  }
426
 
 
454
  updateBtn();
455
  }
456
 
457
// Downmix an AudioBuffer to a single Float32Array by averaging all channels.
// Accumulates the channel sums first, then applies one uniform gain pass,
// matching the float rounding of a sum-then-scale mix.
function mixToMono(audioBuffer) {
  const frameCount = audioBuffer.length;
  const channelCount = audioBuffer.numberOfChannels;
  const sums = new Float32Array(frameCount);
  for (let ch = 0; ch < channelCount; ch++) {
    const samples = audioBuffer.getChannelData(ch);
    for (let i = 0; i < frameCount; i++) {
      sums[i] += samples[i];
    }
  }
  const gainFactor = 1 / channelCount;
  return sums.map((total) => total * gainFactor);
}
468
+
469
// Read a four-character RIFF chunk tag at `offset` as an ASCII string.
function readFourCc(view, offset) {
  const chars = [];
  for (let i = 0; i < 4; i++) {
    chars.push(String.fromCharCode(view.getUint8(offset + i)));
  }
  return chars.join('');
}
477
+
478
// Parse a RIFF/WAVE ArrayBuffer into mono Float32 samples without the Web
// Audio API, so decoding is bit-identical across browsers and platforms.
// Supports PCM (format 1: 8/16/24/32-bit) and IEEE float (format 3: 32-bit),
// including WAVE_FORMAT_EXTENSIBLE (0xfffe) wrappers.
// Returns { mono, sampleRate, channels, source: 'wav' }, or null when the
// buffer is not a decodable WAV (the caller then falls back to Web Audio).
function decodeWavToMono(arrayBuffer) {
  if (arrayBuffer.byteLength < 44) return null;
  const view = new DataView(arrayBuffer);
  const tagAt = (offset) => String.fromCharCode(
    view.getUint8(offset),
    view.getUint8(offset + 1),
    view.getUint8(offset + 2),
    view.getUint8(offset + 3)
  );
  if (tagAt(0) !== 'RIFF' || tagAt(8) !== 'WAVE') return null;

  // Walk the chunk list looking for 'fmt ' and 'data'.
  let offset = 12;
  let fmt = null;
  let dataOffset = 0;
  let dataSize = 0;

  while (offset + 8 <= view.byteLength) {
    const id = tagAt(offset);
    const size = view.getUint32(offset + 4, true);
    const chunkStart = offset + 8;
    if (id === 'fmt ') {
      const format = view.getUint16(chunkStart, true);
      fmt = {
        // WAVE_FORMAT_EXTENSIBLE stores the real format code in the extension.
        format: format === 0xfffe && size >= 40 ? view.getUint16(chunkStart + 24, true) : format,
        channels: view.getUint16(chunkStart + 2, true),
        sampleRate: view.getUint32(chunkStart + 4, true),
        blockAlign: view.getUint16(chunkStart + 12, true),
        bitsPerSample: view.getUint16(chunkStart + 14, true),
      };
    } else if (id === 'data') {
      dataOffset = chunkStart;
      // Clamp to the real buffer length: a truncated file whose header
      // over-reports the data size must not trigger out-of-range reads.
      dataSize = Math.min(size, view.byteLength - chunkStart);
      break;
    }
    offset = chunkStart + size + (size % 2); // RIFF chunks are word-aligned
  }

  if (!fmt || !dataOffset || dataSize <= 0) return null;
  if (fmt.format !== 1 && fmt.format !== 3) return null;
  // Guard against zeroed header fields that would divide by zero below.
  if (!fmt.channels || !fmt.blockAlign) return null;
  const bytesPerSample = fmt.bitsPerSample / 8;
  if (!Number.isInteger(bytesPerSample) || bytesPerSample < 1) return null;
  const frames = Math.floor(dataSize / fmt.blockAlign);
  const mono = new Float32Array(frames);

  // Decode one sample at byte `pos` to the [-1, 1) range.
  const readSample = (pos) => {
    if (fmt.format === 3 && fmt.bitsPerSample === 32) return view.getFloat32(pos, true);
    if (fmt.format !== 1) return 0;
    if (fmt.bitsPerSample === 8) return (view.getUint8(pos) - 128) / 128; // 8-bit PCM is unsigned
    if (fmt.bitsPerSample === 16) return view.getInt16(pos, true) / 32768;
    if (fmt.bitsPerSample === 24) {
      let v = view.getUint8(pos) | (view.getUint8(pos + 1) << 8) | (view.getUint8(pos + 2) << 16);
      if (v & 0x800000) v |= 0xff000000; // sign-extend 24 -> 32 bits
      return v / 8388608;
    }
    if (fmt.bitsPerSample === 32) return view.getInt32(pos, true) / 2147483648;
    return 0; // unsupported bit depth: emit silence rather than throw
  };

  // Average the channels of each frame into the mono output.
  for (let frame = 0; frame < frames; frame++) {
    const frameOffset = dataOffset + frame * fmt.blockAlign;
    let sum = 0;
    for (let ch = 0; ch < fmt.channels; ch++) {
      sum += readSample(frameOffset + ch * bytesPerSample);
    }
    mono[frame] = sum / fmt.channels;
  }

  return {
    mono,
    sampleRate: fmt.sampleRate,
    channels: fmt.channels,
    source: 'wav',
  };
}
546
+
547
// Normalized sinc: sin(pi*x) / (pi*x), with the removable singularity at 0
// handled explicitly (|x| below 1e-8 maps to 1).
function sinc(x) {
  const t = Math.PI * x;
  return Math.abs(x) < 1e-8 ? 1 : Math.sin(t) / t;
}
552
+
553
// Windowed-sinc resampler (sinc kernel with a sinc window, radius 12 taps).
// Pure-JS and deterministic across platforms, unlike OfflineAudioContext
// rendering. When downsampling, the kernel low-passes below the target
// Nyquist with a 5% guard band; weights are normalized per output sample.
function resampleSinc(input, inSr, outSr) {
  if (inSr === outSr) return new Float32Array(input);

  const outLen = Math.round(input.length * outSr / inSr);
  const output = new Float32Array(outLen);
  const ratio = inSr / outSr;
  const cutoff = Math.min(1, outSr / inSr) * 0.95;
  const radius = 12;
  const support = radius / cutoff;

  for (let i = 0; i < outLen; i++) {
    const center = i * ratio;
    const first = Math.max(0, Math.ceil(center - support));
    const last = Math.min(input.length - 1, Math.floor(center + support));
    let acc = 0;
    let norm = 0;
    for (let tap = first; tap <= last; tap++) {
      const x = (center - tap) * cutoff;
      const w = sinc(x) * sinc(x / radius);
      acc += input[tap] * w;
      norm += w;
    }
    // norm can be 0 only when the tap window is empty (degenerate input).
    output[i] = norm ? acc / norm : 0;
  }
  return output;
}
579
+
580
  // -- Decode audio to mono 16kHz Float32 --
581
// -- Decode audio to mono 16kHz Float32 --
// Prefers the deterministic pure-JS WAV parser; falls back to the Web Audio
// decoder (whose output can differ across platforms) for compressed formats.
async function decodeToMono16k(arrayBuffer) {
  let decodedAudio = decodeWavToMono(arrayBuffer);
  if (!decodedAudio) {
    const AudioCtx = window.AudioContext || window.webkitAudioContext;
    const audioCtx = new AudioCtx();
    try {
      // slice(0): decodeAudioData detaches its argument; keep the caller's copy usable.
      const decoded = await audioCtx.decodeAudioData(arrayBuffer.slice(0));
      decodedAudio = {
        mono: mixToMono(decoded),
        sampleRate: decoded.sampleRate,
        channels: decoded.numberOfChannels,
        source: 'webaudio',
      };
    } finally {
      // Always release the context, even when decoding rejects —
      // browsers cap the number of live AudioContexts per page.
      await audioCtx.close();
    }
  }
  return resampleSinc(decodedAudio.mono, decodedAudio.sampleRate, INPUT_SR);
}
600
 
601
  // -- Process --
 
612
  const totalSamples = audio16k.length;
613
  const audioDuration = totalSamples / INPUT_SR;
614
 
615
+ // Chunk sizing: CPU=1000ms, GPU=5000ms
616
  const chunkMs = backend === 'webgpu' ? 5000 : 1000;
617
  const chunkHops = Math.max(1, Math.floor(chunkMs / 1000 * INPUT_SR / HOP));
618
  const chunkSamples = chunkHops * HOP;
 
654
  state_in: stateTensor,
655
  });
656
 
657
+ const audioOut = await readTensorData(result.audio_out);
658
+ const stateOut = await readTensorData(result.state_out);
659
+ outputs.push(new Float32Array(audioOut));
660
+ state = new Float32Array(stateOut);
661
 
662
  chunkIdx++;
663
  const pct = Math.round(chunkIdx / numChunks * 100);
 
729
 
730
  // Load ORT and init
731
  const script = document.createElement('script');
732
+ script.src = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.22.0/dist/ort.webgpu.min.js';
733
  script.crossOrigin = 'anonymous';
734
  script.onload = () => {
735
  ort.env.wasm.numThreads = navigator.hardwareConcurrency || 4;