Pant0x committed
Commit 63e9297 · verified · 1 Parent(s): e351a45

Update app.py

Files changed (1): app.py +42 -29
app.py CHANGED
@@ -2,7 +2,8 @@ import gradio as gr
 import torch
 import torchaudio
 import numpy as np
-from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
+from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor, AutoConfig
+import matplotlib.pyplot as plt

 # =========================
 # CONFIG
@@ -13,14 +14,24 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 # =========================
 # LOAD MODEL & FEATURE EXTRACTOR
 # =========================
-print(f"Loading model: {MODEL_NAME}")
 feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME)
 model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_NAME).to(DEVICE)
 model.eval()
-print("Model loaded successfully.")

-# Emotion labels in correct order
-LABELS = ["Angry", "Disgusted", "Fearful", "Happy", "Neutral", "Sad", "Surprised"]
+# Use the model's label mapping directly
+config = AutoConfig.from_pretrained(MODEL_NAME)
+LABELS = [config.id2label[i] for i in range(len(config.id2label))]
+
+# Map some emojis to each emotion for fun UI
+EMOJIS = {
+    "Angry": "😡",
+    "Disgusted": "🤢",
+    "Fearful": "😨",
+    "Happy": "😄",
+    "Neutral": "😐",
+    "Sad": "😢",
+    "Surprised": "😲"
+}

 # =========================
 # PREDICTION FUNCTION
@@ -28,31 +39,22 @@ LABELS = ["Angry", "Disgusted", "Fearful", "Happy", "Neutral", "Sad", "Surprised"]
 def predict(audio):
     try:
         if audio is None:
-            return {"Error": "No audio provided"}
+            return {"Error": "No audio provided"}, None

         sr, data = audio
-        if isinstance(data, list):
-            data = np.array(data, dtype=np.float32)
+        data = np.array(data, dtype=np.float32)

-        # Convert stereo -> mono
+        # Stereo -> Mono
         if len(data.shape) > 1:
             data = np.mean(data, axis=1)

-        # Resample if needed
+        # Resample to 16kHz
         if sr != 16000:
-            waveform = torch.tensor(data, dtype=torch.float32)
-            data = torchaudio.functional.resample(waveform, sr, 16000).numpy()
+            data = torchaudio.functional.resample(torch.tensor(data), sr, 16000).numpy()
             sr = 16000

-        # Normalize
-        if np.abs(data).max() > 0:
-            data = data / np.abs(data).max()
-
-        # Make sure dtype and shape are clean
-        data = np.array(data, dtype=np.float32).flatten()
-
-        # Debug info
-        print(f"Sample rate: {sr}, Data shape: {data.shape}, Device: {DEVICE}")
+        # Normalize for Wav2Vec2
+        data = data / 32768.0

         # Feature extraction
         inputs = feature_extractor(
@@ -62,7 +64,7 @@ def predict(audio):
             padding=True
         )

-        # Move to device
+        # Move tensors to device
         for k in inputs:
             inputs[k] = inputs[k].to(DEVICE)

@@ -71,13 +73,24 @@ def predict(audio):
         logits = model(**inputs).logits
         probs = torch.nn.functional.softmax(logits, dim=-1)[0].cpu().numpy()

-        result = {LABELS[i]: round(float(probs[i]), 4) for i in range(len(LABELS))}
-        print(f"Predicted: {result}")
-        return result
+        # Format top 3 results with emojis
+        top_idx = np.argsort(probs)[::-1][:3]
+        result = {f"{LABELS[i]} {EMOJIS.get(LABELS[i], '')}": round(float(probs[i]), 4) for i in top_idx}
+
+        # Generate waveform plot
+        fig, ax = plt.subplots(figsize=(6,2))
+        ax.plot(data, color='purple')
+        ax.set_title("Audio Waveform")
+        ax.set_xlabel("Samples")
+        ax.set_ylabel("Amplitude")
+        ax.set_xticks([])
+        ax.set_yticks([])
+        plt.tight_layout()
+
+        return result, fig

     except Exception as e:
-        print(f"ERROR: {str(e)}")
-        return {"Error": str(e)}
+        return {"Error": str(e)}, None

 # =========================
 # GRADIO APP
@@ -85,12 +98,12 @@ def predict(audio):
 demo = gr.Interface(
     fn=predict,
     inputs=gr.Audio(sources=["upload", "microphone"], type="numpy", label="🎤 Upload or Record Audio"),
-    outputs=gr.Label(num_top_classes=3),
+    outputs=[gr.Label(num_top_classes=3), gr.Plot()],
     title="Audio Emotion Detection 🎧",
     description=(
         "Fine-tuned Wav2Vec2 model (`Hatman/audio-emotion-detection`) "
         "for emotion recognition from voice. "
-        "Detects: Angry, Disgusted, Fearful, Happy, Neutral, Sad, and Surprised. "
+        "Detects: Angry, Disgusted, Fearful, Happy, Neutral, Sad, Surprised. "
         "Audio is auto-resampled to 16kHz."
     ),
     allow_flagging="never",
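
A note on the label change: deriving LABELS from the checkpoint's config removes the risk of a hard-coded list drifting out of sync with the logit order. The mapping can be inspected on its own; a minimal sketch (it downloads only the small config file, and the exact label strings are whatever the repo defines):

    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("Hatman/audio-emotion-detection")
    print(config.id2label)  # label strings come from the checkpoint, e.g. {0: "Angry", ...}

If the checkpoint's label strings do not exactly match the EMOJIS keys (case included), `EMOJIS.get(LABELS[i], '')` falls back to an empty string, so the UI degrades gracefully instead of raising a KeyError.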
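The new normalization, `data / 32768.0`, assumes int16-range samples, which is what Gradio's `type="numpy"` audio component typically delivers. If a source ever hands the app float samples already in [-1, 1], that division would shrink the signal by a factor of ~32768. A dtype-aware alternative is sketched below (a hypothetical helper, not part of the commit); note it would have to run on the raw array from Gradio, since app.py casts to float32 before this step:

    import numpy as np

    def normalize_audio(data: np.ndarray) -> np.ndarray:
        """Scale audio to float32 in [-1, 1] based on its dtype."""
        if np.issubdtype(data.dtype, np.integer):
            # Integer PCM (e.g. int16 from the mic): divide by the type's max
            return data.astype(np.float32) / np.iinfo(data.dtype).max
        # Float input is assumed normalized already; clip for safety
        return np.clip(data.astype(np.float32), -1.0, 1.0)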
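The resampling path is easy to sanity-check in isolation. A minimal sketch with a synthetic tone, assuming only torch and torchaudio are installed: one second at 44.1 kHz should come back as roughly 16 000 samples.

    import torch
    import torchaudio

    sr_in, sr_out = 44100, 16000
    t = torch.arange(sr_in, dtype=torch.float32) / sr_in  # one second
    tone = torch.sin(2 * torch.pi * 440.0 * t)            # 440 Hz sine

    out = torchaudio.functional.resample(tone, sr_in, sr_out)
    print(out.shape)  # torch.Size([16000])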
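Since predict now returns a (labels, figure) pair, the whole pipeline can be smoke-tested without launching the UI. A hypothetical check, assuming the file is importable as `app` and the checkpoint downloads successfully:

    import numpy as np
    from app import predict  # assumes this file is saved as app.py

    # One second of int16 noise at 48 kHz, mimicking Gradio's (sr, data) tuple
    sr = 48000
    data = (np.random.randn(sr) * 3000).astype(np.int16)

    labels, fig = predict((sr, data))
    print(labels)            # top-3 dict, e.g. {"Neutral 😐": 0.41, ...}
    print(fig is not None)   # True on success; None signals the error path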