Spaces:
Sleeping
Sleeping
| """ | |
| SepsisPilot — Pre-Submission Validation Script | |
| Run this before submitting to verify OpenEnv spec compliance. | |
| Usage: python validate.py [--url http://localhost:7860] | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import sys | |
| import requests | |
# ANSI-coloured status tags prepended to every line of validator output.
PASS = "\033[92m[PASS]\033[0m"  # green
FAIL = "\033[91m[FAIL]\033[0m"  # red
WARN = "\033[93m[WARN]\033[0m"  # yellow
INFO = "\033[94m[INFO]\033[0m"  # blue

# Running count of failed checks; incremented by check() and used by
# validate() to pick the process exit code.
errors = 0
def check(label: str, condition: bool, msg: str = ""):
    """Report the outcome of a single validation check.

    Prints a PASS line when *condition* is truthy. Otherwise prints a FAIL
    line (followed by *msg*) and increments the module-level ``errors``
    counter.
    """
    global errors
    if not condition:
        print(f" {FAIL} {label} {msg}")
        errors += 1
        return
    print(f" {PASS} {label}")
def section(title: str):
    """Print *title* framed above and below by a 50-character rule."""
    rule = 'β' * 50
    print(f"\n{rule}\n {title}\n{rule}")
def validate(base_url: str):
    """Run every validation section against the server at *base_url*.

    Sections run in order: health, task listing, one episode per task,
    grader variance, reproducibility, and error handling. Each individual
    result is reported through ``check``. The function never returns
    normally: it ends with ``sys.exit(0)`` when no check failed and
    ``sys.exit(1)`` otherwise (or earlier, if the health check cannot
    reach the server).
    """
    print(f"\n㪠SepsisPilot OpenEnv Validation\n Target: {base_url}\n")

    # -- 1. Health --
    section("1. Health Check")
    _check_health(base_url)

    # -- 2. Tasks --
    section("2. Task Listing")
    _check_tasks(base_url)

    # -- 3-5. One full episode per task --
    section("3. Episode Flow β mild_sepsis (Easy)")
    _validate_episode(base_url, "mild_sepsis", max_steps=24)
    section("4. Episode Flow β septic_shock (Medium)")
    _validate_episode(base_url, "septic_shock", max_steps=48)
    section("5. Episode Flow β severe_mods (Hard)")
    _validate_episode(base_url, "severe_mods", max_steps=72)

    # -- 6. Grader variance --
    section("6. Grader Score Variance (anti-trivial check)")
    _check_grader_variance(base_url)

    # -- 7. Reproducibility --
    section("7. Reproducibility (same seed = same result)")
    _check_reproducibility(base_url)

    # -- 8. Error handling --
    section("8. Error Handling")
    _check_error_handling(base_url)

    # -- Summary --
    print(f"\n{'β'*50}")
    if errors == 0:
        print(" β All checks passed. Ready for submission!")
    else:
        print(f" β {errors} check(s) failed. Fix before submitting.")
    print(f"{'β'*50}\n")
    sys.exit(0 if errors == 0 else 1)


def _check_health(base_url: str) -> None:
    """Check GET /health; exit the process with status 1 on any exception."""
    try:
        r = requests.get(f"{base_url}/health", timeout=10)
        check("GET /health returns 200", r.status_code == 200)
        data = r.json()
        check("Response contains 'status'", "status" in data)
        check("Status is 'ok'", data.get("status") == "ok")
    except Exception as e:
        check("Server reachable", False, str(e))
        print("\n [ABORT] Server not reachable. Start the server first.\n")
        # Nothing else can be validated without a server.
        sys.exit(1)


def _check_tasks(base_url: str) -> None:
    """Check that GET /tasks lists the three expected tasks and their fields."""
    try:
        r = requests.get(f"{base_url}/tasks", timeout=10)
        check("GET /tasks returns 200", r.status_code == 200)
        tasks = r.json()
        check("Returns a list", isinstance(tasks, list))
        check("At least 3 tasks", len(tasks) >= 3)
        task_names = [t["name"] for t in tasks]
        for expected in ("mild_sepsis", "septic_shock", "severe_mods"):
            check(f"{expected} present", expected in task_names)
        for t in tasks:
            for field in ("difficulty", "description", "max_steps"):
                check(f" Task '{t['name']}' has {field}", field in t)
    except Exception as e:
        check("Tasks endpoint works", False, str(e))


def _run_strategy(base_url: str, actions: list[int], max_forced_steps: int = 200):
    """Play *actions* on mild_sepsis (seed 42) and return the graded score.

    After the scripted actions, action 0 is repeated until the server
    reports ``done``. That padding is capped at *max_forced_steps* so a
    server that never finishes the episode cannot hang the validator.

    Raises:
        RuntimeError: the episode is still running after the padding cap.
        TypeError: the grade payload's ``score`` is not a number.
    """
    requests.post(f"{base_url}/reset", json={"task": "mild_sepsis", "seed": 42}, timeout=10)
    done = False
    for a in actions:
        r = requests.post(f"{base_url}/step", json={"action": a}, timeout=10)
        done = bool(r.json().get("done"))
        if done:
            break

    # Force episode end, but never loop unboundedly.
    forced = 0
    while not done:
        if forced >= max_forced_steps:
            raise RuntimeError(
                f"episode not done after {max_forced_steps} forced steps"
            )
        r = requests.post(f"{base_url}/step", json={"action": 0}, timeout=10)
        done = bool(r.json().get("done"))
        forced += 1

    score = requests.get(f"{base_url}/grade", timeout=10).json()["score"]
    if not isinstance(score, (int, float)):
        # Reject here so the caller's round()/range checks cannot crash.
        raise TypeError(f"non-numeric score: {score!r}")
    return score


def _check_grader_variance(base_url: str) -> None:
    """Check that distinct strategies earn distinct, in-range scores."""
    actions_list = [
        [5, 5, 5, 1, 1, 1],  # broad + low vaso (good)
        [0, 0, 0, 0, 0, 0],  # no treatment (bad)
        [4, 4, 4, 4, 4, 4],  # high vaso only (wrong)
    ]
    valid_scores = []
    for i, actions in enumerate(actions_list):
        try:
            valid_scores.append(_run_strategy(base_url, actions))
        except Exception as e:
            print(f" {WARN} Strategy {i} failed: {e}")

    check("Grader returns different scores for different strategies",
          len({round(s, 2) for s in valid_scores}) > 1,
          f"(scores: {[round(s,4) for s in valid_scores]})")
    # all() over an empty list is True; require at least one real score so
    # the range check cannot pass when every strategy failed.
    check("All scores in [0.0, 1.0]",
          bool(valid_scores) and all(0.0 <= s <= 1.0 for s in valid_scores),
          "" if valid_scores else "(no strategy produced a score)")


def _check_reproducibility(base_url: str) -> None:
    """Check that two runs with the same seed yield the same reward sequence."""
    try:
        runs = []
        for _run in range(2):
            requests.post(f"{base_url}/reset", json={"task": "mild_sepsis", "seed": 99}, timeout=10)
            rewards = []
            for _ in range(5):
                r = requests.post(f"{base_url}/step", json={"action": 5}, timeout=10)
                payload = r.json()  # parse once per step
                rewards.append(round(payload["reward"], 4))
                if payload["done"]:
                    break
            runs.append(rewards)
        scores_run1, scores_run2 = runs
        check("Reward sequences are identical across runs", scores_run1 == scores_run2,
              f"\n run1={scores_run1}\n run2={scores_run2}")
    except Exception as e:
        check("Reproducibility check", False, str(e))


def _check_error_handling(base_url: str) -> None:
    """Check that an out-of-range action is rejected with HTTP 400 or 422."""
    try:
        r = requests.post(f"{base_url}/step", json={"action": 99}, timeout=10)
        check("Invalid action returns 4xx", r.status_code in (400, 422))
    except Exception as e:
        check("Invalid action error handling", False, str(e))
def _validate_episode(base_url: str, task: str, max_steps: int):
    """Run one episode of *task* and verify the endpoint contracts.

    Resets the environment with seed 42, checks the reset payload, the
    /state endpoint and the shape of one /step response. It then repeats
    action 5 for at most *max_steps* further steps until the server
    reports ``done``, and finally checks the /grade payload. Results are
    reported through ``check``; any exception is recorded as a single
    failed check instead of propagating.
    """
    try:
        # Reset
        r = requests.post(f"{base_url}/reset", json={"task": task, "seed": 42}, timeout=10)
        check("POST /reset 200", r.status_code == 200)
        state = r.json()
        check("Reset returns vitals", "vitals" in state)
        check("Reset returns step=0", state.get("step") == 0)
        # Identity comparison: 0/1 must not be accepted in place of real booleans.
        check("Reset returns done=False", state.get("done") is False)
        check("Reset returns alive=True", state.get("alive") is True)

        # State endpoint
        r = requests.get(f"{base_url}/state", timeout=10)
        check("GET /state 200", r.status_code == 200)

        # Step
        r = requests.post(f"{base_url}/step", json={"action": 5}, timeout=10)
        check("POST /step 200", r.status_code == 200)
        result = r.json()
        check("Step returns state", "state" in result)
        check("Step returns reward (float)", isinstance(result.get("reward"), (int, float)))
        check("Step returns done (bool)", isinstance(result.get("done"), bool))
        check("Step returns info (dict)", isinstance(result.get("info"), dict))
        check("Step increments step counter", result["state"]["step"] == 1)

        # Reward range check (NaN and infinities fail the chained comparison)
        reward = result["reward"]
        check("Reward is finite and in expected range",
              -15.0 <= reward <= 10.0, f"(got {reward})")

        # Run until done (fast: use a fixed action)
        done = result["done"]
        for _ in range(max_steps):
            if done:
                break
            r = requests.post(f"{base_url}/step", json={"action": 5}, timeout=10)
            done = r.json()["done"]
        # One initial step plus up to max_steps more have been taken; the
        # grade checks below are only meaningful for a finished episode.
        check(f"Episode reaches done within {max_steps + 1} steps", bool(done))

        # Grade
        r = requests.get(f"{base_url}/grade", timeout=10)
        check("GET /grade 200 after episode", r.status_code == 200)
        grade = r.json()
        check("Grade has score in [0,1]",
              isinstance(grade.get("score"), (int, float)) and 0.0 <= grade["score"] <= 1.0,
              f"(got {grade.get('score')})")
        check("Grade has reason string", isinstance(grade.get("reason"), str))
        check("Grade has metrics dict", isinstance(grade.get("metrics"), dict))
        check("Grade has passed bool", isinstance(grade.get("passed"), bool))
    except Exception as e:
        check(f"Episode for {task} completed without error", False, str(e))
if __name__ == "__main__":
    # Command-line entry point: the only option is the server base URL.
    cli = argparse.ArgumentParser()
    cli.add_argument("--url", default="http://localhost:7860")
    validate(cli.parse_args().url)