File size: 6,946 Bytes
404c45f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e410e4
 
 
 
 
 
 
404c45f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e410e4
 
 
404c45f
 
 
 
 
 
2e410e4
404c45f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
"""
OpenEnv-native environment adapter for Clinical Trial Triage.

This module exposes a Meta OpenEnv Environment implementation while reusing
existing domain logic from server.environment. It enables native OpenEnv
HTTP/WebSocket operation via openenv.core.env_server.create_fastapi_app.
"""
from __future__ import annotations

import random
from typing import Any, Dict, Optional

from pydantic import Field, model_validator

from openenv.core.env_server import (
    Action,
    Environment,
    Observation,
    State,
)
from openenv.core.env_server.types import EnvironmentMetadata

from models import (
    AdverseEventTriageAction,
    ProtocolDeviationAction,
    SafetyNarrativeAction,
    TaskID,
    TriageAction,
)
from server.environment import ClinicalTrialEnvironment


# Margin that keeps reported scores strictly inside the open interval (0, 1).
_SCORE_EPS = 1e-3


def _clamp_open_score(value: float) -> float:
    """Coerce *value* to float and pin it to [_SCORE_EPS, 1 - _SCORE_EPS]."""
    score = float(value)
    ceiling = 1.0 - _SCORE_EPS

    # Cap from above first; anything not strictly below the ceiling becomes it.
    if score < ceiling:
        capped = score
    else:
        capped = ceiling

    # Then lift anything that is not strictly above the floor.
    if capped > _SCORE_EPS:
        return capped
    return _SCORE_EPS


class OpenEnvTriageAction(Action):
    """Action envelope carrying one task-specific payload plus its task id.

    Exactly one of the three payload fields must be set, and it must be the
    one that corresponds to ``task_id``; the validator below enforces both.
    """

    task_id: TaskID = Field(..., description="Task to execute action against")
    ae_triage: Optional[AdverseEventTriageAction] = None
    deviation_audit: Optional[ProtocolDeviationAction] = None
    safety_narrative: Optional[SafetyNarrativeAction] = None

    @model_validator(mode="after")
    def validate_task_payload(self) -> "OpenEnvTriageAction":
        """Reject actions whose payload count or payload kind is wrong.

        Raises:
            ValueError: if zero or several payloads are present, or if the
                payload required by ``task_id`` is missing.
        """
        ae_present = self.ae_triage is not None
        deviation_present = self.deviation_audit is not None
        narrative_present = self.safety_narrative is not None

        if (ae_present, deviation_present, narrative_present).count(True) != 1:
            raise ValueError(
                "Exactly one payload must be provided: ae_triage, deviation_audit, or safety_narrative"
            )

        # (task, is its payload present?, complaint when it is not), checked
        # in declaration order so the first mismatch wins.
        requirements = (
            (
                TaskID.ADVERSE_EVENT_TRIAGE,
                ae_present,
                "task_id=adverse_event_triage requires ae_triage payload",
            ),
            (
                TaskID.PROTOCOL_DEVIATION_AUDIT,
                deviation_present,
                "task_id=protocol_deviation_audit requires deviation_audit payload",
            ),
            (
                TaskID.SAFETY_NARRATIVE_GENERATION,
                narrative_present,
                "task_id=safety_narrative_generation requires safety_narrative payload",
            ),
        )
        for task, present, complaint in requirements:
            if self.task_id == task and not present:
                raise ValueError(complaint)
        return self


class OpenEnvTriageObservation(Observation):
    """OpenEnv observation wrapper with full domain payload.

    Only the triage-specific fields are declared here. The adapter in this
    module also passes ``reward``, ``done`` and ``metadata`` when building an
    instance; those are presumably declared on the ``Observation`` base
    class -- confirm against the OpenEnv version in use.
    """

    # Task the observation belongs to (required, no default).
    task_id: TaskID
    # Domain observation as a plain dict; the adapter fills it with the core
    # observation's ``model_dump()`` output. Defaults to a fresh empty dict.
    payload: Dict[str, Any] = Field(default_factory=dict)
    # Free-text message; the adapter copies the core observation's message.
    message: str = ""


class OpenEnvTriageState(State):
    """OpenEnv state object for environment introspection.

    The adapter in this module also passes ``episode_id`` and ``step_count``
    when building an instance; those are presumably declared on the ``State``
    base class -- confirm against the OpenEnv version in use.
    """

    # Task of the current episode; None when no task is known.
    task_id: Optional[TaskID] = None
    # Step budget as reported by the core environment's state.
    max_steps: int = 0
    # Whether the core environment reports the episode as finished.
    done: bool = False
    # NOTE: despite the name, the adapter's ``state`` property stores the
    # per-step mean reward here, clamped by ``_clamp_open_score``, not the
    # raw running sum.
    cumulative_reward: float = 0.0
    # Identifier of the case currently being worked on, if any.
    current_case_id: Optional[str] = None


class ClinicalTrialOpenEnv(
    Environment[OpenEnvTriageAction, OpenEnvTriageObservation, OpenEnvTriageState]
):
    """
    Native OpenEnv environment implementation for clinical trial triage.

    This is a thin adapter: every call is delegated to a wrapped
    ``ClinicalTrialEnvironment`` and the result is repackaged into the
    OpenEnv observation/state models defined in this module.

    Supports:
    - task-specific episodes via reset(task_id=...)
    - mixed-task curriculum via reset(task_id="mixed")
    - complete reward + done semantics in OpenEnv Observation
    """

    # Each instance wraps a single stateful core environment.
    # NOTE(review): how the OpenEnv server consumes this flag is not visible
    # here -- confirm against the framework documentation.
    SUPPORTS_CONCURRENT_SESSIONS = False

    def __init__(self) -> None:
        """Create the wrapped core environment and the task-sampling RNG."""
        super().__init__()
        self._core = ClinicalTrialEnvironment()
        # Fixed default seed so "mixed" task sampling is reproducible until a
        # caller reseeds it through reset(seed=...).
        self._task_rng = random.Random(42)
        # Task selected by the most recent reset(); None before the first one.
        self._last_task_id: Optional[TaskID] = None

    def _to_openenv_observation(
        self,
        payload: Dict[str, Any],
        task_id: TaskID,
        reward: Optional[float],
        done: bool,
        message: str,
        reward_detail: Optional[Dict[str, Any]] = None,
    ) -> OpenEnvTriageObservation:
        """Package core results into an ``OpenEnvTriageObservation``.

        Args:
            payload: Dumped core observation.
            task_id: Task the observation belongs to.
            reward: Step reward, or None when there is none (e.g. on reset).
            done: Whether the episode has finished.
            message: Human-readable message from the core observation.
            reward_detail: Optional reward breakdown; when given it is stored
                under ``metadata["reward_detail"]``.

        Returns:
            The populated observation. ``metadata`` is an empty dict when no
            reward detail is supplied.
        """
        metadata: Dict[str, Any] = {}
        if reward_detail is not None:
            metadata["reward_detail"] = reward_detail

        return OpenEnvTriageObservation(
            task_id=task_id,
            payload=payload,
            message=message,
            reward=reward,
            done=done,
            metadata=metadata,
        )

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> OpenEnvTriageObservation:
        """Start a new episode on the core environment.

        Args:
            seed: When given, reseeds only the RNG used to pick a task in
                "mixed" mode; it is not forwarded to the core environment.
            episode_id: Accepted for interface compatibility; not used by
                this adapter.
            **kwargs: ``task_id`` selects the task. It may be a ``TaskID``,
                a value accepted by the ``TaskID`` constructor, or the string
                ``"mixed"`` to sample one of the three tasks at random.
                A missing or ``None`` value selects adverse event triage.

        Returns:
            The initial observation, with ``reward=None`` and ``done=False``.

        Raises:
            ValueError: presumably, from the ``TaskID`` constructor when
                ``task_id`` is not a recognised task -- confirm.
        """
        if seed is not None:
            self._task_rng.seed(seed)

        requested_task = kwargs.get("task_id")
        if requested_task is None:
            # An explicit null (e.g. from a JSON request body) means "use the
            # default task" rather than an invalid task id.
            requested_task = TaskID.ADVERSE_EVENT_TRIAGE

        if requested_task == "mixed":
            chosen_task = self._task_rng.choice(
                [
                    TaskID.ADVERSE_EVENT_TRIAGE,
                    TaskID.PROTOCOL_DEVIATION_AUDIT,
                    TaskID.SAFETY_NARRATIVE_GENERATION,
                ]
            )
        else:
            chosen_task = TaskID(requested_task)

        self._last_task_id = chosen_task
        obs = self._core.reset(task_id=chosen_task)
        payload = obs.model_dump()
        return self._to_openenv_observation(
            payload=payload,
            task_id=chosen_task,
            reward=None,
            done=False,
            message=obs.message,
        )

    def step(
        self,
        action: OpenEnvTriageAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> OpenEnvTriageObservation:
        """Apply one action to the core environment.

        Args:
            action: Validated OpenEnv action; its fields are copied into a
                domain ``TriageAction`` before being handed to the core.
            timeout_s: Accepted for interface compatibility; not used by this
                adapter.
            **kwargs: Ignored.

        Returns:
            The resulting observation with the step's reward and done flag.
            The reward breakdown, when the core provides one, is attached
            under ``metadata["reward_detail"]``.
        """
        triage_action = TriageAction(
            task_id=action.task_id,
            ae_triage=action.ae_triage,
            deviation_audit=action.deviation_audit,
            safety_narrative=action.safety_narrative,
        )

        step_result = self._core.step(triage_action)
        obs = step_result.observation
        payload = obs.model_dump()

        # The helper already treats the breakdown as optional, so do not
        # crash the step if the core did not produce one.
        detail = step_result.reward_detail
        reward_detail = detail.model_dump() if detail is not None else None

        return self._to_openenv_observation(
            payload=payload,
            task_id=TaskID(obs.task_id),
            reward=step_result.reward,
            done=step_result.done,
            message=obs.message,
            reward_detail=reward_detail,
        )

    @property
    def state(self) -> OpenEnvTriageState:
        """Snapshot of the core environment state as an OpenEnv state object.

        ``cumulative_reward`` in the returned object is the per-step mean
        reward (core cumulative reward divided by step count), clamped by
        ``_clamp_open_score``. With zero steps taken it is ``_SCORE_EPS``.
        """
        core_state = self._core.state()

        if core_state.step_count > 0:
            mean_reward = core_state.cumulative_reward / core_state.step_count
        else:
            mean_reward = _SCORE_EPS
        normalized_cumulative = _clamp_open_score(mean_reward)

        # The state model allows a missing task, so pass None through instead
        # of letting the TaskID constructor reject it.
        raw_task_id = core_state.task_id
        task_id = TaskID(raw_task_id) if raw_task_id is not None else None

        return OpenEnvTriageState(
            episode_id=core_state.episode_id,
            step_count=core_state.step_count,
            task_id=task_id,
            max_steps=core_state.max_steps,
            done=core_state.done,
            cumulative_reward=normalized_cumulative,
            current_case_id=core_state.current_case_id,
        )

    def get_metadata(self) -> EnvironmentMetadata:
        """Return the static descriptive metadata for this environment."""
        return EnvironmentMetadata(
            name="Clinical Trial Triage OpenEnv",
            description=(
                "Production-grade multi-task environment for adverse event triage, "
                "protocol deviation auditing, and safety narrative generation."
            ),
            version="2.0.0",
            author="OpenEnv Hackathon Submission",
            documentation_url="/docs",
        )

    def close(self) -> None:
        """Release resources held by the environment (currently a no-op)."""
        # Core env has no external handles to close today.
        return