# Chatterbox-Finnish-ONNX: scripts/analyze_audio.py
# Added by RASMUS in verified commit 87e8f1b ("Add scripts/analyze_audio.py").
"""
analyze_audio.py
Comprehensive audio quality comparison between two WAV files.
Designed for comparing PyTorch TTS output vs ONNX/browser output.
Metrics:
  1. Objective (librosa): mel cosine similarity, MFCC similarity, duration,
     pitch-contour correlation, energy-envelope (pacing) similarity
  2. Groq Whisper: transcription + WER (when a reference text is given)
  3. Gemini 2.5 Flash: MOS score (1-5) with reasoning
Usage:
    conda run -n chatterbox-onnx python analyze_audio.py <file_a.wav> <file_b.wav> \
        [--ref-text "..."] [--label-a NAME] [--label-b NAME] [--lang fi]
    conda run -n chatterbox-onnx python analyze_audio.py _cmp/pytorch_output.wav _cmp/onnx_output.wav
    # Compare against the perfect baseline:
    conda run -n chatterbox-onnx python analyze_audio.py \
        Chatterbox-Finnish/output_finnish.wav \
        _cmp/browser_sim_output.wav
"""
import sys, os, base64, json, argparse
import numpy as np
import librosa
import soundfile as sf
import requests
from pathlib import Path
# Load from .env
def load_env(path=None) -> dict[str, str]:
    """Parse a simple ``KEY=VALUE`` env file into a dict.

    By default this reads the ``.env`` file located next to this script.
    Blank lines and lines whose first non-space character is ``#`` are
    skipped, an optional leading ``export `` is dropped, and a value wrapped
    in one pair of matching single or double quotes is unquoted. Only the
    first ``=`` separates key from value. A missing file yields ``{}``.

    Args:
        path: Optional path to the env file. Defaults to ``<script dir>/.env``.

    Returns:
        Mapping of stripped key -> stripped (and unquoted) value.
    """
    env_path = Path(path) if path is not None else Path(__file__).parent / ".env"
    env: dict[str, str] = {}
    if not env_path.exists():
        return env
    # utf-8-sig also strips a BOM that some Windows editors prepend, which
    # would otherwise end up glued to the first key.
    for raw_line in env_path.read_text(encoding="utf-8-sig").splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        if line.startswith("export "):
            line = line[len("export "):]
        key, value = line.split("=", 1)
        value = value.strip()
        # KEY="abc" / KEY='abc' -> abc
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
            value = value[1:-1]
        env[key.strip()] = value
    return env
# API keys: a non-empty process environment variable wins; otherwise the
# value from the local .env file (see load_env) is used.
ENV = load_env()
# "QROQ_API_KEY" (a misspelling the script previously looked up in .env) is
# still honoured as a last resort so existing .env files keep working.
GROQ_KEY = (os.environ.get("GROQ_API_KEY")
            or ENV.get("GROQ_API_KEY")
            or ENV.get("QROQ_API_KEY", ""))
GEMINI_KEY = os.environ.get("GEMINI_API_KEY") or ENV.get("GEMINI_API_KEY", "")
# ── Objective metrics ─────────────────────────────────────────────────────────
def load_mono(path: str, target_sr: int = 22050) -> tuple[np.ndarray, int]:
    """Decode *path* into a single-channel waveform resampled to *target_sr*.

    Returns the ``(samples, sample_rate)`` pair produced by ``librosa.load``.
    """
    waveform, rate = librosa.load(path, sr=target_sr, mono=True)
    return waveform, rate
def cosine(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity of two arrays, each treated as one flat vector.

    Returns 0.0 instead of dividing when the product of the norms is zero
    (or NaN), e.g. for an all-zero input.
    """
    u = a.ravel()
    v = b.ravel()
    norm_product = np.linalg.norm(u) * np.linalg.norm(v)
    if not norm_product > 0:
        return 0.0
    return float(np.dot(u, v) / norm_product)
def mel_similarity(y_a, y_b, sr) -> float:
    """Cosine similarity of the time-averaged 128-band mel spectrograms.

    Averaging over frames compares the overall spectral envelope (timbre)
    and discards timing, so the two signals may have different lengths.
    """
    profiles = [
        librosa.feature.melspectrogram(y=signal, sr=sr, n_mels=128).mean(axis=1)
        for signal in (y_a, y_b)
    ]
    return cosine(*profiles)
def mfcc_similarity(y_a, y_b, sr, n_mfcc=20) -> float:
    """Cosine similarity of the time-averaged MFCC vectors (*n_mfcc* coefficients).

    NOTE(review): coefficient 0 tracks overall log-energy and is usually much
    larger in magnitude than the rest, so it may dominate this score. Consider
    excluding it if values look saturated near 1.0.
    """
    def mean_mfcc(signal):
        return librosa.feature.mfcc(y=signal, sr=sr, n_mfcc=n_mfcc).mean(axis=1)

    return cosine(mean_mfcc(y_a), mean_mfcc(y_b))
def pitch_correlation(y_a, y_b, sr) -> float:
    """Pearson correlation of two F0 contours over jointly voiced frames.

    F0 is tracked with probabilistic YIN (``librosa.pyin``), which reports
    unvoiced frames as NaN. Only frames voiced in *both* signals are compared.
    Contours are aligned frame-by-frame from the start and truncated to the
    shorter one (no time warping), so the score is meaningful only when the
    two clips are roughly time-aligned.

    Args:
        y_a, y_b: Mono waveforms.
        sr: Sample rate of both waveforms (needed for correct Hz estimates).

    Returns:
        Correlation in [-1, 1], or NaN when fewer than 10 frames are jointly
        voiced or either contour is constant over them (correlation undefined).
    """
    # sr must be passed explicitly: the estimator otherwise assumes 22050 Hz.
    f0_a, _, _ = librosa.pyin(y_a, fmin=60, fmax=400, sr=sr)
    f0_b, _, _ = librosa.pyin(y_b, fmin=60, fmax=400, sr=sr)
    # Truncate (not resample) to a common number of frames.
    length = min(len(f0_a), len(f0_b))
    f0_a, f0_b = f0_a[:length], f0_b[:length]
    voiced = np.isfinite(f0_a) & np.isfinite(f0_b)
    if voiced.sum() < 10:
        return float("nan")
    a, b = f0_a[voiced], f0_b[voiced]
    # A flat contour has zero variance; corrcoef would warn and yield NaN.
    if np.ptp(a) == 0 or np.ptp(b) == 0:
        return float("nan")
    return float(np.corrcoef(a, b)[0, 1])
def spectral_flux_similarity(y_a, y_b, sr) -> float:
    """Cosine similarity of frame-to-frame RMS energy changes (pacing match).

    Despite the name, this uses the first difference of the RMS envelope
    rather than a true spectral flux. *sr* is not used. The two delta
    sequences are truncated to the shorter length before comparison.
    """
    deltas = [np.diff(librosa.feature.rms(y=signal)[0]) for signal in (y_a, y_b)]
    common = min(len(d) for d in deltas)
    return cosine(deltas[0][:common], deltas[1][:common])
def objective_metrics(path_a: str, path_b: str) -> dict:
    """Load both files as 22.05 kHz mono and compute the librosa-based metrics.

    Returns a dict holding both durations in seconds, their ratio B/A (0 when
    A is empty), and the mel / MFCC / pitch / energy-flux similarity scores,
    each rounded for display.
    """
    rate = 22050
    sig_a = load_mono(path_a, rate)[0]
    sig_b = load_mono(path_b, rate)[0]
    seconds_a = librosa.get_duration(y=sig_a, sr=rate)
    seconds_b = librosa.get_duration(y=sig_b, sr=rate)
    ratio = seconds_b / seconds_a if seconds_a > 0 else 0

    metrics = {}
    metrics["duration_a_s"] = round(seconds_a, 2)
    metrics["duration_b_s"] = round(seconds_b, 2)
    metrics["duration_ratio"] = round(ratio, 3)
    metrics["mel_cosine"] = round(mel_similarity(sig_a, sig_b, rate), 4)
    metrics["mfcc_cosine"] = round(mfcc_similarity(sig_a, sig_b, rate), 4)
    metrics["pitch_correlation"] = round(pitch_correlation(sig_a, sig_b, rate), 4)
    metrics["energy_flux_cosine"] = round(spectral_flux_similarity(sig_a, sig_b, rate), 4)
    return metrics
# ── WER helper ────────────────────────────────────────────────────────────────
def simple_wer(ref: str, hyp: str, *, normalize: bool = True) -> float:
    """Word error rate of *hyp* against *ref*.

    Computes the word-level Levenshtein distance divided by the number of
    reference words. Both strings are lower-cased and split on whitespace.
    When *normalize* is true (the default), every character that is not
    alphanumeric or whitespace (punctuation) is removed first, so Whisper
    output such as "Hei, maailma." matches the reference "hei maailma".
    Non-ASCII letters such as Finnish ä/ö are kept.

    Returns:
        edits / len(reference words). With an empty reference the denominator
        is 1, so the result equals the number of hypothesis words.
    """
    def tokens(text: str) -> list[str]:
        text = text.lower()
        if normalize:
            text = "".join(ch for ch in text if ch.isalnum() or ch.isspace())
        return text.split()

    ref_words, hyp_words = tokens(ref), tokens(hyp)
    n, m = len(ref_words), len(hyp_words)
    # Single-row DP: after row i, dp[j] is the edit distance between the
    # first i reference words and the first j hypothesis words.
    dp = list(range(m + 1))
    for i in range(1, n + 1):
        prev = dp.copy()
        dp[0] = i
        for j in range(1, m + 1):
            substitution = prev[j - 1] + int(ref_words[i - 1] != hyp_words[j - 1])
            dp[j] = min(prev[j] + 1,       # deletion
                        dp[j - 1] + 1,     # insertion
                        substitution)
    return dp[m] / max(n, 1)
# ── Groq transcription ────────────────────────────────────────────────────────
def transcribe_groq(wav_path: str, lang: str = "fi", *, timeout: float = 120) -> str:
    """Transcribe *wav_path* with Groq's hosted Whisper (``whisper-large-v3``).

    Args:
        wav_path: Path of the audio file to upload.
        lang: Language code sent with the request (e.g. "fi").
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The plain-text transcript. If transcription is unavailable, returns a
        parenthesised placeholder instead: "(no GROQ_API_KEY)",
        "(error <status>)" or "(request failed: ...)". This lets the report
        continue rather than crash.
    """
    if not GROQ_KEY:
        return "(no GROQ_API_KEY)"
    try:
        with open(wav_path, "rb") as f:
            r = requests.post(
                "https://api.groq.com/openai/v1/audio/transcriptions",
                headers={"Authorization": f"Bearer {GROQ_KEY}"},
                files={"file": (os.path.basename(wav_path), f, "audio/wav")},
                data={"model": "whisper-large-v3", "language": lang, "response_format": "text"},
                timeout=timeout,
            )
    except requests.RequestException as e:
        # Timeouts / connection errors should not abort the whole report.
        return f"(request failed: {e})"
    if r.ok:
        return r.text.strip()
    return f"(error {r.status_code})"
# ── Gemini MOS ────────────────────────────────────────────────────────────────
def _strip_code_fences(text: str) -> str:
    """Extract the JSON object from a model reply, tolerating markdown fences."""
    text = text.strip()
    start, end = text.find("{"), text.rfind("}")
    if start != -1 and end > start:
        # Keep exactly the outermost {...} span; drops ```json fences and any
        # chatter around the object.
        return text[start:end + 1]
    # No braces: fall back to removing a fence prefix/suffix (as literal
    # prefixes, unlike str.lstrip which strips a character *set*).
    return text.removeprefix("```json").removeprefix("```").removesuffix("```").strip()


def gemini_mos(wav_path: str, label: str = "", *, timeout: float = 30) -> dict:
    """
    Uses Gemini 2.5 Flash to give a MOS score + reasoning for a TTS audio file.
    Matches methodology used in the Chatterbox Finnish fine-tuning evaluation.

    The WAV is sent inline (base64) together with a fixed Finnish-TTS
    naturalness prompt that asks for a JSON reply.

    Args:
        wav_path: Audio file to evaluate.
        label: Not used by this function.
        timeout: Seconds to wait for the HTTP request.

    Returns:
        ``{"score": <float or None>, "comment": <str>}``. ``score`` is None
        when no API key is configured, the request fails, or the reply cannot
        be parsed. ``comment`` then holds a diagnostic message.
    """
    if not GEMINI_KEY:
        return {"score": None, "comment": "(no GEMINI_API_KEY)"}
    audio_b64 = base64.b64encode(Path(wav_path).read_bytes()).decode()
    prompt = (
        "You are an expert speech quality evaluator. "
        "Listen to this Finnish text-to-speech audio sample and evaluate its naturalness.\n\n"
        "Rate on MOS (Mean Opinion Score) scale 1-5:\n"
        " 1.0 = Completely unintelligible or robotic\n"
        " 2.0 = Very poor quality, hard to understand\n"
        " 3.0 = Acceptable but clearly synthetic\n"
        " 4.0 = Good quality, natural-sounding\n"
        " 5.0 = Excellent, indistinguishable from human speech\n\n"
        "Return ONLY valid JSON: {\"mos\": <float 1.0-5.0>, \"reason\": \"<one sentence>\"}"
    )
    body = {
        "contents": [{
            "parts": [
                {"inline_data": {"mime_type": "audio/wav", "data": audio_b64}},
                {"text": prompt},
            ]
        }],
        "generationConfig": {"temperature": 0.1, "maxOutputTokens": 1024},
    }
    url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent"
    try:
        # The key travels in a header, not the query string, so it cannot
        # leak into exception messages (which embed the URL) or request logs.
        r = requests.post(url, json=body, headers={"x-goog-api-key": GEMINI_KEY},
                          timeout=timeout)
    except requests.RequestException as e:
        return {"score": None, "comment": f"(Gemini request failed: {e})"}
    if not r.ok:
        return {"score": None, "comment": f"(Gemini error {r.status_code}: {r.text[:200]})"}
    try:
        parts = r.json()["candidates"][0]["content"]["parts"]
        # The reply may be split over several parts; concatenate their text.
        text = "".join(part.get("text", "") for part in parts)
        data = json.loads(_strip_code_fences(text))
        return {"score": data.get("mos"), "comment": data.get("reason", "")}
    except (KeyError, IndexError, TypeError, ValueError, AttributeError) as e:
        return {"score": None, "comment": f"(parse error: {e} | raw: {r.text[:200]})"}
# ── Main report ───────────────────────────────────────────────────────────────
def report(path_a: str, path_b: str, label_a: str = "A", label_b: str = "B",
           ref_text: str = "", lang: str = "fi"):
    """Print a comparison report for two audio files and return the raw results.

    Sections, in order:
      1. librosa objective metrics, with a coarse mel/MFCC verdict.
      2. Groq Whisper transcripts, plus WER when *ref_text* is given.
      3. Gemini MOS scores.
      4. A short summary.

    Returns:
        dict with keys "objective", "transcription", "mos" and "wer". The
        "wer" values are None when no *ref_text* was given.
    """
    BAR = "=" * 65
    print(f"\n{BAR}")
    print(" AUDIO COMPARISON REPORT")
    print(f" A: {path_a}")
    print(f" B: {path_b}")
    print(BAR)

    # ── Objective metrics ──
    print("\n── Objective metrics ──────────────────────────────────────────")
    obj = objective_metrics(path_a, path_b)
    print(f" Duration A={obj['duration_a_s']}s B={obj['duration_b_s']}s "
          f"ratio(B/A)={obj['duration_ratio']}")
    print(f" Mel cosine {obj['mel_cosine']:.4f} (timbre match, 1.0=identical)")
    print(f" MFCC cosine {obj['mfcc_cosine']:.4f} (phonetic match, 1.0=identical)")
    print(f" Pitch corr {obj['pitch_correlation']:.4f} (prosody match, 1.0=identical)")
    print(f" Energy flux {obj['energy_flux_cosine']:.4f} (pacing match, 1.0=identical)")
    mel, mfcc = obj["mel_cosine"], obj["mfcc_cosine"]
    if mel > 0.98 and mfcc > 0.98:
        quality = "excellent (near-identical)"
    elif mel > 0.95 and mfcc > 0.95:
        quality = "good"
    elif mel > 0.90:
        quality = "fair"
    else:
        quality = "poor — significant differences"
    print(f"\n → Waveform match: {quality}")

    # ── Transcription ──
    print("\n── Groq Whisper transcription ─────────────────────────────────")
    tx_a = transcribe_groq(path_a, lang)
    tx_b = transcribe_groq(path_b, lang)
    print(f" {label_a}: '{tx_a}'")
    print(f" {label_b}: '{tx_b}'")
    # WER is computed once here and reused in the summary.
    wer_a = wer_b = None
    if ref_text:
        wer_a = simple_wer(ref_text, tx_a)
        wer_b = simple_wer(ref_text, tx_b)
        print(f" Ref: '{ref_text}'")
        print(f" WER {label_a}: {wer_a:.1%} {label_b}: {wer_b:.1%}")

    # ── Gemini MOS ──
    print("\n── Gemini 2.5 Flash MOS ───────────────────────────────────────")
    mos_a = gemini_mos(path_a, label_a)
    mos_b = gemini_mos(path_b, label_b)
    print(f" {label_a}: MOS={mos_a['score']} — {mos_a['comment']}")
    print(f" {label_b}: MOS={mos_b['score']} — {mos_b['comment']}")

    # ── Summary ──
    print(f"\n{BAR}")
    print(" SUMMARY")
    print(BAR)
    print(f" Mel cosine: {obj['mel_cosine']:.4f} (target: >0.95 for 'good match')")
    print(f" MFCC cosine: {obj['mfcc_cosine']:.4f} (target: >0.95)")
    print(f" MOS {label_a}: {mos_a['score']} MOS {label_b}: {mos_b['score']}")
    if ref_text:
        print(f" WER {label_a}: {wer_a:.1%} WER {label_b}: {wer_b:.1%}")
    return {
        "objective": obj,
        "transcription": {"a": tx_a, "b": tx_b},
        "mos": {"a": mos_a, "b": mos_b},
        "wer": {"a": wer_a, "b": wer_b},
    }
if __name__ == "__main__":
    p = argparse.ArgumentParser(description="Compare two TTS audio files")
    p.add_argument("file_a", help="Reference/baseline WAV (e.g. pytorch output)")
    p.add_argument("file_b", help="Target WAV to compare against (e.g. ONNX/browser output)")
    p.add_argument("--label-a", default="PyTorch", help="Label for file A")
    p.add_argument("--label-b", default="ONNX", help="Label for file B")
    # The longer spelling --reference-text is also accepted (it appears in
    # earlier usage examples); both store into args.ref_text.
    p.add_argument("--ref-text", "--reference-text", dest="ref_text", default="",
                   help="Reference transcript for WER")
    p.add_argument("--lang", default="fi", help="Language code for transcription")
    args = p.parse_args()
    # Fail fast with a clean CLI error instead of a decoder traceback later.
    for audio_path in (args.file_a, args.file_b):
        if not Path(audio_path).is_file():
            p.error(f"file not found: {audio_path}")
    report(args.file_a, args.file_b,
           label_a=args.label_a, label_b=args.label_b,
           ref_text=args.ref_text, lang=args.lang)