# dispatchpulse / task_definitions.py
# Author: Arun-Sanjay
# Commit 82713c7: Trim task_definitions.py to match minimal passing pattern
"""Task registry for DispatchPulse.
This module is the canonical source of truth for the three graded tasks that
DispatchPulse ships. Each task is declared as a frozen ``TaskDefinition``
dataclass and registered in the module-level ``TASKS`` dict. This mirrors the
pattern used by other passing Meta PyTorch OpenEnv Hackathon submissions
(e.g. Calendar Scheduling).
Every task in ``TASKS`` is graded by the module-level ``grade_submission``
function below, which returns a deterministic score in [0.0, 1.0].
There are exactly three tasks: ``easy``, ``medium``, ``hard``.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List, Literal, Optional, Tuple
from grader import grade_simulation
from reward import calculate_episode_reward
from scenario_loader import load_scenario
from simulation import DispatchSimulation
# ---------------------------------------------------------------------------
# Task dataclasses
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class TaskDefinition:
    """Immutable descriptor for one graded DispatchPulse task.

    Instances are frozen: fields cannot be reassigned after construction.
    The field order below is part of the constructor's positional interface.
    """

    # Stable identifier used by the server, the grader, and the inference
    # script. Matches the filename in ``tasks/``.
    task_id: str

    # Human-readable name for the task.
    name: str

    # Difficulty tier of the task.
    difficulty: Literal["easy", "medium", "hard"]

    # Multi-sentence explanation of what the agent has to do and what makes
    # the task hard.
    description: str

    # Upper bound on the number of agent actions per episode (matches the
    # scenario's ``time_limit_minutes``).
    max_steps: int

    # Time limit, in minutes, for the simulated episode.
    time_limit_minutes: int

    # Total number of emergency calls scheduled for the episode.
    num_calls: int

    # Number of emergency units available to dispatch.
    num_units: int

    # Number of hospitals on the map.
    num_hospitals: int

    # Fraction of callers who misreport the emergency type or severity
    # (0.0 = always accurate, 1.0 = always wrong).
    caller_inaccuracy: float
# ---------------------------------------------------------------------------
# Task registry — populated at import time by introspecting the YAML files.
# ---------------------------------------------------------------------------
def _build_task(
    task_id: str,
    name: str,
    difficulty: Literal["easy", "medium", "hard"],
    description: str,
) -> TaskDefinition:
    """Build a ``TaskDefinition`` from the scenario loaded for ``task_id``.

    The static metadata (name, difficulty, description) comes from the
    arguments; the numeric fields are derived from the scenario returned by
    ``load_scenario(task_id)``.

    Args:
        task_id: Scenario identifier passed to ``load_scenario``.
        name: Human-readable task name.
        difficulty: Difficulty tier stored on the task.
        description: Task description; surrounding whitespace is stripped.

    Returns:
        The populated, frozen ``TaskDefinition``.
    """
    scenario = load_scenario(task_id)

    # A key that is present but null comes back as None rather than being
    # absent, so every section is normalised with ``or`` instead of relying
    # on the ``.get`` default alone.
    world_cfg = scenario.get("world_config") or {}

    # One parsed value feeds both fields: ``max_steps`` mirrors the time limit.
    time_limit = int(world_cfg.get("time_limit_minutes", 30))

    return TaskDefinition(
        task_id=task_id,
        name=name,
        difficulty=difficulty,
        description=description.strip(),
        max_steps=time_limit,
        time_limit_minutes=time_limit,
        num_calls=len(scenario.get("calls") or []),
        num_units=len(scenario.get("units") or []),
        num_hospitals=len(scenario.get("hospitals") or []),
        caller_inaccuracy=float(scenario.get("caller_inaccuracy") or 0.0),
    )
# Registry of the three graded tasks, keyed by task id. Each entry is built
# at import time by ``_build_task``, which loads the matching scenario to
# fill in the numeric fields.
TASKS: Dict[str, TaskDefinition] = {
    "easy": _build_task(
        task_id="easy",
        name="Routine Urban Shift",
        difficulty="easy",
        description=(
            "Five emergency calls arrive over 30 minutes. The dispatcher "
            "has four units (ALS ambulance, BLS ambulance, fire engine, "
            "police) and one well-equipped hospital. Callers report their "
            "emergency accurately. Optimal play — dispatching the right "
            "unit type to the right call in the right order — scores 0.85 "
            "or higher. A silent 'do nothing' agent scores 0."
        ),
    ),
    "medium": _build_task(
        task_id="medium",
        name="Urban Mass Casualty",
        difficulty="medium",
        description=(
            "Fifteen emergency calls over 45 minutes including a mass "
            "casualty bus accident at minute 12 that spawns multiple "
            "severity-1 trauma calls simultaneously. The dispatcher has "
            "six units and two hospitals. 20% of callers misreport the "
            "emergency type due to panic. The core challenge: ALS "
            "conservation — if you spend your only ALS ambulance on a "
            "minor injury, the cardiac arrest arriving 4 minutes later "
            "has no good unit to send."
        ),
    ),
    "hard": _build_task(
        task_id="hard",
        name="Earthquake Response",
        difficulty="hard",
        description=(
            "An earthquake triggers 30 emergency calls over 60 minutes. "
            "The dispatcher has eight units and three hospitals — but one "
            "hospital is on diversion and another is near bed capacity. "
            "35% of callers misreport due to panic. Hospital-routing "
            "decisions meaningfully affect outcome: cardiac patients "
            "routed to a hospital without a cardiac unit survive less "
            "often. This is the full difficulty tier — even a good agent "
            "will score in the 0.40-0.55 range because the scenario is "
            "deliberately resource-scarce."
        ),
    ),
}
# ---------------------------------------------------------------------------
# Public API — the symbols the validator looks for
# ---------------------------------------------------------------------------
def list_tasks() -> List[TaskDefinition]:
    """Return every registered task, ordered easy -> medium -> hard.

    The validator uses this (or reads ``TASKS`` directly) to count the graded
    tasks shipped with the environment.
    """
    ordered_ids = ("easy", "medium", "hard")
    return [TASKS[tid] for tid in ordered_ids]
def get_task(task_id: str) -> TaskDefinition:
    """Return the task registered under ``task_id``.

    Raises:
        KeyError: If ``task_id`` is not a key of ``TASKS``; the message lists
            the known task ids.
    """
    if task_id in TASKS:
        return TASKS[task_id]

    known_ids = ", ".join(TASKS.keys())
    raise KeyError(f"unknown task_id '{task_id}'. Known tasks: {known_ids}")
def grade_submission(
    task_id: str,
    actions: Optional[List[Dict]] = None,
    seed: int = 42,
) -> Tuple[float, Dict]:
    """Run one episode of ``task_id`` and return its score with a breakdown.

    Two modes, selected by ``actions``:

    * Silent run: ``actions`` is ``None`` (or empty). No agent decisions are
      made and the simulation is simply advanced to its end. Intended as a
      sanity check that the task and grader load; the module documents this
      as scoring 0.0 because every call times out.
    * Replay mode: ``actions`` is a list of action dicts such as
      ``{"action_type": "dispatch", "call_id": "CALL-001", "unit_id": "ALS-1"}``.
      They are replayed through a fresh simulation seeded with ``seed``.

    In both modes the simulation is advanced until ``episode_done`` before
    the reward is computed, so the reward reflects the finished episode.

    Args:
        task_id: Key of a registered task (``easy``, ``medium``, ``hard``).
        actions: Optional list of action dicts to replay.
        seed: Seed handed to the simulation (default 42).

    Returns:
        ``(score, details)`` where ``score`` is ``reward.total`` and
        ``details`` holds the reward components, a ``passed`` flag
        (``score >= 0.20``), and the completed / timed-out / total call counts.

    Raises:
        KeyError: If ``task_id`` is not registered in ``TASKS``.
    """
    if task_id not in TASKS:
        known_ids = ", ".join(TASKS.keys())
        raise KeyError(f"unknown task_id '{task_id}'. Known tasks: {known_ids}")

    sim = DispatchSimulation(load_scenario(task_id), seed=seed)

    if actions:
        _replay_actions(sim, actions)

    # Whatever the replay did, run the clock out so the reward is final.
    while not sim.episode_done:
        sim.advance_time(sim.config.time_limit_minutes)

    reward = calculate_episode_reward(
        sim.completed_calls,
        sim.timed_out_calls,
        sim.total_calls(),
        sim.dispatches,
    )
    score = reward.total

    # Assembled in three steps; insertion order matches the documented layout.
    details: Dict = {
        "task_id": task_id,
        "score": score,
        "passed": score >= 0.20,
    }
    for component in (
        "survival_score",
        "efficiency_score",
        "triage_accuracy",
        "penalty",
        "details",
    ):
        details[component] = getattr(reward, component)
    details["completed_calls"] = len(sim.completed_calls)
    details["timed_out_calls"] = len(sim.timed_out_calls)
    details["total_calls"] = sim.total_calls()

    return score, details
def _coerce_int(value: object, default: int) -> int:
    """Return ``int(value)``, or ``default`` when ``value`` cannot be converted."""
    try:
        return int(value)  # type: ignore[call-overload]
    except (TypeError, ValueError, OverflowError):
        # OverflowError is what int() raises for float("inf"), which a JSON
        # submission can contain as ``Infinity``.
        return default


def _replay_actions(sim: "DispatchSimulation", actions: List[Dict]) -> None:
    """Replay a scripted action list through ``sim``.

    Replay stops at the end of ``actions``, after 500 entries, or as soon as
    ``sim.episode_done`` is true, whichever comes first.

    Handling per ``action_type`` (case-insensitive, surrounding whitespace
    ignored):

    * ``dispatch``: ``sim.dispatch(call_id, unit_id, hospital_id)``, then
      advance 1 minute.
    * ``classify``: ``sim.classify(call_id, severity)``; an unparsable
      severity falls back to 3. Then advance 1 minute.
    * ``callback``: ``sim.callback(call_id, message)``; ``question`` is
      accepted as an alias for ``message``. Then advance 1 minute.
    * ``wait``: advance ``minutes``, clamped to
      ``[1, sim.config.max_wait_step_minutes]``; unparsable values count as 1.
    * ``view``: no effect and no time advance.
    * anything else: advance 1 minute. This includes entries that are not
      dict-like and entries whose ``action_type`` is missing or not a string,
      so one malformed entry cannot crash grading.

    Args:
        sim: The simulation to drive; mutated in place.
        actions: Action dicts in the order they should be applied.
    """
    max_steps = 500
    for idx, act in enumerate(actions):
        if idx >= max_steps or sim.episode_done:
            break

        # Duck-typed lookup: entries without a ``.get`` (None, str, int, ...)
        # are malformed and fall through to the unknown-action branch.
        getter = getattr(act, "get", None)
        raw_type = getter("action_type") if callable(getter) else None
        atype = raw_type.strip().lower() if isinstance(raw_type, str) else ""

        if atype == "dispatch":
            sim.dispatch(
                call_id=str(act.get("call_id", "")),
                unit_id=str(act.get("unit_id", "")),
                hospital_id=act.get("hospital_id"),
            )
            sim.advance_time(1)
        elif atype == "classify":
            severity = _coerce_int(act.get("severity", 3), 3)
            sim.classify(str(act.get("call_id", "")), severity)
            sim.advance_time(1)
        elif atype == "callback":
            sim.callback(
                str(act.get("call_id", "")),
                str(act.get("message", act.get("question", ""))),
            )
            sim.advance_time(1)
        elif atype == "wait":
            minutes = _coerce_int(act.get("minutes", 1), 1)
            sim.advance_time(max(1, min(minutes, sim.config.max_wait_step_minutes)))
        elif atype == "view":
            # Viewing state costs no simulated time.
            continue
        else:
            sim.advance_time(1)