Spaces:
Sleeping
Sleeping
# app/app.py
# Bantrly TTS Evaluation Framework
# Interactive UI for comparing TTS engines across grade bands.
#
# Run from the app/ directory:
#     uv run gradio app.py
#
# Metrics:
#   WER   — Word Error Rate (Radford et al. 2023, Whisper)
#   UTMOS — Automated MOS prediction (Saeki et al. 2022, VoiceMOS Challenge)
#   RTF   — Real Time Factor (synthesis_time / audio_duration)
#   Cost  — Equivalent Chirp 3 HD cost at $16/1M chars
import os
import sys
import tempfile
from pathlib import Path

import gradio as gr
import pandas as pd
from dotenv import load_dotenv

# Make sibling project modules (storage, engines, evaluator) importable no
# matter what the current working directory is. This MUST run before any
# project-local import below.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

# Loads .env locally; on HF Spaces, secrets are injected as env vars directly.
# override=False keeps any value that is already set in the environment.
# Project-local modules are imported afterwards in case they read their
# configuration from the environment at import time.
load_dotenv(os.path.join(os.path.dirname(__file__), ".env"), override=False)

from engines import ENGINES, ENGINE_MAP  # noqa: E402
from engines.kokoro_engine import KOKORO_VOICES, KOKORO_DEFAULT_VOICE  # noqa: E402
from evaluator import evaluate  # noqa: E402
from storage import upload_audio_background, download_csv  # noqa: E402
# ── constants ─────────────────────────────────────────────────────────────
# Grade bands offered in the band dropdown, in display order.
BANDS = ["K-2", "3-5", "6-8", "9-12"]
# Engine display names for the engine dropdown; run_synthesis maps the
# selected name back to an engine object via ENGINE_MAP.
ENGINE_CHOICES = [engine.name for engine in ENGINES]
# Local evaluation log: read by the history/export handlers and patched with
# audio URLs once background uploads finish.
_EVAL_LOG_PATH = os.path.join(os.path.dirname(__file__), "results", "eval_log.csv")
# Recommended Kokoro voice per grade band (applied when the band changes).
# K-2 through 6-8 share one voice; 9-12 gets its own.
KOKORO_BAND_VOICE = dict.fromkeys(("K-2", "3-5", "6-8"), "af_heart")
KOKORO_BAND_VOICE["9-12"] = "am_echo"

# ── state ─────────────────────────────────────────────────────────────────
# Module-level lists, so they are shared by every request this process serves.
# Row i of the comparison table corresponds to _session_results[i], and
# _session_audio_urls[i] holds that row's playable URL ("" when none yet).
_session_results: list[dict] = []
_session_audio_urls: list[str] = []
# ── helpers ───────────────────────────────────────────────────────────────
def format_wer(wer):
    """
    Format a word error rate given as a fraction (0.123 -> "12.3%").

    Returns "N/A" when the value is missing: None, or NaN. NaN is what an
    empty cell becomes when the history log is read back with pd.read_csv.
    Rates above 0.5 get a short-text warning suffix, because WER on very
    short inputs is noisy and can exceed 100%.
    """
    if pd.isna(wer):
        return "N/A"
    pct = round(wer * 100, 1)
    note = " β (short text)" if wer > 0.5 else ""
    return f"{pct}%{note}"
def format_utmos(score):
    """
    Format a UTMOS score as "<score, 3 decimals> / 5.0".

    Returns "N/A" for a missing score: None, or NaN as produced by
    pd.read_csv for an empty cell in the history log.
    """
    if pd.isna(score):
        return "N/A"
    return f"{score:.3f} / 5.0"
def format_rtf(rtf):
    """
    Format a real-time factor as "<rtf, 3 decimals>x (<faster|slower> ...)".

    RTF < 1.0 means synthesis ran faster than playback. Returns "N/A" for a
    missing value: None, or NaN as produced by pd.read_csv for an empty cell.
    """
    if pd.isna(rtf):
        return "N/A"
    flag = "β faster than real time" if rtf < 1.0 else "β slower than real time"
    return f"{rtf:.3f}x ({flag})"
def format_cost(engine_cost, chirp_cost, engine_name=""):
    """
    Format an engine's synthesis cost for display.

    - Missing engine cost (None / NaN, e.g. an empty CSV cell) -> "N/A".
    - Engines whose name contains "RunPod" -> "$<cost> (actual)".
    - Zero-cost engines -> "$0.00 (Chirp equiv: $<chirp_cost>)" so free
      engines can still be compared against Chirp 3 HD pricing; just "$0.00"
      when the Chirp equivalent is missing.
    - Otherwise -> "$<cost>" with 6 decimals.

    engine_name is converted with str() so a NaN/None name from a reloaded
    history row cannot crash the substring test.
    """
    if pd.isna(engine_cost):
        return "N/A"
    if "RunPod" in str(engine_name):
        return f"${engine_cost:.6f} (actual)"
    if engine_cost == 0.0:
        if pd.isna(chirp_cost):
            return "$0.00"
        return f"$0.00 (Chirp equiv: ${chirp_cost:.6f})"
    return f"${engine_cost:.6f}"
def build_comparison_table(results: list[dict]) -> pd.DataFrame:
    """
    Turn evaluation result dicts into the display DataFrame for the UI table.

    One row per result, with metrics pre-formatted as strings by the
    format_* helpers. An empty/None input yields an empty frame that still
    carries the column headers, so the table renders its header row.
    """
    headers = [
        "Engine",
        "Band",
        "Voice",
        "UTMOS β",
        "WER β",
        "RTF β",
        "Latency (s)",
        "Cost",
    ]
    if not results:
        return pd.DataFrame(columns=headers)

    def _display_row(res: dict) -> dict:
        # Cell values in the same order as `headers`.
        cells = (
            res["engine"],
            res["band"],
            res.get("voice", "β"),
            format_utmos(res["utmos"]),
            format_wer(res["wer"]),
            format_rtf(res["rtf"]),
            res["latency_s"],
            format_cost(res["engine_cost_usd"], res["chirp_equiv_usd"], res["engine"]),
        )
        return dict(zip(headers, cells))

    return pd.DataFrame([_display_row(res) for res in results])
def build_business_chart(results: list[dict]):
    """
    Build the Plotly bubble chart used for business decision making.

    X = RTF (speed, lower = better)
    Y = UTMOS (quality, higher = better)
    Bubble size = 20 px plus engine_cost_usd * 2000, with the cost bump capped at +25 px
    Color = engine type (one trace / legend entry per engine_type)

    Reads numeric values directly from the results dicts, with no dependency
    on display column names or on the table's formatted strings. Rows whose
    RTF or UTMOS is missing, NaN, infinite, or non-numeric are skipped. Such
    values include empty cells in a history CSV loaded with pandas.

    The y-axis shows 3.5-5.0 by default but widens when a point falls outside
    that window, so low-scoring engines are not clipped off the chart.
    """
    import math

    import plotly.graph_objects as go

    if not results:
        fig = go.Figure()
        fig.update_layout(
            title="Run a synthesis to see the comparison chart",
            height=450,
        )
        return fig

    def _to_float(value):
        """Return `value` as a finite float, or None if missing/NaN/unusable."""
        try:
            number = float(value)
        except (TypeError, ValueError):
            return None
        return number if math.isfinite(number) else None

    color_map = {
        "neural-local": "#2ecc71",
        "neural-cloud-free": "#3498db",
        "neural-cloud-paid": "#e74c3c",
        "rule-based-local": "#95a5a6",
    }
    traces = {}
    for r in results:
        rtf = _to_float(r.get("rtf"))
        utmos = _to_float(r.get("utmos"))
        if rtf is None or utmos is None:
            continue
        engine_name = r["engine"]
        engine_type = r.get("engine_type", "neural-local")
        if not isinstance(engine_type, str):
            # e.g. NaN from an empty CSV cell: use the same default as a missing key
            engine_type = "neural-local"
        voice = r.get("voice", "β")
        latency = r.get("latency_s", "β")
        wer_str = format_wer(r.get("wer"))
        production = "β" if r.get("production_ready") else "β"
        color = color_map.get(engine_type, "#bdc3c7")
        hover = (
            f"<b>{engine_name}</b><br>"
            f"Voice: {voice}<br>"
            f"UTMOS: {utmos:.3f}<br>"
            f"RTF: {rtf:.3f}x<br>"
            f"WER: {wer_str}<br>"
            f"Latency: {latency}s<br>"
            f"Cost: {format_cost(r.get('engine_cost_usd', 0), r.get('chirp_equiv_usd', 0), engine_name)}<br>"
            f"Production: {production}"
        )
        data = traces.setdefault(
            engine_type,
            {"x": [], "y": [], "sizes": [], "hovers": [], "color": color},
        )
        # Missing/NaN cost counts as free so the marker size stays finite.
        cost = _to_float(r.get("engine_cost_usd")) or 0.0
        data["x"].append(rtf)
        data["y"].append(utmos)
        data["sizes"].append(20 + min(cost * 2000, 25))
        data["hovers"].append(hover)

    fig = go.Figure()
    for engine_type, data in traces.items():
        fig.add_trace(go.Scatter(
            x=data["x"],
            y=data["y"],
            mode="markers",
            name=engine_type,
            showlegend=True,
            marker=dict(
                size=data["sizes"],
                color=data["color"],
                opacity=0.85,
                line=dict(width=1.5, color="rgba(255,255,255,0.5)"),
            ),
            hovertext=data["hovers"],
            hoverinfo="text",
        ))

    fig.add_vline(
        x=1.0, line_dash="dash", line_color="rgba(255,255,255,0.4)", opacity=0.8,
        annotation_text="RTF = 1.0",
        annotation_font_color="rgba(255,255,255,0.7)",
        annotation_position="top right",
    )
    fig.add_hline(
        y=4.0, line_dash="dash", line_color="rgba(255,255,255,0.4)", opacity=0.8,
        annotation_text="UTMOS = 4.0 threshold",
        annotation_font_color="rgba(255,255,255,0.7)",
        annotation_position="right",
    )
    fig.add_annotation(
        x=0.1, y=4.9,
        text="β Ideal zone<br>(fast + high quality)",
        showarrow=False,
        font=dict(color="#2ecc71", size=11),
        bgcolor="rgba(46,204,113,0.15)",
        bordercolor="#2ecc71",
        borderwidth=1,
    )

    all_rtf = [x for t in traces.values() for x in t["x"]]
    all_utmos = [y for t in traces.values() for y in t["y"]]
    x_max = max(3.0, max(all_rtf) + 0.5) if all_rtf else 3.0
    # Keep the 3.5-5.0 window when possible, but never hide plotted points.
    y_min = min(3.5, min(all_utmos) - 0.2) if all_utmos else 3.5
    y_max = max(5.0, max(all_utmos) + 0.1) if all_utmos else 5.0

    fig.update_layout(
        title=dict(text="TTS Engine Comparison β Business Decision Chart", font=dict(color="white")),
        xaxis_title="RTF β (lower = faster synthesis)",
        yaxis_title="UTMOS β (higher = more natural)",
        height=500,
        legend_title="Engine Type",
        xaxis=dict(
            range=[-0.1, x_max],
            color="white",
            gridcolor="rgba(255,255,255,0.15)",
            title_font=dict(color="white"),
            tickfont=dict(color="white"),
        ),
        yaxis=dict(
            range=[y_min, y_max],
            color="white",
            gridcolor="rgba(255,255,255,0.15)",
            title_font=dict(color="white"),
            tickfont=dict(color="white"),
        ),
        legend=dict(
            title=dict(text="Engine Type", font=dict(color="white", size=12)),
            font=dict(color="white"),
            bgcolor="rgba(30,30,30,0.8)",
            bordercolor="rgba(255,255,255,0.3)",
            borderwidth=1,
        ),
        hovermode="closest",
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
        font=dict(color="white"),
    )
    fig.update_xaxes(showgrid=True, gridcolor="rgba(128,128,128,0.2)")
    fig.update_yaxes(showgrid=True, gridcolor="rgba(128,128,128,0.2)")
    return fig
| def _make_audio_filename(engine_name: str, band: str, ext: str) -> str: | |
| """Generate a unique bucket filename for an audio file.""" | |
| from datetime import datetime | |
| ts = datetime.now().strftime("%Y%m%d_%H%M%S") | |
| safe_engine = engine_name.replace(" ", "_").replace("(", "").replace(")", "") | |
| safe_band = band.replace("-", "") | |
| return f"{ts}_{safe_engine}_{safe_band}{ext}" | |
# ── event handlers ────────────────────────────────────────────────────────
def on_row_select(evt: gr.SelectData) -> tuple:
    """
    Play the clicked row's audio and show its metrics card.

    The clicked row index (evt.index[0]) looks up entries in the parallel
    module-level lists _session_results and _session_audio_urls. Those lists
    are filled by run_synthesis and replaced wholesale by load_history, so
    table rows and list entries line up; the URL itself is never shown in
    the table.

    Returns (row_audio_player update, row_detail update):
    - The player is shown only when the row has an http(s) URL.
    - The Markdown card is shown whenever a result dict exists for the row.
    - On any error both are hidden and the error is printed.
    """
    try:
        clicked = evt.index[0]
        audio_url = _session_audio_urls[clicked] if clicked < len(_session_audio_urls) else None
        row = _session_results[clicked] if clicked < len(_session_results) else None

        card = ""
        if row:
            identity = " | ".join([
                f"**Engine:** {row['engine']}",
                f"**Band:** {row['band']}",
                f"**Voice:** {row.get('voice', 'β')}",
            ])
            metrics = " | ".join([
                f"**UTMOS:** {format_utmos(row['utmos'])}",
                f"**WER:** {format_wer(row['wer'])}",
                f"**RTF:** {format_rtf(row['rtf'])}",
                f"**Latency:** {row['latency_s']}s",
                f"**Cost:** {format_cost(row['engine_cost_usd'], row['chirp_equiv_usd'], row['engine'])}",
            ])
            card = "\n\n".join([identity, metrics, f"**Text:** {row.get('input_text', 'β')}"])

        has_remote_audio = bool(audio_url) and str(audio_url).startswith("http")
        if not has_remote_audio:
            return gr.update(visible=False), gr.update(value=card, visible=bool(card))
        return gr.update(value=audio_url, visible=True), gr.update(value=card, visible=True)
    except Exception as e:
        print(f"[Playback] Row select failed: {e}")
        return gr.update(visible=False), gr.update(visible=False)
def on_engine_change(engine_name: str):
    """Toggle the voice dropdown: visible only while "Kokoro (tuned)" is selected."""
    return gr.update(visible=(engine_name == "Kokoro (tuned)"))
def on_band_change(band: str, engine_name: str):
    """
    React to a grade-band change by updating the voice dropdown.

    For Kokoro, show the dropdown and preselect the band's recommended voice
    from KOKORO_BAND_VOICE, falling back to KOKORO_DEFAULT_VOICE. For any
    other engine, hide the dropdown and reset it to KOKORO_DEFAULT_VOICE.
    """
    if engine_name == "Kokoro (tuned)":
        suggested = KOKORO_BAND_VOICE.get(band, KOKORO_DEFAULT_VOICE)
        return gr.update(visible=True, value=suggested)
    return gr.update(visible=False, value=KOKORO_DEFAULT_VOICE)
def _record_audio_url_in_log(eval_result: dict, url: str) -> None:
    """
    Patch the audio_url of this run's row in the local eval log CSV.

    The row is matched on (timestamp, engine, band). After the CSV is
    rewritten, it is re-uploaded to Supabase via upload_csv_background.
    Best-effort: this runs from the upload callback, where there is no
    caller to report to, so any failure is printed and swallowed.
    """
    try:
        if not os.path.exists(_EVAL_LOG_PATH):
            return
        df = pd.read_csv(_EVAL_LOG_PATH, dtype={"audio_url": str})
        if "audio_url" not in df.columns:
            df["audio_url"] = ""
        # match by timestamp + engine + band; unique enough for one run
        mask = (
            (df["timestamp"] == eval_result["timestamp"])
            & (df["engine"] == eval_result["engine"])
            & (df["band"] == eval_result["band"])
        )
        df.loc[mask, "audio_url"] = url
        df.to_csv(_EVAL_LOG_PATH, index=False)
        from storage import upload_csv_background
        upload_csv_background(_EVAL_LOG_PATH)
    except Exception as e:
        print(f"[Storage] CSV audio_url update failed: {e}")


def run_synthesis(engine_name: str, band: str, text: str, voice: str):
    """
    Synthesize `text`, evaluate the audio, and stream progress to the UI.

    Generator used as a Gradio event handler. Every yield is the 4-tuple
    (audio_path or None, status message, comparison table, business chart),
    matching the outputs bound in build_ui. The table and chart are always
    rebuilt from the current _session_results.

    On success:
    - The evaluation dict is appended to _session_results, with a parallel
      "" entry in _session_audio_urls.
    - The audio is uploaded through upload_audio_background.
    - When the upload reports a URL, the result dict, the matching
      _session_audio_urls slot (so the row becomes playable on click) and the
      eval log CSV are updated.
    """

    def _ui(audio, message):
        return (
            audio,
            message,
            build_comparison_table(_session_results),
            build_business_chart(_session_results),
        )

    if not (text or "").strip():
        yield _ui(None, "β Please enter some text first.")
        return

    engine = ENGINE_MAP.get(engine_name)
    if engine is None:
        yield _ui(None, f"β Engine '{engine_name}' not found.")
        return

    # Reserve a unique temp name. The engine gets the path without the
    # ".wav" suffix; the file it actually writes is reported back in
    # synth_result["audio_path"] (presumably with its own extension).
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        tmp_path = os.path.splitext(f.name)[0]

    yield _ui(None, f"Synthesizing with {engine_name}...")
    try:
        # pass voice override only for Kokoro
        if engine_name == "Kokoro (tuned)":
            synth_result = engine.synthesize(text, band, tmp_path, voice_override=voice)
        else:
            synth_result = engine.synthesize(text, band, tmp_path)
        audio_path = synth_result["audio_path"]
    except NotImplementedError as e:
        yield _ui(None, f"β {e}")
        return
    except Exception as e:
        yield _ui(None, f"β Synthesis failed: {e}")
        return

    yield _ui(audio_path, "Running evals (WER, UTMOS, RTF)...")
    try:
        eval_result = evaluate(
            reference_text=text,
            audio_path=audio_path,
            latency_seconds=synth_result["latency_seconds"],
            engine=engine,
            band=band,
            synth_voice=synth_result.get("voice", "unknown"),
            actual_cost_usd=synth_result.get("actual_cost_usd", None),
        )
    except Exception as e:
        yield _ui(audio_path, f"β Eval failed: {e}")
        return

    # Register the row BEFORE starting the upload. The callback may run at any
    # moment on another thread; setting the placeholder afterwards (as before)
    # could overwrite a URL the callback had already stored.
    eval_result["audio_url"] = None  # placeholder until the upload completes
    _session_results.append(eval_result)
    _session_audio_urls.append("")
    row_idx = len(_session_results) - 1

    def _on_upload(url):
        if not url:
            eval_result["audio_url"] = None
            return
        eval_result["audio_url"] = url
        print(f"[Storage] Uploaded: {url}")
        # Make the row playable on click, unless the session lists were
        # cleared (clear_results) or replaced (load_history) meanwhile.
        if (
            row_idx < len(_session_results)
            and row_idx < len(_session_audio_urls)
            and _session_results[row_idx] is eval_result
        ):
            _session_audio_urls[row_idx] = url
        _record_audio_url_in_log(eval_result, url)

    # upload audio to Supabase in background (non-blocking)
    bucket_filename = _make_audio_filename(engine_name, band, Path(audio_path).suffix)
    upload_audio_background(audio_path, bucket_filename, callback=_on_upload)

    status = (
        f"β Done β "
        f"UTMOS: {format_utmos(eval_result['utmos'])} | "
        f"WER: {format_wer(eval_result['wer'])} | "
        f"RTF: {format_rtf(eval_result['rtf'])}"
    )
    yield _ui(audio_path, status)
def clear_results():
    """
    Empty the session result and audio-URL lists in place.

    Returns (empty comparison table, empty chart, status message) for the
    table, chart and export-status outputs.
    """
    for session_list in (_session_results, _session_audio_urls):
        session_list.clear()
    emptied = _session_results
    return build_comparison_table(emptied), build_business_chart(emptied), "Results cleared."
def export_session():
    """
    Write the current session's raw result dicts to session_export.csv.

    The file is written next to this module, UTF-8 with BOM so spreadsheet
    apps detect the encoding. Returns (download-file update, status message);
    the download component stays hidden when there is nothing to export.
    """
    if _session_results:
        target = os.path.join(os.path.dirname(__file__), "session_export.csv")
        pd.DataFrame(_session_results).to_csv(target, index=False, encoding="utf-8-sig")
        return gr.update(value=target, visible=True), "β Session exported."
    return gr.update(visible=False), "β No session results to export."
def export_all():
    """
    Copy the full eval log into history_export.csv for download.

    The copy is re-encoded as UTF-8 with BOM and audio_url is kept as text.
    Returns (download-file update, status message). The download stays
    hidden when the log is missing or cannot be read or written.
    """
    if os.path.exists(_EVAL_LOG_PATH):
        try:
            history = pd.read_csv(_EVAL_LOG_PATH, dtype={"audio_url": str})
            target = os.path.join(os.path.dirname(__file__), "history_export.csv")
            history.to_csv(target, index=False, encoding="utf-8-sig")
            return gr.update(value=target, visible=True), "β Full history log ready to download."
        except Exception as e:
            return gr.update(visible=False), f"β Failed: {e}"
    return gr.update(visible=False), "β No history log found."
def load_history():
    """
    Replace the session state with every run recorded in the eval log.

    First tries to refresh the local log via storage.download_csv; if that
    fails, the existing local file is used. The parsed rows then become
    _session_results (the global is rebound, not mutated). Their audio URLs,
    with "nan"/"None"/"" normalised to "", become _session_audio_urls, so
    clicking a row can play its audio.

    Returns (comparison table, chart, status message).
    """
    global _session_results, _session_audio_urls

    # try Supabase first, fall back to local CSV
    try:
        from storage import download_csv
        download_csv(_EVAL_LOG_PATH)
    except Exception as e:
        print(f"[Storage] Supabase download skipped, using local: {e}")

    if not os.path.exists(_EVAL_LOG_PATH):
        return build_comparison_table([]), build_business_chart([]), "β No history found."

    def _playable(raw) -> str:
        as_text = str(raw)
        return "" if as_text in ("nan", "None", "") else as_text

    try:
        history = pd.read_csv(_EVAL_LOG_PATH, dtype={"audio_url": str})
        if "audio_url" not in history.columns:
            history["audio_url"] = ""
        rows = history.to_dict(orient="records")
        _session_results = rows
        _session_audio_urls = [_playable(row.get("audio_url", "")) for row in rows]
        return (
            build_comparison_table(rows),
            build_business_chart(rows),
            f"β Loaded {len(rows)} historical runs.",
        )
    except Exception as e:
        return build_comparison_table([]), build_business_chart([]), f"β Failed: {e}"
def refresh_table():
    """
    Re-render the comparison table from the current _session_results.

    The table itself does not display audio URLs; row clicks read those from
    _session_audio_urls.
    """
    current = _session_results
    return build_comparison_table(current)
# ── UI ────────────────────────────────────────────────────────────────────
def build_ui():
    """
    Assemble the Gradio Blocks app and wire up all event handlers.

    Layout, top to bottom:
    - Input column (engine, band, Kokoro voice, text, synthesize button)
      beside an output column (audio, status).
    - The comparison table, with a row-click audio player and detail card.
    - The business chart.
    - Session/history buttons and the CSV download area.

    The voice dropdown starts in the same state that on_engine_change /
    on_band_change would produce for the initial engine and band. It does
    not assume the first engine is Kokoro.

    Returns the gr.Blocks instance without launching it.
    """
    kokoro_label = "Kokoro (tuned)"
    default_engine = ENGINE_CHOICES[0] if ENGINE_CHOICES else None
    default_band = "K-2"

    with gr.Blocks(title="Bantrly TTS Evaluation Framework") as demo:
        gr.Markdown("""
# π Bantrly TTS Evaluation Framework
Compare TTS engines on coaching text across grade bands.
**Metrics:** UTMOS (naturalness, β better) Β· WER (intelligibility, β better) Β· RTF (speed, β better) Β· Cost vs Chirp 3 HD
""")
        with gr.Row():
            with gr.Column(scale=1):
                engine_selector = gr.Dropdown(
                    choices=ENGINE_CHOICES,
                    value=default_engine,
                    label="TTS Engine",
                )
                band_selector = gr.Dropdown(
                    choices=BANDS,
                    value=default_band,
                    label="Grade Band",
                )
                voice_selector = gr.Dropdown(
                    choices=KOKORO_VOICES,
                    # match the info text: start on the band's recommended voice
                    value=KOKORO_BAND_VOICE.get(default_band, KOKORO_DEFAULT_VOICE),
                    label="Voice (Kokoro only)",
                    # only shown when the initially selected engine is Kokoro
                    visible=default_engine == kokoro_label,
                    info="Defaults to recommended voice for selected band. Override freely.",
                )
                text_input = gr.Textbox(
                    label="Coaching Text",
                    placeholder="Type or paste any coaching text here...",
                    lines=4,
                    value="You did such a great job speaking today! I loved how loud and clear your voice was.",
                )
                synthesize_btn = gr.Button("βΆ Synthesize + Eval", variant="primary")
            with gr.Column(scale=1):
                audio_output = gr.Audio(label="Output Audio", type="filepath")
                status_output = gr.Textbox(label="Status", interactive=False, lines=3)

        gr.Markdown("## Comparison Table")
        gr.Markdown(
            "**β higher is better Β· β lower is better** β "
            "WER may exceed 100% on short texts."
        )
        comparison_table = gr.Dataframe(
            value=build_comparison_table([]),
            label="Eval Results β click a row to play audio",
            interactive=False,
        )
        with gr.Row():
            with gr.Column(scale=1):
                row_audio_player = gr.Audio(
                    label="βΆ Selected Row Audio",
                    visible=False,
                    type="filepath",
                )
            with gr.Column(scale=2):
                row_detail = gr.Markdown(
                    value="",
                    visible=False,
                )
        business_chart = gr.Plot(
            value=build_business_chart([]),
            label="Business Decision Chart",
        )
        with gr.Row():
            clear_btn = gr.Button("π Clear Session")
            refresh_btn = gr.Button("π Refresh Table")
            load_history_btn = gr.Button("π Load History")
            export_session_btn = gr.Button("β¬ Export Session")
            export_all_btn = gr.Button("β¬ Export Full History")
        with gr.Row():
            export_file = gr.File(label="Download CSV", visible=False)
            export_status = gr.Textbox(label="", interactive=False, visible=True, value="")

        # ── bindings ──────────────────────────────────────────────────────
        engine_selector.change(
            fn=on_engine_change,
            inputs=[engine_selector],
            outputs=[voice_selector],
        )
        band_selector.change(
            fn=on_band_change,
            inputs=[band_selector, engine_selector],
            outputs=[voice_selector],
        )
        synthesize_btn.click(
            fn=run_synthesis,
            inputs=[engine_selector, band_selector, text_input, voice_selector],
            outputs=[audio_output, status_output, comparison_table, business_chart],
        )
        clear_btn.click(
            fn=clear_results,
            outputs=[comparison_table, business_chart, export_status],
        )
        refresh_btn.click(
            fn=refresh_table,
            outputs=[comparison_table],
        )
        comparison_table.select(
            fn=on_row_select,
            inputs=[],
            outputs=[row_audio_player, row_detail],
        )
        load_history_btn.click(
            fn=load_history,
            outputs=[comparison_table, business_chart, export_status],
        )
        export_session_btn.click(
            fn=export_session,
            outputs=[export_file, export_status],
        )
        export_all_btn.click(
            fn=export_all,
            outputs=[export_file, export_status],
        )
    return demo
# Built at import time so `gradio app.py` reload mode, which looks up a
# module-level Blocks object named `demo`, can find it. `python app.py`
# (e.g. on HF Spaces) still launches it below.
demo = build_ui()

if __name__ == "__main__":
    demo.launch(share=False)