"""Task registry and deterministic scenario fixtures for the 911 dispatch environment.""" from __future__ import annotations import random from typing import Any, Literal from pydantic import BaseModel, Field from src.models import ( IncidentSeverity, IncidentStatus, IncidentType, UnitState, UnitStatus, UnitType, ) class TaskInfo(BaseModel): """Task information metadata.""" model_config = {"extra": "forbid"} task_id: str = Field(..., min_length=1) name: str = Field(..., min_length=1) description: str = Field(..., min_length=1) difficulty: Literal["easy", "medium", "hard"] initial_state_fn: str = Field(..., min_length=1) class TaskRegistry: """Registry for managing available tasks.""" REGISTRY: dict[str, TaskInfo] = {} @classmethod def register(cls, task: TaskInfo) -> None: cls.REGISTRY[task.task_id] = task @classmethod def get(cls, task_id: str) -> TaskInfo: if task_id not in cls.REGISTRY: raise KeyError(f"Task '{task_id}' not found in registry") return cls.REGISTRY[task_id] @classmethod def list_tasks(cls) -> list[TaskInfo]: return list(cls.REGISTRY.values()) TaskRegistry.register( TaskInfo( task_id="single_incident", name="Single Incident Response", description="1 incident, right unit, fast dispatch", difficulty="easy", initial_state_fn="build_single_incident_fixture", ) ) TaskRegistry.register( TaskInfo( task_id="multi_incident", name="Simultaneous Multi-Incident", description="3 concurrent incidents requiring triage", difficulty="medium", initial_state_fn="build_multi_incident_fixture", ) ) TaskRegistry.register( TaskInfo( task_id="mass_casualty", name="Mass Casualty Event", description="wave-based incidents with resource conflict", difficulty="hard", initial_state_fn="build_mass_casualty_fixture", ) ) TaskRegistry.register( TaskInfo( task_id="shift_surge", name="Shift Surge", description="units go out of service + steady incident stream", difficulty="hard", initial_state_fn="build_shift_surge_fixture", ) ) class DispatchScenarioFactory: """Factory for creating deterministic dispatch scenario fixtures. Returns `(state_dict, meta_dict)`. """ @staticmethod def _seeded_random(seed: int) -> random.Random: return random.Random(seed) @classmethod def build(cls, task_id: str, seed: int) -> tuple[dict[str, Any], dict[str, Any]]: task = TaskRegistry.get(task_id) fn_name = task.initial_state_fn if fn_name == "build_single_incident_fixture": return cls.build_single_incident_fixture(seed) if fn_name == "build_multi_incident_fixture": return cls.build_multi_incident_fixture(seed) if fn_name == "build_mass_casualty_fixture": return cls.build_mass_casualty_fixture(seed) if fn_name == "build_shift_surge_fixture": return cls.build_shift_surge_fixture(seed) raise ValueError(f"Unknown initial_state_fn: {fn_name}") @classmethod def _base_units_city_small(cls) -> dict[str, UnitState]: return { "MED-1": UnitState( unit_id="MED-1", unit_type=UnitType.MEDIC, status=UnitStatus.AVAILABLE, location_x=10.0, location_y=10.0, assigned_incident_id=None, eta_seconds=0.0, crew_count=2, ), "ENG-1": UnitState( unit_id="ENG-1", unit_type=UnitType.ENGINE, status=UnitStatus.AVAILABLE, location_x=20.0, location_y=20.0, assigned_incident_id=None, eta_seconds=0.0, crew_count=4, ), "PAT-1": UnitState( unit_id="PAT-1", unit_type=UnitType.PATROL, status=UnitStatus.AVAILABLE, location_x=30.0, location_y=30.0, assigned_incident_id=None, eta_seconds=0.0, crew_count=2, ), } @classmethod def build_single_incident_fixture(cls, seed: int) -> tuple[dict[str, Any], dict[str, Any]]: rng = cls._seeded_random(seed) units = cls._base_units_city_small() incidents = { "INC-001": { "incident_id": "INC-001", "incident_type": IncidentType.CARDIAC_ARREST, "severity": IncidentSeverity.PRIORITY_1, "location_x": 12.0 + rng.uniform(-1.0, 1.0), "location_y": 12.0 + rng.uniform(-1.0, 1.0), "reported_at_step": 0, "units_assigned": [], "status": IncidentStatus.PENDING, "survival_clock": 600.0, } } state_dict = { "units": {k: v.model_dump() for k, v in units.items()}, "incidents": incidents, "episode_id": f"single-{seed}", "step_count": 0, "task_id": "single_incident", "city_time": 0.0, "metadata": {}, } meta = { "max_steps": 20, "waves": [], "districts": ["downtown", "northside", "eastport", "suburbs", "industrial"], "grid_size": [100, 100], } return state_dict, meta @classmethod def build_multi_incident_fixture(cls, seed: int) -> tuple[dict[str, Any], dict[str, Any]]: rng = cls._seeded_random(seed) units: dict[str, UnitState] = { **cls._base_units_city_small(), "MED-2": UnitState( unit_id="MED-2", unit_type=UnitType.MEDIC, status=UnitStatus.AVAILABLE, location_x=70.0, location_y=30.0, assigned_incident_id=None, eta_seconds=0.0, crew_count=2, ), "LAD-1": UnitState( unit_id="LAD-1", unit_type=UnitType.LADDER, status=UnitStatus.AVAILABLE, location_x=10.0, location_y=20.0, assigned_incident_id=None, eta_seconds=0.0, crew_count=5, ), "ENG-2": UnitState( unit_id="ENG-2", unit_type=UnitType.ENGINE, status=UnitStatus.AVAILABLE, location_x=50.0, location_y=50.0, assigned_incident_id=None, eta_seconds=0.0, crew_count=4, ), } incidents = { "INC-001": { "incident_id": "INC-001", "incident_type": IncidentType.STRUCTURE_FIRE, "severity": IncidentSeverity.PRIORITY_2, "location_x": 20.0 + rng.uniform(-2.0, 2.0), "location_y": 80.0 + rng.uniform(-2.0, 2.0), "reported_at_step": 0, "units_assigned": [], "status": IncidentStatus.PENDING, "survival_clock": 1200.0, }, "INC-002": { "incident_id": "INC-002", "incident_type": IncidentType.CARDIAC_ARREST, "severity": IncidentSeverity.PRIORITY_1, "location_x": 14.0 + rng.uniform(-2.0, 2.0), "location_y": 22.0 + rng.uniform(-2.0, 2.0), "reported_at_step": 0, "units_assigned": [], "status": IncidentStatus.PENDING, "survival_clock": 600.0, }, "INC-003": { "incident_id": "INC-003", "incident_type": IncidentType.SHOOTING, "severity": IncidentSeverity.PRIORITY_1, "location_x": 75.0 + rng.uniform(-2.0, 2.0), "location_y": 15.0 + rng.uniform(-2.0, 2.0), "reported_at_step": 0, "units_assigned": [], "status": IncidentStatus.PENDING, "survival_clock": 600.0, }, } state_dict = { "units": {k: v.model_dump() for k, v in units.items()}, "incidents": incidents, "episode_id": f"multi-{seed}", "step_count": 0, "task_id": "multi_incident", "city_time": 0.0, "metadata": {}, } meta = { "max_steps": 40, "waves": [], "districts": ["downtown", "northside", "eastport", "suburbs", "industrial"], "grid_size": [100, 100], } return state_dict, meta @classmethod def build_mass_casualty_fixture(cls, seed: int) -> tuple[dict[str, Any], dict[str, Any]]: rng = cls._seeded_random(seed) units: dict[str, UnitState] = { "ENG-1": UnitState(unit_id="ENG-1", unit_type=UnitType.ENGINE, status=UnitStatus.AVAILABLE, location_x=10.0, location_y=20.0, assigned_incident_id=None, eta_seconds=0.0, crew_count=4), "ENG-2": UnitState(unit_id="ENG-2", unit_type=UnitType.ENGINE, status=UnitStatus.AVAILABLE, location_x=50.0, location_y=50.0, assigned_incident_id=None, eta_seconds=0.0, crew_count=4), "MED-1": UnitState(unit_id="MED-1", unit_type=UnitType.MEDIC, status=UnitStatus.AVAILABLE, location_x=15.0, location_y=25.0, assigned_incident_id=None, eta_seconds=0.0, crew_count=2), "MED-2": UnitState(unit_id="MED-2", unit_type=UnitType.MEDIC, status=UnitStatus.AVAILABLE, location_x=70.0, location_y=30.0, assigned_incident_id=None, eta_seconds=0.0, crew_count=2), "LAD-1": UnitState(unit_id="LAD-1", unit_type=UnitType.LADDER, status=UnitStatus.AVAILABLE, location_x=10.0, location_y=20.0, assigned_incident_id=None, eta_seconds=0.0, crew_count=5), "PAT-1": UnitState(unit_id="PAT-1", unit_type=UnitType.PATROL, status=UnitStatus.AVAILABLE, location_x=30.0, location_y=60.0, assigned_incident_id=None, eta_seconds=0.0, crew_count=2), "PAT-2": UnitState(unit_id="PAT-2", unit_type=UnitType.PATROL, status=UnitStatus.AVAILABLE, location_x=80.0, location_y=10.0, assigned_incident_id=None, eta_seconds=0.0, crew_count=2), } incidents = { "INC-001": { "incident_id": "INC-001", "incident_type": IncidentType.BUILDING_COLLAPSE, "severity": IncidentSeverity.PRIORITY_1, "location_x": 45.0 + rng.uniform(-3.0, 3.0), "location_y": 45.0 + rng.uniform(-3.0, 3.0), "reported_at_step": 0, "units_assigned": [], "status": IncidentStatus.PENDING, "survival_clock": 480.0, } } waves = [ { "at_step": 5, "incidents": [ { "incident_id": "INC-002", "incident_type": IncidentType.STRUCTURE_FIRE, "severity": IncidentSeverity.PRIORITY_2, "location_x": 10.0 + rng.uniform(-3.0, 3.0), "location_y": 90.0 + rng.uniform(-3.0, 3.0), "reported_at_step": 5, "units_assigned": [], "status": IncidentStatus.PENDING, "survival_clock": 900.0, } ], }, { "at_step": 12, "incidents": [ { "incident_id": "INC-003", "incident_type": IncidentType.CARDIAC_ARREST, "severity": IncidentSeverity.PRIORITY_1, "location_x": 85.0 + rng.uniform(-3.0, 3.0), "location_y": 15.0 + rng.uniform(-3.0, 3.0), "reported_at_step": 12, "units_assigned": [], "status": IncidentStatus.PENDING, "survival_clock": 420.0, }, { "incident_id": "INC-004", "incident_type": IncidentType.CARDIAC_ARREST, "severity": IncidentSeverity.PRIORITY_1, "location_x": 15.0 + rng.uniform(-3.0, 3.0), "location_y": 10.0 + rng.uniform(-3.0, 3.0), "reported_at_step": 12, "units_assigned": [], "status": IncidentStatus.PENDING, "survival_clock": 420.0, }, ], }, ] state_dict = { "units": {k: v.model_dump() for k, v in units.items()}, "incidents": incidents, "episode_id": f"mass-{seed}", "step_count": 0, "task_id": "mass_casualty", "city_time": 0.0, "metadata": {}, } meta = { "max_steps": 60, "waves": waves, "mutual_aid_eta_penalty": 120.0, "districts": ["downtown", "northside", "eastport", "suburbs", "industrial"], "grid_size": [100, 100], } return state_dict, meta @classmethod def build_shift_surge_fixture(cls, seed: int) -> tuple[dict[str, Any], dict[str, Any]]: rng = cls._seeded_random(seed) units: dict[str, UnitState] = { "ENG-1": UnitState(unit_id="ENG-1", unit_type=UnitType.ENGINE, status=UnitStatus.AVAILABLE, location_x=10.0, location_y=20.0, assigned_incident_id=None, eta_seconds=0.0, crew_count=4), "MED-1": UnitState(unit_id="MED-1", unit_type=UnitType.MEDIC, status=UnitStatus.AVAILABLE, location_x=15.0, location_y=25.0, assigned_incident_id=None, eta_seconds=0.0, crew_count=2), "PAT-1": UnitState(unit_id="PAT-1", unit_type=UnitType.PATROL, status=UnitStatus.AVAILABLE, location_x=30.0, location_y=60.0, assigned_incident_id=None, eta_seconds=0.0, crew_count=2), "PAT-2": UnitState(unit_id="PAT-2", unit_type=UnitType.PATROL, status=UnitStatus.AVAILABLE, location_x=80.0, location_y=10.0, assigned_incident_id=None, eta_seconds=0.0, crew_count=2), "ENG-2": UnitState(unit_id="ENG-2", unit_type=UnitType.ENGINE, status=UnitStatus.AVAILABLE, location_x=50.0, location_y=50.0, assigned_incident_id=None, eta_seconds=0.0, crew_count=4), } waves: list[dict[str, Any]] = [] next_id = 0 for t in range(0, 56, 8): next_id += 1 waves.append( { "at_step": t, "incidents": [ { "incident_id": f"INC-{next_id:03d}", "incident_type": rng.choice(list(IncidentType)), "severity": rng.choice(list(IncidentSeverity)), "location_x": rng.uniform(0.0, 100.0), "location_y": rng.uniform(0.0, 100.0), "reported_at_step": t, "units_assigned": [], "status": IncidentStatus.PENDING, "survival_clock": 720.0, } ], } ) unit_status_changes = [ {"at_step": 1, "unit_id": "PAT-2", "status": UnitStatus.OUT_OF_SERVICE}, {"at_step": 3, "unit_id": "ENG-2", "status": UnitStatus.OUT_OF_SERVICE}, {"at_step": 5, "unit_id": "PAT-1", "status": UnitStatus.OUT_OF_SERVICE}, ] state_dict = { "units": {k: v.model_dump() for k, v in units.items()}, "incidents": {}, "episode_id": f"surge-{seed}", "step_count": 0, "task_id": "shift_surge", "city_time": 0.0, "metadata": {}, } meta = { "max_steps": 60, "waves": waves, "unit_status_changes": unit_status_changes, "districts": ["downtown", "northside", "eastport", "suburbs", "industrial"], "grid_size": [100, 100], } return state_dict, meta