| |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| from datasets import load_dataset |
| from transformers import AutoProcessor, AutoModel |
| from peft import PeftModel |
| from sklearn.metrics import ( |
| f1_score, |
| recall_score, |
| accuracy_score, |
| classification_report, |
| ) |
| import json |
|
|
# Six-way emotion label space; the index in this list is the class id emitted
# by EmotionHead and used for all metric computations below.
EMOTIONS = ["neutral", "happy", "sad", "angry", "fear", "surprise"]
MODEL_ID = "mistralai/Voxtral-Mini-4B-Realtime-2602"
# Fall back to CPU when CUDA is absent; the original hard-coded "cuda" and
# crashed at the first .to(DEVICE) on CPU-only machines.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Shared feature extractor used by both predict_* helpers below.
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
|
|
|
|
class EmotionHead(nn.Module):
    """MLP classifier mapping a pooled 1280-d audio embedding to 6 emotion logits."""

    def __init__(self):
        super().__init__()
        # Two hidden stages (1280 -> 512 -> 256), each Linear + LayerNorm +
        # GELU + Dropout(0.3), followed by the 6-way output projection.
        # Kept as `self.net` = nn.Sequential with the same module order so the
        # checkpoint state_dict keys (net.0.*, net.4.*, net.8.*) still match.
        layers = []
        for in_dim, out_dim in ((1280, 512), (512, 256)):
            layers.extend(
                [
                    nn.Linear(in_dim, out_dim),
                    nn.LayerNorm(out_dim),
                    nn.GELU(),
                    nn.Dropout(0.3),
                ]
            )
        layers.append(nn.Linear(256, 6))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw class logits of shape (batch, 6)."""
        return self.net(x)
|
|
|
|
def compute_metrics(trues, preds):
    """Return UA/WA/macro-F1/weighted-F1 as percentages rounded to 1 decimal.

    UA = unweighted accuracy (macro recall), WA = weighted accuracy (plain
    accuracy); `trues`/`preds` are parallel lists of EMOTIONS indices.
    """
    # sklearn raises on empty inputs, so report zeros for an empty run.
    if not trues:
        return {"UA": 0, "WA": 0, "F1": 0, "WF1": 0}
    raw = {
        "UA": recall_score(trues, preds, average="macro", zero_division=0),
        "WA": accuracy_score(trues, preds),
        "F1": f1_score(trues, preds, average="macro", zero_division=0),
        "WF1": f1_score(trues, preds, average="weighted", zero_division=0),
    }
    return {name: round(score * 100, 1) for name, score in raw.items()}
|
|
|
|
def predict_iemocap(dataset, model, head):
    """Predict on IEMOCAP dataset - labels from answer field text.

    The ground-truth label is parsed from the free-text ``answer`` field;
    samples whose answer mentions none of the supported emotions, or that
    carry no audio, are skipped. Returns parallel lists (trues, preds) of
    EMOTIONS indices.
    """
    preds, trues = [], []
    model.eval()
    head.eval()

    # Ordered keyword table: first match wins ("excited" folds into happy,
    # mirroring the common 4-class IEMOCAP setup).
    keyword_table = [
        (("happy", "excited"), "happy"),
        (("sad",), "sad"),
        (("angry",), "angry"),
        (("neutral",), "neutral"),
    ]

    for i, sample in enumerate(dataset):
        answer = sample.get("answer", "").lower()
        mapped = next(
            (label for keys, label in keyword_table if any(k in answer for k in keys)),
            None,
        )
        # Second clause is a defensive no-op (all table labels are in EMOTIONS).
        if mapped is None or mapped not in EMOTIONS:
            continue
        true_idx = EMOTIONS.index(mapped)

        # Audio lives under sample["context"] for this dataset.
        context = sample.get("context", {})
        audio_array = context.get("array") if context else None
        if audio_array is None:
            continue

        waveform = np.array(audio_array, dtype=np.float32)
        sr = context.get("sampling_rate", 16000)

        # Best-effort per sample: a failed forward pass is logged and skipped.
        try:
            batch = processor(waveform, sampling_rate=sr, return_tensors="pt")
            feats = batch["input_features"].to(DEVICE, dtype=torch.bfloat16)
            with torch.no_grad():
                # Mean-pool encoder states over time, then classify.
                pooled = model.audio_tower(feats).last_hidden_state.mean(1).float()
                pred_idx = head(pooled).argmax(1).item()
            preds.append(pred_idx)
            trues.append(true_idx)
        except Exception as e:
            print(f" error at {i}: {e}")

        if i % 200 == 0 and i > 0:
            print(f" processed {i}...")

    return trues, preds
|
|
|
|
def predict_generic(dataset, model, head, label_map):
    """Predict on generic dataset with label mapping.

    ``label_map`` maps dataset-native label strings (tried upper-cased, then
    lower-cased) to canonical EMOTIONS names; unmapped samples are skipped.
    Returns parallel lists (trues, preds) of EMOTIONS indices.
    """
    preds, trues = [], []
    model.eval()
    head.eval()

    for i, sample in enumerate(dataset):
        raw_label = sample.get("label") or sample.get("emotion", "")
        key = str(raw_label)
        mapped = label_map.get(key.upper()) or label_map.get(key.lower())
        if mapped is None or mapped not in EMOTIONS:
            continue
        true_idx = EMOTIONS.index(mapped)

        audio_field = sample["audio"]
        waveform = np.array(audio_field["array"], dtype=np.float32)
        sr = audio_field["sampling_rate"]

        # Best-effort per sample: a failed forward pass is logged and skipped.
        try:
            batch = processor(waveform, sampling_rate=sr, return_tensors="pt")
            feats = batch["input_features"].to(DEVICE, dtype=torch.bfloat16)
            with torch.no_grad():
                # Mean-pool encoder states over time, then classify.
                pooled = model.audio_tower(feats).last_hidden_state.mean(1).float()
                pred_idx = head(pooled).argmax(1).item()
            preds.append(pred_idx)
            trues.append(true_idx)
        except Exception as e:
            print(f" error at {i}: {e}")

        if i % 200 == 0 and i > 0:
            print(f" processed {i}...")

    return trues, preds
|
|
|
|
| |
# --- Model loading -----------------------------------------------------------
# Two encoder variants are benchmarked: the frozen Voxtral base, and the same
# base with a LoRA adapter attached via PEFT. Each variant pairs with its own
# trained EmotionHead checkpoint.

print("Loading Voxtral base (frozen)...")
voxtral_frozen = (
    AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True, dtype=torch.bfloat16)
    .to(DEVICE)
    .eval()
)


print("Loading Voxtral + LoRA...")
# Fresh base instance so attaching the adapter cannot affect voxtral_frozen.
voxtral_lora_base = AutoModel.from_pretrained(
    MODEL_ID, trust_remote_code=True, dtype=torch.bfloat16
)
voxtral_lora = (
    PeftModel.from_pretrained(
        voxtral_lora_base, "MrlolDev/voxtral-emotion-speech", subfolder="lora_adapter"
    )
    .to(DEVICE)
    .eval()
)


print("Loading heads...")
# NOTE(review): weights_only=False runs arbitrary pickle code at load time —
# acceptable for locally produced checkpoints, but consider weights_only=True
# if these state_dicts contain only tensors; verify before changing.
head_frozen = EmotionHead().to(DEVICE)
head_frozen.load_state_dict(
    torch.load("emotion_head_best.pt", map_location=DEVICE, weights_only=False)
)
head_frozen.eval()


head_lora = EmotionHead().to(DEVICE)
head_lora.load_state_dict(
    torch.load("emotion_head_lora_best.pt", map_location=DEVICE, weights_only=False)
)
head_lora.eval()
|
|
| |
# CREMA-D three-letter emotion codes -> canonical EMOTIONS names.
# NOTE(review): "DIS" (disgust) is folded into "angry" — presumably because the
# 6-way head has no disgust class; confirm this matches the training mapping.
CREMA_MAP = {
    "HAP": "happy",
    "SAD": "sad",
    "ANG": "angry",
    "NEU": "neutral",
    "FEA": "fear",
    "DIS": "angry",
}


# External reference numbers (SenseVoice-S) shown alongside our results.
SENSEVOICE_REF = {
    "IEMOCAP": {"UA": 70.5, "WA": 65.7, "F1": 67.9, "WF1": 67.8},
}


# Accumulates per-dataset metrics; dumped to JSON at the end of the script.
results = {}


print("\n=== IEMOCAP ===")
ds = load_dataset("AudioLLMs/iemocap_emotion_recognition", trust_remote_code=True)
iemocap = ds["test"]
print(f"Total samples: {len(iemocap)}")
|
|
# Run both encoder variants over the IEMOCAP test split and collect metrics.
print("Predicting with Frozen encoder...")
trues_f, preds_f = predict_iemocap(iemocap, voxtral_frozen, head_frozen)
metrics_frozen = compute_metrics(trues_f, preds_f)
print(
    f" Frozen: n={len(preds_f)} | UA={metrics_frozen['UA']} WA={metrics_frozen['WA']} F1={metrics_frozen['F1']}"
)


print("Predicting with LoRA encoder...")
trues_l, preds_l = predict_iemocap(iemocap, voxtral_lora, head_lora)
metrics_lora = compute_metrics(trues_l, preds_l)
print(
    f" LoRA: n={len(preds_l)} | UA={metrics_lora['UA']} WA={metrics_lora['WA']} F1={metrics_lora['F1']}"
)


# Side-by-side comparison table: our two runs plus the SenseVoice-S reference.
ref = SENSEVOICE_REF.get("IEMOCAP", {})
print(f"\n {'Model':<25} {'UA':>6} {'WA':>6} {'F1':>6} {'WF1':>6}")
print(f" {'-' * 50}")
print(
    f" {'Ours (Frozen)':<25} {metrics_frozen['UA']:>6} {metrics_frozen['WA']:>6} {metrics_frozen['F1']:>6} {metrics_frozen['WF1']:>6}"
)
print(
    f" {'Ours (LoRA)':<25} {metrics_lora['UA']:>6} {metrics_lora['WA']:>6} {metrics_lora['F1']:>6} {metrics_lora['WF1']:>6}"
)
if ref:
    print(
        f" {'SenseVoice-S (ref)':<25} {ref['UA']:>6} {ref['WA']:>6} {ref['F1']:>6} {ref['WF1']:>6}"
    )


# Record everything for the JSON dump at the bottom of the script.
results["IEMOCAP"] = {
    "frozen": metrics_frozen,
    "lora": metrics_lora,
    "sensevoice_s": ref,
}
|
|
| |
# Persist the collected metrics so separate runs can be compared offline.
with open("benchmark_lora_results.json", "w") as out_file:
    out_file.write(json.dumps(results, indent=2))


print("\n✅ Saved: benchmark_lora_results.json")
|
|