codetocare commited on
Commit
0ae8940
Β·
verified Β·
1 Parent(s): f80b262

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -26
app.py CHANGED
@@ -1,40 +1,21 @@
1
- import gradio as gr
2
- from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor
3
- import torch
4
- import torchaudio
5
 
6
- # Load the pre-trained model and processor
7
- model_name = "bhadresh-savani/wav2vec2-large-robust-english-emotion"
8
- processor = Wav2Vec2Processor.from_pretrained(model_name)
9
- model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name)
10
-
11
- # Emotion labels for this specific model
12
- labels = ['angry', 'calm', 'happy', 'sad']
13
 
14
  def predict_emotion(audio):
15
- # audio: tuple -> (sample_rate, numpy array)
16
  speech, sr = audio
17
  if sr != 16000:
18
  resampler = torchaudio.transforms.Resample(sr, 16000)
19
- speech = resampler(torch.tensor(speech))
20
  else:
21
- speech = torch.tensor(speech)
22
 
23
  input_values = processor(speech, sampling_rate=16000, return_tensors="pt").input_values
24
  with torch.no_grad():
25
  logits = model(input_values).logits
26
 
27
  predicted_id = torch.argmax(logits, dim=-1).item()
28
- emotion = labels[predicted_id]
29
  return f"Predicted Emotion: **{emotion}**"
30
-
31
- # Gradio interface
32
- interface = gr.Interface(
33
- fn=predict_emotion,
34
- inputs=gr.Audio(source="microphone", type="numpy", label="Record or Upload Speech"),
35
- outputs=gr.Markdown(label="Emotion"),
36
- title="Voice Emotion Recognition",
37
- description="Speak or upload a WAV file to detect the emotion using a fine-tuned Wav2Vec2 model."
38
- )
39
-
40
- interface.launch()
 
1
+ model_name = "Dpngtm/wav2vec2-emotion-recognition"
2
+ processor = Wav2Vec2Processor.from_pretrained(model_name)
3
+ model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name)
 
4
 
5
+ labels = ["angry", "calm", "disgust", "fearful", "happy", "neutral", "sad", "surprised"]
 
 
 
 
 
 
6
 
7
  def predict_emotion(audio):
 
8
  speech, sr = audio
9
  if sr != 16000:
10
  resampler = torchaudio.transforms.Resample(sr, 16000)
11
+ speech = resampler(torch.tensor(speech))
12
  else:
13
+ speech = torch.tensor(speech)
14
 
15
  input_values = processor(speech, sampling_rate=16000, return_tensors="pt").input_values
16
  with torch.no_grad():
17
  logits = model(input_values).logits
18
 
19
  predicted_id = torch.argmax(logits, dim=-1).item()
20
+ emotion = labels[predicted_id]
21
  return f"Predicted Emotion: **{emotion}**"