File size: 14,244 Bytes
d5e8433
 
 
 
 
 
 
 
 
 
 
 
 
dee17cd
 
 
 
 
 
 
d5e8433
 
dee17cd
d5e8433
 
 
 
 
 
dee17cd
d5e8433
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dee17cd
 
 
 
 
 
 
 
d5e8433
 
 
 
 
 
 
 
 
dee17cd
d5e8433
 
 
dee17cd
 
 
 
 
 
 
 
d5e8433
 
 
 
 
dee17cd
d5e8433
 
 
 
 
dee17cd
 
 
 
 
 
 
 
d5e8433
 
dee17cd
 
 
 
d5e8433
 
dee17cd
d5e8433
 
dee17cd
d5e8433
 
 
dee17cd
 
 
 
 
 
 
 
 
 
d5e8433
 
 
 
 
dee17cd
 
 
d5e8433
 
 
 
 
 
 
 
 
 
 
dee17cd
 
d5e8433
dee17cd
d5e8433
dee17cd
d5e8433
 
 
dee17cd
 
 
d5e8433
dee17cd
d5e8433
 
 
 
 
 
 
dee17cd
d5e8433
dee17cd
d5e8433
 
 
dee17cd
 
 
d5e8433
dee17cd
d5e8433
 
 
 
dee17cd
 
 
d5e8433
 
dee17cd
 
d5e8433
 
 
dee17cd
 
d5e8433
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dee17cd
 
 
 
 
d5e8433
 
 
 
dee17cd
 
 
d5e8433
 
 
dee17cd
d5e8433
 
 
dee17cd
 
d5e8433
 
 
 
 
 
 
 
 
 
 
 
 
 
dee17cd
 
 
d5e8433
 
 
 
dee17cd
 
 
d5e8433
 
 
 
 
 
dee17cd
 
d5e8433
 
 
dee17cd
 
 
d5e8433
dee17cd
 
 
 
 
 
d5e8433
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dee17cd
 
 
 
 
 
d5e8433
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dee17cd
 
 
d5e8433
 
 
 
 
 
 
 
dee17cd
 
 
 
d5e8433
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
# Copyright (c) Ajay Bandiwaddar — OpenEnv Hackathon Round 1
"""
Queue Doctor — OpenEnv MCPEnvironment.

A genuine multi-step reinforcement learning environment for hospital
emergency department triage. The agent makes sequential decisions
each step — which patient to serve — and the environment state changes
meaningfully in response (new arrivals, wait time increases, deterioration).

This is a true Markov Decision Process: the agent's action at step N
changes the state available at step N+1. A better policy produces
measurably better outcomes across all three tasks.

Stochasticity:
    start_task() accepts an optional seed parameter. When provided, small
    random perturbations are applied to patient attributes (severity ±1
    with 15% probability, arrival step ±1 with 10% probability). This
    ensures each episode is distinct, prevents solution memorization, and
    produces non-zero score variance across runs (required by Phase 2).

Episode workflow:
    list_tasks()
    → start_task(task_id, seed=<int>)   # seed optional
    → get_queue_state()
    → [serve_patient(patient_id) | wait()] × N steps
    → finalize_episode()
"""

import json
from typing import Optional
from uuid import uuid4

try:
    from openenv.core.env_server.mcp_environment import MCPEnvironment
    from openenv.core.env_server.types import Action, Observation, State
except ImportError:
    from openenv.core.env_server.mcp_environment import MCPEnvironment
    from openenv.core.env_server.types import Action, Observation, State

from fastmcp import FastMCP

from .tasks import TASKS
from .queue_engine import QueueEngine
from .graders import GRADERS


class QueueDoctorEnvironment(MCPEnvironment):
    """
    Queue Doctor — Hospital Triage RL Environment.

    Three tasks of increasing difficulty:
        task_1_easy   — Static queue, 1 doctor, 10 steps
        task_2_medium — Dynamic arrivals, 2 doctors, 20 steps
        task_3_hard   — Mass casualty, deterioration, ICU, 3 doctors, 30 steps

    MCP tools:
        list_tasks()                    → task catalogue with metadata
        start_task(task_id, seed)       → init episode (seed optional for stochasticity)
        get_queue_state()               → observe current state (no time advance)
        serve_patient(patient_id)       → treat patient, advance 1 step
        wait()                          → skip step (penalized), advance 1 step
        finalize_episode()              → compute final normalized score
        get_current_state()             → environment-level metadata
    """

    def __init__(self):
        """
        Build the FastMCP server, register the triage tools, and create the
        initial (empty) episode state.

        Every tool is a closure over ``self`` and returns a JSON string.
        The tools only read ``self`` attributes when they are invoked, so the
        attributes can safely be assigned after ``super().__init__(mcp)``.
        """
        mcp = FastMCP("queue_doctor")

        # ── Shared guards ────────────────────────────────────────────────
        # Plain local helpers (not registered as tools) so that every tool
        # reports the same error payloads.

        def _no_active_task() -> str:
            """Error payload for tools that require a started task."""
            return json.dumps({
                "error": "No active task. Call start_task(task_id) first."
            })

        def _action_guard() -> Optional[str]:
            """
            Return an error payload when a time-advancing action is not
            allowed (no task started, or step budget exhausted), else None.
            """
            if self._engine is None:
                return _no_active_task()
            if self._engine.step >= self._engine.max_steps:
                return json.dumps({
                    "error": "Episode complete. Call finalize_episode().",
                    "done":  True,
                })
            return None

        # ── Tools ────────────────────────────────────────────────────────

        @mcp.tool
        def list_tasks() -> str:
            """
            List all available triage tasks with metadata.
            Returns task IDs, names, difficulty, resources, and descriptions.
            """
            return json.dumps([
                {
                    "task_id":         tid,
                    "task_name":       t["task_name"],
                    "difficulty":      t["difficulty"],
                    "max_steps":       t["max_steps"],
                    "num_doctors":     t["num_doctors"],
                    "icu_beds":        t.get("icu_beds", 0),
                    "total_patients":  len(t["arrivals"]),
                    "description":     t["description"],
                }
                for tid, t in TASKS.items()
            ], indent=2)

        @mcp.tool
        def start_task(task_id: str, seed: Optional[int] = None) -> str:
            """
            Initialize a task episode. Must be called before any actions.

            Args:
                task_id: One of 'task_1_easy', 'task_2_medium', 'task_3_hard'
                seed:    Optional integer seed for episode randomization.
                         When provided, small stochastic perturbations are
                         applied to patient attributes (severity ±1 with 15%
                         probability, arrival step ±1 with 10% probability).
                         Use different seeds across runs to get score variance.
                         Omit for the deterministic baseline episode.

            Returns task description, rules, initial queue state, and workflow.
            """
            if task_id not in TASKS:
                return json.dumps({
                    "error": f"Unknown task_id '{task_id}'. "
                             f"Valid: {list(TASKS.keys())}"
                })

            task = TASKS[task_id]
            # Build the engine before touching instance state: if construction
            # raises, the previously active task and engine stay consistent.
            engine = QueueEngine(task, seed=seed)

            self._active_task_id = task_id
            self._engine         = engine
            self._state.step_count += 1

            initial_state = engine.get_state()

            return json.dumps({
                "task_id":         task_id,
                "task_name":       task["task_name"],
                "difficulty":      task["difficulty"],
                "description":     task["description"],
                "max_steps":       task["max_steps"],
                "num_doctors":     task["num_doctors"],
                "icu_beds":        task.get("icu_beds", 0),
                "seed":            seed,
                "initial_queue":   initial_state["queue"],
                "queue_length":    initial_state["queue_length"],
                "triage_advisory": initial_state["triage_advisory"],
                "workflow": (
                    "1. Call get_queue_state() to observe current patients.\n"
                    "2. Call serve_patient(patient_id) to treat a patient "
                    "   — this advances time by 1 step.\n"
                    "3. OR call wait() to skip a step "
                    "   (penalized if patients are waiting).\n"
                    "4. Repeat until done=true.\n"
                    "5. Call finalize_episode() to get your final score."
                ),
            }, indent=2)

        @mcp.tool
        def get_queue_state() -> str:
            """
            Observe the current emergency department state. Does NOT advance time.

            Returns:
                - Current step and steps remaining
                - All patients sorted by priority (severity, then wait time)
                - can_serve_now flag per patient (resource availability check)
                - Available doctors and ICU beds
                - Patients served and missed emergencies
                - Cumulative reward
                - Triage advisory (for inspection — not used by the inference agent)
                - done flag
            """
            if self._engine is None:
                return _no_active_task()
            # step_count tracks tool calls on the environment, not simulation
            # time, so it is bumped even though the queue clock does not move.
            self._state.step_count += 1
            return json.dumps(self._engine.get_state(), indent=2)

        @mcp.tool
        def serve_patient(patient_id: str) -> str:
            """
            Assign a doctor to treat a patient. ADVANCES SIMULATION BY 1 STEP.

            After this action:
            - Patient removed from queue
            - Wait times increase for all remaining patients
            - New patients may arrive (deterministic or seeded schedule)
            - Deteriorating patients' countdowns decrease (Task 3)
            - Step counter increments

            Resource errors (no ICU bed, insufficient doctors) do NOT advance
            time — the agent receives an error message and must choose again.

            Args:
                patient_id: Patient ID (e.g. 'P001', 'P007')

            Returns step reward, updated queue state, and events log.
            """
            blocked = _action_guard()
            if blocked is not None:
                return blocked

            reward, new_state, events = self._engine.serve_patient(patient_id)
            self._cumulative_reward  += reward
            self._state.step_count   += 1

            return json.dumps({
                "action":      f"serve_patient({patient_id})",
                "step_reward": round(reward, 4),
                "events":      events,
                "state":       new_state,
                "done":        new_state["done"],
                "hint": (
                    "Call finalize_episode() to get your final score."
                    if new_state["done"] else
                    "Continue serving patients or call finalize_episode() anytime."
                ),
            }, indent=2)

        @mcp.tool
        def wait() -> str:
            """
            Skip this step without serving any patient. ADVANCES SIMULATION BY 1 STEP.

            Penalties:
              Emergency (severity 1) in queue: -0.30 per patient
              Urgent (severity 2-3) in queue:  -0.10
              Any patient in queue:             -0.05
              Empty queue:                       0.00

            Returns step penalty, updated queue state, and events log.
            """
            blocked = _action_guard()
            if blocked is not None:
                return blocked

            penalty, new_state, events = self._engine.wait()
            self._cumulative_reward   += penalty
            self._state.step_count    += 1

            return json.dumps({
                "action":      "wait()",
                "step_reward": round(penalty, 4),
                "events":      events,
                "state":       new_state,
                "done":        new_state["done"],
            }, indent=2)

        @mcp.tool
        def finalize_episode() -> str:
            """
            Finalize the current task and compute the final normalized score.

            Applies the principled grader to produce a score in [0, 1].
            Grader weights are derived from published clinical literature —
            not tuned to hit target scores.

            Returns final score, component scores, and full episode statistics.
            """
            if self._engine is None:
                return _no_active_task()

            task_id = self._active_task_id
            task    = TASKS[task_id]
            result  = GRADERS[task["grader"]](self._engine)

            # Re-finalizing a task overwrites its earlier score; the
            # environment is done once every task in TASKS has a score.
            self._finalized_tasks[task_id] = result["score"]
            all_done   = len(self._finalized_tasks) >= len(TASKS)
            self._done = all_done
            self._state.step_count += 1

            return json.dumps({
                "task_id":         task_id,
                "task_name":       task["task_name"],
                "difficulty":      task["difficulty"],
                **result,
                "episode_steps":   self._engine.step,
                "patients_served": len(self._engine.served),
                "served_detail":   self._engine.served,
                "tasks_completed": len(self._finalized_tasks),
                "tasks_total":     len(TASKS),
                "all_done":        all_done,
            }, indent=2)

        @mcp.tool
        def get_current_state() -> str:
            """Get environment-level metadata (episode state, not queue state)."""
            return json.dumps({
                "episode_id":        self._state.episode_id,
                "step_count":        self._state.step_count,
                "active_task":       self._active_task_id,
                "finalized_scores":  self._finalized_tasks,
                "cumulative_reward": round(self._cumulative_reward, 4),
                "done":              self._done,
                "tasks_available":   list(TASKS.keys()),
            }, indent=2)

        super().__init__(mcp)
        self._state                         = State(episode_id=str(uuid4()), step_count=0)
        # Sum of step rewards/penalties; start_task() does not clear it, so it
        # spans every task played until reset() is called.
        self._cumulative_reward: float      = 0.0
        self._done: bool                    = False
        self._active_task_id: Optional[str] = None
        self._engine: Optional[QueueEngine] = None
        # task_id -> final score, filled in by finalize_episode().
        self._finalized_tasks: dict         = {}

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs,
    ) -> Observation:
        """
        Discard all episode state and return a fresh "ready" observation.

        Args:
            seed:       Accepted but not used by this method; episode
                        randomization is controlled by the seed given to the
                        start_task() tool.
            episode_id: Identifier for the new episode. A random UUID4 string
                        is generated when it is omitted or empty.
            **kwargs:   Accepted and ignored.

        Returns:
            Observation with done=False, reward=0.0 and metadata holding a
            status flag, a workflow reminder, and the available task IDs.
        """
        self._state               = State(episode_id=episode_id or str(uuid4()), step_count=0)
        self._cumulative_reward   = 0.0
        self._done                = False
        self._active_task_id      = None
        self._engine              = None
        self._finalized_tasks     = {}
        return Observation(
            done=False,
            reward=0.0,
            metadata={
                "status":  "ready",
                "message": (
                    "Queue Doctor ready. "
                    "Workflow: list_tasks() → start_task(task_id, seed=<int>) → "
                    "get_queue_state() → "
                    "[serve_patient(patient_id) or wait()] × N → "
                    "finalize_episode()"
                ),
                "tasks_available": list(TASKS.keys()),
            },
        )

    def _step_impl(self, action, timeout_s=None, **kwargs) -> Observation:
        """
        Reject an action by returning an error observation.

        The observation is never terminal (done=False), carries zero reward,
        and its metadata names the action's class and points the caller to
        the MCP tools. ``timeout_s`` and ``kwargs`` are not used.
        """
        action_kind = type(action).__name__
        error_text = f"Unknown action: {action_kind}. Use MCP tools."
        return Observation(
            done=False,
            reward=0.0,
            metadata={"error": error_text},
        )

    def step(self, action, timeout_s=None, **kwargs) -> Observation:
        """
        Increment the episode's step counter, then hand the action to the
        base-class ``step`` and return whatever it produces.
        """
        episode_state = self._state
        episode_state.step_count = episode_state.step_count + 1
        observation = super().step(action, timeout_s=timeout_s, **kwargs)
        return observation

    async def step_async(self, action, timeout_s=None, **kwargs) -> Observation:
        """
        Async counterpart of ``step``: increment the episode's step counter,
        then await the base-class ``step_async`` and return its result.
        """
        episode_state = self._state
        episode_state.step_count = episode_state.step_count + 1
        observation = await super().step_async(
            action, timeout_s=timeout_s, **kwargs
        )
        return observation

    @property
    def state(self) -> State:
        """Return the current episode ``State`` object (episode id and step count)."""
        return self._state