Spaces:
Sleeping
Sleeping
| # app/evaluator.py | |
| # TTS evaluation pipeline for the Bantrly eval framework. | |
| # | |
| # Metrics: | |
| # WER — Word Error Rate via Whisper transcription (Radford et al. 2023) | |
| # UTMOS — Automated MOS prediction (Saeki et al. 2022, VoiceMOS Challenge winner) | |
| # RTF — Real Time Factor: synthesis_time / audio_duration | |
| # Cost — Equivalent cost vs Chirp 3 HD ($16/1M chars) | |
| # | |
| # Results are persisted automatically: evaluate() appends every result row to | |
| # results/eval_log.csv (next to this module) and then triggers the storage upload. | |
| import time | |
| import librosa | |
| import torch | |
| import soundfile as sf | |
| import numpy as np | |
| from jiwer import wer | |
| from faster_whisper import WhisperModel | |
| # --- Whisper setup --- | |
| # "base" model: ~150MB, fast, good enough for WER on clean TTS output | |
| # upgrade to "small" or "medium" if WER accuracy is insufficient | |
_whisper_model = None


def _get_whisper():
    """
    Return the shared faster-whisper model, creating it on first use.

    The "base" checkpoint is loaded once and cached in the module-level
    ``_whisper_model``. When CUDA is available the model runs on the GPU
    with float16 compute; otherwise it runs on the CPU with int8 compute.

    Returns:
        the cached WhisperModel instance
    """
    global _whisper_model
    if _whisper_model is not None:
        return _whisper_model

    # pick device and matching compute type together
    if torch.cuda.is_available():
        device, compute = "cuda", "float16"
    else:
        device, compute = "cpu", "int8"

    _whisper_model = WhisperModel("base", device=device, compute_type=compute)
    return _whisper_model
| # --- UTMOS setup --- | |
| # sarulab-speech/UTMOS22 — winner of VoiceMOS Challenge 2022 | |
| # predicts human MOS scores (1-5) without reference audio | |
_utmos_model = None


def _get_utmos():
    """
    Return the shared UTMOS predictor, loading it via torch.hub on first use.

    The model is cached in the module-level ``_utmos_model`` and switched to
    eval mode right after loading.

    Returns:
        the cached UTMOS model
    """
    global _utmos_model
    if _utmos_model is not None:
        return _utmos_model

    # tarepan/SpeechMOS is a maintained fork with a proper hubconf.py;
    # it wraps the official UTMOS22 strong learner weights (MIT license)
    hub_repo = "tarepan/SpeechMOS:v1.2.0"
    hub_entrypoint = "utmos22_strong"
    _utmos_model = torch.hub.load(hub_repo, hub_entrypoint, trust_repo=True)
    _utmos_model.eval()
    return _utmos_model
def compute_wer(reference_text: str, audio_path: str) -> float:
    """
    Score synthesized audio by transcribing it and comparing to the input text.

    The audio is transcribed with the cached Whisper model (beam size 5), the
    segment texts are joined with single spaces, and both the reference and
    the transcript are lower-cased and stripped before WER is computed.

    Args:
        reference_text: the original input text (ground truth)
        audio_path: path to synthesized audio file
    Returns:
        WER as a fraction rounded to 4 decimals (multiply by 100 for percentage)
    """
    segments, _info = _get_whisper().transcribe(audio_path, beam_size=5)

    # joining consumes the segment iterable returned by transcribe()
    pieces = [segment.text.strip() for segment in segments]
    transcript = " ".join(pieces)

    reference = reference_text.lower().strip()
    hypothesis = transcript.lower().strip()
    return round(wer(reference, hypothesis), 4)
def compute_utmos(audio_path: str) -> float:
    """
    Predict a MOS score for an audio file with UTMOS (naturalness rating 1-5).

    The file is decoded with librosa for every format (WAV + MP3) to avoid
    soundfile subprocess issues in Gradio's hot-reload worker. Audio is
    resampled to 16 kHz mono and passed to the model as a batch of one.

    Args:
        audio_path: path to synthesized audio file
    Returns:
        predicted MOS score rounded to 3 decimals (higher = more natural)
    """
    target_sr = 16000
    samples, _ = librosa.load(audio_path, sr=target_sr, mono=True)

    # add a leading batch dimension: (num_samples,) -> (1, num_samples)
    batch = torch.as_tensor(samples, dtype=torch.float32).unsqueeze(0)

    predictor = _get_utmos()
    with torch.no_grad():
        prediction = predictor(batch, sr=target_sr)
    return round(float(prediction), 3)
def compute_rtf(latency_seconds: float, audio_path: str) -> float:
    """
    Compute Real Time Factor: synthesis_time / audio_duration.

    RTF < 1.0 means faster than real time.
    MP3 files are decoded with librosa (sf.read may fail on MP3 depending on
    the libsndfile version); every other format is read with soundfile.

    Args:
        latency_seconds: wall-clock synthesis time from engine
        audio_path: path to synthesized audio file
    Returns:
        RTF rounded to 3 decimals, or 0.0 when the audio has zero duration
    """
    # Match the extension case-insensitively so "clip.MP3" also takes the
    # librosa route instead of falling through to sf.read.
    if audio_path.lower().endswith(".mp3"):
        audio, sr = librosa.load(audio_path, sr=None)
    else:
        audio, sr = sf.read(audio_path)

    # len() counts along the first axis of the decoded array
    audio_duration = len(audio) / sr
    if audio_duration == 0:
        # avoid dividing by zero on an empty file
        return 0.0
    return round(latency_seconds / audio_duration, 3)
def _run_metric(label, compute):
    """Call ``compute()``; on any exception print a notice and return None."""
    try:
        return compute()
    except Exception as e:
        print(f"{label} computation failed: {e}")
        return None


def evaluate(
    reference_text: str,
    audio_path: str,
    latency_seconds: float,
    engine,
    band: str = "unknown",
    synth_voice: str = "unknown",
    actual_cost_usd: float = None,
) -> dict:
    """
    Run full eval suite on a synthesized audio file and log the result.

    Each metric (WER, UTMOS, RTF) is best-effort: if its computation raises,
    the failure is printed and the metric is reported as None. The result row
    is appended to results/eval_log.csv next to this module, then the storage
    upload/cleanup helpers are invoked (also best-effort).

    Args:
        reference_text: original input text
        audio_path: path to synthesized audio
        latency_seconds: synthesis latency from engine.synthesize()
        engine: TTSEngine instance (for cost + metadata)
        band: label stored in the "band" column of the result
        synth_voice: voice identifier stored in the "voice" column
        actual_cost_usd: real per-request cost if the engine reports one;
            when None, engine.estimate_cost(reference_text) is used instead
    Returns:
        dict with all eval scores + metadata for comparison table
    """
    # function-scoped imports, one per line
    import os
    from datetime import datetime

    import pandas as pd

    # The lambdas defer each call so any failure is raised inside _run_metric.
    wer_score = _run_metric("WER", lambda: compute_wer(reference_text, audio_path))
    utmos_score = _run_metric("UTMOS", lambda: compute_utmos(audio_path))
    rtf = _run_metric("RTF", lambda: compute_rtf(latency_seconds, audio_path))

    # cost estimate vs Chirp baseline ($16 per 1M characters)
    char_count = len(reference_text)
    chirp_cost = round((char_count / 1_000_000) * 16.0, 6)

    # use actual cost if provided by engine (e.g. RunPod returns it per request)
    if actual_cost_usd is not None:
        engine_cost = round(actual_cost_usd, 6)
    else:
        engine_cost = round(engine.estimate_cost(reference_text), 6)

    # key order defines the CSV column order — keep it stable
    result = {
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "engine": engine.name,
        "engine_type": engine.engine_type,
        "production_ready": engine.is_production_ready,
        "band": band,
        "input_text": reference_text,
        "voice": synth_voice,
        "wer": wer_score,
        "utmos": utmos_score,
        "rtf": rtf,
        "latency_s": latency_seconds,
        "engine_cost_usd": engine_cost,
        "chirp_equiv_usd": chirp_cost,
        "chars": char_count,
    }

    # --- persistent saving: append this row to the CSV log ---
    results_path = os.path.join(os.path.dirname(__file__), "results", "eval_log.csv")
    os.makedirs(os.path.dirname(results_path), exist_ok=True)
    # Write the header when the log is missing OR exists but is empty;
    # otherwise an empty file would get rows with no header.
    needs_header = (
        not os.path.exists(results_path) or os.path.getsize(results_path) == 0
    )
    pd.DataFrame([result]).to_csv(
        results_path, mode="a", header=needs_header, index=False
    )

    # upload updated CSV and run cleanup check in background
    try:
        from storage import upload_csv_background, cleanup_bucket_background
        upload_csv_background(results_path)
        cleanup_bucket_background(results_path)
    except Exception as e:
        print(f"[Storage] Background tasks skipped: {e}")
    return result