"""
Core simulation engine for the MedChain Env environment.

MedchainSimulation manages the full episode lifecycle:
  - Inventory tracked as FEFO lots per (location, product)
  - Purchase order pipeline with stochastic lead times
  - Event-driven inbox messages (crises, recalls, demand surges)
  - Daily demand generation and fulfillment
"""

from __future__ import annotations

import math
import uuid
from dataclasses import dataclass, field
from typing import Dict, Iterator, List, Optional, Set, Tuple

import numpy as np

from .tasks import SimEvent, TaskConfig

# Task-3 recall: this exact lot id is injected at every affected location of this SKU
# (see MedchainSimulation._inject_recall_lot).
_RECALL_LOT_ID = "RECALL-LOT-IV2026-9821"
_RECALL_SKU = "IV-SAL-500"


# ─── Simulation Dataclasses ───────────────────────────────────────────────────


@dataclass
class Lot:
    """A quantity of one product at one location that shares a single expiry day."""

    lot_id: str
    qty: int
    expiry_day: Optional[int]  # None = non-perishable. Expired when current_day >= expiry_day.
    cost_per_unit: float


@dataclass
class PurchaseOrder:
    """A purchase order moving through the pipeline until it is delivered as a Lot."""

    po_id: str
    supplier_id: str
    product_id: str
    destination_id: str
    quantity: int
    priority: str  # "standard" or "expedited"
    day_submitted: int
    eta_day: int
    unit_cost: float
    total_cost: float
    status: str  # "pending_justification", "in_transit", "delivered"
    lot_id: str


@dataclass
class PendingBudgetOverride:
    """An expedited PO held back until a justification is filed for its ticket."""

    ticket_id: str
    po: PurchaseOrder


@dataclass
class InboxMessage:
    """One message in the agent's inbox."""

    msg_id: str
    priority: str
    timestamp_str: str  # "Day {n} {HH:MM}"
    sender: str
    subject: str
    body: str
    read: bool
    flagged: bool
    event_id: str


@dataclass
class JustificationRecord:
    """Audit record of a filed justification and its coherence grade."""

    ticket_id: str
    po_id: str
    reason: str
    is_coherent: bool


@dataclass
class SimState:
    """All mutable state of one episode."""

    # Episode meta
    task: str
    episode_id: str
    seed: int
    rng: np.random.Generator
    # Time
    day: int
    max_days: int
    # Action budget
    actions_remaining: int
    actions_per_shift: int
    # Budget
    budget_used: float
    budget_limit: float
    # Inventory: (location_id, product_id) -> List[Lot] (FEFO-sorted)
    inventory: Dict[Tuple[str, str], List[Lot]]
    # Orders
    pipeline_orders: List[PurchaseOrder]
    po_counter: int
    # Inbox
    inbox: List[InboxMessage]
    msg_counter: int
    # Budget override tickets
    pending_overrides: Dict[str, PendingBudgetOverride]
    # Quarantine (by lot id, across all locations)
    quarantined_lots: Set[str]
    # Demand / fulfillment tracking (one value per completed day)
    daily_demand: List[float]
    daily_fulfilled: List[float]
    daily_critical_demand: List[float]
    daily_critical_fulfilled: List[float]
    # Per-(location, product) daily tracking (for demand_history queries)
    daily_product_demand: Dict[Tuple[str, str], List[int]]
    daily_product_fulfilled: Dict[Tuple[str, str], List[int]]
    # Spend tracking
    total_spend: float
    total_wasted_value: float
    # Transfer tracking (task 2)
    transfer_count: int
    transfer_cost_paid: float
    # Capacity violations (task 2)
    capacity_violation_days: int
    # Active event effects: event_id -> last_day_active (inclusive)
    active_events: Dict[str, int]
    # Per-shift shaping reward helpers
    info_rewards_given_this_shift: Set[str]
    daily_stockout_count: int
    daily_expired_lots: int
    # Task 3 crisis tracking
    justification_log: List[JustificationRecord]
    mci_preemptive_order: bool
    recall_handled_by_day: Optional[int]


# ─── MedchainSimulation ───────────────────────────────────────────────────────


class MedchainSimulation:
    """
    Core simulation engine. Called by MedchainEnvironment's MCP tools.

    All public tool methods return a string (displayed to agent as ERP output).
    end_shift_tool() also stores _last_reward and _done for the environment.
    """

    # Per-unit fee for inter-location transfers; tracked in transfer_cost_paid,
    # not charged against budget_used.
    TRANSFER_FEE = 0.5

    def __init__(self, task_config: TaskConfig):
        self._task = task_config
        self._state: Optional[SimState] = None
        self._last_reward: float = 0.0
        self._done: bool = False

    # ── Called by environment.reset() ──────────────────────────────────────

    def reset(self, seed: int, episode_id: str) -> str:
        """Initialize a new episode. Returns dashboard text."""
        self._done = False
        self._last_reward = 0.0
        rng = np.random.default_rng(seed)
        self._state = SimState(
            task=self._task.name,
            episode_id=episode_id,
            seed=seed,
            rng=rng,
            day=1,
            max_days=self._task.max_days,
            actions_remaining=self._task.actions_per_shift,
            actions_per_shift=self._task.actions_per_shift,
            budget_used=0.0,
            budget_limit=self._task.budget_limit,
            inventory={},
            pipeline_orders=[],
            po_counter=1,
            inbox=[],
            msg_counter=1,
            pending_overrides={},
            quarantined_lots=set(),
            daily_demand=[],
            daily_fulfilled=[],
            daily_critical_demand=[],
            daily_critical_fulfilled=[],
            daily_product_demand={},
            daily_product_fulfilled={},
            total_spend=0.0,
            total_wasted_value=0.0,
            transfer_count=0,
            transfer_cost_paid=0.0,
            capacity_violation_days=0,
            active_events={},
            info_rewards_given_this_shift=set(),
            daily_stockout_count=0,
            daily_expired_lots=0,
            justification_log=[],
            mci_preemptive_order=False,
            recall_handled_by_day=None,
        )
        self._initialize_inventory()
        # end_shift_tool only activates events for the day it advances into (Day 2+),
        # so events triggering on Day 1 must be activated here or they never apply.
        self._update_active_events(1)
        self._inject_day1_inbox()
        from .erp_formatter import format_dashboard

        return format_dashboard(self._state, self._task)

    def _initialize_inventory(self):
        """Seed initial inventory: initial_stock_days × base_demand per location/product."""
        state = self._state
        for product in self._task.products:
            for loc_id in product.locations:
                qty = int(product.base_demand * self._task.initial_stock_days)
                # Initial stock starts with 70% of its shelf life remaining.
                expiry_day = (
                    1 + int(product.shelf_life_days * 0.7)
                    if product.shelf_life_days is not None
                    else None
                )
                state.inventory[(loc_id, product.product_id)] = [
                    Lot(
                        lot_id=f"INIT-{product.product_id}-{loc_id}",
                        qty=qty,
                        expiry_day=expiry_day,
                        cost_per_unit=product.unit_cost,
                    )
                ]

    def _inject_day1_inbox(self):
        """Add Day 1 inbox messages (welcome + any Day 1 events)."""
        state = self._state
        welcome = InboxMessage(
            msg_id=f"MSG-{state.msg_counter:04d}",
            priority="LOW",
            timestamp_str="Day 1 08:00",
            sender="System",
            subject="Shift Handover Notes",
            body=(
                f"Welcome to the {self._task.name} scenario.\n"
                f"You are managing medical supplies for {self._task.max_days} days.\n"
                f"Action budget: {self._task.actions_per_shift} actions per shift.\n"
                f"Budget ceiling: ${self._task.budget_limit:,.0f} outstanding orders.\n\n"
                "Use read_inbox to check messages, query_erp to check stock,\n"
                "submit_po to order supplies, and end_shift to advance the day."
            ),
            read=False,
            flagged=False,
            event_id="system_welcome",
        )
        state.inbox.append(welcome)
        state.msg_counter += 1
        self._inject_events_for_day(1)

    # ── Action Budget Helper ────────────────────────────────────────────────

    def _check_action_budget(self, tool_name: str) -> Optional[str]:
        """Returns error string if budget exhausted, None if OK. Does NOT decrement.

        end_shift is always allowed; every other tool needs an initialized episode
        and at least one remaining action.
        """
        if tool_name == "end_shift":
            return None
        if self._state is None:
            return "ERROR: Environment not initialized. Call reset() first."
        if self._state.actions_remaining <= 0:
            return (
                "ERROR: Action budget exhausted for this shift.\n"
                f"Actions used: {self._state.actions_per_shift}/{self._state.actions_per_shift}\n"
                "Call end_shift() to advance to the next day and restore your action budget."
            )
        return None

    # ── Shared lookups ──────────────────────────────────────────────────────

    def _iter_active_events(self) -> Iterator[SimEvent]:
        """Yield the task SimEvent for each active event id (ids with no matching event are skipped)."""
        for event_id in self._state.active_events:
            event = next((e for e in self._task.events if e.event_id == event_id), None)
            if event is not None:
                yield event

    def _active_disruption(self, supplier_id: str) -> Optional[SimEvent]:
        """Return the active supplier_disruption event for supplier_id, or None.

        If several match, the last one in active_events order wins.
        """
        found: Optional[SimEvent] = None
        for event in self._iter_active_events():
            if (
                event.event_type == "supplier_disruption"
                and event.params.get("supplier_id") == supplier_id
            ):
                found = event
        return found

    def _units_at(self, location_id: str) -> int:
        """Total units of every product at location_id, quarantined lots included."""
        return sum(
            lot.qty
            for (loc_id, _product_id), lots in self._state.inventory.items()
            if loc_id == location_id
            for lot in lots
        )

    # ── MCP Tool Implementations ────────────────────────────────────────────

    def read_inbox(self, filter: str = "unread") -> str:
        """List inbox messages and mark the listed ones as read. Costs one action.

        filter: "unread" (default) or "flagged"; "all" (or any other value)
        lists the whole inbox. The READ/UNREAD tag shows each message's status
        before this call.
        """
        err = self._check_action_budget("read_inbox")
        if err:
            return err
        self._state.actions_remaining -= 1

        messages = list(self._state.inbox)
        if filter == "unread":
            messages = [m for m in messages if not m.read]
        elif filter == "flagged":
            messages = [m for m in messages if m.flagged]
        # "all" → use full inbox

        if not messages:
            return f"INBOX EMPTY\nFilter: {filter} | No messages matching filter."

        lines = []
        for m in messages:
            # Capture status before marking read; otherwise every message shows READ.
            read_status = "READ" if m.read else "UNREAD"
            m.read = True
            lines.append(
                f"\n[MSG {m.msg_id} | {read_status} | PRIORITY: {m.priority} | {m.timestamp_str}]"
            )
            lines.append(f"FROM: {m.sender}")
            lines.append(f"SUBJ: {m.subject}")
            lines.append("")
            lines.append(m.body)
            lines.append("")
        return "\n".join(lines)

    def query_erp(self, table: str, location: str = "all", sku: str = "all") -> str:
        """Render one ERP table filtered by location/SKU. Costs one action, even for an unknown table."""
        err = self._check_action_budget("query_erp")
        if err:
            return err
        self._state.actions_remaining -= 1

        valid_tables = ["inventory", "expiry", "pipeline_orders", "demand_history"]
        if table not in valid_tables:
            return f"ERROR: Unknown table '{table}'. Valid tables: {valid_tables}"

        from .erp_formatter import (
            format_demand_history,
            format_expiry_table,
            format_inventory_table,
            format_pipeline_table,
        )

        if table == "inventory":
            return format_inventory_table(self._state, self._task, location, sku)
        if table == "expiry":
            return format_expiry_table(self._state, self._task, location, sku)
        if table == "pipeline_orders":
            return format_pipeline_table(self._state, location, sku)
        if table == "demand_history":
            return format_demand_history(self._state, self._task, location, sku)
        return "ERROR: Unexpected table."

    def query_supplier(self, supplier_id: str) -> str:
        """Describe a supplier, including lead time changed by an active disruption. Costs one action."""
        err = self._check_action_budget("query_supplier")
        if err:
            return err
        self._state.actions_remaining -= 1

        supplier = next((s for s in self._task.suppliers if s.supplier_id == supplier_id), None)
        if not supplier:
            available = [s.supplier_id for s in self._task.suppliers]
            return f"ERROR: Supplier '{supplier_id}' not found. Available: {available}"

        effective_lead_time = supplier.base_lead_time
        disruption_note = "No disruptions reported."
        disruption = self._active_disruption(supplier_id)
        if disruption is not None:
            effective_lead_time = disruption.params["new_lead_time"]
            disruption_note = (
                f"ACTIVE DISRUPTION: Lead time extended to {effective_lead_time} days. "
                f"Reason: {disruption.params['reason']}"
            )

        from .erp_formatter import format_supplier_info

        return format_supplier_info(supplier, effective_lead_time, disruption_note)

    def query_forecast(self, product_id: str, location_id: str, horizon_days: int = 7) -> str:
        """Render a demand forecast; horizon_days is clamped to [1, 21]. Costs one action."""
        err = self._check_action_budget("query_forecast")
        if err:
            return err
        self._state.actions_remaining -= 1

        horizon_days = max(1, min(21, horizon_days))
        product = next((p for p in self._task.products if p.product_id == product_id), None)
        if not product:
            return f"ERROR: Product '{product_id}' not found."
        if location_id not in product.locations and location_id != "all":
            return f"ERROR: Product '{product_id}' is not stocked at '{location_id}'."

        from .erp_formatter import format_forecast

        return format_forecast(self._state, self._task, product, location_id, horizon_days)

    def submit_po(
        self,
        supplier_id: str,
        product_id: str,
        destination_id: str,
        quantity: int,
        priority: str = "standard",
    ) -> str:
        """Place a purchase order.

        Standard orders enter the pipeline immediately and count against the
        outstanding budget. Expedited orders cost 50% more and have 2 fewer lead
        days (minimum 1). They are parked as a budget-override ticket until
        file_justification is called. Validation and budget errors do not consume
        an action; an accepted order or a created ticket does.
        """
        err = self._check_action_budget("submit_po")
        if err:
            return err

        if priority not in ("standard", "expedited"):
            return "ERROR: priority must be 'standard' or 'expedited'."
        if quantity <= 0:
            return "ERROR: quantity must be positive."

        supplier = next((s for s in self._task.suppliers if s.supplier_id == supplier_id), None)
        if not supplier:
            return f"ERROR: Supplier '{supplier_id}' not found."
        if product_id not in supplier.products:
            return f"ERROR: Supplier '{supplier_id}' does not supply '{product_id}'."

        valid_locs = [loc.location_id for loc in self._task.locations]
        if destination_id not in valid_locs:
            return f"ERROR: Destination '{destination_id}' not found. Valid: {valid_locs}"

        product = next((p for p in self._task.products if p.product_id == product_id), None)
        if product is None:
            # Supplier catalog lists a product the task does not define.
            return f"ERROR: Product '{product_id}' not found."

        is_expedited = priority == "expedited"
        expedited_multiplier = 1.5 if is_expedited else 1.0
        unit_cost = product.unit_cost * supplier.cost_multiplier * expedited_multiplier
        total_cost = unit_cost * quantity

        if self._state.budget_used + total_cost > self._state.budget_limit:
            overage = (self._state.budget_used + total_cost) - self._state.budget_limit
            return (
                f"ERROR: BUDGET_EXCEEDED\n"
                f"Order cost: ${total_cost:,.2f} | "
                f"Current outstanding: ${self._state.budget_used:,.2f} | "
                f"Limit: ${self._state.budget_limit:,.2f}\n"
                f"Overage: ${overage:,.2f}\n"
                f"Reduce order quantity or wait for existing orders to be delivered."
            )

        # Effective lead time (active disruption overrides the base lead time)
        disruption = self._active_disruption(supplier_id)
        lead_time = (
            disruption.params["new_lead_time"] if disruption is not None else supplier.base_lead_time
        )
        if is_expedited:
            lead_time = max(1, lead_time - 2)
        # Stochastic jitter for task 3
        if supplier.lead_time_std > 0:
            jitter = int(round(self._state.rng.normal(0, supplier.lead_time_std)))
            lead_time = max(1, lead_time + jitter)
        eta_day = self._state.day + lead_time

        po_id = f"POD-{self._state.po_counter:04d}"
        self._state.po_counter += 1
        po = PurchaseOrder(
            po_id=po_id,
            supplier_id=supplier_id,
            product_id=product_id,
            destination_id=destination_id,
            quantity=quantity,
            priority=priority,
            day_submitted=self._state.day,
            eta_day=eta_day,
            unit_cost=unit_cost,
            total_cost=total_cost,
            status="pending_justification" if is_expedited else "in_transit",
            lot_id=f"LOT-{po_id}",
        )
        self._state.actions_remaining -= 1

        # Expedited: requires justification
        if is_expedited:
            ticket_id = f"BOT-{self._state.po_counter:04d}"
            self._state.po_counter += 1
            self._state.pending_overrides[ticket_id] = PendingBudgetOverride(
                ticket_id=ticket_id, po=po
            )
            return (
                f"ERROR: BUDGET_OVERRIDE_REQUIRED\n"
                f"Order {po_id} ({priority}, ${total_cost:,.2f} incl. 50% expedite premium) "
                f"requires justification.\n"
                f"Ticket ID: {ticket_id}\n"
                f"Use file_justification(ticket_id=\"{ticket_id}\", reason=\"...\") to proceed.\n"
                f"Justification will be audited by Finance. False justifications are flagged."
            )

        # Standard order: submit immediately
        self._state.pipeline_orders.append(po)
        self._state.budget_used += total_cost
        return (
            f"OK — PO {po_id} submitted.\n"
            f"Product: {product_id} × {quantity} units\n"
            f"Supplier: {supplier_id} | Priority: {priority}\n"
            f"Destination: {destination_id} | ETA: Day {eta_day}\n"
            f"Cost: ${total_cost:,.2f} | "
            f"Budget remaining: ${self._state.budget_limit - self._state.budget_used:,.2f}"
        )

    def transfer(
        self,
        from_location_id: str,
        to_location_id: str,
        product_id: str,
        quantity: int,
    ) -> str:
        """Move non-quarantined stock between two different locations, earliest expiry first.

        Each moved portion becomes a new lot "XFR-<source lot id>" at the
        destination, keeping its expiry and cost. Destination capacity (if any)
        is enforced. Costs one action, even when rejected.
        """
        err = self._check_action_budget("transfer")
        if err:
            return err
        self._state.actions_remaining -= 1

        if quantity <= 0:
            return "ERROR: quantity must be positive."
        valid_locs = {loc.location_id for loc in self._task.locations}
        if from_location_id not in valid_locs:
            return f"ERROR: Location '{from_location_id}' not found."
        if to_location_id not in valid_locs:
            return f"ERROR: Location '{to_location_id}' not found."
        if from_location_id == to_location_id:
            # A self-transfer would only rename lots and charge a fee.
            return "ERROR: Source and destination locations must differ."

        key_from = (from_location_id, product_id)
        lots = sorted(
            (
                lot
                for lot in self._state.inventory.get(key_from, [])
                if lot.lot_id not in self._state.quarantined_lots
            ),
            # FEFO: earliest expiry first, non-perishable lots last.
            key=lambda lot: (lot.expiry_day is None, lot.expiry_day or 0),
        )
        available = sum(lot.qty for lot in lots)
        if available < quantity:
            return (
                f"ERROR: Insufficient stock at {from_location_id}. "
                f"Available: {available} units of {product_id}."
            )

        # Check destination capacity (task 2)
        dest_loc = next(
            (loc for loc in self._task.locations if loc.location_id == to_location_id), None
        )
        if dest_loc and dest_loc.capacity is not None:
            current_at_dest = self._units_at(to_location_id)
            if current_at_dest + quantity > dest_loc.capacity:
                return (
                    f"ERROR: CAPACITY_EXCEEDED — {to_location_id} capacity {dest_loc.capacity}. "
                    f"Current: {current_at_dest}, Transfer: {quantity}."
                )

        # FEFO transfer
        remaining = quantity
        dest_lots = self._state.inventory.setdefault((to_location_id, product_id), [])
        for lot in lots:
            if remaining <= 0:
                break
            take = min(remaining, lot.qty)
            lot.qty -= take
            remaining -= take
            dest_lots.append(
                Lot(
                    lot_id=f"XFR-{lot.lot_id}",
                    qty=take,
                    expiry_day=lot.expiry_day,
                    cost_per_unit=lot.cost_per_unit,
                )
            )
        self._state.inventory[key_from] = [
            lot for lot in self._state.inventory[key_from] if lot.qty > 0
        ]

        fee = quantity * self.TRANSFER_FEE
        self._state.transfer_count += 1
        self._state.transfer_cost_paid += fee
        return (
            f"OK — Transfer complete.\n"
            f"{quantity} units of {product_id}: {from_location_id} → {to_location_id}\n"
            f"Transfer fee: ${fee:.2f}"
        )

    def quarantine_lot(self, location_id: str, sku: str, lot_id: str) -> str:
        """Quarantine lots of sku at location_id. Costs one action.

        lot_id may be "all", an exact lot id, or a substring of lot ids (used
        only when there is no exact match). Quarantine is recorded by lot id in
        state.quarantined_lots, so it applies to that id wherever it is stored.
        """
        err = self._check_action_budget("quarantine_lot")
        if err:
            return err
        self._state.actions_remaining -= 1

        valid_locs = {loc.location_id for loc in self._task.locations}
        if location_id not in valid_locs:
            return f"ERROR: Location '{location_id}' not found."

        lots = self._state.inventory.get((location_id, sku), [])
        if lot_id == "all":
            target_lots = list(lots)
        else:
            target_lots = [lot for lot in lots if lot.lot_id == lot_id]
            if not target_lots:
                target_lots = [lot for lot in lots if lot_id in lot.lot_id]

        if not target_lots:
            available_lots = [lot.lot_id for lot in lots]
            return (
                f"ERROR: Lot '{lot_id}' not found at {location_id} for SKU {sku}. "
                f"Available lots: {available_lots}"
            )

        quarantined_qty = 0
        disposal_ids = []
        for lot in target_lots:
            if lot.lot_id not in self._state.quarantined_lots:
                self._state.quarantined_lots.add(lot.lot_id)
                quarantined_qty += lot.qty
                disposal_ids.append(lot.lot_id)

        # Track recall completion for task 3. Check the lots actually matched (not
        # the raw argument) so lot_id="all" or a partial id also counts.
        if sku == _RECALL_SKU and any(lot.lot_id == _RECALL_LOT_ID for lot in target_lots):
            self._check_recall_completion()

        disposal_ticket = f"DIS-{self._state.po_counter:04d}"
        self._state.po_counter += 1
        return (
            f"OK — Quarantine complete.\n"
            f"SKU: {sku} | Location: {location_id}\n"
            f"Lots quarantined: {disposal_ids}\n"
            f"Units quarantined: {quarantined_qty}\n"
            f"Disposal ticket: {disposal_ticket} created."
        )

    def file_justification(self, ticket_id: str, reason: str) -> str:
        """Release an expedited PO held under ticket_id. Costs one action.

        The reason is graded against the types of currently active events and
        logged. The order must still fit the outstanding budget; otherwise the
        ticket stays open. The lead time runs from approval, so a late
        justification shifts the ETA rather than delivering immediately.
        """
        err = self._check_action_budget("file_justification")
        if err:
            return err
        self._state.actions_remaining -= 1

        if ticket_id not in self._state.pending_overrides:
            return (
                f"ERROR: Ticket '{ticket_id}' not found or already processed.\n"
                f"Active tickets: {list(self._state.pending_overrides.keys())}"
            )

        po = self._state.pending_overrides[ticket_id].po

        # Pending tickets do not count toward budget_used, so the budget must be
        # re-checked here; otherwise many tickets could be released past the limit.
        if self._state.budget_used + po.total_cost > self._state.budget_limit:
            overage = (self._state.budget_used + po.total_cost) - self._state.budget_limit
            return (
                f"ERROR: BUDGET_EXCEEDED\n"
                f"Order cost: ${po.total_cost:,.2f} | "
                f"Current outstanding: ${self._state.budget_used:,.2f} | "
                f"Limit: ${self._state.budget_limit:,.2f}\n"
                f"Overage: ${overage:,.2f}\n"
                f"Ticket {ticket_id} remains open; retry once existing orders are delivered."
            )
        del self._state.pending_overrides[ticket_id]

        # Shift the ETA by the days the ticket sat pending (lead time counts from approval).
        approval_delay = self._state.day - po.day_submitted
        if approval_delay > 0:
            po.eta_day += approval_delay

        active_event_types: Set[str] = {e.event_type for e in self._iter_active_events()}

        from .grader import grade_justification

        is_coherent = grade_justification(reason, active_event_types)
        self._state.justification_log.append(
            JustificationRecord(
                ticket_id=ticket_id, po_id=po.po_id, reason=reason, is_coherent=is_coherent
            )
        )

        po.status = "in_transit"
        self._state.pipeline_orders.append(po)
        self._state.budget_used += po.total_cost

        audit_note = ""
        if not is_coherent:
            audit_note = (
                "\n[AUDIT FLAG] Justification does not reference active crisis conditions. "
                "Flagged for Finance review. Penalty applied."
            )
        return (
            f"OK — Justification {'accepted' if is_coherent else 'FLAGGED'}. "
            f"PO {po.po_id} submitted.\n"
            f"Product: {po.product_id} × {po.quantity} units | Destination: {po.destination_id}\n"
            f"ETA: Day {po.eta_day} | Cost: ${po.total_cost:,.2f}"
            f"{audit_note}"
        )

    def end_shift_tool(self) -> str:
        """Advance simulation by one day. Stores _last_reward and _done.

        Order: deliveries → expiry → demand → capacity check → (task 3, Day 2)
        recall injection → advance day / events → shaping reward → terminal score.
        """
        state = self._state
        if state is None:
            return "ERROR: Environment not initialized."

        day = state.day
        report_lines = [f"╔═══ END OF SHIFT — Day {day} {'═' * 40}╗"]

        # ── Step 1: Deliver arriving orders ──────────────────────────────
        delivered = self._deliver_arrivals(day)
        if delivered:
            report_lines.append(f" DELIVERIES: {len(delivered)} order(s) received.")

        # ── Step 2: Expire old lots ───────────────────────────────────────
        total_expired_units, total_expired_value = self._expire_lots(day)
        if total_expired_units > 0:
            report_lines.append(
                f" EXPIRED: {total_expired_units} units (${total_expired_value:,.2f} written off)"
            )

        # ── Step 3: Generate and fulfill demand ───────────────────────────
        day_demand, day_fulfilled = self._fulfill_daily_demand(day)
        day_svc = day_fulfilled / max(day_demand, 1)
        report_lines.append(
            f" DEMAND: {int(day_demand)} units | FULFILLED: {int(day_fulfilled)} ({100 * day_svc:.1f}%)"
        )

        # ── Step 4: Check capacity violations (task 2) ────────────────────
        self._record_capacity_violations()

        # ── Step 5: Inject recall lot for task 3 (Day 2, silent) ─────────
        if self._task.name == "hospital_network_crisis" and day == 2:
            self._inject_recall_lot()

        # ── Step 6: Advance day, reset budget, inject next-day events ────
        state.day += 1
        state.actions_remaining = state.actions_per_shift
        self._update_active_events(state.day)
        self._inject_events_for_day(state.day)

        # ── Step 7: Daily shaping reward ──────────────────────────────────
        # Computed after Step 6 because newly injected events can quarantine lots.
        shaping = 0.0
        shaping += 0.10 * day_svc
        total_units = sum(
            lot.qty
            for lots in state.inventory.values()
            for lot in lots
            if lot.lot_id not in state.quarantined_lots
        )
        shaping -= 0.00005 * total_units
        shaping -= min(0.30, state.daily_expired_lots * 0.10)
        shaping -= min(0.50, state.daily_stockout_count * 0.20)
        state.info_rewards_given_this_shift = set()
        state.daily_stockout_count = 0
        state.daily_expired_lots = 0

        # ── Step 8: Compute terminal score & check done ───────────────────
        from .grader import compute_reward

        final_score = compute_reward(state, self._task)
        if state.day > state.max_days:
            report_lines.append(
                f"╠═══ EPISODE COMPLETE — Final Score: {final_score:.3f} {'═' * 30}╣"
            )
            total_d = sum(state.daily_demand)
            total_f = sum(state.daily_fulfilled)
            report_lines.append(f" Service Level: {total_f / max(total_d, 1) * 100:.1f}%")
            report_lines.append(f" Total Spend: ${state.total_spend:,.2f}")
            report_lines.append(f" Waste Value: ${state.total_wasted_value:,.2f}")
            report_lines.append(f"╚{'═' * 68}╝")
            self._done = True
            self._last_reward = final_score
            return "\n".join(report_lines)

        self._done = False
        self._last_reward = shaping
        report_lines.append(f"╚═══ Day {day} committed. Day {state.day} begins. {'═' * 38}╝")
        report_lines.append("")
        from .erp_formatter import format_dashboard

        report_lines.append(format_dashboard(state, self._task))
        return "\n".join(report_lines)

    # ── end_shift_tool steps ───────────────────────────────────────────────

    def _deliver_arrivals(self, day: int) -> List[PurchaseOrder]:
        """Step 1: turn every pipeline PO with eta_day <= day into a new lot; return those POs.

        Delivered orders release their outstanding budget and are added to total_spend.
        """
        state = self._state
        delivered: List[PurchaseOrder] = []
        for po in state.pipeline_orders:
            if po.eta_day > day:
                continue
            product = next(
                (p for p in self._task.products if p.product_id == po.product_id), None
            )
            # Same perishability rule as _initialize_inventory: only None is non-perishable.
            expiry_day = (
                day + product.shelf_life_days if product.shelf_life_days is not None else None
            )
            state.inventory.setdefault((po.destination_id, po.product_id), []).append(
                Lot(
                    lot_id=po.lot_id,
                    qty=po.quantity,
                    expiry_day=expiry_day,
                    cost_per_unit=po.unit_cost,
                )
            )
            state.budget_used -= po.total_cost
            state.total_spend += po.total_cost
            po.status = "delivered"
            delivered.append(po)
        state.pipeline_orders = [po for po in state.pipeline_orders if po.status != "delivered"]
        return delivered

    def _expire_lots(self, day: int) -> Tuple[int, float]:
        """Step 2: drop lots with expiry_day <= day; return (units, value) written off."""
        state = self._state
        expired_units = 0
        expired_value = 0.0
        for key in list(state.inventory):
            fresh: List[Lot] = []
            expired: List[Lot] = []
            for lot in state.inventory[key]:
                if lot.expiry_day is not None and lot.expiry_day <= day:
                    expired.append(lot)
                else:
                    fresh.append(lot)
            for lot in expired:
                expired_units += lot.qty
                expired_value += lot.qty * lot.cost_per_unit
                state.total_wasted_value += lot.qty * lot.cost_per_unit
            state.daily_expired_lots += len(expired)
            state.inventory[key] = fresh
        return expired_units, expired_value

    def _fulfill_daily_demand(self, day: int) -> Tuple[float, float]:
        """Step 3: sample and fulfill demand for every (product, location).

        Appends to the daily and per-product history series and returns
        (total demand, total fulfilled) for the day.
        """
        state = self._state
        day_demand = 0.0
        day_fulfilled = 0.0
        day_critical_demand = 0.0
        day_critical_fulfilled = 0.0
        for product in self._task.products:
            for loc_id in product.locations:
                demand = self._generate_demand(product, loc_id, day)
                fulfilled = self._fefo_fulfill(product.product_id, loc_id, demand, day)
                day_demand += demand
                day_fulfilled += fulfilled
                if product.criticality == "CRITICAL":
                    day_critical_demand += demand
                    day_critical_fulfilled += fulfilled
                # Per-product daily tracking
                key = (loc_id, product.product_id)
                state.daily_product_demand.setdefault(key, []).append(demand)
                state.daily_product_fulfilled.setdefault(key, []).append(fulfilled)

        state.daily_demand.append(day_demand)
        state.daily_fulfilled.append(day_fulfilled)
        state.daily_critical_demand.append(day_critical_demand)
        state.daily_critical_fulfilled.append(day_critical_fulfilled)
        return day_demand, day_fulfilled

    def _record_capacity_violations(self):
        """Step 4: count one violation for each capacitated location currently over capacity."""
        for location in self._task.locations:
            if location.capacity is None:
                continue
            if self._units_at(location.location_id) > location.capacity:
                self._state.capacity_violation_days += 1

    # ── Private Helpers ────────────────────────────────────────────────────

    def _generate_demand(self, product, location_id: str, day: int) -> int:
        """Sample one day's integer demand (>= 0) for product at location_id.

        Base demand is scaled by an optional sinusoidal seasonality and by active
        "mci" events (CRITICAL/HIGH products at listed locations) and
        "demand_surge" events (listed products). Gaussian noise with
        product.demand_std is then added and the result rounded.
        """
        state = self._state
        base = product.base_demand

        if product.seasonal_amplitude > 0 and product.seasonal_period > 0:
            seasonal = product.seasonal_amplitude * math.sin(
                2 * math.pi * day / product.seasonal_period + product.seasonal_phase
            )
            base *= 1 + seasonal

        for event in self._iter_active_events():
            if event.event_type == "mci":
                if (
                    product.criticality in ("CRITICAL", "HIGH")
                    and location_id in event.params.get("locations", [])
                ):
                    base *= event.params.get("demand_multiplier", 3.0)
            elif event.event_type == "demand_surge":
                if product.product_id in event.params.get("products", []):
                    base *= event.params.get("multiplier", 1.4)

        noise = state.rng.normal(0, product.demand_std)
        return max(0, int(round(base + noise)))

    def _fefo_fulfill(self, product_id: str, location_id: str, demand: int, day: int) -> int:
        """Consume up to demand units from non-quarantined lots, earliest expiry first.

        Returns the units fulfilled. Emptied lots are removed, and a shortfall
        increments daily_stockout_count.
        """
        state = self._state
        key = (location_id, product_id)
        lots = state.inventory.get(key, [])
        lots_sorted = sorted(
            (lot for lot in lots if lot.lot_id not in state.quarantined_lots and lot.qty > 0),
            key=lambda lot: (lot.expiry_day is None, lot.expiry_day or 0),
        )
        fulfilled = 0
        for lot in lots_sorted:
            if fulfilled >= demand:
                break
            take = min(demand - fulfilled, lot.qty)
            lot.qty -= take
            fulfilled += take
        state.inventory[key] = [lot for lot in lots if lot.qty > 0]
        if fulfilled < demand:
            state.daily_stockout_count += 1
        return fulfilled

    def _update_active_events(self, day: int):
        """Drop events whose last active day is before day, then activate events triggering on day.

        Only events with duration_days > 0 are activated; they stay active through
        day + duration_days - 1.
        """
        state = self._state
        state.active_events = {
            eid: last_day for eid, last_day in state.active_events.items() if last_day >= day
        }
        for event in self._task.events:
            if event.trigger_day == day and event.duration_days > 0:
                state.active_events[event.event_id] = day + event.duration_days - 1

    def _inject_events_for_day(self, day: int):
        """Post inbox messages for events on day and warnings for events on day + 1.

        Events triggering on day also apply their immediate effects: a
        cold_chain_breach quarantines lots, and a budget_tighten replaces budget_limit.
        """
        state = self._state
        for event in self._task.events:
            if event.trigger_day == day:
                state.inbox.append(
                    InboxMessage(
                        msg_id=f"MSG-{state.msg_counter:04d}",
                        priority=event.message.priority,
                        timestamp_str=f"Day {day} 06:00",
                        sender=event.message.sender,
                        subject=event.message.subject,
                        body=event.message.body,
                        read=False,
                        flagged=(event.message.priority == "CRITICAL"),
                        event_id=event.event_id,
                    )
                )
                state.msg_counter += 1
                if event.event_type == "cold_chain_breach":
                    self._apply_cold_chain_breach(event)
                if event.event_type == "budget_tighten":
                    state.budget_limit = event.params["new_budget_limit"]

            if event.warning_message and event.trigger_day - 1 == day:
                state.inbox.append(
                    InboxMessage(
                        msg_id=f"MSG-{state.msg_counter:04d}",
                        priority=event.warning_message.priority,
                        timestamp_str=f"Day {day} 18:00",
                        sender=event.warning_message.sender,
                        subject=event.warning_message.subject,
                        body=event.warning_message.body,
                        read=False,
                        flagged=False,
                        event_id=f"{event.event_id}_warning",
                    )
                )
                state.msg_counter += 1

    def _apply_cold_chain_breach(self, event: SimEvent):
        """Quarantine every lot currently held for the event's (location_id, product_id)."""
        state = self._state
        key = (event.params["location_id"], event.params["product_id"])
        for lot in state.inventory.get(key, []):
            state.quarantined_lots.add(lot.lot_id)

    def _inject_recall_lot(self):
        """Silently add the recalled lot to each location named by the "iv_saline_recall" event.

        The lot is added as non-perishable with qty_per_location units per location.
        """
        state = self._state
        for event in self._task.events:
            if event.event_id != "iv_saline_recall":
                continue
            qty = event.params["qty_per_location"]
            product = next(
                (p for p in self._task.products if p.product_id == _RECALL_SKU), None
            )
            if product is None:
                break
            for loc_id in event.params["locations_with_lot"]:
                state.inventory.setdefault((loc_id, _RECALL_SKU), []).append(
                    Lot(
                        lot_id=_RECALL_LOT_ID,
                        qty=qty,
                        expiry_day=None,
                        cost_per_unit=product.unit_cost,
                    )
                )
            break

    def _check_recall_completion(self):
        """Record the current day as recall_handled_by_day the first time the recall lot is quarantined."""
        state = self._state
        if _RECALL_LOT_ID not in state.quarantined_lots:
            return
        if state.recall_handled_by_day is None:
            state.recall_handled_by_day = state.day

    # ── Accessors used by MedchainEnvironment ──────────────────────────────

    def get_last_reward(self) -> float:
        """Return the reward stored by the most recent end_shift_tool call."""
        return self._last_reward

    def is_done(self) -> bool:
        """Return True once the episode has passed max_days."""
        return self._done