"""
Demo runner — runs heuristic (and optionally trained LLM) on the chosen demo seed,
generates GIF(s), and prints a play-by-play narrative.
Chosen demo seed:
DEMO_SEED = 7
Medium tier, seed 7: wind shift fires around step 70, heuristic loses a
populated cell on the south flank while over-committing crews north.
This makes the contrast between a reactive heuristic and a planning LLM
visible in a single GIF.
Usage:
    python scripts/run_demo.py                      # heuristic on DEMO_SEED
    python scripts/run_demo.py --seed 42
    python scripts/run_demo.py --agent trained_llm  # requires TRAINED_MODEL_PATH env var
"""
import argparse
import json
import os
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from env import WildfireEnv
from env.rendering import render_frame, render_episode_gif
from agents.heuristic_agent import HeuristicAgent
from agents.random_agent import RandomAgent
# Default value of the --seed command-line option; the module docstring
# explains why seed 7 was picked for the demo.
DEMO_SEED = 7
# Task tier handed to WildfireEnv.reset() as ``task_id`` for every episode.
TIER = "medium"
# NOTE(review): MAX_STEPS is not referenced by any function visible in this
# file; presumably intended as an episode-length cap — confirm before relying on it.
MAX_STEPS = 150
def _load_trained_agent():
    """Build the trained LLM agent named by the TRAINED_MODEL_PATH env var.

    Returns:
        A ``(agent, note)`` pair. On success ``agent`` is an ``LLMAgent`` and
        ``note`` is the model path. When the variable is unset or empty, or
        when an ``ImportError`` is raised while importing or constructing the
        agent, ``agent`` is ``None`` and ``note`` is a bracketed skip reason.
    """
    configured_path = os.environ.get("TRAINED_MODEL_PATH")
    if configured_path:
        try:
            # Imported lazily so the script still runs without the LLM agent module.
            from agents.llm_agent import LLMAgent  # type: ignore

            # Construction stays inside the try: an ImportError raised here is
            # reported as a skip as well.
            loaded = LLMAgent(model_path=configured_path)
        except ImportError:
            return None, "[skipped — agents.llm_agent not found]"
        return loaded, configured_path
    return None, "[skipped — TRAINED_MODEL_PATH not set]"
def run_episode_with_narrative(agent, seed, gif_path):
    """Play one episode on the TIER task and save a GIF plus a final-frame PNG.

    Args:
        agent: Object whose ``act(observation)`` returns the next action.
        seed: Seed forwarded to ``WildfireEnv.reset``.
        gif_path: Destination of the animated GIF; the PNG is written next to
            it with the same stem.

    Returns:
        ``(stats, events_narrative, gif_path, png_path)`` where ``stats`` is a
        dict of end-of-episode figures and ``events_narrative`` is a list of
        ``(step, event)`` pairs for events containing one of the keywords below.
    """
    notable_keywords = (
        "WIND SHIFT", "populated", "crew", "casualty",
        "IGNITION", "suppressed", "firebreak",
    )

    env = WildfireEnv()
    obs = env.reset(task_id=TIER, seed=seed)

    frames = [render_frame(env.state(), step=0)]
    events_narrative = []
    total_reward = 0.0

    while True:
        outcome = env.step(agent.act(obs))
        total_reward += outcome.reward

        step_no = env.current_step
        frames.append(render_frame(env.state(), step=step_no))

        # Keep only events mentioning at least one notable keyword.
        events_narrative.extend(
            (step_no, message)
            for message in outcome.info.get("events", [])
            if any(keyword in message for keyword in notable_keywords)
        )

        obs = outcome.observation
        if outcome.done:
            break

    # dirname is "" for a bare filename; fall back to the current directory.
    target_dir = os.path.dirname(gif_path) or "."
    os.makedirs(target_dir, exist_ok=True)
    render_episode_gif(frames, gif_path)

    # Imported only once the episode has finished and the GIF is written.
    import imageio.v3 as iio

    stem, _ext = os.path.splitext(gif_path)
    png_path = stem + ".png"
    iio.imwrite(png_path, frames[-1], extension=".png")

    final = env.state()
    # A missing or zero population falls back to 1 to avoid dividing by zero.
    total_pop = final.get("total_population", 1) or 1
    pop_lost = final.get("population_lost", 0)
    saved_fraction = 1 - pop_lost / total_pop

    stats = {
        "steps": env.current_step,
        "total_reward": round(total_reward, 3),
        "pop_lost": pop_lost,
        "pop_saved_pct": round(saved_fraction * 100, 1),
        "containment_pct": round(final.get("containment_pct", 0.0) * 100, 1),
        "cells_burned": final.get("cells_burned", 0),
        "crew_casualty": env._crew_casualty_occurred,
    }
    return stats, events_narrative, gif_path, png_path
def print_narrative(label, stats, events, max_events=20):
    """Print a banner, a play-by-play of notable events, and the final stats.

    Args:
        label: Heading shown inside the banner.
        stats: Mapping with the keys ``steps``, ``total_reward``,
            ``pop_saved_pct``, ``containment_pct``, ``cells_burned`` and
            ``crew_casualty``.
        events: List of ``(step, message)`` pairs in episode order.
        max_events: Maximum number of events to list (default 20). ``None``
            lists every event. When events are cut off, a trailing line
            reports how many were omitted instead of dropping them silently.
    """
    rule = "=" * 60
    print(f"\n{rule}")
    print(f" {label}")
    print(rule)

    if events:
        shown = events if max_events is None else events[:max(max_events, 0)]
        print("Play-by-play:")
        for step, event in shown:
            print(f" Step {step:3d}: {event}")
        hidden = len(events) - len(shown)
        if hidden > 0:
            # Make the truncation visible so late events are not mistaken for absent.
            print(f" ... ({hidden} more events not shown)")
    else:
        print(" (no notable events recorded)")

    print("\nFinal stats:")
    print(f" Steps: {stats['steps']}")
    print(f" Total reward: {stats['total_reward']:+.3f}")
    print(f" Pop saved: {stats['pop_saved_pct']:.1f}%")
    print(f" Containment: {stats['containment_pct']:.1f}%")
    print(f" Cells burned: {stats['cells_burned']}")
    print(f" Crew casualty:{stats['crew_casualty']}")
def _run_and_report(label, agent, seed, gif_path):
    """Run one episode with *agent*, print its narrative and output paths, return its stats."""
    stats, events, gif, png = run_episode_with_narrative(agent, seed, gif_path)
    print_narrative(label, stats, events)
    print(f"\n GIF -> {gif}")
    print(f" PNG -> {png}")
    return stats


def _print_comparison(baseline_stats, other_label, other_stats):
    """Print a two-column table of the heuristic baseline against another agent."""
    print(f"\n{'='*60}")
    print(" Side-by-Side Comparison")
    print(f"{'='*60}")
    print(f"{'Metric':<20} {'Heuristic':>12} {other_label:>12}")
    print("-" * 44)
    for key, label in (
        ("total_reward", "Total Reward"),
        ("pop_saved_pct", "Pop Saved %"),
        ("containment_pct", "Containment %"),
        ("steps", "Steps"),
    ):
        h_val = baseline_stats[key]
        o_val = other_stats[key]
        # Pick the float format if either side is a float, so both columns of
        # a row are always rendered the same way.
        as_float = isinstance(h_val, float) or isinstance(o_val, float)
        fmt = "{:+.2f}" if as_float else "{}"
        print(f"{label:<20} {fmt.format(h_val):>12} {fmt.format(o_val):>12}")


def main(argv=None):
    """Command-line entry point for the demo.

    Always runs the heuristic agent as the baseline. With ``--agent random``
    or ``--agent trained_llm`` a second episode is run on the same seed and a
    side-by-side comparison table is printed.

    Args:
        argv: Optional list of argument strings; ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=DEMO_SEED)
    parser.add_argument("--agent", choices=["heuristic", "random", "trained_llm"],
                        default="heuristic")
    args = parser.parse_args(argv)

    print(f"Demo: tier={TIER}, seed={args.seed}, agent={args.agent}")

    # Always run heuristic as baseline
    h_stats = _run_and_report(
        "Heuristic Agent", HeuristicAgent(), args.seed, "demos/heuristic_demo.gif"
    )

    if args.agent == "random":
        # Previously this choice was accepted but silently ignored.
        # NOTE(review): assumes RandomAgent() needs no required arguments,
        # like HeuristicAgent() — confirm against agents.random_agent.
        r_stats = _run_and_report(
            "Random Agent", RandomAgent(), args.seed, "demos/random_demo.gif"
        )
        _print_comparison(h_stats, "Random", r_stats)
    elif args.agent == "trained_llm":
        # Optionally run trained LLM
        trained, note = _load_trained_agent()
        if trained is None:
            print(f"\nTrained LLM: {note}")
            return
        t_stats = _run_and_report(
            "Trained LLM", trained, args.seed, "demos/trained_demo.gif"
        )
        _print_comparison(h_stats, "Trained LLM", t_stats)
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|