Spaces:
Sleeping
Sleeping
| """Random-action baseline agent for Pyre (single-agent). | |
| Runs N episodes using the PyreEnv client and prints per-episode stats. | |
| Use this to smoke-test the server and verify the reward distribution | |
| spans a meaningful range. | |
| Usage: | |
| # Server must be running first: | |
| # cd pyre_env && uv run server | |
| # | |
| python examples/random_agent.py --episodes 5 | |
| python examples/random_agent.py --episodes 5 --verbose | |
| python examples/random_agent.py --url http://localhost:8000 --episodes 10 | |
| """ | |
| import argparse | |
| import random | |
| import sys | |
| from typing import List | |
| import requests | |
| from pyre_env import PyreEnv, PyreAction | |
| # --------------------------------------------------------------------------- | |
| # Action sampling | |
| # --------------------------------------------------------------------------- | |
def _parse_hint(hint: str) -> PyreAction:
    """Convert one ``available_actions_hint`` entry into a ``PyreAction``.

    Three hint shapes are recognised after stripping surrounding whitespace:

    * text starting with ``move(`` -- the first single-quoted token is used
      as the direction;
    * text starting with ``door(`` -- the first and second single-quoted
      tokens are used as the target id and the door state;
    * the exact text ``wait()``.

    Any other text, or a hint that raises ``IndexError`` / ``ValueError``
    while being converted, produces a ``wait`` action.
    """
    try:
        text = hint.strip()
        if text.startswith("move("):
            direction = text.split("'")[1]
            return PyreAction(action="move", direction=direction)
        if text.startswith("door("):
            # Splitting on single quotes yields
            # [prefix, target_id, separator, door_state, suffix].
            pieces = text.split("'")
            door_id, state = pieces[1], pieces[3]
            return PyreAction(action="door", target_id=door_id, door_state=state)
        if text == "wait()":
            return PyreAction(action="wait")
    except (IndexError, ValueError):
        # Malformed hint: fall through to the default action below.
        pass
    return PyreAction(action="wait")
def random_action(hints: List[str], rng: random.Random) -> PyreAction:
    """Sample the next action for the random baseline.

    When ``hints`` is non-empty, a draw from ``rng`` below 0.7 selects one of
    the hints at random and parses it with ``_parse_hint``. Otherwise the
    agent moves in a uniformly chosen compass direction.
    """
    # ``rng.random()`` is only consumed when there is at least one hint.
    follow_hint = bool(hints) and rng.random() < 0.7
    if not follow_hint:
        # Fallback: random move
        compass = ["north", "south", "east", "west"]
        return PyreAction(action="move", direction=rng.choice(compass))
    chosen = rng.choice(hints)
    return _parse_hint(chosen)
| # --------------------------------------------------------------------------- | |
| # Episode runner | |
| # --------------------------------------------------------------------------- | |
def run_episode(env, max_steps: int, rng: random.Random, verbose: bool) -> dict:
    """Play a single episode with random actions and summarise it.

    Args:
        env: Object exposing ``reset()`` and ``step(action)``; both return a
            result carrying ``observation``, ``done`` and (for ``step``)
            ``reward``.
        max_steps: Upper bound on the number of ``step`` calls.
        rng: Random source forwarded to ``random_action``.
        verbose: When true, print one status line per step.

    Returns:
        A dict with the step count, rounded total reward, final ``done``
        flag, and fields copied from the last observation.
    """
    outcome = env.reset()
    observation = outcome.observation
    finished = outcome.done
    total = 0.0
    step_count = 0

    while not (finished or step_count >= max_steps):
        chosen = random_action(observation.available_actions_hint, rng)
        outcome = env.step(chosen)
        observation = outcome.observation
        # A missing (None) reward counts as zero.
        step_reward = outcome.reward or 0.0
        finished = outcome.done
        total += step_reward
        step_count += 1

        if not verbose:
            continue
        # Only the first line of the narrative is shown, capped at 70 chars.
        headline = ""
        if observation.narrative:
            headline = observation.narrative.split("\n")[0]
        print(
            f" step {step_count:3d} | hp={observation.agent_health:5.1f}"
            f" | r={step_reward:+.3f} | done={finished} | {headline[:70]}"
        )

    extras = observation.metadata or {}
    if observation.narrative:
        closing_text = observation.narrative[:120]
    else:
        closing_text = ""

    summary = {
        "steps": step_count,
        "total_reward": round(total, 3),
        "done": finished,
        "evacuated": observation.agent_evacuated,
        "final_health": observation.agent_health,
        "wind_dir": observation.wind_dir,
        "fire_sources": extras.get("fire_sources", "?"),
        "fire_spread": extras.get("fire_spread_rate", "?"),
        "last_narrative": closing_text,
    }
    return summary
| # --------------------------------------------------------------------------- | |
| # Main | |
| # --------------------------------------------------------------------------- | |
def main():
    """Run the random-agent baseline from the command line.

    Parses ``--url``, ``--episodes``, ``--max-steps``, ``--seed`` and
    ``--verbose``, checks ``<url>/health``, runs the requested number of
    episodes through one ``PyreEnv`` session, and prints a line per episode
    followed by a summary.

    Exits with status 1 (message on stderr) when the health check fails.
    When no episode was run (``--episodes`` of zero or less) only the episode
    count is printed, since min/max/mean are undefined for an empty run.
    """
    parser = argparse.ArgumentParser(description="Pyre random-agent baseline")
    parser.add_argument("--url", default="http://localhost:8000", help="Server base URL")
    parser.add_argument("--episodes", type=int, default=5)
    parser.add_argument("--max-steps", type=int, default=100)
    parser.add_argument("--seed", type=int, default=7)
    parser.add_argument("--verbose", action="store_true")
    args = parser.parse_args()

    # Health check: fail fast with a readable message instead of a traceback
    # from deep inside the environment client.
    try:
        response = requests.get(f"{args.url}/health", timeout=5)
        response.raise_for_status()
    except requests.RequestException as exc:
        # Covers connection errors, timeouts, bad URLs and non-2xx statuses.
        print(f"Server not reachable at {args.url}: {exc}", file=sys.stderr)
        sys.exit(1)
    print(f"Server healthy: {args.url}")

    # One seeded generator is shared by all episodes so a run is reproducible.
    rng = random.Random(args.seed)
    results: List[dict] = []
    with PyreEnv(base_url=args.url).sync() as env:
        for ep in range(args.episodes):
            print(f"\n=== Episode {ep + 1}/{args.episodes} ===")
            stats = run_episode(env, args.max_steps, rng, args.verbose)
            results.append(stats)
            print(
                f" DONE steps={stats['steps']} reward={stats['total_reward']:+.3f}"
                f" health={stats['final_health']:.1f}"
                f" wind={stats['wind_dir']} sources={stats['fire_sources']}"
                f" spread={stats['fire_spread']}"
            )

    print("\n=== Summary ===")
    print(f"Episodes: {len(results)}")
    if not results:
        # min()/max() raise ValueError and the means divide by zero on an
        # empty result list, so stop after reporting the count.
        return

    rewards = [stats["total_reward"] for stats in results]
    total_steps = sum(stats["steps"] for stats in results)
    print(f"Reward min/max: {min(rewards):.3f} / {max(rewards):.3f}")
    print(f"Reward mean: {sum(rewards)/len(rewards):.3f}")
    print(f"Avg steps: {total_steps / len(results):.1f}")


if __name__ == "__main__":
    main()