ustwo-api / scripts /evaluate_emotion2vec_english.py
asdfasdfqrqwer's picture
Deploy from GitHub 2026-04-23T03:56:31Z
c857b85
Raw
History Blame Contribute Delete
6.7 kB
#!/usr/bin/env python3
"""Evaluate the emotion2vec base model on English speech (RAVDESS dataset).

The evaluation is run under two conditions, clean and phone, to estimate
performance in a real phone-call environment.

Usage:
    python scripts/evaluate_emotion2vec_english.py
    python scripts/evaluate_emotion2vec_english.py --condition clean  # clean only
    python scripts/evaluate_emotion2vec_english.py --condition phone  # phone only
    python scripts/evaluate_emotion2vec_english.py --max-samples 100  # quick test
"""
from __future__ import annotations
import argparse
import csv
import json
import logging
import sys
import time
from pathlib import Path
# Two directory levels above this file; the usage examples invoke the script
# as scripts/<name>.py, so this is expected to be the repository root.
PROJECT_ROOT = Path(__file__).parent.parent

# Put that directory first on the import path so the `src.*` package imported
# later in this script can be resolved.
sys.path.insert(0, str(PROJECT_ROOT))

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("eval_emotion2vec_en")

# Input manifest (one CSV row per sample) and the JSON file results go to.
MANIFEST_PATH = PROJECT_ROOT.joinpath("data", "ravdess", "manifest.csv")
OUTPUT_JSON = PROJECT_ROOT.joinpath("data", "ravdess_eval_results.json")

# Label set, in the fixed order used for per-class metrics and for the
# rows/columns of the confusion matrix.
PROJECT_LABELS = [
    "neutral",
    "joy",
    "sadness",
    "anger",
    "surprise",
    "fear",
    "disgust",
]
def load_manifest(max_samples: int | None = None) -> list[dict]:
    """Load sample rows from the manifest CSV at ``MANIFEST_PATH``.

    Args:
        max_samples: Maximum number of rows to return, counted from the top
            of the file. ``None`` or ``0`` returns every row.

    Returns:
        One dict per CSV data row, keyed by the header names.

    Raises:
        ValueError: If ``max_samples`` is negative.
    """
    from itertools import islice

    if max_samples is not None and max_samples < 0:
        # A negative value used to slice rows off the *end* of the manifest,
        # which is never what a "max samples" limit means.
        raise ValueError(f"max_samples must be non-negative, got {max_samples}")

    # 0 keeps its existing meaning of "no limit".
    limit = max_samples or None

    # newline="" lets the csv module do its own newline handling, so quoted
    # fields containing line breaks are parsed correctly.
    # NOTE(review): the encoding is left at the platform default because the
    # script that writes the manifest is not visible here - confirm and pin it.
    with open(MANIFEST_PATH, newline="") as f:
        # islice stops reading as soon as the limit is reached.
        rows = list(islice(csv.DictReader(f), limit))

    logger.info(f"manifest ๋กœ๋“œ: {len(rows)}๊ฐœ ์ƒ˜ํ”Œ")
    return rows
def evaluate_condition(
    samples: list[dict],
    condition: str,
    device: str,
) -> dict:
    """Run the audio emotion predictor over ``samples`` for one condition.

    Every sample whose audio file exists is passed to
    ``src.stage2.audio_emotion.predict`` and the returned ``"emotion"`` is
    compared with the sample's ``"emotion"`` value. Samples whose path is
    empty or does not exist on disk are counted in ``errors`` and excluded
    from all metrics.

    Args:
        samples: Manifest rows. Each must contain ``"emotion"`` and the path
            column for the chosen condition.
        condition: ``"clean"`` reads ``clean_path``; any other value reads
            ``phone_path``.
        device: Forwarded unchanged to the predictor.

    Returns:
        A JSON-serialisable dict with the keys ``condition``,
        ``total_samples``, ``errors``, ``accuracy``, ``macro_f1``,
        ``weighted_f1``, ``per_class``, ``confusion_matrix``,
        ``confusion_labels`` and ``avg_latency_ms``. If no sample could be
        evaluated, all metrics are zero and ``total_samples`` is 0.
    """
    from sklearn.metrics import (
        accuracy_score,
        classification_report,
        confusion_matrix,
    )
    from src.stage2.audio_emotion import predict as audio_predict

    path_key = "clean_path" if condition == "clean" else "phone_path"
    y_true = []
    y_pred = []
    latencies = []
    errors = 0
    correct = 0  # running hit count, so progress logs need no re-scan
    total = len(samples)

    for i, sample in enumerate(samples, 1):
        audio_path = sample[path_key]
        if audio_path and Path(audio_path).exists():
            ground_truth = sample["emotion"]
            t0 = time.perf_counter()
            prediction = audio_predict(audio_path, device=device)
            latencies.append((time.perf_counter() - t0) * 1000)  # ms
            predicted = prediction["emotion"]
            y_true.append(ground_truth)
            y_pred.append(predicted)
            if ground_truth == predicted:
                correct += 1
        else:
            errors += 1

        # Report at every 200th row and at the end, even when that particular
        # row was skipped; only requires at least one evaluated sample.
        if y_true and (i % 200 == 0 or i == total):
            logger.info(
                f" [{condition}] {i}/{total} โ€” "
                f"acc={correct / len(y_true):.3f}, "
                f"avg_latency={sum(latencies)/len(latencies):.0f}ms"
            )

    if not y_true:
        # Nothing was evaluated (empty input or every audio file missing):
        # return an explicit all-zero result rather than computing metrics
        # on empty label lists.
        logger.warning(
            "[%s] no evaluable samples (%d skipped); returning zeroed metrics",
            condition,
            errors,
        )
        n_labels = len(PROJECT_LABELS)
        return {
            "condition": condition,
            "total_samples": 0,
            "errors": errors,
            "accuracy": 0.0,
            "macro_f1": 0.0,
            "weighted_f1": 0.0,
            "per_class": {
                label: {"precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 0}
                for label in PROJECT_LABELS
            },
            "confusion_matrix": [[0] * n_labels for _ in range(n_labels)],
            "confusion_labels": PROJECT_LABELS,
            "avg_latency_ms": 0,
        }

    accuracy = accuracy_score(y_true, y_pred)
    report = classification_report(
        y_true, y_pred, labels=PROJECT_LABELS, output_dict=True, zero_division=0,
    )
    cm = confusion_matrix(y_true, y_pred, labels=PROJECT_LABELS)

    # Collect the per-class metrics in PROJECT_LABELS order.
    per_class = {}
    for label in PROJECT_LABELS:
        if label in report:
            per_class[label] = {
                "precision": round(report[label]["precision"], 4),
                "recall": round(report[label]["recall"], 4),
                "f1": round(report[label]["f1-score"], 4),
                # Store the sample count as a plain int.
                "support": int(report[label]["support"]),
            }

    result = {
        "condition": condition,
        "total_samples": len(y_true),
        "errors": errors,
        "accuracy": round(accuracy, 4),
        "macro_f1": round(report["macro avg"]["f1-score"], 4),
        "weighted_f1": round(report["weighted avg"]["f1-score"], 4),
        "per_class": per_class,
        "confusion_matrix": cm.tolist(),
        "confusion_labels": PROJECT_LABELS,
        "avg_latency_ms": round(sum(latencies) / len(latencies), 1),
    }

    logger.info(f"\n{'='*60}")
    logger.info(f"[{condition.upper()}] ๊ฒฐ๊ณผ:")
    logger.info(f" Accuracy: {accuracy:.4f}")
    logger.info(f" Macro F1: {report['macro avg']['f1-score']:.4f}")
    logger.info(f" Weighted F1: {report['weighted avg']['f1-score']:.4f}")
    logger.info(f" Avg Latency: {result['avg_latency_ms']:.0f}ms")

    logger.info("\nPer-class F1:")
    for label in PROJECT_LABELS:
        if label in per_class:
            logger.info(f" {label:10s}: F1={per_class[label]['f1']:.3f} "
                        f"(P={per_class[label]['precision']:.3f} R={per_class[label]['recall']:.3f}) "
                        f"n={per_class[label]['support']}")

    logger.info("\nConfusion Matrix (rows=true, cols=pred):")
    # Column headers are the first four characters of each label.
    logger.info(f" {'':10s} " + " ".join(f"{name[:4]:>6s}" for name in PROJECT_LABELS))
    for i_row, label in enumerate(PROJECT_LABELS):
        row_str = " ".join(f"{v:6d}" for v in cm[i_row])
        logger.info(f" {label:10s} {row_str}")
    logger.info(f"{'='*60}\n")

    return result
def main():
    """Command-line entry point.

    Parses ``--condition`` (clean / phone / both), ``--device`` and
    ``--max-samples``, loads the manifest, evaluates each requested
    condition, writes the list of result dicts to ``OUTPUT_JSON`` and, when
    both conditions were evaluated, logs the clean-vs-phone accuracy gap.

    Exits with status 1 if the manifest file is missing or contains no
    samples, and with an argparse usage error if ``--max-samples`` is
    negative.
    """
    parser = argparse.ArgumentParser(description="emotion2vec ์˜์–ด ํ‰๊ฐ€ (RAVDESS)")
    parser.add_argument("--condition", choices=["clean", "phone", "both"], default="both")
    parser.add_argument("--device", default="cpu")
    parser.add_argument("--max-samples", type=int, default=None, help="ํ‰๊ฐ€ํ•  ์ตœ๋Œ€ ์ƒ˜ํ”Œ ์ˆ˜")
    args = parser.parse_args()

    # A negative limit would otherwise be applied as a slice from the end.
    if args.max_samples is not None and args.max_samples < 0:
        parser.error("--max-samples must be non-negative")

    if not MANIFEST_PATH.exists():
        logger.error("manifest.csv๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค. ๋จผ์ € prepare_ravdess.py๋ฅผ ์‹คํ–‰ํ•˜์„ธ์š”.")
        sys.exit(1)

    samples = load_manifest(args.max_samples)
    if not samples:
        # Nothing to evaluate; stop before computing metrics on empty input.
        logger.error(f"manifest contains no samples: {MANIFEST_PATH}")
        sys.exit(1)

    conditions = ["clean", "phone"] if args.condition == "both" else [args.condition]

    results = []
    for cond in conditions:
        logger.info(f"\n{'#'*60}")
        logger.info(f"ํ‰๊ฐ€ ์‹œ์ž‘: {cond.upper()} ์กฐ๊ฑด")
        logger.info(f"{'#'*60}")
        results.append(evaluate_condition(samples, cond, args.device))

    # Save the results. The encoding is pinned because ensure_ascii=False
    # writes any non-ASCII characters as-is.
    with open(OUTPUT_JSON, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    logger.info(f"๊ฒฐ๊ณผ ์ €์žฅ: {OUTPUT_JSON}")

    # Clean vs phone comparison, only when both conditions were evaluated.
    # Looked up by name so it does not depend on the order of `results`.
    by_condition = {entry["condition"]: entry for entry in results}
    if "clean" in by_condition and "phone" in by_condition:
        clean_acc = by_condition["clean"]["accuracy"]
        phone_acc = by_condition["phone"]["accuracy"]
        # Positive means the phone condition scored lower than clean.
        degradation = clean_acc - phone_acc
        logger.info(f"\n{'='*60}")
        logger.info("Clean vs Phone ๋น„๊ต:")
        logger.info(f" Clean accuracy: {clean_acc:.4f}")
        logger.info(f" Phone accuracy: {phone_acc:.4f}")
        logger.info(f" Degradation: {degradation:+.4f}")
        logger.info(f"{'='*60}")
if __name__ == "__main__":
main()