# sohambose98's picture
# updated state management
# dda0af2
"""
MiniGrid-style warehouse fulfillment environment.
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple
from .graders import grade_episode
from .models import (
BinState,
OrderLine,
PackedOrderLine,
PendingOrderLine,
TaskDefinition,
WarehouseAction,
WarehouseMetrics,
WarehouseObservation,
WarehouseReward,
WarehouseState,
)
from .tasks import GRID_SIZE, TASKS, get_task
# Compass headings listed so that stepping +1 through the list is a right
# turn and stepping -1 is a left turn.
HEADINGS = ["N", "E", "S", "W"]

# (dx, dy) grid offset for one step along each heading; "N" decreases y and
# "S" increases it.
MOVE_DELTA = dict(
    N=(0, -1),
    E=(1, 0),
    S=(0, 1),
    W=(-1, 0),
)
class WarehouseFulfillmentEnv:
    """Grid-based pick-and-pack warehouse environment.

    The agent turns and moves on a 2-D grid, scans and picks from the bin in
    the cell directly ahead of it, and packs items while facing the pack
    station. An episode ends when every order line is filled (and the task's
    profit target, if it has one, is met) or when the task's step budget is
    exhausted.

    ``step`` returns ``(observation, reward, done, info)``.
    """

    # Command strings understood by ``step``.
    action_space = [
        "turn_left",
        "turn_right",
        "move_forward",
        "scan_bin",
        "pick_item",
        "pack_item",
        "recharge",
        "rest",
        "wait",
    ]

    # Field names populated on every observation built by this class.
    observation_space = [
        "task_id",
        "mission",
        "narrative",
        "agent_position",
        "heading",
        "front_cell",
        "carrying",
        "carrying_weight",
        "battery_level",
        "stamina_level",
        "money",
        "visible_bins",
        "pending_order",
        "packed_order",
        "progress_ratio",
    ]

    def __init__(self, task_id: str = "easy_single_pick", seed: int = 7) -> None:
        """Create the environment and load ``task_id`` as the first episode.

        Args:
            task_id: Identifier passed to ``get_task``.
            seed: Stored on the instance and embedded in the episode id.
        """
        self.grid_size = GRID_SIZE
        self.seed = seed
        self._episode_counter = 0
        self.task: Optional[TaskDefinition] = None
        self._reset_runtime(task_id)

    def reset(
        self,
        task_id: Optional[str] = None,
        seed: Optional[int] = None,
    ) -> WarehouseObservation:
        """Start a new episode and return its first observation.

        Args:
            task_id: Task to load; when falsy the current task is reloaded.
            seed: Replaces the stored seed when not ``None``.
        """
        if seed is not None:
            self.seed = seed
        self._episode_counter += 1
        self._reset_runtime(task_id or self.task.task_id)
        return self._build_observation("Episode reset. Start the pick-pack workflow.")

    def step(
        self,
        action: WarehouseAction | str,
    ) -> Tuple[WarehouseObservation, WarehouseReward, bool, Dict[str, Any]]:
        """Apply one action and advance the episode.

        Every processed action starts from a base reward of -0.01. Handlers
        add or subtract their own amounts, progress toward the order adds
        ``0.15 * progress_delta``, and completing the episode adds 0.50.
        Calling ``step`` after the episode is done changes nothing and
        returns a zero reward.

        Args:
            action: A ``WarehouseAction`` (its ``command`` is used) or any
                other value, which is converted with ``str``.

        Returns:
            ``(observation, reward, done, info)`` where ``info`` carries
            ``score`` (``None`` until done), ``completion_ratio``,
            ``metrics``, ``terminated``, ``truncated`` and
            ``termination_reason``.
        """
        # Accept string actions from callers but keep handling robust for unknowns.
        command = action.command if isinstance(action, WarehouseAction) else str(action)

        if self.done:
            message = "Episode already complete."
            observation = self._build_observation(message)
            reward = WarehouseReward(
                value=0.0,
                reason=message,
                completion_ratio=self._completion_ratio(),
            )
            return observation, reward, True, self._build_info()

        self.step_count += 1
        reward_value = -0.01
        narrative = "Action processed."
        prev_completion = self._completion_ratio()

        if command == "turn_left":
            narrative = self._turn(-1, "left")
        elif command == "turn_right":
            narrative = self._turn(1, "right")
        elif command == "move_forward":
            reward_value, narrative = self._move_forward(reward_value)
        elif command == "scan_bin":
            reward_value, narrative = self._scan_bin(reward_value)
        elif command == "pick_item":
            reward_value, narrative = self._pick_item(reward_value)
        elif command == "pack_item":
            reward_value, narrative = self._pack_item(reward_value)
        elif command == "recharge":
            reward_value, narrative = self._recharge(reward_value)
        elif command == "rest":
            reward_value, narrative = self._rest(reward_value)
        elif command == "wait":
            reward_value -= 0.01
            narrative = "Waited in place and lost time."
        else:
            self.metrics.invalid_actions += 1
            reward_value -= 0.10
            narrative = f"Unknown action: {command}."

        progress_delta = max(0.0, self._completion_ratio() - prev_completion)
        if progress_delta > 0.0:
            # Small dense shaping for moving the order toward completion.
            reward_value += 0.15 * progress_delta

        self.action_history.append(command)

        is_complete = self._is_episode_complete()
        hit_step_limit = self.step_count >= self.task.max_steps
        self.done = is_complete or hit_step_limit
        self.success = is_complete
        if is_complete:
            self.termination_reason = "task_complete"
        elif hit_step_limit:
            self.termination_reason = "max_steps_reached"
        else:
            self.termination_reason = None

        if self.success:
            reward_value += 0.50
            narrative = "Order fully packed and ready for dispatch."

        self.total_reward += reward_value
        observation = self._build_observation(narrative)
        reward = WarehouseReward(
            value=round(reward_value, 4),
            reason=narrative,
            completion_ratio=self._completion_ratio(),
        )
        return observation, reward, self.done, self._build_info()

    def _build_info(self) -> Dict[str, Any]:
        """Assemble the ``info`` dict returned by ``step``.

        The episode is only graded once it is done; before that ``score`` is
        ``None``. ``terminated`` means the task was completed, ``truncated``
        means the step budget ran out without success.
        """
        hit_step_limit = self.step_count >= self.task.max_steps
        return {
            "score": grade_episode(self.state()) if self.done else None,
            "completion_ratio": self._completion_ratio(),
            "metrics": self.metrics.model_dump(),
            "terminated": bool(self.done and self.success),
            "truncated": bool(self.done and not self.success and hit_step_limit),
            "termination_reason": self.termination_reason,
        }

    def state(self) -> WarehouseState:
        """Return a snapshot of the full environment state.

        Bins, order lines, metrics and the action history are copied so the
        snapshot does not share those objects with the live environment.
        """
        return WarehouseState(
            episode_id=self.episode_id,
            task_id=self.task.task_id,
            difficulty=self.task.difficulty,
            step_count=self.step_count,
            done=self.done,
            success=self.success,
            max_steps=self.task.max_steps,
            grid_size=self.grid_size,
            agent_position=self.agent_position,
            heading=self.heading,
            carrying=self.carrying,
            carrying_weight=self.carrying_weight,
            battery_level=self.battery_level,
            battery_capacity=self.task.battery_capacity,
            stamina_level=self.stamina_level,
            stamina_capacity=self.task.stamina_capacity,
            money=round(self.money, 2),
            profit_target=self.task.profit_target,
            dock_position=self.task.dock_position,
            pack_station_position=self.task.pack_station_position,
            charger_position=self.task.charger_position,
            obstacles=list(self.task.obstacles),
            bins=[self._clone_bin(bin_state) for bin_state in self.bins],
            order=[self._clone_order_line(line) for line in self.order],
            packed_order=[self._clone_order_line(line) for line in self.packed_order],
            scanned_bins=sorted(self.scanned_bins),
            metrics=WarehouseMetrics(**self.metrics.model_dump()),
            action_history=list(self.action_history),
            total_reward=round(self.total_reward, 4),
            completion_ratio=self._completion_ratio(),
            task_description=self.task.description,
        )

    def _reset_runtime(self, task_id: str) -> None:
        """Load a private copy of ``task_id`` and reset all per-episode state."""
        self.task = self._clone_task(get_task(task_id))
        self.episode_id = f"{task_id}-seed{self.seed}-ep{self._episode_counter + 1}"
        self.agent_position = self.task.agent_start
        self.heading = self.task.agent_heading
        self.battery_level = self.task.battery_capacity
        self.carrying: Optional[str] = None
        self.carrying_weight: int = 0
        self.stamina_level: int = self.task.stamina_capacity
        self.money: float = 0.0
        self.step_count = 0
        self.done = False
        self.success = False
        self.termination_reason: Optional[str] = None
        self.total_reward = 0.0
        self.metrics = WarehouseMetrics()
        self.scanned_bins: set[str] = set()
        self.action_history: List[str] = []
        # Work on copies so picking never mutates the task definition.
        self.bins = [self._clone_bin(bin_state) for bin_state in self.task.bins]
        self.order = [self._clone_order_line(line) for line in self.task.order]
        # One zero-quantity line per ordered SKU, in the same order as the order.
        self.packed_order = [OrderLine(sku=line.sku, quantity=0) for line in self.task.order]

    def _clone_bin(self, bin_state: BinState | Dict[str, Any]) -> BinState:
        """Build a new ``BinState`` from a model instance or a plain mapping."""
        payload = bin_state.model_dump() if hasattr(bin_state, "model_dump") else dict(bin_state)
        return BinState(**payload)

    def _clone_order_line(self, line: OrderLine | Dict[str, Any]) -> OrderLine:
        """Build a new ``OrderLine`` from a model instance or a plain mapping."""
        payload = line.model_dump() if hasattr(line, "model_dump") else dict(line)
        return OrderLine(**payload)

    def _clone_task(self, task: TaskDefinition | Dict[str, Any]) -> TaskDefinition:
        """Build a new ``TaskDefinition`` whose bins and order lines are fresh objects."""
        payload = task.model_dump() if hasattr(task, "model_dump") else dict(task)
        payload["bins"] = [self._clone_bin(bin_state) for bin_state in payload["bins"]]
        payload["order"] = [self._clone_order_line(line) for line in payload["order"]]
        return TaskDefinition(**payload)

    def _turn(self, offset: int, direction: str) -> str:
        """Rotate the heading by ``offset`` steps through ``HEADINGS``.

        Costs 1 battery and returns the narrative for the turn; ``direction``
        is the word used in that narrative.
        """
        self.heading = HEADINGS[(HEADINGS.index(self.heading) + offset) % len(HEADINGS)]
        self._consume_battery(1)
        return f"Turned {direction}. Now facing {self.heading}."

    def _reject(
        self,
        reward: float,
        penalty: float,
        message: str,
        *,
        drain_battery: bool = True,
    ) -> Tuple[float, str]:
        """Record an invalid action and return the penalised ``(reward, message)``.

        Args:
            reward: Reward accumulated so far for this step.
            penalty: Amount subtracted from ``reward``.
            message: Narrative explaining the rejection.
            drain_battery: When true the failed attempt still costs 1 battery.
        """
        self.metrics.invalid_actions += 1
        if drain_battery:
            self._consume_battery(1)
        return reward - penalty, message

    def _front_position(self) -> Tuple[int, int]:
        """Return the coordinates of the cell directly ahead of the agent."""
        dx, dy = MOVE_DELTA[self.heading]
        x, y = self.agent_position[0], self.agent_position[1]
        return (x + dx, y + dy)

    def _front_bin(self) -> Optional[BinState]:
        """Return the bin located in the cell ahead, or ``None`` if there is none."""
        ahead = self._front_position()
        return next((candidate for candidate in self.bins if candidate.position == ahead), None)

    def _front_cell_label(self) -> str:
        """Describe the cell ahead of the agent as a short human-readable label."""
        front = self._front_position()
        if not self._in_bounds(front):
            return "wall"
        if self._is_obstacle(front):
            return "obstacle"
        front_bin = self._front_bin()
        # Identity check: "no bin" is signalled by None, not by a falsy object.
        if front_bin is not None:
            return f"bin {front_bin.bin_id} ({front_bin.sku})"
        if front == self.task.pack_station_position:
            return "pack station"
        if front == self.task.charger_position:
            return "charger"
        if front == self.task.dock_position:
            return "dock"
        if self.task.rest_position and front == self.task.rest_position:
            return "rest area"
        return "aisle"

    def _move_forward(self, reward: float) -> Tuple[float, str]:
        """Try to move one cell along the current heading.

        A blocked move is counted as invalid and still costs 1 battery. A
        successful move costs 2 battery, plus the carried weight when that
        weight exceeds 1, doubled when the task uses stamina and it is
        exhausted; it also spends the task's stamina move cost.
        """
        target = self._front_position()
        if self._is_obstacle(target):
            self.metrics.obstacle_collisions += 1
            return self._reject(reward, 0.12, "Blocked by an obstacle! Find another route.")
        if not self._in_bounds(target) or self._occupied(target):
            return self._reject(reward, 0.08, "Forward move blocked by warehouse infrastructure.")

        battery_cost = 2
        if self.carrying_weight > 1:
            battery_cost += self.carrying_weight
        if self._has_stamina() and self.stamina_level <= 0:
            battery_cost *= 2

        self.agent_position = target
        self.metrics.distance_travelled += 1
        self._consume_battery(battery_cost)
        self._consume_stamina(self.task.stamina_move_cost)
        return reward, f"Moved to aisle cell {self.agent_position}."

    def _scan_bin(self, reward: float) -> Tuple[float, str]:
        """Scan the bin ahead and reward it according to the task's required scans."""
        bin_state = self._front_bin()
        if bin_state is None:
            return self._reject(reward, 0.08, "No bin in front to scan.")

        self._consume_battery(1)
        if bin_state.bin_id in self.scanned_bins:
            return reward - 0.01, f"Bin {bin_state.bin_id} was already scanned."

        self.scanned_bins.add(bin_state.bin_id)
        if bin_state.bin_id in self.task.required_scans:
            self.metrics.correct_scans += 1
            return reward + 0.12, f"Scanned {bin_state.bin_id}; confirmed {bin_state.sku}."
        self.metrics.wrong_scans += 1
        return reward - 0.02, f"Scanned {bin_state.bin_id}; item not needed for this order."

    def _pick_item(self, reward: float) -> Tuple[float, str]:
        """Pick one unit from the bin ahead into the agent's hands.

        The pick is rejected when there is no bin, the agent already carries
        something, the bin is empty, or the item exceeds the carry capacity.
        A successful pick is rewarded only if the order still needs that SKU.
        """
        bin_state = self._front_bin()
        if bin_state is None:
            return self._reject(reward, 0.08, "No pick face in front of the agent.")
        if self.carrying is not None:
            return self._reject(reward, 0.08, f"Hands already occupied with {self.carrying}.")
        if bin_state.quantity <= 0:
            return self._reject(reward, 0.10, f"Bin {bin_state.bin_id} is empty.")
        if bin_state.weight > self.task.carry_capacity:
            self.metrics.overweight_attempts += 1
            return self._reject(
                reward,
                0.12,
                f"Item {bin_state.sku} weighs {bin_state.weight} but carry capacity "
                f"is {self.task.carry_capacity}. Too heavy!",
            )

        self._consume_battery(1)
        bin_state.quantity -= 1
        self.carrying = bin_state.sku
        self.carrying_weight = bin_state.weight
        if self._remaining_quantity(bin_state.sku) > 0:
            self.metrics.correct_picks += 1
            return reward + 0.20, f"Picked {bin_state.sku} (weight {bin_state.weight}) from {bin_state.bin_id}."
        self.metrics.wrong_picks += 1
        return reward - 0.18, f"Picked {bin_state.sku}, which is not needed now."

    def _pack_item(self, reward: float) -> Tuple[float, str]:
        """Pack the carried item while facing the pack station.

        Packing a needed unit earns the item's value. Packing a unit the
        order no longer needs discards it, counts as a wrong pick and loses
        half of the item's value.
        """
        if self._front_position() != self.task.pack_station_position:
            return self._reject(reward, 0.08, "Agent is not facing the pack station.")
        if self.carrying is None:
            return self._reject(reward, 0.08, "Nothing in hand to pack.")

        self._consume_battery(1)
        item = self.carrying
        item_value = self._item_value(item)
        needed = self._remaining_quantity(item) > 0

        if needed:
            # Credit the first packed line that matches the carried SKU.
            for packed_line in self.packed_order:
                if packed_line.sku == item:
                    packed_line.quantity += 1
                    break

        self.carrying = None
        self.carrying_weight = 0

        if not needed:
            self.metrics.wrong_picks += 1
            if item_value > 0:
                self.money -= item_value * 0.5
                self.metrics.money_lost += item_value * 0.5
            return reward - 0.15, f"Packed extra unit of {item}; order did not require it."

        self.metrics.correct_packs += 1
        if item_value > 0:
            self.money += item_value
            self.metrics.money_earned += item_value
        return reward + 0.35, f"Packed {item} at the station. (+${item_value:.2f})"

    def _recharge(self, reward: float) -> Tuple[float, str]:
        """Refill the battery while facing the charger.

        Recharging at or below the task's low-battery threshold earns a
        bonus; recharging earlier than that carries a small penalty.
        """
        if self._front_position() != self.task.charger_position:
            return self._reject(reward, 0.08, "Recharge action requires facing the charger.")
        if self.battery_level >= self.task.battery_capacity:
            return reward - 0.03, "Battery already full."

        if self.battery_level <= self.task.low_battery_threshold:
            benefit = 0.08
        else:
            benefit = -0.02
        self.battery_level = self.task.battery_capacity
        self.metrics.recharges += 1
        return reward + benefit, "Battery restored to full capacity."

    def _rest(self, reward: float) -> Tuple[float, str]:
        """Refill stamina, facing the rest area when the task defines one.

        Resting at or below a quarter of the stamina capacity earns a bonus;
        resting earlier than that carries a small penalty.
        """
        if not self._has_stamina():
            # This rejection is the one invalid action that costs no battery.
            return self._reject(
                reward, 0.03, "This task has no stamina mechanic.", drain_battery=False
            )
        if self.task.rest_position and self._front_position() != self.task.rest_position:
            return self._reject(reward, 0.08, "Rest action requires facing the rest area.")
        if self.stamina_level >= self.task.stamina_capacity:
            return reward - 0.03, "Stamina already full."

        if self.stamina_level <= self.task.stamina_capacity // 4:
            benefit = 0.06
        else:
            benefit = -0.02
        self.stamina_level = self.task.stamina_capacity
        self.metrics.rest_events += 1
        return reward + benefit, "Stamina restored to full capacity."

    def _build_observation(self, narrative: str) -> WarehouseObservation:
        """Build the agent-facing observation for the current state.

        ``visible_bins`` lists bins within Manhattan distance 2 of the agent
        as ``"bin_id:sku:quantity"`` strings.
        """
        agent_x, agent_y = self.agent_position[0], self.agent_position[1]
        nearby_bins = [
            f"{bin_state.bin_id}:{bin_state.sku}:{bin_state.quantity}"
            for bin_state in self.bins
            if abs(bin_state.position[0] - agent_x) + abs(bin_state.position[1] - agent_y) <= 2
        ]

        pending = []
        packed = []
        for order_line, packed_line in zip(self.order, self.packed_order):
            pending_qty = max(0, order_line.quantity - packed_line.quantity)
            pending.append(PendingOrderLine(sku=order_line.sku, remaining=pending_qty))
            packed.append(PackedOrderLine(sku=packed_line.sku, packed=packed_line.quantity))

        return WarehouseObservation(
            task_id=self.task.task_id,
            mission=self.task.description,
            narrative=narrative,
            agent_position=self.agent_position,
            heading=self.heading,
            front_cell=self._front_cell_label(),
            carrying=self.carrying,
            carrying_weight=self.carrying_weight,
            battery_level=self.battery_level,
            stamina_level=self.stamina_level,
            money=round(self.money, 2),
            visible_bins=nearby_bins,
            pending_order=pending,
            packed_order=packed,
            progress_ratio=self._completion_ratio(),
        )

    def _completion_ratio(self) -> float:
        """Return the packed fraction of the order, rounded to 4 places.

        Units packed beyond a line's required quantity are not counted. An
        order that requires nothing is reported as fully complete.
        """
        total_required = sum(line.quantity for line in self.order)
        if total_required == 0:
            return 1.0
        total_packed = sum(
            min(order_line.quantity, packed_line.quantity)
            for order_line, packed_line in zip(self.order, self.packed_order)
        )
        return round(total_packed / total_required, 4)

    def _remaining_quantity(self, sku: str) -> int:
        """Return how many units of ``sku`` the order still needs (0 if not ordered)."""
        for order_line, packed_line in zip(self.order, self.packed_order):
            if order_line.sku == sku:
                return max(0, order_line.quantity - packed_line.quantity)
        return 0

    def _all_order_lines_complete(self) -> bool:
        """Return True when no order line has units left to pack."""
        return all(self._remaining_quantity(line.sku) == 0 for line in self.order)

    def _consume_battery(self, amount: int) -> None:
        """Drain the battery by ``amount``, clamping at 0 and counting depletions."""
        previous = self.battery_level
        self.battery_level = max(0, previous - amount)
        # Count only the transition to empty, not every action taken while empty.
        if previous > 0 and self.battery_level == 0:
            self.metrics.battery_depletion_events += 1

    def _consume_stamina(self, amount: int) -> None:
        """Drain stamina by ``amount`` when the task uses stamina, clamping at 0."""
        if not self._has_stamina():
            return
        previous = self.stamina_level
        self.stamina_level = max(0, previous - amount)
        if previous > 0 and self.stamina_level == 0:
            self.metrics.stamina_depletion_events += 1

    def _has_stamina(self) -> bool:
        """Return True when the task defines a positive stamina capacity."""
        return self.task.stamina_capacity > 0

    def _is_obstacle(self, position: Tuple[int, int]) -> bool:
        """Return True when ``position`` matches one of the task's obstacle cells."""
        # Normalise both sides to tuples so list-style coordinates compare equal.
        target = tuple(position)
        return any(tuple(cell) == target for cell in self.task.obstacles)

    def _item_value(self, sku: str) -> float:
        """Return the value of the first task bin stocking ``sku``, or 0.0 if none."""
        for bin_state in self.task.bins:
            if bin_state.sku == sku:
                return bin_state.value
        return 0.0

    def _is_episode_complete(self) -> bool:
        """Return True when the order is filled and any profit target is met."""
        if not self._all_order_lines_complete():
            return False
        if self.task.profit_target > 0 and self.money < self.task.profit_target:
            return False
        return True

    def _in_bounds(self, position: Tuple[int, int]) -> bool:
        """Return True when ``position`` lies inside the grid."""
        return 0 <= position[0] < self.grid_size[0] and 0 <= position[1] < self.grid_size[1]

    def _occupied(self, position: Tuple[int, int]) -> bool:
        """Return True when ``position`` holds a station, an obstacle or a bin."""
        fixed = {
            self.task.pack_station_position,
            self.task.charger_position,
            self.task.dock_position,
        }
        if self.task.rest_position:
            fixed.add(self.task.rest_position)
        if position in fixed:
            return True
        if self._is_obstacle(position):
            return True
        return any(bin_state.position == position for bin_state in self.bins)
def available_tasks() -> List[Dict[str, str]]:
    """List every registered task as a small summary dict.

    Returns:
        One dict per entry in ``TASKS`` with the keys ``task_id``,
        ``difficulty``, ``title`` and ``description``, in that order.
    """
    summary_fields = ("task_id", "difficulty", "title", "description")
    return [
        {field: getattr(task, field) for field in summary_fields}
        for task in TASKS.values()
    ]