"""Authoritative Dispatch Arena simulator.""" from __future__ import annotations import math import random from dataclasses import dataclass, field from typing import Any, Dict, List, Mapping, Optional, Tuple from dispatch_arena.models import ( Action, Config, Courier, CourierStatus, EpisodeSummary, MiniActionType, Mode, NormalActionType, Observation, Order, OrderStatus, State, VerifierVerdict, ) from dispatch_arena.server.rewards import RewardModel from dispatch_arena.server.scenarios import generate_scenario from dispatch_arena.server.serializers import idle_courier_count, make_observation, public_state DEFAULT_MAX_TICKS = 12 MINI_ACTION_ORDER = [action.value for action in MiniActionType] NORMAL_ACTION_ORDER = [action.value for action in NormalActionType] @dataclass class DispatchArenaEnvironment: """Native OpenEnv-style dispatch simulation for mini and normal modes.""" config: Config = field(default_factory=Config) _rng: random.Random = field(default_factory=random.Random) _state: Optional[State] = None _reward_model: RewardModel = field(default_factory=RewardModel) _action_trace: List[Action] = field(default_factory=list) # Hidden simulator-only state. Never serialized into Observation/State. _pending_arrivals: List[Order] = field(default_factory=list) _traffic_multipliers: Dict[Tuple[str, str], float] = field(default_factory=dict) def reset( self, seed: Optional[int] = None, episode_id: Optional[str] = None, config: Optional[Config | Mapping[str, Any]] = None, ) -> Observation: if config is not None: self.config = config if isinstance(config, Config) else Config.model_validate(dict(config)) self._rng.seed(seed) scenario = generate_scenario(self.config, seed) # Partition orders: anything arriving at t=0 is visible immediately; # everything else is held in the env's hidden pending list. initial_orders = [o for o in scenario.orders if o.arrival_tick == 0] pending = [o for o in scenario.orders if o.arrival_tick > 0] pending.sort(key=lambda o: o.arrival_tick) self._pending_arrivals = pending self._traffic_multipliers = dict(scenario.traffic_multipliers) self._state = State( episode_id=episode_id, tick=0, max_ticks=self.config.max_ticks, seed=seed, mode=self.config.mode, nodes=scenario.nodes, travel_time_matrix=scenario.travel_time_matrix, couriers=scenario.couriers, orders=initial_orders, ) self._action_trace = [] self._refresh_derived() return self._observation(info={"reset": True}) def step(self, action: Action | str | Mapping[str, Any]) -> Observation: state = self._require_state() if state.done: raise RuntimeError("Episode already finished. Call reset() before stepping again.") parsed_action = self._coerce_action(action) reward = self._reward_model.base() info: Dict[str, Any] = {"invalid_action": False, "invalid_reason": None, "events": []} legal_actions = self.legal_actions() state.tick += 1 state.last_action = parsed_action self._action_trace.append(parsed_action) if self.config.mode == Mode.MINI: valid = parsed_action.action_type in legal_actions if valid: self._progress_prep() self._release_arrivals(info) self._apply_mini_action(parsed_action, reward, info) else: self._mark_invalid(parsed_action, reward, info) else: valid = self._is_valid_normal_action(parsed_action) if valid: self._progress_prep() self._release_arrivals(info) self._apply_normal_action(parsed_action, reward, info) self._advance_normal_couriers(reward, info) self._expire_orders(info) else: self._mark_invalid(parsed_action, reward, info) self._reward_model.idle(reward, idle_courier_count(state)) self._reward_model.fairness(reward, self._delivery_imbalance()) if not state.done and state.tick >= state.max_ticks: state.done = True state.truncated = True self._reward_model.timeout(reward) else: self._reward_model.finalize(reward) state.reward_breakdown = reward state.total_reward += reward.total_reward state.event_log.extend(info["events"]) self._refresh_derived() return self._observation(info=info) @property def state(self) -> State: return public_state(self._require_state(), self.config) def legal_actions(self) -> List[str]: state = self._require_state() if state.done: return [] if state.mode == Mode.NORMAL: actions = [NormalActionType.HOLD.value] if any(c.status == CourierStatus.IDLE and c.load is None for c in state.couriers) and any( o.status in {OrderStatus.QUEUED, OrderStatus.READY} and o.assigned_courier_id is None for o in state.orders ): actions.append(NormalActionType.ASSIGN.value) if any(c.status == CourierStatus.IDLE and c.load is None for c in state.couriers): actions.append(NormalActionType.REPOSITION.value) if any(o.status in {OrderStatus.QUEUED, OrderStatus.READY} for o in state.orders): actions.append(NormalActionType.PRIORITIZE.value) return actions courier = state.couriers[0] order = state.orders[0] actions = [] if courier.load is None and order.status in {OrderStatus.QUEUED, OrderStatus.READY} and courier.node_id != order.pickup_node_id: actions.append(MiniActionType.GO_PICKUP.value) if courier.node_id == order.pickup_node_id and courier.load is None and order.status == OrderStatus.READY: actions.append(MiniActionType.PICKUP.value) if courier.load == order.id and courier.node_id != order.dropoff_node_id: actions.append(MiniActionType.GO_DROPOFF.value) if courier.load == order.id and courier.node_id == order.dropoff_node_id: actions.append(MiniActionType.DROPOFF.value) actions.append(MiniActionType.WAIT.value) return actions def action_mask(self) -> List[int]: legal = set(self.legal_actions()) order = NORMAL_ACTION_ORDER if self.config.mode == Mode.NORMAL else MINI_ACTION_ORDER return [1 if action in legal else 0 for action in order] def get_episode_summary(self) -> Dict[str, Any]: state = self._require_state() summary = EpisodeSummary( episode_id=state.episode_id, seed=state.seed, mode=state.mode, max_ticks=state.max_ticks, ticks_taken=state.tick, invalid_actions=state.invalid_actions, total_reward=state.total_reward, final_verdict=state.verifier_status, action_trace=self._action_trace, delivered_orders=sum(1 for order in state.orders if order.status == OrderStatus.DELIVERED), expired_orders=sum(1 for order in state.orders if order.status == OrderStatus.EXPIRED), ) return summary.to_dict() def _observation(self, info: Optional[dict] = None) -> Observation: return make_observation(self._require_state(), self.config, self.legal_actions(), self.action_mask(), info=info) def _require_state(self) -> State: if self._state is None: raise RuntimeError("Environment not initialized. Call reset() first.") return self._state def _coerce_action(self, action: Action | str | Mapping[str, Any]) -> Action: if isinstance(action, Action): return action if isinstance(action, str): return Action(action_type=action) if isinstance(action, Mapping): return Action.model_validate(dict(action)) raise TypeError("action must be Action, string, or mapping") def _mark_invalid(self, action: Action, reward, info: Dict[str, Any]) -> None: state = self._require_state() state.invalid_actions += 1 info["invalid_action"] = True info["invalid_reason"] = f"{action.action_type} is not legal from the current state" info["events"].append(info["invalid_reason"]) self._reward_model.invalid(reward) def _apply_mini_action(self, action: Action, reward, info: Dict[str, Any]) -> None: state = self._require_state() courier = state.couriers[0] order = state.orders[0] action_type = action.action_type if action_type == MiniActionType.WAIT.value: info["events"].append("courier waited") elif action_type == MiniActionType.GO_PICKUP.value: courier.node_id = order.pickup_node_id courier.status = CourierStatus.IDLE info["events"].append(f"{courier.id} moved to pickup") elif action_type == MiniActionType.PICKUP.value: order.status = OrderStatus.PICKED order.assigned_courier_id = courier.id courier.load = order.id courier.assigned_order_id = order.id courier.status = CourierStatus.IDLE info["events"].append(f"{courier.id} picked {order.id}") elif action_type == MiniActionType.GO_DROPOFF.value: courier.node_id = order.dropoff_node_id courier.status = CourierStatus.IDLE info["events"].append(f"{courier.id} moved to dropoff") elif action_type == MiniActionType.DROPOFF.value: order.status = OrderStatus.DELIVERED order.delivered_tick = state.tick courier.load = None courier.assigned_order_id = None state.done = True info["events"].append(f"{order.id} delivered") self._reward_model.mini_progress(reward, action_type) def _apply_normal_action(self, action: Action, reward, info: Dict[str, Any]) -> None: state = self._require_state() action_type = action.action_type if action_type == NormalActionType.HOLD.value: info["events"].append(f"{action.courier_id or 'dispatcher'} held") elif action_type == NormalActionType.PRIORITIZE.value: info["events"].append(f"{action.order_id or 'backlog'} prioritized") elif action_type == NormalActionType.REPOSITION.value: courier = self._courier(action.courier_id) if courier.node_id != action.node_id: courier.status = CourierStatus.REPOSITIONING courier.target_node_id = action.node_id courier.eta_remaining = self._travel_time(courier.node_id, action.node_id) info["events"].append(f"{courier.id} repositioning to {action.node_id}") elif action_type == NormalActionType.ASSIGN.value: courier = self._courier(action.courier_id) order = self._order(action.order_id) if courier.assigned_order_id and courier.assigned_order_id != order.id: self._reward_model.churn(reward) courier.assigned_order_id = order.id courier.status = CourierStatus.TO_PICKUP courier.target_node_id = order.pickup_node_id courier.eta_remaining = self._travel_time(courier.node_id, order.pickup_node_id) order.assigned_courier_id = courier.id info["events"].append(f"{courier.id} assigned {order.id}") self._reward_model.normal_action_progress(reward, action_type) def _is_valid_normal_action(self, action: Action) -> bool: state = self._require_state() action_type = action.action_type if action_type == NormalActionType.HOLD.value: return action.courier_id is None or any(c.id == action.courier_id for c in state.couriers) if action_type == NormalActionType.PRIORITIZE.value: return action.order_id is None or any(o.id == action.order_id and o.status in {OrderStatus.QUEUED, OrderStatus.READY} for o in state.orders) if action_type == NormalActionType.REPOSITION.value: if not action.courier_id or not action.node_id: return False courier = self._maybe_courier(action.courier_id) return courier is not None and courier.status == CourierStatus.IDLE and courier.load is None and action.node_id in self._node_ids() if action_type == NormalActionType.ASSIGN.value: if not action.courier_id or not action.order_id: return False courier = self._maybe_courier(action.courier_id) order = self._maybe_order(action.order_id) return ( courier is not None and order is not None and courier.status == CourierStatus.IDLE and courier.load is None and order.status in {OrderStatus.QUEUED, OrderStatus.READY} and order.assigned_courier_id is None ) return False def _advance_normal_couriers(self, reward, info: Dict[str, Any]) -> None: state = self._require_state() for courier in state.couriers: if courier.eta_remaining > 0: courier.eta_remaining -= 1 if courier.eta_remaining > 0: continue if courier.target_node_id: courier.node_id = courier.target_node_id if courier.status == CourierStatus.REPOSITIONING and courier.eta_remaining == 0: courier.status = CourierStatus.IDLE courier.target_node_id = None info["events"].append(f"{courier.id} finished reposition") elif courier.status == CourierStatus.TO_PICKUP and courier.eta_remaining == 0: order = self._order(courier.assigned_order_id) if order.status == OrderStatus.READY: self._auto_pickup(courier, order, reward, info) else: courier.status = CourierStatus.WAITING_PICKUP info["events"].append(f"{courier.id} waiting for {order.id}") elif courier.status == CourierStatus.WAITING_PICKUP: order = self._order(courier.assigned_order_id) if order.status == OrderStatus.READY: self._auto_pickup(courier, order, reward, info) elif courier.status == CourierStatus.TO_DROPOFF and courier.eta_remaining == 0: order = self._order(courier.load) order.status = OrderStatus.DELIVERED order.delivered_tick = state.tick on_time = state.tick <= order.deadline_tick courier.load = None courier.assigned_order_id = None courier.target_node_id = None courier.status = CourierStatus.IDLE self._reward_model.delivered(reward, on_time=on_time) info["events"].append(f"{order.id} delivered by {courier.id}") def _auto_pickup(self, courier: Courier, order: Order, reward, info: Dict[str, Any]) -> None: order.status = OrderStatus.PICKED courier.load = order.id courier.status = CourierStatus.TO_DROPOFF courier.target_node_id = order.dropoff_node_id courier.eta_remaining = self._travel_time(courier.node_id, order.dropoff_node_id) reward.progress_reward += self._reward_model.config.pickup_progress_bonus info["events"].append(f"{courier.id} picked {order.id}") def _progress_prep(self) -> None: state = self._require_state() for order in state.orders: if order.status not in {OrderStatus.QUEUED, OrderStatus.READY}: continue if order.prep_remaining is None: continue if order.prep_remaining > 0: order.prep_remaining -= 1 if order.prep_remaining == 0: order.status = OrderStatus.READY def _release_arrivals(self, info: Dict[str, Any]) -> None: if not self._pending_arrivals: return state = self._require_state() while self._pending_arrivals and self._pending_arrivals[0].arrival_tick <= state.tick: new_order = self._pending_arrivals.pop(0) new_order.created_tick = state.tick if new_order.prep_remaining == 0: new_order.status = OrderStatus.READY state.orders.append(new_order) info["events"].append(f"{new_order.id} arrived") def _expire_orders(self, info: Dict[str, Any]) -> None: state = self._require_state() for order in state.orders: if order.status in {OrderStatus.QUEUED, OrderStatus.READY} and state.tick > order.deadline_tick: order.status = OrderStatus.EXPIRED info["events"].append(f"{order.id} expired") def _refresh_derived(self) -> None: state = self._require_state() delivered = sum(1 for order in state.orders if order.status == OrderStatus.DELIVERED) active = [order for order in state.orders if order.status in {OrderStatus.QUEUED, OrderStatus.READY, OrderStatus.PICKED}] # Backlog must include orders not yet visible (rolling arrivals) — else # the SLA pressure metric and 'done' check ignore future work. state.backlog = len(active) + len(self._pending_arrivals) state.sla_pressure = 0.0 if not active else sum(1 for order in active if order.deadline_tick - state.tick <= 3) / len(active) all_visible_resolved = delivered == len(state.orders) no_more_pending = not self._pending_arrivals if all_visible_resolved and no_more_pending: state.done = True state.verifier_status = VerifierVerdict.DELIVERED_SUCCESSFULLY elif state.truncated: state.verifier_status = VerifierVerdict.PARTIAL_SUCCESS if delivered > 0 and state.mode == Mode.NORMAL else VerifierVerdict.TIMEOUT_FAILURE else: state.verifier_status = VerifierVerdict.IN_PROGRESS def _travel_time(self, src: str, dst: Optional[str]) -> int: if dst is None: return 0 base = self._require_state().travel_time_matrix.get(src, {}).get(dst, 1) multiplier = self._traffic_multipliers.get((src, dst), 1.0) return max(1, math.ceil(base * multiplier)) def _delivery_imbalance(self) -> int: state = self._require_state() delivered_by: Dict[str, int] = {courier.id: 0 for courier in state.couriers} for order in state.orders: if order.status == OrderStatus.DELIVERED and order.assigned_courier_id: delivered_by[order.assigned_courier_id] = delivered_by.get(order.assigned_courier_id, 0) + 1 return max(delivered_by.values(), default=0) - min(delivered_by.values(), default=0) def _courier(self, courier_id: Optional[str]) -> Courier: courier = self._maybe_courier(courier_id) if courier is None: raise ValueError(f"Unknown courier_id: {courier_id}") return courier def _maybe_courier(self, courier_id: Optional[str]) -> Optional[Courier]: return next((courier for courier in self._require_state().couriers if courier.id == courier_id), None) def _order(self, order_id: Optional[str]) -> Order: order = self._maybe_order(order_id) if order is None: raise ValueError(f"Unknown order_id: {order_id}") return order def _maybe_order(self, order_id: Optional[str]) -> Optional[Order]: return next((order for order in self._require_state().orders if order.id == order_id), None) def _node_ids(self) -> set[str]: return {node.id for node in self._require_state().nodes} Environment = DispatchArenaEnvironment