ustwo-api / scripts /benchmark_emotion2vec.py
asdfasdfqrqwer's picture
Deploy from GitHub 2026-04-23T03:56:31Z
c857b85
Raw
History Blame Contribute Delete
18.9 kB
#!/usr/bin/env python3
"""emotion2vec Variant Smoke Test — measured comparison of base vs plus_base vs plus_large.

Uses the existing Stage 1 output segments (88 of them) to measure and compare
the latency, RAM usage, and prediction quality of three emotion2vec variants.

Usage:
    python scripts/benchmark_emotion2vec.py                        # all three variants
    python scripts/benchmark_emotion2vec.py --variants plus_base   # a single variant
    python scripts/benchmark_emotion2vec.py --device cuda          # run on GPU
"""
from __future__ import annotations
import argparse
import gc
import glob
import json
import logging
import os
import statistics
import sys
import time
from pathlib import Path
import numpy as np
import psutil
# Configure logging at import time: INFO level, each record prefixed with a
# timestamp and the level name.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
# Module-level logger used by the functions below.
logger = logging.getLogger(__name__)
# ──────────────────────────────────────────────
# Constants
# ──────────────────────────────────────────────
# Short variant name (the values accepted by --variants) mapped to the model
# id that is handed to funasr's AutoModel.
VARIANT_CONFIGS = {
    variant: f"iic/emotion2vec_{variant}"
    for variant in ("base", "plus_base", "plus_large")
}

# emotion2vec 9-class -> project 7-class mapping.
# "other" and "unknown" have no project counterpart and fold into "neutral".
LABEL_MAP = dict(
    angry="anger",
    disgusted="disgust",
    fearful="fear",
    happy="joy",
    neutral="neutral",
    sad="sadness",
    surprised="surprise",
    other="neutral",
    unknown="neutral",
)

# Project taxonomy; this order is also the column order of the distribution table.
PROJECT_LABELS = [
    "neutral",
    "joy",
    "sadness",
    "anger",
    "surprise",
    "fear",
    "disgust",
]

# Number of representative segments shown in the Korean sanity check.
# They are picked dynamically: shortest, longest, and evenly spaced in between.
SANITY_CHECK_COUNT = 5
# ──────────────────────────────────────────────
# Segment Discovery
# ──────────────────────────────────────────────
def discover_segments(segments_dir: str) -> list[dict]:
    """Find all segment WAV files and attach metadata from ``stage1_output.json``.

    Files are discovered with the glob ``<segments_dir>/call_*/seg_*.wav`` and
    returned in sorted path order. If ``stage1_output.json`` exists in the
    parent of ``segments_dir``, each entry of its ``"segments"`` list is
    matched to a file through its ``"audio_path"`` value.

    Args:
        segments_dir: Directory that contains the ``call_*`` sub-directories.

    Returns:
        One dict per WAV file with the keys ``path``, ``call_id`` (name of the
        parent directory), ``seg_name`` (file stem), ``text``, ``speaker_id``
        and ``duration_sec`` (``end - start``). The last three fall back to
        ``""``, ``""`` and ``0`` when no metadata matches the file.

    Exits:
        Calls ``sys.exit(1)`` after logging an error when no file matches.
    """
    pattern = os.path.join(segments_dir, "call_*", "seg_*.wav")
    paths = sorted(glob.glob(pattern))
    if not paths:
        logger.error("No segment files found in %s", segments_dir)
        sys.exit(1)

    # Metadata is keyed by the *normalized* audio path so that harmless
    # spelling differences ("./", doubled separators) still match the glob result.
    metadata = {}
    stage1_path = Path(segments_dir).parent / "stage1_output.json"
    if stage1_path.exists():
        # The transcript text is non-ASCII; read as UTF-8 explicitly instead
        # of relying on the platform's default encoding.
        with open(stage1_path, encoding="utf-8") as f:
            data = json.load(f)
        for seg in data.get("segments", []):
            audio_path = seg.get("audio_path")
            if not audio_path:
                # An entry without a path cannot be matched to any file.
                continue
            metadata[os.path.normpath(audio_path)] = {
                "text": seg.get("text", ""),
                "speaker_id": seg.get("speaker_id", ""),
                "start": seg.get("start", 0),
                "end": seg.get("end", 0),
            }

    segments = []
    for p in paths:
        wav = Path(p)
        meta = metadata.get(os.path.normpath(p), {})
        segments.append({
            "path": p,
            "call_id": wav.parent.name,
            "seg_name": wav.stem,
            "text": meta.get("text", ""),
            "speaker_id": meta.get("speaker_id", ""),
            "duration_sec": meta.get("end", 0) - meta.get("start", 0),
        })

    logger.info("Discovered %d segments across %d calls",
                len(segments), len({s["call_id"] for s in segments}))
    return segments
def select_sanity_segments(segments: list[dict], count: int = SANITY_CHECK_COUNT) -> list[dict]:
    """Pick representative segments: shortest, longest, then evenly spaced ones.

    Segments are ordered by ``duration_sec``. Segments with non-empty ``text``
    are preferred; if fewer than ``count`` of them exist, all segments are
    considered instead.

    Args:
        segments: Segment dicts carrying at least ``duration_sec`` and ``text``.
        count: Maximum number of segments to return.

    Returns:
        ``segments`` itself when it has at most ``count`` entries; an empty
        list when ``count`` is not positive; otherwise up to ``count``
        segments in the order shortest, longest, evenly spaced picks.
    """
    if len(segments) <= count:
        return segments
    if count <= 0:
        return []

    by_dur = sorted(segments, key=lambda s: s["duration_sec"])
    with_text = [s for s in by_dur if s["text"]]
    if len(with_text) < count:
        with_text = by_dur

    if count == 1:
        # Only room for the shortest segment. Returning early also avoids
        # the division by (remaining + 1) == 0 in the spacing step below.
        return [with_text[0]]

    selected = [with_text[0], with_text[-1]]  # shortest, longest
    remaining = count - 2
    step = max(1, len(with_text) // (remaining + 1))
    for i in range(1, remaining + 1):
        idx = min(i * step, len(with_text) - 1)
        candidate = with_text[idx]
        if candidate not in selected:
            selected.append(candidate)
    return selected[:count]
# ──────────────────────────────────────────────
# Benchmarking
# ──────────────────────────────────────────────
def get_process_rss_mb() -> float:
    """Return the resident set size (RSS) of the current process in MB."""
    bytes_per_mb = 1024 * 1024
    this_process = psutil.Process(os.getpid())
    rss_bytes = this_process.memory_info().rss
    return rss_bytes / bytes_per_mb
def _normalize_native_label(label) -> str:
    """Reduce a model label to the lowercase English key used by ``LABEL_MAP``.

    NOTE(review): the emotion2vec+ model cards list labels in a bilingual
    "<Chinese>/<english>" form (e.g. "生气/angry") plus "<unk>"; confirm
    against the installed FunASR version. Plain labels pass through unchanged.
    """
    key = str(label).rsplit("/", 1)[-1].strip().lower()
    return "unknown" if key in ("<unk>", "unk") else key


def map_predictions(raw_scores: dict[str, float]) -> dict:
    """Map emotion2vec native labels to the project 7-class taxonomy.

    Scores of native labels that share a project label are summed. Labels
    that are not in ``LABEL_MAP`` after normalization count as "neutral".

    Args:
        raw_scores: ``{native_label: score}`` as reported by the model.

    Returns:
        Dict with ``label`` (highest-scoring project label; the first entry
        of ``PROJECT_LABELS`` wins ties, including the all-zero case of an
        empty input), ``confidence`` (its summed score), ``scores`` (all
        project labels) and ``raw_scores`` (the input, untouched).
    """
    mapped = {label: 0.0 for label in PROJECT_LABELS}
    for native_label, score in raw_scores.items():
        project_label = LABEL_MAP.get(_normalize_native_label(native_label), "neutral")
        mapped[project_label] += score
    top_label = max(mapped, key=mapped.get)
    return {
        "label": top_label,
        "confidence": mapped[top_label],
        "scores": mapped,
        "raw_scores": raw_scores,
    }
def _parse_emotion_scores(output) -> dict[str, float]:
    """Return ``{native_label: score}`` from the first record of a generate() result.

    Gives an empty dict when ``output`` is empty or not a list.
    """
    if not output or not isinstance(output, list):
        return {}
    first = output[0]
    labels = first.get("labels", [])
    scores = first.get("scores", [])
    return {label: float(score) for label, score in zip(labels, scores)}


def _summarize_latencies(latencies: list[float]) -> dict:
    """Aggregate per-segment latencies (ms) into mean/median/std/p95/min/max."""
    if not latencies:
        return {}
    ordered = sorted(latencies)
    n = len(ordered)
    # Nearest-rank percentile: the ceil(0.95 * n)-th smallest value. Integer
    # math avoids float rounding, and int(n * 0.95) would overshoot by one
    # rank whenever 0.95 * n is a whole number (e.g. n == 20 picked the max).
    p95_index = (95 * n + 99) // 100 - 1
    return {
        "mean_ms": round(statistics.mean(ordered), 1),
        "median_ms": round(statistics.median(ordered), 1),
        "std_ms": round(statistics.stdev(ordered), 1) if n > 1 else 0,
        "p95_ms": round(ordered[p95_index], 1),
        "min_ms": round(ordered[0], 1),
        "max_ms": round(ordered[-1], 1),
    }


def benchmark_variant(
    variant_name: str,
    model_id: str,
    segments: list[dict],
    device: str = "cpu",
    warmup: int = 3,
) -> dict:
    """Run the full benchmark for one emotion2vec variant.

    Loads the model through ``funasr.AutoModel``, runs untimed warmup calls,
    then times ``model.generate`` on every segment while sampling the process
    RSS after each one.

    Args:
        variant_name: Short name stored in the result and used in log lines.
        model_id: Model identifier passed to ``AutoModel``.
        segments: Segment dicts as produced by ``discover_segments``.
        device: Device string passed to ``AutoModel``.
        warmup: Number of leading segments used for untimed warmup runs.

    Returns:
        Result dict with ``variant``, ``model_id`` and ``device``. If the
        model fails to load it additionally holds only ``error``. Otherwise
        it holds ``load_time_sec``, ``model_ram_mb``, ``peak_ram_mb`` (both
        relative to the RSS measured before loading), ``total_segments``,
        ``successful``, ``errors``, ``latency``, ``emotion_distribution``
        and ``predictions``.
    """
    from funasr import AutoModel

    logger.info("=" * 60)
    logger.info("Benchmarking: %s (%s)", variant_name, model_id)
    logger.info("=" * 60)

    result = {
        "variant": variant_name,
        "model_id": model_id,
        "device": device,
    }

    # 1. Baseline RAM
    gc.collect()
    baseline_rss = get_process_rss_mb()

    # 2. Model load + load time
    logger.info("Loading model...")
    load_start = time.perf_counter()
    try:
        model = AutoModel(model=model_id, device=device)
    except Exception as e:
        # A variant that cannot be loaded is reported, not fatal, so the
        # remaining variants still run.
        logger.error("Failed to load %s: %s", model_id, e)
        result["error"] = str(e)
        return result
    load_time = time.perf_counter() - load_start
    result["load_time_sec"] = round(load_time, 2)
    logger.info("Model loaded in %.2fs", load_time)

    # 3. RAM after load
    post_load_rss = get_process_rss_mb()
    result["model_ram_mb"] = round(post_load_rss - baseline_rss, 1)
    logger.info("Model RAM: %.1f MB", result["model_ram_mb"])

    # 4. Warmup (untimed). max() keeps a negative value from slicing
    # "everything but the last few" segments.
    logger.info("Warmup (%d runs)...", warmup)
    for seg in segments[:max(warmup, 0)]:
        try:
            model.generate(seg["path"], granularity="utterance", extract_embedding=False)
        except Exception as e:
            logger.warning("Warmup failed on %s: %s", seg["path"], e)

    # 5. Timed inference on all segments
    logger.info("Running inference on %d segments...", len(segments))
    predictions = []
    latencies = []
    errors = []
    peak_rss = post_load_rss
    for i, seg in enumerate(segments):
        try:
            t0 = time.perf_counter()
            output = model.generate(
                seg["path"], granularity="utterance", extract_embedding=False,
            )
            t1 = time.perf_counter()
            latency_ms = (t1 - t0) * 1000
            latencies.append(latency_ms)

            raw_scores = _parse_emotion_scores(output)
            if not raw_scores:
                # Without scores there is nothing to classify. Record an
                # error instead of a fabricated "neutral" at 0.0 confidence,
                # which would skew the distribution and agreement stats.
                raise ValueError("model returned no labels/scores")

            mapped = map_predictions(raw_scores)
            predictions.append({
                "seg_name": seg["seg_name"],
                "call_id": seg["call_id"],
                "text": seg["text"],
                "speaker_id": seg["speaker_id"],
                "duration_sec": seg["duration_sec"],
                "latency_ms": round(latency_ms, 1),
                **mapped,
            })
        except Exception as e:
            errors.append({"seg_name": seg["seg_name"], "error": str(e)})
            logger.warning("Inference error on %s: %s", seg["seg_name"], e)

        # Track peak RAM (sampled once per segment, failed ones included)
        peak_rss = max(peak_rss, get_process_rss_mb())
        if (i + 1) % 20 == 0:
            logger.info(" %d/%d segments done", i + 1, len(segments))

    # 6. Aggregate results
    result["peak_ram_mb"] = round(peak_rss - baseline_rss, 1)
    result["total_segments"] = len(segments)
    result["successful"] = len(predictions)
    result["errors"] = errors
    result["latency"] = _summarize_latencies(latencies)

    # Emotion distribution
    dist = {label: 0 for label in PROJECT_LABELS}
    for pred in predictions:
        dist[pred["label"]] += 1
    result["emotion_distribution"] = dist
    result["predictions"] = predictions

    # 7. Cleanup so the next variant starts without this model in memory
    logger.info("Cleaning up model...")
    del model
    gc.collect()
    try:
        import torch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except ImportError:
        pass

    logger.info("Done: %s — mean latency %.1fms, peak RAM %.1fMB",
                variant_name,
                result.get("latency", {}).get("mean_ms", 0),
                result.get("peak_ram_mb", 0))
    return result
# ──────────────────────────────────────────────
# Output Formatting
# ──────────────────────────────────────────────
def fmt_table(headers: list[str], rows: list[list[str]], col_widths: list[int] | None = None) -> str:
    """Render ``headers`` and ``rows`` as a box-drawn text table.

    Args:
        headers: Column titles; their number fixes the number of columns.
        rows: Table rows. Cells are converted with ``str``. A row with fewer
            cells than columns is padded with empty cells; cells beyond the
            last column are dropped.
        col_widths: Optional cell widths. When omitted or empty, each column
            is as wide as its longest cell (header included) plus 2.

    Returns:
        The table as one string, lines joined with ``"\\n"``: top border,
        header row, separator, data rows, bottom border.
    """
    if not col_widths:
        col_widths = [
            max([len(h)] + [len(str(row[i])) for row in rows if i < len(row)]) + 2
            for i, h in enumerate(headers)
        ]

    def fmt_row(cells):
        # Pad short rows so every line closes at the same column.
        padded = list(cells) + [""] * (len(col_widths) - len(cells))
        return "│ " + " │ ".join(str(c).ljust(w) for c, w in zip(padded, col_widths)) + " │"

    def rule(left: str, mid: str, right: str) -> str:
        # One horizontal border using the given corner/junction characters.
        return left + "─" + f"─{mid}─".join("─" * w for w in col_widths) + "─" + right

    lines = [rule("┌", "┬", "┐"), fmt_row(headers), rule("├", "┼", "┤")]
    lines.extend(fmt_row(row) for row in rows)
    lines.append(rule("└", "┴", "┘"))
    return "\n".join(lines)
def format_results(all_results: dict[str, dict], segments: list[dict]) -> str:
    """Format benchmark results into readable console output.

    Sections, in order: performance comparison table, knockout check
    (mean latency and peak RAM against fixed limits), emotion distribution
    table, Korean sanity check on a few representative segments, and label
    agreement between variants.

    Args:
        all_results: ``{variant_name: result}`` as returned by
            ``benchmark_variant``. A result holding an ``"error"`` key is
            treated as a failed model load.
        segments: The benchmarked segment dicts (need ``call_id``,
            ``seg_name``, ``text`` and ``duration_sec``).

    Returns:
        The complete report as a single string.
    """
    latency_limit_ms = 500
    ram_limit_mb = 2048

    def seg_key(record: dict) -> tuple:
        # File stems such as "seg_0001" can repeat in every call_* directory,
        # so a segment is only identified by (call_id, seg_name).
        return (record.get("call_id"), record["seg_name"])

    valid_results = {k: v for k, v in all_results.items() if "error" not in v}

    # Per variant: segment key -> prediction (first one wins on duplicates).
    pred_index = {}
    for name, res in valid_results.items():
        by_key = {}
        for pred in res.get("predictions", []):
            by_key.setdefault(seg_key(pred), pred)
        pred_index[name] = by_key

    output_parts = []

    # ── Performance Comparison ──
    output_parts.append("\n=== Performance Comparison ===")
    headers = ["Variant", "Load (s)", "RAM (MB)", "Latency mean (ms)", "Latency p95 (ms)", "Errors"]
    rows = []
    for name, res in all_results.items():
        if "error" in res:
            rows.append([name, "FAIL", "-", "-", "-", res["error"][:40]])
            continue
        lat = res.get("latency", {})
        mean_str = f"{lat.get('mean_ms', 0):.1f} ± {lat.get('std_ms', 0):.1f}"
        rows.append([
            name,
            f"{res.get('load_time_sec', 0):.1f}",
            f"{res.get('peak_ram_mb', 0):.0f}",
            mean_str,
            f"{lat.get('p95_ms', 0):.1f}",
            str(len(res.get("errors", []))),
        ])
    output_parts.append(fmt_table(headers, rows))

    # ── Knockout Check ──
    output_parts.append("\n=== Knockout Check ===")
    for name, res in all_results.items():
        if "error" in res:
            output_parts.append(f" {name}: ❌ LOAD FAILED")
            continue
        # Missing numbers default to 999.
        lat_mean = res.get("latency", {}).get("mean_ms", 999)
        ram = res.get("peak_ram_mb", 999)
        lat_ok = "✅" if lat_mean <= latency_limit_ms else "❌"
        ram_ok = "✅" if ram <= ram_limit_mb else "❌"
        output_parts.append(
            f" {name}: Latency {lat_ok} ({lat_mean:.0f}ms ≤ {latency_limit_ms}ms)"
            f" RAM {ram_ok} ({ram:.0f}MB ≤ {ram_limit_mb}MB)"
        )

    # ── Emotion Distribution ──
    output_parts.append("\n=== Emotion Distribution (across all segments) ===")
    headers = ["Variant"] + PROJECT_LABELS
    rows = []
    for name, res in valid_results.items():
        dist = res.get("emotion_distribution", {})
        rows.append([name] + [str(dist.get(label, 0)) for label in PROJECT_LABELS])
    output_parts.append(fmt_table(headers, rows))

    # ── Korean Sanity Check ──
    output_parts.append("\n=== Korean Sanity Check ===")
    for seg in select_sanity_segments(segments):
        text_preview = seg["text"][:50] + "..." if len(seg["text"]) > 50 else seg["text"]
        output_parts.append(f'\n {seg["seg_name"]} ({seg["duration_sec"]:.1f}s): "{text_preview}"')
        for name, res in all_results.items():
            if "error" in res:
                output_parts.append(f" {name}: FAILED")
                continue
            match = pred_index[name].get(seg_key(seg))
            if match:
                output_parts.append(f" {name:12s}: {match['label']:10s} ({match['confidence']:.2f})")
            else:
                output_parts.append(f" {name:12s}: no prediction")

    # ── Variant Agreement ──
    output_parts.append("\n=== Variant Agreement ===")
    if len(valid_results) >= 2:
        variant_names = list(valid_results)
        # Per variant: segment key -> predicted project label.
        pred_maps = {
            name: {key: pred["label"] for key, pred in index.items()}
            for name, index in pred_index.items()
        }

        # All-agree count, over segments every valid variant predicted.
        shared = set.intersection(*(set(pm) for pm in pred_maps.values()))
        total_count = len(shared)
        agree_count = sum(
            1 for key in shared
            if len({pm[key] for pm in pred_maps.values()}) == 1
        )
        output_parts.append(f" All {len(valid_results)} variants agree: {agree_count}/{total_count} ({agree_count/max(total_count,1)*100:.0f}%)")

        # Pairwise agreement
        for i, a in enumerate(variant_names):
            for b in variant_names[i + 1:]:
                common = set(pred_maps[a]) & set(pred_maps[b])
                pair_agree = sum(1 for s in common if pred_maps[a][s] == pred_maps[b][s])
                pct = pair_agree / max(len(common), 1) * 100
                output_parts.append(f" {a} vs {b}: {pair_agree}/{len(common)} ({pct:.0f}%)")

    return "\n".join(output_parts)
# ──────────────────────────────────────────────
# Main
# ──────────────────────────────────────────────
def main():
    """Command-line entry point: benchmark the chosen variants and report.

    Parses the CLI options, verifies that ``funasr`` can be imported (exits
    with status 1 otherwise), discovers the segments, benchmarks each
    requested variant once, writes the full results as JSON and prints the
    formatted report.
    """
    import platform

    parser = argparse.ArgumentParser(description="emotion2vec variant benchmark")
    parser.add_argument(
        # "+" instead of "*": a bare "--variants" used to yield an empty list
        # and a run that silently benchmarked nothing.
        "--variants", nargs="+", default=list(VARIANT_CONFIGS.keys()),
        choices=list(VARIANT_CONFIGS.keys()),
        help="Which variants to benchmark (default: all)",
    )
    parser.add_argument("--segments-dir", default="data/segments", help="Segments directory")
    parser.add_argument("--output-json", default="data/benchmark_results.json", help="Output JSON path")
    parser.add_argument("--device", default="cpu", choices=["cpu", "cuda"], help="Compute device")
    parser.add_argument("--warmup", type=int, default=3, help="Warmup iterations")
    args = parser.parse_args()

    # Check dependency
    try:
        import funasr  # noqa: F401
    except ImportError:
        logger.error("funasr not installed. Run: pip install funasr onnxruntime")
        sys.exit(1)

    # Discover segments
    segments = discover_segments(args.segments_dir)
    logger.info("Total segments: %d", len(segments))

    # Run benchmarks. A variant named twice is benchmarked once, since a
    # second run would only overwrite the first result.
    all_results = {}
    for variant_name in dict.fromkeys(args.variants):
        all_results[variant_name] = benchmark_variant(
            variant_name, VARIANT_CONFIGS[variant_name], segments,
            device=args.device, warmup=args.warmup,
        )

    # Get system info
    try:
        import torch
        torch_version = torch.__version__
        cuda_available = torch.cuda.is_available()
    except ImportError:
        torch_version = "not installed"
        cuda_available = False

    output_data = {
        "metadata": {
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
            "device": args.device,
            "total_segments": len(segments),
            "python_version": platform.python_version(),
            "torch_version": torch_version,
            "cuda_available": cuda_available,
            "cpu": platform.processor() or "unknown",
            "ram_total_gb": round(psutil.virtual_memory().total / (1024**3), 1),
        },
        "results": all_results,
    }

    # Save JSON before formatting the report, so a formatting error cannot
    # discard the measurements of a long benchmark run.
    output_path = Path(args.output_json)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(output_data, f, indent=2, ensure_ascii=False, default=str)
    logger.info("Results saved to %s", output_path)

    # Format and print results
    print(format_results(all_results, segments))
    print(f"\nFull results saved to {output_path}")


if __name__ == "__main__":
    main()