"""
Eval comparison script — runs multiple agents on fixed seeds and prints a summary table.
Results are also written to a JSON file (--output, default: eval_results.json).

Usage:
    python scripts/eval_compare.py --seeds 42 43 44 45 46 --tiers medium hard --agents random heuristic
    python scripts/eval_compare.py --quick
"""
import argparse
import json
import os
import sys
import warnings
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from env import WildfireEnv
from agents.random_agent import RandomAgent
from agents.heuristic_agent import HeuristicAgent
def _make_llm_agent(model_path_env: str):
    """Build an LLM agent from the model path held in an environment variable.

    Args:
        model_path_env: Name of the environment variable that stores the
            model path (not the path itself).

    Returns:
        An ``LLMAgent`` constructed with ``model_path=<value of the variable>``,
        or ``None`` when the variable is unset or empty, or when an
        ``ImportError`` is raised while importing or constructing the agent
        (a warning is emitted in that case).
    """
    model_path = os.environ.get(model_path_env)
    if model_path:
        try:
            # Imported lazily so the script still runs when the optional
            # LLM agent module is not available.
            from agents.llm_agent import LLMAgent  # type: ignore
            return LLMAgent(model_path=model_path)
        except ImportError:
            warnings.warn(f"agents.llm_agent not found — skipping {model_path_env}")
    return None
# CLI agent name -> zero-argument factory. The factories are only invoked when
# an agent is actually requested; the two LLM factories may return None (see
# _make_llm_agent), which downstream code treats as "skipped".
# Insertion order matters: it is the order of the --agents choices.
AGENT_REGISTRY = dict(
    random=lambda: RandomAgent(),
    heuristic=lambda: HeuristicAgent(),
    base_llm=lambda: _make_llm_agent("BASE_MODEL_PATH"),
    trained_llm=lambda: _make_llm_agent("TRAINED_MODEL_PATH"),
)
# CLI agent name -> human-readable label shown in the first table column.
AGENT_LABELS = dict(
    random="Random Agent",
    heuristic="Heuristic Agent",
    base_llm="Base LLM",
    trained_llm="Trained LLM (ours)",
)
def run_episode(agent, tier: str, seed: int) -> dict:
    """Play a single episode and return its summary metrics.

    A fresh ``WildfireEnv`` is reset with ``task_id=tier`` and ``seed=seed``;
    the agent's ``act(observation)`` is then stepped through the environment
    until a step result reports ``done``. At least one step is always taken.

    Args:
        agent: Object exposing ``act(observation)``.
        tier: Task identifier passed to ``env.reset`` as ``task_id``.
        seed: Seed passed to ``env.reset``.

    Returns:
        A dict with ``containment_pct`` (read from the final state, default
        0.0), ``pop_saved_pct`` (``1 - population_lost / total_population``),
        ``total_reward`` (sum of step rewards) and ``episode_steps``.
    """
    env = WildfireEnv()
    observation = env.reset(task_id=tier, seed=seed)

    reward_sum = 0.0
    step_count = 0
    while True:
        outcome = env.step(agent.act(observation))
        reward_sum += outcome.reward
        observation = outcome.observation
        finished = outcome.done
        step_count += 1
        if finished:
            break

    snapshot = env.state()
    # A missing or zero population falls back to 1 to avoid dividing by zero.
    population = snapshot.get("total_population", 1) or 1
    lost = snapshot.get("population_lost", 0)
    contained = snapshot.get("containment_pct", 0.0)

    summary = {}
    summary["containment_pct"] = contained
    summary["pop_saved_pct"] = 1.0 - lost / population
    summary["total_reward"] = reward_sum
    summary["episode_steps"] = step_count
    return summary
def run_comparison(agent_names, tiers, seeds):
    """Run every requested agent on every tier/seed pair and average the metrics.

    Args:
        agent_names: Iterable of keys looked up in ``AGENT_REGISTRY``.
        tiers: Iterable of tier names forwarded to ``run_episode``.
        seeds: Iterable of seeds; one episode is run per seed.

    Returns:
        ``{agent_name: {tier: summary}}`` where ``summary`` is ``None`` when
        no agent could be built (unknown name, or the factory returned
        ``None``), and otherwise a dict holding the mean of each metric over
        the seeds plus the per-episode dicts under ``"runs"``.
    """
    metric_keys = ("containment_pct", "pop_saved_pct", "total_reward", "episode_steps")
    results = {}
    for agent_name in agent_names:
        factory = AGENT_REGISTRY.get(agent_name)
        # One agent instance is built per name and reused for all its episodes.
        agent = None if not factory else factory()
        per_tier = {}
        results[agent_name] = per_tier
        for tier in tiers:
            if agent is None:
                per_tier[tier] = None
                continue
            runs = [run_episode(agent, tier, seed) for seed in seeds]
            summary = {
                key: sum(run[key] for run in runs) / len(runs)
                for key in metric_keys
            }
            summary["runs"] = runs
            per_tier[tier] = summary
    return results
def print_table(results, tiers, agent_names, seeds):
    """Print one fixed-width results table per tier to stdout.

    Args:
        results: Mapping ``{agent_name: {tier: summary_or_None}}`` as returned
            by ``run_comparison``.
        tiers: Tiers to print, one table each, in the given order.
        agent_names: Agents to list, one row each, in the given order.
        seeds: Seeds that were evaluated; only their count is shown.
    """

    def format_row(agent_name, tier):
        # Fall back to the raw agent name when there is no display label.
        label = AGENT_LABELS.get(agent_name, agent_name)
        stats = results[agent_name].get(tier)
        if stats is None:
            return f"{label:<28} {'[skipped — no model]':>39}"
        # The two *_pct metrics are scaled by 100 for display.
        containment = f"{stats['containment_pct']*100:.0f}%"
        pop_saved = f"{stats['pop_saved_pct']*100:.0f}%"
        reward = f"{stats['total_reward']:+.1f}"
        steps = f"{stats['episode_steps']:.0f}"
        return f"{label:<28} {containment:>12} {pop_saved:>10} {reward:>8} {steps:>7}"

    for tier in tiers:
        n = len(seeds)
        print(f"\n=== EVAL RESULTS — {tier.capitalize()} Tier ({n} seed{'s' if n != 1 else ''}) ===")
        column_titles = f"{'Agent':<28} {'Containment':>12} {'Pop Saved':>10} {'Reward':>8} {'Steps':>7}"
        print(column_titles)
        print("-" * len(column_titles))
        for agent_name in agent_names:
            print(format_row(agent_name, tier))
def main():
    """Parse the command line, run the comparison, print tables and save JSON.

    Options:
        --seeds   one or more integer seeds (default: 42 43 44 45 46)
        --tiers   any of easy / medium / hard (default: medium hard)
        --agents  any key of AGENT_REGISTRY (default: random heuristic)
        --output  path of the JSON results file (default: eval_results.json)
        --quick   force the easy tier with seeds 42 and 43, and keep only the
                  random / heuristic agents out of those requested
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--seeds", type=int, nargs="+", default=[42, 43, 44, 45, 46])
    cli.add_argument("--tiers", nargs="+", choices=["easy", "medium", "hard"], default=["medium", "hard"])
    cli.add_argument("--agents", nargs="+", choices=list(AGENT_REGISTRY), default=["random", "heuristic"])
    cli.add_argument("--output", default="eval_results.json")
    cli.add_argument("--quick", action="store_true", help="Easy tier, 2 seeds only")
    opts = cli.parse_args()

    agents, tiers, seeds = opts.agents, opts.tiers, opts.seeds
    if opts.quick:
        # Quick mode overrides tiers and seeds, and drops any LLM agents.
        tiers = ["easy"]
        seeds = [42, 43]
        agents = [name for name in agents if name in ("random", "heuristic")]

    print(f"Running: agents={agents}, tiers={tiers}, seeds={seeds}")
    results = run_comparison(agents, tiers, seeds)
    print_table(results, tiers, agents, seeds)

    # Create the output directory only when the path actually has one.
    parent_dir = os.path.dirname(opts.output)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Skipped agent/tier pairs stay None; otherwise the aggregate metrics come
    # first and the per-episode "runs" list is written last.
    serializable = {
        agent_name: {
            tier: None if data is None else {
                **{k: v for k, v in data.items() if k != "runs"},
                "runs": data["runs"],
            }
            for tier, data in tier_data.items()
        }
        for agent_name, tier_data in results.items()
    }
    with open(opts.output, "w") as f:
        json.dump(serializable, f, indent=2)
    print(f"\nResults saved -> {opts.output}")


if __name__ == "__main__":
    main()
|