| """ |
| Demo runner — runs heuristic (and optionally trained LLM) on the chosen demo seed, |
| generates GIF(s), and prints a play-by-play narrative. |
| |
| Chosen demo seed: |
| DEMO_SEED = 7 |
| Medium tier, seed 7: wind shift fires around step 70, heuristic loses a |
| populated cell on the south flank while over-committing crews north. |
| This makes the contrast between a reactive heuristic and a planning LLM |
| visible in a single GIF. |
| |
| Usage: |
| python scripts/run_demo.py # heuristic on DEMO_SEED |
| python scripts/run_demo.py --seed 42 |
| python scripts/run_demo.py --agent trained_llm # requires TRAINED_MODEL_PATH env var |
| """ |
|
|
| import argparse |
| import json |
| import os |
| import sys |
|
|
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
| from env import WildfireEnv |
| from env.rendering import render_frame, render_episode_gif |
| from agents.heuristic_agent import HeuristicAgent |
| from agents.random_agent import RandomAgent |
|
|
| DEMO_SEED = 7 |
| TIER = "medium" |
| MAX_STEPS = 150 |
|
|
|
|
| def _load_trained_agent(): |
| model_path = os.environ.get("TRAINED_MODEL_PATH") |
| if not model_path: |
| return None, "[skipped — TRAINED_MODEL_PATH not set]" |
| try: |
| from agents.llm_agent import LLMAgent |
| return LLMAgent(model_path=model_path), model_path |
| except ImportError: |
| return None, "[skipped — agents.llm_agent not found]" |
|
|
|
|
| def run_episode_with_narrative(agent, seed, gif_path): |
| env = WildfireEnv() |
| obs = env.reset(task_id=TIER, seed=seed) |
| frames = [render_frame(env.state(), step=0)] |
| total_reward = 0.0 |
| events_narrative = [] |
| done = False |
|
|
| while not done: |
| action = agent.act(obs) |
| result = env.step(action) |
| total_reward += result.reward |
| step = env.current_step |
| frames.append(render_frame(env.state(), step=step)) |
|
|
| for event in result.info.get("events", []): |
| if any(kw in event for kw in ("WIND SHIFT", "populated", "crew", "casualty", |
| "IGNITION", "suppressed", "firebreak")): |
| events_narrative.append((step, event)) |
|
|
| obs = result.observation |
| done = result.done |
|
|
| os.makedirs(os.path.dirname(gif_path) or ".", exist_ok=True) |
| render_episode_gif(frames, gif_path) |
|
|
| import imageio.v3 as iio |
| png_path = os.path.splitext(gif_path)[0] + ".png" |
| iio.imwrite(png_path, frames[-1], extension=".png") |
|
|
| final = env.state() |
| total_pop = final.get("total_population", 1) or 1 |
| stats = { |
| "steps": env.current_step, |
| "total_reward": round(total_reward, 3), |
| "pop_lost": final.get("population_lost", 0), |
| "pop_saved_pct": round((1 - final.get("population_lost", 0) / total_pop) * 100, 1), |
| "containment_pct": round(final.get("containment_pct", 0.0) * 100, 1), |
| "cells_burned": final.get("cells_burned", 0), |
| "crew_casualty": env._crew_casualty_occurred, |
| } |
| return stats, events_narrative, gif_path, png_path |
|
|
|
|
| def print_narrative(label, stats, events): |
| print(f"\n{'='*60}") |
| print(f" {label}") |
| print(f"{'='*60}") |
| if events: |
| print("Play-by-play:") |
| for step, event in events[:20]: |
| print(f" Step {step:3d}: {event}") |
| else: |
| print(" (no notable events recorded)") |
| print(f"\nFinal stats:") |
| print(f" Steps: {stats['steps']}") |
| print(f" Total reward: {stats['total_reward']:+.3f}") |
| print(f" Pop saved: {stats['pop_saved_pct']:.1f}%") |
| print(f" Containment: {stats['containment_pct']:.1f}%") |
| print(f" Cells burned: {stats['cells_burned']}") |
| print(f" Crew casualty:{stats['crew_casualty']}") |
|
|
|
|
| def main(): |
| parser = argparse.ArgumentParser() |
| parser.add_argument("--seed", type=int, default=DEMO_SEED) |
| parser.add_argument("--agent", choices=["heuristic", "random", "trained_llm"], |
| default="heuristic") |
| args = parser.parse_args() |
|
|
| print(f"Demo: tier={TIER}, seed={args.seed}, agent={args.agent}") |
|
|
| |
| heuristic = HeuristicAgent() |
| h_stats, h_events, h_gif, h_png = run_episode_with_narrative( |
| heuristic, args.seed, "demos/heuristic_demo.gif" |
| ) |
| print_narrative("Heuristic Agent", h_stats, h_events) |
| print(f"\n GIF -> {h_gif}") |
| print(f" PNG -> {h_png}") |
|
|
| |
| if args.agent == "trained_llm": |
| trained, note = _load_trained_agent() |
| if trained is None: |
| print(f"\nTrained LLM: {note}") |
| else: |
| t_stats, t_events, t_gif, t_png = run_episode_with_narrative( |
| trained, args.seed, "demos/trained_demo.gif" |
| ) |
| print_narrative("Trained LLM", t_stats, t_events) |
| print(f"\n GIF -> {t_gif}") |
| print(f" PNG -> {t_png}") |
|
|
| print(f"\n{'='*60}") |
| print(" Side-by-Side Comparison") |
| print(f"{'='*60}") |
| print(f"{'Metric':<20} {'Heuristic':>12} {'Trained LLM':>12}") |
| print("-" * 44) |
| for key, label in [ |
| ("total_reward", "Total Reward"), |
| ("pop_saved_pct", "Pop Saved %"), |
| ("containment_pct", "Containment %"), |
| ("steps", "Steps"), |
| ]: |
| h_val = h_stats[key] |
| t_val = t_stats[key] |
| fmt = "{:+.2f}" if isinstance(h_val, float) else "{}" |
| print(f"{label:<20} {fmt.format(h_val):>12} {fmt.format(t_val):>12}") |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|