| """ |
| Find demo seeds on medium tier where the heuristic struggles interestingly. |
| |
| Filters for seeds where: |
| (a) the first wind shift fires between steps 60 and 90 (inclusive) |
| (b) heuristic loses at least one populated cell |
| (c) heuristic total_reward is between -4.0 and +2.0 |
| |
| Usage: |
| python scripts/find_demo_seed.py |
| """ |
|
|
| import json |
| import os |
| import sys |
|
|
# Prepend the parent of this script's directory to sys.path (presumably the
# repository root) so the `env` and `agents` imports below can be resolved
# even when the script is launched from another working directory.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
| from env import WildfireEnv |
| from agents.heuristic_agent import HeuristicAgent |
|
|
# Task tier passed as `task_id` to WildfireEnv.reset() for every scanned episode.
TIER = "medium"
# Exclusive upper bound of the scanned seed range: main() scans 0 .. MAX_SEED - 1.
MAX_SEED = 500
|
|
|
|
def scan_seed(seed):
    """Play one full episode with the heuristic agent and summarize the outcome.

    A fresh ``WildfireEnv`` is reset with ``task_id=TIER`` and the given seed,
    then stepped with actions from a fresh ``HeuristicAgent`` until the step
    result reports ``done``.

    Args:
        seed: Seed forwarded to ``env.reset``.

    Returns:
        A dict with the keys ``seed``, ``total_reward`` (sum of step rewards,
        rounded to 3 places), ``pop_lost``, ``pop_saved_pct`` (fraction of the
        total population not lost, rounded to 3 places), ``wind_shift_step``
        (``env.current_step`` at the first step whose events mention
        "WIND SHIFT", or ``None`` if none did), ``steps`` (``env.current_step``
        at episode end) and ``containment_pct`` (rounded to 3 places).
    """
    environment = WildfireEnv()
    policy = HeuristicAgent()
    observation = environment.reset(task_id=TIER, seed=seed)

    reward_sum = 0.0
    first_shift = None

    # The episode always runs at least one step; stop once the env says done.
    while True:
        outcome = environment.step(policy.act(observation))
        reward_sum += outcome.reward

        # Only the first wind shift of the episode is recorded.
        shift_events = [
            note for note in outcome.info.get("events", [])
            if "WIND SHIFT" in note
        ]
        if shift_events and first_shift is None:
            first_shift = environment.current_step

        observation = outcome.observation
        if outcome.done:
            break

    snapshot = environment.state()
    lost = snapshot.get("population_lost", 0)
    # Guard against a missing or zero total so the division below is safe.
    population = snapshot.get("total_population", 1) or 1

    summary = {"seed": seed}
    summary["total_reward"] = round(reward_sum, 3)
    summary["pop_lost"] = lost
    summary["pop_saved_pct"] = round(1.0 - lost / population, 3)
    summary["wind_shift_step"] = first_shift
    summary["steps"] = environment.current_step
    summary["containment_pct"] = round(snapshot.get("containment_pct", 0.0), 3)
    return summary
|
|
|
|
def main():
    """Scan seeds 0..MAX_SEED-1 and save up to five demo candidates as JSON.

    A seed qualifies when its ``scan_seed`` summary passes all three filters:
    ``wind_shift_step`` is set and lies in [60, 90], ``pop_lost`` is at least
    1, and ``total_reward`` lies in [-4.0, 2.0]. Qualifying summaries are
    ranked by total reward (highest first); the top five are printed and
    written to ``demos/candidate_seeds.json`` relative to the current working
    directory.

    If no seed qualifies, a fixed list of fallback seeds is scanned instead
    and their unfiltered summaries, ranked the same way, are written to the
    same file. The file is written exactly once per run, so an empty list is
    never reported as "saved" before the fallback replaces it.
    """
    output_dir = "demos"
    output_path = "demos/candidate_seeds.json"
    fallback_seeds = [42, 7, 13, 99, 123]

    def by_reward(info):
        return info["total_reward"]

    candidates = []
    print(f"Scanning seeds 0-{MAX_SEED - 1} on {TIER} tier...")

    for seed in range(MAX_SEED):
        if seed % 50 == 0:
            print(f" seed {seed}...")
        info = scan_seed(seed)

        wind_ok = (info["wind_shift_step"] is not None
                   and 60 <= info["wind_shift_step"] <= 90)
        pop_ok = info["pop_lost"] >= 1
        reward_ok = -4.0 <= info["total_reward"] <= 2.0

        if wind_ok and pop_ok and reward_ok:
            candidates.append(info)

    candidates.sort(key=by_reward, reverse=True)
    top5 = candidates[:5]

    for c in top5:
        ws = c["wind_shift_step"]
        print(f" seed={c['seed']:3d} reward={c['total_reward']:+.2f} "
              f"pop_lost={c['pop_lost']} wind_shift=step {ws} "
              f"steps={c['steps']}")

    if top5:
        selected = top5
        saved_message = (
            f"\nTop {len(top5)} candidates saved -> demos/candidate_seeds.json"
        )
    else:
        # No filter is relaxed here: the fallback is a fixed seed list whose
        # summaries are saved unfiltered, so the message says exactly that.
        print("No candidates matched all 3 filters — "
              f"falling back to fixed seeds {fallback_seeds}...")
        selected = sorted(
            (scan_seed(s) for s in fallback_seeds),
            key=by_reward,
            reverse=True,
        )
        saved_message = "Saved fallback candidates."

    # Decide what to save first, then write once.
    os.makedirs(output_dir, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(selected, f, indent=2)
    print(saved_message)
|
|
|
|
# Run the seed scan only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|