# mini-rl-env / script.py
# Last commit 97ac6b2 by sohambose98: "base inference added"
"""
Runner for the warehouse fulfillment environment.
By default this executes a deterministic planner that solves all tasks
reproducibly. If OpenAI credentials are configured, it can also run a model
policy against the same environment.
"""
from __future__ import annotations
import argparse
import json
import os
from collections import deque
from typing import Any, Dict, List, Sequence, Tuple
from grid_env.env import WarehouseFulfillmentEnv
from grid_env.graders import grade_episode
from grid_env.models import BaselineCommand, WarehouseObservation, WarehouseState, model_to_dict
from grid_env.tasks import TASKS
try:
from openai import OpenAI
except ImportError: # pragma: no cover
OpenAI = None # type: ignore[assignment]
# Compass headings in clockwise order. ``_rotate_actions`` treats ``turn_right`` as one
# step forward through this list and ``turn_left`` as one step back.
HEADINGS = list("NESW")
# Grid offset (dx, dy) for one ``move_forward`` in each heading: x grows eastward and
# y grows southward. Iteration order (N, E, S, W) decides the order in which neighbour
# and approach cells are tried by the planner helpers.
MOVE_DELTA = dict(N=(0, -1), E=(1, 0), S=(0, 1), W=(-1, 0))
# System message sent on every model call; the reply format is additionally enforced
# through the ``BaselineCommand`` JSON schema.
SYSTEM_PROMPT = """You control a warehouse fulfillment robot.
Return exactly one JSON object with:
- command: one of turn_left, turn_right, move_forward, scan_bin, pick_item, pack_item, recharge, wait
- rationale: a short sentence
"""
def _adjacent_goal_positions(
    target: Tuple[int, int],
    blocked: set[Tuple[int, int]],
    grid_size: Tuple[int, int],
) -> List[Tuple[Tuple[int, int], str]]:
    """List the free cells next to *target* together with the heading that faces it.

    For each heading in ``MOVE_DELTA`` the standing cell is the one from which a
    ``move_forward`` in that heading would land on *target*. Cells outside the grid or
    contained in *blocked* are dropped. Results keep ``MOVE_DELTA`` order.
    """
    width, height = grid_size[0], grid_size[1]
    standing_spots = (
        ((target[0] - dx, target[1] - dy), heading)
        for heading, (dx, dy) in MOVE_DELTA.items()
    )
    return [
        (spot, heading)
        for spot, heading in standing_spots
        if 0 <= spot[0] < width and 0 <= spot[1] < height and spot not in blocked
    ]
def _neighbors(
    position: Tuple[int, int],
    blocked: set[Tuple[int, int]],
    grid_size: Tuple[int, int],
) -> List[Tuple[int, int]]:
    """Return the cells one move away from *position* that are on the grid and unblocked.

    Candidates are produced in ``MOVE_DELTA`` order, which fixes how ``_bfs_path``
    breaks ties between equally short routes.
    """
    width, height = grid_size[0], grid_size[1]
    stepped = ((position[0] + dx, position[1] + dy) for dx, dy in MOVE_DELTA.values())
    return [
        cell
        for cell in stepped
        if 0 <= cell[0] < width and 0 <= cell[1] < height and cell not in blocked
    ]
def _bfs_path(
    start: Tuple[int, int],
    goals: Sequence[Tuple[int, int]],
    blocked: set[Tuple[int, int]],
    grid_size: Tuple[int, int],
) -> List[Tuple[int, int]]:
    """Breadth-first search from *start* to the closest cell in *goals*.

    Cells are expanded with ``_neighbors`` (in bounds, not in *blocked*), so the route
    uses the fewest moves. The returned list holds the cells entered after *start*,
    ending at the reached goal; it is empty when *start* is itself a goal.

    Raises:
        RuntimeError: if no goal cell can be reached.
    """
    targets = set(goals)
    parent: Dict[Tuple[int, int], Tuple[int, int] | None] = {start: None}
    frontier = deque([start])
    reached = None
    while frontier:
        cell = frontier.popleft()
        if cell in targets:
            reached = cell
            break
        fresh = [nb for nb in _neighbors(cell, blocked, grid_size) if nb not in parent]
        for nb in fresh:
            parent[nb] = cell
        frontier.extend(fresh)
    if reached is None:
        raise RuntimeError("No path to target.")

    # Follow parent links back to start, then flip into start -> goal order.
    route: List[Tuple[int, int]] = []
    cell = reached
    while cell != start:
        route.append(cell)
        cell = parent[cell]
    return route[::-1]
def _rotate_actions(current_heading: str, desired_heading: str) -> List[str]:
    """Return the fewest turn commands that rotate *current_heading* onto *desired_heading*.

    ``turn_right`` advances one slot through ``HEADINGS`` and ``turn_left`` goes back one.
    A half turn costs the same either way and is done with two right turns; no change
    yields an empty list.

    Raises:
        ValueError: if either heading is not present in ``HEADINGS``.
    """
    start = HEADINGS.index(current_heading)
    goal = HEADINGS.index(desired_heading)
    clockwise = (goal - start) % 4
    counter_clockwise = (start - goal) % 4
    if counter_clockwise < clockwise:
        return ["turn_left"] * counter_clockwise
    return ["turn_right"] * clockwise
def _move_adjacent_and_face(env: WarehouseFulfillmentEnv, target: Tuple[int, int]) -> List[str]:
    """Plan the commands that bring the robot next to *target* and turn it toward *target*.

    Every bin, the pack station, the charger and the dock count as obstacles, except the
    cell the robot currently occupies. ``_bfs_path`` picks the nearest free cell adjacent
    to *target*; each step of that route becomes the turns needed to face the step
    direction followed by ``move_forward``, and a final rotation makes the robot face
    *target*. Only ``env.state()`` is read here; ``env.step`` is never called.

    Raises:
        RuntimeError: propagated from ``_bfs_path`` when no adjacent cell is reachable.
    """
    snapshot = env.state()
    obstacles = {bin_state.position for bin_state in snapshot.bins} | {
        snapshot.pack_station_position,
        snapshot.charger_position,
        snapshot.dock_position,
    }
    # Keep the robot's own cell free so it can still count as an approach cell.
    obstacles.discard(snapshot.agent_position)

    approaches = _adjacent_goal_positions(target, obstacles, snapshot.grid_size)
    route = _bfs_path(
        snapshot.agent_position,
        [spot for spot, _ in approaches],
        obstacles,
        snapshot.grid_size,
    )

    commands: List[str] = []
    facing = snapshot.heading
    here = snapshot.agent_position
    for cell in route:
        offset = (cell[0] - here[0], cell[1] - here[1])
        step_heading = next(name for name, delta in MOVE_DELTA.items() if delta == offset)
        commands += _rotate_actions(facing, step_heading)
        commands.append("move_forward")
        facing, here = step_heading, cell

    # Turn toward the target from whichever approach cell the route ended on.
    final_heading = next((heading for spot, heading in approaches if spot == here), None)
    if final_heading is not None:
        commands += _rotate_actions(facing, final_heading)
    return commands
def _maybe_recharge_plan(env: WarehouseFulfillmentEnv) -> List[str]:
    """Return a charger detour ending in ``recharge`` when the battery is low, else ``[]``.

    The battery counts as low when it is not above ``max(6, 2 * d + 4)``, where ``d`` is
    the Manhattan distance from the robot to the charger (obstacles ignored). The
    ``2 * d`` term presumably budgets a trip there and back. This function does not call
    ``env.step``.
    """
    snapshot = env.state()
    agent, charger = snapshot.agent_position, snapshot.charger_position
    manhattan = abs(agent[0] - charger[0]) + abs(agent[1] - charger[1])
    if snapshot.battery_level > max(6, 2 * manhattan + 4):
        return []
    return _move_adjacent_and_face(env, charger) + ["recharge"]
def _execute_plan(env: WarehouseFulfillmentEnv, actions: List[str], plan: Sequence[str]) -> None:
    """Append *plan* to *actions* and step *env* through each of its commands in order."""
    actions.extend(plan)
    for action in plan:
        env.step(action)


def planned_actions_for_task(env: WarehouseFulfillmentEnv) -> List[str]:
    """Solve the order on *env* and return the complete action sequence.

    *env* is stepped through every planned command while planning, so later decisions
    see the updated position, battery and scanned bins; pass a dedicated env. For each
    unit of each order line the plan is: optional recharge, go to the SKU's bin, scan it
    if not yet scanned, pick, optional recharge, go to the pack station, pack.

    Raises:
        KeyError: if an order line's SKU has no matching bin.
        RuntimeError: propagated from path planning when a target is unreachable.
    """
    actions: List[str] = []
    state = env.state()
    sku_to_bin = {bin_state.sku: bin_state for bin_state in state.bins}
    for order_line in state.order:
        for _ in range(order_line.quantity):
            # Each sub-plan is executed explicitly right after it is recorded, instead of
            # inferring the unexecuted tail from the env's action_history length.
            _execute_plan(env, actions, _maybe_recharge_plan(env))
            bin_state = sku_to_bin[order_line.sku]
            _execute_plan(env, actions, _move_adjacent_and_face(env, bin_state.position))
            # Only scan bins the env has not already recorded as scanned.
            if bin_state.bin_id not in env.state().scanned_bins:
                _execute_plan(env, actions, ["scan_bin"])
            _execute_plan(env, actions, ["pick_item"])
            _execute_plan(env, actions, _maybe_recharge_plan(env))
            _execute_plan(
                env, actions, _move_adjacent_and_face(env, env.state().pack_station_position)
            )
            _execute_plan(env, actions, ["pack_item"])
    return actions
def heuristic_next_action(env: WarehouseFulfillmentEnv, cached_plan: List[str]) -> str:
    """Return the precomputed command for the env's current step, or ``"wait"``.

    ``cached_plan`` is indexed by ``state.step_count``. Once the plan is exhausted,
    whether or not the episode is done, the policy idles with ``"wait"``.
    """
    step = env.state().step_count
    return cached_plan[step] if step < len(cached_plan) else "wait"
def build_openai_prompt(observation: WarehouseObservation, state: WarehouseState) -> str:
    """Render the user message for the model as indented, key-sorted JSON.

    The payload carries the mission text, the full observation (via ``model_to_dict``)
    and a compact state summary that includes only the six most recent actions.
    """
    mission = observation.mission
    observation_dict = model_to_dict(observation)
    state_summary = dict(
        step_count=state.step_count,
        max_steps=state.max_steps,
        battery_level=state.battery_level,
        carrying=state.carrying,
        scanned_bins=state.scanned_bins,
        completion_ratio=state.completion_ratio,
        recent_actions=state.action_history[-6:],
    )
    payload = {
        "mission": mission,
        "observation": observation_dict,
        "state_summary": state_summary,
    }
    return json.dumps(payload, sort_keys=True, indent=2)
def _strict_json_schema(node: Any) -> Any:
    """Return a copy of a JSON-schema fragment adjusted for OpenAI strict structured outputs.

    Strict mode requires every object schema to declare ``additionalProperties: false``
    and to list all of its properties in ``required``. Pydantic's ``model_json_schema()``
    only emits the former for models configured with ``extra="forbid"`` and leaves
    defaulted fields out of ``required``. Nested objects (e.g. under ``$defs``) are
    handled recursively; the input is never mutated.
    """
    if isinstance(node, list):
        return [_strict_json_schema(item) for item in node]
    if not isinstance(node, dict):
        return node
    result = {key: _strict_json_schema(value) for key, value in node.items()}
    properties = result.get("properties")
    if result.get("type") == "object" and isinstance(properties, dict):
        result["additionalProperties"] = False
        result["required"] = list(properties)
    return result


def openai_next_action(
    client: Any,
    model: str,
    observation: WarehouseObservation,
    state: WarehouseState,
) -> str:
    """Ask the model for the next command through the OpenAI Responses API.

    Args:
        client: an ``OpenAI`` client (anything exposing ``responses.create``).
        model: model name passed straight to the API.
        observation: latest observation from the environment.
        state: current environment state, summarised into the prompt.

    Returns:
        The ``command`` field of the parsed ``BaselineCommand``, or ``"wait"`` when the
        response carries no text.

    Raises:
        json.JSONDecodeError: if the response text is not valid JSON.
        Whatever ``BaselineCommand`` raises when the payload fails its validation.
    """
    # Strict mode rejects schemas that allow extra keys or leave fields optional.
    schema = _strict_json_schema(BaselineCommand.model_json_schema())
    response = client.responses.create(
        model=model,
        input=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": build_openai_prompt(observation, state)},
        ],
        text={
            "format": {
                "type": "json_schema",
                "name": "warehouse_action",
                "strict": True,
                "schema": schema,
            }
        },
    )
    # Treat a missing or None ``output_text`` the same as an empty reply.
    content = (getattr(response, "output_text", None) or "").strip()
    if not content:
        return "wait"
    payload = json.loads(content)
    return BaselineCommand(**payload).command
def _make_openai_client() -> Any:
    """Create an ``OpenAI`` client from environment variables.

    The API key comes from ``OPENAI_API_KEY``, then ``HF_TOKEN``, then ``API_KEY``;
    ``API_BASE_URL`` optionally overrides the endpoint.

    Raises:
        RuntimeError: if the ``openai`` package is unavailable or no API key is set.
    """
    if OpenAI is None:
        raise RuntimeError("The openai package is not installed.")
    api_key = os.environ.get("OPENAI_API_KEY") or os.environ.get("HF_TOKEN") or os.environ.get("API_KEY")
    if not api_key:
        raise RuntimeError("OPENAI_API_KEY (or HF_TOKEN/API_KEY) is not set.")
    base_url = os.environ.get("API_BASE_URL")
    return OpenAI(api_key=api_key, base_url=base_url)


def run_episode(task_id: str, seed: int, policy: str, model: str | None) -> Dict[str, Any]:
    """Play one episode of *task_id* and return its summary metrics.

    With ``policy == "openai"`` each command comes from the model; any other value
    replays a plan precomputed by ``planned_actions_for_task``. Every step is printed.

    Args:
        task_id: task key passed to the environment.
        seed: environment seed.
        policy: ``"openai"`` for the model policy, otherwise the heuristic planner.
        model: model name for the openai policy; falls back to ``$MODEL_NAME`` and then
            ``"gpt-4.1-mini"``.

    Returns:
        ``task_id`` (str), ``reward`` (total reward rounded to 4 places), ``score`` (as
        returned by ``grade_episode``), ``steps`` (float) and ``success`` (1.0 or 0.0).

    Raises:
        RuntimeError: if the OpenAI client cannot be created, or the planner finds no path.
    """
    client: Any = None
    model_name = ""
    cached_plan: List[str] = []
    if policy == "openai":
        client = _make_openai_client()
        # Resolve once rather than re-reading the environment on every step.
        model_name = model or os.environ.get("MODEL_NAME", "gpt-4.1-mini")
    else:
        # Plan on a separate, identically seeded env; replaying it below assumes the
        # env behaves deterministically for a given task/seed.
        env_for_plan = WarehouseFulfillmentEnv(task_id=task_id, seed=seed)
        env_for_plan.reset(task_id=task_id, seed=seed)
        cached_plan = planned_actions_for_task(env_for_plan)
    env = WarehouseFulfillmentEnv(task_id=task_id, seed=seed)
    observation = env.reset(task_id=task_id, seed=seed)
    done = False
    while not done:
        state = env.state()
        if policy == "openai":
            command = openai_next_action(client, model_name, observation, state)
        else:
            command = heuristic_next_action(env, cached_plan)
        observation, reward, done, _info = env.step(command)
        print(
            f"[{task_id}] step={state.step_count + 1} action={command} "
            f"reward={reward.value:+.2f} done={done}"
        )
    final_state = env.state()
    return {
        "task_id": task_id,
        "reward": round(final_state.total_reward, 4),
        "score": grade_episode(final_state),
        "steps": float(final_state.step_count),
        "success": 1.0 if final_state.success else 0.0,
    }
def parse_args() -> argparse.Namespace:
    """Parse the runner's command-line options.

    ``--model`` defaults to ``$MODEL_NAME``, falling back to ``$OPENAI_MODEL`` when the
    former is unset or empty.
    """
    parser = argparse.ArgumentParser(description="Run the warehouse fulfillment environment.")
    argument_specs = (
        ("--task-id", {"choices": sorted(TASKS.keys()), "help": "Run a single task instead of all tasks."}),
        ("--seed", {"type": int, "default": 7, "help": "Deterministic environment seed."}),
        (
            "--policy",
            {"choices": ["heuristic", "openai"], "default": "heuristic", "help": "Action policy to use."},
        ),
        (
            "--model",
            {
                "default": os.environ.get("MODEL_NAME") or os.environ.get("OPENAI_MODEL"),
                "help": "Model name for --policy openai.",
            },
        ),
    )
    for flag, options in argument_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()
def main() -> None:
    """Run the selected task(s), then print a results table and the mean score.

    Runs only ``--task-id`` when given, otherwise every task in ``TASKS``.
    """
    args = parse_args()
    task_ids = [args.task_id] if args.task_id else list(TASKS.keys())
    results = [run_episode(task_id, seed=args.seed, policy=args.policy, model=args.model) for task_id in task_ids]
    print("\ntask_id | score | reward | steps | success")
    for result in results:
        print(
            f"{result['task_id']} | {result['score']:.4f} | "
            f"{result['reward']:.4f} | {int(result['steps'])} | {int(result['success'])}"
        )
    # An empty TASKS registry would otherwise raise ZeroDivisionError on the mean.
    if not results:
        return
    mean_score = sum(result["score"] for result in results) / len(results)
    print(f"mean_score | {mean_score:.4f}")


if __name__ == "__main__":
    main()