Pant0x committed
Commit 9ee8bab · verified · 1 Parent(s): d961cfc

Update app.py

Files changed (1)
  1. app.py +24 -51
app.py CHANGED
@@ -1,8 +1,7 @@
 import gradio as gr
 import torch
-import torchaudio
-import numpy as np
 from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
+import torchaudio
 
 # =========================
 # CONFIG
@@ -15,65 +14,39 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 # =========================
 feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME)
 model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_NAME).to(DEVICE)
-model.eval()
 
-# Emotion labels in correct order
+# Emotion labels in model's order
 LABELS = ["Angry", "Disgusted", "Fearful", "Happy", "Neutral", "Sad", "Surprised"]
 
 # =========================
-# PREDICTION FUNCTION
+# PREDICTION PIPELINE
 # =========================
 def predict(audio):
-    try:
-        if audio is None:
-            return {"Error": "No audio provided"}
-
-        sr, data = audio
-
-        # Ensure it's a NumPy array
-        if isinstance(data, list):
-            data = np.array(data)
-
-        # Convert stereo -> mono if needed
-        if len(data.shape) > 1:
-            data = np.mean(data, axis=1)
-
-        # Remove NaNs or infs
-        if np.isnan(data).any() or np.isinf(data).any():
-            data = np.nan_to_num(data)
-
-        # Resample to 16kHz if necessary
-        if sr != 16000:
-            data = torchaudio.functional.resample(
-                torch.tensor(data), sr, 16000
-            ).numpy()
-            sr = 16000
-
-        # Normalize
-        data = data / np.abs(data).max() if np.abs(data).max() > 0 else data
+    sr, data = audio
 
-        # Feature extraction
-        inputs = feature_extractor(
-            data,
-            sampling_rate=sr,
-            return_tensors="pt",
-            padding=True
-        ).to(DEVICE)
+    # Resample to 16kHz if needed
+    if sr != 16000:
+        data = torchaudio.functional.resample(torch.tensor(data), sr, 16000).numpy()
+        sr = 16000
 
-        # Forward pass
-        with torch.no_grad():
-            logits = model(**inputs).logits
-            probs = torch.nn.functional.softmax(logits, dim=-1)[0].cpu().numpy()
+    # Extract features
+    inputs = feature_extractor(
+        data,
+        sampling_rate=sr,
+        return_tensors="pt",
+        padding=True
+    ).to(DEVICE)
 
-        # Format output
-        result = {LABELS[i]: float(probs[i]) for i in range(len(LABELS))}
-        return result
+    # Forward pass
+    with torch.no_grad():
+        logits = model(**inputs).logits
+        probs = torch.nn.functional.softmax(logits, dim=-1)[0]
+        pred_idx = torch.argmax(probs).item()
 
-    except Exception as e:
-        return {"Error": str(e)}
+    return {LABELS[i]: float(probs[i]) for i in range(len(LABELS))}
 
 # =========================
-# GRADIO APP
+# GRADIO INTERFACE
 # =========================
 demo = gr.Interface(
     fn=predict,
@@ -84,13 +57,13 @@ demo = gr.Interface(
         "Fine-tuned Wav2Vec2 model (`Hatman/audio-emotion-detection`) "
        "for emotion recognition from voice. "
        "Detects: Angry, Disgusted, Fearful, Happy, Neutral, Sad, and Surprised. "
-        "Audio must be 16kHz (auto-resampled)."
+        "Audio should be 16kHz for best accuracy."
     ),
     allow_flagging="never",
 )
 
 # =========================
-# LAUNCH
+# LAUNCH APP
 # =========================
 if __name__ == "__main__":
     demo.launch()
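
The slimmed-down predict now relies on Gradio handing it a (sample_rate, samples) tuple and keeps only the torchaudio resample as preprocessing. A minimal sanity check of that path outside the UI (not part of this commit; it assumes app.py is importable as app and substitutes a synthetic float32 waveform for real microphone input):

import numpy as np

from app import predict  # importing app.py also loads the model onto DEVICE

# Synthetic one-second clip at 44.1 kHz, shaped like the tuple gr.Audio delivers;
# float32 samples avoid dtype surprises in torchaudio's resample.
sr = 44100
samples = np.random.uniform(-1.0, 1.0, size=sr).astype(np.float32)

scores = predict((sr, samples))      # dict: each of the 7 emotion labels -> probability
print(max(scores, key=scores.get))   # most likely label for the clip

Note that real Gradio input may arrive as int16 samples, which the removed numpy preprocessing used to tolerate; converting to float before calling predict sidesteps that.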