Spaces:
Sleeping
Sleeping
Commit ·
eaa79f0
1
Parent(s): 3ac18bb
updated the tests and graders
Browse files- grid_env/Server/__pycache__/__init__.cpython-313.pyc +0 -0
- grid_env/Server/__pycache__/app.cpython-313.pyc +0 -0
- grid_env/Server/__pycache__/warehouse_env.cpython-313.pyc +0 -0
- grid_env/__init__.py +16 -1
- grid_env/__pycache__/__init__.cpython-313.pyc +0 -0
- grid_env/__pycache__/baseline.cpython-313.pyc +0 -0
- grid_env/__pycache__/client.cpython-313.pyc +0 -0
- grid_env/__pycache__/env.cpython-313.pyc +0 -0
- grid_env/__pycache__/graders.cpython-313.pyc +0 -0
- grid_env/__pycache__/models.cpython-313.pyc +0 -0
- grid_env/__pycache__/tasks.cpython-313.pyc +0 -0
- grid_env/baseline.py +9 -1
- grid_env/env.py +109 -6
- grid_env/graders.py +42 -7
- grid_env/models.py +25 -1
- grid_env/openv.yaml +15 -0
- grid_env/tasks.py +387 -0
- grid_env/tools.py +1 -0
- openenv.yaml +15 -0
- tests/conftest.py +35 -0
- tests/test_baseline_stub.py +11 -1
- tests/test_env_smoke.py +14 -4
- tests/test_graders.py +145 -9
- tests/test_server.py +12 -2
- tests/test_tasks.py +32 -5
grid_env/Server/__pycache__/__init__.cpython-313.pyc
CHANGED
|
Binary files a/grid_env/Server/__pycache__/__init__.cpython-313.pyc and b/grid_env/Server/__pycache__/__init__.cpython-313.pyc differ
|
|
|
grid_env/Server/__pycache__/app.cpython-313.pyc
CHANGED
|
Binary files a/grid_env/Server/__pycache__/app.cpython-313.pyc and b/grid_env/Server/__pycache__/app.cpython-313.pyc differ
|
|
|
grid_env/Server/__pycache__/warehouse_env.cpython-313.pyc
CHANGED
|
Binary files a/grid_env/Server/__pycache__/warehouse_env.cpython-313.pyc and b/grid_env/Server/__pycache__/warehouse_env.cpython-313.pyc differ
|
|
|
grid_env/__init__.py
CHANGED
|
@@ -1,6 +1,16 @@
|
|
| 1 |
from .client import WarehouseEnvClient
|
| 2 |
from .env import WarehouseFulfillmentEnv, available_tasks
|
| 3 |
-
from .graders import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
from .models import (
|
| 5 |
BaselineCommand,
|
| 6 |
BinState,
|
|
@@ -30,8 +40,13 @@ __all__ = [
|
|
| 30 |
"WarehouseReward",
|
| 31 |
"WarehouseState",
|
| 32 |
"available_tasks",
|
|
|
|
| 33 |
"grade_easy",
|
| 34 |
"grade_episode",
|
|
|
|
| 35 |
"grade_hard",
|
|
|
|
| 36 |
"grade_medium",
|
|
|
|
|
|
|
| 37 |
]
|
|
|
|
| 1 |
from .client import WarehouseEnvClient
|
| 2 |
from .env import WarehouseFulfillmentEnv, available_tasks
|
| 3 |
+
from .graders import (
|
| 4 |
+
grade_budget_run,
|
| 5 |
+
grade_easy,
|
| 6 |
+
grade_episode,
|
| 7 |
+
grade_gauntlet,
|
| 8 |
+
grade_hard,
|
| 9 |
+
grade_heavy_lifting,
|
| 10 |
+
grade_medium,
|
| 11 |
+
grade_obstacle_course,
|
| 12 |
+
grade_stamina_run,
|
| 13 |
+
)
|
| 14 |
from .models import (
|
| 15 |
BaselineCommand,
|
| 16 |
BinState,
|
|
|
|
| 40 |
"WarehouseReward",
|
| 41 |
"WarehouseState",
|
| 42 |
"available_tasks",
|
| 43 |
+
"grade_budget_run",
|
| 44 |
"grade_easy",
|
| 45 |
"grade_episode",
|
| 46 |
+
"grade_gauntlet",
|
| 47 |
"grade_hard",
|
| 48 |
+
"grade_heavy_lifting",
|
| 49 |
"grade_medium",
|
| 50 |
+
"grade_obstacle_course",
|
| 51 |
+
"grade_stamina_run",
|
| 52 |
]
|
grid_env/__pycache__/__init__.cpython-313.pyc
CHANGED
|
Binary files a/grid_env/__pycache__/__init__.cpython-313.pyc and b/grid_env/__pycache__/__init__.cpython-313.pyc differ
|
|
|
grid_env/__pycache__/baseline.cpython-313.pyc
CHANGED
|
Binary files a/grid_env/__pycache__/baseline.cpython-313.pyc and b/grid_env/__pycache__/baseline.cpython-313.pyc differ
|
|
|
grid_env/__pycache__/client.cpython-313.pyc
CHANGED
|
Binary files a/grid_env/__pycache__/client.cpython-313.pyc and b/grid_env/__pycache__/client.cpython-313.pyc differ
|
|
|
grid_env/__pycache__/env.cpython-313.pyc
CHANGED
|
Binary files a/grid_env/__pycache__/env.cpython-313.pyc and b/grid_env/__pycache__/env.cpython-313.pyc differ
|
|
|
grid_env/__pycache__/graders.cpython-313.pyc
CHANGED
|
Binary files a/grid_env/__pycache__/graders.cpython-313.pyc and b/grid_env/__pycache__/graders.cpython-313.pyc differ
|
|
|
grid_env/__pycache__/models.cpython-313.pyc
CHANGED
|
Binary files a/grid_env/__pycache__/models.cpython-313.pyc and b/grid_env/__pycache__/models.cpython-313.pyc differ
|
|
|
grid_env/__pycache__/tasks.cpython-313.pyc
CHANGED
|
Binary files a/grid_env/__pycache__/tasks.cpython-313.pyc and b/grid_env/__pycache__/tasks.cpython-313.pyc differ
|
|
|
grid_env/baseline.py
CHANGED
|
@@ -21,7 +21,7 @@ except ImportError:
|
|
| 21 |
|
| 22 |
SYSTEM_PROMPT = """You control a warehouse fulfillment robot.
|
| 23 |
Return exactly one JSON object with:
|
| 24 |
-
- command: one of turn_left, turn_right, move_forward, scan_bin, pick_item, pack_item, recharge, wait
|
| 25 |
- rationale: a short sentence
|
| 26 |
|
| 27 |
Objective:
|
|
@@ -29,6 +29,14 @@ Objective:
|
|
| 29 |
- Use scans before picks when the task requires verified bins.
|
| 30 |
- Recharge before battery depletion if needed.
|
| 31 |
- Avoid invalid actions and unnecessary wandering.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
"""
|
| 33 |
|
| 34 |
|
|
|
|
| 21 |
|
| 22 |
SYSTEM_PROMPT = """You control a warehouse fulfillment robot.
|
| 23 |
Return exactly one JSON object with:
|
| 24 |
+
- command: one of turn_left, turn_right, move_forward, scan_bin, pick_item, pack_item, recharge, rest, wait
|
| 25 |
- rationale: a short sentence
|
| 26 |
|
| 27 |
Objective:
|
|
|
|
| 29 |
- Use scans before picks when the task requires verified bins.
|
| 30 |
- Recharge before battery depletion if needed.
|
| 31 |
- Avoid invalid actions and unnecessary wandering.
|
| 32 |
+
|
| 33 |
+
Advanced mechanics (active on harder tasks):
|
| 34 |
+
- Obstacles: some cells are impassable. If front_cell says "obstacle", turn to find another route.
|
| 35 |
+
- Item weight: items have weight. If an item exceeds your carry capacity, you cannot pick it.
|
| 36 |
+
Heavier items drain more battery while moving.
|
| 37 |
+
- Stamina: movement costs stamina. When stamina hits 0, movement costs double battery.
|
| 38 |
+
Use the "rest" action at the rest area to restore stamina.
|
| 39 |
+
- Money: packing correct items earns money; wrong packs lose money. Hit the profit target if set.
|
| 40 |
"""
|
| 41 |
|
| 42 |
|
grid_env/env.py
CHANGED
|
@@ -40,6 +40,7 @@ class WarehouseFulfillmentEnv:
|
|
| 40 |
"pick_item",
|
| 41 |
"pack_item",
|
| 42 |
"recharge",
|
|
|
|
| 43 |
"wait",
|
| 44 |
]
|
| 45 |
|
|
@@ -100,6 +101,8 @@ class WarehouseFulfillmentEnv:
|
|
| 100 |
reward_value, narrative = self._pack_item(reward_value)
|
| 101 |
elif command == "recharge":
|
| 102 |
reward_value, narrative = self._recharge(reward_value)
|
|
|
|
|
|
|
| 103 |
elif command == "wait":
|
| 104 |
reward_value -= 0.01
|
| 105 |
self.metrics.invalid_actions += 1
|
|
@@ -110,8 +113,8 @@ class WarehouseFulfillmentEnv:
|
|
| 110 |
narrative = f"Unknown action: {command}."
|
| 111 |
|
| 112 |
self.action_history.append(command)
|
| 113 |
-
self.done = self.
|
| 114 |
-
self.success = self.
|
| 115 |
if self.success:
|
| 116 |
reward_value += 0.50
|
| 117 |
narrative = "Order fully packed and ready for dispatch."
|
|
@@ -143,11 +146,17 @@ class WarehouseFulfillmentEnv:
|
|
| 143 |
agent_position=self.agent_position,
|
| 144 |
heading=self.heading,
|
| 145 |
carrying=self.carrying,
|
|
|
|
| 146 |
battery_level=self.battery_level,
|
| 147 |
battery_capacity=self.task.battery_capacity,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
dock_position=self.task.dock_position,
|
| 149 |
pack_station_position=self.task.pack_station_position,
|
| 150 |
charger_position=self.task.charger_position,
|
|
|
|
| 151 |
bins=[self._clone_bin(bin_state) for bin_state in self.bins],
|
| 152 |
order=[self._clone_order_line(line) for line in self.order],
|
| 153 |
packed_order=[self._clone_order_line(line) for line in self.packed_order],
|
|
@@ -166,6 +175,9 @@ class WarehouseFulfillmentEnv:
|
|
| 166 |
self.heading = self.task.agent_heading
|
| 167 |
self.battery_level = self.task.battery_capacity
|
| 168 |
self.carrying: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
| 169 |
self.step_count = 0
|
| 170 |
self.done = False
|
| 171 |
self.success = False
|
|
@@ -206,6 +218,8 @@ class WarehouseFulfillmentEnv:
|
|
| 206 |
front = self._front_position()
|
| 207 |
if not self._in_bounds(front):
|
| 208 |
return "wall"
|
|
|
|
|
|
|
| 209 |
front_bin = self._front_bin()
|
| 210 |
if front_bin:
|
| 211 |
return f"bin {front_bin.bin_id} ({front_bin.sku})"
|
|
@@ -215,18 +229,35 @@ class WarehouseFulfillmentEnv:
|
|
| 215 |
return "charger"
|
| 216 |
if front == self.task.dock_position:
|
| 217 |
return "dock"
|
|
|
|
|
|
|
| 218 |
return "aisle"
|
| 219 |
|
| 220 |
def _move_forward(self, reward: float) -> Tuple[float, str]:
|
| 221 |
next_pos = self._front_position()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
if not self._in_bounds(next_pos) or self._occupied(next_pos):
|
| 223 |
self.metrics.invalid_actions += 1
|
| 224 |
self._consume_battery(1)
|
| 225 |
return reward - 0.08, "Forward move blocked by warehouse infrastructure."
|
| 226 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
self.agent_position = next_pos
|
| 228 |
self.metrics.distance_travelled += 1
|
| 229 |
-
self._consume_battery(
|
|
|
|
| 230 |
return reward, f"Moved to aisle cell {self.agent_position}."
|
| 231 |
|
| 232 |
def _scan_bin(self, reward: float) -> Tuple[float, str]:
|
|
@@ -260,13 +291,22 @@ class WarehouseFulfillmentEnv:
|
|
| 260 |
self.metrics.invalid_actions += 1
|
| 261 |
self._consume_battery(1)
|
| 262 |
return reward - 0.10, f"Bin {bin_state.bin_id} is empty."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
|
| 264 |
self._consume_battery(1)
|
| 265 |
bin_state.quantity -= 1
|
| 266 |
self.carrying = bin_state.sku
|
|
|
|
| 267 |
if self._remaining_quantity(bin_state.sku) > 0:
|
| 268 |
self.metrics.correct_picks += 1
|
| 269 |
-
return reward + 0.20, f"Picked {bin_state.sku} from {bin_state.bin_id}."
|
| 270 |
|
| 271 |
self.metrics.wrong_picks += 1
|
| 272 |
return reward - 0.18, f"Picked {bin_state.sku}, which is not needed now."
|
|
@@ -285,18 +325,28 @@ class WarehouseFulfillmentEnv:
|
|
| 285 |
remaining = self._remaining_quantity(self.carrying)
|
| 286 |
if remaining <= 0:
|
| 287 |
item = self.carrying
|
|
|
|
| 288 |
self.carrying = None
|
|
|
|
| 289 |
self.metrics.wrong_picks += 1
|
|
|
|
|
|
|
|
|
|
| 290 |
return reward - 0.15, f"Packed extra unit of {item}; order did not require it."
|
| 291 |
|
|
|
|
| 292 |
for packed_line in self.packed_order:
|
| 293 |
if packed_line.sku == self.carrying:
|
| 294 |
packed_line.quantity += 1
|
| 295 |
break
|
| 296 |
item = self.carrying
|
| 297 |
self.carrying = None
|
|
|
|
| 298 |
self.metrics.correct_packs += 1
|
| 299 |
-
|
|
|
|
|
|
|
|
|
|
| 300 |
|
| 301 |
def _recharge(self, reward: float) -> Tuple[float, str]:
|
| 302 |
if self._front_position() != self.task.charger_position:
|
|
@@ -312,6 +362,24 @@ class WarehouseFulfillmentEnv:
|
|
| 312 |
self.metrics.recharges += 1
|
| 313 |
return reward + benefit, "Battery restored to full capacity."
|
| 314 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
def _build_observation(self, narrative: str) -> WarehouseObservation:
|
| 316 |
nearby_bins = []
|
| 317 |
for bin_state in self.bins:
|
|
@@ -334,7 +402,10 @@ class WarehouseFulfillmentEnv:
|
|
| 334 |
heading=self.heading,
|
| 335 |
front_cell=self._front_cell_label(),
|
| 336 |
carrying=self.carrying,
|
|
|
|
| 337 |
battery_level=self.battery_level,
|
|
|
|
|
|
|
| 338 |
visible_bins=nearby_bins,
|
| 339 |
pending_order=pending,
|
| 340 |
packed_order=packed,
|
|
@@ -363,11 +434,43 @@ class WarehouseFulfillmentEnv:
|
|
| 363 |
if previous > 0 and self.battery_level == 0:
|
| 364 |
self.metrics.battery_depletion_events += 1
|
| 365 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 366 |
def _in_bounds(self, position: Tuple[int, int]) -> bool:
|
| 367 |
return 0 <= position[0] < self.grid_size[0] and 0 <= position[1] < self.grid_size[1]
|
| 368 |
|
| 369 |
def _occupied(self, position: Tuple[int, int]) -> bool:
|
| 370 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 371 |
return True
|
| 372 |
return any(bin_state.position == position for bin_state in self.bins)
|
| 373 |
|
|
|
|
| 40 |
"pick_item",
|
| 41 |
"pack_item",
|
| 42 |
"recharge",
|
| 43 |
+
"rest",
|
| 44 |
"wait",
|
| 45 |
]
|
| 46 |
|
|
|
|
| 101 |
reward_value, narrative = self._pack_item(reward_value)
|
| 102 |
elif command == "recharge":
|
| 103 |
reward_value, narrative = self._recharge(reward_value)
|
| 104 |
+
elif command == "rest":
|
| 105 |
+
reward_value, narrative = self._rest(reward_value)
|
| 106 |
elif command == "wait":
|
| 107 |
reward_value -= 0.01
|
| 108 |
self.metrics.invalid_actions += 1
|
|
|
|
| 113 |
narrative = f"Unknown action: {command}."
|
| 114 |
|
| 115 |
self.action_history.append(command)
|
| 116 |
+
self.done = self._is_episode_complete() or self.step_count >= self.task.max_steps
|
| 117 |
+
self.success = self._is_episode_complete()
|
| 118 |
if self.success:
|
| 119 |
reward_value += 0.50
|
| 120 |
narrative = "Order fully packed and ready for dispatch."
|
|
|
|
| 146 |
agent_position=self.agent_position,
|
| 147 |
heading=self.heading,
|
| 148 |
carrying=self.carrying,
|
| 149 |
+
carrying_weight=self.carrying_weight,
|
| 150 |
battery_level=self.battery_level,
|
| 151 |
battery_capacity=self.task.battery_capacity,
|
| 152 |
+
stamina_level=self.stamina_level,
|
| 153 |
+
stamina_capacity=self.task.stamina_capacity,
|
| 154 |
+
money=round(self.money, 2),
|
| 155 |
+
profit_target=self.task.profit_target,
|
| 156 |
dock_position=self.task.dock_position,
|
| 157 |
pack_station_position=self.task.pack_station_position,
|
| 158 |
charger_position=self.task.charger_position,
|
| 159 |
+
obstacles=list(self.task.obstacles),
|
| 160 |
bins=[self._clone_bin(bin_state) for bin_state in self.bins],
|
| 161 |
order=[self._clone_order_line(line) for line in self.order],
|
| 162 |
packed_order=[self._clone_order_line(line) for line in self.packed_order],
|
|
|
|
| 175 |
self.heading = self.task.agent_heading
|
| 176 |
self.battery_level = self.task.battery_capacity
|
| 177 |
self.carrying: Optional[str] = None
|
| 178 |
+
self.carrying_weight: int = 0
|
| 179 |
+
self.stamina_level: int = self.task.stamina_capacity
|
| 180 |
+
self.money: float = 0.0
|
| 181 |
self.step_count = 0
|
| 182 |
self.done = False
|
| 183 |
self.success = False
|
|
|
|
| 218 |
front = self._front_position()
|
| 219 |
if not self._in_bounds(front):
|
| 220 |
return "wall"
|
| 221 |
+
if self._is_obstacle(front):
|
| 222 |
+
return "obstacle"
|
| 223 |
front_bin = self._front_bin()
|
| 224 |
if front_bin:
|
| 225 |
return f"bin {front_bin.bin_id} ({front_bin.sku})"
|
|
|
|
| 229 |
return "charger"
|
| 230 |
if front == self.task.dock_position:
|
| 231 |
return "dock"
|
| 232 |
+
if self.task.rest_position and front == self.task.rest_position:
|
| 233 |
+
return "rest area"
|
| 234 |
return "aisle"
|
| 235 |
|
| 236 |
def _move_forward(self, reward: float) -> Tuple[float, str]:
|
| 237 |
next_pos = self._front_position()
|
| 238 |
+
|
| 239 |
+
if self._is_obstacle(next_pos):
|
| 240 |
+
self.metrics.obstacle_collisions += 1
|
| 241 |
+
self.metrics.invalid_actions += 1
|
| 242 |
+
self._consume_battery(1)
|
| 243 |
+
return reward - 0.12, "Blocked by an obstacle! Find another route."
|
| 244 |
+
|
| 245 |
if not self._in_bounds(next_pos) or self._occupied(next_pos):
|
| 246 |
self.metrics.invalid_actions += 1
|
| 247 |
self._consume_battery(1)
|
| 248 |
return reward - 0.08, "Forward move blocked by warehouse infrastructure."
|
| 249 |
|
| 250 |
+
battery_cost = 2
|
| 251 |
+
weight_penalty = self.carrying_weight if self.carrying_weight > 1 else 0
|
| 252 |
+
battery_cost += weight_penalty
|
| 253 |
+
|
| 254 |
+
if self._has_stamina() and self.stamina_level <= 0:
|
| 255 |
+
battery_cost *= 2
|
| 256 |
+
|
| 257 |
self.agent_position = next_pos
|
| 258 |
self.metrics.distance_travelled += 1
|
| 259 |
+
self._consume_battery(battery_cost)
|
| 260 |
+
self._consume_stamina(self.task.stamina_move_cost)
|
| 261 |
return reward, f"Moved to aisle cell {self.agent_position}."
|
| 262 |
|
| 263 |
def _scan_bin(self, reward: float) -> Tuple[float, str]:
|
|
|
|
| 291 |
self.metrics.invalid_actions += 1
|
| 292 |
self._consume_battery(1)
|
| 293 |
return reward - 0.10, f"Bin {bin_state.bin_id} is empty."
|
| 294 |
+
if bin_state.weight > self.task.carry_capacity:
|
| 295 |
+
self.metrics.overweight_attempts += 1
|
| 296 |
+
self.metrics.invalid_actions += 1
|
| 297 |
+
self._consume_battery(1)
|
| 298 |
+
return reward - 0.12, (
|
| 299 |
+
f"Item {bin_state.sku} weighs {bin_state.weight} but carry capacity "
|
| 300 |
+
f"is {self.task.carry_capacity}. Too heavy!"
|
| 301 |
+
)
|
| 302 |
|
| 303 |
self._consume_battery(1)
|
| 304 |
bin_state.quantity -= 1
|
| 305 |
self.carrying = bin_state.sku
|
| 306 |
+
self.carrying_weight = bin_state.weight
|
| 307 |
if self._remaining_quantity(bin_state.sku) > 0:
|
| 308 |
self.metrics.correct_picks += 1
|
| 309 |
+
return reward + 0.20, f"Picked {bin_state.sku} (weight {bin_state.weight}) from {bin_state.bin_id}."
|
| 310 |
|
| 311 |
self.metrics.wrong_picks += 1
|
| 312 |
return reward - 0.18, f"Picked {bin_state.sku}, which is not needed now."
|
|
|
|
| 325 |
remaining = self._remaining_quantity(self.carrying)
|
| 326 |
if remaining <= 0:
|
| 327 |
item = self.carrying
|
| 328 |
+
item_value = self._item_value(item)
|
| 329 |
self.carrying = None
|
| 330 |
+
self.carrying_weight = 0
|
| 331 |
self.metrics.wrong_picks += 1
|
| 332 |
+
if item_value > 0:
|
| 333 |
+
self.money -= item_value * 0.5
|
| 334 |
+
self.metrics.money_lost += item_value * 0.5
|
| 335 |
return reward - 0.15, f"Packed extra unit of {item}; order did not require it."
|
| 336 |
|
| 337 |
+
item_value = self._item_value(self.carrying)
|
| 338 |
for packed_line in self.packed_order:
|
| 339 |
if packed_line.sku == self.carrying:
|
| 340 |
packed_line.quantity += 1
|
| 341 |
break
|
| 342 |
item = self.carrying
|
| 343 |
self.carrying = None
|
| 344 |
+
self.carrying_weight = 0
|
| 345 |
self.metrics.correct_packs += 1
|
| 346 |
+
if item_value > 0:
|
| 347 |
+
self.money += item_value
|
| 348 |
+
self.metrics.money_earned += item_value
|
| 349 |
+
return reward + 0.35, f"Packed {item} at the station. (+${item_value:.2f})"
|
| 350 |
|
| 351 |
def _recharge(self, reward: float) -> Tuple[float, str]:
|
| 352 |
if self._front_position() != self.task.charger_position:
|
|
|
|
| 362 |
self.metrics.recharges += 1
|
| 363 |
return reward + benefit, "Battery restored to full capacity."
|
| 364 |
|
| 365 |
+
def _rest(self, reward: float) -> Tuple[float, str]:
|
| 366 |
+
if not self._has_stamina():
|
| 367 |
+
self.metrics.invalid_actions += 1
|
| 368 |
+
return reward - 0.03, "This task has no stamina mechanic."
|
| 369 |
+
|
| 370 |
+
if self.task.rest_position and self._front_position() != self.task.rest_position:
|
| 371 |
+
self.metrics.invalid_actions += 1
|
| 372 |
+
self._consume_battery(1)
|
| 373 |
+
return reward - 0.08, "Rest action requires facing the rest area."
|
| 374 |
+
|
| 375 |
+
if self.stamina_level >= self.task.stamina_capacity:
|
| 376 |
+
return reward - 0.03, "Stamina already full."
|
| 377 |
+
|
| 378 |
+
benefit = 0.06 if self.stamina_level <= self.task.stamina_capacity // 4 else -0.02
|
| 379 |
+
self.stamina_level = self.task.stamina_capacity
|
| 380 |
+
self.metrics.rest_events += 1
|
| 381 |
+
return reward + benefit, "Stamina restored to full capacity."
|
| 382 |
+
|
| 383 |
def _build_observation(self, narrative: str) -> WarehouseObservation:
|
| 384 |
nearby_bins = []
|
| 385 |
for bin_state in self.bins:
|
|
|
|
| 402 |
heading=self.heading,
|
| 403 |
front_cell=self._front_cell_label(),
|
| 404 |
carrying=self.carrying,
|
| 405 |
+
carrying_weight=self.carrying_weight,
|
| 406 |
battery_level=self.battery_level,
|
| 407 |
+
stamina_level=self.stamina_level,
|
| 408 |
+
money=round(self.money, 2),
|
| 409 |
visible_bins=nearby_bins,
|
| 410 |
pending_order=pending,
|
| 411 |
packed_order=packed,
|
|
|
|
| 434 |
if previous > 0 and self.battery_level == 0:
|
| 435 |
self.metrics.battery_depletion_events += 1
|
| 436 |
|
| 437 |
+
def _consume_stamina(self, amount: int) -> None:
|
| 438 |
+
if not self._has_stamina():
|
| 439 |
+
return
|
| 440 |
+
previous = self.stamina_level
|
| 441 |
+
self.stamina_level = max(0, self.stamina_level - amount)
|
| 442 |
+
if previous > 0 and self.stamina_level == 0:
|
| 443 |
+
self.metrics.stamina_depletion_events += 1
|
| 444 |
+
|
| 445 |
+
def _has_stamina(self) -> bool:
|
| 446 |
+
return self.task.stamina_capacity > 0
|
| 447 |
+
|
| 448 |
+
def _is_obstacle(self, position: Tuple[int, int]) -> bool:
|
| 449 |
+
return tuple(position) in {tuple(o) for o in self.task.obstacles}
|
| 450 |
+
|
| 451 |
+
def _item_value(self, sku: str) -> float:
|
| 452 |
+
for bin_state in self.task.bins:
|
| 453 |
+
if bin_state.sku == sku:
|
| 454 |
+
return bin_state.value
|
| 455 |
+
return 0.0
|
| 456 |
+
|
| 457 |
+
def _is_episode_complete(self) -> bool:
|
| 458 |
+
if not self._all_order_lines_complete():
|
| 459 |
+
return False
|
| 460 |
+
if self.task.profit_target > 0 and self.money < self.task.profit_target:
|
| 461 |
+
return False
|
| 462 |
+
return True
|
| 463 |
+
|
| 464 |
def _in_bounds(self, position: Tuple[int, int]) -> bool:
|
| 465 |
return 0 <= position[0] < self.grid_size[0] and 0 <= position[1] < self.grid_size[1]
|
| 466 |
|
| 467 |
def _occupied(self, position: Tuple[int, int]) -> bool:
|
| 468 |
+
fixed = {self.task.pack_station_position, self.task.charger_position, self.task.dock_position}
|
| 469 |
+
if self.task.rest_position:
|
| 470 |
+
fixed.add(self.task.rest_position)
|
| 471 |
+
if position in fixed:
|
| 472 |
+
return True
|
| 473 |
+
if self._is_obstacle(position):
|
| 474 |
return True
|
| 475 |
return any(bin_state.position == position for bin_state in self.bins)
|
| 476 |
|
grid_env/graders.py
CHANGED
|
@@ -224,6 +224,12 @@ def _build_action_log(state: WarehouseState) -> List[Dict[str, Any]]:
|
|
| 224 |
"recharges": metrics.recharges,
|
| 225 |
"battery_depletion_events": metrics.battery_depletion_events,
|
| 226 |
"distance_travelled": metrics.distance_travelled,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
},
|
| 228 |
"result": "",
|
| 229 |
}
|
|
@@ -255,11 +261,40 @@ def grade_hard(state: WarehouseState) -> float:
|
|
| 255 |
return _grade_task("hard_restock_priority", state)
|
| 256 |
|
| 257 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
def grade_episode(state: WarehouseState) -> float:
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
if state.task_id == "hard_restock_priority":
|
| 264 |
-
return grade_hard(state)
|
| 265 |
-
raise KeyError(f"No grader for task_id: {state.task_id}")
|
|
|
|
| 224 |
"recharges": metrics.recharges,
|
| 225 |
"battery_depletion_events": metrics.battery_depletion_events,
|
| 226 |
"distance_travelled": metrics.distance_travelled,
|
| 227 |
+
"stamina_depletion_events": metrics.stamina_depletion_events,
|
| 228 |
+
"rest_events": metrics.rest_events,
|
| 229 |
+
"obstacle_collisions": metrics.obstacle_collisions,
|
| 230 |
+
"money_earned": metrics.money_earned,
|
| 231 |
+
"money_lost": metrics.money_lost,
|
| 232 |
+
"overweight_attempts": metrics.overweight_attempts,
|
| 233 |
},
|
| 234 |
"result": "",
|
| 235 |
}
|
|
|
|
| 261 |
return _grade_task("hard_restock_priority", state)
|
| 262 |
|
| 263 |
|
| 264 |
+
def grade_obstacle_course(state: WarehouseState) -> float:
|
| 265 |
+
return _grade_task("obstacle_course", state)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def grade_heavy_lifting(state: WarehouseState) -> float:
|
| 269 |
+
return _grade_task("heavy_lifting", state)
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def grade_stamina_run(state: WarehouseState) -> float:
|
| 273 |
+
return _grade_task("stamina_run", state)
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def grade_budget_run(state: WarehouseState) -> float:
|
| 277 |
+
return _grade_task("budget_run", state)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def grade_gauntlet(state: WarehouseState) -> float:
|
| 281 |
+
return _grade_task("gauntlet", state)
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
_GRADER_DISPATCH = {
|
| 285 |
+
"easy_single_pick": grade_easy,
|
| 286 |
+
"medium_multi_item": grade_medium,
|
| 287 |
+
"hard_restock_priority": grade_hard,
|
| 288 |
+
"obstacle_course": grade_obstacle_course,
|
| 289 |
+
"heavy_lifting": grade_heavy_lifting,
|
| 290 |
+
"stamina_run": grade_stamina_run,
|
| 291 |
+
"budget_run": grade_budget_run,
|
| 292 |
+
"gauntlet": grade_gauntlet,
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
|
| 296 |
def grade_episode(state: WarehouseState) -> float:
|
| 297 |
+
grader = _GRADER_DISPATCH.get(state.task_id)
|
| 298 |
+
if grader is None:
|
| 299 |
+
raise KeyError(f"No grader for task_id: {state.task_id}")
|
| 300 |
+
return grader(state)
|
|
|
|
|
|
|
|
|
grid_env/models.py
CHANGED
|
@@ -62,6 +62,7 @@ Command = Literal[
|
|
| 62 |
"pick_item",
|
| 63 |
"pack_item",
|
| 64 |
"recharge",
|
|
|
|
| 65 |
"wait",
|
| 66 |
]
|
| 67 |
|
|
@@ -84,11 +85,13 @@ class BinState(OpenEnvModel):
|
|
| 84 |
position: Position
|
| 85 |
sku: str
|
| 86 |
quantity: int
|
|
|
|
|
|
|
| 87 |
|
| 88 |
|
| 89 |
class TaskDefinition(OpenEnvModel):
|
| 90 |
task_id: str
|
| 91 |
-
difficulty: Literal["easy", "medium", "hard"]
|
| 92 |
title: str
|
| 93 |
description: str
|
| 94 |
max_steps: int
|
|
@@ -103,6 +106,12 @@ class TaskDefinition(OpenEnvModel):
|
|
| 103 |
order: List[OrderLine]
|
| 104 |
required_scans: List[str] = Field(default_factory=list)
|
| 105 |
rubric_criteria: List[Dict[str, str]] = Field(default_factory=list)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
|
| 107 |
|
| 108 |
class PendingOrderLine(OpenEnvModel):
|
|
@@ -123,7 +132,10 @@ class WarehouseObservation(Observation, OpenEnvModel):
|
|
| 123 |
heading: Heading
|
| 124 |
front_cell: str
|
| 125 |
carrying: Optional[str]
|
|
|
|
| 126 |
battery_level: int
|
|
|
|
|
|
|
| 127 |
visible_bins: List[str]
|
| 128 |
pending_order: List[PendingOrderLine]
|
| 129 |
packed_order: List[PackedOrderLine]
|
|
@@ -146,6 +158,12 @@ class WarehouseMetrics(OpenEnvModel):
|
|
| 146 |
recharges: int = 0
|
| 147 |
battery_depletion_events: int = 0
|
| 148 |
distance_travelled: int = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
|
| 151 |
class WarehouseState(State, OpenEnvModel):
|
|
@@ -160,11 +178,17 @@ class WarehouseState(State, OpenEnvModel):
|
|
| 160 |
agent_position: Position
|
| 161 |
heading: Heading
|
| 162 |
carrying: Optional[str]
|
|
|
|
| 163 |
battery_level: int
|
| 164 |
battery_capacity: int
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
dock_position: Position
|
| 166 |
pack_station_position: Position
|
| 167 |
charger_position: Position
|
|
|
|
| 168 |
bins: List[BinState]
|
| 169 |
order: List[OrderLine]
|
| 170 |
packed_order: List[OrderLine]
|
|
|
|
| 62 |
"pick_item",
|
| 63 |
"pack_item",
|
| 64 |
"recharge",
|
| 65 |
+
"rest",
|
| 66 |
"wait",
|
| 67 |
]
|
| 68 |
|
|
|
|
| 85 |
position: Position
|
| 86 |
sku: str
|
| 87 |
quantity: int
|
| 88 |
+
weight: int = 1
|
| 89 |
+
value: float = 0.0
|
| 90 |
|
| 91 |
|
| 92 |
class TaskDefinition(OpenEnvModel):
|
| 93 |
task_id: str
|
| 94 |
+
difficulty: Literal["easy", "medium", "hard", "expert"]
|
| 95 |
title: str
|
| 96 |
description: str
|
| 97 |
max_steps: int
|
|
|
|
| 106 |
order: List[OrderLine]
|
| 107 |
required_scans: List[str] = Field(default_factory=list)
|
| 108 |
rubric_criteria: List[Dict[str, str]] = Field(default_factory=list)
|
| 109 |
+
obstacles: List[Position] = Field(default_factory=list)
|
| 110 |
+
carry_capacity: int = 99
|
| 111 |
+
stamina_capacity: int = 0
|
| 112 |
+
stamina_move_cost: int = 1
|
| 113 |
+
rest_position: Optional[Position] = None
|
| 114 |
+
profit_target: float = 0.0
|
| 115 |
|
| 116 |
|
| 117 |
class PendingOrderLine(OpenEnvModel):
|
|
|
|
| 132 |
heading: Heading
|
| 133 |
front_cell: str
|
| 134 |
carrying: Optional[str]
|
| 135 |
+
carrying_weight: int = 0
|
| 136 |
battery_level: int
|
| 137 |
+
stamina_level: int = 0
|
| 138 |
+
money: float = 0.0
|
| 139 |
visible_bins: List[str]
|
| 140 |
pending_order: List[PendingOrderLine]
|
| 141 |
packed_order: List[PackedOrderLine]
|
|
|
|
| 158 |
recharges: int = 0
|
| 159 |
battery_depletion_events: int = 0
|
| 160 |
distance_travelled: int = 0
|
| 161 |
+
stamina_depletion_events: int = 0
|
| 162 |
+
rest_events: int = 0
|
| 163 |
+
obstacle_collisions: int = 0
|
| 164 |
+
money_earned: float = 0.0
|
| 165 |
+
money_lost: float = 0.0
|
| 166 |
+
overweight_attempts: int = 0
|
| 167 |
|
| 168 |
|
| 169 |
class WarehouseState(State, OpenEnvModel):
|
|
|
|
| 178 |
agent_position: Position
|
| 179 |
heading: Heading
|
| 180 |
carrying: Optional[str]
|
| 181 |
+
carrying_weight: int = 0
|
| 182 |
battery_level: int
|
| 183 |
battery_capacity: int
|
| 184 |
+
stamina_level: int = 0
|
| 185 |
+
stamina_capacity: int = 0
|
| 186 |
+
money: float = 0.0
|
| 187 |
+
profit_target: float = 0.0
|
| 188 |
dock_position: Position
|
| 189 |
pack_station_position: Position
|
| 190 |
charger_position: Position
|
| 191 |
+
obstacles: List[Position] = Field(default_factory=list)
|
| 192 |
bins: List[BinState]
|
| 193 |
order: List[OrderLine]
|
| 194 |
packed_order: List[OrderLine]
|
grid_env/openv.yaml
CHANGED
|
@@ -28,6 +28,21 @@ tasks:
|
|
| 28 |
- id: hard_restock_priority
|
| 29 |
difficulty: hard
|
| 30 |
grader: grid_env.graders:grade_hard
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
baseline:
|
| 32 |
runner: grid_env.baseline:run_baseline
|
| 33 |
seed: 7
|
|
|
|
| 28 |
- id: hard_restock_priority
|
| 29 |
difficulty: hard
|
| 30 |
grader: grid_env.graders:grade_hard
|
| 31 |
+
- id: obstacle_course
|
| 32 |
+
difficulty: medium
|
| 33 |
+
grader: grid_env.graders:grade_obstacle_course
|
| 34 |
+
- id: heavy_lifting
|
| 35 |
+
difficulty: hard
|
| 36 |
+
grader: grid_env.graders:grade_heavy_lifting
|
| 37 |
+
- id: stamina_run
|
| 38 |
+
difficulty: hard
|
| 39 |
+
grader: grid_env.graders:grade_stamina_run
|
| 40 |
+
- id: budget_run
|
| 41 |
+
difficulty: expert
|
| 42 |
+
grader: grid_env.graders:grade_budget_run
|
| 43 |
+
- id: gauntlet
|
| 44 |
+
difficulty: expert
|
| 45 |
+
grader: grid_env.graders:grade_gauntlet
|
| 46 |
baseline:
|
| 47 |
runner: grid_env.baseline:run_baseline
|
| 48 |
seed: 7
|
grid_env/tasks.py
CHANGED
|
@@ -182,6 +182,393 @@ TASKS: Dict[str, TaskDefinition] = {
|
|
| 182 |
},
|
| 183 |
],
|
| 184 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
}
|
| 186 |
|
| 187 |
|
|
|
|
| 182 |
},
|
| 183 |
],
|
| 184 |
),
|
| 185 |
+
|
| 186 |
+
# -----------------------------------------------------------------------
|
| 187 |
+
# Task 4: obstacle_course (medium) — obstacles block direct paths
|
| 188 |
+
# -----------------------------------------------------------------------
|
| 189 |
+
"obstacle_course": TaskDefinition(
|
| 190 |
+
task_id="obstacle_course",
|
| 191 |
+
difficulty="medium",
|
| 192 |
+
title="Obstacle-filled aisle navigation",
|
| 193 |
+
description=(
|
| 194 |
+
"Fulfill a two-item order in a warehouse cluttered with fallen crates. "
|
| 195 |
+
"Navigate around obstacles to reach bins, scan them, pick one thermometer "
|
| 196 |
+
"and one bandage kit, then pack both at the station."
|
| 197 |
+
),
|
| 198 |
+
max_steps=70,
|
| 199 |
+
battery_capacity=40,
|
| 200 |
+
low_battery_threshold=10,
|
| 201 |
+
agent_start=(0, 0),
|
| 202 |
+
agent_heading="E",
|
| 203 |
+
dock_position=(0, 0),
|
| 204 |
+
pack_station_position=(6, 6),
|
| 205 |
+
charger_position=(0, 6),
|
| 206 |
+
bins=[
|
| 207 |
+
BinState(bin_id="A1", position=(2, 1), sku="thermometer", quantity=2),
|
| 208 |
+
BinState(bin_id="B1", position=(4, 1), sku="bandage_kit", quantity=2),
|
| 209 |
+
BinState(bin_id="A2", position=(2, 3), sku="cough_syrup", quantity=2),
|
| 210 |
+
BinState(bin_id="B2", position=(4, 3), sku="pain_relief", quantity=2),
|
| 211 |
+
BinState(bin_id="C1", position=(2, 5), sku="gloves", quantity=2),
|
| 212 |
+
],
|
| 213 |
+
order=[
|
| 214 |
+
OrderLine(sku="thermometer", quantity=1),
|
| 215 |
+
OrderLine(sku="bandage_kit", quantity=1),
|
| 216 |
+
],
|
| 217 |
+
required_scans=["A1", "B1"],
|
| 218 |
+
obstacles=[(1, 2), (2, 2), (3, 2), (3, 4), (4, 4), (5, 4)],
|
| 219 |
+
rubric_criteria=[
|
| 220 |
+
{
|
| 221 |
+
"name": "completion",
|
| 222 |
+
"description": "All items packed.",
|
| 223 |
+
"check": "param_at_least:state.completion_ratio=1.0",
|
| 224 |
+
},
|
| 225 |
+
{
|
| 226 |
+
"name": "scans",
|
| 227 |
+
"description": "Scanned both required bins.",
|
| 228 |
+
"check": "param_at_least:state.correct_scans=2",
|
| 229 |
+
},
|
| 230 |
+
{
|
| 231 |
+
"name": "pack_item",
|
| 232 |
+
"description": "Packed items at the station.",
|
| 233 |
+
"check": "tool_used:pack_item",
|
| 234 |
+
},
|
| 235 |
+
{
|
| 236 |
+
"name": "no_obstacle_collisions",
|
| 237 |
+
"description": "Avoided all obstacle collisions.",
|
| 238 |
+
"check": "param_at_most:state.obstacle_collisions=0",
|
| 239 |
+
},
|
| 240 |
+
{
|
| 241 |
+
"name": "no_wrong_picks",
|
| 242 |
+
"description": "No incorrect picks.",
|
| 243 |
+
"check": "param_at_most:state.wrong_picks=0",
|
| 244 |
+
},
|
| 245 |
+
{
|
| 246 |
+
"name": "few_invalid_actions",
|
| 247 |
+
"description": "At most two invalid actions.",
|
| 248 |
+
"check": "param_at_most:state.invalid_actions=2",
|
| 249 |
+
},
|
| 250 |
+
],
|
| 251 |
+
),
|
| 252 |
+
|
| 253 |
+
# -----------------------------------------------------------------------
|
| 254 |
+
# Task 5: heavy_lifting (hard) — items have weight, limited carry capacity
|
| 255 |
+
# -----------------------------------------------------------------------
|
| 256 |
+
"heavy_lifting": TaskDefinition(
|
| 257 |
+
task_id="heavy_lifting",
|
| 258 |
+
difficulty="hard",
|
| 259 |
+
title="Heavy-item logistics with weight limits",
|
| 260 |
+
description=(
|
| 261 |
+
"Fulfill a three-item order where items vary in weight (1-4 units). "
|
| 262 |
+
"The agent has a carry capacity of 3 and must choose pickup order wisely. "
|
| 263 |
+
"Heavier items drain more battery while moving. Scan each bin, pick items "
|
| 264 |
+
"that fit within your carry limit, and pack at the station. "
|
| 265 |
+
"The heavy pain_relief (weight 4) cannot be carried — skip it!"
|
| 266 |
+
),
|
| 267 |
+
max_steps=90,
|
| 268 |
+
battery_capacity=32,
|
| 269 |
+
low_battery_threshold=8,
|
| 270 |
+
agent_start=(1, 1),
|
| 271 |
+
agent_heading="E",
|
| 272 |
+
dock_position=(1, 1),
|
| 273 |
+
pack_station_position=(5, 5),
|
| 274 |
+
charger_position=(1, 5),
|
| 275 |
+
bins=[
|
| 276 |
+
BinState(bin_id="A1", position=(2, 1), sku="thermometer", quantity=3, weight=1),
|
| 277 |
+
BinState(bin_id="A2", position=(2, 3), sku="cough_syrup", quantity=2, weight=2),
|
| 278 |
+
BinState(bin_id="B1", position=(4, 1), sku="bandage_kit", quantity=2, weight=3),
|
| 279 |
+
BinState(bin_id="B2", position=(4, 3), sku="pain_relief", quantity=2, weight=4),
|
| 280 |
+
BinState(bin_id="C1", position=(2, 5), sku="gloves", quantity=3, weight=1),
|
| 281 |
+
],
|
| 282 |
+
order=[
|
| 283 |
+
OrderLine(sku="thermometer", quantity=1),
|
| 284 |
+
OrderLine(sku="cough_syrup", quantity=1),
|
| 285 |
+
OrderLine(sku="bandage_kit", quantity=1),
|
| 286 |
+
],
|
| 287 |
+
required_scans=["A1", "A2", "B1"],
|
| 288 |
+
carry_capacity=3,
|
| 289 |
+
rubric_criteria=[
|
| 290 |
+
{
|
| 291 |
+
"name": "completion",
|
| 292 |
+
"description": "All items packed.",
|
| 293 |
+
"check": "param_at_least:state.completion_ratio=1.0",
|
| 294 |
+
},
|
| 295 |
+
{
|
| 296 |
+
"name": "scans",
|
| 297 |
+
"description": "Scanned all three required bins.",
|
| 298 |
+
"check": "param_at_least:state.correct_scans=3",
|
| 299 |
+
},
|
| 300 |
+
{
|
| 301 |
+
"name": "recharge",
|
| 302 |
+
"description": "Recharged at least once.",
|
| 303 |
+
"check": "tool_used:recharge",
|
| 304 |
+
},
|
| 305 |
+
{
|
| 306 |
+
"name": "no_overweight",
|
| 307 |
+
"description": "Never tried to pick an overweight item.",
|
| 308 |
+
"check": "param_at_most:state.overweight_attempts=0",
|
| 309 |
+
},
|
| 310 |
+
{
|
| 311 |
+
"name": "no_battery_depletion",
|
| 312 |
+
"description": "Avoided battery depletion.",
|
| 313 |
+
"check": "param_at_most:state.battery_depletion_events=0",
|
| 314 |
+
},
|
| 315 |
+
{
|
| 316 |
+
"name": "no_wrong_picks",
|
| 317 |
+
"description": "No incorrect picks.",
|
| 318 |
+
"check": "param_at_most:state.wrong_picks=0",
|
| 319 |
+
},
|
| 320 |
+
{
|
| 321 |
+
"name": "few_invalid_actions",
|
| 322 |
+
"description": "At most two invalid actions.",
|
| 323 |
+
"check": "param_at_most:state.invalid_actions=2",
|
| 324 |
+
},
|
| 325 |
+
],
|
| 326 |
+
),
|
| 327 |
+
|
| 328 |
+
# -----------------------------------------------------------------------
|
| 329 |
+
# Task 6: stamina_run (hard) — stamina drains on movement
|
| 330 |
+
# -----------------------------------------------------------------------
|
| 331 |
+
"stamina_run": TaskDefinition(
|
| 332 |
+
task_id="stamina_run",
|
| 333 |
+
difficulty="hard",
|
| 334 |
+
title="Endurance run with stamina management",
|
| 335 |
+
description=(
|
| 336 |
+
"Fulfill a two-item order while managing stamina. Every move drains stamina; "
|
| 337 |
+
"when stamina hits zero, movement costs double battery. Rest at the rest area "
|
| 338 |
+
"to restore stamina. Pick one cough syrup and one gloves unit, scan bins, and "
|
| 339 |
+
"pack at the station without running out of energy."
|
| 340 |
+
),
|
| 341 |
+
max_steps=80,
|
| 342 |
+
battery_capacity=36,
|
| 343 |
+
low_battery_threshold=8,
|
| 344 |
+
agent_start=(0, 0),
|
| 345 |
+
agent_heading="E",
|
| 346 |
+
dock_position=(0, 0),
|
| 347 |
+
pack_station_position=(6, 6),
|
| 348 |
+
charger_position=(0, 6),
|
| 349 |
+
rest_position=(3, 3),
|
| 350 |
+
stamina_capacity=12,
|
| 351 |
+
stamina_move_cost=1,
|
| 352 |
+
bins=[
|
| 353 |
+
BinState(bin_id="A1", position=(2, 1), sku="thermometer", quantity=2),
|
| 354 |
+
BinState(bin_id="A2", position=(2, 3), sku="cough_syrup", quantity=2),
|
| 355 |
+
BinState(bin_id="B1", position=(4, 1), sku="bandage_kit", quantity=2),
|
| 356 |
+
BinState(bin_id="B2", position=(4, 3), sku="pain_relief", quantity=2),
|
| 357 |
+
BinState(bin_id="C1", position=(2, 5), sku="gloves", quantity=2),
|
| 358 |
+
],
|
| 359 |
+
order=[
|
| 360 |
+
OrderLine(sku="cough_syrup", quantity=1),
|
| 361 |
+
OrderLine(sku="gloves", quantity=1),
|
| 362 |
+
],
|
| 363 |
+
required_scans=["A2", "C1"],
|
| 364 |
+
rubric_criteria=[
|
| 365 |
+
{
|
| 366 |
+
"name": "completion",
|
| 367 |
+
"description": "All items packed.",
|
| 368 |
+
"check": "param_at_least:state.completion_ratio=1.0",
|
| 369 |
+
},
|
| 370 |
+
{
|
| 371 |
+
"name": "scans",
|
| 372 |
+
"description": "Scanned both required bins.",
|
| 373 |
+
"check": "param_at_least:state.correct_scans=2",
|
| 374 |
+
},
|
| 375 |
+
{
|
| 376 |
+
"name": "rest_used",
|
| 377 |
+
"description": "Used the rest area at least once.",
|
| 378 |
+
"check": "tool_used:rest",
|
| 379 |
+
},
|
| 380 |
+
{
|
| 381 |
+
"name": "no_stamina_depletion",
|
| 382 |
+
"description": "Avoided complete stamina depletion.",
|
| 383 |
+
"check": "param_at_most:state.stamina_depletion_events=0",
|
| 384 |
+
},
|
| 385 |
+
{
|
| 386 |
+
"name": "no_battery_depletion",
|
| 387 |
+
"description": "Avoided battery depletion.",
|
| 388 |
+
"check": "param_at_most:state.battery_depletion_events=0",
|
| 389 |
+
},
|
| 390 |
+
{
|
| 391 |
+
"name": "no_wrong_picks",
|
| 392 |
+
"description": "No incorrect picks.",
|
| 393 |
+
"check": "param_at_most:state.wrong_picks=0",
|
| 394 |
+
},
|
| 395 |
+
],
|
| 396 |
+
),
|
| 397 |
+
|
| 398 |
+
# -----------------------------------------------------------------------
|
| 399 |
+
# Task 7: budget_run (expert) — money rewards and profit target
|
| 400 |
+
# -----------------------------------------------------------------------
|
| 401 |
+
"budget_run": TaskDefinition(
|
| 402 |
+
task_id="budget_run",
|
| 403 |
+
difficulty="expert",
|
| 404 |
+
title="Profitable fulfillment under budget pressure",
|
| 405 |
+
description=(
|
| 406 |
+
"Fulfill orders for profit. Each item has a dollar value earned when correctly "
|
| 407 |
+
"packed. Wrong packs lose half the item value. You must reach a profit target "
|
| 408 |
+
"of $15.00 while completing the order. Pick high-value items efficiently: "
|
| 409 |
+
"2 thermometers ($5 each) and 1 bandage kit ($8). Budget-aware decisions matter."
|
| 410 |
+
),
|
| 411 |
+
max_steps=70,
|
| 412 |
+
battery_capacity=30,
|
| 413 |
+
low_battery_threshold=6,
|
| 414 |
+
agent_start=(1, 1),
|
| 415 |
+
agent_heading="E",
|
| 416 |
+
dock_position=(1, 1),
|
| 417 |
+
pack_station_position=(5, 5),
|
| 418 |
+
charger_position=(1, 5),
|
| 419 |
+
bins=[
|
| 420 |
+
BinState(bin_id="A1", position=(2, 1), sku="thermometer", quantity=3, value=5.0),
|
| 421 |
+
BinState(bin_id="A2", position=(2, 3), sku="cough_syrup", quantity=2, value=3.0),
|
| 422 |
+
BinState(bin_id="B1", position=(4, 1), sku="bandage_kit", quantity=2, value=8.0),
|
| 423 |
+
BinState(bin_id="B2", position=(4, 3), sku="pain_relief", quantity=2, value=4.0),
|
| 424 |
+
BinState(bin_id="C1", position=(2, 5), sku="gloves", quantity=2, value=2.0),
|
| 425 |
+
],
|
| 426 |
+
order=[
|
| 427 |
+
OrderLine(sku="thermometer", quantity=2),
|
| 428 |
+
OrderLine(sku="bandage_kit", quantity=1),
|
| 429 |
+
],
|
| 430 |
+
required_scans=["A1", "B1"],
|
| 431 |
+
profit_target=15.0,
|
| 432 |
+
rubric_criteria=[
|
| 433 |
+
{
|
| 434 |
+
"name": "completion",
|
| 435 |
+
"description": "All items packed.",
|
| 436 |
+
"check": "param_at_least:state.completion_ratio=1.0",
|
| 437 |
+
},
|
| 438 |
+
{
|
| 439 |
+
"name": "profit_target",
|
| 440 |
+
"description": "Reached the profit target of $15.",
|
| 441 |
+
"check": "param_at_least:state.money_earned=15.0",
|
| 442 |
+
},
|
| 443 |
+
{
|
| 444 |
+
"name": "scans",
|
| 445 |
+
"description": "Scanned required bins.",
|
| 446 |
+
"check": "param_at_least:state.correct_scans=2",
|
| 447 |
+
},
|
| 448 |
+
{
|
| 449 |
+
"name": "recharge",
|
| 450 |
+
"description": "Recharged at least once.",
|
| 451 |
+
"check": "tool_used:recharge",
|
| 452 |
+
},
|
| 453 |
+
{
|
| 454 |
+
"name": "no_money_lost",
|
| 455 |
+
"description": "No money lost from wrong packs.",
|
| 456 |
+
"check": "param_at_most:state.money_lost=0.0",
|
| 457 |
+
},
|
| 458 |
+
{
|
| 459 |
+
"name": "no_wrong_picks",
|
| 460 |
+
"description": "No incorrect picks.",
|
| 461 |
+
"check": "param_at_most:state.wrong_picks=0",
|
| 462 |
+
},
|
| 463 |
+
{
|
| 464 |
+
"name": "few_invalid_actions",
|
| 465 |
+
"description": "At most one invalid action.",
|
| 466 |
+
"check": "param_at_most:state.invalid_actions=1",
|
| 467 |
+
},
|
| 468 |
+
],
|
| 469 |
+
),
|
| 470 |
+
|
| 471 |
+
# -----------------------------------------------------------------------
|
| 472 |
+
# Task 8: gauntlet (expert) — all mechanics combined
|
| 473 |
+
# -----------------------------------------------------------------------
|
| 474 |
+
"gauntlet": TaskDefinition(
|
| 475 |
+
task_id="gauntlet",
|
| 476 |
+
difficulty="expert",
|
| 477 |
+
title="The gauntlet: obstacles, weight, stamina, and profit",
|
| 478 |
+
description=(
|
| 479 |
+
"The ultimate warehouse challenge. Navigate a cluttered warehouse with obstacles, "
|
| 480 |
+
"manage item weights (carry capacity 3), conserve stamina (rest when needed), "
|
| 481 |
+
"earn money for packed items, and hit a $20 profit target. Fulfill a four-item "
|
| 482 |
+
"order: 1 thermometer ($5, wt 1), 1 cough syrup ($6, wt 2), 1 bandage kit ($8, wt 3), "
|
| 483 |
+
"and 1 gloves ($4, wt 1). Recharge battery, rest for stamina, avoid obstacles, "
|
| 484 |
+
"and finish profitable."
|
| 485 |
+
),
|
| 486 |
+
max_steps=120,
|
| 487 |
+
battery_capacity=28,
|
| 488 |
+
low_battery_threshold=7,
|
| 489 |
+
agent_start=(0, 0),
|
| 490 |
+
agent_heading="S",
|
| 491 |
+
dock_position=(0, 0),
|
| 492 |
+
pack_station_position=(6, 6),
|
| 493 |
+
charger_position=(6, 0),
|
| 494 |
+
rest_position=(0, 6),
|
| 495 |
+
stamina_capacity=10,
|
| 496 |
+
stamina_move_cost=1,
|
| 497 |
+
carry_capacity=3,
|
| 498 |
+
profit_target=20.0,
|
| 499 |
+
obstacles=[(1, 1), (3, 1), (5, 3), (3, 3), (1, 5), (5, 5)],
|
| 500 |
+
bins=[
|
| 501 |
+
BinState(bin_id="A1", position=(2, 0), sku="thermometer", quantity=3, weight=1, value=5.0),
|
| 502 |
+
BinState(bin_id="A2", position=(2, 4), sku="cough_syrup", quantity=2, weight=2, value=6.0),
|
| 503 |
+
BinState(bin_id="B1", position=(4, 0), sku="bandage_kit", quantity=2, weight=3, value=8.0),
|
| 504 |
+
BinState(bin_id="B2", position=(4, 4), sku="pain_relief", quantity=2, weight=4, value=4.0),
|
| 505 |
+
BinState(bin_id="C1", position=(4, 2), sku="gloves", quantity=3, weight=1, value=4.0),
|
| 506 |
+
],
|
| 507 |
+
order=[
|
| 508 |
+
OrderLine(sku="thermometer", quantity=1),
|
| 509 |
+
OrderLine(sku="cough_syrup", quantity=1),
|
| 510 |
+
OrderLine(sku="bandage_kit", quantity=1),
|
| 511 |
+
OrderLine(sku="gloves", quantity=1),
|
| 512 |
+
],
|
| 513 |
+
required_scans=["A1", "A2", "B1", "C1"],
|
| 514 |
+
rubric_criteria=[
|
| 515 |
+
{
|
| 516 |
+
"name": "completion",
|
| 517 |
+
"description": "All four items packed.",
|
| 518 |
+
"check": "param_at_least:state.completion_ratio=1.0",
|
| 519 |
+
},
|
| 520 |
+
{
|
| 521 |
+
"name": "profit_target",
|
| 522 |
+
"description": "Reached $20 profit target.",
|
| 523 |
+
"check": "param_at_least:state.money_earned=20.0",
|
| 524 |
+
},
|
| 525 |
+
{
|
| 526 |
+
"name": "scans",
|
| 527 |
+
"description": "Scanned all four required bins.",
|
| 528 |
+
"check": "param_at_least:state.correct_scans=4",
|
| 529 |
+
},
|
| 530 |
+
{
|
| 531 |
+
"name": "recharge",
|
| 532 |
+
"description": "Recharged at least once.",
|
| 533 |
+
"check": "tool_used:recharge",
|
| 534 |
+
},
|
| 535 |
+
{
|
| 536 |
+
"name": "rest_used",
|
| 537 |
+
"description": "Used the rest area at least once.",
|
| 538 |
+
"check": "tool_used:rest",
|
| 539 |
+
},
|
| 540 |
+
{
|
| 541 |
+
"name": "no_obstacle_collisions",
|
| 542 |
+
"description": "Avoided all obstacle collisions.",
|
| 543 |
+
"check": "param_at_most:state.obstacle_collisions=0",
|
| 544 |
+
},
|
| 545 |
+
{
|
| 546 |
+
"name": "no_overweight",
|
| 547 |
+
"description": "Never tried to pick an overweight item.",
|
| 548 |
+
"check": "param_at_most:state.overweight_attempts=0",
|
| 549 |
+
},
|
| 550 |
+
{
|
| 551 |
+
"name": "no_battery_depletion",
|
| 552 |
+
"description": "Avoided battery depletion.",
|
| 553 |
+
"check": "param_at_most:state.battery_depletion_events=0",
|
| 554 |
+
},
|
| 555 |
+
{
|
| 556 |
+
"name": "no_stamina_depletion",
|
| 557 |
+
"description": "Avoided stamina depletion.",
|
| 558 |
+
"check": "param_at_most:state.stamina_depletion_events=0",
|
| 559 |
+
},
|
| 560 |
+
{
|
| 561 |
+
"name": "no_wrong_picks",
|
| 562 |
+
"description": "No incorrect picks.",
|
| 563 |
+
"check": "param_at_most:state.wrong_picks=0",
|
| 564 |
+
},
|
| 565 |
+
{
|
| 566 |
+
"name": "few_invalid_actions",
|
| 567 |
+
"description": "At most two invalid actions.",
|
| 568 |
+
"check": "param_at_most:state.invalid_actions=2",
|
| 569 |
+
},
|
| 570 |
+
],
|
| 571 |
+
),
|
| 572 |
}
|
| 573 |
|
| 574 |
|
grid_env/tools.py
CHANGED
|
@@ -18,6 +18,7 @@ _TOOL_DESCRIPTIONS: Dict[str, str] = {
|
|
| 18 |
"pick_item": "Pick an item from the bin in front.",
|
| 19 |
"pack_item": "Pack the carried item at the packing station.",
|
| 20 |
"recharge": "Recharge the battery at the charging dock.",
|
|
|
|
| 21 |
"wait": "Stay in place and consume time.",
|
| 22 |
}
|
| 23 |
|
|
|
|
| 18 |
"pick_item": "Pick an item from the bin in front.",
|
| 19 |
"pack_item": "Pack the carried item at the packing station.",
|
| 20 |
"recharge": "Recharge the battery at the charging dock.",
|
| 21 |
+
"rest": "Rest at the rest area to restore stamina.",
|
| 22 |
"wait": "Stay in place and consume time.",
|
| 23 |
}
|
| 24 |
|
openenv.yaml
CHANGED
|
@@ -28,6 +28,21 @@ tasks:
|
|
| 28 |
- id: hard_restock_priority
|
| 29 |
difficulty: hard
|
| 30 |
grader: grid_env.graders:grade_hard
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
baseline:
|
| 32 |
runner: grid_env.baseline:run_baseline
|
| 33 |
seed: 7
|
|
|
|
| 28 |
- id: hard_restock_priority
|
| 29 |
difficulty: hard
|
| 30 |
grader: grid_env.graders:grade_hard
|
| 31 |
+
- id: obstacle_course
|
| 32 |
+
difficulty: medium
|
| 33 |
+
grader: grid_env.graders:grade_obstacle_course
|
| 34 |
+
- id: heavy_lifting
|
| 35 |
+
difficulty: hard
|
| 36 |
+
grader: grid_env.graders:grade_heavy_lifting
|
| 37 |
+
- id: stamina_run
|
| 38 |
+
difficulty: hard
|
| 39 |
+
grader: grid_env.graders:grade_stamina_run
|
| 40 |
+
- id: budget_run
|
| 41 |
+
difficulty: expert
|
| 42 |
+
grader: grid_env.graders:grade_budget_run
|
| 43 |
+
- id: gauntlet
|
| 44 |
+
difficulty: expert
|
| 45 |
+
grader: grid_env.graders:grade_gauntlet
|
| 46 |
baseline:
|
| 47 |
runner: grid_env.baseline:run_baseline
|
| 48 |
seed: 7
|
tests/conftest.py
CHANGED
|
@@ -25,3 +25,38 @@ def env_hard():
|
|
| 25 |
env = WarehouseFulfillmentEnv(task_id="hard_restock_priority", seed=7)
|
| 26 |
env.reset()
|
| 27 |
return env
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
env = WarehouseFulfillmentEnv(task_id="hard_restock_priority", seed=7)
|
| 26 |
env.reset()
|
| 27 |
return env
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@pytest.fixture()
|
| 31 |
+
def env_obstacle_course():
|
| 32 |
+
env = WarehouseFulfillmentEnv(task_id="obstacle_course", seed=7)
|
| 33 |
+
env.reset()
|
| 34 |
+
return env
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@pytest.fixture()
|
| 38 |
+
def env_heavy_lifting():
|
| 39 |
+
env = WarehouseFulfillmentEnv(task_id="heavy_lifting", seed=7)
|
| 40 |
+
env.reset()
|
| 41 |
+
return env
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@pytest.fixture()
|
| 45 |
+
def env_stamina_run():
|
| 46 |
+
env = WarehouseFulfillmentEnv(task_id="stamina_run", seed=7)
|
| 47 |
+
env.reset()
|
| 48 |
+
return env
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@pytest.fixture()
|
| 52 |
+
def env_budget_run():
|
| 53 |
+
env = WarehouseFulfillmentEnv(task_id="budget_run", seed=7)
|
| 54 |
+
env.reset()
|
| 55 |
+
return env
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@pytest.fixture()
|
| 59 |
+
def env_gauntlet():
|
| 60 |
+
env = WarehouseFulfillmentEnv(task_id="gauntlet", seed=7)
|
| 61 |
+
env.reset()
|
| 62 |
+
return env
|
tests/test_baseline_stub.py
CHANGED
|
@@ -14,7 +14,16 @@ from grid_env.graders import grade_episode
|
|
| 14 |
from grid_env.models import BaselineCommand
|
| 15 |
|
| 16 |
|
| 17 |
-
TASK_IDS = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
# Cycle of deterministic actions that exercise most code paths without getting stuck.
|
| 20 |
_ACTION_CYCLE = [
|
|
@@ -27,6 +36,7 @@ _ACTION_CYCLE = [
|
|
| 27 |
"turn_right",
|
| 28 |
"move_forward",
|
| 29 |
"pack_item",
|
|
|
|
| 30 |
"wait",
|
| 31 |
]
|
| 32 |
|
|
|
|
| 14 |
from grid_env.models import BaselineCommand
|
| 15 |
|
| 16 |
|
| 17 |
+
TASK_IDS = [
|
| 18 |
+
"easy_single_pick",
|
| 19 |
+
"medium_multi_item",
|
| 20 |
+
"hard_restock_priority",
|
| 21 |
+
"obstacle_course",
|
| 22 |
+
"heavy_lifting",
|
| 23 |
+
"stamina_run",
|
| 24 |
+
"budget_run",
|
| 25 |
+
"gauntlet",
|
| 26 |
+
]
|
| 27 |
|
| 28 |
# Cycle of deterministic actions that exercise most code paths without getting stuck.
|
| 29 |
_ACTION_CYCLE = [
|
|
|
|
| 36 |
"turn_right",
|
| 37 |
"move_forward",
|
| 38 |
"pack_item",
|
| 39 |
+
"rest",
|
| 40 |
"wait",
|
| 41 |
]
|
| 42 |
|
tests/test_env_smoke.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
"""
|
| 2 |
Smoke tests: environment instantiation, reset, step, and episode termination
|
| 3 |
-
for all
|
| 4 |
"""
|
| 5 |
|
| 6 |
import pytest
|
|
@@ -9,7 +9,16 @@ from grid_env.graders import grade_episode
|
|
| 9 |
from grid_env.models import WarehouseObservation, WarehouseReward
|
| 10 |
|
| 11 |
|
| 12 |
-
TASK_IDS = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
ALL_ACTIONS = [
|
| 14 |
"turn_left",
|
| 15 |
"turn_right",
|
|
@@ -18,6 +27,7 @@ ALL_ACTIONS = [
|
|
| 18 |
"pick_item",
|
| 19 |
"pack_item",
|
| 20 |
"recharge",
|
|
|
|
| 21 |
"wait",
|
| 22 |
]
|
| 23 |
|
|
@@ -104,8 +114,8 @@ def test_step_after_done_is_safe():
|
|
| 104 |
assert done
|
| 105 |
|
| 106 |
|
| 107 |
-
def
|
| 108 |
-
"""available_tasks() returns
|
| 109 |
tasks = available_tasks()
|
| 110 |
ids = {t["task_id"] for t in tasks}
|
| 111 |
assert ids == set(TASK_IDS)
|
|
|
|
| 1 |
"""
|
| 2 |
Smoke tests: environment instantiation, reset, step, and episode termination
|
| 3 |
+
for all task IDs.
|
| 4 |
"""
|
| 5 |
|
| 6 |
import pytest
|
|
|
|
| 9 |
from grid_env.models import WarehouseObservation, WarehouseReward
|
| 10 |
|
| 11 |
|
| 12 |
+
TASK_IDS = [
|
| 13 |
+
"easy_single_pick",
|
| 14 |
+
"medium_multi_item",
|
| 15 |
+
"hard_restock_priority",
|
| 16 |
+
"obstacle_course",
|
| 17 |
+
"heavy_lifting",
|
| 18 |
+
"stamina_run",
|
| 19 |
+
"budget_run",
|
| 20 |
+
"gauntlet",
|
| 21 |
+
]
|
| 22 |
ALL_ACTIONS = [
|
| 23 |
"turn_left",
|
| 24 |
"turn_right",
|
|
|
|
| 27 |
"pick_item",
|
| 28 |
"pack_item",
|
| 29 |
"recharge",
|
| 30 |
+
"rest",
|
| 31 |
"wait",
|
| 32 |
]
|
| 33 |
|
|
|
|
| 114 |
assert done
|
| 115 |
|
| 116 |
|
| 117 |
+
def test_available_tasks_returns_all():
|
| 118 |
+
"""available_tasks() returns all expected task IDs."""
|
| 119 |
tasks = available_tasks()
|
| 120 |
ids = {t["task_id"] for t in tasks}
|
| 121 |
assert ids == set(TASK_IDS)
|
tests/test_graders.py
CHANGED
|
@@ -4,7 +4,17 @@ Unit tests for rubric-based graders.
|
|
| 4 |
|
| 5 |
import pytest
|
| 6 |
|
| 7 |
-
from grid_env.graders import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
from grid_env.models import BinState, OrderLine, WarehouseMetrics, WarehouseState
|
| 9 |
|
| 10 |
|
|
@@ -21,6 +31,12 @@ def _make_state(
|
|
| 21 |
invalid_actions: int = 0,
|
| 22 |
recharges: int = 0,
|
| 23 |
battery_depletion_events: int = 0,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
action_history: list[str] | None = None,
|
| 25 |
) -> WarehouseState:
|
| 26 |
metrics = WarehouseMetrics(
|
|
@@ -32,6 +48,12 @@ def _make_state(
|
|
| 32 |
invalid_actions=invalid_actions,
|
| 33 |
recharges=recharges,
|
| 34 |
battery_depletion_events=battery_depletion_events,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
)
|
| 36 |
return WarehouseState(
|
| 37 |
episode_id="test-ep",
|
|
@@ -138,14 +160,128 @@ def test_grade_hard_partial_rubric_scores_lower():
|
|
| 138 |
assert 0.0 < grade_hard(state) < 1.0
|
| 139 |
|
| 140 |
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
def test_grade_episode_dispatches_correctly(task_id, grader):
|
| 150 |
state = _make_state(task_id, completion_ratio=0.5, action_history=["pick_item"])
|
| 151 |
assert grade_episode(state) == grader(state)
|
|
|
|
| 4 |
|
| 5 |
import pytest
|
| 6 |
|
| 7 |
+
from grid_env.graders import (
|
| 8 |
+
grade_budget_run,
|
| 9 |
+
grade_easy,
|
| 10 |
+
grade_episode,
|
| 11 |
+
grade_gauntlet,
|
| 12 |
+
grade_hard,
|
| 13 |
+
grade_heavy_lifting,
|
| 14 |
+
grade_medium,
|
| 15 |
+
grade_obstacle_course,
|
| 16 |
+
grade_stamina_run,
|
| 17 |
+
)
|
| 18 |
from grid_env.models import BinState, OrderLine, WarehouseMetrics, WarehouseState
|
| 19 |
|
| 20 |
|
|
|
|
| 31 |
invalid_actions: int = 0,
|
| 32 |
recharges: int = 0,
|
| 33 |
battery_depletion_events: int = 0,
|
| 34 |
+
stamina_depletion_events: int = 0,
|
| 35 |
+
rest_events: int = 0,
|
| 36 |
+
obstacle_collisions: int = 0,
|
| 37 |
+
money_earned: float = 0.0,
|
| 38 |
+
money_lost: float = 0.0,
|
| 39 |
+
overweight_attempts: int = 0,
|
| 40 |
action_history: list[str] | None = None,
|
| 41 |
) -> WarehouseState:
|
| 42 |
metrics = WarehouseMetrics(
|
|
|
|
| 48 |
invalid_actions=invalid_actions,
|
| 49 |
recharges=recharges,
|
| 50 |
battery_depletion_events=battery_depletion_events,
|
| 51 |
+
stamina_depletion_events=stamina_depletion_events,
|
| 52 |
+
rest_events=rest_events,
|
| 53 |
+
obstacle_collisions=obstacle_collisions,
|
| 54 |
+
money_earned=money_earned,
|
| 55 |
+
money_lost=money_lost,
|
| 56 |
+
overweight_attempts=overweight_attempts,
|
| 57 |
)
|
| 58 |
return WarehouseState(
|
| 59 |
episode_id="test-ep",
|
|
|
|
| 160 |
assert 0.0 < grade_hard(state) < 1.0
|
| 161 |
|
| 162 |
|
| 163 |
+
def test_grade_obstacle_course_full_rubric_passes():
|
| 164 |
+
state = _make_state(
|
| 165 |
+
"obstacle_course",
|
| 166 |
+
completion_ratio=1.0,
|
| 167 |
+
correct_scans=2,
|
| 168 |
+
wrong_picks=0,
|
| 169 |
+
invalid_actions=2,
|
| 170 |
+
obstacle_collisions=0,
|
| 171 |
+
action_history=["scan_bin", "pick_item", "pack_item"],
|
| 172 |
+
)
|
| 173 |
+
assert grade_obstacle_course(state) == pytest.approx(1.0)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def test_grade_obstacle_course_collision_lowers_score():
|
| 177 |
+
state = _make_state(
|
| 178 |
+
"obstacle_course",
|
| 179 |
+
completion_ratio=1.0,
|
| 180 |
+
correct_scans=2,
|
| 181 |
+
wrong_picks=0,
|
| 182 |
+
invalid_actions=2,
|
| 183 |
+
obstacle_collisions=3,
|
| 184 |
+
action_history=["scan_bin", "pick_item", "pack_item"],
|
| 185 |
+
)
|
| 186 |
+
assert 0.0 < grade_obstacle_course(state) < 1.0
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def test_grade_heavy_lifting_full_rubric_passes():
|
| 190 |
+
state = _make_state(
|
| 191 |
+
"heavy_lifting",
|
| 192 |
+
completion_ratio=1.0,
|
| 193 |
+
correct_scans=3,
|
| 194 |
+
wrong_picks=0,
|
| 195 |
+
invalid_actions=2,
|
| 196 |
+
recharges=1,
|
| 197 |
+
battery_depletion_events=0,
|
| 198 |
+
overweight_attempts=0,
|
| 199 |
+
action_history=["scan_bin", "pick_item", "pack_item", "recharge"],
|
| 200 |
+
)
|
| 201 |
+
assert grade_heavy_lifting(state) == pytest.approx(1.0)
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def test_grade_stamina_run_full_rubric_passes():
|
| 205 |
+
state = _make_state(
|
| 206 |
+
"stamina_run",
|
| 207 |
+
completion_ratio=1.0,
|
| 208 |
+
correct_scans=2,
|
| 209 |
+
wrong_picks=0,
|
| 210 |
+
stamina_depletion_events=0,
|
| 211 |
+
battery_depletion_events=0,
|
| 212 |
+
rest_events=1,
|
| 213 |
+
action_history=["scan_bin", "pick_item", "pack_item", "rest"],
|
| 214 |
+
)
|
| 215 |
+
assert grade_stamina_run(state) == pytest.approx(1.0)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def test_grade_budget_run_full_rubric_passes():
|
| 219 |
+
state = _make_state(
|
| 220 |
+
"budget_run",
|
| 221 |
+
completion_ratio=1.0,
|
| 222 |
+
correct_scans=2,
|
| 223 |
+
wrong_picks=0,
|
| 224 |
+
invalid_actions=1,
|
| 225 |
+
recharges=1,
|
| 226 |
+
money_earned=18.0,
|
| 227 |
+
money_lost=0.0,
|
| 228 |
+
action_history=["scan_bin", "pick_item", "pack_item", "recharge"],
|
| 229 |
+
)
|
| 230 |
+
assert grade_budget_run(state) == pytest.approx(1.0)
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def test_grade_gauntlet_full_rubric_passes():
    """Check that an ideal gauntlet episode is graded at (approximately) 1.0.

    The state has full completion, four correct scans, one recharge, one rest
    event, 23.0 money earned, and zero for every failure counter except
    ``invalid_actions`` (which is 2).
    """
    episode = {
        "completion_ratio": 1.0,
        "correct_scans": 4,
        "wrong_picks": 0,
        # Two invalid actions are still expected to yield a full score.
        "invalid_actions": 2,
        "recharges": 1,
        "battery_depletion_events": 0,
        "stamina_depletion_events": 0,
        "rest_events": 1,
        "obstacle_collisions": 0,
        "overweight_attempts": 0,
        "money_earned": 23.0,
        "money_lost": 0.0,
        "action_history": ["scan_bin", "pick_item", "pack_item", "recharge", "rest"],
    }
    state = _make_state("gauntlet", **episode)
    score = grade_gauntlet(state)
    assert score == pytest.approx(1.0)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def test_grade_gauntlet_partial_scores_lower():
    """Check that a flawed gauntlet episode is graded strictly between 0 and 1.

    The order is completed, but the state records depletion events,
    collisions, an overweight attempt, lost money, five invalid actions, and
    no recharge.
    """
    episode = {
        "completion_ratio": 1.0,
        "correct_scans": 2,
        "wrong_picks": 0,
        "invalid_actions": 5,
        "recharges": 0,
        "battery_depletion_events": 1,
        "stamina_depletion_events": 1,
        "obstacle_collisions": 2,
        "overweight_attempts": 1,
        "money_earned": 10.0,
        "money_lost": 5.0,
        "action_history": ["scan_bin", "pick_item", "pack_item"],
    }
    state = _make_state("gauntlet", **episode)
    score = grade_gauntlet(state)
    assert 0.0 < score < 1.0
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
# (task_id, grader) pairs: for each task id, the task-specific grader whose
# result grade_episode is expected to match.
ALL_GRADERS = [
    ("easy_single_pick", grade_easy),
    ("medium_multi_item", grade_medium),
    ("hard_restock_priority", grade_hard),
    ("obstacle_course", grade_obstacle_course),
    ("heavy_lifting", grade_heavy_lifting),
    ("stamina_run", grade_stamina_run),
    ("budget_run", grade_budget_run),
    ("gauntlet", grade_gauntlet),
]


@pytest.mark.parametrize("task_id,grader", ALL_GRADERS)
def test_grade_episode_dispatches_correctly(task_id, grader):
    """grade_episode must return the same score as the task's own grader."""
    state = _make_state(task_id, completion_ratio=0.5, action_history=["pick_item"])
    dispatched_score = grade_episode(state)
    direct_score = grader(state)
    assert dispatched_score == direct_score
|
tests/test_server.py
CHANGED
|
@@ -18,7 +18,16 @@ from fastapi.testclient import TestClient
|
|
| 18 |
from grid_env.Server.app import app
|
| 19 |
|
| 20 |
|
| 21 |
-
TASK_IDS = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
ALL_ACTIONS = [
|
| 23 |
"turn_left",
|
| 24 |
"turn_right",
|
|
@@ -27,6 +36,7 @@ ALL_ACTIONS = [
|
|
| 27 |
"pick_item",
|
| 28 |
"pack_item",
|
| 29 |
"recharge",
|
|
|
|
| 30 |
"wait",
|
| 31 |
]
|
| 32 |
|
|
@@ -78,7 +88,7 @@ def test_tasks_has_tasks_key(client):
|
|
| 78 |
assert isinstance(body["tasks"], list)
|
| 79 |
|
| 80 |
|
| 81 |
-
def
|
| 82 |
body = client.get("/tasks").json()
|
| 83 |
ids = {t["task_id"] for t in body["tasks"]}
|
| 84 |
assert ids == set(TASK_IDS)
|
|
|
|
| 18 |
from grid_env.Server.app import app
|
| 19 |
|
| 20 |
|
| 21 |
+
TASK_IDS = [
|
| 22 |
+
"easy_single_pick",
|
| 23 |
+
"medium_multi_item",
|
| 24 |
+
"hard_restock_priority",
|
| 25 |
+
"obstacle_course",
|
| 26 |
+
"heavy_lifting",
|
| 27 |
+
"stamina_run",
|
| 28 |
+
"budget_run",
|
| 29 |
+
"gauntlet",
|
| 30 |
+
]
|
| 31 |
ALL_ACTIONS = [
|
| 32 |
"turn_left",
|
| 33 |
"turn_right",
|
|
|
|
| 36 |
"pick_item",
|
| 37 |
"pack_item",
|
| 38 |
"recharge",
|
| 39 |
+
"rest",
|
| 40 |
"wait",
|
| 41 |
]
|
| 42 |
|
|
|
|
| 88 |
assert isinstance(body["tasks"], list)
|
| 89 |
|
| 90 |
|
| 91 |
+
def test_tasks_returns_all(client):
    """GET /tasks must list exactly the task ids in TASK_IDS (no more, no fewer)."""
    payload = client.get("/tasks").json()
    returned_ids = set()
    for entry in payload["tasks"]:
        returned_ids.add(entry["task_id"])
    assert returned_ids == set(TASK_IDS)
|
tests/test_tasks.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
"""
|
| 2 |
-
Tests for task definitions: presence of all
|
| 3 |
and grader callability.
|
| 4 |
"""
|
| 5 |
|
|
@@ -9,11 +9,20 @@ from grid_env.graders import grade_episode
|
|
| 9 |
from grid_env.env import WarehouseFulfillmentEnv
|
| 10 |
|
| 11 |
|
| 12 |
-
EXPECTED_TASK_IDS = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
-
def
|
| 16 |
-
assert len(TASKS) ==
|
| 17 |
|
| 18 |
|
| 19 |
def test_all_expected_task_ids_present():
|
|
@@ -24,7 +33,7 @@ def test_all_expected_task_ids_present():
|
|
| 24 |
def test_task_has_required_fields(task_id):
|
| 25 |
task = get_task(task_id)
|
| 26 |
assert task.task_id == task_id
|
| 27 |
-
assert task.difficulty in {"easy", "medium", "hard"}
|
| 28 |
assert task.max_steps > 0
|
| 29 |
assert task.battery_capacity > 0
|
| 30 |
assert len(task.bins) > 0
|
|
@@ -69,6 +78,24 @@ def test_grader_callable_returns_float_in_range(task_id):
|
|
| 69 |
assert 0.0 <= score <= 1.0
|
| 70 |
|
| 71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
def test_get_task_raises_on_unknown_id():
|
| 73 |
with pytest.raises(KeyError, match="Unknown task_id"):
|
| 74 |
get_task("does_not_exist")
|
|
|
|
| 1 |
"""
|
| 2 |
+
Tests for task definitions: presence of all tasks, structural validity,
|
| 3 |
and grader callability.
|
| 4 |
"""
|
| 5 |
|
|
|
|
| 9 |
from grid_env.env import WarehouseFulfillmentEnv
|
| 10 |
|
| 11 |
|
| 12 |
+
# Task ids that TASKS is expected to contain.
EXPECTED_TASK_IDS = {
    "easy_single_pick",
    "medium_multi_item",
    "hard_restock_priority",
    "obstacle_course",
    "heavy_lifting",
    "stamina_run",
    "budget_run",
    "gauntlet",
}


def test_expected_task_count():
    """TASKS must hold exactly as many entries as EXPECTED_TASK_IDS."""
    expected_total = len(EXPECTED_TASK_IDS)
    assert len(TASKS) == expected_total
|
| 26 |
|
| 27 |
|
| 28 |
def test_all_expected_task_ids_present():
|
|
|
|
| 33 |
def test_task_has_required_fields(task_id):
|
| 34 |
task = get_task(task_id)
|
| 35 |
assert task.task_id == task_id
|
| 36 |
+
assert task.difficulty in {"easy", "medium", "hard", "expert"}
|
| 37 |
assert task.max_steps > 0
|
| 38 |
assert task.battery_capacity > 0
|
| 39 |
assert len(task.bins) > 0
|
|
|
|
| 78 |
assert 0.0 <= score <= 1.0
|
| 79 |
|
| 80 |
|
| 81 |
+
# sorted() rather than list(): iterating a set of strings yields an order that
# changes with hash randomization, which makes the collected test order
# unstable between runs (and between pytest-xdist workers).
@pytest.mark.parametrize("task_id", sorted(EXPECTED_TASK_IDS))
def test_obstacles_do_not_overlap_bins_or_stations(task_id):
    """Obstacles must not overlap with bins, stations, or agent start.

    For the given task, collects the obstacle cells, the bin positions, and
    the fixed positions (pack station, charger, dock, agent start, plus the
    rest position when the task defines one) and asserts that the obstacle
    set shares no cell with either group.
    """
    task = get_task(task_id)
    # Normalise obstacle coordinates to tuples like every other position here,
    # so membership compares like with like and list-typed cells stay hashable.
    obstacle_set = {tuple(cell) for cell in task.obstacles}
    bin_positions = {tuple(b.position) for b in task.bins}
    fixed_positions = {
        tuple(task.pack_station_position),
        tuple(task.charger_position),
        tuple(task.dock_position),
        tuple(task.agent_start),
    }
    # The rest position is only included when the task sets one.
    if task.rest_position:
        fixed_positions.add(tuple(task.rest_position))
    assert obstacle_set.isdisjoint(bin_positions), f"Obstacle overlaps bin in {task_id}"
    assert obstacle_set.isdisjoint(fixed_positions), f"Obstacle overlaps station in {task_id}"
|
| 97 |
+
|
| 98 |
+
|
| 99 |
def test_get_task_raises_on_unknown_id():
    """get_task must raise KeyError mentioning 'Unknown task_id' for an unregistered id."""
    bogus_id = "does_not_exist"
    with pytest.raises(KeyError, match="Unknown task_id"):
        get_task(bogus_id)
get_task("does_not_exist")
|