"""Evaluate the emotion2vec base model on English speech (RAVDESS dataset).

The evaluation is run under two conditions, clean and phone, to estimate
performance in a real phone-call environment.

Usage:
    python scripts/evaluate_emotion2vec_english.py
    python scripts/evaluate_emotion2vec_english.py --condition clean  # clean only
    python scripts/evaluate_emotion2vec_english.py --condition phone  # phone only
    python scripts/evaluate_emotion2vec_english.py --max-samples 100  # quick test
"""
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import csv |
| import json |
| import logging |
| import sys |
| import time |
| from pathlib import Path |
|
|
# Two directory levels above this file; put first on sys.path so that the
# project-local ``src`` package imported later in this script can be resolved.
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))

# Script-wide logging: INFO and above, timestamped.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("eval_emotion2vec_en")

# Input manifest and output report locations, both under <PROJECT_ROOT>/data.
MANIFEST_PATH = PROJECT_ROOT.joinpath("data", "ravdess", "manifest.csv")
OUTPUT_JSON = PROJECT_ROOT.joinpath("data", "ravdess_eval_results.json")

# Label set (and ordering) used for the per-class report and confusion matrix.
PROJECT_LABELS = [
    "neutral",
    "joy",
    "sadness",
    "anger",
    "surprise",
    "fear",
    "disgust",
]
|
|
|
|
def load_manifest(max_samples: int | None = None) -> list[dict]:
    """Load the rows of the manifest CSV at ``MANIFEST_PATH``.

    Args:
        max_samples: Keep only the first ``max_samples`` rows. ``None`` and
            ``0`` both mean "no limit" (the value is tested for truthiness).

    Returns:
        One dict per CSV row, keyed by the column names in the header line.

    Raises:
        OSError: If the manifest file cannot be opened.
    """
    # newline="" is what the csv module requires of file objects so that
    # quoted fields containing line breaks are parsed as a single field.
    # The encoding is left at the platform default; the script that writes
    # the manifest is not visible here, so its encoding is not assumed.
    with open(MANIFEST_PATH, newline="") as f:
        rows = list(csv.DictReader(f))
    if max_samples:
        rows = rows[:max_samples]
    logger.info(f"manifest ๋ก๋: {len(rows)}๊ฐ ์ํ")
    return rows
|
|
|
|
def evaluate_condition(
    samples: list[dict],
    condition: str,
    device: str,
) -> dict:
    """Run the audio emotion predictor over one condition and score it.

    Every sample whose audio file exists is passed to
    ``src.stage2.audio_emotion.predict``; the ``"emotion"`` entry of the
    prediction is compared with the sample's ``"emotion"`` column.

    Args:
        samples: Manifest rows. Each row must contain ``"emotion"`` and the
            path column for the chosen condition.
        condition: ``"clean"`` reads ``clean_path``; any other value reads
            ``phone_path``.
        device: Forwarded unchanged as the predictor's ``device`` argument.

    Returns:
        A dict with the keys ``condition``, ``total_samples`` (number of
        samples actually scored), ``errors`` (samples skipped because the
        path was empty or the file was missing), ``accuracy``, ``macro_f1``,
        ``weighted_f1``, ``per_class``, ``confusion_matrix`` (nested lists,
        rows = true label, columns = predicted label), ``confusion_labels``
        and ``avg_latency_ms``. When no sample could be scored, all metrics
        are reported as zero instead of being computed on empty input.
    """
    # Imported lazily so the heavy dependencies are only loaded when an
    # evaluation actually runs.
    from sklearn.metrics import (
        accuracy_score,
        classification_report,
        confusion_matrix,
    )
    from src.stage2.audio_emotion import predict as audio_predict

    path_key = "clean_path" if condition == "clean" else "phone_path"

    y_true: list[str] = []
    y_pred: list[str] = []
    latencies: list[float] = []
    errors = 0

    total = len(samples)
    for i, sample in enumerate(samples, 1):
        audio_path = sample[path_key]
        if not audio_path or not Path(audio_path).exists():
            errors += 1
            continue

        ground_truth = sample["emotion"]

        # Wall-clock time of a single prediction, in milliseconds.
        t0 = time.perf_counter()
        prediction = audio_predict(audio_path, device=device)
        latency = (time.perf_counter() - t0) * 1000

        y_true.append(ground_truth)
        y_pred.append(prediction["emotion"])
        latencies.append(latency)

        # Periodic progress line; only reached for scored samples, so the
        # lists below are never empty here.
        if i % 200 == 0 or i == total:
            acc_so_far = sum(1 for t, p in zip(y_true, y_pred) if t == p) / len(y_true)
            logger.info(
                f" [{condition}] {i}/{total} โ "
                f"acc={acc_so_far:.3f}, "
                f"avg_latency={sum(latencies)/len(latencies):.0f}ms"
            )

    if errors:
        logger.warning(
            "[%s] skipped %d of %d samples: audio path empty or file missing",
            condition, errors, total,
        )

    if not y_true:
        # Nothing was scored: accuracy over zero samples is undefined and
        # would leak a non-finite value into the JSON report, so return an
        # explicit all-zero result with the same shape instead.
        logger.warning("[%s] no samples could be evaluated", condition)
        n_labels = len(PROJECT_LABELS)
        return {
            "condition": condition,
            "total_samples": 0,
            "errors": errors,
            "accuracy": 0.0,
            "macro_f1": 0.0,
            "weighted_f1": 0.0,
            "per_class": {
                label: {"precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 0}
                for label in PROJECT_LABELS
            },
            "confusion_matrix": [[0] * n_labels for _ in range(n_labels)],
            "confusion_labels": PROJECT_LABELS,
            "avg_latency_ms": 0,
        }

    accuracy = accuracy_score(y_true, y_pred)
    report = classification_report(
        y_true, y_pred, labels=PROJECT_LABELS, output_dict=True, zero_division=0,
    )
    cm = confusion_matrix(y_true, y_pred, labels=PROJECT_LABELS)

    # Per-label precision / recall / F1 / support, rounded for the report.
    per_class = {}
    for label in PROJECT_LABELS:
        if label in report:
            per_class[label] = {
                "precision": round(report[label]["precision"], 4),
                "recall": round(report[label]["recall"], 4),
                "f1": round(report[label]["f1-score"], 4),
                "support": report[label]["support"],
            }

    result = {
        "condition": condition,
        "total_samples": len(y_true),
        "errors": errors,
        "accuracy": round(accuracy, 4),
        "macro_f1": round(report["macro avg"]["f1-score"], 4),
        "weighted_f1": round(report["weighted avg"]["f1-score"], 4),
        "per_class": per_class,
        "confusion_matrix": cm.tolist(),
        "confusion_labels": PROJECT_LABELS,
        "avg_latency_ms": round(sum(latencies) / len(latencies), 1) if latencies else 0,
    }

    # Human-readable summary of the same numbers.
    logger.info(f"\n{'='*60}")
    logger.info(f"[{condition.upper()}] ๊ฒฐ๊ณผ:")
    logger.info(f" Accuracy: {accuracy:.4f}")
    logger.info(f" Macro F1: {report['macro avg']['f1-score']:.4f}")
    logger.info(f" Weighted F1: {report['weighted avg']['f1-score']:.4f}")
    logger.info(f" Avg Latency: {result['avg_latency_ms']:.0f}ms")
    logger.info("\nPer-class F1:")
    for label in PROJECT_LABELS:
        if label in per_class:
            logger.info(f" {label:10s}: F1={per_class[label]['f1']:.3f} "
                        f"(P={per_class[label]['precision']:.3f} R={per_class[label]['recall']:.3f}) "
                        f"n={per_class[label]['support']}")
    logger.info("\nConfusion Matrix (rows=true, cols=pred):")
    # Column headers are the first four characters of each label.
    logger.info(f" {'':10s} " + " ".join(f"{name[:4]:>6s}" for name in PROJECT_LABELS))
    for i_row, label in enumerate(PROJECT_LABELS):
        row_str = " ".join(f"{v:6d}" for v in cm[i_row])
        logger.info(f" {label:10s} {row_str}")
    logger.info(f"{'='*60}\n")

    return result
|
|
|
|
def main() -> None:
    """Command-line entry point.

    Parses ``--condition`` (``clean``, ``phone`` or ``both``), ``--device``
    and ``--max-samples``, evaluates each requested condition with
    ``evaluate_condition``, writes the list of result dicts to
    ``OUTPUT_JSON`` and, when both conditions were run, logs the accuracy
    difference between them.

    Exits with status 1 if the manifest file does not exist.
    """
    parser = argparse.ArgumentParser(description="emotion2vec ์์ด ํ๊ฐ (RAVDESS)")
    parser.add_argument("--condition", choices=["clean", "phone", "both"], default="both")
    parser.add_argument("--device", default="cpu")
    parser.add_argument("--max-samples", type=int, default=None, help="ํ๊ฐํ ์ต๋ ์ํ ์")
    args = parser.parse_args()

    if not MANIFEST_PATH.exists():
        logger.error("manifest.csv๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค. ๋จผ์ prepare_ravdess.py๋ฅผ ์คํํ์ธ์.")
        sys.exit(1)

    samples = load_manifest(args.max_samples)

    # "both" expands to clean first, then phone.
    conditions = ["clean", "phone"] if args.condition == "both" else [args.condition]

    results = []
    for cond in conditions:
        logger.info(f"\n{'#'*60}")
        logger.info(f"ํ๊ฐ ์์: {cond.upper()} ์กฐ๊ฑด")
        logger.info(f"{'#'*60}")
        results.append(evaluate_condition(samples, cond, args.device))

    # ensure_ascii=False emits non-ASCII text verbatim, so the file encoding
    # is pinned to UTF-8 rather than left to the platform default.
    with open(OUTPUT_JSON, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    logger.info(f"๊ฒฐ๊ณผ ์ ์ฅ: {OUTPUT_JSON}")

    # Compare the two conditions by name rather than by list position.
    by_condition = {entry["condition"]: entry for entry in results}
    if "clean" in by_condition and "phone" in by_condition:
        clean_acc = by_condition["clean"]["accuracy"]
        phone_acc = by_condition["phone"]["accuracy"]
        degradation = clean_acc - phone_acc
        logger.info(f"\n{'='*60}")
        logger.info("Clean vs Phone ๋น๊ต:")
        logger.info(f" Clean accuracy: {clean_acc:.4f}")
        logger.info(f" Phone accuracy: {phone_acc:.4f}")
        logger.info(f" Degradation: {degradation:+.4f}")
        logger.info(f"{'='*60}")
|
|
|
|
# Run the evaluation only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()
|
|