Spaces:
Sleeping
Sleeping
| """ | |
| ClinicalTrialEnv — OpenEnv-compliant environment for Clinical Trial Protocol Review. | |
| Implements: | |
| reset() → StepResult (initial observation, no reward) | |
| step() → StepResult (graded observation, reward, done) | |
| state() → dict of current internal state | |
| Episode structure: | |
| - Each episode corresponds to one task. | |
| - An agent may take up to max_steps actions. | |
| - The grader runs on every step against all findings and rationale accumulated so far. | |
| - Each step's reward is the change in grader score since the previous step, minus a small penalty for an empty action. | |
| - The episode ends when done=True, which step() sets once max_steps actions have been taken. | |
| """ | |
| from __future__ import annotations | |
| import copy | |
| import time | |
| from typing import Any, Dict, Optional | |
| from models import ClinicalTrialAction, ClinicalTrialObservation, StepResult | |
| from tasks import TASKS | |
class ClinicalTrialEnv:
    """OpenEnv environment for clinical trial protocol review.

    One instance is bound to a single task looked up in ``TASKS``. Findings
    and rationale submitted through :meth:`step` accumulate for the whole
    episode and are re-graded as a whole on every step.
    """

    # Amount subtracted from the reward when an action carries neither
    # findings nor a non-blank rationale.
    EMPTY_ACTION_PENALTY: float = 0.05

    def __init__(self, task_name: str = "eligibility_screening") -> None:
        """Bind the environment to ``task_name`` and start a fresh episode.

        Raises:
            ValueError: If ``task_name`` is not a key of ``TASKS``.
        """
        if task_name not in TASKS:
            raise ValueError(
                f"Unknown task '{task_name}'. Available: {list(TASKS.keys())}"
            )
        self.task_name = task_name
        self._task = TASKS[task_name]
        self._reset_episode_state()

    # ------------------------------------------------------------------
    # OpenEnv required interface
    # ------------------------------------------------------------------
    def reset(self) -> StepResult:
        """Reset the environment and return the initial observation.

        The returned result has ``reward=0.0``, ``done=False`` and an empty
        ``info`` dict.
        """
        self._reset_episode_state()
        obs = self._make_observation(feedback="", partial_score=0.0)
        return StepResult(observation=obs, reward=0.0, done=False, info={})

    def step(self, action: ClinicalTrialAction) -> StepResult:
        """Process one agent action and return the graded result.

        Reward shaping:
            - The grader is run on everything accumulated so far; the reward
              is the change in grader score since the previous step.
            - An action with no findings and a blank (or missing) rationale
              additionally costs ``EMPTY_ACTION_PENALTY``.

        The episode ends only when the step count reaches the task's
        ``max_steps``; no agent-side completion signal is read from the
        action.

        Raises:
            RuntimeError: If called after the episode has ended.
        """
        if self._done:
            raise RuntimeError("Episode is done. Call reset() to start a new episode.")
        self._step += 1

        # Normalise once so a missing (None) field is treated like an empty
        # one everywhere below instead of raising on .strip() / extend().
        findings = list(action.findings or [])
        rationale = action.rationale or ""

        # Accumulate findings and rationale across steps.
        self._all_findings.extend(findings)
        if rationale:
            self._all_rationale += " " + rationale

        # Grade the accumulated submission, not just this step's delta.
        grader = self._task["grader"]
        new_score, feedback = grader(self._all_findings, self._all_rationale)

        # Reward = incremental improvement in score.
        reward = new_score - self._last_score
        if not findings and not rationale.strip():
            reward -= self.EMPTY_ACTION_PENALTY

        max_steps = self._task["max_steps"]
        done = self._step >= max_steps

        self._last_score = new_score
        self._last_feedback = feedback
        self._history.append({
            "step": self._step,
            "n_findings": len(findings),
            "score_after": new_score,
            "reward": reward,
        })
        if done:
            self._done = True

        obs = self._make_observation(feedback=feedback, partial_score=new_score)
        info = {
            "score": new_score,
            "step": self._step,
            "max_steps": max_steps,
            "grader_feedback": feedback,
        }
        return StepResult(observation=obs, reward=reward, done=done, info=info)

    def state(self) -> Dict[str, Any]:
        """Return a snapshot of the internal state (for debugging/inspection).

        ``history`` is a deep copy, so callers may mutate it freely.
        ``elapsed_seconds`` is wall-clock time since construction or the
        last reset.
        """
        return {
            "task_name": self.task_name,
            "step": self._step,
            "done": self._done,
            "current_score": self._last_score,
            "n_findings_accumulated": len(self._all_findings),
            "history": copy.deepcopy(self._history),
            "elapsed_seconds": time.time() - self._start_time,
        }

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _reset_episode_state(self) -> None:
        """Set every per-episode field to its start-of-episode value."""
        self._step: int = 0
        self._done: bool = False
        self._all_findings: list = []
        self._all_rationale: str = ""
        self._last_score: float = 0.0
        self._last_feedback: str = ""
        self._history: list = []
        self._start_time: float = time.time()

    def _make_observation(self, feedback: str, partial_score: float) -> ClinicalTrialObservation:
        """Build an observation from the task's static data and current step."""
        task = self._task
        return ClinicalTrialObservation(
            task_name=self.task_name,
            protocol_summary=task["protocol_summary"],
            patient_records=task["patient_records"],
            adverse_events=task["adverse_events"],
            protocol_text=task["protocol_text"],
            step=self._step,
            feedback=feedback,
            partial_score=partial_score,
        )

    def close(self) -> None:
        """Clean up resources (no-op for this environment)."""
        pass