tts-eval-framework / app /evaluator.py
aankitdas's picture
added storage limit guard
229a3e3
# app/evaluator.py
# TTS evaluation pipeline for the Bantrly eval framework.
#
# Metrics:
# WER — Word Error Rate via Whisper transcription (Radford et al. 2023)
# UTMOS — Automated MOS prediction (Saeki et al. 2022, VoiceMOS Challenge winner)
# RTF — Real Time Factor: synthesis_time / audio_duration
# Cost — Equivalent cost vs Chirp 3 HD ($16/1M chars)
#
# Each evaluate() call appends its result row to results/eval_log.csv (next to
# this module) and then attempts a best-effort hand-off to the storage helpers.
import time
import librosa
import torch
import soundfile as sf
import numpy as np
from jiwer import wer
from faster_whisper import WhisperModel
# --- Whisper setup ---
# "base" model: ~150MB, fast, good enough for WER on clean TTS output
# upgrade to "small" or "medium" if WER accuracy is insufficient
_whisper_model = None


def _get_whisper():
    """
    Return the process-wide Whisper model, constructing it on first use.

    The instance is cached in the module-level ``_whisper_model`` so the
    weights are loaded at most once. When torch reports a CUDA device the
    model runs on "cuda" with float16; otherwise it runs on "cpu" with int8.

    Returns:
        the cached WhisperModel("base") instance
    """
    global _whisper_model
    if _whisper_model is not None:
        return _whisper_model

    # Pick device and matching compute type together.
    if torch.cuda.is_available():
        device, compute = "cuda", "float16"
    else:
        device, compute = "cpu", "int8"
    _whisper_model = WhisperModel("base", device=device, compute_type=compute)
    return _whisper_model
# --- UTMOS setup ---
# sarulab-speech/UTMOS22 — winner of VoiceMOS Challenge 2022
# predicts human MOS scores (1-5) without reference audio
_utmos_model = None


def _get_utmos():
    """
    Return the process-wide UTMOS predictor, fetching it on first use.

    The model is obtained through ``torch.hub.load`` and cached in the
    module-level ``_utmos_model``; it is switched to eval mode once, right
    after loading.

    Returns:
        the cached UTMOS model object
    """
    global _utmos_model
    if _utmos_model is not None:
        return _utmos_model

    # tarepan/SpeechMOS is a maintained fork with a proper hubconf.py;
    # it wraps the official UTMOS22 strong learner weights (MIT license).
    hub_repo = "tarepan/SpeechMOS:v1.2.0"
    hub_entry = "utmos22_strong"
    _utmos_model = torch.hub.load(hub_repo, hub_entry, trust_repo=True)
    _utmos_model.eval()
    return _utmos_model
def compute_wer(reference_text: str, audio_path: str) -> float:
    """
    Transcribe the audio with Whisper and score it against the reference text.

    Both strings are lower-cased and stripped of surrounding whitespace
    before scoring; punctuation is left as-is.

    Args:
        reference_text: the original input text (ground truth)
        audio_path: path to synthesized audio file
    Returns:
        WER as a fraction rounded to 4 decimals (multiply by 100 for a
        percentage)
    """
    segments, _info = _get_whisper().transcribe(audio_path, beam_size=5)
    # Segments arrive one at a time; join their trimmed text into one line.
    pieces = [segment.text.strip() for segment in segments]
    hypothesis = " ".join(pieces).lower().strip()
    reference = reference_text.lower().strip()
    return round(wer(reference, hypothesis), 4)
def compute_utmos(audio_path: str) -> float:
    """
    Predict a MOS score for the audio with UTMOS (naturalness rating 1-5).

    The file is decoded with librosa for every format (WAV + MP3) to avoid
    soundfile subprocess issues in Gradio's hot-reload worker. Audio is
    loaded as 16 kHz mono and given a leading batch dimension before being
    passed to the model.

    Args:
        audio_path: path to synthesized audio file
    Returns:
        predicted MOS score rounded to 3 decimals (higher = more natural)
    """
    target_sr = 16000
    predictor = _get_utmos()
    samples, _ = librosa.load(audio_path, sr=target_sr, mono=True)
    batch = torch.FloatTensor(samples).unsqueeze(0)
    # Inference only: no gradients needed.
    with torch.no_grad():
        prediction = predictor(batch, sr=target_sr)
    return round(float(prediction), 3)
def compute_rtf(latency_seconds: float, audio_path: str) -> float:
    """
    Compute Real Time Factor: synthesis_time / audio_duration.

    RTF < 1.0 means faster than real time.
    MP3 files are decoded with librosa (sf.read may fail on MP3 depending on
    the libsndfile version); everything else is read with soundfile.

    Args:
        latency_seconds: wall-clock synthesis time from engine
        audio_path: path to synthesized audio file
    Returns:
        RTF rounded to 3 decimals, or 0.0 when the audio has zero duration
    """
    # Match the extension case-insensitively so "clip.MP3" takes the same
    # librosa path as "clip.mp3" instead of falling through to sf.read.
    if audio_path.lower().endswith(".mp3"):
        # sr=None keeps the file's native sample rate.
        audio, sr = librosa.load(audio_path, sr=None)
    else:
        audio, sr = sf.read(audio_path)
    audio_duration = len(audio) / sr
    if audio_duration == 0:
        # Avoid dividing by zero for empty audio.
        return 0.0
    return round(latency_seconds / audio_duration, 3)
def _run_metric(label, metric_fn, *args):
    """Call one metric function; on any exception print a notice and return None."""
    try:
        return metric_fn(*args)
    except Exception as e:
        # Best-effort: one failing metric must not abort the whole evaluation.
        print(f"{label} computation failed: {e}")
        return None


def _append_result_row(result: dict) -> str:
    """Append one result row to results/eval_log.csv beside this module; return the CSV path."""
    import pandas as pd, os

    results_path = os.path.join(os.path.dirname(__file__), "results", "eval_log.csv")
    os.makedirs(os.path.dirname(results_path), exist_ok=True)
    # Write the header when the log is new OR exists but is still empty;
    # checking existence alone would leave a zero-byte file header-less.
    needs_header = (
        not os.path.exists(results_path) or os.path.getsize(results_path) == 0
    )
    pd.DataFrame([result]).to_csv(
        results_path, mode="a", header=needs_header, index=False
    )
    return results_path


def evaluate(
    reference_text: str,
    audio_path: str,
    latency_seconds: float,
    engine,
    band: str = "unknown",
    synth_voice: str = "unknown",
    actual_cost_usd: float = None,
) -> dict:
    """
    Run the full eval suite on a synthesized audio file.

    WER, UTMOS and RTF are each computed best-effort: if a metric raises,
    a message is printed and that metric is reported as None. The result
    row is then appended to results/eval_log.csv next to this module, and
    the storage helpers are invoked on that CSV; any failure in the storage
    step is printed and skipped.

    Args:
        reference_text: original input text
        audio_path: path to synthesized audio
        latency_seconds: synthesis latency from engine.synthesize()
        engine: TTSEngine instance; its name, engine_type,
            is_production_ready and estimate_cost() are read here
        band: label copied into the result row as "band"
        synth_voice: label copied into the result row as "voice"
        actual_cost_usd: real per-request cost if the engine reports one
            (e.g. RunPod); when None, engine.estimate_cost() is used instead
    Returns:
        dict with all eval scores + metadata for the comparison table
    """
    from datetime import datetime

    # Metrics: each one is isolated so a single failure yields None
    # for that column only.
    wer_score = _run_metric("WER", compute_wer, reference_text, audio_path)
    utmos_score = _run_metric("UTMOS", compute_utmos, audio_path)
    rtf = _run_metric("RTF", compute_rtf, latency_seconds, audio_path)

    # Cost of the same text at the Chirp baseline rate: $16 per 1M characters.
    chars = len(reference_text)
    chirp_cost = round((chars / 1_000_000) * 16.0, 6)
    # Prefer the actual cost when the engine supplied it per request.
    if actual_cost_usd is not None:
        engine_cost = round(actual_cost_usd, 6)
    else:
        engine_cost = round(engine.estimate_cost(reference_text), 6)

    result = {
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "engine": engine.name,
        "engine_type": engine.engine_type,
        "production_ready": engine.is_production_ready,
        "band": band,
        "input_text": reference_text,
        "voice": synth_voice,
        "wer": wer_score,
        "utmos": utmos_score,
        "rtf": rtf,
        "latency_s": latency_seconds,
        "engine_cost_usd": engine_cost,
        "chirp_equiv_usd": chirp_cost,
        "chars": chars,
    }

    # Persist the row locally. Errors here propagate to the caller.
    results_path = _append_result_row(result)

    # Hand the updated CSV to the storage helpers (upload + cleanup check,
    # per their names; see the storage module). This step is optional: a
    # missing module or any runtime error is reported and ignored.
    try:
        from storage import upload_csv_background, cleanup_bucket_background
        upload_csv_background(results_path)
        cleanup_bucket_background(results_path)
    except Exception as e:
        print(f"[Storage] Background tasks skipped: {e}")
    return result