"""
Tests for task definitions: presence of all tasks, structural validity,
and grader callability.
"""

import pytest

from grid_env.tasks import TASKS, get_task, GRID_SIZE
from grid_env.graders import grade_episode
from grid_env.env import WarehouseFulfillmentEnv

EXPECTED_TASK_IDS = {
    "easy_single_pick",
    "medium_multi_item",
    "hard_restock_priority",
    "obstacle_course",
    "heavy_lifting",
    "stamina_run",
    "budget_run",
    "gauntlet",
}

# Parametrize from a sorted list rather than the raw set: set iteration order
# for strings varies between interpreter processes (hash randomisation), which
# makes collection order unstable and breaks distributed runners that require
# every worker to collect the same tests in the same order.
_SORTED_TASK_IDS = sorted(EXPECTED_TASK_IDS)


def test_expected_task_count():
    """TASKS holds exactly as many entries as EXPECTED_TASK_IDS."""
    assert len(TASKS) == len(EXPECTED_TASK_IDS)


def test_all_expected_task_ids_present():
    """The keys of TASKS are exactly the expected task IDs."""
    assert set(TASKS.keys()) == EXPECTED_TASK_IDS


@pytest.mark.parametrize("task_id", _SORTED_TASK_IDS)
def test_task_has_required_fields(task_id):
    """Each task echoes its ID and has sane, non-empty core fields."""
    task = get_task(task_id)
    assert task.task_id == task_id
    assert task.difficulty in {"easy", "medium", "hard", "expert"}
    assert task.max_steps > 0
    assert task.battery_capacity > 0
    assert len(task.bins) > 0
    assert len(task.order) > 0


@pytest.mark.parametrize("task_id", _SORTED_TASK_IDS)
def test_task_required_scans_non_empty(task_id):
    """Each task lists at least one required scan."""
    task = get_task(task_id)
    assert len(task.required_scans) > 0, f"{task_id} has no required_scans"


@pytest.mark.parametrize("task_id", _SORTED_TASK_IDS)
def test_task_order_skus_exist_in_bins(task_id):
    """Every SKU in the order exists in at least one bin."""
    task = get_task(task_id)
    bin_skus = {b.sku for b in task.bins}
    for line in task.order:
        assert line.sku in bin_skus, f"{line.sku} ordered but not stocked in {task_id}"


@pytest.mark.parametrize("task_id", _SORTED_TASK_IDS)
def test_required_scans_are_valid_bin_ids(task_id):
    """required_scans reference bin IDs that actually exist."""
    task = get_task(task_id)
    bin_ids = {b.bin_id for b in task.bins}
    for scan_id in task.required_scans:
        assert scan_id in bin_ids, f"required scan {scan_id} not a valid bin in {task_id}"


@pytest.mark.parametrize("task_id", _SORTED_TASK_IDS)
def test_grader_callable_returns_float_in_range(task_id):
    """Step "wait" until the episode ends, then check the grade is a float in [0, 1]."""
    env = WarehouseFulfillmentEnv(task_id=task_id, seed=7)
    env.reset()
    done = False
    # NOTE(review): this loop relies on the environment eventually reporting
    # done for a pure "wait" policy; there is no step cap on the test side.
    while not done:
        _, _, done, _ = env.step("wait")
    state = env.state()
    score = grade_episode(state)
    assert isinstance(score, float)
    assert 0.0 <= score <= 1.0


@pytest.mark.parametrize("task_id", _SORTED_TASK_IDS)
def test_obstacles_do_not_overlap_bins_or_stations(task_id):
    """Obstacles must not overlap with bins, stations, or agent start."""
    task = get_task(task_id)
    obstacle_set = set(task.obstacles)
    bin_positions = {tuple(b.position) for b in task.bins}
    fixed_positions = {
        tuple(task.pack_station_position),
        tuple(task.charger_position),
        tuple(task.dock_position),
        tuple(task.agent_start),
    }
    # rest_position is only included when the task defines a truthy value.
    if task.rest_position:
        fixed_positions.add(tuple(task.rest_position))
    assert obstacle_set.isdisjoint(bin_positions), f"Obstacle overlaps bin in {task_id}"
    assert obstacle_set.isdisjoint(fixed_positions), f"Obstacle overlaps station in {task_id}"


def test_get_task_raises_on_unknown_id():
    """get_task raises KeyError mentioning "Unknown task_id" for an unregistered ID."""
    with pytest.raises(KeyError, match="Unknown task_id"):
        get_task("does_not_exist")


def test_grid_size_is_positive_tuple():
    """GRID_SIZE has two entries and both are positive."""
    assert len(GRID_SIZE) == 2
    assert GRID_SIZE[0] > 0 and GRID_SIZE[1] > 0