aankitdas's picture
added storage limit guard
229a3e3
# app/app.py
# Bantrly TTS Evaluation Framework
# Interactive UI for comparing TTS engines across grade bands.
#
# Run from app/ directory:
# uv run gradio app.py
#
# Metrics:
# WER — Word Error Rate (Radford et al. 2023, Whisper)
# UTMOS — Automated MOS prediction (Saeki et al. 2022, VoiceMOS Challenge)
# RTF — Real Time Factor (synthesis_time / audio_duration)
# Cost — Equivalent Chirp 3 HD cost at $16/1M chars
import sys
import os
import tempfile
import pandas as pd
import gradio as gr
from storage import upload_audio_background, download_csv
sys.path.insert(0, os.path.dirname(__file__))
from dotenv import load_dotenv
# Load .env locally — on HF Spaces, secrets are injected directly as env vars.
load_dotenv(os.path.join(os.path.dirname(__file__), ".env"), override=False)
from engines import ENGINES, ENGINE_MAP
from engines.kokoro_engine import KOKORO_VOICES, KOKORO_DEFAULT_VOICE
from evaluator import evaluate
from storage import upload_audio_background
from pathlib import Path
# ── constants ─────────────────────────────────────────────────────────────────
# Grade bands offered in the UI, youngest first.
BANDS = ["K-2", "3-5", "6-8", "9-12"]
# Engine dropdown labels: one per registered engine, in registry order.
ENGINE_CHOICES = [engine.name for engine in ENGINES]
# Persistent eval log. Read by load_history / export_all, and patched with
# the real audio URL by the upload callback in run_synthesis.
_EVAL_LOG_PATH = os.path.join(os.path.dirname(__file__), "results", "eval_log.csv")
# Recommended Kokoro voice per band: the three younger bands share one voice,
# 9-12 gets its own.
KOKORO_BAND_VOICE = {
    **dict.fromkeys(("K-2", "3-5", "6-8"), "af_heart"),
    "9-12": "am_echo",
}
# ── state ─────────────────────────────────────────────────────────────────────
# Module-level session state; row i of the comparison table maps to index i
# of both lists. NOTE(review): being module globals, these are shared by every
# client served by this process.
_session_results: list[dict] = []
_session_audio_urls: list[str] = []
# ── helpers ───────────────────────────────────────────────────────────────────
def format_wer(wer):
    """Render a WER fraction as a percentage string ("N/A" when missing).

    Rates above 0.5 get a warning suffix, since short reference texts can
    push WER very high.
    """
    if wer is not None:
        suffix = " ⚠ (short text)" if wer > 0.5 else ""
        return f"{round(wer * 100, 1)}%{suffix}"
    return "N/A"
def format_utmos(score):
    """Render a UTMOS score to three decimals on a "/ 5.0" scale, or "N/A"."""
    return "N/A" if score is None else f"{score:.3f} / 5.0"
def format_rtf(rtf):
    """Render a real-time factor as "<rtf>x (<verdict>)", or "N/A".

    Only values strictly below 1.0 count as faster than real time; 1.0 and
    anything that fails the comparison (e.g. NaN) are labelled slower.
    """
    if rtf is None:
        return "N/A"
    if rtf < 1.0:
        verdict = "βœ“ faster than real time"
    else:
        verdict = "βœ— slower than real time"
    return f"{rtf:.3f}x ({verdict})"
def format_cost(engine_cost, chirp_cost, engine_name=""):
    """Render a per-run cost string.

    Args:
        engine_cost: cost of the run in USD, or None when unknown.
        chirp_cost: equivalent Chirp 3 HD cost in USD (shown for free
            engines), or None when unknown.
        engine_name: engine display name; names containing "RunPod" are
            labelled as an actual (billed) cost.

    Returns:
        "N/A" when ``engine_cost`` is None (consistent with the other
        format_* helpers); otherwise a dollar string.
    """
    if engine_cost is None:
        return "N/A"
    # str(...) guards against non-string names (e.g. NaN from a CSV row).
    if "RunPod" in str(engine_name or ""):
        return f"${engine_cost:.6f} (actual)"
    if engine_cost == 0.0:
        if chirp_cost is None:
            return "$0.00"
        return f"$0.00 (Chirp equiv: ${chirp_cost:.6f})"
    return f"${engine_cost:.6f}"
def build_comparison_table(results: list[dict]) -> pd.DataFrame:
    """Turn eval result dicts into the display table.

    Metric cells are pre-formatted strings. An empty ``results`` yields an
    empty frame that still carries the column headers, so the UI renders
    the header row.
    """
    header = ["Engine", "Band", "Voice", "UTMOS ↑", "WER ↓", "RTF ↓", "Latency (s)", "Cost"]
    if not results:
        return pd.DataFrame(columns=header)

    def to_row(res: dict) -> dict:
        # One display row per result; keys double as column names.
        return {
            "Engine": res["engine"],
            "Band": res["band"],
            "Voice": res.get("voice", "β€”"),
            "UTMOS ↑": format_utmos(res["utmos"]),
            "WER ↓": format_wer(res["wer"]),
            "RTF ↓": format_rtf(res["rtf"]),
            "Latency (s)": res["latency_s"],
            "Cost": format_cost(res["engine_cost_usd"], res["chirp_equiv_usd"], res["engine"]),
        }

    return pd.DataFrame([to_row(res) for res in results])
def build_business_chart(results: list[dict]):
    """
    Bubble chart for business decision making.

    X = RTF (speed, lower = better)
    Y = UTMOS (quality, higher = better)
    Bubble size = 20 px + engine_cost_usd * 2000, with the cost part capped
                  at 25 px (a missing cost counts as 0)
    Color = engine type ("engine_type" key, default "neural-local")

    Reads directly from results dicts — no dependency on display column
    names — so it works for live session results and for rows loaded from
    the eval log CSV alike. Rows whose RTF or UTMOS is missing, non-numeric,
    NaN or infinite are left off the chart.

    Args:
        results: eval result dicts ("engine" is required; "rtf" and "utmos"
            must be numeric for the row to be plotted).

    Returns:
        A plotly ``go.Figure``; a placeholder figure when ``results`` is empty.
    """
    import math

    import plotly.graph_objects as go

    if not results:
        fig = go.Figure()
        fig.update_layout(
            title="Run a synthesis to see the comparison chart",
            height=450,
        )
        return fig

    def to_float(value):
        """Return ``value`` as a finite float, or None if that is impossible."""
        if value is None:
            return None
        try:
            number = float(value)
        except (TypeError, ValueError):
            return None
        return number if math.isfinite(number) else None

    color_map = {
        "neural-local": "#2ecc71",
        "neural-cloud-free": "#3498db",
        "neural-cloud-paid": "#e74c3c",
        "rule-based-local": "#95a5a6",
    }

    traces = {}
    for r in results:
        # Coerce the raw values directly instead of formatting and re-parsing
        # them: keeps full precision and tolerates string/NaN/inf inputs.
        rtf = to_float(r.get("rtf"))
        utmos = to_float(r.get("utmos"))
        if rtf is None or utmos is None:
            continue

        engine_name = r["engine"]
        engine_type = r.get("engine_type", "neural-local")
        if not isinstance(engine_type, str):
            # CSV rows carry NaN for a missing engine_type; treat any
            # non-string like an absent key.
            engine_type = "neural-local"
        voice = r.get("voice", "β€”")
        latency = r.get("latency_s", "β€”")
        wer_str = format_wer(r.get("wer"))
        production = "βœ“" if r.get("production_ready") else "βœ—"

        cost = to_float(r.get("engine_cost_usd"))
        chirp_cost = to_float(r.get("chirp_equiv_usd"))
        if cost is None:
            cost_str = "N/A"
        else:
            cost_str = format_cost(
                cost, chirp_cost if chirp_cost is not None else 0.0, engine_name
            )

        hover = (
            f"<b>{engine_name}</b><br>"
            f"Voice: {voice}<br>"
            f"UTMOS: {utmos:.3f}<br>"
            f"RTF: {rtf:.3f}x<br>"
            f"WER: {wer_str}<br>"
            f"Latency: {latency}s<br>"
            f"Cost: {cost_str}<br>"
            f"Production: {production}"
        )

        trace = traces.setdefault(engine_type, {
            "x": [], "y": [], "sizes": [], "hovers": [],
            "color": color_map.get(engine_type, "#bdc3c7"),
        })
        trace["x"].append(rtf)
        trace["y"].append(utmos)
        # Bubble grows with cost (+2000 px per USD), capped at +25 px.
        trace["sizes"].append(20 + min((cost or 0.0) * 2000, 25))
        trace["hovers"].append(hover)

    fig = go.Figure()
    for engine_type, data in traces.items():
        fig.add_trace(go.Scatter(
            x=data["x"],
            y=data["y"],
            mode="markers",
            name=engine_type,
            showlegend=True,
            marker=dict(
                size=data["sizes"],
                color=data["color"],
                opacity=0.85,
                line=dict(width=1.5, color="rgba(255,255,255,0.5)"),
            ),
            hovertext=data["hovers"],
            hoverinfo="text",
        ))

    fig.add_vline(
        x=1.0, line_dash="dash", line_color="rgba(255,255,255,0.4)", opacity=0.8,
        annotation_text="RTF = 1.0",
        annotation_font_color="rgba(255,255,255,0.7)",
        annotation_position="top right",
    )
    fig.add_hline(
        y=4.0, line_dash="dash", line_color="rgba(255,255,255,0.4)", opacity=0.8,
        annotation_text="UTMOS = 4.0 threshold",
        annotation_font_color="rgba(255,255,255,0.7)",
        annotation_position="right",
    )
    fig.add_annotation(
        x=0.1, y=4.9,
        text="βœ“ Ideal zone<br>(fast + high quality)",
        showarrow=False,
        font=dict(color="#2ecc71", size=11),
        bgcolor="rgba(46,204,113,0.15)",
        bordercolor="#2ecc71",
        borderwidth=1,
    )

    all_rtf = [x for t in traces.values() for x in t["x"]]
    all_utmos = [y for t in traces.values() for y in t["y"]]
    x_max = max(3.0, max(all_rtf) + 0.5) if all_rtf else 3.0
    # Default UTMOS window is 3.5-5.0; widen it when a point falls outside so
    # low-scoring engines are not silently clipped off the chart.
    y_min = min(3.5, min(all_utmos) - 0.2) if all_utmos else 3.5
    y_max = max(5.0, max(all_utmos) + 0.1) if all_utmos else 5.0

    # Shared axis styling (white text, light-grey grid on a transparent bg).
    axis_style = dict(
        color="white",
        showgrid=True,
        gridcolor="rgba(128,128,128,0.2)",
        title_font=dict(color="white"),
        tickfont=dict(color="white"),
    )
    fig.update_layout(
        title=dict(text="TTS Engine Comparison β€” Business Decision Chart", font=dict(color="white")),
        xaxis_title="RTF ↓ (lower = faster synthesis)",
        yaxis_title="UTMOS ↑ (higher = more natural)",
        height=500,
        xaxis=dict(range=[-0.1, x_max], **axis_style),
        yaxis=dict(range=[y_min, y_max], **axis_style),
        legend=dict(
            title=dict(text="Engine Type", font=dict(color="white", size=12)),
            font=dict(color="white"),
            bgcolor="rgba(30,30,30,0.8)",
            bordercolor="rgba(255,255,255,0.3)",
            borderwidth=1,
        ),
        hovermode="closest",
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
        font=dict(color="white"),
    )
    return fig
def _make_audio_filename(engine_name: str, band: str, ext: str) -> str:
"""Generate a unique bucket filename for an audio file."""
from datetime import datetime
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
safe_engine = engine_name.replace(" ", "_").replace("(", "").replace(")", "")
safe_band = band.replace("-", "")
return f"{ts}_{safe_engine}_{safe_band}{ext}"
# ── event handlers ────────────────────────────────────────────────────────────
def on_row_select(evt: gr.SelectData) -> tuple:
    """
    On row click: play the row's audio and show a metrics detail card.

    The clicked row index is looked up in _session_results /
    _session_audio_urls, which hold either live results from run_synthesis
    or the rows installed by load_history. URLs are never shown in the
    table itself.

    The result dict's own "audio_url" is preferred over the parallel URL
    list: the background-upload callback in run_synthesis writes the real
    URL onto that dict once the upload finishes, whereas the list entry may
    still hold the "" placeholder recorded at synthesis time.

    Returns:
        (audio player update, detail markdown update). The player is shown
        only for http(s) URLs; on any error both components are hidden.
    """
    try:
        row_idx = evt.index[0]

        # result for the detail card
        result = None
        if row_idx < len(_session_results):
            result = _session_results[row_idx]

        # audio url: live value on the result dict first, then the URL list
        url = result.get("audio_url") if result else None
        if not str(url).startswith("http") and row_idx < len(_session_audio_urls):
            url = _session_audio_urls[row_idx]

        # build detail markdown
        if result:
            detail = (
                f"**Engine:** {result['engine']} | "
                f"**Band:** {result['band']} | "
                f"**Voice:** {result.get('voice', 'β€”')}\n\n"
                f"**UTMOS:** {format_utmos(result['utmos'])} | "
                f"**WER:** {format_wer(result['wer'])} | "
                f"**RTF:** {format_rtf(result['rtf'])} | "
                f"**Latency:** {result['latency_s']}s | "
                f"**Cost:** {format_cost(result['engine_cost_usd'], result['chirp_equiv_usd'], result['engine'])}\n\n"
                f"**Text:** {result.get('input_text', 'β€”')}"
            )
        else:
            detail = ""

        if url and str(url).startswith("http"):
            return gr.update(value=url, visible=True), gr.update(value=detail, visible=True)
        return gr.update(visible=False), gr.update(value=detail, visible=bool(detail))
    except Exception as e:
        print(f"[Playback] Row select failed: {e}")
        return gr.update(visible=False), gr.update(visible=False)
def on_engine_change(engine_name: str):
    """Toggle the voice dropdown: shown for "Kokoro (tuned)", hidden for every other engine."""
    return gr.update(visible=(engine_name == "Kokoro (tuned)"))
def on_band_change(band: str, engine_name: str):
    """React to a band change on the voice dropdown.

    For Kokoro, show the dropdown preset to the band's recommended voice
    (KOKORO_BAND_VOICE, falling back to KOKORO_DEFAULT_VOICE). For any other
    engine, hide it and reset it to KOKORO_DEFAULT_VOICE.
    """
    if engine_name == "Kokoro (tuned)":
        return gr.update(visible=True, value=KOKORO_BAND_VOICE.get(band, KOKORO_DEFAULT_VOICE))
    return gr.update(visible=False, value=KOKORO_DEFAULT_VOICE)
def run_synthesis(engine_name: str, band: str, text: str, voice: str):
    """
    Synthesize ``text`` with the selected engine, evaluate it, and stream UI updates.

    Gradio streaming handler (generator). Every yield is a 4-tuple:
    (audio path or None, status message, comparison table, business chart).

    Steps:
    1. Validate the input.
    2. Synthesize into a temp path.
    3. Evaluate (WER, UTMOS, RTF).
    4. Record the result in the session.
    5. Start a background upload of the audio. Its callback patches the real
       URL into the session state and the eval log CSV.

    Args:
        engine_name: key into ENGINE_MAP.
        band: grade band, passed to the engine and the evaluator.
        text: reference text to synthesize.
        voice: voice override, forwarded only to "Kokoro (tuned)".
    """
    def _snapshot(audio, message):
        # Current session rendered as (audio, status, table, chart).
        return (
            audio,
            message,
            build_comparison_table(_session_results),
            build_business_chart(_session_results),
        )

    if not text or not text.strip():
        yield _snapshot(None, "⚠ Please enter some text first.")
        return

    engine = ENGINE_MAP.get(engine_name)
    if engine is None:
        yield _snapshot(None, f"⚠ Engine '{engine_name}' not found.")
        return

    # Engines receive an extension-less base path and return the file they
    # produced as synth_result["audio_path"]. NamedTemporaryFile only reserves
    # a unique name, so remove its empty placeholder instead of leaking one
    # file per synthesis. splitext (not str.replace) leaves the directory part
    # untouched.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        placeholder = f.name
    try:
        os.remove(placeholder)
    except OSError:
        pass
    tmp_path = os.path.splitext(placeholder)[0]

    yield _snapshot(None, f"Synthesizing with {engine_name}...")

    try:
        # pass voice override only for Kokoro
        if engine_name == "Kokoro (tuned)":
            synth_result = engine.synthesize(text, band, tmp_path, voice_override=voice)
        else:
            synth_result = engine.synthesize(text, band, tmp_path)
        audio_path = synth_result["audio_path"]
    except NotImplementedError as e:
        yield _snapshot(None, f"⚠ {e}")
        return
    except Exception as e:
        yield _snapshot(None, f"βœ— Synthesis failed: {e}")
        return

    yield _snapshot(audio_path, "Running evals (WER, UTMOS, RTF)...")

    try:
        eval_result = evaluate(
            reference_text=text,
            audio_path=audio_path,
            latency_seconds=synth_result["latency_seconds"],
            engine=engine,
            band=band,
            synth_voice=synth_result.get("voice", "unknown"),
            actual_cost_usd=synth_result.get("actual_cost_usd", None),
        )
    except Exception as e:
        yield _snapshot(audio_path, f"βœ— Eval failed: {e}")
        return

    # Register the result BEFORE starting the upload. The callback may run as
    # soon as the upload finishes, and a placeholder written afterwards would
    # clobber the real URL it stored.
    eval_result["audio_url"] = None  # placeholder until upload completes
    _session_results.append(eval_result)
    _session_audio_urls.append("")

    # upload audio to Supabase in background — non-blocking
    audio_ext = Path(audio_path).suffix
    bucket_filename = _make_audio_filename(engine_name, band, audio_ext)

    def _on_upload(url):
        """Upload callback: store the real URL (or None when the upload failed)."""
        if not url:
            eval_result["audio_url"] = None
            return
        eval_result["audio_url"] = url
        print(f"[Storage] Uploaded: {url}")

        # Keep the row-click URL list in sync. Locate the row by identity,
        # because Clear / Load History may have emptied or replaced the lists.
        for idx, row in enumerate(_session_results):
            if row is eval_result:
                if idx < len(_session_audio_urls):
                    _session_audio_urls[idx] = url
                break

        # Update the CSV row with the real audio URL.
        # NOTE(review): relies on evaluate() having logged this run to
        # _EVAL_LOG_PATH — presumably it does; confirm in evaluator.
        try:
            if os.path.exists(_EVAL_LOG_PATH):
                df = pd.read_csv(_EVAL_LOG_PATH, dtype={"audio_url": str})
                if "audio_url" not in df.columns:
                    df["audio_url"] = ""
                # match by timestamp + engine + band — unique enough
                mask = (
                    (df["timestamp"] == eval_result["timestamp"]) &
                    (df["engine"] == eval_result["engine"]) &
                    (df["band"] == eval_result["band"])
                )
                df.loc[mask, "audio_url"] = url
                df.to_csv(_EVAL_LOG_PATH, index=False)
                # re-upload updated CSV to Supabase
                from storage import upload_csv_background
                upload_csv_background(_EVAL_LOG_PATH)
        except Exception as e:
            print(f"[Storage] CSV audio_url update failed: {e}")

    upload_audio_background(audio_path, bucket_filename, callback=_on_upload)

    status = (
        f"βœ“ Done β€” "
        f"UTMOS: {format_utmos(eval_result['utmos'])} | "
        f"WER: {format_wer(eval_result['wer'])} | "
        f"RTF: {format_rtf(eval_result['rtf'])}"
    )
    yield _snapshot(audio_path, status)
def clear_results():
    """Empty the session (results and their audio URLs, in place).

    Returns:
        (empty table, empty chart, status message).
    """
    for bucket in (_session_results, _session_audio_urls):
        bucket.clear()
    empty_table = build_comparison_table(_session_results)
    empty_chart = build_business_chart(_session_results)
    return empty_table, empty_chart, "Results cleared."
def export_session():
    """Write the in-memory session results to session_export.csv beside this file.

    The CSV is written as UTF-8 with a BOM ("utf-8-sig"), presumably so
    spreadsheet apps detect the encoding.

    Returns:
        (file component update, status message). The file component is
        hidden when there is nothing to export.
    """
    if _session_results:
        export_path = os.path.join(os.path.dirname(__file__), "session_export.csv")
        pd.DataFrame(_session_results).to_csv(export_path, index=False, encoding="utf-8-sig")
        return gr.update(value=export_path, visible=True), "βœ“ Session exported."
    return gr.update(visible=False), "⚠ No session results to export."
def export_all():
    """Copy the full eval log to history_export.csv (UTF-8 with BOM) for download.

    Returns:
        (file component update, status message). Failures are reported in
        the status message rather than raised.
    """
    if not os.path.exists(_EVAL_LOG_PATH):
        return gr.update(visible=False), "⚠ No history log found."
    try:
        history = pd.read_csv(_EVAL_LOG_PATH, dtype={"audio_url": str})
        destination = os.path.join(os.path.dirname(__file__), "history_export.csv")
        history.to_csv(destination, index=False, encoding="utf-8-sig")
        return gr.update(value=destination, visible=True), "βœ“ Full history log ready to download."
    except Exception as e:
        return gr.update(visible=False), f"βœ— Failed: {e}"
def load_history():
    """Replace the session with every run in the eval log.

    First tries to pull the latest log copy from Supabase via download_csv.
    Any failure there falls back to the local file. The session lists are
    rebound so row clicks work on historical rows. "nan", "None" and empty
    audio URLs are normalised to "".

    Returns:
        (comparison table, business chart, status message).
    """
    global _session_results, _session_audio_urls

    try:
        from storage import download_csv
        download_csv(_EVAL_LOG_PATH)
    except Exception as e:
        print(f"[Storage] Supabase download skipped, using local: {e}")

    if not os.path.exists(_EVAL_LOG_PATH):
        return build_comparison_table([]), build_business_chart([]), "⚠ No history found."

    def _clean_url(raw) -> str:
        # CSV blanks come back as NaN; treat those like "no URL".
        as_text = str(raw)
        return "" if as_text in ("nan", "None", "") else as_text

    try:
        df = pd.read_csv(_EVAL_LOG_PATH, dtype={"audio_url": str})
        if "audio_url" not in df.columns:
            df["audio_url"] = ""
        records = df.to_dict(orient="records")
        _session_results = records
        _session_audio_urls = [_clean_url(rec.get("audio_url", "")) for rec in records]
        return build_comparison_table(records), build_business_chart(records), f"βœ“ Loaded {len(records)} historical runs."
    except Exception as e:
        return build_comparison_table([]), build_business_chart([]), f"βœ— Failed: {e}"
def refresh_table():
    """Re-render the comparison table from the current session results.

    The table has no audio-URL column; URLs from finished uploads are used
    on row click, not displayed here.
    """
    table = build_comparison_table(_session_results)
    return table
# ── UI ────────────────────────────────────────────────────────────────────────
def build_ui():
    """
    Assemble the Gradio Blocks app.

    The app contains:
    - engine/band/voice/text inputs;
    - output audio and status;
    - the comparison table with row-click playback;
    - the business chart;
    - the session/history action buttons, plus all event bindings.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    # Initial control values. The voice dropdown's initial state must match
    # what on_engine_change / on_band_change produce for these defaults.
    default_engine = ENGINE_CHOICES[0]
    default_band = "K-2"
    with gr.Blocks(title="Bantrly TTS Evaluation Framework") as demo:
        gr.Markdown("""
# πŸŽ™ Bantrly TTS Evaluation Framework
Compare TTS engines on coaching text across grade bands.
**Metrics:** UTMOS (naturalness, ↑ better) Β· WER (intelligibility, ↓ better) Β· RTF (speed, ↓ better) Β· Cost vs Chirp 3 HD
""")
        with gr.Row():
            with gr.Column(scale=1):
                engine_selector = gr.Dropdown(
                    choices=ENGINE_CHOICES,
                    value=default_engine,
                    label="TTS Engine",
                )
                band_selector = gr.Dropdown(
                    choices=BANDS,
                    value=default_band,
                    label="Grade Band",
                )
                voice_selector = gr.Dropdown(
                    choices=KOKORO_VOICES,
                    # start on the recommended voice for the default band,
                    # as the info text below promises
                    value=KOKORO_BAND_VOICE.get(default_band, KOKORO_DEFAULT_VOICE),
                    label="Voice (Kokoro only)",
                    # only visible if the default engine is Kokoro — same rule
                    # as on_engine_change (don't assume ENGINES[0] is Kokoro)
                    visible=default_engine == "Kokoro (tuned)",
                    info="Defaults to recommended voice for selected band. Override freely.",
                )
                text_input = gr.Textbox(
                    label="Coaching Text",
                    placeholder="Type or paste any coaching text here...",
                    lines=4,
                    value="You did such a great job speaking today! I loved how loud and clear your voice was.",
                )
                synthesize_btn = gr.Button("β–Ά Synthesize + Eval", variant="primary")
            with gr.Column(scale=1):
                audio_output = gr.Audio(label="Output Audio", type="filepath")
                status_output = gr.Textbox(label="Status", interactive=False, lines=3)

        gr.Markdown("## Comparison Table")
        gr.Markdown(
            "**↑ higher is better Β· ↓ lower is better** β€” "
            "WER may exceed 100% on short texts."
        )
        comparison_table = gr.Dataframe(
            value=build_comparison_table([]),
            label="Eval Results β€” click a row to play audio",
            interactive=False,
        )
        with gr.Row():
            with gr.Column(scale=1):
                row_audio_player = gr.Audio(
                    label="β–Ά Selected Row Audio",
                    visible=False,
                    type="filepath",
                )
            with gr.Column(scale=2):
                row_detail = gr.Markdown(
                    value="",
                    visible=False,
                )
        business_chart = gr.Plot(
            value=build_business_chart([]),
            label="Business Decision Chart",
        )
        with gr.Row():
            clear_btn = gr.Button("πŸ—‘ Clear Session")
            refresh_btn = gr.Button("πŸ”„ Refresh Table")
            load_history_btn = gr.Button("πŸ“‚ Load History")
            export_session_btn = gr.Button("⬇ Export Session")
            export_all_btn = gr.Button("⬇ Export Full History")
        with gr.Row():
            export_file = gr.File(label="Download CSV", visible=False)
            export_status = gr.Textbox(label="", interactive=False, visible=True, value="")

        # ── bindings ──────────────────────────────────────────────────────
        engine_selector.change(
            fn=on_engine_change,
            inputs=[engine_selector],
            outputs=[voice_selector],
        )
        band_selector.change(
            fn=on_band_change,
            inputs=[band_selector, engine_selector],
            outputs=[voice_selector],
        )
        synthesize_btn.click(
            fn=run_synthesis,
            inputs=[engine_selector, band_selector, text_input, voice_selector],
            outputs=[audio_output, status_output, comparison_table, business_chart],
        )
        clear_btn.click(
            fn=clear_results,
            outputs=[comparison_table, business_chart, export_status],
        )
        refresh_btn.click(
            fn=refresh_table,
            outputs=[comparison_table],
        )
        comparison_table.select(
            fn=on_row_select,
            inputs=[],
            outputs=[row_audio_player, row_detail],
        )
        load_history_btn.click(
            fn=load_history,
            outputs=[comparison_table, business_chart, export_status],
        )
        export_session_btn.click(
            fn=export_session,
            outputs=[export_file, export_status],
        )
        export_all_btn.click(
            fn=export_all,
            outputs=[export_file, export_status],
        )
    return demo
# Script entry point: build the Blocks app and serve it.
# NOTE(review): the app is bound to the name `demo`; keep that name —
# presumably the `gradio app.py` workflow from the header looks it up
# (TODO confirm how reload mode finds it when it is only bound under this guard).
if __name__ == "__main__":
    demo = build_ui()
    demo.launch(share=False)  # share=False: no public share link is created