meta-hack / client.py
vvinayakkk's picture
Sync full clinical-trial-triage project into Space
404c45f
"""
Clinical Trial Triage — EnvClient
===================================
OpenEnv-compatible client for interacting with the environment server.
Supports both async (recommended) and sync usage patterns.
Usage (async):
    from client import ClinicalTrialEnv, TriageAction
    async with ClinicalTrialEnv(base_url="http://localhost:8000") as env:
        obs = await env.reset(task_id="adverse_event_triage")
        result = await env.step(action)
Usage (sync):
    with ClinicalTrialEnv(base_url="http://localhost:8000").sync() as env:
        obs = env.reset(task_id="adverse_event_triage")
        result = env.step(action)
Usage (HF Spaces):
    async with ClinicalTrialEnv(base_url="https://YOUR_HF_SPACE.hf.space") as env:
        obs = await env.reset()
"""
from __future__ import annotations
import asyncio
from typing import Any, Dict, Optional
import httpx
from models import (
StepResult,
TaskID,
TriageAction,
TriageObservation,
TriageState,
)
class ClinicalTrialEnv:
    """
    Async environment client implementing OpenEnv EnvClient pattern.

    Communicates with the FastAPI server over HTTP. The underlying
    ``httpx.AsyncClient`` only exists while the instance is entered as an
    async context manager; every request it sends carries the session id in
    the ``X-Session-ID`` header.
    """

    def __init__(
        self,
        base_url: str = "http://localhost:8000",
        timeout: float = 30.0,
        session_id: str = "default",
    ) -> None:
        """
        Store connection settings; no network connection is opened here.

        Args:
            base_url: Server root URL. Trailing slashes are removed.
            timeout: Timeout value handed to ``httpx.AsyncClient``.
            session_id: Value for the ``X-Session-ID`` header. Surrounding
                whitespace is stripped; an empty result falls back to
                ``"default"``.
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self.session_id = session_id.strip() or "default"
        self._client: Optional[httpx.AsyncClient] = None

    async def __aenter__(self) -> "ClinicalTrialEnv":
        """Create the HTTP client and return this instance."""
        self._client = httpx.AsyncClient(
            base_url=self.base_url,
            timeout=self.timeout,
            headers={"X-Session-ID": self.session_id},
        )
        return self

    async def __aexit__(self, *args) -> None:
        """Close the HTTP client, if one is open."""
        # Detach first so the instance never keeps a reference to a client
        # whose close attempt raised part-way through.
        client, self._client = self._client, None
        if client is not None:
            await client.aclose()

    def _ensure_client(self) -> httpx.AsyncClient:
        """
        Return the active HTTP client.

        Raises:
            RuntimeError: If the instance has not been entered with
                ``async with``.
        """
        if self._client is None:
            raise RuntimeError(
                "Client not initialized. Use 'async with ClinicalTrialEnv(...) as env:'"
            )
        return self._client

    async def _get_json(self, path: str) -> Any:
        """GET ``path``, raise on an HTTP error status, return the decoded JSON body."""
        response = await self._ensure_client().get(path)
        response.raise_for_status()
        return response.json()

    async def _post_json(self, path: str, payload: Dict[str, Any]) -> Any:
        """POST ``payload`` as JSON to ``path``, raise on an HTTP error status, return the decoded JSON body."""
        response = await self._ensure_client().post(path, json=payload)
        response.raise_for_status()
        return response.json()

    async def reset(
        self, task_id: str = TaskID.ADVERSE_EVENT_TRIAGE
    ) -> TriageObservation:
        """Initialize a new episode. Returns initial observation."""
        data = await self._post_json("/reset", {"task_id": task_id})
        # Only the "observation" member of the response is used here.
        return TriageObservation(**data["observation"])

    async def step(self, action: TriageAction) -> StepResult:
        """Execute one action. Returns StepResult(observation, reward, done, info)."""
        # mode="json" makes pydantic emit JSON-native values only, so the
        # payload cannot fail in the JSON encoder on enum/datetime-like fields.
        payload = action.model_dump(mode="json")
        return StepResult(**await self._post_json("/step", payload))

    async def state(self) -> TriageState:
        """Return current episode state metadata."""
        return TriageState(**await self._get_json("/state"))

    async def tasks(self) -> Dict[str, Any]:
        """Return available tasks and action schemas."""
        return await self._get_json("/tasks")

    async def grader(self) -> Dict[str, Any]:
        """Return grader scores for the last completed episode."""
        return await self._get_json("/grader")

    def sync(self) -> "SyncClinicalTrialEnv":
        """Return a synchronous wrapper built from this client's settings."""
        return SyncClinicalTrialEnv(self.base_url, self.timeout, self.session_id)
class SyncClinicalTrialEnv:
    """
    Synchronous wrapper for environments that don't use async.

    Exposes the same endpoint methods as the async client, backed by a
    blocking ``httpx.Client`` that only exists while the instance is entered
    as a context manager. Every request carries the session id in the
    ``X-Session-ID`` header.
    """

    def __init__(self, base_url: str, timeout: float = 30.0, session_id: str = "default") -> None:
        """
        Store connection settings; no network connection is opened here.

        Args:
            base_url: Server root URL. Trailing slashes are removed.
            timeout: Timeout value handed to ``httpx.Client``.
            session_id: Value for the ``X-Session-ID`` header. Surrounding
                whitespace is stripped; an empty result falls back to
                ``"default"``.
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self.session_id = session_id.strip() or "default"
        self._client: Optional[httpx.Client] = None

    def __enter__(self) -> "SyncClinicalTrialEnv":
        """Create the HTTP client and return this instance."""
        self._client = httpx.Client(
            base_url=self.base_url,
            timeout=self.timeout,
            headers={"X-Session-ID": self.session_id},
        )
        return self

    def __exit__(self, *args) -> None:
        """Close the HTTP client, if one is open."""
        # Detach first so the instance never keeps a reference to a client
        # whose close attempt raised part-way through.
        client, self._client = self._client, None
        if client is not None:
            client.close()

    def _ensure_client(self) -> httpx.Client:
        """
        Return the active HTTP client.

        An explicit exception is used instead of ``assert`` because asserts
        are removed when Python runs with ``-O``.

        Raises:
            RuntimeError: If the instance has not been entered with ``with``.
        """
        if self._client is None:
            raise RuntimeError(
                "Client not initialized. Use 'with ClinicalTrialEnv(...).sync() as env:'"
            )
        return self._client

    def _get_json(self, path: str) -> Any:
        """GET ``path``, raise on an HTTP error status, return the decoded JSON body."""
        response = self._ensure_client().get(path)
        response.raise_for_status()
        return response.json()

    def _post_json(self, path: str, payload: Dict[str, Any]) -> Any:
        """POST ``payload`` as JSON to ``path``, raise on an HTTP error status, return the decoded JSON body."""
        response = self._ensure_client().post(path, json=payload)
        response.raise_for_status()
        return response.json()

    def reset(self, task_id: str = TaskID.ADVERSE_EVENT_TRIAGE) -> TriageObservation:
        """Initialize a new episode. Returns initial observation."""
        data = self._post_json("/reset", {"task_id": task_id})
        # Only the "observation" member of the response is used here.
        return TriageObservation(**data["observation"])

    def step(self, action: TriageAction) -> StepResult:
        """Execute one action and return the resulting StepResult."""
        # mode="json" makes pydantic emit JSON-native values only, so the
        # payload cannot fail in the JSON encoder on enum/datetime-like fields.
        payload = action.model_dump(mode="json")
        return StepResult(**self._post_json("/step", payload))

    def state(self) -> TriageState:
        """Return current episode state metadata."""
        return TriageState(**self._get_json("/state"))

    def tasks(self) -> Dict[str, Any]:
        """Return available tasks and action schemas."""
        return self._get_json("/tasks")

    def grader(self) -> Dict[str, Any]:
        """Return grader scores for the last completed episode."""
        return self._get_json("/grader")
# ─────────────────────────────────────────
# QUICK DEMO
# ─────────────────────────────────────────
async def _demo(base_url: str = "http://localhost:8000") -> None:
    """
    Quick async demo of the environment client.

    Lists the available tasks, resets an ``adverse_event_triage`` episode,
    submits one hard-coded triage action and prints the reward fields of the
    result. A server must be reachable at ``base_url``.

    Args:
        base_url: Root URL of the environment server to talk to.
    """
    from models import AdverseEventTriageAction, TriageAction

    print("Clinical Trial Triage — Client Demo")
    print("=" * 50)

    async with ClinicalTrialEnv(base_url=base_url) as env:
        # Get available tasks
        tasks = await env.tasks()
        print(f"Available tasks: {[t['id'] for t in tasks['tasks']]}\n")

        # Reset AE triage episode
        obs = await env.reset(task_id="adverse_event_triage")
        print(f"Task: {obs.task_id}")
        if obs.ae_observation:
            print(f"Case: {obs.ae_observation.case_id}")
            print(f"AE: {obs.ae_observation.ae_description}")

        # Send a test action
        action = TriageAction(
            task_id="adverse_event_triage",
            ae_triage=AdverseEventTriageAction(
                severity_classification="life_threatening",
                reporting_timeline="7-day",
                meddra_soc="Cardiac disorders",
                meddra_preferred_term="Myocardial infarction",
                is_serious=True,
                rationale="STEMI with troponin elevation — life-threatening SAE requiring 7-day reporting.",
            ),
        )
        result = await env.step(action)
        print(f"\nReward: {result.reward:.4f}")
        print(f"Done: {result.done}")
        print(f"Reward detail: {result.reward_detail.model_dump()}")
# Script entry point: run the async demo only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    demo_coroutine = _demo()
    asyncio.run(demo_coroutine)