clinical-trial-env / server /environment.py
Geetesh
fix: move server files into server/ subfolder
045a5fe
"""
ClinicalTrialEnv — OpenEnv-compliant environment for Clinical Trial Protocol Review.
Implements:
reset() → StepResult (initial observation, no reward)
step() → StepResult (graded observation, reward, done)
state() → dict of current internal state
Episode structure:
- Each episode corresponds to one task.
- An agent may take up to max_steps actions.
- The final step triggers the grader and returns the terminal reward.
- Intermediate steps return partial rewards based on incremental improvement.
- The episode ends when done=True (max steps reached or agent signals completion).
"""
from __future__ import annotations
import copy
import time
from typing import Any, Dict, Optional
from models import ClinicalTrialAction, ClinicalTrialObservation, StepResult
from tasks import TASKS
class ClinicalTrialEnv:
"""
OpenEnv environment for clinical trial protocol review.
"""
def __init__(self, task_name: str = "eligibility_screening") -> None:
if task_name not in TASKS:
raise ValueError(
f"Unknown task '{task_name}'. Available: {list(TASKS.keys())}"
)
self.task_name = task_name
self._task = TASKS[task_name]
self._step = 0
self._done = False
self._all_findings: list = []
self._all_rationale: str = ""
self._last_score: float = 0.0
self._last_feedback: str = ""
self._history: list = []
self._start_time: float = time.time()
# ------------------------------------------------------------------
# OpenEnv required interface
# ------------------------------------------------------------------
def reset(self) -> StepResult:
"""Reset the environment and return the initial observation."""
self._step = 0
self._done = False
self._all_findings = []
self._all_rationale = ""
self._last_score = 0.0
self._last_feedback = ""
self._history = []
self._start_time = time.time()
obs = self._make_observation(feedback="", partial_score=0.0)
return StepResult(observation=obs, reward=0.0, done=False, info={})
def step(self, action: ClinicalTrialAction) -> StepResult:
"""
Process one agent action and return the result.
Reward shaping:
- Intermediate steps: reward = improvement in partial score since last step
- Final step (max_steps or agent signals done): full grader score
- Penalty for empty/trivial actions: -0.05
"""
if self._done:
raise RuntimeError("Episode is done. Call reset() to start a new episode.")
self._step += 1
# Accumulate findings and rationale across steps
self._all_findings.extend(action.findings)
if action.rationale:
self._all_rationale += " " + action.rationale
# Run grader to get current score
grader = self._task["grader"]
new_score, feedback = grader(self._all_findings, self._all_rationale)
# Reward = incremental improvement in score
reward = new_score - self._last_score
# Penalty for trivial/empty action
if not action.findings and not action.rationale.strip():
reward -= 0.05
# Determine if episode ends
max_steps = self._task["max_steps"]
done = self._step >= max_steps
self._last_score = new_score
self._last_feedback = feedback
self._history.append({
"step": self._step,
"n_findings": len(action.findings),
"score_after": new_score,
"reward": reward,
})
if done:
self._done = True
obs = self._make_observation(feedback=feedback, partial_score=new_score)
info = {
"score": new_score,
"step": self._step,
"max_steps": max_steps,
"grader_feedback": feedback,
}
return StepResult(observation=obs, reward=reward, done=done, info=info)
def state(self) -> Dict[str, Any]:
"""Return current internal state (for debugging/inspection)."""
return {
"task_name": self.task_name,
"step": self._step,
"done": self._done,
"current_score": self._last_score,
"n_findings_accumulated": len(self._all_findings),
"history": copy.deepcopy(self._history),
"elapsed_seconds": time.time() - self._start_time,
}
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _make_observation(self, feedback: str, partial_score: float) -> ClinicalTrialObservation:
task = self._task
return ClinicalTrialObservation(
task_name=self.task_name,
protocol_summary=task["protocol_summary"],
patient_records=task["patient_records"],
adverse_events=task["adverse_events"],
protocol_text=task["protocol_text"],
step=self._step,
feedback=feedback,
partial_score=partial_score,
)
def close(self) -> None:
"""Clean up resources (no-op for this environment)."""
pass