# app/evaluator.py
# TTS evaluation pipeline for the Bantrly eval framework.
#
# Metrics:
#   WER   — Word Error Rate via Whisper transcription (Radford et al. 2023)
#   UTMOS — Automated MOS prediction (Saeki et al. 2022, VoiceMOS Challenge winner)
#   RTF   — Real Time Factor: synthesis_time / audio_duration
#   Cost  — Equivalent cost vs Chirp 3 HD ($16/1M chars)
#
# Every call to evaluate() appends its result row to results/eval_log.csv
# (next to this file) before returning.
import os
import time
from datetime import datetime
from typing import Optional

import librosa
import numpy as np
import soundfile as sf
import torch
from faster_whisper import WhisperModel
from jiwer import wer
# --- Whisper setup ---
# "base" model: ~150MB, fast, good enough for WER on clean TTS output
# upgrade to "small" or "medium" if WER accuracy is insufficient
_whisper_model = None


def _get_whisper():
    """Return the process-wide Whisper model, building it on the first call.

    The model is created once and cached in the module-level
    ``_whisper_model``; later calls return the cached instance.
    CUDA with float16 is used when torch reports a GPU, otherwise CPU
    with int8.
    """
    global _whisper_model
    if _whisper_model is not None:
        return _whisper_model

    if torch.cuda.is_available():
        device, compute = "cuda", "float16"
    else:
        device, compute = "cpu", "int8"
    _whisper_model = WhisperModel("base", device=device, compute_type=compute)
    return _whisper_model
# --- UTMOS setup ---
# sarulab-speech/UTMOS22 — winner of VoiceMOS Challenge 2022
# predicts human MOS scores (1-5) without reference audio
_utmos_model = None


def _get_utmos():
    """Return the process-wide UTMOS predictor, loading it on the first call.

    The predictor is fetched through ``torch.hub`` once, switched to eval
    mode, and cached in the module-level ``_utmos_model``.
    """
    global _utmos_model
    if _utmos_model is not None:
        return _utmos_model

    # tarepan/SpeechMOS is a maintained fork with a proper hubconf.py;
    # it wraps the official UTMOS22 strong learner weights (MIT license).
    # The repo reference is pinned to the v1.2.0 tag.
    _utmos_model = torch.hub.load(
        "tarepan/SpeechMOS:v1.2.0",
        "utmos22_strong",
        trust_repo=True,
    )
    _utmos_model.eval()
    return _utmos_model
def _normalize_for_wer(text: str) -> str:
    """Reduce text to lower-case words separated by single spaces.

    Punctuation is replaced by spaces (apostrophes inside a word are kept so
    contractions stay one token), typographic apostrophes are folded to the
    ASCII one, and quote marks wrapped around a word are stripped.
    """
    lowered = text.lower().replace("\u2019", "'")
    words_only = "".join(
        ch if (ch.isalnum() or ch == "'") else " " for ch in lowered
    )
    tokens = (token.strip("'") for token in words_only.split())
    return " ".join(token for token in tokens if token)


def compute_wer(reference_text: str, audio_path: str) -> float:
    """
    Transcribe audio with Whisper and compute WER against reference text.

    Both the reference and the transcript are normalized to bare lower-case
    words before scoring. Without this, punctuation attached to a word
    ("world." vs "world") is counted as a substitution and inflates the WER
    of an otherwise perfect synthesis.

    Args:
        reference_text: the original input text (ground truth)
        audio_path: path to synthesized audio file
    Returns:
        WER as a float rounded to 4 decimals; 0.0 is a perfect match
        (multiply by 100 for percentage). The value can exceed 1.0 when the
        transcript contains many inserted words.
    """
    model = _get_whisper()
    segments, _ = model.transcribe(audio_path, beam_size=5)
    # `segments` is consumed exactly once here to build the full transcript.
    hypothesis = " ".join(seg.text.strip() for seg in segments)
    score = wer(
        _normalize_for_wer(reference_text),
        _normalize_for_wer(hypothesis),
    )
    return round(score, 4)
def compute_utmos(audio_path: str) -> float:
    """
    Predict a MOS score with UTMOS (automated naturalness rating 1-5).

    Audio is always decoded through librosa (WAV and MP3 alike) to avoid
    soundfile subprocess issues in Gradio's hot-reload worker. It is loaded
    as mono at 16 kHz and passed to the predictor as a batch of one.

    Args:
        audio_path: path to synthesized audio file
    Returns:
        predicted MOS score rounded to 3 decimals (higher = more natural)
    """
    target_sr = 16000
    predictor = _get_utmos()
    samples, _ = librosa.load(audio_path, sr=target_sr, mono=True)
    # Add a leading batch dimension in front of the 1-D mono signal.
    batch = torch.FloatTensor(samples).unsqueeze(0)
    with torch.no_grad():
        predicted = predictor(batch, sr=target_sr)
    return round(float(predicted), 3)
def compute_rtf(latency_seconds: float, audio_path: str) -> float:
    """
    Compute Real Time Factor: synthesis_time / audio_duration.

    RTF < 1.0 means faster than real time.
    MP3 files are decoded with librosa (sf.read may fail on MP3 depending on
    the libsndfile version); everything else goes through soundfile.

    Args:
        latency_seconds: wall-clock synthesis time from engine
        audio_path: path to synthesized audio file
    Returns:
        RTF as float rounded to 3 decimals, or 0.0 when the decoded audio
        has zero length (avoids a division by zero).
    """
    # Compare the extension case-insensitively so "clip.MP3" also takes the
    # librosa path instead of falling through to sf.read.
    if audio_path.lower().endswith(".mp3"):
        audio, sr = librosa.load(audio_path, sr=None)  # sr=None keeps the native rate
    else:
        audio, sr = sf.read(audio_path)

    # Duration = length along the first axis of the decoded array / sample rate.
    audio_duration = len(audio) / sr
    if audio_duration == 0:
        return 0.0
    return round(latency_seconds / audio_duration, 3)
# Chirp 3 HD baseline price used for the cost comparison: USD per 1M input chars.
_CHIRP_USD_PER_MILLION_CHARS = 16.0


def _best_effort(label, compute):
    """Run one metric callable; on any failure print it and return None."""
    try:
        return compute()
    except Exception as e:  # one failing metric must not abort the whole eval
        print(f"{label} computation failed: {e}")
        return None


def _persist_result(result: dict) -> None:
    """Append one result row to results/eval_log.csv and kick off storage sync.

    Persistence is best-effort: a filesystem error is reported and swallowed
    so the caller still receives the metrics that were already computed.
    """
    # Kept function-local (as before) so importing this module does not
    # itself require pandas.
    import pandas as pd

    results_path = os.path.join(os.path.dirname(__file__), "results", "eval_log.csv")
    try:
        os.makedirs(os.path.dirname(results_path), exist_ok=True)
        # Emit the header row only when the log file is being created.
        pd.DataFrame([result]).to_csv(
            results_path,
            mode="a",
            header=not os.path.exists(results_path),
            index=False,
        )
    except OSError as e:
        print(f"[Results] Could not append to {results_path}: {e}")
        return  # nothing new on disk, so there is nothing to upload

    # Upload the updated CSV and run the cleanup check in the background.
    # Any failure here (including the storage module being unavailable) is
    # reported and ignored.
    try:
        from storage import upload_csv_background, cleanup_bucket_background
        upload_csv_background(results_path)
        cleanup_bucket_background(results_path)
    except Exception as e:
        print(f"[Storage] Background tasks skipped: {e}")


def evaluate(
    reference_text: str,
    audio_path: str,
    latency_seconds: float,
    engine,
    band: str = "unknown",
    synth_voice: str = "unknown",
    actual_cost_usd: Optional[float] = None,
) -> dict:
    """
    Run full eval suite on a synthesized audio file.

    Each metric (WER, UTMOS, RTF) is computed independently; a metric that
    raises is reported on stdout and recorded as None. The finished row is
    appended to results/eval_log.csv next to this file and handed to the
    storage background tasks; failures in either step are reported, not raised.

    Args:
        reference_text: original input text
        audio_path: path to synthesized audio
        latency_seconds: synthesis latency from engine.synthesize()
        engine: TTSEngine instance; this function reads its `name`,
            `engine_type` and `is_production_ready` attributes and calls
            `estimate_cost(reference_text)` when no actual cost is given
        band: free-form label stored under the "band" key
        synth_voice: voice identifier stored under the "voice" key
        actual_cost_usd: real cost reported for this request (e.g. RunPod
            returns it per request); when None the engine's estimate is used
    Returns:
        dict with all eval scores + metadata for comparison table
    """
    # Lambdas defer the calls so that any error is raised inside _best_effort.
    wer_score = _best_effort("WER", lambda: compute_wer(reference_text, audio_path))
    utmos_score = _best_effort("UTMOS", lambda: compute_utmos(audio_path))
    rtf = _best_effort("RTF", lambda: compute_rtf(latency_seconds, audio_path))

    # Cost estimate vs the Chirp baseline for the same number of characters.
    char_count = len(reference_text)
    chirp_cost = round((char_count / 1_000_000) * _CHIRP_USD_PER_MILLION_CHARS, 6)
    # Prefer the actual cost when the engine supplied one.
    if actual_cost_usd is not None:
        engine_cost = round(actual_cost_usd, 6)
    else:
        engine_cost = round(engine.estimate_cost(reference_text), 6)

    result = {
        # Naive local time (no timezone attached).
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "engine": engine.name,
        "engine_type": engine.engine_type,
        "production_ready": engine.is_production_ready,
        "band": band,
        "input_text": reference_text,
        "voice": synth_voice,
        "wer": wer_score,
        "utmos": utmos_score,
        "rtf": rtf,
        "latency_s": latency_seconds,
        "engine_cost_usd": engine_cost,
        "chirp_equiv_usd": chirp_cost,
        "chars": char_count,
    }

    _persist_result(result)
    return result