| |
| |
| |
| |
| |
|
|
| """Event-driven simulator for the Emergency Response Allocation environment.""" |
|
|
| from __future__ import annotations |
|
|
| import heapq |
| import math |
| import random |
| from uuid import uuid4 |
|
|
| try: |
| from ..models import ( |
| ACTION_DIM, |
| HOLD_ACTION_OFFSET, |
| MAX_OBSERVABLE_INCIDENTS, |
| NUM_AMBULANCES, |
| AmbulanceSnapshot, |
| EmergencyResponseAllocationObservation, |
| EmergencyResponseAllocationState, |
| ERASInfo, |
| IncidentSnapshot, |
| ) |
| from .config import GridPoint, SimulationConfig |
| from .entities import ( |
| Ambulance, |
| AmbulanceStatus, |
| Event, |
| EventType, |
| EVENT_PRIORITIES, |
| Incident, |
| IncidentStatus, |
| SeverityLevel, |
| ZoneType, |
| ) |
| except ImportError: |
| from models import ( |
| ACTION_DIM, |
| HOLD_ACTION_OFFSET, |
| MAX_OBSERVABLE_INCIDENTS, |
| NUM_AMBULANCES, |
| AmbulanceSnapshot, |
| EmergencyResponseAllocationObservation, |
| EmergencyResponseAllocationState, |
| ERASInfo, |
| IncidentSnapshot, |
| ) |
| from server.config import GridPoint, SimulationConfig |
| from server.entities import ( |
| Ambulance, |
| AmbulanceStatus, |
| Event, |
| EventType, |
| EVENT_PRIORITIES, |
| Incident, |
| IncidentStatus, |
| SeverityLevel, |
| ZoneType, |
| ) |
|
|
|
|
| class ERASSimulator: |
| """Core simulator that powers the OpenEnv-facing environment wrapper.""" |
|
|
| def __init__( |
| self, config: SimulationConfig | None = None, auto_reset: bool = True |
| ) -> None: |
| self.config = config or SimulationConfig() |
| self._rng = random.Random() |
| self._event_sequence = 0 |
| self._all_dispatchable_locations = self._build_dispatchable_locations() |
| self._next_incident_time: float | None = None |
| self._reset_internal_state() |
| if auto_reset: |
| self.reset() |
|
|
| def reset( |
| self, seed: int | None = None, episode_id: str | None = None |
| ) -> EmergencyResponseAllocationObservation: |
| if seed is not None: |
| self._rng = random.Random(seed) |
| self._reset_internal_state() |
| self.episode_id = episode_id or str(uuid4()) |
| self.last_event_type = "reset" |
| self._initialize_ambulances() |
| self._schedule_next_incident(self.current_time) |
| self._advance_until_actionable_event() |
| return self._build_observation(reward=None) |
|
|
| def step(self, action_index: int) -> EmergencyResponseAllocationObservation: |
| if self.done: |
| return self._build_observation(reward=0.0) |
|
|
| mask = self.build_action_mask() |
| if action_index < 0 or action_index >= ACTION_DIM: |
| raise ValueError(f"Action index out of range: {action_index}") |
| if not mask[action_index]: |
| raise ValueError(f"Invalid action for current state: {action_index}") |
|
|
| self.step_count += 1 |
| reward = 0.0 |
|
|
| if action_index < HOLD_ACTION_OFFSET: |
| ambulance_id, incident_slot = divmod( |
| action_index, self.config.max_observable_incidents |
| ) |
| incident = self.get_observable_incidents()[incident_slot] |
| reward += self._dispatch_incident(ambulance_id, incident.incident_id) |
| self.last_event_type = "assignment" |
| if self._has_actionable_assignment(): |
| return self._build_observation(reward=reward) |
| else: |
| self.last_event_type = "hold" |
| reward += self._idle_penalty() |
|
|
| reward += self._advance_until_actionable_event() |
| return self._build_observation(reward=reward) |
|
|
| @property |
| def state(self) -> EmergencyResponseAllocationState: |
| info = self._build_info() |
| return EmergencyResponseAllocationState( |
| episode_id=self.episode_id, |
| step_count=self.step_count, |
| current_sim_time=self.current_time, |
| time_of_day=self.time_of_day, |
| ambulances=self._build_ambulance_snapshots(), |
| incidents=self._build_incident_snapshots(), |
| valid_action_mask=self.build_action_mask(), |
| observation_vector=self.build_observation_vector(), |
| info=info, |
| event_queue_size=len(self._event_queue), |
| pending_incident_count=len(self.get_pending_incidents()), |
| inflight_incident_count=self._count_inflight_incidents(), |
| episode_done=self.done, |
| last_event_type=self.last_event_type, |
| ) |
|
|
| @property |
| def time_of_day(self) -> float: |
| return (self.current_time / 60.0) % 24.0 |
|
|
| def get_observable_incidents(self) -> list[Incident]: |
| pending = self.get_pending_incidents() |
| pending.sort( |
| key=lambda incident: ( |
| -self.config.severity_weights[incident.severity.value], |
| -(self.current_time - incident.reported_at), |
| incident.incident_id, |
| ) |
| ) |
| return pending[: self.config.max_observable_incidents] |
|
|
| def get_pending_incidents(self) -> list[Incident]: |
| return [ |
| incident |
| for incident in self._incidents.values() |
| if incident.status == IncidentStatus.PENDING |
| ] |
|
|
| def get_free_ambulances(self) -> list[Ambulance]: |
| return [ |
| ambulance |
| for ambulance in self._ambulances |
| if ambulance.status == AmbulanceStatus.FREE |
| ] |
|
|
| def build_action_mask(self) -> list[bool]: |
| mask = [False] * ACTION_DIM |
| visible_incidents = self.get_observable_incidents() |
| if not visible_incidents: |
| return mask |
|
|
| for ambulance in self._ambulances: |
| if ambulance.status != AmbulanceStatus.FREE: |
| continue |
| for slot_index, _incident in enumerate(visible_incidents): |
| action_index = ( |
| ambulance.ambulance_id * self.config.max_observable_incidents |
| + slot_index |
| ) |
| mask[action_index] = True |
| mask[HOLD_ACTION_OFFSET + ambulance.ambulance_id] = True |
| return mask |
|
|
| def build_observation_vector(self) -> list[float]: |
| vector: list[float] = [] |
| visible_incidents = self.get_observable_incidents() |
|
|
| for ambulance in self._ambulances: |
| vector.extend( |
| [ |
| float(ambulance.location[0]), |
| float(ambulance.location[1]), |
| 1.0 if ambulance.status == AmbulanceStatus.FREE else 0.0, |
| max(0.0, ambulance.eta_free_at - self.current_time), |
| ] |
| ) |
|
|
| for slot_index in range(self.config.max_observable_incidents): |
| if slot_index < len(visible_incidents): |
| incident = visible_incidents[slot_index] |
| vector.extend( |
| [ |
| float(incident.location[0]), |
| float(incident.location[1]), |
| float(self._severity_code(incident.severity.value)), |
| max(0.0, self.current_time - incident.reported_at), |
| ] |
| ) |
| else: |
| vector.extend([0.0, 0.0, 0.0, 0.0]) |
|
|
| travel_times = self._build_travel_time_matrix(visible_incidents) |
| for row in travel_times: |
| vector.extend(row) |
|
|
| vector.append(self.time_of_day) |
| return vector |
|
|
| def encode_assign_action(self, ambulance_id: int, incident_slot: int) -> int: |
| return ambulance_id * self.config.max_observable_incidents + incident_slot |
|
|
| def encode_hold_action(self, ambulance_id: int) -> int: |
| return HOLD_ACTION_OFFSET + ambulance_id |
|
|
| def _reset_internal_state(self) -> None: |
| self.current_time = 0.0 |
| self.done = False |
| self.step_count = 0 |
| self.episode_id = str(uuid4()) |
| self.last_event_type = "reset" |
| self._event_queue: list[Event] = [] |
| self._ambulances: list[Ambulance] = [] |
| self._incidents: dict[int, Incident] = {} |
| self._incident_counter = 0 |
| self._response_times: list[float] = [] |
| self._severity_weighted_response_score = 0.0 |
| self._incidents_served = 0 |
| self._missed_critical = 0 |
| self._next_incident_time = None |
|
|
| def _initialize_ambulances(self) -> None: |
| self._ambulances = [] |
| for ambulance_id, depot_id in enumerate(self.config.ambulance_depot_ids): |
| home_location = self.config.depot_locations[depot_id] |
| self._ambulances.append( |
| Ambulance( |
| ambulance_id=ambulance_id, |
| depot_id=depot_id, |
| home_location=home_location, |
| location=home_location, |
| status=AmbulanceStatus.FREE, |
| eta_free_at=0.0, |
| ) |
| ) |
|
|
| def _build_dispatchable_locations(self) -> list[GridPoint]: |
| locations: list[GridPoint] = [] |
| for x in range(self.config.grid_size): |
| for y in range(self.config.grid_size): |
| zone = self._zone_type((x, y)) |
| if zone not in (ZoneType.HOSPITAL, ZoneType.DEPOT): |
| locations.append((x, y)) |
| return locations |
|
|
| def _zone_type(self, location: GridPoint) -> ZoneType: |
| if location in self.config.depot_locations: |
| return ZoneType.DEPOT |
| if location in self.config.hospital_locations: |
| return ZoneType.HOSPITAL |
|
|
| min_x, min_y, max_x, max_y = self.config.commercial_bounds |
| x, y = location |
| if min_x <= x <= max_x and min_y <= y <= max_y: |
| return ZoneType.COMMERCIAL |
| return ZoneType.RESIDENTIAL |
|
|
| def _build_ambulance_snapshots(self) -> list[AmbulanceSnapshot]: |
| snapshots: list[AmbulanceSnapshot] = [] |
| for ambulance in self._ambulances: |
| snapshots.append( |
| AmbulanceSnapshot( |
| ambulance_id=ambulance.ambulance_id, |
| x=ambulance.location[0], |
| y=ambulance.location[1], |
| status=ambulance.status.value, |
| eta_free=max(0.0, ambulance.eta_free_at - self.current_time), |
| depot_id=ambulance.depot_id, |
| ) |
| ) |
| return snapshots |
|
|
| def _build_incident_snapshots(self) -> list[IncidentSnapshot]: |
| snapshots: list[IncidentSnapshot] = [] |
| visible_incidents = self.get_observable_incidents() |
| for slot_index in range(self.config.max_observable_incidents): |
| if slot_index < len(visible_incidents): |
| incident = visible_incidents[slot_index] |
| snapshots.append( |
| IncidentSnapshot( |
| incident_id=incident.incident_id, |
| x=incident.location[0], |
| y=incident.location[1], |
| severity=incident.severity.value, |
| severity_code=self._severity_code(incident.severity.value), |
| time_since_reported=max( |
| 0.0, self.current_time - incident.reported_at |
| ), |
| ) |
| ) |
| else: |
| snapshots.append(IncidentSnapshot()) |
| return snapshots |
|
|
| def _build_travel_time_matrix( |
| self, visible_incidents: list[Incident] | None = None |
| ) -> list[list[float]]: |
| incidents = visible_incidents if visible_incidents is not None else self.get_observable_incidents() |
| matrix: list[list[float]] = [] |
| for ambulance in self._ambulances: |
| row = [0.0] * self.config.max_observable_incidents |
| for slot_index, incident in enumerate(incidents): |
| if ambulance.status == AmbulanceStatus.FREE: |
| row[slot_index] = self._travel_time( |
| ambulance.location, incident.location, self.current_time |
| ) |
| matrix.append(row) |
| return matrix |
|
|
| def _build_info(self) -> ERASInfo: |
| avg_response_time = ( |
| sum(self._response_times) / len(self._response_times) |
| if self._response_times |
| else 0.0 |
| ) |
| p95_response_time = self._percentile(self._response_times, 95.0) |
| incidents_total = len(self._incidents) |
| coverage_rate = ( |
| self._incidents_served / incidents_total if incidents_total else 0.0 |
| ) |
| elapsed = max(self.current_time, 1e-6) |
| utilization = [ |
| min(1.0, self._current_busy_time(ambulance) / elapsed) |
| for ambulance in self._ambulances |
| ] |
| return ERASInfo( |
| avg_response_time=avg_response_time, |
| incidents_served=self._incidents_served, |
| incidents_total=incidents_total, |
| missed_critical=self._missed_critical, |
| ambulance_utilization=utilization, |
| current_sim_time=self.current_time, |
| p95_response_time=p95_response_time, |
| severity_weighted_response_score=self._severity_weighted_response_score, |
| coverage_rate=coverage_rate, |
| ) |
|
|
| def _build_observation( |
| self, reward: float | None |
| ) -> EmergencyResponseAllocationObservation: |
| visible_incidents = self.get_observable_incidents() |
| visible_ids = [-1] * self.config.max_observable_incidents |
| for slot_index, incident in enumerate(visible_incidents): |
| visible_ids[slot_index] = incident.incident_id |
|
|
| return EmergencyResponseAllocationObservation( |
| done=self.done, |
| reward=reward, |
| observation_vector=self.build_observation_vector(), |
| valid_action_mask=self.build_action_mask(), |
| ambulances=self._build_ambulance_snapshots(), |
| incidents=self._build_incident_snapshots(), |
| travel_times=self._build_travel_time_matrix(visible_incidents), |
| visible_incident_ids=visible_ids, |
| time_of_day=self.time_of_day, |
| event_type=self.last_event_type, |
| info=self._build_info(), |
| ) |
|
|
| def _push_event(self, scheduled_time: float, event_type: EventType, **payload: object) -> None: |
| if scheduled_time > self.config.episode_duration_minutes: |
| return |
| event = Event( |
| scheduled_time=scheduled_time, |
| priority=EVENT_PRIORITIES[event_type], |
| sequence=self._event_sequence, |
| event_type=event_type, |
| payload=dict(payload), |
| ) |
| self._event_sequence += 1 |
| heapq.heappush(self._event_queue, event) |
|
|
| def _schedule_next_incident(self, from_time: float) -> None: |
| next_time = self._sample_next_incident_time(from_time) |
| self._next_incident_time = next_time |
| if next_time is not None: |
| self._push_event(next_time, EventType.INCIDENT_ARRIVAL) |
|
|
| def _sample_next_incident_time(self, from_time: float) -> float | None: |
| time_cursor = from_time |
| boundaries = self.config.arrival_phase_boundaries_minutes |
| rates = self.config.arrival_rates_per_hour |
|
|
| while time_cursor < self.config.episode_duration_minutes: |
| phase_index = self._phase_index(time_cursor) |
| phase_end = boundaries[phase_index + 1] |
| rate_per_hour = rates[phase_index] |
| if rate_per_hour <= 0.0: |
| time_cursor = phase_end |
| continue |
|
|
| delta_minutes = self._rng.expovariate(rate_per_hour) * 60.0 |
| candidate = time_cursor + delta_minutes |
| if candidate < phase_end: |
| return candidate |
| time_cursor = phase_end |
| return None |
|
|
| def _phase_index(self, sim_minutes: float) -> int: |
| boundaries = self.config.arrival_phase_boundaries_minutes |
| for index in range(len(boundaries) - 1): |
| if boundaries[index] <= sim_minutes < boundaries[index + 1]: |
| return index |
| return len(boundaries) - 2 |
|
|
| def _sample_incident_location(self) -> GridPoint: |
| weights: list[float] = [] |
| for location in self._all_dispatchable_locations: |
| zone = self._zone_type(location) |
| weight = 1.0 |
| if zone == ZoneType.COMMERCIAL and self._is_peak_demand_period(): |
| weight = 3.0 |
| elif zone == ZoneType.RESIDENTIAL and self._is_night_demand_period(): |
| weight = 2.0 |
| weights.append(weight) |
| return self._rng.choices(self._all_dispatchable_locations, weights=weights, k=1)[0] |
|
|
| def _sample_severity(self) -> SeverityLevel: |
| threshold = self._rng.random() |
| cumulative = 0.0 |
| for severity_name, probability in self.config.severity_probabilities: |
| cumulative += probability |
| if threshold <= cumulative: |
| return SeverityLevel(severity_name) |
| return SeverityLevel.LOW |
|
|
| def _sample_service_time(self, severity: SeverityLevel) -> float: |
| low, high = self.config.scene_time_ranges_minutes[severity.value] |
| return self._rng.uniform(low, high) |
|
|
| def _travel_time( |
| self, origin: GridPoint, destination: GridPoint, sim_time_minutes: float |
| ) -> float: |
| distance = math.dist(origin, destination) |
| return distance * self._traffic_multiplier(sim_time_minutes) |
|
|
| def _traffic_multiplier(self, sim_time_minutes: float) -> float: |
| hour = (sim_time_minutes / 60.0) % 24.0 |
| for start, end in self.config.traffic_peak_windows_hours: |
| if start <= hour < end: |
| return self.config.traffic_peak_multiplier |
| for start, end in self.config.traffic_night_windows_hours: |
| if start <= hour < end: |
| return self.config.traffic_night_multiplier |
| return self.config.traffic_offpeak_multiplier |
|
|
| def _is_peak_demand_period(self) -> bool: |
| hour = self.time_of_day |
| return (6.0 <= hour < 10.0) or (16.0 <= hour < 20.0) |
|
|
| def _is_night_demand_period(self) -> bool: |
| hour = self.time_of_day |
| return hour < 6.0 or hour >= 20.0 |
|
|
| def _dispatch_incident(self, ambulance_id: int, incident_id: int) -> float: |
| ambulance = self._ambulances[ambulance_id] |
| incident = self._incidents[incident_id] |
| if ambulance.status != AmbulanceStatus.FREE: |
| raise ValueError(f"Ambulance {ambulance_id} is not free") |
| if incident.status != IncidentStatus.PENDING: |
| raise ValueError(f"Incident {incident_id} is not pending") |
|
|
| travel_to_scene = self._travel_time( |
| ambulance.location, incident.location, self.current_time |
| ) |
| hospital = self._nearest_hospital(incident.location) |
| travel_to_hospital = self._travel_time( |
| incident.location, hospital, self.current_time + travel_to_scene |
| ) |
| travel_to_depot = self._travel_time( |
| hospital, |
| ambulance.home_location, |
| self.current_time + travel_to_scene + incident.service_time + travel_to_hospital, |
| ) |
|
|
| ambulance.status = AmbulanceStatus.BUSY |
| ambulance.assigned_incident_id = incident_id |
| ambulance.busy_start = self.current_time |
| ambulance.eta_free_at = ( |
| self.current_time |
| + travel_to_scene |
| + incident.service_time |
| + travel_to_hospital |
| + travel_to_depot |
| ) |
|
|
| incident.status = IncidentStatus.DISPATCHED |
| incident.assigned_ambulance_id = ambulance_id |
| incident.dispatch_at = self.current_time |
| incident.estimated_travel_time = travel_to_scene |
| incident.hospital_location = hospital |
|
|
| self._push_event( |
| self.current_time + travel_to_scene, |
| EventType.ARRIVE_SCENE, |
| ambulance_id=ambulance_id, |
| incident_id=incident_id, |
| ) |
|
|
| severity_weight = self.config.severity_weights[incident.severity.value] |
| reward_cfg = self.config.reward |
| return -reward_cfg.response_time_penalty_weight * severity_weight * travel_to_scene |
|
|
| def _idle_penalty(self) -> float: |
| pending_incidents = self.get_pending_incidents() |
| if not pending_incidents: |
| return 0.0 |
| idle_free_ambulances = len(self.get_free_ambulances()) |
| return -self.config.reward.idle_penalty_weight * idle_free_ambulances |
|
|
| def _advance_until_actionable_event(self) -> float: |
| reward = 0.0 |
|
|
| while not self.done: |
| if self._should_end_naturally(): |
| self.done = True |
| self.last_event_type = "natural_end" |
| break |
|
|
| if not self._event_queue: |
| self.done = True |
| self.last_event_type = "queue_exhausted" |
| break |
|
|
| next_event_time = self._event_queue[0].scheduled_time |
| if next_event_time >= self.config.episode_duration_minutes: |
| self.current_time = self.config.episode_duration_minutes |
| self.done = True |
| self.last_event_type = "hard_cutoff" |
| break |
|
|
| self.current_time = next_event_time |
| while ( |
| self._event_queue |
| and abs(self._event_queue[0].scheduled_time - next_event_time) < 1e-9 |
| ): |
| event = heapq.heappop(self._event_queue) |
| reward += self._process_event(event) |
|
|
| if self._catastrophic_failure(): |
| reward += self.config.reward.catastrophic_failure_penalty |
| self.done = True |
| self.last_event_type = "catastrophic_failure" |
| break |
|
|
| if self._has_actionable_assignment(): |
| self.last_event_type = "dispatch_decision" |
| break |
|
|
| return reward |
|
|
| def _process_event(self, event: Event) -> float: |
| if event.event_type == EventType.INCIDENT_ARRIVAL: |
| self._next_incident_time = None |
| self._create_incident() |
| self._schedule_next_incident(self.current_time) |
| return 0.0 |
|
|
| if event.event_type == EventType.ARRIVE_SCENE: |
| incident_id = int(event.payload["incident_id"]) |
| ambulance_id = int(event.payload["ambulance_id"]) |
| incident = self._incidents[incident_id] |
| ambulance = self._ambulances[ambulance_id] |
| incident.status = IncidentStatus.ON_SCENE |
| incident.arrival_at_scene = self.current_time |
| ambulance.location = incident.location |
| response_time = self.current_time - incident.reported_at |
| self._response_times.append(response_time) |
| self._severity_weighted_response_score += ( |
| self.config.severity_weights[incident.severity.value] * response_time |
| ) |
| self._incidents_served += 1 |
| if ( |
| incident.severity == SeverityLevel.CRITICAL |
| and response_time > self.config.critical_response_threshold_minutes |
| ): |
| self._missed_critical += 1 |
|
|
| self._push_event( |
| self.current_time + incident.service_time, |
| EventType.SERVICE_COMPLETE, |
| incident_id=incident_id, |
| ambulance_id=ambulance_id, |
| ) |
| return self._arrival_reward(incident, response_time) |
|
|
| if event.event_type == EventType.SERVICE_COMPLETE: |
| incident_id = int(event.payload["incident_id"]) |
| ambulance_id = int(event.payload["ambulance_id"]) |
| incident = self._incidents[incident_id] |
| ambulance = self._ambulances[ambulance_id] |
| incident.status = IncidentStatus.TRANSPORTING |
| hospital_location = incident.hospital_location or self._nearest_hospital( |
| incident.location |
| ) |
| incident.hospital_location = hospital_location |
| travel_to_hospital = self._travel_time( |
| incident.location, hospital_location, self.current_time |
| ) |
| ambulance.location = incident.location |
| self._push_event( |
| self.current_time + travel_to_hospital, |
| EventType.ARRIVE_HOSPITAL, |
| incident_id=incident_id, |
| ambulance_id=ambulance_id, |
| ) |
| return 0.0 |
|
|
| if event.event_type == EventType.ARRIVE_HOSPITAL: |
| incident_id = int(event.payload["incident_id"]) |
| ambulance_id = int(event.payload["ambulance_id"]) |
| incident = self._incidents[incident_id] |
| ambulance = self._ambulances[ambulance_id] |
| hospital_location = incident.hospital_location or self._nearest_hospital( |
| incident.location |
| ) |
| ambulance.location = hospital_location |
| incident.status = IncidentStatus.RESOLVED |
| incident.resolved_at = self.current_time |
| travel_to_depot = self._travel_time( |
| hospital_location, ambulance.home_location, self.current_time |
| ) |
| self._push_event( |
| self.current_time + travel_to_depot, |
| EventType.AMBULANCE_FREE, |
| ambulance_id=ambulance_id, |
| ) |
| return 0.0 |
|
|
| if event.event_type == EventType.AMBULANCE_FREE: |
| ambulance_id = int(event.payload["ambulance_id"]) |
| ambulance = self._ambulances[ambulance_id] |
| ambulance.location = ambulance.home_location |
| ambulance.status = AmbulanceStatus.FREE |
| ambulance.assigned_incident_id = None |
| ambulance.completed_jobs += 1 |
| if ambulance.busy_start is not None: |
| ambulance.busy_time += self.current_time - ambulance.busy_start |
| ambulance.busy_start = None |
| ambulance.eta_free_at = self.current_time |
| return 0.0 |
|
|
| return 0.0 |
|
|
| def _create_incident(self) -> None: |
| severity = self._sample_severity() |
| incident = Incident( |
| incident_id=self._incident_counter, |
| location=self._sample_incident_location(), |
| severity=severity, |
| reported_at=self.current_time, |
| service_time=self._sample_service_time(severity), |
| ) |
| self._incidents[incident.incident_id] = incident |
| self._incident_counter += 1 |
|
|
| def _nearest_hospital(self, location: GridPoint) -> GridPoint: |
| return min( |
| self.config.hospital_locations, |
| key=lambda hospital: math.dist(location, hospital), |
| ) |
|
|
| def _arrival_reward(self, incident: Incident, response_time: float) -> float: |
| severity_weight = self.config.severity_weights[incident.severity.value] |
| reward_cfg = self.config.reward |
| safe_response_time = max(response_time, 1.0) |
| return ( |
| reward_cfg.successful_rescue_weight * severity_weight * (1.0 / safe_response_time) |
| - reward_cfg.waiting_time_penalty_weight * severity_weight * response_time |
| ) |
|
|
| def _has_actionable_assignment(self) -> bool: |
| return bool(self.get_pending_incidents()) and bool(self.get_free_ambulances()) |
|
|
| def _should_end_naturally(self) -> bool: |
| if not self._incidents: |
| return False |
| if ( |
| self.current_time |
| < self.config.episode_duration_minutes - self.config.quiet_window_minutes |
| ): |
| return False |
| if self.get_pending_incidents(): |
| return False |
| if self._count_inflight_incidents() > 0: |
| return False |
| if self._next_incident_time is None: |
| return True |
| return (self._next_incident_time - self.current_time) > self.config.quiet_window_minutes |
|
|
| def _count_inflight_incidents(self) -> int: |
| return sum( |
| 1 |
| for incident in self._incidents.values() |
| if incident.status |
| in ( |
| IncidentStatus.DISPATCHED, |
| IncidentStatus.ON_SCENE, |
| IncidentStatus.TRANSPORTING, |
| ) |
| ) |
|
|
| def _catastrophic_failure(self) -> bool: |
| overdue_critical = 0 |
| for incident in self._incidents.values(): |
| if incident.severity != SeverityLevel.CRITICAL: |
| continue |
| if incident.status in (IncidentStatus.ON_SCENE, IncidentStatus.TRANSPORTING, IncidentStatus.RESOLVED): |
| continue |
| if (self.current_time - incident.reported_at) > self.config.critical_response_threshold_minutes: |
| overdue_critical += 1 |
| return ( |
| overdue_critical > self.config.catastrophic_overdue_critical_limit |
| ) |
|
|
| def _current_busy_time(self, ambulance: Ambulance) -> float: |
| busy_time = ambulance.busy_time |
| if ambulance.status == AmbulanceStatus.BUSY and ambulance.busy_start is not None: |
| busy_time += self.current_time - ambulance.busy_start |
| return busy_time |
|
|
| def _severity_code(self, severity_name: str) -> int: |
| return {"low": 1, "moderate": 2, "critical": 3}.get(severity_name, 0) |
|
|
| def _percentile(self, values: list[float], percentile: float) -> float: |
| if not values: |
| return 0.0 |
| ordered = sorted(values) |
| if len(ordered) == 1: |
| return ordered[0] |
| index = (percentile / 100.0) * (len(ordered) - 1) |
| lower = math.floor(index) |
| upper = math.ceil(index) |
| if lower == upper: |
| return ordered[lower] |
| blend = index - lower |
| return ordered[lower] * (1.0 - blend) + ordered[upper] * blend |
|
|