Spaces:
Sleeping
Sleeping
| """DispatchPulse OpenEnv environment. | |
| Inherits from ``openenv.core.env_server.interfaces.Environment`` and implements | |
| the standard ``reset() / step() / state`` Gym-style API. The wire types | |
| ``DispatchPulseAction`` and ``DispatchPulseObservation`` are defined in | |
| ``models.py`` and inherit from the OpenEnv ``Action`` / ``Observation`` base | |
| classes. | |
| This is a thin wrapper around the in-process ``DispatchSimulation`` engine. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import sys | |
| from typing import Any, Optional | |
| from uuid import uuid4 | |
# Prepend the project root (the parent of this file's directory) to sys.path,
# so the top-level project modules imported below resolve when this module is
# loaded as ``server.app:app`` from /app/env. Skipped if already present.
_PKG_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _PKG_ROOT not in sys.path:
    sys.path.insert(0, _PKG_ROOT)
| from openenv.core.env_server.interfaces import Environment | |
| from grader import grade_simulation | |
| from models import DispatchPulseAction, DispatchPulseObservation, DispatchPulseState | |
| from scenario_loader import VALID_TASKS, load_scenario | |
| from simulation import DispatchSimulation | |
| from text_view import render_dispatch_center | |
| # Re-export the task registry and grader symbols at module level so static | |
| # validators that scan server/environment.py for tasks can find them here | |
| # (same pattern as the SQL Repair passing submission where both TASKS and | |
| # grade_submission are accessible from server/environment.py). | |
| from task_definitions import ( # noqa: F401,E402 | |
| TASKS, | |
| TaskDefinition, | |
| grade_submission, | |
| get_task, | |
| list_tasks, | |
| ) | |
# Task used by the constructor's bootstrap episode and by reset() when no
# task_name (or an unrecognised one) is supplied.
DEFAULT_TASK: str = "easy"
# Simulation seed used by the bootstrap episode and by reset() without a seed.
DEFAULT_SEED: int = 42
class DispatchPulseEnvironment(
    Environment[DispatchPulseAction, DispatchPulseObservation, DispatchPulseState]
):
    """Emergency-dispatch OpenEnv environment.

    Each call to ``reset()`` starts a fresh episode for the chosen task.
    Calls to ``step(action)`` advance the simulation by one decision turn
    (which usually equals 1 minute of simulation time). ``state`` is a
    read-only property returning a ``DispatchPulseState`` snapshot.

    Tasks: ``easy``, ``medium``, ``hard``.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self) -> None:
        super().__init__()
        self.sim: Optional[DispatchSimulation] = None
        self.task_name: str = DEFAULT_TASK
        self.seed: int = DEFAULT_SEED
        self._episode_id: str = str(uuid4())
        self._step_count: int = 0
        # Running sum of per-step rewards for the current episode.
        self._cumulative_step_reward: float = 0.0
        # Reward of the most recent step only (reported while not done).
        self._last_step_reward: float = 0.0
        # Bootstrap so single-shot HTTP /step still works without an explicit reset
        self._bootstrap()

    def _bootstrap(self) -> None:
        """Start a default episode (``DEFAULT_TASK`` with ``DEFAULT_SEED``).

        Used by the constructor and by ``step()`` when no simulation exists.
        A failure is printed to stderr and leaves ``self.sim`` as ``None``
        instead of raising, so callers can still return an error observation.
        """
        try:
            scenario = load_scenario(DEFAULT_TASK)
            self.sim = DispatchSimulation(scenario, seed=DEFAULT_SEED)
            self.task_name = DEFAULT_TASK
            self.seed = DEFAULT_SEED
            self._cumulative_step_reward = 0.0
            self._last_step_reward = 0.0
            self._step_count = 0
        except Exception as exc:  # pragma: no cover
            print(f"[DispatchPulseEnvironment] bootstrap failed: {exc}", file=sys.stderr, flush=True)
            self.sim = None

    # ------------------------------------------------------------------
    # Environment API
    # ------------------------------------------------------------------

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        task_name: Optional[str] = None,
        **kwargs: Any,
    ) -> DispatchPulseObservation:
        """Start a fresh episode and return its first observation.

        Args:
            seed: Simulation seed; ``DEFAULT_SEED`` when omitted.
            episode_id: Identifier reported by ``state``; a new UUID4 when omitted.
            task_name: Task to load, matched case- and whitespace-insensitively
                against ``VALID_TASKS``. An unrecognised name falls back to
                ``DEFAULT_TASK``, and the fallback is reported in the returned
                observation's ``info_message`` rather than being silent.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            The initial observation; ``info_message`` starts with ``"ready"``.
        """
        # str() guards against non-string task names (the original called
        # .strip() directly, which raised AttributeError for e.g. ints).
        if task_name:
            requested = str(task_name).strip().lower()
        else:
            requested = DEFAULT_TASK

        if requested in VALID_TASKS:
            chosen_task = requested
            info_message = "ready"
        else:
            chosen_task = DEFAULT_TASK
            info_message = (
                f"ready (unknown task {requested!r}; using default {DEFAULT_TASK!r})"
            )

        chosen_seed = int(seed) if seed is not None else DEFAULT_SEED
        scenario = load_scenario(chosen_task)
        self.sim = DispatchSimulation(scenario, seed=chosen_seed)
        self.task_name = chosen_task
        self.seed = chosen_seed
        self._episode_id = episode_id or str(uuid4())
        self._step_count = 0
        self._cumulative_step_reward = 0.0
        self._last_step_reward = 0.0
        return self._build_observation(info_message=info_message, error=None)

    def step(
        self,
        action: DispatchPulseAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> DispatchPulseObservation:
        """Apply one agent action and return the resulting observation.

        The action is either structured (``action_type`` plus its fields) or
        text-only: ``action_type`` empty and ``text`` such as
        ``"dispatch CALL-001 ALS-1 H1"``, parsed by ``_parse_text_action``.
        The caller's ``action`` object is not modified.

        Time and reward effects:

        * ``dispatch`` / ``classify`` / ``callback`` advance the clock by
          1 minute when they succeed, with the reward the simulation returns.
        * ``wait`` advances ``minutes`` (clamped to
          ``[1, config.max_wait_step_minutes]``) at a cost proportional to
          the number of calls that were pending.
        * ``view`` (alias ``view_dispatch_center``) costs no time and no reward.

        A request missing its required fields is reported through
        ``last_action_error`` with reward 0.0 and no time cost. An unknown
        action type, or an exception raised by the simulation, costs -0.05.

        Args:
            action: The agent's action.
            timeout_s: Accepted for interface compatibility; unused.
            **kwargs: Accepted for interface compatibility; ignored.
        """
        if self.sim is None:
            self._bootstrap()
        if self.sim is None:
            return self._build_observation(error="environment not initialised")
        if self.sim.episode_done:
            return self._build_observation(error="episode already done")

        self._step_count += 1
        action_type, fields = self._resolve_action(action)

        step_reward = 0.0
        info_message: Optional[str] = None
        error: Optional[str] = None
        try:
            if action_type == "dispatch":
                if not fields["call_id"] or not fields["unit_id"]:
                    error = "dispatch requires call_id and unit_id"
                else:
                    step_reward, info_message = self.sim.dispatch(
                        call_id=fields["call_id"],
                        unit_id=fields["unit_id"],
                        hospital_id=fields["hospital_id"],
                    )
                    self.sim.advance_time(1)
            elif action_type == "classify":
                if not fields["call_id"] or fields["severity"] is None:
                    error = "classify requires call_id and severity (1-5)"
                else:
                    step_reward, info_message = self.sim.classify(
                        call_id=fields["call_id"], severity=int(fields["severity"])
                    )
                    self.sim.advance_time(1)
            elif action_type == "callback":
                if not fields["call_id"]:
                    error = "callback requires call_id"
                else:
                    step_reward, info_message = self.sim.callback(
                        call_id=fields["call_id"], question=fields["message"] or ""
                    )
                    self.sim.advance_time(1)
            elif action_type == "wait":
                # 0/None become 1 minute; the upper bound comes from the scenario config.
                minutes = int(fields["minutes"] or 1)
                minutes = max(1, min(minutes, self.sim.config.max_wait_step_minutes))
                pending_before = len(self.sim.get_pending_calls())
                self.sim.advance_time(minutes)
                # Idling is penalised in proportion to the calls left waiting.
                step_reward = -0.005 * minutes * pending_before
                info_message = f"waited {minutes} minute(s)"
            elif action_type in ("view", "view_dispatch_center"):
                # "view_dispatch_center" is also accepted by _parse_text_action;
                # accept it here too so structured and text actions agree.
                step_reward = 0.0
                info_message = "view (no time cost)"
            else:
                step_reward = -0.05
                error = f"unknown action_type: {action_type!r}"
        except Exception as exc:  # pragma: no cover - defensive
            error = f"{type(exc).__name__}: {exc}"
            step_reward = -0.05

        self._cumulative_step_reward += step_reward
        self._last_step_reward = step_reward
        return self._build_observation(info_message=info_message, error=error)

    @property
    def state(self) -> DispatchPulseState:
        """Snapshot of the current episode's progress.

        Exposed as a property (read as ``env.state``), matching the
        ``reset() / step() / state`` API in the module docstring and the
        OpenEnv ``Environment`` base class. When no simulation exists, only
        the episode id, step count and task name are filled in.
        """
        if self.sim is None:
            return DispatchPulseState(
                episode_id=self._episode_id,
                step_count=self._step_count,
                task_name=self.task_name,
            )
        return DispatchPulseState(
            episode_id=self._episode_id,
            step_count=self._step_count,
            current_time=self.sim.current_time,
            episode_done=self.sim.episode_done,
            total_calls=self.sim.total_calls(),
            calls_dispatched=len(self.sim.dispatches),
            calls_completed=len(self.sim.completed_calls),
            calls_timed_out=len(self.sim.timed_out_calls),
            calls_pending=len(self.sim.get_pending_calls()),
            units_available=len(self.sim.get_available_units()),
            running_reward=self._cumulative_step_reward,
            task_name=self.task_name,
        )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _resolve_action(action: DispatchPulseAction) -> tuple[str, dict[str, Any]]:
        """Return ``(action_type, fields)`` for *action* without mutating it.

        ``fields`` holds ``call_id``, ``unit_id``, ``hospital_id``,
        ``severity``, ``message`` and ``minutes`` as read from the action.
        When ``action_type`` is empty and ``text`` is set, the text is parsed
        with ``_parse_text_action``. Each parsed value fills a field only if
        that field is ``None`` or ``""`` on the action. If the text cannot be
        parsed, ``action_type`` stays ``""``.
        """
        action_type = (action.action_type or "").strip().lower()
        text_action = (action.text or "").strip()
        names = ("call_id", "unit_id", "hospital_id", "severity", "message", "minutes")
        fields: dict[str, Any] = {name: getattr(action, name, None) for name in names}
        if not action_type and text_action:
            parsed = _parse_text_action(text_action)
            if parsed is not None:
                action_type, parsed_fields = parsed
                for key, value in parsed_fields.items():
                    if fields.get(key) in (None, ""):
                        fields[key] = value
        return action_type, fields

    def _build_observation(
        self,
        info_message: Optional[str] = None,
        error: Optional[str] = None,
    ) -> DispatchPulseObservation:
        """Render the current simulation into a ``DispatchPulseObservation``.

        While the episode is running, ``reward`` is the last step's reward.
        Once the episode is done, ``reward`` is the graded total from
        ``grade_simulation`` and the full grade is placed in
        ``metadata["final_reward"]``. ``info_message`` and ``error`` are
        copied into both the observation fields and ``metadata``. Without a
        simulation, a terminal error observation is returned.
        """
        if self.sim is None:
            return DispatchPulseObservation(
                done=True,
                reward=0.0,
                text="ERROR: environment not initialised. Call reset first.",
                last_action_error="not_initialised",
            )
        text = render_dispatch_center(self.sim, self.task_name)
        done = bool(self.sim.episode_done)
        metadata: dict[str, Any]
        if done:
            final = grade_simulation(self.sim)
            reward_value: float = float(final.total)
            metadata = {
                "final_reward": final.model_dump(),
                "task": self.task_name,
                "cumulative_step_reward": float(self._cumulative_step_reward),
            }
        else:
            # Report the per-step delta, not the running cumulative. The
            # cumulative is still available via state and metadata, but the
            # observation's reward field matches the standard Gym/OpenEnv
            # semantics of "reward for this step only".
            reward_value = float(self._last_step_reward)
            metadata = {
                "task": self.task_name,
                "cumulative_step_reward": float(self._cumulative_step_reward),
            }
        if info_message:
            metadata["info"] = info_message
        if error:
            metadata["error"] = error
        return DispatchPulseObservation(
            done=done,
            reward=reward_value,
            text=text,
            current_time=self.sim.current_time,
            time_limit=self.sim.config.time_limit_minutes,
            calls_pending=len(self.sim.get_pending_calls()),
            units_available=len(self.sim.get_available_units()),
            calls_completed=len(self.sim.completed_calls),
            calls_timed_out=len(self.sim.timed_out_calls),
            total_calls=self.sim.total_calls(),
            last_action_error=error,
            info_message=info_message,
            metadata=metadata,
        )
| def _parse_text_action(text: str): | |
| """Parse a text action like ``dispatch CALL-001 ALS-1 H1`` into fields. | |
| Returns ``(action_type, kwargs_dict)`` or None on parse failure. | |
| """ | |
| parts = text.strip().split(maxsplit=4) | |
| if not parts: | |
| return None | |
| head = parts[0].lower() | |
| if head == "dispatch" and len(parts) >= 3: | |
| out = {"call_id": parts[1], "unit_id": parts[2]} | |
| if len(parts) >= 4 and parts[3]: | |
| out["hospital_id"] = parts[3] | |
| return "dispatch", out | |
| if head == "classify" and len(parts) >= 3: | |
| try: | |
| sev = int(parts[2]) | |
| except ValueError: | |
| return None | |
| return "classify", {"call_id": parts[1], "severity": sev} | |
| if head == "callback" and len(parts) >= 2: | |
| return "callback", { | |
| "call_id": parts[1], | |
| "message": " ".join(parts[2:]) if len(parts) > 2 else "", | |
| } | |
| if head == "wait": | |
| try: | |
| mins = int(parts[1]) if len(parts) > 1 else 1 | |
| except ValueError: | |
| mins = 1 | |
| return "wait", {"minutes": mins} | |
| if head in ("view", "view_dispatch_center"): | |
| return "view", {} | |
| return None | |