# scaenv / tests/test_tasks.py
# Author: noanya
# Commit message: test(01-01): add failing tests for Grid and TaskRegistry
# Commit: 9643285
"""Tests for warehouse_env/tasks.py — TaskConfig and TASK_REGISTRY."""
import pytest
from warehouse_env.tasks import TaskConfig, TASK_REGISTRY
class TestTaskRegistry:
    """Checks that TASK_REGISTRY holds exactly the three expected tasks.

    Together, the size check and the three per-key checks pin the registry's
    key set to exactly {"solo_delivery", "coordinated_delivery",
    "crisis_management"}, each mapped to a TaskConfig instance.
    """

    @staticmethod
    def _assert_registered(key):
        # An entry must be present and must be a TaskConfig, not a raw dict.
        assert key in TASK_REGISTRY
        assert isinstance(TASK_REGISTRY[key], TaskConfig)

    def test_has_three_tasks(self):
        registry_size = len(TASK_REGISTRY)
        assert registry_size == 3

    def test_has_solo_delivery(self):
        self._assert_registered("solo_delivery")

    def test_has_coordinated_delivery(self):
        self._assert_registered("coordinated_delivery")

    def test_has_crisis_management(self):
        self._assert_registered("crisis_management")
class TestSoloDelivery:
    """Spec for the "solo_delivery" task.

    Expected: a 10x10 grid, one robot, five initial orders, a 100-step budget,
    and a task_id that matches its registry key.
    """

    def setup_method(self):
        # pytest calls this before every test method, so each test re-reads the entry.
        self.task = TASK_REGISTRY["solo_delivery"]

    def test_grid_size(self):
        cfg = self.task
        assert cfg.grid_rows == 10
        assert cfg.grid_cols == 10

    def test_num_robots(self):
        robots = self.task.num_robots
        assert robots == 1

    def test_initial_orders_count(self):
        orders = self.task.initial_orders
        assert len(orders) == 5

    def test_max_steps(self):
        budget = self.task.max_steps
        assert budget == 100

    def test_order_dict_structure(self):
        required = ("order_id", "shelf_pos", "packing_pos")
        for order in self.task.initial_orders:
            for key in required:
                assert key in order

    def test_task_id(self):
        expected_id = "solo_delivery"
        assert self.task.task_id == expected_id
class TestCoordinatedDelivery:
    """Spec for the "coordinated_delivery" task.

    Expected: a 12x12 grid, three robots, ten initial orders, a 150-step
    budget, at least one disruption event, and a task_id that matches its
    registry key.
    """

    # Keys that every order dict / disruption event must carry.
    ORDER_KEYS = ("order_id", "shelf_pos", "packing_pos")
    EVENT_KEYS = ("step", "type", "params")

    def setup_method(self):
        self.task = TASK_REGISTRY["coordinated_delivery"]

    def test_task_id(self):
        # The registry key and the config's own task_id must agree, as
        # TestSoloDelivery already requires for "solo_delivery".
        assert self.task.task_id == "coordinated_delivery"

    def test_grid_size(self):
        assert self.task.grid_rows == 12
        assert self.task.grid_cols == 12

    def test_num_robots(self):
        assert self.task.num_robots == 3

    def test_initial_orders_count(self):
        assert len(self.task.initial_orders) == 10

    def test_max_steps(self):
        assert self.task.max_steps == 150

    def test_has_disruption_events(self):
        assert len(self.task.disruption_events) >= 1
        # Validate every event, not only the first one, so a malformed
        # later event is also caught.
        for event in self.task.disruption_events:
            for key in self.EVENT_KEYS:
                assert key in event, f"disruption event missing {key!r}: {event!r}"

    def test_order_dict_structure(self):
        for order in self.task.initial_orders:
            for key in self.ORDER_KEYS:
                assert key in order, f"order missing {key!r}: {order!r}"
class TestCrisisManagement:
    """Spec for the "crisis_management" task.

    Expected: a 15x15 grid, five robots, twenty initial orders, a 200-step
    budget, at least one well-formed disruption event, and a task_id that
    matches its registry key.
    """

    # Keys that every order dict / disruption event must carry; the event
    # keys mirror those required for "coordinated_delivery".
    ORDER_KEYS = ("order_id", "shelf_pos", "packing_pos")
    EVENT_KEYS = ("step", "type", "params")

    def setup_method(self):
        self.task = TASK_REGISTRY["crisis_management"]

    def test_task_id(self):
        # The registry key and the config's own task_id must agree.
        assert self.task.task_id == "crisis_management"

    def test_grid_size(self):
        assert self.task.grid_rows == 15
        assert self.task.grid_cols == 15

    def test_num_robots(self):
        assert self.task.num_robots == 5

    def test_initial_orders_count(self):
        assert len(self.task.initial_orders) == 20

    def test_max_steps(self):
        assert self.task.max_steps == 200

    def test_has_disruption_events(self):
        assert len(self.task.disruption_events) >= 1

    def test_disruption_event_structure(self):
        # Previously only the existence of events was checked; their shape
        # must match what the other tasks require.
        for event in self.task.disruption_events:
            for key in self.EVENT_KEYS:
                assert key in event, f"disruption event missing {key!r}: {event!r}"

    def test_order_dict_structure(self):
        for order in self.task.initial_orders:
            for key in self.ORDER_KEYS:
                assert key in order, f"order missing {key!r}: {order!r}"
class TestTaskConfigDataclass:
    """Spec for the TaskConfig container type itself."""

    # Field names that every TaskConfig must declare.
    REQUIRED_FIELDS = frozenset({
        "task_id",
        "name",
        "description",
        "grid_rows",
        "grid_cols",
        "num_robots",
        "shelf_positions",
        "packing_positions",
        "robot_start_positions",
        "initial_orders",
        "max_steps",
        "disruption_events",
        "time_bonus_window",
    })

    def test_is_dataclass(self):
        import dataclasses
        assert dataclasses.is_dataclass(TaskConfig)

    def test_has_required_fields(self):
        import dataclasses
        task = TASK_REGISTRY["solo_delivery"]
        # Compare against the declared dataclass fields instead of hasattr():
        # hasattr() also accepts methods, properties and class attributes, and
        # a chain of asserts would report only the first missing name.
        declared = {f.name for f in dataclasses.fields(task)}
        missing = self.REQUIRED_FIELDS - declared
        assert not missing, f"TaskConfig is missing fields: {sorted(missing)}"