import gradio as gr
import matplotlib

# Select the non-interactive Agg backend BEFORE pyplot is imported, so charts
# render on a headless server (no display attached).
matplotlib.use("Agg")
import matplotlib.pyplot as plt

from model_utils import load_models, predict, EMOTION_LABELS

# ── Load once at startup ───────────────────────────────────────────────────────
print("Loading models...")
load_models(model_dir=".")
print("Ready.")

# Emoji / bar color per emotion label (keys match the model's label set).
EMOJI = {'neutral':'😐','happy':'😊','sad':'😢','angry':'😠','fear':'😨'}
COLORS = {'neutral':'#95a5a6','happy':'#2ecc71','sad':'#3498db','angry':'#e74c3c','fear':'#e67e22'}


# ── Shared inference logic ─────────────────────────────────────────────────────
def _run(audio_path, language, mode):
    """Core inference — used by both the UI and the clean API endpoint.

    Args:
        audio_path: path to the input audio file, or None if nothing was given.
        language: language hint forwarded to predict().
        mode: inference mode forwarded to predict().

    Returns:
        (probs, top) on success — the full probability dict and the top label;
        (None, error_message) on missing input or prediction failure.
    """
    if audio_path is None:
        return None, "No audio provided."
    try:
        probs = predict(audio_path, language=language, mode=mode)
    except Exception as e:
        # Surface the failure to the caller as a message instead of crashing.
        return None, f"Error: {e}"
    # Fix: the original sorted the whole dict just to read the first entry and
    # unpacked a confidence value it never used — max() finds the top label
    # directly (same tie-breaking: first-seen entry wins).
    top = max(probs, key=probs.get)
    return probs, top


# ── UI function (returns markdown + chart) ─────────────────────────────────────
def run_inference(audio_path, language, mode):
    """Run inference for the UI: returns (markdown summary, matplotlib figure).

    On error, returns (error_string, None) so the Markdown pane shows the
    message and the plot pane stays empty.
    """
    probs, top = _run(audio_path, language, mode)
    if probs is None:
        return top, None  # top is the error string here

    # Sort descending by probability for both the headline and the chart.
    sorted_probs = sorted(probs.items(), key=lambda x: -x[1])
    top, top_conf = sorted_probs[0]
    result_md = (
        f"## {EMOJI.get(top, '')} {top.upper()}\n\n"
        f"**Confidence:** {top_conf:.1%}\n\n"
        f"**Language:** {language} | **Mode:** {mode}"
    )

    fig, ax = plt.subplots(figsize=(6, 3.2))
    emos = [e for e, _ in sorted_probs]
    vals = [p for _, p in sorted_probs]
    cols = [COLORS.get(e, "#bdc3c7") for e in emos]
    bars = ax.barh(emos, vals, color=cols, height=0.5, edgecolor="none")
    # Annotate each bar with its percentage, just past the bar's end.
    for bar, val in zip(bars, vals):
        ax.text(val + 0.01, bar.get_y() + bar.get_height() / 2,
                f"{val:.1%}", va="center", fontsize=9)
    ax.set_xlim(0, 1.05)
    ax.set_xlabel("Probability")
    ax.set_title("Emotion Probabilities", fontweight="bold")
    ax.invert_yaxis()  # highest probability on top
    ax.spines[["top", "right", "left"]].set_visible(False)
    plt.tight_layout()
    return result_md, fig
# ── Clean API function (used by your Vercel backend) ──────────────────────────
# Returns a plain dict — no chart, no markdown.
# gradio_client calls this as api_name="/predict_api"
def predict_api(audio_path: str, language: str, mode: str) -> dict:
    """
    Clean JSON endpoint for programmatic access.
    Returns: {"emotion": str, "confidence": float, "all_probs": dict}
    On failure the payload carries "error" with the message and a neutral
    zero-confidence placeholder result.
    """
    probs, top = _run(audio_path, language, mode)
    if probs is None:
        return {"emotion": "neutral", "confidence": 0.0, "all_probs": {}, "error": top}
    return {
        "emotion": top,
        "confidence": round(probs[top], 4),
        "all_probs": {k: round(v, 4) for k, v in probs.items()},
        "error": None,
    }


# ── Gradio UI ──────────────────────────────────────────────────────────────────
# Fix: `theme` is a gr.Blocks constructor argument, not a launch() argument —
# the original passed it to demo.launch(), which raises TypeError because
# launch() has no `theme` parameter. It now lives here on the Blocks.
with gr.Blocks(title="Multilingual SER", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎙️ Multilingual Speech Emotion Recognition
    Detects emotion in **Sinhala**, **Tamil**, and **English** speech.
    """)

    with gr.Row():
        with gr.Column():
            audio_in = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Audio Input (WAV/MP3, max 6s)"
            )
            language = gr.Radio(
                choices=["english", "tamil", "sinhala"],
                value="english",
                label="Language",
                info="Select the language spoken — affects normalization"
            )
            mode = gr.Radio(
                choices=["fusion", "gemaps", "ensemble"],
                value="ensemble",
                label="Inference Mode",
                info="ensemble is most robust | gemaps is fastest | fusion is highest accuracy on English/Tamil"
            )
            btn = gr.Button("Detect Emotion", variant="primary")
        with gr.Column():
            out_text = gr.Markdown()
            out_plot = gr.Plot(label="Confidence")

    btn.click(run_inference, [audio_in, language, mode], [out_text, out_plot])

    # ── Hidden API endpoint (Gradio 6 compatible) ──────────────────────────
    # gr.Interface nested inside gr.Blocks crashes in Gradio 6.
    # Instead: hidden row wired to predict_api — registers as /predict_api
    with gr.Row(visible=False):
        _api_audio = gr.Audio(type="filepath")
        _api_lang = gr.Text(value="english")
        _api_mode = gr.Text(value="ensemble")
        _api_out = gr.JSON()
        _api_btn = gr.Button()
    _api_btn.click(
        fn=predict_api,
        inputs=[_api_audio, _api_lang, _api_mode],
        outputs=_api_out,
        api_name="predict_api",
    )

    gr.Markdown("""
    ---
    **Emotions:** Neutral · Happy · Sad · Angry · Fear
    **Modes:**
    - `fusion` — Whisper-tiny encoder + eGeMAPS (best on English & Tamil)
    - `gemaps` — 88 acoustic features only, language-agnostic, ~50ms
    - `ensemble` — 60% fusion + 40% gemaps (recommended for Sinhala)
    """)

if __name__ == "__main__":
    # Theme is set on the Blocks constructor above; launch() takes no theme.
    demo.launch()