keet-streaming / src /lib /audio /mel-e2e.test.ts
ysdede's picture
feat(space): migrate Hugging Face Space to keet SolidJS app
b8cc2bf
/**
* End-to-End mel spectrogram tests using real audio and ONNX reference data.
*
* Tests:
* 1. Cross-validation against ONNX reference (mel_reference.json from parakeet.js)
* 2. Real WAV file processing (life_Jim.wav from parakeet.js demo)
* 3. Mel filterbank accuracy against ONNX reference
*
* These tests catch regressions in mel computation that unit tests might miss,
* such as incorrect normalization, wrong filterbank values, or precision issues.
*
* The ONNX reference is generated by parakeet.js's tests/generate_mel_reference.py
* using the official NeMo ONNX preprocessor as ground truth.
*
* Run: npm test
*/
import { readFileSync, existsSync } from 'fs';
import { join } from 'path';
import https from 'https';
import { describe, it, expect, beforeAll } from 'vitest';
import {
MEL_CONSTANTS,
hzToMel,
melToHz,
createMelFilterbank,
createPaddedHannWindow,
precomputeTwiddles,
fft,
preemphasize,
computeMelFrame,
normalizeMelFeatures,
sampleToFrame,
} from './mel-math';
import { resampleLinear } from './utils';
// ─── Helpers ──────────────────────────────────────────────────────────────
/**
 * Decode a base64 string into a Float32Array (matching the parakeet.js test format).
 *
 * The decoded bytes are copied into a freshly allocated buffer instead of being
 * viewed in place. A Float32Array view requires its byte offset to be a multiple
 * of 4, and the Buffer API makes no promise about the offset of a Buffer inside
 * its backing store; copying also keeps the result from aliasing shared storage.
 *
 * @param b64 Base64 text whose decoded bytes are float32 values in the platform's byte order.
 * @returns A Float32Array that owns its storage. Trailing bytes that do not form
 *          a complete 4-byte float are ignored.
 */
function base64ToFloat32(b64: string): Float32Array {
  const bytes = Buffer.from(b64, 'base64');
  const count = Math.floor(bytes.byteLength / Float32Array.BYTES_PER_ELEMENT);
  const floats = new Float32Array(count);
  // Byte-wise copy: valid for any source offset, aligned or not.
  new Uint8Array(floats.buffer).set(bytes.subarray(0, count * Float32Array.BYTES_PER_ELEMENT));
  return floats;
}
/**
 * Compute absolute-error metrics between two arrays.
 *
 * @param actual     Values under test.
 * @param expected   Reference values.
 * @param validCount Optional number of leading elements to compare. When omitted,
 *                   the shorter of the two lengths is used. The value is truncated
 *                   to an integer and clamped to `[0, min(actual.length, expected.length)]`,
 *                   so an explicit 0 compares nothing and an oversized count never
 *                   reads past the end of either array.
 * @returns `maxAbsError`, `meanAbsError` and the number of compared elements `n`.
 *          When `n` is 0 both error values are 0 (not NaN).
 */
function computeError(actual: Float32Array, expected: Float32Array, validCount?: number) {
  const available = Math.min(actual.length, expected.length);
  // `|| 0` maps a NaN count to 0; an explicit 0 is honoured rather than treated as "missing".
  const n = validCount === undefined
    ? available
    : Math.min(Math.max(Math.trunc(validCount) || 0, 0), available);

  let maxErr = 0;
  let sumErr = 0;
  for (let i = 0; i < n; i++) {
    const err = Math.abs(actual[i] - expected[i]);
    sumErr += err;
    if (err > maxErr) maxErr = err;
  }

  return {
    maxAbsError: maxErr,
    meanAbsError: n > 0 ? sumErr / n : 0,
    n,
  };
}
/**
 * Parse a 16-bit PCM WAV file into a Float32Array at the file's native sample rate.
 *
 * Only the first channel is returned (mono, or the left channel of interleaved
 * multi-channel data). Samples are scaled to [-1, 1) by dividing by 32768.
 *
 * Robustness notes:
 * - The data chunk's declared size is clamped to the bytes actually present, so a
 *   truncated file or an oversized size field yields the available samples instead
 *   of an out-of-bounds read.
 * - A trailing partial frame is ignored.
 *
 * @param buffer Complete WAV file contents.
 * @returns The decoded first channel plus the sample rate and channel count read
 *          from the `fmt ` chunk.
 * @throws Error if the RIFF/WAVE header is missing, the `fmt ` chunk is truncated,
 *         no `data` chunk is found, the bit depth is not 16, or the channel count is 0.
 */
function parseWav(buffer: ArrayBuffer): { audio: Float32Array; sampleRate: number; channels: number } {
  const view = new DataView(buffer);
  /** Read a 4-character ASCII chunk tag at `pos`. */
  const tagAt = (pos: number): string =>
    String.fromCharCode(
      view.getUint8(pos), view.getUint8(pos + 1),
      view.getUint8(pos + 2), view.getUint8(pos + 3),
    );

  // RIFF header: "RIFF" <size> "WAVE" occupies the first 12 bytes.
  if (buffer.byteLength < 12 || tagAt(0) !== 'RIFF') {
    throw new Error('Not a valid WAV file: missing RIFF header');
  }
  if (tagAt(8) !== 'WAVE') throw new Error('Not a valid WAV file: missing WAVE format');

  // Walk the chunk list looking for "fmt " and "data".
  let offset = 12;
  let sampleRate = 0;
  let channels = 0;
  let bitsPerSample = 0;
  let dataOffset = 0;
  let dataSize = 0;

  // A chunk header is 8 bytes (tag + little-endian size); it may end exactly at EOF.
  while (offset + 8 <= buffer.byteLength) {
    const chunkId = tagAt(offset);
    const chunkSize = view.getUint32(offset + 4, true);

    if (chunkId === 'fmt ') {
      // The fields read below span the first 16 payload bytes.
      if (offset + 24 > buffer.byteLength) {
        throw new Error('Not a valid WAV file: truncated fmt chunk');
      }
      channels = view.getUint16(offset + 10, true);
      sampleRate = view.getUint32(offset + 12, true);
      bitsPerSample = view.getUint16(offset + 22, true);
    } else if (chunkId === 'data') {
      dataOffset = offset + 8;
      // Never trust the declared size beyond what the buffer really holds.
      dataSize = Math.min(chunkSize, buffer.byteLength - dataOffset);
      break;
    }

    // Chunks are padded to an even byte boundary.
    offset += 8 + chunkSize + (chunkSize % 2);
  }

  if (dataOffset === 0) throw new Error('No data chunk found in WAV file');
  if (bitsPerSample !== 16) throw new Error(`Unsupported bit depth: ${bitsPerSample} (expected 16)`);
  if (channels < 1) throw new Error(`Invalid channel count: ${channels}`);

  // Extract PCM samples from the first channel and convert to Float32.
  const bytesPerFrame = channels * 2;
  const numSamples = Math.floor(dataSize / bytesPerFrame);
  const audio = new Float32Array(numSamples);
  for (let i = 0; i < numSamples; i++) {
    audio[i] = view.getInt16(dataOffset + i * bytesPerFrame, true) / 32768.0;
  }

  return { audio, sampleRate, channels };
}
/**
 * Batch mel pipeline over raw PCM audio: pre-emphasis, per-frame mel
 * computation, then normalization.
 * Intended to mirror the JsPreprocessor.process() pipeline in parakeet.js/src/mel.js.
 *
 * @param audio PCM samples.
 * @param nMels Number of mel bins (default 128).
 * @returns `features` (normalized, laid out mel-major: index = mel * T + frame)
 *          and the frame count `T`. Both are empty/zero when the audio yields no frames.
 */
function fullMelPipeline(audio: Float32Array, nMels: number = 128) {
  // Step 1: pre-emphasis over the whole signal.
  const emphasized = preemphasize(audio, 0, MEL_CONSTANTS.PREEMPH);

  // Step 2: how many frames does this much audio produce?
  const frameCount = sampleToFrame(audio.length);
  if (frameCount === 0) {
    return { features: new Float32Array(0), T: 0 };
  }

  const hann = createPaddedHannWindow();
  const fftTwiddles = precomputeTwiddles(MEL_CONSTANTS.N_FFT);
  const filterbank = createMelFilterbank(nMels);

  // Un-normalized mel values, nMels x frameCount, one row per mel bin.
  const melMajor = new Float32Array(nMels * frameCount);
  for (let frameIdx = 0; frameIdx < frameCount; frameIdx++) {
    const column = computeMelFrame(emphasized, frameIdx, hann, fftTwiddles, filterbank, nMels);
    // Scatter this frame's bins down the column: stride is one full row.
    let dst = frameIdx;
    for (let bin = 0; bin < nMels; bin++, dst += frameCount) {
      melMajor[dst] = column[bin];
    }
  }

  // Step 3: normalization.
  return { features: normalizeMelFeatures(melMajor, nMels, frameCount), T: frameCount };
}
// ─── Paths ────────────────────────────────────────────────────────────────
// parakeet.js is expected to be checked out as a sibling of the keet repo:
// __dirname is src/lib/audio, so four levels up is the directory holding both checkouts.
const PARAKEET_ROOT = join(__dirname, '..', '..', '..', '..', 'parakeet.js');
// ONNX reference data; the suites below skip themselves when this file is absent.
const MEL_REFERENCE_PATH = join(PARAKEET_ROOT, 'tests', 'mel_reference.json');
// Local copy of the demo clip, preferred over the network when it exists.
const WAV_LOCAL_PATH = join(PARAKEET_ROOT, 'examples', 'demo', 'public', 'assets', 'life_Jim.wav');
// Fallback download location used when the local copy is missing.
const WAV_GITHUB_URL = 'https://github.com/ysdede/parakeet.js/raw/refs/heads/master/examples/demo/public/assets/life_Jim.wav';
// ─── ONNX Reference Cross-Validation ─────────────────────────────────────
describe('Cross-validation against ONNX reference', () => {
  /** One reference signal, with the fields the tests below read from mel_reference.json. */
  interface MelReferenceSignal {
    /** Base64-encoded float32 input audio. */
    audio: string;
    /** Base64-encoded float32 reference features, mel-major. */
    features: string;
    /** Number of valid frames in `features`. */
    featuresLen: number;
  }

  /** Top-level fields of mel_reference.json used by this suite. */
  interface MelReference {
    nMels: number;
    melFilterbank?: { data: string; shape: number[] };
    tests: Record<string, MelReferenceSignal>;
  }

  // Stays undefined when the reference file is missing or unreadable; every
  // test below then logs a SKIP line and returns early.
  let reference: MelReference | undefined;

  beforeAll(() => {
    if (!existsSync(MEL_REFERENCE_PATH)) return;
    try {
      reference = JSON.parse(readFileSync(MEL_REFERENCE_PATH, 'utf-8')) as MelReference;
    } catch (err: unknown) {
      // Still best-effort (tests skip), but a corrupt or unreadable file must not
      // look identical to "file not present".
      const reason = err instanceof Error ? err.message : String(err);
      console.warn(`Could not load ${MEL_REFERENCE_PATH}: ${reason} (reference tests will be skipped)`);
    }
  });

  it('should load mel_reference.json from parakeet.js', () => {
    const ref = reference;
    if (!ref) {
      console.log(`SKIP: mel_reference.json not found at ${MEL_REFERENCE_PATH}`);
      console.log('Run: cd ../parakeet.js && python tests/generate_mel_reference.py');
      return;
    }

    expect(ref).toBeDefined();
    expect(ref.nMels).toBe(128);
    expect(ref.tests).toBeDefined();
  });

  it('should match ONNX mel filterbank within 1e-5', () => {
    const refFilterbank = reference?.melFilterbank;
    if (!refFilterbank) {
      console.log('SKIP: No filterbank reference');
      return;
    }

    const N_FREQ = 257;
    const N_MEL = 128;
    const refFb = base64ToFloat32(refFilterbank.data);
    const jsFb = createMelFilterbank(N_MEL);

    // The reference is [257, 128] row-major; ours is [128, 257] row-major,
    // so the two are compared through transposed indices.
    let maxErr = 0;
    for (let freq = 0; freq < N_FREQ; freq++) {
      for (let mel = 0; mel < N_MEL; mel++) {
        const err = Math.abs(refFb[freq * N_MEL + mel] - jsFb[mel * N_FREQ + freq]);
        if (err > maxErr) maxErr = err;
      }
    }

    console.log(`Filterbank max error vs ONNX: ${maxErr.toExponential(3)}`);
    expect(maxErr).toBeLessThan(1e-5);
  });

  it('should match ONNX full pipeline for each test signal (max<0.05, mean<0.005)', () => {
    const ref = reference;
    if (!ref) {
      console.log('SKIP: No reference data');
      return;
    }

    const nMels = ref.nMels;

    for (const [name, signal] of Object.entries(ref.tests)) {
      const audio = base64ToFloat32(signal.audio);
      const refFeatures = base64ToFloat32(signal.features);
      const refLen = signal.featuresLen;

      // Run our pipeline
      const { features: ourFeatures, T: ourLen } = fullMelPipeline(audio, nMels);

      console.log(`Signal "${name}": ${audio.length} samples (${(audio.length / 16000).toFixed(2)}s), ` +
        `frames: ours=${ourLen}, ref=${refLen}`);

      // Frame count should match
      expect(ourLen).toBe(refLen);

      // Compare valid frames mel-by-mel. Each buffer is indexed with its own row
      // stride, since the two arrays may hold a different number of frames per row.
      const nFramesOurs = ourFeatures.length / nMels;
      const nFramesRef = refFeatures.length / nMels;
      let maxErr = 0;
      let sumErr = 0;
      let n = 0;

      for (let m = 0; m < nMels; m++) {
        for (let t = 0; t < ourLen; t++) {
          const err = Math.abs(ourFeatures[m * nFramesOurs + t] - refFeatures[m * nFramesRef + t]);
          sumErr += err;
          if (err > maxErr) maxErr = err;
          n++;
        }
      }

      const meanErr = sumErr / n;
      console.log(`  Max error: ${maxErr.toExponential(3)}, Mean error: ${meanErr.toExponential(3)}`);

      // Same thresholds as parakeet.js test_mel.mjs
      expect(maxErr).toBeLessThan(0.05);
      expect(meanErr).toBeLessThan(0.005);
    }
  });
});
// ─── Real WAV File Tests ──────────────────────────────────────────────────
/** Copy raw bytes into a standalone ArrayBuffer that shares no storage with the source. */
function copyToArrayBuffer(bytes: Uint8Array): ArrayBuffer {
  const out = new ArrayBuffer(bytes.byteLength);
  new Uint8Array(out).set(bytes);
  return out;
}

/**
 * Download a URL over HTTPS with Node's https module and return the body bytes.
 *
 * @param url           Absolute https URL.
 * @param redirectsLeft How many more 3xx redirects may be followed (default 5).
 * @returns The response body of the final 200 response.
 * @throws (rejects) on network errors, a non-200 final status, or too many redirects.
 */
function downloadBinary(url: string, redirectsLeft = 5): Promise<ArrayBuffer> {
  return new Promise<ArrayBuffer>((resolve, reject) => {
    const request = https.get(url, (res) => {
      const status = res.statusCode ?? 0;
      const location = res.headers.location;

      // Follow redirects (GitHub sends 301/302)
      if (status >= 300 && status < 400 && location) {
        res.resume(); // discard the redirect body so the socket is released
        if (redirectsLeft <= 0) {
          reject(new Error('Too many redirects'));
          return;
        }
        // Location may be relative, so resolve it against the current URL.
        // Chaining the promise also turns a throw from the nested request
        // into a rejection instead of an uncaught exception.
        downloadBinary(new URL(location, url).toString(), redirectsLeft - 1).then(resolve, reject);
        return;
      }

      if (status !== 200) {
        res.resume();
        reject(new Error(`HTTP ${res.statusCode}`));
        return;
      }

      const chunks: Buffer[] = [];
      res.on('data', (chunk: Buffer) => chunks.push(chunk));
      res.on('end', () => resolve(copyToArrayBuffer(Buffer.concat(chunks))));
      res.on('error', reject);
    });
    request.on('error', reject);
  });
}

/**
 * Index of the first element where two arrays differ, or -1 when they match.
 * Uses Object.is, i.e. the same equality as expect(...).toBe(...).
 */
function firstMismatchIndex(a: Float32Array, b: Float32Array): number {
  const n = Math.min(a.length, b.length);
  for (let i = 0; i < n; i++) {
    if (!Object.is(a[i], b[i])) return i;
  }
  return a.length === b.length ? -1 : n;
}

/** Count the non-finite (NaN or infinite) entries of an array. */
function countNonFinite(values: Float32Array): number {
  let bad = 0;
  for (let i = 0; i < values.length; i++) {
    if (!Number.isFinite(values[i])) bad++;
  }
  return bad;
}

describe('Real audio: life_Jim.wav', () => {
  // Populated by beforeAll. Expected transcript of the clip, for reference only
  // (not asserted here): 'it is not life as we know or understand it'.
  let audioData: Float32Array;
  let audioDuration: number;

  beforeAll(async () => {
    let wavBuffer: ArrayBuffer;

    if (existsSync(WAV_LOCAL_PATH)) {
      // Read local file (fast, no network dependency)
      const fileBuffer = readFileSync(WAV_LOCAL_PATH);
      wavBuffer = copyToArrayBuffer(fileBuffer);
      console.log(`Loaded local WAV: ${WAV_LOCAL_PATH} (${fileBuffer.length} bytes)`);
    } else {
      // Download from GitHub using Node.js https (happy-dom blocks CORS fetch)
      console.log(`Local WAV not found, downloading from ${WAV_GITHUB_URL}`);
      wavBuffer = await downloadBinary(WAV_GITHUB_URL);
      console.log(`Downloaded WAV: ${wavBuffer.byteLength} bytes`);
    }

    // Parse WAV
    const { audio, sampleRate, channels } = parseWav(wavBuffer);
    console.log(`Parsed WAV: ${audio.length} samples, ${sampleRate} Hz, ${channels} ch`);

    // Resample to 16kHz if needed
    if (sampleRate !== 16000) {
      audioData = resampleLinear(audio, sampleRate, 16000);
      console.log(`Resampled: ${audio.length} → ${audioData.length} samples (${sampleRate} → 16000 Hz)`);
    } else {
      audioData = audio;
    }

    audioDuration = audioData.length / 16000;
    console.log(`Audio duration: ${audioDuration.toFixed(2)}s`);
  });

  it('should parse the WAV file correctly', () => {
    expect(audioData).toBeInstanceOf(Float32Array);
    expect(audioData.length).toBeGreaterThan(0);
    // life_Jim.wav is about 1.4 seconds of speech
    expect(audioDuration).toBeGreaterThan(0.5);
    expect(audioDuration).toBeLessThan(10);
  });

  it('should have valid PCM values in [-1, 1] range', () => {
    let min = Infinity;
    let max = -Infinity;
    for (let i = 0; i < audioData.length; i++) {
      const sample = audioData[i];
      if (sample < min) min = sample;
      if (sample > max) max = sample;
    }

    // One aggregated assertion instead of one expect() call per sample.
    expect(countNonFinite(audioData)).toBe(0);
    expect(min).toBeGreaterThanOrEqual(-1.0);
    expect(max).toBeLessThanOrEqual(1.0);
    // Should have actual audio content (not silence)
    expect(max - min).toBeGreaterThan(0.01);
    console.log(`Audio range: [${min.toFixed(4)}, ${max.toFixed(4)}]`);
  });

  it('should produce correct number of mel frames', () => {
    const expectedFrames = sampleToFrame(audioData.length);
    expect(expectedFrames).toBeGreaterThan(0);
    // The pipeline must report the frame count that sampleToFrame predicts.
    expect(fullMelPipeline(audioData, 128).T).toBe(expectedFrames);
    console.log(`Expected frames: ${expectedFrames} (${audioDuration.toFixed(2)}s × 100 fps)`);
  });

  it('should produce finite, normalized mel features', () => {
    const { features, T } = fullMelPipeline(audioData, 128);
    expect(T).toBeGreaterThan(0);
    expect(features.length).toBe(128 * T);

    // All values should be finite
    expect(countNonFinite(features)).toBe(0);

    // Per-mel-bin: should have ~zero mean (normalized)
    for (let m = 0; m < 128; m++) {
      let sum = 0;
      for (let t = 0; t < T; t++) {
        sum += features[m * T + t];
      }
      const mean = sum / T;
      expect(Math.abs(mean)).toBeLessThan(0.01);
    }
  });

  it('should produce deterministic results', () => {
    const result1 = fullMelPipeline(audioData, 128);
    const result2 = fullMelPipeline(audioData, 128);

    expect(result1.T).toBe(result2.T);
    expect(result1.features.length).toBe(result2.features.length);
    expect(firstMismatchIndex(result1.features, result2.features)).toBe(-1);
  });

  it('should produce different features for different time windows', () => {
    const { features, T } = fullMelPipeline(audioData, 128);

    // Compare the first frame with the middle frame; they should differ
    // (it's speech, not silence).
    const halfT = Math.floor(T / 2);
    if (halfT < 2) return; // too short

    let diffCount = 0;
    for (let m = 0; m < 128; m++) {
      const v1 = features[m * T + 0]; // first frame
      const v2 = features[m * T + halfT]; // middle frame
      if (Math.abs(v1 - v2) > 0.01) diffCount++;
    }

    // At least some mel bins should differ between speech regions
    expect(diffCount).toBeGreaterThan(10);
  });

  it('should match mel-worker output for the same audio', () => {
    // This test validates that our mel-math (used by mel.worker.ts) produces
    // the same features as the full pipeline, ensuring the worker's incremental
    // computation matches batch processing.
    const nMels = 128;
    const { features: batchFeatures, T } = fullMelPipeline(audioData, nMels);

    // Simulate incremental processing (like mel.worker does):
    // Push all audio at once, then extract all frames
    const hannWindow = createPaddedHannWindow();
    const twiddles = precomputeTwiddles(MEL_CONSTANTS.N_FFT);
    const melFilterbank = createMelFilterbank(nMels);

    // Pre-emphasize the full audio (relies on preemphasize's default arguments)
    const preemph = preemphasize(audioData);

    // Compute frames one by one (like worker does incrementally)
    const rawMel = new Float32Array(nMels * T);
    for (let t = 0; t < T; t++) {
      const frame = computeMelFrame(preemph, t, hannWindow, twiddles, melFilterbank, nMels);
      for (let m = 0; m < nMels; m++) {
        rawMel[m * T + t] = frame[m];
      }
    }

    // Normalize (same as getFeatures in worker)
    const incrementalFeatures = normalizeMelFeatures(rawMel, nMels, T);

    // Should be bit-for-bit identical since same code path
    expect(incrementalFeatures.length).toBe(batchFeatures.length);
    expect(firstMismatchIndex(incrementalFeatures, batchFeatures)).toBe(-1);
  });

  it('should complete mel processing under 100ms for this audio', () => {
    const t0 = performance.now();
    const { T } = fullMelPipeline(audioData, 128);
    const elapsed = performance.now() - t0;

    console.log(`Mel pipeline: ${T} frames in ${elapsed.toFixed(1)}ms ` +
      `(${(audioDuration / (elapsed / 1000)).toFixed(1)}x realtime)`);

    // Should be fast enough for real-time use
    expect(elapsed).toBeLessThan(100);
  });
});
// ─── WAV Parser Tests ─────────────────────────────────────────────────────
describe('WAV parser', () => {
  // Sanity check against the real demo clip; only runs when a local copy exists.
  it('should parse a known WAV file correctly', () => {
    const haveLocalWav = existsSync(WAV_LOCAL_PATH);
    if (haveLocalWav === false) {
      console.log('SKIP: WAV file not available locally');
      return;
    }

    // Cut the file's bytes out of the Buffer's backing store as a plain ArrayBuffer.
    const fileBytes = readFileSync(WAV_LOCAL_PATH);
    const start = fileBytes.byteOffset;
    const end = start + fileBytes.byteLength;
    const parsed = parseWav(fileBytes.buffer.slice(start, end));

    expect(parsed.audio).toBeInstanceOf(Float32Array);
    expect(parsed.audio.length).toBeGreaterThan(0);
    expect(parsed.sampleRate).toBeGreaterThan(0);
    expect(parsed.channels).toBeGreaterThanOrEqual(1);

    const seconds = (parsed.audio.length / parsed.sampleRate).toFixed(2);
    console.log(`WAV: ${parsed.audio.length} samples, ${parsed.sampleRate} Hz, ${parsed.channels} ch, ${seconds}s`);
  });

  // 44 zero bytes have the size of a minimal WAV header but no RIFF tag.
  it('should reject non-WAV data', () => {
    const zeros = new Uint8Array(new ArrayBuffer(44));
    zeros.fill(0);
    const attempt = () => parseWav(zeros.buffer);
    expect(attempt).toThrow();
  });
});