ashiqabdulkhader's picture
Upload folder using huggingface_hub
508f4b1 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Event-driven simulator for the Emergency Response Allocation environment."""
from __future__ import annotations
import heapq
import math
import random
from uuid import uuid4
try:
from ..models import (
ACTION_DIM,
HOLD_ACTION_OFFSET,
MAX_OBSERVABLE_INCIDENTS,
NUM_AMBULANCES,
AmbulanceSnapshot,
EmergencyResponseAllocationObservation,
EmergencyResponseAllocationState,
ERASInfo,
IncidentSnapshot,
)
from .config import GridPoint, SimulationConfig
from .entities import (
Ambulance,
AmbulanceStatus,
Event,
EventType,
EVENT_PRIORITIES,
Incident,
IncidentStatus,
SeverityLevel,
ZoneType,
)
except ImportError:
from models import (
ACTION_DIM,
HOLD_ACTION_OFFSET,
MAX_OBSERVABLE_INCIDENTS,
NUM_AMBULANCES,
AmbulanceSnapshot,
EmergencyResponseAllocationObservation,
EmergencyResponseAllocationState,
ERASInfo,
IncidentSnapshot,
)
from server.config import GridPoint, SimulationConfig
from server.entities import (
Ambulance,
AmbulanceStatus,
Event,
EventType,
EVENT_PRIORITIES,
Incident,
IncidentStatus,
SeverityLevel,
ZoneType,
)
class ERASSimulator:
"""Core simulator that powers the OpenEnv-facing environment wrapper."""
def __init__(
self, config: SimulationConfig | None = None, auto_reset: bool = True
) -> None:
self.config = config or SimulationConfig()
self._rng = random.Random()
self._event_sequence = 0
self._all_dispatchable_locations = self._build_dispatchable_locations()
self._next_incident_time: float | None = None
self._reset_internal_state()
if auto_reset:
self.reset()
def reset(
self, seed: int | None = None, episode_id: str | None = None
) -> EmergencyResponseAllocationObservation:
if seed is not None:
self._rng = random.Random(seed)
self._reset_internal_state()
self.episode_id = episode_id or str(uuid4())
self.last_event_type = "reset"
self._initialize_ambulances()
self._schedule_next_incident(self.current_time)
self._advance_until_actionable_event()
return self._build_observation(reward=None)
def step(self, action_index: int) -> EmergencyResponseAllocationObservation:
if self.done:
return self._build_observation(reward=0.0)
mask = self.build_action_mask()
if action_index < 0 or action_index >= ACTION_DIM:
raise ValueError(f"Action index out of range: {action_index}")
if not mask[action_index]:
raise ValueError(f"Invalid action for current state: {action_index}")
self.step_count += 1
reward = 0.0
if action_index < HOLD_ACTION_OFFSET:
ambulance_id, incident_slot = divmod(
action_index, self.config.max_observable_incidents
)
incident = self.get_observable_incidents()[incident_slot]
reward += self._dispatch_incident(ambulance_id, incident.incident_id)
self.last_event_type = "assignment"
if self._has_actionable_assignment():
return self._build_observation(reward=reward)
else:
self.last_event_type = "hold"
reward += self._idle_penalty()
reward += self._advance_until_actionable_event()
return self._build_observation(reward=reward)
@property
def state(self) -> EmergencyResponseAllocationState:
info = self._build_info()
return EmergencyResponseAllocationState(
episode_id=self.episode_id,
step_count=self.step_count,
current_sim_time=self.current_time,
time_of_day=self.time_of_day,
ambulances=self._build_ambulance_snapshots(),
incidents=self._build_incident_snapshots(),
valid_action_mask=self.build_action_mask(),
observation_vector=self.build_observation_vector(),
info=info,
event_queue_size=len(self._event_queue),
pending_incident_count=len(self.get_pending_incidents()),
inflight_incident_count=self._count_inflight_incidents(),
episode_done=self.done,
last_event_type=self.last_event_type,
)
@property
def time_of_day(self) -> float:
return (self.current_time / 60.0) % 24.0
def get_observable_incidents(self) -> list[Incident]:
pending = self.get_pending_incidents()
pending.sort(
key=lambda incident: (
-self.config.severity_weights[incident.severity.value],
-(self.current_time - incident.reported_at),
incident.incident_id,
)
)
return pending[: self.config.max_observable_incidents]
def get_pending_incidents(self) -> list[Incident]:
return [
incident
for incident in self._incidents.values()
if incident.status == IncidentStatus.PENDING
]
def get_free_ambulances(self) -> list[Ambulance]:
return [
ambulance
for ambulance in self._ambulances
if ambulance.status == AmbulanceStatus.FREE
]
def build_action_mask(self) -> list[bool]:
mask = [False] * ACTION_DIM
visible_incidents = self.get_observable_incidents()
if not visible_incidents:
return mask
for ambulance in self._ambulances:
if ambulance.status != AmbulanceStatus.FREE:
continue
for slot_index, _incident in enumerate(visible_incidents):
action_index = (
ambulance.ambulance_id * self.config.max_observable_incidents
+ slot_index
)
mask[action_index] = True
mask[HOLD_ACTION_OFFSET + ambulance.ambulance_id] = True
return mask
def build_observation_vector(self) -> list[float]:
vector: list[float] = []
visible_incidents = self.get_observable_incidents()
for ambulance in self._ambulances:
vector.extend(
[
float(ambulance.location[0]),
float(ambulance.location[1]),
1.0 if ambulance.status == AmbulanceStatus.FREE else 0.0,
max(0.0, ambulance.eta_free_at - self.current_time),
]
)
for slot_index in range(self.config.max_observable_incidents):
if slot_index < len(visible_incidents):
incident = visible_incidents[slot_index]
vector.extend(
[
float(incident.location[0]),
float(incident.location[1]),
float(self._severity_code(incident.severity.value)),
max(0.0, self.current_time - incident.reported_at),
]
)
else:
vector.extend([0.0, 0.0, 0.0, 0.0])
travel_times = self._build_travel_time_matrix(visible_incidents)
for row in travel_times:
vector.extend(row)
vector.append(self.time_of_day)
return vector
def encode_assign_action(self, ambulance_id: int, incident_slot: int) -> int:
return ambulance_id * self.config.max_observable_incidents + incident_slot
def encode_hold_action(self, ambulance_id: int) -> int:
return HOLD_ACTION_OFFSET + ambulance_id
def _reset_internal_state(self) -> None:
self.current_time = 0.0
self.done = False
self.step_count = 0
self.episode_id = str(uuid4())
self.last_event_type = "reset"
self._event_queue: list[Event] = []
self._ambulances: list[Ambulance] = []
self._incidents: dict[int, Incident] = {}
self._incident_counter = 0
self._response_times: list[float] = []
self._severity_weighted_response_score = 0.0
self._incidents_served = 0
self._missed_critical = 0
self._next_incident_time = None
def _initialize_ambulances(self) -> None:
self._ambulances = []
for ambulance_id, depot_id in enumerate(self.config.ambulance_depot_ids):
home_location = self.config.depot_locations[depot_id]
self._ambulances.append(
Ambulance(
ambulance_id=ambulance_id,
depot_id=depot_id,
home_location=home_location,
location=home_location,
status=AmbulanceStatus.FREE,
eta_free_at=0.0,
)
)
def _build_dispatchable_locations(self) -> list[GridPoint]:
locations: list[GridPoint] = []
for x in range(self.config.grid_size):
for y in range(self.config.grid_size):
zone = self._zone_type((x, y))
if zone not in (ZoneType.HOSPITAL, ZoneType.DEPOT):
locations.append((x, y))
return locations
def _zone_type(self, location: GridPoint) -> ZoneType:
if location in self.config.depot_locations:
return ZoneType.DEPOT
if location in self.config.hospital_locations:
return ZoneType.HOSPITAL
min_x, min_y, max_x, max_y = self.config.commercial_bounds
x, y = location
if min_x <= x <= max_x and min_y <= y <= max_y:
return ZoneType.COMMERCIAL
return ZoneType.RESIDENTIAL
def _build_ambulance_snapshots(self) -> list[AmbulanceSnapshot]:
snapshots: list[AmbulanceSnapshot] = []
for ambulance in self._ambulances:
snapshots.append(
AmbulanceSnapshot(
ambulance_id=ambulance.ambulance_id,
x=ambulance.location[0],
y=ambulance.location[1],
status=ambulance.status.value,
eta_free=max(0.0, ambulance.eta_free_at - self.current_time),
depot_id=ambulance.depot_id,
)
)
return snapshots
def _build_incident_snapshots(self) -> list[IncidentSnapshot]:
snapshots: list[IncidentSnapshot] = []
visible_incidents = self.get_observable_incidents()
for slot_index in range(self.config.max_observable_incidents):
if slot_index < len(visible_incidents):
incident = visible_incidents[slot_index]
snapshots.append(
IncidentSnapshot(
incident_id=incident.incident_id,
x=incident.location[0],
y=incident.location[1],
severity=incident.severity.value,
severity_code=self._severity_code(incident.severity.value),
time_since_reported=max(
0.0, self.current_time - incident.reported_at
),
)
)
else:
snapshots.append(IncidentSnapshot())
return snapshots
def _build_travel_time_matrix(
self, visible_incidents: list[Incident] | None = None
) -> list[list[float]]:
incidents = visible_incidents if visible_incidents is not None else self.get_observable_incidents()
matrix: list[list[float]] = []
for ambulance in self._ambulances:
row = [0.0] * self.config.max_observable_incidents
for slot_index, incident in enumerate(incidents):
if ambulance.status == AmbulanceStatus.FREE:
row[slot_index] = self._travel_time(
ambulance.location, incident.location, self.current_time
)
matrix.append(row)
return matrix
def _build_info(self) -> ERASInfo:
avg_response_time = (
sum(self._response_times) / len(self._response_times)
if self._response_times
else 0.0
)
p95_response_time = self._percentile(self._response_times, 95.0)
incidents_total = len(self._incidents)
coverage_rate = (
self._incidents_served / incidents_total if incidents_total else 0.0
)
elapsed = max(self.current_time, 1e-6)
utilization = [
min(1.0, self._current_busy_time(ambulance) / elapsed)
for ambulance in self._ambulances
]
return ERASInfo(
avg_response_time=avg_response_time,
incidents_served=self._incidents_served,
incidents_total=incidents_total,
missed_critical=self._missed_critical,
ambulance_utilization=utilization,
current_sim_time=self.current_time,
p95_response_time=p95_response_time,
severity_weighted_response_score=self._severity_weighted_response_score,
coverage_rate=coverage_rate,
)
def _build_observation(
self, reward: float | None
) -> EmergencyResponseAllocationObservation:
visible_incidents = self.get_observable_incidents()
visible_ids = [-1] * self.config.max_observable_incidents
for slot_index, incident in enumerate(visible_incidents):
visible_ids[slot_index] = incident.incident_id
return EmergencyResponseAllocationObservation(
done=self.done,
reward=reward,
observation_vector=self.build_observation_vector(),
valid_action_mask=self.build_action_mask(),
ambulances=self._build_ambulance_snapshots(),
incidents=self._build_incident_snapshots(),
travel_times=self._build_travel_time_matrix(visible_incidents),
visible_incident_ids=visible_ids,
time_of_day=self.time_of_day,
event_type=self.last_event_type,
info=self._build_info(),
)
def _push_event(self, scheduled_time: float, event_type: EventType, **payload: object) -> None:
if scheduled_time > self.config.episode_duration_minutes:
return
event = Event(
scheduled_time=scheduled_time,
priority=EVENT_PRIORITIES[event_type],
sequence=self._event_sequence,
event_type=event_type,
payload=dict(payload),
)
self._event_sequence += 1
heapq.heappush(self._event_queue, event)
def _schedule_next_incident(self, from_time: float) -> None:
next_time = self._sample_next_incident_time(from_time)
self._next_incident_time = next_time
if next_time is not None:
self._push_event(next_time, EventType.INCIDENT_ARRIVAL)
def _sample_next_incident_time(self, from_time: float) -> float | None:
time_cursor = from_time
boundaries = self.config.arrival_phase_boundaries_minutes
rates = self.config.arrival_rates_per_hour
while time_cursor < self.config.episode_duration_minutes:
phase_index = self._phase_index(time_cursor)
phase_end = boundaries[phase_index + 1]
rate_per_hour = rates[phase_index]
if rate_per_hour <= 0.0:
time_cursor = phase_end
continue
delta_minutes = self._rng.expovariate(rate_per_hour) * 60.0
candidate = time_cursor + delta_minutes
if candidate < phase_end:
return candidate
time_cursor = phase_end
return None
def _phase_index(self, sim_minutes: float) -> int:
boundaries = self.config.arrival_phase_boundaries_minutes
for index in range(len(boundaries) - 1):
if boundaries[index] <= sim_minutes < boundaries[index + 1]:
return index
return len(boundaries) - 2
def _sample_incident_location(self) -> GridPoint:
weights: list[float] = []
for location in self._all_dispatchable_locations:
zone = self._zone_type(location)
weight = 1.0
if zone == ZoneType.COMMERCIAL and self._is_peak_demand_period():
weight = 3.0
elif zone == ZoneType.RESIDENTIAL and self._is_night_demand_period():
weight = 2.0
weights.append(weight)
return self._rng.choices(self._all_dispatchable_locations, weights=weights, k=1)[0]
def _sample_severity(self) -> SeverityLevel:
threshold = self._rng.random()
cumulative = 0.0
for severity_name, probability in self.config.severity_probabilities:
cumulative += probability
if threshold <= cumulative:
return SeverityLevel(severity_name)
return SeverityLevel.LOW
def _sample_service_time(self, severity: SeverityLevel) -> float:
low, high = self.config.scene_time_ranges_minutes[severity.value]
return self._rng.uniform(low, high)
def _travel_time(
self, origin: GridPoint, destination: GridPoint, sim_time_minutes: float
) -> float:
distance = math.dist(origin, destination)
return distance * self._traffic_multiplier(sim_time_minutes)
def _traffic_multiplier(self, sim_time_minutes: float) -> float:
hour = (sim_time_minutes / 60.0) % 24.0
for start, end in self.config.traffic_peak_windows_hours:
if start <= hour < end:
return self.config.traffic_peak_multiplier
for start, end in self.config.traffic_night_windows_hours:
if start <= hour < end:
return self.config.traffic_night_multiplier
return self.config.traffic_offpeak_multiplier
def _is_peak_demand_period(self) -> bool:
hour = self.time_of_day
return (6.0 <= hour < 10.0) or (16.0 <= hour < 20.0)
def _is_night_demand_period(self) -> bool:
hour = self.time_of_day
return hour < 6.0 or hour >= 20.0
def _dispatch_incident(self, ambulance_id: int, incident_id: int) -> float:
ambulance = self._ambulances[ambulance_id]
incident = self._incidents[incident_id]
if ambulance.status != AmbulanceStatus.FREE:
raise ValueError(f"Ambulance {ambulance_id} is not free")
if incident.status != IncidentStatus.PENDING:
raise ValueError(f"Incident {incident_id} is not pending")
travel_to_scene = self._travel_time(
ambulance.location, incident.location, self.current_time
)
hospital = self._nearest_hospital(incident.location)
travel_to_hospital = self._travel_time(
incident.location, hospital, self.current_time + travel_to_scene
)
travel_to_depot = self._travel_time(
hospital,
ambulance.home_location,
self.current_time + travel_to_scene + incident.service_time + travel_to_hospital,
)
ambulance.status = AmbulanceStatus.BUSY
ambulance.assigned_incident_id = incident_id
ambulance.busy_start = self.current_time
ambulance.eta_free_at = (
self.current_time
+ travel_to_scene
+ incident.service_time
+ travel_to_hospital
+ travel_to_depot
)
incident.status = IncidentStatus.DISPATCHED
incident.assigned_ambulance_id = ambulance_id
incident.dispatch_at = self.current_time
incident.estimated_travel_time = travel_to_scene
incident.hospital_location = hospital
self._push_event(
self.current_time + travel_to_scene,
EventType.ARRIVE_SCENE,
ambulance_id=ambulance_id,
incident_id=incident_id,
)
severity_weight = self.config.severity_weights[incident.severity.value]
reward_cfg = self.config.reward
return -reward_cfg.response_time_penalty_weight * severity_weight * travel_to_scene
def _idle_penalty(self) -> float:
pending_incidents = self.get_pending_incidents()
if not pending_incidents:
return 0.0
idle_free_ambulances = len(self.get_free_ambulances())
return -self.config.reward.idle_penalty_weight * idle_free_ambulances
def _advance_until_actionable_event(self) -> float:
reward = 0.0
while not self.done:
if self._should_end_naturally():
self.done = True
self.last_event_type = "natural_end"
break
if not self._event_queue:
self.done = True
self.last_event_type = "queue_exhausted"
break
next_event_time = self._event_queue[0].scheduled_time
if next_event_time >= self.config.episode_duration_minutes:
self.current_time = self.config.episode_duration_minutes
self.done = True
self.last_event_type = "hard_cutoff"
break
self.current_time = next_event_time
while (
self._event_queue
and abs(self._event_queue[0].scheduled_time - next_event_time) < 1e-9
):
event = heapq.heappop(self._event_queue)
reward += self._process_event(event)
if self._catastrophic_failure():
reward += self.config.reward.catastrophic_failure_penalty
self.done = True
self.last_event_type = "catastrophic_failure"
break
if self._has_actionable_assignment():
self.last_event_type = "dispatch_decision"
break
return reward
def _process_event(self, event: Event) -> float:
if event.event_type == EventType.INCIDENT_ARRIVAL:
self._next_incident_time = None
self._create_incident()
self._schedule_next_incident(self.current_time)
return 0.0
if event.event_type == EventType.ARRIVE_SCENE:
incident_id = int(event.payload["incident_id"])
ambulance_id = int(event.payload["ambulance_id"])
incident = self._incidents[incident_id]
ambulance = self._ambulances[ambulance_id]
incident.status = IncidentStatus.ON_SCENE
incident.arrival_at_scene = self.current_time
ambulance.location = incident.location
response_time = self.current_time - incident.reported_at
self._response_times.append(response_time)
self._severity_weighted_response_score += (
self.config.severity_weights[incident.severity.value] * response_time
)
self._incidents_served += 1
if (
incident.severity == SeverityLevel.CRITICAL
and response_time > self.config.critical_response_threshold_minutes
):
self._missed_critical += 1
self._push_event(
self.current_time + incident.service_time,
EventType.SERVICE_COMPLETE,
incident_id=incident_id,
ambulance_id=ambulance_id,
)
return self._arrival_reward(incident, response_time)
if event.event_type == EventType.SERVICE_COMPLETE:
incident_id = int(event.payload["incident_id"])
ambulance_id = int(event.payload["ambulance_id"])
incident = self._incidents[incident_id]
ambulance = self._ambulances[ambulance_id]
incident.status = IncidentStatus.TRANSPORTING
hospital_location = incident.hospital_location or self._nearest_hospital(
incident.location
)
incident.hospital_location = hospital_location
travel_to_hospital = self._travel_time(
incident.location, hospital_location, self.current_time
)
ambulance.location = incident.location
self._push_event(
self.current_time + travel_to_hospital,
EventType.ARRIVE_HOSPITAL,
incident_id=incident_id,
ambulance_id=ambulance_id,
)
return 0.0
if event.event_type == EventType.ARRIVE_HOSPITAL:
incident_id = int(event.payload["incident_id"])
ambulance_id = int(event.payload["ambulance_id"])
incident = self._incidents[incident_id]
ambulance = self._ambulances[ambulance_id]
hospital_location = incident.hospital_location or self._nearest_hospital(
incident.location
)
ambulance.location = hospital_location
incident.status = IncidentStatus.RESOLVED
incident.resolved_at = self.current_time
travel_to_depot = self._travel_time(
hospital_location, ambulance.home_location, self.current_time
)
self._push_event(
self.current_time + travel_to_depot,
EventType.AMBULANCE_FREE,
ambulance_id=ambulance_id,
)
return 0.0
if event.event_type == EventType.AMBULANCE_FREE:
ambulance_id = int(event.payload["ambulance_id"])
ambulance = self._ambulances[ambulance_id]
ambulance.location = ambulance.home_location
ambulance.status = AmbulanceStatus.FREE
ambulance.assigned_incident_id = None
ambulance.completed_jobs += 1
if ambulance.busy_start is not None:
ambulance.busy_time += self.current_time - ambulance.busy_start
ambulance.busy_start = None
ambulance.eta_free_at = self.current_time
return 0.0
return 0.0
def _create_incident(self) -> None:
severity = self._sample_severity()
incident = Incident(
incident_id=self._incident_counter,
location=self._sample_incident_location(),
severity=severity,
reported_at=self.current_time,
service_time=self._sample_service_time(severity),
)
self._incidents[incident.incident_id] = incident
self._incident_counter += 1
def _nearest_hospital(self, location: GridPoint) -> GridPoint:
return min(
self.config.hospital_locations,
key=lambda hospital: math.dist(location, hospital),
)
def _arrival_reward(self, incident: Incident, response_time: float) -> float:
severity_weight = self.config.severity_weights[incident.severity.value]
reward_cfg = self.config.reward
safe_response_time = max(response_time, 1.0)
return (
reward_cfg.successful_rescue_weight * severity_weight * (1.0 / safe_response_time)
- reward_cfg.waiting_time_penalty_weight * severity_weight * response_time
)
def _has_actionable_assignment(self) -> bool:
return bool(self.get_pending_incidents()) and bool(self.get_free_ambulances())
def _should_end_naturally(self) -> bool:
if not self._incidents:
return False
if (
self.current_time
< self.config.episode_duration_minutes - self.config.quiet_window_minutes
):
return False
if self.get_pending_incidents():
return False
if self._count_inflight_incidents() > 0:
return False
if self._next_incident_time is None:
return True
return (self._next_incident_time - self.current_time) > self.config.quiet_window_minutes
def _count_inflight_incidents(self) -> int:
return sum(
1
for incident in self._incidents.values()
if incident.status
in (
IncidentStatus.DISPATCHED,
IncidentStatus.ON_SCENE,
IncidentStatus.TRANSPORTING,
)
)
def _catastrophic_failure(self) -> bool:
overdue_critical = 0
for incident in self._incidents.values():
if incident.severity != SeverityLevel.CRITICAL:
continue
if incident.status in (IncidentStatus.ON_SCENE, IncidentStatus.TRANSPORTING, IncidentStatus.RESOLVED):
continue
if (self.current_time - incident.reported_at) > self.config.critical_response_threshold_minutes:
overdue_critical += 1
return (
overdue_critical > self.config.catastrophic_overdue_critical_limit
)
def _current_busy_time(self, ambulance: Ambulance) -> float:
busy_time = ambulance.busy_time
if ambulance.status == AmbulanceStatus.BUSY and ambulance.busy_start is not None:
busy_time += self.current_time - ambulance.busy_start
return busy_time
def _severity_code(self, severity_name: str) -> int:
return {"low": 1, "moderate": 2, "critical": 3}.get(severity_name, 0)
def _percentile(self, values: list[float], percentile: float) -> float:
if not values:
return 0.0
ordered = sorted(values)
if len(ordered) == 1:
return ordered[0]
index = (percentile / 100.0) * (len(ordered) - 1)
lower = math.floor(index)
upper = math.ceil(index)
if lower == upper:
return ordered[lower]
blend = index - lower
return ordered[lower] * (1.0 - blend) + ordered[upper] * blend