Pant0x committed
Commit 76a3749 · verified · 1 Parent(s): e7edbfd

Update app.py

Files changed (1):
  app.py  +83 -38
app.py CHANGED
@@ -1,8 +1,9 @@
 import gradio as gr
 import torch
-from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor
-import numpy as np
 import torchaudio
+import numpy as np
+from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor, AutoConfig
+import matplotlib.pyplot as plt
 
 # =========================
 # CONFIG
@@ -11,61 +12,105 @@ MODEL_NAME = "Hatman/audio-emotion-detection"
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 # =========================
-# LOAD MODEL & PROCESSOR
+# LOAD MODEL & FEATURE EXTRACTOR
 # =========================
-processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
+feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME)
 model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_NAME).to(DEVICE)
+model.eval()
 
-# Emotion labels in same order used during training (matches the model card)
-LABELS = ["angry", "disgust", "fear", "happy", "neutral", "sad", "surprised"]
+# Use the model’s label mapping directly
+config = AutoConfig.from_pretrained(MODEL_NAME)
+LABELS = [config.id2label[i] for i in range(len(config.id2label))]
+
+# Map some emojis to each emotion for fun UI
+EMOJIS = {
+    "Angry": "😡",
+    "Disgusted": "🤢",
+    "Fearful": "😨",
+    "Happy": "😄",
+    "Neutral": "😐",
+    "Sad": "😢",
+    "Surprised": "😲"
+}
 
 # =========================
-# PREDICTION PIPELINE
+# PREDICTION FUNCTION
 # =========================
 def predict(audio):
-    # audio: tuple (sample_rate, numpy array)
-    sr, data = audio
-
-    # Resample to 16k if needed
-    if sr != 16000:
-        data = torchaudio.functional.resample(torch.tensor(data), sr, 16000).numpy()
-        sr = 16000
-
-    # Process input
-    inputs = processor(
-        data,
-        sampling_rate=sr,
-        return_tensors="pt",
-        padding=True,
-        truncation=True
-    ).to(DEVICE)
-
-    # Forward pass
-    with torch.no_grad():
-        logits = model(**inputs).logits
-        probs = torch.nn.functional.softmax(logits, dim=-1)[0]
-        pred_idx = torch.argmax(probs).item()
-
-    return {LABELS[i]: float(probs[i]) for i in range(len(LABELS))}
+    try:
+        if audio is None:
+            return {"Error": "No audio provided"}, None
+
+        sr, data = audio
+        data = np.array(data, dtype=np.float32)
+
+        # Stereo -> Mono
+        if len(data.shape) > 1:
+            data = np.mean(data, axis=1)
+
+        # Resample to 16kHz
+        if sr != 16000:
+            data = torchaudio.functional.resample(torch.tensor(data), sr, 16000).numpy()
+            sr = 16000
+
+        # Normalize for Wav2Vec2
+        data = data / 32768.0
+
+        # Feature extraction
+        inputs = feature_extractor(
+            data,
+            sampling_rate=sr,
+            return_tensors="pt",
+            padding=True
+        )
+
+        # Move tensors to device
+        for k in inputs:
+            inputs[k] = inputs[k].to(DEVICE)
+
+        # Forward pass
+        with torch.no_grad():
+            logits = model(**inputs).logits
+            probs = torch.nn.functional.softmax(logits, dim=-1)[0].cpu().numpy()
+
+        # Format top 3 results with emojis
+        top_idx = np.argsort(probs)[::-1][:3]
+        result = {f"{LABELS[i]} {EMOJIS.get(LABELS[i], '')}": round(float(probs[i]), 4) for i in top_idx}
+
+        # Generate waveform plot
+        fig, ax = plt.subplots(figsize=(6,2))
+        ax.plot(data, color='purple')
+        ax.set_title("Audio Waveform")
+        ax.set_xlabel("Samples")
+        ax.set_ylabel("Amplitude")
+        ax.set_xticks([])
+        ax.set_yticks([])
+        plt.tight_layout()
+
+        return result, fig
+
+    except Exception as e:
+        return {"Error": str(e)}, None
 
 # =========================
-# GRADIO INTERFACE
+# GRADIO APP
 # =========================
 demo = gr.Interface(
     fn=predict,
-    inputs=gr.Audio(sources=["upload", "microphone"], type="numpy", label="Upload or Record Audio"),
-    outputs=gr.Label(num_top_classes=3),
+    inputs=gr.Audio(sources=["upload", "microphone"], type="numpy", label="🎤 Upload or Record Audio"),
+    outputs=[gr.Label(num_top_classes=3), gr.Plot()],
     title="Audio Emotion Detection 🎧",
     description=(
-        "Wav2Vec2 emotion classification model. "
-        "Supports 7 emotions: Angry, Disgust, Fear, Happy, Neutral, Sad, and Surprised. "
-        "Upload audio or use your microphone."
+        "Fine-tuned Wav2Vec2 model (`Hatman/audio-emotion-detection`) "
+        "for emotion recognition from voice. "
+        "Detects: Angry, Disgusted, Fearful, Happy, Neutral, Sad, Surprised. "
+        "Audio is auto-resampled to 16kHz."
     ),
     allow_flagging="never",
 )
 
 # =========================
-# LAUNCH APP
+# LAUNCH
 # =========================
 if __name__ == "__main__":
     demo.launch()
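
The updated predict now returns a (scores, waveform figure) pair instead of a plain label dict, so callers of the old version need adjusting. A minimal smoke-test sketch, assuming app.py is importable from the working directory and the model weights can be fetched from the Hub; the file name, tone parameters, and example output labels below are illustrative, not part of this commit:

# smoke_test.py — hypothetical helper, not part of this commit.
import numpy as np

from app import predict  # importing app.py loads the model (downloads weights on first run)

# One second of a 440 Hz tone at 44.1 kHz as int16, mimicking the
# (sample_rate, array) tuple that gr.Audio(type="numpy") passes in.
sr = 44100
t = np.linspace(0.0, 1.0, sr, endpoint=False)
tone = (0.5 * np.sin(2 * np.pi * 440.0 * t) * 32767).astype(np.int16)

scores, fig = predict((sr, tone))
print(scores)  # top-3 label/probability pairs, e.g. {"neutral 😐": 0.41, ...}

Gradio delivers type="numpy" audio as 16-bit integer samples, which is why the sketch scales the sine to ±32767 before casting, matching the data / 32768.0 normalization in the new predict.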