Pant0x committed
Commit e351a45 · verified · 1 Parent(s): 9ee8bab

Update app.py

Files changed (1)
  1. app.py +58 -24
app.py CHANGED
@@ -1,7 +1,8 @@
 import gradio as gr
 import torch
-from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
 import torchaudio
+import numpy as np
+from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor

 # =========================
 # CONFIG
@@ -12,41 +13,74 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 # =========================
 # LOAD MODEL & FEATURE EXTRACTOR
 # =========================
+print(f"Loading model: {MODEL_NAME}")
 feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME)
 model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_NAME).to(DEVICE)
+model.eval()
+print("Model loaded successfully.")

-# Emotion labels in model's order
+# Emotion labels in correct order
 LABELS = ["Angry", "Disgusted", "Fearful", "Happy", "Neutral", "Sad", "Surprised"]

 # =========================
-# PREDICTION PIPELINE
+# PREDICTION FUNCTION
 # =========================
 def predict(audio):
-    sr, data = audio
+    try:
+        if audio is None:
+            return {"Error": "No audio provided"}
+
+        sr, data = audio
+        if isinstance(data, list):
+            data = np.array(data, dtype=np.float32)
+
+        # Convert stereo -> mono
+        if len(data.shape) > 1:
+            data = np.mean(data, axis=1)
+
+        # Resample if needed
+        if sr != 16000:
+            waveform = torch.tensor(data, dtype=torch.float32)
+            data = torchaudio.functional.resample(waveform, sr, 16000).numpy()
+            sr = 16000
+
+        # Normalize
+        if np.abs(data).max() > 0:
+            data = data / np.abs(data).max()
+
+        # Make sure dtype and shape are clean
+        data = np.array(data, dtype=np.float32).flatten()
+
+        # Debug info
+        print(f"Sample rate: {sr}, Data shape: {data.shape}, Device: {DEVICE}")
+
+        # Feature extraction
+        inputs = feature_extractor(
+            data,
+            sampling_rate=sr,
+            return_tensors="pt",
+            padding=True
+        )

-    # Resample to 16kHz if needed
-    if sr != 16000:
-        data = torchaudio.functional.resample(torch.tensor(data), sr, 16000).numpy()
-        sr = 16000
+        # Move to device
+        for k in inputs:
+            inputs[k] = inputs[k].to(DEVICE)

-    # Extract features
-    inputs = feature_extractor(
-        data,
-        sampling_rate=sr,
-        return_tensors="pt",
-        padding=True
-    ).to(DEVICE)
+        # Forward pass
+        with torch.no_grad():
+            logits = model(**inputs).logits
+        probs = torch.nn.functional.softmax(logits, dim=-1)[0].cpu().numpy()

-    # Forward pass
-    with torch.no_grad():
-        logits = model(**inputs).logits
-    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
-    pred_idx = torch.argmax(probs).item()
+        result = {LABELS[i]: round(float(probs[i]), 4) for i in range(len(LABELS))}
+        print(f"Predicted: {result}")
+        return result

-    return {LABELS[i]: float(probs[i]) for i in range(len(LABELS))}
+    except Exception as e:
+        print(f"ERROR: {str(e)}")
+        return {"Error": str(e)}

 # =========================
-# GRADIO INTERFACE
+# GRADIO APP
 # =========================
 demo = gr.Interface(
     fn=predict,
@@ -57,13 +91,13 @@ demo = gr.Interface(
         "Fine-tuned Wav2Vec2 model (`Hatman/audio-emotion-detection`) "
         "for emotion recognition from voice. "
         "Detects: Angry, Disgusted, Fearful, Happy, Neutral, Sad, and Surprised. "
-        "Audio should be 16kHz for best accuracy."
+        "Audio is auto-resampled to 16kHz."
     ),
     allow_flagging="never",
 )

 # =========================
-# LAUNCH APP
+# LAUNCH
 # =========================
 if __name__ == "__main__":
     demo.launch()
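
To sanity-check the new preprocessing and error handling locally, a minimal smoke test might look like the sketch below. This is hypothetical: it assumes `app.py` is importable from the working directory and, as the code above does, that Gradio delivers audio as a `(sample_rate, numpy_array)` tuple. Importing `app` loads the model, so the first run takes a moment.

    import numpy as np
    from app import predict  # importing app loads the model once

    # 1 second of a 440 Hz tone: stereo at 44.1 kHz, so the call exercises
    # the new stereo->mono downmix, resampling, and normalization branches.
    sr = 44100
    t = np.linspace(0, 1, sr, endpoint=False)
    tone = 0.5 * np.sin(2 * np.pi * 440 * t).astype(np.float32)
    stereo = np.stack([tone, tone], axis=1)  # shape (44100, 2)

    print(predict((sr, stereo)))  # emotion -> probability dict
    print(predict(None))          # should return {"Error": "No audio provided"}

A pure tone will land on an arbitrary label, so real speech clips are needed for meaningful scores, but the call path is identical.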