# meta-hack/scripts/validate_submission.py
# vvinayakkk: "Sync full clinical-trial-triage project into Space" (commit 404c45f)
"""
Pre-submission validator for Clinical Trial Triage OpenEnv.
Checks:
1. Core endpoints respond and return expected shapes.
2. /tasks returns >= 3 tasks.
3. Each task can be completed and /grader returns score in [0.0, 1.0].
4. Root inference script runs without errors and produces outputs/baseline_results.json.
Usage:
python scripts/validate_submission.py
Notes:
- Requires the API server to be running (default: http://localhost:8000).
- Uses deterministic heuristic actions for endpoint and grader checks.
"""
from __future__ import annotations
import json
import os
import subprocess
import sys
from pathlib import Path
from typing import Any, Dict
import httpx
# Ensure project root import resolution
ROOT = Path(__file__).resolve().parent.parent
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
from models import TaskID
from scripts.heuristic_baseline import (
_heuristic_ae_triage,
_heuristic_deviation_audit,
_heuristic_narrative,
)
from tasks.case_bank import AE_CASES, DEVIATION_CASES, NARRATIVE_CASES
# Base URL of the API server under validation. Override it with the
# VALIDATOR_BASE_URL environment variable; trailing slashes are stripped.
BASE_URL = os.environ.get("VALIDATOR_BASE_URL", "http://localhost:8000").rstrip("/")
# Results file that _run_baseline_script() expects to find after inference.py runs.
OUTPUT_FILE = ROOT / "outputs" / "baseline_results.json"
# Runtime budget for inference.py, in seconds (20 minutes).
INFERENCE_TIMEOUT_SECONDS = 20 * 60
def _assert(condition: bool, message: str) -> None:
if not condition:
raise AssertionError(message)
def _post_json(client: httpx.Client, path: str, payload: Dict[str, Any]) -> Dict[str, Any]:
    """POST *payload* as JSON to *path* and return the decoded response body.

    Args:
        client: HTTP client used to send the request.
        path: Request path passed to ``client.post``.
        payload: Object sent as the JSON request body.

    Returns:
        Whatever ``response.json()`` yields for the reply.

    Raises:
        AssertionError: If the reply status is not 200. The message carries
            the path, the status code and the response text.
    """
    reply = client.post(path, json=payload)
    status_ok = reply.status_code == 200
    _assert(status_ok, f"{path} returned {reply.status_code}: {reply.text}")
    return reply.json()
def _run_episode(client: httpx.Client, task_id: str) -> float:
    """Play one episode of *task_id* with heuristic actions and return its grade.

    The episode is started with ``POST /reset``. One heuristic action per
    case is then sent to ``POST /step`` until the cases run out or a step
    response reports a truthy ``"done"``. Finally ``GET /grader`` is queried
    and its ``"normalized_score"`` is validated.

    Args:
        client: HTTP client bound to the server under test.
        task_id: Task to run; must equal one of the three ``TaskID`` values
            handled below.

    Returns:
        The grader's ``normalized_score`` as a float in ``[0.0, 1.0]``.

    Raises:
        AssertionError: If an endpoint does not answer with status 200, the
            reset payload has no ``"observation"``, the task id is unknown,
            or the score is missing, non-numeric or out of range.
    """
    reset_reply = _post_json(client, "/reset", {"task_id": task_id})
    _assert("observation" in reset_reply, f"/reset missing observation for task {task_id}")

    # One row per supported task: (task constant, case bank, action builder).
    # Rows are compared with ``==`` in this order, like the if/elif chain
    # this table replaces.
    playbooks = (
        (TaskID.ADVERSE_EVENT_TRIAGE, AE_CASES, _heuristic_ae_triage),
        (TaskID.PROTOCOL_DEVIATION_AUDIT, DEVIATION_CASES, _heuristic_deviation_audit),
        (TaskID.SAFETY_NARRATIVE_GENERATION, NARRATIVE_CASES, _heuristic_narrative),
    )
    selected = next(
        ((cases, build_action) for known, cases, build_action in playbooks if task_id == known),
        None,
    )
    if selected is None:
        raise AssertionError(f"Unknown task_id: {task_id}")

    cases, build_action = selected
    for case in cases:
        action = build_action(case).model_dump()
        # Stop stepping as soon as the server reports the episode finished.
        if _post_json(client, "/step", action).get("done"):
            break

    grader_reply = client.get("/grader")
    _assert(
        grader_reply.status_code == 200,
        f"/grader failed for task {task_id}: {grader_reply.text}",
    )
    score = grader_reply.json().get("normalized_score")
    _assert(isinstance(score, (int, float)), f"normalized_score missing for task {task_id}")
    normalized = float(score)
    _assert(0.0 <= normalized <= 1.0, f"normalized_score out of range for task {task_id}: {score}")
    return normalized
def _check_openenv_endpoints(client: httpx.Client) -> None:
    """Smoke-test the ``/openenv/*`` routes, in a fixed order.

    Requests are issued as: ``GET metadata``, ``GET schema``, ``POST reset``
    (adverse-event triage task), ``POST step`` with one hard-coded triage
    action, ``GET state`` and ``GET health``. Every reply must have status
    200, and the reset reply must contain an ``"observation"`` key.

    Args:
        client: HTTP client bound to the server under test.

    Raises:
        AssertionError: On the first check that fails.
    """
    # Read-only discovery routes; their failure messages omit the body.
    for route in ("/openenv/metadata", "/openenv/schema"):
        reply = client.get(route)
        _assert(reply.status_code == 200, f"{route} returned {reply.status_code}")

    reset_reply = client.post("/openenv/reset", json={"task_id": TaskID.ADVERSE_EVENT_TRIAGE})
    _assert(
        reset_reply.status_code == 200,
        f"/openenv/reset returned {reset_reply.status_code}: {reset_reply.text}",
    )
    _assert("observation" in reset_reply.json(), "/openenv/reset missing observation")

    # Fixed action used only to exercise the step route.
    smoke_action = {
        "task_id": TaskID.ADVERSE_EVENT_TRIAGE,
        "ae_triage": {
            "severity_classification": "severe",
            "reporting_timeline": "15-day",
            "meddra_soc": "Cardiac disorders",
            "meddra_preferred_term": "Myocardial infarction",
            "is_serious": True,
            "rationale": "validator openenv smoke action",
        },
    }
    step_reply = client.post("/openenv/step", json={"action": smoke_action})
    _assert(
        step_reply.status_code == 200,
        f"/openenv/step returned {step_reply.status_code}: {step_reply.text}",
    )

    state_reply = client.get("/openenv/state")
    _assert(
        state_reply.status_code == 200,
        f"/openenv/state returned {state_reply.status_code}: {state_reply.text}",
    )

    health_reply = client.get("/openenv/health")
    _assert(health_reply.status_code == 200, f"/openenv/health returned {health_reply.status_code}")
def _run_baseline_script() -> Dict[str, Any]:
    """Run the root ``inference.py`` and validate the results file it writes.

    The script is launched with the current interpreter from the project
    root and must exit with status 0 within ``INFERENCE_TIMEOUT_SECONDS``.
    It must then have written (or rewritten) ``OUTPUT_FILE`` with a JSON
    object that holds at least three entries under ``"tasks"`` plus the
    ``"mean_score"`` and ``"overall_mean_reward"`` keys.

    Returns:
        The parsed contents of ``OUTPUT_FILE``.

    Raises:
        AssertionError: On timeout, a non-zero exit status, a missing, stale
            or malformed output file, or missing expected keys.
    """
    # Remember the modification time of any results file left behind by an
    # earlier run, so a stale file cannot satisfy the "produces output" check.
    try:
        previous_mtime_ns = OUTPUT_FILE.stat().st_mtime_ns
    except FileNotFoundError:
        previous_mtime_ns = None

    cmd = [sys.executable, str(ROOT / "inference.py")]
    try:
        process = subprocess.run(
            cmd,
            cwd=str(ROOT),
            capture_output=True,
            text=True,
            timeout=INFERENCE_TIMEOUT_SECONDS,
        )
    except subprocess.TimeoutExpired as exc:
        raise AssertionError(
            f"inference.py exceeded runtime budget ({INFERENCE_TIMEOUT_SECONDS}s). "
            "Submission requires completion under 20 minutes."
        ) from exc

    _assert(process.returncode == 0, f"inference.py failed:\n{process.stderr}\n{process.stdout}")
    _assert(OUTPUT_FILE.exists(), f"Missing baseline output file: {OUTPUT_FILE}")
    if previous_mtime_ns is not None:
        _assert(
            OUTPUT_FILE.stat().st_mtime_ns != previous_mtime_ns,
            f"inference.py did not rewrite {OUTPUT_FILE}; "
            "refusing to validate stale results from an earlier run",
        )

    # Report unreadable output as a validation failure, not a raw traceback
    # from the JSON decoder.
    try:
        with open(OUTPUT_FILE, "r", encoding="utf-8") as file:
            data = json.load(file)
    except json.JSONDecodeError as exc:
        raise AssertionError(f"Baseline output is not valid JSON: {OUTPUT_FILE} ({exc})") from exc

    _assert(isinstance(data, dict), f"Baseline output must be a JSON object: {OUTPUT_FILE}")

    tasks = data.get("tasks", {})
    # Guard the type first so a null or scalar "tasks" value fails the check
    # instead of crashing len().
    _assert(
        isinstance(tasks, (dict, list)) and len(tasks) >= 3,
        "Baseline output does not contain all 3 tasks",
    )
    _assert("mean_score" in data, "Baseline output missing mean_score")
    _assert("overall_mean_reward" in data, "Baseline output missing overall_mean_reward")
    return data
def main() -> None:
    """Run every validation stage in order and print a summary.

    Stages: the root and health routes, the ``/tasks`` listing (at least
    three tasks), the ``/openenv/*`` smoke checks, one graded episode per
    task, and finally the baseline inference script. The first failing
    check raises ``AssertionError`` and aborts the run.
    """
    print("Running pre-submission validator")
    print(f"Base URL: {BASE_URL}")

    with httpx.Client(base_url=BASE_URL, timeout=60.0) as client:
        for route in ("/", "/health"):
            reply = client.get(route)
            _assert(reply.status_code == 200, f"{route} returned {reply.status_code}")

        tasks_reply = client.get("/tasks")
        _assert(tasks_reply.status_code == 200, f"/tasks returned {tasks_reply.status_code}")
        listed_tasks = tasks_reply.json().get("tasks", [])
        task_count = len(listed_tasks)
        _assert(task_count >= 3, f"Expected >=3 tasks, found {task_count}")

        _check_openenv_endpoints(client)

        episode_tasks = (
            TaskID.ADVERSE_EVENT_TRIAGE,
            TaskID.PROTOCOL_DEVIATION_AUDIT,
            TaskID.SAFETY_NARRATIVE_GENERATION,
        )
        # Dict comprehensions evaluate in iteration order, so the episodes
        # still run one after another in the order listed above.
        scores: Dict[str, float] = {task: _run_episode(client, task) for task in episode_tasks}

    baseline_data = _run_baseline_script()

    print("All checks passed")
    print("Episode grader scores:")
    for task_id, score in scores.items():
        print(f" - {task_id}: {score:.4f}")
    print(f"Baseline overall mean: {baseline_data.get('overall_mean_reward')}")


if __name__ == "__main__":
    main()