sohambose98 committed on
Commit
eaa79f0
·
1 Parent(s): 3ac18bb

updated the tests and graders

Browse files
grid_env/Server/__pycache__/__init__.cpython-313.pyc CHANGED
Binary files a/grid_env/Server/__pycache__/__init__.cpython-313.pyc and b/grid_env/Server/__pycache__/__init__.cpython-313.pyc differ
 
grid_env/Server/__pycache__/app.cpython-313.pyc CHANGED
Binary files a/grid_env/Server/__pycache__/app.cpython-313.pyc and b/grid_env/Server/__pycache__/app.cpython-313.pyc differ
 
grid_env/Server/__pycache__/warehouse_env.cpython-313.pyc CHANGED
Binary files a/grid_env/Server/__pycache__/warehouse_env.cpython-313.pyc and b/grid_env/Server/__pycache__/warehouse_env.cpython-313.pyc differ
 
grid_env/__init__.py CHANGED
@@ -1,6 +1,16 @@
1
  from .client import WarehouseEnvClient
2
  from .env import WarehouseFulfillmentEnv, available_tasks
3
- from .graders import grade_easy, grade_episode, grade_hard, grade_medium
 
 
 
 
 
 
 
 
 
 
4
  from .models import (
5
  BaselineCommand,
6
  BinState,
@@ -30,8 +40,13 @@ __all__ = [
30
  "WarehouseReward",
31
  "WarehouseState",
32
  "available_tasks",
 
33
  "grade_easy",
34
  "grade_episode",
 
35
  "grade_hard",
 
36
  "grade_medium",
 
 
37
  ]
 
1
  from .client import WarehouseEnvClient
2
  from .env import WarehouseFulfillmentEnv, available_tasks
3
+ from .graders import (
4
+ grade_budget_run,
5
+ grade_easy,
6
+ grade_episode,
7
+ grade_gauntlet,
8
+ grade_hard,
9
+ grade_heavy_lifting,
10
+ grade_medium,
11
+ grade_obstacle_course,
12
+ grade_stamina_run,
13
+ )
14
  from .models import (
15
  BaselineCommand,
16
  BinState,
 
40
  "WarehouseReward",
41
  "WarehouseState",
42
  "available_tasks",
43
+ "grade_budget_run",
44
  "grade_easy",
45
  "grade_episode",
46
+ "grade_gauntlet",
47
  "grade_hard",
48
+ "grade_heavy_lifting",
49
  "grade_medium",
50
+ "grade_obstacle_course",
51
+ "grade_stamina_run",
52
  ]
grid_env/__pycache__/__init__.cpython-313.pyc CHANGED
Binary files a/grid_env/__pycache__/__init__.cpython-313.pyc and b/grid_env/__pycache__/__init__.cpython-313.pyc differ
 
grid_env/__pycache__/baseline.cpython-313.pyc CHANGED
Binary files a/grid_env/__pycache__/baseline.cpython-313.pyc and b/grid_env/__pycache__/baseline.cpython-313.pyc differ
 
grid_env/__pycache__/client.cpython-313.pyc CHANGED
Binary files a/grid_env/__pycache__/client.cpython-313.pyc and b/grid_env/__pycache__/client.cpython-313.pyc differ
 
grid_env/__pycache__/env.cpython-313.pyc CHANGED
Binary files a/grid_env/__pycache__/env.cpython-313.pyc and b/grid_env/__pycache__/env.cpython-313.pyc differ
 
grid_env/__pycache__/graders.cpython-313.pyc CHANGED
Binary files a/grid_env/__pycache__/graders.cpython-313.pyc and b/grid_env/__pycache__/graders.cpython-313.pyc differ
 
grid_env/__pycache__/models.cpython-313.pyc CHANGED
Binary files a/grid_env/__pycache__/models.cpython-313.pyc and b/grid_env/__pycache__/models.cpython-313.pyc differ
 
grid_env/__pycache__/tasks.cpython-313.pyc CHANGED
Binary files a/grid_env/__pycache__/tasks.cpython-313.pyc and b/grid_env/__pycache__/tasks.cpython-313.pyc differ
 
grid_env/baseline.py CHANGED
@@ -21,7 +21,7 @@ except ImportError:
21
 
22
  SYSTEM_PROMPT = """You control a warehouse fulfillment robot.
23
  Return exactly one JSON object with:
24
- - command: one of turn_left, turn_right, move_forward, scan_bin, pick_item, pack_item, recharge, wait
25
  - rationale: a short sentence
26
 
27
  Objective:
@@ -29,6 +29,14 @@ Objective:
29
  - Use scans before picks when the task requires verified bins.
30
  - Recharge before battery depletion if needed.
31
  - Avoid invalid actions and unnecessary wandering.
 
 
 
 
 
 
 
 
32
  """
33
 
34
 
 
21
 
22
  SYSTEM_PROMPT = """You control a warehouse fulfillment robot.
23
  Return exactly one JSON object with:
24
+ - command: one of turn_left, turn_right, move_forward, scan_bin, pick_item, pack_item, recharge, rest, wait
25
  - rationale: a short sentence
26
 
27
  Objective:
 
29
  - Use scans before picks when the task requires verified bins.
30
  - Recharge before battery depletion if needed.
31
  - Avoid invalid actions and unnecessary wandering.
32
+
33
+ Advanced mechanics (active on harder tasks):
34
+ - Obstacles: some cells are impassable. If front_cell says "obstacle", turn to find another route.
35
+ - Item weight: items have weight. If an item exceeds your carry capacity, you cannot pick it.
36
+ Heavier items drain more battery while moving.
37
+ - Stamina: movement costs stamina. When stamina hits 0, movement costs double battery.
38
+ Use the "rest" action at the rest area to restore stamina.
39
+ - Money: packing correct items earns money; wrong packs lose money. Hit the profit target if set.
40
  """
41
 
42
 
grid_env/env.py CHANGED
@@ -40,6 +40,7 @@ class WarehouseFulfillmentEnv:
40
  "pick_item",
41
  "pack_item",
42
  "recharge",
 
43
  "wait",
44
  ]
45
 
@@ -100,6 +101,8 @@ class WarehouseFulfillmentEnv:
100
  reward_value, narrative = self._pack_item(reward_value)
101
  elif command == "recharge":
102
  reward_value, narrative = self._recharge(reward_value)
 
 
103
  elif command == "wait":
104
  reward_value -= 0.01
105
  self.metrics.invalid_actions += 1
@@ -110,8 +113,8 @@ class WarehouseFulfillmentEnv:
110
  narrative = f"Unknown action: {command}."
111
 
112
  self.action_history.append(command)
113
- self.done = self._all_order_lines_complete() or self.step_count >= self.task.max_steps
114
- self.success = self._all_order_lines_complete()
115
  if self.success:
116
  reward_value += 0.50
117
  narrative = "Order fully packed and ready for dispatch."
@@ -143,11 +146,17 @@ class WarehouseFulfillmentEnv:
143
  agent_position=self.agent_position,
144
  heading=self.heading,
145
  carrying=self.carrying,
 
146
  battery_level=self.battery_level,
147
  battery_capacity=self.task.battery_capacity,
 
 
 
 
148
  dock_position=self.task.dock_position,
149
  pack_station_position=self.task.pack_station_position,
150
  charger_position=self.task.charger_position,
 
151
  bins=[self._clone_bin(bin_state) for bin_state in self.bins],
152
  order=[self._clone_order_line(line) for line in self.order],
153
  packed_order=[self._clone_order_line(line) for line in self.packed_order],
@@ -166,6 +175,9 @@ class WarehouseFulfillmentEnv:
166
  self.heading = self.task.agent_heading
167
  self.battery_level = self.task.battery_capacity
168
  self.carrying: Optional[str] = None
 
 
 
169
  self.step_count = 0
170
  self.done = False
171
  self.success = False
@@ -206,6 +218,8 @@ class WarehouseFulfillmentEnv:
206
  front = self._front_position()
207
  if not self._in_bounds(front):
208
  return "wall"
 
 
209
  front_bin = self._front_bin()
210
  if front_bin:
211
  return f"bin {front_bin.bin_id} ({front_bin.sku})"
@@ -215,18 +229,35 @@ class WarehouseFulfillmentEnv:
215
  return "charger"
216
  if front == self.task.dock_position:
217
  return "dock"
 
 
218
  return "aisle"
219
 
220
  def _move_forward(self, reward: float) -> Tuple[float, str]:
221
  next_pos = self._front_position()
 
 
 
 
 
 
 
222
  if not self._in_bounds(next_pos) or self._occupied(next_pos):
223
  self.metrics.invalid_actions += 1
224
  self._consume_battery(1)
225
  return reward - 0.08, "Forward move blocked by warehouse infrastructure."
226
 
 
 
 
 
 
 
 
227
  self.agent_position = next_pos
228
  self.metrics.distance_travelled += 1
229
- self._consume_battery(2)
 
230
  return reward, f"Moved to aisle cell {self.agent_position}."
231
 
232
  def _scan_bin(self, reward: float) -> Tuple[float, str]:
@@ -260,13 +291,22 @@ class WarehouseFulfillmentEnv:
260
  self.metrics.invalid_actions += 1
261
  self._consume_battery(1)
262
  return reward - 0.10, f"Bin {bin_state.bin_id} is empty."
 
 
 
 
 
 
 
 
263
 
264
  self._consume_battery(1)
265
  bin_state.quantity -= 1
266
  self.carrying = bin_state.sku
 
267
  if self._remaining_quantity(bin_state.sku) > 0:
268
  self.metrics.correct_picks += 1
269
- return reward + 0.20, f"Picked {bin_state.sku} from {bin_state.bin_id}."
270
 
271
  self.metrics.wrong_picks += 1
272
  return reward - 0.18, f"Picked {bin_state.sku}, which is not needed now."
@@ -285,18 +325,28 @@ class WarehouseFulfillmentEnv:
285
  remaining = self._remaining_quantity(self.carrying)
286
  if remaining <= 0:
287
  item = self.carrying
 
288
  self.carrying = None
 
289
  self.metrics.wrong_picks += 1
 
 
 
290
  return reward - 0.15, f"Packed extra unit of {item}; order did not require it."
291
 
 
292
  for packed_line in self.packed_order:
293
  if packed_line.sku == self.carrying:
294
  packed_line.quantity += 1
295
  break
296
  item = self.carrying
297
  self.carrying = None
 
298
  self.metrics.correct_packs += 1
299
- return reward + 0.35, f"Packed {item} at the station."
 
 
 
300
 
301
  def _recharge(self, reward: float) -> Tuple[float, str]:
302
  if self._front_position() != self.task.charger_position:
@@ -312,6 +362,24 @@ class WarehouseFulfillmentEnv:
312
  self.metrics.recharges += 1
313
  return reward + benefit, "Battery restored to full capacity."
314
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
  def _build_observation(self, narrative: str) -> WarehouseObservation:
316
  nearby_bins = []
317
  for bin_state in self.bins:
@@ -334,7 +402,10 @@ class WarehouseFulfillmentEnv:
334
  heading=self.heading,
335
  front_cell=self._front_cell_label(),
336
  carrying=self.carrying,
 
337
  battery_level=self.battery_level,
 
 
338
  visible_bins=nearby_bins,
339
  pending_order=pending,
340
  packed_order=packed,
@@ -363,11 +434,43 @@ class WarehouseFulfillmentEnv:
363
  if previous > 0 and self.battery_level == 0:
364
  self.metrics.battery_depletion_events += 1
365
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
366
  def _in_bounds(self, position: Tuple[int, int]) -> bool:
367
  return 0 <= position[0] < self.grid_size[0] and 0 <= position[1] < self.grid_size[1]
368
 
369
  def _occupied(self, position: Tuple[int, int]) -> bool:
370
- if position in {self.task.pack_station_position, self.task.charger_position, self.task.dock_position}:
 
 
 
 
 
371
  return True
372
  return any(bin_state.position == position for bin_state in self.bins)
373
 
 
40
  "pick_item",
41
  "pack_item",
42
  "recharge",
43
+ "rest",
44
  "wait",
45
  ]
46
 
 
101
  reward_value, narrative = self._pack_item(reward_value)
102
  elif command == "recharge":
103
  reward_value, narrative = self._recharge(reward_value)
104
+ elif command == "rest":
105
+ reward_value, narrative = self._rest(reward_value)
106
  elif command == "wait":
107
  reward_value -= 0.01
108
  self.metrics.invalid_actions += 1
 
113
  narrative = f"Unknown action: {command}."
114
 
115
  self.action_history.append(command)
116
+ self.done = self._is_episode_complete() or self.step_count >= self.task.max_steps
117
+ self.success = self._is_episode_complete()
118
  if self.success:
119
  reward_value += 0.50
120
  narrative = "Order fully packed and ready for dispatch."
 
146
  agent_position=self.agent_position,
147
  heading=self.heading,
148
  carrying=self.carrying,
149
+ carrying_weight=self.carrying_weight,
150
  battery_level=self.battery_level,
151
  battery_capacity=self.task.battery_capacity,
152
+ stamina_level=self.stamina_level,
153
+ stamina_capacity=self.task.stamina_capacity,
154
+ money=round(self.money, 2),
155
+ profit_target=self.task.profit_target,
156
  dock_position=self.task.dock_position,
157
  pack_station_position=self.task.pack_station_position,
158
  charger_position=self.task.charger_position,
159
+ obstacles=list(self.task.obstacles),
160
  bins=[self._clone_bin(bin_state) for bin_state in self.bins],
161
  order=[self._clone_order_line(line) for line in self.order],
162
  packed_order=[self._clone_order_line(line) for line in self.packed_order],
 
175
  self.heading = self.task.agent_heading
176
  self.battery_level = self.task.battery_capacity
177
  self.carrying: Optional[str] = None
178
+ self.carrying_weight: int = 0
179
+ self.stamina_level: int = self.task.stamina_capacity
180
+ self.money: float = 0.0
181
  self.step_count = 0
182
  self.done = False
183
  self.success = False
 
218
  front = self._front_position()
219
  if not self._in_bounds(front):
220
  return "wall"
221
+ if self._is_obstacle(front):
222
+ return "obstacle"
223
  front_bin = self._front_bin()
224
  if front_bin:
225
  return f"bin {front_bin.bin_id} ({front_bin.sku})"
 
229
  return "charger"
230
  if front == self.task.dock_position:
231
  return "dock"
232
+ if self.task.rest_position and front == self.task.rest_position:
233
+ return "rest area"
234
  return "aisle"
235
 
236
  def _move_forward(self, reward: float) -> Tuple[float, str]:
237
  next_pos = self._front_position()
238
+
239
+ if self._is_obstacle(next_pos):
240
+ self.metrics.obstacle_collisions += 1
241
+ self.metrics.invalid_actions += 1
242
+ self._consume_battery(1)
243
+ return reward - 0.12, "Blocked by an obstacle! Find another route."
244
+
245
  if not self._in_bounds(next_pos) or self._occupied(next_pos):
246
  self.metrics.invalid_actions += 1
247
  self._consume_battery(1)
248
  return reward - 0.08, "Forward move blocked by warehouse infrastructure."
249
 
250
+ battery_cost = 2
251
+ weight_penalty = self.carrying_weight if self.carrying_weight > 1 else 0
252
+ battery_cost += weight_penalty
253
+
254
+ if self._has_stamina() and self.stamina_level <= 0:
255
+ battery_cost *= 2
256
+
257
  self.agent_position = next_pos
258
  self.metrics.distance_travelled += 1
259
+ self._consume_battery(battery_cost)
260
+ self._consume_stamina(self.task.stamina_move_cost)
261
  return reward, f"Moved to aisle cell {self.agent_position}."
262
 
263
  def _scan_bin(self, reward: float) -> Tuple[float, str]:
 
291
  self.metrics.invalid_actions += 1
292
  self._consume_battery(1)
293
  return reward - 0.10, f"Bin {bin_state.bin_id} is empty."
294
+ if bin_state.weight > self.task.carry_capacity:
295
+ self.metrics.overweight_attempts += 1
296
+ self.metrics.invalid_actions += 1
297
+ self._consume_battery(1)
298
+ return reward - 0.12, (
299
+ f"Item {bin_state.sku} weighs {bin_state.weight} but carry capacity "
300
+ f"is {self.task.carry_capacity}. Too heavy!"
301
+ )
302
 
303
  self._consume_battery(1)
304
  bin_state.quantity -= 1
305
  self.carrying = bin_state.sku
306
+ self.carrying_weight = bin_state.weight
307
  if self._remaining_quantity(bin_state.sku) > 0:
308
  self.metrics.correct_picks += 1
309
+ return reward + 0.20, f"Picked {bin_state.sku} (weight {bin_state.weight}) from {bin_state.bin_id}."
310
 
311
  self.metrics.wrong_picks += 1
312
  return reward - 0.18, f"Picked {bin_state.sku}, which is not needed now."
 
325
  remaining = self._remaining_quantity(self.carrying)
326
  if remaining <= 0:
327
  item = self.carrying
328
+ item_value = self._item_value(item)
329
  self.carrying = None
330
+ self.carrying_weight = 0
331
  self.metrics.wrong_picks += 1
332
+ if item_value > 0:
333
+ self.money -= item_value * 0.5
334
+ self.metrics.money_lost += item_value * 0.5
335
  return reward - 0.15, f"Packed extra unit of {item}; order did not require it."
336
 
337
+ item_value = self._item_value(self.carrying)
338
  for packed_line in self.packed_order:
339
  if packed_line.sku == self.carrying:
340
  packed_line.quantity += 1
341
  break
342
  item = self.carrying
343
  self.carrying = None
344
+ self.carrying_weight = 0
345
  self.metrics.correct_packs += 1
346
+ if item_value > 0:
347
+ self.money += item_value
348
+ self.metrics.money_earned += item_value
349
+ return reward + 0.35, f"Packed {item} at the station. (+${item_value:.2f})"
350
 
351
  def _recharge(self, reward: float) -> Tuple[float, str]:
352
  if self._front_position() != self.task.charger_position:
 
362
  self.metrics.recharges += 1
363
  return reward + benefit, "Battery restored to full capacity."
364
 
365
+ def _rest(self, reward: float) -> Tuple[float, str]:
366
+ if not self._has_stamina():
367
+ self.metrics.invalid_actions += 1
368
+ return reward - 0.03, "This task has no stamina mechanic."
369
+
370
+ if self.task.rest_position and self._front_position() != self.task.rest_position:
371
+ self.metrics.invalid_actions += 1
372
+ self._consume_battery(1)
373
+ return reward - 0.08, "Rest action requires facing the rest area."
374
+
375
+ if self.stamina_level >= self.task.stamina_capacity:
376
+ return reward - 0.03, "Stamina already full."
377
+
378
+ benefit = 0.06 if self.stamina_level <= self.task.stamina_capacity // 4 else -0.02
379
+ self.stamina_level = self.task.stamina_capacity
380
+ self.metrics.rest_events += 1
381
+ return reward + benefit, "Stamina restored to full capacity."
382
+
383
  def _build_observation(self, narrative: str) -> WarehouseObservation:
384
  nearby_bins = []
385
  for bin_state in self.bins:
 
402
  heading=self.heading,
403
  front_cell=self._front_cell_label(),
404
  carrying=self.carrying,
405
+ carrying_weight=self.carrying_weight,
406
  battery_level=self.battery_level,
407
+ stamina_level=self.stamina_level,
408
+ money=round(self.money, 2),
409
  visible_bins=nearby_bins,
410
  pending_order=pending,
411
  packed_order=packed,
 
434
  if previous > 0 and self.battery_level == 0:
435
  self.metrics.battery_depletion_events += 1
436
 
437
+ def _consume_stamina(self, amount: int) -> None:
438
+ if not self._has_stamina():
439
+ return
440
+ previous = self.stamina_level
441
+ self.stamina_level = max(0, self.stamina_level - amount)
442
+ if previous > 0 and self.stamina_level == 0:
443
+ self.metrics.stamina_depletion_events += 1
444
+
445
+ def _has_stamina(self) -> bool:
446
+ return self.task.stamina_capacity > 0
447
+
448
+ def _is_obstacle(self, position: Tuple[int, int]) -> bool:
449
+ return tuple(position) in {tuple(o) for o in self.task.obstacles}
450
+
451
+ def _item_value(self, sku: str) -> float:
452
+ for bin_state in self.task.bins:
453
+ if bin_state.sku == sku:
454
+ return bin_state.value
455
+ return 0.0
456
+
457
+ def _is_episode_complete(self) -> bool:
458
+ if not self._all_order_lines_complete():
459
+ return False
460
+ if self.task.profit_target > 0 and self.money < self.task.profit_target:
461
+ return False
462
+ return True
463
+
464
  def _in_bounds(self, position: Tuple[int, int]) -> bool:
465
  return 0 <= position[0] < self.grid_size[0] and 0 <= position[1] < self.grid_size[1]
466
 
467
  def _occupied(self, position: Tuple[int, int]) -> bool:
468
+ fixed = {self.task.pack_station_position, self.task.charger_position, self.task.dock_position}
469
+ if self.task.rest_position:
470
+ fixed.add(self.task.rest_position)
471
+ if position in fixed:
472
+ return True
473
+ if self._is_obstacle(position):
474
  return True
475
  return any(bin_state.position == position for bin_state in self.bins)
476
 
grid_env/graders.py CHANGED
@@ -224,6 +224,12 @@ def _build_action_log(state: WarehouseState) -> List[Dict[str, Any]]:
224
  "recharges": metrics.recharges,
225
  "battery_depletion_events": metrics.battery_depletion_events,
226
  "distance_travelled": metrics.distance_travelled,
 
 
 
 
 
 
227
  },
228
  "result": "",
229
  }
@@ -255,11 +261,40 @@ def grade_hard(state: WarehouseState) -> float:
255
  return _grade_task("hard_restock_priority", state)
256
 
257
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
  def grade_episode(state: WarehouseState) -> float:
259
- if state.task_id == "easy_single_pick":
260
- return grade_easy(state)
261
- if state.task_id == "medium_multi_item":
262
- return grade_medium(state)
263
- if state.task_id == "hard_restock_priority":
264
- return grade_hard(state)
265
- raise KeyError(f"No grader for task_id: {state.task_id}")
 
224
  "recharges": metrics.recharges,
225
  "battery_depletion_events": metrics.battery_depletion_events,
226
  "distance_travelled": metrics.distance_travelled,
227
+ "stamina_depletion_events": metrics.stamina_depletion_events,
228
+ "rest_events": metrics.rest_events,
229
+ "obstacle_collisions": metrics.obstacle_collisions,
230
+ "money_earned": metrics.money_earned,
231
+ "money_lost": metrics.money_lost,
232
+ "overweight_attempts": metrics.overweight_attempts,
233
  },
234
  "result": "",
235
  }
 
261
  return _grade_task("hard_restock_priority", state)
262
 
263
 
264
+ def grade_obstacle_course(state: WarehouseState) -> float:
265
+ return _grade_task("obstacle_course", state)
266
+
267
+
268
+ def grade_heavy_lifting(state: WarehouseState) -> float:
269
+ return _grade_task("heavy_lifting", state)
270
+
271
+
272
+ def grade_stamina_run(state: WarehouseState) -> float:
273
+ return _grade_task("stamina_run", state)
274
+
275
+
276
+ def grade_budget_run(state: WarehouseState) -> float:
277
+ return _grade_task("budget_run", state)
278
+
279
+
280
+ def grade_gauntlet(state: WarehouseState) -> float:
281
+ return _grade_task("gauntlet", state)
282
+
283
+
284
+ _GRADER_DISPATCH = {
285
+ "easy_single_pick": grade_easy,
286
+ "medium_multi_item": grade_medium,
287
+ "hard_restock_priority": grade_hard,
288
+ "obstacle_course": grade_obstacle_course,
289
+ "heavy_lifting": grade_heavy_lifting,
290
+ "stamina_run": grade_stamina_run,
291
+ "budget_run": grade_budget_run,
292
+ "gauntlet": grade_gauntlet,
293
+ }
294
+
295
+
296
  def grade_episode(state: WarehouseState) -> float:
297
+ grader = _GRADER_DISPATCH.get(state.task_id)
298
+ if grader is None:
299
+ raise KeyError(f"No grader for task_id: {state.task_id}")
300
+ return grader(state)
 
 
 
grid_env/models.py CHANGED
@@ -62,6 +62,7 @@ Command = Literal[
62
  "pick_item",
63
  "pack_item",
64
  "recharge",
 
65
  "wait",
66
  ]
67
 
@@ -84,11 +85,13 @@ class BinState(OpenEnvModel):
84
  position: Position
85
  sku: str
86
  quantity: int
 
 
87
 
88
 
89
  class TaskDefinition(OpenEnvModel):
90
  task_id: str
91
- difficulty: Literal["easy", "medium", "hard"]
92
  title: str
93
  description: str
94
  max_steps: int
@@ -103,6 +106,12 @@ class TaskDefinition(OpenEnvModel):
103
  order: List[OrderLine]
104
  required_scans: List[str] = Field(default_factory=list)
105
  rubric_criteria: List[Dict[str, str]] = Field(default_factory=list)
 
 
 
 
 
 
106
 
107
 
108
  class PendingOrderLine(OpenEnvModel):
@@ -123,7 +132,10 @@ class WarehouseObservation(Observation, OpenEnvModel):
123
  heading: Heading
124
  front_cell: str
125
  carrying: Optional[str]
 
126
  battery_level: int
 
 
127
  visible_bins: List[str]
128
  pending_order: List[PendingOrderLine]
129
  packed_order: List[PackedOrderLine]
@@ -146,6 +158,12 @@ class WarehouseMetrics(OpenEnvModel):
146
  recharges: int = 0
147
  battery_depletion_events: int = 0
148
  distance_travelled: int = 0
 
 
 
 
 
 
149
 
150
 
151
  class WarehouseState(State, OpenEnvModel):
@@ -160,11 +178,17 @@ class WarehouseState(State, OpenEnvModel):
160
  agent_position: Position
161
  heading: Heading
162
  carrying: Optional[str]
 
163
  battery_level: int
164
  battery_capacity: int
 
 
 
 
165
  dock_position: Position
166
  pack_station_position: Position
167
  charger_position: Position
 
168
  bins: List[BinState]
169
  order: List[OrderLine]
170
  packed_order: List[OrderLine]
 
62
  "pick_item",
63
  "pack_item",
64
  "recharge",
65
+ "rest",
66
  "wait",
67
  ]
68
 
 
85
  position: Position
86
  sku: str
87
  quantity: int
88
+ weight: int = 1
89
+ value: float = 0.0
90
 
91
 
92
  class TaskDefinition(OpenEnvModel):
93
  task_id: str
94
+ difficulty: Literal["easy", "medium", "hard", "expert"]
95
  title: str
96
  description: str
97
  max_steps: int
 
106
  order: List[OrderLine]
107
  required_scans: List[str] = Field(default_factory=list)
108
  rubric_criteria: List[Dict[str, str]] = Field(default_factory=list)
109
+ obstacles: List[Position] = Field(default_factory=list)
110
+ carry_capacity: int = 99
111
+ stamina_capacity: int = 0
112
+ stamina_move_cost: int = 1
113
+ rest_position: Optional[Position] = None
114
+ profit_target: float = 0.0
115
 
116
 
117
  class PendingOrderLine(OpenEnvModel):
 
132
  heading: Heading
133
  front_cell: str
134
  carrying: Optional[str]
135
+ carrying_weight: int = 0
136
  battery_level: int
137
+ stamina_level: int = 0
138
+ money: float = 0.0
139
  visible_bins: List[str]
140
  pending_order: List[PendingOrderLine]
141
  packed_order: List[PackedOrderLine]
 
158
  recharges: int = 0
159
  battery_depletion_events: int = 0
160
  distance_travelled: int = 0
161
+ stamina_depletion_events: int = 0
162
+ rest_events: int = 0
163
+ obstacle_collisions: int = 0
164
+ money_earned: float = 0.0
165
+ money_lost: float = 0.0
166
+ overweight_attempts: int = 0
167
 
168
 
169
  class WarehouseState(State, OpenEnvModel):
 
178
  agent_position: Position
179
  heading: Heading
180
  carrying: Optional[str]
181
+ carrying_weight: int = 0
182
  battery_level: int
183
  battery_capacity: int
184
+ stamina_level: int = 0
185
+ stamina_capacity: int = 0
186
+ money: float = 0.0
187
+ profit_target: float = 0.0
188
  dock_position: Position
189
  pack_station_position: Position
190
  charger_position: Position
191
+ obstacles: List[Position] = Field(default_factory=list)
192
  bins: List[BinState]
193
  order: List[OrderLine]
194
  packed_order: List[OrderLine]
grid_env/openv.yaml CHANGED
@@ -28,6 +28,21 @@ tasks:
28
  - id: hard_restock_priority
29
  difficulty: hard
30
  grader: grid_env.graders:grade_hard
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  baseline:
32
  runner: grid_env.baseline:run_baseline
33
  seed: 7
 
28
  - id: hard_restock_priority
29
  difficulty: hard
30
  grader: grid_env.graders:grade_hard
31
+ - id: obstacle_course
32
+ difficulty: medium
33
+ grader: grid_env.graders:grade_obstacle_course
34
+ - id: heavy_lifting
35
+ difficulty: hard
36
+ grader: grid_env.graders:grade_heavy_lifting
37
+ - id: stamina_run
38
+ difficulty: hard
39
+ grader: grid_env.graders:grade_stamina_run
40
+ - id: budget_run
41
+ difficulty: expert
42
+ grader: grid_env.graders:grade_budget_run
43
+ - id: gauntlet
44
+ difficulty: expert
45
+ grader: grid_env.graders:grade_gauntlet
46
  baseline:
47
  runner: grid_env.baseline:run_baseline
48
  seed: 7
grid_env/tasks.py CHANGED
@@ -182,6 +182,393 @@ TASKS: Dict[str, TaskDefinition] = {
182
  },
183
  ],
184
  ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  }
186
 
187
 
 
182
  },
183
  ],
184
  ),
185
+
186
+ # -----------------------------------------------------------------------
187
+ # Task 4: obstacle_course (medium) — obstacles block direct paths
188
+ # -----------------------------------------------------------------------
189
+ "obstacle_course": TaskDefinition(
190
+ task_id="obstacle_course",
191
+ difficulty="medium",
192
+ title="Obstacle-filled aisle navigation",
193
+ description=(
194
+ "Fulfill a two-item order in a warehouse cluttered with fallen crates. "
195
+ "Navigate around obstacles to reach bins, scan them, pick one thermometer "
196
+ "and one bandage kit, then pack both at the station."
197
+ ),
198
+ max_steps=70,
199
+ battery_capacity=40,
200
+ low_battery_threshold=10,
201
+ agent_start=(0, 0),
202
+ agent_heading="E",
203
+ dock_position=(0, 0),
204
+ pack_station_position=(6, 6),
205
+ charger_position=(0, 6),
206
+ bins=[
207
+ BinState(bin_id="A1", position=(2, 1), sku="thermometer", quantity=2),
208
+ BinState(bin_id="B1", position=(4, 1), sku="bandage_kit", quantity=2),
209
+ BinState(bin_id="A2", position=(2, 3), sku="cough_syrup", quantity=2),
210
+ BinState(bin_id="B2", position=(4, 3), sku="pain_relief", quantity=2),
211
+ BinState(bin_id="C1", position=(2, 5), sku="gloves", quantity=2),
212
+ ],
213
+ order=[
214
+ OrderLine(sku="thermometer", quantity=1),
215
+ OrderLine(sku="bandage_kit", quantity=1),
216
+ ],
217
+ required_scans=["A1", "B1"],
218
+ obstacles=[(1, 2), (2, 2), (3, 2), (3, 4), (4, 4), (5, 4)],
219
+ rubric_criteria=[
220
+ {
221
+ "name": "completion",
222
+ "description": "All items packed.",
223
+ "check": "param_at_least:state.completion_ratio=1.0",
224
+ },
225
+ {
226
+ "name": "scans",
227
+ "description": "Scanned both required bins.",
228
+ "check": "param_at_least:state.correct_scans=2",
229
+ },
230
+ {
231
+ "name": "pack_item",
232
+ "description": "Packed items at the station.",
233
+ "check": "tool_used:pack_item",
234
+ },
235
+ {
236
+ "name": "no_obstacle_collisions",
237
+ "description": "Avoided all obstacle collisions.",
238
+ "check": "param_at_most:state.obstacle_collisions=0",
239
+ },
240
+ {
241
+ "name": "no_wrong_picks",
242
+ "description": "No incorrect picks.",
243
+ "check": "param_at_most:state.wrong_picks=0",
244
+ },
245
+ {
246
+ "name": "few_invalid_actions",
247
+ "description": "At most two invalid actions.",
248
+ "check": "param_at_most:state.invalid_actions=2",
249
+ },
250
+ ],
251
+ ),
252
+
253
+ # -----------------------------------------------------------------------
254
+ # Task 5: heavy_lifting (hard) — items have weight, limited carry capacity
255
+ # -----------------------------------------------------------------------
256
+ "heavy_lifting": TaskDefinition(
257
+ task_id="heavy_lifting",
258
+ difficulty="hard",
259
+ title="Heavy-item logistics with weight limits",
260
+ description=(
261
+ "Fulfill a three-item order where items vary in weight (1-4 units). "
262
+ "The agent has a carry capacity of 3 and must choose pickup order wisely. "
263
+ "Heavier items drain more battery while moving. Scan each bin, pick items "
264
+ "that fit within your carry limit, and pack at the station. "
265
+ "The heavy pain_relief (weight 4) cannot be carried — skip it!"
266
+ ),
267
+ max_steps=90,
268
+ battery_capacity=32,
269
+ low_battery_threshold=8,
270
+ agent_start=(1, 1),
271
+ agent_heading="E",
272
+ dock_position=(1, 1),
273
+ pack_station_position=(5, 5),
274
+ charger_position=(1, 5),
275
+ bins=[
276
+ BinState(bin_id="A1", position=(2, 1), sku="thermometer", quantity=3, weight=1),
277
+ BinState(bin_id="A2", position=(2, 3), sku="cough_syrup", quantity=2, weight=2),
278
+ BinState(bin_id="B1", position=(4, 1), sku="bandage_kit", quantity=2, weight=3),
279
+ BinState(bin_id="B2", position=(4, 3), sku="pain_relief", quantity=2, weight=4),
280
+ BinState(bin_id="C1", position=(2, 5), sku="gloves", quantity=3, weight=1),
281
+ ],
282
+ order=[
283
+ OrderLine(sku="thermometer", quantity=1),
284
+ OrderLine(sku="cough_syrup", quantity=1),
285
+ OrderLine(sku="bandage_kit", quantity=1),
286
+ ],
287
+ required_scans=["A1", "A2", "B1"],
288
+ carry_capacity=3,
289
+ rubric_criteria=[
290
+ {
291
+ "name": "completion",
292
+ "description": "All items packed.",
293
+ "check": "param_at_least:state.completion_ratio=1.0",
294
+ },
295
+ {
296
+ "name": "scans",
297
+ "description": "Scanned all three required bins.",
298
+ "check": "param_at_least:state.correct_scans=3",
299
+ },
300
+ {
301
+ "name": "recharge",
302
+ "description": "Recharged at least once.",
303
+ "check": "tool_used:recharge",
304
+ },
305
+ {
306
+ "name": "no_overweight",
307
+ "description": "Never tried to pick an overweight item.",
308
+ "check": "param_at_most:state.overweight_attempts=0",
309
+ },
310
+ {
311
+ "name": "no_battery_depletion",
312
+ "description": "Avoided battery depletion.",
313
+ "check": "param_at_most:state.battery_depletion_events=0",
314
+ },
315
+ {
316
+ "name": "no_wrong_picks",
317
+ "description": "No incorrect picks.",
318
+ "check": "param_at_most:state.wrong_picks=0",
319
+ },
320
+ {
321
+ "name": "few_invalid_actions",
322
+ "description": "At most two invalid actions.",
323
+ "check": "param_at_most:state.invalid_actions=2",
324
+ },
325
+ ],
326
+ ),
327
+
328
+ # -----------------------------------------------------------------------
329
+ # Task 6: stamina_run (hard) — stamina drains on movement
330
+ # -----------------------------------------------------------------------
331
+ "stamina_run": TaskDefinition(
332
+ task_id="stamina_run",
333
+ difficulty="hard",
334
+ title="Endurance run with stamina management",
335
+ description=(
336
+ "Fulfill a two-item order while managing stamina. Every move drains stamina; "
337
+ "when stamina hits zero, movement costs double battery. Rest at the rest area "
338
+ "to restore stamina. Pick one cough syrup and one gloves unit, scan bins, and "
339
+ "pack at the station without running out of energy."
340
+ ),
341
+ max_steps=80,
342
+ battery_capacity=36,
343
+ low_battery_threshold=8,
344
+ agent_start=(0, 0),
345
+ agent_heading="E",
346
+ dock_position=(0, 0),
347
+ pack_station_position=(6, 6),
348
+ charger_position=(0, 6),
349
+ rest_position=(3, 3),
350
+ stamina_capacity=12,
351
+ stamina_move_cost=1,
352
+ bins=[
353
+ BinState(bin_id="A1", position=(2, 1), sku="thermometer", quantity=2),
354
+ BinState(bin_id="A2", position=(2, 3), sku="cough_syrup", quantity=2),
355
+ BinState(bin_id="B1", position=(4, 1), sku="bandage_kit", quantity=2),
356
+ BinState(bin_id="B2", position=(4, 3), sku="pain_relief", quantity=2),
357
+ BinState(bin_id="C1", position=(2, 5), sku="gloves", quantity=2),
358
+ ],
359
+ order=[
360
+ OrderLine(sku="cough_syrup", quantity=1),
361
+ OrderLine(sku="gloves", quantity=1),
362
+ ],
363
+ required_scans=["A2", "C1"],
364
+ rubric_criteria=[
365
+ {
366
+ "name": "completion",
367
+ "description": "All items packed.",
368
+ "check": "param_at_least:state.completion_ratio=1.0",
369
+ },
370
+ {
371
+ "name": "scans",
372
+ "description": "Scanned both required bins.",
373
+ "check": "param_at_least:state.correct_scans=2",
374
+ },
375
+ {
376
+ "name": "rest_used",
377
+ "description": "Used the rest area at least once.",
378
+ "check": "tool_used:rest",
379
+ },
380
+ {
381
+ "name": "no_stamina_depletion",
382
+ "description": "Avoided complete stamina depletion.",
383
+ "check": "param_at_most:state.stamina_depletion_events=0",
384
+ },
385
+ {
386
+ "name": "no_battery_depletion",
387
+ "description": "Avoided battery depletion.",
388
+ "check": "param_at_most:state.battery_depletion_events=0",
389
+ },
390
+ {
391
+ "name": "no_wrong_picks",
392
+ "description": "No incorrect picks.",
393
+ "check": "param_at_most:state.wrong_picks=0",
394
+ },
395
+ ],
396
+ ),
397
+
398
+ # -----------------------------------------------------------------------
399
+ # Task 7: budget_run (expert) — money rewards and profit target
400
+ # -----------------------------------------------------------------------
401
+ "budget_run": TaskDefinition(
402
+ task_id="budget_run",
403
+ difficulty="expert",
404
+ title="Profitable fulfillment under budget pressure",
405
+ description=(
406
+ "Fulfill orders for profit. Each item has a dollar value earned when correctly "
407
+ "packed. Wrong packs lose half the item value. You must reach a profit target "
408
+ "of $15.00 while completing the order. Pick high-value items efficiently: "
409
+ "2 thermometers ($5 each) and 1 bandage kit ($8). Budget-aware decisions matter."
410
+ ),
411
+ max_steps=70,
412
+ battery_capacity=30,
413
+ low_battery_threshold=6,
414
+ agent_start=(1, 1),
415
+ agent_heading="E",
416
+ dock_position=(1, 1),
417
+ pack_station_position=(5, 5),
418
+ charger_position=(1, 5),
419
+ bins=[
420
+ BinState(bin_id="A1", position=(2, 1), sku="thermometer", quantity=3, value=5.0),
421
+ BinState(bin_id="A2", position=(2, 3), sku="cough_syrup", quantity=2, value=3.0),
422
+ BinState(bin_id="B1", position=(4, 1), sku="bandage_kit", quantity=2, value=8.0),
423
+ BinState(bin_id="B2", position=(4, 3), sku="pain_relief", quantity=2, value=4.0),
424
+ BinState(bin_id="C1", position=(2, 5), sku="gloves", quantity=2, value=2.0),
425
+ ],
426
+ order=[
427
+ OrderLine(sku="thermometer", quantity=2),
428
+ OrderLine(sku="bandage_kit", quantity=1),
429
+ ],
430
+ required_scans=["A1", "B1"],
431
+ profit_target=15.0,
432
+ rubric_criteria=[
433
+ {
434
+ "name": "completion",
435
+ "description": "All items packed.",
436
+ "check": "param_at_least:state.completion_ratio=1.0",
437
+ },
438
+ {
439
+ "name": "profit_target",
440
+ "description": "Reached the profit target of $15.",
441
+ "check": "param_at_least:state.money_earned=15.0",
442
+ },
443
+ {
444
+ "name": "scans",
445
+ "description": "Scanned required bins.",
446
+ "check": "param_at_least:state.correct_scans=2",
447
+ },
448
+ {
449
+ "name": "recharge",
450
+ "description": "Recharged at least once.",
451
+ "check": "tool_used:recharge",
452
+ },
453
+ {
454
+ "name": "no_money_lost",
455
+ "description": "No money lost from wrong packs.",
456
+ "check": "param_at_most:state.money_lost=0.0",
457
+ },
458
+ {
459
+ "name": "no_wrong_picks",
460
+ "description": "No incorrect picks.",
461
+ "check": "param_at_most:state.wrong_picks=0",
462
+ },
463
+ {
464
+ "name": "few_invalid_actions",
465
+ "description": "At most one invalid action.",
466
+ "check": "param_at_most:state.invalid_actions=1",
467
+ },
468
+ ],
469
+ ),
470
+
471
+ # -----------------------------------------------------------------------
472
+ # Task 8: gauntlet (expert) — all mechanics combined
473
+ # -----------------------------------------------------------------------
474
+ "gauntlet": TaskDefinition(
475
+ task_id="gauntlet",
476
+ difficulty="expert",
477
+ title="The gauntlet: obstacles, weight, stamina, and profit",
478
+ description=(
479
+ "The ultimate warehouse challenge. Navigate a cluttered warehouse with obstacles, "
480
+ "manage item weights (carry capacity 3), conserve stamina (rest when needed), "
481
+ "earn money for packed items, and hit a $20 profit target. Fulfill a four-item "
482
+ "order: 1 thermometer ($5, wt 1), 1 cough syrup ($6, wt 2), 1 bandage kit ($8, wt 3), "
483
+ "and 1 gloves ($4, wt 1). Recharge battery, rest for stamina, avoid obstacles, "
484
+ "and finish profitable."
485
+ ),
486
+ max_steps=120,
487
+ battery_capacity=28,
488
+ low_battery_threshold=7,
489
+ agent_start=(0, 0),
490
+ agent_heading="S",
491
+ dock_position=(0, 0),
492
+ pack_station_position=(6, 6),
493
+ charger_position=(6, 0),
494
+ rest_position=(0, 6),
495
+ stamina_capacity=10,
496
+ stamina_move_cost=1,
497
+ carry_capacity=3,
498
+ profit_target=20.0,
499
+ obstacles=[(1, 1), (3, 1), (5, 3), (3, 3), (1, 5), (5, 5)],
500
+ bins=[
501
+ BinState(bin_id="A1", position=(2, 0), sku="thermometer", quantity=3, weight=1, value=5.0),
502
+ BinState(bin_id="A2", position=(2, 4), sku="cough_syrup", quantity=2, weight=2, value=6.0),
503
+ BinState(bin_id="B1", position=(4, 0), sku="bandage_kit", quantity=2, weight=3, value=8.0),
504
+ BinState(bin_id="B2", position=(4, 4), sku="pain_relief", quantity=2, weight=4, value=4.0),
505
+ BinState(bin_id="C1", position=(4, 2), sku="gloves", quantity=3, weight=1, value=4.0),
506
+ ],
507
+ order=[
508
+ OrderLine(sku="thermometer", quantity=1),
509
+ OrderLine(sku="cough_syrup", quantity=1),
510
+ OrderLine(sku="bandage_kit", quantity=1),
511
+ OrderLine(sku="gloves", quantity=1),
512
+ ],
513
+ required_scans=["A1", "A2", "B1", "C1"],
514
+ rubric_criteria=[
515
+ {
516
+ "name": "completion",
517
+ "description": "All four items packed.",
518
+ "check": "param_at_least:state.completion_ratio=1.0",
519
+ },
520
+ {
521
+ "name": "profit_target",
522
+ "description": "Reached $20 profit target.",
523
+ "check": "param_at_least:state.money_earned=20.0",
524
+ },
525
+ {
526
+ "name": "scans",
527
+ "description": "Scanned all four required bins.",
528
+ "check": "param_at_least:state.correct_scans=4",
529
+ },
530
+ {
531
+ "name": "recharge",
532
+ "description": "Recharged at least once.",
533
+ "check": "tool_used:recharge",
534
+ },
535
+ {
536
+ "name": "rest_used",
537
+ "description": "Used the rest area at least once.",
538
+ "check": "tool_used:rest",
539
+ },
540
+ {
541
+ "name": "no_obstacle_collisions",
542
+ "description": "Avoided all obstacle collisions.",
543
+ "check": "param_at_most:state.obstacle_collisions=0",
544
+ },
545
+ {
546
+ "name": "no_overweight",
547
+ "description": "Never tried to pick an overweight item.",
548
+ "check": "param_at_most:state.overweight_attempts=0",
549
+ },
550
+ {
551
+ "name": "no_battery_depletion",
552
+ "description": "Avoided battery depletion.",
553
+ "check": "param_at_most:state.battery_depletion_events=0",
554
+ },
555
+ {
556
+ "name": "no_stamina_depletion",
557
+ "description": "Avoided stamina depletion.",
558
+ "check": "param_at_most:state.stamina_depletion_events=0",
559
+ },
560
+ {
561
+ "name": "no_wrong_picks",
562
+ "description": "No incorrect picks.",
563
+ "check": "param_at_most:state.wrong_picks=0",
564
+ },
565
+ {
566
+ "name": "few_invalid_actions",
567
+ "description": "At most two invalid actions.",
568
+ "check": "param_at_most:state.invalid_actions=2",
569
+ },
570
+ ],
571
+ ),
572
  }
573
 
574
 
grid_env/tools.py CHANGED
@@ -18,6 +18,7 @@ _TOOL_DESCRIPTIONS: Dict[str, str] = {
18
  "pick_item": "Pick an item from the bin in front.",
19
  "pack_item": "Pack the carried item at the packing station.",
20
  "recharge": "Recharge the battery at the charging dock.",
 
21
  "wait": "Stay in place and consume time.",
22
  }
23
 
 
18
  "pick_item": "Pick an item from the bin in front.",
19
  "pack_item": "Pack the carried item at the packing station.",
20
  "recharge": "Recharge the battery at the charging dock.",
21
+ "rest": "Rest at the rest area to restore stamina.",
22
  "wait": "Stay in place and consume time.",
23
  }
24
 
openenv.yaml CHANGED
@@ -28,6 +28,21 @@ tasks:
28
  - id: hard_restock_priority
29
  difficulty: hard
30
  grader: grid_env.graders:grade_hard
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  baseline:
32
  runner: grid_env.baseline:run_baseline
33
  seed: 7
 
28
  - id: hard_restock_priority
29
  difficulty: hard
30
  grader: grid_env.graders:grade_hard
31
+ - id: obstacle_course
32
+ difficulty: medium
33
+ grader: grid_env.graders:grade_obstacle_course
34
+ - id: heavy_lifting
35
+ difficulty: hard
36
+ grader: grid_env.graders:grade_heavy_lifting
37
+ - id: stamina_run
38
+ difficulty: hard
39
+ grader: grid_env.graders:grade_stamina_run
40
+ - id: budget_run
41
+ difficulty: expert
42
+ grader: grid_env.graders:grade_budget_run
43
+ - id: gauntlet
44
+ difficulty: expert
45
+ grader: grid_env.graders:grade_gauntlet
46
  baseline:
47
  runner: grid_env.baseline:run_baseline
48
  seed: 7
tests/conftest.py CHANGED
@@ -25,3 +25,38 @@ def env_hard():
25
  env = WarehouseFulfillmentEnv(task_id="hard_restock_priority", seed=7)
26
  env.reset()
27
  return env
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  env = WarehouseFulfillmentEnv(task_id="hard_restock_priority", seed=7)
26
  env.reset()
27
  return env
28
+
29
+
30
@pytest.fixture()
def env_obstacle_course():
    """Return an ``obstacle_course`` environment that has already been reset.

    The seed is pinned to 7, the same value the ``env_hard`` fixture uses.
    """
    warehouse = WarehouseFulfillmentEnv(seed=7, task_id="obstacle_course")
    warehouse.reset()
    return warehouse
35
+
36
+
37
@pytest.fixture()
def env_heavy_lifting():
    """Return a ``heavy_lifting`` environment that has already been reset.

    The seed is pinned to 7, the same value the ``env_hard`` fixture uses.
    """
    warehouse = WarehouseFulfillmentEnv(seed=7, task_id="heavy_lifting")
    warehouse.reset()
    return warehouse
42
+
43
+
44
@pytest.fixture()
def env_stamina_run():
    """Return a ``stamina_run`` environment that has already been reset.

    The seed is pinned to 7, the same value the ``env_hard`` fixture uses.
    """
    warehouse = WarehouseFulfillmentEnv(seed=7, task_id="stamina_run")
    warehouse.reset()
    return warehouse
49
+
50
+
51
@pytest.fixture()
def env_budget_run():
    """Return a ``budget_run`` environment that has already been reset.

    The seed is pinned to 7, the same value the ``env_hard`` fixture uses.
    """
    warehouse = WarehouseFulfillmentEnv(seed=7, task_id="budget_run")
    warehouse.reset()
    return warehouse
56
+
57
+
58
@pytest.fixture()
def env_gauntlet():
    """Return a ``gauntlet`` environment that has already been reset.

    The seed is pinned to 7, the same value the ``env_hard`` fixture uses.
    """
    warehouse = WarehouseFulfillmentEnv(seed=7, task_id="gauntlet")
    warehouse.reset()
    return warehouse
tests/test_baseline_stub.py CHANGED
@@ -14,7 +14,16 @@ from grid_env.graders import grade_episode
14
  from grid_env.models import BaselineCommand
15
 
16
 
17
- TASK_IDS = ["easy_single_pick", "medium_multi_item", "hard_restock_priority"]
 
 
 
 
 
 
 
 
 
18
 
19
  # Cycle of deterministic actions that exercise most code paths without getting stuck.
20
  _ACTION_CYCLE = [
@@ -27,6 +36,7 @@ _ACTION_CYCLE = [
27
  "turn_right",
28
  "move_forward",
29
  "pack_item",
 
30
  "wait",
31
  ]
32
 
 
14
  from grid_env.models import BaselineCommand
15
 
16
 
17
+ TASK_IDS = [
18
+ "easy_single_pick",
19
+ "medium_multi_item",
20
+ "hard_restock_priority",
21
+ "obstacle_course",
22
+ "heavy_lifting",
23
+ "stamina_run",
24
+ "budget_run",
25
+ "gauntlet",
26
+ ]
27
 
28
  # Cycle of deterministic actions that exercise most code paths without getting stuck.
29
  _ACTION_CYCLE = [
 
36
  "turn_right",
37
  "move_forward",
38
  "pack_item",
39
+ "rest",
40
  "wait",
41
  ]
42
 
tests/test_env_smoke.py CHANGED
@@ -1,6 +1,6 @@
1
  """
2
  Smoke tests: environment instantiation, reset, step, and episode termination
3
- for all three task IDs.
4
  """
5
 
6
  import pytest
@@ -9,7 +9,16 @@ from grid_env.graders import grade_episode
9
  from grid_env.models import WarehouseObservation, WarehouseReward
10
 
11
 
12
- TASK_IDS = ["easy_single_pick", "medium_multi_item", "hard_restock_priority"]
 
 
 
 
 
 
 
 
 
13
  ALL_ACTIONS = [
14
  "turn_left",
15
  "turn_right",
@@ -18,6 +27,7 @@ ALL_ACTIONS = [
18
  "pick_item",
19
  "pack_item",
20
  "recharge",
 
21
  "wait",
22
  ]
23
 
@@ -104,8 +114,8 @@ def test_step_after_done_is_safe():
104
  assert done
105
 
106
 
107
- def test_available_tasks_returns_all_three():
108
- """available_tasks() returns exactly the three expected task IDs."""
109
  tasks = available_tasks()
110
  ids = {t["task_id"] for t in tasks}
111
  assert ids == set(TASK_IDS)
 
1
  """
2
  Smoke tests: environment instantiation, reset, step, and episode termination
3
+ for all task IDs.
4
  """
5
 
6
  import pytest
 
9
  from grid_env.models import WarehouseObservation, WarehouseReward
10
 
11
 
12
+ TASK_IDS = [
13
+ "easy_single_pick",
14
+ "medium_multi_item",
15
+ "hard_restock_priority",
16
+ "obstacle_course",
17
+ "heavy_lifting",
18
+ "stamina_run",
19
+ "budget_run",
20
+ "gauntlet",
21
+ ]
22
  ALL_ACTIONS = [
23
  "turn_left",
24
  "turn_right",
 
27
  "pick_item",
28
  "pack_item",
29
  "recharge",
30
+ "rest",
31
  "wait",
32
  ]
33
 
 
114
  assert done
115
 
116
 
117
def test_available_tasks_returns_all():
    """available_tasks() reports exactly the IDs listed in TASK_IDS."""
    reported_ids = set()
    for entry in available_tasks():
        reported_ids.add(entry["task_id"])
    # Set equality: no task may be missing and no unexpected task may appear.
    assert reported_ids == set(TASK_IDS)
tests/test_graders.py CHANGED
@@ -4,7 +4,17 @@ Unit tests for rubric-based graders.
4
 
5
  import pytest
6
 
7
- from grid_env.graders import grade_easy, grade_episode, grade_hard, grade_medium
 
 
 
 
 
 
 
 
 
 
8
  from grid_env.models import BinState, OrderLine, WarehouseMetrics, WarehouseState
9
 
10
 
@@ -21,6 +31,12 @@ def _make_state(
21
  invalid_actions: int = 0,
22
  recharges: int = 0,
23
  battery_depletion_events: int = 0,
 
 
 
 
 
 
24
  action_history: list[str] | None = None,
25
  ) -> WarehouseState:
26
  metrics = WarehouseMetrics(
@@ -32,6 +48,12 @@ def _make_state(
32
  invalid_actions=invalid_actions,
33
  recharges=recharges,
34
  battery_depletion_events=battery_depletion_events,
 
 
 
 
 
 
35
  )
36
  return WarehouseState(
37
  episode_id="test-ep",
@@ -138,14 +160,128 @@ def test_grade_hard_partial_rubric_scores_lower():
138
  assert 0.0 < grade_hard(state) < 1.0
139
 
140
 
141
- @pytest.mark.parametrize(
142
- "task_id,grader",
143
- [
144
- ("easy_single_pick", grade_easy),
145
- ("medium_multi_item", grade_medium),
146
- ("hard_restock_priority", grade_hard),
147
- ],
148
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  def test_grade_episode_dispatches_correctly(task_id, grader):
150
  state = _make_state(task_id, completion_ratio=0.5, action_history=["pick_item"])
151
  assert grade_episode(state) == grader(state)
 
4
 
5
  import pytest
6
 
7
+ from grid_env.graders import (
8
+ grade_budget_run,
9
+ grade_easy,
10
+ grade_episode,
11
+ grade_gauntlet,
12
+ grade_hard,
13
+ grade_heavy_lifting,
14
+ grade_medium,
15
+ grade_obstacle_course,
16
+ grade_stamina_run,
17
+ )
18
  from grid_env.models import BinState, OrderLine, WarehouseMetrics, WarehouseState
19
 
20
 
 
31
  invalid_actions: int = 0,
32
  recharges: int = 0,
33
  battery_depletion_events: int = 0,
34
+ stamina_depletion_events: int = 0,
35
+ rest_events: int = 0,
36
+ obstacle_collisions: int = 0,
37
+ money_earned: float = 0.0,
38
+ money_lost: float = 0.0,
39
+ overweight_attempts: int = 0,
40
  action_history: list[str] | None = None,
41
  ) -> WarehouseState:
42
  metrics = WarehouseMetrics(
 
48
  invalid_actions=invalid_actions,
49
  recharges=recharges,
50
  battery_depletion_events=battery_depletion_events,
51
+ stamina_depletion_events=stamina_depletion_events,
52
+ rest_events=rest_events,
53
+ obstacle_collisions=obstacle_collisions,
54
+ money_earned=money_earned,
55
+ money_lost=money_lost,
56
+ overweight_attempts=overweight_attempts,
57
  )
58
  return WarehouseState(
59
  episode_id="test-ep",
 
160
  assert 0.0 < grade_hard(state) < 1.0
161
 
162
 
163
def test_grade_obstacle_course_full_rubric_passes():
    """A complete, collision-free obstacle_course run is graded 1.0."""
    rubric_inputs = dict(
        completion_ratio=1.0,
        correct_scans=2,
        wrong_picks=0,
        invalid_actions=2,
        obstacle_collisions=0,
        action_history=["scan_bin", "pick_item", "pack_item"],
    )
    clean_run = _make_state("obstacle_course", **rubric_inputs)
    score = grade_obstacle_course(clean_run)
    assert score == pytest.approx(1.0)
174
+
175
+
176
def test_grade_obstacle_course_collision_lowers_score():
    """Three obstacle collisions leave the score strictly between 0 and 1."""
    rubric_inputs = dict(
        completion_ratio=1.0,
        correct_scans=2,
        wrong_picks=0,
        invalid_actions=2,
        # The only value that differs from the full-rubric test's inputs.
        obstacle_collisions=3,
        action_history=["scan_bin", "pick_item", "pack_item"],
    )
    bumpy_run = _make_state("obstacle_course", **rubric_inputs)
    score = grade_obstacle_course(bumpy_run)
    assert 0.0 < score < 1.0
187
+
188
+
189
def test_grade_heavy_lifting_full_rubric_passes():
    """A heavy_lifting run that meets every rubric threshold is graded 1.0."""
    rubric_inputs = dict(
        completion_ratio=1.0,
        correct_scans=3,  # rubric asks for at least three correct scans
        wrong_picks=0,
        invalid_actions=2,  # rubric tolerates at most two invalid actions
        recharges=1,
        battery_depletion_events=0,
        overweight_attempts=0,
        # "recharge" is listed for the rubric's tool_used:recharge check.
        action_history=["scan_bin", "pick_item", "pack_item", "recharge"],
    )
    clean_run = _make_state("heavy_lifting", **rubric_inputs)
    score = grade_heavy_lifting(clean_run)
    assert score == pytest.approx(1.0)
202
+
203
+
204
def test_grade_stamina_run_full_rubric_passes():
    """A stamina_run run that meets every rubric threshold is graded 1.0."""
    rubric_inputs = dict(
        completion_ratio=1.0,
        correct_scans=2,  # rubric asks for at least two correct scans
        wrong_picks=0,
        stamina_depletion_events=0,
        battery_depletion_events=0,
        rest_events=1,
        # "rest" is listed for the rubric's tool_used:rest check.
        action_history=["scan_bin", "pick_item", "pack_item", "rest"],
    )
    clean_run = _make_state("stamina_run", **rubric_inputs)
    score = grade_stamina_run(clean_run)
    assert score == pytest.approx(1.0)
216
+
217
+
218
def test_grade_budget_run_full_rubric_passes():
    """A budget_run run that meets every rubric threshold is graded 1.0."""
    rubric_inputs = dict(
        completion_ratio=1.0,
        correct_scans=2,
        wrong_picks=0,
        invalid_actions=1,  # rubric tolerates at most one invalid action
        recharges=1,
        # Two thermometers at $5 plus one bandage kit at $8; above the $15 target.
        money_earned=18.0,
        money_lost=0.0,
        action_history=["scan_bin", "pick_item", "pack_item", "recharge"],
    )
    profitable_run = _make_state("budget_run", **rubric_inputs)
    score = grade_budget_run(profitable_run)
    assert score == pytest.approx(1.0)
231
+
232
+
233
def test_grade_gauntlet_full_rubric_passes():
    """A gauntlet run that meets every rubric threshold is graded 1.0."""
    rubric_inputs = dict(
        completion_ratio=1.0,
        correct_scans=4,  # rubric asks for all four required bins
        wrong_picks=0,
        invalid_actions=2,  # rubric tolerates at most two invalid actions
        recharges=1,
        battery_depletion_events=0,
        stamina_depletion_events=0,
        rest_events=1,
        obstacle_collisions=0,
        overweight_attempts=0,
        # $5 + $6 + $8 + $4 for the four ordered items; above the $20 target.
        money_earned=23.0,
        money_lost=0.0,
        action_history=["scan_bin", "pick_item", "pack_item", "recharge", "rest"],
    )
    clean_run = _make_state("gauntlet", **rubric_inputs)
    score = grade_gauntlet(clean_run)
    assert score == pytest.approx(1.0)
251
+
252
+
253
def test_grade_gauntlet_partial_scores_lower():
    """A gauntlet run missing several thresholds scores strictly between 0 and 1."""
    rubric_inputs = dict(
        completion_ratio=1.0,
        correct_scans=2,  # below the four scans the rubric asks for
        wrong_picks=0,
        invalid_actions=5,  # above the rubric's limit of two
        recharges=0,
        battery_depletion_events=1,
        stamina_depletion_events=1,
        obstacle_collisions=2,
        overweight_attempts=1,
        money_earned=10.0,  # below the $20 profit target
        money_lost=5.0,
        # Neither "recharge" nor "rest" appears in the history.
        action_history=["scan_bin", "pick_item", "pack_item"],
    )
    flawed_run = _make_state("gauntlet", **rubric_inputs)
    score = grade_gauntlet(flawed_run)
    assert 0.0 < score < 1.0
270
+
271
+
272
+ ALL_GRADERS = [
273
+ ("easy_single_pick", grade_easy),
274
+ ("medium_multi_item", grade_medium),
275
+ ("hard_restock_priority", grade_hard),
276
+ ("obstacle_course", grade_obstacle_course),
277
+ ("heavy_lifting", grade_heavy_lifting),
278
+ ("stamina_run", grade_stamina_run),
279
+ ("budget_run", grade_budget_run),
280
+ ("gauntlet", grade_gauntlet),
281
+ ]
282
+
283
+
284
@pytest.mark.parametrize("task_id,grader", ALL_GRADERS)
def test_grade_episode_dispatches_correctly(task_id, grader):
    """grade_episode returns the same score as the task-specific grader."""
    partial_run = _make_state(task_id, completion_ratio=0.5, action_history=["pick_item"])
    dispatched_score = grade_episode(partial_run)
    direct_score = grader(partial_run)
    assert dispatched_score == direct_score
tests/test_server.py CHANGED
@@ -18,7 +18,16 @@ from fastapi.testclient import TestClient
18
  from grid_env.Server.app import app
19
 
20
 
21
- TASK_IDS = ["easy_single_pick", "medium_multi_item", "hard_restock_priority"]
 
 
 
 
 
 
 
 
 
22
  ALL_ACTIONS = [
23
  "turn_left",
24
  "turn_right",
@@ -27,6 +36,7 @@ ALL_ACTIONS = [
27
  "pick_item",
28
  "pack_item",
29
  "recharge",
 
30
  "wait",
31
  ]
32
 
@@ -78,7 +88,7 @@ def test_tasks_has_tasks_key(client):
78
  assert isinstance(body["tasks"], list)
79
 
80
 
81
- def test_tasks_returns_all_three(client):
82
  body = client.get("/tasks").json()
83
  ids = {t["task_id"] for t in body["tasks"]}
84
  assert ids == set(TASK_IDS)
 
18
  from grid_env.Server.app import app
19
 
20
 
21
+ TASK_IDS = [
22
+ "easy_single_pick",
23
+ "medium_multi_item",
24
+ "hard_restock_priority",
25
+ "obstacle_course",
26
+ "heavy_lifting",
27
+ "stamina_run",
28
+ "budget_run",
29
+ "gauntlet",
30
+ ]
31
  ALL_ACTIONS = [
32
  "turn_left",
33
  "turn_right",
 
36
  "pick_item",
37
  "pack_item",
38
  "recharge",
39
+ "rest",
40
  "wait",
41
  ]
42
 
 
88
  assert isinstance(body["tasks"], list)
89
 
90
 
91
def test_tasks_returns_all(client):
    """GET /tasks lists exactly the task IDs in TASK_IDS."""
    payload = client.get("/tasks").json()
    listed_ids = set()
    for entry in payload["tasks"]:
        listed_ids.add(entry["task_id"])
    # Set equality: no task may be missing and no unexpected task may appear.
    assert listed_ids == set(TASK_IDS)
tests/test_tasks.py CHANGED
@@ -1,5 +1,5 @@
1
  """
2
- Tests for task definitions: presence of all 3 tasks, structural validity,
3
  and grader callability.
4
  """
5
 
@@ -9,11 +9,20 @@ from grid_env.graders import grade_episode
9
  from grid_env.env import WarehouseFulfillmentEnv
10
 
11
 
12
- EXPECTED_TASK_IDS = {"easy_single_pick", "medium_multi_item", "hard_restock_priority"}
 
 
 
 
 
 
 
 
 
13
 
14
 
15
- def test_exactly_three_tasks_registered():
16
- assert len(TASKS) == 3
17
 
18
 
19
  def test_all_expected_task_ids_present():
@@ -24,7 +33,7 @@ def test_all_expected_task_ids_present():
24
  def test_task_has_required_fields(task_id):
25
  task = get_task(task_id)
26
  assert task.task_id == task_id
27
- assert task.difficulty in {"easy", "medium", "hard"}
28
  assert task.max_steps > 0
29
  assert task.battery_capacity > 0
30
  assert len(task.bins) > 0
@@ -69,6 +78,24 @@ def test_grader_callable_returns_float_in_range(task_id):
69
  assert 0.0 <= score <= 1.0
70
 
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  def test_get_task_raises_on_unknown_id():
73
  with pytest.raises(KeyError, match="Unknown task_id"):
74
  get_task("does_not_exist")
 
1
  """
2
+ Tests for task definitions: presence of all tasks, structural validity,
3
  and grader callability.
4
  """
5
 
 
9
  from grid_env.env import WarehouseFulfillmentEnv
10
 
11
 
12
+ EXPECTED_TASK_IDS = {
13
+ "easy_single_pick",
14
+ "medium_multi_item",
15
+ "hard_restock_priority",
16
+ "obstacle_course",
17
+ "heavy_lifting",
18
+ "stamina_run",
19
+ "budget_run",
20
+ "gauntlet",
21
+ }
22
 
23
 
24
def test_expected_task_count():
    """TASKS holds as many entries as EXPECTED_TASK_IDS names."""
    registered_count = len(TASKS)
    expected_count = len(EXPECTED_TASK_IDS)
    assert registered_count == expected_count
26
 
27
 
28
  def test_all_expected_task_ids_present():
 
33
  def test_task_has_required_fields(task_id):
34
  task = get_task(task_id)
35
  assert task.task_id == task_id
36
+ assert task.difficulty in {"easy", "medium", "hard", "expert"}
37
  assert task.max_steps > 0
38
  assert task.battery_capacity > 0
39
  assert len(task.bins) > 0
 
78
  assert 0.0 <= score <= 1.0
79
 
80
 
81
# sorted() rather than list(): EXPECTED_TASK_IDS is a set of strings, whose
# iteration order changes with hash randomization. A sorted list keeps the
# parametrized test order and IDs identical from run to run.
@pytest.mark.parametrize("task_id", sorted(EXPECTED_TASK_IDS))
def test_obstacles_do_not_overlap_bins_or_stations(task_id):
    """Obstacles must not overlap with bins, stations, or agent start.

    Every coordinate is normalized to a tuple before comparison so that the
    check behaves the same whether positions are stored as tuples or lists.
    On failure the assertion message names the clashing cells.
    """
    task = get_task(task_id)
    # Obstacles get the same tuple() treatment as every other position here;
    # a bare set(task.obstacles) would raise TypeError if the cells were lists.
    obstacle_set = {tuple(cell) for cell in task.obstacles}
    bin_positions = {tuple(b.position) for b in task.bins}
    fixed_positions = {
        tuple(task.pack_station_position),
        tuple(task.charger_position),
        tuple(task.dock_position),
        tuple(task.agent_start),
    }
    # rest_position is optional: only some tasks in this file define one.
    if task.rest_position:
        fixed_positions.add(tuple(task.rest_position))

    bin_clashes = obstacle_set & bin_positions
    station_clashes = obstacle_set & fixed_positions
    assert not bin_clashes, (
        f"Obstacle overlaps bin in {task_id}: {sorted(bin_clashes)}"
    )
    assert not station_clashes, (
        f"Obstacle overlaps station in {task_id}: {sorted(station_clashes)}"
    )
97
+
98
+
99
  def test_get_task_raises_on_unknown_id():
100
  with pytest.raises(KeyError, match="Unknown task_id"):
101
  get_task("does_not_exist")