Spaces:
Sleeping
Sleeping
"""
Deterministic terminal reward computation for the MedChain Env environment.

Two reward streams exist:
- Per-step shaping rewards (in medchain_env_environment.py)
- Terminal score on the final end_shift() call — this module

All formulas are deterministic — no LLM judge.
"""
| from __future__ import annotations | |
| from typing import TYPE_CHECKING, Dict, List, Set | |
| if TYPE_CHECKING: | |
| from .simulation import SimState | |
| from .tasks import TaskConfig | |
def compute_reward(state: "SimState", task_config: "TaskConfig") -> float:
    """Route terminal scoring to the scorer registered for this task.

    Args:
        state: Final simulation state of the episode; passed through unchanged.
        task_config: Configuration of the task that was played. Only its
            ``name`` is used to pick the scorer.

    Returns:
        The task-specific terminal score, or ``0.0`` when ``task_config.name``
        matches none of the known task names.
    """
    scorers = {
        "orientation_ward": compute_reward_task0,
        "single_ward_stable": compute_reward_task1,
        "multi_ward_seasonal": compute_reward_task2,
        "hospital_network_crisis": compute_reward_task3,
    }
    scorer = scorers.get(task_config.name)
    if scorer is None:
        # Unknown task: no terminal reward.
        return 0.0
    return scorer(state, task_config)
def compute_reward_task0(state: "SimState", task_config: "TaskConfig") -> float:
    """Score the introductory ``orientation_ward`` task.

    score = 0.70 × service_level + 0.30 × ordered_at_least_once, capped at 1.0

    * service_level — total units fulfilled divided by total units demanded
      over the episode (the denominator is floored at 1).
    * ordered_at_least_once — 1.0 if ``state.pipeline_orders`` is non-empty
      or ``state.total_spend`` is positive, otherwise 0.0.

    ``task_config`` is accepted for signature parity with the other scorers
    and is not read.

    Returns 0.0 when no days of demand were recorded.
    """
    demand_history = state.daily_demand
    if not demand_history:
        return 0.0

    demanded = sum(demand_history)
    served = sum(state.daily_fulfilled)
    service_level = served / max(demanded, 1)

    # Reward the agent for placing at least one replenishment order.
    placed_an_order = bool(state.pipeline_orders) or state.total_spend > 0
    order_credit = 1.0 if placed_an_order else 0.0

    return min(1.0, 0.70 * service_level + 0.30 * order_credit)
def compute_reward_task1(state: "SimState", task_config: "TaskConfig") -> float:
    """Score the ``single_ward_stable`` task.

    score = 0.50 × service_level + 0.50 × cost_efficiency

    * service_level — total units fulfilled divided by total units demanded
      (the denominator is floored at 1).
    * cost_efficiency — benchmark spend divided by actual spend, capped at
      1.0. The benchmark prices every fulfilled unit at the
      base-demand-weighted mean unit cost of ``task_config.products``, plus
      a 15 % allowance. Zero or negative spend yields 0.0, so never ordering
      earns no cost credit.

    Returns 0.0 when no days of demand were recorded.
    """
    if not state.daily_demand:
        return 0.0

    units_demanded = sum(state.daily_demand)
    units_served = sum(state.daily_fulfilled)
    service_level = units_served / max(units_demanded, 1)

    weighted_cost = sum(p.unit_cost * p.base_demand for p in task_config.products)
    demand_weight = max(sum(p.base_demand for p in task_config.products), 1)
    mean_unit_cost = weighted_cost / demand_weight

    spent = state.total_spend
    if spent <= 0:
        cost_efficiency = 0.0
    else:
        allowed_spend = units_served * mean_unit_cost * 1.15
        cost_efficiency = min(1.0, allowed_spend / spent)

    return 0.50 * service_level + 0.50 * cost_efficiency
def compute_reward_task2(state: "SimState", task_config: "TaskConfig") -> float:
    """Score the ``multi_ward_seasonal`` task.

    score = 0.40 × service_level
          + 0.35 × cost_efficiency
          + 0.15 × capacity_score
          + 0.10 × transfer_efficiency

    The result is clamped to [0, 1].

    * service_level — total units fulfilled divided by total units demanded
      (the denominator is floored at 1).
    * cost_efficiency — benchmark spend divided by actual spend, capped at
      1.0. The benchmark prices every fulfilled unit at the
      base-demand-weighted mean unit cost of ``task_config.products``, plus
      a 20 % allowance. Zero or negative spend scores 0.0, the same as in
      ``compute_reward_task1``.
    * capacity_score — 1 − capacity_violation_days / days simulated,
      floored at 0.
    * transfer_efficiency — full credit up to an average of 10 transfers per
      day, falling linearly to 0 at 20 per day.

    Returns 0.0 when no days of demand were recorded.
    """
    if not state.daily_demand:
        return 0.0

    total_demand = sum(state.daily_demand)
    total_fulfilled = sum(state.daily_fulfilled)
    service_level = total_fulfilled / max(total_demand, 1)

    avg_unit_cost = (
        sum(p.unit_cost * p.base_demand for p in task_config.products)
        / max(sum(p.base_demand for p in task_config.products), 1)
    )
    benchmark_spend = total_fulfilled * avg_unit_cost * 1.2
    actual_spend = state.total_spend
    if actual_spend <= 0:
        # Without this guard, a small denominator floor would turn zero spend
        # into a huge ratio: full cost-efficiency credit for an agent that
        # never orders and only drains its starting stock.
        cost_efficiency = 0.0
    else:
        cost_efficiency = min(1.0, benchmark_spend / actual_spend)

    total_days = len(state.daily_demand)
    capacity_score = max(0.0, 1.0 - state.capacity_violation_days / max(total_days, 1))

    avg_transfers_per_day = state.transfer_count / max(total_days, 1)
    # Up to 10 transfers/day are free; each extra transfer/day costs 10 %.
    transfer_efficiency = max(0.0, 1.0 - max(0.0, avg_transfers_per_day - 10) / 10.0)

    score = (
        0.40 * service_level
        + 0.35 * cost_efficiency
        + 0.15 * capacity_score
        + 0.10 * transfer_efficiency
    )
    return max(0.0, min(1.0, score))
def compute_reward_task3(state: "SimState", task_config: "TaskConfig") -> float:
    """Score the ``hospital_network_crisis`` task.

    score = 0.35 × service_level
          + 0.25 × cost_efficiency
          + 0.20 × (1 − critical_stockout_rate)
          + 0.15 × (1 − waste_fraction)
          + 0.05 × crisis_response_score
          − justification_penalty

    The result is clamped to [0, 1].

    * service_level — total units fulfilled divided by total units demanded
      (the denominator is floored at 1).
    * cost_efficiency — benchmark spend divided by actual spend, capped at
      1.0. The benchmark prices every fulfilled unit at the
      base-demand-weighted mean unit cost of ``task_config.products``, plus
      a 20 % allowance. Zero or negative spend scores 0.0.
    * critical_stockout_rate — 1 − share of critical demand fulfilled. It is
      0.0 when no critical demand was recorded.
    * waste_fraction — ``state.total_wasted_value`` / spend (the spend is
      floored at 0.01), capped at 1.0.
    * crisis_response_score — see ``_compute_crisis_response_score``.
    * justification_penalty — 0.05 per incoherent entry in
      ``state.justification_log``, capped at 0.15.

    Returns 0.0 when no days of demand were recorded.
    """
    if not state.daily_demand:
        return 0.0

    total_demand = sum(state.daily_demand)
    total_fulfilled = sum(state.daily_fulfilled)
    service_level = total_fulfilled / max(total_demand, 1)

    avg_unit_cost = (
        sum(p.unit_cost * p.base_demand for p in task_config.products)
        / max(sum(p.base_demand for p in task_config.products), 1)
    )
    benchmark_spend = total_fulfilled * avg_unit_cost * 1.2
    actual_spend = state.total_spend
    if actual_spend <= 0:
        # A never-ordering agent must not get full cost-efficiency credit.
        cost_efficiency = 0.0
    else:
        cost_efficiency = min(1.0, benchmark_spend / actual_spend)

    critical_stockout_rate = 1.0 - _critical_service_level(state)

    # The 0.01 floor only prevents division by zero when nothing was spent.
    waste_fraction = min(1.0, state.total_wasted_value / max(actual_spend, 0.01))

    crisis_response_score = _compute_crisis_response_score(state, task_config)

    incoherent_count = sum(1 for r in state.justification_log if not r.is_coherent)
    justification_penalty = min(0.15, incoherent_count * 0.05)

    score = (
        0.35 * service_level
        + 0.25 * cost_efficiency
        + 0.20 * (1.0 - critical_stockout_rate)
        + 0.15 * (1.0 - waste_fraction)
        + 0.05 * crisis_response_score
        - justification_penalty
    )
    return max(0.0, min(1.0, score))


def _compute_crisis_response_score(
    state: "SimState",
    task_config: "TaskConfig",
) -> float:
    """Measure crisis response for MCI and recall events, in [0.0, 1.0].

    Only components whose event appears in ``task_config.events`` count:

    * ``mci_activation`` (weight 0.6): critical-demand service level over the
      whole episode. It is 1.0 when no critical demand was recorded.
    * ``iv_saline_recall`` (weight 0.4):
      - full credit if ``state.recall_handled_by_day`` is on or before the
        event's ``trigger_day``;
      - half credit if it is at most 2 days late;
      - no credit otherwise, or if the recall was never handled.

    The sum is normalised by the total weight of the configured components.
    Returns 1.0 when neither event is configured.
    """
    score = 0.0
    max_score = 0.0

    mci_event = next((e for e in task_config.events if e.event_id == "mci_activation"), None)
    if mci_event:
        max_score += 0.6
        score += 0.6 * _critical_service_level(state)

    recall_event = next((e for e in task_config.events if e.event_id == "iv_saline_recall"), None)
    if recall_event:
        max_score += 0.4
        if state.recall_handled_by_day is not None:
            days_delayed = state.recall_handled_by_day - recall_event.trigger_day
            if days_delayed <= 0:
                score += 0.4
            elif days_delayed <= 2:
                score += 0.2

    if max_score == 0:
        return 1.0
    return score / max_score


def _critical_service_level(state: "SimState") -> float:
    """Return the share of critical demand fulfilled over the episode.

    With no critical demand there was nothing to miss, so this returns 1.0.
    Returning 0.0 would wrongly report a 100 % critical stockout rate.
    """
    total_crit_dem = sum(state.daily_critical_demand)
    if total_crit_dem <= 0:
        return 1.0
    return sum(state.daily_critical_fulfilled) / max(total_crit_dem, 1)
# Keywords that make a justification coherent for a given active event type.
# Matching is a case-insensitive substring test, so stems such as "refriger"
# cover "refrigerator" / "refrigeration".
# NOTE(review): substring matching also yields hits like "flu" in "fluids"
# and "lot" in "slot"; confirm this leniency is intended.
_CRISIS_KEYWORDS: Dict[str, List[str]] = {
    "mci": ["mci", "mass casualty", "trauma", "incident", "accident",
            "emergency", "casualties", "blood", "critical patients"],
    "supplier_disruption": ["disruption", "delay", "lead time", "supplier",
                            "shortage", "force majeure", "extended"],
    "product_recall": ["recall", "quarantine", "contamination", "lot",
                       "health authority", "batch", "defective", "compromised"],
    "budget_tighten": ["budget", "fiscal", "quarter", "constraint",
                       "ceiling", "limit", "finance"],
    "cold_chain_breach": ["cold chain", "temperature", "breach",
                          "refriger", "spoilage", "compromised"],
    "demand_surge": ["demand", "surge", "increased", "elevated",
                     "high usage", "outbreak", "flu", "influenza"],
}

# Fallback keywords accepted whichever events are active (or none).
_GENERIC_KEYWORDS: List[str] = [
    "urgent", "critical", "shortage", "low stock",
    "stockout", "emergency", "insufficient",
]


def grade_justification(reason: str, active_event_types: Set[str]) -> bool:
    """Grade a free-text justification deterministically by keywords.

    A reason is coherent if, case-insensitively, it contains a keyword
    registered for any of ``active_event_types``, or any generic urgency
    keyword. Event types without registered keywords only get the generic
    check.

    Args:
        reason: The agent's justification text. ``None`` is treated as an
            empty string, so it grades as incoherent.
        active_event_types: Event types active when the justification was
            given. May be empty; ``None`` behaves like empty.

    Returns:
        True if coherent (no penalty), False if incoherent.
    """
    reason_lower = (reason or "").lower()
    for event_type in active_event_types or ():
        keywords = _CRISIS_KEYWORDS.get(event_type, ())
        if any(kw in reason_lower for kw in keywords):
            return True
    return any(kw in reason_lower for kw in _GENERIC_KEYWORDS)