#!/usr/bin/env python3 """3-Model SER Benchmark — emotion2vec vs SpeechBrain vs Whisper+Head. AI Hub 한국어 감정 데이터셋 테스트 서브셋을 사용하여 3개 모델의 정확도, 레이턴시, 메모리 사용량을 객관적으로 비교한다. Usage: # 2개 모델 먼저 (Whisper head 없이) python scripts/benchmark_ser_models.py \\ --test-dir data/evaluation/korean \\ --models emotion2vec speechbrain # 전체 3개 모델 python scripts/benchmark_ser_models.py \\ --test-dir data/evaluation/korean \\ --models emotion2vec speechbrain whisper \\ --whisper-head-ckpt data/models/whisper_emotion_head.pt # Quick smoke test python scripts/benchmark_ser_models.py \\ --test-dir data/evaluation/korean \\ --models emotion2vec --max-samples 10 """ from __future__ import annotations import argparse import csv import gc import json import logging import os import statistics import sys import tempfile import time from abc import ABC, abstractmethod from pathlib import Path import numpy as np import psutil import soundfile as sf logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") logger = logging.getLogger(__name__) # ────────────────────────────────────────────── # Constants # ────────────────────────────────────────────── EVAL_LABELS = ["neutral", "joy", "sadness", "anger", "surprise", "fear"] # Knockout criteria (from evaluation-framework.md) KNOCKOUT_F1 = 0.70 KNOCKOUT_LATENCY_MS = 500 KNOCKOUT_RAM_MB = 2048 # ────────────────────────────────────────────── # Model Adapter Interface # ────────────────────────────────────────────── class SERModelAdapter(ABC): """Abstract base for SER model adapters.""" name: str model_id: str params_m: int # millions @abstractmethod def load(self, device: str) -> None: ... @abstractmethod def predict(self, audio_path: str) -> dict[str, float]: """Return {emotion_label: score} in project taxonomy.""" ... @abstractmethod def unload(self) -> None: ... 
# ──────────────────────────────────────────────
# Adapter 1: emotion2vec_plus_base
# ──────────────────────────────────────────────

class Emotion2vecAdapter(SERModelAdapter):
    """Adapter around the funasr ``AutoModel`` for emotion2vec_plus_base."""

    name = "emotion2vec_plus_base"
    model_id = "iic/emotion2vec_plus_base"
    params_m = 90

    # emotion2vec 9-class → project 7-class (from src/stage2/audio_emotion.py)
    # Both the plain English and the bilingual "中文/english" label spellings
    # are listed; labels missing from this table fall back to "neutral".
    LABEL_MAP = {
        "angry": "anger",
        "disgusted": "disgust",
        "fearful": "fear",
        "happy": "joy",
        "neutral": "neutral",
        "sad": "sadness",
        "surprised": "surprise",
        "other": "neutral",
        "unknown": "neutral",
        "生气/angry": "anger",
        "厌恶/disgusted": "disgust",
        "恐惧/fearful": "fear",
        "开心/happy": "joy",
        "中立/neutral": "neutral",
        "难过/sad": "sadness",
        "吃惊/surprised": "surprise",
        "其他/other": "neutral",
        "": "neutral",
    }

    def __init__(self) -> None:
        self._model = None

    def load(self, device: str) -> None:
        """Instantiate the funasr model for ``model_id`` on ``device``."""
        from funasr import AutoModel

        self._model = AutoModel(model=self.model_id, device=device, hub="hf")

    def predict(self, audio_path: str) -> dict[str, float]:
        """Score one audio file and return normalized per-label scores.

        Native labels are mapped through ``LABEL_MAP``; mapped labels that are
        not in ``EVAL_LABELS`` (e.g. "disgust") are dropped, and the remaining
        scores are rescaled to sum to 1 when their total is positive.
        """
        output = self._model.generate(
            audio_path,
            granularity="utterance",
            extract_embedding=False,
        )

        scores = {label: 0.0 for label in EVAL_LABELS}
        if isinstance(output, list) and output:
            rec = output[0]
            for native_label, score in zip(rec.get("labels", []), rec.get("scores", [])):
                mapped = self.LABEL_MAP.get(native_label, "neutral")
                if mapped in scores:
                    scores[mapped] += float(score)

        # Normalize
        total = sum(scores.values())
        if total > 0:
            scores = {k: v / total for k, v in scores.items()}
        return scores

    def unload(self) -> None:
        """Drop the reference to the loaded model."""
        self._model = None


# ──────────────────────────────────────────────
# Adapter 2: SpeechBrain wav2vec2-IEMOCAP
# ──────────────────────────────────────────────

class SpeechBrainAdapter(SERModelAdapter):
    """Adapter around the SpeechBrain wav2vec2 IEMOCAP emotion classifier."""

    name = "speechbrain_wav2vec2"
    model_id = "speechbrain/emotion-recognition-wav2vec2-IEMOCAP"
    params_m = 314

    # SpeechBrain 4-class → project taxonomy
    # NOTE: This model CANNOT predict fear or surprise
    LABEL_MAP = {
        "ang": "anger",
        "hap": "joy",
        "sad": "sadness",
        "neu": "neutral",
    }

    def __init__(self) -> None:
        self._classifier = None
        self._label_order = None  # populated from label_encoder
        self._device = "cpu"  # device the classifier was loaded on

    def load(self, device: str) -> None:
        """Load the classifier on ``device`` and read its label order."""
        from speechbrain.inference.classifiers import EncoderClassifier

        self._device = device
        self._classifier = EncoderClassifier.from_hparams(
            source=self.model_id,
            run_opts={"device": device},
        )
        self._classifier = self._classifier.to(device)

        # Get label order from label_encoder
        try:
            le = self._classifier.hparams.label_encoder
            # lab2ind: {'neu': 0, 'ang': 1, 'hap': 2, 'sad': 3}
            self._label_order = [None] * len(le.lab2ind)
            for lab, idx in le.lab2ind.items():
                self._label_order[idx] = lab
            logger.info("SpeechBrain labels: %s", self._label_order)
        except Exception as e:
            # Best effort: fall back to a hard-coded order, but say so, because
            # a wrong order would silently mislabel every prediction.
            self._label_order = ["neu", "ang", "hap", "sad"]
            logger.warning(
                "Could not read SpeechBrain label order (%s); assuming %s",
                e, self._label_order,
            )

    def predict(self, audio_path: str) -> dict[str, float]:
        """Score one audio file and return per-label probabilities.

        The audio is resampled to 16 kHz and mixed down to one channel. Labels
        the model cannot emit keep a score of 0.0.
        """
        import torch
        import torchaudio

        signal, sr = torchaudio.load(audio_path)
        if sr != 16000:
            signal = torchaudio.functional.resample(signal, sr, 16000)
        if signal.shape[0] > 1:
            signal = signal.mean(dim=0, keepdim=True)
        # The classifier lives on self._device; a CPU tensor would fail on cuda.
        signal = signal.to(self._device)

        # Use modules directly (classify_batch broken in SpeechBrain 1.1.0)
        with torch.no_grad():
            feats = self._classifier.mods.wav2vec2(signal)
            pooled = self._classifier.mods.avg_pool(feats)
            logits = self._classifier.mods.output_mlp(pooled)
            probs = torch.softmax(logits.squeeze(1), dim=-1).squeeze().tolist()

        # squeeze() collapses a single-class output to a scalar
        if isinstance(probs, float):
            probs = [probs]

        scores = {label: 0.0 for label in EVAL_LABELS}
        for sb_label, prob in zip(self._label_order, probs):
            mapped = self.LABEL_MAP.get(sb_label, "neutral")
            if mapped in scores:
                scores[mapped] += prob
        return scores

    def unload(self) -> None:
        """Drop the reference to the loaded classifier."""
        self._classifier = None


# ──────────────────────────────────────────────
# Adapter 3: Whisper-Medium + Emotion Head
# ──────────────────────────────────────────────

class WhisperMediumAdapter(SERModelAdapter):
    """Whisper-medium encoder followed by a linear emotion classifier head."""

    name = "whisper_medium_head"
    model_id = "openai/whisper-medium"
    params_m = 769

    def __init__(self, head_ckpt: str | None = None) -> None:
        """Remember the optional path to a trained head checkpoint."""
        self._encoder = None
        self._head = None
        self._processor = None
        self._head_ckpt = head_ckpt
        self._device = "cpu"

    def load(self, device: str) -> None:
        """Load the encoder and the head; use random head weights if no checkpoint."""
        import torch
        from transformers import WhisperModel, WhisperFeatureExtractor

        self._device = device
        self._processor = WhisperFeatureExtractor.from_pretrained(self.model_id)
        self._encoder = WhisperModel.from_pretrained(self.model_id).to(device)
        self._encoder.eval()

        # Classifier head: hidden_dim → 6 classes
        hidden_dim = self._encoder.config.d_model  # 1024 for medium
        self._head = torch.nn.Linear(hidden_dim, len(EVAL_LABELS)).to(device)

        if self._head_ckpt and Path(self._head_ckpt).exists():
            logger.info("Loading Whisper emotion head from %s", self._head_ckpt)
            state = torch.load(self._head_ckpt, map_location=device, weights_only=True)
            self._head.load_state_dict(state)
        else:
            if self._head_ckpt:
                # A path was given but is missing: make the fallback visible.
                logger.warning("Whisper head checkpoint not found: %s", self._head_ckpt)
            logger.warning("No trained Whisper head — using random weights (baseline)")

        self._head.eval()

    def predict(self, audio_path: str) -> dict[str, float]:
        """Score one audio file and return a softmax over ``EVAL_LABELS``."""
        import torch
        import librosa

        # Load and preprocess
        audio, _ = librosa.load(audio_path, sr=16000)
        inputs = self._processor(
            audio,
            sampling_rate=16000,
            return_tensors="pt",
        )
        input_features = inputs.input_features.to(self._device)

        with torch.no_grad():
            encoder_out = self._encoder.encoder(input_features)
            hidden = encoder_out.last_hidden_state  # (1, T, D)
            pooled = hidden.mean(dim=1)  # (1, D)
            logits = self._head(pooled)  # (1, 6)
            probs = torch.softmax(logits, dim=-1).squeeze().cpu().tolist()

        return dict(zip(EVAL_LABELS, probs))

    def unload(self) -> None:
        """Drop the references to the encoder, head and feature extractor."""
        self._encoder = self._head = self._processor = None


# ──────────────────────────────────────────────
# Phone Augmentation
# ──────────────────────────────────────────────

def apply_phone_augmentation(audio_path: str) -> str:
    """Apply phone-quality degradation, return path to temp WAV file.

    The caller owns the returned file and is responsible for deleting it.
    If writing the degraded audio fails, the temp file is removed and the
    error is re-raised.
    """
    # This runs once per sample, so only extend sys.path the first time.
    src_dir = str(Path(__file__).resolve().parent.parent / "src")
    if src_dir not in sys.path:
        sys.path.insert(0, src_dir)
    from common.phone_simulator import PhoneSimulator, CompandingType

    audio, sr = sf.read(audio_path, dtype="float32")
    if audio.ndim == 2:
        audio = audio.mean(axis=1)

    sim = PhoneSimulator(companding=CompandingType.ALAW)
    degraded, new_sr = sim.process(audio, sr)

    # Save to temp file. mkstemp returns an open descriptor; close it right
    # away so no handle is leaked per sample and the path can be reopened.
    fd, tmp_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        sf.write(tmp_path, degraded, new_sr, subtype="PCM_16")
    except Exception:
        os.unlink(tmp_path)
        raise
    return tmp_path


# ──────────────────────────────────────────────
# Test Data Loading
# ──────────────────────────────────────────────

def load_test_data(test_dir: str, max_samples: int | None = None) -> list[dict]:
    """Load test samples from prepared subset.

    Reads ``test_labels.csv`` in ``test_dir`` and skips rows whose audio file
    is missing. Exits the process with status 1 when the CSV does not exist.

    Args:
        test_dir: Directory containing ``test_labels.csv`` and the audio files.
        max_samples: If set and smaller than the number of loaded samples,
            a fixed-seed random subset of this size is returned.

    Returns:
        A list of dicts with keys ``audio_path``, ``emotion``, ``duration``,
        ``speaker_id`` and ``intensity``.
    """
    csv_path = Path(test_dir) / "test_labels.csv"
    if not csv_path.exists():
        logger.error("test_labels.csv not found in %s", test_dir)
        sys.exit(1)

    samples = []
    with open(csv_path, encoding="utf-8") as f:
        for row in csv.DictReader(f):
            audio_path = str(Path(test_dir) / row["file_path"])
            if not Path(audio_path).exists():
                logger.warning("Audio file not found: %s", audio_path)
                continue
            samples.append({
                "audio_path": audio_path,
                "emotion": row["emotion"],
                "duration": float(row["duration"]),
                "speaker_id": row.get("speaker_id", ""),
                "intensity": row.get("intensity", ""),
            })

    if max_samples and len(samples) > max_samples:
        import random

        # A private generator gives the same subset as seeding the module-level
        # one with 42, without disturbing global random state.
        samples = random.Random(42).sample(samples, max_samples)

    logger.info("Loaded %d test samples from %s", len(samples), test_dir)
    return samples


# ──────────────────────────────────────────────
# Benchmark Runner
# ──────────────────────────────────────────────

def get_process_rss_mb() -> float:
    """Return the resident set size of the current process in MiB."""
    return psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)


def _timed_predict(
    adapter: SERModelAdapter,
    audio_path: str,
    condition: str,
) -> tuple[dict[str, float], float]:
    """Run one prediction and return (scores, latency in ms).

    For the "phone" condition the audio is degraded into a temp file first;
    that file is always removed, and its creation is not part of the timing.
    """
    tmp_path = None
    try:
        if condition == "phone":
            tmp_path = apply_phone_augmentation(audio_path)
            audio_path = tmp_path
        t0 = time.perf_counter()
        scores = adapter.predict(audio_path)
        latency_ms = (time.perf_counter() - t0) * 1000
        return scores, latency_ms
    finally:
        if tmp_path and os.path.exists(tmp_path):
            os.unlink(tmp_path)


def benchmark_model(
    adapter: SERModelAdapter,
    samples: list[dict],
    device: str,
    phone_augment: bool,
    warmup: int = 5,
) -> dict:
    """Run full benchmark for one model on both clean and optionally phone conditions.

    Args:
        adapter: Model adapter to load, run and unload.
        samples: Test samples; each needs ``audio_path`` and ``emotion``.
        device: Device string handed to ``adapter.load``.
        phone_augment: Also evaluate the "phone" (degraded audio) condition.
        warmup: Number of leading samples predicted, untimed, before each
            condition.

    Returns:
        A dict with model metadata, ``load_time_sec``, ``model_ram_mb`` and one
        metrics dict per condition ("clean", optionally "phone"). If loading
        fails, the dict has an ``error`` key and no condition entries.
    """
    logger.info("=" * 60)
    logger.info("Benchmarking: %s (%dM params)", adapter.name, adapter.params_m)
    logger.info("=" * 60)

    result = {
        "model": adapter.name,
        "model_id": adapter.model_id,
        "params_m": adapter.params_m,
        "device": device,
    }

    # Baseline RAM
    gc.collect()
    baseline_rss = get_process_rss_mb()

    # Load model
    logger.info("Loading model...")
    load_start = time.perf_counter()
    try:
        adapter.load(device)
    except Exception as e:
        logger.error("Failed to load %s: %s", adapter.name, e)
        result["error"] = str(e)
        return result
    load_time = time.perf_counter() - load_start
    result["load_time_sec"] = round(load_time, 2)

    post_load_rss = get_process_rss_mb()
    result["model_ram_mb"] = round(post_load_rss - baseline_rss, 1)
    logger.info("Loaded in %.1fs, RAM: %.0fMB", load_time, result["model_ram_mb"])

    # Run for each condition
    conditions = ["clean"]
    if phone_augment:
        conditions.append("phone")

    for condition in conditions:
        logger.info("--- Condition: %s ---", condition)

        # Warmup: results are discarded and failures are non-fatal, since the
        # measured loop below records errors per sample.
        for s in samples[:max(warmup, 0)]:
            try:
                _timed_predict(adapter, s["audio_path"], condition)
            except Exception as e:
                logger.debug("Warmup failed on %s: %s", s["audio_path"], e)

        # Inference
        y_true = []
        y_pred = []
        latencies = []
        errors = []
        peak_rss = get_process_rss_mb()

        for i, sample in enumerate(samples):
            try:
                scores, latency_ms = _timed_predict(adapter, sample["audio_path"], condition)
                pred_label = max(scores, key=scores.get)
            except Exception as e:
                errors.append({"index": i, "error": str(e)})
                logger.warning("Error on sample %d: %s", i, e)
            else:
                # Record only complete predictions so the lists stay aligned.
                latencies.append(latency_ms)
                y_true.append(sample["emotion"])
                y_pred.append(pred_label)

            peak_rss = max(peak_rss, get_process_rss_mb())

            if (i + 1) % 50 == 0:
                logger.info(" %d/%d done (mean lat: %.0fms)",
                            i + 1, len(samples),
                            statistics.mean(latencies) if latencies else 0)

        # Compute metrics
        cond_result = compute_metrics(y_true, y_pred, latencies, peak_rss - baseline_rss, errors)
        result[condition] = cond_result

        # The latency dict is empty when every sample failed, so use .get().
        logger.info(" %s: macro_f1=%.3f, accuracy=%.3f, mean_latency=%.0fms, peak_ram=%.0fMB",
                    condition, cond_result["macro_f1"], cond_result["accuracy"],
                    cond_result["latency"].get("mean_ms", 0.0), cond_result["peak_ram_mb"])

    # Unload
    adapter.unload()
    gc.collect()
    try:
        import torch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except ImportError:
        pass

    return result


# ──────────────────────────────────────────────
# Metrics
# ──────────────────────────────────────────────

def compute_metrics(
    y_true: list[str],
    y_pred: list[str],
    latencies: list[float],
    peak_ram_mb: float,
    errors: list[dict],
) -> dict:
    """Compute accuracy, F1, confusion matrix, latency stats.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels, aligned with ``y_true``.
        latencies: Per-prediction latencies in milliseconds.
        peak_ram_mb: Peak RAM to report, in MB.
        errors: Per-sample error records, passed through to the result.

    Returns:
        A dict of metrics. When there are no predictions, all scores are zero,
        ``latency`` is an empty dict and a ``note`` key explains why.
    """
    from sklearn.metrics import (
        accuracy_score,
        precision_recall_fscore_support,
        confusion_matrix,
    )

    if not y_true or not y_pred:
        n_labels = len(EVAL_LABELS)
        return {
            "accuracy": 0.0,
            "macro_f1": 0.0,
            "weighted_f1": 0.0,
            "per_class": {
                label: {"precision": 0, "recall": 0, "f1": 0, "support": 0}
                for label in EVAL_LABELS
            },
            # Build each row separately so the rows are not one shared list.
            "confusion_matrix": [[0] * n_labels for _ in range(n_labels)],
            "confusion_labels": EVAL_LABELS,
            "latency": {},
            "peak_ram_mb": round(peak_ram_mb, 1),
            "total_samples": 0,
            "errors": errors,
            "note": "All samples failed — no predictions available",
        }

    accuracy = accuracy_score(y_true, y_pred)
    precision, recall, f1, support = precision_recall_fscore_support(
        y_true, y_pred, labels=EVAL_LABELS, average=None, zero_division=0,
    )
    macro_f1 = float(np.mean(f1))
    weighted_f1 = float(np.average(f1, weights=support)) if sum(support) > 0 else 0.0

    cm = confusion_matrix(y_true, y_pred, labels=EVAL_LABELS).tolist()

    per_class = {
        label: {
            "precision": round(float(precision[i]), 4),
            "recall": round(float(recall[i]), 4),
            "f1": round(float(f1[i]), 4),
            "support": int(support[i]),
        }
        for i, label in enumerate(EVAL_LABELS)
    }

    latency_stats = {}
    if latencies:
        ordered = sorted(latencies)
        latency_stats = {
            "mean_ms": round(statistics.mean(latencies), 1),
            "median_ms": round(statistics.median(latencies), 1),
            "std_ms": round(statistics.stdev(latencies), 1) if len(latencies) > 1 else 0,
            # int(n * 0.95) is always < n for n >= 1, so the index is in range.
            "p95_ms": round(ordered[int(len(ordered) * 0.95)], 1),
            "min_ms": round(ordered[0], 1),
            "max_ms": round(ordered[-1], 1),
        }

    return {
        "accuracy": round(accuracy, 4),
        "macro_f1": round(macro_f1, 4),
        "weighted_f1": round(weighted_f1, 4),
        "per_class": per_class,
        "confusion_matrix": cm,
        "confusion_labels": EVAL_LABELS,
        "latency": latency_stats,
        "peak_ram_mb": round(peak_ram_mb, 1),
        "total_samples": len(y_true),
        "errors": errors,
    }


# ──────────────────────────────────────────────
# Knockout Check
# ──────────────────────────────────────────────

def knockout_check(result: dict) -> dict:
    """Check if model passes knockout criteria.

    Args:
        result: A per-model result dict holding per-condition metrics under
            "clean" and/or "phone".

    Returns:
        A dict keyed by each condition present in ``result``, with PASS/FAIL
        strings for ``korean_f1``, ``latency``, ``ram`` and ``overall``.
        Empty when ``result`` holds neither condition.
    """
    checks = {}
    for condition in ("clean", "phone"):
        if condition not in result:
            continue
        cond = result[condition]
        f1_ok = cond["macro_f1"] >= KNOCKOUT_F1
        # A missing latency (no successful prediction) counts as a failure.
        lat_ok = cond["latency"].get("mean_ms", 999) <= KNOCKOUT_LATENCY_MS
        ram_ok = cond["peak_ram_mb"] <= KNOCKOUT_RAM_MB
        checks[condition] = {
            "korean_f1": f"{'PASS' if f1_ok else 'FAIL'} ({cond['macro_f1']:.3f} {'≥' if f1_ok else '<'} {KNOCKOUT_F1})",
            "latency": f"{'PASS' if lat_ok else 'FAIL'} ({cond['latency'].get('mean_ms', 0):.0f}ms {'≤' if lat_ok else '>'} {KNOCKOUT_LATENCY_MS}ms)",
            "ram": f"{'PASS' if ram_ok else 'FAIL'} ({cond['peak_ram_mb']:.0f}MB {'≤' if ram_ok else '>'} {KNOCKOUT_RAM_MB}MB)",
            "overall": "PASS" if (f1_ok and lat_ok and ram_ok) else "FAIL",
        }
    return checks


# ──────────────────────────────────────────────
# Report Generation
# ──────────────────────────────────────────────

def generate_markdown_report(all_results: dict, output_path: str):
    """Generate a markdown comparison report.

    Args:
        all_results: Mapping of model name to its result dict.
        output_path: File to write; parent directories are created as needed.
    """
    lines = [
        "# 3-Model SER Benchmark Report",
        "",
        f"**Generated**: {time.strftime('%Y-%m-%d %H:%M:%S')}",
        "**Dataset**: AI Hub #71631 (감정이 태깅된 자유대화 - 성인)",
        f"**Evaluation Classes**: {', '.join(EVAL_LABELS)} (6-class, no disgust)",
        "",
        "---",
        "",
        "## Summary Comparison",
        "",
    ]

    # Summary table
    headers = ["Model", "Params", "Clean F1", "Phone F1", "Latency (mean)",
               "Latency (p95)", "RAM", "Knockout"]
    rows = []
    for name, res in all_results.items():
        if "error" in res:
            rows.append(f"| {name} | {res.get('params_m', '?')}M | LOAD FAILED | - | - | - | - | FAIL |")
            continue
        clean = res.get("clean", {})
        phone = res.get("phone", {})
        ko = knockout_check(res)
        clean_ko = ko.get("clean", {}).get("overall", "N/A")

        # Format the phone F1 like the clean F1 instead of printing it raw.
        phone_f1 = f"{phone['macro_f1']:.3f}" if "macro_f1" in phone else "N/A"

        rows.append(
            f"| {name} | {res['params_m']}M "
            f"| {clean.get('macro_f1', 0):.3f} "
            f"| {phone_f1} "
            f"| {clean.get('latency', {}).get('mean_ms', 0):.0f}ms "
            f"| {clean.get('latency', {}).get('p95_ms', 0):.0f}ms "
            f"| {clean.get('peak_ram_mb', 0):.0f}MB "
            f"| {clean_ko} |"
        )

    lines.append(f"| {' | '.join(headers)} |")
    lines.append(f"| {'---|' * len(headers)}")
    lines.extend(rows)
    lines.append("")

    # Knockout details
    lines.extend(["", "## Knockout Check", ""])
    for name, res in all_results.items():
        if "error" in res:
            continue
        ko = knockout_check(res)
        lines.append(f"### {name}")
        for cond, checks in ko.items():
            lines.append(f"**{cond}**: {checks['overall']}")
            lines.append(f" - F1: {checks['korean_f1']}")
            lines.append(f" - Latency: {checks['latency']}")
            lines.append(f" - RAM: {checks['ram']}")
            lines.append("")

    # Per-model details with confusion matrix
    lines.extend(["## Per-Model Details", ""])
    for name, res in all_results.items():
        if "error" in res:
            continue
        lines.append(f"### {name}")

        for condition in ("clean", "phone"):
            if condition not in res:
                continue
            cond = res[condition]
            lines.extend([
                "",
                f"#### {condition.title()} Condition",
                "",
                f"- Accuracy: {cond['accuracy']:.3f}",
                f"- Macro F1: {cond['macro_f1']:.3f}",
                f"- Weighted F1: {cond['weighted_f1']:.3f}",
                "",
                "**Per-class F1:**",
                "",
                "| Emotion | Precision | Recall | F1 | Support |",
                "|---|---|---|---|---|",
            ])
            for label in EVAL_LABELS:
                pc = cond["per_class"].get(label, {})
                lines.append(
                    f"| {label} | {pc.get('precision', 0):.3f} "
                    f"| {pc.get('recall', 0):.3f} "
                    f"| {pc.get('f1', 0):.3f} "
                    f"| {pc.get('support', 0)} |"
                )

            # Confusion matrix
            lines.extend(["", "**Confusion Matrix:**", ""])
            cm = cond.get("confusion_matrix", [])
            if cm:
                lines.append("| | " + " | ".join(EVAL_LABELS) + " |")
                lines.append("| --- | " + " | ".join(["---"] * len(EVAL_LABELS)) + " |")
                for i, row in enumerate(cm):
                    lines.append(f"| **{EVAL_LABELS[i]}** | " + " | ".join(str(v) for v in row) + " |")
            lines.append("")

    # Limitations
    lines.extend([
        "## Known Limitations",
        "",
        "- **SpeechBrain wav2vec2-IEMOCAP**: Only outputs 4 classes (angry, happy, sad, neutral). "
        "Cannot predict fear or surprise → structurally penalized in 6-class macro F1.",
        "- **Whisper-Medium + Head**: Requires a separately trained classifier head. "
        "Without training, results reflect random baseline (~16.7%).",
        "- **AI Hub dataset**: No 'disgust' class → evaluated as 6-class instead of project's 7-class.",
        "",
    ])

    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))
    logger.info("Markdown report saved to %s", output_path)


# ──────────────────────────────────────────────
# Main
# ──────────────────────────────────────────────

# CLI model name → adapter class
ADAPTER_MAP = {
    "emotion2vec": Emotion2vecAdapter,
    "speechbrain": SpeechBrainAdapter,
    "whisper": WhisperMediumAdapter,
}


def main():
    """Parse CLI arguments, benchmark the models, and write JSON and markdown."""
    parser = argparse.ArgumentParser(description="3-Model SER Benchmark")
    parser.add_argument("--test-dir", required=True,
                        help="테스트 서브셋 디렉토리 (test_labels.csv 포함)")
    parser.add_argument("--models", nargs="+",
                        default=["emotion2vec", "speechbrain"],
                        choices=list(ADAPTER_MAP.keys()),
                        help="벤치마크할 모델 (default: emotion2vec speechbrain)")
    parser.add_argument("--device", default="cpu", choices=["cpu", "cuda"],
                        help="Compute device")
    parser.add_argument("--whisper-head-ckpt", default=None,
                        help="Whisper emotion head 체크포인트 경로")
    parser.add_argument("--phone-augment", action="store_true", default=False,
                        help="Phone augmentation 평가 추가")
    parser.add_argument("--warmup", type=int, default=5,
                        help="Warmup 횟수")
    parser.add_argument("--max-samples", type=int, default=None,
                        help="최대 샘플 수 (smoke test용)")
    parser.add_argument("--output-json",
                        default="data/evaluation/benchmark_3model_results.json")
    parser.add_argument("--output-md",
                        default="docs/stage2/benchmark-3model-report.md")
    args = parser.parse_args()

    # Load test data
    samples = load_test_data(args.test_dir, max_samples=args.max_samples)
    if not samples:
        logger.error("No test samples loaded")
        sys.exit(1)

    # Run benchmarks
    all_results = {}
    for model_name in args.models:
        adapter_cls = ADAPTER_MAP[model_name]
        if model_name == "whisper":
            adapter = adapter_cls(head_ckpt=args.whisper_head_ckpt)
        else:
            adapter = adapter_cls()

        result = benchmark_model(
            adapter, samples, args.device,
            phone_augment=args.phone_augment,
            warmup=args.warmup,
        )
        result["knockout"] = knockout_check(result)
        all_results[adapter.name] = result

    # Save JSON
    output_json_path = Path(args.output_json)
    output_json_path.parent.mkdir(parents=True, exist_ok=True)

    import platform
    try:
        import torch
        torch_version = torch.__version__
        cuda_available = torch.cuda.is_available()
    except ImportError:
        torch_version = "not installed"
        cuda_available = False

    output_data = {
        "metadata": {
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
            "device": args.device,
            "test_samples": len(samples),
            "eval_classes": EVAL_LABELS,
            "conditions": ["clean"] + (["phone"] if args.phone_augment else []),
            "system_info": {
                "cpu": platform.processor() or "unknown",
                "ram_total_gb": round(psutil.virtual_memory().total / (1024**3), 1),
                "python": platform.python_version(),
                "torch": torch_version,
                "cuda": cuda_available,
            },
        },
        "results": all_results,
    }

    with open(output_json_path, "w", encoding="utf-8") as f:
        json.dump(output_data, f, indent=2, ensure_ascii=False, default=str)
    logger.info("JSON results saved to %s", output_json_path)

    # Generate markdown report
    generate_markdown_report(all_results, args.output_md)

    # Console summary
    print("\n" + "=" * 60)
    print("BENCHMARK COMPLETE")
    print("=" * 60)
    for name, res in all_results.items():
        if "error" in res:
            print(f"\n {name}: LOAD FAILED — {res['error']}")
            continue
        clean = res.get("clean", {})
        ko = res.get("knockout", {}).get("clean", {})
        print(f"\n {name} ({res['params_m']}M params):")
        print(f" Clean F1: {clean.get('macro_f1', 0):.3f} Accuracy: {clean.get('accuracy', 0):.3f}")
        print(f" Latency: {clean.get('latency', {}).get('mean_ms', 0):.0f}ms (mean), "
              f"{clean.get('latency', {}).get('p95_ms', 0):.0f}ms (p95)")
        print(f" RAM: {clean.get('peak_ram_mb', 0):.0f}MB")
        print(f" Knockout: {ko.get('overall', 'N/A')}")

    if args.phone_augment:
        print("\n --- Phone Degradation ---")
        for name, res in all_results.items():
            if "error" in res or "phone" not in res:
                continue
            clean_f1 = res.get("clean", {}).get("macro_f1", 0)
            phone_f1 = res["phone"]["macro_f1"]
            drop = clean_f1 - phone_f1
            print(f" {name}: {clean_f1:.3f} → {phone_f1:.3f} (Δ={drop:+.3f})")

    print(f"\n Results: {args.output_json}")
    print(f" Report: {args.output_md}")


if __name__ == "__main__":
    main()