Spaces:
Sleeping
Sleeping
File size: 8,113 Bytes
4afc4db | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 | """
Deterministic terminal reward computation for the MedChain Env environment.
Two reward streams exist:
- Per-step shaping rewards (in medchain_env_environment.py)
- Terminal score on the final end_shift() call — this module
All formulas are deterministic — no LLM judge.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Dict, List, Set
if TYPE_CHECKING:
from .simulation import SimState
from .tasks import TaskConfig
def compute_reward(state: "SimState", task_config: "TaskConfig") -> float:
    """
    Pick the terminal scorer that matches ``task_config.name`` and run it.

    Recognised task names are "orientation_ward", "single_ward_stable",
    "multi_ward_seasonal" and "hospital_network_crisis". Any other name
    scores 0.0 without calling a scorer.
    """
    task_name = task_config.name

    # Select first, call once at the end.
    if task_name == "orientation_ward":
        scorer = compute_reward_task0
    elif task_name == "single_ward_stable":
        scorer = compute_reward_task1
    elif task_name == "multi_ward_seasonal":
        scorer = compute_reward_task2
    elif task_name == "hospital_network_crisis":
        scorer = compute_reward_task3
    else:
        return 0.0

    return scorer(state, task_config)
def compute_reward_task0(state: "SimState", task_config: "TaskConfig") -> float:
    """
    Terminal score for the intro task.

    score = min(1.0, 0.70 × service_level + 0.30 × ordered_at_least_once)

    service_level is total fulfilled over total demand (divisor floored at 1).
    The ordering term is 1.0 when the state has pipeline orders or a positive
    total spend, otherwise 0.0. Returns 0.0 when no daily demand is recorded.
    ``task_config`` is not read by this scorer.
    """
    if not state.daily_demand:
        return 0.0

    demanded = sum(state.daily_demand)
    fulfilled = sum(state.daily_fulfilled)
    service_level = fulfilled / max(demanded, 1)

    # Either signal counts as "an order was placed at least once".
    placed_order = state.pipeline_orders or state.total_spend > 0
    ordered = 1.0 if placed_order else 0.0

    weighted = 0.70 * service_level + 0.30 * ordered
    return min(1.0, weighted)
def compute_reward_task1(state: "SimState", task_config: "TaskConfig") -> float:
    """
    score = 0.50 × service_level + 0.50 × cost_efficiency

    service_level is total fulfilled over total demand (divisor floored at 1).
    cost_efficiency compares a benchmark spend (fulfilled units × the
    base-demand-weighted mean unit cost × 1.15) against the actual spend,
    capped at 1.0; it is 0.0 when the actual spend is zero or negative.
    Returns 0.0 when no daily demand is recorded.
    """
    if not state.daily_demand:
        return 0.0

    demanded = sum(state.daily_demand)
    fulfilled = sum(state.daily_fulfilled)
    service_level = fulfilled / max(demanded, 1)

    # Mean unit cost weighted by each product's base demand.
    catalogue = task_config.products
    weighted_cost = sum(item.unit_cost * item.base_demand for item in catalogue)
    base_volume = sum(item.base_demand for item in catalogue)
    mean_unit_cost = weighted_cost / max(base_volume, 1)

    benchmark_spend = fulfilled * mean_unit_cost * 1.15
    spend = state.total_spend
    cost_efficiency = 0.0 if spend <= 0 else min(1.0, benchmark_spend / spend)

    return 0.50 * service_level + 0.50 * cost_efficiency
def compute_reward_task2(state: "SimState", task_config: "TaskConfig") -> float:
    """
    score = 0.40 × service_level
          + 0.35 × cost_efficiency
          + 0.15 × capacity_score
          + 0.10 × transfer_efficiency

    - service_level: total fulfilled over total demand (divisor floored at 1).
    - cost_efficiency: benchmark spend (fulfilled × weighted mean unit cost
      × 1.2) over actual spend (floored at 0.01), capped at 1.0.
    - capacity_score: 1 minus the share of days with a capacity violation,
      never below 0.
    - transfer_efficiency: 1.0 up to an average of 10 transfers per day, then
      falling linearly to 0.0 at 20 per day.

    Returns 0.0 when no daily demand is recorded.
    """
    if not state.daily_demand:
        return 0.0

    demanded = sum(state.daily_demand)
    fulfilled = sum(state.daily_fulfilled)
    service_level = fulfilled / max(demanded, 1)

    # Mean unit cost weighted by each product's base demand.
    catalogue = task_config.products
    weighted_cost = sum(item.unit_cost * item.base_demand for item in catalogue)
    base_volume = sum(item.base_demand for item in catalogue)
    mean_unit_cost = weighted_cost / max(base_volume, 1)

    benchmark_spend = fulfilled * mean_unit_cost * 1.2
    spend_floor = max(state.total_spend, 0.01)
    cost_efficiency = min(1.0, benchmark_spend / spend_floor)

    day_count = max(len(state.daily_demand), 1)
    violation_share = state.capacity_violation_days / day_count
    capacity_score = max(0.0, 1.0 - violation_share)

    # Only transfers beyond 10 per day (on average) are penalised.
    transfers_per_day = state.transfer_count / day_count
    excess_transfers = max(0.0, transfers_per_day - 10)
    transfer_efficiency = max(0.0, 1.0 - excess_transfers / 10.0)

    components = (
        0.40 * service_level,
        0.35 * cost_efficiency,
        0.15 * capacity_score,
        0.10 * transfer_efficiency,
    )
    return components[0] + components[1] + components[2] + components[3]
def compute_reward_task3(state: "SimState", task_config: "TaskConfig") -> float:
    """
    score = 0.35 × service_level
          + 0.25 × cost_efficiency
          + 0.20 × (1 - critical_stockout_rate)
          + 0.15 × (1 - waste_fraction)
          + 0.05 × crisis_response_score
          - justification_penalty (0.05 per incoherent entry, capped at 0.15)

    The result is clamped to [0.0, 1.0]. Returns 0.0 when no daily demand is
    recorded.

    NOTE(review): when the summed critical demand is 0 the divisor is floored
    at 1, so with nothing fulfilled the critical stockout rate evaluates to
    1.0 and the 0.20 term is lost — confirm this is intended.
    """
    if not state.daily_demand:
        return 0.0

    demanded = sum(state.daily_demand)
    fulfilled = sum(state.daily_fulfilled)
    service_level = fulfilled / max(demanded, 1)

    # Mean unit cost weighted by each product's base demand.
    catalogue = task_config.products
    weighted_cost = sum(item.unit_cost * item.base_demand for item in catalogue)
    base_volume = sum(item.base_demand for item in catalogue)
    mean_unit_cost = weighted_cost / max(base_volume, 1)

    benchmark_spend = fulfilled * mean_unit_cost * 1.2
    cost_efficiency = min(1.0, benchmark_spend / max(state.total_spend, 0.01))

    critical_demanded = sum(state.daily_critical_demand)
    critical_fulfilled = sum(state.daily_critical_fulfilled)
    critical_service = critical_fulfilled / max(critical_demanded, 1)
    critical_stockout_rate = 1.0 - critical_service

    # Wasted value as a share of spend, capped at 100%.
    waste_fraction = min(1.0, state.total_wasted_value / max(state.total_spend, 0.01))

    crisis_response_score = _compute_crisis_response_score(state, task_config)

    incoherent = [entry for entry in state.justification_log if not entry.is_coherent]
    justification_penalty = min(0.15, len(incoherent) * 0.05)

    raw_score = (
        0.35 * service_level
        + 0.25 * cost_efficiency
        + 0.20 * (1.0 - critical_stockout_rate)
        + 0.15 * (1.0 - waste_fraction)
        + 0.05 * crisis_response_score
        - justification_penalty
    )
    bounded_above = min(1.0, raw_score)
    return max(0.0, bounded_above)
def _compute_crisis_response_score(
state: "SimState",
task_config: "TaskConfig",
) -> float:
"""
Measures crisis response for MCI and recall events.
Returns 0.0 to 1.0.
"""
score = 0.0
max_score = 0.0
mci_event = next((e for e in task_config.events if e.event_id == "mci_activation"), None)
if mci_event:
max_score += 0.6
total_crit_dem = sum(state.daily_critical_demand)
total_crit_ful = sum(state.daily_critical_fulfilled)
mci_service = total_crit_ful / max(total_crit_dem, 1)
score += 0.6 * mci_service
recall_event = next((e for e in task_config.events if e.event_id == "iv_saline_recall"), None)
if recall_event:
max_score += 0.4
if state.recall_handled_by_day is not None:
days_delayed = state.recall_handled_by_day - recall_event.trigger_day
if days_delayed <= 0:
score += 0.4
elif days_delayed <= 2:
score += 0.2
if max_score == 0:
return 1.0
return score / max_score
def grade_justification(reason: str, active_event_types: Set[str]) -> bool:
    """
    Deterministic keyword-based justification grading.

    The reason is lower-cased and searched for substrings. It is coherent if
    it mentions a keyword belonging to any currently active event type, or,
    failing that, any generic urgency keyword. Unknown event types contribute
    no keywords.

    Args:
        reason: Free-text justification. A missing (``None``), non-string or
            empty value is graded incoherent instead of raising.
        active_event_types: Event-type keys currently active; may be empty
            or ``None``, in which case only generic keywords are considered.

    Returns:
        True if coherent (no penalty), False if incoherent.
    """
    CRISIS_KEYWORDS: Dict[str, List[str]] = {
        "mci": ["mci", "mass casualty", "trauma", "incident", "accident",
                "emergency", "casualties", "blood", "critical patients"],
        "supplier_disruption": ["disruption", "delay", "lead time", "supplier",
                                "shortage", "force majeure", "extended"],
        "product_recall": ["recall", "quarantine", "contamination", "lot",
                           "health authority", "batch", "defective", "compromised"],
        "budget_tighten": ["budget", "fiscal", "quarter", "constraint",
                           "ceiling", "limit", "finance"],
        "cold_chain_breach": ["cold chain", "temperature", "breach",
                              "refriger", "spoilage", "compromised"],
        "demand_surge": ["demand", "surge", "increased", "elevated",
                         "high usage", "outbreak", "flu", "influenza"],
    }
    GENERIC_KEYWORDS = [
        "urgent", "critical", "shortage", "low stock",
        "stockout", "emergency", "insufficient",
    ]

    # A justification that was never supplied cannot be coherent; grading it
    # False keeps the scorer from crashing on a missing reason.
    if not isinstance(reason, str) or not reason:
        return False

    reason_lower = reason.lower()

    # Event-specific vocabulary takes precedence while events are active.
    for event_type in active_event_types or ():
        if any(kw in reason_lower for kw in CRISIS_KEYWORDS.get(event_type, ())):
            return True

    # With no active events, or no event-specific match, fall back to the
    # generic urgency vocabulary.
    return any(kw in reason_lower for kw in GENERIC_KEYWORDS)
|