Eshit commited on
Commit
ad92ece
·
1 Parent(s): 363abf3

Improve wildfire metrics and training assets

Browse files
env/models.py CHANGED
@@ -288,7 +288,13 @@ class ClusterStats(BaseModel):
288
  cells_saved: int = 0
289
  population_threatened: int = 0
290
  population_lost: int = 0
 
291
  containment_pct: float = Field(ge=0.0, le=100.0, default=0.0)
 
 
 
 
 
292
  current_step: int = 0
293
  max_steps: int = 100
294
  firebreaks_built: int = 0
@@ -342,6 +348,7 @@ class TierConfig(BaseModel):
342
  crew_loss_step: Optional[int] = None
343
  crew_loss_id: Optional[str] = None
344
  tanker_cooldown: int = 5
 
345
  wind_speed_init: float = 10.0
346
  wind_dir_init: float = 0.0
347
  humidity_init: float = 40.0
@@ -367,11 +374,12 @@ TIER_EASY = TierConfig(
367
  firebreak_budget=15,
368
  recon_budget=0,
369
  episode_length=80,
370
- num_ignition_points=1,
371
  enable_smoke_occlusion=False,
372
  enable_sensor_noise=False,
373
  enable_fog_of_war=False,
374
  enable_wind_shifts=False,
 
375
  wind_speed_init=10.0,
376
  wind_dir_init=0.0,
377
  humidity_init=40.0,
@@ -391,11 +399,12 @@ TIER_MEDIUM = TierConfig(
391
  firebreak_budget=20,
392
  recon_budget=1,
393
  episode_length=150,
394
- num_ignition_points=2,
395
  enable_smoke_occlusion=True,
396
  enable_sensor_noise=True,
397
  enable_fog_of_war=False,
398
  enable_wind_shifts=True,
 
399
  wind_speed_init=15.0,
400
  wind_dir_init=45.0,
401
  humidity_init=35.0,
@@ -417,6 +426,7 @@ TIER_HARD = TierConfig(
417
  episode_length=300,
418
  num_ignition_points=3,
419
  staggered_ignition_step=30,
 
420
  enable_smoke_occlusion=True,
421
  enable_sensor_noise=True,
422
  enable_fog_of_war=True,
 
288
  cells_saved: int = 0
289
  population_threatened: int = 0
290
  population_lost: int = 0
291
+ total_population: int = Field(ge=0, default=0, description="Initial population (for UI % civ safe)")
292
  containment_pct: float = Field(ge=0.0, le=100.0, default=0.0)
293
+ # Meaningful progress metrics shown to agent and display
294
+ area_saved_pct: float = Field(ge=0.0, le=100.0, default=100.0,
295
+ description="Percentage of burnable land not yet burned")
296
+ civilians_saved_pct: float = Field(ge=0.0, le=100.0, default=100.0,
297
+ description="Percentage of civilians in unburned zones")
298
  current_step: int = 0
299
  max_steps: int = 100
300
  firebreaks_built: int = 0
 
348
  crew_loss_step: Optional[int] = None
349
  crew_loss_id: Optional[str] = None
350
  tanker_cooldown: int = 5
351
+ min_active_steps: int = 5 # episode cannot end via fire-out before this step
352
  wind_speed_init: float = 10.0
353
  wind_dir_init: float = 0.0
354
  humidity_init: float = 40.0
 
374
  firebreak_budget=15,
375
  recon_budget=0,
376
  episode_length=80,
377
+ num_ignition_points=2,
378
  enable_smoke_occlusion=False,
379
  enable_sensor_noise=False,
380
  enable_fog_of_war=False,
381
  enable_wind_shifts=False,
382
+ min_active_steps=25,
383
  wind_speed_init=10.0,
384
  wind_dir_init=0.0,
385
  humidity_init=40.0,
 
399
  firebreak_budget=20,
400
  recon_budget=1,
401
  episode_length=150,
402
+ num_ignition_points=3,
403
  enable_smoke_occlusion=True,
404
  enable_sensor_noise=True,
405
  enable_fog_of_war=False,
406
  enable_wind_shifts=True,
407
+ min_active_steps=45,
408
  wind_speed_init=15.0,
409
  wind_dir_init=45.0,
410
  humidity_init=35.0,
 
426
  episode_length=300,
427
  num_ignition_points=3,
428
  staggered_ignition_step=30,
429
+ min_active_steps=80,
430
  enable_smoke_occlusion=True,
431
  enable_sensor_noise=True,
432
  enable_fog_of_war=True,
env/serialization.py CHANGED
@@ -12,8 +12,14 @@ if TYPE_CHECKING:
12
  from .models import FireState, IntensityBin
13
 
14
 
15
- def serialize_observation(obs: "Observation", step_num: int, max_steps: int) -> str:
16
- situation = _format_situation(obs)
 
 
 
 
 
 
17
  grid_summary = _summarize_grid_regions(obs.grid)
18
  resources = _format_resources(obs.resources)
19
  events = _format_events(obs.recent_events)
@@ -29,8 +35,9 @@ def serialize_observation(obs: "Observation", step_num: int, max_steps: int) ->
29
  parts.append(obs._briefing_reminder)
30
  parts.append("")
31
 
 
32
  parts += [
33
- f"=== WILDFIRE INCIDENT COMMAND — STEP {step_num}/{max_steps} ===",
34
  "",
35
  "SITUATION:",
36
  situation,
@@ -52,12 +59,14 @@ def serialize_observation(obs: "Observation", step_num: int, max_steps: int) ->
52
 
53
  # ── Situation block ──────────────────────────────────────────
54
 
55
- def _format_situation(obs: "Observation") -> str:
56
  stats = obs.stats
57
  w = obs.weather
58
 
59
  burning = stats.cells_burning
60
- containment = round(stats.containment_pct, 1)
 
 
61
  pop_at_risk = stats.population_threatened
62
 
63
  wind_dir = _deg_to_compass(w.wind_direction_deg)
@@ -65,8 +74,19 @@ def _format_situation(obs: "Observation") -> str:
65
 
66
  last_event = obs.recent_events[-1] if obs.recent_events else "None"
67
 
 
 
 
 
 
 
 
 
 
68
  lines = [
69
- f"- Fire active on {burning} cells. Containment: {containment}%. Population at risk: {pop_at_risk} zones.",
 
 
70
  f"- Wind: {w.wind_speed_kmh:.0f} km/h {wind_dir} (±5 km/h noise). Humidity: {w.humidity_pct:.0f}%. Rain: {rain}.",
71
  f"- Last event: {last_event}",
72
  ]
 
12
  from .models import FireState, IntensityBin
13
 
14
 
15
+ def serialize_observation(
16
+ obs: "Observation",
17
+ step_num: int,
18
+ max_steps: int,
19
+ tier: str = "",
20
+ prev_cells_burning: int = 0,
21
+ ) -> str:
22
+ situation = _format_situation(obs, prev_cells_burning)
23
  grid_summary = _summarize_grid_regions(obs.grid)
24
  resources = _format_resources(obs.resources)
25
  events = _format_events(obs.recent_events)
 
35
  parts.append(obs._briefing_reminder)
36
  parts.append("")
37
 
38
+ tier_str = f" [{tier.upper()}]" if tier else ""
39
  parts += [
40
+ f"=== WILDFIRE INCIDENT COMMAND{tier_str} — STEP {step_num}/{max_steps} ===",
41
  "",
42
  "SITUATION:",
43
  situation,
 
59
 
60
  # ── Situation block ──────────────────────────────────────────
61
 
62
+ def _format_situation(obs: "Observation", prev_cells_burning: int = 0) -> str:
63
  stats = obs.stats
64
  w = obs.weather
65
 
66
  burning = stats.cells_burning
67
+ land_saved = round(stats.area_saved_pct, 1)
68
+ civ_safe = round(stats.civilians_saved_pct, 1)
69
+ cells_burned = stats.cells_burned
70
  pop_at_risk = stats.population_threatened
71
 
72
  wind_dir = _deg_to_compass(w.wind_direction_deg)
 
74
 
75
  last_event = obs.recent_events[-1] if obs.recent_events else "None"
76
 
77
+ # Spread delta — positive means fire is growing, negative means shrinking
78
+ delta = burning - prev_cells_burning
79
+ if delta > 0:
80
+ spread_str = f" (+{delta} spreading)"
81
+ elif delta < 0:
82
+ spread_str = f" ({delta} shrinking)"
83
+ else:
84
+ spread_str = " (stable)"
85
+
86
  lines = [
87
+ f"- Fire active on {burning} cells{spread_str}. Land saved: {land_saved}% of burnable area "
88
+ f"({cells_burned} cells burned out). Civilians safe: {civ_safe}%. "
89
+ f"Population at risk: {pop_at_risk} zones.",
90
  f"- Wind: {w.wind_speed_kmh:.0f} km/h {wind_dir} (±5 km/h noise). Humidity: {w.humidity_pct:.0f}%. Rain: {rain}.",
91
  f"- Last event: {last_event}",
92
  ]
env/wildfire_env.py CHANGED
@@ -212,6 +212,16 @@ class WildfireEnv:
212
 
213
  self.current_step += 1
214
 
 
 
 
 
 
 
 
 
 
 
215
  # ── Step 9: Compute reward ──
216
  legacy_reward = self.reward_calc.compute_reward(self.grid, self.resources, self.current_step)
217
 
@@ -324,19 +334,39 @@ class WildfireEnv:
324
  }
325
 
326
  def _is_redundant(self, action: Action) -> bool:
327
- """True if action repeats the same type + target coords as the previous action."""
 
 
 
 
 
 
 
328
  if self._prev_action is None:
329
  return False
330
  prev = self._prev_action
331
  if action.action_type != prev.action_type:
332
  return False
333
- return action.target_row == prev.target_row and action.target_col == prev.target_col
 
 
 
 
 
 
 
 
334
 
335
  def _ignite_initial_fires(self) -> None:
336
  """Place initial fire ignition points based on tier config.
337
 
338
  Ignition candidates are shifted away from populated cells to ensure
339
  a minimum survivable distance, reducing unwinnable-scenario variance.
 
 
 
 
 
340
  """
341
  rows, cols = self.config.grid_rows, self.config.grid_cols
342
 
@@ -344,19 +374,25 @@ class WildfireEnv:
344
  min_pop_dist = {"easy": 4, "medium": 6, "hard": 7}.get(self.config.tier_name, 5)
345
 
346
  if self.config.tier_name == "easy":
347
- r, c = self._find_ignition_candidate(rows // 2, cols // 2, min_pop_dist)
348
- self.grid.ignite_cell(r, c, intensity=0.3)
 
 
 
349
  elif self.config.tier_name == "medium":
350
- r1, c1 = self._find_ignition_candidate(rows // 3, cols // 3, min_pop_dist)
351
- self.grid.ignite_cell(r1, c1, intensity=0.3)
 
352
  r2, c2 = self._find_ignition_candidate(2 * rows // 3, 2 * cols // 3, min_pop_dist)
353
- self.grid.ignite_cell(r2, c2, intensity=0.3)
 
 
354
  else:
355
- # Two initial points (third comes later via staggered ignition)
356
  r1, c1 = self._find_ignition_candidate(rows // 4, cols // 4, min_pop_dist)
357
- self.grid.ignite_cell(r1, c1, intensity=0.3)
358
  r2, c2 = self._find_ignition_candidate(rows // 2, 3 * cols // 4, min_pop_dist)
359
- self.grid.ignite_cell(r2, c2, intensity=0.3)
360
 
361
  def _find_ignition_candidate(self, target_r: int, target_c: int, min_pop_dist: int) -> tuple[int, int]:
362
  """Return the nearest valid ignition cell to (target_r, target_c) that is at
@@ -482,11 +518,17 @@ class WildfireEnv:
482
  # Fire fully contained (no burning cells)
483
  burning = self.grid.count_by_state(FireState.BURNING)
484
  ember = self.grid.count_by_state(FireState.EMBER)
485
- if burning == 0 and ember == 0 and self.current_step > 1:
486
- # Don't end on step 0-1 (fire just started)
487
- if not (self.config.staggered_ignition_step
 
 
 
 
 
488
  and self.current_step < self.config.staggered_ignition_step):
489
- return True
 
490
 
491
  # All populated zones burned (catastrophic failure)
492
  total_pop = self.grid.get_total_population()
@@ -514,13 +556,29 @@ class WildfireEnv:
514
  resource_state = self.resources.get_resource_state()
515
 
516
  # Stats
 
 
 
 
 
 
 
 
 
 
 
 
 
517
  stats = ClusterStats(
518
- cells_burned=self.grid.get_burned_count(),
519
  cells_burning=self.grid.count_by_state(FireState.BURNING),
520
- cells_saved=self.grid.get_total_burnable() - self.grid.get_burned_count() - self.grid.count_by_state(FireState.BURNING),
521
  population_threatened=self._count_threatened_population(),
522
- population_lost=self.grid.get_population_lost(),
 
523
  containment_pct=self._compute_containment_pct(),
 
 
524
  current_step=self.current_step,
525
  max_steps=self.config.episode_length,
526
  firebreaks_built=self.resources.total_firebreaks_built,
 
212
 
213
  self.current_step += 1
214
 
215
+ # Log a hold-message when fire is extinguished before min_active_steps so
216
+ # agents (and the LLM) understand the episode must continue for monitoring.
217
+ burning_now = (self.grid.count_by_state(FireState.BURNING)
218
+ + self.grid.count_by_state(FireState.EMBER))
219
+ if burning_now == 0 and self.current_step < self.config.min_active_steps:
220
+ step_events.append(
221
+ f"All fires contained. Holding perimeter until step "
222
+ f"{self.config.min_active_steps} (min_active_steps)."
223
+ )
224
+
225
  # ── Step 9: Compute reward ──
226
  legacy_reward = self.reward_calc.compute_reward(self.grid, self.resources, self.current_step)
227
 
 
334
  }
335
 
336
  def _is_redundant(self, action: Action) -> bool:
337
+ """True if action is a meaningless repeat of the previous action.
338
+
339
+ Actions that use target coordinates (DROP_RETARDANT, DEPLOY_CREW, RECON_FLIGHT)
340
+ are redundant when the type + target cell match. Directional actions (MOVE_CREW,
341
+ BUILD_FIREBREAK) require the same crew_id AND direction to be redundant — two
342
+ consecutive MOVE_CREW steps by different crews, or in different directions, are
343
+ valid patrol behaviour and must not be penalised.
344
+ """
345
  if self._prev_action is None:
346
  return False
347
  prev = self._prev_action
348
  if action.action_type != prev.action_type:
349
  return False
350
+ # Coordinate-targeted actions: redundant when same cell is targeted again
351
+ if action.target_row is not None or prev.target_row is not None:
352
+ return (action.target_row == prev.target_row
353
+ and action.target_col == prev.target_col)
354
+ # Crew directional actions: redundant only when same crew moves same direction
355
+ if action.crew_id is not None:
356
+ return (action.crew_id == prev.crew_id
357
+ and action.direction == prev.direction)
358
+ return False
359
 
360
  def _ignite_initial_fires(self) -> None:
361
  """Place initial fire ignition points based on tier config.
362
 
363
  Ignition candidates are shifted away from populated cells to ensure
364
  a minimum survivable distance, reducing unwinnable-scenario variance.
365
+
366
+ Intensity is set high enough (0.65) that a single tanker drop (-0.4)
367
+ leaves residual fire (0.25) so the episode cannot be solved in 1-2
368
+ steps. The fire must spread, be actively managed, and burn for at
369
+ least min_active_steps before the episode can end.
370
  """
371
  rows, cols = self.config.grid_rows, self.config.grid_cols
372
 
 
374
  min_pop_dist = {"easy": 4, "medium": 6, "hard": 7}.get(self.config.tier_name, 5)
375
 
376
  if self.config.tier_name == "easy":
377
+ # Two ignition points spread across the grid so crews must split
378
+ r1, c1 = self._find_ignition_candidate(rows // 2, cols // 3, min_pop_dist)
379
+ self.grid.ignite_cell(r1, c1, intensity=0.65)
380
+ r2, c2 = self._find_ignition_candidate(rows // 2, 2 * cols // 3, min_pop_dist)
381
+ self.grid.ignite_cell(r2, c2, intensity=0.65)
382
  elif self.config.tier_name == "medium":
383
+ # Three ignition points: forces genuine multi-front management
384
+ r1, c1 = self._find_ignition_candidate(rows // 4, cols // 3, min_pop_dist)
385
+ self.grid.ignite_cell(r1, c1, intensity=0.65)
386
  r2, c2 = self._find_ignition_candidate(2 * rows // 3, 2 * cols // 3, min_pop_dist)
387
+ self.grid.ignite_cell(r2, c2, intensity=0.65)
388
+ r3, c3 = self._find_ignition_candidate(rows // 2, cols // 2, min_pop_dist)
389
+ self.grid.ignite_cell(r3, c3, intensity=0.65)
390
  else:
391
+ # Two initial points (third comes later via staggered ignition at step 30)
392
  r1, c1 = self._find_ignition_candidate(rows // 4, cols // 4, min_pop_dist)
393
+ self.grid.ignite_cell(r1, c1, intensity=0.65)
394
  r2, c2 = self._find_ignition_candidate(rows // 2, 3 * cols // 4, min_pop_dist)
395
+ self.grid.ignite_cell(r2, c2, intensity=0.65)
396
 
397
  def _find_ignition_candidate(self, target_r: int, target_c: int, min_pop_dist: int) -> tuple[int, int]:
398
  """Return the nearest valid ignition cell to (target_r, target_c) that is at
 
518
  # Fire fully contained (no burning cells)
519
  burning = self.grid.count_by_state(FireState.BURNING)
520
  ember = self.grid.count_by_state(FireState.EMBER)
521
+ if burning == 0 and ember == 0:
522
+ # Enforce minimum active steps prevents trivial 1-2 step episodes
523
+ # where a single tanker drop or natural burnout ends the episode
524
+ # before the agent has taken any meaningful sequence of actions.
525
+ if self.current_step < self.config.min_active_steps:
526
+ return False
527
+ # Don't terminate before staggered ignition fires (hard tier)
528
+ if (self.config.staggered_ignition_step
529
  and self.current_step < self.config.staggered_ignition_step):
530
+ return False
531
+ return True
532
 
533
  # All populated zones burned (catastrophic failure)
534
  total_pop = self.grid.get_total_population()
 
556
  resource_state = self.resources.get_resource_state()
557
 
558
  # Stats
559
+ total_burnable = self.grid.get_total_burnable()
560
+ cells_burned = self.grid.get_burned_count()
561
+ total_pop = self.grid.get_total_population()
562
+ pop_lost = self.grid.get_population_lost()
563
+
564
+ area_saved_pct = round(
565
+ 100.0 * (total_burnable - cells_burned) / total_burnable, 1
566
+ ) if total_burnable > 0 else 100.0
567
+
568
+ civilians_saved_pct = round(
569
+ 100.0 * (total_pop - pop_lost) / total_pop, 1
570
+ ) if total_pop > 0 else 100.0
571
+
572
  stats = ClusterStats(
573
+ cells_burned=cells_burned,
574
  cells_burning=self.grid.count_by_state(FireState.BURNING),
575
+ cells_saved=total_burnable - cells_burned - self.grid.count_by_state(FireState.BURNING),
576
  population_threatened=self._count_threatened_population(),
577
+ population_lost=pop_lost,
578
+ total_population=total_pop,
579
  containment_pct=self._compute_containment_pct(),
580
+ area_saved_pct=area_saved_pct,
581
+ civilians_saved_pct=civilians_saved_pct,
582
  current_step=self.current_step,
583
  max_steps=self.config.episode_length,
584
  firebreaks_built=self.resources.total_firebreaks_built,
frontend/app.js CHANGED
@@ -11,6 +11,69 @@
11
 
12
  "use strict";
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  // ── Simulation state ──────────────────────────────────────────────────────────
15
  const sim = {
16
  obs: null, // current Observation (agent's view)
@@ -161,17 +224,28 @@ function renderCanvas(obs, groundTruth = null) {
161
  }
162
 
163
  // ── Stats panel ───────────────────────────────────────────────────────────────
164
- function updateStats(stats, cumulativeReward, lastStepReward) {
165
- if (!stats) return;
166
-
167
- const cur = stats.current_step ?? 0;
168
- const max = stats.max_steps ?? 1;
169
-
170
- setText("stat-step", `${cur} / ${max}`);
171
- setText("stat-containment-val", `${(stats.containment_pct ?? 0).toFixed(1)}%`);
172
- setText("stat-burning-val", stats.cells_burning ?? 0);
173
- setText("stat-pop-threat-val", stats.population_threatened ?? 0);
174
- setText("stat-pop-lost-val", stats.population_lost ?? 0);
 
 
 
 
 
 
 
 
 
 
 
175
 
176
  // Cumulative reward
177
  setText("reward-total", cumulativeReward.toFixed(3));
@@ -298,31 +372,53 @@ function updateActionLog(action) {
298
  }
299
 
300
  // ── Terminal overlay ──────────────────────────────────────────────────────────
301
- function showTerminal(obs) {
302
  const overlay = document.getElementById("terminal-overlay");
303
  if (!overlay) return;
304
 
305
- const stats = obs?.stats ?? {};
306
- const popLost = stats.population_lost ?? 0;
307
- const containment = stats.containment_pct ?? 0;
308
-
309
  const card = document.getElementById("terminal-card");
 
 
 
310
  const title = card.querySelector("h2");
311
 
312
- if (popLost === 0) {
313
- title.textContent = "✅ FIRE CONTAINED";
314
  title.className = "win";
315
  } else {
316
  title.textContent = "⚠ EPISODE ENDED";
317
  title.className = "loss";
318
  }
319
 
320
- setText("terminal-containment", `${containment.toFixed(1)}%`);
321
- setText("terminal-pop-lost", popLost);
322
- setText("terminal-reward", sim.cumulativeReward.toFixed(3));
323
- setText("terminal-step", stats.current_step ?? "—");
 
 
 
 
324
 
325
  overlay.classList.add("show");
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
326
  }
327
 
328
  function hideTerminal() {
@@ -356,7 +452,7 @@ async function apiGet(path) {
356
  function applyObservation(obs) {
357
  sim.obs = obs;
358
  renderCanvas(obs, sim.groundTruthData);
359
- updateStats(obs.stats, sim.cumulativeReward, sim.lastStepReward);
360
  updateResources(obs.resources);
361
  updateWeather(obs.weather);
362
  updateEvents(obs.recent_events ?? []);
@@ -417,7 +513,7 @@ async function doAutoStep() {
417
 
418
  if (snap.done) {
419
  stopPlay();
420
- showTerminal(snap.observation);
421
  break;
422
  }
423
  }
 
11
 
12
  "use strict";
13
 
14
+ // ── API field helpers (snake_case from Python; tolerate camelCase if ever used) ─
15
+ function pickStat(obj, ...keys) {
16
+ if (!obj) return undefined;
17
+ for (const k of keys) {
18
+ if (Object.prototype.hasOwnProperty.call(obj, k) && obj[k] != null) {
19
+ return obj[k];
20
+ }
21
+ }
22
+ return undefined;
23
+ }
24
+
25
+ /**
26
+ * Build display-ready episode metrics from the latest observation.
27
+ * Falls back to grid-visible cells for land % only when server omits area_saved_pct.
28
+ */
29
+ function normalizeEpisodeStats(obs) {
30
+ const st = obs?.stats ?? {};
31
+ const cellsBurned = pickStat(st, "cells_burned", "cellsBurned") ?? 0;
32
+ const popLost = pickStat(st, "population_lost", "populationLost") ?? 0;
33
+ const totalPop = pickStat(st, "total_population", "totalPopulation") ?? 0;
34
+
35
+ let areaSaved = pickStat(st, "area_saved_pct", "areaSavedPct");
36
+ let civSafe = pickStat(st, "civilians_saved_pct", "civiliansSavedPct");
37
+
38
+ if (areaSaved == null && obs?.grid?.length) {
39
+ let burnable = 0;
40
+ let burnedVis = 0;
41
+ for (const row of obs.grid) {
42
+ for (const cell of row) {
43
+ const f = cell.fuel_type;
44
+ if (!f || f === "water" || f === "road") continue;
45
+ if (cell.fire_state === "unknown") continue;
46
+ burnable++;
47
+ if (cell.fire_state === "burned_out") burnedVis++;
48
+ }
49
+ }
50
+ if (burnable > 0) {
51
+ areaSaved = Math.round(1000 * (burnable - burnedVis) / burnable) / 10;
52
+ }
53
+ }
54
+
55
+ if (civSafe == null && totalPop > 0) {
56
+ civSafe = Math.round(1000 * (totalPop - popLost) / totalPop) / 10;
57
+ } else if (civSafe == null && popLost === 0) {
58
+ civSafe = 100.0;
59
+ }
60
+
61
+ const containment = pickStat(st, "containment_pct", "containmentPct");
62
+ if (areaSaved == null && containment != null) {
63
+ areaSaved = containment;
64
+ }
65
+
66
+ return {
67
+ areaSaved,
68
+ civSafe,
69
+ cellsBurned,
70
+ popLost,
71
+ totalPop,
72
+ currentStep: pickStat(st, "current_step", "currentStep"),
73
+ raw: st,
74
+ };
75
+ }
76
+
77
  // ── Simulation state ──────────────────────────────────────────────────────────
78
  const sim = {
79
  obs: null, // current Observation (agent's view)
 
224
  }
225
 
226
  // ── Stats panel ───────────────────────────────────────────────────────────────
227
+ function updateStats(obs, cumulativeReward, lastStepReward) {
228
+ if (!obs?.stats) return;
229
+ const stats = obs.stats;
230
+
231
+ const cur = pickStat(stats, "current_step", "currentStep") ?? 0;
232
+ const max = pickStat(stats, "max_steps", "maxSteps") ?? 1;
233
+
234
+ setText("stat-step", `${cur} / ${max}`);
235
+
236
+ const n = normalizeEpisodeStats(obs);
237
+ setText(
238
+ "stat-land-saved-val",
239
+ n.areaSaved != null ? `${Number(n.areaSaved).toFixed(1)}%` : "—"
240
+ );
241
+ setText(
242
+ "stat-civilians-safe-val",
243
+ n.civSafe != null ? `${Number(n.civSafe).toFixed(1)}%` : "—"
244
+ );
245
+ setText("stat-cells-burned-val", n.cellsBurned);
246
+ setText("stat-burning-val", pickStat(stats, "cells_burning", "cellsBurning") ?? 0);
247
+ setText("stat-pop-threat-val", pickStat(stats, "population_threatened", "populationThreatened") ?? 0);
248
+ setText("stat-pop-lost-val", n.popLost);
249
 
250
  // Cumulative reward
251
  setText("reward-total", cumulativeReward.toFixed(3));
 
372
  }
373
 
374
  // ── Terminal overlay ──────────────────────────────────────────────────────────
375
+ async function showTerminal() {
376
  const overlay = document.getElementById("terminal-overlay");
377
  if (!overlay) return;
378
 
 
 
 
 
379
  const card = document.getElementById("terminal-card");
380
+ if (!card) return;
381
+
382
+ const n = normalizeEpisodeStats(sim.obs);
383
  const title = card.querySelector("h2");
384
 
385
+ if (n.popLost === 0) {
386
+ title.textContent = "✅ EPISODE COMPLETE";
387
  title.className = "win";
388
  } else {
389
  title.textContent = "⚠ EPISODE ENDED";
390
  title.className = "loss";
391
  }
392
 
393
+ const landStr = n.areaSaved != null ? `${Number(n.areaSaved).toFixed(1)}%` : "—";
394
+ const civStr = n.civSafe != null ? `${Number(n.civSafe).toFixed(1)}%` : "";
395
+ setText("terminal-land-saved", landStr);
396
+ setText("terminal-civilians-safe", civStr);
397
+ setText("terminal-cells-burned", String(n.cellsBurned));
398
+ setText("terminal-pop-lost", n.popLost);
399
+ setText("terminal-reward", sim.cumulativeReward.toFixed(3));
400
+ setText("terminal-step", n.currentStep ?? "—");
401
 
402
  overlay.classList.add("show");
403
+
404
+ // Authoritative end-game numbers (ground truth — fixes blank UI if observation JSON differed)
405
+ try {
406
+ const st = await apiGet("/state");
407
+ if (st.error) return;
408
+ const tb = st.total_burnable ?? 0;
409
+ const burned = st.cells_burned ?? 0;
410
+ const landPct = tb > 0 ? Math.round(1000 * (tb - burned) / tb) / 10 : 100;
411
+ const tp = st.total_population ?? 0;
412
+ const lost = st.population_lost ?? 0;
413
+ const civPct = tp > 0 ? Math.round(1000 * (tp - lost) / tp) / 10 : 100;
414
+ setText("terminal-land-saved", `${landPct}%`);
415
+ setText("terminal-civilians-safe", `${civPct}%`);
416
+ setText("terminal-cells-burned", String(burned));
417
+ setText("terminal-pop-lost", String(lost));
418
+ setText("terminal-step", st.current_step ?? "—");
419
+ } catch (e) {
420
+ console.warn("Could not refresh end-game stats from /state", e);
421
+ }
422
  }
423
 
424
  function hideTerminal() {
 
452
  function applyObservation(obs) {
453
  sim.obs = obs;
454
  renderCanvas(obs, sim.groundTruthData);
455
+ updateStats(obs, sim.cumulativeReward, sim.lastStepReward);
456
  updateResources(obs.resources);
457
  updateWeather(obs.weather);
458
  updateEvents(obs.recent_events ?? []);
 
513
 
514
  if (snap.done) {
515
  stopPlay();
516
+ await showTerminal();
517
  break;
518
  }
519
  }
frontend/index.html CHANGED
@@ -83,8 +83,16 @@
83
  <div id="terminal-card">
84
  <h2 class="win">✅ FIRE CONTAINED</h2>
85
  <div class="stat-row">
86
- <span>Containment</span>
87
- <span id="terminal-containment">—</span>
 
 
 
 
 
 
 
 
88
  </div>
89
  <div class="stat-row">
90
  <span>Population lost</span>
@@ -104,6 +112,10 @@
104
  </div>
105
  </div>
106
  </div>
 
 
 
 
107
  </main>
108
 
109
  <!-- Sidebar -->
@@ -117,9 +129,17 @@
117
  <span class="stat-label">STEP</span>
118
  <span class="stat-value" id="stat-step">— / —</span>
119
  </div>
120
- <div class="stat-item" id="stat-containment">
121
- <span class="stat-label">CONTAINMENT</span>
122
- <span class="stat-value" id="stat-containment-val">—</span>
 
 
 
 
 
 
 
 
123
  </div>
124
  <div class="stat-item" id="stat-burning">
125
  <span class="stat-label">BURNING</span>
@@ -274,6 +294,6 @@
274
  </span>
275
  </footer>
276
 
277
- <script src="app.js"></script>
278
  </body>
279
  </html>
 
83
  <div id="terminal-card">
84
  <h2 class="win">✅ FIRE CONTAINED</h2>
85
  <div class="stat-row">
86
+ <span>Land saved (unburned)</span>
87
+ <span id="terminal-land-saved">—</span>
88
+ </div>
89
+ <div class="stat-row">
90
+ <span>Civilians safe</span>
91
+ <span id="terminal-civilians-safe">—</span>
92
+ </div>
93
+ <div class="stat-row">
94
+ <span>Cells burned (total)</span>
95
+ <span id="terminal-cells-burned">—</span>
96
  </div>
97
  <div class="stat-row">
98
  <span>Population lost</span>
 
112
  </div>
113
  </div>
114
  </div>
115
+ <p id="map-legend" class="map-legend">
116
+ <strong>Map:</strong> green dot / circle = ground crew · blue outline = populated zone ·
117
+ bright blue cells = water · grey = roads
118
+ </p>
119
  </main>
120
 
121
  <!-- Sidebar -->
 
129
  <span class="stat-label">STEP</span>
130
  <span class="stat-value" id="stat-step">— / —</span>
131
  </div>
132
+ <div class="stat-item" id="stat-land-saved">
133
+ <span class="stat-label">LAND SAVED</span>
134
+ <span class="stat-value" id="stat-land-saved-val">—</span>
135
+ </div>
136
+ <div class="stat-item" id="stat-civilians-safe">
137
+ <span class="stat-label">CIVILIANS SAFE</span>
138
+ <span class="stat-value" id="stat-civilians-safe-val">—</span>
139
+ </div>
140
+ <div class="stat-item" id="stat-cells-burned">
141
+ <span class="stat-label">CELLS BURNED</span>
142
+ <span class="stat-value" id="stat-cells-burned-val">—</span>
143
  </div>
144
  <div class="stat-item" id="stat-burning">
145
  <span class="stat-label">BURNING</span>
 
294
  </span>
295
  </footer>
296
 
297
+ <script src="app.js?v=4"></script>
298
  </body>
299
  </html>
frontend/style.css CHANGED
@@ -250,6 +250,16 @@ input[type="range"]::-webkit-slider-thumb {
250
 
251
  #grid-canvas { display: block; image-rendering: pixelated; }
252
 
 
 
 
 
 
 
 
 
 
 
253
  /* Tooltip overlay (shows cell info on hover) */
254
  #cell-tooltip {
255
  position: absolute;
@@ -356,8 +366,10 @@ input[type="range"]::-webkit-slider-thumb {
356
  .stat-item.step-item { grid-column: 1 / -1; }
357
  .stat-item.step-item .stat-value { font-size: 14px; }
358
 
359
- #stat-containment .stat-value { color: var(--safe); }
360
- #stat-burning .stat-value { color: var(--fire); }
 
 
361
  #stat-pop-threat .stat-value { color: var(--warn); }
362
  #stat-pop-lost .stat-value { color: var(--crit); }
363
 
 
250
 
251
  #grid-canvas { display: block; image-rendering: pixelated; }
252
 
253
+ .map-legend {
254
+ margin: 8px 0 0;
255
+ padding: 6px 10px;
256
+ font-size: 11px;
257
+ color: var(--text-muted);
258
+ line-height: 1.45;
259
+ max-width: 100%;
260
+ }
261
+ .map-legend strong { color: var(--text); }
262
+
263
  /* Tooltip overlay (shows cell info on hover) */
264
  #cell-tooltip {
265
  position: absolute;
 
366
  .stat-item.step-item { grid-column: 1 / -1; }
367
  .stat-item.step-item .stat-value { font-size: 14px; }
368
 
369
+ #stat-land-saved .stat-value { color: var(--safe); }
370
+ #stat-civilians-safe .stat-value { color: var(--safe); }
371
+ #stat-cells-burned .stat-value { color: var(--warn); }
372
+ #stat-burning .stat-value { color: var(--fire); }
373
  #stat-pop-threat .stat-value { color: var(--warn); }
374
  #stat-pop-lost .stat-value { color: var(--crit); }
375
 
graders/grader_easy.py CHANGED
@@ -27,7 +27,7 @@ def grade(agent, seed: int = 42):
27
 
28
  details = {
29
  "total_reward": round(total_reward, 4),
30
- "containment_pct": round(final.get("containment_pct", 0.0), 4),
31
  "pop_saved_pct": round(1.0 - pop_lost / total_pop, 4),
32
  "steps": env.current_step,
33
  "crew_casualty": env._crew_casualty_occurred,
 
27
 
28
  details = {
29
  "total_reward": round(total_reward, 4),
30
+ "containment_pct": round(final.get("reward_breakdown", {}).get("containment", 0.0), 4),
31
  "pop_saved_pct": round(1.0 - pop_lost / total_pop, 4),
32
  "steps": env.current_step,
33
  "crew_casualty": env._crew_casualty_occurred,
graders/grader_hard.py CHANGED
@@ -27,7 +27,7 @@ def grade(agent, seed: int = 42):
27
 
28
  details = {
29
  "total_reward": round(total_reward, 4),
30
- "containment_pct": round(final.get("containment_pct", 0.0), 4),
31
  "pop_saved_pct": round(1.0 - pop_lost / total_pop, 4),
32
  "steps": env.current_step,
33
  "crew_casualty": env._crew_casualty_occurred,
 
27
 
28
  details = {
29
  "total_reward": round(total_reward, 4),
30
+ "containment_pct": round(final.get("reward_breakdown", {}).get("containment", 0.0), 4),
31
  "pop_saved_pct": round(1.0 - pop_lost / total_pop, 4),
32
  "steps": env.current_step,
33
  "crew_casualty": env._crew_casualty_occurred,
graders/grader_medium.py CHANGED
@@ -27,7 +27,7 @@ def grade(agent, seed: int = 42):
27
 
28
  details = {
29
  "total_reward": round(total_reward, 4),
30
- "containment_pct": round(final.get("containment_pct", 0.0), 4),
31
  "pop_saved_pct": round(1.0 - pop_lost / total_pop, 4),
32
  "steps": env.current_step,
33
  "crew_casualty": env._crew_casualty_occurred,
 
27
 
28
  details = {
29
  "total_reward": round(total_reward, 4),
30
+ "containment_pct": round(final.get("reward_breakdown", {}).get("containment", 0.0), 4),
31
  "pop_saved_pct": round(1.0 - pop_lost / total_pop, 4),
32
  "steps": env.current_step,
33
  "crew_casualty": env._crew_casualty_occurred,
scripts/generate_sft_data.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Generate supervised fine-tuning (SFT) training examples by running the
3
+ HeuristicAgent through episodes and recording (prompt, action) pairs.
4
+
5
+ Usage:
6
+ python scripts/generate_sft_data.py
7
+ python scripts/generate_sft_data.py --output training/sft_data.jsonl --easy-seeds 500
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import argparse
13
+ import json
14
+ import os
15
+ import random
16
+ import sys
17
+ from pathlib import Path
18
+
19
+ PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
20
+ sys.path.insert(0, PROJECT_ROOT)
21
+
22
+ from env.wildfire_env import WildfireEnv
23
+ from env.serialization import serialize_observation
24
+ from env.models import TIER_EASY, TIER_MEDIUM, TIER_HARD, ActionType
25
+ from agents.heuristic_agent import HeuristicAgent
26
+
27
+ SYSTEM_PROMPT = (
28
+ "You are an AI Incident Commander managing wildfire containment. "
29
+ "You will receive a situation briefing each step. "
30
+ "Respond with ONLY a valid JSON action object and nothing else. "
31
+ 'Example: {"action_type": "idle"}'
32
+ )
33
+
34
+ TIER_CONFIGS = {
35
+ "easy": {"max_steps": TIER_EASY.episode_length, "target": 2000},
36
+ "medium": {"max_steps": TIER_MEDIUM.episode_length, "target": 1500},
37
+ "hard": {"max_steps": TIER_HARD.episode_length, "target": 800},
38
+ }
39
+
40
+
41
+ def run_episode(tier: str, seed: int) -> list[dict] | None:
42
+ """Run a full episode with the HeuristicAgent.
43
+
44
+ Returns a list of raw (prompt, action, step) records for the episode,
45
+ or None if the episode is unsuccessful (population lost > 0).
46
+ """
47
+ max_steps = TIER_CONFIGS[tier]["max_steps"]
48
+ env = WildfireEnv()
49
+ obs = env.reset(task_id=tier, seed=seed)
50
+ agent = HeuristicAgent()
51
+
52
+ offset = random.randint(0, min(30, max_steps // 4))
53
+
54
+ prev_cells_burning = 0
55
+ records: list[dict] = []
56
+ step_num = 0
57
+
58
+ while not env.done:
59
+ action = agent.act(obs)
60
+
61
+ if step_num >= offset:
62
+ prompt_text = serialize_observation(
63
+ obs, step_num, max_steps,
64
+ tier=tier, prev_cells_burning=prev_cells_burning,
65
+ )
66
+ action_json = action.model_dump_json(exclude_none=True)
67
+ records.append({
68
+ "messages": [
69
+ {"role": "system", "content": SYSTEM_PROMPT},
70
+ {"role": "user", "content": prompt_text},
71
+ ],
72
+ "completion": action_json,
73
+ "tier": tier,
74
+ "seed": seed,
75
+ "step": step_num,
76
+ "action_type": action.action_type.value,
77
+ })
78
+
79
+ prev_cells_burning = obs.stats.cells_burning
80
+ result = env.step(action)
81
+ obs = result.observation
82
+ step_num += 1
83
+
84
+ state = env.state()
85
+ if state["population_lost"] != 0:
86
+ return None
87
+
88
+ return records
89
+
90
+
91
+ def filter_idle(records: list[dict]) -> list[dict]:
92
+ """Keep all non-IDLE steps, then cap IDLE steps at 20% of total."""
93
+ non_idle = [r for r in records if r["action_type"] != "idle"]
94
+ idle = [r for r in records if r["action_type"] == "idle"]
95
+
96
+ if not non_idle:
97
+ return idle
98
+
99
+ max_idle = max(1, int(len(non_idle) * 0.25))
100
+ if len(idle) > max_idle:
101
+ random.shuffle(idle)
102
+ idle = idle[:max_idle]
103
+
104
+ combined = non_idle + idle
105
+ combined.sort(key=lambda r: r["step"])
106
+ return combined
107
+
108
+
109
+ def strip_internal_fields(records: list[dict]) -> list[dict]:
110
+ """Remove the action_type helper field before writing."""
111
+ for r in records:
112
+ r.pop("action_type", None)
113
+ return records
114
+
115
+
116
+ def generate(output_path: str, max_seeds: dict[str, int]) -> None:
117
+ os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
118
+
119
+ all_examples: list[dict] = []
120
+ tier_counts = {t: 0 for t in TIER_CONFIGS}
121
+
122
+ for tier in ["easy", "medium", "hard"]:
123
+ target = TIER_CONFIGS[tier]["target"]
124
+ limit = max_seeds[tier]
125
+ seed = 0
126
+
127
+ print(f"\n{'='*50}")
128
+ print(f"Generating {tier} tier (target={target}, max_seeds={limit})")
129
+ print(f"{'='*50}")
130
+
131
+ while tier_counts[tier] < target and seed < limit:
132
+ records = run_episode(tier, seed)
133
+
134
+ if records is not None:
135
+ filtered = filter_idle(records)
136
+ remaining = target - tier_counts[tier]
137
+ if len(filtered) > remaining:
138
+ filtered = filtered[:remaining]
139
+ all_examples.extend(strip_internal_fields(filtered))
140
+ tier_counts[tier] += len(filtered)
141
+
142
+ seed += 1
143
+ if seed % 50 == 0:
144
+ print(f" [{tier}] seed={seed}, examples={tier_counts[tier]}/{target}")
145
+
146
+ print(f" [{tier}] DONE — {tier_counts[tier]} examples from {seed} seeds")
147
+
148
+ with open(output_path, "w", encoding="utf-8") as f:
149
+ for ex in all_examples:
150
+ f.write(json.dumps(ex, ensure_ascii=False) + "\n")
151
+
152
+ total = len(all_examples)
153
+ print(f"\n{'='*50}")
154
+ print(f"SFT data saved to {output_path}")
155
+ print(f"Total examples: {total}")
156
+ print(f"Tier distribution:")
157
+ for tier in ["easy", "medium", "hard"]:
158
+ print(f" {tier}: {tier_counts[tier]}")
159
+ print(f"{'='*50}")
160
+
161
+
162
+ def main():
163
+ parser = argparse.ArgumentParser(description="Generate SFT training data from HeuristicAgent episodes")
164
+ parser.add_argument("--output", default="training/sft_data.jsonl",
165
+ help="Output JSONL file path (default: training/sft_data.jsonl)")
166
+ parser.add_argument("--easy-seeds", type=int, default=500,
167
+ help="Max seeds to try for easy tier")
168
+ parser.add_argument("--medium-seeds", type=int, default=500,
169
+ help="Max seeds to try for medium tier")
170
+ parser.add_argument("--hard-seeds", type=int, default=500,
171
+ help="Max seeds to try for hard tier")
172
+ args = parser.parse_args()
173
+
174
+ max_seeds = {
175
+ "easy": args.easy_seeds,
176
+ "medium": args.medium_seeds,
177
+ "hard": args.hard_seeds,
178
+ }
179
+ generate(args.output, max_seeds)
180
+
181
+
182
+ if __name__ == "__main__":
183
+ main()
scripts/results.json CHANGED
@@ -2,131 +2,101 @@
2
  "random": {
3
  "easy": {
4
  "scores": [
5
- 8.225,
6
- 8.35,
7
- 0.39,
8
- 8.35,
9
- 7.875,
10
- 8.25,
11
- 0.36,
12
- 8.35,
13
- 6.8251,
14
- 5.825
15
  ],
16
- "mean": 6.28,
17
- "std": 3.0546,
18
- "mean_containment_pct": 0.0,
19
- "mean_pop_saved_pct": 0.975,
20
- "mean_steps": 16.1,
21
  "crew_casualty_rate": 0.0,
22
- "mean_time_s": 0.097
23
  },
24
  "medium": {
25
  "scores": [
26
- -1.1475,
27
- 8.3067,
28
- 8.0667,
29
- 7.84,
30
- 0.2919,
31
- 7.2,
32
- 8.3733,
33
- 8.3333,
34
- -1.024,
35
- -3.6238
36
  ],
37
- "mean": 4.2617,
38
- "std": 4.7,
39
- "mean_containment_pct": 0.0,
40
- "mean_pop_saved_pct": 0.9587,
41
- "mean_steps": 32.2,
42
  "crew_casualty_rate": 0.0,
43
- "mean_time_s": 0.468
44
  },
45
  "hard": {
46
  "scores": [
47
- -7.6189,
48
- -3.9186,
49
- 5.3,
50
- 5.2999,
51
- -2.8187,
52
- -2.9395,
53
- -5.5375,
54
- -1.5395,
55
- 5.3,
56
- 5.3
57
  ],
58
- "mean": -0.3173,
59
- "std": 4.8412,
60
- "mean_containment_pct": 0.0,
61
- "mean_pop_saved_pct": 0.9802,
62
- "mean_steps": 44.7,
63
  "crew_casualty_rate": 0.0,
64
- "mean_time_s": 1.298
65
  }
66
  },
67
  "heuristic": {
68
  "easy": {
69
  "scores": [
70
- 8.35,
71
- 8.35,
72
- 8.35,
73
- 8.35,
74
- 8.35,
75
- 8.35,
76
- 8.35,
77
- 8.35,
78
- 8.35,
79
- 8.35
80
  ],
81
- "mean": 8.35,
82
- "std": 0.0,
83
- "mean_containment_pct": 0.0,
84
  "mean_pop_saved_pct": 1.0,
85
- "mean_steps": 2.0,
86
  "crew_casualty_rate": 0.0,
87
- "mean_time_s": 0.021
88
  },
89
  "medium": {
90
  "scores": [
91
- 5.5,
92
- 8.3733,
93
- 8.3733,
94
- 8.3733,
95
- 8.3067,
96
- 7.94,
97
- 8.3733,
98
- 8.3733,
99
- 7.8467,
100
- 7.2933
101
  ],
102
- "mean": 7.8753,
103
- "std": 0.8609,
104
- "mean_containment_pct": 0.0,
105
- "mean_pop_saved_pct": 1.0,
106
- "mean_steps": 11.6,
107
  "crew_casualty_rate": 0.0,
108
- "mean_time_s": 0.214
109
  },
110
  "hard": {
111
  "scores": [
112
- 6.5001,
113
- -5.5396,
114
- 6.0,
115
- 4.6468,
116
- 6.8,
117
- 5.8,
118
- 5.8,
119
- 4.8001,
120
- 5.6,
121
- 5.9
122
  ],
123
- "mean": 4.6307,
124
- "std": 3.4471,
125
- "mean_containment_pct": 0.0,
126
- "mean_pop_saved_pct": 0.9988,
127
- "mean_steps": 41.4,
128
  "crew_casualty_rate": 0.0,
129
- "mean_time_s": 1.384
130
  }
131
  }
132
  }
 
2
  "random": {
3
  "easy": {
4
  "scores": [
5
+ 7.7749,
6
+ 7.7751,
7
+ 7.775,
8
+ 7.775,
9
+ 0.04
 
 
 
 
 
10
  ],
11
+ "mean": 6.228,
12
+ "std": 3.094,
13
+ "mean_containment_pct": 1.0,
14
+ "mean_pop_saved_pct": 0.92,
15
+ "mean_steps": 25.8,
16
  "crew_casualty_rate": 0.0,
17
+ "mean_time_s": 0.067
18
  },
19
  "medium": {
20
  "scores": [
21
+ -1.7044,
22
+ -1.0029,
23
+ 1.0762,
24
+ 0.7527,
25
+ 7.4403
 
 
 
 
 
26
  ],
27
+ "mean": 1.3124,
28
+ "std": 3.2367,
29
+ "mean_containment_pct": 1.0,
30
+ "mean_pop_saved_pct": 0.7365,
31
+ "mean_steps": 72.0,
32
  "crew_casualty_rate": 0.0,
33
+ "mean_time_s": 0.676
34
  },
35
  "hard": {
36
  "scores": [
37
+ 7.8668,
38
+ 1.3602,
39
+ -0.7466,
40
+ 1.0443,
41
+ 1.2813
 
 
 
 
 
42
  ],
43
+ "mean": 2.1612,
44
+ "std": 2.9554,
45
+ "mean_containment_pct": 1.0,
46
+ "mean_pop_saved_pct": 0.9023,
47
+ "mean_steps": 84.6,
48
  "crew_casualty_rate": 0.0,
49
+ "mean_time_s": 1.301
50
  }
51
  },
52
  "heuristic": {
53
  "easy": {
54
  "scores": [
55
+ 7.6749,
56
+ 7.575,
57
+ 7.475,
58
+ 7.475,
59
+ 7.4749
 
 
 
 
 
60
  ],
61
+ "mean": 7.535,
62
+ "std": 0.08,
63
+ "mean_containment_pct": 1.0,
64
  "mean_pop_saved_pct": 1.0,
65
+ "mean_steps": 26.6,
66
  "crew_casualty_rate": 0.0,
67
+ "mean_time_s": 0.118
68
  },
69
  "medium": {
70
  "scores": [
71
+ 7.6001,
72
+ 7.7001,
73
+ 7.8,
74
+ 7.7,
75
+ 0.7683
 
 
 
 
 
76
  ],
77
+ "mean": 6.3137,
78
+ "std": 2.7734,
79
+ "mean_containment_pct": 1.0,
80
+ "mean_pop_saved_pct": 0.9746,
81
+ "mean_steps": 46.2,
82
  "crew_casualty_rate": 0.0,
83
+ "mean_time_s": 0.48
84
  },
85
  "hard": {
86
  "scores": [
87
+ 7.8668,
88
+ 7.867,
89
+ 0.9443,
90
+ 7.6667,
91
+ -0.6696
 
 
 
 
 
92
  ],
93
+ "mean": 4.735,
94
+ "std": 3.7892,
95
+ "mean_containment_pct": 1.0,
96
+ "mean_pop_saved_pct": 0.9279,
97
+ "mean_steps": 83.2,
98
  "crew_casualty_rate": 0.0,
99
+ "mean_time_s": 1.487
100
  }
101
  }
102
  }
training/grpo_v2_colab.ipynb ADDED
@@ -0,0 +1,678 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Wildfire Incident Command - GRPO Training (v2)\n",
8
+ "\n",
9
+ "GRPO reinforcement learning on a wildfire incident command model, starting from the SFT checkpoint.\n",
10
+ "\n",
11
+ "**Five critical issues fixed in this version:**\n",
12
+ "1. Prompt/reward state mismatch - dataset uses step-0 prompts only; reward replays the exact (tier, seed)\n",
13
+ "2. Truncated rollout - reward runs full episode to completion (heuristic continuation), terminal reward always included\n",
14
+ "3. Wasted inner model generations - MODEL_STEPS=1, only the sampled completion is applied\n",
15
+ "4. GRPO loop too slow - consequence of fix 3\n",
16
+ "5. parse_action(text, None) crash - standalone check_json_format() for format reward\n",
17
+ "\n",
18
+ "**Hardware:** A100 40GB on Colab"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "markdown",
23
+ "metadata": {},
24
+ "source": [
25
+ "## Section 1 - Install and Assert GPU"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": null,
31
+ "metadata": {},
32
+ "outputs": [],
33
+ "source": [
34
+ "!pip install \"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git\"\n",
35
+ "!pip install trl==0.15.2 datasets==3.4.1 wandb"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": null,
41
+ "metadata": {},
42
+ "outputs": [],
43
+ "source": [
44
+ "import torch\n",
45
+ "assert torch.cuda.is_available(), \"GPU not available - switch to a GPU runtime\"\n",
46
+ "gpu_name = torch.cuda.get_device_name(0)\n",
47
+ "gpu_mem = torch.cuda.get_device_properties(0).total_mem / 1e9\n",
48
+ "print(f\"GPU: {gpu_name} | VRAM: {gpu_mem:.1f} GB\")"
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "markdown",
53
+ "metadata": {},
54
+ "source": [
55
+ "## Section 2 - Load SFT Checkpoint"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": null,
61
+ "metadata": {},
62
+ "outputs": [],
63
+ "source": [
64
+ "from unsloth import FastLanguageModel\n",
65
+ "\n",
66
+ "# Option A: Load from HuggingFace Hub\n",
67
+ "SFT_MODEL = \"Eshit/wildfire-sft-7b\"\n",
68
+ "# Option B: Load from local zip (uncomment and adjust if needed)\n",
69
+ "# !unzip sft_final.zip -d sft_final_dir\n",
70
+ "# SFT_MODEL = \"./sft_final_dir/sft_final\"\n",
71
+ "\n",
72
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
73
+ " model_name=SFT_MODEL,\n",
74
+ " max_seq_length=2048,\n",
75
+ " load_in_4bit=True,\n",
76
+ ")\n",
77
+ "\n",
78
+ "if tokenizer.pad_token is None:\n",
79
+ " tokenizer.pad_token = tokenizer.eos_token\n",
80
+ "\n",
81
+ "print(f\"Loaded SFT checkpoint: {SFT_MODEL}\")\n",
82
+ "model.print_trainable_parameters()"
83
+ ]
84
+ },
85
+ {
86
+ "cell_type": "markdown",
87
+ "metadata": {},
88
+ "source": [
89
+ "## Section 3 - Constants and Controller Setup"
90
+ ]
91
+ },
92
+ {
93
+ "cell_type": "code",
94
+ "execution_count": null,
95
+ "metadata": {},
96
+ "outputs": [],
97
+ "source": [
98
+ "import os, random, json, sys\n",
99
+ "import torch\n",
100
+ "\n",
101
+ "REPO_ROOT = \".\" # Adjust to repo root in Colab\n",
102
+ "if REPO_ROOT not in sys.path:\n",
103
+ " sys.path.insert(0, REPO_ROOT)\n",
104
+ "\n",
105
+ "from env.wildfire_env import WildfireEnv\n",
106
+ "from env.serialization import serialize_observation\n",
107
+ "from env.action_parser import parse_action\n",
108
+ "from agents.heuristic_agent import HeuristicAgent\n",
109
+ "from env.curriculum import CurriculumController\n",
110
+ "from datasets import Dataset\n",
111
+ "\n",
112
+ "SEED_POOL = list(range(100))\n",
113
+ "TIER_MAX_STEPS = {'easy': 80, 'medium': 150, 'hard': 300}\n",
114
+ "SYSTEM_PROMPT = (\n",
115
+ " 'You are an AI Incident Commander managing wildfire containment. '\n",
116
+ " 'You will receive a situation briefing each step. '\n",
117
+ " 'Respond with ONLY a valid JSON action object and nothing else. '\n",
118
+ " 'Example: {\"action_type\": \"idle\"}'\n",
119
+ ")\n",
120
+ "\n",
121
+ "# Thresholds calibrated to full-episode reward with heuristic continuation.\n",
122
+ "# Promote easy->medium once model's first action consistently beats random (+6.23).\n",
123
+ "# Promote medium->hard once model demonstrates meaningful improvement over random (+1.31).\n",
124
+ "controller = CurriculumController(\n",
125
+ " start_tier='easy',\n",
126
+ " thresholds={'easy': 6.5, 'medium': 3.5},\n",
127
+ ")\n",
128
+ "\n",
129
+ "os.makedirs('training/samples', exist_ok=True)\n",
130
+ "_reward_call_count = 0\n",
131
+ "\n",
132
+ "print(f\"Start tier: {controller.get_tier()}\")\n",
133
+ "print(f\"Seed pool: {len(SEED_POOL)} seeds\")\n",
134
+ "print(\"Env imports OK\")"
135
+ ]
136
+ },
137
+ {
138
+ "cell_type": "markdown",
139
+ "metadata": {},
140
+ "source": [
141
+ "## Section 4 - Standalone JSON Format Checker\n",
142
+ "\n",
143
+ "Replaces parse_action for format reward - no obs object needed (Issue 5 fix)."
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "code",
148
+ "execution_count": null,
149
+ "metadata": {},
150
+ "outputs": [],
151
+ "source": [
152
+ "import json as _json\n",
153
+ "import re as _re\n",
154
+ "from env.models import ActionType as _AT\n",
155
+ "\n",
156
+ "_VALID_ACTION_TYPES = {a.value for a in _AT}\n",
157
+ "\n",
158
+ "\n",
159
+ "def check_json_format(text: str) -> str:\n",
160
+ " \"\"\"\n",
161
+ " Validate LLM output format without needing an obs object.\n",
162
+ " Returns 'json_success', 'regex_fallback', or 'safe_idle'.\n",
163
+ " Does NOT use parse_action - avoids the obs.grid dependency.\n",
164
+ " \"\"\"\n",
165
+ " text = _re.sub(r'```(?:json)?\\s*', '', text).replace('```', '')\n",
166
+ " start = text.find('{')\n",
167
+ " if start == -1:\n",
168
+ " return 'safe_idle'\n",
169
+ " depth = 0\n",
170
+ " end = -1\n",
171
+ " for i, ch in enumerate(text[start:], start=start):\n",
172
+ " if ch == '{':\n",
173
+ " depth += 1\n",
174
+ " elif ch == '}':\n",
175
+ " depth -= 1\n",
176
+ " if depth == 0:\n",
177
+ " end = i\n",
178
+ " break\n",
179
+ " if end == -1:\n",
180
+ " return 'safe_idle'\n",
181
+ " try:\n",
182
+ " obj = _json.loads(text[start:end+1])\n",
183
+ " if not isinstance(obj, dict):\n",
184
+ " return 'safe_idle'\n",
185
+ " at = str(obj.get('action_type', '')).lower()\n",
186
+ " if at in _VALID_ACTION_TYPES:\n",
187
+ " return 'json_success'\n",
188
+ " return 'regex_fallback'\n",
189
+ " except Exception:\n",
190
+ " return 'regex_fallback'\n",
191
+ "\n",
192
+ "\n",
193
+ "assert check_json_format('{\"action_type\": \"idle\"}') == 'json_success'\n",
194
+ "assert check_json_format('{\"action_type\": \"bogus\"}') == 'regex_fallback'\n",
195
+ "assert check_json_format('no json here') == 'safe_idle'\n",
196
+ "print('check_json_format OK')"
197
+ ]
198
+ },
199
+ {
200
+ "cell_type": "markdown",
201
+ "metadata": {},
202
+ "source": [
203
+ "## Section 5 - Reward Functions\n",
204
+ "\n",
205
+ "Two reward signals for GRPO:\n",
206
+ "- **reward_fn_outcome** - full-episode env reward (1 model step + heuristic continuation)\n",
207
+ "- **reward_fn_format** - JSON formatting quality (fast, no env needed)"
208
+ ]
209
+ },
210
+ {
211
+ "cell_type": "code",
212
+ "execution_count": null,
213
+ "metadata": {},
214
+ "outputs": [],
215
+ "source": [
216
+ "def reward_fn_outcome(completions, prompts, tier=None, seed=None, **kwargs):\n",
217
+ " \"\"\"\n",
218
+ " Score each GRPO completion by:\n",
219
+ " 1. Resetting the env to the EXACT (tier, seed) that generated the prompt (Issue 1 fix).\n",
220
+ " 2. Applying the sampled completion as the single first action (MODEL_STEPS=1, Issue 3/4 fix).\n",
221
+ " 3. Running HeuristicAgent until episode completion (Issue 2 fix - captures terminal reward).\n",
222
+ "\n",
223
+ " tier and seed are dataset columns forwarded by GRPOTrainer.\n",
224
+ " \"\"\"\n",
225
+ " global _reward_call_count\n",
226
+ " _reward_call_count += 1\n",
227
+ " rewards = []\n",
228
+ "\n",
229
+ " for i, completion in enumerate(completions):\n",
230
+ " ep_tier = tier[i] if tier is not None else controller.get_tier()\n",
231
+ " ep_seed = seed[i] if seed is not None else random.choice(SEED_POOL)\n",
232
+ "\n",
233
+ " env = WildfireEnv()\n",
234
+ " obs = env.reset(task_id=ep_tier, seed=ep_seed)\n",
235
+ " total_reward = 0.0\n",
236
+ "\n",
237
+ " # Apply the sampled completion as step 0\n",
238
+ " text = completion if isinstance(completion, str) else completion[0]['content']\n",
239
+ " action, _ = parse_action(text, obs)\n",
240
+ " result = env.step(action)\n",
241
+ " total_reward += result.reward\n",
242
+ " obs = result.observation\n",
243
+ "\n",
244
+ " # Heuristic drives everything after (full episode to capture terminal reward)\n",
245
+ " heuristic = HeuristicAgent()\n",
246
+ " while not env.done:\n",
247
+ " action = heuristic.act(obs)\n",
248
+ " result = env.step(action)\n",
249
+ " total_reward += result.reward\n",
250
+ " obs = result.observation\n",
251
+ "\n",
252
+ " rewards.append(total_reward)\n",
253
+ "\n",
254
+ " # Update curriculum (once per batch, not per completion)\n",
255
+ " mean_r = sum(rewards) / len(rewards)\n",
256
+ " promoted = controller.after_episode(mean_r)\n",
257
+ " if promoted:\n",
258
+ " print(f' *** Curriculum promoted to: {promoted} (mean batch reward={mean_r:.2f}) ***')\n",
259
+ "\n",
260
+ " # Sample completions to disk for inspection\n",
261
+ " if _reward_call_count % 10 == 0:\n",
262
+ " sample_path = f'training/samples/call_{_reward_call_count}.txt'\n",
263
+ " with open(sample_path, 'w') as f:\n",
264
+ " f.write(f'call={_reward_call_count} tier={tier[0] if tier else \"?\"} reward={rewards[0]:.3f}\\n')\n",
265
+ " f.write('---\\n')\n",
266
+ " c = completions[0]\n",
267
+ " f.write(c if isinstance(c, str) else c[0]['content'])\n",
268
+ " f.write('\\n')\n",
269
+ "\n",
270
+ " return rewards\n",
271
+ "\n",
272
+ "\n",
273
+ "def reward_fn_format(completions, prompts, **kwargs):\n",
274
+ " \"\"\"\n",
275
+ " Scores JSON formatting quality using check_json_format() (no obs needed).\n",
276
+ " Runs independently of the env - fast and always well-defined.\n",
277
+ " \"\"\"\n",
278
+ " rewards = []\n",
279
+ " for completion in completions:\n",
280
+ " text = completion if isinstance(completion, str) else completion[0]['content']\n",
281
+ " status = check_json_format(text)\n",
282
+ " if status == 'json_success':\n",
283
+ " r = 0.15\n",
284
+ " elif status == 'regex_fallback':\n",
285
+ " r = 0.0\n",
286
+ " else:\n",
287
+ " r = -0.20\n",
288
+ " rewards.append(r)\n",
289
+ " return rewards\n",
290
+ "\n",
291
+ "\n",
292
+ "print('Reward functions defined.')"
293
+ ]
294
+ },
295
+ {
296
+ "cell_type": "markdown",
297
+ "metadata": {},
298
+ "source": [
299
+ "## Section 6 - Dataset Builder (Step-0 Only)\n",
300
+ "\n",
301
+ "Each row stores the seed so reward_fn_outcome can replay the exact same env state.\n",
302
+ "No mid-episode offset - GRPO prompt and reward state are always step-0."
303
+ ]
304
+ },
305
+ {
306
+ "cell_type": "code",
307
+ "execution_count": null,
308
+ "metadata": {},
309
+ "outputs": [],
310
+ "source": [
311
+ "def build_prompt_dataset(n=200):\n",
312
+ " \"\"\"\n",
313
+ " Build step-0 prompts for the current curriculum tier.\n",
314
+ " Stores the seed in each row so reward_fn can replay the exact same env state.\n",
315
+ " \"\"\"\n",
316
+ " rows = []\n",
317
+ " env_tmp = WildfireEnv()\n",
318
+ " tier = controller.get_tier()\n",
319
+ " max_steps = TIER_MAX_STEPS[tier]\n",
320
+ "\n",
321
+ " for i in range(n):\n",
322
+ " seed = SEED_POOL[i % len(SEED_POOL)]\n",
323
+ " obs = env_tmp.reset(task_id=tier, seed=seed)\n",
324
+ " prompt = serialize_observation(obs, 0, max_steps, tier=tier, prev_cells_burning=0)\n",
325
+ " rows.append({\n",
326
+ " 'prompt': [\n",
327
+ " {'role': 'system', 'content': SYSTEM_PROMPT},\n",
328
+ " {'role': 'user', 'content': prompt},\n",
329
+ " ],\n",
330
+ " 'tier': tier,\n",
331
+ " 'seed': seed,\n",
332
+ " })\n",
333
+ " return rows\n",
334
+ "\n",
335
+ "\n",
336
+ "_test_ds = build_prompt_dataset(3)\n",
337
+ "print(f\"Sample dataset row keys: {list(_test_ds[0].keys())}\")\n",
338
+ "print(f\"Tier: {_test_ds[0]['tier']}, Seed: {_test_ds[0]['seed']}\")\n",
339
+ "print(f\"Prompt roles: {[m['role'] for m in _test_ds[0]['prompt']]}\")\n",
340
+ "del _test_ds"
341
+ ]
342
+ },
343
+ {
344
+ "cell_type": "markdown",
345
+ "metadata": {},
346
+ "source": [
347
+ "## Section 7 - CurriculumDatasetCallback\n",
348
+ "\n",
349
+ "Rebuilds the training dataset whenever the curriculum controller promotes to a new tier."
350
+ ]
351
+ },
352
+ {
353
+ "cell_type": "code",
354
+ "execution_count": null,
355
+ "metadata": {},
356
+ "outputs": [],
357
+ "source": [
358
+ "from transformers import TrainerCallback\n",
359
+ "\n",
360
+ "\n",
361
+ "class CurriculumDatasetCallback(TrainerCallback):\n",
362
+ " def __init__(self, trainer_ref):\n",
363
+ " self._trainer = trainer_ref\n",
364
+ " self._last_tier = controller.get_tier()\n",
365
+ "\n",
366
+ " def on_step_end(self, args, state, control, **kwargs):\n",
367
+ " current_tier = controller.get_tier()\n",
368
+ " if current_tier != self._last_tier:\n",
369
+ " print(f' Rebuilding dataset for tier: {current_tier}')\n",
370
+ " new_ds = Dataset.from_list(build_prompt_dataset(200))\n",
371
+ " self._trainer.train_dataset = new_ds\n",
372
+ " self._last_tier = current_tier\n",
373
+ "\n",
374
+ "\n",
375
+ "print('CurriculumDatasetCallback defined.')"
376
+ ]
377
+ },
378
+ {
379
+ "cell_type": "markdown",
380
+ "metadata": {},
381
+ "source": [
382
+ "## Section 8 - GRPOTrainer Setup"
383
+ ]
384
+ },
385
+ {
386
+ "cell_type": "code",
387
+ "execution_count": null,
388
+ "metadata": {},
389
+ "outputs": [],
390
+ "source": [
391
+ "from trl import GRPOTrainer, GRPOConfig\n",
392
+ "\n",
393
+ "grpo_config = GRPOConfig(\n",
394
+ " output_dir='./grpo_checkpoints',\n",
395
+ " num_generations=8,\n",
396
+ " learning_rate=3e-6,\n",
397
+ " max_steps=400,\n",
398
+ " save_steps=20,\n",
399
+ " per_device_train_batch_size=1,\n",
400
+ " gradient_accumulation_steps=4,\n",
401
+ " max_completion_length=192,\n",
402
+ " logging_steps=1,\n",
403
+ " report_to='wandb',\n",
404
+ ")\n",
405
+ "\n",
406
+ "FastLanguageModel.for_training(model)\n",
407
+ "\n",
408
+ "dataset = Dataset.from_list(build_prompt_dataset(200))\n",
409
+ "print(f'Initial dataset: {len(dataset)} rows, tier={controller.get_tier()}')\n",
410
+ "\n",
411
+ "trainer = GRPOTrainer(\n",
412
+ " model=model,\n",
413
+ " processing_class=tokenizer,\n",
414
+ " reward_funcs=[reward_fn_outcome, reward_fn_format],\n",
415
+ " args=grpo_config,\n",
416
+ " train_dataset=dataset,\n",
417
+ ")\n",
418
+ "trainer.add_callback(CurriculumDatasetCallback(trainer))\n",
419
+ "\n",
420
+ "print('GRPOTrainer ready.')"
421
+ ]
422
+ },
423
+ {
424
+ "cell_type": "markdown",
425
+ "metadata": {},
426
+ "source": [
427
+ "## Section 9 - Run Training"
428
+ ]
429
+ },
430
+ {
431
+ "cell_type": "code",
432
+ "execution_count": null,
433
+ "metadata": {},
434
+ "outputs": [],
435
+ "source": [
436
+ "import wandb\n",
437
+ "wandb.init(project='wildfire-grpo', name='qwen7b-v2')\n",
438
+ "\n",
439
+ "print(f'Starting GRPO - {grpo_config.max_steps} steps, {grpo_config.num_generations} gen/prompt')\n",
440
+ "print(f'Reward: 1 model step at step-0, heuristic continuation to episode completion')\n",
441
+ "print(f'Start tier: {controller.get_tier()}')\n",
442
+ "\n",
443
+ "trainer.train()\n",
444
+ "print('Training complete.')\n",
445
+ "\n",
446
+ "history = controller.get_history()\n",
447
+ "stats = [{'step': ep, 'tier': t, 'mean_reward': r} for ep, t, r in history]\n",
448
+ "with open('./training_stats.json', 'w') as f:\n",
449
+ " json.dump(stats, f, indent=2)\n",
450
+ "print('Stats saved -> training_stats.json')"
451
+ ]
452
+ },
453
+ {
454
+ "cell_type": "markdown",
455
+ "metadata": {},
456
+ "source": [
457
+ "## Section 10 - Evaluate vs Baselines\n",
458
+ "\n",
459
+ "Run 15 full episodes per tier (seeds 42-56), compare with heuristic and random baselines."
460
+ ]
461
+ },
462
+ {
463
+ "cell_type": "code",
464
+ "execution_count": null,
465
+ "metadata": {},
466
+ "outputs": [],
467
+ "source": [
468
+ "class LLMAgent:\n",
469
+ " \"\"\"Wraps the trained model for evaluation. Must be re-instantiated per episode.\"\"\"\n",
470
+ "\n",
471
+ " def __init__(self, model, tokenizer, tier, max_steps):\n",
472
+ " self.model = model\n",
473
+ " self.tokenizer = tokenizer\n",
474
+ " self.tier = tier\n",
475
+ " self.max_steps = max_steps\n",
476
+ " self._step = 0\n",
477
+ " self._prev_burning = 0\n",
478
+ " self.json_success = self.regex_fallback = self.safe_idle = 0\n",
479
+ "\n",
480
+ " def act(self, obs):\n",
481
+ " prompt = serialize_observation(\n",
482
+ " obs, self._step, self.max_steps,\n",
483
+ " tier=self.tier,\n",
484
+ " prev_cells_burning=self._prev_burning,\n",
485
+ " )\n",
486
+ " self._prev_burning = obs.stats.cells_burning\n",
487
+ " messages = [\n",
488
+ " {'role': 'system', 'content': SYSTEM_PROMPT},\n",
489
+ " {'role': 'user', 'content': prompt},\n",
490
+ " ]\n",
491
+ " input_ids = self.tokenizer.apply_chat_template(\n",
492
+ " messages, tokenize=True,\n",
493
+ " add_generation_prompt=True, return_tensors='pt',\n",
494
+ " ).to(self.model.device)\n",
495
+ " with torch.no_grad():\n",
496
+ " out = self.model.generate(\n",
497
+ " input_ids, max_new_tokens=128,\n",
498
+ " pad_token_id=self.tokenizer.eos_token_id,\n",
499
+ " )\n",
500
+ " text = self.tokenizer.decode(out[0][input_ids.shape[1]:], skip_special_tokens=True)\n",
501
+ " action, status = parse_action(text, obs)\n",
502
+ " if status == 'json_success':\n",
503
+ " self.json_success += 1\n",
504
+ " elif status == 'regex_fallback':\n",
505
+ " self.regex_fallback += 1\n",
506
+ " else:\n",
507
+ " self.safe_idle += 1\n",
508
+ " self._step += 1\n",
509
+ " return action\n",
510
+ "\n",
511
+ "\n",
512
+ "print('LLMAgent class defined.')"
513
+ ]
514
+ },
515
+ {
516
+ "cell_type": "code",
517
+ "execution_count": null,
518
+ "metadata": {},
519
+ "outputs": [],
520
+ "source": [
521
+ "import numpy as np\n",
522
+ "\n",
523
+ "with open('scripts/results.json', 'r') as f:\n",
524
+ " baselines = json.load(f)\n",
525
+ "\n",
526
+ "FastLanguageModel.for_inference(model)\n",
527
+ "\n",
528
+ "EVAL_SEEDS = list(range(42, 57))\n",
529
+ "TIERS = ['easy', 'medium', 'hard']\n",
530
+ "\n",
531
+ "results = {}\n",
532
+ "\n",
533
+ "for tier in TIERS:\n",
534
+ " max_steps = TIER_MAX_STEPS[tier]\n",
535
+ " tier_rewards = []\n",
536
+ " tier_pop_saved = []\n",
537
+ " tier_json_success = 0\n",
538
+ " tier_total_actions = 0\n",
539
+ "\n",
540
+ " print(f'\\nEvaluating {tier} tier...')\n",
541
+ "\n",
542
+ " for seed in EVAL_SEEDS:\n",
543
+ " agent = LLMAgent(model, tokenizer, tier, max_steps)\n",
544
+ " env = WildfireEnv()\n",
545
+ " obs = env.reset(task_id=tier, seed=seed)\n",
546
+ " total_reward = 0.0\n",
547
+ "\n",
548
+ " while not env.done:\n",
549
+ " action = agent.act(obs)\n",
550
+ " result = env.step(action)\n",
551
+ " total_reward += result.reward\n",
552
+ " obs = result.observation\n",
553
+ "\n",
554
+ " tier_rewards.append(total_reward)\n",
555
+ "\n",
556
+ " state = env.state()\n",
557
+ " total_pop = state['total_population']\n",
558
+ " pop_lost = state['population_lost']\n",
559
+ " pop_saved = 100.0 * (total_pop - pop_lost) / total_pop if total_pop > 0 else 100.0\n",
560
+ " tier_pop_saved.append(pop_saved)\n",
561
+ "\n",
562
+ " tier_json_success += agent.json_success\n",
563
+ " tier_total_actions += agent.json_success + agent.regex_fallback + agent.safe_idle\n",
564
+ "\n",
565
+ " print(f' seed={seed}: reward={total_reward:+.2f}, pop_saved={pop_saved:.0f}%')\n",
566
+ "\n",
567
+ " json_rate = 100.0 * tier_json_success / tier_total_actions if tier_total_actions > 0 else 0\n",
568
+ " results[tier] = {\n",
569
+ " 'mean': float(np.mean(tier_rewards)),\n",
570
+ " 'std': float(np.std(tier_rewards)),\n",
571
+ " 'pop_saved_pct': float(np.mean(tier_pop_saved)),\n",
572
+ " 'json_success_rate': json_rate,\n",
573
+ " }\n",
574
+ "\n",
575
+ "print()\n",
576
+ "print('=' * 65)\n",
577
+ "print('=== Evaluation: Trained Model vs Baselines ===')\n",
578
+ "print('Seeds: 42-56 (15 per tier)')\n",
579
+ "print('=' * 65)\n",
580
+ "header = f'{\"Tier\":<10} {\"Trained\":>12} {\"Heuristic\":>12} {\"Random\":>12} {\"vs Heur.\":>12}'\n",
581
+ "print(header)\n",
582
+ "print('-' * 65)\n",
583
+ "\n",
584
+ "any_tier_close = False\n",
585
+ "for tier in TIERS:\n",
586
+ " t = results[tier]\n",
587
+ " h_mean = baselines['heuristic'][tier]['mean']\n",
588
+ " h_std = baselines['heuristic'][tier]['std']\n",
589
+ " r_mean = baselines['random'][tier]['mean']\n",
590
+ " r_std = baselines['random'][tier]['std']\n",
591
+ " delta = t['mean'] - h_mean\n",
592
+ " marker = ' OK' if delta >= -1.0 else ''\n",
593
+ " if delta >= -1.0:\n",
594
+ " any_tier_close = True\n",
595
+ " print(\n",
596
+ " f'{tier:<10} '\n",
597
+ " f'{t[\"mean\"]:+.2f}+/-{t[\"std\"]:.1f} '\n",
598
+ " f'{h_mean:+.2f}+/-{h_std:.1f} '\n",
599
+ " f'{r_mean:+.2f}+/-{r_std:.1f} '\n",
600
+ " f'{delta:+.2f}{marker}'\n",
601
+ " )\n",
602
+ "\n",
603
+ "print()\n",
604
+ "print('JSON success rate: ', end='')\n",
605
+ "print(' '.join(f'{t}={results[t][\"json_success_rate\"]:.1f}%' for t in TIERS))\n",
606
+ "print('Pop saved rate: ', end='')\n",
607
+ "print(' '.join(f'{t}={results[t][\"pop_saved_pct\"]:.0f}%' for t in TIERS))\n",
608
+ "\n",
609
+ "assert any_tier_close, (\n",
610
+ " 'Trained model did not come within 1.0 of heuristic on any tier. '\n",
611
+ " 'Check training logs and sample completions.'\n",
612
+ ")\n",
613
+ "print('\\nPASS: At least one tier within 1.0 of heuristic baseline.')\n",
614
+ "\n",
615
+ "FastLanguageModel.for_training(model)"
616
+ ]
617
+ },
618
+ {
619
+ "cell_type": "markdown",
620
+ "metadata": {},
621
+ "source": [
622
+ "## Section 11 - Save and Push"
623
+ ]
624
+ },
625
+ {
626
+ "cell_type": "code",
627
+ "execution_count": null,
628
+ "metadata": {},
629
+ "outputs": [],
630
+ "source": [
631
+ "model.save_pretrained('./grpo_final')\n",
632
+ "tokenizer.save_pretrained('./grpo_final')\n",
633
+ "print('Saved to ./grpo_final')"
634
+ ]
635
+ },
636
+ {
637
+ "cell_type": "code",
638
+ "execution_count": null,
639
+ "metadata": {},
640
+ "outputs": [],
641
+ "source": [
642
+ "HF_USERNAME = 'Eshit' # <-- CHANGE THIS\n",
643
+ "model.push_to_hub(f'{HF_USERNAME}/wildfire-grpo-7b')\n",
644
+ "tokenizer.push_to_hub(f'{HF_USERNAME}/wildfire-grpo-7b')\n",
645
+ "print(f'Pushed to hub: {HF_USERNAME}/wildfire-grpo-7b')"
646
+ ]
647
+ },
648
+ {
649
+ "cell_type": "code",
650
+ "execution_count": null,
651
+ "metadata": {},
652
+ "outputs": [],
653
+ "source": [
654
+ "!zip -r grpo_final.zip ./grpo_final\n",
655
+ "from google.colab import files\n",
656
+ "files.download('grpo_final.zip')\n",
657
+ "print('Download started.')"
658
+ ]
659
+ }
660
+ ],
661
+ "metadata": {
662
+ "accelerator": "GPU",
663
+ "colab": {
664
+ "gpuType": "A100",
665
+ "provenance": []
666
+ },
667
+ "kernelspec": {
668
+ "display_name": "Python 3",
669
+ "name": "python3"
670
+ },
671
+ "language_info": {
672
+ "name": "python",
673
+ "version": "3.10.0"
674
+ }
675
+ },
676
+ "nbformat": 4,
677
+ "nbformat_minor": 5
678
+ }
training/sft_colab.ipynb ADDED
@@ -0,0 +1,387 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Wildfire Incident Command — SFT Training\n",
8
+ "\n",
9
+ "Supervised fine-tuning of **Qwen2.5-7B-Instruct** on wildfire incident command data.\n",
10
+ "\n",
11
+ "- **Input:** `training/sft_data.jsonl` (generated by `scripts/generate_sft_data.py`)\n",
12
+ "- **Goal:** Teach the model to output valid JSON action objects given wildfire observations\n",
13
+ "- **Hardware:** A100 40GB on Colab"
14
+ ]
15
+ },
16
+ {
17
+ "cell_type": "markdown",
18
+ "metadata": {},
19
+ "source": [
20
+ "## Section 1 — Install Dependencies"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": null,
26
+ "metadata": {},
27
+ "outputs": [],
28
+ "source": [
29
+ "!pip install \"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git\"\n",
30
+ "!pip install trl==0.15.2 datasets==3.4.1"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": null,
36
+ "metadata": {},
37
+ "outputs": [],
38
+ "source": [
39
+ "import torch\n",
40
+ "assert torch.cuda.is_available(), \"GPU not available — switch to a GPU runtime\"\n",
41
+ "gpu_name = torch.cuda.get_device_name(0)\n",
42
+ "gpu_mem = torch.cuda.get_device_properties(0).total_mem / 1e9\n",
43
+ "print(f\"GPU: {gpu_name} | VRAM: {gpu_mem:.1f} GB\")"
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "markdown",
48
+ "metadata": {},
49
+ "source": [
50
+ "## Section 2 — Load Model"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "code",
55
+ "execution_count": null,
56
+ "metadata": {},
57
+ "outputs": [],
58
+ "source": [
59
+ "from unsloth import FastLanguageModel\n",
60
+ "\n",
61
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
62
+ " model_name=\"unsloth/Qwen2.5-7B-Instruct\",\n",
63
+ " max_seq_length=2048,\n",
64
+ " load_in_4bit=True,\n",
65
+ ")\n",
66
+ "\n",
67
+ "if tokenizer.pad_token is None:\n",
68
+ " tokenizer.pad_token = tokenizer.eos_token\n",
69
+ "\n",
70
+ "model = FastLanguageModel.get_peft_model(\n",
71
+ " model,\n",
72
+ " r=32,\n",
73
+ " lora_alpha=64,\n",
74
+ " lora_dropout=0.05,\n",
75
+ " target_modules=[\n",
76
+ " 'q_proj', 'k_proj', 'v_proj', 'o_proj',\n",
77
+ " 'gate_proj', 'up_proj', 'down_proj',\n",
78
+ " ],\n",
79
+ " bias=\"none\",\n",
80
+ " use_gradient_checkpointing=\"unsloth\",\n",
81
+ ")\n",
82
+ "\n",
83
+ "print(f\"Model loaded. Trainable params: {model.print_trainable_parameters()}\")"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "markdown",
88
+ "metadata": {},
89
+ "source": [
90
+ "## Section 3 — Load Data"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "execution_count": null,
96
+ "metadata": {},
97
+ "outputs": [],
98
+ "source": [
99
+ "import json\n",
100
+ "from datasets import Dataset\n",
101
+ "from collections import Counter\n",
102
+ "\n",
103
+ "SFT_DATA_PATH = \"training/sft_data.jsonl\"\n",
104
+ "\n",
105
+ "raw_examples = []\n",
106
+ "with open(SFT_DATA_PATH, \"r\", encoding=\"utf-8\") as f:\n",
107
+ " for line in f:\n",
108
+ " raw_examples.append(json.loads(line))\n",
109
+ "\n",
110
+ "print(f\"Loaded {len(raw_examples)} raw examples\")\n",
111
+ "\n",
112
+ "tier_dist = Counter(ex[\"tier\"] for ex in raw_examples)\n",
113
+ "print(f\"Tier distribution: {dict(tier_dist)}\")"
114
+ ]
115
+ },
116
+ {
117
+ "cell_type": "code",
118
+ "execution_count": null,
119
+ "metadata": {},
120
+ "outputs": [],
121
+ "source": [
122
+ "def format_example(ex):\n",
123
+ " \"\"\"Format a single SFT example into a full conversation string for causal LM loss.\"\"\"\n",
124
+ " messages = ex[\"messages\"]\n",
125
+ " completion = ex[\"completion\"]\n",
126
+ "\n",
127
+ " prompt_str = tokenizer.apply_chat_template(\n",
128
+ " messages,\n",
129
+ " tokenize=False,\n",
130
+ " add_generation_prompt=True,\n",
131
+ " )\n",
132
+ " full_text = prompt_str + completion + tokenizer.eos_token\n",
133
+ " return {\"text\": full_text}\n",
134
+ "\n",
135
+ "\n",
136
+ "formatted = [format_example(ex) for ex in raw_examples]\n",
137
+ "dataset = Dataset.from_list(formatted)\n",
138
+ "\n",
139
+ "split = dataset.train_test_split(test_size=0.05, seed=42)\n",
140
+ "train_dataset = split[\"train\"]\n",
141
+ "val_dataset = split[\"test\"]\n",
142
+ "\n",
143
+ "print(f\"Train: {len(train_dataset)} | Val: {len(val_dataset)}\")\n",
144
+ "print(f\"\\nSample (first 500 chars):\\n{formatted[0]['text'][:500]}\")"
145
+ ]
146
+ },
147
+ {
148
+ "cell_type": "markdown",
149
+ "metadata": {},
150
+ "source": [
151
+ "## Section 4 — Train"
152
+ ]
153
+ },
154
+ {
155
+ "cell_type": "code",
156
+ "execution_count": null,
157
+ "metadata": {},
158
+ "outputs": [],
159
+ "source": [
160
+ "from trl import SFTTrainer\n",
161
+ "from transformers import TrainingArguments\n",
162
+ "\n",
163
+ "trainer = SFTTrainer(\n",
164
+ " model=model,\n",
165
+ " tokenizer=tokenizer,\n",
166
+ " train_dataset=train_dataset,\n",
167
+ " eval_dataset=val_dataset,\n",
168
+ " dataset_text_field=\"text\",\n",
169
+ " max_seq_length=2048,\n",
170
+ " packing=True,\n",
171
+ " args=TrainingArguments(\n",
172
+ " output_dir=\"./sft_checkpoints\",\n",
173
+ " per_device_train_batch_size=2,\n",
174
+ " gradient_accumulation_steps=4,\n",
175
+ " num_train_epochs=1,\n",
176
+ " learning_rate=2e-4,\n",
177
+ " warmup_ratio=0.05,\n",
178
+ " lr_scheduler_type=\"cosine\",\n",
179
+ " logging_steps=10,\n",
180
+ " save_steps=100,\n",
181
+ " save_total_limit=2,\n",
182
+ " eval_strategy=\"steps\",\n",
183
+ " eval_steps=100,\n",
184
+ " fp16=not torch.cuda.is_bf16_supported(),\n",
185
+ " bf16=torch.cuda.is_bf16_supported(),\n",
186
+ " report_to=\"none\",\n",
187
+ " optim=\"adamw_8bit\",\n",
188
+ " seed=42,\n",
189
+ " ),\n",
190
+ ")\n",
191
+ "\n",
192
+ "print(\"Starting SFT training...\")\n",
193
+ "trainer.train()\n",
194
+ "print(\"SFT training complete.\")"
195
+ ]
196
+ },
197
+ {
198
+ "cell_type": "markdown",
199
+ "metadata": {},
200
+ "source": [
201
+ "## Section 5 — Quick Eval\n",
202
+ "\n",
203
+ "Run 10 full episodes on easy tier with the trained model driving every step.\n",
204
+ "Requires env imports — upload the repo or clone it."
205
+ ]
206
+ },
207
+ {
208
+ "cell_type": "code",
209
+ "execution_count": null,
210
+ "metadata": {},
211
+ "outputs": [],
212
+ "source": [
213
+ "import sys, os\n",
214
+ "\n",
215
+ "# Adjust this path to wherever the repo root is in Colab\n",
216
+ "REPO_ROOT = \".\" # or e.g. \"/content/Wildfire-Containment-Simulator-main\"\n",
217
+ "if REPO_ROOT not in sys.path:\n",
218
+ " sys.path.insert(0, REPO_ROOT)\n",
219
+ "\n",
220
+ "from env.wildfire_env import WildfireEnv\n",
221
+ "from env.serialization import serialize_observation\n",
222
+ "from env.action_parser import parse_action\n",
223
+ "from env.models import TIER_EASY\n",
224
+ "\n",
225
+ "SYSTEM_PROMPT = (\n",
226
+ " \"You are an AI Incident Commander managing wildfire containment. \"\n",
227
+ " \"You will receive a situation briefing each step. \"\n",
228
+ " \"Respond with ONLY a valid JSON action object and nothing else. \"\n",
229
+ " 'Example: {\"action_type\": \"idle\"}'\n",
230
+ ")\n",
231
+ "\n",
232
+ "print(\"Env imports OK\")"
233
+ ]
234
+ },
235
+ {
236
+ "cell_type": "code",
237
+ "execution_count": null,
238
+ "metadata": {},
239
+ "outputs": [],
240
+ "source": [
241
+ "import numpy as np\n",
242
+ "\n",
243
+ "FastLanguageModel.for_inference(model)\n",
244
+ "\n",
245
+ "EVAL_SEEDS = range(42, 52)\n",
246
+ "TIER = \"easy\"\n",
247
+ "MAX_STEPS = TIER_EASY.episode_length\n",
248
+ "\n",
249
+ "rewards = []\n",
250
+ "pop_saved_pcts = []\n",
251
+ "parse_counts = {\"json_success\": 0, \"regex_fallback\": 0, \"safe_idle\": 0}\n",
252
+ "total_steps = 0\n",
253
+ "\n",
254
+ "for seed in EVAL_SEEDS:\n",
255
+ " env = WildfireEnv()\n",
256
+ " obs = env.reset(task_id=TIER, seed=seed)\n",
257
+ " episode_reward = 0.0\n",
258
+ " step_num = 0\n",
259
+ " prev_burning = 0\n",
260
+ "\n",
261
+ " while not env.done:\n",
262
+ " prompt = serialize_observation(\n",
263
+ " obs, step_num, MAX_STEPS,\n",
264
+ " tier=TIER, prev_cells_burning=prev_burning,\n",
265
+ " )\n",
266
+ " prev_burning = obs.stats.cells_burning\n",
267
+ "\n",
268
+ " messages = [\n",
269
+ " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
270
+ " {\"role\": \"user\", \"content\": prompt},\n",
271
+ " ]\n",
272
+ " input_ids = tokenizer.apply_chat_template(\n",
273
+ " messages, tokenize=True,\n",
274
+ " add_generation_prompt=True, return_tensors=\"pt\",\n",
275
+ " ).to(model.device)\n",
276
+ "\n",
277
+ " with torch.no_grad():\n",
278
+ " out = model.generate(\n",
279
+ " input_ids,\n",
280
+ " max_new_tokens=128,\n",
281
+ " pad_token_id=tokenizer.eos_token_id,\n",
282
+ " )\n",
283
+ " text = tokenizer.decode(out[0][input_ids.shape[1]:], skip_special_tokens=True)\n",
284
+ "\n",
285
+ " action, status = parse_action(text, obs)\n",
286
+ " parse_counts[status] = parse_counts.get(status, 0) + 1\n",
287
+ "\n",
288
+ " result = env.step(action)\n",
289
+ " episode_reward += result.reward\n",
290
+ " obs = result.observation\n",
291
+ " step_num += 1\n",
292
+ "\n",
293
+ " total_steps += step_num\n",
294
+ " rewards.append(episode_reward)\n",
295
+ "\n",
296
+ " state = env.state()\n",
297
+ " total_pop = state[\"total_population\"]\n",
298
+ " pop_lost = state[\"population_lost\"]\n",
299
+ " pop_saved = 100.0 * (total_pop - pop_lost) / total_pop if total_pop > 0 else 100.0\n",
300
+ " pop_saved_pcts.append(pop_saved)\n",
301
+ "\n",
302
+ " print(f\" Seed {seed}: reward={episode_reward:+.2f}, steps={step_num}, pop_saved={pop_saved:.0f}%\")\n",
303
+ "\n",
304
+ "mean_reward = np.mean(rewards)\n",
305
+ "std_reward = np.std(rewards)\n",
306
+ "total_parses = sum(parse_counts.values())\n",
307
+ "json_rate = 100.0 * parse_counts[\"json_success\"] / total_parses if total_parses > 0 else 0\n",
308
+ "mean_pop = np.mean(pop_saved_pcts)\n",
309
+ "\n",
310
+ "print(f\"\\n{'='*50}\")\n",
311
+ "print(f\"SFT Quick Eval — {TIER} tier, seeds {EVAL_SEEDS.start}-{EVAL_SEEDS.stop-1}\")\n",
312
+ "print(f\"Mean reward: {mean_reward:+.2f} ± {std_reward:.2f}\")\n",
313
+ "print(f\"JSON success rate: {json_rate:.1f}%\")\n",
314
+ "print(f\"Mean pop saved: {mean_pop:.1f}%\")\n",
315
+ "print(f\"Parse breakdown: {dict(parse_counts)}\")\n",
316
+ "print(f\"{'='*50}\")\n",
317
+ "\n",
318
+ "assert mean_reward > 2.0, (\n",
319
+ " f\"SFT warm-up insufficient (mean_reward={mean_reward:.2f}) — do not proceed to GRPO\"\n",
320
+ ")\n",
321
+ "print(\"\\n✓ SFT checkpoint passes warm-up gate. Safe to proceed to GRPO.\")\n",
322
+ "\n",
323
+ "FastLanguageModel.for_training(model)"
324
+ ]
325
+ },
326
+ {
327
+ "cell_type": "markdown",
328
+ "metadata": {},
329
+ "source": [
330
+ "## Section 6 — Save & Export"
331
+ ]
332
+ },
333
+ {
334
+ "cell_type": "code",
335
+ "execution_count": null,
336
+ "metadata": {},
337
+ "outputs": [],
338
+ "source": [
339
+ "model.save_pretrained(\"./sft_final\")\n",
340
+ "tokenizer.save_pretrained(\"./sft_final\")\n",
341
+ "print(\"Saved to ./sft_final\")"
342
+ ]
343
+ },
344
+ {
345
+ "cell_type": "code",
346
+ "execution_count": null,
347
+ "metadata": {},
348
+ "outputs": [],
349
+ "source": [
350
+ "# Push to HuggingFace Hub — replace with your username\n",
351
+ "HF_USERNAME = \"Eshit\"\n",
352
+ "model.push_to_hub(f\"{HF_USERNAME}/wildfire-sft-7b\")\n",
353
+ "tokenizer.push_to_hub(f\"{HF_USERNAME}/wildfire-sft-7b\")\n",
354
+ "print(f\"Pushed to hub: {HF_USERNAME}/wildfire-sft-7b\")"
355
+ ]
356
+ },
357
+ {
358
+ "cell_type": "code",
359
+ "execution_count": null,
360
+ "metadata": {},
361
+ "outputs": [],
362
+ "source": [
363
+ "!zip -r sft_final.zip ./sft_final\n",
364
+ "from google.colab import files\n",
365
+ "files.download(\"sft_final.zip\")\n",
366
+ "print(\"Download started.\")"
367
+ ]
368
+ }
369
+ ],
370
+ "metadata": {
371
+ "accelerator": "GPU",
372
+ "colab": {
373
+ "gpuType": "A100",
374
+ "provenance": []
375
+ },
376
+ "kernelspec": {
377
+ "display_name": "Python 3",
378
+ "name": "python3"
379
+ },
380
+ "language_info": {
381
+ "name": "python",
382
+ "version": "3.10.0"
383
+ }
384
+ },
385
+ "nbformat": 4,
386
+ "nbformat_minor": 5
387
+ }
training/sft_data.jsonl ADDED
The diff for this file is too large to render. See raw diff