nik-55's picture
Upload folder using huggingface_hub
4afc4db verified
"""
Deterministic terminal reward computation for the MedChain Env environment.
Two reward streams exist:
- Per-step shaping rewards (in medchain_env_environment.py)
- Terminal score on the final end_shift() call — this module
All formulas are deterministic — no LLM judge.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Dict, List, Set
if TYPE_CHECKING:
from .simulation import SimState
from .tasks import TaskConfig
def compute_reward(state: "SimState", task_config: "TaskConfig") -> float:
    """
    Route end-of-shift scoring to the scorer registered for ``task_config.name``.

    Task names without a scorer score 0.0.
    """
    routes = (
        ("orientation_ward", compute_reward_task0),
        ("single_ward_stable", compute_reward_task1),
        ("multi_ward_seasonal", compute_reward_task2),
        ("hospital_network_crisis", compute_reward_task3),
    )
    # Linear scan with ``==`` (not a dict lookup) so matching is by equality only.
    for task_name, scorer in routes:
        if task_config.name == task_name:
            return scorer(state, task_config)
    return 0.0
def compute_reward_task0(state: "SimState", task_config: "TaskConfig") -> float:
    """
    Terminal score for the introductory ``orientation_ward`` task.

    score = min(1.0, 0.70 × service_level + 0.30 × ordered_at_least_once)

    ``service_level`` is total fulfilled / total demand over the episode, with
    the denominator floored at 1. ``ordered_at_least_once`` is 1.0 when any
    order is in the pipeline or any money has been spent, otherwise 0.0.
    ``task_config`` is not used. Returns 0.0 when no days were recorded.
    """
    demand_history = state.daily_demand
    if not demand_history:
        return 0.0

    demanded = sum(demand_history)
    served = sum(state.daily_fulfilled)
    service_level = served / max(demanded, 1)

    # Either an outstanding pipeline order or any spend counts as "ordered".
    has_ordered = bool(state.pipeline_orders) or state.total_spend > 0
    ordered = 1.0 if has_ordered else 0.0
    return min(1.0, 0.70 * service_level + 0.30 * ordered)
def compute_reward_task1(state: "SimState", task_config: "TaskConfig") -> float:
    """
    Terminal score for the ``single_ward_stable`` task.

    score = 0.50 × service_level + 0.50 × cost_efficiency

    service_level: episode-wide fulfilled / demand (denominator floored at 1).
    cost_efficiency: benchmark spend / actual spend, capped at 1.0. The
        benchmark is fulfilled units × base-demand-weighted mean unit cost
        × 1.15. Zero or negative spend scores 0.0 on this term.

    Returns 0.0 when no days were recorded.
    """
    if not state.daily_demand:
        return 0.0

    demanded = sum(state.daily_demand)
    served = sum(state.daily_fulfilled)
    service_level = served / max(demanded, 1)

    weighted_cost = sum(p.unit_cost * p.base_demand for p in task_config.products)
    total_weight = sum(p.base_demand for p in task_config.products)
    mean_unit_cost = weighted_cost / max(total_weight, 1)
    benchmark = served * mean_unit_cost * 1.15

    spent = state.total_spend
    # No spend earns no cost credit, so a do-nothing policy is not rewarded here.
    cost_efficiency = 0.0 if spent <= 0 else min(1.0, benchmark / spent)
    return 0.50 * service_level + 0.50 * cost_efficiency
def compute_reward_task2(state: "SimState", task_config: "TaskConfig") -> float:
    """
    Terminal score for the ``multi_ward_seasonal`` task.

    score = 0.40 × service_level
          + 0.35 × cost_efficiency
          + 0.15 × capacity_score
          + 0.10 × transfer_efficiency

    service_level: episode-wide fulfilled / demand (denominator floored at 1).
    cost_efficiency: benchmark spend / max(actual spend, 0.01), capped at 1.0.
        The benchmark is fulfilled units × base-demand-weighted mean unit cost
        × 1.2.
    capacity_score: 1 - (capacity-violation days / days), floored at 0.
    transfer_efficiency: 1.0 up to an average of 10 transfers per day, falling
        linearly to 0.0 at 20 per day.

    Returns 0.0 when no days were recorded.
    """
    if not state.daily_demand:
        return 0.0

    demanded = sum(state.daily_demand)
    served = sum(state.daily_fulfilled)
    service_level = served / max(demanded, 1)

    weighted_cost = sum(p.unit_cost * p.base_demand for p in task_config.products)
    total_weight = sum(p.base_demand for p in task_config.products)
    benchmark = served * (weighted_cost / max(total_weight, 1)) * 1.2
    # NOTE(review): the 0.01 floor means zero spend with any fulfilled demand
    # scores 1.0 here, whereas compute_reward_task1 scores zero spend as 0.0 —
    # confirm the difference is intended.
    cost_efficiency = min(1.0, benchmark / max(state.total_spend, 0.01))

    day_floor = max(len(state.daily_demand), 1)
    capacity_score = max(0.0, 1.0 - state.capacity_violation_days / day_floor)
    excess_transfers = max(0.0, state.transfer_count / day_floor - 10)
    transfer_efficiency = max(0.0, 1.0 - excess_transfers / 10.0)

    return (
        0.40 * service_level
        + 0.35 * cost_efficiency
        + 0.15 * capacity_score
        + 0.10 * transfer_efficiency
    )
def compute_reward_task3(state: "SimState", task_config: "TaskConfig") -> float:
    """
    Terminal score for the ``hospital_network_crisis`` task.

    score = 0.35 × service_level
          + 0.25 × cost_efficiency
          + 0.20 × (1 - critical_stockout_rate)
          + 0.15 × (1 - waste_fraction)
          + 0.05 × crisis_response_score
          - justification_penalty (0.05 per incoherent justification, capped at 0.15)

    service_level: episode-wide fulfilled / demand (denominator floored at 1).
    cost_efficiency: benchmark spend / max(actual spend, 0.01), capped at 1.0,
        where the benchmark is fulfilled units × base-demand-weighted mean
        unit cost × 1.2.
    critical_stockout_rate: 1 - fraction of critical demand fulfilled; 0.0 when
        there was no critical demand (nothing critical was requested, so
        nothing critical was missed).
    waste_fraction: wasted value / max(actual spend, 0.01), capped at 1.0.

    The result is clamped to [0.0, 1.0]. Returns 0.0 when no days were recorded.
    """
    if not state.daily_demand:
        return 0.0

    total_demand = sum(state.daily_demand)
    total_fulfilled = sum(state.daily_fulfilled)
    service_level = total_fulfilled / max(total_demand, 1)

    avg_unit_cost = (
        sum(p.unit_cost * p.base_demand for p in task_config.products)
        / max(sum(p.base_demand for p in task_config.products), 1)
    )
    benchmark_spend = total_fulfilled * avg_unit_cost * 1.2
    # NOTE(review): the 0.01 floor means zero spend with any fulfilled demand
    # scores 1.0 here, whereas compute_reward_task1 scores zero spend as 0.0 —
    # confirm the difference is intended.
    cost_efficiency = min(1.0, benchmark_spend / max(state.total_spend, 0.01))

    critical_stockout_rate = 1.0 - _critical_service_level(state)
    waste_fraction = min(1.0, state.total_wasted_value / max(state.total_spend, 0.01))
    crisis_response_score = _compute_crisis_response_score(state, task_config)

    incoherent_count = sum(1 for r in state.justification_log if not r.is_coherent)
    justification_penalty = min(0.15, incoherent_count * 0.05)

    score = (
        0.35 * service_level
        + 0.25 * cost_efficiency
        + 0.20 * (1.0 - critical_stockout_rate)
        + 0.15 * (1.0 - waste_fraction)
        + 0.05 * crisis_response_score
        - justification_penalty
    )
    return max(0.0, min(1.0, score))


def _critical_service_level(state: "SimState") -> float:
    """
    Return the fraction of critical demand fulfilled over the whole episode.

    Returns 1.0 when no critical demand was recorded. The previous
    ``max(demand, 1)`` guard turned "no critical requests" into a 0% service
    level (i.e. a 100% stockout rate) and mis-scaled fractional totals below 1.
    """
    total_crit_dem = sum(state.daily_critical_demand)
    total_crit_ful = sum(state.daily_critical_fulfilled)
    if total_crit_dem <= 0:
        return 1.0
    return total_crit_ful / total_crit_dem


def _compute_crisis_response_score(
    state: "SimState",
    task_config: "TaskConfig",
) -> float:
    """
    Score the response to the MCI and IV-saline recall events, in [0.0, 1.0].

    - ``mci_activation`` (weight 0.6): credited by the critical-demand service
      level (1.0 when there was no critical demand).
    - ``iv_saline_recall`` (weight 0.4): full credit if handled by the trigger
      day, half credit if handled within 2 days after it, none otherwise or if
      never handled.

    The score is normalised by the weights of the events present in
    ``task_config.events``; if neither event is configured, returns 1.0.
    """
    score = 0.0
    max_score = 0.0

    mci_event = next((e for e in task_config.events if e.event_id == "mci_activation"), None)
    if mci_event:
        max_score += 0.6
        # NOTE(review): this is the episode-wide critical service level, not
        # one restricted to days after mci_event.trigger_day — confirm intended.
        score += 0.6 * _critical_service_level(state)

    recall_event = next((e for e in task_config.events if e.event_id == "iv_saline_recall"), None)
    if recall_event:
        max_score += 0.4
        if state.recall_handled_by_day is not None:
            days_delayed = state.recall_handled_by_day - recall_event.trigger_day
            if days_delayed <= 0:
                score += 0.4
            elif days_delayed <= 2:
                score += 0.2

    if max_score == 0:
        return 1.0
    return score / max_score
def grade_justification(reason: str, active_event_types: Set[str]) -> bool:
    """
    Grade a free-text order justification against the currently active events.

    Matching is a case-insensitive substring search. The reason is coherent
    when it contains a keyword belonging to any active event type, or any of
    the generic urgency keywords. Event types with no keyword list contribute
    nothing. With no active events, only the generic keywords are checked.

    Returns True if coherent (no penalty), False if incoherent.
    """
    event_keywords: Dict[str, List[str]] = {
        "mci": ["mci", "mass casualty", "trauma", "incident", "accident",
                "emergency", "casualties", "blood", "critical patients"],
        "supplier_disruption": ["disruption", "delay", "lead time", "supplier",
                                "shortage", "force majeure", "extended"],
        "product_recall": ["recall", "quarantine", "contamination", "lot",
                           "health authority", "batch", "defective", "compromised"],
        "budget_tighten": ["budget", "fiscal", "quarter", "constraint",
                           "ceiling", "limit", "finance"],
        "cold_chain_breach": ["cold chain", "temperature", "breach",
                              "refriger", "spoilage", "compromised"],
        "demand_surge": ["demand", "surge", "increased", "elevated",
                         "high usage", "outbreak", "flu", "influenza"],
    }
    generic_keywords = [
        "urgent", "critical", "shortage", "low stock",
        "stockout", "emergency", "insufficient",
    ]
    text = reason.lower()
    # NOTE(review): plain substring matching lets short keywords hit inside
    # longer words, e.g. "flu" in "fluids" or "lot" in "slot" — confirm acceptable.
    mentions_active_event = any(
        keyword in text
        for event_type in (active_event_types or ())
        for keyword in event_keywords.get(event_type, ())
    )
    return mentions_active_event or any(keyword in text for keyword in generic_keywords)