Pant0x committed
Commit e7edbfd · verified · 1 parent: 4f15c75

Update app.py

Files changed (1): app.py (+71, -71)
app.py CHANGED
@@ -1,71 +1,71 @@
- import gradio as gr
- import torch
- from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor
- import numpy as np
- import torchaudio
-
- # =========================
- # CONFIG
- # =========================
- MODEL_NAME = "your-username/Audio-Emotion-Detection"  # <- replace with your repo name
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-
- # =========================
- # LOAD MODEL & PROCESSOR
- # =========================
- processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
- model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_NAME).to(DEVICE)
-
- # Emotion labels in same order used during training
- LABELS = ["Angry", "Disgusted", "Fearful", "Happy", "Neutral", "Sad", "Surprised"]
-
- # =========================
- # PREDICTION PIPELINE
- # =========================
- def predict(audio):
-     # audio: tuple (sample_rate, numpy array)
-     sr, data = audio
-
-     # Resample to 16k if necessary
-     if sr != 16000:
-         data = torchaudio.functional.resample(torch.tensor(data), sr, 16000).numpy()
-         sr = 16000
-
-     # Process input
-     inputs = processor(
-         data,
-         sampling_rate=sr,
-         return_tensors="pt",
-         padding=True,
-         truncation=True
-     ).to(DEVICE)
-
-     # Forward pass
-     with torch.no_grad():
-         logits = model(**inputs).logits
-     probs = torch.nn.functional.softmax(logits, dim=-1)[0]
-     pred_idx = torch.argmax(probs).item()
-
-     return {LABELS[i]: float(probs[i]) for i in range(len(LABELS))}
-
- # =========================
- # GRADIO INTERFACE
- # =========================
- demo = gr.Interface(
-     fn=predict,
-     inputs=gr.Audio(sources=["upload", "microphone"], type="numpy", label="Upload or Record Audio"),
-     outputs=gr.Label(num_top_classes=3),
-     title="Audio Emotion Detection 🎧",
-     description=(
-         "Fine-tuned Wav2Vec2 model for detecting emotions from voice. "
-         "Supports 7 emotions: Angry, Disgusted, Fearful, Happy, Neutral, Sad, and Surprised. "
-         "All audio should be 16kHz."
-     ),
-     allow_flagging="never",
- )
-
- # =========================
- # LAUNCH APP
- # =========================
- if __name__ == "__main__":
-     demo.launch()
 
+ import gradio as gr
+ import torch
+ from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor
+ import numpy as np
+ import torchaudio
+
+ # =========================
+ # CONFIG
+ # =========================
+ MODEL_NAME = "Hatman/audio-emotion-detection"
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # =========================
+ # LOAD MODEL & PROCESSOR
+ # =========================
+ processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
+ model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_NAME).to(DEVICE)
+
+ # Emotion labels in same order used during training (matches the model card)
+ LABELS = ["angry", "disgust", "fear", "happy", "neutral", "sad", "surprised"]
+
+ # =========================
+ # PREDICTION PIPELINE
+ # =========================
+ def predict(audio):
+     # audio: tuple (sample_rate, numpy array)
+     sr, data = audio
+
+     # Resample to 16k if needed
+     if sr != 16000:
+         data = torchaudio.functional.resample(torch.tensor(data), sr, 16000).numpy()
+         sr = 16000
+
+     # Process input
+     inputs = processor(
+         data,
+         sampling_rate=sr,
+         return_tensors="pt",
+         padding=True,
+         truncation=True
+     ).to(DEVICE)
+
+     # Forward pass
+     with torch.no_grad():
+         logits = model(**inputs).logits
+     probs = torch.nn.functional.softmax(logits, dim=-1)[0]
+     pred_idx = torch.argmax(probs).item()
+
+     return {LABELS[i]: float(probs[i]) for i in range(len(LABELS))}
+
+ # =========================
+ # GRADIO INTERFACE
+ # =========================
+ demo = gr.Interface(
+     fn=predict,
+     inputs=gr.Audio(sources=["upload", "microphone"], type="numpy", label="Upload or Record Audio"),
+     outputs=gr.Label(num_top_classes=3),
+     title="Audio Emotion Detection 🎧",
+     description=(
+         "Wav2Vec2 emotion classification model. "
+         "Supports 7 emotions: Angry, Disgust, Fear, Happy, Neutral, Sad, and Surprised. "
+         "Upload audio or use your microphone."
+     ),
+     allow_flagging="never",
+ )
+
+ # =========================
+ # LAUNCH APP
+ # =========================
+ if __name__ == "__main__":
+     demo.launch()
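
A practical note on the input handling in `predict` (an observation about the code above, not part of the commit): Gradio's `type="numpy"` audio component typically hands back 16-bit integer PCM, and recordings can be 2-D (samples × channels) for stereo, while `torchaudio.functional.resample` expects a floating-point tensor. Below is a minimal sketch of a normalization step that could run before the processor call; `prepare_audio` is a hypothetical helper name, and the int16/stereo assumptions about Gradio's output are noted in the comments.

import numpy as np
import torch
import torchaudio

TARGET_SR = 16000  # Wav2Vec2 checkpoints expect 16 kHz mono input

def prepare_audio(sr, data):
    # Scale integer PCM (Gradio's usual int16 output) to [-1.0, 1.0] floats;
    # torchaudio.functional.resample needs a floating-point tensor.
    if np.issubdtype(data.dtype, np.integer):
        data = data.astype(np.float32) / np.iinfo(data.dtype).max
    else:
        data = data.astype(np.float32)

    # Downmix stereo (assumed shape: samples x channels) to mono.
    if data.ndim == 2:
        data = data.mean(axis=1)

    # Resample to the model's expected rate.
    if sr != TARGET_SR:
        data = torchaudio.functional.resample(torch.from_numpy(data), sr, TARGET_SR).numpy()
    return TARGET_SR, data

With that in place, `predict` could start with `sr, data = prepare_audio(sr, data)` right after unpacking the tuple, which would also make the existing resample branch redundant.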