""" Clinical Trial Triage — Environment (Server-Side) =================================================== Implements the OpenEnv Environment base with reset(), step(), state(). Full episode management, reward shaping, and multi-task support. """ from __future__ import annotations import json import threading import uuid from datetime import datetime, timezone from typing import Any, Dict, Optional from models import ( AdverseEventObservation, AdverseEventTriageAction, ProtocolDeviationAction, ProtocolDeviationObservation, SafetyNarrativeAction, SafetyNarrativeObservation, StepResult, TaskID, TriageAction, TriageObservation, TriageReward, TriageState, ) from tasks.case_bank import AE_CASES, DEVIATION_CASES, NARRATIVE_CASES # noqa: E402 from tasks.graders import ( # noqa: E402 grade_ae_triage, grade_protocol_deviation, grade_safety_narrative, ) # Task → max steps configuration TASK_MAX_STEPS: Dict[str, int] = { TaskID.ADVERSE_EVENT_TRIAGE: 3, # 3 AE cases per episode TaskID.PROTOCOL_DEVIATION_AUDIT: 3, # 3 site audits per episode TaskID.SAFETY_NARRATIVE_GENERATION: 1, # 1 complex narrative per episode } TASK_CASES: Dict[str, list] = { TaskID.ADVERSE_EVENT_TRIAGE: AE_CASES, TaskID.PROTOCOL_DEVIATION_AUDIT: DEVIATION_CASES, TaskID.SAFETY_NARRATIVE_GENERATION: NARRATIVE_CASES, } _SCORE_EPS = 1e-3 _sessions: Dict[str, "ClinicalTrialEnvironment"] = {} _sessions_lock = threading.Lock() def get_or_create_session(session_id: str = "default") -> "ClinicalTrialEnvironment": with _sessions_lock: if session_id not in _sessions: _sessions[session_id] = ClinicalTrialEnvironment() return _sessions[session_id] def clear_session(session_id: str = "default") -> None: with _sessions_lock: _sessions.pop(session_id, None) class ClinicalTrialEnvironment: """ Main environment class implementing OpenEnv-compatible APIs. Episode lifecycle: reset(task_id) → initial observation step(action) → observation, reward, done, info state() → episode metadata """ def __init__(self) -> None: self._state: Optional[TriageState] = None self._case_index: int = 0 self._current_task: Optional[TaskID] = None self._cumulative_reward: float = 0.0 self._actions_log: list = [] self._last_action_signature: Optional[str] = None # ───────────────────────────────────── # PUBLIC OPENENV API # ───────────────────────────────────── def reset(self, task_id: str = TaskID.ADVERSE_EVENT_TRIAGE) -> TriageObservation: """Initialize a new episode for the given task.""" task = TaskID(task_id) self._current_task = task self._case_index = 0 self._cumulative_reward = 0.0 self._actions_log = [] self._last_action_signature = None self._state = TriageState( episode_id=str(uuid.uuid4()), task_id=task, step_count=0, max_steps=TASK_MAX_STEPS[task], done=False, cumulative_reward=0.0, actions_taken=[], current_case_id=self._get_current_case_id(), started_at=datetime.now(timezone.utc).isoformat(), ) return self._build_observation() def step(self, action: TriageAction) -> StepResult: """Execute one action and return (observation, reward, done, info).""" if self._state is None or self._state.done: raise RuntimeError("Call reset() before step(), or episode is already done.") if TaskID(action.task_id) != self._current_task: raise ValueError( f"Action task_id '{action.task_id}' does not match " f"current episode task '{self._current_task}'." ) current_observation = self._build_observation() # Grade this step reward_detail = self._grade(action) step_reward = reward_detail.total # Reward shaping: small partial reward for being in the right direction # This gives signal across the trajectory, not just at episode end shaped_reward = self._shape_reward(step_reward, reward_detail, action, current_observation) # Update state self._cumulative_reward += shaped_reward self._case_index += 1 self._state.step_count += 1 self._state.cumulative_reward = self._cumulative_reward self._state.actions_taken.append( {"step": self._state.step_count, "action": action.model_dump(), "reward": shaped_reward} ) done = self._state.step_count >= self._state.max_steps self._state.done = done if done: self._state.completed_at = datetime.now(timezone.utc).isoformat() self._state.current_case_id = None else: self._state.current_case_id = self._get_current_case_id() obs = self._build_observation() return StepResult( observation=obs, reward=shaped_reward, reward_detail=reward_detail, done=done, info={ "episode_id": self._state.episode_id, "step": self._state.step_count, "cumulative_reward": self._cumulative_reward, "done": done, }, ) def state(self) -> TriageState: """Return current episode state metadata.""" if self._state is None: raise RuntimeError("No active episode. Call reset() first.") return self._state # ───────────────────────────────────── # INTERNAL HELPERS # ───────────────────────────────────── def _get_current_case_id(self) -> Optional[str]: cases = TASK_CASES.get(self._current_task, []) if self._case_index < len(cases): case = cases[self._case_index] return case.get("case_id") or case.get("site_id") return None def _build_observation(self) -> TriageObservation: """Build the typed observation for the current state.""" if self._state is None: raise RuntimeError("No active state.") cases = TASK_CASES[self._current_task] step = self._state.step_count max_steps = self._state.max_steps if self._state.done or self._case_index >= len(cases): return TriageObservation( task_id=self._current_task, message=f"Episode complete. Cumulative reward: {self._cumulative_reward:.4f}", ) case = cases[self._case_index] if self._current_task == TaskID.ADVERSE_EVENT_TRIAGE: obs = AdverseEventObservation( case_id=case["case_id"], narrative=case["narrative"], patient_age=case["patient_age"], patient_sex=case["patient_sex"], study_drug=case["study_drug"], dose_mg=case["dose_mg"], days_on_drug=case["days_on_drug"], relevant_medical_history=case["relevant_medical_history"], concomitant_medications=case["concomitant_medications"], lab_values=case["lab_values"], ae_onset_date=case["ae_onset_date"], ae_description=case["ae_description"], outcome=case["outcome"], step_count=step, max_steps=max_steps, ) return TriageObservation( task_id=self._current_task, ae_observation=obs, message=f"Step {step + 1}/{max_steps}: Classify the adverse event.", ) elif self._current_task == TaskID.PROTOCOL_DEVIATION_AUDIT: obs = ProtocolDeviationObservation( site_id=case["site_id"], site_name=case["site_name"], visit_type=case["visit_type"], findings=case["findings"], prior_deviations=case["prior_deviations"], active_subjects=case["active_subjects"], study_phase=case["study_phase"], last_monitoring_visit=case["last_monitoring_visit"], step_count=step, max_steps=max_steps, ) return TriageObservation( task_id=self._current_task, deviation_observation=obs, message=f"Step {step + 1}/{max_steps}: Audit the site findings.", ) elif self._current_task == TaskID.SAFETY_NARRATIVE_GENERATION: obs = SafetyNarrativeObservation( case_id=case["case_id"], patient_demographics=case["patient_demographics"], study_drug=case["study_drug"], suspect_drugs=case["suspect_drugs"], concomitant_medications=case["concomitant_medications"], adverse_event=case["adverse_event"], lab_values_timeline=case["lab_values_timeline"], medical_history=case["medical_history"], action_taken=case["action_taken"], outcome_at_last_followup=case["outcome_at_last_followup"], reference_documents=case["reference_documents"], step_count=step, max_steps=max_steps, ) return TriageObservation( task_id=self._current_task, narrative_observation=obs, message=f"Step {step + 1}/{max_steps}: Write the ICSR safety narrative.", ) return TriageObservation(task_id=self._current_task, message="Unknown task state.") def _grade(self, action: TriageAction) -> TriageReward: """Route grading to the correct task grader.""" cases = TASK_CASES[self._current_task] if self._case_index >= len(cases): return TriageReward(total=_SCORE_EPS) case = cases[self._case_index] if self._current_task == TaskID.ADVERSE_EVENT_TRIAGE: if action.ae_triage is None: return TriageReward( total=_SCORE_EPS, penalty_applied=True, penalty_reason="No ae_triage action provided for adverse_event_triage task." ) return grade_ae_triage(action.ae_triage, case) elif self._current_task == TaskID.PROTOCOL_DEVIATION_AUDIT: if action.deviation_audit is None: return TriageReward( total=_SCORE_EPS, penalty_applied=True, penalty_reason="No deviation_audit action provided for protocol_deviation_audit task." ) return grade_protocol_deviation(action.deviation_audit, case) elif self._current_task == TaskID.SAFETY_NARRATIVE_GENERATION: if action.safety_narrative is None: return TriageReward( total=_SCORE_EPS, penalty_applied=True, penalty_reason="No safety_narrative action provided for safety_narrative_generation task." ) return grade_safety_narrative(action.safety_narrative, case) return TriageReward(total=_SCORE_EPS) def _shape_reward( self, raw_reward: float, detail: TriageReward, action: TriageAction, current_observation: TriageObservation, ) -> float: """ Apply reward shaping to ensure dense signals across the trajectory. - Small bonus for partial progress (>0.3 total) - Deduction for penalty-flagged actions - No bonus for trivially wrong answers """ shaped = raw_reward # Partial progress signal if 0.3 <= raw_reward < 0.6: shaped += 0.02 # small encouragement signal # Penalty deduction if detail.penalty_applied: shaped = max(0.0, shaped - 0.05) # Anti-gaming penalty for obviously inflated severity when narrative implies mild signal. if ( action.task_id == TaskID.ADVERSE_EVENT_TRIAGE and action.ae_triage is not None and current_observation.ae_observation is not None ): ae_action = action.ae_triage narrative = current_observation.ae_observation.narrative.lower() narrative_implies_mild = any(word in narrative for word in ["mild", "minor", "slight", "minimal"]) narrative_has_critical = any( word in narrative for word in ["life-threatening", "critical", "icu", "intubat", "cardiac arrest"] ) severity_value = getattr(ae_action.severity_classification, "value", ae_action.severity_classification) agent_says_life_threatening = str(severity_value) == "life_threatening" if narrative_implies_mild and not narrative_has_critical and agent_says_life_threatening: shaped -= 0.15 # Anti-loop penalty for repeating identical consecutive actions. action_signature = json.dumps(action.model_dump(mode="json"), sort_keys=True) if self._last_action_signature == action_signature: shaped -= 0.05 self._last_action_signature = action_signature return round(min(1.0 - _SCORE_EPS, max(_SCORE_EPS, shaped)), 6)