"""Gradio Space: multilingual speech emotion recognition (Sinhala/Tamil/English)."""
import gradio as gr
import matplotlib

# Headless backend — must be selected BEFORE pyplot is imported.
matplotlib.use("Agg")
import matplotlib.pyplot as plt

from model_utils import load_models, predict, EMOTION_LABELS
# ── Load once at startup ───────────────────────────────────────────────────────
# Model weights are loaded exactly once at module import, so each request pays
# only inference cost. model_dir="." — weights live in the Space's root dir.
print("Loading models...")
load_models(model_dir=".")
print("Ready.")
# Per-emotion display metadata used by the UI: emoji for the headline,
# hex color for the probability bars.
EMOJI = {
    "neutral": "😐",
    "happy": "😊",
    "sad": "😢",
    "angry": "😠",
    "fear": "😨",
}
COLORS = {
    "neutral": "#95a5a6",
    "happy": "#2ecc71",
    "sad": "#3498db",
    "angry": "#e74c3c",
    "fear": "#e67e22",
}
# ── Shared inference logic ─────────────────────────────────────────────────────
def _run(audio_path, language, mode):
    """Core inference — used by both the UI and the clean API endpoint.

    Args:
        audio_path: Filesystem path to the input audio, or None.
        language: Language tag passed through to predict() (e.g. "english").
        mode: Inference mode passed through to predict() (e.g. "ensemble").

    Returns:
        (probs, top) on success — probs is the emotion->probability dict from
        predict(), top is the highest-probability emotion label.
        (None, message) on failure — message is a human-readable error string.
    """
    if audio_path is None:
        return None, "No audio provided."
    try:
        probs = predict(audio_path, language=language, mode=mode)
    except Exception as e:  # boundary: surface any model failure as a message
        return None, f"Error: {e}"
    # FIX: the old code sorted every item and bound a `top_conf` it never
    # returned — max() computes the argmax directly without the dead work.
    top = max(probs, key=probs.get)
    return probs, top
# ── UI function (returns markdown + chart) ─────────────────────────────────────
def run_inference(audio_path, language, mode):
    """Run inference and render a markdown summary plus a probability bar chart.

    Returns (markdown_str, matplotlib_figure) on success,
    or (error_str, None) on failure.
    """
    probs, top = _run(audio_path, language, mode)
    if probs is None:
        # On failure `top` carries the error string from _run.
        return top, None

    ranked = sorted(probs.items(), key=lambda item: item[1], reverse=True)
    top, top_conf = ranked[0]

    result_md = "\n\n".join([
        f"## {EMOJI.get(top, '')} {top.upper()}",
        f"**Confidence:** {top_conf:.1%}",
        f"**Language:** {language} | **Mode:** {mode}",
    ])

    labels = [emotion for emotion, _ in ranked]
    scores = [score for _, score in ranked]
    palette = [COLORS.get(emotion, "#bdc3c7") for emotion in labels]

    fig, ax = plt.subplots(figsize=(6, 3.2))
    bars = ax.barh(labels, scores, color=palette, height=0.5, edgecolor="none")
    # Annotate each bar with its percentage, just past the bar's end.
    for rect, score in zip(bars, scores):
        ax.text(rect.get_width() + 0.01, rect.get_y() + rect.get_height() / 2,
                f"{score:.1%}", va="center", fontsize=9)
    ax.set_xlim(0, 1.05)
    ax.set_xlabel("Probability")
    ax.set_title("Emotion Probabilities", fontweight="bold")
    ax.invert_yaxis()  # highest probability on top
    ax.spines[["top", "right", "left"]].set_visible(False)
    plt.tight_layout()
    return result_md, fig
# ── Clean API function (used by your Vercel backend) ──────────────────────────
# Returns a plain dict — no chart, no markdown.
# gradio_client calls this as api_name="/predict_api"
def predict_api(audio_path: str, language: str, mode: str) -> dict:
    """
    Clean JSON endpoint for programmatic access.
    Returns: {"emotion": str, "confidence": float, "all_probs": dict}
    """
    probs, top = _run(audio_path, language, mode)
    if probs is None:
        # Failure path: neutral placeholder plus the error string from _run.
        return {"emotion": "neutral", "confidence": 0.0, "all_probs": {}, "error": top}
    rounded = {emotion: round(p, 4) for emotion, p in probs.items()}
    return {
        "emotion": top,
        "confidence": rounded[top],
        "all_probs": rounded,
        "error": None,
    }
# ── Gradio UI ──────────────────────────────────────────────────────────────────
# FIX: `theme` is a gr.Blocks() constructor argument, not a launch() argument —
# Blocks.launch() has no `theme` parameter and raises TypeError when given one.
with gr.Blocks(title="Multilingual SER", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🎙️ Multilingual Speech Emotion Recognition
Detects emotion in **Sinhala**, **Tamil**, and **English** speech.
""")
    with gr.Row():
        with gr.Column():
            audio_in = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Audio Input (WAV/MP3, max 6s)"
            )
            language = gr.Radio(
                choices=["english", "tamil", "sinhala"],
                value="english",
                label="Language",
                info="Select the language spoken — affects normalization"
            )
            mode = gr.Radio(
                choices=["fusion", "gemaps", "ensemble"],
                value="ensemble",
                label="Inference Mode",
                info="ensemble is most robust | gemaps is fastest | fusion is highest accuracy on English/Tamil"
            )
            btn = gr.Button("Detect Emotion", variant="primary")
        with gr.Column():
            out_text = gr.Markdown()
            out_plot = gr.Plot(label="Confidence")
    btn.click(run_inference, [audio_in, language, mode], [out_text, out_plot])

    # ── Hidden API endpoint (Gradio 6 compatible) ──────────────────────────────
    # gr.Interface nested inside gr.Blocks crashes in Gradio 6.
    # Instead: hidden row wired to predict_api — registers as /predict_api
    with gr.Row(visible=False):
        _api_audio = gr.Audio(type="filepath")
        _api_lang = gr.Text(value="english")
        _api_mode = gr.Text(value="ensemble")
        _api_out = gr.JSON()
        _api_btn = gr.Button()
        _api_btn.click(
            fn=predict_api,
            inputs=[_api_audio, _api_lang, _api_mode],
            outputs=_api_out,
            api_name="predict_api",
        )

    gr.Markdown("""
---
**Emotions:** Neutral · Happy · Sad · Angry · Fear
**Modes:**
- `fusion` — Whisper-tiny encoder + eGeMAPS (best on English & Tamil)
- `gemaps` — 88 acoustic features only, language-agnostic, ~50ms
- `ensemble` — 60% fusion + 40% gemaps (recommended for Sinhala)
""")

if __name__ == "__main__":
    demo.launch()