"""DispatchPulse OpenEnv environment. Inherits from ``openenv.core.env_server.interfaces.Environment`` and implements the standard ``reset() / step() / state`` Gym-style API. The wire types ``DispatchPulseAction`` and ``DispatchPulseObservation`` are defined in ``models.py`` and inherit from the OpenEnv ``Action`` / ``Observation`` base classes. This is a thin wrapper around the in-process ``DispatchSimulation`` engine. """ from __future__ import annotations import os import sys from typing import Any, Optional from uuid import uuid4 # Make project root importable when running as ``server.app:app`` from /app/env _PKG_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) if _PKG_ROOT not in sys.path: sys.path.insert(0, _PKG_ROOT) from openenv.core.env_server.interfaces import Environment from grader import grade_simulation from models import DispatchPulseAction, DispatchPulseObservation, DispatchPulseState from scenario_loader import VALID_TASKS, load_scenario from simulation import DispatchSimulation from text_view import render_dispatch_center # Re-export the task registry and grader symbols at module level so static # validators that scan server/environment.py for tasks can find them here # (same pattern as the SQL Repair passing submission where both TASKS and # grade_submission are accessible from server/environment.py). from task_definitions import ( # noqa: F401,E402 TASKS, TaskDefinition, grade_submission, get_task, list_tasks, ) DEFAULT_TASK = "easy" DEFAULT_SEED = 42 class DispatchPulseEnvironment( Environment[DispatchPulseAction, DispatchPulseObservation, DispatchPulseState] ): """Emergency-dispatch OpenEnv environment. Each call to ``reset()`` starts a fresh episode for the chosen task. Calls to ``step(action)`` advance the simulation by one decision turn (which usually equals 1 minute of simulation time). Tasks: ``easy``, ``medium``, ``hard``. """ SUPPORTS_CONCURRENT_SESSIONS: bool = True def __init__(self) -> None: super().__init__() self.sim: Optional[DispatchSimulation] = None self.task_name: str = DEFAULT_TASK self.seed: int = DEFAULT_SEED self._episode_id: str = str(uuid4()) self._step_count: int = 0 self._cumulative_step_reward: float = 0.0 self._last_step_reward: float = 0.0 # Bootstrap so single-shot HTTP /step still works without an explicit reset self._bootstrap() def _bootstrap(self) -> None: try: scenario = load_scenario(DEFAULT_TASK) self.sim = DispatchSimulation(scenario, seed=DEFAULT_SEED) self.task_name = DEFAULT_TASK self.seed = DEFAULT_SEED self._cumulative_step_reward = 0.0 self._last_step_reward = 0.0 self._step_count = 0 except Exception as exc: # pragma: no cover print(f"[DispatchPulseEnvironment] bootstrap failed: {exc}", file=sys.stderr, flush=True) self.sim = None # ------------------------------------------------------------------ # Environment API # ------------------------------------------------------------------ def reset( self, seed: Optional[int] = None, episode_id: Optional[str] = None, task_name: Optional[str] = None, **kwargs: Any, ) -> DispatchPulseObservation: chosen_task = (task_name or DEFAULT_TASK).strip().lower() if chosen_task not in VALID_TASKS: chosen_task = DEFAULT_TASK chosen_seed = int(seed) if seed is not None else DEFAULT_SEED scenario = load_scenario(chosen_task) self.sim = DispatchSimulation(scenario, seed=chosen_seed) self.task_name = chosen_task self.seed = chosen_seed self._episode_id = episode_id or str(uuid4()) self._step_count = 0 self._cumulative_step_reward = 0.0 self._last_step_reward = 0.0 return self._build_observation(info_message="ready", error=None) def step( self, action: DispatchPulseAction, timeout_s: Optional[float] = None, **kwargs: Any, ) -> DispatchPulseObservation: if self.sim is None: self._bootstrap() if self.sim is None: return self._build_observation(error="environment not initialised") if self.sim.episode_done: return self._build_observation(error="episode already done") self._step_count += 1 action_type = (action.action_type or "").strip().lower() text_action = (action.text or "").strip() # Allow text-only actions: parse the text into structured fields if not action_type and text_action: parsed = _parse_text_action(text_action) if parsed is not None: action_type, fields = parsed for key, value in fields.items(): if getattr(action, key, None) in (None, ""): setattr(action, key, value) step_reward = 0.0 info_message: Optional[str] = None error: Optional[str] = None try: if action_type == "dispatch": if not action.call_id or not action.unit_id: error = "dispatch requires call_id and unit_id" else: step_reward, info_message = self.sim.dispatch( call_id=action.call_id, unit_id=action.unit_id, hospital_id=action.hospital_id, ) self.sim.advance_time(1) elif action_type == "classify": if not action.call_id or action.severity is None: error = "classify requires call_id and severity (1-5)" else: step_reward, info_message = self.sim.classify( call_id=action.call_id, severity=int(action.severity) ) self.sim.advance_time(1) elif action_type == "callback": if not action.call_id: error = "callback requires call_id" else: step_reward, info_message = self.sim.callback( call_id=action.call_id, question=action.message or "" ) self.sim.advance_time(1) elif action_type == "wait": minutes = int(action.minutes or 1) minutes = max(1, min(minutes, self.sim.config.max_wait_step_minutes)) pending_before = len(self.sim.get_pending_calls()) self.sim.advance_time(minutes) step_reward = -0.005 * minutes * pending_before info_message = f"waited {minutes} minute(s)" elif action_type == "view": step_reward = 0.0 info_message = "view (no time cost)" else: step_reward = -0.05 error = f"unknown action_type: {action_type!r}" except Exception as exc: # pragma: no cover - defensive error = f"{type(exc).__name__}: {exc}" step_reward = -0.05 self._cumulative_step_reward += step_reward self._last_step_reward = step_reward return self._build_observation(info_message=info_message, error=error) @property def state(self) -> DispatchPulseState: if self.sim is None: return DispatchPulseState( episode_id=self._episode_id, step_count=self._step_count, task_name=self.task_name, ) return DispatchPulseState( episode_id=self._episode_id, step_count=self._step_count, current_time=self.sim.current_time, episode_done=self.sim.episode_done, total_calls=self.sim.total_calls(), calls_dispatched=len(self.sim.dispatches), calls_completed=len(self.sim.completed_calls), calls_timed_out=len(self.sim.timed_out_calls), calls_pending=len(self.sim.get_pending_calls()), units_available=len(self.sim.get_available_units()), running_reward=self._cumulative_step_reward, task_name=self.task_name, ) # ------------------------------------------------------------------ # Helpers # ------------------------------------------------------------------ def _build_observation( self, info_message: Optional[str] = None, error: Optional[str] = None, ) -> DispatchPulseObservation: if self.sim is None: return DispatchPulseObservation( done=True, reward=0.0, text="ERROR: environment not initialised. Call reset first.", last_action_error="not_initialised", ) text = render_dispatch_center(self.sim, self.task_name) done = bool(self.sim.episode_done) if done: final = grade_simulation(self.sim) reward_value: float = float(final.total) metadata = { "final_reward": final.model_dump(), "task": self.task_name, "cumulative_step_reward": float(self._cumulative_step_reward), } else: # Report the per-step delta, not the running cumulative. The # cumulative is still available via state() and metadata, but the # observation's reward field matches the standard Gym/OpenEnv # semantics of "reward for this step only". reward_value = float(self._last_step_reward) metadata = { "task": self.task_name, "cumulative_step_reward": float(self._cumulative_step_reward), } if info_message: metadata["info"] = info_message if error: metadata["error"] = error return DispatchPulseObservation( done=done, reward=reward_value, text=text, current_time=self.sim.current_time, time_limit=self.sim.config.time_limit_minutes, calls_pending=len(self.sim.get_pending_calls()), units_available=len(self.sim.get_available_units()), calls_completed=len(self.sim.completed_calls), calls_timed_out=len(self.sim.timed_out_calls), total_calls=self.sim.total_calls(), last_action_error=error, info_message=info_message, metadata=metadata, ) def _parse_text_action(text: str): """Parse a text action like ``dispatch CALL-001 ALS-1 H1`` into fields. Returns ``(action_type, kwargs_dict)`` or None on parse failure. """ parts = text.strip().split(maxsplit=4) if not parts: return None head = parts[0].lower() if head == "dispatch" and len(parts) >= 3: out = {"call_id": parts[1], "unit_id": parts[2]} if len(parts) >= 4 and parts[3]: out["hospital_id"] = parts[3] return "dispatch", out if head == "classify" and len(parts) >= 3: try: sev = int(parts[2]) except ValueError: return None return "classify", {"call_id": parts[1], "severity": sev} if head == "callback" and len(parts) >= 2: return "callback", { "call_id": parts[1], "message": " ".join(parts[2:]) if len(parts) > 2 else "", } if head == "wait": try: mins = int(parts[1]) if len(parts) > 1 else 1 except ValueError: mins = 1 return "wait", {"minutes": mins} if head in ("view", "view_dispatch_center"): return "view", {} return None