| |
"""MELD English test sets — E2E server test.

Uploads the 8 MELD test WAVs listed in data/meld_test/ground_truth.json to the
deployed HF Spaces server, runs the full pipeline (Stage 1→2→3), and saves each
result next to its ground-truth emotion labels for comparison. A test set
passes when its result contains character reactions, a garden update and a
recap card.

Usage:
    python scripts/test_meld_e2e_server.py
    python scripts/test_meld_e2e_server.py --server http://localhost:8000
"""
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import sys |
| import time |
| from pathlib import Path |
|
|
| import requests |
|
|
# Repository layout: this script lives one directory below the project root.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
MELD_DIR = PROJECT_ROOT.joinpath("data", "meld_test")  # WAVs, ground truth, results
GT_PATH = MELD_DIR.joinpath("ground_truth.json")

# Server under test and polling behaviour (all times in seconds).
DEFAULT_SERVER = "https://bbbakery-ustwo-api.hf.space"
POLL_INTERVAL = 5  # pause between two status polls
MAX_WAIT = 5 * 60  # give up on a single analysis after this long
|
|
|
|
def health_check(base: str) -> bool:
    """Return True if ``{base}/api/health`` answers with JSON ``{"status": "ok", ...}``.

    Prints a one-line verdict either way. On failure the reason is included:
    the exception, or the HTTP status plus an excerpt of the response body.

    Args:
        base: Server base URL without a trailing slash.
    """
    try:
        r = requests.get(f"{base}/api/health", timeout=10)
        data = r.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers a non-JSON body (e.g. an HTML error page from a proxy).
        print(f" ❌ Health check failed: {e}")
        return False

    if isinstance(data, dict) and data.get("status") == "ok":
        print(f" ✅ Server OK ({data.get('timestamp', '?')})")
        return True

    # Reachable but not healthy: say so instead of returning False silently.
    print(f" ❌ Health check failed ({r.status_code}): {r.text[:200]}")
    return False
|
|
|
|
def upload(base: str, wav_path: Path) -> str | None:
    """Upload ``wav_path`` to ``{base}/api/upload`` and return the server's ``call_id``.

    Returns None (after printing the reason) on a file/network error, a non-200
    response, or a response without a usable ``call_id``. The caller can then
    record the failure and continue with the next test set instead of crashing.
    """
    try:
        with open(wav_path, "rb") as f:
            r = requests.post(
                f"{base}/api/upload",
                files={"file": (wav_path.name, f, "audio/wav")},
                timeout=60,
            )
    except (OSError, requests.RequestException) as e:
        print(f" ❌ Upload failed: {e}")
        return None

    if r.status_code != 200:
        print(f" ❌ Upload failed ({r.status_code}): {r.text[:200]}")
        return None

    try:
        data = r.json()
    except ValueError:
        data = None
    call_id = data.get("call_id") if isinstance(data, dict) else None
    if not call_id:
        print(f" ❌ Upload response has no call_id: {r.text[:200]}")
        return None
    return call_id
|
|
|
|
def analyze_and_poll(base: str, call_id: str) -> dict | None:
    """Start the analysis pipeline for ``call_id`` and poll until it finishes.

    Returns:
        The ``result`` payload once the server reports status ``"done"``. Returns
        None (after printing why) in three cases:
          * starting the analysis fails;
          * the pipeline reports ``"error"``;
          * MAX_WAIT seconds of wall-clock time pass without completion.

    A status poll that fails transiently is reported and retried rather than
    crashing the whole run. This covers network errors and non-JSON bodies,
    e.g. while the Space restarts.
    """
    try:
        r = requests.post(f"{base}/api/analyze", params={"call_id": call_id}, timeout=30)
    except requests.RequestException as e:
        print(f" ❌ Analyze start failed: {e}")
        return None
    if r.status_code not in (200, 202):
        print(f" ❌ Analyze start failed ({r.status_code}): {r.text[:200]}")
        return None

    try:
        data = r.json()
    except ValueError:
        data = None  # accepted without a JSON body; fall through to polling
    if isinstance(data, dict) and data.get("status") == "done":
        return data.get("result")

    # Measure real elapsed time. Each status request may itself take up to 15 s,
    # so counting only the sleep intervals could overshoot MAX_WAIT considerably.
    start = time.monotonic()
    while time.monotonic() - start < MAX_WAIT:
        time.sleep(POLL_INTERVAL)
        elapsed = time.monotonic() - start

        try:
            r = requests.get(f"{base}/api/analyze/{call_id}/status", timeout=15)
            data = r.json()
        except (requests.RequestException, ValueError) as e:
            print(f"\n ⚠️ Status poll failed, retrying: {e}")
            continue
        if not isinstance(data, dict):
            print(f"\n ⚠️ Unexpected status payload, retrying: {r.text[:200]}")
            continue

        status = data.get("status")
        if status == "done":
            return data.get("result")
        if status == "error":
            print(f" ❌ Pipeline error: {data.get('error', '?')}")
            return None

        mins, secs = divmod(int(elapsed), 60)
        # "\r" rewrites the same progress line; flush so it shows up immediately
        # even though the line is never newline-terminated.
        print(f" ⏳ {status}... ({mins}m{secs}s)", end="\r", flush=True)

    print(f" ❌ Timeout after {MAX_WAIT}s")
    return None
|
|
|
|
def extract_emotions(result: dict) -> dict:
    """Flatten the emotion-related parts of a pipeline result into one flat dict.

    Keys produced:
      * ``{speaker}_state``: the ``solo_state`` of each ``character_reactions``
        entry. ``speaker`` is the entry's ``speaker_id``, or ``speaker_{i}``
        (its index) when the id is missing or empty.
      * ``garden_mood`` / ``garden_delta``: taken from ``garden_update``.
      * ``recap_headline``: taken from ``recap_card``.
      * ``{spk}_dominant`` / ``{spk}_distribution``: one pair per entry of
        ``stage2_output["speaker_summaries"]``, when present.

    Missing or null sections fall back to "?" / 0 / nothing instead of raising.
    """
    info: dict = {}

    for i, rx in enumerate(result.get("character_reactions") or []):
        # Key by speaker id so reactions line up with the per-speaker stage-2 keys
        # below. The index fallback keeps the previous "speaker_{i}_state" naming.
        speaker = rx.get("speaker_id") or f"speaker_{i}"
        info[f"{speaker}_state"] = rx.get("solo_state", "?")

    garden = result.get("garden_update") or {}
    info["garden_mood"] = garden.get("mood", "?")
    info["garden_delta"] = garden.get("growth_delta", 0)

    recap = result.get("recap_card") or {}
    info["recap_headline"] = recap.get("headline", "?")

    stage2 = result.get("stage2_output") or {}
    for spk, summary in (stage2.get("speaker_summaries") or {}).items():
        summary = summary or {}
        info[f"{spk}_dominant"] = summary.get("dominant_emotion", "?")
        info[f"{spk}_distribution"] = summary.get("emotion_distribution", {})

    return info
|
|
|
|
def _mark(ok: object) -> str:
    """Render a truthy/falsy check result as a pass/fail icon."""
    return "✅" if ok else "❌"


def _run_one(base: str, tag: str, wav_path: Path, gt_entry: dict) -> dict:
    """Upload and analyze one test set, print its findings, and return its results record."""
    print(f" 📦 {tag} — {gt_entry['description']}")
    print(
        f" Primary: {gt_entry['primary_emotion']} | Duration: {gt_entry['duration_sec']}s"
        f" | Utts: {gt_entry['total_utterances']}"
    )

    call_id = upload(base, wav_path)
    if not call_id:
        return {"status": "upload_failed"}
    print(f" Upload OK → {call_id}")

    # end="" leaves the line unterminated; flush so it is visible while we wait.
    print(" Analyzing...", end="", flush=True)
    start_time = time.monotonic()
    result = analyze_and_poll(base, call_id)
    elapsed = time.monotonic() - start_time
    if not result:
        print()
        return {"status": "analyze_failed", "call_id": call_id}
    print(f"\r ✅ Done in {elapsed:.1f}s")

    emotions = extract_emotions(result)

    # Structural check only: all Stage 3 sections must be present.
    has_reactions = bool(result.get("character_reactions"))
    has_garden = "garden_update" in result
    has_recap = "recap_card" in result
    status = "pass" if (has_reactions and has_garden and has_recap) else "partial"

    print(f" Reactions: {_mark(has_reactions)} | Garden: {_mark(has_garden)} | Recap: {_mark(has_recap)}")
    for k, v in emotions.items():
        if not k.startswith("full_"):
            print(f" {k}: {v}")
    print()

    return {
        "status": status,
        "call_id": call_id,
        "elapsed_sec": round(elapsed, 1),
        "has_reactions": has_reactions,
        "has_garden": has_garden,
        "has_recap": has_recap,
        "emotions": emotions,
        "full_result": result,
        "ground_truth": {
            "primary_emotion": gt_entry["primary_emotion"],
            "emotion_distribution": gt_entry["emotion_distribution"],
        },
    }


def _save_results(results: dict, out_path: Path) -> None:
    """Write ``results`` as JSON, omitting the bulky raw ``full_result`` payloads."""
    # Build the payload before opening the file so a failure cannot leave it truncated.
    slim = {
        tag: {k: v for k, v in record.items() if k != "full_result"}
        for tag, record in results.items()
    }
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(slim, f, indent=2, ensure_ascii=False)


def _print_summary(
    results: dict, pass_count: int, fail_count: int, skipped: list, out_path: Path
) -> None:
    """Print the per-test-set summary table and the overall totals."""
    print("=" * 70)
    print(" SUMMARY")
    print("=" * 70)
    print(f"\n {'Tag':<25} {'Status':<10} {'Time':>6} {'Reactions':>10} {'Garden':>8} {'Recap':>7}")
    print(" " + "-" * 70)
    for tag in sorted(results):
        r = results[tag]
        time_str = f"{r['elapsed_sec']:.1f}s" if "elapsed_sec" in r else "—"
        print(
            f" {tag:<25} {_mark(r['status'] == 'pass'):<10} {time_str:>6}"
            f" {_mark(r.get('has_reactions')):>10} {_mark(r.get('has_garden')):>8}"
            f" {_mark(r.get('has_recap')):>7}"
        )

    print(f"\n Total: {pass_count} pass / {fail_count} fail / {len(results)} total")
    if skipped:
        print(f" Skipped (WAV not found): {', '.join(skipped)}")
    print(f" Results saved: {out_path}")
    print("=" * 70)


def main() -> bool:
    """Run the MELD E2E test against a server.

    Returns:
        True iff at least one test set ran and every test set that ran passed.

    Exits the process with status 1 if the server health check fails.
    """
    parser = argparse.ArgumentParser(description="MELD E2E server test")
    parser.add_argument("--server", default=DEFAULT_SERVER, help="Server base URL")
    args = parser.parse_args()
    base = args.server.rstrip("/")

    print("=" * 70)
    print(" MELD E2E Server Test")
    print(f" Server: {base}")
    print("=" * 70)

    print("\n[1] Health check")
    if not health_check(base):
        sys.exit(1)

    print("\n[2] Loading ground truth")
    # JSON is UTF-8; don't depend on the platform's locale encoding.
    with open(GT_PATH, encoding="utf-8") as f:
        gt = json.load(f)
    print(f" {len(gt)} test sets loaded")

    print("\n[3] Running E2E tests\n")
    results: dict = {}
    skipped: list = []
    for tag in sorted(gt):
        wav_path = MELD_DIR / f"{tag}.wav"
        if not wav_path.exists():
            print(f" ⚠️ {tag}: WAV not found, skipping")
            skipped.append(tag)
            continue
        results[tag] = _run_one(base, tag, wav_path, gt[tag])

    # "partial", "upload_failed" and "analyze_failed" all count as failures.
    pass_count = sum(1 for r in results.values() if r["status"] == "pass")
    fail_count = len(results) - pass_count

    out_path = MELD_DIR / "e2e_results.json"
    _save_results(results, out_path)
    _print_summary(results, pass_count, fail_count, skipped, out_path)

    # A run that executed no test set proves nothing, so it must not report success.
    return fail_count == 0 and pass_count > 0
|
|
|
|
if __name__ == "__main__":
    # Turn main()'s verdict into the process exit status: 0 on success, 1 otherwise.
    raise SystemExit(0 if main() else 1)
|
|