Spaces:
Sleeping
Sleeping
| """Task registry for DispatchPulse. | |
| This module is the canonical source of truth for the three graded tasks that | |
| DispatchPulse ships. Each task is declared as a frozen ``TaskDefinition`` | |
| dataclass and registered in the module-level ``TASKS`` dict. This mirrors the | |
| pattern used by other passing Meta PyTorch OpenEnv Hackathon submissions | |
| (e.g. Calendar Scheduling). | |
| Every task in ``TASKS`` is graded by the module-level ``grade_submission`` | |
| function below, which returns a deterministic score in [0.0, 1.0]. | |
| There are exactly three tasks: ``easy``, ``medium``, ``hard``. | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from typing import Dict, List, Literal, Optional, Tuple | |
| from grader import grade_simulation | |
| from reward import calculate_episode_reward | |
| from scenario_loader import load_scenario | |
| from simulation import DispatchSimulation | |
| # --------------------------------------------------------------------------- | |
| # Task dataclasses | |
| # --------------------------------------------------------------------------- | |
| class TaskDefinition: | |
| """A single graded task. | |
| Attributes: | |
| task_id: Stable identifier used by the server, the grader, and the | |
| inference script. Matches the filename in ``tasks/``. | |
| name: Human-readable name for the task. | |
| difficulty: One of ``easy``, ``medium``, ``hard``. | |
| description: Multi-sentence description explaining what the agent has | |
| to do and what makes the task hard. | |
| max_steps: Upper bound on the number of agent actions per episode | |
| (matches the scenario's ``time_limit_minutes``). | |
| time_limit_minutes: Wall-clock time limit for the simulated episode. | |
| num_calls: Total number of emergency calls scheduled for the episode. | |
| num_units: Number of emergency units available to dispatch. | |
| num_hospitals: Number of hospitals on the map. | |
| caller_inaccuracy: Fraction of callers who misreport the emergency | |
| type or severity (0.0 = always accurate, 1.0 = always wrong). | |
| """ | |
| task_id: str | |
| name: str | |
| difficulty: Literal["easy", "medium", "hard"] | |
| description: str | |
| max_steps: int | |
| time_limit_minutes: int | |
| num_calls: int | |
| num_units: int | |
| num_hospitals: int | |
| caller_inaccuracy: float | |
| # --------------------------------------------------------------------------- | |
| # Task registry β populated at import time by introspecting the YAML files. | |
| # --------------------------------------------------------------------------- | |
| def _build_task(task_id: str, name: str, difficulty: str, description: str) -> TaskDefinition: | |
| """Build a TaskDefinition by loading the YAML scenario for task_id.""" | |
| scenario = load_scenario(task_id) | |
| world_cfg = scenario.get("world_config", {}) or {} | |
| return TaskDefinition( | |
| task_id=task_id, | |
| name=name, | |
| difficulty=difficulty, # type: ignore[arg-type] | |
| description=description.strip(), | |
| max_steps=int(world_cfg.get("time_limit_minutes", 30)), | |
| time_limit_minutes=int(world_cfg.get("time_limit_minutes", 30)), | |
| num_calls=len(scenario.get("calls", [])), | |
| num_units=len(scenario.get("units", [])), | |
| num_hospitals=len(scenario.get("hospitals", [])), | |
| caller_inaccuracy=float(scenario.get("caller_inaccuracy", 0.0)), | |
| ) | |
| TASKS: Dict[str, TaskDefinition] = { | |
| "easy": _build_task( | |
| task_id="easy", | |
| name="Routine Urban Shift", | |
| difficulty="easy", | |
| description=( | |
| "Five emergency calls arrive over 30 minutes. The dispatcher " | |
| "has four units (ALS ambulance, BLS ambulance, fire engine, " | |
| "police) and one well-equipped hospital. Callers report their " | |
| "emergency accurately. Optimal play β dispatching the right " | |
| "unit type to the right call in the right order β scores 0.85 " | |
| "or higher. A silent 'do nothing' agent scores 0." | |
| ), | |
| ), | |
| "medium": _build_task( | |
| task_id="medium", | |
| name="Urban Mass Casualty", | |
| difficulty="medium", | |
| description=( | |
| "Fifteen emergency calls over 45 minutes including a mass " | |
| "casualty bus accident at minute 12 that spawns multiple " | |
| "severity-1 trauma calls simultaneously. The dispatcher has " | |
| "six units and two hospitals. 20% of callers misreport the " | |
| "emergency type due to panic. The core challenge: ALS " | |
| "conservation β if you spend your only ALS ambulance on a " | |
| "minor injury, the cardiac arrest arriving 4 minutes later " | |
| "has no good unit to send." | |
| ), | |
| ), | |
| "hard": _build_task( | |
| task_id="hard", | |
| name="Earthquake Response", | |
| difficulty="hard", | |
| description=( | |
| "An earthquake triggers 30 emergency calls over 60 minutes. " | |
| "The dispatcher has eight units and three hospitals β but one " | |
| "hospital is on diversion and another is near bed capacity. " | |
| "35% of callers misreport due to panic. Hospital-routing " | |
| "decisions meaningfully affect outcome: cardiac patients " | |
| "routed to a hospital without a cardiac unit survive less " | |
| "often. This is the full difficulty tier β even a good agent " | |
| "will score in the 0.40-0.55 range because the scenario is " | |
| "deliberately resource-scarce." | |
| ), | |
| ), | |
| } | |
| # --------------------------------------------------------------------------- | |
| # Public API β the symbols the validator looks for | |
| # --------------------------------------------------------------------------- | |
| def list_tasks() -> List[TaskDefinition]: | |
| """Return all registered tasks as a list. | |
| The validator calls this (or inspects the ``TASKS`` dict directly) to | |
| count how many graded tasks the environment ships with. We return them | |
| in difficulty order: easy, medium, hard. | |
| """ | |
| return [TASKS["easy"], TASKS["medium"], TASKS["hard"]] | |
| def get_task(task_id: str) -> TaskDefinition: | |
| """Look up a single task by id. Raises KeyError if unknown.""" | |
| if task_id not in TASKS: | |
| raise KeyError( | |
| f"unknown task_id '{task_id}'. Known tasks: {', '.join(TASKS.keys())}" | |
| ) | |
| return TASKS[task_id] | |
| def grade_submission( | |
| task_id: str, | |
| actions: Optional[List[Dict]] = None, | |
| seed: int = 42, | |
| ) -> Tuple[float, Dict]: | |
| """Grade a submission for a task. | |
| Two modes: | |
| 1. **Silent run** β when ``actions`` is None, runs the task to time | |
| limit with no agent decisions. All calls time out. Used as a | |
| sanity check that the grader and task both load correctly. Returns | |
| score 0.0. | |
| 2. **Replay mode** β when ``actions`` is a list of action dicts like | |
| ``[{"action_type": "dispatch", "call_id": "CALL-001", "unit_id": "ALS-1"}, ...]``, | |
| the grader replays them through a fresh simulation seeded with | |
| ``seed`` and returns the final score. | |
| Args: | |
| task_id: One of ``easy``, ``medium``, ``hard``. | |
| actions: Optional list of action dicts to replay. | |
| seed: Random seed for the simulation (default 42 for reproducibility). | |
| Returns: | |
| A tuple ``(score, details_dict)`` where ``score`` is a float in | |
| [0.0, 1.0] and ``details_dict`` has the full reward breakdown plus | |
| call counts. | |
| """ | |
| if task_id not in TASKS: | |
| raise KeyError( | |
| f"unknown task_id '{task_id}'. Known tasks: {', '.join(TASKS.keys())}" | |
| ) | |
| scenario = load_scenario(task_id) | |
| sim = DispatchSimulation(scenario, seed=seed) | |
| if actions: | |
| _replay_actions(sim, actions) | |
| # Always fast-forward to episode end so the reward is final. | |
| while not sim.episode_done: | |
| sim.advance_time(sim.config.time_limit_minutes) | |
| reward = calculate_episode_reward( | |
| sim.completed_calls, | |
| sim.timed_out_calls, | |
| sim.total_calls(), | |
| sim.dispatches, | |
| ) | |
| details = { | |
| "task_id": task_id, | |
| "score": reward.total, | |
| "passed": reward.total >= 0.20, | |
| "survival_score": reward.survival_score, | |
| "efficiency_score": reward.efficiency_score, | |
| "triage_accuracy": reward.triage_accuracy, | |
| "penalty": reward.penalty, | |
| "details": reward.details, | |
| "completed_calls": len(sim.completed_calls), | |
| "timed_out_calls": len(sim.timed_out_calls), | |
| "total_calls": sim.total_calls(), | |
| } | |
| return reward.total, details | |
| def _replay_actions(sim: DispatchSimulation, actions: List[Dict]) -> None: | |
| """Replay a scripted action list through a fresh simulation.""" | |
| max_steps = 500 | |
| for idx, act in enumerate(actions): | |
| if idx >= max_steps or sim.episode_done: | |
| break | |
| atype = (act.get("action_type") or "").strip().lower() | |
| if atype == "dispatch": | |
| sim.dispatch( | |
| call_id=str(act.get("call_id", "")), | |
| unit_id=str(act.get("unit_id", "")), | |
| hospital_id=act.get("hospital_id"), | |
| ) | |
| sim.advance_time(1) | |
| elif atype == "classify": | |
| try: | |
| sev = int(act.get("severity", 3)) | |
| except (TypeError, ValueError): | |
| sev = 3 | |
| sim.classify(str(act.get("call_id", "")), sev) | |
| sim.advance_time(1) | |
| elif atype == "callback": | |
| sim.callback( | |
| str(act.get("call_id", "")), | |
| str(act.get("message", act.get("question", ""))), | |
| ) | |
| sim.advance_time(1) | |
| elif atype == "wait": | |
| try: | |
| mins = int(act.get("minutes", 1)) | |
| except (TypeError, ValueError): | |
| mins = 1 | |
| sim.advance_time(max(1, min(mins, sim.config.max_wait_step_minutes))) | |
| elif atype == "view": | |
| continue | |
| else: | |
| sim.advance_time(1) | |