File size: 5,831 Bytes
"""Tests for warehouse_env/models.py — TDD RED phase."""
import json

import pytest

from warehouse_env.models import (
    RobotAction,
    WarehouseAction,
    RobotState,
    OrderState,
    WarehouseObservation,
    WarehouseState,
    WarehouseReward,
)
class TestRobotAction:
    """Tests for RobotAction, a single robot command (robot_id + action_type)."""

    def test_valid_action(self):
        a = RobotAction(robot_id=0, action_type="move_up")
        assert a.robot_id == 0
        assert a.action_type == "move_up"

    def test_invalid_action_type_still_validates(self):
        # The model does not restrict action_type: unknown strings are kept
        # verbatim. Rejecting them is presumably the environment's job (TODO
        # confirm against the env step logic).
        a = RobotAction(robot_id=0, action_type="INVALID")
        assert a.action_type == "INVALID"

    def test_json_roundtrip(self):
        # Dict round-trip via model_dump/model_validate; the JSON-string path
        # is exercised in test_model_dump_json below.
        a = RobotAction(robot_id=1, action_type="pick")
        dumped = a.model_dump()
        restored = RobotAction.model_validate(dumped)
        assert restored.robot_id == 1
        assert restored.action_type == "pick"

    def test_model_dump_json(self):
        a = RobotAction(robot_id=0, action_type="wait")
        json_str = a.model_dump_json()
        # Parse instead of substring-matching: `"robot_id" in json_str` would
        # also pass for malformed output or when the text sits inside a value.
        data = json.loads(json_str)
        assert data["robot_id"] == 0
        assert data["action_type"] == "wait"
        # The JSON string must also load back into an equivalent model.
        restored = RobotAction.model_validate_json(json_str)
        assert restored.robot_id == 0
        assert restored.action_type == "wait"
class TestWarehouseAction:
    """Tests for WarehouseAction, the per-step list of RobotAction commands."""

    def test_empty_robots_default(self):
        a = WarehouseAction()
        assert a.robots == []

    def test_with_robot(self):
        ra = RobotAction(robot_id=0, action_type="wait")
        a = WarehouseAction(robots=[ra])
        assert len(a.robots) == 1
        assert a.robots[0].robot_id == 0

    def test_roundtrip(self):
        ra = RobotAction(robot_id=0, action_type="move_down")
        a = WarehouseAction(robots=[ra])
        dumped = a.model_dump()
        restored = WarehouseAction.model_validate(dumped)
        # Check the whole nested entry survives, not just one field.
        assert len(restored.robots) == 1
        assert restored.robots[0].robot_id == 0
        assert restored.robots[0].action_type == "move_down"

    def test_inherits_action(self):
        from openenv.core.env_server.types import Action

        a = WarehouseAction()
        assert isinstance(a, Action)

    def test_model_dump_json(self):
        a = WarehouseAction(robots=[RobotAction(robot_id=0, action_type="wait")])
        json_str = a.model_dump_json()
        # Parse so the test verifies real JSON structure, including the nested
        # robot entry, rather than the presence of a substring.
        data = json.loads(json_str)
        assert len(data["robots"]) == 1
        assert data["robots"][0]["robot_id"] == 0
        assert data["robots"][0]["action_type"] == "wait"
class TestRobotState:
    """Tests for RobotState, the per-robot position / carrying / assignment record."""

    def test_basic_instantiation(self):
        r = RobotState(
            id=0,
            row=2,
            col=3,
            carrying_item=False,
            assigned_order_id=None,
            is_active=True,
        )
        assert r.id == 0
        assert r.row == 2
        assert r.col == 3
        assert r.carrying_item is False
        assert r.assigned_order_id is None
        assert r.is_active is True

    def test_to_dict(self):
        # assigned_order_id / is_active are omitted, so they must have defaults.
        r = RobotState(id=0, row=2, col=3, carrying_item=False)
        d = r.to_dict()
        assert isinstance(d, dict)
        assert d["id"] == 0

    def test_model_dump_json(self):
        r = RobotState(id=0, row=0, col=0, carrying_item=False)
        json_str = r.model_dump_json()
        # Parse rather than substring-match so the boolean is checked as a
        # real JSON `false`, not merely the key name appearing in the text.
        data = json.loads(json_str)
        assert data["id"] == 0
        assert data["carrying_item"] is False
class TestOrderState:
    """Tests for OrderState, one pick-and-deliver order (shelf -> packing station)."""

    def test_basic_instantiation(self):
        o = OrderState(
            order_id="o1",
            shelf_pos=(1, 1),
            packing_pos=(7, 3),
            status="pending",
            created_at_step=0,
        )
        assert o.order_id == "o1"
        assert o.shelf_pos == (1, 1)
        assert o.packing_pos == (7, 3)
        assert o.status == "pending"
        assert o.created_at_step == 0

    def test_defaults(self):
        o = OrderState(order_id="o2", shelf_pos=(1, 1), packing_pos=(8, 2))
        assert o.status == "pending"
        assert o.created_at_step == 0
        assert o.assigned_robot_id is None

    def test_to_dict(self):
        o = OrderState(order_id="o1", shelf_pos=(1, 1), packing_pos=(7, 3))
        d = o.to_dict()
        assert isinstance(d, dict)
        assert d["order_id"] == "o1"

    def test_model_dump_json(self):
        o = OrderState(order_id="o1", shelf_pos=(1, 1), packing_pos=(7, 3))
        json_str = o.model_dump_json()
        data = json.loads(json_str)
        assert data["order_id"] == "o1"
        # Tuple positions serialize as JSON arrays, so they load back as lists.
        assert data["shelf_pos"] == [1, 1]
        assert data["packing_pos"] == [7, 3]
class TestWarehouseObservation:
    """Tests for WarehouseObservation, the agent-facing view returned each step."""

    def test_instantiates_with_defaults(self):
        obs = WarehouseObservation()
        assert obs.grid == []
        assert obs.robots == []
        assert obs.order_queue == []
        assert obs.step_count == 0
        assert obs.max_steps == 50
        assert obs.task_id == "solo_delivery"
        assert obs.description == ""

    def test_inherits_observation(self):
        from openenv.core.env_server.types import Observation

        obs = WarehouseObservation()
        assert isinstance(obs, Observation)

    def test_inherited_done_reward(self):
        # done/reward come from the openenv Observation base class.
        obs = WarehouseObservation()
        assert obs.done is False
        assert obs.reward is None

    def test_model_dump_json(self):
        obs = WarehouseObservation()
        json_str = obs.model_dump_json()
        # Parse so we verify valid JSON and the serialized default value,
        # not just that the word "grid" appears somewhere in the string.
        data = json.loads(json_str)
        assert data["grid"] == []
class TestWarehouseState:
    """Tests for WarehouseState, the full internal environment state."""

    def test_instantiates_with_defaults(self):
        s = WarehouseState()
        assert s.task_id == ""
        assert s.grid == []
        assert s.robots == []
        assert s.orders == []
        assert s.done is False

    def test_inherits_state(self):
        from openenv.core.env_server.types import State

        s = WarehouseState()
        assert isinstance(s, State)

    def test_inherited_fields(self):
        # episode_id/step_count come from the openenv State base class.
        s = WarehouseState()
        assert s.episode_id is None
        assert s.step_count == 0

    def test_model_dump_json(self):
        s = WarehouseState()
        json_str = s.model_dump_json()
        # Parse and check the serialized default rather than a substring.
        data = json.loads(json_str)
        assert data["task_id"] == ""
class TestWarehouseReward:
    """Tests for WarehouseReward, a scalar reward plus named component breakdown."""

    def test_basic_instantiation(self):
        r = WarehouseReward(value=10.0, breakdown={"delivery": 10.0})
        assert r.value == 10.0
        assert r.breakdown["delivery"] == 10.0

    def test_empty_breakdown_default(self):
        r = WarehouseReward(value=0.0)
        assert r.breakdown == {}

    def test_model_dump_json(self):
        r = WarehouseReward(value=5.0, breakdown={"fast_bonus": 5.0})
        json_str = r.model_dump_json()
        # `"value" in json_str` is near-vacuous; parse and check the values.
        data = json.loads(json_str)
        assert data["value"] == 5.0
        assert data["breakdown"] == {"fast_bonus": 5.0}
|