import gradio as gr
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from model_utils import load_models, predict, EMOTION_LABELS
# ── Load once at startup ───────────────────────────────────────────────────────
# Module-level side effect: models are loaded at import time so the first
# request does not pay the loading cost.
print("Loading models...")
load_models(model_dir=".")  # load model artifacts from the current directory
print("Ready.")

# Per-emotion emoji and bar colors — used only by the UI renderer.
# NOTE(review): assumes these keys match model_utils.EMOTION_LABELS — confirm.
EMOJI = {'neutral':'😐','happy':'😊','sad':'😢','angry':'😠','fear':'😨'}
COLORS = {'neutral':'#95a5a6','happy':'#2ecc71','sad':'#3498db','angry':'#e74c3c','fear':'#e67e22'}
# ── Shared inference logic ─────────────────────────────────────────────────────
def _run(audio_path, language, mode):
    """Core inference — used by both the UI and the clean API endpoint.

    Returns:
        (probs, top_emotion) on success, where probs maps emotion -> probability;
        (None, error_message) on failure, so callers can branch on probs is None.
    """
    if audio_path is None:
        return None, "No audio provided."
    try:
        probs = predict(audio_path, language=language, mode=mode)
    except Exception as e:
        # Surface the failure as a message instead of crashing the UI/API.
        return None, f"Error: {e}"
    # Only the argmax label is needed here — the original sorted all items and
    # unpacked an unused top_conf; a single max() is sufficient.
    top = max(probs, key=probs.get)
    return probs, top
# ── UI function (returns markdown + chart) ─────────────────────────────────────
def run_inference(audio_path, language, mode):
    """UI entry point: return (markdown summary, matplotlib bar chart).

    On failure, returns (error_message, None) so the UI shows the error text
    and an empty plot.
    """
    probs, top = _run(audio_path, language, mode)
    if probs is None:
        return top, None  # top carries the error string here

    # _run already identified the top emotion; look up its confidence directly
    # instead of re-deriving both from a second sort (original re-computed top).
    top_conf = probs[top]
    # Sort descending only to fix the display order of the bars.
    sorted_probs = sorted(probs.items(), key=lambda x: -x[1])

    result_md = (
        f"## {EMOJI.get(top, '')} {top.upper()}\n\n"
        f"**Confidence:** {top_conf:.1%}\n\n"
        f"**Language:** {language} | **Mode:** {mode}"
    )

    fig, ax = plt.subplots(figsize=(6, 3.2))
    emos = [e for e, _ in sorted_probs]
    vals = [p for _, p in sorted_probs]
    cols = [COLORS.get(e, "#bdc3c7") for e in emos]
    bars = ax.barh(emos, vals, color=cols, height=0.5, edgecolor="none")
    # Annotate each bar with its probability just past the bar's end.
    for bar, val in zip(bars, vals):
        ax.text(val + 0.01, bar.get_y() + bar.get_height() / 2,
                f"{val:.1%}", va="center", fontsize=9)
    ax.set_xlim(0, 1.05)  # leave headroom for the text labels at 100%
    ax.set_xlabel("Probability")
    ax.set_title("Emotion Probabilities", fontweight="bold")
    ax.invert_yaxis()  # highest probability bar at the top
    ax.spines[["top", "right", "left"]].set_visible(False)
    plt.tight_layout()
    return result_md, fig
# ── Clean API function (used by your Vercel backend) ──────────────────────────
# Returns a plain dict — no chart, no markdown.
# gradio_client calls this as api_name="/predict_api"
def predict_api(audio_path: str, language: str, mode: str) -> dict:
    """
    Clean JSON endpoint for programmatic access.
    Returns: {"emotion": str, "confidence": float, "all_probs": dict}
    """
    probs, top = _run(audio_path, language, mode)
    if probs is None:
        # _run failed; top holds its error message.
        return {"emotion": "neutral", "confidence": 0.0, "all_probs": {}, "error": top}
    # Round once, then reuse for both the per-label map and the top confidence.
    rounded = {emotion: round(score, 4) for emotion, score in probs.items()}
    return {
        "emotion": top,
        "confidence": rounded[top],
        "all_probs": rounded,
        "error": None,
    }
# ── Gradio UI ──────────────────────────────────────────────────────────────────
# UI layout: inputs on the left, results on the right, plus a hidden row that
# registers the programmatic /predict_api endpoint.
with gr.Blocks(title="Multilingual SER") as demo:
    gr.Markdown("""
# 🎙️ Multilingual Speech Emotion Recognition
Detects emotion in **Sinhala**, **Tamil**, and **English** speech.
""")
    with gr.Row():
        with gr.Column():
            # Input column: audio source + language/mode selectors.
            audio_in = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",  # handlers receive a filesystem path, not raw samples
                label="Audio Input (WAV/MP3, max 6s)"
            )
            language = gr.Radio(
                choices=["english", "tamil", "sinhala"],
                value="english",
                label="Language",
                info="Select the language spoken — affects normalization"
            )
            mode = gr.Radio(
                choices=["fusion", "gemaps", "ensemble"],
                value="ensemble",
                label="Inference Mode",
                info="ensemble is most robust | gemaps is fastest | fusion is highest accuracy on English/Tamil"
            )
            btn = gr.Button("Detect Emotion", variant="primary")
        with gr.Column():
            # Output column: markdown summary + probability bar chart.
            out_text = gr.Markdown()
            out_plot = gr.Plot(label="Confidence")
    btn.click(run_inference, [audio_in, language, mode], [out_text, out_plot])

    # ── Hidden API endpoint (Gradio 6 compatible) ──────────────────────────────
    # gr.Interface nested inside gr.Blocks crashes in Gradio 6.
    # Instead: hidden row wired to predict_api — registers as /predict_api
    with gr.Row(visible=False):
        _api_audio = gr.Audio(type="filepath")
        _api_lang = gr.Text(value="english")
        _api_mode = gr.Text(value="ensemble")
        _api_out = gr.JSON()
        _api_btn = gr.Button()
    _api_btn.click(
        fn=predict_api,
        inputs=[_api_audio, _api_lang, _api_mode],
        outputs=_api_out,
        api_name="predict_api",  # gradio_client calls this as api_name="/predict_api"
    )

    gr.Markdown("""
---
**Emotions:** Neutral · Happy · Sad · Angry · Fear
**Modes:**
- `fusion` — Whisper-tiny encoder + eGeMAPS (best on English & Tamil)
- `gemaps` — 88 acoustic features only, language-agnostic, ~50ms
- `ensemble` — 60% fusion + 40% gemaps (recommended for Sinhala)
""")
if __name__ == "__main__":
    # BUG FIX: Blocks.launch() does not accept a `theme` keyword — `theme` is a
    # gr.Blocks(...) constructor parameter, and passing it here raises
    # TypeError on current Gradio releases. Move theming to the Blocks
    # constructor: gr.Blocks(title=..., theme=gr.themes.Soft()).
    demo.launch()