Arun-Sanjay committed on
Commit
64d56f9
·
1 Parent(s): 1091ce2

Replace MCPEnvironment with canonical Environment base class

Browse files

- New server/environment.py inherits from openenv Environment base class
with the standard reset()/step()/state API
- server/app.py now uses create_fastapi_app(env, ActionCls, ObservationCls)
- Drops the MCP-tool wrapper in favor of typed Action/Observation
- Removes server/dispatchpulse_environment.py (replaced by environment.py)

server/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
  """DispatchPulse environment server components."""
2
 
3
- from .dispatchpulse_environment import DispatchPulseEnvironment
4
 
5
  __all__ = ["DispatchPulseEnvironment"]
 
1
  """DispatchPulse environment server components."""
2
 
3
+ from .environment import DispatchPulseEnvironment
4
 
5
  __all__ = ["DispatchPulseEnvironment"]
server/app.py CHANGED
@@ -1,31 +1,42 @@
1
- """FastAPI application for DispatchPulse, served via openenv-core's create_app."""
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
5
  # Support both in-repo and standalone imports.
6
  try:
7
- from openenv.core.env_server.http_server import create_app
8
- from openenv.core.env_server.mcp_types import CallToolAction, CallToolObservation
9
-
10
- from .dispatchpulse_environment import DispatchPulseEnvironment
11
- except ImportError:
12
- from openenv.core.env_server.http_server import create_app
13
- from openenv.core.env_server.mcp_types import CallToolAction, CallToolObservation
14
- from server.dispatchpulse_environment import DispatchPulseEnvironment
15
-
16
- # create_app expects the environment class (not instance) so each WebSocket
17
- # session gets its own environment object — this enables concurrent grading
18
- # without cross-session state leakage.
19
- app = create_app(
 
 
 
 
 
 
20
  DispatchPulseEnvironment,
21
- CallToolAction,
22
- CallToolObservation,
23
- env_name="dispatchpulse",
24
  )
25
 
26
 
27
  def main() -> None:
28
- """Entry point for ``uv run --project . server`` or direct execution."""
29
  import uvicorn
30
 
31
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
+ """FastAPI application for DispatchPulse.
2
+
3
+ Uses ``create_fastapi_app(env_factory, ActionCls, ObservationCls)`` from
4
+ openenv-core's HTTP server, which exposes the standard ``/reset``, ``/step``,
5
+ ``/state``, ``/health``, ``/metadata``, ``/schema``, and ``/ws`` routes.
6
+ """
7
 
8
  from __future__ import annotations
9
 
10
  # Support both in-repo and standalone imports.
11
  try:
12
+ from openenv.core.env_server.http_server import create_fastapi_app
13
+
14
+ from .environment import DispatchPulseEnvironment
15
+ except ImportError: # pragma: no cover
16
+ from openenv.core.env_server.http_server import create_fastapi_app
17
+ from server.environment import DispatchPulseEnvironment
18
+
19
+ # Import the typed Action / Observation classes from the project root models.py
20
+ import os
21
+ import sys
22
+
23
+ _PKG_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
24
+ if _PKG_ROOT not in sys.path:
25
+ sys.path.insert(0, _PKG_ROOT)
26
+
27
+ from models import DispatchPulseAction, DispatchPulseObservation # noqa: E402
28
+
29
+ # Pass the class (factory) so each session gets its own env instance.
30
+ app = create_fastapi_app(
31
  DispatchPulseEnvironment,
32
+ DispatchPulseAction,
33
+ DispatchPulseObservation,
34
+ max_concurrent_envs=8,
35
  )
36
 
37
 
38
  def main() -> None:
39
+ """Entry point for ``uv run server`` or direct execution."""
40
  import uvicorn
41
 
42
  uvicorn.run(app, host="0.0.0.0", port=8000)
server/dispatchpulse_environment.py DELETED
@@ -1,358 +0,0 @@
1
- """DispatchPulse MCP environment.
2
-
3
- Inherits from openenv MCPEnvironment and exposes the dispatcher interface
4
- as MCP tools (FastMCP). Each tool advances the simulation by 1 minute,
5
- except `view_dispatch_center` (free inspection) and `wait` (custom n minutes).
6
- """
7
-
8
- from __future__ import annotations
9
-
10
- import os
11
- import sys
12
- from typing import Any, Optional
13
- from uuid import uuid4
14
-
15
- # Make package modules importable when running as `server.app:app`.
16
- _PKG_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
17
- if _PKG_ROOT not in sys.path:
18
- sys.path.insert(0, _PKG_ROOT)
19
-
20
- # Import the OpenEnv base classes.
21
- try:
22
- from openenv.core.env_server.mcp_environment import MCPEnvironment
23
- from openenv.core.env_server.types import Action, Observation, State
24
- except ImportError as e: # pragma: no cover - we still want the file to import for tests
25
- MCPEnvironment = object # type: ignore
26
- Action = object # type: ignore
27
- Observation = dict # type: ignore
28
- State = dict # type: ignore
29
- _OPENENV_IMPORT_ERROR = e
30
- else:
31
- _OPENENV_IMPORT_ERROR = None
32
-
33
- try:
34
- from fastmcp import FastMCP
35
- except ImportError as e: # pragma: no cover
36
- FastMCP = None # type: ignore
37
- _FASTMCP_IMPORT_ERROR = e
38
- else:
39
- _FASTMCP_IMPORT_ERROR = None
40
-
41
- from grader import grade_simulation
42
- from scenario_loader import load_scenario, list_tasks
43
- from simulation import DispatchSimulation
44
- from text_view import render_dispatch_center
45
-
46
-
47
- DEFAULT_TASK = "easy"
48
- DEFAULT_SEED = 42
49
-
50
-
51
- class DispatchPulseEnvironment(MCPEnvironment): # type: ignore[misc]
52
- """Emergency-dispatch OpenEnv environment exposed as MCP tools."""
53
-
54
- def __init__(self) -> None:
55
- if _OPENENV_IMPORT_ERROR is not None:
56
- raise RuntimeError(
57
- "openenv-core is required to run the DispatchPulse server. "
58
- f"Original import error: {_OPENENV_IMPORT_ERROR}"
59
- )
60
- if _FASTMCP_IMPORT_ERROR is not None:
61
- raise RuntimeError(
62
- "fastmcp is required to run the DispatchPulse server. "
63
- f"Original import error: {_FASTMCP_IMPORT_ERROR}"
64
- )
65
-
66
- # Internal mutable state set by reset()
67
- self.sim: Optional[DispatchSimulation] = None
68
- self.task_name: str = DEFAULT_TASK
69
- self.seed: int = DEFAULT_SEED
70
- self.cumulative_step_reward: float = 0.0
71
- self.episode_count: int = 0
72
-
73
- mcp = FastMCP("dispatchpulse")
74
-
75
- # Capture self for tool closures
76
- env = self
77
-
78
- @mcp.tool
79
- def view_dispatch_center() -> str:
80
- """Return the current dispatch center view as text.
81
-
82
- This is a FREE inspection action — it does NOT advance the
83
- simulation clock. Use it whenever you need to re-check pending
84
- calls, available units, or hospital status before deciding what
85
- to do next.
86
-
87
- Returns:
88
- A formatted text snapshot of the dispatch center.
89
- """
90
- if env.sim is None:
91
- return "ERROR: environment not initialised. Call reset first."
92
- return render_dispatch_center(env.sim, env.task_name)
93
-
94
- @mcp.tool
95
- def dispatch(call_id: str, unit_id: str, hospital_id: str = "") -> str:
96
- """Dispatch an emergency unit to a pending call.
97
-
98
- This advances the simulation clock by 1 minute (the dispatcher's
99
- decision time).
100
-
101
- Args:
102
- call_id: ID of the call to send a unit to (e.g. "CALL-007").
103
- unit_id: ID of the unit to dispatch (e.g. "ALS-1").
104
- hospital_id: Optional destination hospital (e.g. "H1"). Leave
105
- empty to defer the hospital choice. Choosing the hospital
106
- that has the right specialty (cardiac/stroke/trauma) for
107
- the call meaningfully improves patient outcome.
108
-
109
- Returns:
110
- Confirmation message followed by the new dispatch center view.
111
- """
112
- if env.sim is None:
113
- return "ERROR: environment not initialised. Call reset first."
114
- if env.sim.episode_done:
115
- return "ERROR: episode is already complete. Call reset to start a new one."
116
- chosen_hospital = hospital_id.strip() or None
117
- step_reward, msg = env.sim.dispatch(call_id, unit_id, chosen_hospital)
118
- env.cumulative_step_reward += step_reward
119
- env.sim.advance_time(1)
120
- return msg + "\n\n" + render_dispatch_center(env.sim, env.task_name)
121
-
122
- @mcp.tool
123
- def classify(call_id: str, severity: int) -> str:
124
- """Reclassify the severity of a pending call.
125
-
126
- Use this when you suspect the caller's reported severity is wrong
127
- (for example, after gathering more details). Severity is on a
128
- 1-5 scale where 1 is life-threatening and 5 is a false alarm.
129
- Advances the simulation clock by 1 minute.
130
-
131
- Args:
132
- call_id: ID of the call to reclassify.
133
- severity: New severity level (1=critical, 2=urgent, 3=moderate,
134
- 4=low, 5=false alarm).
135
-
136
- Returns:
137
- Confirmation message followed by the new dispatch center view.
138
- """
139
- if env.sim is None:
140
- return "ERROR: environment not initialised. Call reset first."
141
- if env.sim.episode_done:
142
- return "ERROR: episode is already complete."
143
- step_reward, msg = env.sim.classify(call_id, severity)
144
- env.cumulative_step_reward += step_reward
145
- env.sim.advance_time(1)
146
- return msg + "\n\n" + render_dispatch_center(env.sim, env.task_name)
147
-
148
- @mcp.tool
149
- def callback(call_id: str, question: str) -> str:
150
- """Phone the caller back to clarify their emergency.
151
-
152
- Useful when the caller's description is ambiguous and you want
153
- ground-truth on the emergency type before committing your most
154
- valuable units. There's a 70% chance the caller will clarify;
155
- otherwise they'll be too distressed. Advances the clock by 1
156
- minute (you spent that minute on the phone).
157
-
158
- Args:
159
- call_id: ID of the call to phone back.
160
- question: The clarifying question to ask (free text).
161
-
162
- Returns:
163
- The caller's response followed by the dispatch center view.
164
- """
165
- if env.sim is None:
166
- return "ERROR: environment not initialised. Call reset first."
167
- if env.sim.episode_done:
168
- return "ERROR: episode is already complete."
169
- step_reward, msg = env.sim.callback(call_id, question)
170
- env.cumulative_step_reward += step_reward
171
- env.sim.advance_time(1)
172
- return msg + "\n\n" + render_dispatch_center(env.sim, env.task_name)
173
-
174
- @mcp.tool
175
- def wait(minutes: int = 1) -> str:
176
- """Skip ahead in the simulation by the given number of minutes.
177
-
178
- Use this when there are no decisions to make right now (e.g. all
179
- units are en route and you're waiting for one to free up). The
180
- cap is 5 minutes per call to keep the agent in the loop on
181
- incoming calls. Calling wait while critical calls are unhandled
182
- costs you score.
183
-
184
- Args:
185
- minutes: Number of simulation minutes to skip (1-5).
186
-
187
- Returns:
188
- The new dispatch center view after time has advanced.
189
- """
190
- if env.sim is None:
191
- return "ERROR: environment not initialised. Call reset first."
192
- if env.sim.episode_done:
193
- return "ERROR: episode is already complete."
194
- n = max(1, min(int(minutes), env.sim.config.max_wait_step_minutes))
195
- pending_before = len(env.sim.get_pending_calls())
196
- env.sim.advance_time(n)
197
- # Slight per-minute penalty for waiting while pending calls exist
198
- env.cumulative_step_reward -= 0.005 * n * pending_before
199
- return f"Advanced {n} minute(s).\n\n" + render_dispatch_center(
200
- env.sim, env.task_name
201
- )
202
-
203
- # Register MCP server with the base class
204
- super().__init__(mcp)
205
- self._state = State(episode_id=str(uuid4()), step_count=0)
206
-
207
- # Auto-bootstrap with the default task so single-shot HTTP /step calls
208
- # (which create a fresh env per request) start in a usable state.
209
- # WebSocket / MCP sessions can still call reset() explicitly with a
210
- # different task_name to override.
211
- self._auto_reset()
212
-
213
- def _auto_reset(self) -> None:
214
- try:
215
- scenario = load_scenario(DEFAULT_TASK)
216
- self.sim = DispatchSimulation(scenario, seed=DEFAULT_SEED)
217
- self.task_name = DEFAULT_TASK
218
- self.seed = DEFAULT_SEED
219
- self.cumulative_step_reward = 0.0
220
- except Exception: # pragma: no cover - never crash __init__
221
- self.sim = None
222
-
223
- # ------------------------------------------------------------------
224
- # OpenEnv lifecycle methods
225
- # ------------------------------------------------------------------
226
-
227
- def reset(
228
- self,
229
- seed: Optional[int] = None,
230
- episode_id: Optional[str] = None,
231
- task_name: Optional[str] = None,
232
- **kwargs: Any,
233
- ) -> Observation:
234
- """Reset the environment to the start of a fresh episode.
235
-
236
- Args:
237
- seed: random seed for reproducibility (default 42).
238
- episode_id: Optional caller-supplied episode ID.
239
- task_name: One of {"easy", "medium", "hard"} (default "easy").
240
-
241
- Returns:
242
- Observation with the initial dispatch center view in metadata["text"].
243
- """
244
- chosen_task = (task_name or DEFAULT_TASK).strip().lower()
245
- if chosen_task not in list_tasks():
246
- chosen_task = DEFAULT_TASK
247
- chosen_seed = int(seed) if seed is not None else DEFAULT_SEED
248
-
249
- scenario = load_scenario(chosen_task)
250
- self.sim = DispatchSimulation(scenario, seed=chosen_seed)
251
- self.task_name = chosen_task
252
- self.seed = chosen_seed
253
- self.cumulative_step_reward = 0.0
254
- self.episode_count += 1
255
-
256
- self._state = State(
257
- episode_id=episode_id or str(uuid4()),
258
- step_count=0,
259
- )
260
-
261
- return Observation(
262
- done=False,
263
- reward=0.0,
264
- metadata={
265
- "status": "ready",
266
- "task": chosen_task,
267
- "seed": chosen_seed,
268
- "tasks_available": list_tasks(),
269
- "text": render_dispatch_center(self.sim, self.task_name),
270
- "tools": [
271
- "view_dispatch_center",
272
- "dispatch",
273
- "classify",
274
- "callback",
275
- "wait",
276
- ],
277
- },
278
- )
279
-
280
- def _step_impl(
281
- self,
282
- action: Action,
283
- timeout_s: Optional[float] = None,
284
- **kwargs: Any,
285
- ) -> Observation:
286
- return Observation(
287
- done=False,
288
- reward=0.0,
289
- metadata={
290
- "error": (
291
- f"Unknown action type {type(action).__name__}. "
292
- "Use ListToolsAction or CallToolAction (MCP)."
293
- )
294
- },
295
- )
296
-
297
- def step(
298
- self,
299
- action: Action,
300
- timeout_s: Optional[float] = None,
301
- **kwargs: Any,
302
- ) -> Observation:
303
- """Execute one step. Tools advance the sim; we then enrich the obs."""
304
- self._state.step_count += 1
305
- obs = super().step(action, timeout_s=timeout_s, **kwargs)
306
- return self._enrich_observation(obs)
307
-
308
- async def step_async(
309
- self,
310
- action: Action,
311
- timeout_s: Optional[float] = None,
312
- **kwargs: Any,
313
- ) -> Observation:
314
- self._state.step_count += 1
315
- obs = await super().step_async(action, timeout_s=timeout_s, **kwargs)
316
- return self._enrich_observation(obs)
317
-
318
- @property
319
- def state(self) -> State:
320
- return self._state
321
-
322
- # ------------------------------------------------------------------
323
- # Helpers
324
- # ------------------------------------------------------------------
325
-
326
- def _enrich_observation(self, obs: Observation) -> Observation:
327
- """Inject reward / done / sim metadata onto a base Observation."""
328
- if self.sim is None:
329
- return obs
330
-
331
- obs.done = bool(self.sim.episode_done)
332
- if self.sim.episode_done:
333
- final = grade_simulation(self.sim)
334
- obs.reward = float(final.total)
335
- md = obs.metadata or {}
336
- md.update(
337
- {
338
- "final_reward": final.model_dump(),
339
- "task": self.task_name,
340
- "completed_calls": len(self.sim.completed_calls),
341
- "timed_out_calls": len(self.sim.timed_out_calls),
342
- "total_calls": self.sim.total_calls(),
343
- }
344
- )
345
- obs.metadata = md
346
- else:
347
- obs.reward = float(self.cumulative_step_reward)
348
- md = obs.metadata or {}
349
- md.update(
350
- {
351
- "current_time": self.sim.current_time,
352
- "calls_pending": len(self.sim.get_pending_calls()),
353
- "units_available": len(self.sim.get_available_units()),
354
- "task": self.task_name,
355
- }
356
- )
357
- obs.metadata = md
358
- return obs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
server/environment.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """DispatchPulse OpenEnv environment.
2
+
3
+ Inherits from ``openenv.core.env_server.interfaces.Environment`` and implements
4
+ the standard ``reset() / step() / state`` Gym-style API. The wire types
5
+ ``DispatchPulseAction`` and ``DispatchPulseObservation`` are defined in
6
+ ``models.py`` and inherit from the OpenEnv ``Action`` / ``Observation`` base
7
+ classes.
8
+
9
+ This is a thin wrapper around the in-process ``DispatchSimulation`` engine.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import os
15
+ import sys
16
+ from typing import Any, Optional
17
+ from uuid import uuid4
18
+
19
+ # Make project root importable when running as ``server.app:app`` from /app/env
20
+ _PKG_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
21
+ if _PKG_ROOT not in sys.path:
22
+ sys.path.insert(0, _PKG_ROOT)
23
+
24
+ from openenv.core.env_server.interfaces import Environment
25
+
26
+ from grader import grade_simulation
27
+ from models import DispatchPulseAction, DispatchPulseObservation, DispatchPulseState
28
+ from scenario_loader import VALID_TASKS, load_scenario
29
+ from simulation import DispatchSimulation
30
+ from text_view import render_dispatch_center
31
+
32
+ DEFAULT_TASK = "easy"
33
+ DEFAULT_SEED = 42
34
+
35
+
36
+ class DispatchPulseEnvironment(
37
+ Environment[DispatchPulseAction, DispatchPulseObservation, DispatchPulseState]
38
+ ):
39
+ """Emergency-dispatch OpenEnv environment.
40
+
41
+ Each call to ``reset()`` starts a fresh episode for the chosen task.
42
+ Calls to ``step(action)`` advance the simulation by one decision turn
43
+ (which usually equals 1 minute of simulation time).
44
+
45
+ Tasks: ``easy``, ``medium``, ``hard``.
46
+ """
47
+
48
+ SUPPORTS_CONCURRENT_SESSIONS: bool = True
49
+
50
+ def __init__(self) -> None:
51
+ super().__init__()
52
+ self.sim: Optional[DispatchSimulation] = None
53
+ self.task_name: str = DEFAULT_TASK
54
+ self.seed: int = DEFAULT_SEED
55
+ self._episode_id: str = str(uuid4())
56
+ self._step_count: int = 0
57
+ self._cumulative_step_reward: float = 0.0
58
+ # Bootstrap so single-shot HTTP /step still works without an explicit reset
59
+ self._bootstrap()
60
+
61
+ def _bootstrap(self) -> None:
62
+ try:
63
+ scenario = load_scenario(DEFAULT_TASK)
64
+ self.sim = DispatchSimulation(scenario, seed=DEFAULT_SEED)
65
+ self.task_name = DEFAULT_TASK
66
+ self.seed = DEFAULT_SEED
67
+ self._cumulative_step_reward = 0.0
68
+ self._step_count = 0
69
+ except Exception: # pragma: no cover
70
+ self.sim = None
71
+
72
+ # ------------------------------------------------------------------
73
+ # Environment API
74
+ # ------------------------------------------------------------------
75
+
76
+ def reset(
77
+ self,
78
+ seed: Optional[int] = None,
79
+ episode_id: Optional[str] = None,
80
+ task_name: Optional[str] = None,
81
+ **kwargs: Any,
82
+ ) -> DispatchPulseObservation:
83
+ chosen_task = (task_name or DEFAULT_TASK).strip().lower()
84
+ if chosen_task not in VALID_TASKS:
85
+ chosen_task = DEFAULT_TASK
86
+ chosen_seed = int(seed) if seed is not None else DEFAULT_SEED
87
+
88
+ scenario = load_scenario(chosen_task)
89
+ self.sim = DispatchSimulation(scenario, seed=chosen_seed)
90
+ self.task_name = chosen_task
91
+ self.seed = chosen_seed
92
+ self._episode_id = episode_id or str(uuid4())
93
+ self._step_count = 0
94
+ self._cumulative_step_reward = 0.0
95
+ return self._build_observation(info_message="ready", error=None)
96
+
97
+ def step(
98
+ self,
99
+ action: DispatchPulseAction,
100
+ timeout_s: Optional[float] = None,
101
+ **kwargs: Any,
102
+ ) -> DispatchPulseObservation:
103
+ if self.sim is None:
104
+ self._bootstrap()
105
+ if self.sim is None:
106
+ return self._build_observation(error="environment not initialised")
107
+
108
+ if self.sim.episode_done:
109
+ return self._build_observation(error="episode already done")
110
+
111
+ self._step_count += 1
112
+ action_type = (action.action_type or "").strip().lower()
113
+ text_action = (action.text or "").strip()
114
+
115
+ # Allow text-only actions: parse the text into structured fields
116
+ if not action_type and text_action:
117
+ parsed = _parse_text_action(text_action)
118
+ if parsed is not None:
119
+ action_type, fields = parsed
120
+ for key, value in fields.items():
121
+ if getattr(action, key, None) in (None, ""):
122
+ setattr(action, key, value)
123
+
124
+ step_reward = 0.0
125
+ info_message: Optional[str] = None
126
+ error: Optional[str] = None
127
+
128
+ try:
129
+ if action_type == "dispatch":
130
+ if not action.call_id or not action.unit_id:
131
+ error = "dispatch requires call_id and unit_id"
132
+ else:
133
+ step_reward, info_message = self.sim.dispatch(
134
+ call_id=action.call_id,
135
+ unit_id=action.unit_id,
136
+ hospital_id=action.hospital_id,
137
+ )
138
+ self.sim.advance_time(1)
139
+ elif action_type == "classify":
140
+ if not action.call_id or action.severity is None:
141
+ error = "classify requires call_id and severity (1-5)"
142
+ else:
143
+ step_reward, info_message = self.sim.classify(
144
+ call_id=action.call_id, severity=int(action.severity)
145
+ )
146
+ self.sim.advance_time(1)
147
+ elif action_type == "callback":
148
+ if not action.call_id:
149
+ error = "callback requires call_id"
150
+ else:
151
+ step_reward, info_message = self.sim.callback(
152
+ call_id=action.call_id, question=action.message or ""
153
+ )
154
+ self.sim.advance_time(1)
155
+ elif action_type == "wait":
156
+ minutes = int(action.minutes or 1)
157
+ minutes = max(1, min(minutes, self.sim.config.max_wait_step_minutes))
158
+ pending_before = len(self.sim.get_pending_calls())
159
+ self.sim.advance_time(minutes)
160
+ step_reward = -0.005 * minutes * pending_before
161
+ info_message = f"waited {minutes} minute(s)"
162
+ elif action_type == "view":
163
+ step_reward = 0.0
164
+ info_message = "view (no time cost)"
165
+ else:
166
+ step_reward = -0.05
167
+ error = f"unknown action_type: {action_type!r}"
168
+ except Exception as exc: # pragma: no cover - defensive
169
+ error = f"{type(exc).__name__}: {exc}"
170
+ step_reward = -0.05
171
+
172
+ self._cumulative_step_reward += step_reward
173
+ return self._build_observation(info_message=info_message, error=error)
174
+
175
+ @property
176
+ def state(self) -> DispatchPulseState:
177
+ if self.sim is None:
178
+ return DispatchPulseState(
179
+ episode_id=self._episode_id,
180
+ step_count=self._step_count,
181
+ task_name=self.task_name,
182
+ )
183
+ return DispatchPulseState(
184
+ episode_id=self._episode_id,
185
+ step_count=self._step_count,
186
+ current_time=self.sim.current_time,
187
+ episode_done=self.sim.episode_done,
188
+ total_calls=self.sim.total_calls(),
189
+ calls_dispatched=len(self.sim.dispatches),
190
+ calls_completed=len(self.sim.completed_calls),
191
+ calls_timed_out=len(self.sim.timed_out_calls),
192
+ calls_pending=len(self.sim.get_pending_calls()),
193
+ units_available=len(self.sim.get_available_units()),
194
+ running_reward=self._cumulative_step_reward,
195
+ task_name=self.task_name,
196
+ )
197
+
198
+ # ------------------------------------------------------------------
199
+ # Helpers
200
+ # ------------------------------------------------------------------
201
+
202
+ def _build_observation(
203
+ self,
204
+ info_message: Optional[str] = None,
205
+ error: Optional[str] = None,
206
+ ) -> DispatchPulseObservation:
207
+ if self.sim is None:
208
+ return DispatchPulseObservation(
209
+ done=True,
210
+ reward=0.0,
211
+ text="ERROR: environment not initialised. Call reset first.",
212
+ last_action_error="not_initialised",
213
+ )
214
+
215
+ text = render_dispatch_center(self.sim, self.task_name)
216
+ done = bool(self.sim.episode_done)
217
+ if done:
218
+ final = grade_simulation(self.sim)
219
+ reward_value: float = float(final.total)
220
+ metadata = {
221
+ "final_reward": final.model_dump(),
222
+ "task": self.task_name,
223
+ }
224
+ else:
225
+ reward_value = float(self._cumulative_step_reward)
226
+ metadata = {"task": self.task_name}
227
+
228
+ if info_message:
229
+ metadata["info"] = info_message
230
+ if error:
231
+ metadata["error"] = error
232
+
233
+ return DispatchPulseObservation(
234
+ done=done,
235
+ reward=reward_value,
236
+ text=text,
237
+ current_time=self.sim.current_time,
238
+ time_limit=self.sim.config.time_limit_minutes,
239
+ calls_pending=len(self.sim.get_pending_calls()),
240
+ units_available=len(self.sim.get_available_units()),
241
+ calls_completed=len(self.sim.completed_calls),
242
+ calls_timed_out=len(self.sim.timed_out_calls),
243
+ total_calls=self.sim.total_calls(),
244
+ last_action_error=error,
245
+ info_message=info_message,
246
+ metadata=metadata,
247
+ )
248
+
249
+
250
+ def _parse_text_action(text: str):
251
+ """Parse a text action like ``dispatch CALL-001 ALS-1 H1`` into fields.
252
+
253
+ Returns ``(action_type, kwargs_dict)`` or None on parse failure.
254
+ """
255
+ parts = text.strip().split(maxsplit=4)
256
+ if not parts:
257
+ return None
258
+ head = parts[0].lower()
259
+ if head == "dispatch" and len(parts) >= 3:
260
+ out = {"call_id": parts[1], "unit_id": parts[2]}
261
+ if len(parts) >= 4 and parts[3]:
262
+ out["hospital_id"] = parts[3]
263
+ return "dispatch", out
264
+ if head == "classify" and len(parts) >= 3:
265
+ try:
266
+ sev = int(parts[2])
267
+ except ValueError:
268
+ return None
269
+ return "classify", {"call_id": parts[1], "severity": sev}
270
+ if head == "callback" and len(parts) >= 2:
271
+ return "callback", {
272
+ "call_id": parts[1],
273
+ "message": " ".join(parts[2:]) if len(parts) > 2 else "",
274
+ }
275
+ if head == "wait":
276
+ try:
277
+ mins = int(parts[1]) if len(parts) > 1 else 1
278
+ except ValueError:
279
+ mins = 1
280
+ return "wait", {"minutes": mins}
281
+ if head in ("view", "view_dispatch_center"):
282
+ return "view", {}
283
+ return None