ashiqabdulkhader's picture
Upload folder using huggingface_hub
508f4b1 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Baseline dispatch heuristics and evaluation helpers for ERAS."""
from __future__ import annotations
import random
from dataclasses import dataclass
try:
from ..models import HOLD_ACTION_OFFSET
from .config import SimulationConfig
from .entities import Incident
from .simulator import ERASSimulator
except ImportError:
from models import HOLD_ACTION_OFFSET
from server.config import SimulationConfig
from server.entities import Incident
from server.simulator import ERASSimulator
@dataclass()
class EvaluationSummary:
    """Aggregate metrics for one policy's evaluation run.

    Every field defaults to ``0.0``, so ``EvaluationSummary()`` is an
    all-zero placeholder. ``format_comparison_table`` uses it that way when
    no RL agent result is supplied.
    """

    # Copied from the simulator's ``info.avg_response_time`` per episode.
    # Units are whatever the simulator reports (TODO confirm).
    avg_response_time: float = 0.0
    # Copied from ``info.p95_response_time`` per episode.
    p95_response_time: float = 0.0
    # Copied from ``info.severity_weighted_response_score`` per episode.
    severity_weighted_response_score: float = 0.0
    # Copied from ``info.coverage_rate`` per episode.
    coverage_rate: float = 0.0
    # Mean of ``info.ambulance_utilization``; 0.0 when that sequence is empty.
    utilization: float = 0.0
    # ``info.missed_critical`` stored as a float so episode averages stay fractional.
    missed_critical_cases: float = 0.0
def random_policy(simulator: ERASSimulator, rng: random.Random | None = None) -> int:
    """Return a uniformly random action index that the action mask allows.

    Args:
        simulator: Simulator whose ``build_action_mask()`` yields one truthy or
            falsy entry per action index.
        rng: Optional random source. When omitted (or falsy), a fresh,
            unseeded ``random.Random`` is used.

    Returns:
        An index ``i`` for which the mask entry is truthy.

    Raises:
        RuntimeError: If the mask allows no action at all.
    """
    picker = rng if rng else random.Random()
    legal_actions = [
        index for index, allowed in enumerate(simulator.build_action_mask()) if allowed
    ]
    if not legal_actions:
        raise RuntimeError("No valid actions available for random policy")
    return picker.choice(legal_actions)
def nearest_ambulance_policy(simulator: ERASSimulator) -> int:
    """Dispatch the (free ambulance, visible incident) pair with the shortest travel time.

    Every free ambulance is paired with every observable incident, and the
    pairs are ranked lexicographically by:

    1. travel time from the ambulance to the incident at the current time,
    2. higher ``_severity_code`` first (the code is negated for ranking),
    3. lower incident slot index,
    4. lower ambulance id.

    Returns:
        The encoded assign action for the winning pair.

    Raises:
        RuntimeError: If there are no visible incidents or no free ambulances.
    """
    incidents = simulator.get_observable_incidents()
    ambulances = simulator.get_free_ambulances()
    if not (incidents and ambulances):
        raise RuntimeError("No dispatch action available for nearest heuristic")

    # Tuples compare element by element, so min() applies the ranking above.
    # On an exact tie, min() keeps the first candidate generated.
    ranked_pairs = (
        (
            simulator._travel_time(  # noqa: SLF001
                ambulance.location, incident.location, simulator.current_time
            ),
            -simulator._severity_code(incident.severity.value),  # noqa: SLF001
            slot,
            ambulance.ambulance_id,
        )
        for slot, incident in enumerate(incidents)
        for ambulance in ambulances
    )
    _, _, chosen_slot, chosen_ambulance_id = min(ranked_pairs)
    return simulator.encode_assign_action(chosen_ambulance_id, chosen_slot)
def severity_first_policy(simulator: ERASSimulator) -> int:
    """Send the nearest free ambulance to the most severe visible incident.

    The target incident is the observable incident with the highest
    ``simulator._severity_code(incident.severity.value)``. If several share
    that code, the lowest slot index wins. The "higher code means more
    severe" convention matches the negated severity tie-break in
    ``nearest_ambulance_policy``.
    NOTE(review): confirm this against the simulator's severity coding.

    Among free ambulances, the one with the smallest travel time to that
    incident at the current time is chosen. Ties go to the first ambulance in
    ``get_free_ambulances()`` order.

    Returns:
        The encoded assign action for (nearest ambulance, chosen incident slot).

    Raises:
        RuntimeError: If there are no visible incidents or no free ambulances.
    """
    visible_incidents = simulator.get_observable_incidents()
    free_ambulances = simulator.get_free_ambulances()
    if not visible_incidents or not free_ambulances:
        raise RuntimeError("No dispatch action available for severity-first heuristic")

    # Select by severity explicitly rather than trusting that slot 0 is the most
    # severe incident; nothing here guarantees the observation ordering.
    # max() returns the first maximal element, so if slot 0 is already the
    # most severe this picks slot 0 exactly as before.
    incident_slot, incident = max(
        enumerate(visible_incidents),
        key=lambda slot_and_incident: simulator._severity_code(  # noqa: SLF001
            slot_and_incident[1].severity.value
        ),
    )
    nearest_ambulance = min(
        free_ambulances,
        key=lambda ambulance: simulator._travel_time(  # noqa: SLF001
            ambulance.location, incident.location, simulator.current_time
        ),
    )
    return simulator.encode_assign_action(nearest_ambulance.ambulance_id, incident_slot)
def evaluate_policy(
    policy_name: str,
    policy_fn,
    seeds: list[int],
    config: SimulationConfig | None = None,
) -> EvaluationSummary:
    """Run ``policy_fn`` for one full episode per seed and average the metrics.

    Each episode uses a brand-new ``ERASSimulator(config=config)`` reset with
    the given seed. The policy is queried for an action until the simulator
    reports ``done``.

    Args:
        policy_name: Label for the policy. Exactly ``"random"`` makes
            ``policy_fn`` receive a second argument: one ``random.Random(0)``
            shared across all episodes, which keeps random runs reproducible.
        policy_fn: Callable mapping the simulator (plus the RNG for
            ``"random"``) to an action index.
        seeds: Episode seeds, one episode each.
        config: Optional simulation configuration forwarded to every simulator.

    Returns:
        The field-wise mean of the per-episode summaries (all zeros if
        ``seeds`` is empty).
    """
    shared_rng = random.Random(0)
    wants_rng = policy_name == "random"
    episode_summaries: list[EvaluationSummary] = []

    for episode_seed in seeds:
        sim = ERASSimulator(config=config)
        sim.reset(seed=episode_seed)
        while not sim.done:
            if wants_rng:
                chosen = policy_fn(sim, shared_rng)
            else:
                chosen = policy_fn(sim)
            sim.step(chosen)

        final_info = sim.state.info
        per_ambulance = final_info.ambulance_utilization
        # Average utilization over ambulances; guard against an empty fleet.
        mean_utilization = (
            sum(per_ambulance) / len(per_ambulance) if per_ambulance else 0.0
        )
        episode_summaries.append(
            EvaluationSummary(
                avg_response_time=final_info.avg_response_time,
                p95_response_time=final_info.p95_response_time,
                severity_weighted_response_score=final_info.severity_weighted_response_score,
                coverage_rate=final_info.coverage_rate,
                utilization=mean_utilization,
                missed_critical_cases=float(final_info.missed_critical),
            )
        )

    return _mean_summary(episode_summaries)
def evaluate_baselines(
    num_episodes: int = 100, config: SimulationConfig | None = None
) -> dict[str, EvaluationSummary]:
    """Evaluate the three built-in heuristics on the same seeds.

    Seeds ``0 .. num_episodes - 1`` are shared by every policy, so the results
    are directly comparable.

    Returns:
        A dict keyed ``"random"``, ``"nearest"`` and ``"severity_first"``
        (in that order), mapping each key to its averaged summary.
    """
    episode_seeds = list(range(num_episodes))
    baselines = (
        ("random", random_policy),
        ("nearest", nearest_ambulance_policy),
        ("severity_first", severity_first_policy),
    )
    return {
        label: evaluate_policy(label, heuristic, episode_seeds, config=config)
        for label, heuristic in baselines
    }
def format_comparison_table(
    baseline_results: dict[str, EvaluationSummary],
    rl_agent: EvaluationSummary | None = None,
) -> str:
    """Render baseline and RL metrics as a Markdown table.

    Args:
        baseline_results: Must contain the keys ``"random"``, ``"nearest"``
            and ``"severity_first"``. Any other keys are ignored.
        rl_agent: Metrics for the RL column. When omitted, an all-zero
            ``EvaluationSummary`` is shown instead.

    Returns:
        A newline-joined Markdown table with one row per metric. Every value
        is formatted with three decimals.

    Raises:
        KeyError: If a required baseline key is missing.
    """
    rl_summary = rl_agent or EvaluationSummary()
    # Column order matches the header below.
    column_summaries = (
        baseline_results["random"],
        baseline_results["nearest"],
        baseline_results["severity_first"],
        rl_summary,
    )
    metric_attributes = (
        ("Avg Response Time", "avg_response_time"),
        ("95th Percentile", "p95_response_time"),
        ("Severity-Weighted Score", "severity_weighted_response_score"),
        ("Coverage Rate", "coverage_rate"),
        ("Utilization", "utilization"),
        ("Missed Critical Cases", "missed_critical_cases"),
    )
    # Gather every value before formatting anything.
    table_rows = [
        (label, [getattr(summary, attribute) for summary in column_summaries])
        for label, attribute in metric_attributes
    ]

    rendered = [
        "| Metric | Random | Nearest | Severity-First | RL Agent |",
        "|---|---:|---:|---:|---:|",
    ]
    for label, values in table_rows:
        cells = " | ".join(f"{value:.3f}" for value in values)
        rendered.append(f"| {label} | {cells} |")
    return "\n".join(rendered)
def _mean_summary(summaries: list[EvaluationSummary]) -> EvaluationSummary:
    """Return the field-by-field arithmetic mean of ``summaries``.

    An empty list yields an all-zero ``EvaluationSummary`` rather than
    dividing by zero.
    """
    if not summaries:
        return EvaluationSummary()
    episode_count = float(len(summaries))

    def field_mean(attribute: str) -> float:
        # Plain left-to-right sum, divided by the episode count.
        return sum(getattr(entry, attribute) for entry in summaries) / episode_count

    return EvaluationSummary(
        avg_response_time=field_mean("avg_response_time"),
        p95_response_time=field_mean("p95_response_time"),
        severity_weighted_response_score=field_mean("severity_weighted_response_score"),
        coverage_rate=field_mean("coverage_rate"),
        utilization=field_mean("utilization"),
        missed_critical_cases=field_mean("missed_critical_cases"),
    )