RASMUS committed on
Commit
87e8f1b
·
verified ·
1 Parent(s): 6264e00

Add scripts/analyze_audio.py

Browse files
Files changed (1) hide show
  1. scripts/analyze_audio.py +279 -0
scripts/analyze_audio.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ analyze_audio.py
3
+
4
+ Comprehensive audio quality comparison between two WAV files.
5
+ Designed for comparing PyTorch TTS output vs ONNX/browser output.
6
+
7
+ Metrics:
8
+ 1. Objective (librosa): mel cosine similarity, MFCC similarity, duration, pitch contour
9
+ 2. Groq Whisper: transcription + WER
10
+ 3. Gemini Flash: MOS score (1-5) with reasoning
11
+
12
+ Usage:
13
+ conda run -n chatterbox-onnx python analyze_audio.py <file_a.wav> <file_b.wav> [--reference-text "..."]
14
+ conda run -n chatterbox-onnx python analyze_audio.py _cmp/pytorch_output.wav _cmp/onnx_output.wav
15
+
16
+ # Compare against the perfect baseline:
17
+ conda run -n chatterbox-onnx python analyze_audio.py \
18
+ Chatterbox-Finnish/output_finnish.wav \
19
+ _cmp/browser_sim_output.wav
20
+ """
21
+
22
+ import sys, os, base64, json, argparse
23
+ import numpy as np
24
+ import librosa
25
+ import soundfile as sf
26
+ import requests
27
+ from pathlib import Path
28
+
29
+ # Load from .env
30
def load_env():
    """Parse KEY=VALUE pairs from a ``.env`` file next to this script.

    Lines without ``=`` and lines starting with ``#`` are skipped.
    Returns an empty dict when no ``.env`` file exists. Values are kept
    verbatim (no quote stripping) — NOTE(review): quoted .env values
    keep their quotes; confirm that is intended.
    """
    values = {}
    env_file = Path(__file__).parent / ".env"
    if not env_file.exists():
        return values
    for raw in env_file.read_text().splitlines():
        if raw.startswith("#") or "=" not in raw:
            continue
        key, _, value = raw.partition("=")
        values[key.strip()] = value.strip()
    return values
39
+
40
# Resolve API keys: real environment variables win, then .env entries.
ENV = load_env()
# BUG FIX: the .env fallback previously looked up the misspelled key
# "QROQ_API_KEY" only. Check the correct spelling first; keep the old
# misspelling as a last-resort fallback for existing .env files.
GROQ_KEY = os.environ.get(
    "GROQ_API_KEY",
    ENV.get("GROQ_API_KEY", ENV.get("QROQ_API_KEY", "")),
)
GEMINI_KEY = os.environ.get("GEMINI_API_KEY", ENV.get("GEMINI_API_KEY", ""))
43
+
44
+
45
+ # ── Objective metrics ─────────────────────────────────────────────────────────
46
+
47
def load_mono(path: str, target_sr: int = 22050) -> tuple[np.ndarray, int]:
    """Load *path* as mono audio resampled to *target_sr*; return (samples, sr)."""
    samples, rate = librosa.load(path, sr=target_sr, mono=True)
    return samples, rate
50
+
51
+
52
def cosine(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity of two vectors; 0.0 when either has zero norm."""
    u = np.ravel(a)
    v = np.ravel(b)
    scale = np.linalg.norm(u) * np.linalg.norm(v)
    if scale <= 0:
        return 0.0
    return float(np.dot(u, v) / scale)
56
+
57
+
58
def mel_similarity(y_a, y_b, sr) -> float:
    """Cosine similarity of time-averaged mel spectrograms (overall timbre match)."""
    avg_a = librosa.feature.melspectrogram(y=y_a, sr=sr, n_mels=128).mean(axis=1)
    avg_b = librosa.feature.melspectrogram(y=y_b, sr=sr, n_mels=128).mean(axis=1)
    return cosine(avg_a, avg_b)
64
+
65
+
66
def mfcc_similarity(y_a, y_b, sr, n_mfcc=20) -> float:
    """Cosine similarity of time-averaged MFCC vectors (phonetic content match)."""
    def mean_mfcc(y):
        return librosa.feature.mfcc(y=y, sr=sr, n_mfcc=n_mfcc).mean(axis=1)

    return cosine(mean_mfcc(y_a), mean_mfcc(y_b))
71
+
72
+
73
def pitch_correlation(y_a, y_b, sr) -> float:
    """Pearson correlation of the two F0 contours (prosody match).

    Frames where either track is non-finite or non-positive are excluded.
    Returns NaN when fewer than 10 usable frames overlap.

    NOTE(review): librosa.yin produces an F0 estimate for *every* frame
    (it does not emit 0/NaN for unvoiced audio), so the mask below mostly
    guards against degenerate input; switch to librosa.pyin if genuine
    unvoiced-frame exclusion is wanted — confirm before changing metrics.
    """
    # BUG FIX: pass the caller's sample rate instead of relying on
    # librosa's default (they coincide at 22050 Hz today, but the
    # function would silently miscompute for any other sr).
    f0_a = librosa.yin(y_a, fmin=60, fmax=400, sr=sr)
    f0_b = librosa.yin(y_b, fmin=60, fmax=400, sr=sr)
    # Truncate to the common length before frame-by-frame comparison.
    length = min(len(f0_a), len(f0_b))
    f0_a, f0_b = f0_a[:length], f0_b[:length]
    # BUG FIX: also drop NaN/inf frames, which would otherwise poison
    # np.corrcoef (the old code only checked > 0, which NaN fails
    # silently anyway but inf does not).
    usable = np.isfinite(f0_a) & np.isfinite(f0_b) & (f0_a > 0) & (f0_b > 0)
    if usable.sum() < 10:
        return float("nan")
    return float(np.corrcoef(f0_a[usable], f0_b[usable])[0, 1])
86
+
87
+
88
def spectral_flux_similarity(y_a, y_b, sr) -> float:
    """Cosine similarity of frame-to-frame RMS energy deltas (pacing match).

    Note: *sr* is accepted for signature symmetry with the other metrics
    but is not used by the RMS computation.
    """
    def rms_flux(y):
        return np.diff(librosa.feature.rms(y=y)[0])

    flux_a = rms_flux(y_a)
    flux_b = rms_flux(y_b)
    n = min(len(flux_a), len(flux_b))
    return cosine(flux_a[:n], flux_b[:n])
94
+
95
+
96
def objective_metrics(path_a: str, path_b: str) -> dict:
    """Compute all waveform-level comparison metrics for two audio files.

    Both files are loaded mono at 22050 Hz. Returns a dict of rounded
    durations, the B/A duration ratio (0 when A is empty), and the four
    similarity scores.
    """
    SR = 22050
    y_a, _ = load_mono(path_a, SR)
    y_b, _ = load_mono(path_b, SR)

    dur_a = librosa.get_duration(y=y_a, sr=SR)
    dur_b = librosa.get_duration(y=y_b, sr=SR)
    ratio = dur_b / dur_a if dur_a > 0 else 0

    metrics = {
        "duration_a_s": round(dur_a, 2),
        "duration_b_s": round(dur_b, 2),
        "duration_ratio": round(ratio, 3),
        "mel_cosine": round(mel_similarity(y_a, y_b, SR), 4),
        "mfcc_cosine": round(mfcc_similarity(y_a, y_b, SR), 4),
        "pitch_correlation": round(pitch_correlation(y_a, y_b, SR), 4),
        "energy_flux_cosine": round(spectral_flux_similarity(y_a, y_b, SR), 4),
    }
    return metrics
113
+
114
+
115
+ # ── WER helper ────────────────────────────────────────────────────────────────
116
+
117
def simple_wer(ref: str, hyp: str) -> float:
    """Word error rate: token-level Levenshtein distance over len(ref).

    Both strings are lowercased and whitespace-tokenized. The denominator
    clamps to 1, so an empty reference yields the raw edit distance.
    """
    ref_tokens = ref.lower().split()
    hyp_tokens = hyp.lower().split()
    rows, cols = len(ref_tokens), len(hyp_tokens)
    # Single-row dynamic program over the edit-distance matrix.
    current = list(range(cols + 1))
    for i in range(1, rows + 1):
        previous = current[:]
        current[0] = i
        for j in range(1, cols + 1):
            same = ref_tokens[i - 1] == hyp_tokens[j - 1]
            current[j] = min(
                previous[j] + 1,                        # deletion
                current[j - 1] + 1,                     # insertion
                previous[j - 1] + (0 if same else 1),   # substitution
            )
    return current[cols] / max(rows, 1)
130
+
131
+
132
+ # ── Groq transcription ────────────────────────────────────────────────────────
133
+
134
def transcribe_groq(wav_path: str, lang: str = "fi") -> str:
    """Transcribe a WAV file with Groq's hosted Whisper model.

    Returns the transcript text, or a parenthesized placeholder string
    when the API key is missing or the HTTP request fails.
    """
    if not GROQ_KEY:
        return "(no GROQ_API_KEY)"
    endpoint = "https://api.groq.com/openai/v1/audio/transcriptions"
    form_fields = {"model": "whisper-large-v3", "language": lang, "response_format": "text"}
    with open(wav_path, "rb") as audio:
        resp = requests.post(
            endpoint,
            headers={"Authorization": f"Bearer {GROQ_KEY}"},
            files={"file": (os.path.basename(wav_path), audio, "audio/wav")},
            data=form_fields,
        )
    if not resp.ok:
        return f"(error {resp.status_code})"
    return resp.text.strip()
147
+
148
+
149
+ # ── Gemini MOS ────────────────────────────────────────────────────────────────
150
+
151
def gemini_mos(wav_path: str, label: str = "") -> dict:
    """Ask Gemini 2.5 Flash for a MOS (1-5) rating of a TTS audio file.

    Matches the methodology used in the Chatterbox Finnish fine-tuning
    evaluation. *label* is accepted for interface compatibility but is
    not sent to the model.

    Returns ``{"score": float | None, "comment": str}``; *score* is None
    whenever the key is missing, the API call fails, or the reply cannot
    be parsed as JSON.
    """
    if not GEMINI_KEY:
        return {"score": None, "comment": "(no GEMINI_API_KEY)"}

    # BUG FIX: read via a context manager — open(...).read() leaked the
    # file handle until GC.
    with open(wav_path, "rb") as f:
        audio_b64 = base64.b64encode(f.read()).decode()

    prompt = (
        "You are an expert speech quality evaluator. "
        "Listen to this Finnish text-to-speech audio sample and evaluate its naturalness.\n\n"
        "Rate on MOS (Mean Opinion Score) scale 1-5:\n"
        " 1.0 = Completely unintelligible or robotic\n"
        " 2.0 = Very poor quality, hard to understand\n"
        " 3.0 = Acceptable but clearly synthetic\n"
        " 4.0 = Good quality, natural-sounding\n"
        " 5.0 = Excellent, indistinguishable from human speech\n\n"
        "Return ONLY valid JSON: {\"mos\": <float 1.0-5.0>, \"reason\": \"<one sentence>\"}"
    )

    body = {
        "contents": [{
            "parts": [
                {"inline_data": {"mime_type": "audio/wav", "data": audio_b64}},
                {"text": prompt},
            ]
        }],
        "generationConfig": {"temperature": 0.1, "maxOutputTokens": 1024},
    }

    url = f"https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent?key={GEMINI_KEY}"
    r = requests.post(url, json=body, timeout=30)
    if not r.ok:
        return {"score": None, "comment": f"(Gemini error {r.status_code}: {r.text[:200]})"}

    try:
        text = r.json()["candidates"][0]["content"]["parts"][0]["text"].strip()
        # BUG FIX: str.lstrip/rstrip strip a *character set*, not a
        # prefix/suffix, so lstrip("```json") could eat leading payload
        # characters. Use removeprefix/removesuffix to strip markdown
        # fences safely.
        if text.startswith("```json"):
            text = text.removeprefix("```json")
        elif text.startswith("```"):
            text = text.removeprefix("```")
        text = text.removesuffix("```").strip()
        data = json.loads(text)
        return {"score": data.get("mos"), "comment": data.get("reason", "")}
    except Exception as e:
        return {"score": None, "comment": f"(parse error: {e} | raw: {r.text[:200]})"}
197
+
198
+
199
+ # ── Main report ───────────────────────────────────────────────────────────────
200
+
201
def report(path_a: str, path_b: str, label_a: str = "A", label_b: str = "B",
           ref_text: str = "", lang: str = "fi"):
    """Print a full comparison report for two WAV files and return raw data.

    Sections: objective librosa metrics, Groq Whisper transcriptions
    (plus WER against *ref_text* when given), and Gemini MOS scores.
    Returns a dict with keys "objective", "transcription", "mos".
    """
    BAR = "=" * 65
    print(f"\n{BAR}")
    print(" AUDIO COMPARISON REPORT")
    print(f" A: {path_a}")
    print(f" B: {path_b}")
    print(BAR)

    # ── Objective metrics ──
    print("\n── Objective metrics ──────────────────────────────────────────")
    obj = objective_metrics(path_a, path_b)
    print(f" Duration A={obj['duration_a_s']}s B={obj['duration_b_s']}s "
          f"ratio(B/A)={obj['duration_ratio']}")
    print(f" Mel cosine {obj['mel_cosine']:.4f} (timbre match, 1.0=identical)")
    print(f" MFCC cosine {obj['mfcc_cosine']:.4f} (phonetic match, 1.0=identical)")
    print(f" Pitch corr {obj['pitch_correlation']:.4f} (prosody match, 1.0=identical)")
    print(f" Energy flux {obj['energy_flux_cosine']:.4f} (pacing match, 1.0=identical)")

    mel = obj["mel_cosine"]
    mfcc = obj["mfcc_cosine"]
    quality = "excellent (near-identical)" if mel > 0.98 and mfcc > 0.98 \
        else "good" if mel > 0.95 and mfcc > 0.95 \
        else "fair" if mel > 0.90 \
        else "poor — significant differences"
    print(f"\n → Waveform match: {quality}")

    # ── Transcription ──
    print("\n── Groq Whisper transcription ─────────────────────────────────")
    tx_a = transcribe_groq(path_a, lang)
    tx_b = transcribe_groq(path_b, lang)
    print(f" {label_a}: '{tx_a}'")
    print(f" {label_b}: '{tx_b}'")
    # Compute WER once and reuse it in the summary (the original
    # recomputed both distances there).
    wer_a = wer_b = None
    if ref_text:
        wer_a = simple_wer(ref_text, tx_a)
        wer_b = simple_wer(ref_text, tx_b)
        print(f" Ref: '{ref_text}'")
        print(f" WER {label_a}: {wer_a:.1%} {label_b}: {wer_b:.1%}")

    # ── Gemini MOS ──
    # BUG FIX: the header previously said "Gemini 2.0 Flash", but
    # gemini_mos calls the gemini-2.5-flash model.
    print("\n── Gemini 2.5 Flash MOS ───────────────────────────────────────")
    mos_a = gemini_mos(path_a, label_a)
    mos_b = gemini_mos(path_b, label_b)
    print(f" {label_a}: MOS={mos_a['score']} — {mos_a['comment']}")
    print(f" {label_b}: MOS={mos_b['score']} — {mos_b['comment']}")

    # ── Summary ──
    print(f"\n{BAR}")
    print(" SUMMARY")
    print(BAR)
    print(f" Mel cosine: {obj['mel_cosine']:.4f} (target: >0.95 for 'good match')")
    print(f" MFCC cosine: {obj['mfcc_cosine']:.4f} (target: >0.95)")
    print(f" MOS {label_a}: {mos_a['score']} MOS {label_b}: {mos_b['score']}")
    if ref_text:
        print(f" WER {label_a}: {wer_a:.1%} WER {label_b}: {wer_b:.1%}")

    return {
        "objective": obj,
        "transcription": {"a": tx_a, "b": tx_b},
        "mos": {"a": mos_a, "b": mos_b},
    }
265
+
266
+
267
if __name__ == "__main__":
    # CLI entry point: parse the two file paths plus optional labels,
    # reference transcript, and language, then run the full report.
    parser = argparse.ArgumentParser(description="Compare two TTS audio files")
    parser.add_argument("file_a", help="Reference/baseline WAV (e.g. pytorch output)")
    parser.add_argument("file_b", help="Target WAV to compare against (e.g. ONNX/browser output)")
    parser.add_argument("--label-a", default="PyTorch", help="Label for file A")
    parser.add_argument("--label-b", default="ONNX", help="Label for file B")
    parser.add_argument("--ref-text", default="", help="Reference transcript for WER")
    parser.add_argument("--lang", default="fi", help="Language code for transcription")
    opts = parser.parse_args()

    report(
        opts.file_a,
        opts.file_b,
        label_a=opts.label_a,
        label_b=opts.label_b,
        ref_text=opts.ref_text,
        lang=opts.lang,
    )