File size: 3,463 Bytes
ea847ad
eaa79f0
ea847ad
 
 
 
 
 
 
 
 
eaa79f0
 
 
 
 
 
 
 
 
 
ea847ad
 
eaa79f0
 
ea847ad
 
 
 
 
 
 
 
 
 
eaa79f0
ea847ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eaa79f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea847ad
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""
Tests for task definitions: presence of all tasks, structural validity,
and grader callability.
"""

import pytest
from grid_env.tasks import TASKS, get_task, GRID_SIZE
from grid_env.graders import grade_episode
from grid_env.env import WarehouseFulfillmentEnv


# Exact set of task ids that grid_env.tasks.TASKS is expected to register.
# A frozenset so no test can mutate this shared module-level constant; it still
# compares equal to a plain set and supports len()/sorted()/list().
EXPECTED_TASK_IDS = frozenset({
    "easy_single_pick",
    "medium_multi_item",
    "hard_restock_priority",
    "obstacle_course",
    "heavy_lifting",
    "stamina_run",
    "budget_run",
    "gauntlet",
})


def test_expected_task_count():
    """The task registry holds exactly as many entries as EXPECTED_TASK_IDS lists."""
    registered = len(TASKS)
    assert registered == len(EXPECTED_TASK_IDS)


def test_all_expected_task_ids_present():
    """The registry's ids match EXPECTED_TASK_IDS exactly: none missing, none extra."""
    registered_ids = set(TASKS)
    assert registered_ids == EXPECTED_TASK_IDS


# sorted() rather than list(): set iteration order varies with PYTHONHASHSEED,
# so list() gives non-deterministic collection order (breaks pytest-xdist).
@pytest.mark.parametrize("task_id", sorted(EXPECTED_TASK_IDS))
def test_task_has_required_fields(task_id):
    """Each task reports its own id, a known difficulty, and non-empty core data."""
    task = get_task(task_id)
    assert task.task_id == task_id
    assert task.difficulty in {"easy", "medium", "hard", "expert"}
    assert task.max_steps > 0
    assert task.battery_capacity > 0
    assert len(task.bins) > 0
    assert len(task.order) > 0


# sorted() rather than list(): set iteration order varies with PYTHONHASHSEED,
# so list() gives non-deterministic collection order (breaks pytest-xdist).
@pytest.mark.parametrize("task_id", sorted(EXPECTED_TASK_IDS))
def test_task_required_scans_non_empty(task_id):
    """Every task lists at least one required scan."""
    task = get_task(task_id)
    assert len(task.required_scans) > 0, f"{task_id} has no required_scans"


# sorted() rather than list(): set iteration order varies with PYTHONHASHSEED,
# so list() gives non-deterministic collection order (breaks pytest-xdist).
@pytest.mark.parametrize("task_id", sorted(EXPECTED_TASK_IDS))
def test_task_order_skus_exist_in_bins(task_id):
    """Every SKU in the order exists in at least one bin."""
    task = get_task(task_id)
    # Build the stocked-SKU set once so each order line is an O(1) lookup.
    bin_skus = {b.sku for b in task.bins}
    for line in task.order:
        assert line.sku in bin_skus, f"{line.sku} ordered but not stocked in {task_id}"


# sorted() rather than list(): set iteration order varies with PYTHONHASHSEED,
# so list() gives non-deterministic collection order (breaks pytest-xdist).
@pytest.mark.parametrize("task_id", sorted(EXPECTED_TASK_IDS))
def test_required_scans_are_valid_bin_ids(task_id):
    """required_scans reference bin IDs that actually exist."""
    task = get_task(task_id)
    bin_ids = {b.bin_id for b in task.bins}
    for scan_id in task.required_scans:
        assert scan_id in bin_ids, f"required scan {scan_id} not a valid bin in {task_id}"


# sorted() rather than list(): set iteration order varies with PYTHONHASHSEED,
# so list() gives non-deterministic collection order (breaks pytest-xdist).
@pytest.mark.parametrize("task_id", sorted(EXPECTED_TASK_IDS))
def test_grader_callable_returns_float_in_range(task_id):
    """Run an all-'wait' episode to completion and check the grade is a float in [0, 1].

    The step loop is bounded, so an environment that never reports ``done``
    fails this test with a clear message instead of hanging the whole suite.
    """
    # NOTE(review): assumes an episode ends within roughly max_steps steps;
    # the 10x-plus-slack cap is only a hang guard, not a contract check.
    step_cap = 10 * get_task(task_id).max_steps + 10
    env = WarehouseFulfillmentEnv(task_id=task_id, seed=7)
    env.reset()
    for _ in range(step_cap):
        _, _, done, _ = env.step("wait")
        if done:
            break
    else:
        pytest.fail(f"{task_id}: episode did not finish within {step_cap} 'wait' steps")
    state = env.state()
    score = grade_episode(state)
    assert isinstance(score, float)
    assert 0.0 <= score <= 1.0


# sorted() rather than list(): set iteration order varies with PYTHONHASHSEED,
# so list() gives non-deterministic collection order (breaks pytest-xdist).
@pytest.mark.parametrize("task_id", sorted(EXPECTED_TASK_IDS))
def test_obstacles_do_not_overlap_bins_or_stations(task_id):
    """Obstacles must not overlap with bins, stations, or agent start."""
    task = get_task(task_id)
    # Normalise obstacles to tuples exactly like the other positions below, so
    # list-valued obstacles neither make set() raise TypeError nor silently
    # fail to compare equal to tuple positions.
    obstacle_set = {tuple(pos) for pos in task.obstacles}
    bin_positions = {tuple(b.position) for b in task.bins}
    fixed_positions = {
        tuple(task.pack_station_position),
        tuple(task.charger_position),
        tuple(task.dock_position),
        tuple(task.agent_start),
    }
    # rest_position is optional; skip it when unset/falsy.
    if task.rest_position:
        fixed_positions.add(tuple(task.rest_position))
    assert obstacle_set.isdisjoint(bin_positions), f"Obstacle overlaps bin in {task_id}"
    assert obstacle_set.isdisjoint(fixed_positions), f"Obstacle overlaps station in {task_id}"


def test_get_task_raises_on_unknown_id():
    """Looking up an unregistered id raises KeyError mentioning 'Unknown task_id'."""
    bogus_id = "does_not_exist"
    with pytest.raises(KeyError, match="Unknown task_id"):
        get_task(bogus_id)


def test_grid_size_is_positive_tuple():
    """GRID_SIZE has exactly two entries, and both are positive."""
    assert len(GRID_SIZE) == 2
    assert all(dim > 0 for dim in GRID_SIZE)