"""Runner for the warehouse fulfillment environment.

By default this executes a deterministic planner that solves all tasks
reproducibly. If OpenAI credentials are configured, it can also run a model
policy against the same environment.
"""

from __future__ import annotations

import argparse
import json
import os
from collections import deque
from typing import Any, Dict, List, Sequence, Tuple

from grid_env.env import WarehouseFulfillmentEnv
from grid_env.graders import grade_episode
from grid_env.models import BaselineCommand, WarehouseObservation, WarehouseState, model_to_dict
from grid_env.tasks import TASKS

try:
    from openai import OpenAI
except ImportError:  # pragma: no cover
    OpenAI = None  # type: ignore[assignment]


# Headings in clockwise order; _rotate_actions relies on this ordering.
HEADINGS = ["N", "E", "S", "W"]

# (dx, dy) applied to a position when moving one cell along each heading.
MOVE_DELTA = {
    "N": (0, -1),
    "E": (1, 0),
    "S": (0, 1),
    "W": (-1, 0),
}

# Inverse of MOVE_DELTA: unit step -> heading that produces it.
_DELTA_TO_HEADING = {delta: heading for heading, delta in MOVE_DELTA.items()}

SYSTEM_PROMPT = """You control a warehouse fulfillment robot.
Return exactly one JSON object with:
- command: one of turn_left, turn_right, move_forward, scan_bin, pick_item, pack_item, recharge, wait
- rationale: a short sentence
"""


def _in_bounds(position: Tuple[int, int], grid_size: Tuple[int, int]) -> bool:
    """Return True when *position* lies inside a grid of *grid_size* cells."""
    return 0 <= position[0] < grid_size[0] and 0 <= position[1] < grid_size[1]


def _adjacent_goal_positions(
    target: Tuple[int, int],
    blocked: set[Tuple[int, int]],
    grid_size: Tuple[int, int],
) -> List[Tuple[Tuple[int, int], str]]:
    """Return the free cells next to *target*, each with the heading that faces it.

    For every heading the cell one step *behind* the target is considered, so
    that moving forward from that cell along the heading would reach the
    target. Cells outside the grid or contained in *blocked* are dropped.
    """
    candidates = []
    for heading, (dx, dy) in MOVE_DELTA.items():
        pos = (target[0] - dx, target[1] - dy)
        if _in_bounds(pos, grid_size) and pos not in blocked:
            candidates.append((pos, heading))
    return candidates


def _neighbors(
    position: Tuple[int, int],
    blocked: set[Tuple[int, int]],
    grid_size: Tuple[int, int],
) -> List[Tuple[int, int]]:
    """Return the in-bounds, unblocked cells one step away from *position*."""
    results = []
    for dx, dy in MOVE_DELTA.values():
        nxt = (position[0] + dx, position[1] + dy)
        if _in_bounds(nxt, grid_size) and nxt not in blocked:
            results.append(nxt)
    return results


def _bfs_path(
    start: Tuple[int, int],
    goals: Sequence[Tuple[int, int]],
    blocked: set[Tuple[int, int]],
    grid_size: Tuple[int, int],
) -> List[Tuple[int, int]]:
    """Breadth-first search from *start* to the first reached cell of *goals*.

    Returns the cells to step through, excluding *start* and ending on the
    reached goal; the list is empty when *start* is itself a goal.

    Raises:
        RuntimeError: if no goal is reachable (including when *goals* is empty).
    """
    goal_set = set(goals)
    queue = deque([start])
    came_from: Dict[Tuple[int, int], Tuple[int, int] | None] = {start: None}
    found = None

    while queue:
        current = queue.popleft()
        if current in goal_set:
            found = current
            break
        for nxt in _neighbors(current, blocked, grid_size):
            if nxt not in came_from:
                came_from[nxt] = current
                queue.append(nxt)

    if found is None:
        raise RuntimeError("No path to target.")

    # Walk the predecessor chain back to the start, then flip it.
    path = []
    current = found
    while current != start:
        path.append(current)
        current = came_from[current]
    path.reverse()
    return path


def _rotate_actions(current_heading: str, desired_heading: str) -> List[str]:
    """Return the shortest turn sequence from *current_heading* to *desired_heading*.

    Ties (a 180 degree turn) are resolved in favour of turning right.
    """
    current_idx = HEADINGS.index(current_heading)
    desired_idx = HEADINGS.index(desired_heading)
    right_turns = (desired_idx - current_idx) % 4
    left_turns = (current_idx - desired_idx) % 4
    if right_turns <= left_turns:
        return ["turn_right"] * right_turns
    return ["turn_left"] * left_turns


def _move_adjacent_and_face(env: WarehouseFulfillmentEnv, target: Tuple[int, int]) -> List[str]:
    """Plan the actions that bring the agent next to *target*, facing it.

    Bins, the pack station, the charger and the dock are treated as obstacles,
    except for the cell the agent currently occupies. The environment is only
    read, never stepped.

    Raises:
        RuntimeError: propagated from the path search when no free adjacent
            cell of *target* can be reached.
    """
    state = env.state()
    blocked = {bin_state.position for bin_state in state.bins}
    blocked.update({state.pack_station_position, state.charger_position, state.dock_position})
    # The agent's own cell must stay walkable even if it coincides with a fixture.
    blocked.discard(state.agent_position)

    candidates = _adjacent_goal_positions(target, blocked, state.grid_size)
    positions = [pos for pos, _ in candidates]
    path = _bfs_path(state.agent_position, positions, blocked, state.grid_size)

    planned_actions: List[str] = []
    current_heading = state.heading
    current_position = state.agent_position
    for step in path:
        delta = (step[0] - current_position[0], step[1] - current_position[1])
        desired_heading = _DELTA_TO_HEADING[delta]
        planned_actions.extend(_rotate_actions(current_heading, desired_heading))
        planned_actions.append("move_forward")
        current_heading = desired_heading
        current_position = step

    # Final rotation so the agent looks at the target from the cell it ended on.
    for pos, heading in candidates:
        if pos == current_position:
            planned_actions.extend(_rotate_actions(current_heading, heading))
            break
    return planned_actions


def _maybe_recharge_plan(env: WarehouseFulfillmentEnv) -> List[str]:
    """Return a detour to the charger when the battery is at or below the threshold.

    The threshold is ``max(6, 2 * manhattan_distance_to_charger + 4)``. When
    the battery level is above it an empty list is returned; otherwise the
    plan is the route to the charger followed by ``"recharge"``.
    """
    state = env.state()
    distance_to_charger = abs(state.agent_position[0] - state.charger_position[0]) + abs(
        state.agent_position[1] - state.charger_position[1]
    )
    threshold = max(6, (2 * distance_to_charger) + 4)
    if state.battery_level > threshold:
        return []
    return _move_adjacent_and_face(env, state.charger_position) + ["recharge"]


def _record_and_step(env: WarehouseFulfillmentEnv, actions: List[str], plan: Sequence[str]) -> None:
    """Append *plan* to *actions* and execute it on *env*, in that order."""
    actions.extend(plan)
    for action in plan:
        env.step(action)


def planned_actions_for_task(env: WarehouseFulfillmentEnv) -> List[str]:
    """Build the full deterministic action list for the task loaded in *env*.

    The plan is produced by simulating it: *env* is stepped through every
    planned action, so callers should pass a scratch environment rather than
    the one used for the real episode. For each unit of each order line the
    agent optionally recharges, goes to the bin, scans it if it has not been
    scanned yet, picks, optionally recharges again, goes to the pack station
    and packs.
    """
    actions: List[str] = []
    state = env.state()
    sku_to_bin = {bin_state.sku: bin_state for bin_state in state.bins}

    for order_line in state.order:
        for _ in range(order_line.quantity):
            actions.extend(_maybe_recharge_plan(env))
            # Replays the suffix of `actions` that the environment has not yet
            # recorded in action_history - presumably exactly the recharge
            # plan appended above. Kept as written; TODO confirm and unify
            # with _record_and_step.
            for action in actions[len(env.state().action_history):]:
                env.step(action)

            bin_state = sku_to_bin[order_line.sku]
            _record_and_step(env, actions, _move_adjacent_and_face(env, bin_state.position))

            if bin_state.bin_id not in env.state().scanned_bins:
                _record_and_step(env, actions, ["scan_bin"])
            _record_and_step(env, actions, ["pick_item"])

            _record_and_step(env, actions, _maybe_recharge_plan(env))
            _record_and_step(
                env, actions, _move_adjacent_and_face(env, env.state().pack_station_position)
            )
            _record_and_step(env, actions, ["pack_item"])

    return actions


def heuristic_next_action(env: WarehouseFulfillmentEnv, cached_plan: List[str]) -> str:
    """Return the planned action for the current step, or ``"wait"`` once the plan is exhausted."""
    state = env.state()
    if state.step_count < len(cached_plan):
        return cached_plan[state.step_count]
    return "wait"


def build_openai_prompt(observation: WarehouseObservation, state: WarehouseState) -> str:
    """Serialize the observation and a short state summary as the model's user message.

    The result is pretty-printed JSON with sorted keys; only the six most
    recent actions are included.
    """
    payload = {
        "mission": observation.mission,
        "observation": model_to_dict(observation),
        "state_summary": {
            "step_count": state.step_count,
            "max_steps": state.max_steps,
            "battery_level": state.battery_level,
            "carrying": state.carrying,
            "scanned_bins": state.scanned_bins,
            "completion_ratio": state.completion_ratio,
            "recent_actions": state.action_history[-6:],
        },
    }
    return json.dumps(payload, indent=2, sort_keys=True)


def openai_next_action(
    client: Any,
    model: str,
    observation: WarehouseObservation,
    state: WarehouseState,
) -> str:
    """Ask the model for the next command and return it.

    The request constrains the reply to the JSON schema of ``BaselineCommand``.
    An empty, missing, unparsable or schema-invalid reply yields ``"wait"``
    so that one bad completion does not abort the episode.
    """
    response = client.responses.create(
        model=model,
        input=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": build_openai_prompt(observation, state)},
        ],
        text={
            "format": {
                "type": "json_schema",
                "name": "warehouse_action",
                "strict": True,
                "schema": BaselineCommand.model_json_schema(),
            }
        },
    )
    # `output_text` may be present but None; treat that like an empty reply.
    content = (getattr(response, "output_text", "") or "").strip()
    if not content:
        return "wait"
    try:
        payload = json.loads(content)
        return BaselineCommand(**payload).command
    except (ValueError, TypeError):
        # json.JSONDecodeError is a ValueError; TypeError covers a non-object
        # payload. Assumes BaselineCommand validation failures are ValueError
        # subclasses (true for pydantic models) - TODO confirm.
        return "wait"


def run_episode(task_id: str, seed: int, policy: str, model: str | None) -> Dict[str, Any]:
    """Run one episode of *task_id* and return its summary metrics.

    With ``policy == "openai"`` every action is requested from the model;
    any other value uses the deterministic plan, which is computed up front
    on a separate environment instance. Each step is printed as it happens.

    Returns:
        A dict with ``task_id`` (str) and numeric ``reward``, ``score``,
        ``steps`` and ``success`` entries.

    Raises:
        RuntimeError: for the OpenAI policy when the ``openai`` package is
            missing or no API key is configured.
    """
    use_model = policy == "openai"

    # The plan is only consumed by the heuristic policy, so skip the extra
    # simulated episode (and its possible planning failure) for the model policy.
    cached_plan: List[str] = []
    if not use_model:
        env_for_plan = WarehouseFulfillmentEnv(task_id=task_id, seed=seed)
        env_for_plan.reset(task_id=task_id, seed=seed)
        cached_plan = planned_actions_for_task(env_for_plan)

    env = WarehouseFulfillmentEnv(task_id=task_id, seed=seed)
    observation = env.reset(task_id=task_id, seed=seed)

    client = None
    model_name = ""
    if use_model:
        if OpenAI is None:
            raise RuntimeError("The openai package is not installed.")
        api_key = os.environ.get("OPENAI_API_KEY") or os.environ.get("HF_TOKEN") or os.environ.get("API_KEY")
        if not api_key:
            raise RuntimeError("OPENAI_API_KEY (or HF_TOKEN/API_KEY) is not set.")
        base_url = os.environ.get("API_BASE_URL")
        client = OpenAI(api_key=api_key, base_url=base_url)
        model_name = model or os.environ.get("MODEL_NAME", "gpt-4.1-mini")

    done = False
    while not done:
        state = env.state()
        if use_model:
            command = openai_next_action(client, model_name, observation, state)
        else:
            command = heuristic_next_action(env, cached_plan)

        observation, reward, done, info = env.step(command)
        print(
            f"[{task_id}] step={state.step_count + 1} action={command} "
            f"reward={reward.value:+.2f} done={done}"
        )

    final_state = env.state()
    return {
        "task_id": task_id,
        "reward": round(final_state.total_reward, 4),
        "score": grade_episode(final_state),
        "steps": float(final_state.step_count),
        "success": 1.0 if final_state.success else 0.0,
    }


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the runner."""
    parser = argparse.ArgumentParser(description="Run the warehouse fulfillment environment.")
    parser.add_argument("--task-id", choices=sorted(TASKS.keys()), help="Run a single task instead of all tasks.")
    parser.add_argument("--seed", type=int, default=7, help="Deterministic environment seed.")
    parser.add_argument(
        "--policy",
        choices=["heuristic", "openai"],
        default="heuristic",
        help="Action policy to use.",
    )
    parser.add_argument(
        "--model",
        default=os.environ.get("MODEL_NAME") or os.environ.get("OPENAI_MODEL"),
        help="Model name for --policy openai.",
    )
    return parser.parse_args()


def main() -> None:
    """Run the selected task (or every task) and print a results table."""
    args = parse_args()
    task_ids = [args.task_id] if args.task_id else list(TASKS.keys())
    results = [run_episode(task_id, seed=args.seed, policy=args.policy, model=args.model) for task_id in task_ids]

    print("\ntask_id | score | reward | steps | success")
    for result in results:
        print(
            f"{result['task_id']} | {result['score']:.4f} | "
            f"{result['reward']:.4f} | {int(result['steps'])} | {int(result['success'])}"
        )
    # No tasks means no mean; avoid dividing by zero.
    if results:
        mean_score = sum(result["score"] for result in results) / len(results)
        print(f"mean_score | {mean_score:.4f}")


if __name__ == "__main__":
    main()