Spaces:
Sleeping
Sleeping
| """ | |
| Clinical Trial Triage β EnvClient | |
| =================================== | |
| OpenEnv-compatible client for interacting with the environment server. | |
| Supports both async (recommended) and sync usage patterns. | |
| Usage (async): | |
| from client import ClinicalTrialEnv, TriageAction | |
| async with ClinicalTrialEnv(base_url="http://localhost:8000") as env: | |
| obs = await env.reset(task_id="adverse_event_triage") | |
| result = await env.step(action) | |
| Usage (sync): | |
| with ClinicalTrialEnv(base_url="http://localhost:8000").sync() as env: | |
| obs = env.reset(task_id="adverse_event_triage") | |
| result = env.step(action) | |
| Usage (HF Spaces): | |
| async with ClinicalTrialEnv(base_url="https://YOUR_HF_SPACE.hf.space") as env: | |
| obs = await env.reset() | |
| """ | |
| from __future__ import annotations | |
| import asyncio | |
| from typing import Any, Dict, Optional | |
| import httpx | |
| from models import ( | |
| StepResult, | |
| TaskID, | |
| TriageAction, | |
| TriageObservation, | |
| TriageState, | |
| ) | |
| class ClinicalTrialEnv: | |
| """ | |
| Async environment client implementing OpenEnv EnvClient pattern. | |
| Communicates with the FastAPI server over HTTP. | |
| """ | |
| def __init__( | |
| self, | |
| base_url: str = "http://localhost:8000", | |
| timeout: float = 30.0, | |
| session_id: str = "default", | |
| ) -> None: | |
| self.base_url = base_url.rstrip("/") | |
| self.timeout = timeout | |
| self.session_id = session_id.strip() or "default" | |
| self._client: Optional[httpx.AsyncClient] = None | |
| async def __aenter__(self) -> "ClinicalTrialEnv": | |
| self._client = httpx.AsyncClient( | |
| base_url=self.base_url, | |
| timeout=self.timeout, | |
| headers={"X-Session-ID": self.session_id}, | |
| ) | |
| return self | |
| async def __aexit__(self, *args) -> None: | |
| if self._client: | |
| await self._client.aclose() | |
| self._client = None | |
| def _ensure_client(self) -> httpx.AsyncClient: | |
| if self._client is None: | |
| raise RuntimeError( | |
| "Client not initialized. Use 'async with ClinicalTrialEnv(...) as env:'" | |
| ) | |
| return self._client | |
| async def reset( | |
| self, task_id: str = TaskID.ADVERSE_EVENT_TRIAGE | |
| ) -> TriageObservation: | |
| """Initialize a new episode. Returns initial observation.""" | |
| client = self._ensure_client() | |
| response = await client.post("/reset", json={"task_id": task_id}) | |
| response.raise_for_status() | |
| data = response.json() | |
| return TriageObservation(**data["observation"]) | |
| async def step(self, action: TriageAction) -> StepResult: | |
| """Execute one action. Returns StepResult(observation, reward, done, info).""" | |
| client = self._ensure_client() | |
| response = await client.post("/step", json=action.model_dump()) | |
| response.raise_for_status() | |
| return StepResult(**response.json()) | |
| async def state(self) -> TriageState: | |
| """Return current episode state metadata.""" | |
| client = self._ensure_client() | |
| response = await client.get("/state") | |
| response.raise_for_status() | |
| return TriageState(**response.json()) | |
| async def tasks(self) -> Dict[str, Any]: | |
| """Return available tasks and action schemas.""" | |
| client = self._ensure_client() | |
| response = await client.get("/tasks") | |
| response.raise_for_status() | |
| return response.json() | |
| async def grader(self) -> Dict[str, Any]: | |
| """Return grader scores for the last completed episode.""" | |
| client = self._ensure_client() | |
| response = await client.get("/grader") | |
| response.raise_for_status() | |
| return response.json() | |
| def sync(self) -> "SyncClinicalTrialEnv": | |
| """Return a synchronous wrapper around this client.""" | |
| return SyncClinicalTrialEnv(self.base_url, self.timeout, self.session_id) | |
| class SyncClinicalTrialEnv: | |
| """Synchronous wrapper for environments that don't use async.""" | |
| def __init__(self, base_url: str, timeout: float = 30.0, session_id: str = "default") -> None: | |
| self.base_url = base_url.rstrip("/") | |
| self.timeout = timeout | |
| self.session_id = session_id.strip() or "default" | |
| self._client: Optional[httpx.Client] = None | |
| def __enter__(self) -> "SyncClinicalTrialEnv": | |
| self._client = httpx.Client( | |
| base_url=self.base_url, | |
| timeout=self.timeout, | |
| headers={"X-Session-ID": self.session_id}, | |
| ) | |
| return self | |
| def __exit__(self, *args) -> None: | |
| if self._client: | |
| self._client.close() | |
| self._client = None | |
| def reset(self, task_id: str = TaskID.ADVERSE_EVENT_TRIAGE) -> TriageObservation: | |
| assert self._client | |
| r = self._client.post("/reset", json={"task_id": task_id}) | |
| r.raise_for_status() | |
| return TriageObservation(**r.json()["observation"]) | |
| def step(self, action: TriageAction) -> StepResult: | |
| assert self._client | |
| r = self._client.post("/step", json=action.model_dump()) | |
| r.raise_for_status() | |
| return StepResult(**r.json()) | |
| def state(self) -> TriageState: | |
| assert self._client | |
| r = self._client.get("/state") | |
| r.raise_for_status() | |
| return TriageState(**r.json()) | |
| # βββββββββββββββββββββββββββββββββββββββββ | |
| # QUICK DEMO | |
| # βββββββββββββββββββββββββββββββββββββββββ | |
| async def _demo(): | |
| """Quick async demo of the environment client.""" | |
| from models import AdverseEventTriageAction, TriageAction | |
| print("Clinical Trial Triage β Client Demo") | |
| print("=" * 50) | |
| async with ClinicalTrialEnv(base_url="http://localhost:8000") as env: | |
| # Get available tasks | |
| tasks = await env.tasks() | |
| print(f"Available tasks: {[t['id'] for t in tasks['tasks']]}\n") | |
| # Reset AE triage episode | |
| obs = await env.reset(task_id="adverse_event_triage") | |
| print(f"Task: {obs.task_id}") | |
| if obs.ae_observation: | |
| print(f"Case: {obs.ae_observation.case_id}") | |
| print(f"AE: {obs.ae_observation.ae_description}") | |
| # Send a test action | |
| action = TriageAction( | |
| task_id="adverse_event_triage", | |
| ae_triage=AdverseEventTriageAction( | |
| severity_classification="life_threatening", | |
| reporting_timeline="7-day", | |
| meddra_soc="Cardiac disorders", | |
| meddra_preferred_term="Myocardial infarction", | |
| is_serious=True, | |
| rationale="STEMI with troponin elevation β life-threatening SAE requiring 7-day reporting.", | |
| ), | |
| ) | |
| result = await env.step(action) | |
| print(f"\nReward: {result.reward:.4f}") | |
| print(f"Done: {result.done}") | |
| print(f"Reward detail: {result.reward_detail.model_dump()}") | |
| if __name__ == "__main__": | |
| asyncio.run(_demo()) | |