Pant0x committed on
Commit
afb816c
·
verified ·
1 Parent(s): 63e9297

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -32
app.py CHANGED
@@ -18,7 +18,7 @@ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME)
18
  model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_NAME).to(DEVICE)
19
  model.eval()
20
 
21
- # Use the models label mapping directly
22
  config = AutoConfig.from_pretrained(MODEL_NAME)
23
  LABELS = [config.id2label[i] for i in range(len(config.id2label))]
24
 
@@ -40,55 +40,78 @@ def predict(audio):
40
  try:
41
  if audio is None:
42
  return {"Error": "No audio provided"}, None
43
-
44
  sr, data = audio
45
  data = np.array(data, dtype=np.float32)
46
-
47
  # Stereo -> Mono
48
  if len(data.shape) > 1:
49
  data = np.mean(data, axis=1)
50
-
51
  # Resample to 16kHz
52
  if sr != 16000:
53
  data = torchaudio.functional.resample(torch.tensor(data), sr, 16000).numpy()
54
  sr = 16000
55
-
56
- # Normalize for Wav2Vec2
57
- data = data / 32768.0
58
-
59
- # Feature extraction
 
 
 
 
 
 
 
 
 
 
60
  inputs = feature_extractor(
61
  data,
62
  sampling_rate=sr,
63
  return_tensors="pt",
64
- padding=True
 
 
65
  )
66
-
67
  # Move tensors to device
68
  for k in inputs:
69
  inputs[k] = inputs[k].to(DEVICE)
70
-
71
  # Forward pass
72
  with torch.no_grad():
73
  logits = model(**inputs).logits
 
 
 
 
 
 
74
  probs = torch.nn.functional.softmax(logits, dim=-1)[0].cpu().numpy()
75
-
76
- # Format top 3 results with emojis
77
- top_idx = np.argsort(probs)[::-1][:3]
78
- result = {f"{LABELS[i]} {EMOJIS.get(LABELS[i], '')}": round(float(probs[i]), 4) for i in top_idx}
79
-
 
 
 
 
 
80
  # Generate waveform plot
81
- fig, ax = plt.subplots(figsize=(6,2))
82
- ax.plot(data, color='purple')
83
- ax.set_title("Audio Waveform")
84
- ax.set_xlabel("Samples")
 
85
  ax.set_ylabel("Amplitude")
86
- ax.set_xticks([])
87
- ax.set_yticks([])
88
  plt.tight_layout()
89
-
90
  return result, fig
91
-
92
  except Exception as e:
93
  return {"Error": str(e)}, None
94
 
@@ -98,19 +121,28 @@ def predict(audio):
98
  demo = gr.Interface(
99
  fn=predict,
100
  inputs=gr.Audio(sources=["upload", "microphone"], type="numpy", label="🎤 Upload or Record Audio"),
101
- outputs=[gr.Label(num_top_classes=3), gr.Plot()],
102
- title="Audio Emotion Detection 🎧",
 
 
 
103
  description=(
104
- "Fine-tuned Wav2Vec2 model (`Hatman/audio-emotion-detection`) "
105
- "for emotion recognition from voice. "
106
- "Detects: Angry, Disgusted, Fearful, Happy, Neutral, Sad, Surprised. "
107
- "Audio is auto-resampled to 16kHz."
 
 
 
 
108
  ),
 
109
  allow_flagging="never",
 
110
  )
111
 
112
  # =========================
113
  # LAUNCH
114
  # =========================
115
  if __name__ == "__main__":
116
- demo.launch()
 
18
  model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_NAME).to(DEVICE)
19
  model.eval()
20
 
21
+ # Use the model's label mapping directly
22
  config = AutoConfig.from_pretrained(MODEL_NAME)
23
  LABELS = [config.id2label[i] for i in range(len(config.id2label))]
24
 
 
40
  try:
41
  if audio is None:
42
  return {"Error": "No audio provided"}, None
43
+
44
  sr, data = audio
45
  data = np.array(data, dtype=np.float32)
46
+
47
  # Stereo -> Mono
48
  if len(data.shape) > 1:
49
  data = np.mean(data, axis=1)
50
+
51
  # Resample to 16kHz
52
  if sr != 16000:
53
  data = torchaudio.functional.resample(torch.tensor(data), sr, 16000).numpy()
54
  sr = 16000
55
+
56
+ # Improved normalization - normalize to [-1, 1] range
57
+ # Check if data is in int16 range or already normalized
58
+ if np.abs(data).max() > 1.0:
59
+ data = data / np.abs(data).max() # Normalize by max value
60
+
61
+ # Apply gentle audio preprocessing to improve feature extraction
62
+ # Remove DC offset
63
+ data = data - np.mean(data)
64
+
65
+ # Apply light pre-emphasis filter to balance frequencies
66
+ pre_emphasis = 0.97
67
+ data = np.append(data[0], data[1:] - pre_emphasis * data[:-1])
68
+
69
+ # Feature extraction with proper padding
70
  inputs = feature_extractor(
71
  data,
72
  sampling_rate=sr,
73
  return_tensors="pt",
74
+ padding=True,
75
+ max_length=16000 * 10, # Max 10 seconds
76
+ truncation=True
77
  )
78
+
79
  # Move tensors to device
80
  for k in inputs:
81
  inputs[k] = inputs[k].to(DEVICE)
82
+
83
  # Forward pass
84
  with torch.no_grad():
85
  logits = model(**inputs).logits
86
+
87
+ # Apply temperature scaling to reduce overconfidence
88
+ # Lower temperature = more uniform distribution
89
+ temperature = 1.5
90
+ logits = logits / temperature
91
+
92
  probs = torch.nn.functional.softmax(logits, dim=-1)[0].cpu().numpy()
93
+
94
+ # Show ALL emotions with their scores (not just top 3)
95
+ result = {}
96
+ for i, label in enumerate(LABELS):
97
+ emoji = EMOJIS.get(label, '')
98
+ result[f"{label} {emoji}"] = round(float(probs[i]), 4)
99
+
100
+ # Sort by probability
101
+ result = dict(sorted(result.items(), key=lambda x: x[1], reverse=True))
102
+
103
  # Generate waveform plot
104
+ fig, ax = plt.subplots(figsize=(8, 3))
105
+ time_axis = np.linspace(0, len(data) / sr, len(data))
106
+ ax.plot(time_axis, data, color='purple', linewidth=0.5)
107
+ ax.set_title("Audio Waveform", fontsize=12, fontweight='bold')
108
+ ax.set_xlabel("Time (seconds)")
109
  ax.set_ylabel("Amplitude")
110
+ ax.grid(True, alpha=0.3)
 
111
  plt.tight_layout()
112
+
113
  return result, fig
114
+
115
  except Exception as e:
116
  return {"Error": str(e)}, None
117
 
 
121
  demo = gr.Interface(
122
  fn=predict,
123
  inputs=gr.Audio(sources=["upload", "microphone"], type="numpy", label="🎤 Upload or Record Audio"),
124
+ outputs=[
125
+ gr.Label(num_top_classes=7, label="Emotion Probabilities"),
126
+ gr.Plot(label="Waveform Visualization")
127
+ ],
128
+ title="🎧 Audio Emotion Detection",
129
  description=(
130
+ "Fine-tuned Wav2Vec2 model for emotion recognition from voice. "
131
+ "Detects: **Angry, Disgusted, Fearful, Happy, Neutral, Sad, Surprised**.\n\n"
132
+ "**Tips for better results:**\n"
133
+ "- Speak clearly and naturally\n"
134
+ "- Record at least 2-3 seconds of audio\n"
135
+ "- Avoid background noise\n"
136
+ "- Try exaggerating emotions for testing\n\n"
137
+ "Audio is automatically resampled to 16kHz and normalized."
138
  ),
139
+ examples=[],
140
  allow_flagging="never",
141
+ theme=gr.themes.Soft()
142
  )
143
 
144
  # =========================
145
  # LAUNCH
146
  # =========================
147
  if __name__ == "__main__":
148
+ demo.launch()