Spaces:
Sleeping
Sleeping
File size: 7,247 Bytes
"""
Pre-submission validator for Clinical Trial Triage OpenEnv.

Checks:
  1. Core endpoints respond and return expected shapes.
  2. /tasks returns >= 3 tasks.
  3. Each task can be completed and /grader returns a score in [0.0, 1.0].
  4. Root inference script runs without errors and produces outputs/baseline_results.json.

Usage:
  python scripts/validate_submission.py

Notes:
  - Requires the API server to be running (default: http://localhost:8000;
    override with the VALIDATOR_BASE_URL environment variable).
  - Uses deterministic heuristic actions for endpoint and grader checks.
"""
from __future__ import annotations
import json
import os
import subprocess
import sys
from pathlib import Path
from typing import Any, Dict
import httpx
# Put the project root (the parent of this script's directory) on sys.path so
# the project-local imports below resolve when this file is run as a script.
ROOT = Path(__file__).resolve().parent.parent
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from models import TaskID
from scripts.heuristic_baseline import (
    _heuristic_ae_triage,
    _heuristic_deviation_audit,
    _heuristic_narrative,
)
from tasks.case_bank import AE_CASES, DEVIATION_CASES, NARRATIVE_CASES

# Base URL of the API server under test; any trailing slash is stripped.
BASE_URL = os.getenv("VALIDATOR_BASE_URL", "http://localhost:8000").rstrip("/")
# Results file that the root inference script is expected to write.
OUTPUT_FILE = ROOT / "outputs" / "baseline_results.json"
# Runtime budget for the inference script, in seconds (20 minutes).
INFERENCE_TIMEOUT_SECONDS = 20 * 60
def _assert(condition: bool, message: str) -> None:
if not condition:
raise AssertionError(message)
def _post_json(client: httpx.Client, path: str, payload: Dict[str, Any]) -> Dict[str, Any]:
    """POST ``payload`` as JSON to ``path`` and return the decoded response body.

    Args:
        client: HTTP client used to send the request.
        path: Request path passed to ``client.post``.
        payload: Object sent as the JSON request body.

    Returns:
        The response body decoded with ``response.json()``.

    Raises:
        AssertionError: If the status code is not 200; the message contains
            the path, the status code and the response text.
    """
    reply = client.post(path, json=payload)
    status = reply.status_code
    _assert(status == 200, f"{path} returned {status}: {reply.text}")
    return reply.json()
def _run_episode(client: httpx.Client, task_id: str) -> float:
    """Play one episode of ``task_id`` with heuristic actions and return its grader score.

    The episode is started via ``/reset``, then one heuristic action per case
    in the task's case list is posted to ``/step`` until the server reports
    ``done`` or the cases run out. Finally ``/grader`` is queried.

    Args:
        client: HTTP client pointed at the API server.
        task_id: Task to run; compared with ``==`` against the ``TaskID`` members.

    Returns:
        The ``normalized_score`` reported by ``/grader``, as a float.

    Raises:
        AssertionError: If ``/reset`` has no ``observation``, ``task_id`` is
            unknown, any request does not return 200, or the score is missing
            or outside ``[0.0, 1.0]``.
    """
    reset_data = _post_json(client, "/reset", {"task_id": task_id})
    _assert("observation" in reset_data, f"/reset missing observation for task {task_id}")

    # (task, case list, heuristic) triples, checked in this order.
    playbook = (
        (TaskID.ADVERSE_EVENT_TRIAGE, AE_CASES, _heuristic_ae_triage),
        (TaskID.PROTOCOL_DEVIATION_AUDIT, DEVIATION_CASES, _heuristic_deviation_audit),
        (TaskID.SAFETY_NARRATIVE_GENERATION, NARRATIVE_CASES, _heuristic_narrative),
    )
    selected = next((entry for entry in playbook if task_id == entry[0]), None)
    if selected is None:
        raise AssertionError(f"Unknown task_id: {task_id}")
    _, cases, heuristic = selected

    for case in cases:
        outcome = _post_json(client, "/step", heuristic(case).model_dump())
        if outcome.get("done"):
            break

    graded = client.get("/grader")
    _assert(graded.status_code == 200, f"/grader failed for task {task_id}: {graded.text}")

    score = graded.json().get("normalized_score")
    _assert(isinstance(score, (int, float)), f"normalized_score missing for task {task_id}")
    numeric_score = float(score)
    _assert(0.0 <= numeric_score <= 1.0, f"normalized_score out of range for task {task_id}: {score}")
    return numeric_score
def _check_openenv_endpoints(client: httpx.Client) -> None:
    """Smoke-test the ``/openenv/*`` routes and require a 200 from each.

    Routes are exercised in this order: ``metadata``, ``schema``, ``reset``,
    ``step``, ``state``, ``health``. The reset response must also contain an
    ``observation`` key.

    Raises:
        AssertionError: On the first route that does not return 200, or if the
            reset response lacks ``observation``.
    """

    def expect_ok(response: httpx.Response, route: str, *, include_body: bool = False) -> None:
        # Some checks report only the status code; others append the body text.
        detail = f": {response.text}" if include_body else ""
        _assert(response.status_code == 200, f"{route} returned {response.status_code}{detail}")

    expect_ok(client.get("/openenv/metadata"), "/openenv/metadata")
    expect_ok(client.get("/openenv/schema"), "/openenv/schema")

    reset_reply = client.post("/openenv/reset", json={"task_id": TaskID.ADVERSE_EVENT_TRIAGE})
    expect_ok(reset_reply, "/openenv/reset", include_body=True)
    _assert("observation" in reset_reply.json(), "/openenv/reset missing observation")

    # Fixed adverse-event triage action used only to exercise the step route.
    smoke_action = {
        "task_id": TaskID.ADVERSE_EVENT_TRIAGE,
        "ae_triage": {
            "severity_classification": "severe",
            "reporting_timeline": "15-day",
            "meddra_soc": "Cardiac disorders",
            "meddra_preferred_term": "Myocardial infarction",
            "is_serious": True,
            "rationale": "validator openenv smoke action",
        },
    }
    step_reply = client.post("/openenv/step", json={"action": smoke_action})
    expect_ok(step_reply, "/openenv/step", include_body=True)

    expect_ok(client.get("/openenv/state"), "/openenv/state", include_body=True)
    expect_ok(client.get("/openenv/health"), "/openenv/health")
def _run_baseline_script() -> Dict[str, Any]:
    """Run the root ``inference.py`` script and validate the results file it writes.

    The script is executed with the current interpreter from the project root,
    with stdout and stderr captured, under a budget of
    ``INFERENCE_TIMEOUT_SECONDS`` seconds.

    Returns:
        The parsed contents of ``OUTPUT_FILE``.

    Raises:
        AssertionError: If the script times out or exits non-zero; if the
            output file is missing, was not rewritten by this run, is not
            valid JSON or is not a JSON object; if it lists fewer than 3
            tasks; or if ``mean_score`` / ``overall_mean_reward`` is absent.
    """
    # Record the pre-run modification time so that a stale file left over from
    # an earlier run cannot satisfy the "output was produced" check below.
    try:
        mtime_before = OUTPUT_FILE.stat().st_mtime_ns
    except OSError:
        mtime_before = None

    cmd = [sys.executable, str(ROOT / "inference.py")]
    try:
        process = subprocess.run(
            cmd,
            cwd=str(ROOT),
            capture_output=True,
            text=True,
            timeout=INFERENCE_TIMEOUT_SECONDS,
        )
    except subprocess.TimeoutExpired as exc:
        raise AssertionError(
            f"inference.py exceeded runtime budget ({INFERENCE_TIMEOUT_SECONDS}s). "
            "Submission requires completion under 20 minutes."
        ) from exc

    _assert(process.returncode == 0, f"inference.py failed:\n{process.stderr}\n{process.stdout}")
    _assert(OUTPUT_FILE.exists(), f"Missing baseline output file: {OUTPUT_FILE}")
    _assert(
        OUTPUT_FILE.stat().st_mtime_ns != mtime_before,
        f"Baseline output file was not updated by inference.py: {OUTPUT_FILE}",
    )

    # Report malformed output as a validation failure, like every other check,
    # instead of leaking a decoder/attribute/type error.
    try:
        with open(OUTPUT_FILE, "r", encoding="utf-8") as file:
            data = json.load(file)
    except (json.JSONDecodeError, UnicodeDecodeError) as exc:
        raise AssertionError(f"Baseline output is not valid JSON: {exc}") from exc
    _assert(isinstance(data, dict), "Baseline output must be a JSON object")

    tasks = data.get("tasks", {})
    _assert(
        isinstance(tasks, (dict, list)) and len(tasks) >= 3,
        "Baseline output does not contain all 3 tasks",
    )
    _assert("mean_score" in data, "Baseline output missing mean_score")
    _assert("overall_mean_reward" in data, "Baseline output missing overall_mean_reward")
    return data
def main() -> None:
    """Run every validation stage in order and print a summary.

    Stages: core routes (``/``, ``/health``, ``/tasks``), the ``/openenv/*``
    smoke test, one heuristic episode per task, and finally the baseline
    inference script. Any failed check raises ``AssertionError``.
    """
    print("Running pre-submission validator")
    print(f"Base URL: {BASE_URL}")

    episode_tasks = (
        TaskID.ADVERSE_EVENT_TRIAGE,
        TaskID.PROTOCOL_DEVIATION_AUDIT,
        TaskID.SAFETY_NARRATIVE_GENERATION,
    )

    with httpx.Client(base_url=BASE_URL, timeout=60.0) as client:
        for route in ("/", "/health"):
            reply = client.get(route)
            _assert(reply.status_code == 200, f"{route} returned {reply.status_code}")

        tasks_reply = client.get("/tasks")
        _assert(tasks_reply.status_code == 200, f"/tasks returned {tasks_reply.status_code}")
        advertised = tasks_reply.json().get("tasks", [])
        _assert(len(advertised) >= 3, f"Expected >=3 tasks, found {len(advertised)}")

        _check_openenv_endpoints(client)

        scores: Dict[str, float] = {task: _run_episode(client, task) for task in episode_tasks}

    # The baseline run launches a subprocess and does not use the HTTP client.
    baseline_data = _run_baseline_script()

    print("All checks passed")
    print("Episode grader scores:")
    for task_id, score in scores.items():
        print(f" - {task_id}: {score:.4f}")
    print(f"Baseline overall mean: {baseline_data.get('overall_mean_reward')}")


if __name__ == "__main__":
    main()