File size: 8,615 Bytes
53d9f07 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 | """
SepsisPilot — Pre-Submission Validation Script
Run this before submitting to verify OpenEnv spec compliance.
Usage: python validate.py [--url http://localhost:7860]
"""
from __future__ import annotations
import argparse
import sys
import requests
# ANSI-coloured status tags: "\033[<code>m" selects the colour and "\033[0m"
# resets it after the bracketed label.
PASS, FAIL, WARN, INFO = (
    f"\033[{code}m[{tag}]\033[0m"
    for code, tag in ((92, "PASS"), (91, "FAIL"), (93, "WARN"), (94, "INFO"))
)

# Running count of failed checks; incremented by check().
errors = 0
def check(label: str, condition: bool, msg: str = ""):
    """Print a PASS or FAIL line for one check and count failures.

    Args:
        label: Short description of what is being verified.
        condition: Outcome of the check; a truthy value means it passed.
        msg: Extra detail appended to the printed line on failure only.
    """
    global errors
    if not condition:
        print(f" {FAIL} {label} {msg}")
        # Module-level tally read by validate() for the summary / exit code.
        errors += 1
        return
    print(f" {PASS} {label}")
def section(title: str):
    """Print *title* framed above and below by a 50-character rule.

    The rule is built from the box-drawing character U+2500; the character
    previously stored here was a mis-encoded remnant of it.

    Args:
        title: Heading text printed between the two rules.
    """
    rule = "─" * 50
    print(f"\n{rule}\n {title}\n{rule}")
def validate(base_url: str):
    """Run the full OpenEnv compliance suite against a running server.

    Eight groups of checks run in order: health, task listing, one episode
    per task (mild_sepsis, septic_shock, severe_mods), grader score variance,
    seed reproducibility and invalid-action handling. Each result is printed
    through ``check``; the process then exits with status 0 when no check
    failed and 1 otherwise.

    Args:
        base_url: Root URL of the server, e.g. ``http://localhost:7860``.
            A trailing slash is tolerated.

    Raises:
        SystemExit: Always. Status 1 immediately if the health request
            fails; otherwise raised after the summary is printed.
    """
    # Avoid "//health"-style paths when the URL is given with a trailing slash.
    base_url = base_url.rstrip("/")
    rule = "─" * 50
    # Upper bound on the extra steps used to force an episode to finish, so a
    # server that never reports done cannot hang the validator.
    step_cap = 500

    print(f"\nSepsisPilot OpenEnv Validation\n Target: {base_url}\n")

    # ── 1. Health ──────────────────────────────────────────────
    section("1. Health Check")
    try:
        r = requests.get(f"{base_url}/health", timeout=10)
        check("GET /health returns 200", r.status_code == 200)
        data = r.json()
        check("Response contains 'status'", "status" in data)
        check("Status is 'ok'", data.get("status") == "ok")
    except Exception as e:
        # Nothing else can be validated without a reachable server.
        check("Server reachable", False, str(e))
        print("\n [ABORT] Server not reachable. Start the server first.\n")
        sys.exit(1)

    # ── 2. Tasks ───────────────────────────────────────────────
    section("2. Task Listing")
    try:
        r = requests.get(f"{base_url}/tasks", timeout=10)
        check("GET /tasks returns 200", r.status_code == 200)
        tasks = r.json()
        check("Returns a list", isinstance(tasks, list))
        check("At least 3 tasks", len(tasks) >= 3)
        task_names = [t["name"] for t in tasks]
        check("mild_sepsis present", "mild_sepsis" in task_names)
        check("septic_shock present", "septic_shock" in task_names)
        check("severe_mods present", "severe_mods" in task_names)
        for t in tasks:
            check(f" Task '{t['name']}' has difficulty", "difficulty" in t)
            check(f" Task '{t['name']}' has description", "description" in t)
            check(f" Task '{t['name']}' has max_steps", "max_steps" in t)
    except Exception as e:
        check("Tasks endpoint works", False, str(e))

    # ── 3. Episode — mild_sepsis ───────────────────────────────
    section("3. Episode Flow — mild_sepsis (Easy)")
    _validate_episode(base_url, "mild_sepsis", max_steps=24)

    # ── 4. Episode — septic_shock ──────────────────────────────
    section("4. Episode Flow — septic_shock (Medium)")
    _validate_episode(base_url, "septic_shock", max_steps=48)

    # ── 5. Episode — severe_mods ───────────────────────────────
    section("5. Episode Flow — severe_mods (Hard)")
    _validate_episode(base_url, "severe_mods", max_steps=72)

    # ── 6. Grader variance ─────────────────────────────────────
    section("6. Grader Score Variance (anti-trivial check)")
    scores = []
    actions_list = [
        [5, 5, 5, 1, 1, 1],  # broad + low vaso (good)
        [0, 0, 0, 0, 0, 0],  # no treatment (bad)
        [4, 4, 4, 4, 4, 4],  # high vaso only (wrong)
    ]
    for i, actions in enumerate(actions_list):
        try:
            requests.post(f"{base_url}/reset", json={"task": "mild_sepsis", "seed": 42}, timeout=10)
            done = False
            for a in actions:
                r = requests.post(f"{base_url}/step", json={"action": a}, timeout=10)
                done = bool(r.json().get("done"))
                if done:
                    break
            # Force episode end, but give up after step_cap extra steps.
            forced = 0
            while not done:
                if forced >= step_cap:
                    raise RuntimeError(f"episode not done after {step_cap} forced steps")
                r = requests.post(f"{base_url}/step", json={"action": 0}, timeout=10)
                done = bool(r.json().get("done"))
                forced += 1
            grade = requests.get(f"{base_url}/grade", timeout=10).json()
            score = grade["score"]
            # Reject non-numeric scores here so the round() calls below
            # cannot raise outside this try block.
            if not isinstance(score, (int, float)):
                raise TypeError(f"non-numeric score: {score!r}")
            scores.append(score)
        except Exception as e:
            scores.append(None)
            print(f" {WARN} Strategy {i} failed: {e}")

    valid_scores = [s for s in scores if s is not None]
    check("Grader returns different scores for different strategies",
          len({round(s, 2) for s in valid_scores}) > 1,
          f"(scores: {[round(s, 4) for s in valid_scores]})")
    # all() is vacuously true on an empty list, so require at least one score.
    check("All scores in [0.0, 1.0]",
          bool(valid_scores) and all(0.0 <= s <= 1.0 for s in valid_scores))

    # ── 7. Reproducibility ─────────────────────────────────────
    section("7. Reproducibility (same seed = same result)")
    try:
        scores_run1, scores_run2 = [], []
        for run_scores in (scores_run1, scores_run2):
            requests.post(f"{base_url}/reset", json={"task": "mild_sepsis", "seed": 99}, timeout=10)
            for _ in range(5):
                r = requests.post(f"{base_url}/step", json={"action": 5}, timeout=10)
                payload = r.json()
                run_scores.append(round(payload["reward"], 4))
                if payload["done"]:
                    break
        check("Reward sequences are identical across runs", scores_run1 == scores_run2,
              f"\n run1={scores_run1}\n run2={scores_run2}")
    except Exception as e:
        check("Reproducibility check", False, str(e))

    # ── 8. Error handling ──────────────────────────────────────
    section("8. Error Handling")
    try:
        # Start from a fresh episode so the rejection cannot be caused by the
        # (possibly finished) episode left over from the previous section.
        requests.post(f"{base_url}/reset", json={"task": "mild_sepsis", "seed": 42}, timeout=10)
        r = requests.post(f"{base_url}/step", json={"action": 99}, timeout=10)
        check("Invalid action returns 4xx", r.status_code in (400, 422))
    except Exception as e:
        check("Invalid action error handling", False, str(e))

    # ── Summary ────────────────────────────────────────────────
    print(f"\n{rule}")
    if errors == 0:
        print(" ✅ All checks passed. Ready for submission!")
    else:
        print(f" ❌ {errors} check(s) failed. Fix before submitting.")
    print(f"{rule}\n")
    sys.exit(0 if errors == 0 else 1)
def _validate_episode(base_url: str, task: str, max_steps: int):
    """Run one episode of *task* and verify the reset/state/step/grade contracts.

    The episode is reset with seed 42, stepped once with action 5 while the
    response shape is checked, stepped further with the same action until the
    server reports done, and finally graded. Each expectation is reported
    through ``check``. Any exception (network error, non-JSON body, missing
    "done" key mid-episode) is recorded as a single failed check rather than
    propagated.

    Args:
        base_url: Root URL of the server; a trailing slash is tolerated.
        task: Task name sent to ``POST /reset``.
        max_steps: Maximum number of additional steps issued after the first
            one while waiting for the episode to finish.
    """
    base_url = base_url.rstrip("/")
    try:
        # Reset
        r = requests.post(f"{base_url}/reset", json={"task": task, "seed": 42}, timeout=10)
        check("POST /reset 200", r.status_code == 200)
        state = r.json()
        check("Reset returns vitals", "vitals" in state)
        check("Reset returns step=0", state.get("step") == 0)
        check("Reset returns done=False", state.get("done") is False)
        check("Reset returns alive=True", state.get("alive") is True)

        # State endpoint
        r = requests.get(f"{base_url}/state", timeout=10)
        check("GET /state 200", r.status_code == 200)

        # Step
        r = requests.post(f"{base_url}/step", json={"action": 5}, timeout=10)
        check("POST /step 200", r.status_code == 200)
        result = r.json()
        # Read optional fields with .get() so one missing key fails only its
        # own check instead of aborting every check that follows.
        reward = result.get("reward")
        done = result.get("done")
        step_state = result.get("state")
        check("Step returns state", "state" in result)
        check("Step returns reward (float)", isinstance(reward, (int, float)))
        check("Step returns done (bool)", isinstance(done, bool))
        check("Step returns info (dict)", isinstance(result.get("info"), dict))
        check("Step increments step counter",
              isinstance(step_state, dict) and step_state.get("step") == 1)

        # Reward range check (NaN fails both comparisons, so it is rejected).
        check("Reward is finite and in expected range",
              isinstance(reward, (int, float)) and -15.0 <= reward <= 10.0,
              f"(got {reward})")

        # Run until done (fast: use a fixed action)
        for _ in range(max_steps):
            if done:
                break
            r = requests.post(f"{base_url}/step", json={"action": 5}, timeout=10)
            done = r.json()["done"]

        # Grade
        r = requests.get(f"{base_url}/grade", timeout=10)
        check("GET /grade 200 after episode", r.status_code == 200)
        grade = r.json()
        check("Grade has score in [0,1]",
              isinstance(grade.get("score"), (int, float)) and 0.0 <= grade["score"] <= 1.0,
              f"(got {grade.get('score')})")
        check("Grade has reason string", isinstance(grade.get("reason"), str))
        check("Grade has metrics dict", isinstance(grade.get("metrics"), dict))
        check("Grade has passed bool", isinstance(grade.get("passed"), bool))
    except Exception as e:
        check(f"Episode for {task} completed without error", False, str(e))
if __name__ == "__main__":
    # Command-line entry point: --url selects the server to validate and
    # defaults to the local development address.
    cli = argparse.ArgumentParser()
    cli.add_argument("--url", default="http://localhost:7860")
    validate(cli.parse_args().url)
|