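"""Gradio app for audio emotion recognition.

Wraps the fine-tuned Wav2Vec2 checkpoint `Hatman/audio-emotion-detection`:
input audio is converted to 16 kHz mono float PCM, classified, and the top-3
emotion probabilities are shown next to a waveform plot.
"""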
import gradio as gr
import torch
import torchaudio
import numpy as np
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
import matplotlib.pyplot as plt

# =========================
# CONFIG
# =========================
MODEL_NAME = "Hatman/audio-emotion-detection"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# =========================
# LOAD MODEL & FEATURE EXTRACTOR
# =========================
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME)
model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()

# Use the model's label mapping directly (id2label ships with the checkpoint)
LABELS = [model.config.id2label[i] for i in range(len(model.config.id2label))]

# Map an emoji to each emotion for a friendlier UI
EMOJIS = {
    "Angry": "😡",
    "Disgusted": "🤢",
    "Fearful": "😨",
    "Happy": "😄",
    "Neutral": "😐",
    "Sad": "😢",
    "Surprised": "😲"
}
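# NOTE: these keys are assumed to match the checkpoint's id2label strings;
# a label without a matching key simply renders without an emoji (EMOJIS.get below).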

# =========================
# PREDICTION FUNCTION
# =========================
def predict(audio):
    try:
        if audio is None:
            # gr.Label accepts a plain string for a single (error) label
            return "No audio provided", None

        # Gradio's "numpy" audio type yields a (sample_rate, data) tuple;
        # mic/upload data is typically int16 PCM, but guard for float input too
        sr, data = audio
        data = np.asarray(data)

        # Scale integer PCM to [-1, 1]; float data is assumed already normalized
        if np.issubdtype(data.dtype, np.integer):
            data = data.astype(np.float32) / np.iinfo(data.dtype).max
        else:
            data = data.astype(np.float32)

        # Stereo -> mono
        if data.ndim > 1:
            data = data.mean(axis=1)

        # Resample to the 16 kHz rate Wav2Vec2 expects
        if sr != 16000:
            data = torchaudio.functional.resample(torch.from_numpy(data), sr, 16000).numpy()
            sr = 16000

        # Feature extraction
        inputs = feature_extractor(
            data,
            sampling_rate=sr,
            return_tensors="pt",
            padding=True
        )

        # Move tensors to device
        for k in inputs:
            inputs[k] = inputs[k].to(DEVICE)

        # Forward pass
        with torch.no_grad():
            logits = model(**inputs).logits
            probs = torch.nn.functional.softmax(logits, dim=-1)[0].cpu().numpy()

        # Top-3 results as a {label: confidence} mapping, which gr.Label expects
        top_idx = np.argsort(probs)[::-1][:3]
        result = {f"{LABELS[i]} {EMOJIS.get(LABELS[i], '')}": round(float(probs[i]), 4) for i in top_idx}

        # Waveform plot for the UI
        fig, ax = plt.subplots(figsize=(6, 2))
        ax.plot(data, color="purple")
        ax.set_title("Audio Waveform")
        ax.set_xlabel("Samples")
        ax.set_ylabel("Amplitude")
        ax.set_xticks([])
        ax.set_yticks([])
        fig.tight_layout()
        plt.close(fig)  # drop pyplot's reference so figures don't accumulate server-side

        return result, fig

    except Exception as e:
        return f"Error: {e}", None

# =========================
# GRADIO APP
# =========================
demo = gr.Interface(
    fn=predict,
    inputs=gr.Audio(sources=["upload", "microphone"], type="numpy", label="🎀 Upload or Record Audio"),
    outputs=[gr.Label(num_top_classes=3), gr.Plot()],
    title="Audio Emotion Detection 🎧",
    description=(
        "Fine-tuned Wav2Vec2 model (`Hatman/audio-emotion-detection`) "
        "for emotion recognition from voice. "
        "Detects: Angry, Disgusted, Fearful, Happy, Neutral, Sad, Surprised. "
        "Audio is auto-resampled to 16kHz."
    ),
    allow_flagging="never",
)

# =========================
# LAUNCH
# =========================
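# demo.launch() serves on http://127.0.0.1:7860 by default; when running
# locally, share=True can be passed for a temporary public link.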
if __name__ == "__main__":
    demo.launch()