import gradio as gr
import torch
import torchaudio
import numpy as np
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor, AutoConfig
import matplotlib.pyplot as plt

# =========================
# CONFIG
# =========================
MODEL_NAME = "Hatman/audio-emotion-detection"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# =========================
# LOAD MODEL & FEATURE EXTRACTOR
# =========================
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME)
model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()

# Use the model's own label mapping so indices always line up with the classification head
config = AutoConfig.from_pretrained(MODEL_NAME)
LABELS = [config.id2label[i] for i in range(len(config.id2label))]

# Emoji per emotion for a friendlier UI. Keys must match the model's
# labels; anything unmatched simply falls back to no emoji below.
EMOJIS = {
    "Angry": "😑",
    "Disgusted": "🀒",
    "Fearful": "😨",
    "Happy": "πŸ˜„",
    "Neutral": "😐",
    "Sad": "😒",
    "Surprised": "😲",
}

# =========================
# PREDICTION FUNCTION
# =========================
def predict(audio):
    try:
        if audio is None:
            return {"Error": "No audio provided"}, None

        sr, data = audio

        # Gradio hands back raw PCM. Scale integer samples (e.g. int16)
        # to [-1, 1] for Wav2Vec2; float input is assumed normalized already.
        if np.issubdtype(data.dtype, np.integer):
            data = data.astype(np.float32) / np.iinfo(data.dtype).max
        else:
            data = data.astype(np.float32)

        # Stereo -> mono
        if data.ndim > 1:
            data = data.mean(axis=1)

        # Resample to the 16 kHz rate Wav2Vec2 expects
        if sr != 16000:
            data = torchaudio.functional.resample(torch.from_numpy(data), sr, 16000).numpy()
            sr = 16000

        # Feature extraction
        inputs = feature_extractor(
            data,
            sampling_rate=sr,
            return_tensors="pt",
            padding=True,
        )

        # Move tensors to the model's device
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

        # Forward pass
        with torch.no_grad():
            logits = model(**inputs).logits
        probs = torch.nn.functional.softmax(logits, dim=-1)[0].cpu().numpy()

        # Format the top-3 results with emojis
        top_idx = np.argsort(probs)[::-1][:3]
        result = {
            f"{LABELS[i]} {EMOJIS.get(LABELS[i], '')}": round(float(probs[i]), 4)
            for i in top_idx
        }

        # Waveform plot for the UI
        fig, ax = plt.subplots(figsize=(6, 2))
        ax.plot(data, color="purple")
        ax.set_title("Audio Waveform")
        ax.set_xlabel("Samples")
        ax.set_ylabel("Amplitude")
        ax.set_xticks([])
        ax.set_yticks([])
        fig.tight_layout()
        plt.close(fig)  # detach from pyplot so figures don't pile up; Gradio still renders the returned fig

        return result, fig

    except Exception as e:
        return {"Error": str(e)}, None

# =========================
# GRADIO APP
# =========================
demo = gr.Interface(
    fn=predict,
    inputs=gr.Audio(
        sources=["upload", "microphone"],
        type="numpy",
        label="🎀 Upload or Record Audio",
    ),
    outputs=[gr.Label(num_top_classes=3), gr.Plot()],
    title="Audio Emotion Detection 🎧",
    description=(
        "Fine-tuned Wav2Vec2 model (`Hatman/audio-emotion-detection`) "
        "for emotion recognition from voice. "
        "Detects: Angry, Disgusted, Fearful, Happy, Neutral, Sad, Surprised. "
        "Audio is automatically resampled to 16 kHz."
    ),
    allow_flagging="never",
)

# =========================
# LAUNCH
# =========================
if __name__ == "__main__":
    demo.launch()
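
# =========================
# OPTIONAL SMOKE TEST
# =========================
# A minimal sketch, not part of the app's UI flow: feeds predict() a
# synthetic 440 Hz int16 tone. A sine wave is not speech, so the
# predicted label is arbitrary; this only verifies that normalization,
# resampling, inference, and plotting all run end to end. Assumes this
# file is saved as app.py; run it without launching the UI via:
#     python -c "import app; app.smoke_test()"
def smoke_test():
    sr = 44100  # deliberately not 16 kHz, to exercise the resampling branch
    t = np.linspace(0, 1.0, sr, endpoint=False)
    pcm = (0.1 * np.sin(2 * np.pi * 440.0 * t) * 32767).astype(np.int16)
    scores, _fig = predict((sr, pcm))
    print(scores)  # top-3 {label + emoji: probability} dict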