dispatchpulse / server /environment.py
Arun-Sanjay's picture
Trim task_definitions.py to match minimal passing pattern
82713c7
"""DispatchPulse OpenEnv environment.
Inherits from ``openenv.core.env_server.interfaces.Environment`` and implements
the standard ``reset() / step() / state`` Gym-style API. The wire types
``DispatchPulseAction`` and ``DispatchPulseObservation`` are defined in
``models.py`` and inherit from the OpenEnv ``Action`` / ``Observation`` base
classes.
This is a thin wrapper around the in-process ``DispatchSimulation`` engine.
"""
from __future__ import annotations
import os
import sys
from typing import Any, Optional
from uuid import uuid4
# Make project root importable when running as ``server.app:app`` from /app/env
_PKG_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _PKG_ROOT not in sys.path:
sys.path.insert(0, _PKG_ROOT)
from openenv.core.env_server.interfaces import Environment
from grader import grade_simulation
from models import DispatchPulseAction, DispatchPulseObservation, DispatchPulseState
from scenario_loader import VALID_TASKS, load_scenario
from simulation import DispatchSimulation
from text_view import render_dispatch_center
# Re-export the task registry and grader symbols at module level so static
# validators that scan server/environment.py for tasks can find them here
# (same pattern as the SQL Repair passing submission where both TASKS and
# grade_submission are accessible from server/environment.py).
from task_definitions import ( # noqa: F401,E402
TASKS,
TaskDefinition,
grade_submission,
get_task,
list_tasks,
)
DEFAULT_TASK = "easy"
DEFAULT_SEED = 42
class DispatchPulseEnvironment(
Environment[DispatchPulseAction, DispatchPulseObservation, DispatchPulseState]
):
"""Emergency-dispatch OpenEnv environment.
Each call to ``reset()`` starts a fresh episode for the chosen task.
Calls to ``step(action)`` advance the simulation by one decision turn
(which usually equals 1 minute of simulation time).
Tasks: ``easy``, ``medium``, ``hard``.
"""
SUPPORTS_CONCURRENT_SESSIONS: bool = True
def __init__(self) -> None:
super().__init__()
self.sim: Optional[DispatchSimulation] = None
self.task_name: str = DEFAULT_TASK
self.seed: int = DEFAULT_SEED
self._episode_id: str = str(uuid4())
self._step_count: int = 0
self._cumulative_step_reward: float = 0.0
self._last_step_reward: float = 0.0
# Bootstrap so single-shot HTTP /step still works without an explicit reset
self._bootstrap()
def _bootstrap(self) -> None:
try:
scenario = load_scenario(DEFAULT_TASK)
self.sim = DispatchSimulation(scenario, seed=DEFAULT_SEED)
self.task_name = DEFAULT_TASK
self.seed = DEFAULT_SEED
self._cumulative_step_reward = 0.0
self._last_step_reward = 0.0
self._step_count = 0
except Exception as exc: # pragma: no cover
print(f"[DispatchPulseEnvironment] bootstrap failed: {exc}", file=sys.stderr, flush=True)
self.sim = None
# ------------------------------------------------------------------
# Environment API
# ------------------------------------------------------------------
def reset(
self,
seed: Optional[int] = None,
episode_id: Optional[str] = None,
task_name: Optional[str] = None,
**kwargs: Any,
) -> DispatchPulseObservation:
chosen_task = (task_name or DEFAULT_TASK).strip().lower()
if chosen_task not in VALID_TASKS:
chosen_task = DEFAULT_TASK
chosen_seed = int(seed) if seed is not None else DEFAULT_SEED
scenario = load_scenario(chosen_task)
self.sim = DispatchSimulation(scenario, seed=chosen_seed)
self.task_name = chosen_task
self.seed = chosen_seed
self._episode_id = episode_id or str(uuid4())
self._step_count = 0
self._cumulative_step_reward = 0.0
self._last_step_reward = 0.0
return self._build_observation(info_message="ready", error=None)
def step(
self,
action: DispatchPulseAction,
timeout_s: Optional[float] = None,
**kwargs: Any,
) -> DispatchPulseObservation:
if self.sim is None:
self._bootstrap()
if self.sim is None:
return self._build_observation(error="environment not initialised")
if self.sim.episode_done:
return self._build_observation(error="episode already done")
self._step_count += 1
action_type = (action.action_type or "").strip().lower()
text_action = (action.text or "").strip()
# Allow text-only actions: parse the text into structured fields
if not action_type and text_action:
parsed = _parse_text_action(text_action)
if parsed is not None:
action_type, fields = parsed
for key, value in fields.items():
if getattr(action, key, None) in (None, ""):
setattr(action, key, value)
step_reward = 0.0
info_message: Optional[str] = None
error: Optional[str] = None
try:
if action_type == "dispatch":
if not action.call_id or not action.unit_id:
error = "dispatch requires call_id and unit_id"
else:
step_reward, info_message = self.sim.dispatch(
call_id=action.call_id,
unit_id=action.unit_id,
hospital_id=action.hospital_id,
)
self.sim.advance_time(1)
elif action_type == "classify":
if not action.call_id or action.severity is None:
error = "classify requires call_id and severity (1-5)"
else:
step_reward, info_message = self.sim.classify(
call_id=action.call_id, severity=int(action.severity)
)
self.sim.advance_time(1)
elif action_type == "callback":
if not action.call_id:
error = "callback requires call_id"
else:
step_reward, info_message = self.sim.callback(
call_id=action.call_id, question=action.message or ""
)
self.sim.advance_time(1)
elif action_type == "wait":
minutes = int(action.minutes or 1)
minutes = max(1, min(minutes, self.sim.config.max_wait_step_minutes))
pending_before = len(self.sim.get_pending_calls())
self.sim.advance_time(minutes)
step_reward = -0.005 * minutes * pending_before
info_message = f"waited {minutes} minute(s)"
elif action_type == "view":
step_reward = 0.0
info_message = "view (no time cost)"
else:
step_reward = -0.05
error = f"unknown action_type: {action_type!r}"
except Exception as exc: # pragma: no cover - defensive
error = f"{type(exc).__name__}: {exc}"
step_reward = -0.05
self._cumulative_step_reward += step_reward
self._last_step_reward = step_reward
return self._build_observation(info_message=info_message, error=error)
@property
def state(self) -> DispatchPulseState:
if self.sim is None:
return DispatchPulseState(
episode_id=self._episode_id,
step_count=self._step_count,
task_name=self.task_name,
)
return DispatchPulseState(
episode_id=self._episode_id,
step_count=self._step_count,
current_time=self.sim.current_time,
episode_done=self.sim.episode_done,
total_calls=self.sim.total_calls(),
calls_dispatched=len(self.sim.dispatches),
calls_completed=len(self.sim.completed_calls),
calls_timed_out=len(self.sim.timed_out_calls),
calls_pending=len(self.sim.get_pending_calls()),
units_available=len(self.sim.get_available_units()),
running_reward=self._cumulative_step_reward,
task_name=self.task_name,
)
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def _build_observation(
self,
info_message: Optional[str] = None,
error: Optional[str] = None,
) -> DispatchPulseObservation:
if self.sim is None:
return DispatchPulseObservation(
done=True,
reward=0.0,
text="ERROR: environment not initialised. Call reset first.",
last_action_error="not_initialised",
)
text = render_dispatch_center(self.sim, self.task_name)
done = bool(self.sim.episode_done)
if done:
final = grade_simulation(self.sim)
reward_value: float = float(final.total)
metadata = {
"final_reward": final.model_dump(),
"task": self.task_name,
"cumulative_step_reward": float(self._cumulative_step_reward),
}
else:
# Report the per-step delta, not the running cumulative. The
# cumulative is still available via state() and metadata, but the
# observation's reward field matches the standard Gym/OpenEnv
# semantics of "reward for this step only".
reward_value = float(self._last_step_reward)
metadata = {
"task": self.task_name,
"cumulative_step_reward": float(self._cumulative_step_reward),
}
if info_message:
metadata["info"] = info_message
if error:
metadata["error"] = error
return DispatchPulseObservation(
done=done,
reward=reward_value,
text=text,
current_time=self.sim.current_time,
time_limit=self.sim.config.time_limit_minutes,
calls_pending=len(self.sim.get_pending_calls()),
units_available=len(self.sim.get_available_units()),
calls_completed=len(self.sim.completed_calls),
calls_timed_out=len(self.sim.timed_out_calls),
total_calls=self.sim.total_calls(),
last_action_error=error,
info_message=info_message,
metadata=metadata,
)
def _parse_text_action(text: str):
"""Parse a text action like ``dispatch CALL-001 ALS-1 H1`` into fields.
Returns ``(action_type, kwargs_dict)`` or None on parse failure.
"""
parts = text.strip().split(maxsplit=4)
if not parts:
return None
head = parts[0].lower()
if head == "dispatch" and len(parts) >= 3:
out = {"call_id": parts[1], "unit_id": parts[2]}
if len(parts) >= 4 and parts[3]:
out["hospital_id"] = parts[3]
return "dispatch", out
if head == "classify" and len(parts) >= 3:
try:
sev = int(parts[2])
except ValueError:
return None
return "classify", {"call_id": parts[1], "severity": sev}
if head == "callback" and len(parts) >= 2:
return "callback", {
"call_id": parts[1],
"message": " ".join(parts[2:]) if len(parts) > 2 else "",
}
if head == "wait":
try:
mins = int(parts[1]) if len(parts) > 1 else 1
except ValueError:
mins = 1
return "wait", {"minutes": mins}
if head in ("view", "view_dispatch_center"):
return "view", {}
return None