# nik-55's picture
# Upload folder using huggingface_hub
# 4afc4db verified
"""
Core simulation engine for the MedChain Env environment.
MedchainSimulation manages the full episode lifecycle:
- Inventory tracked as FEFO lots per (location, product)
- Purchase order pipeline with stochastic lead times
- Event-driven inbox messages (crises, recalls, demand surges)
- Daily demand generation and fulfillment
"""
from __future__ import annotations
import uuid
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set, Tuple
import numpy as np
from .tasks import SimEvent, TaskConfig
# ─── Simulation Dataclasses ───────────────────────────────────────────────────
@dataclass
class Lot:
    """One batch of stock held in an inventory slot keyed by (location, product)."""

    lot_id: str  # Identifier; quarantine is tracked by adding this ID to SimState.quarantined_lots
    qty: int  # Units remaining in the lot (decremented by fulfillment and transfers)
    expiry_day: Optional[int] # None = non-perishable. Expired when current_day >= expiry_day.
    cost_per_unit: float  # Used to value units written off when the lot expires
@dataclass
class PurchaseOrder:
    """A purchase order created by submit_po and delivered by end_shift_tool."""

    po_id: str  # "POD-NNNN", numbered from SimState.po_counter
    supplier_id: str
    product_id: str
    destination_id: str  # Location whose inventory receives the lot on delivery
    quantity: int
    priority: str # "standard" or "expedited"
    day_submitted: int
    eta_day: int  # Delivered at the end of the first shift whose day is >= eta_day
    unit_cost: float  # Product cost x supplier multiplier x expedite premium
    total_cost: float  # unit_cost * quantity
    status: str # "pending_justification", "in_transit", "delivered"
    lot_id: str  # ID given to the Lot created when the order is delivered
@dataclass
class PendingBudgetOverride:
    """An expedited purchase order parked until file_justification is called for its ticket."""

    ticket_id: str  # "BOT-NNNN"; also the key in SimState.pending_overrides
    po: PurchaseOrder  # Held with status "pending_justification" until the ticket is resolved
@dataclass
class InboxMessage:
    """A message shown to the agent through read_inbox."""

    msg_id: str  # "MSG-NNNN", numbered from SimState.msg_counter
    priority: str  # e.g. "LOW" or "CRITICAL"; CRITICAL event messages are created flagged
    timestamp_str: str # "Day {n} {HH:MM}"
    sender: str
    subject: str
    body: str
    read: bool  # Set to True once the message has been returned by read_inbox
    flagged: bool
    event_id: str  # Originating event ID ("system_welcome" / "<event_id>_warning" for synthetic messages)
@dataclass
class JustificationRecord:
    """Audit-log entry written each time file_justification resolves an override ticket."""

    ticket_id: str
    po_id: str
    reason: str  # Free-text justification supplied by the agent
    is_coherent: bool  # Verdict returned by grader.grade_justification for the reason
@dataclass
class SimState:
    """Complete mutable state of one episode; built by MedchainSimulation.reset()."""

    # Episode meta
    task: str
    episode_id: str
    seed: int
    rng: np.random.Generator  # Seeded generator used for demand noise and lead-time jitter
    # Time
    day: int  # Current day; starts at 1 and is advanced by end_shift_tool
    max_days: int
    # Action budget
    actions_remaining: int  # Restored to actions_per_shift at every shift end
    actions_per_shift: int
    # Budget
    budget_used: float  # Cost of in-transit orders; released again when an order is delivered
    budget_limit: float
    # Inventory: (location_id, product_id) -> List[Lot] (FEFO-sorted)
    inventory: Dict[Tuple[str, str], List[Lot]]
    # Orders
    pipeline_orders: List[PurchaseOrder]
    po_counter: int  # Shared sequence for PO, override-ticket and disposal-ticket numbers
    # Inbox
    inbox: List[InboxMessage]
    msg_counter: int
    # Budget override tickets
    pending_overrides: Dict[str, PendingBudgetOverride]
    # Quarantine
    quarantined_lots: Set[str]  # Lot IDs excluded from fulfillment and transfers
    # Demand / fulfillment tracking (one value per completed day)
    daily_demand: List[float]
    daily_fulfilled: List[float]
    daily_critical_demand: List[float]
    daily_critical_fulfilled: List[float]
    # Per-(location, product) daily tracking (for demand_history queries)
    daily_product_demand: Dict[Tuple[str, str], List[int]]
    daily_product_fulfilled: Dict[Tuple[str, str], List[int]]
    # Spend tracking
    total_spend: float  # Accumulated on delivery, not on submission
    total_wasted_value: float  # Value of units removed at expiry
    # Transfer tracking (task 2)
    transfer_count: int
    transfer_cost_paid: float
    # Capacity violations (task 2)
    capacity_violation_days: int  # Incremented once per over-capacity location per shift
    # Active event effects: event_id -> last_day_active (inclusive)
    active_events: Dict[str, int]
    # Per-shift shaping reward helpers
    info_rewards_given_this_shift: Set[str]
    daily_stockout_count: int
    daily_expired_lots: int
    # Task 3 crisis tracking
    justification_log: List[JustificationRecord]
    mci_preemptive_order: bool
    recall_handled_by_day: Optional[int]  # Day the recall lot was first quarantined, else None
# ─── MedchainSimulation ───────────────────────────────────────────────────────
class MedchainSimulation:
"""
Core simulation engine. Called by MedchainEnvironment's MCP tools.
All public tool methods return a string (displayed to agent as ERP output).
end_shift_tool() also stores _last_reward and _done for the environment.
"""
def __init__(self, task_config: TaskConfig):
self._task = task_config
self._state: Optional[SimState] = None
self._last_reward: float = 0.0
self._done: bool = False
# ── Called by environment.reset() ──────────────────────────────────────
def reset(self, seed: int, episode_id: str) -> str:
    """Start a fresh episode and return the opening dashboard text.

    Args:
        seed: Seed for the episode's NumPy random generator.
        episode_id: Identifier stored on the new state.

    Returns:
        The string produced by erp_formatter.format_dashboard for day 1.
    """
    task = self._task
    # Clear the step outputs exposed through get_last_reward() / is_done().
    self._last_reward = 0.0
    self._done = False

    fresh_state = SimState(
        # Episode meta
        task=task.name,
        episode_id=episode_id,
        seed=seed,
        rng=np.random.default_rng(seed),
        # Time and action budget
        day=1,
        max_days=task.max_days,
        actions_remaining=task.actions_per_shift,
        actions_per_shift=task.actions_per_shift,
        # Money
        budget_used=0.0,
        budget_limit=task.budget_limit,
        total_spend=0.0,
        total_wasted_value=0.0,
        # Stock, orders, inbox
        inventory={},
        pipeline_orders=[],
        po_counter=1,
        inbox=[],
        msg_counter=1,
        pending_overrides={},
        quarantined_lots=set(),
        # Demand history
        daily_demand=[],
        daily_fulfilled=[],
        daily_critical_demand=[],
        daily_critical_fulfilled=[],
        daily_product_demand={},
        daily_product_fulfilled={},
        # Task-specific trackers
        transfer_count=0,
        transfer_cost_paid=0.0,
        capacity_violation_days=0,
        active_events={},
        info_rewards_given_this_shift=set(),
        daily_stockout_count=0,
        daily_expired_lots=0,
        justification_log=[],
        mci_preemptive_order=False,
        recall_handled_by_day=None,
    )
    self._state = fresh_state

    self._initialize_inventory()
    # NOTE(review): only the inbox side of day-1 events is handled here;
    # _update_active_events() is not called for day 1, so an event with
    # trigger_day == 1 and a duration never enters active_events. Confirm
    # whether task configs rely on this.
    self._inject_day1_inbox()

    from .erp_formatter import format_dashboard

    return format_dashboard(fresh_state, task)
def _initialize_inventory(self):
"""Seed initial inventory: initial_stock_days Γ— base_demand per location/product."""
state = self._state
for product in self._task.products:
for loc_id in product.locations:
key = (loc_id, product.product_id)
qty = int(product.base_demand * self._task.initial_stock_days)
expiry_day = (
1 + int(product.shelf_life_days * 0.7)
if product.shelf_life_days is not None
else None
)
lot = Lot(
lot_id=f"INIT-{product.product_id}-{loc_id}",
qty=qty,
expiry_day=expiry_day,
cost_per_unit=product.unit_cost,
)
state.inventory[key] = [lot]
def _inject_day1_inbox(self):
    """Post the welcome/handover message, then any messages for events that trigger on day 1."""
    state = self._state
    task = self._task
    handover_body = (
        f"Welcome to the {self._task.name} scenario.\n"
        f"You are managing medical supplies for {self._task.max_days} days.\n"
        f"Action budget: {self._task.actions_per_shift} actions per shift.\n"
        f"Budget ceiling: ${self._task.budget_limit:,.0f} outstanding orders.\n\n"
        "Use read_inbox to check messages, query_erp to check stock,\n"
        "submit_po to order supplies, and end_shift to advance the day."
    )
    state.inbox.append(
        InboxMessage(
            msg_id=f"MSG-{state.msg_counter:04d}",
            priority="LOW",
            timestamp_str="Day 1 08:00",
            sender="System",
            subject="Shift Handover Notes",
            body=handover_body,
            read=False,
            flagged=False,
            event_id="system_welcome",
        )
    )
    state.msg_counter += 1
    # Event messages follow the welcome note so they receive later message IDs.
    self._inject_events_for_day(1)
# ── Action Budget Helper ────────────────────────────────────────────────
def _check_action_budget(self, tool_name: str) -> Optional[str]:
"""Returns error string if budget exhausted, None if OK. Does NOT decrement."""
if tool_name == "end_shift":
return None
if self._state is None:
return "ERROR: Environment not initialized. Call reset() first."
if self._state.actions_remaining <= 0:
return (
"ERROR: Action budget exhausted for this shift.\n"
f"Actions used: {self._state.actions_per_shift}/{self._state.actions_per_shift}\n"
"Call end_shift() to advance to the next day and restore your action budget."
)
return None
# ── MCP Tool Implementations ────────────────────────────────────────────
def read_inbox(self, filter: str = "unread") -> str:
err = self._check_action_budget("read_inbox")
if err:
return err
self._state.actions_remaining -= 1
messages = list(self._state.inbox)
if filter == "unread":
messages = [m for m in messages if not m.read]
elif filter == "flagged":
messages = [m for m in messages if m.flagged]
# "all" β†’ use full inbox
for m in messages:
m.read = True
if not messages:
return f"INBOX EMPTY\nFilter: {filter} | No messages matching filter."
lines = []
for m in messages:
read_status = "READ" if m.read else "UNREAD"
lines.append(
f"\n[MSG {m.msg_id} | {read_status} | PRIORITY: {m.priority} | {m.timestamp_str}]"
)
lines.append(f"FROM: {m.sender}")
lines.append(f"SUBJ: {m.subject}")
lines.append("")
lines.append(m.body)
lines.append("")
return "\n".join(lines)
def query_erp(self, table: str, location: str = "all", sku: str = "all") -> str:
err = self._check_action_budget("query_erp")
if err:
return err
self._state.actions_remaining -= 1
valid_tables = ["inventory", "expiry", "pipeline_orders", "demand_history"]
if table not in valid_tables:
return f"ERROR: Unknown table '{table}'. Valid tables: {valid_tables}"
from .erp_formatter import (
format_demand_history,
format_expiry_table,
format_inventory_table,
format_pipeline_table,
)
if table == "inventory":
return format_inventory_table(self._state, self._task, location, sku)
elif table == "expiry":
return format_expiry_table(self._state, self._task, location, sku)
elif table == "pipeline_orders":
return format_pipeline_table(self._state, location, sku)
elif table == "demand_history":
return format_demand_history(self._state, self._task, location, sku)
return "ERROR: Unexpected table."
def query_supplier(self, supplier_id: str) -> str:
err = self._check_action_budget("query_supplier")
if err:
return err
self._state.actions_remaining -= 1
supplier = next((s for s in self._task.suppliers if s.supplier_id == supplier_id), None)
if not supplier:
available = [s.supplier_id for s in self._task.suppliers]
return f"ERROR: Supplier '{supplier_id}' not found. Available: {available}"
effective_lead_time = supplier.base_lead_time
disruption_note = "No disruptions reported."
for event_id, last_day in self._state.active_events.items():
event = next((e for e in self._task.events if e.event_id == event_id), None)
if (
event
and event.event_type == "supplier_disruption"
and event.params.get("supplier_id") == supplier_id
):
effective_lead_time = event.params["new_lead_time"]
disruption_note = (
f"ACTIVE DISRUPTION: Lead time extended to {effective_lead_time} days. "
f"Reason: {event.params['reason']}"
)
from .erp_formatter import format_supplier_info
return format_supplier_info(supplier, effective_lead_time, disruption_note)
def query_forecast(self, product_id: str, location_id: str, horizon_days: int = 7) -> str:
err = self._check_action_budget("query_forecast")
if err:
return err
self._state.actions_remaining -= 1
horizon_days = max(1, min(21, horizon_days))
product = next((p for p in self._task.products if p.product_id == product_id), None)
if not product:
return f"ERROR: Product '{product_id}' not found."
if location_id not in product.locations and location_id != "all":
return f"ERROR: Product '{product_id}' is not stocked at '{location_id}'."
from .erp_formatter import format_forecast
return format_forecast(self._state, self._task, product, location_id, horizon_days)
def submit_po(
self,
supplier_id: str,
product_id: str,
destination_id: str,
quantity: int,
priority: str = "standard",
) -> str:
err = self._check_action_budget("submit_po")
if err:
return err
if priority not in ("standard", "expedited"):
return "ERROR: priority must be 'standard' or 'expedited'."
if quantity <= 0:
return "ERROR: quantity must be positive."
supplier = next((s for s in self._task.suppliers if s.supplier_id == supplier_id), None)
if not supplier:
return f"ERROR: Supplier '{supplier_id}' not found."
if product_id not in supplier.products:
return f"ERROR: Supplier '{supplier_id}' does not supply '{product_id}'."
valid_locs = [l.location_id for l in self._task.locations]
if destination_id not in valid_locs:
return f"ERROR: Destination '{destination_id}' not found. Valid: {valid_locs}"
product = next((p for p in self._task.products if p.product_id == product_id), None)
expedited_multiplier = 1.5 if priority == "expedited" else 1.0
unit_cost = product.unit_cost * supplier.cost_multiplier * expedited_multiplier
total_cost = unit_cost * quantity
if self._state.budget_used + total_cost > self._state.budget_limit:
overage = (self._state.budget_used + total_cost) - self._state.budget_limit
return (
f"ERROR: BUDGET_EXCEEDED\n"
f"Order cost: ${total_cost:,.2f} | "
f"Current outstanding: ${self._state.budget_used:,.2f} | "
f"Limit: ${self._state.budget_limit:,.2f}\n"
f"Overage: ${overage:,.2f}\n"
f"Reduce order quantity or wait for existing orders to be delivered."
)
# Effective lead time (check active disruptions)
lead_time = supplier.base_lead_time
for event_id, last_day in self._state.active_events.items():
event = next((e for e in self._task.events if e.event_id == event_id), None)
if (
event
and event.event_type == "supplier_disruption"
and event.params.get("supplier_id") == supplier_id
):
lead_time = event.params["new_lead_time"]
if priority == "expedited":
lead_time = max(1, lead_time - 2)
# Stochastic jitter for task 3
if supplier.lead_time_std > 0:
jitter = int(round(self._state.rng.normal(0, supplier.lead_time_std)))
lead_time = max(1, lead_time + jitter)
eta_day = self._state.day + lead_time
po_id = f"POD-{self._state.po_counter:04d}"
lot_id = f"LOT-{po_id}"
self._state.po_counter += 1
# Expedited: requires justification
if priority == "expedited":
ticket_id = f"BOT-{self._state.po_counter:04d}"
self._state.po_counter += 1
po = PurchaseOrder(
po_id=po_id, supplier_id=supplier_id, product_id=product_id,
destination_id=destination_id, quantity=quantity, priority=priority,
day_submitted=self._state.day, eta_day=eta_day, unit_cost=unit_cost,
total_cost=total_cost, status="pending_justification", lot_id=lot_id,
)
self._state.pending_overrides[ticket_id] = PendingBudgetOverride(
ticket_id=ticket_id, po=po
)
self._state.actions_remaining -= 1
return (
f"ERROR: BUDGET_OVERRIDE_REQUIRED\n"
f"Order {po_id} ({priority}, ${total_cost:,.2f} incl. 50% expedite premium) "
f"requires justification.\n"
f"Ticket ID: {ticket_id}\n"
f"Use file_justification(ticket_id=\"{ticket_id}\", reason=\"...\") to proceed.\n"
f"Justification will be audited by Finance. False justifications are flagged."
)
# Standard order: submit immediately
self._state.actions_remaining -= 1
po = PurchaseOrder(
po_id=po_id, supplier_id=supplier_id, product_id=product_id,
destination_id=destination_id, quantity=quantity, priority=priority,
day_submitted=self._state.day, eta_day=eta_day, unit_cost=unit_cost,
total_cost=total_cost, status="in_transit", lot_id=lot_id,
)
self._state.pipeline_orders.append(po)
self._state.budget_used += total_cost
return (
f"OK β€” PO {po_id} submitted.\n"
f"Product: {product_id} Γ— {quantity} units\n"
f"Supplier: {supplier_id} | Priority: {priority}\n"
f"Destination: {destination_id} | ETA: Day {eta_day}\n"
f"Cost: ${total_cost:,.2f} | "
f"Budget remaining: ${self._state.budget_limit - self._state.budget_used:,.2f}"
)
def transfer(
self,
from_location_id: str,
to_location_id: str,
product_id: str,
quantity: int,
) -> str:
err = self._check_action_budget("transfer")
if err:
return err
self._state.actions_remaining -= 1
if quantity <= 0:
return "ERROR: quantity must be positive."
valid_locs = {l.location_id for l in self._task.locations}
if from_location_id not in valid_locs:
return f"ERROR: Location '{from_location_id}' not found."
if to_location_id not in valid_locs:
return f"ERROR: Location '{to_location_id}' not found."
key_from = (from_location_id, product_id)
lots = sorted(
[
l for l in self._state.inventory.get(key_from, [])
if l.lot_id not in self._state.quarantined_lots
],
key=lambda l: (l.expiry_day is None, l.expiry_day or 0),
)
available = sum(l.qty for l in lots)
if available < quantity:
return (
f"ERROR: Insufficient stock at {from_location_id}. "
f"Available: {available} units of {product_id}."
)
# Check destination capacity (task 2)
dest_loc = next(
(l for l in self._task.locations if l.location_id == to_location_id), None
)
if dest_loc and dest_loc.capacity is not None:
current_at_dest = sum(
sum(lot.qty for lot in lots2)
for (loc, pid), lots2 in self._state.inventory.items()
if loc == to_location_id
)
if current_at_dest + quantity > dest_loc.capacity:
return (
f"ERROR: CAPACITY_EXCEEDED β€” {to_location_id} capacity {dest_loc.capacity}. "
f"Current: {current_at_dest}, Transfer: {quantity}."
)
# FEFO transfer
remaining = quantity
key_to = (to_location_id, product_id)
if key_to not in self._state.inventory:
self._state.inventory[key_to] = []
for lot in lots:
if remaining <= 0:
break
take = min(remaining, lot.qty)
lot.qty -= take
remaining -= take
self._state.inventory[key_to].append(
Lot(
lot_id=f"XFR-{lot.lot_id}",
qty=take,
expiry_day=lot.expiry_day,
cost_per_unit=lot.cost_per_unit,
)
)
self._state.inventory[key_from] = [
l for l in self._state.inventory[key_from] if l.qty > 0
]
TRANSFER_FEE = 0.5
fee = quantity * TRANSFER_FEE
self._state.transfer_count += 1
self._state.transfer_cost_paid += fee
return (
f"OK β€” Transfer complete.\n"
f"{quantity} units of {product_id}: {from_location_id} β†’ {to_location_id}\n"
f"Transfer fee: ${fee:.2f}"
)
def quarantine_lot(self, location_id: str, sku: str, lot_id: str) -> str:
err = self._check_action_budget("quarantine_lot")
if err:
return err
self._state.actions_remaining -= 1
valid_locs = {l.location_id for l in self._task.locations}
if location_id not in valid_locs:
return f"ERROR: Location '{location_id}' not found."
key = (location_id, sku)
lots = self._state.inventory.get(key, [])
if lot_id == "all":
target_lots = [l for l in lots]
else:
target_lots = [l for l in lots if l.lot_id == lot_id]
if not target_lots:
target_lots = [l for l in lots if lot_id in l.lot_id]
if not target_lots:
available_lots = [l.lot_id for l in lots]
return (
f"ERROR: Lot '{lot_id}' not found at {location_id} for SKU {sku}. "
f"Available lots: {available_lots}"
)
quarantined_qty = 0
disposal_ids = []
for lot in target_lots:
if lot.lot_id not in self._state.quarantined_lots:
self._state.quarantined_lots.add(lot.lot_id)
quarantined_qty += lot.qty
disposal_ids.append(lot.lot_id)
# Track recall completion for task 3
if sku == "IV-SAL-500" and "RECALL-LOT" in lot_id:
self._check_recall_completion()
disposal_ticket = f"DIS-{self._state.po_counter:04d}"
self._state.po_counter += 1
return (
f"OK β€” Quarantine complete.\n"
f"SKU: {sku} | Location: {location_id}\n"
f"Lots quarantined: {disposal_ids}\n"
f"Units quarantined: {quarantined_qty}\n"
f"Disposal ticket: {disposal_ticket} created."
)
def file_justification(self, ticket_id: str, reason: str) -> str:
err = self._check_action_budget("file_justification")
if err:
return err
self._state.actions_remaining -= 1
if ticket_id not in self._state.pending_overrides:
return (
f"ERROR: Ticket '{ticket_id}' not found or already processed.\n"
f"Active tickets: {list(self._state.pending_overrides.keys())}"
)
override = self._state.pending_overrides.pop(ticket_id)
po = override.po
active_event_types: Set[str] = set()
for event_id in self._state.active_events:
event = next((e for e in self._task.events if e.event_id == event_id), None)
if event:
active_event_types.add(event.event_type)
from .grader import grade_justification
is_coherent = grade_justification(reason, active_event_types)
record = JustificationRecord(
ticket_id=ticket_id, po_id=po.po_id, reason=reason, is_coherent=is_coherent
)
self._state.justification_log.append(record)
po.status = "in_transit"
self._state.pipeline_orders.append(po)
self._state.budget_used += po.total_cost
audit_note = ""
if not is_coherent:
audit_note = (
"\n[AUDIT FLAG] Justification does not reference active crisis conditions. "
"Flagged for Finance review. Penalty applied."
)
return (
f"OK β€” Justification {'accepted' if is_coherent else 'FLAGGED'}. "
f"PO {po.po_id} submitted.\n"
f"Product: {po.product_id} Γ— {po.quantity} units | Destination: {po.destination_id}\n"
f"ETA: Day {po.eta_day} | Cost: ${po.total_cost:,.2f}"
f"{audit_note}"
)
def end_shift_tool(self) -> str:
    """Advance simulation by one day. Stores _last_reward and _done.

    Processing order for the closing day: deliveries, expiry, demand
    generation and fulfillment, capacity check, (task-specific) recall lot
    injection, then the day counter advances and the next day's events are
    activated and announced.

    Returns:
        The end-of-shift report. While the episode continues, the next day's
        dashboard is appended and _last_reward holds the daily shaping
        reward; once the new day exceeds max_days, the report ends with the
        final summary and _last_reward holds the score from
        grader.compute_reward.
    """
    state = self._state
    if state is None:
        return "ERROR: Environment not initialized."
    day = state.day
    report_lines = [f"╔═══ END OF SHIFT β€” Day {day} {'═' * 40}β•—"]
    # ── Step 1: Deliver arriving orders ──────────────────────────────
    delivered = []
    for po in list(state.pipeline_orders):
        if po.eta_day <= day:
            # NOTE(review): product is None if the PO's product is not in the
            # task's product list; the shelf-life lookup below would then raise.
            product = next(
                (p for p in self._task.products if p.product_id == po.product_id), None
            )
            key = (po.destination_id, po.product_id)
            if key not in state.inventory:
                state.inventory[key] = []
            # NOTE(review): truthiness test, so a shelf life of 0 is treated as
            # non-perishable here, whereas _initialize_inventory tests
            # "is not None" — confirm which is intended.
            expiry_day = (day + product.shelf_life_days) if product.shelf_life_days else None
            lot = Lot(
                lot_id=po.lot_id, qty=po.quantity,
                expiry_day=expiry_day, cost_per_unit=po.unit_cost
            )
            state.inventory[key].append(lot)
            # Delivery releases the outstanding-order budget and books the spend.
            state.budget_used -= po.total_cost
            state.total_spend += po.total_cost
            po.status = "delivered"
            delivered.append(po)
    state.pipeline_orders = [po for po in state.pipeline_orders if po.status != "delivered"]
    if delivered:
        report_lines.append(f" DELIVERIES: {len(delivered)} order(s) received.")
    # ── Step 2: Expire old lots ───────────────────────────────────────
    total_expired_units = 0
    total_expired_value = 0.0
    for key in list(state.inventory.keys()):
        fresh, expired = [], []
        for lot in state.inventory[key]:
            # Quarantined lots are not exempt: they expire and count as waste too.
            if lot.expiry_day is not None and lot.expiry_day <= day:
                expired.append(lot)
            else:
                fresh.append(lot)
        if expired:
            for lot in expired:
                total_expired_units += lot.qty
                total_expired_value += lot.qty * lot.cost_per_unit
                state.total_wasted_value += lot.qty * lot.cost_per_unit
            state.daily_expired_lots += len(expired)
            state.inventory[key] = fresh
    if total_expired_units > 0:
        report_lines.append(
            f" EXPIRED: {total_expired_units} units (${total_expired_value:,.2f} written off)"
        )
    # ── Step 3: Generate and fulfill demand ───────────────────────────
    day_demand = 0.0
    day_fulfilled = 0.0
    day_critical_demand = 0.0
    day_critical_fulfilled = 0.0
    # The iteration order below fixes the order of RNG draws in _generate_demand.
    for product in self._task.products:
        for loc_id in product.locations:
            demand = self._generate_demand(product, loc_id, day)
            fulfilled = self._fefo_fulfill(product.product_id, loc_id, demand, day)
            day_demand += demand
            day_fulfilled += fulfilled
            if product.criticality == "CRITICAL":
                day_critical_demand += demand
                day_critical_fulfilled += fulfilled
            # Per-product daily tracking
            key = (loc_id, product.product_id)
            if key not in state.daily_product_demand:
                state.daily_product_demand[key] = []
                state.daily_product_fulfilled[key] = []
            state.daily_product_demand[key].append(demand)
            state.daily_product_fulfilled[key].append(fulfilled)
    state.daily_demand.append(day_demand)
    state.daily_fulfilled.append(day_fulfilled)
    state.daily_critical_demand.append(day_critical_demand)
    state.daily_critical_fulfilled.append(day_critical_fulfilled)
    # max(..., 1) guards the division on a day with zero demand.
    day_svc = day_fulfilled / max(day_demand, 1)
    report_lines.append(
        f" DEMAND: {int(day_demand)} units | FULFILLED: {int(day_fulfilled)} ({100 * day_svc:.1f}%)"
    )
    # ── Step 4: Check capacity violations (task 2) ────────────────────
    if any(l.capacity is not None for l in self._task.locations):
        for location in self._task.locations:
            if location.capacity is None:
                continue
            current = sum(
                sum(lot.qty for lot in lots)
                for (lid, pid), lots in state.inventory.items()
                if lid == location.location_id
            )
            if current > location.capacity:
                # Counted once per over-capacity location, so several
                # locations can add several counts on the same day.
                state.capacity_violation_days += 1
    # ── Step 5: Inject recall lot for task 3 (Day 2, silent) ─────────
    if self._task.name == "hospital_network_crisis" and day == 2:
        self._inject_recall_lot()
    # ── Step 6: Advance day, reset budget, inject next-day events ────
    state.day += 1
    state.actions_remaining = state.actions_per_shift
    self._update_active_events(state.day)
    self._inject_events_for_day(state.day)
    # ── Step 7: Daily shaping reward ──────────────────────────────────
    shaping = 0.0
    day_service = day_fulfilled / max(day_demand, 1)
    shaping += 0.10 * day_service
    # Holding penalty on all non-quarantined units left after fulfillment.
    total_units = sum(
        lot.qty
        for lots in state.inventory.values()
        for lot in lots
        if lot.lot_id not in state.quarantined_lots
    )
    shaping -= 0.00005 * total_units
    # Expiry and stockout penalties are capped at 0.30 and 0.50 respectively.
    shaping -= min(0.30, state.daily_expired_lots * 0.10)
    shaping -= min(0.50, state.daily_stockout_count * 0.20)
    state.info_rewards_given_this_shift = set()
    state.daily_stockout_count = 0
    state.daily_expired_lots = 0
    # ── Step 8: Compute terminal score & check done ───────────────────
    from .grader import compute_reward
    # NOTE(review): the score is computed on every shift but only used when
    # the episode ends.
    final_score = compute_reward(state, self._task)
    done = state.day > state.max_days
    if done:
        report_lines.append(
            f"╠═══ EPISODE COMPLETE β€” Final Score: {final_score:.3f} {'═' * 30}β•£"
        )
        total_d = sum(state.daily_demand)
        total_f = sum(state.daily_fulfilled)
        report_lines.append(
            f" Service Level: {total_f / max(total_d, 1) * 100:.1f}%"
        )
        report_lines.append(f" Total Spend: ${state.total_spend:,.2f}")
        report_lines.append(f" Waste Value: ${state.total_wasted_value:,.2f}")
        report_lines.append(f"β•š{'═' * 68}╝")
        self._done = True
        self._last_reward = final_score
        return "\n".join(report_lines)
    self._done = False
    self._last_reward = shaping
    report_lines.append(
        f"β•šβ•β•β• Day {day} committed. Day {state.day} begins. {'═' * 38}╝"
    )
    report_lines.append("")
    from .erp_formatter import format_dashboard
    report_lines.append(format_dashboard(state, self._task))
    return "\n".join(report_lines)
# ── Private Helpers ────────────────────────────────────────────────────
def _generate_demand(self, product, location_id: str, day: int) -> int:
import math as _math
state = self._state
base = product.base_demand
if product.seasonal_amplitude > 0 and product.seasonal_period > 0:
seasonal = product.seasonal_amplitude * _math.sin(
2 * _math.pi * day / product.seasonal_period + product.seasonal_phase
)
base *= (1 + seasonal)
for event_id, last_day in state.active_events.items():
event = next((e for e in self._task.events if e.event_id == event_id), None)
if event is None:
continue
if event.event_type == "mci":
if (
product.criticality in ("CRITICAL", "HIGH")
and location_id in event.params.get("locations", [])
):
base *= event.params.get("demand_multiplier", 3.0)
elif event.event_type == "demand_surge":
if product.product_id in event.params.get("products", []):
base *= event.params.get("multiplier", 1.4)
noise = state.rng.normal(0, product.demand_std)
return max(0, int(round(base + noise)))
def _fefo_fulfill(
self, product_id: str, location_id: str, demand: int, day: int
) -> int:
state = self._state
key = (location_id, product_id)
lots = state.inventory.get(key, [])
lots_sorted = sorted(
[l for l in lots if l.lot_id not in state.quarantined_lots and l.qty > 0],
key=lambda l: (l.expiry_day is None, l.expiry_day or 0),
)
fulfilled = 0
for lot in lots_sorted:
if fulfilled >= demand:
break
take = min(demand - fulfilled, lot.qty)
lot.qty -= take
fulfilled += take
state.inventory[key] = [l for l in lots if l.qty > 0]
if fulfilled < demand:
state.daily_stockout_count += 1
return fulfilled
def _update_active_events(self, day: int):
state = self._state
state.active_events = {
eid: last_day
for eid, last_day in state.active_events.items()
if last_day >= day
}
for event in self._task.events:
if event.trigger_day == day and event.duration_days > 0:
state.active_events[event.event_id] = day + event.duration_days - 1
def _inject_events_for_day(self, day: int):
state = self._state
for event in self._task.events:
if event.trigger_day == day:
msg = InboxMessage(
msg_id=f"MSG-{state.msg_counter:04d}",
priority=event.message.priority,
timestamp_str=f"Day {day} 06:00",
sender=event.message.sender,
subject=event.message.subject,
body=event.message.body,
read=False,
flagged=(event.message.priority == "CRITICAL"),
event_id=event.event_id,
)
state.inbox.append(msg)
state.msg_counter += 1
if event.event_type == "cold_chain_breach":
self._apply_cold_chain_breach(event)
if event.event_type == "budget_tighten":
state.budget_limit = event.params["new_budget_limit"]
if event.warning_message and event.trigger_day - 1 == day:
msg = InboxMessage(
msg_id=f"MSG-{state.msg_counter:04d}",
priority=event.warning_message.priority,
timestamp_str=f"Day {day} 18:00",
sender=event.warning_message.sender,
subject=event.warning_message.subject,
body=event.warning_message.body,
read=False,
flagged=False,
event_id=f"{event.event_id}_warning",
)
state.inbox.append(msg)
state.msg_counter += 1
def _apply_cold_chain_breach(self, event: SimEvent):
state = self._state
loc = event.params["location_id"]
prod = event.params["product_id"]
key = (loc, prod)
for lot in state.inventory.get(key, []):
state.quarantined_lots.add(lot.lot_id)
def _inject_recall_lot(self):
state = self._state
recall_lot_id = "RECALL-LOT-IV2026-9821"
for event in self._task.events:
if event.event_id == "iv_saline_recall":
qty = event.params["qty_per_location"]
product = next(
(p for p in self._task.products if p.product_id == "IV-SAL-500"), None
)
if product is None:
break
for loc_id in event.params["locations_with_lot"]:
key = (loc_id, "IV-SAL-500")
if key not in state.inventory:
state.inventory[key] = []
lot = Lot(
lot_id=recall_lot_id,
qty=qty,
expiry_day=None,
cost_per_unit=product.unit_cost,
)
state.inventory[key].append(lot)
break
def _check_recall_completion(self):
state = self._state
recall_lot_id = "RECALL-LOT-IV2026-9821"
if recall_lot_id not in state.quarantined_lots:
return
if state.recall_handled_by_day is None:
state.recall_handled_by_day = state.day
# ── Accessors used by MedchainEnvironment ──────────────────────────────
def get_last_reward(self) -> float:
    """Return the reward stored by the most recent end_shift_tool() call (0.0 after reset)."""
    return self._last_reward
def is_done(self) -> bool:
    """Return True once end_shift_tool() has advanced the episode past max_days."""
    return self._done