Spaces:
Running
Running
/**
 * End-to-end mel spectrogram tests using real audio and ONNX reference data.
 *
 * Tests:
 * 1. Cross-validation against the ONNX reference (mel_reference.json from parakeet.js)
 * 2. Real WAV file processing (life_Jim.wav from the parakeet.js demo)
 * 3. Mel filterbank accuracy against the ONNX reference
 *
 * These tests catch regressions in mel computation that unit tests might miss,
 * such as incorrect normalization, wrong filterbank values, or precision issues.
 *
 * The ONNX reference is generated by parakeet.js's tests/generate_mel_reference.py
 * using the official NeMo ONNX preprocessor as ground truth.
 *
 * Run: npm test
 */
| import { readFileSync, existsSync } from 'fs'; | |
| import { join } from 'path'; | |
| import https from 'https'; | |
| import { describe, it, expect, beforeAll } from 'vitest'; | |
| import { | |
| MEL_CONSTANTS, | |
| hzToMel, | |
| melToHz, | |
| createMelFilterbank, | |
| createPaddedHannWindow, | |
| precomputeTwiddles, | |
| fft, | |
| preemphasize, | |
| computeMelFrame, | |
| normalizeMelFeatures, | |
| sampleToFrame, | |
| } from './mel-math'; | |
| import { resampleLinear } from './utils'; | |
| // βββ Helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| /** Decode base64 to Float32Array (matching parakeet.js test format) */ | |
| function base64ToFloat32(b64: string): Float32Array { | |
| const buf = Buffer.from(b64, 'base64'); | |
| return new Float32Array(buf.buffer, buf.byteOffset, buf.byteLength / Float32Array.BYTES_PER_ELEMENT); | |
| } | |
| /** Compute error metrics between two arrays */ | |
| function computeError(actual: Float32Array, expected: Float32Array, validCount?: number) { | |
| const n = validCount || Math.min(actual.length, expected.length); | |
| let maxErr = 0; | |
| let sumErr = 0; | |
| for (let i = 0; i < n; i++) { | |
| const err = Math.abs(actual[i] - expected[i]); | |
| sumErr += err; | |
| if (err > maxErr) maxErr = err; | |
| } | |
| return { | |
| maxAbsError: maxErr, | |
| meanAbsError: sumErr / n, | |
| n, | |
| }; | |
| } | |
| /** Parse a 16-bit PCM WAV file into Float32Array at the native sample rate */ | |
| function parseWav(buffer: ArrayBuffer): { audio: Float32Array; sampleRate: number; channels: number } { | |
| const view = new DataView(buffer); | |
| // RIFF header | |
| const riff = String.fromCharCode(view.getUint8(0), view.getUint8(1), view.getUint8(2), view.getUint8(3)); | |
| if (riff !== 'RIFF') throw new Error('Not a valid WAV file: missing RIFF header'); | |
| const wave = String.fromCharCode(view.getUint8(8), view.getUint8(9), view.getUint8(10), view.getUint8(11)); | |
| if (wave !== 'WAVE') throw new Error('Not a valid WAV file: missing WAVE format'); | |
| // Find fmt and data chunks | |
| let offset = 12; | |
| let sampleRate = 0; | |
| let channels = 0; | |
| let bitsPerSample = 0; | |
| let dataOffset = 0; | |
| let dataSize = 0; | |
| while (offset < buffer.byteLength - 8) { | |
| const chunkId = String.fromCharCode( | |
| view.getUint8(offset), view.getUint8(offset + 1), | |
| view.getUint8(offset + 2), view.getUint8(offset + 3), | |
| ); | |
| const chunkSize = view.getUint32(offset + 4, true); | |
| if (chunkId === 'fmt ') { | |
| channels = view.getUint16(offset + 10, true); | |
| sampleRate = view.getUint32(offset + 12, true); | |
| bitsPerSample = view.getUint16(offset + 22, true); | |
| } else if (chunkId === 'data') { | |
| dataOffset = offset + 8; | |
| dataSize = chunkSize; | |
| break; | |
| } | |
| offset += 8 + chunkSize; | |
| // Align to even byte boundary | |
| if (chunkSize % 2 !== 0) offset++; | |
| } | |
| if (dataOffset === 0) throw new Error('No data chunk found in WAV file'); | |
| if (bitsPerSample !== 16) throw new Error(`Unsupported bit depth: ${bitsPerSample} (expected 16)`); | |
| // Extract PCM samples and convert to Float32 [-1, 1] | |
| const numSamples = dataSize / (bitsPerSample / 8) / channels; | |
| const audio = new Float32Array(numSamples); | |
| for (let i = 0; i < numSamples; i++) { | |
| // Read first channel (mono or left channel) | |
| const sampleOffset = dataOffset + i * channels * 2; | |
| const sample = view.getInt16(sampleOffset, true); | |
| audio[i] = sample / 32768.0; | |
| } | |
| return { audio, sampleRate, channels }; | |
| } | |
| /** | |
| * Run our full mel pipeline on raw PCM audio. | |
| * Matches the JsPreprocessor.process() pipeline in parakeet.js/src/mel.js. | |
| */ | |
| function fullMelPipeline(audio: Float32Array, nMels: number = 128) { | |
| const { N_FFT, HOP_LENGTH, PREEMPH } = MEL_CONSTANTS; | |
| // 1. Pre-emphasize | |
| const preemph = preemphasize(audio, 0, PREEMPH); | |
| // 2. Compute mel frames | |
| const numFrames = sampleToFrame(audio.length); | |
| if (numFrames === 0) return { features: new Float32Array(0), T: 0 }; | |
| const hannWindow = createPaddedHannWindow(); | |
| const twiddles = precomputeTwiddles(N_FFT); | |
| const melFilterbank = createMelFilterbank(nMels); | |
| // Raw mel buffer [nMels Γ numFrames], mel-major layout | |
| const rawMel = new Float32Array(nMels * numFrames); | |
| for (let t = 0; t < numFrames; t++) { | |
| const frame = computeMelFrame(preemph, t, hannWindow, twiddles, melFilterbank, nMels); | |
| for (let m = 0; m < nMels; m++) { | |
| rawMel[m * numFrames + t] = frame[m]; | |
| } | |
| } | |
| // 3. Normalize | |
| const features = normalizeMelFeatures(rawMel, nMels, numFrames); | |
| return { features, T: numFrames }; | |
| } | |
// ─── Paths ────────────────────────────────────────────────────────────────
// parakeet.js is sibling to keet: __dirname = src/lib/audio, 4 levels up = N:\github\ysdede
const PARAKEET_ROOT = join(__dirname, '..', '..', '..', '..', 'parakeet.js');
// ONNX-generated reference features (see generate_mel_reference.py).
const MEL_REFERENCE_PATH = join(PARAKEET_ROOT, 'tests', 'mel_reference.json');
// Local copy of the demo WAV; used when present to avoid network access.
const WAV_LOCAL_PATH = join(PARAKEET_ROOT, 'examples', 'demo', 'public', 'assets', 'life_Jim.wav');
// Fallback download URL for the same WAV when the local copy is missing.
const WAV_GITHUB_URL = 'https://github.com/ysdede/parakeet.js/raw/refs/heads/master/examples/demo/public/assets/life_Jim.wav';
// ─── ONNX Reference Cross-Validation ─────────────────────────────────────
describe('Cross-validation against ONNX reference', () => {
  // Parsed mel_reference.json. The schema is owned by parakeet.js's
  // generate_mel_reference.py, so it is typed loosely here.
  let reference: any;
  let hasReference = false;
  beforeAll(() => {
    try {
      if (existsSync(MEL_REFERENCE_PATH)) {
        const content = readFileSync(MEL_REFERENCE_PATH, 'utf-8');
        reference = JSON.parse(content);
        hasReference = true;
      }
    } catch {
      // Reference not available — tests will be skipped
    }
  });
  it('should load mel_reference.json from parakeet.js', () => {
    if (!hasReference) {
      // Soft-skip with instructions for regenerating the reference file.
      console.log(`SKIP: mel_reference.json not found at ${MEL_REFERENCE_PATH}`);
      console.log('Run: cd ../parakeet.js && python tests/generate_mel_reference.py');
      return;
    }
    expect(reference).toBeDefined();
    expect(reference.nMels).toBe(128);
    expect(reference.tests).toBeDefined();
  });
  it('should match ONNX mel filterbank within 1e-5', () => {
    if (!hasReference || !reference.melFilterbank) {
      console.log('SKIP: No filterbank reference');
      return;
    }
    const refFb = base64ToFloat32(reference.melFilterbank.data);
    const refShape = reference.melFilterbank.shape; // [257, 128] (kept for documentation; not asserted)
    const jsFb = createMelFilterbank(128);
    // Compare (ref is [257,128] row-major, ours is [128,257] row-major)
    // — i.e. the reference stores the transpose of our layout, hence the
    // swapped index arithmetic below.
    let maxErr = 0;
    for (let freq = 0; freq < 257; freq++) {
      for (let mel = 0; mel < 128; mel++) {
        const refVal = refFb[freq * 128 + mel];
        const jsVal = jsFb[mel * 257 + freq];
        const err = Math.abs(refVal - jsVal);
        if (err > maxErr) maxErr = err;
      }
    }
    console.log(`Filterbank max error vs ONNX: ${maxErr.toExponential(3)}`);
    expect(maxErr).toBeLessThan(1e-5);
  });
  it('should match ONNX full pipeline for each test signal (max<0.05, mean<0.005)', () => {
    if (!hasReference) {
      console.log('SKIP: No reference data');
      return;
    }
    const nMels = reference.nMels;
    // Each entry in reference.tests carries base64-encoded audio and the
    // ONNX preprocessor's features for that audio.
    for (const [name, test] of Object.entries(reference.tests) as [string, any][]) {
      const audio = base64ToFloat32(test.audio);
      const refFeatures = base64ToFloat32(test.features);
      const refLen = test.featuresLen;
      // Run our pipeline
      const { features: ourFeatures, T: ourLen } = fullMelPipeline(audio, nMels);
      console.log(`Signal "${name}": ${audio.length} samples (${(audio.length / 16000).toFixed(2)}s), ` +
        `frames: ours=${ourLen}, ref=${refLen}`);
      // Frame count should match
      expect(ourLen).toBe(refLen);
      // Compare valid frames (mel-by-mel). Both buffers are mel-major, but
      // they may have different frame strides, so compute each separately.
      const nFramesOurs = ourFeatures.length / nMels;
      const nFramesRef = refFeatures.length / nMels;
      let maxErr = 0;
      let sumErr = 0;
      let n = 0;
      for (let m = 0; m < nMels; m++) {
        for (let t = 0; t < ourLen; t++) {
          const ourVal = ourFeatures[m * nFramesOurs + t];
          const refVal = refFeatures[m * nFramesRef + t];
          const err = Math.abs(ourVal - refVal);
          sumErr += err;
          if (err > maxErr) maxErr = err;
        }
      }
      const meanErr = sumErr / n;
      console.log(`  Max error: ${maxErr.toExponential(3)}, Mean error: ${meanErr.toExponential(3)}`);
      // Same thresholds as parakeet.js test_mel.mjs
      expect(maxErr).toBeLessThan(0.05);
      expect(meanErr).toBeLessThan(0.005);
    }
  });
});
// ─── Real WAV File Tests ──────────────────────────────────────────────────
describe('Real audio: life_Jim.wav', () => {
  // 16 kHz mono PCM produced by beforeAll (parsed + resampled if needed).
  let audioData: Float32Array;
  // Duration in seconds at 16 kHz, derived from audioData.length.
  let audioDuration: number;
  // NOTE(review): not asserted anywhere in this file — presumably kept as
  // documentation of the clip's content; confirm before removing.
  const EXPECTED_TRANSCRIPT = 'it is not life as we know or understand it';
  beforeAll(async () => {
    let wavBuffer: ArrayBuffer;
    if (existsSync(WAV_LOCAL_PATH)) {
      // Read local file (fast, no network dependency)
      const fileBuffer = readFileSync(WAV_LOCAL_PATH);
      // Slice to the exact byte range: the Buffer may be a view into a
      // larger pooled ArrayBuffer.
      wavBuffer = fileBuffer.buffer.slice(
        fileBuffer.byteOffset,
        fileBuffer.byteOffset + fileBuffer.byteLength,
      );
      console.log(`Loaded local WAV: ${WAV_LOCAL_PATH} (${fileBuffer.length} bytes)`);
    } else {
      // Download from GitHub using Node.js https (happy-dom blocks CORS fetch)
      console.log(`Local WAV not found, downloading from ${WAV_GITHUB_URL}`);
      wavBuffer = await new Promise<ArrayBuffer>((resolve, reject) => {
        const download = (url: string, redirects = 0) => {
          if (redirects > 5) return reject(new Error('Too many redirects'));
          https.get(url, (res) => {
            // Follow redirects (GitHub sends 301/302)
            if (res.statusCode && res.statusCode >= 300 && res.statusCode < 400 && res.headers.location) {
              return download(res.headers.location, redirects + 1);
            }
            if (res.statusCode !== 200) return reject(new Error(`HTTP ${res.statusCode}`));
            const chunks: Buffer[] = [];
            res.on('data', (chunk: Buffer) => chunks.push(chunk));
            res.on('end', () => {
              const buf = Buffer.concat(chunks);
              // Same exact-range slice as the local-file path above.
              resolve(buf.buffer.slice(buf.byteOffset, buf.byteOffset + buf.byteLength));
            });
            res.on('error', reject);
          }).on('error', reject);
        };
        download(WAV_GITHUB_URL);
      });
      console.log(`Downloaded WAV: ${wavBuffer.byteLength} bytes`);
    }
    // Parse WAV
    const { audio, sampleRate, channels } = parseWav(wavBuffer);
    console.log(`Parsed WAV: ${audio.length} samples, ${sampleRate} Hz, ${channels} ch`);
    // Resample to 16kHz if needed
    if (sampleRate !== 16000) {
      audioData = resampleLinear(audio, sampleRate, 16000);
      console.log(`Resampled: ${audio.length} → ${audioData.length} samples (${sampleRate} → 16000 Hz)`);
    } else {
      audioData = audio;
    }
    audioDuration = audioData.length / 16000;
    console.log(`Audio duration: ${audioDuration.toFixed(2)}s`);
  });
  it('should parse the WAV file correctly', () => {
    expect(audioData).toBeInstanceOf(Float32Array);
    expect(audioData.length).toBeGreaterThan(0);
    // life_Jim.wav is about 1.4 seconds of speech
    expect(audioDuration).toBeGreaterThan(0.5);
    expect(audioDuration).toBeLessThan(10);
  });
  it('should have valid PCM values in [-1, 1] range', () => {
    let min = Infinity, max = -Infinity;
    for (let i = 0; i < audioData.length; i++) {
      if (audioData[i] < min) min = audioData[i];
      if (audioData[i] > max) max = audioData[i];
      expect(isFinite(audioData[i])).toBe(true);
    }
    expect(min).toBeGreaterThanOrEqual(-1.0);
    expect(max).toBeLessThanOrEqual(1.0);
    // Should have actual audio content (not silence)
    expect(max - min).toBeGreaterThan(0.01);
    console.log(`Audio range: [${min.toFixed(4)}, ${max.toFixed(4)}]`);
  });
  it('should produce correct number of mel frames', () => {
    const expectedFrames = sampleToFrame(audioData.length);
    expect(expectedFrames).toBeGreaterThan(0);
    console.log(`Expected frames: ${expectedFrames} (${audioDuration.toFixed(2)}s × 100 fps)`);
  });
  it('should produce finite, normalized mel features', () => {
    const { features, T } = fullMelPipeline(audioData, 128);
    expect(T).toBeGreaterThan(0);
    expect(features.length).toBe(128 * T);
    // All values should be finite
    for (let i = 0; i < features.length; i++) {
      expect(isFinite(features[i])).toBe(true);
    }
    // Per-mel-bin: should have ~zero mean (normalized)
    for (let m = 0; m < 128; m++) {
      let sum = 0;
      for (let t = 0; t < T; t++) {
        sum += features[m * T + t];
      }
      const mean = sum / T;
      expect(Math.abs(mean)).toBeLessThan(0.01);
    }
  });
  it('should produce deterministic results', () => {
    // Two runs over identical input must be bit-for-bit equal.
    const result1 = fullMelPipeline(audioData, 128);
    const result2 = fullMelPipeline(audioData, 128);
    expect(result1.T).toBe(result2.T);
    expect(result1.features.length).toBe(result2.features.length);
    for (let i = 0; i < result1.features.length; i++) {
      expect(result1.features[i]).toBe(result2.features[i]);
    }
  });
  it('should produce different features for different time windows', () => {
    const { features, T } = fullMelPipeline(audioData, 128);
    // Compare first and second halves — they should differ (it's speech, not silence)
    const halfT = Math.floor(T / 2);
    if (halfT < 2) return; // too short
    let diffCount = 0;
    for (let m = 0; m < 128; m++) {
      const v1 = features[m * T + 0]; // first frame
      const v2 = features[m * T + halfT]; // middle frame
      if (Math.abs(v1 - v2) > 0.01) diffCount++;
    }
    // At least some mel bins should differ between speech regions
    expect(diffCount).toBeGreaterThan(10);
  });
  it('should match mel-worker output for the same audio', async () => {
    // This test validates that our mel-math (used by mel.worker.ts) produces
    // the same features as the full pipeline, ensuring the worker's incremental
    // computation matches batch processing.
    const nMels = 128;
    const { features: batchFeatures, T } = fullMelPipeline(audioData, nMels);
    // Simulate incremental processing (like mel.worker does):
    // Push all audio at once, then extract all frames
    const hannWindow = createPaddedHannWindow();
    const twiddles = precomputeTwiddles(MEL_CONSTANTS.N_FFT);
    const melFilterbank = createMelFilterbank(nMels);
    // Pre-emphasize the full audio
    const preemph = preemphasize(audioData);
    // Compute frames one by one (like worker does incrementally)
    const rawMel = new Float32Array(nMels * T);
    for (let t = 0; t < T; t++) {
      const frame = computeMelFrame(preemph, t, hannWindow, twiddles, melFilterbank, nMels);
      for (let m = 0; m < nMels; m++) {
        rawMel[m * T + t] = frame[m];
      }
    }
    // Normalize (same as getFeatures in worker)
    const incrementalFeatures = normalizeMelFeatures(rawMel, nMels, T);
    // Should be bit-for-bit identical since same code path
    expect(incrementalFeatures.length).toBe(batchFeatures.length);
    for (let i = 0; i < incrementalFeatures.length; i++) {
      expect(incrementalFeatures[i]).toBe(batchFeatures[i]);
    }
  });
  it('should complete mel processing under 100ms for this audio', () => {
    const t0 = performance.now();
    const { features, T } = fullMelPipeline(audioData, 128);
    const elapsed = performance.now() - t0;
    console.log(`Mel pipeline: ${T} frames in ${elapsed.toFixed(1)}ms ` +
      `(${(audioDuration / (elapsed / 1000)).toFixed(1)}x realtime)`);
    // Should be fast enough for real-time use
    expect(elapsed).toBeLessThan(100);
  });
});
| // βββ WAV Parser Tests βββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| describe('WAV parser', () => { | |
| it('should parse a known WAV file correctly', () => { | |
| if (!existsSync(WAV_LOCAL_PATH)) { | |
| console.log('SKIP: WAV file not available locally'); | |
| return; | |
| } | |
| const buffer = readFileSync(WAV_LOCAL_PATH); | |
| const wavBuffer = buffer.buffer.slice(buffer.byteOffset, buffer.byteOffset + buffer.byteLength); | |
| const { audio, sampleRate, channels } = parseWav(wavBuffer); | |
| expect(audio).toBeInstanceOf(Float32Array); | |
| expect(audio.length).toBeGreaterThan(0); | |
| expect(sampleRate).toBeGreaterThan(0); | |
| expect(channels).toBeGreaterThanOrEqual(1); | |
| console.log(`WAV: ${audio.length} samples, ${sampleRate} Hz, ${channels} ch, ` + | |
| `${(audio.length / sampleRate).toFixed(2)}s`); | |
| }); | |
| it('should reject non-WAV data', () => { | |
| const notWav = new ArrayBuffer(44); | |
| new Uint8Array(notWav).fill(0); | |
| expect(() => parseWav(notWav)).toThrow(); | |
| }); | |
| }); | |