# dispatch_arena_v0/tests/test_environment.py
# Uploaded by Freakdivi using huggingface_hub (commit c71bf62, verified).
import random
import unittest
from dispatch_arena.models import Config, Mode, OrderStatus, VerifierVerdict
from dispatch_arena.server.env import DispatchArenaEnvironment
def heuristic_action(obs):
    """Pick a simple greedy action for the given observation.

    When ``obs.state.mode`` is ``Mode.NORMAL`` the first courier whose status
    is ``"idle"`` and whose ``load`` is ``None`` is paired with the first order
    whose status is ``"queued"`` or ``"ready"`` and that has no assigned
    courier.  If either is missing, a ``hold`` action is returned instead.

    In any other mode the first entry of a fixed priority list that is present
    in ``obs.legal_actions`` is returned; if none of them is legal the function
    falls back to ``obs.legal_actions[0]``.

    Args:
        obs: Observation exposing ``state`` (with ``mode``, ``couriers`` and
            ``orders``) and ``legal_actions``.

    Returns:
        An action dict (``assign`` or ``hold``) in normal mode, otherwise an
        entry of ``obs.legal_actions``.

    Note:
        The final fallback indexes ``obs.legal_actions[0]`` and therefore
        fails if no action is legal at all (presumably never the case for a
        non-terminal observation -- confirm against the environment).
    """
    state = obs.state
    if state.mode == Mode.NORMAL:
        courier = next(
            (c for c in state.couriers if c.status == "idle" and c.load is None),
            None,
        )
        order = next(
            (
                o
                for o in state.orders
                if o.status in {"queued", "ready"} and o.assigned_courier_id is None
            ),
            None,
        )
        # ``next(..., None)`` signals "nothing found" with ``None``; compare
        # against that sentinel explicitly rather than relying on the
        # truthiness of the courier/order objects themselves.
        if courier is not None and order is not None:
            return {
                "action_type": "assign",
                "courier_id": courier.id,
                "order_id": order.id,
            }
        return {"action_type": "hold"}

    # Outside normal mode: prefer completing a delivery step, then moving,
    # then waiting.
    for action in ("pickup", "dropoff", "go_pickup", "go_dropoff", "wait"):
        if action in obs.legal_actions:
            return action
    return obs.legal_actions[0]
def play_episode(env: DispatchArenaEnvironment, seed: int = 1):
    """Run one full episode driven by ``heuristic_action``.

    The environment is reset with ``seed`` and then stepped with whatever
    ``heuristic_action`` chooses until an observation reports ``done``.

    Args:
        env: Environment to drive; it is reset before the rollout starts.
        seed: Seed forwarded to ``env.reset``.

    Returns:
        The first observation whose ``done`` flag is set.
    """
    current = env.reset(seed=seed)
    while True:
        if current.done:
            return current
        chosen = heuristic_action(current)
        current = env.step(chosen)
class DispatchArenaEnvironmentTests(unittest.TestCase):
    """Behavioural tests for ``DispatchArenaEnvironment``.

    The tests cover the MINI mode (string actions) and the NORMAL mode
    (dict-shaped assign/hold actions), including reward breakdowns, invalid
    actions, timeouts, seeded reproducibility, and checks that hidden
    information does not show up in publicly serialised data.
    """

    # MINI-mode action names, in the order the mask test pairs them with
    # ``obs.action_mask``.
    MINI_ACTIONS = ("wait", "go_pickup", "go_dropoff", "pickup", "dropoff")

    def test_mini_golden_successful_trajectory(self):
        """The heuristic finishes a MINI episode with a successful delivery."""
        env = DispatchArenaEnvironment(Config(mode=Mode.MINI, max_ticks=12))
        obs = play_episode(env, seed=1)
        self.assertTrue(obs.done)
        self.assertFalse(obs.truncated)
        self.assertEqual(obs.verifier_status, VerifierVerdict.DELIVERED_SUCCESSFULLY)
        self.assertEqual(obs.state.orders[0].status, OrderStatus.DELIVERED)
        self.assertGreater(env.get_episode_summary()["total_reward"], 0.0)

    def test_mini_reward_components_are_returned(self):
        """A ``go_pickup`` step reports step cost, progress and a matching total."""
        env = DispatchArenaEnvironment(Config(mode=Mode.MINI, max_ticks=12))
        obs = env.reset(seed=1)
        self.assertIn("go_pickup", obs.legal_actions)
        obs = env.step("go_pickup")
        breakdown = obs.reward_breakdown
        # Rewards are floats: compare with a tolerance instead of exact ``==``
        # so harmless rounding in the reward arithmetic cannot fail the test.
        self.assertAlmostEqual(breakdown.step_cost, -0.1)
        self.assertGreater(breakdown.progress_reward, 0.0)
        self.assertAlmostEqual(obs.reward, breakdown.total_reward)

    def test_invalid_action_penalty_without_mutation(self):
        """An illegal action is penalised and only the tick advances."""
        env = DispatchArenaEnvironment(Config(mode=Mode.MINI, max_ticks=12))
        env.reset(seed=1)
        # Capture plain values *before* stepping.  Holding on to the state
        # object itself would only be sound if ``env.state`` returns a
        # snapshot rather than the live state.
        before = env.state
        node_before = before.couriers[0].node_id
        status_before = before.orders[0].status
        tick_before = before.tick

        obs = env.step("dropoff")
        after = env.state

        self.assertTrue(obs.info["invalid_action"])
        self.assertLess(obs.reward_breakdown.invalid_penalty, 0.0)
        self.assertEqual(after.couriers[0].node_id, node_before)
        self.assertEqual(after.orders[0].status, status_before)
        self.assertEqual(after.tick, tick_before + 1)

    def test_timeout_trajectory_is_negative(self):
        """Running out of ticks ends the episode truncated with a timeout penalty."""
        env = DispatchArenaEnvironment(Config(mode=Mode.MINI, max_ticks=1))
        env.reset(seed=1)
        obs = env.step("wait")
        self.assertTrue(obs.done)
        self.assertTrue(obs.truncated)
        self.assertEqual(obs.verifier_status, VerifierVerdict.TIMEOUT_FAILURE)
        self.assertLess(obs.reward_breakdown.timeout_penalty, 0.0)

    def test_hidden_prep_remaining_never_appears_publicly(self):
        """With ``visible_prep=False`` no public output mentions ``prep_remaining``."""
        env = DispatchArenaEnvironment(
            Config(mode=Mode.MINI, max_ticks=12, visible_prep=False)
        )
        reset_obs = env.reset(seed=3)
        # Keep the post-step observation as well: a leak that only shows up
        # after the first step must be caught too.
        step_obs = env.step("wait")
        public_blob = " ".join(
            [
                str(reset_obs.to_dict()),
                reset_obs.summary_text,
                str(step_obs.to_dict()),
                step_obs.summary_text,
                str(env.state.to_dict()),
                str(env.get_episode_summary()),
            ]
        )
        self.assertNotIn("prep_remaining", public_blob)

    def test_visible_mode_can_expose_ready_now_and_prep(self):
        """With ``visible_prep=True`` the observation exposes prep information."""
        env = DispatchArenaEnvironment(
            Config(mode=Mode.MINI, max_ticks=12, visible_prep=True)
        )
        obs = env.reset(seed=3)
        serialized = str(obs.to_dict())
        self.assertIn("prep_remaining", serialized)
        self.assertIn("ready_now", serialized)

    def test_action_mask_matches_legal_action_list(self):
        """Each mask entry is 1 exactly when its action is listed as legal."""
        env = DispatchArenaEnvironment(Config(mode=Mode.MINI, max_ticks=12))
        obs = env.reset(seed=4)
        for action, mask_value in zip(self.MINI_ACTIONS, obs.action_mask):
            # subTest keeps checking the remaining actions after a failure and
            # names the offending action in the report.
            with self.subTest(action=action):
                expected = 1 if action in obs.legal_actions else 0
                self.assertEqual(mask_value, expected)

    def test_seeded_reset_is_reproducible(self):
        """Two environments with the same seed and actions yield identical traces."""
        config = Config(mode=Mode.MINI, max_ticks=12)
        env1 = DispatchArenaEnvironment(config)
        env2 = DispatchArenaEnvironment(config)
        obs1 = env1.reset(seed=42)
        obs2 = env2.reset(seed=42)
        self.assertEqual(obs1.to_dict(), obs2.to_dict())

        actions = ["go_pickup", "wait", "pickup"]
        trace1 = [env1.step(action).to_dict() for action in actions]
        trace2 = [env2.step(action).to_dict() for action in actions]
        self.assertEqual(trace1, trace2)

    def test_random_rollout_never_breaks_invariants(self):
        """Random MINI actions never leave a delivered order loaded or overrun ticks."""
        env = DispatchArenaEnvironment(Config(mode=Mode.MINI, max_ticks=12))
        obs = env.reset(seed=9)
        rng = random.Random(9)
        while not obs.done:
            action = rng.choice(self.MINI_ACTIONS)
            obs = env.step(action)
            delivered = obs.state.orders[0].status == OrderStatus.DELIVERED
            # A delivered order must not still be carried by the courier.
            self.assertFalse(delivered and obs.state.couriers[0].load)
            self.assertLessEqual(obs.state.tick, obs.state.max_ticks)

    def test_normal_heuristic_rollout_delivers_some_orders(self):
        """The heuristic delivers at least one order in a NORMAL episode."""
        config = Config(mode=Mode.NORMAL, max_ticks=18, num_couriers=3, num_orders=5)
        env = DispatchArenaEnvironment(config)
        obs = play_episode(env, seed=5)
        self.assertTrue(obs.done)
        self.assertGreaterEqual(env.get_episode_summary()["delivered_orders"], 1)
        for order in obs.state.orders:
            with self.subTest(order_id=order.id):
                # Every delivered order must still record who delivered it.
                self.assertFalse(
                    order.status == OrderStatus.DELIVERED
                    and order.assigned_courier_id is None
                )

    def test_normal_invalid_duplicate_assignment_is_penalized(self):
        """Assigning an already-assigned order to a second courier is invalid."""
        config = Config(mode=Mode.NORMAL, max_ticks=18, num_couriers=2, num_orders=3)
        env = DispatchArenaEnvironment(config)
        env.reset(seed=5)
        env.step(
            {"action_type": "assign", "courier_id": "courier_0", "order_id": "order_0"}
        )
        obs = env.step(
            {"action_type": "assign", "courier_id": "courier_1", "order_id": "order_0"}
        )
        self.assertTrue(obs.info["invalid_action"])
        self.assertLess(obs.reward_breakdown.invalid_penalty, 0.0)

    def test_rolling_arrivals_appear_over_time(self):
        """With rolling arrivals, orders start partially hidden and all arrive by the end."""
        total_orders = 5
        config = Config(
            mode=Mode.NORMAL,
            max_ticks=20,
            num_couriers=3,
            num_orders=total_orders,
            scenario_bucket="easy",
            rolling_arrivals=True,
        )
        env = DispatchArenaEnvironment(config)
        obs = env.reset(seed=11)
        initial_visible = len(obs.state.orders)
        # With rolling arrivals enabled, fewer than all orders should be
        # visible at t=0; the rest are scheduled to arrive later.
        self.assertLess(initial_visible, total_orders)
        # Run the heuristic to the end of the episode; by then every order
        # should have arrived, i.e. all of them appear in ``state.orders``.
        while not obs.done:
            obs = env.step(heuristic_action(obs))
        self.assertEqual(len(obs.state.orders), total_orders)

    def test_pending_arrivals_never_leak_in_public_state(self):
        """Orders still pending arrival are absent from every public view."""
        config = Config(
            mode=Mode.NORMAL,
            max_ticks=20,
            num_couriers=3,
            num_orders=6,
            rolling_arrivals=True,
        )
        env = DispatchArenaEnvironment(config)
        obs = env.reset(seed=13)
        # The env must have arrivals scheduled for future ticks, and none of
        # those pending orders may already be among the visible orders.
        self.assertGreater(len(env._pending_arrivals), 0)
        visible_ids = {order.id for order in obs.state.orders}
        pending_ids = {order.id for order in env._pending_arrivals}
        self.assertEqual(visible_ids & pending_ids, set())
        # And no pending order id should appear in any public serialization.
        public_blob = str(obs.to_dict()) + str(env.state.to_dict()) + obs.summary_text
        for pending_id in pending_ids:
            with self.subTest(pending_id=pending_id):
                self.assertNotIn(pending_id, public_blob)

    def test_traffic_noise_is_deterministic_and_extends_eta(self):
        """Traffic multipliers are seed-deterministic, >= 1.0, and absent without noise."""
        config_no_traffic = Config(
            mode=Mode.NORMAL,
            max_ticks=20,
            num_couriers=2,
            num_orders=3,
            traffic_noise=0.0,
        )
        config_with_traffic = Config(
            mode=Mode.NORMAL,
            max_ticks=20,
            num_couriers=2,
            num_orders=3,
            traffic_noise=1.0,
        )
        env_a = DispatchArenaEnvironment(config_with_traffic)
        env_b = DispatchArenaEnvironment(config_with_traffic)
        env_a.reset(seed=21)
        env_b.reset(seed=21)
        # NOTE(review): both checks below pass vacuously if the multiplier
        # mapping is empty right after reset, and nothing here asserts the
        # "extends ETA" part of the test name -- consider asserting a
        # non-empty mapping once it is confirmed when it gets populated.
        # Same seed -> identical traffic multipliers.
        self.assertEqual(env_a._traffic_multipliers, env_b._traffic_multipliers)
        # Multipliers must never fall below 1.0.
        self.assertTrue(all(m >= 1.0 for m in env_a._traffic_multipliers.values()))

        env_clean = DispatchArenaEnvironment(config_no_traffic)
        env_clean.reset(seed=21)
        self.assertEqual(env_clean._traffic_multipliers, {})

    def test_traffic_multipliers_never_appear_in_observation(self):
        """Traffic multipliers stay out of the serialised observation and state."""
        config = Config(
            mode=Mode.NORMAL,
            max_ticks=20,
            num_couriers=2,
            num_orders=3,
            traffic_noise=0.8,
        )
        env = DispatchArenaEnvironment(config)
        obs = env.reset(seed=33)
        public_blob = str(obs.to_dict()) + str(env.state.to_dict())
        self.assertNotIn("traffic_multiplier", public_blob)
        self.assertNotIn("_traffic", public_blob)
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()