Spaces:
Sleeping
Sleeping
| """ | |
| Pre-submission validator for Clinical Trial Triage OpenEnv. | |
| Checks: | |
| 1. Core endpoints respond and return expected shapes. | |
| 2. /tasks returns >= 3 tasks. | |
| 3. Each task can be completed and /grader returns score in [0.0, 1.0]. | |
| 4. Root inference script runs without errors and produces outputs/baseline_results.json. | |
| Usage: | |
| python scripts/validate_submission.py | |
| Notes: | |
| - Requires the API server to be running (default: http://localhost:8000). | |
| - Uses deterministic heuristic actions for endpoint and grader checks. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import os | |
| import subprocess | |
| import sys | |
| from pathlib import Path | |
| from typing import Any, Dict | |
| import httpx | |
| # Ensure project root import resolution | |
| ROOT = Path(__file__).resolve().parent.parent | |
| if str(ROOT) not in sys.path: | |
| sys.path.insert(0, str(ROOT)) | |
| from models import TaskID | |
| from scripts.heuristic_baseline import ( | |
| _heuristic_ae_triage, | |
| _heuristic_deviation_audit, | |
| _heuristic_narrative, | |
| ) | |
| from tasks.case_bank import AE_CASES, DEVIATION_CASES, NARRATIVE_CASES | |
# Base URL of the API server under validation; VALIDATOR_BASE_URL overrides the
# default, and trailing slashes are dropped.
BASE_URL = os.getenv("VALIDATOR_BASE_URL", "http://localhost:8000").rstrip("/")

# Results file that the root inference script is expected to write.
OUTPUT_FILE = ROOT.joinpath("outputs", "baseline_results.json")

# Runtime budget for inference.py: 20 minutes, expressed in seconds.
INFERENCE_TIMEOUT_SECONDS = 20 * 60
| def _assert(condition: bool, message: str) -> None: | |
| if not condition: | |
| raise AssertionError(message) | |
| def _post_json(client: httpx.Client, path: str, payload: Dict[str, Any]) -> Dict[str, Any]: | |
| response = client.post(path, json=payload) | |
| _assert(response.status_code == 200, f"{path} returned {response.status_code}: {response.text}") | |
| return response.json() | |
def _run_episode(client: httpx.Client, task_id: str) -> float:
    """Play one episode of ``task_id`` with heuristic actions and return its score.

    The episode is started with ``/reset``. One heuristic action per case in
    the task's case bank is then posted to ``/step`` until the server reports
    ``done`` or the cases run out, after which ``/grader`` is queried.

    Args:
        client: HTTP client used for all requests.
        task_id: Task identifier; matched by equality against the ``TaskID``
            members handled here.

    Returns:
        The grader's ``normalized_score`` as a float in ``[0.0, 1.0]``.

    Raises:
        AssertionError: If an endpoint returns a non-200 status, ``/reset``
            has no ``observation``, ``task_id`` is not one of the three known
            tasks, or the score is missing, non-numeric or out of range.
    """
    reset_data = _post_json(client, "/reset", {"task_id": task_id})
    _assert("observation" in reset_data, f"/reset missing observation for task {task_id}")

    # One row per task: (task id, case bank, heuristic that builds the action).
    # Matching uses ``==`` in this order, exactly like an if/elif chain would.
    episode_plans = (
        (TaskID.ADVERSE_EVENT_TRIAGE, AE_CASES, _heuristic_ae_triage),
        (TaskID.PROTOCOL_DEVIATION_AUDIT, DEVIATION_CASES, _heuristic_deviation_audit),
        (TaskID.SAFETY_NARRATIVE_GENERATION, NARRATIVE_CASES, _heuristic_narrative),
    )
    plan = next(
        ((cases, build_action) for known_id, cases, build_action in episode_plans if task_id == known_id),
        None,
    )
    if plan is None:
        raise AssertionError(f"Unknown task_id: {task_id}")
    cases, build_action = plan

    for case in cases:
        step_response = _post_json(client, "/step", build_action(case).model_dump())
        if step_response.get("done"):
            break

    grader_response = client.get("/grader")
    _assert(grader_response.status_code == 200, f"/grader failed for task {task_id}: {grader_response.text}")
    score = grader_response.json().get("normalized_score")
    # bool is a subclass of int, so a JSON true/false would otherwise be
    # accepted as a score of 1.0/0.0.
    _assert(
        isinstance(score, (int, float)) and not isinstance(score, bool),
        f"normalized_score missing or non-numeric for task {task_id}: {score!r}",
    )
    _assert(0.0 <= float(score) <= 1.0, f"normalized_score out of range for task {task_id}: {score}")
    return float(score)
def _check_openenv_endpoints(client: httpx.Client) -> None:
    """Smoke-test the ``/openenv/*`` routes, requiring HTTP 200 from each.

    Requests are issued in a fixed order: ``metadata``, ``schema``, ``reset``,
    ``step``, ``state``, ``health``. The reset response must additionally
    contain an ``observation`` key.

    Args:
        client: HTTP client used for all requests.

    Raises:
        AssertionError: On the first route that does not return 200, or if
            the reset response lacks ``observation``.
    """
    for route in ("/openenv/metadata", "/openenv/schema"):
        info_resp = client.get(route)
        _assert(info_resp.status_code == 200, f"{route} returned {info_resp.status_code}")

    reset_resp = client.post("/openenv/reset", json={"task_id": TaskID.ADVERSE_EVENT_TRIAGE})
    _assert(reset_resp.status_code == 200, f"/openenv/reset returned {reset_resp.status_code}: {reset_resp.text}")
    _assert("observation" in reset_resp.json(), "/openenv/reset missing observation")

    # Fixed adverse-event triage action used only to exercise the step route.
    smoke_action = {
        "task_id": TaskID.ADVERSE_EVENT_TRIAGE,
        "ae_triage": {
            "severity_classification": "severe",
            "reporting_timeline": "15-day",
            "meddra_soc": "Cardiac disorders",
            "meddra_preferred_term": "Myocardial infarction",
            "is_serious": True,
            "rationale": "validator openenv smoke action",
        },
    }
    step_resp = client.post("/openenv/step", json={"action": smoke_action})
    _assert(step_resp.status_code == 200, f"/openenv/step returned {step_resp.status_code}: {step_resp.text}")

    state_resp = client.get("/openenv/state")
    _assert(state_resp.status_code == 200, f"/openenv/state returned {state_resp.status_code}: {state_resp.text}")

    health_resp = client.get("/openenv/health")
    _assert(health_resp.status_code == 200, f"/openenv/health returned {health_resp.status_code}")
def _run_baseline_script() -> Dict[str, Any]:
    """Run the root ``inference.py`` and validate the results file it writes.

    The script is executed with the current interpreter from ``ROOT`` under a
    ``INFERENCE_TIMEOUT_SECONDS`` budget. Afterwards ``OUTPUT_FILE`` must
    exist, must have been (re)written by this run, and must be a JSON object
    with at least three entries under ``tasks`` plus the ``mean_score`` and
    ``overall_mean_reward`` keys.

    Returns:
        The parsed contents of ``OUTPUT_FILE``.

    Raises:
        AssertionError: If the script times out, exits non-zero, does not
            produce or refresh the output file, or the file's content fails
            any of the checks above.
    """
    cmd = [sys.executable, str(ROOT / "inference.py")]

    # Remember the mtime of any results file left by an earlier run, so that a
    # stale file cannot satisfy the "produces output" check below.
    try:
        stale_mtime_ns = OUTPUT_FILE.stat().st_mtime_ns
    except OSError:
        stale_mtime_ns = None

    try:
        process = subprocess.run(
            cmd,
            cwd=str(ROOT),
            capture_output=True,
            text=True,
            timeout=INFERENCE_TIMEOUT_SECONDS,
        )
    except subprocess.TimeoutExpired as exc:
        raise AssertionError(
            f"inference.py exceeded runtime budget ({INFERENCE_TIMEOUT_SECONDS}s). "
            "Submission requires completion under 20 minutes."
        ) from exc

    _assert(process.returncode == 0, f"inference.py failed:\n{process.stderr}\n{process.stdout}")
    _assert(OUTPUT_FILE.exists(), f"Missing baseline output file: {OUTPUT_FILE}")
    _assert(
        stale_mtime_ns is None or OUTPUT_FILE.stat().st_mtime_ns != stale_mtime_ns,
        f"inference.py did not update baseline output file: {OUTPUT_FILE}",
    )

    with open(OUTPUT_FILE, "r", encoding="utf-8") as file:
        try:
            data = json.load(file)
        except json.JSONDecodeError as exc:
            # Report malformed output like every other failed check.
            raise AssertionError(f"Baseline output is not valid JSON: {exc}") from exc
    _assert(isinstance(data, dict), "Baseline output must be a JSON object")

    tasks = data.get("tasks", {})
    _assert(len(tasks) >= 3, "Baseline output does not contain all 3 tasks")
    _assert("mean_score" in data, "Baseline output missing mean_score")
    _assert("overall_mean_reward" in data, "Baseline output missing overall_mean_reward")
    return data
def main() -> None:
    """Run every validation check in order and print a summary.

    Order of checks: ``/`` and ``/health`` respond with 200; ``/tasks``
    responds with 200 and lists at least three tasks; the ``/openenv/*``
    routes pass their smoke test; one heuristic episode per task is graded;
    finally the baseline inference script is run and its output validated.

    Raises:
        AssertionError: On the first check that fails.
    """
    print("Running pre-submission validator")
    print(f"Base URL: {BASE_URL}")

    episode_tasks = (
        TaskID.ADVERSE_EVENT_TRIAGE,
        TaskID.PROTOCOL_DEVIATION_AUDIT,
        TaskID.SAFETY_NARRATIVE_GENERATION,
    )

    with httpx.Client(base_url=BASE_URL, timeout=60.0) as client:
        for path in ("/", "/health"):
            probe = client.get(path)
            _assert(probe.status_code == 200, f"{path} returned {probe.status_code}")

        tasks = client.get("/tasks")
        _assert(tasks.status_code == 200, f"/tasks returned {tasks.status_code}")
        task_list = tasks.json().get("tasks", [])
        _assert(len(task_list) >= 3, f"Expected >=3 tasks, found {len(task_list)}")

        _check_openenv_endpoints(client)

        # Episodes run in the listed order; the dict keeps that order for printing.
        scores: Dict[str, float] = {task: _run_episode(client, task) for task in episode_tasks}

    baseline_data = _run_baseline_script()

    print("All checks passed")
    print("Episode grader scores:")
    for task_id, score in scores.items():
        print(f" - {task_id}: {score:.4f}")
    print(f"Baseline overall mean: {baseline_data.get('overall_mean_reward')}")


if __name__ == "__main__":
    main()