mini-rl-env / inference.py
sohambose98's picture
updated configs
0e823e2
"""
inference.py — Warehouse Fulfillment Agent
==========================================
Mandatory environment variables:
API_BASE_URL The API endpoint for the LLM (OpenAI-compatible).
MODEL_NAME The model identifier to use for inference.
HF_TOKEN Your Hugging Face / API key.
Optional environment variables:
LOCAL_IMAGE_NAME Local Docker image name, if needed by the runner.
Uses the OpenAI client for all LLM calls.
Runs all 3 warehouse tasks and prints scores.
Runtime target: < 20 min on vcpu=2, 8 GB RAM.
"""
from __future__ import annotations
import json
import os
import sys
from typing import Any, Dict, List
from openai import OpenAI
from grid_env.env import WarehouseFulfillmentEnv
from grid_env.graders import grade_episode
from grid_env.models import WarehouseObservation, WarehouseState
from grid_env.tasks import TASKS
# ── Environment configuration ────────────────────────────────────────────────
API_BASE_URL: str = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME: str = os.getenv("MODEL_NAME", "gpt-4o-mini")
HF_TOKEN: str | None = os.getenv("HF_TOKEN")
LOCAL_IMAGE_NAME: str | None = os.getenv("LOCAL_IMAGE_NAME")
# Multi-seed evaluation: comma-separated list of seeds
EVAL_SEEDS_STR: str = os.getenv("EVAL_SEEDS", "7,42,123,456,789")
EVAL_SEEDS: List[int] = [int(s.strip()) for s in EVAL_SEEDS_STR.split(",")]
# Safety cap so we never exceed the 20-min wall-clock limit.
# Each task has its own max_steps; this is an outer guard.
GLOBAL_MAX_STEPS: int = 60
# ── Validator-parseable stdout blocks ─────────────────────────────────────────
def _stdout_block(tag: str, **fields: Any) -> None:
"""
Emit a single-line structured block to stdout.
External validators look for literal `[START]`, `[STEP]`, `[END]` tokens in
stdout and then parse key=value pairs.
"""
parts: List[str] = [f"[{tag}]"]
for k, v in fields.items():
parts.append(f"{k}={v}")
print(" ".join(parts), flush=True)
def _reward_value(reward: Any) -> float:
try:
return float(getattr(reward, "value"))
except Exception: # noqa: BLE001
try:
return float(reward)
except Exception: # noqa: BLE001
return 0.0
# ── Prompts ───────────────────────────────────────────────────────────────────
# System message sent with every chat completion request in pick_action.
# It lists the accepted command strings and asks for a single JSON object.
SYSTEM_PROMPT = """\
You control a warehouse fulfillment robot.
You will receive the current environment observation as JSON.
Reply with exactly one JSON object — no markdown, no extra text:
{
"command": "<one of: turn_left | turn_right | move_forward | scan_bin | pick_item | pack_item | recharge | rest | wait>",
"rationale": "<one short sentence>"
}
Guidelines:
- Scan the required bin before picking from it.
- Pick the item your order needs and carry it to the pack station.
- Face the pack station before calling pack_item.
- Recharge before battery runs to 0 if needed.
- Avoid invalid actions — every wasted step costs score.
Advanced mechanics (active on harder tasks):
- Obstacles: some cells are impassable. If front_cell says "obstacle", turn to find another route.
- Item weight: items have weight. If an item exceeds your carry capacity, you cannot pick it.
Heavier items drain more battery while moving.
- Stamina: movement costs stamina. When stamina hits 0, movement costs double battery.
Use the "rest" action at the rest area to restore stamina.
- Money: packing correct items earns money; wrong packs lose money. Hit the profit target if set.
"""
def _build_user_message(obs: WarehouseObservation, state: WarehouseState) -> str:
payload = {
"task_id": obs.task_id,
"mission": obs.mission,
"narrative": obs.narrative,
"agent_position": list(obs.agent_position),
"heading": obs.heading,
"front_cell": obs.front_cell,
"carrying": obs.carrying,
"battery_level": obs.battery_level,
"visible_bins": obs.visible_bins,
"pending_order": [{"sku": p.sku, "remaining": p.remaining} for p in obs.pending_order],
"packed_order": [{"sku": p.sku, "packed": p.packed} for p in obs.packed_order],
"progress_ratio": obs.progress_ratio,
"step_count": state.step_count,
"max_steps": state.max_steps,
"recent_actions": state.action_history[-6:],
}
return json.dumps(payload, indent=2)
# ── LLM action selection ──────────────────────────────────────────────────────
# Command strings accepted from the model; pick_action substitutes "wait"
# for anything outside this set.
VALID_COMMANDS = {
    # movement
    "turn_left",
    "turn_right",
    "move_forward",
    # item handling
    "scan_bin",
    "pick_item",
    "pack_item",
    # upkeep / idle
    "recharge",
    "rest",
    "wait",
}
def pick_action(client: OpenAI, obs: WarehouseObservation, state: WarehouseState) -> str:
"""Call the LLM and return a valid command string."""
try:
response = client.chat.completions.create(
model=MODEL_NAME,
messages=[
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": _build_user_message(obs, state)},
],
temperature=0.2,
max_tokens=120,
response_format={"type": "json_object"},
)
content = response.choices[0].message.content or ""
parsed = json.loads(content)
command = str(parsed.get("command", "wait")).strip()
if command not in VALID_COMMANDS:
print(f" [warn] Model returned unknown command '{command}', using 'wait'")
return "wait"
return command
except Exception as exc: # noqa: BLE001
print(f" [warn] LLM call failed: {exc} β€” using 'wait'")
return "wait"
# ── Task runner ───────────────────────────────────────────────────────────────
def run_task(client: OpenAI, task_id: str, seed: int, verbose: bool = True) -> Dict[str, Any]:
    """Run a single task with a specific seed.

    Emits one ``[START]`` block, one ``[STEP]`` block per environment step and
    one ``[END]`` block on stdout. The loop stops when the environment reports
    ``done`` or after ``GLOBAL_MAX_STEPS`` steps, whichever comes first.

    Args:
        client: OpenAI-compatible client passed through to ``pick_action``.
        task_id: Task identifier handed to the environment.
        seed: Seed handed to the environment.
        verbose: When true, also print a banner, a per-step line and a summary.

    Returns:
        A dict with ``task_id``, ``seed``, ``score``, ``reward``, ``steps``,
        ``success`` and ``completion_ratio`` read from the final state.
    """
    if verbose:
        print(f"\n{'='*60}")
        print(f"Task: {task_id} | Seed: {seed}")
        print(f"{'='*60}")
    env = WarehouseFulfillmentEnv(task_id=task_id, seed=seed)
    obs = env.reset(task_id=task_id, seed=seed)
    done = False
    step = 0
    _stdout_block("START", task=task_id, seed=seed)
    while not done and step < GLOBAL_MAX_STEPS:
        state = env.state()
        command = pick_action(client, obs, state)
        obs, reward, done, _info = env.step(command)
        step += 1
        # Coerce once so the STEP block and the verbose line agree, and so a
        # reward without a ``.value`` attribute cannot crash verbose mode.
        reward_value = _reward_value(reward)
        _stdout_block(
            "STEP",
            task=task_id,
            seed=seed,
            step=step,
            reward=f"{reward_value:.6f}",
            command=command,
            done=int(bool(done)),
        )
        if verbose:
            print(
                f" step {step:>3} | {command:<14} | reward {reward_value:+.3f} "
                f"| progress {obs.progress_ratio:.2f} | battery {obs.battery_level}"
            )
    final_state = env.state()
    score = grade_episode(final_state)
    _stdout_block(
        "END",
        task=task_id,
        seed=seed,
        score=f"{float(score):.6f}",
        steps=final_state.step_count,
        success=int(bool(final_state.success)),
    )
    result = {
        "task_id": task_id,
        "seed": seed,
        "score": round(score, 4),
        "reward": round(final_state.total_reward, 4),
        "steps": final_state.step_count,
        "success": final_state.success,
        "completion_ratio": final_state.completion_ratio,
    }
    if verbose:
        print(
            f"\n → score={result['score']:.4f} reward={result['reward']:.4f} "
            f"steps={result['steps']} success={result['success']}"
        )
    return result
def run_task_multiseed(client: OpenAI, task_id: str, seeds: List[int]) -> Dict[str, Any]:
"""Run a task across multiple seeds and aggregate results."""
print(f"\n{'='*60}")
print(f"Task: {task_id} | Seeds: {seeds}")
print(f"{'='*60}")
seed_results = []
for seed in seeds:
result = run_task(client, task_id, seed, verbose=False)
seed_results.append(result)
print(f" Seed {seed:>3}: score={result['score']:.4f} steps={result['steps']:>3} success={result['success']}")
# Aggregate statistics
scores = [r["score"] for r in seed_results]
rewards = [r["reward"] for r in seed_results]
steps = [r["steps"] for r in seed_results]
successes = [1 if r["success"] else 0 for r in seed_results]
mean_score = sum(scores) / len(scores)
std_score = (sum((s - mean_score)**2 for s in scores) / len(scores))**0.5
aggregated = {
"task_id": task_id,
"score": round(mean_score, 4),
"score_std": round(std_score, 4),
"score_min": round(min(scores), 4),
"score_max": round(max(scores), 4),
"reward": round(sum(rewards) / len(rewards), 4),
"steps": round(sum(steps) / len(steps), 1),
"success_rate": round(sum(successes) / len(successes), 2),
"num_seeds": len(seeds),
"seeds": seeds,
"seed_results": seed_results,
}
print(f"\n β†’ Aggregated: score={aggregated['score']:.4f} (Β±{aggregated['score_std']:.4f}) "
f"success_rate={aggregated['success_rate']:.2f}")
return aggregated
# ── Main ──────────────────────────────────────────────────────────────────────
def main() -> None:
    """Run every task in ``TASKS`` and print a summary plus a JSON dump.

    With exactly one seed in ``EVAL_SEEDS`` each task is run once in verbose
    mode; with several seeds each task is run per seed and aggregated.

    Exit codes:
        1 — missing ``HF_TOKEN`` / ``MODEL_NAME``, no seeds, or no tasks.
        2 — a reported score lies outside ``[0, 1]``.
    """
    if not HF_TOKEN:
        print("ERROR: HF_TOKEN environment variable is not set.", file=sys.stderr)
        sys.exit(1)
    if not MODEL_NAME:
        print("ERROR: MODEL_NAME environment variable is not set.", file=sys.stderr)
        sys.exit(1)
    if not EVAL_SEEDS:
        print("ERROR: EVAL_SEEDS does not contain any seed.", file=sys.stderr)
        sys.exit(1)
    print(f"API_BASE_URL : {API_BASE_URL}")
    print(f"MODEL_NAME : {MODEL_NAME}")
    print(f"LOCAL_IMAGE_NAME : {LOCAL_IMAGE_NAME}")
    print(f"Tasks : {list(TASKS.keys())}")
    print(f"Eval Seeds : {EVAL_SEEDS} ({len(EVAL_SEEDS)} seeds)")
    client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
    # One flag drives both how tasks are run and how the table is laid out.
    is_multiseed = len(EVAL_SEEDS) > 1
    results: List[Dict[str, Any]] = []
    for task_id in TASKS:
        if is_multiseed:
            # Multi-seed evaluation
            result = run_task_multiseed(client, task_id, EVAL_SEEDS)
        else:
            # Single-seed evaluation (backward compatible)
            result = run_task(client, task_id, EVAL_SEEDS[0], verbose=True)
        results.append(result)
    if not results:
        # Without tasks there is nothing to average or report.
        print("ERROR: no tasks are registered in TASKS.", file=sys.stderr)
        sys.exit(1)
    # ── Summary report ────────────────────────────────────────────────────────
    print(f"\n{'='*60}")
    print("RESULTS SUMMARY")
    print(f"{'='*60}")
    if is_multiseed:
        print(f"{'task_id':<30} {'score (±std)':>15} {'min':>7} {'max':>7} {'success':>8}")
        print("-" * 80)
        for r in results:
            print(
                f"{r['task_id']:<30} {r['score']:>6.4f} (±{r['score_std']:.4f}) "
                f"{r['score_min']:>7.4f} {r['score_max']:>7.4f} {r['success_rate']:>8.2f}"
            )
    else:
        print(f"{'task_id':<30} {'score':>7} {'reward':>8} {'steps':>6} {'success':>8}")
        print("-" * 60)
        for r in results:
            print(
                f"{r['task_id']:<30} {r['score']:>7.4f} {r['reward']:>8.4f} "
                f"{r['steps']:>6} {str(r['success']):>8}"
            )
    mean_score = sum(r["score"] for r in results) / len(results)
    print("-" * (80 if is_multiseed else 60))
    print(f"{'mean_score':<30} {mean_score:>7.4f}")
    # Validator-friendly JSON output to stdout
    print("\nJSON_RESULTS:", json.dumps(results, indent=2))
    # Exit non-zero if any score is outside [0, 1] — catches grader bugs
    for r in results:
        if not (0.0 <= r["score"] <= 1.0):
            print(f"ERROR: score out of range for {r['task_id']}: {r['score']}", file=sys.stderr)
            sys.exit(2)
    print("\nAll tasks completed successfully.")
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()