Spaces:
Sleeping
Sleeping
| """ | |
| Core simulation engine for the MedChain Env environment. | |
| MedchainSimulation manages the full episode lifecycle: | |
| - Inventory tracked as FEFO lots per (location, product) | |
| - Purchase order pipeline with stochastic lead times | |
| - Event-driven inbox messages (crises, recalls, demand surges) | |
| - Daily demand generation and fulfillment | |
| """ | |
| from __future__ import annotations | |
| import uuid | |
| from dataclasses import dataclass, field | |
| from typing import Dict, List, Optional, Set, Tuple | |
| import numpy as np | |
| from .tasks import SimEvent, TaskConfig | |
# ─── Simulation Dataclasses ───────────────────────────────────────────────────
@dataclass
class Lot:
    """A quantity of one product at one location with a single expiry and unit cost."""

    lot_id: str
    qty: int
    expiry_day: Optional[int]  # None = non-perishable. Expired when current_day >= expiry_day.
    cost_per_unit: float
@dataclass
class PurchaseOrder:
    """A purchase order from a supplier to a destination location."""

    po_id: str
    supplier_id: str
    product_id: str
    destination_id: str
    quantity: int
    priority: str  # "standard" or "expedited"
    day_submitted: int
    eta_day: int
    unit_cost: float
    total_cost: float
    status: str  # "pending_justification", "in_transit", "delivered"
    lot_id: str  # lot id the delivered stock will carry
@dataclass
class PendingBudgetOverride:
    """An expedited purchase order held until its justification ticket is filed."""

    ticket_id: str
    po: PurchaseOrder
@dataclass
class InboxMessage:
    """A message shown to the agent through read_inbox."""

    msg_id: str
    priority: str
    timestamp_str: str  # "Day {n} {HH:MM}"
    sender: str
    subject: str
    body: str
    read: bool
    flagged: bool
    event_id: str  # id of the event (or "system_welcome" / "<event>_warning") that produced it
@dataclass
class JustificationRecord:
    """Audit record of a filed expedite justification and its coherence verdict."""

    ticket_id: str
    po_id: str
    reason: str
    is_coherent: bool
@dataclass
class SimState:
    """Mutable state of one simulation episode, owned by MedchainSimulation."""

    # Episode meta
    task: str
    episode_id: str
    seed: int
    rng: np.random.Generator
    # Time
    day: int
    max_days: int
    # Action budget
    actions_remaining: int
    actions_per_shift: int
    # Budget (outstanding, undelivered order value vs. ceiling)
    budget_used: float
    budget_limit: float
    # Inventory: (location_id, product_id) -> List[Lot] (FEFO-sorted)
    inventory: Dict[Tuple[str, str], List[Lot]]
    # Orders
    pipeline_orders: List[PurchaseOrder]
    po_counter: int
    # Inbox
    inbox: List[InboxMessage]
    msg_counter: int
    # Budget override tickets
    pending_overrides: Dict[str, PendingBudgetOverride]
    # Quarantine (by lot id)
    quarantined_lots: Set[str]
    # Demand / fulfillment tracking (one value per completed day)
    daily_demand: List[float]
    daily_fulfilled: List[float]
    daily_critical_demand: List[float]
    daily_critical_fulfilled: List[float]
    # Per-(location, product) daily tracking (for demand_history queries)
    daily_product_demand: Dict[Tuple[str, str], List[int]]
    daily_product_fulfilled: Dict[Tuple[str, str], List[int]]
    # Spend tracking
    total_spend: float
    total_wasted_value: float
    # Transfer tracking (task 2)
    transfer_count: int
    transfer_cost_paid: float
    # Capacity violations (task 2)
    capacity_violation_days: int
    # Active event effects: event_id -> last_day_active (inclusive)
    active_events: Dict[str, int]
    # Per-shift shaping reward helpers
    info_rewards_given_this_shift: Set[str]
    daily_stockout_count: int
    daily_expired_lots: int
    # Task 3 crisis tracking
    justification_log: List[JustificationRecord]
    mci_preemptive_order: bool
    recall_handled_by_day: Optional[int]
# ─── MedchainSimulation ───────────────────────────────────────────────────────
class MedchainSimulation:
    """
    Core simulation engine. Called by MedchainEnvironment's MCP tools.

    All public tool methods return a string (displayed to the agent as ERP
    output). ``end_shift_tool()`` also stores ``_last_reward`` and ``_done``,
    which the environment reads through ``get_last_reward()`` / ``is_done()``.

    Tool calls draw on a per-shift action budget (``SimState.actions_remaining``).
    ``end_shift_tool()`` is free and restores the budget for the next day.
    """

    # Flat fee charged per unit moved by transfer().
    _TRANSFER_FEE_PER_UNIT = 0.5
    # Task 3 recall: the lot id injected on Day 2 and the SKU it belongs to.
    _RECALL_LOT_ID = "RECALL-LOT-IV2026-9821"
    _RECALL_SKU = "IV-SAL-500"

    def __init__(self, task_config: TaskConfig):
        """Bind the simulation to a task. The episode itself starts on ``reset()``."""
        self._task = task_config
        self._state: Optional[SimState] = None
        self._last_reward: float = 0.0
        self._done: bool = False

    # ── Called by environment.reset() ──────────────────────────────────────

    def reset(self, seed: int, episode_id: str) -> str:
        """Initialize a new episode and return the Day 1 dashboard text.

        Args:
            seed: Seed for the episode's numpy Generator (demand noise and
                lead-time jitter).
            episode_id: Identifier stored on the episode state.
        """
        self._done = False
        self._last_reward = 0.0
        rng = np.random.default_rng(seed)
        self._state = SimState(
            task=self._task.name,
            episode_id=episode_id,
            seed=seed,
            rng=rng,
            day=1,
            max_days=self._task.max_days,
            actions_remaining=self._task.actions_per_shift,
            actions_per_shift=self._task.actions_per_shift,
            budget_used=0.0,
            budget_limit=self._task.budget_limit,
            inventory={},
            pipeline_orders=[],
            po_counter=1,
            inbox=[],
            msg_counter=1,
            pending_overrides={},
            quarantined_lots=set(),
            daily_demand=[],
            daily_fulfilled=[],
            daily_critical_demand=[],
            daily_critical_fulfilled=[],
            daily_product_demand={},
            daily_product_fulfilled={},
            total_spend=0.0,
            total_wasted_value=0.0,
            transfer_count=0,
            transfer_cost_paid=0.0,
            capacity_violation_days=0,
            active_events={},
            info_rewards_given_this_shift=set(),
            daily_stockout_count=0,
            daily_expired_lots=0,
            justification_log=[],
            mci_preemptive_order=False,
            recall_handled_by_day=None,
        )
        self._initialize_inventory()
        # end_shift_tool() only activates events for the day it advances *to*
        # (Day 2 onwards). Without this call, Day 1 events would be announced
        # in the inbox but never take effect.
        self._update_active_events(1)
        self._inject_day1_inbox()
        from .erp_formatter import format_dashboard
        return format_dashboard(self._state, self._task)

    def _initialize_inventory(self):
        """Seed initial inventory: initial_stock_days × base_demand per location/product.

        Perishable starting lots expire at 70% of the product's shelf life,
        counted from Day 1.
        """
        state = self._state
        for product in self._task.products:
            for loc_id in product.locations:
                key = (loc_id, product.product_id)
                qty = int(product.base_demand * self._task.initial_stock_days)
                expiry_day = (
                    1 + int(product.shelf_life_days * 0.7)
                    if product.shelf_life_days is not None
                    else None
                )
                lot = Lot(
                    lot_id=f"INIT-{product.product_id}-{loc_id}",
                    qty=qty,
                    expiry_day=expiry_day,
                    cost_per_unit=product.unit_cost,
                )
                state.inventory[key] = [lot]

    def _inject_day1_inbox(self):
        """Add Day 1 inbox messages (welcome + any Day 1 events)."""
        state = self._state
        welcome = InboxMessage(
            msg_id=f"MSG-{state.msg_counter:04d}",
            priority="LOW",
            timestamp_str="Day 1 08:00",
            sender="System",
            subject="Shift Handover Notes",
            body=(
                f"Welcome to the {self._task.name} scenario.\n"
                f"You are managing medical supplies for {self._task.max_days} days.\n"
                f"Action budget: {self._task.actions_per_shift} actions per shift.\n"
                f"Budget ceiling: ${self._task.budget_limit:,.0f} outstanding orders.\n\n"
                "Use read_inbox to check messages, query_erp to check stock,\n"
                "submit_po to order supplies, and end_shift to advance the day."
            ),
            read=False,
            flagged=False,
            event_id="system_welcome",
        )
        state.inbox.append(welcome)
        state.msg_counter += 1
        self._inject_events_for_day(1)

    # ── Lookup helpers ─────────────────────────────────────────────────────

    def _find_event(self, event_id: str) -> Optional[SimEvent]:
        """Return the task event with ``event_id``, or None if the task has none."""
        return next((e for e in self._task.events if e.event_id == event_id), None)

    def _find_product(self, product_id: str):
        """Return the task product with ``product_id``, or None if unknown."""
        return next((p for p in self._task.products if p.product_id == product_id), None)

    def _find_supplier(self, supplier_id: str):
        """Return the task supplier with ``supplier_id``, or None if unknown."""
        return next((s for s in self._task.suppliers if s.supplier_id == supplier_id), None)

    def _active_disruption(self, supplier_id: str) -> Optional[SimEvent]:
        """Return the active ``supplier_disruption`` event for ``supplier_id``, if any.

        If several are active, the last one in ``active_events`` order wins.
        """
        disruption = None
        for event_id in self._state.active_events:
            event = self._find_event(event_id)
            if (
                event is not None
                and event.event_type == "supplier_disruption"
                and event.params.get("supplier_id") == supplier_id
            ):
                disruption = event
        return disruption

    def _units_at(self, location_id: str) -> int:
        """Total units held at ``location_id`` across all products, quarantined included."""
        return sum(
            sum(lot.qty for lot in lots)
            for (loc, _pid), lots in self._state.inventory.items()
            if loc == location_id
        )

    # ── Action Budget Helper ───────────────────────────────────────────────

    def _check_action_budget(self, tool_name: str) -> Optional[str]:
        """Return an error string if the budget is exhausted, else None.

        Does NOT decrement the budget. ``end_shift`` is always allowed.
        """
        if tool_name == "end_shift":
            return None
        if self._state is None:
            return "ERROR: Environment not initialized. Call reset() first."
        if self._state.actions_remaining <= 0:
            return (
                "ERROR: Action budget exhausted for this shift.\n"
                f"Actions used: {self._state.actions_per_shift}/{self._state.actions_per_shift}\n"
                "Call end_shift() to advance to the next day and restore your action budget."
            )
        return None

    # ── MCP Tool Implementations ───────────────────────────────────────────

    def read_inbox(self, filter: str = "unread") -> str:
        """List inbox messages and mark them as read. Costs one action.

        Args:
            filter: "unread", "flagged", or anything else (e.g. "all") for the
                full inbox.
        """
        err = self._check_action_budget("read_inbox")
        if err:
            return err
        self._state.actions_remaining -= 1
        if filter == "unread":
            messages = [m for m in self._state.inbox if not m.read]
        elif filter == "flagged":
            messages = [m for m in self._state.inbox if m.flagged]
        else:  # "all" — use full inbox
            messages = list(self._state.inbox)
        if not messages:
            return f"INBOX EMPTY\nFilter: {filter} | No messages matching filter."
        lines = []
        for m in messages:
            # Report the status the message had *before* this read, then mark it.
            read_status = "READ" if m.read else "UNREAD"
            m.read = True
            lines.append(
                f"\n[MSG {m.msg_id} | {read_status} | PRIORITY: {m.priority} | {m.timestamp_str}]"
            )
            lines.append(f"FROM: {m.sender}")
            lines.append(f"SUBJ: {m.subject}")
            lines.append("")
            lines.append(m.body)
            lines.append("")
        return "\n".join(lines)

    def query_erp(self, table: str, location: str = "all", sku: str = "all") -> str:
        """Render one ERP table, optionally filtered by location/SKU. Costs one action.

        Valid tables: inventory, expiry, pipeline_orders, demand_history.
        """
        err = self._check_action_budget("query_erp")
        if err:
            return err
        self._state.actions_remaining -= 1
        valid_tables = ["inventory", "expiry", "pipeline_orders", "demand_history"]
        if table not in valid_tables:
            return f"ERROR: Unknown table '{table}'. Valid tables: {valid_tables}"
        from .erp_formatter import (
            format_demand_history,
            format_expiry_table,
            format_inventory_table,
            format_pipeline_table,
        )
        if table == "inventory":
            return format_inventory_table(self._state, self._task, location, sku)
        elif table == "expiry":
            return format_expiry_table(self._state, self._task, location, sku)
        elif table == "pipeline_orders":
            return format_pipeline_table(self._state, location, sku)
        elif table == "demand_history":
            return format_demand_history(self._state, self._task, location, sku)
        return "ERROR: Unexpected table."

    def query_supplier(self, supplier_id: str) -> str:
        """Describe a supplier, including any active lead-time disruption. Costs one action."""
        err = self._check_action_budget("query_supplier")
        if err:
            return err
        self._state.actions_remaining -= 1
        supplier = self._find_supplier(supplier_id)
        if not supplier:
            available = [s.supplier_id for s in self._task.suppliers]
            return f"ERROR: Supplier '{supplier_id}' not found. Available: {available}"
        disruption = self._active_disruption(supplier_id)
        if disruption is None:
            effective_lead_time = supplier.base_lead_time
            disruption_note = "No disruptions reported."
        else:
            effective_lead_time = disruption.params["new_lead_time"]
            disruption_note = (
                f"ACTIVE DISRUPTION: Lead time extended to {effective_lead_time} days. "
                f"Reason: {disruption.params['reason']}"
            )
        from .erp_formatter import format_supplier_info
        return format_supplier_info(supplier, effective_lead_time, disruption_note)

    def query_forecast(self, product_id: str, location_id: str, horizon_days: int = 7) -> str:
        """Render a demand forecast for a product. Costs one action.

        ``horizon_days`` is clamped to [1, 21]. ``location_id`` may be "all".
        """
        err = self._check_action_budget("query_forecast")
        if err:
            return err
        self._state.actions_remaining -= 1
        horizon_days = max(1, min(21, horizon_days))
        product = self._find_product(product_id)
        if not product:
            return f"ERROR: Product '{product_id}' not found."
        if location_id not in product.locations and location_id != "all":
            return f"ERROR: Product '{product_id}' is not stocked at '{location_id}'."
        from .erp_formatter import format_forecast
        return format_forecast(self._state, self._task, product, location_id, horizon_days)

    def submit_po(
        self,
        supplier_id: str,
        product_id: str,
        destination_id: str,
        quantity: int,
        priority: str = "standard",
    ) -> str:
        """Place a purchase order.

        Standard orders enter the pipeline immediately and count against the
        outstanding-order budget. Expedited orders cost 50% more, arrive up to
        2 days sooner, and are parked behind a justification ticket (see
        ``file_justification``). Validation and budget errors do not consume
        an action. Accepted orders and opened tickets do.
        """
        err = self._check_action_budget("submit_po")
        if err:
            return err
        if priority not in ("standard", "expedited"):
            return "ERROR: priority must be 'standard' or 'expedited'."
        if quantity <= 0:
            return "ERROR: quantity must be positive."
        supplier = self._find_supplier(supplier_id)
        if not supplier:
            return f"ERROR: Supplier '{supplier_id}' not found."
        if product_id not in supplier.products:
            return f"ERROR: Supplier '{supplier_id}' does not supply '{product_id}'."
        valid_locs = [l.location_id for l in self._task.locations]
        if destination_id not in valid_locs:
            return f"ERROR: Destination '{destination_id}' not found. Valid: {valid_locs}"
        product = self._find_product(product_id)
        if product is None:
            # A supplier catalog may list SKUs the task does not define.
            return f"ERROR: Product '{product_id}' not found."

        expedited_multiplier = 1.5 if priority == "expedited" else 1.0
        unit_cost = product.unit_cost * supplier.cost_multiplier * expedited_multiplier
        total_cost = unit_cost * quantity
        if self._state.budget_used + total_cost > self._state.budget_limit:
            overage = (self._state.budget_used + total_cost) - self._state.budget_limit
            return (
                f"ERROR: BUDGET_EXCEEDED\n"
                f"Order cost: ${total_cost:,.2f} | "
                f"Current outstanding: ${self._state.budget_used:,.2f} | "
                f"Limit: ${self._state.budget_limit:,.2f}\n"
                f"Overage: ${overage:,.2f}\n"
                f"Reduce order quantity or wait for existing orders to be delivered."
            )

        # Effective lead time: active disruption overrides the base lead time.
        disruption = self._active_disruption(supplier_id)
        lead_time = (
            disruption.params["new_lead_time"] if disruption is not None
            else supplier.base_lead_time
        )
        if priority == "expedited":
            lead_time = max(1, lead_time - 2)
        # Stochastic jitter (task 3 suppliers have lead_time_std > 0).
        if supplier.lead_time_std > 0:
            jitter = int(round(self._state.rng.normal(0, supplier.lead_time_std)))
            lead_time = max(1, lead_time + jitter)
        eta_day = self._state.day + lead_time

        po_id = f"POD-{self._state.po_counter:04d}"
        lot_id = f"LOT-{po_id}"
        self._state.po_counter += 1

        # Expedited: requires justification before it enters the pipeline.
        if priority == "expedited":
            ticket_id = f"BOT-{self._state.po_counter:04d}"
            self._state.po_counter += 1
            po = PurchaseOrder(
                po_id=po_id, supplier_id=supplier_id, product_id=product_id,
                destination_id=destination_id, quantity=quantity, priority=priority,
                day_submitted=self._state.day, eta_day=eta_day, unit_cost=unit_cost,
                total_cost=total_cost, status="pending_justification", lot_id=lot_id,
            )
            self._state.pending_overrides[ticket_id] = PendingBudgetOverride(
                ticket_id=ticket_id, po=po
            )
            self._state.actions_remaining -= 1
            return (
                f"ERROR: BUDGET_OVERRIDE_REQUIRED\n"
                f"Order {po_id} ({priority}, ${total_cost:,.2f} incl. 50% expedite premium) "
                f"requires justification.\n"
                f"Ticket ID: {ticket_id}\n"
                f"Use file_justification(ticket_id=\"{ticket_id}\", reason=\"...\") to proceed.\n"
                f"Justification will be audited by Finance. False justifications are flagged."
            )

        # Standard order: submit immediately.
        self._state.actions_remaining -= 1
        po = PurchaseOrder(
            po_id=po_id, supplier_id=supplier_id, product_id=product_id,
            destination_id=destination_id, quantity=quantity, priority=priority,
            day_submitted=self._state.day, eta_day=eta_day, unit_cost=unit_cost,
            total_cost=total_cost, status="in_transit", lot_id=lot_id,
        )
        self._state.pipeline_orders.append(po)
        self._state.budget_used += total_cost
        return (
            f"OK — PO {po_id} submitted.\n"
            f"Product: {product_id} × {quantity} units\n"
            f"Supplier: {supplier_id} | Priority: {priority}\n"
            f"Destination: {destination_id} | ETA: Day {eta_day}\n"
            f"Cost: ${total_cost:,.2f} | "
            f"Budget remaining: ${self._state.budget_limit - self._state.budget_used:,.2f}"
        )

    def transfer(
        self,
        from_location_id: str,
        to_location_id: str,
        product_id: str,
        quantity: int,
    ) -> str:
        """Move non-quarantined stock between two locations in FEFO order. Costs one action.

        Moved stock becomes new ``XFR-<lot_id>`` lots at the destination,
        keeping the source lot's expiry and unit cost. A destination capacity
        limit, when set, is enforced. A per-unit fee is added to
        ``transfer_cost_paid``.
        """
        err = self._check_action_budget("transfer")
        if err:
            return err
        self._state.actions_remaining -= 1
        if quantity <= 0:
            return "ERROR: quantity must be positive."
        valid_locs = {l.location_id for l in self._task.locations}
        if from_location_id not in valid_locs:
            return f"ERROR: Location '{from_location_id}' not found."
        if to_location_id not in valid_locs:
            return f"ERROR: Location '{to_location_id}' not found."
        if from_location_id == to_location_id:
            # A self-transfer would only relabel lots as XFR-* and charge a fee.
            return "ERROR: Source and destination locations must differ."

        key_from = (from_location_id, product_id)
        lots = sorted(
            [
                l for l in self._state.inventory.get(key_from, [])
                if l.lot_id not in self._state.quarantined_lots
            ],
            key=lambda l: (l.expiry_day is None, l.expiry_day or 0),
        )
        available = sum(l.qty for l in lots)
        if available < quantity:
            return (
                f"ERROR: Insufficient stock at {from_location_id}. "
                f"Available: {available} units of {product_id}."
            )

        # Check destination capacity (task 2).
        dest_loc = next(
            (l for l in self._task.locations if l.location_id == to_location_id), None
        )
        if dest_loc is not None and dest_loc.capacity is not None:
            current_at_dest = self._units_at(to_location_id)
            if current_at_dest + quantity > dest_loc.capacity:
                return (
                    f"ERROR: CAPACITY_EXCEEDED — {to_location_id} capacity {dest_loc.capacity}. "
                    f"Current: {current_at_dest}, Transfer: {quantity}."
                )

        # FEFO transfer: drain the earliest-expiring lots first.
        remaining = quantity
        dest_lots = self._state.inventory.setdefault((to_location_id, product_id), [])
        for lot in lots:
            if remaining <= 0:
                break
            take = min(remaining, lot.qty)
            lot.qty -= take
            remaining -= take
            dest_lots.append(
                Lot(
                    lot_id=f"XFR-{lot.lot_id}",
                    qty=take,
                    expiry_day=lot.expiry_day,
                    cost_per_unit=lot.cost_per_unit,
                )
            )
        self._state.inventory[key_from] = [
            l for l in self._state.inventory[key_from] if l.qty > 0
        ]

        fee = quantity * self._TRANSFER_FEE_PER_UNIT
        self._state.transfer_count += 1
        self._state.transfer_cost_paid += fee
        return (
            f"OK — Transfer complete.\n"
            f"{quantity} units of {product_id}: {from_location_id} → {to_location_id}\n"
            f"Transfer fee: ${fee:.2f}"
        )

    def quarantine_lot(self, location_id: str, sku: str, lot_id: str) -> str:
        """Quarantine lots of ``sku`` at ``location_id``. Costs one action.

        ``lot_id`` may be an exact lot id, ``"all"`` for every lot of the SKU
        at the location, or (when nothing matches exactly) a non-empty
        substring of a lot id. Quarantine is recorded by lot id in
        ``SimState.quarantined_lots``. Fulfilment and transfers skip any lot
        whose id is listed there.
        """
        err = self._check_action_budget("quarantine_lot")
        if err:
            return err
        self._state.actions_remaining -= 1
        valid_locs = {l.location_id for l in self._task.locations}
        if location_id not in valid_locs:
            return f"ERROR: Location '{location_id}' not found."
        lots = self._state.inventory.get((location_id, sku), [])
        if lot_id == "all":
            target_lots = list(lots)
        else:
            target_lots = [l for l in lots if l.lot_id == lot_id]
            if not target_lots and lot_id:
                # Partial-id fallback. Skipped for "", which would match every lot.
                target_lots = [l for l in lots if lot_id in l.lot_id]
        if not target_lots:
            available_lots = [l.lot_id for l in lots]
            return (
                f"ERROR: Lot '{lot_id}' not found at {location_id} for SKU {sku}. "
                f"Available lots: {available_lots}"
            )

        quarantined_qty = 0
        disposal_ids = []
        for lot in target_lots:
            if lot.lot_id not in self._state.quarantined_lots:
                self._state.quarantined_lots.add(lot.lot_id)
                quarantined_qty += lot.qty
                disposal_ids.append(lot.lot_id)

        # Track recall completion for task 3. Checked against the lots actually
        # matched, so "all" and partial ids count too.
        if sku == self._RECALL_SKU and any("RECALL-LOT" in l.lot_id for l in target_lots):
            self._check_recall_completion()

        disposal_ticket = f"DIS-{self._state.po_counter:04d}"
        self._state.po_counter += 1
        return (
            f"OK — Quarantine complete.\n"
            f"SKU: {sku} | Location: {location_id}\n"
            f"Lots quarantined: {disposal_ids}\n"
            f"Units quarantined: {quarantined_qty}\n"
            f"Disposal ticket: {disposal_ticket} created."
        )

    def file_justification(self, ticket_id: str, reason: str) -> str:
        """Release an expedited PO held behind ``ticket_id``. Costs one action.

        The budget ceiling is re-checked first. ``budget_used`` only grows when
        a ticket is filed, so several tickets opened against the same headroom,
        or a later budget_tighten event, could otherwise exceed it. If the
        order does not fit, the ticket stays open.

        The reason is graded against the currently active event types and
        logged. Incoherent reasons are flagged, but the PO is still released.
        """
        err = self._check_action_budget("file_justification")
        if err:
            return err
        self._state.actions_remaining -= 1
        override = self._state.pending_overrides.get(ticket_id)
        if override is None:
            return (
                f"ERROR: Ticket '{ticket_id}' not found or already processed.\n"
                f"Active tickets: {list(self._state.pending_overrides.keys())}"
            )
        po = override.po
        if self._state.budget_used + po.total_cost > self._state.budget_limit:
            overage = (self._state.budget_used + po.total_cost) - self._state.budget_limit
            return (
                f"ERROR: BUDGET_EXCEEDED\n"
                f"Order {po.po_id} cost: ${po.total_cost:,.2f} | "
                f"Current outstanding: ${self._state.budget_used:,.2f} | "
                f"Limit: ${self._state.budget_limit:,.2f}\n"
                f"Overage: ${overage:,.2f}\n"
                f"Ticket {ticket_id} remains open. File again after existing orders are delivered."
            )
        del self._state.pending_overrides[ticket_id]

        active_event_types: Set[str] = set()
        for event_id in self._state.active_events:
            event = self._find_event(event_id)
            if event:
                active_event_types.add(event.event_type)
        from .grader import grade_justification
        is_coherent = grade_justification(reason, active_event_types)
        self._state.justification_log.append(
            JustificationRecord(
                ticket_id=ticket_id, po_id=po.po_id, reason=reason, is_coherent=is_coherent
            )
        )

        po.status = "in_transit"
        self._state.pipeline_orders.append(po)
        self._state.budget_used += po.total_cost
        audit_note = ""
        if not is_coherent:
            audit_note = (
                "\n[AUDIT FLAG] Justification does not reference active crisis conditions. "
                "Flagged for Finance review. Penalty applied."
            )
        return (
            f"OK — Justification {'accepted' if is_coherent else 'FLAGGED'}. "
            f"PO {po.po_id} submitted.\n"
            f"Product: {po.product_id} × {po.quantity} units | Destination: {po.destination_id}\n"
            f"ETA: Day {po.eta_day} | Cost: ${po.total_cost:,.2f}"
            f"{audit_note}"
        )

    def end_shift_tool(self) -> str:
        """Advance the simulation by one day and return the end-of-shift report.

        For the day being closed this runs, in order: deliveries, expiry
        write-offs, demand generation and fulfilment, and capacity-violation
        counting. In task 3 it also injects the silent recall lot on Day 2.
        It then advances the day, restores the action budget, and
        activates/announces the next day's events.

        Sets ``_last_reward``: the daily shaping reward, or the final score
        from ``grader.compute_reward`` once ``day`` passes ``max_days``.
        Also sets ``_done``. Calling again after the episode has ended
        returns an error instead of simulating further days.
        """
        state = self._state
        if state is None:
            return "ERROR: Environment not initialized."
        if self._done:
            # Zero the reward so the terminal score is not reported twice.
            self._last_reward = 0.0
            return "ERROR: Episode complete. Call reset() to start a new episode."

        day = state.day
        report_lines = [f"╔═══ END OF SHIFT — Day {day} {'═' * 40}╗"]

        # ── Step 1: Deliver arriving orders ──────────────────────────────
        delivered_count = self._deliver_orders(day)
        if delivered_count:
            report_lines.append(f" DELIVERIES: {delivered_count} order(s) received.")

        # ── Step 2: Expire old lots ──────────────────────────────────────
        expired_units, expired_value = self._expire_lots(day)
        if expired_units > 0:
            report_lines.append(
                f" EXPIRED: {expired_units} units (${expired_value:,.2f} written off)"
            )

        # ── Step 3: Generate and fulfill demand ──────────────────────────
        day_demand, day_fulfilled = self._generate_and_fulfill_demand(day)
        day_service = day_fulfilled / max(day_demand, 1)
        report_lines.append(
            f" DEMAND: {int(day_demand)} units | FULFILLED: {int(day_fulfilled)} ({100 * day_service:.1f}%)"
        )

        # ── Step 4: Check capacity violations (task 2) ───────────────────
        self._record_capacity_violations()

        # ── Step 5: Inject recall lot for task 3 (Day 2, silent) ─────────
        if self._task.name == "hospital_network_crisis" and day == 2:
            self._inject_recall_lot()

        # ── Step 6: Advance day, reset budget, inject next-day events ────
        state.day += 1
        state.actions_remaining = state.actions_per_shift
        self._update_active_events(state.day)
        self._inject_events_for_day(state.day)

        # ── Step 7: Daily shaping reward (after next-day events apply) ───
        shaping = self._shaping_reward(day_service)
        state.info_rewards_given_this_shift = set()
        state.daily_stockout_count = 0
        state.daily_expired_lots = 0

        # ── Step 8: Compute terminal score & check done ──────────────────
        from .grader import compute_reward
        final_score = compute_reward(state, self._task)
        if state.day > state.max_days:
            total_d = sum(state.daily_demand)
            total_f = sum(state.daily_fulfilled)
            report_lines.append(
                f"╠═══ EPISODE COMPLETE — Final Score: {final_score:.3f} {'═' * 30}╣"
            )
            report_lines.append(f" Service Level: {total_f / max(total_d, 1) * 100:.1f}%")
            report_lines.append(f" Total Spend: ${state.total_spend:,.2f}")
            report_lines.append(f" Waste Value: ${state.total_wasted_value:,.2f}")
            report_lines.append(f"╚{'═' * 68}╝")
            self._done = True
            self._last_reward = final_score
            return "\n".join(report_lines)

        self._done = False
        self._last_reward = shaping
        report_lines.append(
            f"╚═══ Day {day} committed. Day {state.day} begins. {'═' * 38}╝"
        )
        report_lines.append("")
        from .erp_formatter import format_dashboard
        report_lines.append(format_dashboard(state, self._task))
        return "\n".join(report_lines)

    # ── Private Helpers ────────────────────────────────────────────────────

    def _deliver_orders(self, day: int) -> int:
        """Receive every pipeline PO with ``eta_day <= day`` and return how many arrived.

        Each arrival becomes a lot at its destination, expiring
        ``shelf_life_days`` after ``day`` (never, if the product has no shelf
        life). Its cost moves from ``budget_used`` to ``total_spend``.
        """
        state = self._state
        still_in_transit = []
        delivered = 0
        for po in state.pipeline_orders:
            if po.eta_day > day:
                still_in_transit.append(po)
                continue
            product = self._find_product(po.product_id)
            shelf_life = product.shelf_life_days if product is not None else None
            # `is not None` matches _initialize_inventory: a 0-day shelf life is perishable.
            expiry_day = day + shelf_life if shelf_life is not None else None
            state.inventory.setdefault((po.destination_id, po.product_id), []).append(
                Lot(
                    lot_id=po.lot_id, qty=po.quantity,
                    expiry_day=expiry_day, cost_per_unit=po.unit_cost,
                )
            )
            state.budget_used -= po.total_cost
            state.total_spend += po.total_cost
            po.status = "delivered"
            delivered += 1
        state.pipeline_orders = still_in_transit
        return delivered

    def _expire_lots(self, day: int) -> Tuple[int, float]:
        """Remove lots with ``expiry_day <= day``. Return ``(units, value)`` written off.

        Quarantined lots expire too. Written-off value accumulates in
        ``total_wasted_value``, and the lot count in ``daily_expired_lots``.
        """
        state = self._state
        total_units = 0
        total_value = 0.0
        for key, lots in list(state.inventory.items()):
            fresh, expired = [], []
            for lot in lots:
                if lot.expiry_day is not None and lot.expiry_day <= day:
                    expired.append(lot)
                else:
                    fresh.append(lot)
            if not expired:
                continue
            for lot in expired:
                total_units += lot.qty
                total_value += lot.qty * lot.cost_per_unit
                state.total_wasted_value += lot.qty * lot.cost_per_unit
            state.daily_expired_lots += len(expired)
            state.inventory[key] = fresh
        return total_units, total_value

    def _generate_and_fulfill_demand(self, day: int) -> Tuple[float, float]:
        """Generate and FEFO-fulfil demand for every stocked (product, location).

        Appends the day's totals to the network-wide and per-(location,
        product) tracking lists. Returns ``(total_demand, total_fulfilled)``.
        """
        state = self._state
        day_demand = 0.0
        day_fulfilled = 0.0
        day_critical_demand = 0.0
        day_critical_fulfilled = 0.0
        for product in self._task.products:
            for loc_id in product.locations:
                demand = self._generate_demand(product, loc_id, day)
                fulfilled = self._fefo_fulfill(product.product_id, loc_id, demand, day)
                day_demand += demand
                day_fulfilled += fulfilled
                if product.criticality == "CRITICAL":
                    day_critical_demand += demand
                    day_critical_fulfilled += fulfilled
                # Per-product daily tracking
                key = (loc_id, product.product_id)
                state.daily_product_demand.setdefault(key, []).append(demand)
                state.daily_product_fulfilled.setdefault(key, []).append(fulfilled)
        state.daily_demand.append(day_demand)
        state.daily_fulfilled.append(day_fulfilled)
        state.daily_critical_demand.append(day_critical_demand)
        state.daily_critical_fulfilled.append(day_critical_fulfilled)
        return day_demand, day_fulfilled

    def _record_capacity_violations(self):
        """Add one to ``capacity_violation_days`` per location currently over capacity."""
        state = self._state
        for location in self._task.locations:
            if location.capacity is None:
                continue
            if self._units_at(location.location_id) > location.capacity:
                state.capacity_violation_days += 1

    def _shaping_reward(self, day_service: float) -> float:
        """Compute the per-day shaping reward.

        Components: +0.10 × service level, minus 0.00005 per non-quarantined
        unit held, minus 0.10 per expired lot (capped at 0.30), minus 0.20 per
        stockout (capped at 0.50).
        """
        state = self._state
        unquarantined_units = sum(
            lot.qty
            for lots in state.inventory.values()
            for lot in lots
            if lot.lot_id not in state.quarantined_lots
        )
        return (
            0.10 * day_service
            - 0.00005 * unquarantined_units
            - min(0.30, state.daily_expired_lots * 0.10)
            - min(0.50, state.daily_stockout_count * 0.20)
        )

    def _generate_demand(self, product, location_id: str, day: int) -> int:
        """Draw one day's demand for ``product`` at ``location_id``.

        Base demand is scaled by the product's seasonal sine term, then by any
        active ``mci`` event (CRITICAL/HIGH products at listed locations) or
        ``demand_surge`` event (listed products). Gaussian noise with
        ``product.demand_std`` is added, and the result is rounded and clipped
        at 0.
        """
        import math as _math

        state = self._state
        base = product.base_demand
        if product.seasonal_amplitude > 0 and product.seasonal_period > 0:
            seasonal = product.seasonal_amplitude * _math.sin(
                2 * _math.pi * day / product.seasonal_period + product.seasonal_phase
            )
            base *= (1 + seasonal)
        for event_id in state.active_events:
            event = self._find_event(event_id)
            if event is None:
                continue
            if event.event_type == "mci":
                if (
                    product.criticality in ("CRITICAL", "HIGH")
                    and location_id in event.params.get("locations", [])
                ):
                    base *= event.params.get("demand_multiplier", 3.0)
            elif event.event_type == "demand_surge":
                if product.product_id in event.params.get("products", []):
                    base *= event.params.get("multiplier", 1.4)
        noise = state.rng.normal(0, product.demand_std)
        return max(0, int(round(base + noise)))

    def _fefo_fulfill(
        self, product_id: str, location_id: str, demand: int, day: int
    ) -> int:
        """Consume up to ``demand`` units from non-quarantined lots, earliest expiry first.

        Emptied lots are dropped from inventory. Any shortfall increments
        ``daily_stockout_count``. Returns the units fulfilled.
        """
        state = self._state
        key = (location_id, product_id)
        lots = state.inventory.get(key, [])
        lots_sorted = sorted(
            [l for l in lots if l.lot_id not in state.quarantined_lots and l.qty > 0],
            key=lambda l: (l.expiry_day is None, l.expiry_day or 0),
        )
        fulfilled = 0
        for lot in lots_sorted:
            if fulfilled >= demand:
                break
            take = min(demand - fulfilled, lot.qty)
            lot.qty -= take
            fulfilled += take
        state.inventory[key] = [l for l in lots if l.qty > 0]
        if fulfilled < demand:
            state.daily_stockout_count += 1
        return fulfilled

    def _update_active_events(self, day: int):
        """Drop events that ended before ``day``. Activate events triggering on ``day``.

        Only events with ``duration_days > 0`` are activated. Each stays
        active through ``trigger_day + duration_days - 1`` (inclusive).
        """
        state = self._state
        state.active_events = {
            eid: last_day
            for eid, last_day in state.active_events.items()
            if last_day >= day
        }
        for event in self._task.events:
            if event.trigger_day == day and event.duration_days > 0:
                state.active_events[event.event_id] = day + event.duration_days - 1

    def _inject_events_for_day(self, day: int):
        """Post inbox messages for events triggering on ``day`` and apply their instant effects.

        Also posts the warning message of any event triggering on ``day + 1``.
        Cold-chain breaches quarantine lots immediately. Budget-tighten events
        replace ``budget_limit``.
        """
        state = self._state
        for event in self._task.events:
            if event.trigger_day == day:
                msg = InboxMessage(
                    msg_id=f"MSG-{state.msg_counter:04d}",
                    priority=event.message.priority,
                    timestamp_str=f"Day {day} 06:00",
                    sender=event.message.sender,
                    subject=event.message.subject,
                    body=event.message.body,
                    read=False,
                    flagged=(event.message.priority == "CRITICAL"),
                    event_id=event.event_id,
                )
                state.inbox.append(msg)
                state.msg_counter += 1
                if event.event_type == "cold_chain_breach":
                    self._apply_cold_chain_breach(event)
                if event.event_type == "budget_tighten":
                    state.budget_limit = event.params["new_budget_limit"]
            if event.warning_message and event.trigger_day - 1 == day:
                msg = InboxMessage(
                    msg_id=f"MSG-{state.msg_counter:04d}",
                    priority=event.warning_message.priority,
                    timestamp_str=f"Day {day} 18:00",
                    sender=event.warning_message.sender,
                    subject=event.warning_message.subject,
                    body=event.warning_message.body,
                    read=False,
                    flagged=False,
                    event_id=f"{event.event_id}_warning",
                )
                state.inbox.append(msg)
                state.msg_counter += 1

    def _apply_cold_chain_breach(self, event: SimEvent):
        """Quarantine every lot currently held for the breached (location, product)."""
        state = self._state
        key = (event.params["location_id"], event.params["product_id"])
        for lot in state.inventory.get(key, []):
            state.quarantined_lots.add(lot.lot_id)

    def _inject_recall_lot(self):
        """Silently add the recalled IV saline lot at each location listed by the recall event.

        Does nothing if the task has no ``iv_saline_recall`` event or no
        IV-SAL-500 product.
        """
        state = self._state
        event = self._find_event("iv_saline_recall")
        if event is None:
            return
        product = self._find_product(self._RECALL_SKU)
        if product is None:
            return
        qty = event.params["qty_per_location"]
        for loc_id in event.params["locations_with_lot"]:
            state.inventory.setdefault((loc_id, self._RECALL_SKU), []).append(
                Lot(
                    lot_id=self._RECALL_LOT_ID,
                    qty=qty,
                    expiry_day=None,
                    cost_per_unit=product.unit_cost,
                )
            )

    def _check_recall_completion(self):
        """Record the first day the recall lot id appears in ``quarantined_lots``."""
        state = self._state
        if self._RECALL_LOT_ID not in state.quarantined_lots:
            return
        if state.recall_handled_by_day is None:
            state.recall_handled_by_day = state.day

    # ── Accessors used by MedchainEnvironment ──────────────────────────────

    def get_last_reward(self) -> float:
        """Return the reward stored by the most recent ``end_shift_tool`` call."""
        return self._last_reward

    def is_done(self) -> bool:
        """Return True once the episode has passed ``max_days``."""
        return self._done