Pant0x committed · verified
Commit d961cfc · 1 Parent(s): 1c88dc7

Update app.py

Files changed (1): app.py (+51 −24)
app.py CHANGED
@@ -1,7 +1,8 @@
 import gradio as gr
 import torch
-from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
 import torchaudio
+import numpy as np
+from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
 
 # =========================
 # CONFIG
@@ -14,39 +15,65 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 # =========================
 feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME)
 model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_NAME).to(DEVICE)
+model.eval()
 
-# Emotion labels in model's order
+# Emotion labels in correct order
 LABELS = ["Angry", "Disgusted", "Fearful", "Happy", "Neutral", "Sad", "Surprised"]
 
 # =========================
-# PREDICTION PIPELINE
+# PREDICTION FUNCTION
 # =========================
 def predict(audio):
-    sr, data = audio
+    try:
+        if audio is None:
+            return {"Error": "No audio provided"}
+
+        sr, data = audio
+
+        # Ensure it's a NumPy array
+        if isinstance(data, list):
+            data = np.array(data)
+
+        # Convert stereo -> mono if needed
+        if len(data.shape) > 1:
+            data = np.mean(data, axis=1)
+
+        # Remove NaNs or infs
+        if np.isnan(data).any() or np.isinf(data).any():
+            data = np.nan_to_num(data)
+
+        # Resample to 16kHz if necessary
+        if sr != 16000:
+            data = torchaudio.functional.resample(
+                torch.tensor(data), sr, 16000
+            ).numpy()
+            sr = 16000
+
+        # Normalize
+        data = data / np.abs(data).max() if np.abs(data).max() > 0 else data
 
-    # Resample to 16kHz if needed
-    if sr != 16000:
-        data = torchaudio.functional.resample(torch.tensor(data), sr, 16000).numpy()
-        sr = 16000
+        # Feature extraction
+        inputs = feature_extractor(
+            data,
+            sampling_rate=sr,
+            return_tensors="pt",
+            padding=True
+        ).to(DEVICE)
 
-    # Extract features
-    inputs = feature_extractor(
-        data,
-        sampling_rate=sr,
-        return_tensors="pt",
-        padding=True
-    ).to(DEVICE)
+        # Forward pass
+        with torch.no_grad():
+            logits = model(**inputs).logits
+        probs = torch.nn.functional.softmax(logits, dim=-1)[0].cpu().numpy()
 
-    # Forward pass
-    with torch.no_grad():
-        logits = model(**inputs).logits
-    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
-    pred_idx = torch.argmax(probs).item()
+        # Format output
+        result = {LABELS[i]: float(probs[i]) for i in range(len(LABELS))}
+        return result
 
-    return {LABELS[i]: float(probs[i]) for i in range(len(LABELS))}
+    except Exception as e:
+        return {"Error": str(e)}
 
 # =========================
-# GRADIO INTERFACE
+# GRADIO APP
 # =========================
 demo = gr.Interface(
     fn=predict,
@@ -57,13 +84,13 @@ demo = gr.Interface(
         "Fine-tuned Wav2Vec2 model (`Hatman/audio-emotion-detection`) "
         "for emotion recognition from voice. "
         "Detects: Angry, Disgusted, Fearful, Happy, Neutral, Sad, and Surprised. "
-        "Audio should be 16kHz for best accuracy."
+        "Audio must be 16kHz (auto-resampled)."
     ),
     allow_flagging="never",
 )
 
 # =========================
-# LAUNCH APP
+# LAUNCH
 # =========================
 if __name__ == "__main__":
     demo.launch()
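
The new preprocessing chain can be sanity-checked in isolation. A minimal sketch (not part of the commit; the synthetic signal and the float32 cast are assumptions) feeds a 1-second, 44.1 kHz stereo sine with an injected NaN through the same steps and checks that the result is mono, finite, 16 kHz, and peak-normalized:

# Standalone sanity check for the new preprocessing path (sketch, not from
# the commit). The synthetic input and float32 cast are assumptions.
import numpy as np
import torch
import torchaudio

sr = 44100
t = np.linspace(0, 1.0, sr, endpoint=False)
data = np.stack([np.sin(2 * np.pi * 220 * t),
                 np.sin(2 * np.pi * 330 * t)], axis=1)  # stereo, shape (N, 2)
data[100, 0] = np.nan                                   # corrupt one sample

# Same steps as predict(): mono, de-NaN, resample, peak-normalize
data = np.mean(data, axis=1)
data = np.nan_to_num(data)
data = torchaudio.functional.resample(
    torch.tensor(data, dtype=torch.float32), sr, 16000
).numpy()
data = data / np.abs(data).max() if np.abs(data).max() > 0 else data

assert data.ndim == 1 and len(data) == 16000   # 1 s at 16 kHz
assert np.isfinite(data).all() and np.abs(data).max() <= 1.0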
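
Since a numpy-type Gradio Audio input hands the function a (sample_rate, ndarray) tuple, the updated predict() can also be smoke-tested without launching the UI. A sketch, assuming app.py is importable from the working directory and the Hatman/audio-emotion-detection weights download successfully:

# Smoke test for the updated predict() (sketch; assumes app.py is on the
# path and the model weights can be fetched).
import numpy as np
from app import predict

sr = 48000
t = np.linspace(0, 2.0, 2 * sr, endpoint=False)
stereo = np.stack([np.sin(2 * np.pi * 200 * t),
                   np.sin(2 * np.pi * 300 * t)], axis=1).astype(np.float32)

scores = predict((sr, stereo))   # Gradio passes (sample_rate, samples)
if "Error" in scores:
    print("predict() failed:", scores["Error"])
else:
    print(max(scores, key=scores.get), scores)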