Spaces:
Sleeping
Sleeping
"""
inference.py — Warehouse Fulfillment Agent
==========================================
Core environment variables (only HF_TOKEN has no default):
    API_BASE_URL      The API endpoint for the LLM (OpenAI-compatible).
                      Default: https://api.openai.com/v1
    MODEL_NAME        The model identifier to use for inference.
                      Default: gpt-4o-mini
    HF_TOKEN          Your Hugging Face / API key. Required.
Optional environment variables:
    LOCAL_IMAGE_NAME  Local Docker image name, if needed by the runner.
    EVAL_SEEDS        Comma-separated integer seeds (default: 7,42,123,456,789).
                      With one seed, each task runs once with a verbose trace;
                      with several, mean/std/min/max are reported per task.
Uses the OpenAI client for all LLM calls.
Runs all 3 warehouse tasks and prints scores.
Runtime target: < 20 min on vcpu=2, 8 GB RAM.
"""
| from __future__ import annotations | |
| import json | |
| import os | |
| import sys | |
| from typing import Any, Dict, List | |
| from openai import OpenAI | |
| from grid_env.env import WarehouseFulfillmentEnv | |
| from grid_env.graders import grade_episode | |
| from grid_env.models import WarehouseObservation, WarehouseState | |
| from grid_env.tasks import TASKS | |
| # ββ Environment configuration ββββββββββββββββββββββββββββββββββββββββββββββββ | |
| API_BASE_URL: str = os.getenv("API_BASE_URL", "https://api.openai.com/v1") | |
| MODEL_NAME: str = os.getenv("MODEL_NAME", "gpt-4o-mini") | |
| HF_TOKEN: str | None = os.getenv("HF_TOKEN") | |
| LOCAL_IMAGE_NAME: str | None = os.getenv("LOCAL_IMAGE_NAME") | |
| # Multi-seed evaluation: comma-separated list of seeds | |
| EVAL_SEEDS_STR: str = os.getenv("EVAL_SEEDS", "7,42,123,456,789") | |
| EVAL_SEEDS: List[int] = [int(s.strip()) for s in EVAL_SEEDS_STR.split(",")] | |
| # Safety cap so we never exceed the 20-min wall-clock limit. | |
| # Each task has its own max_steps; this is an outer guard. | |
| GLOBAL_MAX_STEPS: int = 60 | |
| # ββ Validator-parseable stdout blocks βββββββββββββββββββββββββββββββββββββββββ | |
| def _stdout_block(tag: str, **fields: Any) -> None: | |
| """ | |
| Emit a single-line structured block to stdout. | |
| External validators look for literal `[START]`, `[STEP]`, `[END]` tokens in | |
| stdout and then parse key=value pairs. | |
| """ | |
| parts: List[str] = [f"[{tag}]"] | |
| for k, v in fields.items(): | |
| parts.append(f"{k}={v}") | |
| print(" ".join(parts), flush=True) | |
| def _reward_value(reward: Any) -> float: | |
| try: | |
| return float(getattr(reward, "value")) | |
| except Exception: # noqa: BLE001 | |
| try: | |
| return float(reward) | |
| except Exception: # noqa: BLE001 | |
| return 0.0 | |
# ── Prompts ─────────────────────────────────────────────────────────────────
# System message sent as the first chat message on every call in pick_action.
# The command list in the JSON template must stay in sync with VALID_COMMANDS:
# pick_action replaces any command outside that set with "wait".
# NOTE(review): the "Advanced mechanics" section mentions stamina and money,
# but _build_user_message does not send any stamina/money fields. Confirm the
# model can actually observe them, or add them to the user payload.
# NOTE(review): the "β" characters below look like mis-encoded em dashes;
# confirm against the repository copy before changing the prompt text.
SYSTEM_PROMPT = """\
You control a warehouse fulfillment robot.
You will receive the current environment observation as JSON.
Reply with exactly one JSON object β no markdown, no extra text:
{
"command": "<one of: turn_left | turn_right | move_forward | scan_bin | pick_item | pack_item | recharge | rest | wait>",
"rationale": "<one short sentence>"
}
Guidelines:
- Scan the required bin before picking from it.
- Pick the item your order needs and carry it to the pack station.
- Face the pack station before calling pack_item.
- Recharge before battery runs to 0 if needed.
- Avoid invalid actions β every wasted step costs score.
Advanced mechanics (active on harder tasks):
- Obstacles: some cells are impassable. If front_cell says "obstacle", turn to find another route.
- Item weight: items have weight. If an item exceeds your carry capacity, you cannot pick it.
Heavier items drain more battery while moving.
- Stamina: movement costs stamina. When stamina hits 0, movement costs double battery.
Use the "rest" action at the rest area to restore stamina.
- Money: packing correct items earns money; wrong packs lose money. Hit the profit target if set.
"""
| def _build_user_message(obs: WarehouseObservation, state: WarehouseState) -> str: | |
| payload = { | |
| "task_id": obs.task_id, | |
| "mission": obs.mission, | |
| "narrative": obs.narrative, | |
| "agent_position": list(obs.agent_position), | |
| "heading": obs.heading, | |
| "front_cell": obs.front_cell, | |
| "carrying": obs.carrying, | |
| "battery_level": obs.battery_level, | |
| "visible_bins": obs.visible_bins, | |
| "pending_order": [{"sku": p.sku, "remaining": p.remaining} for p in obs.pending_order], | |
| "packed_order": [{"sku": p.sku, "packed": p.packed} for p in obs.packed_order], | |
| "progress_ratio": obs.progress_ratio, | |
| "step_count": state.step_count, | |
| "max_steps": state.max_steps, | |
| "recent_actions": state.action_history[-6:], | |
| } | |
| return json.dumps(payload, indent=2) | |
# ── LLM action selection ────────────────────────────────────────────────────
# Every command string the agent may send to env.step(); pick_action falls
# back to "wait" for anything the model returns outside this set.
VALID_COMMANDS = set(
    "turn_left turn_right move_forward "
    "scan_bin pick_item pack_item recharge rest wait".split()
)
def pick_action(client: OpenAI, obs: WarehouseObservation, state: WarehouseState) -> str:
    """Ask the LLM for the next command and always return a member of VALID_COMMANDS.

    Every failure mode falls back to ``"wait"``, so one bad completion
    cannot abort the episode. Failure modes covered:

    * an exception from the API call;
    * an empty or non-JSON reply;
    * a JSON value that is not an object;
    * an unknown command.

    Warnings go to stderr so they never interleave with the validator-parsed
    ``[START]``/``[STEP]``/``[END]`` lines on stdout.
    """
    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": _build_user_message(obs, state)},
            ],
            temperature=0.2,
            max_tokens=120,
            response_format={"type": "json_object"},
        )
        content = response.choices[0].message.content or ""
    except Exception as exc:  # noqa: BLE001 - deliberate best-effort fallback
        print(f" [warn] LLM call failed: {exc} β using 'wait'", file=sys.stderr)
        return "wait"
    return _parse_command(content)


def _parse_command(content: str) -> str:
    """Extract and validate the ``command`` field of the model's JSON reply.

    The value is trimmed, lower-cased, and has spaces/hyphens mapped to
    underscores, so near-misses like ``"Move Forward"`` still resolve.
    Returns ``"wait"`` (with a stderr warning) when the reply is unusable.
    """
    try:
        parsed = json.loads(content)
    except ValueError as exc:  # json.JSONDecodeError subclasses ValueError
        print(f" [warn] Could not parse model reply as JSON: {exc}, using 'wait'", file=sys.stderr)
        return "wait"
    if not isinstance(parsed, dict):
        print(" [warn] Model reply is not a JSON object, using 'wait'", file=sys.stderr)
        return "wait"
    command = (
        str(parsed.get("command", "wait"))
        .strip()
        .lower()
        .replace(" ", "_")
        .replace("-", "_")
    )
    if command not in VALID_COMMANDS:
        print(f" [warn] Model returned unknown command '{command}', using 'wait'", file=sys.stderr)
        return "wait"
    return command
# ── Task runner ─────────────────────────────────────────────────────────────
def run_task(client: OpenAI, task_id: str, seed: int, verbose: bool = True) -> Dict[str, Any]:
    """Run one episode of *task_id* with *seed* and return its graded result.

    Emits ``[START]``, one ``[STEP]`` per action, and ``[END]`` to stdout for
    the external validator. The episode stops when the environment reports
    ``done`` or after ``GLOBAL_MAX_STEPS`` actions, whichever comes first.

    Args:
        client: OpenAI-compatible client passed through to ``pick_action``.
        task_id: task identifier passed to the environment.
        seed: seed passed to the environment constructor and ``reset``.
        verbose: if true, also print a human-readable per-step trace.

    Returns:
        Dict with ``task_id``, ``seed``, ``score`` (rounded to 4 places),
        ``reward``, ``steps``, ``success`` and ``completion_ratio``.
    """
    if verbose:
        print(f"\n{'='*60}")
        print(f"Task: {task_id} | Seed: {seed}")
        print(f"{'='*60}")
    env = WarehouseFulfillmentEnv(task_id=task_id, seed=seed)
    obs = env.reset(task_id=task_id, seed=seed)
    done = False
    step = 0
    _stdout_block("START", task=task_id, seed=seed)
    while not done and step < GLOBAL_MAX_STEPS:
        state = env.state()
        command = pick_action(client, obs, state)
        obs, reward, done, _info = env.step(command)
        step += 1
        # Normalise once: the reward may be an object with `.value` or a plain
        # number, and both the validator line and the trace need a float.
        reward_value = _reward_value(reward)
        _stdout_block(
            "STEP",
            task=task_id,
            seed=seed,
            step=step,
            reward=f"{reward_value:.6f}",
            command=command,
            done=int(bool(done)),
        )
        if verbose:
            print(
                f" step {step:>3} | {command:<14} | reward {reward_value:+.3f} "
                f"| progress {obs.progress_ratio:.2f} | battery {obs.battery_level}"
            )
    final_state = env.state()
    # Convert once so the [END] line and the returned dict agree.
    score = float(grade_episode(final_state))
    _stdout_block(
        "END",
        task=task_id,
        seed=seed,
        score=f"{score:.6f}",
        steps=final_state.step_count,
        success=int(bool(final_state.success)),
    )
    result = {
        "task_id": task_id,
        "seed": seed,
        "score": round(score, 4),
        "reward": round(final_state.total_reward, 4),
        "steps": final_state.step_count,
        "success": final_state.success,
        "completion_ratio": final_state.completion_ratio,
    }
    if verbose:
        print(
            f"\n β score={result['score']:.4f} reward={result['reward']:.4f} "
            f"steps={result['steps']} success={result['success']}"
        )
    return result
def run_task_multiseed(client: OpenAI, task_id: str, seeds: List[int]) -> Dict[str, Any]:
    """Run *task_id* once per seed (non-verbose) and aggregate the results.

    Args:
        client: OpenAI-compatible client passed through to ``run_task``.
        task_id: task identifier passed to ``run_task``.
        seeds: seeds to evaluate, in order; must be non-empty.

    Returns:
        Dict with ``task_id`` and these aggregates over the seed runs:

        * mean ``score``, population ``score_std``, ``score_min``, ``score_max``;
        * mean ``reward`` and mean ``steps``;
        * ``success_rate``;
        * ``num_seeds`` and a copy of ``seeds``;
        * the raw per-seed dicts in ``seed_results``.

    Raises:
        ValueError: if *seeds* is empty (the statistics would be undefined).
    """
    if not seeds:
        raise ValueError(f"run_task_multiseed needs at least one seed for task {task_id!r}")
    print(f"\n{'='*60}")
    print(f"Task: {task_id} | Seeds: {seeds}")
    print(f"{'='*60}")
    seed_results: List[Dict[str, Any]] = []
    for seed in seeds:
        result = run_task(client, task_id, seed, verbose=False)
        seed_results.append(result)
        print(f" Seed {seed:>3}: score={result['score']:.4f} steps={result['steps']:>3} success={result['success']}")
    # Aggregate statistics
    n = len(seed_results)
    scores = [r["score"] for r in seed_results]
    mean_score = sum(scores) / n
    # Population standard deviation (divides by n, not n - 1).
    std_score = (sum((s - mean_score) ** 2 for s in scores) / n) ** 0.5
    aggregated = {
        "task_id": task_id,
        "score": round(mean_score, 4),
        "score_std": round(std_score, 4),
        "score_min": round(min(scores), 4),
        "score_max": round(max(scores), 4),
        "reward": round(sum(r["reward"] for r in seed_results) / n, 4),
        "steps": round(sum(r["steps"] for r in seed_results) / n, 1),
        "success_rate": round(sum(1 for r in seed_results if r["success"]) / n, 2),
        "num_seeds": n,
        # Copy so later mutation of the caller's list cannot change the report.
        "seeds": list(seeds),
        "seed_results": seed_results,
    }
    print(f"\n β Aggregated: score={aggregated['score']:.4f} (Β±{aggregated['score_std']:.4f}) "
          f"success_rate={aggregated['success_rate']:.2f}")
    return aggregated
# ── Main ────────────────────────────────────────────────────────────────────
def main() -> None:
    """Evaluate every task in ``TASKS`` with the configured model and report.

    With a single seed, each task runs once via ``run_task`` with a verbose
    per-step trace. With several seeds, each task is aggregated via
    ``run_task_multiseed``. The function prints a summary table, the mean
    score and a ``JSON_RESULTS:`` dump to stdout.

    Exits with status 1 when configuration is unusable: ``HF_TOKEN`` or
    ``MODEL_NAME`` is empty, or there are no seeds or no tasks. Exits with
    status 2 when any score falls outside [0, 1], which indicates a grader
    bug; every offending task is reported before exiting.
    """
    if not HF_TOKEN:
        print("ERROR: HF_TOKEN environment variable is not set.", file=sys.stderr)
        sys.exit(1)
    if not MODEL_NAME:
        print("ERROR: MODEL_NAME environment variable is not set.", file=sys.stderr)
        sys.exit(1)
    if not EVAL_SEEDS:
        print("ERROR: EVAL_SEEDS must contain at least one seed.", file=sys.stderr)
        sys.exit(1)
    if not TASKS:
        print("ERROR: no tasks are registered in TASKS.", file=sys.stderr)
        sys.exit(1)
    print(f"API_BASE_URL : {API_BASE_URL}")
    print(f"MODEL_NAME : {MODEL_NAME}")
    print(f"LOCAL_IMAGE_NAME : {LOCAL_IMAGE_NAME}")
    print(f"Tasks : {list(TASKS.keys())}")
    print(f"Eval Seeds : {EVAL_SEEDS} ({len(EVAL_SEEDS)} seeds)")
    client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)

    # Decide the evaluation mode once; the summary layout depends on it too.
    is_multiseed = len(EVAL_SEEDS) > 1
    results: List[Dict[str, Any]]
    if is_multiseed:
        results = [run_task_multiseed(client, task_id, EVAL_SEEDS) for task_id in TASKS]
    else:
        # Single-seed evaluation (backward compatible)
        results = [run_task(client, task_id, EVAL_SEEDS[0], verbose=True) for task_id in TASKS]

    # ── Summary report ──────────────────────────────────────────────────────
    print(f"\n{'='*60}")
    print("RESULTS SUMMARY")
    print(f"{'='*60}")
    if is_multiseed:
        print(f"{'task_id':<30} {'score (Β±std)':>15} {'min':>7} {'max':>7} {'success':>8}")
        print("-" * 80)
        for r in results:
            print(
                f"{r['task_id']:<30} {r['score']:>6.4f} (Β±{r['score_std']:.4f}) "
                f"{r['score_min']:>7.4f} {r['score_max']:>7.4f} {r['success_rate']:>8.2f}"
            )
    else:
        print(f"{'task_id':<30} {'score':>7} {'reward':>8} {'steps':>6} {'success':>8}")
        print("-" * 60)
        for r in results:
            print(
                f"{r['task_id']:<30} {r['score']:>7.4f} {r['reward']:>8.4f} "
                f"{r['steps']:>6} {str(r['success']):>8}"
            )
    # Safe: TASKS was checked non-empty above, so results is non-empty.
    mean_score = sum(r["score"] for r in results) / len(results)
    print("-" * (80 if is_multiseed else 60))
    print(f"{'mean_score':<30} {mean_score:>7.4f}")
    # Validator-friendly JSON output to stdout
    print("\nJSON_RESULTS:", json.dumps(results, indent=2))
    # Exit non-zero if any score is outside [0, 1] (catches grader bugs).
    # Report every offender before exiting, not just the first.
    out_of_range = [r for r in results if not (0.0 <= r["score"] <= 1.0)]
    for r in out_of_range:
        print(f"ERROR: score out of range for {r['task_id']}: {r['score']}", file=sys.stderr)
    if out_of_range:
        sys.exit(2)
    print("\nAll tasks completed successfully.")


if __name__ == "__main__":
    main()