""" Deterministic terminal reward computation for the MedChain Env environment. Two reward streams exist: - Per-step shaping rewards (in medchain_env_environment.py) - Terminal score on the final end_shift() call — this module All formulas are deterministic — no LLM judge. """ from __future__ import annotations from typing import TYPE_CHECKING, Dict, List, Set if TYPE_CHECKING: from .simulation import SimState from .tasks import TaskConfig def compute_reward(state: "SimState", task_config: "TaskConfig") -> float: """Dispatch to task-specific terminal scorer.""" if task_config.name == "orientation_ward": return compute_reward_task0(state, task_config) elif task_config.name == "single_ward_stable": return compute_reward_task1(state, task_config) elif task_config.name == "multi_ward_seasonal": return compute_reward_task2(state, task_config) elif task_config.name == "hospital_network_crisis": return compute_reward_task3(state, task_config) return 0.0 def compute_reward_task0(state: "SimState", task_config: "TaskConfig") -> float: """ Intro task: score = 0.70 × service_level + 0.30 × ordered_at_least_once Rewards reading the situation and placing at least one replenishment order. """ if not state.daily_demand: return 0.0 total_demand = sum(state.daily_demand) total_fulfilled = sum(state.daily_fulfilled) service_level = total_fulfilled / max(total_demand, 1) ordered = 1.0 if state.pipeline_orders or state.total_spend > 0 else 0.0 return min(1.0, 0.70 * service_level + 0.30 * ordered) def compute_reward_task1(state: "SimState", task_config: "TaskConfig") -> float: """ score = 0.50 × avg_service_level + 0.50 × cost_efficiency_vs_benchmark """ if not state.daily_demand: return 0.0 total_demand = sum(state.daily_demand) total_fulfilled = sum(state.daily_fulfilled) service_level = total_fulfilled / max(total_demand, 1) avg_unit_cost = ( sum(p.unit_cost * p.base_demand for p in task_config.products) / max(sum(p.base_demand for p in task_config.products), 1) ) benchmark_spend = total_fulfilled * avg_unit_cost * 1.15 actual_spend = state.total_spend if actual_spend <= 0: cost_efficiency = 0.0 else: cost_efficiency = min(1.0, benchmark_spend / actual_spend) return 0.50 * service_level + 0.50 * cost_efficiency def compute_reward_task2(state: "SimState", task_config: "TaskConfig") -> float: """ score = 0.40 × avg_service_level + 0.35 × cost_efficiency + 0.15 × capacity_score + 0.10 × transfer_efficiency """ if not state.daily_demand: return 0.0 total_demand = sum(state.daily_demand) total_fulfilled = sum(state.daily_fulfilled) service_level = total_fulfilled / max(total_demand, 1) avg_unit_cost = ( sum(p.unit_cost * p.base_demand for p in task_config.products) / max(sum(p.base_demand for p in task_config.products), 1) ) benchmark_spend = total_fulfilled * avg_unit_cost * 1.2 cost_efficiency = min(1.0, benchmark_spend / max(state.total_spend, 0.01)) total_days = len(state.daily_demand) capacity_score = max(0.0, 1.0 - state.capacity_violation_days / max(total_days, 1)) avg_transfers_per_day = state.transfer_count / max(total_days, 1) transfer_efficiency = max(0.0, 1.0 - max(0.0, avg_transfers_per_day - 10) / 10.0) return ( 0.40 * service_level + 0.35 * cost_efficiency + 0.15 * capacity_score + 0.10 * transfer_efficiency ) def compute_reward_task3(state: "SimState", task_config: "TaskConfig") -> float: """ score = 0.35 × avg_service_level + 0.25 × cost_efficiency + 0.20 × (1 - critical_stockout_rate) + 0.15 × (1 - waste_fraction) + 0.05 × crisis_response_score - justification_penalty (capped at 0.15) """ if not state.daily_demand: return 0.0 total_demand = sum(state.daily_demand) total_fulfilled = sum(state.daily_fulfilled) service_level = total_fulfilled / max(total_demand, 1) avg_unit_cost = ( sum(p.unit_cost * p.base_demand for p in task_config.products) / max(sum(p.base_demand for p in task_config.products), 1) ) benchmark_spend = total_fulfilled * avg_unit_cost * 1.2 cost_efficiency = min(1.0, benchmark_spend / max(state.total_spend, 0.01)) total_crit_dem = sum(state.daily_critical_demand) total_crit_ful = sum(state.daily_critical_fulfilled) critical_service = total_crit_ful / max(total_crit_dem, 1) critical_stockout_rate = 1.0 - critical_service waste_fraction = min(1.0, state.total_wasted_value / max(state.total_spend, 0.01)) crisis_response_score = _compute_crisis_response_score(state, task_config) incoherent_count = sum(1 for r in state.justification_log if not r.is_coherent) justification_penalty = min(0.15, incoherent_count * 0.05) score = ( 0.35 * service_level + 0.25 * cost_efficiency + 0.20 * (1.0 - critical_stockout_rate) + 0.15 * (1.0 - waste_fraction) + 0.05 * crisis_response_score - justification_penalty ) return max(0.0, min(1.0, score)) def _compute_crisis_response_score( state: "SimState", task_config: "TaskConfig", ) -> float: """ Measures crisis response for MCI and recall events. Returns 0.0 to 1.0. """ score = 0.0 max_score = 0.0 mci_event = next((e for e in task_config.events if e.event_id == "mci_activation"), None) if mci_event: max_score += 0.6 total_crit_dem = sum(state.daily_critical_demand) total_crit_ful = sum(state.daily_critical_fulfilled) mci_service = total_crit_ful / max(total_crit_dem, 1) score += 0.6 * mci_service recall_event = next((e for e in task_config.events if e.event_id == "iv_saline_recall"), None) if recall_event: max_score += 0.4 if state.recall_handled_by_day is not None: days_delayed = state.recall_handled_by_day - recall_event.trigger_day if days_delayed <= 0: score += 0.4 elif days_delayed <= 2: score += 0.2 if max_score == 0: return 1.0 return score / max_score def grade_justification(reason: str, active_event_types: Set[str]) -> bool: """ Deterministic keyword-based justification grading. Returns True if coherent (no penalty), False if incoherent. """ CRISIS_KEYWORDS: Dict[str, List[str]] = { "mci": ["mci", "mass casualty", "trauma", "incident", "accident", "emergency", "casualties", "blood", "critical patients"], "supplier_disruption": ["disruption", "delay", "lead time", "supplier", "shortage", "force majeure", "extended"], "product_recall": ["recall", "quarantine", "contamination", "lot", "health authority", "batch", "defective", "compromised"], "budget_tighten": ["budget", "fiscal", "quarter", "constraint", "ceiling", "limit", "finance"], "cold_chain_breach": ["cold chain", "temperature", "breach", "refriger", "spoilage", "compromised"], "demand_surge": ["demand", "surge", "increased", "elevated", "high usage", "outbreak", "flu", "influenza"], } GENERIC_KEYWORDS = [ "urgent", "critical", "shortage", "low stock", "stockout", "emergency", "insufficient", ] reason_lower = reason.lower() if not active_event_types: return any(kw in reason_lower for kw in GENERIC_KEYWORDS) for event_type in active_event_types: keywords = CRISIS_KEYWORDS.get(event_type, []) if any(kw in reason_lower for kw in keywords): return True return any(kw in reason_lower for kw in GENERIC_KEYWORDS)