Raemih committed on
Commit
cf64064
·
verified ·
1 Parent(s): ec48379

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -57
app.py CHANGED
@@ -1,68 +1,48 @@
1
  import gradio as gr
2
  import torch
3
  import librosa
4
- from transformers import AutoFeatureExtractor
5
- from model import MMSForMultilingualSER
6
 
7
- MODEL_ID = "E-motionAssistant/mms-300m-multilingual-ser"
 
 
 
8
 
9
- emotion_labels = [
10
- "neutral",
11
- "happy",
12
- "sad",
13
- "anger",
14
- "fear"
15
- ]
16
-
17
- device = "cpu"
18
-
19
- print("Loading model...")
20
-
21
- feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID)
22
-
23
- emotion_model = MMSForMultilingualSER.from_pretrained(
24
- MODEL_ID,
25
- ignore_mismatched_sizes=True
26
- )
27
-
28
- emotion_model.eval()
29
-
30
- print("Model loaded")
31
-
32
-
33
- def detect_emotion(audio):
34
 
 
 
35
  speech, sr = librosa.load(audio, sr=16000)
36
 
37
- inputs = feature_extractor(
38
- speech,
39
- sampling_rate=16000,
40
- return_tensors="pt"
41
- )
42
 
 
43
  with torch.no_grad():
44
- logits = emotion_model(**inputs)
45
-
46
- pred = torch.argmax(logits, dim=-1).item()
47
-
48
- return emotion_labels[pred]
49
-
50
-
51
- with gr.Blocks() as demo:
52
-
53
- gr.Markdown("# Emotion Regulation Assistant")
54
-
55
- with gr.Tab("Emotion Detection"):
56
-
57
- audio_input = gr.Audio(type="filepath")
58
- output = gr.Textbox(label="Detected Emotion")
59
-
60
- btn = gr.Button("Detect Emotion")
61
-
62
- btn.click(
63
- fn=detect_emotion,
64
- inputs=audio_input,
65
- outputs=output
66
- )
67
 
68
- demo.launch()
 
 
1
  import gradio as gr
2
  import torch
3
  import librosa
4
+ from transformers import Wav2Vec2FeatureExtractor, HubertForSequenceClassification
 
5
 
6
# Load model and processor once at import time.
# The checkpoint "superb/hubert-base-superb-er" is a HuBERT classifier
# head for speech emotion recognition; weights are downloaded (or read
# from the local cache) by transformers on first use.
# NOTE(review): loading at module scope means the Space blocks on the
# download before the UI starts — acceptable for a demo app.
model_id = "superb/hubert-base-superb-er"
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_id)
model = HubertForSequenceClassification.from_pretrained(model_id)
10
 
11
def predict_emotion(audio):
    """Classify the emotion expressed in an audio clip.

    Parameters
    ----------
    audio : str | None
        Path to the audio file (Gradio passes the temp-file path for an
        ``gr.Audio(type="filepath")`` input), or ``None`` when the user
        submitted without uploading/recording anything.

    Returns
    -------
    dict[str, float] | str
        Mapping of emotion label -> probability, which ``gr.Label``
        renders as a ranked confidence bar chart; or a plain instruction
        string when no audio was provided (``gr.Label`` shows it as-is).
    """
    # Guard clause for an empty submission.
    if audio is None:
        return "Please upload an audio file."

    # Load and resample to the 16 kHz rate the feature extractor expects.
    # The returned sample rate is forced to 16000, so it is discarded.
    speech, _ = librosa.load(audio, sr=16000)

    # Preprocess the waveform into model-ready tensors.
    inputs = feature_extractor(
        speech, sampling_rate=16000, return_tensors="pt", padding=True
    )

    # Inference without gradient tracking.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Softmax over the class dimension to turn logits into probabilities.
    probs = torch.nn.functional.softmax(logits, dim=-1)

    # Label order follows the checkpoint's id2label:
    # 0: neu, 1: hap, 2: ang, 3: sad.
    labels = ["Neutral", "Happy", "Angry", "Sad"]

    # zip pairs each label with its probability for the single batch item.
    return dict(zip(labels, probs[0].tolist()))
35
+
36
# Assemble the Gradio UI: a single audio input feeding the classifier,
# with a Label component to display the predicted emotion probabilities.
audio_in = gr.Audio(type="filepath", label="Upload Audio or Record")
emotion_out = gr.Label(label="Detected Emotion")

demo = gr.Interface(
    fn=predict_emotion,
    inputs=audio_in,
    outputs=emotion_out,
    title="HuBERT Emotion Recognition",
    description="Upload an audio clip to detect the primary emotion. This model (hubert-base-superb-er) is fine-tuned for Neutral, Happy, Angry, and Sad classifications.",
    examples=[],  # You can add paths to example .wav files here
    theme="soft",
)

if __name__ == "__main__":
    demo.launch()