File size: 5,162 Bytes
"""Random-action baseline agent for Pyre (single-agent).
Runs N episodes using the PyreEnv client and prints per-episode stats.
Use this to smoke-test the server and verify the reward distribution
spans a meaningful range.
Usage:
# Server must be running first:
# cd pyre_env && uv run server
#
python examples/random_agent.py --episodes 5
python examples/random_agent.py --episodes 5 --verbose
python examples/random_agent.py --url http://localhost:8000 --episodes 10
"""
import argparse
import random
import sys
from typing import List
import requests
from pyre_env import PyreEnv, PyreAction
# ---------------------------------------------------------------------------
# Action sampling
# ---------------------------------------------------------------------------
def _parse_hint(hint: str) -> PyreAction:
    """Turn one ``available_actions_hint`` entry into a PyreAction.

    Three shapes are recognised after stripping surrounding whitespace:

    * text starting with ``move(`` -- the first single-quoted token is
      used as the direction;
    * text starting with ``door(`` -- the first and second single-quoted
      tokens are used as the target id and the door state;
    * the literal ``wait()``.

    Any other text, or a hint missing the expected quoted tokens, yields
    a ``wait`` action.
    """
    try:
        text = hint.strip()
        if text == "wait()":
            return PyreAction(action="wait")
        # Splitting on single quotes leaves the quoted values at the odd
        # indices, e.g. ["door(target_id=", id, ", door_state=", state, ")"].
        quoted = text.split("'")
        if text.startswith("move("):
            return PyreAction(action="move", direction=quoted[1])
        if text.startswith("door("):
            return PyreAction(
                action="door",
                target_id=quoted[1],
                door_state=quoted[3],
            )
    except (IndexError, ValueError):
        # Malformed hint: fall through to the default action below.
        pass
    return PyreAction(action="wait")
def random_action(hints: List[str], rng: random.Random) -> PyreAction:
    """Sample the next action, preferring a hinted one 70% of the time.

    When ``hints`` is non-empty and a draw from ``rng`` falls below 0.7,
    one hint is chosen at random and parsed into an action. In every
    other case a ``move`` in a randomly chosen cardinal direction is
    returned.
    """
    # The short-circuit matters: no random number is consumed when there
    # are no hints to choose from.
    follow_hint = bool(hints) and rng.random() < 0.7
    if not follow_hint:
        # Fallback: random move
        directions = ["north", "south", "east", "west"]
        return PyreAction(action="move", direction=rng.choice(directions))
    return _parse_hint(rng.choice(hints))
# ---------------------------------------------------------------------------
# Episode runner
# ---------------------------------------------------------------------------
def run_episode(env, max_steps: int, rng: random.Random, verbose: bool) -> dict:
    """Play a single episode with randomly sampled actions.

    The environment is reset, then stepped until it reports ``done`` or
    ``max_steps`` steps have been taken. A missing (falsy) step reward is
    counted as 0.0. With ``verbose`` set, one status line is printed per
    step.

    Returns a dict with the step count, the total reward rounded to three
    decimals, the final ``done`` flag, selected fields of the last
    observation, and the first 120 characters of its narrative. Metadata
    entries that are absent are reported as ``"?"``.
    """
    outcome = env.reset()
    obs = outcome.observation
    finished = outcome.done
    total = 0.0
    step_count = 0

    while not (finished or step_count >= max_steps):
        outcome = env.step(random_action(obs.available_actions_hint, rng))
        obs = outcome.observation
        step_reward = outcome.reward or 0.0
        finished = outcome.done
        total += step_reward
        step_count += 1

        if not verbose:
            continue
        # Only the first line of the narrative is shown, capped at 70 chars.
        headline = obs.narrative.split("\n")[0] if obs.narrative else ""
        print(
            f" step {step_count:3d} | hp={obs.agent_health:5.1f}"
            f" | r={step_reward:+.3f} | done={finished} | {headline[:70]}"
        )

    extra = obs.metadata or {}
    summary = {
        "steps": step_count,
        "total_reward": round(total, 3),
        "done": finished,
        "evacuated": obs.agent_evacuated,
        "final_health": obs.agent_health,
        "wind_dir": obs.wind_dir,
        "fire_sources": extra.get("fire_sources", "?"),
        "fire_spread": extra.get("fire_spread_rate", "?"),
        "last_narrative": obs.narrative[:120] if obs.narrative else "",
    }
    return summary
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """Run the random-agent baseline from the command line.

    Parses the CLI options, checks that ``<url>/health`` responds
    successfully, then plays ``--episodes`` episodes through a synchronous
    ``PyreEnv`` session, printing one stats line per episode and a final
    reward/step summary.

    Exits with status 1 when the health check fails. When no episode was
    run (``--episodes`` of zero or less) the summary reports that instead
    of computing min/max/mean over an empty list.
    """
    parser = argparse.ArgumentParser(description="Pyre random-agent baseline")
    parser.add_argument("--url", default="http://localhost:8000", help="Server base URL")
    parser.add_argument("--episodes", type=int, default=5)
    parser.add_argument("--max-steps", type=int, default=100)
    parser.add_argument("--seed", type=int, default=7)
    parser.add_argument("--verbose", action="store_true")
    args = parser.parse_args()

    # Health check. Only request-level failures (connection, timeout, bad
    # status) mean "server not reachable"; anything else should surface.
    try:
        response = requests.get(f"{args.url}/health", timeout=5)
        response.raise_for_status()
    except requests.RequestException as e:
        print(f"Server not reachable at {args.url}: {e}")
        sys.exit(1)
    print(f"Server healthy: {args.url}")

    # One seeded generator shared by all episodes keeps runs reproducible.
    rng = random.Random(args.seed)
    results: List[dict] = []
    with PyreEnv(base_url=args.url).sync() as env:
        for ep in range(args.episodes):
            print(f"\n=== Episode {ep + 1}/{args.episodes} ===")
            stats = run_episode(env, args.max_steps, rng, args.verbose)
            results.append(stats)
            print(
                f" DONE steps={stats['steps']} reward={stats['total_reward']:+.3f}"
                f" health={stats['final_health']:.1f}"
                f" wind={stats['wind_dir']} sources={stats['fire_sources']}"
                f" spread={stats['fire_spread']}"
            )

    print("\n=== Summary ===")
    print(f"Episodes: {len(results)}")
    if not results:
        # min()/max() raise ValueError and the means divide by zero on an
        # empty run, so stop here rather than crash.
        print("No episodes were run; nothing to summarise.")
        return

    rewards = [r["total_reward"] for r in results]
    print(f"Reward min/max: {min(rewards):.3f} / {max(rewards):.3f}")
    print(f"Reward mean: {sum(rewards)/len(rewards):.3f}")
    print(f"Avg steps: {sum(r['steps'] for r in results) / len(results):.1f}")


# Run the baseline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|