Pant0x committed on
Commit
1c88dc7
·
verified ·
1 Parent(s): 257f2a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -12
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor
4
  import torchaudio
5
 
6
  # =========================
@@ -10,12 +10,12 @@ MODEL_NAME = "Hatman/audio-emotion-detection"
10
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
 
12
  # =========================
13
- # LOAD MODEL & PROCESSOR
14
  # =========================
15
- processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
16
  model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_NAME).to(DEVICE)
17
 
18
- # Emotion labels (must match model training order)
19
  LABELS = ["Angry", "Disgusted", "Fearful", "Happy", "Neutral", "Sad", "Surprised"]
20
 
21
  # =========================
@@ -24,18 +24,17 @@ LABELS = ["Angry", "Disgusted", "Fearful", "Happy", "Neutral", "Sad", "Surprised
24
  def predict(audio):
25
  sr, data = audio
26
 
27
- # Resample to 16kHz if necessary
28
  if sr != 16000:
29
  data = torchaudio.functional.resample(torch.tensor(data), sr, 16000).numpy()
30
  sr = 16000
31
 
32
- # Prepare input
33
- inputs = processor(
34
  data,
35
  sampling_rate=sr,
36
  return_tensors="pt",
37
- padding=True,
38
- truncation=True
39
  ).to(DEVICE)
40
 
41
  # Forward pass
@@ -56,9 +55,9 @@ demo = gr.Interface(
56
  title="Audio Emotion Detection 🎧",
57
  description=(
58
  "Fine-tuned Wav2Vec2 model (`Hatman/audio-emotion-detection`) "
59
- "for emotion recognition in voice. "
60
- "Predicts: Angry, Disgusted, Fearful, Happy, Neutral, Sad, and Surprised. "
61
- "Audio must be 16kHz."
62
  ),
63
  allow_flagging="never",
64
  )
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
4
  import torchaudio
5
 
6
  # =========================
 
10
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
 
12
  # =========================
13
+ # LOAD MODEL & FEATURE EXTRACTOR
14
  # =========================
15
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME)
16
  model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_NAME).to(DEVICE)
17
 
18
+ # Emotion labels in model's order
19
  LABELS = ["Angry", "Disgusted", "Fearful", "Happy", "Neutral", "Sad", "Surprised"]
20
 
21
  # =========================
 
24
  def predict(audio):
25
  sr, data = audio
26
 
27
+ # Resample to 16kHz if needed
28
  if sr != 16000:
29
  data = torchaudio.functional.resample(torch.tensor(data), sr, 16000).numpy()
30
  sr = 16000
31
 
32
+ # Extract features
33
+ inputs = feature_extractor(
34
  data,
35
  sampling_rate=sr,
36
  return_tensors="pt",
37
+ padding=True
 
38
  ).to(DEVICE)
39
 
40
  # Forward pass
 
55
  title="Audio Emotion Detection 🎧",
56
  description=(
57
  "Fine-tuned Wav2Vec2 model (`Hatman/audio-emotion-detection`) "
58
+ "for emotion recognition from voice. "
59
+ "Detects: Angry, Disgusted, Fearful, Happy, Neutral, Sad, and Surprised. "
60
+ "Audio should be 16kHz for best accuracy."
61
  ),
62
  allow_flagging="never",
63
  )