Pant0x committed on
Commit
4f15c75
·
verified ·
1 Parent(s): 574ac6a

Upload app.py

Files changed (1)
  1. app.py +71 -148
app.py CHANGED
@@ -1,148 +1,71 @@
- import gradio as gr
- import torch
- import torchaudio
- import numpy as np
- from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor, AutoConfig
- import matplotlib.pyplot as plt
-
- # =========================
- # CONFIG
- # =========================
- MODEL_NAME = "Hatman/audio-emotion-detection"
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-
- # =========================
- # LOAD MODEL & FEATURE EXTRACTOR
- # =========================
- feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME)
- model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_NAME).to(DEVICE)
- model.eval()
-
- # Use the model's label mapping directly
- config = AutoConfig.from_pretrained(MODEL_NAME)
- LABELS = [config.id2label[i] for i in range(len(config.id2label))]
-
- # Map some emojis to each emotion for fun UI
- EMOJIS = {
-     "Angry": "😡",
-     "Disgusted": "🤢",
-     "Fearful": "😨",
-     "Happy": "😄",
-     "Neutral": "😐",
-     "Sad": "😢",
-     "Surprised": "😲"
- }
-
- # =========================
- # PREDICTION FUNCTION
- # =========================
- def predict(audio):
-     try:
-         if audio is None:
-             return {"Error": "No audio provided"}, None
-
-         sr, data = audio
-         data = np.array(data, dtype=np.float32)
-
-         # Stereo -> Mono
-         if len(data.shape) > 1:
-             data = np.mean(data, axis=1)
-
-         # Resample to 16kHz
-         if sr != 16000:
-             data = torchaudio.functional.resample(torch.tensor(data), sr, 16000).numpy()
-             sr = 16000
-
-         # Improved normalization - normalize to [-1, 1] range
-         # Check if data is in int16 range or already normalized
-         if np.abs(data).max() > 1.0:
-             data = data / np.abs(data).max()  # Normalize by max value
-
-         # Apply gentle audio preprocessing to improve feature extraction
-         # Remove DC offset
-         data = data - np.mean(data)
-
-         # Apply light pre-emphasis filter to balance frequencies
-         pre_emphasis = 0.97
-         data = np.append(data[0], data[1:] - pre_emphasis * data[:-1])
-
-         # Feature extraction with proper padding
-         inputs = feature_extractor(
-             data,
-             sampling_rate=sr,
-             return_tensors="pt",
-             padding=True,
-             max_length=16000 * 10,  # Max 10 seconds
-             truncation=True
-         )
-
-         # Move tensors to device
-         for k in inputs:
-             inputs[k] = inputs[k].to(DEVICE)
-
-         # Forward pass
-         with torch.no_grad():
-             logits = model(**inputs).logits
-
-         # Apply temperature scaling to reduce overconfidence
-         # Higher temperature = more uniform distribution
-         temperature = 1.5
-         logits = logits / temperature
-
-         probs = torch.nn.functional.softmax(logits, dim=-1)[0].cpu().numpy()
-
-         # Show ALL emotions with their scores (not just top 3)
-         result = {}
-         for i, label in enumerate(LABELS):
-             emoji = EMOJIS.get(label, '')
-             result[f"{label} {emoji}"] = round(float(probs[i]), 4)
-
-         # Sort by probability
-         result = dict(sorted(result.items(), key=lambda x: x[1], reverse=True))
-
-         # Generate waveform plot
-         fig, ax = plt.subplots(figsize=(8, 3))
-         time_axis = np.linspace(0, len(data) / sr, len(data))
-         ax.plot(time_axis, data, color='purple', linewidth=0.5)
-         ax.set_title("Audio Waveform", fontsize=12, fontweight='bold')
-         ax.set_xlabel("Time (seconds)")
-         ax.set_ylabel("Amplitude")
-         ax.grid(True, alpha=0.3)
-         plt.tight_layout()
-
-         return result, fig
-
-     except Exception as e:
-         return {"Error": str(e)}, None
-
- # =========================
- # GRADIO APP
- # =========================
- demo = gr.Interface(
-     fn=predict,
-     inputs=gr.Audio(sources=["upload", "microphone"], type="numpy", label="🎤 Upload or Record Audio"),
-     outputs=[
-         gr.Label(num_top_classes=7, label="Emotion Probabilities"),
-         gr.Plot(label="Waveform Visualization")
-     ],
-     title="🎧 Audio Emotion Detection",
-     description=(
-         "Fine-tuned Wav2Vec2 model for emotion recognition from voice. "
-         "Detects: **Angry, Disgusted, Fearful, Happy, Neutral, Sad, Surprised**.\n\n"
-         "**Tips for better results:**\n"
-         "- Speak clearly and naturally\n"
-         "- Record at least 2-3 seconds of audio\n"
-         "- Avoid background noise\n"
-         "- Try exaggerating emotions for testing\n\n"
-         "Audio is automatically resampled to 16kHz and normalized."
-     ),
-     examples=[],
-     allow_flagging="never",
-     theme=gr.themes.Soft()
- )
-
- # =========================
- # LAUNCH
- # =========================
- if __name__ == "__main__":
-     demo.launch()
 
+ import gradio as gr
+ import torch
+ from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor
+ import numpy as np
+ import torchaudio
+
+ # =========================
+ # CONFIG
+ # =========================
+ MODEL_NAME = "your-username/Audio-Emotion-Detection"  # <- replace with your repo name
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # =========================
+ # LOAD MODEL & PROCESSOR
+ # =========================
+ processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
+ model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_NAME).to(DEVICE)
+
+ # Emotion labels in the same order used during training
+ LABELS = ["Angry", "Disgusted", "Fearful", "Happy", "Neutral", "Sad", "Surprised"]
+
+ # =========================
+ # PREDICTION PIPELINE
+ # =========================
+ def predict(audio):
+     # audio: tuple (sample_rate, numpy array)
+     sr, data = audio
+
+     # Resample to 16 kHz if necessary
+     if sr != 16000:
+         data = torchaudio.functional.resample(torch.tensor(data), sr, 16000).numpy()
+         sr = 16000
+
+     # Process input (max_length is required when truncation=True)
+     inputs = processor(
+         data,
+         sampling_rate=sr,
+         return_tensors="pt",
+         padding=True,
+         max_length=16000 * 10,  # Max 10 seconds
+         truncation=True
+     ).to(DEVICE)
+
+     # Forward pass
+     with torch.no_grad():
+         logits = model(**inputs).logits
+         probs = torch.nn.functional.softmax(logits, dim=-1)[0]
+
+     return {LABELS[i]: float(probs[i]) for i in range(len(LABELS))}
+
+ # =========================
+ # GRADIO INTERFACE
+ # =========================
+ demo = gr.Interface(
+     fn=predict,
+     inputs=gr.Audio(sources=["upload", "microphone"], type="numpy", label="Upload or Record Audio"),
+     outputs=gr.Label(num_top_classes=3),
+     title="Audio Emotion Detection 🎧",
+     description=(
+         "Fine-tuned Wav2Vec2 model for detecting emotions from voice. "
+         "Supports 7 emotions: Angry, Disgusted, Fearful, Happy, Neutral, Sad, and Surprised. "
+         "Audio is automatically resampled to 16 kHz."
+     ),
+     allow_flagging="never",
+ )
+
+ # =========================
+ # LAUNCH APP
+ # =========================
+ if __name__ == "__main__":
+     demo.launch()
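A quick way to sanity-check the slimmed-down predict() above is the hypothetical smoke test sketched below; it is not part of the commit. It assumes app.py is importable from the working directory and that MODEL_NAME has been replaced with a real checkpoint (the committed placeholder will not load). The 440 Hz tone is synthetic input for a shape/type check, not a meaningful emotion sample.

# Hypothetical smoke test for predict() -- assumes app.py is on the path
# and MODEL_NAME points at a real fine-tuned checkpoint.
import numpy as np

from app import predict  # importing app.py loads the model once

SR = 16000  # already at the target rate, so predict() skips resampling
t = np.linspace(0.0, 2.0, 2 * SR, endpoint=False)
tone = (0.1 * np.sin(2 * np.pi * 440.0 * t)).astype(np.float32)

scores = predict((SR, tone))  # Gradio delivers (sample_rate, samples) tuples
assert set(scores) == {"Angry", "Disgusted", "Fearful", "Happy",
                       "Neutral", "Sad", "Surprised"}
print(max(scores, key=scores.get), scores)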