#!/usr/bin/env python3 """emotion2vec base 모델 영어 평가 (RAVDESS 데이터셋). clean + phone 2개 조건으로 평가하여 실전 통화 환경 성능을 추정한다. Usage: python scripts/evaluate_emotion2vec_english.py python scripts/evaluate_emotion2vec_english.py --condition clean # clean만 python scripts/evaluate_emotion2vec_english.py --condition phone # phone만 python scripts/evaluate_emotion2vec_english.py --max-samples 100 # 빠른 테스트 """ from __future__ import annotations import argparse import csv import json import logging import sys import time from pathlib import Path PROJECT_ROOT = Path(__file__).parent.parent sys.path.insert(0, str(PROJECT_ROOT)) logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", ) logger = logging.getLogger("eval_emotion2vec_en") MANIFEST_PATH = PROJECT_ROOT / "data" / "ravdess" / "manifest.csv" OUTPUT_JSON = PROJECT_ROOT / "data" / "ravdess_eval_results.json" PROJECT_LABELS = ["neutral", "joy", "sadness", "anger", "surprise", "fear", "disgust"] def load_manifest(max_samples: int | None = None) -> list[dict]: """manifest.csv 로드.""" rows = [] with open(MANIFEST_PATH) as f: reader = csv.DictReader(f) for row in reader: rows.append(row) if max_samples: rows = rows[:max_samples] logger.info(f"manifest 로드: {len(rows)}개 샘플") return rows def evaluate_condition( samples: list[dict], condition: str, device: str, ) -> dict: """한 조건(clean/phone)에 대해 평가 실행.""" from sklearn.metrics import ( accuracy_score, classification_report, confusion_matrix, ) from src.stage2.audio_emotion import predict as audio_predict path_key = "clean_path" if condition == "clean" else "phone_path" y_true = [] y_pred = [] latencies = [] errors = 0 total = len(samples) for i, sample in enumerate(samples, 1): audio_path = sample[path_key] if not audio_path or not Path(audio_path).exists(): errors += 1 continue ground_truth = sample["emotion"] t0 = time.perf_counter() result = audio_predict(audio_path, device=device) latency = (time.perf_counter() - t0) * 1000 # ms y_true.append(ground_truth) y_pred.append(result["emotion"]) latencies.append(latency) if i % 200 == 0 or i == total: acc_so_far = sum(1 for t, p in zip(y_true, y_pred) if t == p) / len(y_true) logger.info( f" [{condition}] {i}/{total} — " f"acc={acc_so_far:.3f}, " f"avg_latency={sum(latencies)/len(latencies):.0f}ms" ) accuracy = accuracy_score(y_true, y_pred) report = classification_report( y_true, y_pred, labels=PROJECT_LABELS, output_dict=True, zero_division=0, ) cm = confusion_matrix(y_true, y_pred, labels=PROJECT_LABELS) # per-class metrics 정리 per_class = {} for label in PROJECT_LABELS: if label in report: per_class[label] = { "precision": round(report[label]["precision"], 4), "recall": round(report[label]["recall"], 4), "f1": round(report[label]["f1-score"], 4), "support": report[label]["support"], } result = { "condition": condition, "total_samples": len(y_true), "errors": errors, "accuracy": round(accuracy, 4), "macro_f1": round(report["macro avg"]["f1-score"], 4), "weighted_f1": round(report["weighted avg"]["f1-score"], 4), "per_class": per_class, "confusion_matrix": cm.tolist(), "confusion_labels": PROJECT_LABELS, "avg_latency_ms": round(sum(latencies) / len(latencies), 1) if latencies else 0, } logger.info(f"\n{'='*60}") logger.info(f"[{condition.upper()}] 결과:") logger.info(f" Accuracy: {accuracy:.4f}") logger.info(f" Macro F1: {report['macro avg']['f1-score']:.4f}") logger.info(f" Weighted F1: {report['weighted avg']['f1-score']:.4f}") logger.info(f" Avg Latency: {result['avg_latency_ms']:.0f}ms") logger.info(f"\nPer-class F1:") for label in PROJECT_LABELS: if label in per_class: logger.info(f" {label:10s}: F1={per_class[label]['f1']:.3f} " f"(P={per_class[label]['precision']:.3f} R={per_class[label]['recall']:.3f}) " f"n={per_class[label]['support']}") logger.info(f"\nConfusion Matrix (rows=true, cols=pred):") logger.info(f" {'':10s} " + " ".join(f"{l[:4]:>6s}" for l in PROJECT_LABELS)) for i_row, label in enumerate(PROJECT_LABELS): row_str = " ".join(f"{v:6d}" for v in cm[i_row]) logger.info(f" {label:10s} {row_str}") logger.info(f"{'='*60}\n") return result def main(): parser = argparse.ArgumentParser(description="emotion2vec 영어 평가 (RAVDESS)") parser.add_argument("--condition", choices=["clean", "phone", "both"], default="both") parser.add_argument("--device", default="cpu") parser.add_argument("--max-samples", type=int, default=None, help="평가할 최대 샘플 수") args = parser.parse_args() if not MANIFEST_PATH.exists(): logger.error(f"manifest.csv를 찾을 수 없습니다. 먼저 prepare_ravdess.py를 실행하세요.") sys.exit(1) samples = load_manifest(args.max_samples) conditions = [] if args.condition in ("clean", "both"): conditions.append("clean") if args.condition in ("phone", "both"): conditions.append("phone") results = [] for cond in conditions: logger.info(f"\n{'#'*60}") logger.info(f"평가 시작: {cond.upper()} 조건") logger.info(f"{'#'*60}") result = evaluate_condition(samples, cond, args.device) results.append(result) # 결과 저장 with open(OUTPUT_JSON, "w") as f: json.dump(results, f, indent=2, ensure_ascii=False) logger.info(f"결과 저장: {OUTPUT_JSON}") # clean vs phone 비교 (both일 때) if len(results) == 2: clean_acc = results[0]["accuracy"] phone_acc = results[1]["accuracy"] degradation = clean_acc - phone_acc logger.info(f"\n{'='*60}") logger.info(f"Clean vs Phone 비교:") logger.info(f" Clean accuracy: {clean_acc:.4f}") logger.info(f" Phone accuracy: {phone_acc:.4f}") logger.info(f" Degradation: {degradation:+.4f}") logger.info(f"{'='*60}") if __name__ == "__main__": main()