Spaces:
Sleeping
Sleeping
| """ | |
| OpenEnv-native environment adapter for Clinical Trial Triage. | |
| This module exposes a Meta OpenEnv Environment implementation while reusing | |
| existing domain logic from server.environment. It enables native OpenEnv | |
| HTTP/WebSocket operation via openenv.core.env_server.create_fastapi_app. | |
| """ | |
| from __future__ import annotations | |
| import random | |
| from typing import Any, Dict, Optional | |
| from pydantic import Field, model_validator | |
| from openenv.core.env_server import ( | |
| Action, | |
| Environment, | |
| Observation, | |
| State, | |
| ) | |
| from openenv.core.env_server.types import EnvironmentMetadata | |
| from models import ( | |
| AdverseEventTriageAction, | |
| ProtocolDeviationAction, | |
| SafetyNarrativeAction, | |
| TaskID, | |
| TriageAction, | |
| ) | |
| from server.environment import ClinicalTrialEnvironment | |
_SCORE_EPS = 1e-3


def _clamp_open_score(value: float) -> float:
    """Clamp ``value`` into ``[_SCORE_EPS, 1 - _SCORE_EPS]``.

    Keeps reported scores strictly inside the open interval (0, 1).

    Args:
        value: Any value accepted by ``float()``.

    Returns:
        The clamped score. NaN maps to ``_SCORE_EPS``.
    """
    import math

    score = float(value)
    # A bare min/max chain lets NaN fall through to the upper bound, which would
    # report an uncomputable score as near-perfect; treat it as the minimum.
    if math.isnan(score):
        return _SCORE_EPS
    return max(_SCORE_EPS, min(1.0 - _SCORE_EPS, score))
class OpenEnvTriageAction(Action):
    """OpenEnv action wrapper for the clinical triage tasks.

    Exactly one of the three payload fields must be set, and it must be the
    payload that matches ``task_id``. This is enforced at construction time by
    ``validate_task_payload``.
    """

    task_id: TaskID = Field(..., description="Task to execute action against")
    ae_triage: Optional[AdverseEventTriageAction] = None
    deviation_audit: Optional[ProtocolDeviationAction] = None
    safety_narrative: Optional[SafetyNarrativeAction] = None

    # Without this decorator the method is never invoked by pydantic and the
    # checks below silently do nothing.
    @model_validator(mode="after")
    def validate_task_payload(self) -> "OpenEnvTriageAction":
        """Check that exactly one payload is present and that it matches ``task_id``.

        Returns:
            ``self``, unchanged, when the action is consistent.

        Raises:
            ValueError: If zero or several payloads are set, or if the payload
                does not belong to ``task_id``. Pydantic surfaces this as a
                ``ValidationError``.
        """
        has_ae = self.ae_triage is not None
        has_dev = self.deviation_audit is not None
        has_nr = self.safety_narrative is not None

        if sum([has_ae, has_dev, has_nr]) != 1:
            raise ValueError(
                "Exactly one payload must be provided: ae_triage, deviation_audit, or safety_narrative"
            )
        if self.task_id == TaskID.ADVERSE_EVENT_TRIAGE and not has_ae:
            raise ValueError("task_id=adverse_event_triage requires ae_triage payload")
        if self.task_id == TaskID.PROTOCOL_DEVIATION_AUDIT and not has_dev:
            raise ValueError("task_id=protocol_deviation_audit requires deviation_audit payload")
        if self.task_id == TaskID.SAFETY_NARRATIVE_GENERATION and not has_nr:
            raise ValueError("task_id=safety_narrative_generation requires safety_narrative payload")
        return self
class OpenEnvTriageObservation(Observation):
    """Observation handed back to OpenEnv clients.

    Wraps the serialized core-environment observation together with the task
    it was produced for. Reward, done and metadata come from the ``Observation``
    base class.
    """

    # Task this observation belongs to.
    task_id: TaskID
    # The core observation serialized with ``model_dump()`` by the adapter.
    payload: Dict[str, Any] = Field(default_factory=dict)
    # Message text copied from the core observation.
    message: str = Field(default="")
class OpenEnvTriageState(State):
    """Introspection snapshot of the environment exposed through OpenEnv."""

    # Active task; None until a task has been selected.
    task_id: Optional[TaskID] = Field(default=None)
    # Mirrors the core environment's max_steps value.
    max_steps: int = Field(default=0)
    # Whether the core environment reports the episode as finished.
    done: bool = Field(default=False)
    # Populated by ClinicalTrialOpenEnv.state() with a clamped per-step mean
    # reward rather than a raw running sum.
    cumulative_reward: float = Field(default=0.0)
    # Case identifier mirrored from the core state, if any.
    current_case_id: Optional[str] = Field(default=None)
class ClinicalTrialOpenEnv(
    Environment[OpenEnvTriageAction, OpenEnvTriageObservation, OpenEnvTriageState]
):
    """
    Native OpenEnv environment implementation for clinical trial triage.

    Adapter around ``ClinicalTrialEnvironment``. It converts OpenEnv actions into
    ``TriageAction`` objects, forwards them to the core environment, and wraps
    the core observations and state back into OpenEnv models.

    Supports:
    - task-specific episodes via reset(task_id=...)
    - mixed-task curriculum via reset(task_id="mixed")
    - complete reward + done semantics in OpenEnv Observation
    """

    # All episode state lives in the single ``self._core`` instance, so one
    # adapter must not serve several sessions at once.
    SUPPORTS_CONCURRENT_SESSIONS = False

    def __init__(self) -> None:
        super().__init__()
        self._core = ClinicalTrialEnvironment()
        # Fixed seed keeps the "mixed" curriculum reproducible unless reset()
        # receives an explicit seed.
        self._task_rng = random.Random(42)
        self._last_task_id: Optional[TaskID] = None

    def _to_openenv_observation(
        self,
        payload: Dict[str, Any],
        task_id: TaskID,
        reward: Optional[float],
        done: bool,
        message: str,
        reward_detail: Optional[Dict[str, Any]] = None,
    ) -> OpenEnvTriageObservation:
        """Build an ``OpenEnvTriageObservation``.

        ``reward_detail``, when given, is placed under ``metadata["reward_detail"]``.
        """
        metadata: Dict[str, Any] = {}
        if reward_detail is not None:
            metadata["reward_detail"] = reward_detail
        return OpenEnvTriageObservation(
            task_id=task_id,
            payload=payload,
            message=message,
            reward=reward,
            done=done,
            metadata=metadata,
        )

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> OpenEnvTriageObservation:
        """Start a new episode and return its initial observation.

        Args:
            seed: If given, reseeds the RNG used to pick a task for
                ``task_id="mixed"``. It is not forwarded to the core environment.
            episode_id: Accepted for OpenEnv interface compatibility; unused here.
            **kwargs: ``task_id`` selects the task. It may be a ``TaskID`` (or
                its value), ``"mixed"`` for a random task, or omitted/``None``
                for adverse event triage.

        Returns:
            The first observation, with ``reward=None`` and ``done=False``.

        Raises:
            ValueError: Raised by the ``TaskID`` constructor when ``task_id`` is
                neither ``"mixed"`` nor a valid task value.
        """
        if seed is not None:
            self._task_rng.seed(seed)

        # Treat an explicit None (e.g. JSON null) like an omitted task_id instead
        # of failing in TaskID(None).
        requested_task = kwargs.get("task_id")
        if requested_task is None:
            requested_task = TaskID.ADVERSE_EVENT_TRIAGE

        if requested_task == "mixed":
            chosen_task = self._task_rng.choice(
                [
                    TaskID.ADVERSE_EVENT_TRIAGE,
                    TaskID.PROTOCOL_DEVIATION_AUDIT,
                    TaskID.SAFETY_NARRATIVE_GENERATION,
                ]
            )
        else:
            chosen_task = TaskID(requested_task)

        self._last_task_id = chosen_task
        obs = self._core.reset(task_id=chosen_task)
        return self._to_openenv_observation(
            payload=obs.model_dump(),
            task_id=chosen_task,
            reward=None,
            done=False,
            message=obs.message,
        )

    def step(
        self,
        action: OpenEnvTriageAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> OpenEnvTriageObservation:
        """Forward ``action`` to the core environment and wrap the result.

        Args:
            action: The OpenEnv action; its fields are copied into a ``TriageAction``.
            timeout_s: Accepted for OpenEnv interface compatibility; unused here.

        Returns:
            An observation carrying the core reward, the done flag, and (when the
            core provides it) the serialized reward detail in ``metadata``.
        """
        triage_action = TriageAction(
            task_id=action.task_id,
            ae_triage=action.ae_triage,
            deviation_audit=action.deviation_audit,
            safety_narrative=action.safety_narrative,
        )
        step_result = self._core.step(triage_action)
        obs = step_result.observation
        detail = step_result.reward_detail
        return self._to_openenv_observation(
            payload=obs.model_dump(),
            task_id=TaskID(obs.task_id),
            reward=step_result.reward,
            done=step_result.done,
            message=obs.message,
            # Guard against a missing detail; _to_openenv_observation accepts None.
            reward_detail=detail.model_dump() if detail is not None else None,
        )

    # NOTE(review): OpenEnv's Environment base may declare ``state`` as a
    # property; this adapter exposes a method. Confirm against the installed
    # openenv version before relying on server-side attribute access.
    def state(self) -> OpenEnvTriageState:
        """Return an OpenEnv snapshot of the core environment's state.

        ``cumulative_reward`` is reported as the mean reward per step, clamped
        with ``_clamp_open_score``. Before any step it is ``_SCORE_EPS``.
        ``task_id`` is ``None`` when the core state has no task yet.
        """
        state = self._core.state()
        normalized_cumulative = _clamp_open_score(
            state.cumulative_reward / state.step_count if state.step_count > 0 else _SCORE_EPS
        )
        # TaskID(None) would raise; the OpenEnv state field is Optional.
        task_id = TaskID(state.task_id) if state.task_id is not None else None
        return OpenEnvTriageState(
            episode_id=state.episode_id,
            step_count=state.step_count,
            task_id=task_id,
            max_steps=state.max_steps,
            done=state.done,
            cumulative_reward=normalized_cumulative,
            current_case_id=state.current_case_id,
        )

    def get_metadata(self) -> EnvironmentMetadata:
        """Return static descriptive metadata for this environment."""
        return EnvironmentMetadata(
            name="Clinical Trial Triage OpenEnv",
            description=(
                "Production-grade multi-task environment for adverse event triage, "
                "protocol deviation auditing, and safety narrative generation."
            ),
            version="2.0.0",
            author="OpenEnv Hackathon Submission",
            documentation_url="/docs",
        )

    def close(self) -> None:
        """Release resources; currently a no-op."""
        # Core env has no external handles to close today.
        return