# Source: mini-rl-env / tests/test_tasks.py
# Author: sohambose98
# Commit eaa79f0: "updated the tests and graders"
"""
Tests for task definitions: presence of all tasks, structural validity,
and grader callability.
"""
import pytest
from grid_env.tasks import TASKS, get_task, GRID_SIZE
from grid_env.graders import grade_episode
from grid_env.env import WarehouseFulfillmentEnv
# Every task id the registry in grid_env.tasks is expected to expose.
# Kept as a (mutable) set so it compares directly against the registry's keys.
EXPECTED_TASK_IDS = set(
    (
        "easy_single_pick",
        "medium_multi_item",
        "hard_restock_priority",
        "obstacle_course",
        "heavy_lifting",
        "stamina_run",
        "budget_run",
        "gauntlet",
    )
)
def test_expected_task_count():
    """The registry holds exactly as many tasks as EXPECTED_TASK_IDS lists."""
    expected_count = len(EXPECTED_TASK_IDS)
    assert expected_count == len(TASKS)
def test_all_expected_task_ids_present():
    """The registry's keys match EXPECTED_TASK_IDS exactly: nothing missing, nothing extra."""
    registered_ids = set(TASKS.keys())
    assert registered_ids == EXPECTED_TASK_IDS
# Parametrize over a sorted list: iterating a set of strings yields a
# different order per interpreter run (hash randomization), which makes
# collection order nondeterministic and can break pytest-xdist.
@pytest.mark.parametrize("task_id", sorted(EXPECTED_TASK_IDS))
def test_task_has_required_fields(task_id):
    """Each task reports its own id, a known difficulty, and non-empty limits/content."""
    task = get_task(task_id)
    assert task.task_id == task_id
    assert task.difficulty in {"easy", "medium", "hard", "expert"}
    assert task.max_steps > 0
    assert task.battery_capacity > 0
    assert len(task.bins) > 0
    assert len(task.order) > 0
# Sorted so the parametrized ids are collected in a stable order across runs
# (set iteration order of strings varies with hash randomization).
@pytest.mark.parametrize("task_id", sorted(EXPECTED_TASK_IDS))
def test_task_required_scans_non_empty(task_id):
    """Each task declares at least one required bin scan."""
    task = get_task(task_id)
    assert len(task.required_scans) > 0, f"{task_id} has no required_scans"
# Sorted so the parametrized ids are collected in a stable order across runs
# (set iteration order of strings varies with hash randomization).
@pytest.mark.parametrize("task_id", sorted(EXPECTED_TASK_IDS))
def test_task_order_skus_exist_in_bins(task_id):
    """Every SKU in the order exists in at least one bin."""
    task = get_task(task_id)
    # Build the stocked-SKU lookup once so each order line is an O(1) check.
    bin_skus = {b.sku for b in task.bins}
    for line in task.order:
        assert line.sku in bin_skus, f"{line.sku} ordered but not stocked in {task_id}"
# Sorted so the parametrized ids are collected in a stable order across runs
# (set iteration order of strings varies with hash randomization).
@pytest.mark.parametrize("task_id", sorted(EXPECTED_TASK_IDS))
def test_required_scans_are_valid_bin_ids(task_id):
    """required_scans reference bin IDs that actually exist."""
    task = get_task(task_id)
    bin_ids = {b.bin_id for b in task.bins}
    for scan_id in task.required_scans:
        assert scan_id in bin_ids, f"required scan {scan_id} not a valid bin in {task_id}"
# Sorted so the parametrized ids are collected in a stable order across runs
# (set iteration order of strings varies with hash randomization).
@pytest.mark.parametrize("task_id", sorted(EXPECTED_TASK_IDS))
def test_grader_callable_returns_float_in_range(task_id):
    """Run a short episode of "wait" actions and verify the grader returns a float in [0, 1]."""
    task = get_task(task_id)
    env = WarehouseFulfillmentEnv(task_id=task_id, seed=7)
    env.reset()
    # Guard against hanging the whole suite if the env never reports done.
    # The budget is deliberately generous relative to task.max_steps so it only
    # trips on a genuinely non-terminating episode.
    step_budget = 2 * task.max_steps + 1
    steps_taken = 0
    done = False
    while not done:
        assert steps_taken < step_budget, (
            f"{task_id} episode did not terminate within {step_budget} steps"
        )
        _, _, done, _ = env.step("wait")
        steps_taken += 1
    state = env.state()
    score = grade_episode(state)
    assert isinstance(score, float)
    assert 0.0 <= score <= 1.0
# Sorted so the parametrized ids are collected in a stable order across runs
# (set iteration order of strings varies with hash randomization).
@pytest.mark.parametrize("task_id", sorted(EXPECTED_TASK_IDS))
def test_obstacles_do_not_overlap_bins_or_stations(task_id):
    """Obstacles must not overlap with bins, stations, or agent start.

    Every coordinate, obstacles included, is normalized with ``tuple(...)``.
    List- and tuple-typed positions then hash and compare the same way.
    """
    task = get_task(task_id)
    obstacle_set = {tuple(cell) for cell in task.obstacles}
    bin_positions = {tuple(b.position) for b in task.bins}
    fixed_positions = {
        tuple(task.pack_station_position),
        tuple(task.charger_position),
        tuple(task.dock_position),
        tuple(task.agent_start),
    }
    # rest_position may be unset (falsy) for some tasks; only check it when present.
    if task.rest_position:
        fixed_positions.add(tuple(task.rest_position))
    bin_overlap = obstacle_set & bin_positions
    station_overlap = obstacle_set & fixed_positions
    assert not bin_overlap, f"Obstacle overlaps bin in {task_id}: {bin_overlap}"
    assert not station_overlap, f"Obstacle overlaps station in {task_id}: {station_overlap}"
def test_get_task_raises_on_unknown_id():
    """get_task rejects an unregistered id with a KeyError mentioning "Unknown task_id"."""
    unknown_id = "does_not_exist"
    with pytest.raises(KeyError, match="Unknown task_id"):
        get_task(unknown_id)
def test_grid_size_is_positive_tuple():
    """GRID_SIZE is a pair of strictly positive dimensions."""
    assert len(GRID_SIZE) == 2
    # With exactly two entries guaranteed above, this checks both dimensions.
    assert all(dim > 0 for dim in GRID_SIZE)