# mini-rl-env / grid_env / tasks.py
# Committer: sohambose98
# Commit message: updated the tests and graders
# Commit: eaa79f0
"""
Task definitions for the warehouse fulfillment environment.
"""
from __future__ import annotations
from typing import Dict, List
from .models import BinState, OrderLine, TaskDefinition
# Grid dimensions for the tasks in this module; every position used below lies
# in 0..6 on both axes. NOTE(review): not referenced in this module — presumably
# read by the environment; confirm whether it is (width, height) or (rows, cols).
GRID_SIZE = (7, 7)


def _base_bins() -> List[BinState]:
    """Build the default five-bin shelf layout shared by several tasks.

    A brand-new list of brand-new ``BinState`` objects is returned on every
    call, so tasks that use this layout never share bin instances.
    """
    layout = (
        ("A1", (2, 1), "thermometer"),
        ("A2", (2, 3), "cough_syrup"),
        ("B1", (4, 1), "bandage_kit"),
        ("B2", (4, 3), "pain_relief"),
        ("C1", (2, 5), "gloves"),
    )
    # Every default bin is stocked with exactly two units.
    return [
        BinState(bin_id=bin_id, position=position, sku=sku, quantity=2)
        for bin_id, position, sku in layout
    ]
# ---------------------------------------------------------------------------
# Rubric criterion builders
#
# Every task grades itself with a list of {"name", "description", "check"}
# dicts. The ``check`` strings use a small mini-language
# (``param_at_least:<path>=<value>``, ``param_at_most:<path>=<value>``,
# ``tool_used:<tool>``) that is presumably parsed by the graders. Building the
# strings here keeps that syntax in one place, so a typo cannot silently break
# the grading of a single task. Each builder returns a fresh dict on every call,
# so tasks never share (and can never co-mutate) a criterion object.
# ---------------------------------------------------------------------------
def _criterion(name: str, description: str, check: str) -> Dict[str, str]:
    """Return one rubric entry in the shape every task uses."""
    return {"name": name, "description": description, "check": check}


def _completion(description: str) -> Dict[str, str]:
    """Require the whole order to be packed (completion ratio reaches 1.0)."""
    return _criterion(
        "completion", description, "param_at_least:state.completion_ratio=1.0"
    )


def _scans(count: int, description: str) -> Dict[str, str]:
    """Require at least ``count`` correct bin scans."""
    return _criterion(
        "scans", description, f"param_at_least:state.correct_scans={count}"
    )


def _at_most(
    name: str, description: str, counter: str, limit: float = 0
) -> Dict[str, str]:
    """Require ``state.<counter>`` to end at or below ``limit``.

    ``limit`` is interpolated verbatim, so ``0`` renders as ``=0`` and ``0.0``
    as ``=0.0``; pass the literal form the grader expects.
    """
    return _criterion(name, description, f"param_at_most:state.{counter}={limit}")


def _tool_used(name: str, description: str, tool: str) -> Dict[str, str]:
    """Require the agent to invoke ``tool`` at least once."""
    return _criterion(name, description, f"tool_used:{tool}")


def _profit_at_least(amount: float, description: str) -> Dict[str, str]:
    """Require ``state.money_earned`` to reach ``amount`` (e.g. 15.0 -> '=15.0')."""
    return _criterion(
        "profit_target", description, f"param_at_least:state.money_earned={amount}"
    )


def _invalid_actions_at_most(limit: int, description: str) -> Dict[str, str]:
    """Tolerate up to ``limit`` invalid actions."""
    return _at_most("few_invalid_actions", description, "invalid_actions", limit)


# Criteria whose wording is identical in every task that uses them.
def _no_wrong_picks() -> Dict[str, str]:
    return _at_most("no_wrong_picks", "No incorrect picks.", "wrong_picks")


def _no_battery_depletion() -> Dict[str, str]:
    return _at_most(
        "no_battery_depletion", "Avoided battery depletion.", "battery_depletion_events"
    )


def _no_obstacle_collisions() -> Dict[str, str]:
    return _at_most(
        "no_obstacle_collisions", "Avoided all obstacle collisions.", "obstacle_collisions"
    )


def _no_overweight() -> Dict[str, str]:
    return _at_most(
        "no_overweight", "Never tried to pick an overweight item.", "overweight_attempts"
    )


def _recharged() -> Dict[str, str]:
    return _tool_used("recharge", "Recharged at least once.", "recharge")


def _rested() -> Dict[str, str]:
    return _tool_used("rest_used", "Used the rest area at least once.", "rest")


# Registry of every task, keyed by ``task_id`` (each key matches its
# TaskDefinition's ``task_id``). Looked up through ``get_task``.
TASKS: Dict[str, TaskDefinition] = {
    # -----------------------------------------------------------------------
    # Task 1: easy_single_pick (easy) — one item, default layout
    # -----------------------------------------------------------------------
    "easy_single_pick": TaskDefinition(
        task_id="easy_single_pick",
        difficulty="easy",
        title="Single urgent order",
        description=(
            "Fulfill a same-hour pharmacy order by retrieving one thermometer "
            "from the correct bin and packing it at the station."
        ),
        max_steps=40,
        battery_capacity=36,
        low_battery_threshold=8,
        agent_start=(1, 1),
        agent_heading="E",
        dock_position=(1, 1),
        pack_station_position=(5, 5),
        charger_position=(1, 5),
        bins=_base_bins(),
        order=[OrderLine(sku="thermometer", quantity=1)],
        required_scans=["A1"],
        rubric_criteria=[
            _completion("Order fully packed."),
            _tool_used("pick_item", "Picked at least one item.", "pick_item"),
            _tool_used("pack_item", "Packed the order at the station.", "pack_item"),
            _no_wrong_picks(),
            _at_most("no_invalid_actions", "No invalid actions.", "invalid_actions"),
        ],
    ),
    # -----------------------------------------------------------------------
    # Task 2: medium_multi_item (medium) — two lines, scans required
    # -----------------------------------------------------------------------
    "medium_multi_item": TaskDefinition(
        task_id="medium_multi_item",
        difficulty="medium",
        title="Two-line prescription basket",
        description=(
            "Fulfill a two-line order by scanning the relevant bins before each "
            "pick, then pack one cough syrup and one pain relief unit."
        ),
        max_steps=60,
        battery_capacity=34,
        low_battery_threshold=8,
        agent_start=(1, 1),
        agent_heading="E",
        dock_position=(1, 1),
        pack_station_position=(5, 5),
        charger_position=(1, 5),
        bins=_base_bins(),
        order=[
            OrderLine(sku="cough_syrup", quantity=1),
            OrderLine(sku="pain_relief", quantity=1),
        ],
        required_scans=["A2", "B2"],
        rubric_criteria=[
            _completion("All items packed."),
            _scans(2, "Scanned required bins."),
            _tool_used("pack_item", "Packed items at the station.", "pack_item"),
            _no_wrong_picks(),
            _at_most("no_wrong_scans", "No incorrect scans.", "wrong_scans"),
            _invalid_actions_at_most(1, "At most one invalid action."),
        ],
    ),
    # -----------------------------------------------------------------------
    # Task 3: hard_restock_priority (hard) — three lines, tight battery
    # -----------------------------------------------------------------------
    "hard_restock_priority": TaskDefinition(
        task_id="hard_restock_priority",
        difficulty="hard",
        title="Priority basket with battery management",
        description=(
            "Fulfill a three-line urgent order while managing battery reserve. "
            "The agent must scan each target bin, recharge when needed, and "
            "pack a thermometer, gloves, and bandage kit."
        ),
        max_steps=85,
        battery_capacity=24,
        low_battery_threshold=6,
        agent_start=(1, 1),
        agent_heading="E",
        dock_position=(1, 1),
        pack_station_position=(5, 5),
        charger_position=(1, 5),
        bins=_base_bins(),
        order=[
            OrderLine(sku="thermometer", quantity=1),
            OrderLine(sku="gloves", quantity=1),
            OrderLine(sku="bandage_kit", quantity=1),
        ],
        required_scans=["A1", "C1", "B1"],
        rubric_criteria=[
            _completion("All items packed."),
            _scans(3, "Scanned all required bins."),
            _recharged(),
            _no_battery_depletion(),
            _no_wrong_picks(),
            _invalid_actions_at_most(1, "At most one invalid action."),
        ],
    ),
    # -----------------------------------------------------------------------
    # Task 4: obstacle_course (medium) — obstacles block direct paths
    # -----------------------------------------------------------------------
    "obstacle_course": TaskDefinition(
        task_id="obstacle_course",
        difficulty="medium",
        title="Obstacle-filled aisle navigation",
        description=(
            "Fulfill a two-item order in a warehouse cluttered with fallen crates. "
            "Navigate around obstacles to reach bins, scan them, pick one thermometer "
            "and one bandage kit, then pack both at the station."
        ),
        max_steps=70,
        battery_capacity=40,
        low_battery_threshold=10,
        agent_start=(0, 0),
        agent_heading="E",
        dock_position=(0, 0),
        pack_station_position=(6, 6),
        charger_position=(0, 6),
        # Same bins as _base_bins() but listed in a different order; kept
        # explicit in case bin order is observable to the agent.
        bins=[
            BinState(bin_id="A1", position=(2, 1), sku="thermometer", quantity=2),
            BinState(bin_id="B1", position=(4, 1), sku="bandage_kit", quantity=2),
            BinState(bin_id="A2", position=(2, 3), sku="cough_syrup", quantity=2),
            BinState(bin_id="B2", position=(4, 3), sku="pain_relief", quantity=2),
            BinState(bin_id="C1", position=(2, 5), sku="gloves", quantity=2),
        ],
        order=[
            OrderLine(sku="thermometer", quantity=1),
            OrderLine(sku="bandage_kit", quantity=1),
        ],
        required_scans=["A1", "B1"],
        obstacles=[(1, 2), (2, 2), (3, 2), (3, 4), (4, 4), (5, 4)],
        rubric_criteria=[
            _completion("All items packed."),
            _scans(2, "Scanned both required bins."),
            _tool_used("pack_item", "Packed items at the station.", "pack_item"),
            _no_obstacle_collisions(),
            _no_wrong_picks(),
            _invalid_actions_at_most(2, "At most two invalid actions."),
        ],
    ),
    # -----------------------------------------------------------------------
    # Task 5: heavy_lifting (hard) — items have weight, limited carry capacity
    # -----------------------------------------------------------------------
    "heavy_lifting": TaskDefinition(
        task_id="heavy_lifting",
        difficulty="hard",
        title="Heavy-item logistics with weight limits",
        description=(
            "Fulfill a three-item order where items vary in weight (1-4 units). "
            "The agent has a carry capacity of 3 and must choose pickup order wisely. "
            "Heavier items drain more battery while moving. Scan each bin, pick items "
            "that fit within your carry limit, and pack at the station. "
            "The heavy pain_relief (weight 4) cannot be carried — skip it!"
        ),
        max_steps=90,
        battery_capacity=32,
        low_battery_threshold=8,
        agent_start=(1, 1),
        agent_heading="E",
        dock_position=(1, 1),
        pack_station_position=(5, 5),
        charger_position=(1, 5),
        bins=[
            BinState(bin_id="A1", position=(2, 1), sku="thermometer", quantity=3, weight=1),
            BinState(bin_id="A2", position=(2, 3), sku="cough_syrup", quantity=2, weight=2),
            BinState(bin_id="B1", position=(4, 1), sku="bandage_kit", quantity=2, weight=3),
            BinState(bin_id="B2", position=(4, 3), sku="pain_relief", quantity=2, weight=4),
            BinState(bin_id="C1", position=(2, 5), sku="gloves", quantity=3, weight=1),
        ],
        order=[
            OrderLine(sku="thermometer", quantity=1),
            OrderLine(sku="cough_syrup", quantity=1),
            OrderLine(sku="bandage_kit", quantity=1),
        ],
        required_scans=["A1", "A2", "B1"],
        carry_capacity=3,
        rubric_criteria=[
            _completion("All items packed."),
            _scans(3, "Scanned all three required bins."),
            _recharged(),
            _no_overweight(),
            _no_battery_depletion(),
            _no_wrong_picks(),
            _invalid_actions_at_most(2, "At most two invalid actions."),
        ],
    ),
    # -----------------------------------------------------------------------
    # Task 6: stamina_run (hard) — stamina drains on movement
    # -----------------------------------------------------------------------
    "stamina_run": TaskDefinition(
        task_id="stamina_run",
        difficulty="hard",
        title="Endurance run with stamina management",
        description=(
            "Fulfill a two-item order while managing stamina. Every move drains stamina; "
            "when stamina hits zero, movement costs double battery. Rest at the rest area "
            "to restore stamina. Pick one cough syrup and one gloves unit, scan bins, and "
            "pack at the station without running out of energy."
        ),
        max_steps=80,
        battery_capacity=36,
        low_battery_threshold=8,
        agent_start=(0, 0),
        agent_heading="E",
        dock_position=(0, 0),
        pack_station_position=(6, 6),
        charger_position=(0, 6),
        rest_position=(3, 3),
        stamina_capacity=12,
        stamina_move_cost=1,
        # Identical (content and order) to the default layout.
        bins=_base_bins(),
        order=[
            OrderLine(sku="cough_syrup", quantity=1),
            OrderLine(sku="gloves", quantity=1),
        ],
        required_scans=["A2", "C1"],
        rubric_criteria=[
            _completion("All items packed."),
            _scans(2, "Scanned both required bins."),
            _rested(),
            _at_most(
                "no_stamina_depletion",
                "Avoided complete stamina depletion.",
                "stamina_depletion_events",
            ),
            _no_battery_depletion(),
            _no_wrong_picks(),
        ],
    ),
    # -----------------------------------------------------------------------
    # Task 7: budget_run (expert) — money rewards and profit target
    # -----------------------------------------------------------------------
    "budget_run": TaskDefinition(
        task_id="budget_run",
        difficulty="expert",
        title="Profitable fulfillment under budget pressure",
        description=(
            "Fulfill orders for profit. Each item has a dollar value earned when correctly "
            "packed. Wrong packs lose half the item value. You must reach a profit target "
            "of $15.00 while completing the order. Pick high-value items efficiently: "
            "2 thermometers ($5 each) and 1 bandage kit ($8). Budget-aware decisions matter."
        ),
        max_steps=70,
        battery_capacity=30,
        low_battery_threshold=6,
        agent_start=(1, 1),
        agent_heading="E",
        dock_position=(1, 1),
        pack_station_position=(5, 5),
        charger_position=(1, 5),
        bins=[
            BinState(bin_id="A1", position=(2, 1), sku="thermometer", quantity=3, value=5.0),
            BinState(bin_id="A2", position=(2, 3), sku="cough_syrup", quantity=2, value=3.0),
            BinState(bin_id="B1", position=(4, 1), sku="bandage_kit", quantity=2, value=8.0),
            BinState(bin_id="B2", position=(4, 3), sku="pain_relief", quantity=2, value=4.0),
            BinState(bin_id="C1", position=(2, 5), sku="gloves", quantity=2, value=2.0),
        ],
        order=[
            OrderLine(sku="thermometer", quantity=2),
            OrderLine(sku="bandage_kit", quantity=1),
        ],
        required_scans=["A1", "B1"],
        # Order value is 2 * $5 + $8 = $18, so the $15 target is reachable.
        profit_target=15.0,
        rubric_criteria=[
            _completion("All items packed."),
            _profit_at_least(15.0, "Reached the profit target of $15."),
            _scans(2, "Scanned required bins."),
            _recharged(),
            # Float limit on purpose: renders as "money_lost=0.0".
            _at_most("no_money_lost", "No money lost from wrong packs.", "money_lost", 0.0),
            _no_wrong_picks(),
            _invalid_actions_at_most(1, "At most one invalid action."),
        ],
    ),
    # -----------------------------------------------------------------------
    # Task 8: gauntlet (expert) — all mechanics combined
    # -----------------------------------------------------------------------
    "gauntlet": TaskDefinition(
        task_id="gauntlet",
        difficulty="expert",
        title="The gauntlet: obstacles, weight, stamina, and profit",
        description=(
            "The ultimate warehouse challenge. Navigate a cluttered warehouse with obstacles, "
            "manage item weights (carry capacity 3), conserve stamina (rest when needed), "
            "earn money for packed items, and hit a $20 profit target. Fulfill a four-item "
            "order: 1 thermometer ($5, wt 1), 1 cough syrup ($6, wt 2), 1 bandage kit ($8, wt 3), "
            "and 1 gloves ($4, wt 1). Recharge battery, rest for stamina, avoid obstacles, "
            "and finish profitable."
        ),
        max_steps=120,
        battery_capacity=28,
        low_battery_threshold=7,
        agent_start=(0, 0),
        agent_heading="S",
        dock_position=(0, 0),
        pack_station_position=(6, 6),
        charger_position=(6, 0),
        rest_position=(0, 6),
        stamina_capacity=10,
        stamina_move_cost=1,
        carry_capacity=3,
        # Order value is $5 + $6 + $8 + $4 = $23, so the $20 target is reachable.
        profit_target=20.0,
        obstacles=[(1, 1), (3, 1), (5, 3), (3, 3), (1, 5), (5, 5)],
        bins=[
            BinState(bin_id="A1", position=(2, 0), sku="thermometer", quantity=3, weight=1, value=5.0),
            BinState(bin_id="A2", position=(2, 4), sku="cough_syrup", quantity=2, weight=2, value=6.0),
            BinState(bin_id="B1", position=(4, 0), sku="bandage_kit", quantity=2, weight=3, value=8.0),
            BinState(bin_id="B2", position=(4, 4), sku="pain_relief", quantity=2, weight=4, value=4.0),
            BinState(bin_id="C1", position=(4, 2), sku="gloves", quantity=3, weight=1, value=4.0),
        ],
        order=[
            OrderLine(sku="thermometer", quantity=1),
            OrderLine(sku="cough_syrup", quantity=1),
            OrderLine(sku="bandage_kit", quantity=1),
            OrderLine(sku="gloves", quantity=1),
        ],
        required_scans=["A1", "A2", "B1", "C1"],
        rubric_criteria=[
            _completion("All four items packed."),
            _profit_at_least(20.0, "Reached $20 profit target."),
            _scans(4, "Scanned all four required bins."),
            _recharged(),
            _rested(),
            _no_obstacle_collisions(),
            _no_overweight(),
            _no_battery_depletion(),
            _at_most(
                "no_stamina_depletion",
                "Avoided stamina depletion.",
                "stamina_depletion_events",
            ),
            _no_wrong_picks(),
            _invalid_actions_at_most(2, "At most two invalid actions."),
        ],
    ),
}
def get_task(task_id: str) -> TaskDefinition:
    """Return the task definition registered under ``task_id``.

    Args:
        task_id: A key of ``TASKS`` (e.g. ``"easy_single_pick"``).

    Returns:
        The ``TaskDefinition`` stored in ``TASKS``. This is the shared
        module-level object itself, not a copy. NOTE(review): if a caller
        mutates it (e.g. bin quantities), later lookups see the change —
        confirm the environment copies it before use.

    Raises:
        KeyError: If ``task_id`` is not registered. The message keeps the
            ``Unknown task_id:`` prefix and lists the valid identifiers.
    """
    try:
        return TASKS[task_id]
    except KeyError:
        available = ", ".join(sorted(TASKS))
        # ``from None`` hides the bare internal KeyError so the traceback shows
        # only the descriptive one.
        raise KeyError(
            f"Unknown task_id: {task_id} (available: {available})"
        ) from None