File size: 11,874 Bytes
64d56f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
431e294
82713c7
 
 
431e294
 
 
 
 
 
 
 
64d56f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82713c7
64d56f9
 
 
 
 
 
 
 
 
 
82713c7
64d56f9
82713c7
 
64d56f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82713c7
64d56f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82713c7
64d56f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82713c7
64d56f9
 
82713c7
 
 
 
 
 
 
 
 
64d56f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
"""DispatchPulse OpenEnv environment.

``DispatchPulseEnvironment`` inherits from
``openenv.core.env_server.interfaces.Environment`` and implements the standard
``reset() / step() / state`` Gym-style API. The wire types
``DispatchPulseAction`` and ``DispatchPulseObservation`` are defined in
``models.py`` and inherit from the OpenEnv ``Action`` / ``Observation`` base
classes.

This is a thin wrapper around the in-process ``DispatchSimulation`` engine.
"""

from __future__ import annotations

import os
import sys
from typing import Any, Optional
from uuid import uuid4

# Make project root importable when running as ``server.app:app`` from /app/env
_PKG_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _PKG_ROOT not in sys.path:
    sys.path.insert(0, _PKG_ROOT)

from openenv.core.env_server.interfaces import Environment

from grader import grade_simulation
from models import DispatchPulseAction, DispatchPulseObservation, DispatchPulseState
from scenario_loader import VALID_TASKS, load_scenario
from simulation import DispatchSimulation
from text_view import render_dispatch_center

# Re-export the task registry and grader symbols at module level so static
# validators that scan server/environment.py for tasks can find them here
# (same pattern as the SQL Repair passing submission where both TASKS and
# grade_submission are accessible from server/environment.py).
from task_definitions import (  # noqa: F401,E402
    TASKS,
    TaskDefinition,
    grade_submission,
    get_task,
    list_tasks,
)

# Task loaded by the constructor's bootstrap, and the fallback used by
# reset() when no task_name is given or the name is not in VALID_TASKS.
DEFAULT_TASK = "easy"
# Simulation seed used by the bootstrap and by reset() when seed is None.
DEFAULT_SEED = 42


class DispatchPulseEnvironment(
    Environment[DispatchPulseAction, DispatchPulseObservation, DispatchPulseState]
):
    """Emergency-dispatch OpenEnv environment.

    Each call to ``reset()`` starts a fresh episode for the chosen task.
    Each call to ``step(action)`` applies one decision turn: ``dispatch``,
    ``classify`` and ``callback`` advance simulation time by one minute
    (when their required fields are present), ``wait`` advances it by the
    requested number of minutes (clamped), and ``view`` costs no time.

    Tasks: ``easy``, ``medium``, ``hard``.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    # Payload fields read from an action. For text-only actions, values parsed
    # from ``action.text`` fill any of these the structured action left empty.
    _ACTION_FIELDS = (
        "call_id",
        "unit_id",
        "hospital_id",
        "severity",
        "message",
        "minutes",
    )

    def __init__(self) -> None:
        super().__init__()
        self.sim: Optional[DispatchSimulation] = None
        self.task_name: str = DEFAULT_TASK
        self.seed: int = DEFAULT_SEED
        self._episode_id: str = str(uuid4())
        self._step_count: int = 0
        # Sum of all per-step rewards this episode (exposed via state/metadata).
        self._cumulative_step_reward: float = 0.0
        # Reward of the most recent step (reported while the episode runs).
        self._last_step_reward: float = 0.0
        # Bootstrap so single-shot HTTP /step still works without an explicit reset
        self._bootstrap()

    def _bootstrap(self) -> None:
        """Load the default task so the environment is usable before ``reset()``.

        On failure the error is printed to stderr and ``self.sim`` is left as
        ``None``; ``step()`` retries this bootstrap in that case.
        """
        try:
            scenario = load_scenario(DEFAULT_TASK)
            self.sim = DispatchSimulation(scenario, seed=DEFAULT_SEED)
            self.task_name = DEFAULT_TASK
            self.seed = DEFAULT_SEED
            self._cumulative_step_reward = 0.0
            self._last_step_reward = 0.0
            self._step_count = 0
        except Exception as exc:  # pragma: no cover
            print(f"[DispatchPulseEnvironment] bootstrap failed: {exc}", file=sys.stderr, flush=True)
            self.sim = None

    # ------------------------------------------------------------------
    # Environment API
    # ------------------------------------------------------------------

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        task_name: Optional[str] = None,
        **kwargs: Any,
    ) -> DispatchPulseObservation:
        """Start a new episode and return its first observation.

        Args:
            seed: Simulation seed; ``None`` falls back to ``DEFAULT_SEED``.
            episode_id: Identifier for the episode; a fresh uuid4 is used
                when omitted.
            task_name: Task to load (case/whitespace-insensitive). Missing or
                unrecognised names fall back to ``DEFAULT_TASK``.
            **kwargs: Accepted for interface compatibility and ignored.
        """
        # str() guards against non-string task names sent over the wire.
        chosen_task = str(task_name or DEFAULT_TASK).strip().lower()
        if chosen_task not in VALID_TASKS:
            chosen_task = DEFAULT_TASK
        chosen_seed = int(seed) if seed is not None else DEFAULT_SEED

        scenario = load_scenario(chosen_task)
        self.sim = DispatchSimulation(scenario, seed=chosen_seed)
        self.task_name = chosen_task
        self.seed = chosen_seed
        self._episode_id = episode_id or str(uuid4())
        self._step_count = 0
        self._cumulative_step_reward = 0.0
        self._last_step_reward = 0.0
        return self._build_observation(info_message="ready", error=None)

    def step(
        self,
        action: DispatchPulseAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> DispatchPulseObservation:
        """Apply one action and return the resulting observation.

        Supported ``action_type`` values:

        * ``dispatch`` -- needs ``call_id`` and ``unit_id`` (``hospital_id``
          optional); advances time by one minute.
        * ``classify`` -- needs ``call_id`` and ``severity``; advances one minute.
        * ``callback`` -- needs ``call_id`` (``message`` optional); advances
          one minute.
        * ``wait`` -- advances ``minutes`` clamped to
          ``[1, config.max_wait_step_minutes]``; reward is
          ``-0.005 * minutes * pending_calls_before_waiting``.
        * ``view`` -- no time cost, zero reward.

        When ``action_type`` is empty, ``action.text`` is parsed as a command
        such as ``dispatch CALL-001 ALS-1 H1``. The caller's ``action`` object
        is never modified. Unknown or unparseable actions, and exceptions
        raised by the simulation, are reported through ``last_action_error``
        with a -0.05 reward instead of propagating.

        ``timeout_s`` and ``**kwargs`` are accepted for interface
        compatibility and ignored.
        """
        if self.sim is None:
            self._bootstrap()
        if self.sim is None:
            return self._build_observation(error="environment not initialised")

        if self.sim.episode_done:
            return self._build_observation(error="episode already done")

        self._step_count += 1
        action_type, fields, parse_error = self._resolve_action(action)

        step_reward = 0.0
        info_message: Optional[str] = None
        error: Optional[str] = None

        try:
            if action_type == "dispatch":
                if not fields["call_id"] or not fields["unit_id"]:
                    error = "dispatch requires call_id and unit_id"
                else:
                    step_reward, info_message = self.sim.dispatch(
                        call_id=fields["call_id"],
                        unit_id=fields["unit_id"],
                        hospital_id=fields["hospital_id"],
                    )
                    self.sim.advance_time(1)
            elif action_type == "classify":
                if not fields["call_id"] or fields["severity"] is None:
                    error = "classify requires call_id and severity (1-5)"
                else:
                    step_reward, info_message = self.sim.classify(
                        call_id=fields["call_id"], severity=int(fields["severity"])
                    )
                    self.sim.advance_time(1)
            elif action_type == "callback":
                if not fields["call_id"]:
                    error = "callback requires call_id"
                else:
                    step_reward, info_message = self.sim.callback(
                        call_id=fields["call_id"], question=fields["message"] or ""
                    )
                    self.sim.advance_time(1)
            elif action_type == "wait":
                minutes = int(fields["minutes"] or 1)
                minutes = max(1, min(minutes, self.sim.config.max_wait_step_minutes))
                # Penalty is based on the backlog present before time passes.
                pending_before = len(self.sim.get_pending_calls())
                self.sim.advance_time(minutes)
                step_reward = -0.005 * minutes * pending_before
                info_message = f"waited {minutes} minute(s)"
            elif action_type == "view":
                step_reward = 0.0
                info_message = "view (no time cost)"
            else:
                step_reward = -0.05
                # Prefer the more specific message when text parsing failed.
                error = parse_error or f"unknown action_type: {action_type!r}"
        except Exception as exc:  # pragma: no cover - defensive
            error = f"{type(exc).__name__}: {exc}"
            step_reward = -0.05

        self._cumulative_step_reward += step_reward
        self._last_step_reward = step_reward
        return self._build_observation(info_message=info_message, error=error)

    @property
    def state(self) -> DispatchPulseState:
        """Snapshot of episode bookkeeping and simulation counters.

        When no simulation exists, only ``episode_id``, ``step_count`` and
        ``task_name`` are populated.
        """
        if self.sim is None:
            return DispatchPulseState(
                episode_id=self._episode_id,
                step_count=self._step_count,
                task_name=self.task_name,
            )
        return DispatchPulseState(
            episode_id=self._episode_id,
            step_count=self._step_count,
            current_time=self.sim.current_time,
            episode_done=self.sim.episode_done,
            total_calls=self.sim.total_calls(),
            calls_dispatched=len(self.sim.dispatches),
            calls_completed=len(self.sim.completed_calls),
            calls_timed_out=len(self.sim.timed_out_calls),
            calls_pending=len(self.sim.get_pending_calls()),
            units_available=len(self.sim.get_available_units()),
            running_reward=self._cumulative_step_reward,
            task_name=self.task_name,
        )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _resolve_action(
        self, action: DispatchPulseAction
    ) -> tuple[str, dict[str, Any], Optional[str]]:
        """Return ``(action_type, fields, parse_error)`` without mutating *action*.

        Structured fields take precedence. When ``action_type`` is empty but
        ``text`` is set, the text is parsed with ``_parse_text_action`` and
        its values fill only the fields the action left as ``None`` or ``""``.
        ``parse_error`` is a message when that text could not be parsed and
        ``None`` otherwise.
        """
        action_type = (action.action_type or "").strip().lower()
        text_action = (action.text or "").strip()
        fields = {name: getattr(action, name, None) for name in self._ACTION_FIELDS}

        if action_type or not text_action:
            return action_type, fields, None

        parsed = _parse_text_action(text_action)
        if parsed is None:
            return action_type, fields, f"could not parse text action: {text_action!r}"

        action_type, parsed_fields = parsed
        for key, value in parsed_fields.items():
            if fields.get(key) in (None, ""):
                fields[key] = value
        return action_type, fields, None

    def _build_observation(
        self,
        info_message: Optional[str] = None,
        error: Optional[str] = None,
    ) -> DispatchPulseObservation:
        """Render the current simulation into an observation.

        While the episode is running, ``reward`` is the most recent step's
        reward. Once the episode is done, ``reward`` is the graded total from
        ``grade_simulation`` and ``metadata["final_reward"]`` carries the
        full breakdown. ``info_message`` and ``error`` are echoed both as
        observation fields and in ``metadata`` when non-empty.
        """
        if self.sim is None:
            return DispatchPulseObservation(
                done=True,
                reward=0.0,
                text="ERROR: environment not initialised. Call reset first.",
                last_action_error="not_initialised",
            )

        text = render_dispatch_center(self.sim, self.task_name)
        done = bool(self.sim.episode_done)
        if done:
            final = grade_simulation(self.sim)
            reward_value: float = float(final.total)
            metadata = {
                "final_reward": final.model_dump(),
                "task": self.task_name,
                "cumulative_step_reward": float(self._cumulative_step_reward),
            }
        else:
            # Report the per-step delta, not the running cumulative. The
            # cumulative is still available via state() and metadata, but the
            # observation's reward field matches the standard Gym/OpenEnv
            # semantics of "reward for this step only".
            reward_value = float(self._last_step_reward)
            metadata = {
                "task": self.task_name,
                "cumulative_step_reward": float(self._cumulative_step_reward),
            }

        if info_message:
            metadata["info"] = info_message
        if error:
            metadata["error"] = error

        return DispatchPulseObservation(
            done=done,
            reward=reward_value,
            text=text,
            current_time=self.sim.current_time,
            time_limit=self.sim.config.time_limit_minutes,
            calls_pending=len(self.sim.get_pending_calls()),
            units_available=len(self.sim.get_available_units()),
            calls_completed=len(self.sim.completed_calls),
            calls_timed_out=len(self.sim.timed_out_calls),
            total_calls=self.sim.total_calls(),
            last_action_error=error,
            info_message=info_message,
            metadata=metadata,
        )


def _parse_text_action(text: str):
    """Parse a text action like ``dispatch CALL-001 ALS-1 H1`` into fields.

    Returns ``(action_type, kwargs_dict)`` or None on parse failure.
    """
    parts = text.strip().split(maxsplit=4)
    if not parts:
        return None
    head = parts[0].lower()
    if head == "dispatch" and len(parts) >= 3:
        out = {"call_id": parts[1], "unit_id": parts[2]}
        if len(parts) >= 4 and parts[3]:
            out["hospital_id"] = parts[3]
        return "dispatch", out
    if head == "classify" and len(parts) >= 3:
        try:
            sev = int(parts[2])
        except ValueError:
            return None
        return "classify", {"call_id": parts[1], "severity": sev}
    if head == "callback" and len(parts) >= 2:
        return "callback", {
            "call_id": parts[1],
            "message": " ".join(parts[2:]) if len(parts) > 2 else "",
        }
    if head == "wait":
        try:
            mins = int(parts[1]) if len(parts) > 1 else 1
        except ValueError:
            mins = 1
        return "wait", {"minutes": mins}
    if head in ("view", "view_dispatch_center"):
        return "view", {}
    return None