# ser-wav2vec / app.py
# (Hugging Face Space page header — author: Raemih, "Update app.py", commit c02d65d verified)
import gradio as gr
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from model_utils import load_models, predict, EMOTION_LABELS
# ── Load once at startup ───────────────────────────────────────────────────────
# Warm up the models at import time so the first user request is fast.
# NOTE(review): model_dir="." assumes weights sit next to this file — confirm
# against model_utils.load_models.
print("Loading models...")
load_models(model_dir=".")
print("Ready.")
# Display metadata per emotion label: an emoji for the result headline and a
# hex color for its bar in the probability chart. Keys must match the labels
# returned by predict() / EMOTION_LABELS.
EMOJI = {
    'neutral': '😐',
    'happy': '😊',
    'sad': '😢',
    'angry': '😠',
    'fear': '😨',
}
COLORS = {
    'neutral': '#95a5a6',
    'happy': '#2ecc71',
    'sad': '#3498db',
    'angry': '#e74c3c',
    'fear': '#e67e22',
}
# ── Shared inference logic ─────────────────────────────────────────────────────
def _run(audio_path, language, mode):
"""Core inference — used by both the UI and the clean API endpoint."""
if audio_path is None:
return None, "No audio provided."
try:
probs = predict(audio_path, language=language, mode=mode)
except Exception as e:
return None, f"Error: {e}"
sorted_probs = sorted(probs.items(), key=lambda x: -x[1])
top, top_conf = sorted_probs[0]
return probs, top
# ── UI function (returns markdown + chart) ─────────────────────────────────────
def run_inference(audio_path, language, mode):
    """UI entry point: return (markdown summary, matplotlib bar chart).

    On failure the first return value is the error string and the second
    is None, so the Plot component simply clears.
    """
    probs, top = _run(audio_path, language, mode)
    if probs is None:
        # `top` carries the error message when inference failed.
        return top, None

    # Rank emotions by descending probability (stable sort preserves
    # insertion order among ties, matching the original behavior).
    ranked = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)
    top, top_conf = ranked[0]

    result_md = (
        f"## {EMOJI.get(top, '')} {top.upper()}\n\n"
        f"**Confidence:** {top_conf:.1%}\n\n"
        f"**Language:** {language} | **Mode:** {mode}"
    )

    labels = [emotion for emotion, _ in ranked]
    scores = [prob for _, prob in ranked]
    bar_colors = [COLORS.get(emotion, "#bdc3c7") for emotion in labels]

    fig, ax = plt.subplots(figsize=(6, 3.2))
    handles = ax.barh(labels, scores, color=bar_colors, height=0.5, edgecolor="none")
    # Annotate each bar with its probability, just past the bar's end.
    for handle, score in zip(handles, scores):
        ax.text(score + 0.01, handle.get_y() + handle.get_height() / 2,
                f"{score:.1%}", va="center", fontsize=9)
    ax.set_xlim(0, 1.05)
    ax.set_xlabel("Probability")
    ax.set_title("Emotion Probabilities", fontweight="bold")
    ax.invert_yaxis()  # highest probability on top
    ax.spines[["top", "right", "left"]].set_visible(False)
    plt.tight_layout()
    return result_md, fig
# ── Clean API function (used by your Vercel backend) ──────────────────────────
# Returns a plain dict — no chart, no markdown.
# gradio_client calls this as api_name="/predict_api"
def predict_api(audio_path: str, language: str, mode: str) -> dict:
    """
    Clean JSON endpoint for programmatic access.
    Returns: {"emotion": str, "confidence": float, "all_probs": dict}
    """
    probs, top = _run(audio_path, language, mode)
    if probs is None:
        # Failure: `top` holds the error string; keep the response schema stable.
        return {"emotion": "neutral", "confidence": 0.0, "all_probs": {}, "error": top}
    # Round once and reuse, so "confidence" always equals all_probs[emotion].
    rounded = {emotion: round(p, 4) for emotion, p in probs.items()}
    return {
        "emotion": top,
        "confidence": rounded[top],
        "all_probs": rounded,
        "error": None,
    }
# ── Gradio UI ──────────────────────────────────────────────────────────────────
# Gradio UI. FIX: `theme` is a constructor argument of gr.Blocks, not of
# Blocks.launch() — passing it to launch() raises TypeError. Moved here.
with gr.Blocks(title="Multilingual SER", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎙️ Multilingual Speech Emotion Recognition
    Detects emotion in **Sinhala**, **Tamil**, and **English** speech.
    """)
    with gr.Row():
        with gr.Column():
            audio_in = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Audio Input (WAV/MP3, max 6s)"
            )
            language = gr.Radio(
                choices=["english", "tamil", "sinhala"],
                value="english",
                label="Language",
                info="Select the language spoken — affects normalization"
            )
            mode = gr.Radio(
                choices=["fusion", "gemaps", "ensemble"],
                value="ensemble",
                label="Inference Mode",
                info="ensemble is most robust | gemaps is fastest | fusion is highest accuracy on English/Tamil"
            )
            btn = gr.Button("Detect Emotion", variant="primary")
        with gr.Column():
            out_text = gr.Markdown()
            out_plot = gr.Plot(label="Confidence")
    # Visible UI path: markdown summary + probability chart.
    btn.click(run_inference, [audio_in, language, mode], [out_text, out_plot])
    # ── Hidden API endpoint (Gradio 6 compatible) ──────────────────────────────
    # gr.Interface nested inside gr.Blocks crashes in Gradio 6.
    # Instead: hidden row wired to predict_api — registers as /predict_api
    with gr.Row(visible=False):
        _api_audio = gr.Audio(type="filepath")
        _api_lang = gr.Text(value="english")
        _api_mode = gr.Text(value="ensemble")
        _api_out = gr.JSON()
        _api_btn = gr.Button()
        _api_btn.click(
            fn=predict_api,
            inputs=[_api_audio, _api_lang, _api_mode],
            outputs=_api_out,
            api_name="predict_api",
        )
    gr.Markdown("""
    ---
    **Emotions:** Neutral · Happy · Sad · Angry · Fear
    **Modes:**
    - `fusion` — Whisper-tiny encoder + eGeMAPS (best on English & Tamil)
    - `gemaps` — 88 acoustic features only, language-agnostic, ~50ms
    - `ensemble` — 60% fusion + 40% gemaps (recommended for Sinhala)
    """)
if __name__ == "__main__":
    # theme is applied at Blocks construction above; launch() takes no theme kwarg.
    demo.launch()