dispatch_arena_v0 / models.py
Freakdivi's picture
Upload folder using huggingface_hub
c71bf62 verified
"""Public models for the Dispatch Arena environment."""
from __future__ import annotations
from enum import Enum
from typing import Any, Dict, List, Mapping, Optional
from pydantic import BaseModel, ConfigDict, Field, model_validator
class DispatchArenaModel(BaseModel):
"""Base model with JSON-friendly helpers used by server and client."""
model_config = ConfigDict(use_enum_values=False)
def to_dict(self) -> Dict[str, Any]:
return self.model_dump(mode="json", exclude_none=True)
class Mode(str, Enum):
MINI = "mini"
NORMAL = "normal"
class CourierStatus(str, Enum):
IDLE = "idle"
TO_PICKUP = "to_pickup"
WAITING_PICKUP = "waiting_pickup"
TO_DROPOFF = "to_dropoff"
REPOSITIONING = "repositioning"
class OrderStatus(str, Enum):
QUEUED = "queued"
READY = "ready"
PICKED = "picked"
DELIVERED = "delivered"
EXPIRED = "expired"
class VerifierVerdict(str, Enum):
IN_PROGRESS = "in_progress"
DELIVERED_SUCCESSFULLY = "delivered_successfully"
TIMEOUT_FAILURE = "timeout_failure"
PARTIAL_SUCCESS = "partial_success"
class MiniActionType(str, Enum):
WAIT = "wait"
GO_PICKUP = "go_pickup"
GO_DROPOFF = "go_dropoff"
PICKUP = "pickup"
DROPOFF = "dropoff"
class NormalActionType(str, Enum):
ASSIGN = "assign"
REPOSITION = "reposition"
HOLD = "hold"
PRIORITIZE = "prioritize"
class Action(DispatchArenaModel):
"""OpenEnv-facing action payload."""
action_type: str
courier_id: Optional[str] = None
order_id: Optional[str] = None
node_id: Optional[str] = None
@property
def name(self) -> str:
return self.action_type
@classmethod
def from_dict(cls, data: Mapping[str, Any]) -> "Action":
return cls.model_validate(dict(data))
class Node(DispatchArenaModel):
id: str
kind: str
label: str
class Courier(DispatchArenaModel):
id: str
node_id: str
status: CourierStatus = CourierStatus.IDLE
eta_remaining: int = 0
assigned_order_id: Optional[str] = None
load: Optional[str] = None
target_node_id: Optional[str] = None
@model_validator(mode="after")
def _validate_eta(self) -> "Courier":
if self.eta_remaining < 0:
raise ValueError("eta_remaining must be >= 0")
return self
class Order(DispatchArenaModel):
id: str
kind: str = "food"
pickup_node_id: str
dropoff_node_id: str
created_tick: int = 0
arrival_tick: int = 0
prep_remaining: Optional[int] = None
deadline_tick: int = 20
status: OrderStatus = OrderStatus.QUEUED
assigned_courier_id: Optional[str] = None
ready_now: Optional[bool] = None
delivered_tick: Optional[int] = None
class Config(DispatchArenaModel):
mode: Mode = Mode.MINI
max_ticks: int = 12
visible_prep: bool = False
num_couriers: int = 1
num_orders: int = 1
scenario_bucket: str = "easy"
progress_shaping: bool = True
rolling_arrivals: bool = False
traffic_noise: float = 0.0
@model_validator(mode="after")
def _validate_config(self) -> "Config":
if self.max_ticks <= 0:
raise ValueError("max_ticks must be > 0")
if self.traffic_noise < 0.0 or self.traffic_noise > 2.0:
raise ValueError("traffic_noise must be in [0.0, 2.0]")
if self.mode == Mode.MINI:
self.num_couriers = 1
self.num_orders = 1
if self.mode == Mode.NORMAL:
self.num_couriers = min(max(self.num_couriers, 2), 5)
self.num_orders = min(max(self.num_orders, 3), 10)
return self
class RewardBreakdown(DispatchArenaModel):
"""Machine-readable reward decomposition for a transition."""
step_cost: float = 0.0
progress_reward: float = 0.0
invalid_penalty: float = 0.0
success_reward: float = 0.0
timeout_penalty: float = 0.0
on_time_bonus: float = 0.0
late_penalty: float = 0.0
idle_penalty: float = 0.0
route_churn_penalty: float = 0.0
fairness_penalty: float = 0.0
total_reward: float = 0.0
class State(DispatchArenaModel):
"""Sanitized public environment state.
Hidden `prep_remaining` values are excluded unless visible prep mode is
explicitly enabled by the scenario config.
"""
episode_id: Optional[str] = None
tick: int = 0
max_ticks: int = 12
seed: Optional[int] = None
mode: Mode = Mode.MINI
nodes: List[Node] = Field(default_factory=list)
travel_time_matrix: Dict[str, Dict[str, int]] = Field(default_factory=dict)
couriers: List[Courier] = Field(default_factory=list)
orders: List[Order] = Field(default_factory=list)
reward_breakdown: RewardBreakdown = Field(default_factory=RewardBreakdown)
done: bool = False
truncated: bool = False
verifier_status: VerifierVerdict = VerifierVerdict.IN_PROGRESS
last_action: Optional[Action] = None
event_log: List[str] = Field(default_factory=list)
invalid_actions: int = 0
total_reward: float = 0.0
backlog: int = 0
sla_pressure: float = 0.0
@model_validator(mode="after")
def _validate_ticks(self) -> "State":
if self.tick < 0:
raise ValueError("tick must be >= 0")
if self.max_ticks <= 0:
raise ValueError("max_ticks must be > 0")
return self
class Observation(DispatchArenaModel):
"""Public observation returned by reset and step."""
state: State = Field(default_factory=State)
reward: float = 0.0
done: bool = False
truncated: bool = False
verifier_status: VerifierVerdict = VerifierVerdict.IN_PROGRESS
reward_breakdown: RewardBreakdown = Field(default_factory=RewardBreakdown)
legal_actions: List[str] = Field(default_factory=list)
action_mask: List[int] = Field(default_factory=list)
summary_text: str = "Awaiting dispatch."
info: Dict[str, Any] = Field(default_factory=dict)
@classmethod
def from_dict(cls, data: Mapping[str, Any]) -> "Observation":
return cls.model_validate(dict(data))
class EpisodeSummary(DispatchArenaModel):
episode_id: Optional[str] = None
seed: Optional[int] = None
mode: Mode = Mode.MINI
max_ticks: int = 12
ticks_taken: int = 0
invalid_actions: int = 0
total_reward: float = 0.0
final_verdict: VerifierVerdict = VerifierVerdict.IN_PROGRESS
action_trace: List[Action] = Field(default_factory=list)
delivered_orders: int = 0
expired_orders: int = 0
DispatchArenaAction = Action
DispatchArenaObservation = Observation
DispatchArenaState = State