import gradio as gr
import torch
import torchaudio
import numpy as np
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor, AutoConfig
import matplotlib
matplotlib.use("Agg")  # headless backend; Gradio callbacks run off the main thread
import matplotlib.pyplot as plt
# =========================
# CONFIG
# =========================
MODEL_NAME = "Hatman/audio-emotion-detection"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# =========================
# LOAD MODEL & FEATURE EXTRACTOR
# =========================
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME)
model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()
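# eval() disables dropout, so repeated predictions on the same clip are deterministic.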
# Use the model's label mapping directly
config = AutoConfig.from_pretrained(MODEL_NAME)
LABELS = [config.id2label[i] for i in range(len(config.id2label))]
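# LABELS is ordered by class index, so probs[i] from the model lines up with LABELS[i].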
# Map an emoji to each emotion for a fun UI
EMOJIS = {
    "Angry": "😡",
    "Disgusted": "🤢",
    "Fearful": "😨",
    "Happy": "😄",
    "Neutral": "😐",
    "Sad": "😢",
    "Surprised": "😲",
}
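# predict() looks these up with EMOJIS.get(label, ""), so if the checkpoint's
# label casing ever differs, the emoji is simply dropped rather than erroring.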
# =========================
# PREDICTION FUNCTION
# =========================
def predict(audio):
    try:
        if audio is None:
            return {"Error": "No audio provided"}, None
        sr, data = audio
        # Convert to float32 in [-1, 1]. Gradio's "numpy" audio arrives as int16 PCM,
        # but guard for float input so an already-normalized signal isn't divided again.
        if np.issubdtype(data.dtype, np.integer):
            scale = float(np.iinfo(data.dtype).max)
            data = data.astype(np.float32) / scale
        else:
            data = np.asarray(data, dtype=np.float32)
        # Stereo -> mono: Gradio returns (samples, channels), so average over channels
        if data.ndim > 1:
            data = data.mean(axis=1)
        # Resample to the 16 kHz rate Wav2Vec2 was pretrained on
        if sr != 16000:
            data = torchaudio.functional.resample(torch.from_numpy(data), sr, 16000).numpy()
            sr = 16000
        # Feature extraction
        inputs = feature_extractor(
            data,
            sampling_rate=sr,
            return_tensors="pt",
            padding=True,
        )
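        # The extractor normalizes the waveform (zero mean, unit variance by default)
        # and returns `input_values`, plus an `attention_mask` for some checkpoints;
        # padding=True is a no-op for a single clip but harmless.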
        # Move tensors to device
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        # Forward pass: softmax turns the logits into a probability distribution
        # over the emotion classes; [0] selects the single batch item
        with torch.no_grad():
            logits = model(**inputs).logits
            probs = torch.nn.functional.softmax(logits, dim=-1)[0].cpu().numpy()
        # Format the top-3 results with emojis
        top_idx = np.argsort(probs)[::-1][:3]
        result = {f"{LABELS[i]} {EMOJIS.get(LABELS[i], '')}": round(float(probs[i]), 4) for i in top_idx}

        # Generate waveform plot
        fig, ax = plt.subplots(figsize=(6, 2))
        ax.plot(data, color="purple")
        ax.set_title("Audio Waveform")
        ax.set_xlabel("Samples")
        ax.set_ylabel("Amplitude")
        ax.set_xticks([])
        ax.set_yticks([])
        fig.tight_layout()
        return result, fig
    except Exception as e:
        # Surface the error in the Label output instead of crashing the Space
        return {"Error": str(e)}, None
# =========================
# GRADIO APP
# =========================
demo = gr.Interface(
    fn=predict,
    inputs=gr.Audio(sources=["upload", "microphone"], type="numpy", label="🎤 Upload or Record Audio"),
    outputs=[gr.Label(num_top_classes=3), gr.Plot()],
    title="Audio Emotion Detection 🎧",
    description=(
        "Fine-tuned Wav2Vec2 model (`Hatman/audio-emotion-detection`) "
        "for emotion recognition from voice. "
        "Detects: Angry, Disgusted, Fearful, Happy, Neutral, Sad, Surprised. "
        "Audio is auto-resampled to 16 kHz."
    ),
    allow_flagging="never",
)
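# Optional local smoke test (a sketch only; "sample.wav" is a placeholder path).
# torchaudio.load returns float32 in [-1, 1], so cast to int16 to mimic the
# payload that Gradio's "numpy" audio component produces:
#
#   wav, sr = torchaudio.load("sample.wav")                   # (channels, samples)
#   pcm = (wav.mean(dim=0).numpy() * 32767).astype(np.int16)  # int16 PCM, mono
#   print(predict((sr, pcm)))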
# =========================
# LAUNCH
# =========================
if __name__ == "__main__":
    demo.launch()
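# Note: when running locally, demo.launch(share=True) additionally serves a
# temporary public URL through Gradio's tunneling service.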