File size: 8,615 Bytes
53d9f07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
"""
SepsisPilot — Pre-Submission Validation Script
Run this before submitting to verify OpenEnv spec compliance.

Usage: python validate.py [--url http://localhost:7860]
"""

from __future__ import annotations
import argparse
import sys
import requests

# ANSI sequence that switches the terminal back to its default colors.
_ANSI_RESET = "\033[0m"

# Colored status tags printed in front of each line of validation output.
PASS = f"\033[92m[PASS]{_ANSI_RESET}"  # SGR 92: bright green
FAIL = f"\033[91m[FAIL]{_ANSI_RESET}"  # SGR 91: bright red
WARN = f"\033[93m[WARN]{_ANSI_RESET}"  # SGR 93: bright yellow
INFO = f"\033[94m[INFO]{_ANSI_RESET}"  # SGR 94: bright blue

# Number of failed checks so far; check() increments it and validate()
# turns it into the process exit status.
errors = 0


def check(label: str, condition: bool, msg: str = ""):
    """Print the outcome of one validation check and tally failures.

    Args:
        label: Human-readable name of the check.
        condition: Truthy when the check passed.
        msg: Extra detail appended to the line when the check failed.
    """
    global errors
    if condition:
        print(f"  {PASS} {label}")
        return
    # Failure path: report first, then bump the module-wide failure count.
    print(f"  {FAIL} {label} {msg}")
    errors += 1


def section(title: str):
    """Print *title* framed by two 50-character horizontal rules."""
    rule = "─" * 50
    print(f"\n{rule}\n  {title}\n{rule}")


def validate(base_url: str):
    """Run every validation section against the server and exit with the result.

    Sections run in order: health, task listing, one episode per task,
    grader variance, reproducibility and invalid-action handling. Results
    are printed as they are produced and tallied in the module-level
    ``errors`` counter.

    Args:
        base_url: Root URL of the server under test, e.g.
            ``http://localhost:7860``. Trailing slashes are ignored.

    Raises:
        SystemExit: Always. Status 0 when every check passed, 1 when any
            check failed or the health probe raised.
    """
    # A trailing slash would otherwise produce "//health"-style paths.
    base_url = base_url.rstrip("/")
    # NOTE(review): the emoji/dash characters in the printed titles below look
    # mis-encoded; they are runtime output and are left byte-for-byte as found.
    print(f"\nπŸ”¬ SepsisPilot OpenEnv Validation\n   Target: {base_url}\n")

    _check_health(base_url)
    _check_tasks(base_url)

    section("3. Episode Flow β€” mild_sepsis (Easy)")
    _validate_episode(base_url, "mild_sepsis", max_steps=24)

    section("4. Episode Flow β€” septic_shock (Medium)")
    _validate_episode(base_url, "septic_shock", max_steps=48)

    section("5. Episode Flow β€” severe_mods (Hard)")
    _validate_episode(base_url, "severe_mods", max_steps=72)

    _check_grader_variance(base_url)
    _check_reproducibility(base_url)
    _check_invalid_action(base_url)

    # Summary and exit status.
    print(f"\n{'═'*50}")
    if errors == 0:
        print("  βœ…  All checks passed. Ready for submission!")
    else:
        print(f"  ❌  {errors} check(s) failed. Fix before submitting.")
    print(f"{'═'*50}\n")
    sys.exit(0 if errors == 0 else 1)


# Upper bound on steps per grader-variance episode. mild_sepsis is exercised
# with max_steps=24 in validate(), so this leaves a wide margin while still
# guaranteeing the script terminates if the server never reports done.
_VARIANCE_STEP_CAP = 100


def _check_health(base_url: str) -> None:
    """Section 1: GET /health must return 200 with ``status == "ok"``.

    Exits the process with status 1 if the request or JSON decoding raises,
    because none of the later sections can run without a server.
    """
    section("1. Health Check")
    try:
        r = requests.get(f"{base_url}/health", timeout=10)
        check("GET /health returns 200", r.status_code == 200)
        data = r.json()
        check("Response contains 'status'", "status" in data)
        check("Status is 'ok'", data.get("status") == "ok")
    except Exception as e:
        check("Server reachable", False, str(e))
        print("\n  [ABORT] Server not reachable. Start the server first.\n")
        sys.exit(1)


def _check_tasks(base_url: str) -> None:
    """Section 2: GET /tasks must list the three tasks with their metadata."""
    section("2. Task Listing")
    try:
        r = requests.get(f"{base_url}/tasks", timeout=10)
        check("GET /tasks returns 200", r.status_code == 200)
        tasks = r.json()
        check("Returns a list", isinstance(tasks, list))
        check("At least 3 tasks", len(tasks) >= 3)
        task_names = [t["name"] for t in tasks]
        for expected in ("mild_sepsis", "septic_shock", "severe_mods"):
            check(f"{expected} present", expected in task_names)
        for t in tasks:
            for field in ("difficulty", "description", "max_steps"):
                check(f"  Task '{t['name']}' has {field}", field in t)
    except Exception as e:
        check("Tasks endpoint works", False, str(e))


def _run_strategy(base_url: str, opening_actions: list[int]) -> float:
    """Play one seeded mild_sepsis episode to its end and return the grade score.

    The episode starts with ``opening_actions`` and is then padded with
    action 0 until the server reports ``done``.

    Raises:
        RuntimeError: The episode is still running after
            ``_VARIANCE_STEP_CAP`` steps.
        TypeError: The grade's ``score`` is not a number.
    """
    requests.post(f"{base_url}/reset", json={"task": "mild_sepsis", "seed": 42}, timeout=10)
    done = False
    steps_taken = 0
    for action in opening_actions:
        r = requests.post(f"{base_url}/step", json={"action": action}, timeout=10)
        steps_taken += 1
        done = bool(r.json().get("done"))
        if done:
            break
    # Force the episode to end, but never loop forever on a server that
    # omits or never sets "done".
    while not done:
        if steps_taken >= _VARIANCE_STEP_CAP:
            raise RuntimeError(f"episode not done after {steps_taken} steps")
        r = requests.post(f"{base_url}/step", json={"action": 0}, timeout=10)
        steps_taken += 1
        done = bool(r.json().get("done"))
    score = requests.get(f"{base_url}/grade", timeout=10).json()["score"]
    if isinstance(score, bool) or not isinstance(score, (int, float)):
        raise TypeError(f"non-numeric score {score!r}")
    return score


def _check_grader_variance(base_url: str) -> None:
    """Section 6: different strategies must earn different, in-range scores."""
    section("6. Grader Score Variance (anti-trivial check)")
    strategies = [
        [5, 5, 5, 1, 1, 1],   # broad + low vaso (good)
        [0, 0, 0, 0, 0, 0],   # no treatment (bad)
        [4, 4, 4, 4, 4, 4],   # high vaso only (wrong)
    ]
    scores = []
    for i, actions in enumerate(strategies):
        try:
            scores.append(_run_strategy(base_url, actions))
        except Exception as e:
            # A failed strategy contributes no score; the checks below then
            # judge only the strategies that completed.
            print(f"  {WARN} Strategy {i} failed: {e}")

    shown = f"(scores: {[round(s, 4) for s in scores]})"
    check("Grader returns different scores for different strategies",
          len({round(s, 2) for s in scores}) > 1,
          shown)
    # Require at least one score so the range check cannot pass vacuously
    # when every strategy failed.
    check("All scores in [0.0, 1.0]",
          bool(scores) and all(0.0 <= s <= 1.0 for s in scores),
          shown)


def _check_reproducibility(base_url: str) -> None:
    """Section 7: two runs with the same seed must yield the same rewards."""
    section("7. Reproducibility (same seed = same result)")
    try:
        scores_run1, scores_run2 = [], []
        for run_scores in (scores_run1, scores_run2):
            requests.post(f"{base_url}/reset", json={"task": "mild_sepsis", "seed": 99}, timeout=10)
            for _ in range(5):
                r = requests.post(f"{base_url}/step", json={"action": 5}, timeout=10)
                body = r.json()
                run_scores.append(round(body["reward"], 4))
                if body["done"]:
                    break
        check("Reward sequences are identical across runs", scores_run1 == scores_run2,
              f"\n    run1={scores_run1}\n    run2={scores_run2}")
    except Exception as e:
        check("Reproducibility check", False, str(e))


def _check_invalid_action(base_url: str) -> None:
    """Section 8: an out-of-range action must be rejected with 400 or 422."""
    section("8. Error Handling")
    try:
        # Start from a fresh episode so a 4xx can only be blamed on the bad
        # action, not on stepping an episode that an earlier section finished.
        requests.post(f"{base_url}/reset", json={"task": "mild_sepsis", "seed": 42}, timeout=10)
        r = requests.post(f"{base_url}/step", json={"action": 99}, timeout=10)
        check("Invalid action returns 4xx", r.status_code in (400, 422),
              f"(got {r.status_code})")
    except Exception as e:
        check("Invalid action error handling", False, str(e))


def _validate_episode(base_url: str, task: str, max_steps: int):
    """Run one seeded episode of *task* and check the reset/state/step/grade contracts.

    The episode is reset with seed 42 and stepped once with action 5 so the
    step payload can be inspected. It is then stepped with the same action
    until the server reports ``done`` (at most ``max_steps`` further steps),
    and finally graded. Any exception is recorded as a single failed check
    instead of propagating.

    Args:
        base_url: Root URL of the server under test, without a trailing slash.
        task: Task name sent to ``/reset``.
        max_steps: Maximum number of extra steps taken while waiting for
            ``done``.
    """
    try:
        # Reset
        r = requests.post(f"{base_url}/reset", json={"task": task, "seed": 42}, timeout=10)
        check("POST /reset 200", r.status_code == 200)
        state = r.json()
        check("Reset returns vitals", "vitals" in state)
        check("Reset returns step=0", state.get("step") == 0)
        # Identity checks: the labels promise real booleans, and `== False`
        # would also accept 0 (likewise `== True` accepts 1).
        check("Reset returns done=False", state.get("done") is False)
        check("Reset returns alive=True", state.get("alive") is True)

        # State endpoint
        r = requests.get(f"{base_url}/state", timeout=10)
        check("GET /state 200", r.status_code == 200)

        # Step
        r = requests.post(f"{base_url}/step", json={"action": 5}, timeout=10)
        check("POST /step 200", r.status_code == 200)
        result = r.json()
        reward = result.get("reward")
        step_state = result.get("state")
        check("Step returns state", "state" in result)
        check("Step returns reward (float)", isinstance(reward, (int, float)))
        check("Step returns done (bool)", isinstance(result.get("done"), bool))
        check("Step returns info (dict)", isinstance(result.get("info"), dict))
        # Missing or malformed fields fail their own check rather than raising
        # and hiding every check that follows.
        check("Step increments step counter",
              isinstance(step_state, dict) and step_state.get("step") == 1)

        # Reward range check (NaN and infinities fall outside the range).
        check("Reward is finite and in expected range",
              isinstance(reward, (int, float)) and -15.0 <= reward <= 10.0,
              f"(got {reward})")

        # Run until done, using a fixed action to keep it fast.
        done = bool(result.get("done"))
        for _ in range(max_steps):
            if done:
                break
            r = requests.post(f"{base_url}/step", json={"action": 5}, timeout=10)
            done = bool(r.json().get("done"))
        # The grade below is only meaningful for a finished episode.
        check(f"Episode terminates within {max_steps} steps", done)

        # Grade
        r = requests.get(f"{base_url}/grade", timeout=10)
        check("GET /grade 200 after episode", r.status_code == 200)
        grade = r.json()
        check("Grade has score in [0,1]",
              isinstance(grade.get("score"), (int, float)) and 0.0 <= grade["score"] <= 1.0,
              f"(got {grade.get('score')})")
        check("Grade has reason string", isinstance(grade.get("reason"), str))
        check("Grade has metrics dict", isinstance(grade.get("metrics"), dict))
        check("Grade has passed bool", isinstance(grade.get("passed"), bool))

    except Exception as e:
        check(f"Episode for {task} completed without error", False, str(e))


if __name__ == "__main__":
    # Command-line entry point: one optional flag selects the server under test.
    cli = argparse.ArgumentParser()
    cli.add_argument("--url", default="http://localhost:7860")
    validate(cli.parse_args().url)