File size: 6,555 Bytes
a3419b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24a256c
 
a3419b6
 
24a256c
a3419b6
 
 
 
 
24a256c
a3419b6
 
 
 
 
 
 
 
 
 
 
24a256c
a3419b6
 
 
 
 
 
 
 
24a256c
 
 
 
 
a3419b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229a3e3
29ebcc1
229a3e3
29ebcc1
229a3e3
29ebcc1
229a3e3
29ebcc1
a3419b6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
# app/evaluator.py
# TTS evaluation pipeline for the Bantrly eval framework.
#
# Metrics:
#   WER   — Word Error Rate via Whisper transcription (Radford et al. 2023)
#   UTMOS — Automated MOS prediction (Saeki et al. 2022, VoiceMOS Challenge winner)
#   RTF   — Real Time Factor: synthesis_time / audio_duration
#   Cost  — Equivalent cost vs Chirp 3 HD ($16/1M chars)
#
# evaluate() appends every result to results/eval_log.csv (next to this file)
# and then passes that CSV to the background upload/cleanup helpers in storage.

import time
import librosa
import torch
import soundfile as sf
import numpy as np
from jiwer import wer
from faster_whisper import WhisperModel

# --- Whisper setup ---
# "base" model: ~150MB, fast, good enough for WER on clean TTS output
# upgrade to "small" or "medium" if WER accuracy is insufficient
_whisper_model = None


def _get_whisper():
    """Return the shared Whisper model, creating it on the first call.

    The instance is cached in the module-level ``_whisper_model`` so later
    calls reuse it. When CUDA is available the model runs on the GPU with
    float16 weights; otherwise it runs on the CPU with int8.
    """
    global _whisper_model
    if _whisper_model is not None:
        return _whisper_model

    if torch.cuda.is_available():
        device, compute = "cuda", "float16"
    else:
        device, compute = "cpu", "int8"

    _whisper_model = WhisperModel("base", device=device, compute_type=compute)
    return _whisper_model


# --- UTMOS setup ---
# sarulab-speech/UTMOS22 — winner of VoiceMOS Challenge 2022
# predicts human MOS scores (1-5) without reference audio
_utmos_model = None


def _get_utmos():
    """Return the shared UTMOS predictor, loading it on the first call.

    The model is fetched through ``torch.hub`` from the pinned
    ``tarepan/SpeechMOS:v1.2.0`` repo (entry point ``utmos22_strong``),
    switched to eval mode and cached in the module-level ``_utmos_model``.
    """
    global _utmos_model
    if _utmos_model is not None:
        return _utmos_model

    # tarepan/SpeechMOS is a maintained fork with a proper hubconf.py;
    # it wraps the official UTMOS22 strong learner weights (MIT license).
    hub_repo = "tarepan/SpeechMOS:v1.2.0"
    hub_entry = "utmos22_strong"
    _utmos_model = torch.hub.load(hub_repo, hub_entry, trust_repo=True)
    _utmos_model.eval()
    return _utmos_model


def compute_wer(reference_text: str, audio_path: str) -> float:
    """
    Score how closely a Whisper transcript of the audio matches the input text.

    Both the reference and the transcript are lower-cased and stripped of
    surrounding whitespace before comparison; no other normalization
    (e.g. punctuation removal) is applied here.

    Args:
        reference_text: the text that was given to the TTS engine
        audio_path:     path to the synthesized audio file

    Returns:
        Word error rate rounded to 4 decimals; 0.0 is a perfect match
        (multiply by 100 for a percentage).
        NOTE(review): WER can exceed 1.0 when the transcript contains many
        inserted words — confirm downstream code does not assume a 0–1 range.
    """
    whisper = _get_whisper()
    segments, _info = whisper.transcribe(audio_path, beam_size=5)

    pieces = [segment.text.strip() for segment in segments]
    reference = reference_text.lower().strip()
    hypothesis = " ".join(pieces).lower().strip()

    return round(wer(reference, hypothesis), 4)


def compute_utmos(audio_path: str) -> float:
    """
    Estimate the naturalness (MOS, 1-5 scale) of an audio file with UTMOS.

    Decoding goes through librosa for every format (WAV + MP3) to avoid
    soundfile subprocess issues in Gradio's hot-reload worker. The audio is
    resampled to 16 kHz mono and passed to the model as a batch of one.

    Args:
        audio_path: path to synthesized audio file

    Returns:
        predicted MOS rounded to 3 decimals (higher = more natural)
    """
    target_sr = 16000
    predictor = _get_utmos()

    samples, _ = librosa.load(audio_path, sr=target_sr, mono=True)
    # add a leading batch dimension: (num_samples,) -> (1, num_samples)
    batch = torch.FloatTensor(samples).unsqueeze(0)

    with torch.no_grad():
        prediction = predictor(batch, sr=target_sr)

    return round(float(prediction), 3)

def compute_rtf(latency_seconds: float, audio_path: str) -> float:
    """
    Compute Real Time Factor: synthesis_time / audio_duration.
    RTF < 1.0 means faster than real time.
    Uses librosa for MP3 (sf.read may fail on MP3 depending on libsndfile version).

    Args:
        latency_seconds: wall-clock synthesis time from engine
        audio_path:      path to synthesized audio file (str or path-like)

    Returns:
        RTF as float rounded to 3 decimals; 0.0 when the decoded audio is
        empty, so callers never hit a division by zero.
    """
    # Compare the extension case-insensitively so "clip.MP3" also takes the
    # librosa route; str() lets path-like objects through the same check.
    is_mp3 = str(audio_path).lower().endswith(".mp3")

    if is_mp3:
        # sr=None keeps the file's native sample rate (no resampling)
        audio, sr = librosa.load(audio_path, sr=None)
    else:
        audio, sr = sf.read(audio_path)

    # len() counts samples along the first axis of the decoded array
    audio_duration = len(audio) / sr
    if audio_duration == 0:
        return 0.0
    return round(latency_seconds / audio_duration, 3)


def _run_metric(label: str, compute):
    """Call ``compute()`` and return its value, or None if it raises.

    A failure is reported on stdout as "<label> computation failed: <error>"
    so one broken metric does not abort the rest of the evaluation.
    """
    try:
        return compute()
    except Exception as e:
        print(f"{label} computation failed: {e}")
        return None


def _append_to_eval_log(result: dict) -> str:
    """Append ``result`` as one row to results/eval_log.csv; return the CSV path."""
    import os

    import pandas as pd

    results_path = os.path.join(os.path.dirname(__file__), "results", "eval_log.csv")
    os.makedirs(os.path.dirname(results_path), exist_ok=True)
    # only the very first row written to a new file carries the column header
    write_header = not os.path.exists(results_path)
    pd.DataFrame([result]).to_csv(
        results_path, mode="a", header=write_header, index=False
    )
    return results_path


def evaluate(
    reference_text: str,
    audio_path: str,
    latency_seconds: float,
    engine,
    band: str = "unknown",
    synth_voice: str = "unknown",
    actual_cost_usd: float = None,
) -> dict:
    """
    Run full eval suite on a synthesized audio file.

    Every metric is best-effort: if one fails, its value is None and the
    error is printed. The result is then appended to results/eval_log.csv
    and the storage background tasks are triggered; failures in either of
    those steps are printed and do not prevent the result being returned.

    Args:
        reference_text:  original input text
        audio_path:      path to synthesized audio
        latency_seconds: synthesis latency from engine.synthesize()
        engine:          TTSEngine instance (for cost + metadata); must expose
                         name, engine_type, is_production_ready and
                         estimate_cost(text)
        band:            label stored as-is in the "band" field
        synth_voice:     voice identifier stored as-is in the "voice" field
        actual_cost_usd: real cost of this request if the engine reports one
                         (e.g. RunPod returns it per request); when None,
                         engine.estimate_cost(reference_text) is used

    Returns:
        dict with all eval scores + metadata for comparison table
    """
    from datetime import datetime

    # Lambdas defer the calls so any error is raised inside _run_metric's try.
    # WER — skip for mp3 if Whisper has issues; wav is preferred
    wer_score = _run_metric("WER", lambda: compute_wer(reference_text, audio_path))
    utmos_score = _run_metric("UTMOS", lambda: compute_utmos(audio_path))
    rtf = _run_metric("RTF", lambda: compute_rtf(latency_seconds, audio_path))

    # cost estimate vs Chirp baseline ($16 per 1M characters)
    chirp_cost = round((len(reference_text) / 1_000_000) * 16.0, 6)
    # use actual cost if provided by engine (e.g. RunPod returns it per request)
    if actual_cost_usd is not None:
        engine_cost = round(actual_cost_usd, 6)
    else:
        engine_cost = round(engine.estimate_cost(reference_text), 6)

    result = {
        "timestamp":         datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "engine":            engine.name,
        "engine_type":       engine.engine_type,
        "production_ready":  engine.is_production_ready,
        "band":              band,
        "input_text":        reference_text,
        "voice":             synth_voice,
        "wer":               wer_score,
        "utmos":             utmos_score,
        "rtf":               rtf,
        "latency_s":         latency_seconds,
        "engine_cost_usd":   engine_cost,
        "chirp_equiv_usd":   chirp_cost,
        "chars":             len(reference_text),
    }

    # Persist the row. A failed write (read-only disk, pandas missing, ...)
    # must not throw away the scores that were just computed.
    try:
        results_path = _append_to_eval_log(result)
    except Exception as e:
        print(f"[Storage] Could not save eval result: {e}")
        return result

    # upload updated CSV and run cleanup check in background
    try:
        from storage import upload_csv_background, cleanup_bucket_background
        upload_csv_background(results_path)
        cleanup_bucket_background(results_path)
    except Exception as e:
        print(f"[Storage] Background tasks skipped: {e}")

    return result