"""Tests for WarehouseEnv — reset, step, state property, episode logic.

TDD: tests written before implementation (env.py does not exist yet).
"""
| import pytest |
| from warehouse_env.models import ( |
| WarehouseAction, |
| RobotAction, |
| WarehouseObservation, |
| WarehouseState, |
| ) |
|
|
|
|
| |
| |
| |
|
|
@pytest.fixture
def env():
    """Provide a freshly constructed ``WarehouseEnv`` to each test.

    The import is done inside the fixture so that collecting this module
    does not require ``warehouse_env.env`` to be importable.
    """
    from warehouse_env.env import WarehouseEnv

    fresh_env = WarehouseEnv()
    return fresh_env
|
|
|
|
| |
| |
| |
|
|
class TestReset:
    """Checks on the observation handed back by ``WarehouseEnv.reset``."""

    def test_reset_returns_warehouse_observation(self, env):
        """reset() returns a ``WarehouseObservation`` instance."""
        initial = env.reset()
        assert isinstance(initial, WarehouseObservation)

    def test_reset_default_task_is_solo_delivery(self, env):
        """Without a task_id, the observation reports ``solo_delivery``."""
        initial = env.reset()
        assert initial.task_id == "solo_delivery"

    def test_reset_solo_delivery_one_robot(self, env):
        """The solo_delivery task exposes exactly one robot."""
        robots = env.reset(task_id="solo_delivery").robots
        assert len(robots) == 1

    def test_reset_coordinated_delivery_three_robots(self, env):
        """The coordinated_delivery task exposes exactly three robots."""
        robots = env.reset(task_id="coordinated_delivery").robots
        assert len(robots) == 3

    def test_reset_crisis_management_five_robots(self, env):
        """The crisis_management task exposes exactly five robots."""
        robots = env.reset(task_id="crisis_management").robots
        assert len(robots) == 5

    def test_reset_unknown_task_raises_value_error(self, env):
        """An unrecognised task_id raises ValueError mentioning that id."""
        bogus_task = "unknown_task"
        with pytest.raises(ValueError, match=bogus_task):
            env.reset(task_id=bogus_task)

    def test_reset_step_count_is_zero(self, env):
        """A fresh episode starts with step_count == 0."""
        initial = env.reset()
        assert initial.step_count == 0

    def test_reset_done_is_false(self, env):
        """A fresh episode is not done (and ``done`` is the literal ``False``)."""
        initial = env.reset()
        assert initial.done is False

    def test_reset_grid_is_2d_list(self, env):
        """The grid is a list whose rows are all lists."""
        grid = env.reset().grid
        assert isinstance(grid, list)
        non_list_rows = [row for row in grid if not isinstance(row, list)]
        assert not non_list_rows

    def test_reset_solo_grid_dimensions_10x10(self, env):
        """solo_delivery uses a 10-row grid with 10 cells per row."""
        grid = env.reset(task_id="solo_delivery").grid
        assert len(grid) == 10
        assert all(width == 10 for width in map(len, grid))

    def test_reset_coordinated_grid_dimensions_12x12(self, env):
        """coordinated_delivery uses a 12-row grid with 12 cells per row."""
        grid = env.reset(task_id="coordinated_delivery").grid
        assert len(grid) == 12
        assert all(width == 12 for width in map(len, grid))

    def test_reset_solo_order_queue_length_five(self, env):
        """solo_delivery starts with five entries in the order queue."""
        queue = env.reset(task_id="solo_delivery").order_queue
        assert len(queue) == 5

    def test_reset_description_is_nonempty_string(self, env):
        """The observation carries a non-empty ``str`` description."""
        description = env.reset().description
        assert isinstance(description, str)
        assert description

    def test_reset_description_contains_robot_info(self, env):
        """The description mentions robots, either as "robot" or as "r0"."""
        lowered = env.reset().description.lower()
        assert any(token in lowered for token in ("robot", "r0"))
|
|
|
|
| |
| |
| |
|
|
class TestStep:
    """Behaviour of ``WarehouseEnv.step``: movement, pick/drop, collisions, termination."""

    @staticmethod
    def _send(env, action_type, *, robot_id=0, times=1):
        """Issue the same single-robot action ``times`` times.

        A new ``WarehouseAction`` is built for every step. Returns the
        observation from the last ``env.step`` call, or ``None`` when
        ``times`` is 0.
        """
        obs = None
        for _ in range(times):
            obs = env.step(
                WarehouseAction(
                    robots=[RobotAction(robot_id=robot_id, action_type=action_type)]
                )
            )
        return obs

    @staticmethod
    def _cell(obs, index):
        """Return ``(row, col)`` of the robot at position ``index`` in ``obs.robots``."""
        robot = obs.robots[index]
        return (robot.row, robot.col)

    def test_step_returns_warehouse_observation(self, env):
        """step() returns a ``WarehouseObservation`` instance."""
        env.reset()
        obs = env.step(WarehouseAction(robots=[]))
        assert isinstance(obs, WarehouseObservation)

    def test_step_empty_action_increments_step_count(self, env):
        """A step with no robot actions still advances step_count to 1."""
        env.reset()
        obs = env.step(WarehouseAction(robots=[]))
        assert obs.step_count == 1

    def test_step_before_reset_raises_runtime_error(self, env):
        """Calling step() on an env that was never reset raises RuntimeError."""
        with pytest.raises(RuntimeError):
            env.step(WarehouseAction(robots=[]))

    def test_step_invalid_action_type_treated_as_wait(self, env):
        """An unrecognised action_type leaves the robot where it was."""
        # One reset is enough: its observation gives the reference position.
        start = self._cell(env.reset(task_id="solo_delivery"), 0)
        obs = self._send(env, "JUMP")
        assert self._cell(obs, 0) == start

    def test_step_move_up_decrements_row(self, env):
        """Robot at (5,5) with move_up ends at (4,5) if passable."""
        obs = env.reset(task_id="solo_delivery")
        assert self._cell(obs, 0) == (5, 5)
        obs2 = self._send(env, "move_up")
        assert self._cell(obs2, 0) == (4, 5)

    def test_step_robot_cannot_move_into_shelf(self, env):
        """Shelf cells are not passable — robot stays in place."""
        # TODO: this only pins the layout (a shelf marker at grid[1][1]); it
        # does not yet drive a robot into a shelf and check that it stays put.
        grid = env.reset(task_id="solo_delivery").grid
        assert grid[1][1] == "S", f"Expected shelf at (1,1) but got {grid[1][1]}"

    def test_step_robot_cannot_move_outside_bounds(self, env):
        """Repeated move_up never produces a negative row index."""
        env.reset(task_id="solo_delivery")
        # Five moves up from the start, then one more that is meant to hit the
        # top edge. NOTE(review): presumably the robot is at row 0 by then —
        # the test does not assert that, only that the row stays >= 0.
        self._send(env, "move_up", times=5)
        obs2 = self._send(env, "move_up")
        assert obs2.robots[0].row >= 0

    def test_step_done_when_max_steps_reached(self, env):
        """Episode ends when step_count >= max_steps."""
        obs = env.reset(task_id="solo_delivery")
        for _ in range(obs.max_steps):
            obs = env.step(WarehouseAction(robots=[]))
        assert obs.done is True

    def test_step_pick_drop_full_cycle(self, env):
        """Integration: pick item from shelf, carry to packing station, drop, reward > 0."""
        env.reset(task_id="solo_delivery")

        # Walk to the pick position: 3 up, 2 left.
        self._send(env, "move_up", times=3)
        obs = self._send(env, "move_left", times=2)
        assert self._cell(obs, 0) == (2, 3)

        obs = self._send(env, "pick")
        assert obs.robots[0].carrying_item is True

        # Carry the item to the drop position: 5 down, 1 left.
        self._send(env, "move_down", times=5)
        self._send(env, "move_left")

        obs = self._send(env, "drop")
        assert obs.reward is not None and obs.reward > 0

    def test_step_pick_sets_carrying_item(self, env):
        """Pick action adjacent to shelf with pending order sets carrying_item=True."""
        env.reset(task_id="solo_delivery")
        self._send(env, "move_up", times=3)
        self._send(env, "move_left", times=2)
        obs = self._send(env, "pick")
        assert obs.robots[0].carrying_item is True

    def test_step_drop_at_packing_station_delivers_order(self, env):
        """Drop when carrying item adjacent to correct packing station delivers order, reward > 0."""
        env.reset(task_id="solo_delivery")

        # Walk to the pick position (3 up, 2 left) and pick.
        self._send(env, "move_up", times=3)
        self._send(env, "move_left", times=2)
        obs = self._send(env, "pick")
        assert obs.robots[0].carrying_item is True, "Robot should be carrying after pick"

        # Walk to the drop position: 5 down, 1 left.
        self._send(env, "move_down", times=5)
        obs = self._send(env, "move_left")
        assert self._cell(obs, 0) == (7, 2)

        obs = self._send(env, "drop")
        assert obs.robots[0].carrying_item is False
        assert obs.reward is not None and obs.reward > 0

    def test_step_done_when_all_orders_delivered(self, env):
        """obs.done becomes True when all orders are delivered."""
        # TODO: placeholder — no orders are delivered here, so only the
        # initial ``done is False`` state is checked.
        obs = env.reset(task_id="solo_delivery")
        assert obs.done is False

    def test_step_two_robots_same_target_both_stay(self, env):
        """Two robots driven toward each other never end up in the same cell."""
        env.reset(task_id="coordinated_delivery")

        obs = None
        for _ in range(2):
            obs = env.step(
                WarehouseAction(
                    robots=[
                        RobotAction(robot_id=0, action_type="move_right"),
                        RobotAction(robot_id=1, action_type="move_left"),
                    ]
                )
            )

        assert self._cell(obs, 0) != self._cell(obs, 1), (
            "Robots should not occupy the same cell"
        )

    def test_step_two_robots_swap_cells_both_stay(self, env):
        """Two robots swapping cells both stay in place."""
        env.reset(task_id="coordinated_delivery")

        # Three move_right steps for robot 0, intended to bring it next to robot 1.
        obs = self._send(env, "move_right", times=3)
        r0_before = self._cell(obs, 0)
        r1_before = self._cell(obs, 1)

        # A swap needs robot 1 directly to the right of robot 0 on the same
        # row. Skip visibly instead of passing without asserting anything.
        if r1_before != (r0_before[0], r0_before[1] + 1):
            pytest.skip(
                f"robots not adjacent for a swap: r0={r0_before}, r1={r1_before}"
            )

        obs = env.step(
            WarehouseAction(
                robots=[
                    RobotAction(robot_id=0, action_type="move_right"),
                    RobotAction(robot_id=1, action_type="move_left"),
                ]
            )
        )
        assert self._cell(obs, 0) == r0_before
        assert self._cell(obs, 1) == r1_before
|
|
|
|
| |
| |
| |
|
|
class TestStateProperty:
    """Checks on the ``WarehouseEnv.state`` attribute."""

    def test_state_returns_warehouse_state(self, env):
        """After reset, ``env.state`` is a ``WarehouseState``."""
        env.reset()
        snapshot = env.state
        assert isinstance(snapshot, WarehouseState)

    def test_state_is_not_dict(self, env):
        """After reset, ``env.state`` is not a plain dict."""
        env.reset()
        snapshot = env.state
        assert not isinstance(snapshot, dict)

    def test_state_not_none(self, env):
        """After reset, ``env.state`` is not ``None``."""
        env.reset()
        snapshot = env.state
        assert snapshot is not None

    def test_state_before_reset_returns_warehouse_state(self, env):
        """Even without reset, ``env.state`` is a ``WarehouseState`` with done False."""
        snapshot = env.state
        assert isinstance(snapshot, WarehouseState)
        assert snapshot.done is False

    def test_state_after_reset_task_id_correct(self, env):
        """The state reports the task_id passed to reset."""
        requested_task = "solo_delivery"
        env.reset(task_id=requested_task)
        assert env.state.task_id == requested_task

    def test_state_after_reset_step_count_zero(self, env):
        """Immediately after reset, the state's step_count is 0."""
        env.reset()
        assert env.state.step_count == 0

    def test_state_after_step_step_count_incremented(self, env):
        """After one step, the state's step_count is 1."""
        env.reset()
        env.step(WarehouseAction(robots=[]))
        assert env.state.step_count == 1

    def test_state_is_property_not_method(self, env):
        """state should be accessed without parentheses."""
        env.reset()
        # Plain attribute access must already yield the state object ...
        snapshot = env.state
        assert isinstance(snapshot, WarehouseState)
        # ... and that object must not be something left to call.
        assert not callable(snapshot)
|
|