ustwo-api / scripts /test_20hours_e2e_server.py
asdfasdfqrqwer's picture
Deploy from GitHub 2026-04-23T03:56:31Z
c857b85
Raw
History Blame Contribute Delete
7.07 kB
#!/usr/bin/env python3
"""20Hours Korean demo set — E2E server test.
Uploads 7 curated Korean demo WAVs to the deployed HF Spaces server,
runs the full pipeline (Stage 1 → 2 → 3), and compares results
against the intended demo emotion labels (data/20hours_test/ground_truth.json).
Usage:
python scripts/test_20hours_e2e_server.py
python scripts/test_20hours_e2e_server.py --server http://localhost:8000
"""
from __future__ import annotations
import argparse
import json
import sys
import time
from pathlib import Path
import requests
# Parent of the directory that contains this script (resolved to an absolute path).
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Directory the demo WAVs are read from; the results JSON is written here too.
TEST_DIR = PROJECT_ROOT / "data" / "20hours_test"
# JSON file mapping each clip tag to its intended-emotion metadata.
GT_PATH = TEST_DIR / "ground_truth.json"
# Server used when --server is not given on the command line.
DEFAULT_SERVER = "https://bbbakery-ustwo-api.hf.space"
# Seconds to sleep between two analysis status polls.
POLL_INTERVAL = 5
# Seconds to keep polling one analysis before reporting a timeout.
MAX_WAIT = 300
def health_check(base: str) -> bool:
    """Report whether the server's health endpoint answers with status "ok".

    Sends ``GET {base}/api/health`` with a 10 second timeout.

    Args:
        base: Server base URL without a trailing slash.

    Returns:
        True when the response is a JSON object whose ``status`` is ``"ok"``.
        False on a transport error, a non-JSON body, or any other payload;
        the reason is printed in every False case.
    """
    try:
        r = requests.get(f"{base}/api/health", timeout=10)
        data = r.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers JSON decode failures on older requests versions.
        print(f" Health check failed: {e}")
        return False

    if isinstance(data, dict) and data.get("status") == "ok":
        print(f" Server OK ({data.get('timestamp', '?')})")
        return True

    # The server answered, but not with the expected payload: say so instead
    # of returning False silently.
    print(f" Health check failed: unexpected response ({r.status_code}): {r.text[:200]}")
    return False
def upload(base: str, wav_path: Path) -> str | None:
    """Upload one WAV file and return the call id assigned by the server.

    Posts the file as multipart field ``file`` (content type ``audio/wav``)
    to ``{base}/api/upload`` with a 60 second timeout.

    Args:
        base: Server base URL without a trailing slash.
        wav_path: Path of the WAV file to send.

    Returns:
        The ``call_id`` value from the JSON response, or None when the file
        cannot be read, the request fails, the status code is not 200, or the
        body is not a JSON object. A reason is printed for every failure.
    """
    try:
        with open(wav_path, "rb") as f:
            r = requests.post(
                f"{base}/api/upload",
                files={"file": (wav_path.name, f, "audio/wav")},
                timeout=60,
            )
    except (OSError, requests.RequestException) as e:
        # One unreachable upload should fail this clip, not abort the run.
        print(f" Upload failed: {e}")
        return None

    if r.status_code != 200:
        print(f" Upload failed ({r.status_code}): {r.text[:200]}")
        return None

    try:
        payload = r.json()
    except ValueError:
        print(f" Upload failed: response is not JSON: {r.text[:200]}")
        return None

    if not isinstance(payload, dict):
        print(f" Upload failed: unexpected response: {r.text[:200]}")
        return None
    return payload.get("call_id")
def _json_dict(resp) -> dict:
    """Return the response body as a dict, or ``{}`` if it is not a JSON object."""
    try:
        body = resp.json()
    except ValueError:
        return {}
    return body if isinstance(body, dict) else {}


def analyze_and_poll(base: str, call_id: str) -> dict | None:
    """Start the analysis for an uploaded call and wait for its result.

    Posts to ``{base}/api/analyze`` and, unless that response already reports
    ``status == "done"``, polls ``{base}/api/analyze/{call_id}/status`` every
    ``POLL_INTERVAL`` seconds until the status is ``"done"`` or ``"error"`` or
    ``MAX_WAIT`` seconds of wall-clock time have passed.

    Args:
        base: Server base URL without a trailing slash.
        call_id: Identifier returned by the upload step.

    Returns:
        The ``result`` field of the "done" response, or None when the analysis
        could not be started, the server reported an error, or the wait timed
        out. A reason is printed for every None outcome except a "done"
        response that carries no result.
    """
    try:
        r = requests.post(f"{base}/api/analyze", params={"call_id": call_id}, timeout=30)
    except requests.RequestException as e:
        print(f" Analyze start failed: {e}")
        return None
    if r.status_code not in (200, 202):
        print(f" Analyze start failed ({r.status_code}): {r.text[:200]}")
        return None

    data = _json_dict(r)
    if data.get("status") == "done":
        return data.get("result")

    # Measure real elapsed time: counting only the sleeps would ignore the
    # time spent inside each status request and overshoot MAX_WAIT.
    started = time.monotonic()
    deadline = started + MAX_WAIT
    while time.monotonic() < deadline:
        time.sleep(POLL_INTERVAL)
        try:
            r = requests.get(f"{base}/api/analyze/{call_id}/status", timeout=15)
        except requests.RequestException as e:
            # A single failed poll is treated as transient; keep waiting.
            data = {}
            status = f"poll failed: {type(e).__name__}"
        else:
            data = _json_dict(r)
            status = data.get("status")

        if status == "done":
            return data.get("result")
        if status == "error":
            print(f" Pipeline error: {data.get('error', '?')}")
            return None

        mins, secs = divmod(int(time.monotonic() - started), 60)
        # flush: the line ends in "\r", so it would otherwise sit in the buffer.
        print(f" {status}... ({mins}m{secs}s)", end="\r", flush=True)

    print(f" Timeout after {MAX_WAIT}s")
    return None
def extract_emotions(result: dict) -> dict:
    """Flatten the emotion-related parts of a pipeline result into one dict.

    Every section of ``result`` is optional; a key that is missing and a key
    that is present but null/empty are treated the same way.

    Args:
        result: Pipeline result object as returned by the server.

    Returns:
        A dict with, in this order:
        ``speaker_<i>_state`` for each entry of ``character_reactions``;
        ``garden_mood`` and ``garden_delta`` from ``garden_update``;
        ``recap_headline`` from ``recap_card`` (``headline``, else ``title``);
        ``<speaker>_dominant`` and ``<speaker>_distribution`` for each entry of
        ``stage2_output.speaker_summaries``;
        ``segments`` (number of emotion segments) and ``segments_by_lang``
        (segment count per ``language`` value, ``"?"`` when unset).
    """
    info: dict = {}

    # `or` instead of a .get() default: the default only applies when the key
    # is absent, while an explicit null would otherwise reach .get()/enumerate().
    reactions = result.get("character_reactions") or []
    for i, rx in enumerate(reactions):
        info[f"speaker_{i}_state"] = rx.get("solo_state", "?")

    garden = result.get("garden_update") or {}
    info["garden_mood"] = garden.get("mood", "?")
    info["garden_delta"] = garden.get("growth_delta", 0)

    recap = result.get("recap_card") or {}
    info["recap_headline"] = recap.get("headline") or recap.get("title", "?")

    stage2 = result.get("stage2_output") or {}
    speaker_summaries = stage2.get("speaker_summaries") or {}
    for spk, summary in speaker_summaries.items():
        info[f"{spk}_dominant"] = summary.get("dominant_emotion", "?")
        info[f"{spk}_distribution"] = summary.get("emotion_distribution", {})

    # Segment-level language breakdown; top-level "emotions" wins over the
    # copy nested under stage2_output.
    emotions = result.get("emotions") or stage2.get("emotions") or []
    segs_by_lang: dict[str, int] = {}
    for e in emotions:
        lang = e.get("language") or "?"
        segs_by_lang[lang] = segs_by_lang.get(lang, 0) + 1
    info["segments"] = len(emotions)
    info["segments_by_lang"] = segs_by_lang
    return info
def main():
    """Run the end-to-end check against the server and print a report.

    Steps:
      1. Health-check the server given by ``--server`` (default
         ``DEFAULT_SERVER``); exit with status 1 if it fails.
      2. Load the intended labels from ``GT_PATH``.
      3. For every tag (sorted) that has a ``<tag>.wav`` in ``TEST_DIR``:
         upload it, run the analysis, and count a hit when the intended
         ``primary_emotion`` equals any reported ``*_state`` value.

    Per-clip results, minus the raw pipeline payload, are written to
    ``TEST_DIR / "e2e_results.json"``, and a summary table of the analysed
    clips is printed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--server", default=DEFAULT_SERVER)
    args = parser.parse_args()
    base = args.server.rstrip("/")

    print("=" * 70)
    print(" 20Hours Korean Demo — E2E Server Test")
    print(f" Server: {base}")
    print("=" * 70)

    print("\n[1] Health check")
    if not health_check(base):
        sys.exit(1)

    print("\n[2] Loading intended emotion labels")
    # Explicit UTF-8: the labels contain Korean text and the locale default
    # encoding is not UTF-8 on every platform.
    gt = json.loads(GT_PATH.read_text(encoding="utf-8"))
    print(f" {len(gt)} demo clips loaded")

    print("\n[3] Running E2E tests\n")
    results = {}
    hit = 0
    total = 0
    for tag in sorted(gt):
        wav_path = TEST_DIR / f"{tag}.wav"
        if not wav_path.exists():
            print(f" {tag}: WAV not found, skipping")
            continue

        gt_entry = gt[tag]
        print(f" {tag}{gt_entry['description'][:55]}")
        print(
            f" Intended: {gt_entry['primary_emotion']}"
            f" | Duration: {gt_entry['duration_sec']}s"
            f" | Utts: {gt_entry['total_utterances']}"
        )

        call_id = upload(base, wav_path)
        if not call_id:
            results[tag] = {"status": "upload_failed"}
            continue
        print(f" Upload OK → {call_id}")

        # flush: no newline is printed, so the prompt would otherwise stay
        # buffered for the whole analysis.
        print(" Analyzing...", end="", flush=True)
        start_time = time.time()
        result = analyze_and_poll(base, call_id)
        elapsed = time.time() - start_time
        if not result:
            results[tag] = {"status": "analyze_failed", "call_id": call_id}
            print()
            continue
        print(f"\r Done in {elapsed:.1f}s ")

        emotions = extract_emotions(result)
        total += 1
        intended = gt_entry["primary_emotion"]
        speaker_states = {k: v for k, v in emotions.items() if k.endswith("_state")}
        if intended in speaker_states.values():
            hit += 1
            match = "HIT"
        else:
            match = "miss"

        results[tag] = {
            "status": "pass",
            "call_id": call_id,
            "elapsed_sec": round(elapsed, 1),
            "intended_emotion": intended,
            "match": match,
            "emotions": emotions,
            "full_result": result,
        }

        for k, v in emotions.items():
            if not k.endswith("_distribution"):
                print(f" {k}: {v}")
            else:
                # Show only the three highest-scoring emotions.
                top3 = sorted(v.items(), key=lambda x: -x[1])[:3]
                dist = ", ".join(f"{kk}:{vv:.2f}" for kk, vv in top3)
                print(f" {k}: {dist}")
        print(f" → {match}")
        print()

    out_path = TEST_DIR / "e2e_results.json"
    # The raw pipeline payload is kept out of the saved file.
    save = {
        tag: {k: v for k, v in r.items() if k != "full_result"}
        for tag, r in results.items()
    }
    # ensure_ascii=False emits raw non-ASCII characters, so the file must be
    # written as UTF-8 regardless of the locale.
    out_path.write_text(
        json.dumps(save, indent=2, ensure_ascii=False), encoding="utf-8"
    )

    print("=" * 70)
    print(" SUMMARY")
    print("=" * 70)
    print(f"\n {'Tag':<24} {'Intended':<12} {'Match':<6} {'Time':>6}")
    print(" " + "-" * 55)
    for tag in sorted(results):
        r = results[tag]
        # Clips whose upload or analysis failed are not listed in the table.
        if r.get("status") != "pass":
            continue
        intended = r["intended_emotion"]
        m = r["match"]
        t = f"{r['elapsed_sec']:.0f}s"
        print(f" {tag:<24} {intended:<12} {m:<6} {t:>6}")
    print(f"\n Intended-emotion match: {hit}/{total}")
    print(f" Results saved: {out_path}")
# Run the E2E test only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()