# openenv-warehouse / models.py
# Uploaded by flamingo44333 — "Upload models.py" (commit f35152b, verified)
"""
env_core/models.py
Typed Pydantic models for the OpenEnv Warehouse environment.
"""
from __future__ import annotations
from enum import Enum
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
# ─────────────────────────── Domain enums ────────────────────────────
class AlertSeverity(str, Enum):
    # Severity label for a stock alert. The ``str`` mixin makes every member
    # compare equal to its plain string value (AlertSeverity.OK == "OK").
    CRITICAL = "CRITICAL"  # stockout imminent (< safety stock)
    WARNING = "WARNING"  # below reorder point but not critical
    OK = "OK"  # healthy stock level
class ActionType(str, Enum):
    # Kinds of action an agent can submit. The ``str`` mixin makes every
    # member compare equal to its lowercase string value.
    TRIAGE_ALERT = "triage_alert"  # classify a pending alert
    PLACE_ORDER = "place_order"  # order N units of SKU from supplier
    RESOLVE_CONFLICT = "resolve_conflict"  # pick correct count when records conflict
    SKIP = "skip"  # explicitly do nothing this step
class OrderStatus(str, Enum):
    # Status values an OrderRecord can carry. The ``str`` mixin makes every
    # member compare equal to its lowercase string value.
    PENDING = "pending"
    CONFIRMED = "confirmed"
    DELIVERED = "delivered"
    CANCELLED = "cancelled"
# ─────────────────────────── Sub-models ──────────────────────────────
class SKU(BaseModel):
    # Inventory record for a single stock-keeping unit.
    sku_id: str
    name: str
    current_stock: int
    safety_stock: int  # below this → CRITICAL
    reorder_point: int  # below this → WARNING
    max_capacity: int
    unit_cost: float  # $ per unit
    days_until_expiry: Optional[int] = None  # None = non-perishable
    lead_time_days: int = 2  # supplier lead time
class Supplier(BaseModel):
    # A supplier and the ordering constraints attached to it.
    supplier_id: str
    name: str
    skus_offered: List[str]  # SKU IDs this supplier carries
    price_multiplier: float = 1.0  # relative to unit_cost
    min_order_qty: int = 1
    max_order_qty: int = 500
    # NOTE(review): the 0–1 range below is documented only; no ge/le bound is
    # declared on the field, so out-of-range values are accepted by this model.
    reliability_score: float = 1.0  # 0–1, affects on-time delivery
class Alert(BaseModel):
    # A stock alert awaiting (or already given) a severity classification.
    alert_id: str
    sku_id: str
    message: str
    raw_stock: int  # stock level that triggered the alert
    # NOTE(review): "hidden from agent" is a convention of the surrounding
    # environment; nothing in this model excludes the field from serialization.
    expected_severity: Optional[AlertSeverity] = None  # set by grader, hidden from agent
    triaged: bool = False
    agent_classification: Optional[AlertSeverity] = None
class OrderRecord(BaseModel):
    # A purchase order for one SKU from one supplier.
    order_id: str
    sku_id: str
    supplier_id: str
    quantity: int
    total_cost: float
    day_placed: int
    expected_delivery_day: int
    status: OrderStatus = OrderStatus.PENDING  # new records start as pending
class StockConflict(BaseModel):
    # A disagreement between two recorded stock counts for the same SKU.
    conflict_id: str
    sku_id: str
    system_count: int  # what the WMS says
    physical_count: int  # what a manual audit found
    # NOTE(review): "hidden from agent" is a convention of the surrounding
    # environment; nothing in this model excludes the field from serialization.
    correct_count: Optional[int] = None  # ground truth, hidden from agent
# ─────────────────────────── Observation ─────────────────────────────
class Observation(BaseModel):
    """Full observation returned to the agent each step."""
    # Every field is required except ``recent_actions`` and ``task_hint``.
    day: int  # current simulation day (0-indexed)
    budget_remaining: float  # $ left this week
    skus: List[SKU]
    pending_alerts: List[Alert]
    suppliers: List[Supplier]
    open_orders: List[OrderRecord]
    stock_conflicts: List[StockConflict]
    # default_factory gives each instance its own fresh list.
    recent_actions: List[str] = Field(default_factory=list)  # human-readable log
    task_hint: str = ""  # natural-language description of what to do
# ─────────────────────────── Action ──────────────────────────────────
class Action(BaseModel):
    """Structured action the agent sends to step()."""
    # Only ``action_type`` is required. Every per-action field below is
    # optional and defaults to None; this model does not check that the fields
    # matching the chosen action type are actually populated.
    # NOTE(review): presumably step() validates them — confirm in the caller.
    action_type: ActionType

    # for TRIAGE_ALERT
    alert_id: Optional[str] = None
    severity_classification: Optional[AlertSeverity] = None

    # for PLACE_ORDER
    sku_id: Optional[str] = None
    supplier_id: Optional[str] = None
    quantity: Optional[int] = None

    # for RESOLVE_CONFLICT
    conflict_id: Optional[str] = None
    chosen_count: Optional[int] = None

    # free-form rationale (logged but not scored)
    rationale: Optional[str] = None
# ─────────────────────────── Reward ──────────────────────────────────
class Reward(BaseModel):
    """Shaped step reward with breakdown for transparency."""
    # ``total`` is required (``...``) and must lie in [-1.0, 1.0]; a value
    # outside that range fails validation.
    total: float = Field(..., ge=-1.0, le=1.0)
    # The breakdown components are unconstrained floats defaulting to 0.0.
    # This model does not check that they add up to ``total``.
    correctness: float = 0.0  # classification / conflict accuracy
    efficiency: float = 0.0  # cost minimization bonus
    coverage: float = 0.0  # critical alerts handled
    penalty: float = 0.0  # stockouts, over-budget, ignored criticals
    explanation: str = ""