"""
MiniGrid-style warehouse fulfillment environment.

An agent moves on a 2D grid, picks items from bins, and packs them at a pack
station until the task's order (and, optionally, a profit target) is met.
"""

from __future__ import annotations

from typing import Any, Dict, List, Optional, Tuple

from .graders import grade_episode
from .models import (
    BinState,
    OrderLine,
    PackedOrderLine,
    PendingOrderLine,
    TaskDefinition,
    WarehouseAction,
    WarehouseMetrics,
    WarehouseObservation,
    WarehouseReward,
    WarehouseState,
)
from .tasks import GRID_SIZE, TASKS, get_task

# Headings in clockwise order: index + 1 is a right turn, index - 1 a left turn.
HEADINGS = ["N", "E", "S", "W"]

# (dx, dy) for one step in each heading. "N" decreases y, so y grows southward.
MOVE_DELTA = {
    "N": (0, -1),
    "E": (1, 0),
    "S": (0, 1),
    "W": (-1, 0),
}


class WarehouseFulfillmentEnv:
    """Single-agent pick-pack environment with a Gym-style API.

    The agent has a grid position and a heading from ``HEADINGS``. It can
    carry at most one item at a time, and fulfils the task's order by
    picking items from bins and packing them at the pack station. Most
    actions drain battery, and ``recharge`` restores it. When the task's
    ``stamina_capacity`` is positive, moves also drain stamina, and
    ``rest`` restores it.

    ``reset()`` returns an observation. ``step()`` returns
    ``(observation, reward, done, info)``.
    """

    # Commands understood by ``step``; anything else is an invalid action.
    action_space = [
        "turn_left",
        "turn_right",
        "move_forward",
        "scan_bin",
        "pick_item",
        "pack_item",
        "recharge",
        "rest",
        "wait",
    ]

    # Field names exposed on ``WarehouseObservation`` by ``_build_observation``.
    observation_space = [
        "task_id",
        "mission",
        "narrative",
        "agent_position",
        "heading",
        "front_cell",
        "carrying",
        "carrying_weight",
        "battery_level",
        "stamina_level",
        "money",
        "visible_bins",
        "pending_order",
        "packed_order",
        "progress_ratio",
    ]

    def __init__(self, task_id: str = "easy_single_pick", seed: int = 7) -> None:
        """Create the environment and load ``task_id`` as the first episode.

        Args:
            task_id: Identifier passed to ``get_task``.
            seed: Stored on the instance and embedded in ``episode_id``; it
                is not otherwise used by this class.
        """
        self.grid_size = GRID_SIZE
        self.seed = seed
        self._episode_counter = 0
        self.task: Optional[TaskDefinition] = None
        self._reset_runtime(task_id)

    def reset(
        self,
        task_id: Optional[str] = None,
        seed: Optional[int] = None,
    ) -> WarehouseObservation:
        """Start a new episode and return its first observation.

        Args:
            task_id: Task to load; defaults to the currently loaded task.
            seed: If given, replaces ``self.seed`` before the reset.

        Returns:
            The initial ``WarehouseObservation`` of the new episode.
        """
        if seed is not None:
            self.seed = seed
        self._episode_counter += 1
        self._reset_runtime(task_id or self.task.task_id)
        return self._build_observation("Episode reset. Start the pick-pack workflow.")

    def step(
        self,
        action: WarehouseAction | str,
    ) -> Tuple[WarehouseObservation, WarehouseReward, bool, Dict[str, Any]]:
        """Apply one action and advance the episode by one step.

        Reward shaping visible here:
        - a -0.01 time penalty on every step;
        - the per-action adjustment returned by the action handler;
        - +0.15 times any increase in the order completion ratio;
        - +0.50 on the step that completes the episode.

        Args:
            action: A ``WarehouseAction`` (its ``command`` is used) or any
                value that is converted with ``str``. Unknown commands are
                penalized rather than raising.

        Returns:
            ``(observation, reward, done, info)``. ``info`` holds ``score``
            (the graded episode, or ``None`` while running),
            ``completion_ratio``, ``metrics``, ``terminated``, ``truncated``
            and ``termination_reason``.
        """
        # Accept string actions from callers but keep handling robust for unknowns.
        command = action.command if isinstance(action, WarehouseAction) else str(action)

        if self.done:
            # Stepping a finished episode is a no-op that re-reports the final state.
            observation = self._build_observation("Episode already complete.")
            reward = WarehouseReward(
                value=0.0,
                reason="Episode already complete.",
                completion_ratio=self._completion_ratio(),
            )
            return observation, reward, True, self._step_info(grade_episode(self.state()))

        self.step_count += 1
        reward_value = -0.01
        narrative = "Action processed."
        prev_completion = self._completion_ratio()

        if command == "turn_left":
            narrative = self._turn(-1, "left")
        elif command == "turn_right":
            narrative = self._turn(1, "right")
        elif command == "move_forward":
            reward_value, narrative = self._move_forward(reward_value)
        elif command == "scan_bin":
            reward_value, narrative = self._scan_bin(reward_value)
        elif command == "pick_item":
            reward_value, narrative = self._pick_item(reward_value)
        elif command == "pack_item":
            reward_value, narrative = self._pack_item(reward_value)
        elif command == "recharge":
            reward_value, narrative = self._recharge(reward_value)
        elif command == "rest":
            reward_value, narrative = self._rest(reward_value)
        elif command == "wait":
            reward_value -= 0.01
            narrative = "Waited in place and lost time."
        else:
            self.metrics.invalid_actions += 1
            reward_value -= 0.10
            narrative = f"Unknown action: {command}."

        completion_now = self._completion_ratio()
        progress_delta = max(0.0, completion_now - prev_completion)
        if progress_delta > 0.0:
            # Small dense shaping for moving the order toward completion.
            reward_value += 0.15 * progress_delta

        self.action_history.append(command)

        is_complete = self._is_episode_complete()
        hit_step_limit = self.step_count >= self.task.max_steps
        self.done = is_complete or hit_step_limit
        self.success = is_complete
        self.termination_reason = (
            "task_complete"
            if is_complete
            else ("max_steps_reached" if hit_step_limit else None)
        )

        if self.success:
            reward_value += 0.50
            narrative = "Order fully packed and ready for dispatch."

        self.total_reward += reward_value
        observation = self._build_observation(narrative)
        reward = WarehouseReward(
            value=round(reward_value, 4),
            reason=narrative,
            completion_ratio=self._completion_ratio(),
        )
        # The episode is only graded once it has ended.
        info = self._step_info(grade_episode(self.state()) if self.done else None)
        return observation, reward, self.done, info

    def _step_info(self, score: Any) -> Dict[str, Any]:
        """Build the ``info`` dict returned by ``step``.

        ``terminated`` means the episode ended because the task was
        completed. ``truncated`` means it ended at the step limit without
        success. Both are False while the episode is still running.
        """
        return {
            "score": score,
            "completion_ratio": self._completion_ratio(),
            "metrics": self.metrics.model_dump(),
            "terminated": bool(self.done and self.success),
            "truncated": bool(
                self.done and not self.success and self.step_count >= self.task.max_steps
            ),
            "termination_reason": self.termination_reason,
        }

    def _turn(self, offset: int, side: str) -> str:
        """Rotate the heading by ``offset`` steps in ``HEADINGS``.

        A turn costs 1 battery. Returns the narrative for the step.
        """
        self.heading = HEADINGS[(HEADINGS.index(self.heading) + offset) % len(HEADINGS)]
        self._consume_battery(1)
        return f"Turned {side}. Now facing {self.heading}."

    def state(self) -> WarehouseState:
        """Return a snapshot of the full environment state.

        Bins, order lines and metrics are copied into new model instances,
        and obstacles and action history into new lists. Mutating the
        snapshot therefore does not affect the environment.
        """
        return WarehouseState(
            episode_id=self.episode_id,
            task_id=self.task.task_id,
            difficulty=self.task.difficulty,
            step_count=self.step_count,
            done=self.done,
            success=self.success,
            max_steps=self.task.max_steps,
            grid_size=self.grid_size,
            agent_position=self.agent_position,
            heading=self.heading,
            carrying=self.carrying,
            carrying_weight=self.carrying_weight,
            battery_level=self.battery_level,
            battery_capacity=self.task.battery_capacity,
            stamina_level=self.stamina_level,
            stamina_capacity=self.task.stamina_capacity,
            money=round(self.money, 2),
            profit_target=self.task.profit_target,
            dock_position=self.task.dock_position,
            pack_station_position=self.task.pack_station_position,
            charger_position=self.task.charger_position,
            obstacles=list(self.task.obstacles),
            bins=[self._clone_bin(bin_state) for bin_state in self.bins],
            order=[self._clone_order_line(line) for line in self.order],
            packed_order=[self._clone_order_line(line) for line in self.packed_order],
            scanned_bins=sorted(self.scanned_bins),
            metrics=WarehouseMetrics(**self.metrics.model_dump()),
            action_history=list(self.action_history),
            total_reward=round(self.total_reward, 4),
            completion_ratio=self._completion_ratio(),
            task_description=self.task.description,
        )

    def _reset_runtime(self, task_id: str) -> None:
        """(Re)initialise all per-episode state from a fresh copy of ``task_id``."""
        self.task = self._clone_task(get_task(task_id))
        self.episode_id = f"{task_id}-seed{self.seed}-ep{self._episode_counter + 1}"
        self.agent_position = self.task.agent_start
        self.heading = self.task.agent_heading
        self.battery_level = self.task.battery_capacity
        self.carrying: Optional[str] = None
        self.carrying_weight: int = 0
        self.stamina_level: int = self.task.stamina_capacity
        self.money: float = 0.0
        self.step_count = 0
        self.done = False
        self.success = False
        self.termination_reason: Optional[str] = None
        self.total_reward = 0.0
        self.metrics = WarehouseMetrics()
        self.scanned_bins: set[str] = set()
        self.action_history: List[str] = []
        # Working copies: bin quantities are decremented during the episode.
        self.bins = [self._clone_bin(bin_state) for bin_state in self.task.bins]
        self.order = [self._clone_order_line(line) for line in self.task.order]
        # One zero-quantity line per order line, kept index-aligned with self.order.
        self.packed_order = [OrderLine(sku=line.sku, quantity=0) for line in self.task.order]

    def _clone_bin(self, bin_state: BinState | Dict[str, Any]) -> BinState:
        """Return a new ``BinState`` from a model instance or a plain dict."""
        payload = bin_state.model_dump() if hasattr(bin_state, "model_dump") else dict(bin_state)
        return BinState(**payload)

    def _clone_order_line(self, line: OrderLine | Dict[str, Any]) -> OrderLine:
        """Return a new ``OrderLine`` from a model instance or a plain dict."""
        payload = line.model_dump() if hasattr(line, "model_dump") else dict(line)
        return OrderLine(**payload)

    def _clone_task(self, task: TaskDefinition | Dict[str, Any]) -> TaskDefinition:
        """Return a copy of ``task`` with its bins and order lines re-created.

        This keeps per-episode mutation from touching the object returned
        by ``get_task``.
        """
        payload = task.model_dump() if hasattr(task, "model_dump") else dict(task)
        payload["bins"] = [self._clone_bin(bin_state) for bin_state in payload["bins"]]
        payload["order"] = [self._clone_order_line(line) for line in payload["order"]]
        return TaskDefinition(**payload)

    def _front_position(self) -> Tuple[int, int]:
        """Return the grid cell directly in front of the agent (may be out of bounds)."""
        dx, dy = MOVE_DELTA[self.heading]
        return (self.agent_position[0] + dx, self.agent_position[1] + dy)

    def _front_bin(self) -> Optional[BinState]:
        """Return the bin occupying the front cell, or ``None``."""
        pos = self._front_position()
        for bin_state in self.bins:
            if bin_state.position == pos:
                return bin_state
        return None

    def _front_cell_label(self) -> str:
        """Describe the front cell for the observation.

        Checks run in priority order: wall, obstacle, bin, pack station,
        charger, dock, rest area, and otherwise aisle.
        """
        front = self._front_position()
        if not self._in_bounds(front):
            return "wall"
        if self._is_obstacle(front):
            return "obstacle"
        front_bin = self._front_bin()
        if front_bin:
            return f"bin {front_bin.bin_id} ({front_bin.sku})"
        if front == self.task.pack_station_position:
            return "pack station"
        if front == self.task.charger_position:
            return "charger"
        if front == self.task.dock_position:
            return "dock"
        if self.task.rest_position and front == self.task.rest_position:
            return "rest area"
        return "aisle"

    def _move_forward(self, reward: float) -> Tuple[float, str]:
        """Advance one cell along the heading if the cell is free aisle.

        A blocked move costs 1 battery. A successful move costs 2 battery.
        The carried weight is added when it exceeds 1, and the total is
        doubled when stamina is enabled and exhausted. A successful move
        also consumes ``stamina_move_cost`` stamina.
        """
        next_pos = self._front_position()
        if self._is_obstacle(next_pos):
            self.metrics.obstacle_collisions += 1
            self.metrics.invalid_actions += 1
            self._consume_battery(1)
            return reward - 0.12, "Blocked by an obstacle! Find another route."
        if not self._in_bounds(next_pos) or self._occupied(next_pos):
            self.metrics.invalid_actions += 1
            self._consume_battery(1)
            return reward - 0.08, "Forward move blocked by warehouse infrastructure."

        battery_cost = 2
        weight_penalty = self.carrying_weight if self.carrying_weight > 1 else 0
        battery_cost += weight_penalty
        if self._has_stamina() and self.stamina_level <= 0:
            # Exhausted agents pay double battery per move.
            battery_cost *= 2

        self.agent_position = next_pos
        self.metrics.distance_travelled += 1
        self._consume_battery(battery_cost)
        self._consume_stamina(self.task.stamina_move_cost)
        return reward, f"Moved to aisle cell {self.agent_position}."

    def _scan_bin(self, reward: float) -> Tuple[float, str]:
        """Scan the bin in front.

        Only the first scan of each bin is rewarded (or penalized), based on
        whether the bin is in the task's ``required_scans``.
        """
        bin_state = self._front_bin()
        if not bin_state:
            self.metrics.invalid_actions += 1
            self._consume_battery(1)
            return reward - 0.08, "No bin in front to scan."

        self._consume_battery(1)
        if bin_state.bin_id not in self.scanned_bins:
            self.scanned_bins.add(bin_state.bin_id)
            if bin_state.bin_id in self.task.required_scans:
                self.metrics.correct_scans += 1
                return reward + 0.12, f"Scanned {bin_state.bin_id}; confirmed {bin_state.sku}."
            self.metrics.wrong_scans += 1
            return reward - 0.02, f"Scanned {bin_state.bin_id}; item not needed for this order."
        return reward - 0.01, f"Bin {bin_state.bin_id} was already scanned."

    def _pick_item(self, reward: float) -> Tuple[float, str]:
        """Take one unit from the bin in front into the agent's hands.

        Fails (with a penalty and 1 battery) if there is no bin, the hands
        are full, the bin is empty, or the item exceeds ``carry_capacity``.
        A pick counts as correct only while the order still needs that SKU.
        """
        bin_state = self._front_bin()
        if not bin_state:
            self.metrics.invalid_actions += 1
            self._consume_battery(1)
            return reward - 0.08, "No pick face in front of the agent."
        if self.carrying is not None:
            self.metrics.invalid_actions += 1
            self._consume_battery(1)
            return reward - 0.08, f"Hands already occupied with {self.carrying}."
        if bin_state.quantity <= 0:
            self.metrics.invalid_actions += 1
            self._consume_battery(1)
            return reward - 0.10, f"Bin {bin_state.bin_id} is empty."
        if bin_state.weight > self.task.carry_capacity:
            self.metrics.overweight_attempts += 1
            self.metrics.invalid_actions += 1
            self._consume_battery(1)
            return reward - 0.12, (
                f"Item {bin_state.sku} weighs {bin_state.weight} but carry capacity "
                f"is {self.task.carry_capacity}. Too heavy!"
            )

        self._consume_battery(1)
        bin_state.quantity -= 1
        self.carrying = bin_state.sku
        self.carrying_weight = bin_state.weight
        if self._remaining_quantity(bin_state.sku) > 0:
            self.metrics.correct_picks += 1
            return reward + 0.20, f"Picked {bin_state.sku} (weight {bin_state.weight}) from {bin_state.bin_id}."
        self.metrics.wrong_picks += 1
        return reward - 0.18, f"Picked {bin_state.sku}, which is not needed now."

    def _pack_item(self, reward: float) -> Tuple[float, str]:
        """Pack the carried item at the pack station (which must be in front).

        A needed item increments its packed line and earns its value. An
        unneeded item is discarded and costs half its value.
        """
        if self._front_position() != self.task.pack_station_position:
            self.metrics.invalid_actions += 1
            self._consume_battery(1)
            return reward - 0.08, "Agent is not facing the pack station."
        if self.carrying is None:
            self.metrics.invalid_actions += 1
            self._consume_battery(1)
            return reward - 0.08, "Nothing in hand to pack."

        self._consume_battery(1)
        remaining = self._remaining_quantity(self.carrying)
        if remaining <= 0:
            item = self.carrying
            item_value = self._item_value(item)
            self.carrying = None
            self.carrying_weight = 0
            # NOTE(review): an unneeded item was already counted in
            # wrong_picks by _pick_item (remaining can only change on a pack),
            # so this counts the same item a second time. Confirm this is
            # intended by the grader.
            self.metrics.wrong_picks += 1
            if item_value > 0:
                self.money -= item_value * 0.5
                self.metrics.money_lost += item_value * 0.5
            return reward - 0.15, f"Packed extra unit of {item}; order did not require it."

        item_value = self._item_value(self.carrying)
        for packed_line in self.packed_order:
            if packed_line.sku == self.carrying:
                packed_line.quantity += 1
                break
        item = self.carrying
        self.carrying = None
        self.carrying_weight = 0
        self.metrics.correct_packs += 1
        if item_value > 0:
            self.money += item_value
            self.metrics.money_earned += item_value
        return reward + 0.35, f"Packed {item} at the station. (+${item_value:.2f})"

    def _recharge(self, reward: float) -> Tuple[float, str]:
        """Refill the battery while facing the charger.

        A recharge is rewarded only when the battery is at or below
        ``low_battery_threshold``; topping up earlier is mildly penalized.
        """
        if self._front_position() != self.task.charger_position:
            self.metrics.invalid_actions += 1
            self._consume_battery(1)
            return reward - 0.08, "Recharge action requires facing the charger."
        if self.battery_level >= self.task.battery_capacity:
            return reward - 0.03, "Battery already full."
        benefit = 0.08 if self.battery_level <= self.task.low_battery_threshold else -0.02
        self.battery_level = self.task.battery_capacity
        self.metrics.recharges += 1
        return reward + benefit, "Battery restored to full capacity."

    def _rest(self, reward: float) -> Tuple[float, str]:
        """Refill stamina.

        If the task defines a ``rest_position``, the agent must be facing it.
        Resting is rewarded only when stamina is at or below a quarter of
        capacity.
        """
        if not self._has_stamina():
            self.metrics.invalid_actions += 1
            return reward - 0.03, "This task has no stamina mechanic."
        if self.task.rest_position and self._front_position() != self.task.rest_position:
            self.metrics.invalid_actions += 1
            self._consume_battery(1)
            return reward - 0.08, "Rest action requires facing the rest area."
        if self.stamina_level >= self.task.stamina_capacity:
            return reward - 0.03, "Stamina already full."
        benefit = 0.06 if self.stamina_level <= self.task.stamina_capacity // 4 else -0.02
        self.stamina_level = self.task.stamina_capacity
        self.metrics.rest_events += 1
        return reward + benefit, "Stamina restored to full capacity."

    def _build_observation(self, narrative: str) -> WarehouseObservation:
        """Assemble the agent-facing observation.

        ``visible_bins`` lists bins within Manhattan distance 2 as
        ``"bin_id:sku:quantity"`` strings.
        """
        nearby_bins = []
        for bin_state in self.bins:
            distance = abs(bin_state.position[0] - self.agent_position[0]) + abs(
                bin_state.position[1] - self.agent_position[1]
            )
            if distance <= 2:
                nearby_bins.append(f"{bin_state.bin_id}:{bin_state.sku}:{bin_state.quantity}")

        pending = []
        packed = []
        for order_line, packed_line in zip(self.order, self.packed_order):
            pending_qty = max(0, order_line.quantity - packed_line.quantity)
            pending.append(PendingOrderLine(sku=order_line.sku, remaining=pending_qty))
            packed.append(PackedOrderLine(sku=packed_line.sku, packed=packed_line.quantity))

        return WarehouseObservation(
            task_id=self.task.task_id,
            mission=self.task.description,
            narrative=narrative,
            agent_position=self.agent_position,
            heading=self.heading,
            front_cell=self._front_cell_label(),
            carrying=self.carrying,
            carrying_weight=self.carrying_weight,
            battery_level=self.battery_level,
            stamina_level=self.stamina_level,
            money=round(self.money, 2),
            visible_bins=nearby_bins,
            pending_order=pending,
            packed_order=packed,
            progress_ratio=self._completion_ratio(),
        )

    def _completion_ratio(self) -> float:
        """Return the fraction of required units packed, rounded to 4 decimals.

        Packed quantities are capped at each line's requirement. An empty
        order counts as fully complete (1.0).
        """
        total_required = sum(line.quantity for line in self.order)
        total_packed = sum(
            min(order_line.quantity, packed_line.quantity)
            for order_line, packed_line in zip(self.order, self.packed_order)
        )
        if total_required == 0:
            return 1.0
        return round(total_packed / total_required, 4)

    def _remaining_quantity(self, sku: str) -> int:
        """Return how many more units of ``sku`` the order needs (0 if not ordered)."""
        for order_line, packed_line in zip(self.order, self.packed_order):
            if order_line.sku == sku:
                return max(0, order_line.quantity - packed_line.quantity)
        return 0

    def _all_order_lines_complete(self) -> bool:
        """Return True when every order line has been fully packed."""
        return all(self._remaining_quantity(line.sku) == 0 for line in self.order)

    def _consume_battery(self, amount: int) -> None:
        """Drain battery (clamped at 0) and count the transition to empty.

        NOTE(review): no action is blocked at zero battery; depletion is
        only recorded in the metrics.
        """
        previous = self.battery_level
        self.battery_level = max(0, self.battery_level - amount)
        if previous > 0 and self.battery_level == 0:
            self.metrics.battery_depletion_events += 1

    def _consume_stamina(self, amount: int) -> None:
        """Drain stamina (clamped at 0) if the task uses stamina; count depletion."""
        if not self._has_stamina():
            return
        previous = self.stamina_level
        self.stamina_level = max(0, self.stamina_level - amount)
        if previous > 0 and self.stamina_level == 0:
            self.metrics.stamina_depletion_events += 1

    def _has_stamina(self) -> bool:
        """Return True when the task enables the stamina mechanic."""
        return self.task.stamina_capacity > 0

    def _is_obstacle(self, position: Tuple[int, int]) -> bool:
        """Return True if ``position`` is an obstacle cell (list or tuple form accepted)."""
        return tuple(position) in {tuple(o) for o in self.task.obstacles}

    def _item_value(self, sku: str) -> float:
        """Return the value of ``sku`` from the first task bin stocking it, else 0.0."""
        for bin_state in self.task.bins:
            if bin_state.sku == sku:
                return bin_state.value
        return 0.0

    def _is_episode_complete(self) -> bool:
        """Return True when the order is fully packed and any profit target is met."""
        if not self._all_order_lines_complete():
            return False
        if self.task.profit_target > 0 and self.money < self.task.profit_target:
            return False
        return True

    def _in_bounds(self, position: Tuple[int, int]) -> bool:
        """Return True if ``position`` lies inside ``grid_size`` (x, y)."""
        return 0 <= position[0] < self.grid_size[0] and 0 <= position[1] < self.grid_size[1]

    def _occupied(self, position: Tuple[int, int]) -> bool:
        """Return True if ``position`` holds a station, obstacle, or bin (not walkable)."""
        fixed = {
            self.task.pack_station_position,
            self.task.charger_position,
            self.task.dock_position,
        }
        if self.task.rest_position:
            fixed.add(self.task.rest_position)
        if position in fixed:
            return True
        if self._is_obstacle(position):
            return True
        return any(bin_state.position == position for bin_state in self.bins)


def available_tasks() -> List[Dict[str, str]]:
    """Return a summary (id, difficulty, title, description) of every registered task."""
    return [
        {
            "task_id": task.task_id,
            "difficulty": task.difficulty,
            "title": task.title,
            "description": task.description,
        }
        for task in TASKS.values()
    ]