"""Task registry for DispatchPulse. This module is the canonical source of truth for the three graded tasks that DispatchPulse ships. Each task is declared as a frozen ``TaskDefinition`` dataclass and registered in the module-level ``TASKS`` dict. This mirrors the pattern used by other passing Meta PyTorch OpenEnv Hackathon submissions (e.g. Calendar Scheduling). Every task in ``TASKS`` is graded by the module-level ``grade_submission`` function below, which returns a deterministic score in [0.0, 1.0]. There are exactly three tasks: ``easy``, ``medium``, ``hard``. """ from __future__ import annotations from dataclasses import dataclass from typing import Dict, List, Literal, Optional, Tuple from grader import grade_simulation from reward import calculate_episode_reward from scenario_loader import load_scenario from simulation import DispatchSimulation # --------------------------------------------------------------------------- # Task dataclasses # --------------------------------------------------------------------------- @dataclass(frozen=True) class TaskDefinition: """A single graded task. Attributes: task_id: Stable identifier used by the server, the grader, and the inference script. Matches the filename in ``tasks/``. name: Human-readable name for the task. difficulty: One of ``easy``, ``medium``, ``hard``. description: Multi-sentence description explaining what the agent has to do and what makes the task hard. max_steps: Upper bound on the number of agent actions per episode (matches the scenario's ``time_limit_minutes``). time_limit_minutes: Wall-clock time limit for the simulated episode. num_calls: Total number of emergency calls scheduled for the episode. num_units: Number of emergency units available to dispatch. num_hospitals: Number of hospitals on the map. caller_inaccuracy: Fraction of callers who misreport the emergency type or severity (0.0 = always accurate, 1.0 = always wrong). """ task_id: str name: str difficulty: Literal["easy", "medium", "hard"] description: str max_steps: int time_limit_minutes: int num_calls: int num_units: int num_hospitals: int caller_inaccuracy: float # --------------------------------------------------------------------------- # Task registry — populated at import time by introspecting the YAML files. # --------------------------------------------------------------------------- def _build_task(task_id: str, name: str, difficulty: str, description: str) -> TaskDefinition: """Build a TaskDefinition by loading the YAML scenario for task_id.""" scenario = load_scenario(task_id) world_cfg = scenario.get("world_config", {}) or {} return TaskDefinition( task_id=task_id, name=name, difficulty=difficulty, # type: ignore[arg-type] description=description.strip(), max_steps=int(world_cfg.get("time_limit_minutes", 30)), time_limit_minutes=int(world_cfg.get("time_limit_minutes", 30)), num_calls=len(scenario.get("calls", [])), num_units=len(scenario.get("units", [])), num_hospitals=len(scenario.get("hospitals", [])), caller_inaccuracy=float(scenario.get("caller_inaccuracy", 0.0)), ) TASKS: Dict[str, TaskDefinition] = { "easy": _build_task( task_id="easy", name="Routine Urban Shift", difficulty="easy", description=( "Five emergency calls arrive over 30 minutes. The dispatcher " "has four units (ALS ambulance, BLS ambulance, fire engine, " "police) and one well-equipped hospital. Callers report their " "emergency accurately. Optimal play — dispatching the right " "unit type to the right call in the right order — scores 0.85 " "or higher. A silent 'do nothing' agent scores 0." ), ), "medium": _build_task( task_id="medium", name="Urban Mass Casualty", difficulty="medium", description=( "Fifteen emergency calls over 45 minutes including a mass " "casualty bus accident at minute 12 that spawns multiple " "severity-1 trauma calls simultaneously. The dispatcher has " "six units and two hospitals. 20% of callers misreport the " "emergency type due to panic. The core challenge: ALS " "conservation — if you spend your only ALS ambulance on a " "minor injury, the cardiac arrest arriving 4 minutes later " "has no good unit to send." ), ), "hard": _build_task( task_id="hard", name="Earthquake Response", difficulty="hard", description=( "An earthquake triggers 30 emergency calls over 60 minutes. " "The dispatcher has eight units and three hospitals — but one " "hospital is on diversion and another is near bed capacity. " "35% of callers misreport due to panic. Hospital-routing " "decisions meaningfully affect outcome: cardiac patients " "routed to a hospital without a cardiac unit survive less " "often. This is the full difficulty tier — even a good agent " "will score in the 0.40-0.55 range because the scenario is " "deliberately resource-scarce." ), ), } # --------------------------------------------------------------------------- # Public API — the symbols the validator looks for # --------------------------------------------------------------------------- def list_tasks() -> List[TaskDefinition]: """Return all registered tasks as a list. The validator calls this (or inspects the ``TASKS`` dict directly) to count how many graded tasks the environment ships with. We return them in difficulty order: easy, medium, hard. """ return [TASKS["easy"], TASKS["medium"], TASKS["hard"]] def get_task(task_id: str) -> TaskDefinition: """Look up a single task by id. Raises KeyError if unknown.""" if task_id not in TASKS: raise KeyError( f"unknown task_id '{task_id}'. Known tasks: {', '.join(TASKS.keys())}" ) return TASKS[task_id] def grade_submission( task_id: str, actions: Optional[List[Dict]] = None, seed: int = 42, ) -> Tuple[float, Dict]: """Grade a submission for a task. Two modes: 1. **Silent run** — when ``actions`` is None, runs the task to time limit with no agent decisions. All calls time out. Used as a sanity check that the grader and task both load correctly. Returns score 0.0. 2. **Replay mode** — when ``actions`` is a list of action dicts like ``[{"action_type": "dispatch", "call_id": "CALL-001", "unit_id": "ALS-1"}, ...]``, the grader replays them through a fresh simulation seeded with ``seed`` and returns the final score. Args: task_id: One of ``easy``, ``medium``, ``hard``. actions: Optional list of action dicts to replay. seed: Random seed for the simulation (default 42 for reproducibility). Returns: A tuple ``(score, details_dict)`` where ``score`` is a float in [0.0, 1.0] and ``details_dict`` has the full reward breakdown plus call counts. """ if task_id not in TASKS: raise KeyError( f"unknown task_id '{task_id}'. Known tasks: {', '.join(TASKS.keys())}" ) scenario = load_scenario(task_id) sim = DispatchSimulation(scenario, seed=seed) if actions: _replay_actions(sim, actions) # Always fast-forward to episode end so the reward is final. while not sim.episode_done: sim.advance_time(sim.config.time_limit_minutes) reward = calculate_episode_reward( sim.completed_calls, sim.timed_out_calls, sim.total_calls(), sim.dispatches, ) details = { "task_id": task_id, "score": reward.total, "passed": reward.total >= 0.20, "survival_score": reward.survival_score, "efficiency_score": reward.efficiency_score, "triage_accuracy": reward.triage_accuracy, "penalty": reward.penalty, "details": reward.details, "completed_calls": len(sim.completed_calls), "timed_out_calls": len(sim.timed_out_calls), "total_calls": sim.total_calls(), } return reward.total, details def _replay_actions(sim: DispatchSimulation, actions: List[Dict]) -> None: """Replay a scripted action list through a fresh simulation.""" max_steps = 500 for idx, act in enumerate(actions): if idx >= max_steps or sim.episode_done: break atype = (act.get("action_type") or "").strip().lower() if atype == "dispatch": sim.dispatch( call_id=str(act.get("call_id", "")), unit_id=str(act.get("unit_id", "")), hospital_id=act.get("hospital_id"), ) sim.advance_time(1) elif atype == "classify": try: sev = int(act.get("severity", 3)) except (TypeError, ValueError): sev = 3 sim.classify(str(act.get("call_id", "")), sev) sim.advance_time(1) elif atype == "callback": sim.callback( str(act.get("call_id", "")), str(act.get("message", act.get("question", ""))), ) sim.advance_time(1) elif atype == "wait": try: mins = int(act.get("minutes", 1)) except (TypeError, ValueError): mins = 1 sim.advance_time(max(1, min(mins, sim.config.max_wait_step_minutes))) elif atype == "view": continue else: sim.advance_time(1)