| """Random-action baseline agent for Pyre (single-agent). |
| |
| Runs N episodes using the PyreEnv client and prints per-episode stats. |
| Use this to smoke-test the server and verify the reward distribution |
| spans a meaningful range. |
| |
| Usage: |
| # Server must be running first: |
| # cd pyre_env && uv run server |
| # |
| python examples/random_agent.py --episodes 5 |
| python examples/random_agent.py --episodes 5 --verbose |
| python examples/random_agent.py --url http://localhost:8000 --episodes 10 |
| """ |
|
|
| import argparse |
| import random |
| import sys |
| from typing import List |
|
|
| import requests |
|
|
| from pyre_env import PyreEnv, PyreAction |
|
|
|
|
| |
| |
| |
|
|
def _parse_hint(hint: str) -> PyreAction:
    """Parse a hint string from available_actions_hint into a PyreAction.

    Recognised forms (arguments may be single- or double-quoted):

    * ``move('<direction>')``
    * ``door('<target_id>', '<door_state>')``
    * ``wait()``

    Anything else falls back to a ``wait`` action, so a bad hint never aborts
    an episode. That covers an unknown verb, missing or unbalanced quoted
    arguments, and values for which ``PyreAction`` raises ``ValueError``.

    Args:
        hint: One entry of an observation's ``available_actions_hint`` list.

    Returns:
        The action described by the hint, or ``PyreAction(action="wait")``.
    """
    import re  # local import: this helper is the only user of ``re``

    text = hint.strip()
    # The backreference forces the closing quote to equal the opening one, so
    # a value such as "o'brien" is read whole instead of being cut at the
    # apostrophe (which a plain split on "'" would do).
    quoted_args = [m.group(2) for m in re.finditer(r"""(['"])(.*?)\1""", text)]

    try:
        if text.startswith("move(") and quoted_args:
            return PyreAction(action="move", direction=quoted_args[0])
        if text.startswith("door(") and len(quoted_args) >= 2:
            return PyreAction(
                action="door",
                target_id=quoted_args[0],
                door_state=quoted_args[1],
            )
    except ValueError:
        # Deliberate best effort: a value rejected by PyreAction degrades to
        # a harmless wait rather than crashing the random baseline.
        return PyreAction(action="wait")

    # "wait()" itself and every unrecognised or incomplete hint end up here.
    return PyreAction(action="wait")
|
|
|
|
def random_action(hints: List[str], rng: random.Random) -> PyreAction:
    """Choose the next action at random, preferring the server's hints.

    When ``hints`` is non-empty, one hint is picked and parsed with
    probability 0.7; otherwise (or when there are no hints) a move in a
    uniformly chosen compass direction is returned.

    Args:
        hints: Hint strings from the current observation; may be empty.
        rng: Random source, so runs are reproducible for a given seed.

    Returns:
        The chosen ``PyreAction``.
    """
    # ``and`` short-circuits: the probability draw is only consumed when
    # there is at least one hint to follow.
    follow_hint = bool(hints) and rng.random() < 0.7
    if not follow_hint:
        compass = ["north", "south", "east", "west"]
        return PyreAction(action="move", direction=rng.choice(compass))

    picked = rng.choice(hints)
    return _parse_hint(picked)
|
|
|
|
| |
| |
| |
|
|
def run_episode(env, max_steps: int, rng: random.Random, verbose: bool) -> dict:
    """Play a single episode with random actions and summarise it.

    The environment is reset, then stepped with ``random_action`` until the
    result reports ``done`` or ``max_steps`` steps have been taken.

    Args:
        env: Environment client exposing ``reset()`` and ``step(action)``.
        max_steps: Upper bound on the number of steps in this episode.
        rng: Random source handed to ``random_action``.
        verbose: When true, print one status line per step.

    Returns:
        A dict with the keys ``steps``, ``total_reward`` (rounded to three
        decimals), ``done``, ``evacuated``, ``final_health``, ``wind_dir``,
        ``fire_sources``, ``fire_spread`` and ``last_narrative``.
    """
    outcome = env.reset()
    observation = outcome.observation
    finished = outcome.done

    reward_sum = 0.0
    step_count = 0

    while step_count < max_steps and not finished:
        chosen = random_action(observation.available_actions_hint, rng)
        outcome = env.step(chosen)
        observation = outcome.observation
        # A missing (None) reward is counted as zero.
        step_reward = outcome.reward or 0.0
        finished = outcome.done
        reward_sum += step_reward
        step_count += 1

        if not verbose:
            continue

        # Only the first narrative line, truncated, keeps the log one-line-per-step.
        if observation.narrative:
            headline = observation.narrative.split("\n")[0]
        else:
            headline = ""
        print(
            f" step {step_count:3d} | hp={observation.agent_health:5.1f}"
            f" | r={step_reward:+.3f} | done={finished} | {headline[:70]}"
        )

    # Metadata may be absent; unknown fire parameters are reported as "?".
    extras = observation.metadata or {}
    narrative_tail = observation.narrative[:120] if observation.narrative else ""

    summary = {}
    summary["steps"] = step_count
    summary["total_reward"] = round(reward_sum, 3)
    summary["done"] = finished
    summary["evacuated"] = observation.agent_evacuated
    summary["final_health"] = observation.agent_health
    summary["wind_dir"] = observation.wind_dir
    summary["fire_sources"] = extras.get("fire_sources", "?")
    summary["fire_spread"] = extras.get("fire_spread_rate", "?")
    summary["last_narrative"] = narrative_tail
    return summary
|
|
|
|
| |
| |
| |
|
|
def main(argv=None):
    """Run the random-agent baseline from the command line.

    Parses the CLI options, verifies that the server answers on ``/health``,
    plays ``--episodes`` episodes with an RNG seeded from ``--seed``, and
    prints per-episode stats followed by a reward/step summary.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, so existing ``main()`` callers
            behave as before.

    Exits with status 1 when the health check fails.
    """
    parser = argparse.ArgumentParser(description="Pyre random-agent baseline")
    parser.add_argument("--url", default="http://localhost:8000", help="Server base URL")
    parser.add_argument("--episodes", type=int, default=5)
    parser.add_argument("--max-steps", type=int, default=100)
    parser.add_argument("--seed", type=int, default=7)
    parser.add_argument("--verbose", action="store_true")
    args = parser.parse_args(argv)

    # Fail fast with a readable message before opening the env client.
    # Stripping a trailing slash avoids requesting ".../​/health"-style URLs.
    health_url = f"{args.url.rstrip('/')}/health"
    try:
        response = requests.get(health_url, timeout=5)
        response.raise_for_status()
    except requests.RequestException as e:
        print(f"Server not reachable at {args.url}: {e}", file=sys.stderr)
        sys.exit(1)
    print(f"Server healthy: {args.url}")

    rng = random.Random(args.seed)
    results: List[dict] = []

    with PyreEnv(base_url=args.url).sync() as env:
        for ep in range(args.episodes):
            print(f"\n=== Episode {ep + 1}/{args.episodes} ===")
            stats = run_episode(env, args.max_steps, rng, args.verbose)
            results.append(stats)
            print(
                f" DONE steps={stats['steps']} reward={stats['total_reward']:+.3f}"
                f" health={stats['final_health']:.1f}"
                f" wind={stats['wind_dir']} sources={stats['fire_sources']}"
                f" spread={stats['fire_spread']}"
            )

    print("\n=== Summary ===")
    print(f"Episodes: {len(results)}")
    if not results:
        # With --episodes <= 0 nothing ran; min()/max() and the averages
        # below would raise on an empty list, so stop here.
        print("No episodes were run.")
        return

    rewards = [r["total_reward"] for r in results]
    print(f"Reward min/max: {min(rewards):.3f} / {max(rewards):.3f}")
    print(f"Reward mean: {sum(rewards)/len(rewards):.3f}")
    print(f"Avg steps: {sum(r['steps'] for r in results) / len(results):.1f}")
|
|
|
|
# Run the baseline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|