ustwo-api / scripts /test_meld_e2e_server.py
asdfasdfqrqwer's picture
Deploy from GitHub 2026-04-23T03:56:31Z
c857b85
Raw
History Blame Contribute Delete
8.12 kB
#!/usr/bin/env python3
"""MELD English test sets β€” E2E server test.
Uploads 8 MELD test WAVs to the deployed HF Spaces server,
runs the full pipeline (Stage 1→2→3), and compares results
against ground truth emotion labels.
Usage:
python scripts/test_meld_e2e_server.py
python scripts/test_meld_e2e_server.py --server http://localhost:8000
"""
from __future__ import annotations
import argparse
import json
import sys
import time
from pathlib import Path
import requests
# Project root: the directory two levels above this script.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Directory holding the MELD test WAVs and their ground-truth labels.
MELD_DIR = PROJECT_ROOT.joinpath("data", "meld_test")
GT_PATH = MELD_DIR.joinpath("ground_truth.json")

# Server used when --server is not given on the command line.
DEFAULT_SERVER = "https://bbbakery-ustwo-api.hf.space"

POLL_INTERVAL = 5  # seconds between status polls
MAX_WAIT = 5 * 60  # seconds; polling budget per file (5 minutes)
def health_check(base: str) -> bool:
    """Return True when the server's health endpoint reports ``status == "ok"``.

    Sends ``GET {base}/api/health`` with a 10-second timeout. A transport
    failure, a non-JSON body, or a payload whose status is not ``"ok"`` is
    reported on stdout and yields ``False``; none of those cases raise.

    Args:
        base: Server base URL without a trailing slash.

    Returns:
        True if the server answered with ``{"status": "ok", ...}``.
    """
    try:
        r = requests.get(f"{base}/api/health", timeout=10)
        data = r.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers JSON decode errors raised by older requests versions.
        print(f" ❌ Health check failed: {e}")
        return False

    if isinstance(data, dict) and data.get("status") == "ok":
        print(f" ✅ Server OK ({data.get('timestamp', '?')})")
        return True

    # The server answered, but not with an "ok" payload: say so instead of
    # returning False silently.
    print(f" ❌ Health check failed: HTTP {r.status_code}, unexpected response: {r.text[:200]}")
    return False
def upload(base: str, wav_path: Path) -> str | None:
    """Upload a WAV file and return the ``call_id`` from the server's reply.

    POSTs the file to ``{base}/api/upload`` as multipart field ``file`` with
    content type ``audio/wav`` and a 60-second timeout.

    Args:
        base: Server base URL without a trailing slash.
        wav_path: Path of the WAV file to send.

    Returns:
        The ``call_id`` value from the JSON response, or ``None`` when the
        file cannot be read, the request fails, the status is not 200, the
        body is not JSON, or the response carries no ``call_id``. Failures
        are reported on stdout rather than raised, so one bad file does not
        abort the caller's loop.
    """
    try:
        with open(wav_path, "rb") as f:
            r = requests.post(
                f"{base}/api/upload",
                files={"file": (wav_path.name, f, "audio/wav")},
                timeout=60,
            )
    except (OSError, requests.RequestException) as e:
        print(f" ❌ Upload failed: {e}")
        return None

    if r.status_code != 200:
        print(f" ❌ Upload failed ({r.status_code}): {r.text[:200]}")
        return None

    try:
        data = r.json()
    except ValueError:
        print(f" ❌ Upload failed: non-JSON response: {r.text[:200]}")
        return None

    call_id = data.get("call_id") if isinstance(data, dict) else None
    if not call_id:
        print(f" ❌ Upload failed: no call_id in response: {r.text[:200]}")
        return None
    return call_id
def analyze_and_poll(base: str, call_id: str) -> dict | None:
    """Start analysis for *call_id* and poll until the pipeline finishes.

    POSTs to ``{base}/api/analyze`` and, unless the reply is already
    ``status == "done"``, polls ``{base}/api/analyze/{call_id}/status`` every
    ``POLL_INTERVAL`` seconds for at most ``MAX_WAIT`` seconds of wall-clock
    time.

    Args:
        base: Server base URL without a trailing slash.
        call_id: Identifier returned by the upload step.

    Returns:
        The ``result`` field of the "done" response, or ``None`` if the start
        request fails, the pipeline reports ``status == "error"``, or the
        wait budget is exhausted. Failures are printed, not raised.
    """
    # Start
    try:
        r = requests.post(f"{base}/api/analyze", params={"call_id": call_id}, timeout=30)
    except requests.RequestException as e:
        print(f" ❌ Analyze start failed: {e}")
        return None
    if r.status_code not in (200, 202):
        print(f" ❌ Analyze start failed ({r.status_code}): {r.text[:200]}")
        return None
    try:
        data = r.json()
    except ValueError:
        print(f" ❌ Analyze start failed: non-JSON response: {r.text[:200]}")
        return None
    if data.get("status") == "done":
        return data.get("result")

    # Poll. The deadline is measured on a monotonic clock so that time spent
    # inside slow status requests counts against MAX_WAIT too.
    started = time.monotonic()
    deadline = started + MAX_WAIT
    while time.monotonic() < deadline:
        time.sleep(POLL_INTERVAL)
        try:
            r = requests.get(f"{base}/api/analyze/{call_id}/status", timeout=15)
            data = r.json()
        except (requests.RequestException, ValueError) as e:
            # A single failed poll should not abort the run; retry until the
            # deadline expires.
            print(f" ⚠️ Status poll failed ({e}), retrying")
            continue

        status = data.get("status")
        if status == "done":
            return data.get("result")
        if status == "error":
            print(f" ❌ Pipeline error: {data.get('error', '?')}")
            return None

        mins, secs = divmod(int(time.monotonic() - started), 60)
        print(f" ⏳ {status}... ({mins}m{secs}s)", end="\r")

    print(f" ❌ Timeout after {MAX_WAIT}s")
    return None
def extract_emotions(result: dict) -> dict:
    """Flatten the emotion-related fields of a pipeline result into one dict.

    Keys produced:
        ``speaker_{i}_state``: ``solo_state`` of the i-th entry of
            ``character_reactions``.
        ``garden_mood`` / ``garden_delta``: ``mood`` and ``growth_delta`` of
            ``garden_update``.
        ``recap_headline``: ``headline`` of ``recap_card``.
        ``{spk}_dominant`` / ``{spk}_distribution``: per-speaker values from
            ``stage2_output["speaker_summaries"]``, when that section exists.

    Missing values default to ``"?"`` (``0`` for ``garden_delta``, ``{}`` for
    distributions). A section that is absent or explicitly ``None`` is
    treated as empty instead of raising.
    """
    info = {}

    # Character reactions -> per-speaker state, keyed by position.
    for i, rx in enumerate(result.get("character_reactions") or []):
        info[f"speaker_{i}_state"] = (rx or {}).get("solo_state", "?")

    # Garden update
    garden = result.get("garden_update") or {}
    info["garden_mood"] = garden.get("mood", "?")
    info["garden_delta"] = garden.get("growth_delta", 0)

    # Recap
    recap = result.get("recap_card") or {}
    info["recap_headline"] = recap.get("headline", "?")

    # Stage 2 emotions (only if the server exposes them in the result)
    stage2 = result.get("stage2_output") or {}
    for spk, summary in (stage2.get("speaker_summaries") or {}).items():
        summary = summary or {}
        info[f"{spk}_dominant"] = summary.get("dominant_emotion", "?")
        info[f"{spk}_distribution"] = summary.get("emotion_distribution", {})

    return info
def _run_test_set(base: str, tag: str, gt_entry: dict, wav_path: Path) -> dict:
    """Run one test set end to end, print progress, and return its result record.

    The returned dict always has a ``"status"`` key: ``"pass"``, ``"partial"``,
    ``"upload_failed"`` or ``"analyze_failed"``.
    """
    print(f" 📦 {tag} — {gt_entry['description']}")
    print(
        f" Primary: {gt_entry['primary_emotion']}"
        f" | Duration: {gt_entry['duration_sec']}s"
        f" | Utts: {gt_entry['total_utterances']}"
    )

    # Upload
    call_id = upload(base, wav_path)
    if not call_id:
        return {"status": "upload_failed"}
    print(f" Upload OK → {call_id}")

    # Analyze + poll
    print(" Analyzing...", end="")
    start_time = time.time()
    result = analyze_and_poll(base, call_id)
    elapsed = time.time() - start_time
    if not result:
        print()
        return {"status": "analyze_failed", "call_id": call_id}
    print(f"\r ✅ Done in {elapsed:.1f}s")

    emotions = extract_emotions(result)

    # Pipeline completeness: all three Stage 3 sections must be present.
    # bool(...) instead of len(...) so a null reactions list does not raise.
    has_reactions = bool(result.get("character_reactions"))
    has_garden = "garden_update" in result
    has_recap = "recap_card" in result
    status = "pass" if (has_reactions and has_garden and has_recap) else "partial"

    print(
        f" Reactions: {'✅' if has_reactions else '❌'}"
        f" | Garden: {'✅' if has_garden else '❌'}"
        f" | Recap: {'✅' if has_recap else '❌'}"
    )
    for k, v in emotions.items():
        if not k.startswith("full_"):
            print(f" {k}: {v}")
    print()

    return {
        "status": status,
        "call_id": call_id,
        "elapsed_sec": round(elapsed, 1),
        "has_reactions": has_reactions,
        "has_garden": has_garden,
        "has_recap": has_recap,
        "emotions": emotions,
        "full_result": result,
        "ground_truth": {
            "primary_emotion": gt_entry["primary_emotion"],
            "emotion_distribution": gt_entry["emotion_distribution"],
        },
    }


def _save_results(results: dict, out_path: Path) -> None:
    """Write *results* to *out_path* as UTF-8 JSON, omitting each ``full_result``."""
    # full_result is dropped to keep the file manageable.
    save_results = {
        tag: {k: v for k, v in r.items() if k != "full_result"}
        for tag, r in results.items()
    }
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(save_results, f, indent=2, ensure_ascii=False)


def _print_summary(results: dict) -> None:
    """Print the per-tag summary table."""
    print("=" * 70)
    print(" SUMMARY")
    print("=" * 70)
    print(f"\n {'Tag':<25} {'Status':<10} {'Time':>6} {'Reactions':>10} {'Garden':>8} {'Recap':>7}")
    print(" " + "-" * 70)
    for tag in sorted(results):
        r = results[tag]
        status_icon = "✅" if r["status"] == "pass" else "❌"
        elapsed = f"{r['elapsed_sec']:.1f}s" if "elapsed_sec" in r else "—"
        react = "✅" if r.get("has_reactions") else "❌"
        garden = "✅" if r.get("has_garden") else "❌"
        recap = "✅" if r.get("has_recap") else "❌"
        print(f" {tag:<25} {status_icon:<10} {elapsed:>6} {react:>10} {garden:>8} {recap:>7}")


def main() -> bool:
    """Run the MELD E2E test against the server and report the outcome.

    Parses ``--server``, performs a health check (exiting with status 1 if it
    fails), loads the ground truth from ``GT_PATH``, runs every test set whose
    WAV exists in ``MELD_DIR``, writes ``e2e_results.json`` there, and prints
    a summary table.

    Returns:
        True only if at least one test set ran and none failed. A run in
        which every WAV was missing is reported as a failure rather than a
        vacuous success.
    """
    parser = argparse.ArgumentParser(description="MELD E2E server test")
    parser.add_argument("--server", default=DEFAULT_SERVER, help="Server base URL")
    args = parser.parse_args()
    base = args.server.rstrip("/")

    print("=" * 70)
    print(" MELD E2E Server Test")
    print(f" Server: {base}")
    print("=" * 70)

    # Health check
    print("\n[1] Health check")
    if not health_check(base):
        sys.exit(1)

    # Load ground truth (explicit UTF-8 to match how results are written)
    print("\n[2] Loading ground truth")
    with open(GT_PATH, encoding="utf-8") as f:
        gt = json.load(f)
    print(f" {len(gt)} test sets loaded")

    # Process each test set
    print("\n[3] Running E2E tests\n")
    results = {}
    skipped = 0
    for tag in sorted(gt):
        wav_path = MELD_DIR / f"{tag}.wav"
        if not wav_path.exists():
            print(f" ⚠️ {tag}: WAV not found, skipping")
            skipped += 1
            continue
        results[tag] = _run_test_set(base, tag, gt[tag], wav_path)

    # Every status other than "pass" counts as a failure.
    pass_count = sum(1 for r in results.values() if r["status"] == "pass")
    fail_count = len(results) - pass_count

    # Save results
    out_path = MELD_DIR / "e2e_results.json"
    _save_results(results, out_path)

    # Summary
    _print_summary(results)
    print(f"\n Total: {pass_count} pass / {fail_count} fail / {len(results)} total")
    if skipped:
        print(f" Skipped: {skipped} (WAV not found)")
    if not results:
        print(" ❌ No test sets were run")
    print(f" Results saved: {out_path}")
    print("=" * 70)

    return fail_count == 0 and bool(results)
if __name__ == "__main__":
    # Map main()'s truthy/falsy outcome onto the process exit status:
    # 0 when it reports success, 1 otherwise.
    exit_code = 0 if main() else 1
    sys.exit(exit_code)