File size: 5,162 Bytes
2d3e659
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""Random-action baseline agent for Pyre (single-agent).

Runs N episodes using the PyreEnv client and prints per-episode stats.
Use this to smoke-test the server and verify the reward distribution
spans a meaningful range.

Usage:
    # Server must be running first:
    #   cd pyre_env && uv run server
    #
    python examples/random_agent.py --episodes 5
    python examples/random_agent.py --episodes 5 --verbose
    python examples/random_agent.py --url http://localhost:8000 --episodes 10
"""

import argparse
import random
import sys
from typing import List

import requests

from pyre_env import PyreEnv, PyreAction


# ---------------------------------------------------------------------------
# Action sampling
# ---------------------------------------------------------------------------

def _parse_hint(hint: str) -> PyreAction:
    """Translate one ``available_actions_hint`` string into a PyreAction.

    Quoted arguments are pulled out by splitting the hint on single quotes:

    * ``move(...'<direction>'...)``  -> ``move`` with that direction
    * ``door(target_id='<id>', door_state='<state>')`` -> ``door`` action
    * ``wait()``                     -> ``wait``

    Any other text, a hint missing its quoted fields (IndexError), or values
    that PyreAction rejects (ValueError) all degrade to a ``wait`` action.
    """
    try:
        text = hint.strip()
        quoted = text.split("'")
        if text.startswith("move("):
            return PyreAction(action="move", direction=quoted[1])
        if text.startswith("door("):
            # quoted == ["door(target_id=", <id>, ", door_state=", <state>, ")"]
            return PyreAction(
                action="door",
                target_id=quoted[1],
                door_state=quoted[3],
            )
        if text == "wait()":
            return PyreAction(action="wait")
    except (IndexError, ValueError):
        # Malformed or rejected hint: fall through to the safe default below.
        pass
    return PyreAction(action="wait")


def random_action(hints: List[str], rng: random.Random) -> PyreAction:
    """Sample one action for the random baseline policy.

    When ``hints`` is non-empty, a draw from ``rng.random()`` below 0.7 picks
    one hint uniformly and parses it with ``_parse_hint``. Otherwise (no hints,
    or the 30% case) a ``move`` in a uniformly random cardinal direction is
    returned. ``rng.random()`` is only consumed when hints are present.
    """
    take_hint = bool(hints) and rng.random() < 0.7
    if not take_hint:
        # Fallback: wander in a random direction.
        heading = rng.choice(["north", "south", "east", "west"])
        return PyreAction(action="move", direction=heading)
    return _parse_hint(rng.choice(hints))


# ---------------------------------------------------------------------------
# Episode runner
# ---------------------------------------------------------------------------

def run_episode(env, max_steps: int, rng: random.Random, verbose: bool) -> dict:
    """Play a single episode with the random policy and summarize it.

    Resets ``env`` and then steps it with ``random_action`` until the result
    reports ``done`` or ``max_steps`` steps have been taken. A falsy step
    reward (e.g. ``None``) is counted as 0.0. With ``verbose`` set, one line
    per step is printed: step number, agent health, step reward, done flag and
    the first narrative line (cut to 70 characters).

    Returns a dict with keys, in order: ``steps``, ``total_reward`` (rounded
    to 3 decimals), ``done``, ``evacuated``, ``final_health``, ``wind_dir``,
    ``fire_sources`` / ``fire_spread`` (from the final observation's
    ``metadata``; "?" when missing) and ``last_narrative`` (first 120 chars).
    """
    first = env.reset()
    obs = first.observation
    done = first.done
    steps = 0
    total = 0.0

    while steps < max_steps and not done:
        outcome = env.step(random_action(obs.available_actions_hint, rng))
        obs = outcome.observation
        step_reward = outcome.reward or 0.0
        done = outcome.done
        total += step_reward
        steps += 1

        if not verbose:
            continue
        headline = (obs.narrative or "").partition("\n")[0]
        print(
            f"  step {steps:3d} | hp={obs.agent_health:5.1f}"
            f" | r={step_reward:+.3f} | done={done} | {headline[:70]}"
        )

    meta = obs.metadata or {}
    return dict(
        steps=steps,
        total_reward=round(total, 3),
        done=done,
        evacuated=obs.agent_evacuated,
        final_health=obs.agent_health,
        wind_dir=obs.wind_dir,
        fire_sources=meta.get("fire_sources", "?"),
        fire_spread=meta.get("fire_spread_rate", "?"),
        last_narrative=(obs.narrative or "")[:120],
    )


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    """Command-line entry point for the random-agent baseline.

    Parses flags, probes ``GET <url>/health`` (exiting with status 1 if the
    request fails or returns an HTTP error status), runs ``--episodes``
    episodes through a single ``PyreEnv`` connection, prints per-episode
    stats as each one finishes, then prints a reward/steps summary. The
    summary statistics are skipped when no episodes were run.
    """
    parser = argparse.ArgumentParser(description="Pyre random-agent baseline")
    parser.add_argument("--url", default="http://localhost:8000", help="Server base URL")
    parser.add_argument("--episodes", type=int, default=5, help="Number of episodes to run")
    parser.add_argument("--max-steps", type=int, default=100, help="Step cap per episode")
    parser.add_argument("--seed", type=int, default=7, help="Seed for the action-sampling RNG")
    parser.add_argument("--verbose", action="store_true", help="Print one line per step")
    args = parser.parse_args()

    # Health check: fail fast with a readable message rather than a client traceback.
    # RequestException covers connection errors, timeouts, bad URLs and the
    # HTTPError raised by raise_for_status().
    try:
        resp = requests.get(f"{args.url}/health", timeout=5)
        resp.raise_for_status()
    except requests.RequestException as e:
        print(f"Server not reachable at {args.url}: {e}", file=sys.stderr)
        sys.exit(1)
    print(f"Server healthy: {args.url}")

    # A single seeded RNG shared by all episodes, so the action stream is
    # determined by --seed (given identical observations from the server).
    rng = random.Random(args.seed)
    results: List[dict] = []

    with PyreEnv(base_url=args.url).sync() as env:
        for ep in range(args.episodes):
            print(f"\n=== Episode {ep + 1}/{args.episodes} ===")
            stats = run_episode(env, args.max_steps, rng, args.verbose)
            results.append(stats)
            print(
                f"  DONE  steps={stats['steps']}  reward={stats['total_reward']:+.3f}"
                f"  health={stats['final_health']:.1f}"
                f"  wind={stats['wind_dir']}  sources={stats['fire_sources']}"
                f"  spread={stats['fire_spread']}"
            )

    print("\n=== Summary ===")
    print(f"Episodes:       {len(results)}")
    if not results:
        # --episodes <= 0: min()/max() and the means below would raise on an empty list.
        print("No episodes run; nothing to summarize.")
        return

    rewards = [ep_stats["total_reward"] for ep_stats in results]
    print(f"Reward min/max: {min(rewards):.3f} / {max(rewards):.3f}")
    print(f"Reward mean:    {sum(rewards) / len(rewards):.3f}")
    print(f"Avg steps:      {sum(ep_stats['steps'] for ep_stats in results) / len(results):.1f}")


if __name__ == "__main__":
    # Run the baseline only when executed as a script
    # (e.g. `python examples/random_agent.py`), not when imported.
    main()