Spaces:
Sleeping
Sleeping
| """Dispatch state machine for the 911 supervisor environment.""" | |
| from __future__ import annotations | |
| import math | |
| import random | |
| from src.city_schema import CitySchema | |
| from src.models import ( | |
| Action, | |
| DispatchAction, | |
| IncidentSeverity, | |
| IncidentState, | |
| IncidentStatus, | |
| IncidentType, | |
| Observation, | |
| State, | |
| UnitState, | |
| UnitStatus, | |
| ) | |
| from src.protocol import DispatchProtocolValidator | |
| from src.rewards import RewardCalculator | |
| from src.tasks.registry import DispatchScenarioFactory | |
| DEFAULT_DT_S = 30.0 | |
| MAX_STEPS = 200 | |
| def _severity_deadline_seconds(severity: IncidentSeverity) -> float: | |
| if severity == IncidentSeverity.PRIORITY_1: | |
| return 600.0 | |
| if severity == IncidentSeverity.PRIORITY_2: | |
| return 1200.0 | |
| return 1800.0 | |
| def _resolve_timer_seconds(severity: IncidentSeverity) -> float: | |
| if severity == IncidentSeverity.PRIORITY_1: | |
| return 300.0 | |
| if severity == IncidentSeverity.PRIORITY_2: | |
| return 600.0 | |
| return 900.0 | |
| def _distance(x1: float, y1: float, x2: float, y2: float) -> float: | |
| return math.hypot(x2 - x1, y2 - y1) | |
| class DispatchStateMachine: | |
| """Deterministic dispatch state machine. | |
| Supports dispatch operations and advances incidents through: | |
| PENDING → RESPONDING → ON_SCENE → RESOLVED. | |
| """ | |
| def __init__(self, schema: CitySchema, seed: int | None = None) -> None: | |
| self._schema = schema | |
| self._seed = seed | |
| self._rng = random.Random(seed) | |
| self._validator = DispatchProtocolValidator() | |
| self._rewards = RewardCalculator() | |
| self._incident_counter = 0 | |
| def reset(self, task_id: str, episode_id: str) -> State: | |
| self._rng = random.Random(self._seed) | |
| self._incident_counter = 0 | |
| seed = self._seed if self._seed is not None else 42 | |
| state_dict, meta = DispatchScenarioFactory.build(task_id=task_id, seed=seed) | |
| state_dict["episode_id"] = episode_id | |
| state = State.model_validate(state_dict) | |
| # Enrich metadata with schema-derived info for rewards and validation. | |
| schema_dump = self._schema.model_dump() | |
| state.metadata.setdefault("seed", seed) | |
| state.metadata.setdefault("schema", self._schema.city_name) | |
| state.metadata.setdefault("districts", meta.get("districts", schema_dump.get("districts", []))) | |
| state.metadata.setdefault("grid_size", meta.get("grid_size", schema_dump.get("grid_size", []))) | |
| state.metadata.setdefault("unit_speeds", schema_dump.get("unit_speeds", {})) | |
| # Convert unit type values to plain strings for consistent lookup | |
| raw_required = schema_dump.get("default_required_units", {}) | |
| converted_required: dict[str, list[str]] = {} | |
| for inc_type, unit_types in raw_required.items(): | |
| inc_key = getattr(inc_type, "value", None) or str(inc_type) | |
| converted_required[str(inc_key)] = [ | |
| str(getattr(u, "value", None) or str(u)) for u in list(unit_types) | |
| ] | |
| state.metadata.setdefault("default_required_units", converted_required) | |
| state.metadata["max_steps"] = int(meta.get("max_steps", MAX_STEPS)) | |
| state.metadata["waves"] = list(meta.get("waves", [])) | |
| state.metadata["unit_status_changes"] = list(meta.get("unit_status_changes", [])) | |
| if "mutual_aid_eta_penalty" in meta: | |
| state.metadata["mutual_aid_eta_penalty"] = float(meta["mutual_aid_eta_penalty"]) | |
| state.metadata.setdefault("resolved_incidents", []) | |
| state.metadata.setdefault("failed_incidents", []) | |
| state.metadata.setdefault("p1_seen", []) | |
| # Apply any wave configured for step 0 at reset. | |
| for wave in list(state.metadata.get("waves", [])): | |
| if int(wave.get("at_step", -1)) != 0: | |
| continue | |
| for inc in wave.get("incidents", []): | |
| incident_obj = IncidentState.model_validate(inc) | |
| state.incidents[incident_obj.incident_id] = incident_obj | |
| # Initialize P1 tracking. | |
| for inc in state.incidents.values(): | |
| if inc.severity == IncidentSeverity.PRIORITY_1 and inc.incident_id not in state.metadata["p1_seen"]: | |
| state.metadata["p1_seen"].append(inc.incident_id) | |
| return state | |
| def get_legal_actions(self, state: State) -> list[Action]: | |
| actions: list[Action] = [] | |
| active_incidents = [ | |
| i for i in state.incidents.values() if i.status not in {IncidentStatus.RESOLVED} | |
| ] | |
| if not active_incidents: | |
| return actions | |
| # Keep ordering stable and DISPATCH-first for callers that take legal[0]. | |
| active_incidents_sorted = sorted(active_incidents, key=lambda i: i.incident_id) | |
| units_sorted = sorted(state.units.values(), key=lambda u: u.unit_id) | |
| # Pick a deterministic "reference" unit for actions that don't semantically need one | |
| # (UPGRADE/DOWNGRADE require unit_id in the Action contract). | |
| ref_unit_id = units_sorted[0].unit_id if units_sorted else "" | |
| # DISPATCH actions (primary control surface) | |
| for unit in units_sorted: | |
| if unit.status != UnitStatus.AVAILABLE: | |
| continue | |
| for incident in active_incidents_sorted: | |
| actions.append( | |
| Action( | |
| action_type=DispatchAction.DISPATCH, | |
| unit_id=unit.unit_id, | |
| incident_id=incident.incident_id, | |
| ) | |
| ) | |
| # STAGE actions (pre-position without committing as assigned) | |
| for unit in units_sorted: | |
| if unit.status != UnitStatus.AVAILABLE: | |
| continue | |
| for incident in active_incidents_sorted: | |
| if incident.status != IncidentStatus.PENDING: | |
| continue | |
| actions.append( | |
| Action( | |
| action_type=DispatchAction.STAGE, | |
| unit_id=unit.unit_id, | |
| incident_id=incident.incident_id, | |
| ) | |
| ) | |
| # CANCEL actions (release currently assigned units) | |
| for unit in units_sorted: | |
| if unit.assigned_incident_id is None: | |
| continue | |
| actions.append( | |
| Action( | |
| action_type=DispatchAction.CANCEL, | |
| unit_id=unit.unit_id, | |
| incident_id=unit.assigned_incident_id, | |
| ) | |
| ) | |
| # REASSIGN actions (redirect already-assigned units to a different active incident) | |
| for unit in units_sorted: | |
| if unit.assigned_incident_id is None: | |
| continue | |
| if unit.status not in {UnitStatus.DISPATCHED, UnitStatus.ON_SCENE, UnitStatus.TRANSPORTING}: | |
| continue | |
| for incident in active_incidents_sorted: | |
| if incident.incident_id == unit.assigned_incident_id: | |
| continue | |
| actions.append( | |
| Action( | |
| action_type=DispatchAction.REASSIGN, | |
| unit_id=unit.unit_id, | |
| incident_id=incident.incident_id, | |
| ) | |
| ) | |
| # MUTUAL_AID actions (only for unit types with no local availability) | |
| # Use any existing unit as the "type selector". | |
| available_types = {u.unit_type for u in units_sorted if u.status == UnitStatus.AVAILABLE} | |
| type_to_template_unit: dict[object, str] = {} | |
| for unit in units_sorted: | |
| type_to_template_unit.setdefault(unit.unit_type, unit.unit_id) | |
| for unit_type, template_unit_id in sorted(type_to_template_unit.items(), key=lambda kv: str(kv[0])): | |
| if unit_type in available_types: | |
| continue | |
| for incident in active_incidents_sorted: | |
| actions.append( | |
| Action( | |
| action_type=DispatchAction.MUTUAL_AID, | |
| unit_id=template_unit_id, | |
| incident_id=incident.incident_id, | |
| ) | |
| ) | |
| # UPGRADE / DOWNGRADE actions (severity adjustments) | |
| if ref_unit_id: | |
| for incident in active_incidents_sorted: | |
| if incident.status == IncidentStatus.RESOLVED: | |
| continue | |
| # These candidates are filtered by protocol validation at step-time, | |
| # but we only generate the obviously-relevant ones. | |
| if incident.severity == IncidentSeverity.PRIORITY_1: | |
| actions.append( | |
| Action( | |
| action_type=DispatchAction.DOWNGRADE, | |
| unit_id=ref_unit_id, | |
| incident_id=incident.incident_id, | |
| priority_override=IncidentSeverity.PRIORITY_2, | |
| ) | |
| ) | |
| actions.append( | |
| Action( | |
| action_type=DispatchAction.DOWNGRADE, | |
| unit_id=ref_unit_id, | |
| incident_id=incident.incident_id, | |
| priority_override=IncidentSeverity.PRIORITY_3, | |
| ) | |
| ) | |
| elif incident.severity == IncidentSeverity.PRIORITY_2: | |
| actions.append( | |
| Action( | |
| action_type=DispatchAction.UPGRADE, | |
| unit_id=ref_unit_id, | |
| incident_id=incident.incident_id, | |
| priority_override=IncidentSeverity.PRIORITY_1, | |
| ) | |
| ) | |
| actions.append( | |
| Action( | |
| action_type=DispatchAction.DOWNGRADE, | |
| unit_id=ref_unit_id, | |
| incident_id=incident.incident_id, | |
| priority_override=IncidentSeverity.PRIORITY_3, | |
| ) | |
| ) | |
| else: | |
| actions.append( | |
| Action( | |
| action_type=DispatchAction.UPGRADE, | |
| unit_id=ref_unit_id, | |
| incident_id=incident.incident_id, | |
| priority_override=IncidentSeverity.PRIORITY_2, | |
| ) | |
| ) | |
| actions.append( | |
| Action( | |
| action_type=DispatchAction.UPGRADE, | |
| unit_id=ref_unit_id, | |
| incident_id=incident.incident_id, | |
| priority_override=IncidentSeverity.PRIORITY_1, | |
| ) | |
| ) | |
| # Filter out any actions that violate the protocol validator. | |
| legal: list[Action] = [] | |
| for a in actions: | |
| result = self._validator.validate(self._schema, state, a) | |
| if result.ok: | |
| legal.append(a) | |
| return legal | |
| def step(self, state: State, action: Action) -> tuple[State, Observation]: | |
| validation = self._validator.validate(self._schema, state, action) | |
| if not validation.ok: | |
| state = self._tick(state) | |
| breakdown = { | |
| "response_time": 0.0, | |
| "triage": 0.0, | |
| "survival": 0.0, | |
| "coverage": 0.0, | |
| "protocol": 0.0, | |
| } | |
| return ( | |
| state, | |
| Observation( | |
| result="invalid action", | |
| score=0.0, | |
| protocol_ok=False, | |
| issues=validation.issues, | |
| reward_breakdown=breakdown, | |
| ), | |
| ) | |
| if action.action_type == DispatchAction.DISPATCH: | |
| self._apply_dispatch(state, action) | |
| elif action.action_type == DispatchAction.CANCEL: | |
| self._apply_cancel(state, action) | |
| elif action.action_type == DispatchAction.REASSIGN: | |
| self._apply_reassign(state, action) | |
| elif action.action_type == DispatchAction.STAGE: | |
| self._apply_stage(state, action) | |
| elif action.action_type == DispatchAction.MUTUAL_AID: | |
| self._apply_mutual_aid(state, action) | |
| elif action.action_type in {DispatchAction.UPGRADE, DispatchAction.DOWNGRADE}: | |
| self._apply_severity_change(state, action) | |
| state = self._tick(state) | |
| obs = Observation( | |
| result="ok", | |
| score=0.0, | |
| protocol_ok=True, | |
| issues=validation.issues, | |
| ) | |
| signal, total = self._rewards.compute_reward(state, action, obs) | |
| obs = obs.model_copy(update={"score": total, "reward_breakdown": signal.model_dump()}) | |
| return (state, obs) | |
| def is_terminal(self, state: State) -> bool: | |
| max_steps = int(state.metadata.get("max_steps", MAX_STEPS)) | |
| if state.step_count >= max_steps: | |
| return True | |
| if any(i.status == IncidentStatus.ESCALATED for i in state.incidents.values()): | |
| return True | |
| if state.incidents and all( | |
| i.status == IncidentStatus.RESOLVED for i in state.incidents.values() | |
| ): | |
| return True | |
| return False | |
| def _create_incident(self, state: State) -> IncidentState: | |
| self._incident_counter += 1 | |
| incident_id = f"INC-{self._incident_counter:04d}" | |
| incident_type = self._rng.choice(list(IncidentType)) | |
| severity = self._rng.choice(list(IncidentSeverity)) | |
| width, height = self._schema.grid_size | |
| location_x = float(self._rng.uniform(0.0, float(width))) | |
| location_y = float(self._rng.uniform(0.0, float(height))) | |
| return IncidentState( | |
| incident_id=incident_id, | |
| incident_type=incident_type, | |
| severity=severity, | |
| location_x=location_x, | |
| location_y=location_y, | |
| reported_at_step=state.step_count, | |
| units_assigned=[], | |
| status=IncidentStatus.PENDING, | |
| survival_clock=_severity_deadline_seconds(severity), | |
| ) | |
| def _apply_dispatch(self, state: State, action: Action) -> None: | |
| unit = state.units[action.unit_id] | |
| incident = state.incidents[action.incident_id] | |
| speed = float(self._schema.unit_speeds.get(unit.unit_type, 1.0)) | |
| # Use Manhattan distance to match move_unit_toward physics | |
| dx = abs(unit.location_x - incident.location_x) | |
| dy = abs(unit.location_y - incident.location_y) | |
| manhattan_dist = dx + dy | |
| eta = manhattan_dist / max(speed, 1e-6) | |
| unit.status = UnitStatus.DISPATCHED | |
| unit.assigned_incident_id = incident.incident_id | |
| unit.eta_seconds = max(0.0, float(eta)) | |
| if unit.unit_id not in incident.units_assigned: | |
| incident.units_assigned.append(unit.unit_id) | |
| if incident.status == IncidentStatus.PENDING: | |
| incident.status = IncidentStatus.RESPONDING | |
| def _apply_cancel(self, state: State, action: Action) -> None: | |
| unit = state.units[action.unit_id] | |
| incident = state.incidents[action.incident_id] | |
| unit.status = UnitStatus.AVAILABLE | |
| unit.assigned_incident_id = None | |
| unit.eta_seconds = 0.0 | |
| if unit.unit_id in incident.units_assigned: | |
| incident.units_assigned.remove(unit.unit_id) | |
| if not incident.units_assigned and incident.status in { | |
| IncidentStatus.RESPONDING, | |
| IncidentStatus.ON_SCENE, | |
| }: | |
| incident.status = IncidentStatus.PENDING | |
| incident.survival_clock = _severity_deadline_seconds(incident.severity) | |
| def _apply_reassign(self, state: State, action: Action) -> None: | |
| unit = state.units[action.unit_id] | |
| new_incident = state.incidents[action.incident_id] | |
| old_incident_id = unit.assigned_incident_id | |
| old_incident = state.incidents.get(old_incident_id) if old_incident_id else None | |
| # Remove from the old incident, if present. | |
| if old_incident is not None and unit.unit_id in old_incident.units_assigned: | |
| old_incident.units_assigned.remove(unit.unit_id) | |
| if not old_incident.units_assigned and old_incident.status in { | |
| IncidentStatus.RESPONDING, | |
| IncidentStatus.ON_SCENE, | |
| }: | |
| old_incident.status = IncidentStatus.PENDING | |
| old_incident.survival_clock = _severity_deadline_seconds(old_incident.severity) | |
| # Assign to the new incident like a dispatch. | |
| unit.status = UnitStatus.DISPATCHED | |
| unit.assigned_incident_id = new_incident.incident_id | |
| speed = float(self._schema.unit_speeds.get(unit.unit_type, 1.0)) | |
| dx = abs(unit.location_x - new_incident.location_x) | |
| dy = abs(unit.location_y - new_incident.location_y) | |
| manhattan_dist = dx + dy | |
| eta = manhattan_dist / max(speed, 1e-6) | |
| unit.eta_seconds = max(0.0, float(eta)) | |
| if unit.unit_id not in new_incident.units_assigned: | |
| new_incident.units_assigned.append(unit.unit_id) | |
| if new_incident.status == IncidentStatus.PENDING: | |
| new_incident.status = IncidentStatus.RESPONDING | |
| def _apply_stage(self, state: State, action: Action) -> None: | |
| """Pre-position a unit towards an incident without counting as 'assigned'.""" | |
| unit = state.units[action.unit_id] | |
| incident = state.incidents[action.incident_id] | |
| speed = float(self._schema.unit_speeds.get(unit.unit_type, 1.0)) | |
| dx = abs(unit.location_x - incident.location_x) | |
| dy = abs(unit.location_y - incident.location_y) | |
| manhattan_dist = dx + dy | |
| eta = manhattan_dist / max(speed, 1e-6) | |
| unit.status = UnitStatus.DISPATCHED | |
| unit.assigned_incident_id = incident.incident_id | |
| unit.eta_seconds = max(0.0, float(eta)) | |
| def _apply_mutual_aid(self, state: State, action: Action) -> None: | |
| """Request an external unit of the given type and dispatch it.""" | |
| template = state.units[action.unit_id] | |
| incident = state.incidents[action.incident_id] | |
| counter = int(state.metadata.get("mutual_aid_counter", 0)) + 1 | |
| state.metadata["mutual_aid_counter"] = counter | |
| prefix = template.unit_type.value[:3] | |
| new_unit_id = f"MA-{prefix}-{counter}" | |
| new_unit_id = new_unit_id[:20] | |
| speed = float(self._schema.unit_speeds.get(template.unit_type, 1.0)) | |
| dx = abs(template.location_x - incident.location_x) | |
| dy = abs(template.location_y - incident.location_y) | |
| manhattan_dist = dx + dy | |
| base_eta = manhattan_dist / max(speed, 1e-6) | |
| penalty = float(state.metadata.get("mutual_aid_eta_penalty", 120.0)) | |
| unit = UnitState( | |
| unit_id=new_unit_id, | |
| unit_type=template.unit_type, | |
| status=UnitStatus.DISPATCHED, | |
| location_x=float(template.location_x), | |
| location_y=float(template.location_y), | |
| assigned_incident_id=incident.incident_id, | |
| eta_seconds=max(0.0, float(base_eta + penalty)), | |
| crew_count=int(template.crew_count), | |
| ) | |
| state.units[unit.unit_id] = unit | |
| if unit.unit_id not in incident.units_assigned: | |
| incident.units_assigned.append(unit.unit_id) | |
| if incident.status == IncidentStatus.PENDING: | |
| incident.status = IncidentStatus.RESPONDING | |
| def _apply_severity_change(self, state: State, action: Action) -> None: | |
| if action.priority_override is None: | |
| return | |
| incident = state.incidents[action.incident_id] | |
| incident.severity = action.priority_override | |
| # Update clocks based on current incident phase. | |
| if incident.status in {IncidentStatus.PENDING, IncidentStatus.RESPONDING}: | |
| incident.survival_clock = _severity_deadline_seconds(incident.severity) | |
| elif incident.status == IncidentStatus.ON_SCENE: | |
| incident.survival_clock = _resolve_timer_seconds(incident.severity) | |
| def _tick(self, state: State) -> State: | |
| state.step_count += 1 | |
| state.city_time += DEFAULT_DT_S | |
| # Apply any scheduled unit status changes. | |
| for change in list(state.metadata.get("unit_status_changes", [])): | |
| if int(change.get("at_step", -1)) != state.step_count: | |
| continue | |
| unit_id = str(change.get("unit_id", "")) | |
| if unit_id in state.units: | |
| new_status = UnitStatus(change.get("status")) | |
| unit = state.units[unit_id] | |
| unit.status = new_status | |
| if new_status in {UnitStatus.OUT_OF_SERVICE, UnitStatus.AVAILABLE}: | |
| unit.assigned_incident_id = None | |
| unit.eta_seconds = 0.0 | |
| # Spawn incident waves. | |
| for wave in list(state.metadata.get("waves", [])): | |
| if int(wave.get("at_step", -1)) != state.step_count: | |
| continue | |
| for inc in wave.get("incidents", []): | |
| incident_obj = IncidentState.model_validate(inc) | |
| if incident_obj.incident_id not in state.incidents: | |
| state.incidents[incident_obj.incident_id] = incident_obj | |
| if ( | |
| incident_obj.severity == IncidentSeverity.PRIORITY_1 | |
| and incident_obj.incident_id not in state.metadata.get("p1_seen", []) | |
| ): | |
| state.metadata.setdefault("p1_seen", []).append(incident_obj.incident_id) | |
| for unit in state.units.values(): | |
| if unit.status == UnitStatus.DISPATCHED: | |
| unit.eta_seconds = max(0.0, unit.eta_seconds - DEFAULT_DT_S) | |
| if unit.eta_seconds <= 0.0 and unit.assigned_incident_id is not None: | |
| unit.status = UnitStatus.ON_SCENE | |
| for incident in state.incidents.values(): | |
| if incident.status in {IncidentStatus.PENDING, IncidentStatus.RESPONDING}: | |
| incident.survival_clock = max(0.0, incident.survival_clock - DEFAULT_DT_S) | |
| if incident.survival_clock <= 0.0: | |
| incident.status = IncidentStatus.ESCALATED | |
| failed = state.metadata.setdefault("failed_incidents", []) | |
| if incident.incident_id not in failed: | |
| failed.append(incident.incident_id) | |
| continue | |
| if incident.status == IncidentStatus.RESPONDING: | |
| if any( | |
| state.units[uid].status == UnitStatus.ON_SCENE | |
| for uid in incident.units_assigned | |
| if uid in state.units | |
| ): | |
| incident.status = IncidentStatus.ON_SCENE | |
| incident.survival_clock = _resolve_timer_seconds(incident.severity) | |
| if incident.status == IncidentStatus.ON_SCENE: | |
| incident.survival_clock = max(0.0, incident.survival_clock - DEFAULT_DT_S) | |
| if incident.survival_clock <= 0.0: | |
| incident.status = IncidentStatus.RESOLVED | |
| resolved = state.metadata.setdefault("resolved_incidents", []) | |
| if incident.incident_id not in resolved: | |
| resolved.append(incident.incident_id) | |
| return state | |