Voice_model / app.py
import gradio as gr
import torch
import torchaudio
import numpy as np
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor, AutoConfig
import matplotlib.pyplot as plt
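# Assumed dependencies (not pinned here): gradio, torch, torchaudio, transformers,
# numpy, and matplotlib; these are presumably listed in the Space's requirements.txt.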
# =========================
# CONFIG
# =========================
MODEL_NAME = "Hatman/audio-emotion-detection"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# =========================
# LOAD MODEL & FEATURE EXTRACTOR
# =========================
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME)
model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()
# Use the model's label mapping directly
config = AutoConfig.from_pretrained(MODEL_NAME)
LABELS = [config.id2label[i] for i in range(len(config.id2label))]
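# Note: the emoji lookup below is keyed by these label strings; if the checkpoint's
# id2label names differ (e.g. lowercase), EMOJIS.get() falls back to an empty string
# and the prediction is still shown without an emoji. To inspect the actual names,
# uncomment: print(config.id2label)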
# Map some emojis to each emotion for fun UI
EMOJIS = {
"Angry": "😑",
"Disgusted": "🀒",
"Fearful": "😨",
"Happy": "πŸ˜„",
"Neutral": "😐",
"Sad": "😒",
"Surprised": "😲"
}
# =========================
# PREDICTION FUNCTION
# =========================
def predict(audio):
    try:
        if audio is None:
            return {"Error": "No audio provided"}, None
        # Gradio's numpy audio component yields a (sample_rate, samples) tuple
        sr, data = audio
        data = np.array(data, dtype=np.float32)
        # Stereo -> mono
        if len(data.shape) > 1:
            data = np.mean(data, axis=1)
        # Resample to 16 kHz, the rate the Wav2Vec2 checkpoint expects
        if sr != 16000:
            data = torchaudio.functional.resample(torch.tensor(data), sr, 16000).numpy()
            sr = 16000
        # Normalize for Wav2Vec2: scale int16-range PCM to [-1, 1];
        # if the samples are already float-scaled, leave them untouched
        if np.abs(data).max() > 1.0:
            data = data / 32768.0
        # Feature extraction
        inputs = feature_extractor(
            data,
            sampling_rate=sr,
            return_tensors="pt",
            padding=True
        )
        # Move tensors to device
        for k in inputs:
            inputs[k] = inputs[k].to(DEVICE)
        # Forward pass
        with torch.no_grad():
            logits = model(**inputs).logits
        probs = torch.nn.functional.softmax(logits, dim=-1)[0].cpu().numpy()
        # Format top 3 results with emojis
        top_idx = np.argsort(probs)[::-1][:3]
        result = {f"{LABELS[i]} {EMOJIS.get(LABELS[i], '')}": round(float(probs[i]), 4) for i in top_idx}
        # Generate waveform plot
        fig, ax = plt.subplots(figsize=(6, 2))
        ax.plot(data, color='purple')
        ax.set_title("Audio Waveform")
        ax.set_xlabel("Samples")
        ax.set_ylabel("Amplitude")
        ax.set_xticks([])
        ax.set_yticks([])
        plt.tight_layout()
        return result, fig
    except Exception as e:
        return {"Error": str(e)}, None
# =========================
# GRADIO APP
# =========================
demo = gr.Interface(
    fn=predict,
    inputs=gr.Audio(sources=["upload", "microphone"], type="numpy", label="🎀 Upload or Record Audio"),
    outputs=[gr.Label(num_top_classes=3), gr.Plot()],
    title="Audio Emotion Detection 🎧",
    description=(
        "Fine-tuned Wav2Vec2 model (`Hatman/audio-emotion-detection`) "
        "for emotion recognition from voice. "
        "Detects: Angry, Disgusted, Fearful, Happy, Neutral, Sad, Surprised. "
        "Audio is auto-resampled to 16kHz."
    ),
    allow_flagging="never",
)
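# gr.Label renders the {label: probability} dict returned by predict(), and gr.Plot
# shows the matplotlib waveform figure; errors are returned as a dict with a single
# "Error" key.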
# =========================
# LAUNCH
# =========================
if __name__ == "__main__":
    demo.launch()
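# On Hugging Face Spaces this demo is launched automatically; for a local run,
# `python app.py` starts the server. demo.launch(share=True) would additionally
# create a temporary public link (optional; an assumption about how this is deployed).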