| import gradio as gr |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| from transformers import AutoProcessor, AutoModel |
|
|
# Closed label set predicted by EmotionHead — order matters: the index of each
# label matches the corresponding output logit of the classifier.
EMOTIONS = ["neutral", "happy", "sad", "angry", "fear", "surprise"]

# Display emoji for each label, index-aligned with EMOTIONS.
EMOJIS = dict(zip(EMOTIONS, ["😐", "😄", "😢", "😡", "😨", "😲"]))
|
|
|
|
class EmotionHead(nn.Module):
    """MLP classifier head: 1280-d pooled audio embedding -> 6 emotion logits.

    Layer order (Linear/BatchNorm/ReLU/Dropout per hidden stage) is kept
    identical to the trained checkpoint so ``load_state_dict`` keys match.
    """

    def __init__(self):
        super().__init__()
        stages = []
        width = 1280
        for hidden in (512, 256):
            stages.extend(
                [
                    nn.Linear(width, hidden),
                    nn.BatchNorm1d(hidden),
                    nn.ReLU(),
                    nn.Dropout(0.3),
                ]
            )
            width = hidden
        stages.append(nn.Linear(width, 6))  # final projection to 6 classes
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Return raw (unnormalized) logits of shape ``(batch, 6)``."""
        return self.net(x)
|
|
|
|
# Prefer GPU, but fall back to CPU so the app can still start on machines
# without CUDA (the original hard-coded "cuda" and crashed otherwise).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_ID = "mistralai/Voxtral-Mini-4B-Realtime-2602"
print("Loading models...")
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
voxtral = (
    AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True, dtype=torch.bfloat16)
    .to(device)
    .eval()
)
emotion_model = EmotionHead().to(device)
# weights_only=True restricts torch.load to tensors/containers instead of
# arbitrary pickled objects; the file feeds load_state_dict directly, so it
# must be a plain state_dict and tensor-only loading is sufficient (and safer).
emotion_model.load_state_dict(
    torch.load("emotion_head_best.pt", map_location=device, weights_only=True)
)
emotion_model.eval()
print("Ready")
|
|
| |
# NOTE(review): these module-level globals appear unused — per-session data is
# carried in the Gradio state dict (see reset_state) — candidates for removal.
audio_buffer = []
last_transcript = ""
|
|
|
|
def process_stream(audio, state):
    """Gradio streaming callback: buffer mic audio and, once >=2 s has
    accumulated, transcribe it with Voxtral and classify the speaker emotion.

    Args:
        audio: ``(sample_rate, np.ndarray)`` chunk from ``gr.Audio``, or None.
        state: per-session dict with keys ``"buffer"`` (list of float samples
            at the mic's native rate), ``"emotion"``, ``"transcript"``.

    Returns:
        Tuple of (emotion display string, transcript text, updated state).
    """
    if audio is None:
        return state["emotion"], state["transcript"], state

    sr, chunk = audio
    # Normalize integer PCM to [-1, 1]. Checking the dtype — rather than the
    # peak amplitude as before — also scales quiet int16 chunks correctly
    # (a quiet int16 chunk can have max() <= 1.0 and would have been skipped).
    if np.issubdtype(chunk.dtype, np.integer):
        chunk = chunk.astype(np.float32) / 32768.0
    else:
        chunk = chunk.astype(np.float32)
    if chunk.ndim > 1:
        chunk = chunk.mean(axis=1)  # downmix to mono

    state["buffer"].extend(chunk.tolist())

    # Wait until ~2 seconds of audio (at the mic's native rate) is buffered.
    min_samples = sr * 2
    if len(state["buffer"]) < min_samples:
        return state["emotion"], state["transcript"], state

    data = np.array(state["buffer"], dtype=np.float32)

    # The buffer holds samples at the mic's native rate; remember that rate so
    # the tail-trim below keeps the right duration. (Previously ``sr`` was
    # rebound to 16000 after resampling, so e.g. a 48 kHz mic had only 2/3 s
    # of context retained instead of 2 s.)
    native_sr = sr
    if sr != 16000:
        import librosa  # lazy import: only needed when resampling

        data = librosa.resample(data, orig_sr=sr, target_sr=16000)
        sr = 16000

    try:
        inputs = processor(data, return_tensors="pt")
        # Audio features are cast to bfloat16 to match the model weights;
        # integral tensors (ids/masks) only move to the device.
        inputs = {
            k: v.to(device=device, dtype=torch.bfloat16)
            if "feature" in k
            else v.to(device)
            for k, v in inputs.items()
        }
        feats = inputs["input_features"]

        with torch.no_grad():
            # Transcription.
            out = voxtral.generate(**inputs, max_new_tokens=200)
            text = processor.decode(out[0], skip_special_tokens=True)

            # Emotion: mean-pool the audio encoder's hidden states over time,
            # then classify with the trained head.
            hidden = voxtral.audio_tower(feats).last_hidden_state.mean(1).float()
            probs = F.softmax(emotion_model(hidden), dim=1).squeeze(0)

        idx = probs.argmax().item()
        emotion = EMOTIONS[idx]
        confidence = probs[idx].item()
        emoji = EMOJIS[emotion]

        bars = "\n".join(
            [f"{EMOJIS[e]} {e}: {probs[i].item():.0%}" for i, e in enumerate(EMOTIONS)]
        )

        emotion_str = f"{emoji} {emotion} ({confidence:.0%})\n\n{bars}"

        # Keep a 2 s tail (at the native rate) as context for the next pass.
        state["buffer"] = state["buffer"][-native_sr * 2 :]
        state["emotion"] = emotion_str
        state["transcript"] = text

    except Exception as e:
        # Boundary handler: one bad chunk must not kill the stream. Log and
        # keep returning the last good outputs.
        print(f"Error: {e}")

    return state["emotion"], state["transcript"], state
|
|
|
|
def reset_state():
    """Build a pristine per-session state dict (empty buffer, blank outputs)."""
    return dict(buffer=[], emotion="", transcript="")
|
|
|
|
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks(title="Voxtral + Emotion") as demo:
    gr.Markdown("## 🎙️ Voxtral Realtime + Emotion Detection")
    gr.Markdown("Habla — la transcripción y emoción se actualizan cada ~2 segundos")

    # Per-session state; Gradio gives each session its own copy of the value.
    session_state = gr.State(reset_state())

    with gr.Row():
        mic_input = gr.Audio(
            sources=["microphone"], streaming=True, type="numpy", label="Micrófono"
        )

    with gr.Row():
        emotion_box = gr.Textbox(label="Emoción", lines=8)
        transcript_box = gr.Textbox(label="Transcripción", lines=8)

    # Stream mic chunks into the processing callback.
    mic_input.stream(
        fn=process_stream,
        inputs=[mic_input, session_state],
        outputs=[emotion_box, transcript_box, session_state],
    )

    # Clear both output boxes and rebuild the session state.
    reset_btn = gr.Button("Reset")
    reset_btn.click(
        fn=lambda: ("", "", reset_state()),
        outputs=[emotion_box, transcript_box, session_state],
    )

demo.launch(share=True)
|
|