Spaces:
Sleeping
Sleeping
| """ | |
| Runner for the warehouse fulfillment environment. | |
| By default this executes a deterministic planner that solves all tasks | |
| reproducibly. If OpenAI credentials are configured, it can also run a model | |
| policy against the same environment. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import os | |
| from collections import deque | |
| from typing import Any, Dict, List, Sequence, Tuple | |
| from grid_env.env import WarehouseFulfillmentEnv | |
| from grid_env.graders import grade_episode | |
| from grid_env.models import BaselineCommand, WarehouseObservation, WarehouseState, model_to_dict | |
| from grid_env.tasks import TASKS | |
| try: | |
| from openai import OpenAI | |
| except ImportError: # pragma: no cover | |
| OpenAI = None # type: ignore[assignment] | |
# Compass headings listed clockwise. _rotate_actions relies on this order:
# moving one index forward (mod 4) corresponds to a single "turn_right".
HEADINGS = ["N", "E", "S", "W"]

# Grid displacement (dx, dy) produced by one forward move along each heading.
# North is (0, -1), i.e. the y coordinate decreases toward the north.
MOVE_DELTA = {
    "N": (0, -1),
    "E": (1, 0),
    "S": (0, 1),
    "W": (-1, 0),
}

# System message sent with every request issued by openai_next_action.
SYSTEM_PROMPT = """You control a warehouse fulfillment robot.
Return exactly one JSON object with:
- command: one of turn_left, turn_right, move_forward, scan_bin, pick_item, pack_item, recharge, wait
- rationale: a short sentence
"""
def _adjacent_goal_positions(
    target: Tuple[int, int],
    blocked: set[Tuple[int, int]],
    grid_size: Tuple[int, int],
) -> List[Tuple[Tuple[int, int], str]]:
    """List the free cells bordering ``target`` and the heading that faces it.

    For every heading in ``MOVE_DELTA`` the cell one step *behind* the target
    (target minus the heading's delta) is considered; an agent standing there
    and facing that heading looks at ``target``.

    Args:
        target: Cell the agent needs to face.
        blocked: Cells that may not be occupied.
        grid_size: Grid extent as ``(width, height)``; valid coordinates are
            ``0 <= x < width`` and ``0 <= y < height``.

    Returns:
        ``(cell, heading)`` pairs, in ``MOVE_DELTA`` iteration order, for every
        in-bounds, unblocked standing cell.
    """
    width = grid_size[0]
    height = grid_size[1]
    stands = []
    for facing, (dx, dy) in MOVE_DELTA.items():
        cell = (target[0] - dx, target[1] - dy)
        inside_grid = 0 <= cell[0] < width and 0 <= cell[1] < height
        if inside_grid and cell not in blocked:
            stands.append((cell, facing))
    return stands
def _neighbors(
    position: Tuple[int, int],
    blocked: set[Tuple[int, int]],
    grid_size: Tuple[int, int],
) -> List[Tuple[int, int]]:
    """Return the cells reachable from ``position`` in one move.

    Args:
        position: Current cell.
        blocked: Cells that may not be entered.
        grid_size: Grid extent as ``(width, height)``.

    Returns:
        The in-bounds, unblocked cells one ``MOVE_DELTA`` step away, in
        ``MOVE_DELTA`` iteration order.
    """
    width = grid_size[0]
    height = grid_size[1]
    shifted = (
        (position[0] + dx, position[1] + dy) for dx, dy in MOVE_DELTA.values()
    )
    return [
        cell
        for cell in shifted
        if 0 <= cell[0] < width and 0 <= cell[1] < height and cell not in blocked
    ]
def _bfs_path(
    start: Tuple[int, int],
    goals: Sequence[Tuple[int, int]],
    blocked: set[Tuple[int, int]],
    grid_size: Tuple[int, int],
) -> List[Tuple[int, int]]:
    """Breadth-first search from ``start`` to the first goal cell reached.

    Args:
        start: Cell the search begins at.
        goals: Acceptable destination cells.
        blocked: Cells that may not be entered.
        grid_size: Grid extent as ``(width, height)``.

    Returns:
        The cells to visit in order, excluding ``start`` and ending at the
        reached goal. The list is empty when ``start`` is itself a goal.

    Raises:
        RuntimeError: If no goal is reachable (including when ``goals`` is
            empty).
    """
    targets = set(goals)
    frontier = deque([start])
    # Maps every discovered cell to the cell it was discovered from.
    parents: Dict[Tuple[int, int], Tuple[int, int] | None] = {start: None}

    while frontier:
        cell = frontier.popleft()
        if cell in targets:
            break
        for adjacent in _neighbors(cell, blocked, grid_size):
            if adjacent in parents:
                continue
            parents[adjacent] = cell
            frontier.append(adjacent)
    else:
        # The frontier drained without ever hitting a goal.
        raise RuntimeError("No path to target.")

    # Walk the parent links back to the start, then flip into travel order.
    backwards = []
    while cell != start:
        backwards.append(cell)
        cell = parents[cell]
    return backwards[::-1]
def _rotate_actions(current_heading: str, desired_heading: str) -> List[str]:
    """Return the shortest turn sequence from one heading to another.

    Args:
        current_heading: Heading the agent faces now (an entry of ``HEADINGS``).
        desired_heading: Heading the agent should face afterwards.

    Returns:
        A list of ``"turn_right"`` or ``"turn_left"`` commands. Ties (a
        half-turn) are resolved in favour of turning right; the list is empty
        when the headings already match.

    Raises:
        ValueError: If either heading is not in ``HEADINGS``.
    """
    start_idx = HEADINGS.index(current_heading)
    goal_idx = HEADINGS.index(desired_heading)
    clockwise = (goal_idx - start_idx) % 4
    counter_clockwise = (start_idx - goal_idx) % 4
    if counter_clockwise < clockwise:
        return ["turn_left"] * counter_clockwise
    return ["turn_right"] * clockwise
def _move_adjacent_and_face(env: WarehouseFulfillmentEnv, target: Tuple[int, int]) -> List[str]:
    """Plan the commands that bring the agent next to ``target``, facing it.

    The plan is derived from ``env.state()`` only; the environment is not
    stepped. Bins, the pack station, the charger and the dock are treated as
    impassable, except for the cell the agent currently occupies.

    Args:
        env: Environment whose current state is used for planning.
        target: Cell the agent must end up adjacent to and facing.

    Returns:
        Turn and ``"move_forward"`` commands, in execution order.

    Raises:
        RuntimeError: Propagated from ``_bfs_path`` when no standing cell next
            to ``target`` is reachable.
    """
    snapshot = env.state()
    here = snapshot.agent_position
    facing = snapshot.heading

    obstacles = {bin_state.position for bin_state in snapshot.bins}
    obstacles |= {
        snapshot.pack_station_position,
        snapshot.charger_position,
        snapshot.dock_position,
    }
    # The agent may currently stand on a station cell; it must stay walkable.
    obstacles.discard(here)

    stands = _adjacent_goal_positions(target, obstacles, snapshot.grid_size)
    route = _bfs_path(here, [cell for cell, _ in stands], obstacles, snapshot.grid_size)

    heading_for_delta = {delta: name for name, delta in MOVE_DELTA.items()}
    commands: List[str] = []
    for cell in route:
        travel_heading = heading_for_delta[(cell[0] - here[0], cell[1] - here[1])]
        commands += _rotate_actions(facing, travel_heading)
        commands.append("move_forward")
        facing = travel_heading
        here = cell

    # Finish by turning toward the target from whichever stand was reached.
    final_heading = next((name for cell, name in stands if cell == here), None)
    if final_heading is not None:
        commands += _rotate_actions(facing, final_heading)
    return commands
def _maybe_recharge_plan(env: WarehouseFulfillmentEnv) -> List[str]:
    """Plan a detour to the charger when the battery is at or below a reserve.

    The reserve is ``max(6, 2 * d + 4)`` where ``d`` is the Manhattan distance
    from the agent to the charger. The environment is not stepped.

    Args:
        env: Environment whose current state is inspected.

    Returns:
        An empty list when the battery level exceeds the reserve; otherwise
        the commands to reach and face the charger followed by ``"recharge"``.
    """
    snapshot = env.state()
    agent = snapshot.agent_position
    charger = snapshot.charger_position
    manhattan = abs(agent[0] - charger[0]) + abs(agent[1] - charger[1])
    reserve = max(6, (2 * manhattan) + 4)
    if snapshot.battery_level <= reserve:
        return [*_move_adjacent_and_face(env, charger), "recharge"]
    return []
def _run_and_record(env: WarehouseFulfillmentEnv, plan: Sequence[str], recorded: List[str]) -> None:
    """Step ``env`` through ``plan`` and append the same commands to ``recorded``."""
    for command in plan:
        env.step(command)
    recorded.extend(plan)


def planned_actions_for_task(env: WarehouseFulfillmentEnv) -> List[str]:
    """Build the full deterministic command plan for the task loaded in ``env``.

    The plan is produced by simulating it: ``env`` is stepped through every
    planned command, so callers should pass a dedicated planning environment
    rather than the one used for the real episode.

    For each unit of each order line the planner will, in order: recharge if
    needed, travel to the bin holding the SKU, scan the bin if it has not been
    scanned yet, pick the item, recharge if needed, travel to the pack station
    and pack the item.

    Args:
        env: Environment to plan against; it is mutated by this call.

    Returns:
        Every command issued during the simulation, in execution order.

    Raises:
        KeyError: If an order line references a SKU no bin provides.
        RuntimeError: Propagated from path planning when a target is
            unreachable.
    """
    actions: List[str] = []
    state = env.state()
    sku_to_bin = {bin_state.sku: bin_state for bin_state in state.bins}

    for order_line in state.order:
        for _ in range(order_line.quantity):
            # Each segment is replayed directly instead of being derived from
            # the length of env.state().action_history, so planning stays
            # correct even if the env already carries history.
            _run_and_record(env, _maybe_recharge_plan(env), actions)

            bin_state = sku_to_bin[order_line.sku]
            _run_and_record(env, _move_adjacent_and_face(env, bin_state.position), actions)

            if bin_state.bin_id not in env.state().scanned_bins:
                _run_and_record(env, ["scan_bin"], actions)
            _run_and_record(env, ["pick_item"], actions)

            _run_and_record(env, _maybe_recharge_plan(env), actions)

            pack_station = env.state().pack_station_position
            _run_and_record(env, _move_adjacent_and_face(env, pack_station), actions)
            _run_and_record(env, ["pack_item"], actions)
    return actions
def heuristic_next_action(env: WarehouseFulfillmentEnv, cached_plan: List[str]) -> str:
    """Return the next command of a precomputed plan for the current step.

    Args:
        env: Environment whose ``state().step_count`` indexes into the plan.
        cached_plan: Commands produced ahead of time, one per step.

    Returns:
        ``cached_plan[step_count]`` while the plan still has entries for the
        current step; ``"wait"`` once the plan is exhausted.
    """
    step_index = env.state().step_count
    if step_index < len(cached_plan):
        return cached_plan[step_index]
    # Plan exhausted: idling is the only remaining choice, whether or not the
    # episode has already finished.
    return "wait"
def build_openai_prompt(observation: WarehouseObservation, state: WarehouseState) -> str:
    """Serialize the observation and a state summary into the user prompt.

    Args:
        observation: Latest observation; its ``mission`` and its full
            ``model_to_dict`` form are included.
        state: Current environment state; selected fields and the last six
            entries of ``action_history`` are included.

    Returns:
        A JSON document (2-space indent, keys sorted) with the top-level keys
        ``mission``, ``observation`` and ``state_summary``.
    """
    prompt: Dict[str, Any] = {"mission": observation.mission}
    prompt["observation"] = model_to_dict(observation)

    copied_fields = (
        "step_count",
        "max_steps",
        "battery_level",
        "carrying",
        "scanned_bins",
        "completion_ratio",
    )
    summary = {field: getattr(state, field) for field in copied_fields}
    summary["recent_actions"] = state.action_history[-6:]
    prompt["state_summary"] = summary

    return json.dumps(prompt, indent=2, sort_keys=True)
def openai_next_action(
    client: Any,
    model: str,
    observation: WarehouseObservation,
    state: WarehouseState,
) -> str:
    """Ask the model for the next command and return it.

    The request carries ``SYSTEM_PROMPT`` plus the prompt built by
    ``build_openai_prompt`` and asks for output matching the
    ``BaselineCommand`` JSON schema.

    Args:
        client: Object exposing ``responses.create(...)``.
        model: Model name passed through to the request.
        observation: Latest observation from the environment.
        state: Current environment state.

    Returns:
        The ``command`` field of the parsed reply, or ``"wait"`` when the
        reply is empty, is not valid JSON, or does not validate as a
        ``BaselineCommand``.
    """
    response = client.responses.create(
        model=model,
        input=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": build_openai_prompt(observation, state)},
        ],
        text={
            "format": {
                "type": "json_schema",
                "name": "warehouse_action",
                "strict": True,
                "schema": BaselineCommand.model_json_schema(),
            }
        },
    )
    # ``output_text`` may be missing or explicitly None; treat both as empty.
    content = (getattr(response, "output_text", None) or "").strip()
    if not content:
        return "wait"
    try:
        payload = json.loads(content)
        return BaselineCommand(**payload).command
    except (ValueError, TypeError):
        # ValueError covers json.JSONDecodeError; pydantic validation errors
        # are expected to subclass it too (NOTE(review): confirm for the
        # pinned pydantic version). TypeError covers a non-mapping payload.
        # One unusable reply should idle for a step, matching the empty-reply
        # fallback above, instead of aborting the whole episode.
        return "wait"
def run_episode(task_id: str, seed: int, policy: str, model: str | None) -> Dict[str, Any]:
    """Run one episode of ``task_id`` and return its summary.

    Args:
        task_id: Task to load into the environment.
        seed: Seed passed to the environment constructor and ``reset``.
        policy: ``"openai"`` to query a model each step; any other value uses
            the precomputed deterministic plan.
        model: Model name for the ``"openai"`` policy. When falsy, the
            ``MODEL_NAME`` environment variable is used, defaulting to
            ``"gpt-4.1-mini"``.

    Returns:
        A dict with ``task_id`` (str) and the numeric fields ``reward``,
        ``score``, ``steps`` and ``success``.

    Raises:
        RuntimeError: For the ``"openai"`` policy, if the ``openai`` package
            is missing or no API key is configured.
    """
    use_model = policy == "openai"
    client = None
    model_name = model
    cached_plan: List[str] = []

    if use_model:
        # Validate credentials before doing any other work so that
        # misconfiguration fails fast.
        if OpenAI is None:
            raise RuntimeError("The openai package is not installed.")
        api_key = os.environ.get("OPENAI_API_KEY") or os.environ.get("HF_TOKEN") or os.environ.get("API_KEY")
        if not api_key:
            raise RuntimeError("OPENAI_API_KEY (or HF_TOKEN/API_KEY) is not set.")
        base_url = os.environ.get("API_BASE_URL")
        client = OpenAI(api_key=api_key, base_url=base_url)
        # Resolve the model name once instead of on every step.
        model_name = model or os.environ.get("MODEL_NAME", "gpt-4.1-mini")
    else:
        # The deterministic plan is only needed by the heuristic policy; it is
        # simulated on a separate environment so the real one stays pristine.
        env_for_plan = WarehouseFulfillmentEnv(task_id=task_id, seed=seed)
        env_for_plan.reset(task_id=task_id, seed=seed)
        cached_plan = planned_actions_for_task(env_for_plan)

    env = WarehouseFulfillmentEnv(task_id=task_id, seed=seed)
    observation = env.reset(task_id=task_id, seed=seed)

    done = False
    while not done:
        state = env.state()
        if use_model:
            command = openai_next_action(client, model_name, observation, state)
        else:
            command = heuristic_next_action(env, cached_plan)
        observation, reward, done, info = env.step(command)
        print(
            f"[{task_id}] step={state.step_count + 1} action={command} "
            f"reward={reward.value:+.2f} done={done}"
        )

    final_state = env.state()
    return {
        "task_id": task_id,
        "reward": round(final_state.total_reward, 4),
        "score": grade_episode(final_state),
        "steps": float(final_state.step_count),
        "success": 1.0 if final_state.success else 0.0,
    }
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the runner.

    Options:
        --task-id: One of the keys of ``TASKS``; omitted means all tasks.
        --seed: Integer environment seed (default 7).
        --policy: ``heuristic`` (default) or ``openai``.
        --model: Model name for the ``openai`` policy; defaults to the
            ``MODEL_NAME`` or ``OPENAI_MODEL`` environment variable.

    Returns:
        The populated ``argparse.Namespace``.
    """
    cli = argparse.ArgumentParser(description="Run the warehouse fulfillment environment.")

    known_tasks = sorted(TASKS.keys())
    cli.add_argument("--task-id", help="Run a single task instead of all tasks.", choices=known_tasks)

    cli.add_argument("--seed", default=7, type=int, help="Deterministic environment seed.")

    cli.add_argument(
        "--policy",
        default="heuristic",
        choices=["heuristic", "openai"],
        help="Action policy to use.",
    )

    model_from_env = os.environ.get("MODEL_NAME") or os.environ.get("OPENAI_MODEL")
    cli.add_argument("--model", default=model_from_env, help="Model name for --policy openai.")

    return cli.parse_args()
def main() -> None:
    """Run the selected task(s) and print a per-task and mean-score summary.

    Runs the single task given by ``--task-id``, or every task in ``TASKS``
    when none is given, then prints one ``|``-separated row per task followed
    by the mean score. The mean row is omitted when no task was run.
    """
    args = parse_args()
    task_ids = [args.task_id] if args.task_id else list(TASKS.keys())
    results = [run_episode(task_id, seed=args.seed, policy=args.policy, model=args.model) for task_id in task_ids]

    print("\ntask_id | score | reward | steps | success")
    for result in results:
        print(
            f"{result['task_id']} | {result['score']:.4f} | "
            f"{result['reward']:.4f} | {int(result['steps'])} | {int(result['success'])}"
        )

    # An empty task registry would otherwise divide by zero below.
    if not results:
        return
    mean_score = sum(result["score"] for result in results) / len(results)
    print(f"mean_score | {mean_score:.4f}")


if __name__ == "__main__":
    main()