File size: 6,586 Bytes
c71bf62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
"""Public models for the Dispatch Arena environment."""

from __future__ import annotations

from enum import Enum
from typing import Any, Dict, List, Mapping, Optional

from pydantic import BaseModel, ConfigDict, Field, model_validator


class DispatchArenaModel(BaseModel):
    """Common base for every public model exchanged by server and client.

    Enum-typed fields keep their Enum members on the Python object
    (``use_enum_values=False``); conversion to plain JSON-compatible values
    happens only in :meth:`to_dict`.
    """

    model_config = ConfigDict(use_enum_values=False)

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-compatible dict of this model, omitting None-valued fields."""
        serialized: Dict[str, Any] = self.model_dump(exclude_none=True, mode="json")
        return serialized


class Mode(str, Enum):
    """Episode mode selector.

    ``Config`` normalizes its courier/order counts by mode: MINI pins both to
    1, NORMAL clamps couriers to [2, 5] and orders to [3, 10]. The ``str``
    mixin makes members compare equal to their string values.
    """

    MINI = "mini"
    NORMAL = "normal"


class CourierStatus(str, Enum):
    """Status values for a ``Courier``; ``Courier.status`` defaults to IDLE.

    Member meanings are inferred from their names only; no transition logic
    between these states is defined in this file.
    """

    IDLE = "idle"
    TO_PICKUP = "to_pickup"
    WAITING_PICKUP = "waiting_pickup"
    TO_DROPOFF = "to_dropoff"
    REPOSITIONING = "repositioning"


class OrderStatus(str, Enum):
    """Status values for an ``Order``; ``Order.status`` defaults to QUEUED.

    Transition rules between these states are not defined in this file.
    """

    QUEUED = "queued"
    READY = "ready"
    PICKED = "picked"
    DELIVERED = "delivered"
    EXPIRED = "expired"


class VerifierVerdict(str, Enum):
    """Episode outcome reported by the verifier.

    Used by ``State.verifier_status``, ``Observation.verifier_status`` and
    ``EpisodeSummary.final_verdict``, all of which default to IN_PROGRESS.
    """

    IN_PROGRESS = "in_progress"
    DELIVERED_SUCCESSFULLY = "delivered_successfully"
    TIMEOUT_FAILURE = "timeout_failure"
    PARTIAL_SUCCESS = "partial_success"


class MiniActionType(str, Enum):
    """Action names presumably intended for MINI mode (inferred from the name).

    NOTE(review): ``Action.action_type`` is a plain ``str``, so nothing in
    this file enforces that an action uses one of these values.
    """

    WAIT = "wait"
    GO_PICKUP = "go_pickup"
    GO_DROPOFF = "go_dropoff"
    PICKUP = "pickup"
    DROPOFF = "dropoff"


class NormalActionType(str, Enum):
    """Action names presumably intended for NORMAL mode (inferred from the name).

    NOTE(review): ``Action.action_type`` is a plain ``str``, so nothing in
    this file enforces that an action uses one of these values.
    """

    ASSIGN = "assign"
    REPOSITION = "reposition"
    HOLD = "hold"
    PRIORITIZE = "prioritize"


class Action(DispatchArenaModel):
    """OpenEnv-facing action payload.

    Only ``action_type`` is required. Which of the optional target ids are
    meaningful depends on the action type; this model does not check that.
    """

    action_type: str
    courier_id: Optional[str] = None
    order_id: Optional[str] = None
    node_id: Optional[str] = None

    @property
    def name(self) -> str:
        """Read-only alias for ``action_type``."""
        return self.action_type

    @classmethod
    def from_dict(cls, data: Mapping[str, Any]) -> "Action":
        """Validate any mapping into an ``Action``.

        The mapping is copied into a plain ``dict`` before validation, so
        the caller's object is never handed to the validator directly.
        """
        payload = dict(data)
        return cls.model_validate(payload)


class Node(DispatchArenaModel):
    """A named location that couriers and orders refer to by ``id``."""

    id: str
    # Free-form category string; the allowed values are not defined in this file.
    kind: str
    # Human-readable name, presumably for display — confirm with consumers.
    label: str


class Courier(DispatchArenaModel):
    """Public snapshot of one courier.

    ``eta_remaining`` is rejected when negative; every other field is
    accepted as given.
    """

    id: str
    node_id: str
    status: CourierStatus = CourierStatus.IDLE
    eta_remaining: int = 0
    assigned_order_id: Optional[str] = None
    load: Optional[str] = None
    target_node_id: Optional[str] = None

    @model_validator(mode="after")
    def _validate_eta(self) -> "Courier":
        # Guard clause: a non-negative ETA is the accepted case.
        if self.eta_remaining >= 0:
            return self
        raise ValueError("eta_remaining must be >= 0")


class Order(DispatchArenaModel):
    """A delivery request from a pickup node to a dropoff node.

    Only ``id``, ``pickup_node_id`` and ``dropoff_node_id`` are required. No
    cross-field validation (e.g. deadline vs. created tick) happens here.
    """

    id: str
    # Free-form category; allowed values are not defined in this file.
    kind: str = "food"
    pickup_node_id: str
    dropoff_node_id: str
    created_tick: int = 0
    # NOTE(review): how this differs from created_tick is not visible here — confirm.
    arrival_tick: int = 0
    # Optional so it can be left as None (to_dict() then omits it), e.g. when
    # prep time should stay hidden from the public state.
    prep_remaining: Optional[int] = None
    deadline_tick: int = 20
    status: OrderStatus = OrderStatus.QUEUED
    assigned_courier_id: Optional[str] = None
    ready_now: Optional[bool] = None
    # Presumably set when the order reaches DELIVERED — TODO confirm.
    delivered_tick: Optional[int] = None


class Config(DispatchArenaModel):
    """Scenario configuration for an episode.

    Validation rejects a non-positive ``max_ticks`` and any ``traffic_noise``
    outside ``[0.0, 2.0]``, including NaN. Courier/order counts are *not*
    rejected when out of range. They are silently normalized to the mode:
    MINI pins both to 1, and NORMAL clamps couriers to [2, 5] and orders to
    [3, 10].

    Raises:
        ValueError: wrapped by Pydantic into a validation error, when
            ``max_ticks`` or ``traffic_noise`` is out of range.
    """

    mode: Mode = Mode.MINI
    max_ticks: int = 12
    visible_prep: bool = False
    num_couriers: int = 1
    num_orders: int = 1
    scenario_bucket: str = "easy"
    progress_shaping: bool = True
    rolling_arrivals: bool = False
    traffic_noise: float = 0.0

    @model_validator(mode="after")
    def _validate_config(self) -> "Config":
        if self.max_ticks <= 0:
            raise ValueError("max_ticks must be > 0")
        # Chained comparison is False for NaN, so NaN is rejected here as well.
        # A `< 0.0 or > 2.0` test would let NaN through, and Pydantic's float
        # type accepts NaN by default.
        if not 0.0 <= self.traffic_noise <= 2.0:
            raise ValueError("traffic_noise must be in [0.0, 2.0]")
        # Normalize counts to the range each mode supports instead of erroring.
        if self.mode == Mode.MINI:
            self.num_couriers = 1
            self.num_orders = 1
        elif self.mode == Mode.NORMAL:
            self.num_couriers = min(max(self.num_couriers, 2), 5)
            self.num_orders = min(max(self.num_orders, 3), 10)
        return self


class RewardBreakdown(DispatchArenaModel):
    """Machine-readable reward decomposition for a transition.

    All components default to 0.0. NOTE(review): the sign convention for the
    ``*_penalty`` and ``step_cost`` fields (stored negative vs. positive) and
    whether ``total_reward`` must equal the sum of the components are not
    enforced or visible here — confirm with the reward computation.
    """

    step_cost: float = 0.0
    progress_reward: float = 0.0
    invalid_penalty: float = 0.0
    success_reward: float = 0.0
    timeout_penalty: float = 0.0
    on_time_bonus: float = 0.0
    late_penalty: float = 0.0
    idle_penalty: float = 0.0
    route_churn_penalty: float = 0.0
    fairness_penalty: float = 0.0
    # Presumably the aggregate of the fields above — not computed by this model.
    total_reward: float = 0.0


class State(DispatchArenaModel):
    """Sanitized public environment state.

    Hidden ``prep_remaining`` values are meant to be excluded unless visible
    prep mode is explicitly enabled by the scenario config. This model does
    not strip them itself. Presumably the code that builds a ``State`` leaves
    them as None, which ``to_dict()`` then omits — TODO confirm with the
    producer.

    Validation rejects a negative ``tick`` and a non-positive ``max_ticks``.
    """

    episode_id: Optional[str] = None
    tick: int = 0
    max_ticks: int = 12
    seed: Optional[int] = None
    mode: Mode = Mode.MINI
    nodes: List[Node] = Field(default_factory=list)
    # Nested mapping of str keys to int travel times; presumably keyed by node id.
    travel_time_matrix: Dict[str, Dict[str, int]] = Field(default_factory=dict)
    couriers: List[Courier] = Field(default_factory=list)
    orders: List[Order] = Field(default_factory=list)
    reward_breakdown: RewardBreakdown = Field(default_factory=RewardBreakdown)
    done: bool = False
    truncated: bool = False
    verifier_status: VerifierVerdict = VerifierVerdict.IN_PROGRESS
    last_action: Optional[Action] = None
    event_log: List[str] = Field(default_factory=list)
    invalid_actions: int = 0
    total_reward: float = 0.0
    backlog: int = 0
    sla_pressure: float = 0.0

    @model_validator(mode="after")
    def _validate_ticks(self) -> "State":
        # Checked in order; the first failing rule determines the error.
        rules = (
            (self.tick < 0, "tick must be >= 0"),
            (self.max_ticks <= 0, "max_ticks must be > 0"),
        )
        for violated, message in rules:
            if violated:
                raise ValueError(message)
        return self


class Observation(DispatchArenaModel):
    """Public observation returned by reset and step.

    Wraps the full public ``State`` together with the per-step reward, the
    termination flags, the verifier status and action-selection hints.
    """

    state: State = Field(default_factory=State)
    reward: float = 0.0
    done: bool = False
    truncated: bool = False
    verifier_status: VerifierVerdict = VerifierVerdict.IN_PROGRESS
    reward_breakdown: RewardBreakdown = Field(default_factory=RewardBreakdown)
    legal_actions: List[str] = Field(default_factory=list)
    # NOTE(review): presumably aligned index-for-index with an action list —
    # nothing here ties its length to legal_actions; confirm with the producer.
    action_mask: List[int] = Field(default_factory=list)
    summary_text: str = "Awaiting dispatch."
    info: Dict[str, Any] = Field(default_factory=dict)

    @classmethod
    def from_dict(cls, data: Mapping[str, Any]) -> "Observation":
        """Validate any mapping into an ``Observation``.

        The mapping is first copied into a plain ``dict``, and nested
        payloads are validated into their model types by Pydantic.
        """
        payload = dict(data)
        return cls.model_validate(payload)


class EpisodeSummary(DispatchArenaModel):
    """End-of-episode report: identity, tick and reward totals, verdict and action trace."""

    episode_id: Optional[str] = None
    seed: Optional[int] = None
    mode: Mode = Mode.MINI
    max_ticks: int = 12
    ticks_taken: int = 0
    invalid_actions: int = 0
    total_reward: float = 0.0
    final_verdict: VerifierVerdict = VerifierVerdict.IN_PROGRESS
    # Actions in the order they were taken, presumably — confirm with the producer.
    action_trace: List[Action] = Field(default_factory=list)
    # Presumably counts of orders that ended DELIVERED / EXPIRED — TODO confirm.
    delivered_orders: int = 0
    expired_orders: int = 0


# Prefixed aliases so callers can import the core models under
# Dispatch Arena-specific names; each alias is the same class object.
DispatchArenaAction = Action
DispatchArenaObservation = Observation
DispatchArenaState = State