| |
| """20Hours Korean demo set — E2E server test. |
| |
| Uploads 7 curated Korean demo WAVs to the deployed HF Spaces server, |
| runs the full pipeline (Stage 1 → 2 → 3), and compares results |
| against the intended demo emotion labels (data/20hours_test/ground_truth.json). |
| |
| Usage: |
| python scripts/test_20hours_e2e_server.py |
| python scripts/test_20hours_e2e_server.py --server http://localhost:8000 |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import sys |
| import time |
| from pathlib import Path |
|
|
| import requests |
|
|
# Repository root: two directory levels above this script's resolved path.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Directory the demo WAV files and their label file are read from.
TEST_DIR = PROJECT_ROOT.joinpath("data", "20hours_test")
GT_PATH = TEST_DIR.joinpath("ground_truth.json")

# Server used when --server is not given on the command line.
DEFAULT_SERVER = "https://bbbakery-ustwo-api.hf.space"
POLL_INTERVAL = 5  # seconds slept between status polls
MAX_WAIT = 300  # polling budget in seconds before giving up on a clip
|
|
|
|
def health_check(base: str) -> bool:
    """Report whether the server at ``base`` answers its health endpoint.

    Sends ``GET {base}/api/health`` with a 10-second timeout and treats a
    JSON object whose ``status`` field equals ``"ok"`` as healthy.

    Args:
        base: Server base URL without a trailing slash.

    Returns:
        True when the response reports ``status == "ok"``. False on a
        transport error, a non-JSON body, or any other status; a diagnostic
        line is printed in every failing case.
    """
    try:
        r = requests.get(f"{base}/api/health", timeout=10)
        data = r.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers a body that is not valid JSON.
        print(f" Health check failed: {e}")
        return False

    # Guard against a JSON body that is not an object (e.g. a list or string).
    if isinstance(data, dict) and data.get("status") == "ok":
        print(f" Server OK ({data.get('timestamp', '?')})")
        return True

    # The server answered, but not with an "ok" status: say so instead of
    # returning False silently.
    print(f" Health check failed: HTTP {r.status_code}, body: {r.text[:200]}")
    return False
|
|
|
|
def upload(base: str, wav_path: Path) -> str | None:
    """Upload one WAV file and return the call id assigned by the server.

    Posts the file as multipart field ``file`` (content type ``audio/wav``)
    to ``{base}/api/upload`` with a 60-second timeout.

    Args:
        base: Server base URL without a trailing slash.
        wav_path: Path of the WAV file to send.

    Returns:
        The ``call_id`` value from the JSON response, or None when the file
        cannot be read, the request fails, the status code is not 200, or
        the body is not a JSON object. A diagnostic line is printed for
        each failure so the caller can simply skip the clip.
    """
    try:
        with open(wav_path, "rb") as f:
            r = requests.post(
                f"{base}/api/upload",
                files={"file": (wav_path.name, f, "audio/wav")},
                timeout=60,
            )
    except (OSError, requests.RequestException) as e:
        # A network hiccup on one clip must not abort the whole run.
        print(f" Upload failed: {e}")
        return None

    if r.status_code != 200:
        print(f" Upload failed ({r.status_code}): {r.text[:200]}")
        return None

    try:
        payload = r.json()
    except ValueError:
        print(f" Upload failed: response is not JSON: {r.text[:200]}")
        return None
    if not isinstance(payload, dict):
        print(f" Upload failed: unexpected response: {r.text[:200]}")
        return None
    return payload.get("call_id")
|
|
|
|
def analyze_and_poll(base: str, call_id: str) -> dict | None:
    """Start analysis for ``call_id`` and wait for the pipeline result.

    Posts to ``{base}/api/analyze``; if the response is not already
    ``status == "done"``, polls ``{base}/api/analyze/{call_id}/status``
    every ``POLL_INTERVAL`` seconds until the status is ``"done"`` or
    ``"error"``, or until ``MAX_WAIT`` seconds of wall-clock time elapse.

    Args:
        base: Server base URL without a trailing slash.
        call_id: Identifier returned by the upload endpoint.

    Returns:
        The ``result`` field of the finished response, or None when the
        start request fails, the pipeline reports an error, or the wait
        times out.
    """
    try:
        r = requests.post(
            f"{base}/api/analyze", params={"call_id": call_id}, timeout=30
        )
    except requests.RequestException as e:
        print(f" Analyze start failed: {e}")
        return None
    if r.status_code not in (200, 202):
        print(f" Analyze start failed ({r.status_code}): {r.text[:200]}")
        return None

    try:
        data = r.json()
    except ValueError:
        print(f" Analyze start failed: response is not JSON: {r.text[:200]}")
        return None
    if data.get("status") == "done":
        return data.get("result")

    # Measure real elapsed time: counting only the sleeps would let slow
    # status requests push the total wait well beyond MAX_WAIT.
    started = time.monotonic()
    deadline = started + MAX_WAIT
    while time.monotonic() < deadline:
        time.sleep(POLL_INTERVAL)
        try:
            r = requests.get(f"{base}/api/analyze/{call_id}/status", timeout=15)
            data = r.json()
        except (requests.RequestException, ValueError) as e:
            # A single failed poll (dropped connection, non-JSON gateway
            # page) is treated as transient; keep polling until the deadline.
            print(f" Poll failed, retrying: {e}")
            continue

        status = data.get("status")
        if status == "done":
            return data.get("result")
        if status == "error":
            print(f" Pipeline error: {data.get('error', '?')}")
            return None

        mins, secs = divmod(int(time.monotonic() - started), 60)
        # "\r" rewrites the same console line; flush because no newline is sent.
        print(f" {status}... ({mins}m{secs}s)", end="\r", flush=True)

    print(f" Timeout after {MAX_WAIT}s")
    return None
|
|
|
|
def extract_emotions(result: dict) -> dict:
    """Flatten the emotion-related fields of a pipeline result.

    Every section of ``result`` is optional. A section that is missing or
    explicitly ``None`` (JSON ``null``) is treated as empty rather than
    raising.

    Args:
        result: Pipeline result object returned by the server.

    Returns:
        A dict containing:
        - ``speaker_{i}_state``: ``solo_state`` of each entry in
          ``character_reactions`` (``"?"`` when absent);
        - ``garden_mood`` / ``garden_delta``: from ``garden_update``;
        - ``recap_headline``: ``headline`` of ``recap_card``, falling back
          to its ``title``;
        - ``{speaker}_dominant`` / ``{speaker}_distribution``: from
          ``stage2_output.speaker_summaries``;
        - ``segments``: number of emotion segments;
        - ``segments_by_lang``: segment count per ``language`` value, with
          ``"?"`` for segments that have no language.
    """
    info: dict = {}

    # ``or`` instead of a .get() default: the default only covers a missing
    # key, while ``or`` also covers a key that is present with value None.
    reactions = result.get("character_reactions") or []
    for i, rx in enumerate(reactions):
        info[f"speaker_{i}_state"] = rx.get("solo_state", "?")

    garden = result.get("garden_update") or {}
    info["garden_mood"] = garden.get("mood", "?")
    info["garden_delta"] = garden.get("growth_delta", 0)

    recap = result.get("recap_card") or {}
    info["recap_headline"] = recap.get("headline") or recap.get("title", "?")

    stage2 = result.get("stage2_output") or {}
    summaries = stage2.get("speaker_summaries") or {}
    for spk, summary in summaries.items():
        info[f"{spk}_dominant"] = summary.get("dominant_emotion", "?")
        info[f"{spk}_distribution"] = summary.get("emotion_distribution", {})

    # Prefer the top-level segment list; fall back to the one in stage 2.
    emotions = result.get("emotions") or stage2.get("emotions") or []
    segs_by_lang: dict[str, int] = {}
    for e in emotions:
        lang = e.get("language") or "?"
        segs_by_lang[lang] = segs_by_lang.get(lang, 0) + 1
    info["segments"] = len(emotions)
    info["segments_by_lang"] = segs_by_lang
    return info
|
|
|
|
def main():
    """Run every labelled demo clip through the server and report matches.

    Steps:
      1. Parse ``--server`` and exit with status 1 if the health check fails.
      2. Load the intended labels from ``GT_PATH``.
      3. For each tag (sorted) whose ``<tag>.wav`` exists in ``TEST_DIR``:
         upload it, run the analysis, and count a hit when the intended
         ``primary_emotion`` equals any speaker's ``*_state`` value.
      4. Write per-clip results (without the full server payload) to
         ``TEST_DIR / "e2e_results.json"`` and print a summary table.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--server", default=DEFAULT_SERVER)
    args = parser.parse_args()
    base = args.server.rstrip("/")

    print("=" * 70)
    print(" 20Hours Korean Demo — E2E Server Test")
    print(f" Server: {base}")
    print("=" * 70)

    print("\n[1] Health check")
    if not health_check(base):
        sys.exit(1)

    print("\n[2] Loading intended emotion labels")
    # Pin UTF-8: the label file may contain non-ASCII text and the locale
    # default encoding is not UTF-8 on every platform.
    gt = json.loads(GT_PATH.read_text(encoding="utf-8"))
    print(f" {len(gt)} demo clips loaded")

    print("\n[3] Running E2E tests\n")
    results = {}
    hit = 0
    total = 0

    for tag in sorted(gt):
        wav_path = TEST_DIR / f"{tag}.wav"
        if not wav_path.exists():
            print(f" {tag}: WAV not found, skipping")
            continue
        gt_entry = gt[tag]
        print(f" {tag} — {gt_entry['description'][:55]}")
        print(
            f" Intended: {gt_entry['primary_emotion']}"
            f" | Duration: {gt_entry['duration_sec']}s"
            f" | Utts: {gt_entry['total_utterances']}"
        )

        call_id = upload(base, wav_path)
        if not call_id:
            results[tag] = {"status": "upload_failed"}
            continue
        print(f" Upload OK → {call_id}")

        # No newline is emitted here, so flush to make the line visible
        # while the (possibly minutes-long) analysis is running.
        print(" Analyzing...", end="", flush=True)
        start_time = time.time()
        result = analyze_and_poll(base, call_id)
        elapsed = time.time() - start_time

        if not result:
            results[tag] = {"status": "analyze_failed", "call_id": call_id}
            print()
            continue

        print(f"\r Done in {elapsed:.1f}s ")

        emotions = extract_emotions(result)
        total += 1
        intended = gt_entry["primary_emotion"]
        speaker_states = {k: v for k, v in emotions.items() if k.endswith("_state")}
        if intended in speaker_states.values():
            hit += 1
            match = "HIT"
        else:
            match = "miss"

        # "pass" means the pipeline completed; label agreement is recorded
        # separately under "match".
        results[tag] = {
            "status": "pass",
            "call_id": call_id,
            "elapsed_sec": round(elapsed, 1),
            "intended_emotion": intended,
            "match": match,
            "emotions": emotions,
            "full_result": result,
        }

        for k, v in emotions.items():
            if not k.endswith("_distribution"):
                print(f" {k}: {v}")
            else:
                # Show only the three highest-scoring emotions.
                top3 = sorted(v.items(), key=lambda x: -x[1])[:3]
                dist = ", ".join(f"{kk}:{vv:.2f}" for kk, vv in top3)
                print(f" {k}: {dist}")
        print(f" → {match}")
        print()

    out_path = TEST_DIR / "e2e_results.json"
    # Drop the raw server payload from the saved file.
    save = {
        tag: {k: v for k, v in r.items() if k != "full_result"}
        for tag, r in results.items()
    }
    # ensure_ascii=False writes non-ASCII characters verbatim, so the file
    # must be encoded as UTF-8 regardless of the locale.
    out_path.write_text(
        json.dumps(save, indent=2, ensure_ascii=False), encoding="utf-8"
    )

    print("=" * 70)
    print(" SUMMARY")
    print("=" * 70)
    print(f"\n {'Tag':<24} {'Intended':<12} {'Match':<6} {'Time':>6}")
    print(" " + "-" * 55)
    for tag in sorted(results):
        r = results[tag]
        if r.get("status") != "pass":
            continue
        intended = r["intended_emotion"]
        m = r["match"]
        t = f"{r['elapsed_sec']:.0f}s"
        print(f" {tag:<24} {intended:<12} {m:<6} {t:>6}")
    print(f"\n Intended-emotion match: {hit}/{total}")
    print(f" Results saved: {out_path}")
|
|
|
|
# Run the E2E test only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|