| """ |
| Eval comparison script — runs multiple agents on fixed seeds and prints a summary table. |
| |
| Usage: |
| python scripts/eval_compare.py --seeds 42 43 44 45 46 --tiers medium hard --agents random heuristic |
| python scripts/eval_compare.py --quick |
| """ |
|
|
| import argparse |
| import json |
| import os |
| import sys |
| import warnings |
|
|
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
| from env import WildfireEnv |
| from agents.random_agent import RandomAgent |
| from agents.heuristic_agent import HeuristicAgent |
|
|
|
|
def _make_llm_agent(model_path_env: str):
    """Build an LLM agent from the model path held in an environment variable.

    Args:
        model_path_env: Name of the environment variable that stores the
            model path (e.g. ``"BASE_MODEL_PATH"``).

    Returns:
        An ``LLMAgent`` constructed with ``model_path=<value of the variable>``,
        or ``None`` when the variable is unset/empty or an ``ImportError`` is
        raised while importing or constructing the agent.
    """
    model_path = os.environ.get(model_path_env)
    if model_path:
        try:
            # Imported lazily so the script still runs when the optional
            # LLM agent module is not available.
            from agents.llm_agent import LLMAgent

            agent = LLMAgent(model_path=model_path)
        except ImportError:
            warnings.warn(f"agents.llm_agent not found — skipping {model_path_env}")
        else:
            return agent
    return None
|
|
|
|
# Maps each CLI agent name to a zero-argument factory. The factories are
# lambdas so that agents are only constructed when actually requested; the
# two LLM entries yield None when _make_llm_agent cannot build an agent.
AGENT_REGISTRY = {
    "random": lambda: RandomAgent(),
    "heuristic": lambda: HeuristicAgent(),
    "base_llm": lambda: _make_llm_agent("BASE_MODEL_PATH"),
    "trained_llm": lambda: _make_llm_agent("TRAINED_MODEL_PATH"),
}

# Human-readable row labels used in the printed summary table, keyed by the
# same agent names as AGENT_REGISTRY.
AGENT_LABELS = {
    "random": "Random Agent",
    "heuristic": "Heuristic Agent",
    "base_llm": "Base LLM",
    "trained_llm": "Trained LLM (ours)",
}
|
|
|
|
def run_episode(agent, tier: str, seed: int) -> dict:
    """Play a single episode and summarise its outcome.

    A fresh ``WildfireEnv`` is reset with ``task_id=tier`` and ``seed=seed``;
    the agent's ``act(obs)`` is then fed to ``env.step`` until a step reports
    ``done``.

    Args:
        agent: Object exposing ``act(observation)``.
        tier: Task identifier passed to ``env.reset`` as ``task_id``.
        seed: Seed passed to ``env.reset``.

    Returns:
        A dict with ``containment_pct``, ``pop_saved_pct``, ``total_reward``
        and ``episode_steps`` for the episode.
    """
    env = WildfireEnv()
    obs = env.reset(task_id=tier, seed=seed)

    reward_sum = 0.0
    step_count = 0
    while True:  # the environment is always stepped at least once
        outcome = env.step(agent.act(obs))
        reward_sum += outcome.reward
        obs = outcome.observation
        finished = outcome.done
        step_count += 1
        if finished:
            break

    snapshot = env.state()
    # Fall back to 1 when the total is missing or falsy to avoid dividing by zero.
    population = snapshot.get("total_population", 1) or 1
    lost = snapshot.get("population_lost", 0)
    contained = snapshot.get("containment_pct", 0.0)

    return {
        "containment_pct": contained,
        "pop_saved_pct": 1.0 - lost / population,
        "total_reward": reward_sum,
        "episode_steps": step_count,
    }
|
|
|
|
def run_comparison(agent_names, tiers, seeds):
    """Run every requested agent on every tier/seed pair and average the metrics.

    Args:
        agent_names: Iterable of keys into ``AGENT_REGISTRY``.
        tiers: Iterable of tier names handed to ``run_episode``.
        seeds: Non-empty iterable of integer seeds; each (agent, tier) pair is
            evaluated once per seed.

    Returns:
        ``{agent_name: {tier: summary}}`` where ``summary`` is ``None`` when no
        agent could be built for that name (unknown name or factory returned
        ``None``), otherwise a dict holding the per-seed mean of
        ``containment_pct``, ``pop_saved_pct``, ``total_reward`` and
        ``episode_steps`` plus the raw per-episode dicts under ``"runs"``.

    Raises:
        ValueError: If ``seeds`` is empty.
    """
    # Materialise once: both sequences are iterated repeatedly, so a one-shot
    # iterator would otherwise be exhausted after the first agent/tier.
    tiers = list(tiers)
    seeds = list(seeds)
    if not seeds:
        # Fail before any agent (possibly an expensive model) is constructed,
        # instead of hitting a ZeroDivisionError while averaging.
        raise ValueError("run_comparison requires at least one seed")

    metric_keys = ("containment_pct", "pop_saved_pct", "total_reward", "episode_steps")
    results = {}
    for agent_name in agent_names:
        factory = AGENT_REGISTRY.get(agent_name)
        agent = factory() if factory else None
        per_tier = results[agent_name] = {}
        for tier in tiers:
            if agent is None:
                per_tier[tier] = None
                continue
            runs = [run_episode(agent, tier, seed) for seed in seeds]
            summary = {
                key: sum(run[key] for run in runs) / len(runs)
                for key in metric_keys
            }
            summary["runs"] = runs
            per_tier[tier] = summary
    return results
|
|
|
|
def print_table(results, tiers, agent_names, seeds):
    """Print one fixed-width results table per tier to stdout.

    Args:
        results: Nested mapping ``results[agent_name][tier]`` holding either
            ``None`` (agent skipped) or a dict with ``containment_pct``,
            ``pop_saved_pct``, ``total_reward`` and ``episode_steps``.
        tiers: Tier names; one table is printed for each.
        agent_names: Agent keys, printed as rows in the given order.
        seeds: Seeds that were evaluated; only their count is shown.
    """
    # The column header does not depend on the tier, so build it once.
    header = f"{'Agent':<28} {'Containment':>12} {'Pop Saved':>10} {'Reward':>8} {'Steps':>7}"
    rule = "-" * len(header)

    for tier in tiers:
        seed_count = len(seeds)
        plural = "" if seed_count == 1 else "s"
        print(f"\n=== EVAL RESULTS — {tier.capitalize()} Tier ({seed_count} seed{plural}) ===")
        print(header)
        print(rule)
        for agent_name in agent_names:
            label = AGENT_LABELS.get(agent_name, agent_name)
            summary = results[agent_name].get(tier)
            if summary is None:
                print(f"{label:<28} {'[skipped — no model]':>39}")
                continue
            # containment_pct and pop_saved_pct are multiplied by 100 for display.
            containment = f"{summary['containment_pct'] * 100:.0f}%"
            pop_saved = f"{summary['pop_saved_pct'] * 100:.0f}%"
            reward = f"{summary['total_reward']:+.1f}"
            steps = f"{summary['episode_steps']:.0f}"
            print(f"{label:<28} {containment:>12} {pop_saved:>10} {reward:>8} {steps:>7}")
|
|
|
|
def main(argv=None):
    """Parse CLI options, run the comparison, print the table and save JSON.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what the script entry
            point relies on; passing a list allows programmatic invocation.

    Side effects:
        Prints progress and the results table to stdout, creates the output
        file's parent directory if needed, and writes the results as JSON to
        ``--output`` (default ``eval_results.json``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--seeds", type=int, nargs="+", default=[42, 43, 44, 45, 46])
    parser.add_argument("--tiers", nargs="+", choices=["easy", "medium", "hard"], default=["medium", "hard"])
    parser.add_argument("--agents", nargs="+", choices=list(AGENT_REGISTRY), default=["random", "heuristic"])
    parser.add_argument("--output", default="eval_results.json")
    parser.add_argument("--quick", action="store_true", help="Easy tier, 2 seeds only")
    args = parser.parse_args(argv)

    if args.quick:
        # Quick mode overrides tiers and seeds and drops every agent other
        # than the two baseline agents.
        args.tiers = ["easy"]
        args.seeds = [42, 43]
        args.agents = [a for a in args.agents if a in ("random", "heuristic")]

    print(f"Running: agents={args.agents}, tiers={args.tiers}, seeds={args.seeds}")
    results = run_comparison(args.agents, args.tiers, args.seeds)
    print_table(results, args.tiers, args.agents, args.seeds)

    # os.path.dirname is empty for a bare filename; only create real directories.
    out_dir = os.path.dirname(args.output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Shallow-copy each summary into a plain dict (skipped agents stay None).
    serializable = {
        agent_name: {
            tier: None if data is None else dict(data)
            for tier, data in tier_data.items()
        }
        for agent_name, tier_data in results.items()
    }

    with open(args.output, "w", encoding="utf-8") as f:
        json.dump(serializable, f, indent=2)
    print(f"\nResults saved -> {args.output}")
|
|
|
|
# Run the comparison only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|