ustwo-api / scripts /benchmark_ser_models.py
asdfasdfqrqwer's picture
Deploy from GitHub 2026-04-23T03:56:31Z
c857b85
Raw
History Blame Contribute Delete
30.6 kB
#!/usr/bin/env python3
"""3-Model SER Benchmark β€” emotion2vec vs SpeechBrain vs Whisper+Head.
Uses a test subset of the AI Hub Korean emotion dataset to objectively compare
the accuracy, latency, and memory usage of the three models.
Usage:
# First, only two models (without the Whisper head)
python scripts/benchmark_ser_models.py \\
--test-dir data/evaluation/korean \\
--models emotion2vec speechbrain
# All three models
python scripts/benchmark_ser_models.py \\
--test-dir data/evaluation/korean \\
--models emotion2vec speechbrain whisper \\
--whisper-head-ckpt data/models/whisper_emotion_head.pt
# Quick smoke test
python scripts/benchmark_ser_models.py \\
--test-dir data/evaluation/korean \\
--models emotion2vec --max-samples 10
"""
from __future__ import annotations
import argparse
import csv
import gc
import json
import logging
import os
import statistics
import sys
import tempfile
import time
from abc import ABC, abstractmethod
from pathlib import Path
import numpy as np
import psutil
import soundfile as sf
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# ──────────────────────────────────────────────
# Constants
# ──────────────────────────────────────────────
# Labels scored by the benchmark: the project taxonomy without "disgust",
# which the AI Hub evaluation set does not contain.
EVAL_LABELS = [
    "neutral",
    "joy",
    "sadness",
    "anger",
    "surprise",
    "fear",
]
# Knockout criteria (from evaluation-framework.md): minimum macro F1, and the
# maximum mean latency (ms) and peak RAM (MB) a model may use.
KNOCKOUT_F1 = 0.7
KNOCKOUT_LATENCY_MS = 500
KNOCKOUT_RAM_MB = 2 * 1024
# ──────────────────────────────────────────────
# Model Adapter Interface
# ──────────────────────────────────────────────
class SERModelAdapter(ABC):
    """Common interface implemented by every SER model adapter in this benchmark.

    Concrete adapters set three class attributes that the benchmark reports:
    ``name`` (short identifier), ``model_id`` (model id passed to the model
    loader) and ``params_m`` (parameter count in millions).
    """

    name: str
    model_id: str
    params_m: int  # millions

    @abstractmethod
    def load(self, device: str) -> None:
        """Load the model onto ``device`` (the CLI offers "cpu" or "cuda")."""

    @abstractmethod
    def predict(self, audio_path: str) -> dict[str, float]:
        """Return {emotion_label: score} in project taxonomy."""

    @abstractmethod
    def unload(self) -> None:
        """Drop the loaded model so its memory can be reclaimed."""
# ──────────────────────────────────────────────
# Adapter 1: emotion2vec_plus_base
# ──────────────────────────────────────────────
class Emotion2vecAdapter(SERModelAdapter):
    """Adapter for FunASR's ``emotion2vec_plus_base`` utterance-level SER model.

    Native labels are mapped onto EVAL_LABELS. Scores of mapped labels outside
    EVAL_LABELS (e.g. "disgust") are dropped, and the rest are renormalized to
    sum to 1.
    """

    name = "emotion2vec_plus_base"
    model_id = "iic/emotion2vec_plus_base"
    params_m = 90

    # emotion2vec 9-class -> project 7-class (from src/stage2/audio_emotion.py).
    # Native labels are either English or bilingual "<Chinese>/<English>".
    # NOTE(review): the bilingual keys below look mis-encoded in this copy;
    # _map_label() falls back to the English suffix, so lookups still work.
    LABEL_MAP = {
        "angry": "anger", "disgusted": "disgust", "fearful": "fear",
        "happy": "joy", "neutral": "neutral", "sad": "sadness",
        "surprised": "surprise", "other": "neutral", "unknown": "neutral",
        "η”Ÿζ°”/angry": "anger", "厌恢/disgusted": "disgust",
        "恐惧/fearful": "fear", "εΌ€εΏƒ/happy": "joy",
        "δΈ­η«‹/neutral": "neutral", "ιšΎθΏ‡/sad": "sadness",
        "εƒζƒŠ/surprised": "surprise", "ε…Άδ»–/other": "neutral", "<unk>": "neutral",
    }

    def __init__(self):
        self._model = None

    def load(self, device: str) -> None:
        """Load the FunASR model (Hugging Face hub) onto ``device``."""
        from funasr import AutoModel

        self._model = AutoModel(model=self.model_id, device=device, hub="hf")

    @classmethod
    def _map_label(cls, native_label) -> str:
        """Map one native emotion2vec label to the project taxonomy.

        Exact LABEL_MAP hits win. Otherwise the English part after the last
        "/" (lower-cased) is tried, so a bilingual label whose Chinese prefix
        differs from the key cannot silently collapse to "neutral". Labels
        that still do not match map to "neutral".
        """
        mapped = cls.LABEL_MAP.get(native_label)
        if mapped is None:
            english = str(native_label).rsplit("/", 1)[-1].strip().lower()
            mapped = cls.LABEL_MAP.get(english, "neutral")
        return mapped

    def predict(self, audio_path: str) -> dict[str, float]:
        """Return scores over EVAL_LABELS for one file, normalized to sum to 1.

        All scores stay 0.0 if the model returns no usable output.

        Raises:
            RuntimeError: if called before ``load()``.
        """
        if self._model is None:
            raise RuntimeError(f"{self.name} is not loaded; call load() first")
        output = self._model.generate(
            audio_path, granularity="utterance", extract_embedding=False,
        )
        scores = dict.fromkeys(EVAL_LABELS, 0.0)
        if isinstance(output, list) and output:
            rec = output[0]
            for native_label, score in zip(rec.get("labels", []), rec.get("scores", [])):
                mapped = self._map_label(native_label)
                # Labels outside EVAL_LABELS (e.g. "disgust") are dropped here.
                if mapped in scores:
                    scores[mapped] += float(score)
        total = sum(scores.values())
        if total > 0:
            scores = {k: v / total for k, v in scores.items()}
        return scores

    def unload(self) -> None:
        """Drop the model reference so it can be garbage-collected."""
        del self._model
        self._model = None
# ──────────────────────────────────────────────
# Adapter 2: SpeechBrain wav2vec2-IEMOCAP
# ──────────────────────────────────────────────
class SpeechBrainAdapter(SERModelAdapter):
    """Adapter for SpeechBrain's wav2vec2 emotion classifier trained on IEMOCAP.

    NOTE: This model has only 4 classes and CANNOT predict fear or surprise,
    so those scores are always 0.0.
    """

    name = "speechbrain_wav2vec2"
    model_id = "speechbrain/emotion-recognition-wav2vec2-IEMOCAP"
    params_m = 314

    # SpeechBrain 4-class -> project taxonomy.
    LABEL_MAP = {
        "ang": "anger",
        "hap": "joy",
        "sad": "sadness",
        "neu": "neutral",
    }
    # Class-index order assumed when the label encoder cannot be read.
    _FALLBACK_LABEL_ORDER = ("neu", "ang", "hap", "sad")

    def __init__(self):
        self._classifier = None
        self._label_order = None  # populated from label_encoder in load()
        self._device = "cpu"

    def load(self, device: str) -> None:
        """Load the classifier onto ``device`` and read its class-index order."""
        from speechbrain.inference.classifiers import EncoderClassifier

        self._device = device
        self._classifier = EncoderClassifier.from_hparams(
            source=self.model_id,
            run_opts={"device": device},
        )
        self._classifier = self._classifier.to(device)
        try:
            le = self._classifier.hparams.label_encoder
            # lab2ind: {'neu': 0, 'ang': 1, 'hap': 2, 'sad': 3} -- invert it
            # into a list indexed by class id.
            order = [None] * len(le.lab2ind)
            for lab, idx in le.lab2ind.items():
                order[idx] = lab
            self._label_order = order
            logger.info("SpeechBrain labels: %s", self._label_order)
        except Exception as e:
            # Best effort: keep going with the known default order, but say so.
            self._label_order = list(self._FALLBACK_LABEL_ORDER)
            logger.warning(
                "Could not read SpeechBrain label encoder (%s); assuming order %s",
                e, self._label_order,
            )

    def predict(self, audio_path: str) -> dict[str, float]:
        """Return softmax scores mapped onto EVAL_LABELS for one audio file."""
        import torch
        import torchaudio

        signal, sr = torchaudio.load(audio_path)
        if sr != 16000:
            signal = torchaudio.functional.resample(signal, sr, 16000)
        if signal.shape[0] > 1:
            # Down-mix multi-channel audio to one channel.
            signal = signal.mean(dim=0, keepdim=True)
        # load() moved the modules to self._device. The input must be on the same
        # device, otherwise every prediction on CUDA fails with a device mismatch.
        signal = signal.to(self._device)
        # Use modules directly (classify_batch broken in SpeechBrain 1.1.0)
        with torch.no_grad():
            feats = self._classifier.mods.wav2vec2(signal)
            pooled = self._classifier.mods.avg_pool(feats)
            logits = self._classifier.mods.output_mlp(pooled)
        probs = torch.softmax(logits.squeeze(1), dim=-1).squeeze().cpu().tolist()
        if isinstance(probs, float):
            probs = [probs]
        scores = {label: 0.0 for label in EVAL_LABELS}
        for sb_label, prob in zip(self._label_order, probs):
            mapped = self.LABEL_MAP.get(sb_label, "neutral")
            if mapped in scores:
                scores[mapped] += prob
        return scores

    def unload(self) -> None:
        """Drop the classifier reference so it can be garbage-collected."""
        del self._classifier
        self._classifier = None
# ──────────────────────────────────────────────
# Adapter 3: Whisper-Medium + Emotion Head
# ──────────────────────────────────────────────
class WhisperMediumAdapter(SERModelAdapter):
    """Whisper-medium encoder plus a linear emotion head over mean-pooled states.

    Without a trained head checkpoint, the head keeps its random
    initialization, so results are only a random baseline.
    """

    name = "whisper_medium_head"
    model_id = "openai/whisper-medium"
    params_m = 769

    def __init__(self, head_ckpt: str | None = None):
        self._encoder = None
        self._head = None
        self._processor = None
        self._head_ckpt = head_ckpt
        self._device = "cpu"

    def load(self, device: str) -> None:
        """Load the encoder and feature extractor, then build (and maybe load) the head."""
        import torch
        from transformers import WhisperFeatureExtractor, WhisperModel

        self._device = device
        self._processor = WhisperFeatureExtractor.from_pretrained(self.model_id)
        self._encoder = WhisperModel.from_pretrained(self.model_id).to(device)
        self._encoder.eval()
        # Classifier head: hidden_dim -> len(EVAL_LABELS) classes
        hidden_dim = self._encoder.config.d_model  # 1024 for medium
        self._head = torch.nn.Linear(hidden_dim, len(EVAL_LABELS)).to(device)
        if not self._head_ckpt:
            logger.warning("No trained Whisper head β€” using random weights (baseline)")
        elif not Path(self._head_ckpt).exists():
            # A mistyped path must not look like "no checkpoint requested":
            # name the missing file explicitly.
            logger.warning(
                "Whisper head checkpoint not found: %s - using random weights (baseline)",
                self._head_ckpt,
            )
        else:
            logger.info("Loading Whisper emotion head from %s", self._head_ckpt)
            state = torch.load(self._head_ckpt, map_location=device, weights_only=True)
            self._head.load_state_dict(state)
        self._head.eval()

    def predict(self, audio_path: str) -> dict[str, float]:
        """Return softmax probabilities over EVAL_LABELS for one audio file."""
        import librosa
        import torch

        # Load and preprocess at 16 kHz (the sample rate passed to the extractor).
        audio, _ = librosa.load(audio_path, sr=16000)
        inputs = self._processor(audio, sampling_rate=16000, return_tensors="pt")
        input_features = inputs.input_features.to(self._device)
        with torch.no_grad():
            hidden = self._encoder.encoder(input_features).last_hidden_state  # (1, T, D)
            pooled = hidden.mean(dim=1)  # (1, D)
            logits = self._head(pooled)  # (1, len(EVAL_LABELS))
        probs = torch.softmax(logits, dim=-1).squeeze().cpu().tolist()
        return dict(zip(EVAL_LABELS, probs))

    def unload(self) -> None:
        """Drop encoder, head and feature extractor references."""
        del self._encoder, self._head, self._processor
        self._encoder = self._head = self._processor = None
# ──────────────────────────────────────────────
# Phone Augmentation
# ──────────────────────────────────────────────
def _ensure_src_on_path() -> str:
    """Make the project's ``src`` directory importable (idempotent); return it."""
    src_dir = str(Path(__file__).resolve().parent.parent / "src")
    # Insert only once: this runs for every augmented sample, and an
    # unconditional insert would grow sys.path without bound.
    if src_dir not in sys.path:
        sys.path.insert(0, src_dir)
    return src_dir


def apply_phone_augmentation(audio_path: str) -> str:
    """Apply phone-quality degradation and return the path to a temp WAV file.

    Multi-channel input is averaged to mono before degradation. The degraded
    audio is written as 16-bit PCM. The caller owns the returned file and must
    delete it.
    """
    _ensure_src_on_path()
    from common.phone_simulator import PhoneSimulator, CompandingType

    audio, sr = sf.read(audio_path, dtype="float32")
    if audio.ndim == 2:
        audio = audio.mean(axis=1)
    sim = PhoneSimulator(companding=CompandingType.ALAW)
    degraded, new_sr = sim.process(audio, sr)
    # mkstemp returns an OS-level handle; close it right away so no descriptor
    # stays open, and so sf.write can reopen the path (this fails on Windows
    # while another handle is open).
    fd, tmp_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        sf.write(tmp_path, degraded, new_sr, subtype="PCM_16")
    except BaseException:
        os.unlink(tmp_path)  # don't leak the temp file when the write fails
        raise
    return tmp_path
# ──────────────────────────────────────────────
# Test Data Loading
# ──────────────────────────────────────────────
def load_test_data(test_dir: str, max_samples: int | None = None) -> list[dict]:
    """Load test samples listed in ``<test_dir>/test_labels.csv``.

    Each row must provide ``file_path`` (relative to ``test_dir``), ``emotion``
    and ``duration``; ``speaker_id`` and ``intensity`` default to "". Rows
    whose audio file is missing are skipped with a warning.

    Args:
        test_dir: directory containing ``test_labels.csv`` and the audio files.
        max_samples: if a positive number smaller than the number of usable
            rows, return a reproducible random subset (seed 42) of that size.
            ``None`` or non-positive values keep every sample.

    Exits the process with status 1 if the CSV does not exist.
    """
    import random

    csv_path = Path(test_dir) / "test_labels.csv"
    if not csv_path.exists():
        logger.error("test_labels.csv not found in %s", test_dir)
        sys.exit(1)
    samples = []
    # "utf-8-sig" also strips a leading BOM (common in spreadsheet exports),
    # which would otherwise corrupt the first column name. newline="" is the
    # mode the csv module requires for correct newline/quote handling.
    with open(csv_path, encoding="utf-8-sig", newline="") as f:
        for row in csv.DictReader(f):
            audio_path = str(Path(test_dir) / row["file_path"])
            if not Path(audio_path).exists():
                logger.warning("Audio file not found: %s", audio_path)
                continue
            samples.append({
                "audio_path": audio_path,
                "emotion": row["emotion"],
                "duration": float(row["duration"]),
                "speaker_id": row.get("speaker_id", ""),
                "intensity": row.get("intensity", ""),
            })
    if max_samples is not None and 0 < max_samples < len(samples):
        # A private RNG seeded with 42 draws the same subset as seeding the
        # global RNG would, without resetting process-wide random state.
        samples = random.Random(42).sample(samples, max_samples)
    logger.info("Loaded %d test samples from %s", len(samples), test_dir)
    return samples
# ──────────────────────────────────────────────
# Benchmark Runner
# ──────────────────────────────────────────────
def get_process_rss_mb() -> float:
    """Return the current process's resident set size (RSS) in MiB."""
    rss_bytes = psutil.Process().memory_info().rss
    return rss_bytes / 1024 ** 2
def _run_warmup(
    adapter: SERModelAdapter,
    samples: list[dict],
    condition: str,
    warmup: int,
) -> None:
    """Run untimed predictions on the first ``warmup`` samples (best effort)."""
    for sample in samples[:max(warmup, 0)]:
        tmp_path = None
        try:
            audio_path = sample["audio_path"]
            if condition == "phone":
                tmp_path = apply_phone_augmentation(audio_path)
                audio_path = tmp_path
            adapter.predict(audio_path)
        except Exception as e:
            # Warm-up failures are not fatal; the timed pass records real errors.
            logger.debug("Warmup prediction failed (%s): %s", condition, e)
        finally:
            # Remove the degraded copy even when predict() raised.
            if tmp_path and os.path.exists(tmp_path):
                os.unlink(tmp_path)


def _run_condition(
    adapter: SERModelAdapter,
    samples: list[dict],
    condition: str,
    baseline_rss: float,
) -> dict:
    """Time one prediction per sample under ``condition`` and return its metrics."""
    y_true: list[str] = []
    y_pred: list[str] = []
    latencies: list[float] = []
    errors: list[dict] = []
    peak_rss = get_process_rss_mb()
    for i, sample in enumerate(samples):
        audio_path = sample["audio_path"]
        tmp_path = None
        try:
            if condition == "phone":
                # Augmentation happens outside the timed region.
                tmp_path = apply_phone_augmentation(audio_path)
                audio_path = tmp_path
            t0 = time.perf_counter()
            scores = adapter.predict(audio_path)
            latency_ms = (time.perf_counter() - t0) * 1000
            pred_label = max(scores, key=scores.get)
            true_label = sample["emotion"]
        except Exception as e:
            errors.append({"index": i, "error": str(e)})
            logger.warning("Error on sample %d: %s", i, e)
        else:
            # Record latency and labels together so the three lists stay aligned.
            latencies.append(latency_ms)
            y_true.append(true_label)
            y_pred.append(pred_label)
        finally:
            if tmp_path and os.path.exists(tmp_path):
                os.unlink(tmp_path)
        peak_rss = max(peak_rss, get_process_rss_mb())
        if (i + 1) % 50 == 0:
            logger.info(" %d/%d done (mean lat: %.0fms)", i + 1, len(samples),
                        statistics.mean(latencies) if latencies else 0)
    return compute_metrics(y_true, y_pred, latencies, peak_rss - baseline_rss, errors)


def benchmark_model(
    adapter: SERModelAdapter,
    samples: list[dict],
    device: str,
    phone_augment: bool,
    warmup: int = 5,
) -> dict:
    """Benchmark one adapter on clean audio and, optionally, phone-degraded audio.

    Args:
        adapter: model adapter to load, run and unload.
        samples: test samples as produced by ``load_test_data``.
        device: device passed to ``adapter.load``.
        phone_augment: also evaluate the "phone" condition when True.
        warmup: number of leading samples predicted untimed before each
            condition (non-positive disables warm-up).

    Returns:
        A dict with model metadata, ``load_time_sec``, ``model_ram_mb`` (RSS
        growth after loading) and one metrics dict per evaluated condition.
        If loading fails, it contains an ``"error"`` key instead of metrics.
        Once loaded, the model is unloaded before returning, even if a
        condition raises.
    """
    logger.info("=" * 60)
    logger.info("Benchmarking: %s (%dM params)", adapter.name, adapter.params_m)
    logger.info("=" * 60)
    result = {
        "model": adapter.name,
        "model_id": adapter.model_id,
        "params_m": adapter.params_m,
        "device": device,
    }
    # Baseline RAM before the model is loaded.
    gc.collect()
    baseline_rss = get_process_rss_mb()
    logger.info("Loading model...")
    load_start = time.perf_counter()
    try:
        adapter.load(device)
    except Exception as e:
        logger.error("Failed to load %s: %s", adapter.name, e)
        result["error"] = str(e)
        return result
    load_time = time.perf_counter() - load_start
    result["load_time_sec"] = round(load_time, 2)
    result["model_ram_mb"] = round(get_process_rss_mb() - baseline_rss, 1)
    logger.info("Loaded in %.1fs, RAM: %.0fMB", load_time, result["model_ram_mb"])

    conditions = ["clean", "phone"] if phone_augment else ["clean"]
    try:
        for condition in conditions:
            logger.info("--- Condition: %s ---", condition)
            _run_warmup(adapter, samples, condition, warmup)
            cond_result = _run_condition(adapter, samples, condition, baseline_rss)
            result[condition] = cond_result
            # "latency" is {} when every sample failed, so don't index it directly.
            logger.info(" %s: macro_f1=%.3f, accuracy=%.3f, mean_latency=%.0fms, peak_ram=%.0fMB",
                        condition,
                        cond_result["macro_f1"],
                        cond_result["accuracy"],
                        cond_result["latency"].get("mean_ms", 0.0),
                        cond_result["peak_ram_mb"])
    finally:
        adapter.unload()
        gc.collect()
        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except ImportError:
            pass
    return result
# ──────────────────────────────────────────────
# Metrics
# ──────────────────────────────────────────────
def compute_metrics(
    y_true: list[str],
    y_pred: list[str],
    latencies: list[float],
    peak_ram_mb: float,
    errors: list[dict],
) -> dict:
    """Compute accuracy, F1, confusion matrix and latency stats for one condition.

    Args:
        y_true: true labels of successfully predicted samples.
        y_pred: predicted labels, aligned with ``y_true``.
        latencies: per-sample prediction latencies in milliseconds.
        peak_ram_mb: peak RAM figure to report, in MB.
        errors: per-sample error records, passed through unchanged.

    Per-class, macro and weighted F1 use EVAL_LABELS as the label set. Labels
    with no predictions or support score 0 (``zero_division=0``). ``p95_ms``
    is the linearly interpolated 95th percentile. With no predictions, a
    zeroed result carrying a ``"note"`` key is returned.
    """
    n_labels = len(EVAL_LABELS)
    if not y_true or not y_pred:
        return {
            "accuracy": 0.0, "macro_f1": 0.0, "weighted_f1": 0.0,
            "per_class": {
                label: {"precision": 0, "recall": 0, "f1": 0, "support": 0}
                for label in EVAL_LABELS
            },
            # One independent list per row: ``[[0] * n] * n`` would repeat the
            # same row object n times, so mutating one row would change all.
            "confusion_matrix": [[0] * n_labels for _ in range(n_labels)],
            "confusion_labels": EVAL_LABELS,
            "latency": {}, "peak_ram_mb": round(peak_ram_mb, 1),
            "total_samples": 0, "errors": errors,
            "note": "All samples failed β€” no predictions available",
        }

    from sklearn.metrics import (
        accuracy_score,
        confusion_matrix,
        precision_recall_fscore_support,
    )

    accuracy = float(accuracy_score(y_true, y_pred))
    precision, recall, f1, support = precision_recall_fscore_support(
        y_true, y_pred, labels=EVAL_LABELS, average=None, zero_division=0,
    )
    macro_f1 = float(np.mean(f1))
    weighted_f1 = float(np.average(f1, weights=support)) if sum(support) > 0 else 0.0
    # Rows are true labels, columns predicted labels, both in EVAL_LABELS order.
    cm = confusion_matrix(y_true, y_pred, labels=EVAL_LABELS).tolist()
    per_class = {
        label: {
            "precision": round(float(p), 4),
            "recall": round(float(r), 4),
            "f1": round(float(f), 4),
            "support": int(s),
        }
        for label, p, r, f, s in zip(EVAL_LABELS, precision, recall, f1, support)
    }

    latency_stats = {}
    if latencies:
        latency_stats = {
            "mean_ms": round(statistics.mean(latencies), 1),
            "median_ms": round(statistics.median(latencies), 1),
            "std_ms": round(statistics.stdev(latencies), 1) if len(latencies) > 1 else 0,
            # Interpolated percentile. A plain sorted(...)[int(0.95 * n)] index
            # would return the maximum latency whenever n <= 20.
            "p95_ms": round(float(np.percentile(latencies, 95)), 1),
            "min_ms": round(min(latencies), 1),
            "max_ms": round(max(latencies), 1),
        }
    return {
        "accuracy": round(accuracy, 4),
        "macro_f1": round(macro_f1, 4),
        "weighted_f1": round(weighted_f1, 4),
        "per_class": per_class,
        "confusion_matrix": cm,
        "confusion_labels": EVAL_LABELS,
        "latency": latency_stats,
        "peak_ram_mb": round(peak_ram_mb, 1),
        "total_samples": len(y_true),
        "errors": errors,
    }
# ──────────────────────────────────────────────
# Knockout Check
# ──────────────────────────────────────────────
def knockout_check(result: dict) -> dict:
    """Check each evaluated condition of ``result`` against the knockout criteria.

    A condition passes when macro F1 >= KNOCKOUT_F1, mean latency <=
    KNOCKOUT_LATENCY_MS and peak RAM <= KNOCKOUT_RAM_MB. Missing latency data
    (every sample failed) counts as a latency FAIL and is shown as "N/A".

    Returns:
        {condition: {"korean_f1", "latency", "ram", "overall"}} with
        human-readable PASS/FAIL strings. Conditions absent from ``result``
        (e.g. after a load failure) are omitted.
    """
    checks = {}
    for condition in ("clean", "phone"):
        cond = result.get(condition)
        if cond is None:
            continue
        f1 = cond["macro_f1"]
        mean_lat = cond.get("latency", {}).get("mean_ms")
        ram = cond["peak_ram_mb"]
        f1_ok = f1 >= KNOCKOUT_F1
        # Use the same value for the verdict and the message: otherwise a
        # missing latency fails the check while the message prints "0ms".
        lat_ok = mean_lat is not None and mean_lat <= KNOCKOUT_LATENCY_MS
        ram_ok = ram <= KNOCKOUT_RAM_MB
        lat_text = f"{mean_lat:.0f}ms" if mean_lat is not None else "N/A"
        checks[condition] = {
            "korean_f1": f"{'PASS' if f1_ok else 'FAIL'} ({f1:.3f} {'β‰₯' if f1_ok else '<'} {KNOCKOUT_F1})",
            "latency": f"{'PASS' if lat_ok else 'FAIL'} ({lat_text} {'≀' if lat_ok else '>'} {KNOCKOUT_LATENCY_MS}ms)",
            "ram": f"{'PASS' if ram_ok else 'FAIL'} ({ram:.0f}MB {'≀' if ram_ok else '>'} {KNOCKOUT_RAM_MB}MB)",
            "overall": "PASS" if (f1_ok and lat_ok and ram_ok) else "FAIL",
        }
    return checks
# ──────────────────────────────────────────────
# Report Generation
# ──────────────────────────────────────────────
def _md_summary_table(all_results: dict) -> list[str]:
    """Summary table: one row per model (clean/phone F1, clean latency/RAM, knockout)."""
    headers = ["Model", "Params", "Clean F1", "Phone F1", "Latency (mean)", "Latency (p95)", "RAM", "Knockout"]
    lines = [f"| {' | '.join(headers)} |", f"| {'---|' * len(headers)}"]
    for name, res in all_results.items():
        if "error" in res:
            lines.append(f"| {name} | {res.get('params_m', '?')}M | LOAD FAILED | - | - | - | - | FAIL |")
            continue
        clean = res.get("clean", {})
        phone = res.get("phone", {})
        clean_lat = clean.get("latency", {})
        clean_ko = knockout_check(res).get("clean", {}).get("overall", "N/A")
        # Phone F1 uses the same 3-decimal format as clean F1.
        phone_f1 = f"{phone['macro_f1']:.3f}" if "macro_f1" in phone else "N/A"
        lines.append(
            f"| {name} | {res['params_m']}M "
            f"| {clean.get('macro_f1', 0):.3f} "
            f"| {phone_f1} "
            f"| {clean_lat.get('mean_ms', 0):.0f}ms "
            f"| {clean_lat.get('p95_ms', 0):.0f}ms "
            f"| {clean.get('peak_ram_mb', 0):.0f}MB "
            f"| {clean_ko} |"
        )
    lines.append("")
    return lines


def _md_knockout_section(all_results: dict) -> list[str]:
    """Per-model, per-condition knockout verdicts as short bullet lists."""
    lines = ["", "## Knockout Check", ""]
    for name, res in all_results.items():
        if "error" in res:
            continue
        lines.append(f"### {name}")
        for cond, checks in knockout_check(res).items():
            lines.extend([
                f"**{cond}**: {checks['overall']}",
                f" - F1: {checks['korean_f1']}",
                f" - Latency: {checks['latency']}",
                f" - RAM: {checks['ram']}",
                # The blank line ends the list, so the next condition's label
                # is not absorbed into the last bullet.
                "",
            ])
    return lines


def _md_model_details(all_results: dict) -> list[str]:
    """Per-model metrics, per-class table and confusion matrix for each condition."""
    lines = ["## Per-Model Details", ""]
    for name, res in all_results.items():
        if "error" in res:
            continue
        lines.append(f"### {name}")
        for condition in ("clean", "phone"):
            if condition not in res:
                continue
            cond = res[condition]
            lines.extend([
                "",
                f"#### {condition.title()} Condition",
                "",
                f"- Accuracy: {cond['accuracy']:.3f}",
                f"- Macro F1: {cond['macro_f1']:.3f}",
                f"- Weighted F1: {cond['weighted_f1']:.3f}",
                "",
                "**Per-class F1:**",
                "",
                "| Emotion | Precision | Recall | F1 | Support |",
                "|---|---|---|---|---|",
            ])
            for label in EVAL_LABELS:
                pc = cond["per_class"].get(label, {})
                lines.append(
                    f"| {label} | {pc.get('precision', 0):.3f} "
                    f"| {pc.get('recall', 0):.3f} "
                    f"| {pc.get('f1', 0):.3f} "
                    f"| {pc.get('support', 0)} |"
                )
            # compute_metrics builds the matrix with sklearn: rows are true
            # labels, columns are predicted labels.
            lines.extend(["", "**Confusion Matrix:** (rows = true label, columns = predicted label)", ""])
            cm = cond.get("confusion_matrix", [])
            if cm:
                lines.append("| | " + " | ".join(EVAL_LABELS) + " |")
                lines.append("| --- | " + " | ".join(["---"] * len(EVAL_LABELS)) + " |")
                lines.extend(
                    f"| **{label}** | " + " | ".join(str(v) for v in row) + " |"
                    for label, row in zip(EVAL_LABELS, cm)
                )
            lines.append("")
    return lines


def generate_markdown_report(all_results: dict, output_path: str):
    """Write a markdown comparison report for ``all_results`` to ``output_path``.

    ``all_results`` maps model name -> result dict from ``benchmark_model``
    (plus optional keys). Parent directories of ``output_path`` are created
    if needed, and the file is written as UTF-8.
    """
    lines = [
        "# 3-Model SER Benchmark Report",
        "",
        f"**Generated**: {time.strftime('%Y-%m-%d %H:%M:%S')}",
        "**Dataset**: AI Hub #71631 (감정이 νƒœκΉ…λœ μžμœ λŒ€ν™” - 성인)",
        f"**Evaluation Classes**: {', '.join(EVAL_LABELS)} (6-class, no disgust)",
        "",
        "---",
        "",
        "## Summary Comparison",
        "",
    ]
    lines += _md_summary_table(all_results)
    lines += _md_knockout_section(all_results)
    lines += _md_model_details(all_results)
    lines += [
        "## Known Limitations",
        "",
        "- **SpeechBrain wav2vec2-IEMOCAP**: Only outputs 4 classes (angry, happy, sad, neutral). "
        "Cannot predict fear or surprise β†’ structurally penalized in 6-class macro F1.",
        "- **Whisper-Medium + Head**: Requires a separately trained classifier head. "
        "Without training, results reflect random baseline (~16.7%).",
        "- **AI Hub dataset**: No 'disgust' class β†’ evaluated as 6-class instead of project's 7-class.",
        "",
    ]
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))
    logger.info("Markdown report saved to %s", output_path)
# ──────────────────────────────────────────────
# Main
# ──────────────────────────────────────────────
# CLI model name -> adapter class.
ADAPTER_MAP = {
    "emotion2vec": Emotion2vecAdapter,
    "speechbrain": SpeechBrainAdapter,
    "whisper": WhisperMediumAdapter,
}


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the benchmark."""
    # NOTE(review): the Korean help strings below appear mis-encoded in this
    # copy; they are kept byte-for-byte. Verify against the source file.
    parser = argparse.ArgumentParser(description="3-Model SER Benchmark")
    parser.add_argument("--test-dir", required=True, help="ν…ŒμŠ€νŠΈ μ„œλΈŒμ…‹ 디렉토리 (test_labels.csv 포함)")
    parser.add_argument("--models", nargs="+", default=["emotion2vec", "speechbrain"],
                        choices=list(ADAPTER_MAP.keys()), help="λ²€μΉ˜λ§ˆν¬ν•  λͺ¨λΈ (default: emotion2vec speechbrain)")
    parser.add_argument("--device", default="cpu", choices=["cpu", "cuda"], help="Compute device")
    parser.add_argument("--whisper-head-ckpt", default=None, help="Whisper emotion head 체크포인트 경둜")
    parser.add_argument("--phone-augment", action="store_true", default=False, help="Phone augmentation 평가 μΆ”κ°€")
    parser.add_argument("--warmup", type=int, default=5, help="Warmup 횟수")
    parser.add_argument("--max-samples", type=int, default=None, help="μ΅œλŒ€ μƒ˜ν”Œ 수 (smoke test용)")
    parser.add_argument("--output-json", default="data/evaluation/benchmark_3model_results.json")
    parser.add_argument("--output-md", default="docs/stage2/benchmark-3model-report.md")
    return parser


def _collect_system_info() -> dict:
    """Collect host and runtime details recorded in the JSON metadata."""
    import platform

    try:
        import torch
        torch_version = torch.__version__
        cuda_available = torch.cuda.is_available()
    except ImportError:
        torch_version = "not installed"
        cuda_available = False
    return {
        "cpu": platform.processor() or "unknown",
        "ram_total_gb": round(psutil.virtual_memory().total / (1024**3), 1),
        "python": platform.python_version(),
        "torch": torch_version,
        "cuda": cuda_available,
    }


def _print_console_summary(all_results: dict, args: argparse.Namespace) -> None:
    """Print clean-condition results per model and, if run, phone F1 deltas."""
    print("\n" + "=" * 60)
    print("BENCHMARK COMPLETE")
    print("=" * 60)
    for name, res in all_results.items():
        if "error" in res:
            print(f"\n {name}: LOAD FAILED β€” {res['error']}")
            continue
        clean = res.get("clean", {})
        ko = res.get("knockout", {}).get("clean", {})
        print(f"\n {name} ({res['params_m']}M params):")
        print(f" Clean F1: {clean.get('macro_f1', 0):.3f} Accuracy: {clean.get('accuracy', 0):.3f}")
        print(f" Latency: {clean.get('latency', {}).get('mean_ms', 0):.0f}ms (mean), "
              f"{clean.get('latency', {}).get('p95_ms', 0):.0f}ms (p95)")
        print(f" RAM: {clean.get('peak_ram_mb', 0):.0f}MB")
        print(f" Knockout: {ko.get('overall', 'N/A')}")
    if args.phone_augment:
        print("\n --- Phone Degradation ---")
        for name, res in all_results.items():
            if "error" in res or "phone" not in res:
                continue
            clean_f1 = res.get("clean", {}).get("macro_f1", 0)
            phone_f1 = res["phone"]["macro_f1"]
            drop = clean_f1 - phone_f1
            print(f" {name}: {clean_f1:.3f} β†’ {phone_f1:.3f} (Ξ”={drop:+.3f})")
    print(f"\n Results: {args.output_json}")
    print(f" Report: {args.output_md}")


def main():
    """CLI entry point.

    Runs the selected benchmarks, then writes the JSON results and the
    markdown report and prints a console summary.
    """
    args = _build_arg_parser().parse_args()

    samples = load_test_data(args.test_dir, max_samples=args.max_samples)
    if not samples:
        logger.error("No test samples loaded")
        sys.exit(1)

    all_results = {}
    # Benchmark each model once, even if it is listed twice in --models.
    # Results are keyed by adapter name, so a repeat would only overwrite the
    # first run.
    for model_name in dict.fromkeys(args.models):
        adapter_cls = ADAPTER_MAP[model_name]
        if model_name == "whisper":
            adapter = adapter_cls(head_ckpt=args.whisper_head_ckpt)
        else:
            adapter = adapter_cls()
        result = benchmark_model(
            adapter, samples, args.device,
            phone_augment=args.phone_augment,
            warmup=args.warmup,
        )
        result["knockout"] = knockout_check(result)
        all_results[adapter.name] = result

    output_json_path = Path(args.output_json)
    output_json_path.parent.mkdir(parents=True, exist_ok=True)
    output_data = {
        "metadata": {
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
            "device": args.device,
            "test_samples": len(samples),
            "eval_classes": EVAL_LABELS,
            "conditions": ["clean"] + (["phone"] if args.phone_augment else []),
            "system_info": _collect_system_info(),
        },
        "results": all_results,
    }
    with open(output_json_path, "w", encoding="utf-8") as f:
        json.dump(output_data, f, indent=2, ensure_ascii=False, default=str)
    logger.info("JSON results saved to %s", output_json_path)

    generate_markdown_report(all_results, args.output_md)
    _print_console_summary(all_results, args)


if __name__ == "__main__":
    main()