# Source: Hugging Face Space upload (preview.py, rev ff42eed) by MrlolDev.
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformers import AutoProcessor, AutoModel
# Emotion classes in the order the classifier head was trained on;
# index i of the head's logits corresponds to EMOTIONS[i].
EMOTIONS = ["neutral", "happy", "sad", "angry", "fear", "surprise"]
# Emoji shown next to each emotion in the UI output.
EMOJIS = {
    "neutral": "😐",
    "happy": "😄",
    "sad": "😢",
    "angry": "😡",
    "fear": "😨",
    "surprise": "😲",
}
class EmotionHead(nn.Module):
    """MLP head mapping a pooled audio embedding to emotion logits.

    The defaults (1280-dim input, 6 classes, 0.3 dropout) match the
    pretrained checkpoint loaded below; they are parameters so the same
    head can be reused for other encoders or label sets.
    """

    def __init__(self, in_dim: int = 1280, num_classes: int = 6, dropout: float = 0.3):
        super().__init__()
        # Two hidden layers (BatchNorm + ReLU + Dropout) and a final linear
        # projection to raw logits; softmax is applied by the caller.
        self.net = nn.Sequential(
            nn.Linear(in_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return unnormalized logits of shape (batch, num_classes)."""
        return self.net(x)
# Prefer GPU when present; fall back to CPU so the demo still starts on
# machines without CUDA (the original hard-coded "cuda" and crashed there).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_ID = "mistralai/Voxtral-Mini-4B-Realtime-2602"

print("Loading models...")
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
voxtral = (
    AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True, dtype=torch.bfloat16)
    .to(device)
    .eval()
)

emotion_model = EmotionHead().to(device)
# weights_only=True restricts unpickling to tensors/containers, avoiding
# arbitrary code execution from an untrusted checkpoint; a plain state_dict
# loads fine under this mode.
emotion_model.load_state_dict(
    torch.load("emotion_head_best.pt", map_location=device, weights_only=True)
)
emotion_model.eval()
print("Ready")
# Accumulated audio buffer (original comment was in Spanish).
# NOTE(review): these module-level globals appear unused — per-session
# accumulation happens in the gr.State dict inside process_stream instead;
# they look like candidates for removal. Confirm nothing external uses them.
audio_buffer = []
last_transcript = ""
def process_stream(audio, state):
    """Streaming callback: accumulate mic chunks, then transcribe + classify.

    Parameters
    ----------
    audio : tuple[int, np.ndarray] | None
        (sample_rate, samples) from the streaming gr.Audio component;
        None when no audio has arrived yet.
    state : dict
        Per-session state with keys "buffer" (list[float] at the capture
        sample rate), "emotion" (str), and "transcript" (str).

    Returns
    -------
    tuple[str, str, dict]
        (emotion display text, transcript text, updated state).
    """
    if audio is None:
        return state["emotion"], state["transcript"], state

    sr, chunk = audio
    # Normalize integer PCM to [-1, 1] based on dtype, not on sample peaks.
    # The original heuristic (`chunk.max() > 1.0`) left quiet int16 audio
    # unscaled and missed chunks whose samples were all negative.
    if np.issubdtype(chunk.dtype, np.integer):
        chunk = chunk.astype(np.float32) / float(np.iinfo(chunk.dtype).max)
    else:
        chunk = chunk.astype(np.float32)
    if chunk.ndim > 1:
        chunk = chunk.mean(axis=1)  # downmix to mono

    # Accumulate; only run inference once ~2 s of audio is buffered.
    state["buffer"].extend(chunk.tolist())
    min_samples = sr * 2
    if len(state["buffer"]) < min_samples:
        return state["emotion"], state["transcript"], state

    data = np.array(state["buffer"], dtype=np.float32)
    # Resample to the 16 kHz the processor expects. Do NOT overwrite `sr`:
    # the sliding-window slice below operates on state["buffer"], which is
    # stored at the capture rate (the original code reassigned sr = 16000
    # and therefore kept the wrong number of samples when sr != 16000).
    if sr != 16000:
        import librosa  # local import: heavy dependency, only needed here

        data = librosa.resample(data, orig_sr=sr, target_sr=16000)

    try:
        inputs = processor(data, return_tensors="pt")
        # Audio features are cast to bf16 to match the model; other entries
        # (token ids / masks) keep their integer dtypes.
        inputs = {
            k: v.to(device=device, dtype=torch.bfloat16)
            if "feature" in k
            else v.to(device)
            for k, v in inputs.items()
        }
        feats = inputs["input_features"]
        with torch.no_grad():
            # Transcription
            out = voxtral.generate(**inputs, max_new_tokens=200)
            text = processor.decode(out[0], skip_special_tokens=True)
            # Emotion: mean-pool the audio encoder states, then classify.
            hidden = voxtral.audio_tower(feats).last_hidden_state.mean(1).float()
            probs = F.softmax(emotion_model(hidden), dim=1).squeeze(0)
        idx = probs.argmax().item()
        emotion = EMOTIONS[idx]
        confidence = probs[idx].item()
        emoji = EMOJIS[emotion]
        bars = "\n".join(
            [f"{EMOJIS[e]} {e}: {probs[i].item():.0%}" for i, e in enumerate(EMOTIONS)]
        )
        emotion_str = f"{emoji} {emotion} ({confidence:.0%})\n\n{bars}"
        # Keep only the last ~2 s (at the capture rate) as a sliding window.
        state["buffer"] = state["buffer"][-min_samples:]
        state["emotion"] = emotion_str
        state["transcript"] = text
    except Exception as e:
        # Best-effort streaming: log the failure and keep the last good outputs.
        print(f"Error: {e}")
    return state["emotion"], state["transcript"], state
def reset_state():
    """Build a fresh per-session state: empty audio buffer, blank outputs."""
    return dict(buffer=[], emotion="", transcript="")
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks(title="Voxtral + Emotion") as demo:
    gr.Markdown("## 🎙️ Voxtral Realtime + Emotion Detection")
    gr.Markdown("Habla — la transcripción y emoción se actualizan cada ~2 segundos")
    # Per-session state dict: {"buffer": [...], "emotion": str, "transcript": str}.
    state = gr.State(reset_state())
    with gr.Row():
        # Streaming mic input; chunks arrive as (sample_rate, np.ndarray).
        audio = gr.Audio(
            sources=["microphone"], streaming=True, type="numpy", label="Micrófono"
        )
    with gr.Row():
        emotion_out = gr.Textbox(label="Emoción", lines=8)
        transcript_out = gr.Textbox(label="Transcripción", lines=8)
    # Every incoming chunk runs process_stream; both textboxes and the
    # session state are updated in place.
    audio.stream(
        fn=process_stream,
        inputs=[audio, state],
        outputs=[emotion_out, transcript_out, state],
    )
    # Reset clears both textboxes and replaces the session state.
    gr.Button("Reset").click(
        fn=lambda: ("", "", reset_state()), outputs=[emotion_out, transcript_out, state]
    )

# share=True exposes a temporary public gradio.live URL.
demo.launch(share=True)