Spaces:
Sleeping
Sleeping
| """ | |
| Tests for task definitions: presence of all tasks, structural validity, | |
| and grader callability. | |
| """ | |
| import pytest | |
| from grid_env.tasks import TASKS, get_task, GRID_SIZE | |
| from grid_env.graders import grade_episode | |
| from grid_env.env import WarehouseFulfillmentEnv | |
# Canonical set of task IDs the tests below compare the TASKS registry against.
# Adding or removing a task in the registry requires updating this set too.
EXPECTED_TASK_IDS = {
    "easy_single_pick",
    "medium_multi_item",
    "hard_restock_priority",
    "obstacle_course",
    "heavy_lifting",
    "stamina_run",
    "budget_run",
    "gauntlet",
}
def test_expected_task_count():
    """The TASKS registry holds exactly as many entries as EXPECTED_TASK_IDS."""
    registered_count = len(TASKS)
    expected_count = len(EXPECTED_TASK_IDS)
    assert registered_count == expected_count
def test_all_expected_task_ids_present():
    """The keys of the TASKS registry are exactly the expected task IDs."""
    registered_ids = set(TASKS.keys())
    assert registered_ids == EXPECTED_TASK_IDS
def test_task_has_required_fields(task_id):
    """Check the basic structural fields of the task looked up by ``task_id``.

    ``task_id`` is expected to be supplied by a parametrization or fixture
    that is not visible in this part of the file.
    """
    allowed_difficulties = {"easy", "medium", "hard", "expert"}
    spec = get_task(task_id)

    # The looked-up task must report the same ID it was requested under.
    assert spec.task_id == task_id
    assert spec.difficulty in allowed_difficulties

    # Numeric limits must be strictly positive.
    assert spec.max_steps > 0
    assert spec.battery_capacity > 0

    # A task needs at least one bin and at least one order line.
    assert len(spec.bins) > 0
    assert len(spec.order) > 0
def test_task_required_scans_non_empty(task_id):
    """Each task declares at least one required scan.

    ``task_id`` is expected to be supplied by a parametrization or fixture
    that is not visible in this part of the file.
    """
    spec = get_task(task_id)
    scan_count = len(spec.required_scans)
    assert scan_count > 0, f"{task_id} has no required_scans"
def test_task_order_skus_exist_in_bins(task_id):
    """Every SKU in the order exists in at least one bin.

    ``task_id`` is expected to be supplied by a parametrization or fixture
    that is not visible in this part of the file.
    """
    spec = get_task(task_id)

    # Collect the SKUs that are stocked anywhere in the task's bins.
    stocked_skus = set()
    for storage_bin in spec.bins:
        stocked_skus.add(storage_bin.sku)

    # Fail on the first order line whose SKU is not stocked.
    for order_line in spec.order:
        assert order_line.sku in stocked_skus, f"{order_line.sku} ordered but not stocked in {task_id}"
def test_required_scans_are_valid_bin_ids(task_id):
    """required_scans reference bin IDs that actually exist.

    ``task_id`` is expected to be supplied by a parametrization or fixture
    that is not visible in this part of the file.
    """
    spec = get_task(task_id)

    # IDs of every bin defined by the task.
    known_bin_ids = set()
    for storage_bin in spec.bins:
        known_bin_ids.add(storage_bin.bin_id)

    # Fail on the first required scan that names an unknown bin.
    for required_id in spec.required_scans:
        assert required_id in known_bin_ids, f"required scan {required_id} not a valid bin in {task_id}"
def test_grader_callable_returns_float_in_range(task_id):
    """Play an all-"wait" episode to completion and check the grade is in [0, 1].

    ``task_id`` is expected to be supplied by a parametrization or fixture
    that is not visible in this part of the file.
    """
    env = WarehouseFulfillmentEnv(task_id=task_id, seed=7)
    env.reset()

    # NOTE(review): there is no step cap here; the loop relies on the
    # environment eventually reporting ``done`` for a stream of "wait" actions.
    while True:
        # step() is unpacked as a 4-tuple; only the third element is used.
        _, _, finished, _ = env.step("wait")
        if finished:
            break

    final_state = env.state()
    score = grade_episode(final_state)

    assert isinstance(score, float)
    assert 0.0 <= score <= 1.0
def test_obstacles_do_not_overlap_bins_or_stations(task_id):
    """Obstacles must not overlap with bins, stations, or agent start.

    ``task_id`` is expected to be supplied by a parametrization or fixture
    that is not visible in this part of the file.

    Every coordinate, obstacles included, is converted to a ``tuple`` before
    the set comparisons. This keeps the check meaningful when positions are
    stored as lists (which are unhashable) or as any other sequence type
    that would never compare equal to the tuple-converted bin and station
    positions.
    """
    task = get_task(task_id)

    # Normalise obstacles the same way as all other positions below so the
    # disjointness checks compare like with like.
    obstacle_set = {tuple(cell) for cell in task.obstacles}
    bin_positions = {tuple(b.position) for b in task.bins}
    fixed_positions = {
        tuple(task.pack_station_position),
        tuple(task.charger_position),
        tuple(task.dock_position),
        tuple(task.agent_start),
    }
    # The rest position is only included when it is truthy
    # (presumably unset for tasks without a rest area; confirm against tasks).
    if task.rest_position:
        fixed_positions.add(tuple(task.rest_position))

    assert obstacle_set.isdisjoint(bin_positions), f"Obstacle overlaps bin in {task_id}"
    assert obstacle_set.isdisjoint(fixed_positions), f"Obstacle overlaps station in {task_id}"
def test_get_task_raises_on_unknown_id():
    """Looking up an unregistered ID raises KeyError mentioning 'Unknown task_id'."""
    missing_id = "does_not_exist"
    with pytest.raises(KeyError, match="Unknown task_id"):
        get_task(missing_id)
def test_grid_size_is_positive_tuple():
    """GRID_SIZE has exactly two entries and both are strictly positive."""
    assert len(GRID_SIZE) == 2
    # Checked in index order, stopping at the first non-positive entry.
    assert all(GRID_SIZE[axis] > 0 for axis in (0, 1))