Spaces:
Sleeping
Sleeping
| """ | |
| MiniGrid-style warehouse fulfillment environment. | |
| """ | |
| from __future__ import annotations | |
| from typing import Any, Dict, List, Optional, Tuple | |
| from .graders import grade_episode | |
| from .models import ( | |
| BinState, | |
| OrderLine, | |
| PackedOrderLine, | |
| PendingOrderLine, | |
| TaskDefinition, | |
| WarehouseAction, | |
| WarehouseMetrics, | |
| WarehouseObservation, | |
| WarehouseReward, | |
| WarehouseState, | |
| ) | |
| from .tasks import GRID_SIZE, TASKS, get_task | |
# Compass headings in clockwise order (N -> E -> S -> W).
HEADINGS = list("NESW")
# (dx, dy) grid offset of one forward step for each heading; moving north
# decreases y and moving east increases x.
MOVE_DELTA = dict(
    N=(0, -1),
    E=(1, 0),
    S=(0, 1),
    W=(-1, 0),
)
class WarehouseFulfillmentEnv:
    """MiniGrid-style pick-pack environment driven by discrete text commands.

    The agent occupies one cell of a ``GRID_SIZE`` grid and faces one of
    ``HEADINGS``. Scanning, picking, packing and recharging all target the
    cell directly in front of the agent (resting does too when the task
    defines a rest area). An episode ends when every order line is packed and
    the task's profit target (if positive) is reached, or when
    ``task.max_steps`` actions have been taken.

    The API is gym-like: ``reset()`` returns an observation and ``step()``
    returns ``(observation, reward, done, info)``.
    """

    # Commands understood by ``step``; anything else is penalised as invalid.
    action_space = [
        "turn_left",
        "turn_right",
        "move_forward",
        "scan_bin",
        "pick_item",
        "pack_item",
        "recharge",
        "rest",
        "wait",
    ]
    # Field names populated on every ``WarehouseObservation``.
    observation_space = [
        "task_id",
        "mission",
        "narrative",
        "agent_position",
        "heading",
        "front_cell",
        "carrying",
        "carrying_weight",
        "battery_level",
        "stamina_level",
        "money",
        "visible_bins",
        "pending_order",
        "packed_order",
        "progress_ratio",
    ]

    def __init__(self, task_id: str = "easy_single_pick", seed: int = 7) -> None:
        """Load ``task_id`` and prepare the first episode.

        Args:
            task_id: Task key resolved through ``get_task``.
            seed: Stored on the instance and embedded in ``episode_id``.
        """
        self.grid_size = GRID_SIZE
        self.seed = seed
        self._episode_counter = 0
        self.task: Optional[TaskDefinition] = None
        self._reset_runtime(task_id)

    def reset(
        self,
        task_id: Optional[str] = None,
        seed: Optional[int] = None,
    ) -> WarehouseObservation:
        """Start a new episode and return its first observation.

        Args:
            task_id: Task to load; ``None`` reloads the current task.
            seed: Replaces the stored seed when given.
        """
        if seed is not None:
            self.seed = seed
        self._episode_counter += 1
        self._reset_runtime(task_id or self.task.task_id)
        return self._build_observation("Episode reset. Start the pick-pack workflow.")

    def step(
        self,
        action: WarehouseAction | str,
    ) -> Tuple[WarehouseObservation, WarehouseReward, bool, Dict[str, Any]]:
        """Apply one command and return ``(observation, reward, done, info)``.

        Each step of a running episode starts from a base reward of -0.01.
        Action handlers add their own shaping, packing progress adds
        ``0.15 * completion_delta``, and completing the episode adds +0.50.
        Once the episode is done, further calls return a zero reward and do
        not change any state.

        Args:
            action: A ``WarehouseAction`` (its ``command`` is used) or any
                value whose ``str()`` is a command name.

        Returns:
            The observation, the (4-decimal rounded) reward, the done flag and
            the ``info`` dict produced by ``_step_info``.
        """
        # Accept string actions from callers but keep handling robust for unknowns.
        command = action.command if isinstance(action, WarehouseAction) else str(action)
        if self.done:
            observation = self._build_observation("Episode already complete.")
            reward = WarehouseReward(
                value=0.0,
                reason="Episode already complete.",
                completion_ratio=self._completion_ratio(),
            )
            return observation, reward, True, self._step_info()

        self.step_count += 1
        reward_value = -0.01
        narrative = "Action processed."
        prev_completion = self._completion_ratio()

        if command == "turn_left":
            self.heading = HEADINGS[(HEADINGS.index(self.heading) - 1) % len(HEADINGS)]
            self._consume_battery(1)
            narrative = f"Turned left. Now facing {self.heading}."
        elif command == "turn_right":
            self.heading = HEADINGS[(HEADINGS.index(self.heading) + 1) % len(HEADINGS)]
            self._consume_battery(1)
            narrative = f"Turned right. Now facing {self.heading}."
        elif command == "move_forward":
            reward_value, narrative = self._move_forward(reward_value)
        elif command == "scan_bin":
            reward_value, narrative = self._scan_bin(reward_value)
        elif command == "pick_item":
            reward_value, narrative = self._pick_item(reward_value)
        elif command == "pack_item":
            reward_value, narrative = self._pack_item(reward_value)
        elif command == "recharge":
            reward_value, narrative = self._recharge(reward_value)
        elif command == "rest":
            reward_value, narrative = self._rest(reward_value)
        elif command == "wait":
            reward_value -= 0.01
            narrative = "Waited in place and lost time."
        else:
            self.metrics.invalid_actions += 1
            reward_value -= 0.10
            narrative = f"Unknown action: {command}."

        completion_now = self._completion_ratio()
        progress_delta = max(0.0, completion_now - prev_completion)
        if progress_delta > 0.0:
            # Small dense shaping for moving the order toward completion.
            reward_value += 0.15 * progress_delta
        self.action_history.append(command)

        is_complete = self._is_episode_complete()
        hit_step_limit = self.step_count >= self.task.max_steps
        self.done = is_complete or hit_step_limit
        self.success = is_complete
        # Completion takes precedence if the final allowed step also finishes the order.
        self.termination_reason = (
            "task_complete" if is_complete else ("max_steps_reached" if hit_step_limit else None)
        )
        if self.success:
            reward_value += 0.50
            narrative = "Order fully packed and ready for dispatch."
        self.total_reward += reward_value

        observation = self._build_observation(narrative)
        reward = WarehouseReward(
            value=round(reward_value, 4),
            reason=narrative,
            completion_ratio=self._completion_ratio(),
        )
        return observation, reward, self.done, self._step_info()

    def _step_info(self) -> Dict[str, Any]:
        """Build the ``info`` dict returned by ``step``.

        ``score`` is computed with ``grade_episode`` only once the episode is
        done (``None`` otherwise). ``terminated`` is True for a successful
        finish; ``truncated`` is True when the step limit ended an
        unsuccessful episode.
        """
        return {
            "score": grade_episode(self.state()) if self.done else None,
            "completion_ratio": self._completion_ratio(),
            "metrics": self.metrics.model_dump(),
            "terminated": bool(self.done and self.success),
            "truncated": bool(
                self.done and not self.success and self.step_count >= self.task.max_steps
            ),
            "termination_reason": self.termination_reason,
        }

    def state(self) -> WarehouseState:
        """Return a snapshot of the episode.

        Bins, order lines, packed lines and metrics are cloned, and the
        obstacle and action-history lists are copied, so mutating the
        snapshot's containers does not affect the live episode.
        """
        return WarehouseState(
            episode_id=self.episode_id,
            task_id=self.task.task_id,
            difficulty=self.task.difficulty,
            step_count=self.step_count,
            done=self.done,
            success=self.success,
            max_steps=self.task.max_steps,
            grid_size=self.grid_size,
            agent_position=self.agent_position,
            heading=self.heading,
            carrying=self.carrying,
            carrying_weight=self.carrying_weight,
            battery_level=self.battery_level,
            battery_capacity=self.task.battery_capacity,
            stamina_level=self.stamina_level,
            stamina_capacity=self.task.stamina_capacity,
            money=round(self.money, 2),
            profit_target=self.task.profit_target,
            dock_position=self.task.dock_position,
            pack_station_position=self.task.pack_station_position,
            charger_position=self.task.charger_position,
            obstacles=list(self.task.obstacles),
            bins=[self._clone_bin(bin_state) for bin_state in self.bins],
            order=[self._clone_order_line(line) for line in self.order],
            packed_order=[self._clone_order_line(line) for line in self.packed_order],
            scanned_bins=sorted(self.scanned_bins),
            metrics=WarehouseMetrics(**self.metrics.model_dump()),
            action_history=list(self.action_history),
            total_reward=round(self.total_reward, 4),
            completion_ratio=self._completion_ratio(),
            task_description=self.task.description,
        )

    def _reset_runtime(self, task_id: str) -> None:
        """Load a private copy of ``task_id`` and reset all per-episode state."""
        self.task = self._clone_task(get_task(task_id))
        # The id uses the *next* counter value; reset() increments before calling.
        self.episode_id = f"{task_id}-seed{self.seed}-ep{self._episode_counter + 1}"
        self.agent_position = self.task.agent_start
        self.heading = self.task.agent_heading
        self.battery_level = self.task.battery_capacity
        self.carrying: Optional[str] = None
        self.carrying_weight: int = 0
        self.stamina_level: int = self.task.stamina_capacity
        self.money: float = 0.0
        self.step_count = 0
        self.done = False
        self.success = False
        self.termination_reason: Optional[str] = None
        self.total_reward = 0.0
        self.metrics = WarehouseMetrics()
        self.scanned_bins: set[str] = set()
        self.action_history: List[str] = []
        self.bins = [self._clone_bin(bin_state) for bin_state in self.task.bins]
        self.order = [self._clone_order_line(line) for line in self.task.order]
        # One zero-quantity packed line per order line, index-aligned with self.order.
        self.packed_order = [OrderLine(sku=line.sku, quantity=0) for line in self.task.order]

    def _clone_bin(self, bin_state: BinState | Dict[str, Any]) -> BinState:
        """Return a new ``BinState`` from a model instance or a plain dict."""
        payload = bin_state.model_dump() if hasattr(bin_state, "model_dump") else dict(bin_state)
        return BinState(**payload)

    def _clone_order_line(self, line: OrderLine | Dict[str, Any]) -> OrderLine:
        """Return a new ``OrderLine`` from a model instance or a plain dict."""
        payload = line.model_dump() if hasattr(line, "model_dump") else dict(line)
        return OrderLine(**payload)

    def _clone_task(self, task: TaskDefinition | Dict[str, Any]) -> TaskDefinition:
        """Return a new ``TaskDefinition`` so episodes never mutate the registry's task."""
        payload = task.model_dump() if hasattr(task, "model_dump") else dict(task)
        payload["bins"] = [self._clone_bin(bin_state) for bin_state in payload["bins"]]
        payload["order"] = [self._clone_order_line(line) for line in payload["order"]]
        return TaskDefinition(**payload)

    def _front_position(self) -> Tuple[int, int]:
        """Return the cell one step ahead of the agent along its heading."""
        dx, dy = MOVE_DELTA[self.heading]
        return (self.agent_position[0] + dx, self.agent_position[1] + dy)

    def _front_bin(self) -> Optional[BinState]:
        """Return the live bin occupying the front cell, if any."""
        pos = self._front_position()
        for bin_state in self.bins:
            if bin_state.position == pos:
                return bin_state
        return None

    def _front_cell_label(self) -> str:
        """Describe the front cell for observations (wall, obstacle, bin, station, ...)."""
        front = self._front_position()
        if not self._in_bounds(front):
            return "wall"
        if self._is_obstacle(front):
            return "obstacle"
        front_bin = self._front_bin()
        if front_bin:
            return f"bin {front_bin.bin_id} ({front_bin.sku})"
        if front == self.task.pack_station_position:
            return "pack station"
        if front == self.task.charger_position:
            return "charger"
        if front == self.task.dock_position:
            return "dock"
        if self.task.rest_position and front == self.task.rest_position:
            return "rest area"
        return "aisle"

    def _move_forward(self, reward: float) -> Tuple[float, str]:
        """Advance one cell if the front cell is free.

        Battery cost is 2, plus the carried weight when it exceeds 1, and the
        total is doubled when the stamina mechanic is active and stamina is 0.
        Blocked moves cost 1 battery and count as invalid actions.
        """
        next_pos = self._front_position()
        if self._is_obstacle(next_pos):
            self.metrics.obstacle_collisions += 1
            self.metrics.invalid_actions += 1
            self._consume_battery(1)
            return reward - 0.12, "Blocked by an obstacle! Find another route."
        if not self._in_bounds(next_pos) or self._occupied(next_pos):
            self.metrics.invalid_actions += 1
            self._consume_battery(1)
            return reward - 0.08, "Forward move blocked by warehouse infrastructure."
        battery_cost = 2
        weight_penalty = self.carrying_weight if self.carrying_weight > 1 else 0
        battery_cost += weight_penalty
        if self._has_stamina() and self.stamina_level <= 0:
            battery_cost *= 2
        self.agent_position = next_pos
        self.metrics.distance_travelled += 1
        self._consume_battery(battery_cost)
        self._consume_stamina(self.task.stamina_move_cost)
        return reward, f"Moved to aisle cell {self.agent_position}."

    def _scan_bin(self, reward: float) -> Tuple[float, str]:
        """Scan the bin in front; only the first scan of each bin is scored."""
        bin_state = self._front_bin()
        if not bin_state:
            self.metrics.invalid_actions += 1
            self._consume_battery(1)
            return reward - 0.08, "No bin in front to scan."
        self._consume_battery(1)
        if bin_state.bin_id not in self.scanned_bins:
            self.scanned_bins.add(bin_state.bin_id)
            if bin_state.bin_id in self.task.required_scans:
                self.metrics.correct_scans += 1
                return reward + 0.12, f"Scanned {bin_state.bin_id}; confirmed {bin_state.sku}."
            self.metrics.wrong_scans += 1
            return reward - 0.02, f"Scanned {bin_state.bin_id}; item not needed for this order."
        return reward - 0.01, f"Bin {bin_state.bin_id} was already scanned."

    def _pick_item(self, reward: float) -> Tuple[float, str]:
        """Take one unit from the bin in front into the agent's (single-item) hands.

        The pick is rewarded when the order still needs that SKU and
        penalised otherwise; in both cases the unit leaves the bin.
        """
        bin_state = self._front_bin()
        if not bin_state:
            self.metrics.invalid_actions += 1
            self._consume_battery(1)
            return reward - 0.08, "No pick face in front of the agent."
        if self.carrying is not None:
            self.metrics.invalid_actions += 1
            self._consume_battery(1)
            return reward - 0.08, f"Hands already occupied with {self.carrying}."
        if bin_state.quantity <= 0:
            self.metrics.invalid_actions += 1
            self._consume_battery(1)
            return reward - 0.10, f"Bin {bin_state.bin_id} is empty."
        if bin_state.weight > self.task.carry_capacity:
            self.metrics.overweight_attempts += 1
            self.metrics.invalid_actions += 1
            self._consume_battery(1)
            return reward - 0.12, (
                f"Item {bin_state.sku} weighs {bin_state.weight} but carry capacity "
                f"is {self.task.carry_capacity}. Too heavy!"
            )
        self._consume_battery(1)
        bin_state.quantity -= 1
        self.carrying = bin_state.sku
        self.carrying_weight = bin_state.weight
        if self._remaining_quantity(bin_state.sku) > 0:
            self.metrics.correct_picks += 1
            return reward + 0.20, f"Picked {bin_state.sku} (weight {bin_state.weight}) from {bin_state.bin_id}."
        self.metrics.wrong_picks += 1
        return reward - 0.18, f"Picked {bin_state.sku}, which is not needed now."

    def _pack_item(self, reward: float) -> Tuple[float, str]:
        """Pack the carried unit at the pack station in front of the agent.

        A unit the order still needs is credited to the first order line for
        that SKU that is still short and earns the item's value. A unit the
        order no longer needs is removed from the agent's hands uncredited and
        costs half its value.
        """
        if self._front_position() != self.task.pack_station_position:
            self.metrics.invalid_actions += 1
            self._consume_battery(1)
            return reward - 0.08, "Agent is not facing the pack station."
        if self.carrying is None:
            self.metrics.invalid_actions += 1
            self._consume_battery(1)
            return reward - 0.08, "Nothing in hand to pack."
        self._consume_battery(1)
        item = self.carrying
        item_value = self._item_value(item)
        if self._remaining_quantity(item) <= 0:
            self.carrying = None
            self.carrying_weight = 0
            self.metrics.wrong_picks += 1
            if item_value > 0:
                self.money -= item_value * 0.5
                self.metrics.money_lost += item_value * 0.5
            return reward - 0.15, f"Packed extra unit of {item}; order did not require it."
        # Credit the first still-short line for this SKU, so orders that list the
        # same SKU on several lines fill each line in turn instead of
        # over-filling the first one.
        for order_line, packed_line in zip(self.order, self.packed_order):
            if order_line.sku == item and packed_line.quantity < order_line.quantity:
                packed_line.quantity += 1
                break
        self.carrying = None
        self.carrying_weight = 0
        self.metrics.correct_packs += 1
        if item_value > 0:
            self.money += item_value
            self.metrics.money_earned += item_value
        return reward + 0.35, f"Packed {item} at the station. (+${item_value:.2f})"

    def _recharge(self, reward: float) -> Tuple[float, str]:
        """Refill the battery while facing the charger.

        Recharging is rewarded only at or below ``low_battery_threshold``;
        earlier top-ups are mildly penalised.
        """
        if self._front_position() != self.task.charger_position:
            self.metrics.invalid_actions += 1
            self._consume_battery(1)
            return reward - 0.08, "Recharge action requires facing the charger."
        if self.battery_level >= self.task.battery_capacity:
            return reward - 0.03, "Battery already full."
        benefit = 0.08 if self.battery_level <= self.task.low_battery_threshold else -0.02
        self.battery_level = self.task.battery_capacity
        self.metrics.recharges += 1
        return reward + benefit, "Battery restored to full capacity."

    def _rest(self, reward: float) -> Tuple[float, str]:
        """Refill stamina, facing the rest area when the task defines one.

        Resting is rewarded only at or below a quarter of stamina capacity.
        """
        if not self._has_stamina():
            self.metrics.invalid_actions += 1
            return reward - 0.03, "This task has no stamina mechanic."
        if self.task.rest_position and self._front_position() != self.task.rest_position:
            self.metrics.invalid_actions += 1
            self._consume_battery(1)
            return reward - 0.08, "Rest action requires facing the rest area."
        if self.stamina_level >= self.task.stamina_capacity:
            return reward - 0.03, "Stamina already full."
        benefit = 0.06 if self.stamina_level <= self.task.stamina_capacity // 4 else -0.02
        self.stamina_level = self.task.stamina_capacity
        self.metrics.rest_events += 1
        return reward + benefit, "Stamina restored to full capacity."

    def _build_observation(self, narrative: str) -> WarehouseObservation:
        """Assemble the agent-facing observation.

        ``visible_bins`` lists bins within Manhattan distance 2 of the agent as
        ``"bin_id:sku:quantity"``.
        """
        nearby_bins = []
        for bin_state in self.bins:
            distance = abs(bin_state.position[0] - self.agent_position[0]) + abs(bin_state.position[1] - self.agent_position[1])
            if distance <= 2:
                nearby_bins.append(f"{bin_state.bin_id}:{bin_state.sku}:{bin_state.quantity}")
        pending = []
        packed = []
        for order_line, packed_line in zip(self.order, self.packed_order):
            pending_qty = max(0, order_line.quantity - packed_line.quantity)
            pending.append(PendingOrderLine(sku=order_line.sku, remaining=pending_qty))
            packed.append(PackedOrderLine(sku=packed_line.sku, packed=packed_line.quantity))
        return WarehouseObservation(
            task_id=self.task.task_id,
            mission=self.task.description,
            narrative=narrative,
            agent_position=self.agent_position,
            heading=self.heading,
            front_cell=self._front_cell_label(),
            carrying=self.carrying,
            carrying_weight=self.carrying_weight,
            battery_level=self.battery_level,
            stamina_level=self.stamina_level,
            money=round(self.money, 2),
            visible_bins=nearby_bins,
            pending_order=pending,
            packed_order=packed,
            progress_ratio=self._completion_ratio(),
        )

    def _completion_ratio(self) -> float:
        """Fraction of ordered units packed (capped per line), rounded to 4 places.

        An order with no required units counts as fully complete (1.0).
        """
        total_required = sum(line.quantity for line in self.order)
        total_packed = sum(min(order_line.quantity, packed_line.quantity) for order_line, packed_line in zip(self.order, self.packed_order))
        if total_required == 0:
            return 1.0
        return round(total_packed / total_required, 4)

    def _remaining_quantity(self, sku: str) -> int:
        """Return how many more units of ``sku`` the order needs.

        Sums the outstanding quantity of every order line for ``sku``, so an
        order that lists the same SKU on several lines is handled; returns 0
        for SKUs the order does not contain.
        """
        return sum(
            max(0, order_line.quantity - packed_line.quantity)
            for order_line, packed_line in zip(self.order, self.packed_order)
            if order_line.sku == sku
        )

    def _all_order_lines_complete(self) -> bool:
        """Return True when every order line has been packed in full."""
        # Compare each line with its own packed counterpart (index-aligned).
        return all(
            packed_line.quantity >= order_line.quantity
            for order_line, packed_line in zip(self.order, self.packed_order)
        )

    def _consume_battery(self, amount: int) -> None:
        """Drain ``amount`` battery (floored at 0), counting the transition to empty."""
        previous = self.battery_level
        self.battery_level = max(0, self.battery_level - amount)
        if previous > 0 and self.battery_level == 0:
            self.metrics.battery_depletion_events += 1

    def _consume_stamina(self, amount: int) -> None:
        """Drain ``amount`` stamina (floored at 0) when the mechanic is active."""
        if not self._has_stamina():
            return
        previous = self.stamina_level
        self.stamina_level = max(0, self.stamina_level - amount)
        if previous > 0 and self.stamina_level == 0:
            self.metrics.stamina_depletion_events += 1

    def _has_stamina(self) -> bool:
        """The stamina mechanic is active only for a positive stamina capacity."""
        return self.task.stamina_capacity > 0

    def _is_obstacle(self, position: Tuple[int, int]) -> bool:
        """Return True if ``position`` is an obstacle cell.

        Both sides are coerced to tuples so list-valued coordinates still match.
        """
        return tuple(position) in {tuple(o) for o in self.task.obstacles}

    def _item_value(self, sku: str) -> float:
        """Return the value of the first task bin stocking ``sku`` (0.0 if none)."""
        for bin_state in self.task.bins:
            if bin_state.sku == sku:
                return bin_state.value
        return 0.0

    def _is_episode_complete(self) -> bool:
        """The order must be fully packed and any positive profit target met."""
        if not self._all_order_lines_complete():
            return False
        if self.task.profit_target > 0 and self.money < self.task.profit_target:
            return False
        return True

    def _in_bounds(self, position: Tuple[int, int]) -> bool:
        """Return True if ``position`` lies on the grid (x < grid_size[0], y < grid_size[1])."""
        return 0 <= position[0] < self.grid_size[0] and 0 <= position[1] < self.grid_size[1]

    def _occupied(self, position: Tuple[int, int]) -> bool:
        """Return True if a station, rest area, obstacle or bin blocks ``position``."""
        fixed = {self.task.pack_station_position, self.task.charger_position, self.task.dock_position}
        if self.task.rest_position:
            fixed.add(self.task.rest_position)
        if position in fixed:
            return True
        if self._is_obstacle(position):
            return True
        return any(bin_state.position == position for bin_state in self.bins)
def available_tasks() -> List[Dict[str, str]]:
    """Summarise every registered task.

    Returns:
        One dict per entry of ``TASKS`` (in its iteration order) holding the
        task's ``task_id``, ``difficulty``, ``title`` and ``description``.
    """
    summaries: List[Dict[str, str]] = []
    for definition in TASKS.values():
        summaries.append(
            dict(
                task_id=definition.task_id,
                difficulty=definition.difficulty,
                title=definition.title,
                description=definition.description,
            )
        )
    return summaries