| |
| |
| |
| |
| |
|
|
| """Baseline dispatch heuristics and evaluation helpers for ERAS.""" |
|
|
| from __future__ import annotations |
|
|
| import random |
| from dataclasses import dataclass |
|
|
| try: |
| from ..models import HOLD_ACTION_OFFSET |
| from .config import SimulationConfig |
| from .entities import Incident |
| from .simulator import ERASSimulator |
| except ImportError: |
| from models import HOLD_ACTION_OFFSET |
| from server.config import SimulationConfig |
| from server.entities import Incident |
| from server.simulator import ERASSimulator |
|
|
|
|
| @dataclass |
| class EvaluationSummary: |
| avg_response_time: float = 0.0 |
| p95_response_time: float = 0.0 |
| severity_weighted_response_score: float = 0.0 |
| coverage_rate: float = 0.0 |
| utilization: float = 0.0 |
| missed_critical_cases: float = 0.0 |
|
|
|
|
| def random_policy(simulator: ERASSimulator, rng: random.Random | None = None) -> int: |
| chooser = rng or random.Random() |
| valid_actions = [ |
| action_index |
| for action_index, is_valid in enumerate(simulator.build_action_mask()) |
| if is_valid |
| ] |
| if not valid_actions: |
| raise RuntimeError("No valid actions available for random policy") |
| return chooser.choice(valid_actions) |
|
|
|
|
| def nearest_ambulance_policy(simulator: ERASSimulator) -> int: |
| visible_incidents = simulator.get_observable_incidents() |
| free_ambulances = simulator.get_free_ambulances() |
| if not visible_incidents or not free_ambulances: |
| raise RuntimeError("No dispatch action available for nearest heuristic") |
|
|
| best_choice: tuple[float, int, int, int] | None = None |
| for slot_index, incident in enumerate(visible_incidents): |
| for ambulance in free_ambulances: |
| travel_time = simulator._travel_time( |
| ambulance.location, incident.location, simulator.current_time |
| ) |
| severity_rank = -simulator._severity_code(incident.severity.value) |
| candidate = ( |
| travel_time, |
| severity_rank, |
| slot_index, |
| ambulance.ambulance_id, |
| ) |
| if best_choice is None or candidate < best_choice: |
| best_choice = candidate |
|
|
| assert best_choice is not None |
| _travel_time, _severity_rank, incident_slot, ambulance_id = best_choice |
| return simulator.encode_assign_action(ambulance_id, incident_slot) |
|
|
|
|
| def severity_first_policy(simulator: ERASSimulator) -> int: |
| visible_incidents = simulator.get_observable_incidents() |
| free_ambulances = simulator.get_free_ambulances() |
| if not visible_incidents or not free_ambulances: |
| raise RuntimeError("No dispatch action available for severity-first heuristic") |
|
|
| incident = visible_incidents[0] |
| incident_slot = 0 |
| nearest_ambulance = min( |
| free_ambulances, |
| key=lambda ambulance: simulator._travel_time( |
| ambulance.location, incident.location, simulator.current_time |
| ), |
| ) |
| return simulator.encode_assign_action(nearest_ambulance.ambulance_id, incident_slot) |
|
|
|
|
| def evaluate_policy( |
| policy_name: str, |
| policy_fn, |
| seeds: list[int], |
| config: SimulationConfig | None = None, |
| ) -> EvaluationSummary: |
| summaries: list[EvaluationSummary] = [] |
| randomizer = random.Random(0) |
|
|
| for seed in seeds: |
| simulator = ERASSimulator(config=config) |
| simulator.reset(seed=seed) |
|
|
| while not simulator.done: |
| if policy_name == "random": |
| action_index = policy_fn(simulator, randomizer) |
| else: |
| action_index = policy_fn(simulator) |
| simulator.step(action_index) |
|
|
| info = simulator.state.info |
| utilization = ( |
| sum(info.ambulance_utilization) / len(info.ambulance_utilization) |
| if info.ambulance_utilization |
| else 0.0 |
| ) |
| summaries.append( |
| EvaluationSummary( |
| avg_response_time=info.avg_response_time, |
| p95_response_time=info.p95_response_time, |
| severity_weighted_response_score=info.severity_weighted_response_score, |
| coverage_rate=info.coverage_rate, |
| utilization=utilization, |
| missed_critical_cases=float(info.missed_critical), |
| ) |
| ) |
|
|
| return _mean_summary(summaries) |
|
|
|
|
| def evaluate_baselines( |
| num_episodes: int = 100, config: SimulationConfig | None = None |
| ) -> dict[str, EvaluationSummary]: |
| seeds = list(range(num_episodes)) |
| return { |
| "random": evaluate_policy("random", random_policy, seeds, config=config), |
| "nearest": evaluate_policy( |
| "nearest", nearest_ambulance_policy, seeds, config=config |
| ), |
| "severity_first": evaluate_policy( |
| "severity_first", severity_first_policy, seeds, config=config |
| ), |
| } |
|
|
|
|
| def format_comparison_table( |
| baseline_results: dict[str, EvaluationSummary], |
| rl_agent: EvaluationSummary | None = None, |
| ) -> str: |
| rl_summary = rl_agent or EvaluationSummary() |
| rows = [ |
| ( |
| "Avg Response Time", |
| baseline_results["random"].avg_response_time, |
| baseline_results["nearest"].avg_response_time, |
| baseline_results["severity_first"].avg_response_time, |
| rl_summary.avg_response_time, |
| ), |
| ( |
| "95th Percentile", |
| baseline_results["random"].p95_response_time, |
| baseline_results["nearest"].p95_response_time, |
| baseline_results["severity_first"].p95_response_time, |
| rl_summary.p95_response_time, |
| ), |
| ( |
| "Severity-Weighted Score", |
| baseline_results["random"].severity_weighted_response_score, |
| baseline_results["nearest"].severity_weighted_response_score, |
| baseline_results["severity_first"].severity_weighted_response_score, |
| rl_summary.severity_weighted_response_score, |
| ), |
| ( |
| "Coverage Rate", |
| baseline_results["random"].coverage_rate, |
| baseline_results["nearest"].coverage_rate, |
| baseline_results["severity_first"].coverage_rate, |
| rl_summary.coverage_rate, |
| ), |
| ( |
| "Utilization", |
| baseline_results["random"].utilization, |
| baseline_results["nearest"].utilization, |
| baseline_results["severity_first"].utilization, |
| rl_summary.utilization, |
| ), |
| ( |
| "Missed Critical Cases", |
| baseline_results["random"].missed_critical_cases, |
| baseline_results["nearest"].missed_critical_cases, |
| baseline_results["severity_first"].missed_critical_cases, |
| rl_summary.missed_critical_cases, |
| ), |
| ] |
|
|
| lines = [ |
| "| Metric | Random | Nearest | Severity-First | RL Agent |", |
| "|---|---:|---:|---:|---:|", |
| ] |
| for metric, random_value, nearest_value, severity_value, rl_value in rows: |
| lines.append( |
| f"| {metric} | {random_value:.3f} | {nearest_value:.3f} | " |
| f"{severity_value:.3f} | {rl_value:.3f} |" |
| ) |
| return "\n".join(lines) |
|
|
|
|
| def _mean_summary(summaries: list[EvaluationSummary]) -> EvaluationSummary: |
| if not summaries: |
| return EvaluationSummary() |
| count = float(len(summaries)) |
| return EvaluationSummary( |
| avg_response_time=sum(item.avg_response_time for item in summaries) / count, |
| p95_response_time=sum(item.p95_response_time for item in summaries) / count, |
| severity_weighted_response_score=( |
| sum(item.severity_weighted_response_score for item in summaries) / count |
| ), |
| coverage_rate=sum(item.coverage_rate for item in summaries) / count, |
| utilization=sum(item.utilization for item in summaries) / count, |
| missed_critical_cases=( |
| sum(item.missed_critical_cases for item in summaries) / count |
| ), |
| ) |
|
|