Improve wildfire metrics and training assets
Browse files- env/models.py +12 -2
- env/serialization.py +26 -6
- env/wildfire_env.py +75 -17
- frontend/app.js +120 -24
- frontend/index.html +26 -6
- frontend/style.css +14 -2
- graders/grader_easy.py +1 -1
- graders/grader_hard.py +1 -1
- graders/grader_medium.py +1 -1
- scripts/generate_sft_data.py +183 -0
- scripts/results.json +65 -95
- training/grpo_v2_colab.ipynb +678 -0
- training/sft_colab.ipynb +387 -0
- training/sft_data.jsonl +0 -0
env/models.py
CHANGED
|
@@ -288,7 +288,13 @@ class ClusterStats(BaseModel):
|
|
| 288 |
cells_saved: int = 0
|
| 289 |
population_threatened: int = 0
|
| 290 |
population_lost: int = 0
|
|
|
|
| 291 |
containment_pct: float = Field(ge=0.0, le=100.0, default=0.0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
current_step: int = 0
|
| 293 |
max_steps: int = 100
|
| 294 |
firebreaks_built: int = 0
|
|
@@ -342,6 +348,7 @@ class TierConfig(BaseModel):
|
|
| 342 |
crew_loss_step: Optional[int] = None
|
| 343 |
crew_loss_id: Optional[str] = None
|
| 344 |
tanker_cooldown: int = 5
|
|
|
|
| 345 |
wind_speed_init: float = 10.0
|
| 346 |
wind_dir_init: float = 0.0
|
| 347 |
humidity_init: float = 40.0
|
|
@@ -367,11 +374,12 @@ TIER_EASY = TierConfig(
|
|
| 367 |
firebreak_budget=15,
|
| 368 |
recon_budget=0,
|
| 369 |
episode_length=80,
|
| 370 |
-
num_ignition_points=
|
| 371 |
enable_smoke_occlusion=False,
|
| 372 |
enable_sensor_noise=False,
|
| 373 |
enable_fog_of_war=False,
|
| 374 |
enable_wind_shifts=False,
|
|
|
|
| 375 |
wind_speed_init=10.0,
|
| 376 |
wind_dir_init=0.0,
|
| 377 |
humidity_init=40.0,
|
|
@@ -391,11 +399,12 @@ TIER_MEDIUM = TierConfig(
|
|
| 391 |
firebreak_budget=20,
|
| 392 |
recon_budget=1,
|
| 393 |
episode_length=150,
|
| 394 |
-
num_ignition_points=
|
| 395 |
enable_smoke_occlusion=True,
|
| 396 |
enable_sensor_noise=True,
|
| 397 |
enable_fog_of_war=False,
|
| 398 |
enable_wind_shifts=True,
|
|
|
|
| 399 |
wind_speed_init=15.0,
|
| 400 |
wind_dir_init=45.0,
|
| 401 |
humidity_init=35.0,
|
|
@@ -417,6 +426,7 @@ TIER_HARD = TierConfig(
|
|
| 417 |
episode_length=300,
|
| 418 |
num_ignition_points=3,
|
| 419 |
staggered_ignition_step=30,
|
|
|
|
| 420 |
enable_smoke_occlusion=True,
|
| 421 |
enable_sensor_noise=True,
|
| 422 |
enable_fog_of_war=True,
|
|
|
|
| 288 |
cells_saved: int = 0
|
| 289 |
population_threatened: int = 0
|
| 290 |
population_lost: int = 0
|
| 291 |
+
total_population: int = Field(ge=0, default=0, description="Initial population (for UI % civ safe)")
|
| 292 |
containment_pct: float = Field(ge=0.0, le=100.0, default=0.0)
|
| 293 |
+
# Meaningful progress metrics shown to agent and display
|
| 294 |
+
area_saved_pct: float = Field(ge=0.0, le=100.0, default=100.0,
|
| 295 |
+
description="Percentage of burnable land not yet burned")
|
| 296 |
+
civilians_saved_pct: float = Field(ge=0.0, le=100.0, default=100.0,
|
| 297 |
+
description="Percentage of civilians in unburned zones")
|
| 298 |
current_step: int = 0
|
| 299 |
max_steps: int = 100
|
| 300 |
firebreaks_built: int = 0
|
|
|
|
| 348 |
crew_loss_step: Optional[int] = None
|
| 349 |
crew_loss_id: Optional[str] = None
|
| 350 |
tanker_cooldown: int = 5
|
| 351 |
+
min_active_steps: int = 5 # episode cannot end via fire-out before this step
|
| 352 |
wind_speed_init: float = 10.0
|
| 353 |
wind_dir_init: float = 0.0
|
| 354 |
humidity_init: float = 40.0
|
|
|
|
| 374 |
firebreak_budget=15,
|
| 375 |
recon_budget=0,
|
| 376 |
episode_length=80,
|
| 377 |
+
num_ignition_points=2,
|
| 378 |
enable_smoke_occlusion=False,
|
| 379 |
enable_sensor_noise=False,
|
| 380 |
enable_fog_of_war=False,
|
| 381 |
enable_wind_shifts=False,
|
| 382 |
+
min_active_steps=25,
|
| 383 |
wind_speed_init=10.0,
|
| 384 |
wind_dir_init=0.0,
|
| 385 |
humidity_init=40.0,
|
|
|
|
| 399 |
firebreak_budget=20,
|
| 400 |
recon_budget=1,
|
| 401 |
episode_length=150,
|
| 402 |
+
num_ignition_points=3,
|
| 403 |
enable_smoke_occlusion=True,
|
| 404 |
enable_sensor_noise=True,
|
| 405 |
enable_fog_of_war=False,
|
| 406 |
enable_wind_shifts=True,
|
| 407 |
+
min_active_steps=45,
|
| 408 |
wind_speed_init=15.0,
|
| 409 |
wind_dir_init=45.0,
|
| 410 |
humidity_init=35.0,
|
|
|
|
| 426 |
episode_length=300,
|
| 427 |
num_ignition_points=3,
|
| 428 |
staggered_ignition_step=30,
|
| 429 |
+
min_active_steps=80,
|
| 430 |
enable_smoke_occlusion=True,
|
| 431 |
enable_sensor_noise=True,
|
| 432 |
enable_fog_of_war=True,
|
env/serialization.py
CHANGED
|
@@ -12,8 +12,14 @@ if TYPE_CHECKING:
|
|
| 12 |
from .models import FireState, IntensityBin
|
| 13 |
|
| 14 |
|
| 15 |
-
def serialize_observation(
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
grid_summary = _summarize_grid_regions(obs.grid)
|
| 18 |
resources = _format_resources(obs.resources)
|
| 19 |
events = _format_events(obs.recent_events)
|
|
@@ -29,8 +35,9 @@ def serialize_observation(obs: "Observation", step_num: int, max_steps: int) ->
|
|
| 29 |
parts.append(obs._briefing_reminder)
|
| 30 |
parts.append("")
|
| 31 |
|
|
|
|
| 32 |
parts += [
|
| 33 |
-
f"=== WILDFIRE INCIDENT COMMAND — STEP {step_num}/{max_steps} ===",
|
| 34 |
"",
|
| 35 |
"SITUATION:",
|
| 36 |
situation,
|
|
@@ -52,12 +59,14 @@ def serialize_observation(obs: "Observation", step_num: int, max_steps: int) ->
|
|
| 52 |
|
| 53 |
# ── Situation block ──────────────────────────────────────────
|
| 54 |
|
| 55 |
-
def _format_situation(obs: "Observation") -> str:
|
| 56 |
stats = obs.stats
|
| 57 |
w = obs.weather
|
| 58 |
|
| 59 |
burning = stats.cells_burning
|
| 60 |
-
|
|
|
|
|
|
|
| 61 |
pop_at_risk = stats.population_threatened
|
| 62 |
|
| 63 |
wind_dir = _deg_to_compass(w.wind_direction_deg)
|
|
@@ -65,8 +74,19 @@ def _format_situation(obs: "Observation") -> str:
|
|
| 65 |
|
| 66 |
last_event = obs.recent_events[-1] if obs.recent_events else "None"
|
| 67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
lines = [
|
| 69 |
-
f"- Fire active on {burning} cells.
|
|
|
|
|
|
|
| 70 |
f"- Wind: {w.wind_speed_kmh:.0f} km/h {wind_dir} (±5 km/h noise). Humidity: {w.humidity_pct:.0f}%. Rain: {rain}.",
|
| 71 |
f"- Last event: {last_event}",
|
| 72 |
]
|
|
|
|
| 12 |
from .models import FireState, IntensityBin
|
| 13 |
|
| 14 |
|
| 15 |
+
def serialize_observation(
|
| 16 |
+
obs: "Observation",
|
| 17 |
+
step_num: int,
|
| 18 |
+
max_steps: int,
|
| 19 |
+
tier: str = "",
|
| 20 |
+
prev_cells_burning: int = 0,
|
| 21 |
+
) -> str:
|
| 22 |
+
situation = _format_situation(obs, prev_cells_burning)
|
| 23 |
grid_summary = _summarize_grid_regions(obs.grid)
|
| 24 |
resources = _format_resources(obs.resources)
|
| 25 |
events = _format_events(obs.recent_events)
|
|
|
|
| 35 |
parts.append(obs._briefing_reminder)
|
| 36 |
parts.append("")
|
| 37 |
|
| 38 |
+
tier_str = f" [{tier.upper()}]" if tier else ""
|
| 39 |
parts += [
|
| 40 |
+
f"=== WILDFIRE INCIDENT COMMAND{tier_str} — STEP {step_num}/{max_steps} ===",
|
| 41 |
"",
|
| 42 |
"SITUATION:",
|
| 43 |
situation,
|
|
|
|
| 59 |
|
| 60 |
# ── Situation block ──────────────────────────────────────────
|
| 61 |
|
| 62 |
+
def _format_situation(obs: "Observation", prev_cells_burning: int = 0) -> str:
|
| 63 |
stats = obs.stats
|
| 64 |
w = obs.weather
|
| 65 |
|
| 66 |
burning = stats.cells_burning
|
| 67 |
+
land_saved = round(stats.area_saved_pct, 1)
|
| 68 |
+
civ_safe = round(stats.civilians_saved_pct, 1)
|
| 69 |
+
cells_burned = stats.cells_burned
|
| 70 |
pop_at_risk = stats.population_threatened
|
| 71 |
|
| 72 |
wind_dir = _deg_to_compass(w.wind_direction_deg)
|
|
|
|
| 74 |
|
| 75 |
last_event = obs.recent_events[-1] if obs.recent_events else "None"
|
| 76 |
|
| 77 |
+
# Spread delta — positive means fire is growing, negative means shrinking
|
| 78 |
+
delta = burning - prev_cells_burning
|
| 79 |
+
if delta > 0:
|
| 80 |
+
spread_str = f" (+{delta} spreading)"
|
| 81 |
+
elif delta < 0:
|
| 82 |
+
spread_str = f" ({delta} shrinking)"
|
| 83 |
+
else:
|
| 84 |
+
spread_str = " (stable)"
|
| 85 |
+
|
| 86 |
lines = [
|
| 87 |
+
f"- Fire active on {burning} cells{spread_str}. Land saved: {land_saved}% of burnable area "
|
| 88 |
+
f"({cells_burned} cells burned out). Civilians safe: {civ_safe}%. "
|
| 89 |
+
f"Population at risk: {pop_at_risk} zones.",
|
| 90 |
f"- Wind: {w.wind_speed_kmh:.0f} km/h {wind_dir} (±5 km/h noise). Humidity: {w.humidity_pct:.0f}%. Rain: {rain}.",
|
| 91 |
f"- Last event: {last_event}",
|
| 92 |
]
|
env/wildfire_env.py
CHANGED
|
@@ -212,6 +212,16 @@ class WildfireEnv:
|
|
| 212 |
|
| 213 |
self.current_step += 1
|
| 214 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
# ── Step 9: Compute reward ──
|
| 216 |
legacy_reward = self.reward_calc.compute_reward(self.grid, self.resources, self.current_step)
|
| 217 |
|
|
@@ -324,19 +334,39 @@ class WildfireEnv:
|
|
| 324 |
}
|
| 325 |
|
| 326 |
def _is_redundant(self, action: Action) -> bool:
|
| 327 |
-
"""True if action
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 328 |
if self._prev_action is None:
|
| 329 |
return False
|
| 330 |
prev = self._prev_action
|
| 331 |
if action.action_type != prev.action_type:
|
| 332 |
return False
|
| 333 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
|
| 335 |
def _ignite_initial_fires(self) -> None:
|
| 336 |
"""Place initial fire ignition points based on tier config.
|
| 337 |
|
| 338 |
Ignition candidates are shifted away from populated cells to ensure
|
| 339 |
a minimum survivable distance, reducing unwinnable-scenario variance.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
"""
|
| 341 |
rows, cols = self.config.grid_rows, self.config.grid_cols
|
| 342 |
|
|
@@ -344,19 +374,25 @@ class WildfireEnv:
|
|
| 344 |
min_pop_dist = {"easy": 4, "medium": 6, "hard": 7}.get(self.config.tier_name, 5)
|
| 345 |
|
| 346 |
if self.config.tier_name == "easy":
|
| 347 |
-
|
| 348 |
-
self.
|
|
|
|
|
|
|
|
|
|
| 349 |
elif self.config.tier_name == "medium":
|
| 350 |
-
|
| 351 |
-
|
|
|
|
| 352 |
r2, c2 = self._find_ignition_candidate(2 * rows // 3, 2 * cols // 3, min_pop_dist)
|
| 353 |
-
self.grid.ignite_cell(r2, c2, intensity=0.
|
|
|
|
|
|
|
| 354 |
else:
|
| 355 |
-
# Two initial points (third comes later via staggered ignition)
|
| 356 |
r1, c1 = self._find_ignition_candidate(rows // 4, cols // 4, min_pop_dist)
|
| 357 |
-
self.grid.ignite_cell(r1, c1, intensity=0.
|
| 358 |
r2, c2 = self._find_ignition_candidate(rows // 2, 3 * cols // 4, min_pop_dist)
|
| 359 |
-
self.grid.ignite_cell(r2, c2, intensity=0.
|
| 360 |
|
| 361 |
def _find_ignition_candidate(self, target_r: int, target_c: int, min_pop_dist: int) -> tuple[int, int]:
|
| 362 |
"""Return the nearest valid ignition cell to (target_r, target_c) that is at
|
|
@@ -482,11 +518,17 @@ class WildfireEnv:
|
|
| 482 |
# Fire fully contained (no burning cells)
|
| 483 |
burning = self.grid.count_by_state(FireState.BURNING)
|
| 484 |
ember = self.grid.count_by_state(FireState.EMBER)
|
| 485 |
-
if burning == 0 and ember == 0
|
| 486 |
-
#
|
| 487 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 488 |
and self.current_step < self.config.staggered_ignition_step):
|
| 489 |
-
return
|
|
|
|
| 490 |
|
| 491 |
# All populated zones burned (catastrophic failure)
|
| 492 |
total_pop = self.grid.get_total_population()
|
|
@@ -514,13 +556,29 @@ class WildfireEnv:
|
|
| 514 |
resource_state = self.resources.get_resource_state()
|
| 515 |
|
| 516 |
# Stats
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 517 |
stats = ClusterStats(
|
| 518 |
-
cells_burned=
|
| 519 |
cells_burning=self.grid.count_by_state(FireState.BURNING),
|
| 520 |
-
cells_saved=
|
| 521 |
population_threatened=self._count_threatened_population(),
|
| 522 |
-
population_lost=
|
|
|
|
| 523 |
containment_pct=self._compute_containment_pct(),
|
|
|
|
|
|
|
| 524 |
current_step=self.current_step,
|
| 525 |
max_steps=self.config.episode_length,
|
| 526 |
firebreaks_built=self.resources.total_firebreaks_built,
|
|
|
|
| 212 |
|
| 213 |
self.current_step += 1
|
| 214 |
|
| 215 |
+
# Log a hold-message when fire is extinguished before min_active_steps so
|
| 216 |
+
# agents (and the LLM) understand the episode must continue for monitoring.
|
| 217 |
+
burning_now = (self.grid.count_by_state(FireState.BURNING)
|
| 218 |
+
+ self.grid.count_by_state(FireState.EMBER))
|
| 219 |
+
if burning_now == 0 and self.current_step < self.config.min_active_steps:
|
| 220 |
+
step_events.append(
|
| 221 |
+
f"All fires contained. Holding perimeter until step "
|
| 222 |
+
f"{self.config.min_active_steps} (min_active_steps)."
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
# ── Step 9: Compute reward ──
|
| 226 |
legacy_reward = self.reward_calc.compute_reward(self.grid, self.resources, self.current_step)
|
| 227 |
|
|
|
|
| 334 |
}
|
| 335 |
|
| 336 |
def _is_redundant(self, action: Action) -> bool:
|
| 337 |
+
"""True if action is a meaningless repeat of the previous action.
|
| 338 |
+
|
| 339 |
+
Actions that use target coordinates (DROP_RETARDANT, DEPLOY_CREW, RECON_FLIGHT)
|
| 340 |
+
are redundant when the type + target cell match. Directional actions (MOVE_CREW,
|
| 341 |
+
BUILD_FIREBREAK) require the same crew_id AND direction to be redundant — two
|
| 342 |
+
consecutive MOVE_CREW steps by different crews, or in different directions, are
|
| 343 |
+
valid patrol behaviour and must not be penalised.
|
| 344 |
+
"""
|
| 345 |
if self._prev_action is None:
|
| 346 |
return False
|
| 347 |
prev = self._prev_action
|
| 348 |
if action.action_type != prev.action_type:
|
| 349 |
return False
|
| 350 |
+
# Coordinate-targeted actions: redundant when same cell is targeted again
|
| 351 |
+
if action.target_row is not None or prev.target_row is not None:
|
| 352 |
+
return (action.target_row == prev.target_row
|
| 353 |
+
and action.target_col == prev.target_col)
|
| 354 |
+
# Crew directional actions: redundant only when same crew moves same direction
|
| 355 |
+
if action.crew_id is not None:
|
| 356 |
+
return (action.crew_id == prev.crew_id
|
| 357 |
+
and action.direction == prev.direction)
|
| 358 |
+
return False
|
| 359 |
|
| 360 |
def _ignite_initial_fires(self) -> None:
|
| 361 |
"""Place initial fire ignition points based on tier config.
|
| 362 |
|
| 363 |
Ignition candidates are shifted away from populated cells to ensure
|
| 364 |
a minimum survivable distance, reducing unwinnable-scenario variance.
|
| 365 |
+
|
| 366 |
+
Intensity is set high enough (0.65) that a single tanker drop (-0.4)
|
| 367 |
+
leaves residual fire (0.25) so the episode cannot be solved in 1-2
|
| 368 |
+
steps. The fire must spread, be actively managed, and burn for at
|
| 369 |
+
least min_active_steps before the episode can end.
|
| 370 |
"""
|
| 371 |
rows, cols = self.config.grid_rows, self.config.grid_cols
|
| 372 |
|
|
|
|
| 374 |
min_pop_dist = {"easy": 4, "medium": 6, "hard": 7}.get(self.config.tier_name, 5)
|
| 375 |
|
| 376 |
if self.config.tier_name == "easy":
|
| 377 |
+
# Two ignition points spread across the grid so crews must split
|
| 378 |
+
r1, c1 = self._find_ignition_candidate(rows // 2, cols // 3, min_pop_dist)
|
| 379 |
+
self.grid.ignite_cell(r1, c1, intensity=0.65)
|
| 380 |
+
r2, c2 = self._find_ignition_candidate(rows // 2, 2 * cols // 3, min_pop_dist)
|
| 381 |
+
self.grid.ignite_cell(r2, c2, intensity=0.65)
|
| 382 |
elif self.config.tier_name == "medium":
|
| 383 |
+
# Three ignition points: forces genuine multi-front management
|
| 384 |
+
r1, c1 = self._find_ignition_candidate(rows // 4, cols // 3, min_pop_dist)
|
| 385 |
+
self.grid.ignite_cell(r1, c1, intensity=0.65)
|
| 386 |
r2, c2 = self._find_ignition_candidate(2 * rows // 3, 2 * cols // 3, min_pop_dist)
|
| 387 |
+
self.grid.ignite_cell(r2, c2, intensity=0.65)
|
| 388 |
+
r3, c3 = self._find_ignition_candidate(rows // 2, cols // 2, min_pop_dist)
|
| 389 |
+
self.grid.ignite_cell(r3, c3, intensity=0.65)
|
| 390 |
else:
|
| 391 |
+
# Two initial points (third comes later via staggered ignition at step 30)
|
| 392 |
r1, c1 = self._find_ignition_candidate(rows // 4, cols // 4, min_pop_dist)
|
| 393 |
+
self.grid.ignite_cell(r1, c1, intensity=0.65)
|
| 394 |
r2, c2 = self._find_ignition_candidate(rows // 2, 3 * cols // 4, min_pop_dist)
|
| 395 |
+
self.grid.ignite_cell(r2, c2, intensity=0.65)
|
| 396 |
|
| 397 |
def _find_ignition_candidate(self, target_r: int, target_c: int, min_pop_dist: int) -> tuple[int, int]:
|
| 398 |
"""Return the nearest valid ignition cell to (target_r, target_c) that is at
|
|
|
|
| 518 |
# Fire fully contained (no burning cells)
|
| 519 |
burning = self.grid.count_by_state(FireState.BURNING)
|
| 520 |
ember = self.grid.count_by_state(FireState.EMBER)
|
| 521 |
+
if burning == 0 and ember == 0:
|
| 522 |
+
# Enforce minimum active steps — prevents trivial 1-2 step episodes
|
| 523 |
+
# where a single tanker drop or natural burnout ends the episode
|
| 524 |
+
# before the agent has taken any meaningful sequence of actions.
|
| 525 |
+
if self.current_step < self.config.min_active_steps:
|
| 526 |
+
return False
|
| 527 |
+
# Don't terminate before staggered ignition fires (hard tier)
|
| 528 |
+
if (self.config.staggered_ignition_step
|
| 529 |
and self.current_step < self.config.staggered_ignition_step):
|
| 530 |
+
return False
|
| 531 |
+
return True
|
| 532 |
|
| 533 |
# All populated zones burned (catastrophic failure)
|
| 534 |
total_pop = self.grid.get_total_population()
|
|
|
|
| 556 |
resource_state = self.resources.get_resource_state()
|
| 557 |
|
| 558 |
# Stats
|
| 559 |
+
total_burnable = self.grid.get_total_burnable()
|
| 560 |
+
cells_burned = self.grid.get_burned_count()
|
| 561 |
+
total_pop = self.grid.get_total_population()
|
| 562 |
+
pop_lost = self.grid.get_population_lost()
|
| 563 |
+
|
| 564 |
+
area_saved_pct = round(
|
| 565 |
+
100.0 * (total_burnable - cells_burned) / total_burnable, 1
|
| 566 |
+
) if total_burnable > 0 else 100.0
|
| 567 |
+
|
| 568 |
+
civilians_saved_pct = round(
|
| 569 |
+
100.0 * (total_pop - pop_lost) / total_pop, 1
|
| 570 |
+
) if total_pop > 0 else 100.0
|
| 571 |
+
|
| 572 |
stats = ClusterStats(
|
| 573 |
+
cells_burned=cells_burned,
|
| 574 |
cells_burning=self.grid.count_by_state(FireState.BURNING),
|
| 575 |
+
cells_saved=total_burnable - cells_burned - self.grid.count_by_state(FireState.BURNING),
|
| 576 |
population_threatened=self._count_threatened_population(),
|
| 577 |
+
population_lost=pop_lost,
|
| 578 |
+
total_population=total_pop,
|
| 579 |
containment_pct=self._compute_containment_pct(),
|
| 580 |
+
area_saved_pct=area_saved_pct,
|
| 581 |
+
civilians_saved_pct=civilians_saved_pct,
|
| 582 |
current_step=self.current_step,
|
| 583 |
max_steps=self.config.episode_length,
|
| 584 |
firebreaks_built=self.resources.total_firebreaks_built,
|
frontend/app.js
CHANGED
|
@@ -11,6 +11,69 @@
|
|
| 11 |
|
| 12 |
"use strict";
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
// ── Simulation state ──────────────────────────────────────────────────────────
|
| 15 |
const sim = {
|
| 16 |
obs: null, // current Observation (agent's view)
|
|
@@ -161,17 +224,28 @@ function renderCanvas(obs, groundTruth = null) {
|
|
| 161 |
}
|
| 162 |
|
| 163 |
// ── Stats panel ───────────────────────────────────────────────────────────────
|
| 164 |
-
function updateStats(
|
| 165 |
-
if (!stats) return;
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
const
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
setText("stat-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
setText(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
|
| 176 |
// Cumulative reward
|
| 177 |
setText("reward-total", cumulativeReward.toFixed(3));
|
|
@@ -298,31 +372,53 @@ function updateActionLog(action) {
|
|
| 298 |
}
|
| 299 |
|
| 300 |
// ── Terminal overlay ──────────────────────────────────────────────────────────
|
| 301 |
-
function showTerminal(
|
| 302 |
const overlay = document.getElementById("terminal-overlay");
|
| 303 |
if (!overlay) return;
|
| 304 |
|
| 305 |
-
const stats = obs?.stats ?? {};
|
| 306 |
-
const popLost = stats.population_lost ?? 0;
|
| 307 |
-
const containment = stats.containment_pct ?? 0;
|
| 308 |
-
|
| 309 |
const card = document.getElementById("terminal-card");
|
|
|
|
|
|
|
|
|
|
| 310 |
const title = card.querySelector("h2");
|
| 311 |
|
| 312 |
-
if (popLost === 0) {
|
| 313 |
-
title.textContent = "✅
|
| 314 |
title.className = "win";
|
| 315 |
} else {
|
| 316 |
title.textContent = "⚠ EPISODE ENDED";
|
| 317 |
title.className = "loss";
|
| 318 |
}
|
| 319 |
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
setText("terminal-
|
| 323 |
-
setText("terminal-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
|
| 325 |
overlay.classList.add("show");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 326 |
}
|
| 327 |
|
| 328 |
function hideTerminal() {
|
|
@@ -356,7 +452,7 @@ async function apiGet(path) {
|
|
| 356 |
function applyObservation(obs) {
|
| 357 |
sim.obs = obs;
|
| 358 |
renderCanvas(obs, sim.groundTruthData);
|
| 359 |
-
updateStats(obs
|
| 360 |
updateResources(obs.resources);
|
| 361 |
updateWeather(obs.weather);
|
| 362 |
updateEvents(obs.recent_events ?? []);
|
|
@@ -417,7 +513,7 @@ async function doAutoStep() {
|
|
| 417 |
|
| 418 |
if (snap.done) {
|
| 419 |
stopPlay();
|
| 420 |
-
showTerminal(
|
| 421 |
break;
|
| 422 |
}
|
| 423 |
}
|
|
|
|
| 11 |
|
| 12 |
"use strict";
|
| 13 |
|
| 14 |
+
// ── API field helpers (snake_case from Python; tolerate camelCase if ever used) ─
|
| 15 |
+
function pickStat(obj, ...keys) {
|
| 16 |
+
if (!obj) return undefined;
|
| 17 |
+
for (const k of keys) {
|
| 18 |
+
if (Object.prototype.hasOwnProperty.call(obj, k) && obj[k] != null) {
|
| 19 |
+
return obj[k];
|
| 20 |
+
}
|
| 21 |
+
}
|
| 22 |
+
return undefined;
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
/**
|
| 26 |
+
* Build display-ready episode metrics from the latest observation.
|
| 27 |
+
* Falls back to grid-visible cells for land % only when server omits area_saved_pct.
|
| 28 |
+
*/
|
| 29 |
+
function normalizeEpisodeStats(obs) {
|
| 30 |
+
const st = obs?.stats ?? {};
|
| 31 |
+
const cellsBurned = pickStat(st, "cells_burned", "cellsBurned") ?? 0;
|
| 32 |
+
const popLost = pickStat(st, "population_lost", "populationLost") ?? 0;
|
| 33 |
+
const totalPop = pickStat(st, "total_population", "totalPopulation") ?? 0;
|
| 34 |
+
|
| 35 |
+
let areaSaved = pickStat(st, "area_saved_pct", "areaSavedPct");
|
| 36 |
+
let civSafe = pickStat(st, "civilians_saved_pct", "civiliansSavedPct");
|
| 37 |
+
|
| 38 |
+
if (areaSaved == null && obs?.grid?.length) {
|
| 39 |
+
let burnable = 0;
|
| 40 |
+
let burnedVis = 0;
|
| 41 |
+
for (const row of obs.grid) {
|
| 42 |
+
for (const cell of row) {
|
| 43 |
+
const f = cell.fuel_type;
|
| 44 |
+
if (!f || f === "water" || f === "road") continue;
|
| 45 |
+
if (cell.fire_state === "unknown") continue;
|
| 46 |
+
burnable++;
|
| 47 |
+
if (cell.fire_state === "burned_out") burnedVis++;
|
| 48 |
+
}
|
| 49 |
+
}
|
| 50 |
+
if (burnable > 0) {
|
| 51 |
+
areaSaved = Math.round(1000 * (burnable - burnedVis) / burnable) / 10;
|
| 52 |
+
}
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
if (civSafe == null && totalPop > 0) {
|
| 56 |
+
civSafe = Math.round(1000 * (totalPop - popLost) / totalPop) / 10;
|
| 57 |
+
} else if (civSafe == null && popLost === 0) {
|
| 58 |
+
civSafe = 100.0;
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
const containment = pickStat(st, "containment_pct", "containmentPct");
|
| 62 |
+
if (areaSaved == null && containment != null) {
|
| 63 |
+
areaSaved = containment;
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
return {
|
| 67 |
+
areaSaved,
|
| 68 |
+
civSafe,
|
| 69 |
+
cellsBurned,
|
| 70 |
+
popLost,
|
| 71 |
+
totalPop,
|
| 72 |
+
currentStep: pickStat(st, "current_step", "currentStep"),
|
| 73 |
+
raw: st,
|
| 74 |
+
};
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
// ── Simulation state ──────────────────────────────────────────────────────────
|
| 78 |
const sim = {
|
| 79 |
obs: null, // current Observation (agent's view)
|
|
|
|
| 224 |
}
|
| 225 |
|
| 226 |
// ── Stats panel ───────────────────────────────────────────────────────────────
|
| 227 |
+
function updateStats(obs, cumulativeReward, lastStepReward) {
|
| 228 |
+
if (!obs?.stats) return;
|
| 229 |
+
const stats = obs.stats;
|
| 230 |
+
|
| 231 |
+
const cur = pickStat(stats, "current_step", "currentStep") ?? 0;
|
| 232 |
+
const max = pickStat(stats, "max_steps", "maxSteps") ?? 1;
|
| 233 |
+
|
| 234 |
+
setText("stat-step", `${cur} / ${max}`);
|
| 235 |
+
|
| 236 |
+
const n = normalizeEpisodeStats(obs);
|
| 237 |
+
setText(
|
| 238 |
+
"stat-land-saved-val",
|
| 239 |
+
n.areaSaved != null ? `${Number(n.areaSaved).toFixed(1)}%` : "—"
|
| 240 |
+
);
|
| 241 |
+
setText(
|
| 242 |
+
"stat-civilians-safe-val",
|
| 243 |
+
n.civSafe != null ? `${Number(n.civSafe).toFixed(1)}%` : "—"
|
| 244 |
+
);
|
| 245 |
+
setText("stat-cells-burned-val", n.cellsBurned);
|
| 246 |
+
setText("stat-burning-val", pickStat(stats, "cells_burning", "cellsBurning") ?? 0);
|
| 247 |
+
setText("stat-pop-threat-val", pickStat(stats, "population_threatened", "populationThreatened") ?? 0);
|
| 248 |
+
setText("stat-pop-lost-val", n.popLost);
|
| 249 |
|
| 250 |
// Cumulative reward
|
| 251 |
setText("reward-total", cumulativeReward.toFixed(3));
|
|
|
|
| 372 |
}
|
| 373 |
|
| 374 |
// ── Terminal overlay ──────────────────────────────────────────────────────────
|
| 375 |
+
async function showTerminal() {
|
| 376 |
const overlay = document.getElementById("terminal-overlay");
|
| 377 |
if (!overlay) return;
|
| 378 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 379 |
const card = document.getElementById("terminal-card");
|
| 380 |
+
if (!card) return;
|
| 381 |
+
|
| 382 |
+
const n = normalizeEpisodeStats(sim.obs);
|
| 383 |
const title = card.querySelector("h2");
|
| 384 |
|
| 385 |
+
if (n.popLost === 0) {
|
| 386 |
+
title.textContent = "✅ EPISODE COMPLETE";
|
| 387 |
title.className = "win";
|
| 388 |
} else {
|
| 389 |
title.textContent = "⚠ EPISODE ENDED";
|
| 390 |
title.className = "loss";
|
| 391 |
}
|
| 392 |
|
| 393 |
+
const landStr = n.areaSaved != null ? `${Number(n.areaSaved).toFixed(1)}%` : "—";
|
| 394 |
+
const civStr = n.civSafe != null ? `${Number(n.civSafe).toFixed(1)}%` : "—";
|
| 395 |
+
setText("terminal-land-saved", landStr);
|
| 396 |
+
setText("terminal-civilians-safe", civStr);
|
| 397 |
+
setText("terminal-cells-burned", String(n.cellsBurned));
|
| 398 |
+
setText("terminal-pop-lost", n.popLost);
|
| 399 |
+
setText("terminal-reward", sim.cumulativeReward.toFixed(3));
|
| 400 |
+
setText("terminal-step", n.currentStep ?? "—");
|
| 401 |
|
| 402 |
overlay.classList.add("show");
|
| 403 |
+
|
| 404 |
+
// Authoritative end-game numbers (ground truth — fixes blank UI if observation JSON differed)
|
| 405 |
+
try {
|
| 406 |
+
const st = await apiGet("/state");
|
| 407 |
+
if (st.error) return;
|
| 408 |
+
const tb = st.total_burnable ?? 0;
|
| 409 |
+
const burned = st.cells_burned ?? 0;
|
| 410 |
+
const landPct = tb > 0 ? Math.round(1000 * (tb - burned) / tb) / 10 : 100;
|
| 411 |
+
const tp = st.total_population ?? 0;
|
| 412 |
+
const lost = st.population_lost ?? 0;
|
| 413 |
+
const civPct = tp > 0 ? Math.round(1000 * (tp - lost) / tp) / 10 : 100;
|
| 414 |
+
setText("terminal-land-saved", `${landPct}%`);
|
| 415 |
+
setText("terminal-civilians-safe", `${civPct}%`);
|
| 416 |
+
setText("terminal-cells-burned", String(burned));
|
| 417 |
+
setText("terminal-pop-lost", String(lost));
|
| 418 |
+
setText("terminal-step", st.current_step ?? "—");
|
| 419 |
+
} catch (e) {
|
| 420 |
+
console.warn("Could not refresh end-game stats from /state", e);
|
| 421 |
+
}
|
| 422 |
}
|
| 423 |
|
| 424 |
function hideTerminal() {
|
|
|
|
| 452 |
function applyObservation(obs) {
|
| 453 |
sim.obs = obs;
|
| 454 |
renderCanvas(obs, sim.groundTruthData);
|
| 455 |
+
updateStats(obs, sim.cumulativeReward, sim.lastStepReward);
|
| 456 |
updateResources(obs.resources);
|
| 457 |
updateWeather(obs.weather);
|
| 458 |
updateEvents(obs.recent_events ?? []);
|
|
|
|
| 513 |
|
| 514 |
if (snap.done) {
|
| 515 |
stopPlay();
|
| 516 |
+
await showTerminal();
|
| 517 |
break;
|
| 518 |
}
|
| 519 |
}
|
frontend/index.html
CHANGED
|
@@ -83,8 +83,16 @@
|
|
| 83 |
<div id="terminal-card">
|
| 84 |
<h2 class="win">✅ FIRE CONTAINED</h2>
|
| 85 |
<div class="stat-row">
|
| 86 |
-
<span>
|
| 87 |
-
<span id="terminal-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
</div>
|
| 89 |
<div class="stat-row">
|
| 90 |
<span>Population lost</span>
|
|
@@ -104,6 +112,10 @@
|
|
| 104 |
</div>
|
| 105 |
</div>
|
| 106 |
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
</main>
|
| 108 |
|
| 109 |
<!-- Sidebar -->
|
|
@@ -117,9 +129,17 @@
|
|
| 117 |
<span class="stat-label">STEP</span>
|
| 118 |
<span class="stat-value" id="stat-step">— / —</span>
|
| 119 |
</div>
|
| 120 |
-
<div class="stat-item" id="stat-
|
| 121 |
-
<span class="stat-label">
|
| 122 |
-
<span class="stat-value" id="stat-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
</div>
|
| 124 |
<div class="stat-item" id="stat-burning">
|
| 125 |
<span class="stat-label">BURNING</span>
|
|
@@ -274,6 +294,6 @@
|
|
| 274 |
</span>
|
| 275 |
</footer>
|
| 276 |
|
| 277 |
-
<script src="app.js"></script>
|
| 278 |
</body>
|
| 279 |
</html>
|
|
|
|
| 83 |
<div id="terminal-card">
|
| 84 |
<h2 class="win">✅ FIRE CONTAINED</h2>
|
| 85 |
<div class="stat-row">
|
| 86 |
+
<span>Land saved (unburned)</span>
|
| 87 |
+
<span id="terminal-land-saved">—</span>
|
| 88 |
+
</div>
|
| 89 |
+
<div class="stat-row">
|
| 90 |
+
<span>Civilians safe</span>
|
| 91 |
+
<span id="terminal-civilians-safe">—</span>
|
| 92 |
+
</div>
|
| 93 |
+
<div class="stat-row">
|
| 94 |
+
<span>Cells burned (total)</span>
|
| 95 |
+
<span id="terminal-cells-burned">—</span>
|
| 96 |
</div>
|
| 97 |
<div class="stat-row">
|
| 98 |
<span>Population lost</span>
|
|
|
|
| 112 |
</div>
|
| 113 |
</div>
|
| 114 |
</div>
|
| 115 |
+
<p id="map-legend" class="map-legend">
|
| 116 |
+
<strong>Map:</strong> green dot / circle = ground crew · blue outline = populated zone ·
|
| 117 |
+
bright blue cells = water · grey = roads
|
| 118 |
+
</p>
|
| 119 |
</main>
|
| 120 |
|
| 121 |
<!-- Sidebar -->
|
|
|
|
| 129 |
<span class="stat-label">STEP</span>
|
| 130 |
<span class="stat-value" id="stat-step">— / —</span>
|
| 131 |
</div>
|
| 132 |
+
<div class="stat-item" id="stat-land-saved">
|
| 133 |
+
<span class="stat-label">LAND SAVED</span>
|
| 134 |
+
<span class="stat-value" id="stat-land-saved-val">—</span>
|
| 135 |
+
</div>
|
| 136 |
+
<div class="stat-item" id="stat-civilians-safe">
|
| 137 |
+
<span class="stat-label">CIVILIANS SAFE</span>
|
| 138 |
+
<span class="stat-value" id="stat-civilians-safe-val">—</span>
|
| 139 |
+
</div>
|
| 140 |
+
<div class="stat-item" id="stat-cells-burned">
|
| 141 |
+
<span class="stat-label">CELLS BURNED</span>
|
| 142 |
+
<span class="stat-value" id="stat-cells-burned-val">—</span>
|
| 143 |
</div>
|
| 144 |
<div class="stat-item" id="stat-burning">
|
| 145 |
<span class="stat-label">BURNING</span>
|
|
|
|
| 294 |
</span>
|
| 295 |
</footer>
|
| 296 |
|
| 297 |
+
<script src="app.js?v=4"></script>
|
| 298 |
</body>
|
| 299 |
</html>
|
frontend/style.css
CHANGED
|
@@ -250,6 +250,16 @@ input[type="range"]::-webkit-slider-thumb {
|
|
| 250 |
|
| 251 |
#grid-canvas { display: block; image-rendering: pixelated; }
|
| 252 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
/* Tooltip overlay (shows cell info on hover) */
|
| 254 |
#cell-tooltip {
|
| 255 |
position: absolute;
|
|
@@ -356,8 +366,10 @@ input[type="range"]::-webkit-slider-thumb {
|
|
| 356 |
.stat-item.step-item { grid-column: 1 / -1; }
|
| 357 |
.stat-item.step-item .stat-value { font-size: 14px; }
|
| 358 |
|
| 359 |
-
#stat-
|
| 360 |
-
#stat-
|
|
|
|
|
|
|
| 361 |
#stat-pop-threat .stat-value { color: var(--warn); }
|
| 362 |
#stat-pop-lost .stat-value { color: var(--crit); }
|
| 363 |
|
|
|
|
| 250 |
|
| 251 |
#grid-canvas { display: block; image-rendering: pixelated; }
|
| 252 |
|
| 253 |
+
.map-legend {
|
| 254 |
+
margin: 8px 0 0;
|
| 255 |
+
padding: 6px 10px;
|
| 256 |
+
font-size: 11px;
|
| 257 |
+
color: var(--text-muted);
|
| 258 |
+
line-height: 1.45;
|
| 259 |
+
max-width: 100%;
|
| 260 |
+
}
|
| 261 |
+
.map-legend strong { color: var(--text); }
|
| 262 |
+
|
| 263 |
/* Tooltip overlay (shows cell info on hover) */
|
| 264 |
#cell-tooltip {
|
| 265 |
position: absolute;
|
|
|
|
| 366 |
.stat-item.step-item { grid-column: 1 / -1; }
|
| 367 |
.stat-item.step-item .stat-value { font-size: 14px; }
|
| 368 |
|
| 369 |
+
#stat-land-saved .stat-value { color: var(--safe); }
|
| 370 |
+
#stat-civilians-safe .stat-value { color: var(--safe); }
|
| 371 |
+
#stat-cells-burned .stat-value { color: var(--warn); }
|
| 372 |
+
#stat-burning .stat-value { color: var(--fire); }
|
| 373 |
#stat-pop-threat .stat-value { color: var(--warn); }
|
| 374 |
#stat-pop-lost .stat-value { color: var(--crit); }
|
| 375 |
|
graders/grader_easy.py
CHANGED
|
@@ -27,7 +27,7 @@ def grade(agent, seed: int = 42):
|
|
| 27 |
|
| 28 |
details = {
|
| 29 |
"total_reward": round(total_reward, 4),
|
| 30 |
-
"containment_pct": round(final.get("
|
| 31 |
"pop_saved_pct": round(1.0 - pop_lost / total_pop, 4),
|
| 32 |
"steps": env.current_step,
|
| 33 |
"crew_casualty": env._crew_casualty_occurred,
|
|
|
|
| 27 |
|
| 28 |
details = {
|
| 29 |
"total_reward": round(total_reward, 4),
|
| 30 |
+
"containment_pct": round(final.get("reward_breakdown", {}).get("containment", 0.0), 4),
|
| 31 |
"pop_saved_pct": round(1.0 - pop_lost / total_pop, 4),
|
| 32 |
"steps": env.current_step,
|
| 33 |
"crew_casualty": env._crew_casualty_occurred,
|
graders/grader_hard.py
CHANGED
|
@@ -27,7 +27,7 @@ def grade(agent, seed: int = 42):
|
|
| 27 |
|
| 28 |
details = {
|
| 29 |
"total_reward": round(total_reward, 4),
|
| 30 |
-
"containment_pct": round(final.get("
|
| 31 |
"pop_saved_pct": round(1.0 - pop_lost / total_pop, 4),
|
| 32 |
"steps": env.current_step,
|
| 33 |
"crew_casualty": env._crew_casualty_occurred,
|
|
|
|
| 27 |
|
| 28 |
details = {
|
| 29 |
"total_reward": round(total_reward, 4),
|
| 30 |
+
"containment_pct": round(final.get("reward_breakdown", {}).get("containment", 0.0), 4),
|
| 31 |
"pop_saved_pct": round(1.0 - pop_lost / total_pop, 4),
|
| 32 |
"steps": env.current_step,
|
| 33 |
"crew_casualty": env._crew_casualty_occurred,
|
graders/grader_medium.py
CHANGED
|
@@ -27,7 +27,7 @@ def grade(agent, seed: int = 42):
|
|
| 27 |
|
| 28 |
details = {
|
| 29 |
"total_reward": round(total_reward, 4),
|
| 30 |
-
"containment_pct": round(final.get("
|
| 31 |
"pop_saved_pct": round(1.0 - pop_lost / total_pop, 4),
|
| 32 |
"steps": env.current_step,
|
| 33 |
"crew_casualty": env._crew_casualty_occurred,
|
|
|
|
| 27 |
|
| 28 |
details = {
|
| 29 |
"total_reward": round(total_reward, 4),
|
| 30 |
+
"containment_pct": round(final.get("reward_breakdown", {}).get("containment", 0.0), 4),
|
| 31 |
"pop_saved_pct": round(1.0 - pop_lost / total_pop, 4),
|
| 32 |
"steps": env.current_step,
|
| 33 |
"crew_casualty": env._crew_casualty_occurred,
|
scripts/generate_sft_data.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Generate supervised fine-tuning (SFT) training examples by running the
|
| 3 |
+
HeuristicAgent through episodes and recording (prompt, action) pairs.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
python scripts/generate_sft_data.py
|
| 7 |
+
python scripts/generate_sft_data.py --output training/sft_data.jsonl --easy-seeds 500
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import json
|
| 14 |
+
import os
|
| 15 |
+
import random
|
| 16 |
+
import sys
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
|
| 19 |
+
PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
|
| 20 |
+
sys.path.insert(0, PROJECT_ROOT)
|
| 21 |
+
|
| 22 |
+
from env.wildfire_env import WildfireEnv
|
| 23 |
+
from env.serialization import serialize_observation
|
| 24 |
+
from env.models import TIER_EASY, TIER_MEDIUM, TIER_HARD, ActionType
|
| 25 |
+
from agents.heuristic_agent import HeuristicAgent
|
| 26 |
+
|
| 27 |
+
SYSTEM_PROMPT = (
|
| 28 |
+
"You are an AI Incident Commander managing wildfire containment. "
|
| 29 |
+
"You will receive a situation briefing each step. "
|
| 30 |
+
"Respond with ONLY a valid JSON action object and nothing else. "
|
| 31 |
+
'Example: {"action_type": "idle"}'
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
TIER_CONFIGS = {
|
| 35 |
+
"easy": {"max_steps": TIER_EASY.episode_length, "target": 2000},
|
| 36 |
+
"medium": {"max_steps": TIER_MEDIUM.episode_length, "target": 1500},
|
| 37 |
+
"hard": {"max_steps": TIER_HARD.episode_length, "target": 800},
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def run_episode(tier: str, seed: int) -> list[dict] | None:
|
| 42 |
+
"""Run a full episode with the HeuristicAgent.
|
| 43 |
+
|
| 44 |
+
Returns a list of raw (prompt, action, step) records for the episode,
|
| 45 |
+
or None if the episode is unsuccessful (population lost > 0).
|
| 46 |
+
"""
|
| 47 |
+
max_steps = TIER_CONFIGS[tier]["max_steps"]
|
| 48 |
+
env = WildfireEnv()
|
| 49 |
+
obs = env.reset(task_id=tier, seed=seed)
|
| 50 |
+
agent = HeuristicAgent()
|
| 51 |
+
|
| 52 |
+
offset = random.randint(0, min(30, max_steps // 4))
|
| 53 |
+
|
| 54 |
+
prev_cells_burning = 0
|
| 55 |
+
records: list[dict] = []
|
| 56 |
+
step_num = 0
|
| 57 |
+
|
| 58 |
+
while not env.done:
|
| 59 |
+
action = agent.act(obs)
|
| 60 |
+
|
| 61 |
+
if step_num >= offset:
|
| 62 |
+
prompt_text = serialize_observation(
|
| 63 |
+
obs, step_num, max_steps,
|
| 64 |
+
tier=tier, prev_cells_burning=prev_cells_burning,
|
| 65 |
+
)
|
| 66 |
+
action_json = action.model_dump_json(exclude_none=True)
|
| 67 |
+
records.append({
|
| 68 |
+
"messages": [
|
| 69 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 70 |
+
{"role": "user", "content": prompt_text},
|
| 71 |
+
],
|
| 72 |
+
"completion": action_json,
|
| 73 |
+
"tier": tier,
|
| 74 |
+
"seed": seed,
|
| 75 |
+
"step": step_num,
|
| 76 |
+
"action_type": action.action_type.value,
|
| 77 |
+
})
|
| 78 |
+
|
| 79 |
+
prev_cells_burning = obs.stats.cells_burning
|
| 80 |
+
result = env.step(action)
|
| 81 |
+
obs = result.observation
|
| 82 |
+
step_num += 1
|
| 83 |
+
|
| 84 |
+
state = env.state()
|
| 85 |
+
if state["population_lost"] != 0:
|
| 86 |
+
return None
|
| 87 |
+
|
| 88 |
+
return records
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def filter_idle(records: list[dict]) -> list[dict]:
|
| 92 |
+
"""Keep all non-IDLE steps, then cap IDLE steps at 20% of total."""
|
| 93 |
+
non_idle = [r for r in records if r["action_type"] != "idle"]
|
| 94 |
+
idle = [r for r in records if r["action_type"] == "idle"]
|
| 95 |
+
|
| 96 |
+
if not non_idle:
|
| 97 |
+
return idle
|
| 98 |
+
|
| 99 |
+
max_idle = max(1, int(len(non_idle) * 0.25))
|
| 100 |
+
if len(idle) > max_idle:
|
| 101 |
+
random.shuffle(idle)
|
| 102 |
+
idle = idle[:max_idle]
|
| 103 |
+
|
| 104 |
+
combined = non_idle + idle
|
| 105 |
+
combined.sort(key=lambda r: r["step"])
|
| 106 |
+
return combined
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def strip_internal_fields(records: list[dict]) -> list[dict]:
|
| 110 |
+
"""Remove the action_type helper field before writing."""
|
| 111 |
+
for r in records:
|
| 112 |
+
r.pop("action_type", None)
|
| 113 |
+
return records
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def generate(output_path: str, max_seeds: dict[str, int]) -> None:
|
| 117 |
+
os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
|
| 118 |
+
|
| 119 |
+
all_examples: list[dict] = []
|
| 120 |
+
tier_counts = {t: 0 for t in TIER_CONFIGS}
|
| 121 |
+
|
| 122 |
+
for tier in ["easy", "medium", "hard"]:
|
| 123 |
+
target = TIER_CONFIGS[tier]["target"]
|
| 124 |
+
limit = max_seeds[tier]
|
| 125 |
+
seed = 0
|
| 126 |
+
|
| 127 |
+
print(f"\n{'='*50}")
|
| 128 |
+
print(f"Generating {tier} tier (target={target}, max_seeds={limit})")
|
| 129 |
+
print(f"{'='*50}")
|
| 130 |
+
|
| 131 |
+
while tier_counts[tier] < target and seed < limit:
|
| 132 |
+
records = run_episode(tier, seed)
|
| 133 |
+
|
| 134 |
+
if records is not None:
|
| 135 |
+
filtered = filter_idle(records)
|
| 136 |
+
remaining = target - tier_counts[tier]
|
| 137 |
+
if len(filtered) > remaining:
|
| 138 |
+
filtered = filtered[:remaining]
|
| 139 |
+
all_examples.extend(strip_internal_fields(filtered))
|
| 140 |
+
tier_counts[tier] += len(filtered)
|
| 141 |
+
|
| 142 |
+
seed += 1
|
| 143 |
+
if seed % 50 == 0:
|
| 144 |
+
print(f" [{tier}] seed={seed}, examples={tier_counts[tier]}/{target}")
|
| 145 |
+
|
| 146 |
+
print(f" [{tier}] DONE — {tier_counts[tier]} examples from {seed} seeds")
|
| 147 |
+
|
| 148 |
+
with open(output_path, "w", encoding="utf-8") as f:
|
| 149 |
+
for ex in all_examples:
|
| 150 |
+
f.write(json.dumps(ex, ensure_ascii=False) + "\n")
|
| 151 |
+
|
| 152 |
+
total = len(all_examples)
|
| 153 |
+
print(f"\n{'='*50}")
|
| 154 |
+
print(f"SFT data saved to {output_path}")
|
| 155 |
+
print(f"Total examples: {total}")
|
| 156 |
+
print(f"Tier distribution:")
|
| 157 |
+
for tier in ["easy", "medium", "hard"]:
|
| 158 |
+
print(f" {tier}: {tier_counts[tier]}")
|
| 159 |
+
print(f"{'='*50}")
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def main():
|
| 163 |
+
parser = argparse.ArgumentParser(description="Generate SFT training data from HeuristicAgent episodes")
|
| 164 |
+
parser.add_argument("--output", default="training/sft_data.jsonl",
|
| 165 |
+
help="Output JSONL file path (default: training/sft_data.jsonl)")
|
| 166 |
+
parser.add_argument("--easy-seeds", type=int, default=500,
|
| 167 |
+
help="Max seeds to try for easy tier")
|
| 168 |
+
parser.add_argument("--medium-seeds", type=int, default=500,
|
| 169 |
+
help="Max seeds to try for medium tier")
|
| 170 |
+
parser.add_argument("--hard-seeds", type=int, default=500,
|
| 171 |
+
help="Max seeds to try for hard tier")
|
| 172 |
+
args = parser.parse_args()
|
| 173 |
+
|
| 174 |
+
max_seeds = {
|
| 175 |
+
"easy": args.easy_seeds,
|
| 176 |
+
"medium": args.medium_seeds,
|
| 177 |
+
"hard": args.hard_seeds,
|
| 178 |
+
}
|
| 179 |
+
generate(args.output, max_seeds)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
if __name__ == "__main__":
|
| 183 |
+
main()
|
scripts/results.json
CHANGED
|
@@ -2,131 +2,101 @@
|
|
| 2 |
"random": {
|
| 3 |
"easy": {
|
| 4 |
"scores": [
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
8.25,
|
| 11 |
-
0.36,
|
| 12 |
-
8.35,
|
| 13 |
-
6.8251,
|
| 14 |
-
5.825
|
| 15 |
],
|
| 16 |
-
"mean": 6.
|
| 17 |
-
"std": 3.
|
| 18 |
-
"mean_containment_pct":
|
| 19 |
-
"mean_pop_saved_pct": 0.
|
| 20 |
-
"mean_steps":
|
| 21 |
"crew_casualty_rate": 0.0,
|
| 22 |
-
"mean_time_s": 0.
|
| 23 |
},
|
| 24 |
"medium": {
|
| 25 |
"scores": [
|
| 26 |
-
-1.
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
7.2,
|
| 32 |
-
8.3733,
|
| 33 |
-
8.3333,
|
| 34 |
-
-1.024,
|
| 35 |
-
-3.6238
|
| 36 |
],
|
| 37 |
-
"mean":
|
| 38 |
-
"std":
|
| 39 |
-
"mean_containment_pct":
|
| 40 |
-
"mean_pop_saved_pct": 0.
|
| 41 |
-
"mean_steps":
|
| 42 |
"crew_casualty_rate": 0.0,
|
| 43 |
-
"mean_time_s": 0.
|
| 44 |
},
|
| 45 |
"hard": {
|
| 46 |
"scores": [
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
-2.9395,
|
| 53 |
-
-5.5375,
|
| 54 |
-
-1.5395,
|
| 55 |
-
5.3,
|
| 56 |
-
5.3
|
| 57 |
],
|
| 58 |
-
"mean":
|
| 59 |
-
"std":
|
| 60 |
-
"mean_containment_pct":
|
| 61 |
-
"mean_pop_saved_pct": 0.
|
| 62 |
-
"mean_steps":
|
| 63 |
"crew_casualty_rate": 0.0,
|
| 64 |
-
"mean_time_s": 1.
|
| 65 |
}
|
| 66 |
},
|
| 67 |
"heuristic": {
|
| 68 |
"easy": {
|
| 69 |
"scores": [
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
8.35,
|
| 76 |
-
8.35,
|
| 77 |
-
8.35,
|
| 78 |
-
8.35,
|
| 79 |
-
8.35
|
| 80 |
],
|
| 81 |
-
"mean":
|
| 82 |
-
"std": 0.
|
| 83 |
-
"mean_containment_pct":
|
| 84 |
"mean_pop_saved_pct": 1.0,
|
| 85 |
-
"mean_steps":
|
| 86 |
"crew_casualty_rate": 0.0,
|
| 87 |
-
"mean_time_s": 0.
|
| 88 |
},
|
| 89 |
"medium": {
|
| 90 |
"scores": [
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
7.94,
|
| 97 |
-
8.3733,
|
| 98 |
-
8.3733,
|
| 99 |
-
7.8467,
|
| 100 |
-
7.2933
|
| 101 |
],
|
| 102 |
-
"mean":
|
| 103 |
-
"std":
|
| 104 |
-
"mean_containment_pct":
|
| 105 |
-
"mean_pop_saved_pct":
|
| 106 |
-
"mean_steps":
|
| 107 |
"crew_casualty_rate": 0.0,
|
| 108 |
-
"mean_time_s": 0.
|
| 109 |
},
|
| 110 |
"hard": {
|
| 111 |
"scores": [
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
5.8,
|
| 118 |
-
5.8,
|
| 119 |
-
4.8001,
|
| 120 |
-
5.6,
|
| 121 |
-
5.9
|
| 122 |
],
|
| 123 |
-
"mean": 4.
|
| 124 |
-
"std": 3.
|
| 125 |
-
"mean_containment_pct":
|
| 126 |
-
"mean_pop_saved_pct": 0.
|
| 127 |
-
"mean_steps":
|
| 128 |
"crew_casualty_rate": 0.0,
|
| 129 |
-
"mean_time_s": 1.
|
| 130 |
}
|
| 131 |
}
|
| 132 |
}
|
|
|
|
| 2 |
"random": {
|
| 3 |
"easy": {
|
| 4 |
"scores": [
|
| 5 |
+
7.7749,
|
| 6 |
+
7.7751,
|
| 7 |
+
7.775,
|
| 8 |
+
7.775,
|
| 9 |
+
0.04
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
],
|
| 11 |
+
"mean": 6.228,
|
| 12 |
+
"std": 3.094,
|
| 13 |
+
"mean_containment_pct": 1.0,
|
| 14 |
+
"mean_pop_saved_pct": 0.92,
|
| 15 |
+
"mean_steps": 25.8,
|
| 16 |
"crew_casualty_rate": 0.0,
|
| 17 |
+
"mean_time_s": 0.067
|
| 18 |
},
|
| 19 |
"medium": {
|
| 20 |
"scores": [
|
| 21 |
+
-1.7044,
|
| 22 |
+
-1.0029,
|
| 23 |
+
1.0762,
|
| 24 |
+
0.7527,
|
| 25 |
+
7.4403
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
],
|
| 27 |
+
"mean": 1.3124,
|
| 28 |
+
"std": 3.2367,
|
| 29 |
+
"mean_containment_pct": 1.0,
|
| 30 |
+
"mean_pop_saved_pct": 0.7365,
|
| 31 |
+
"mean_steps": 72.0,
|
| 32 |
"crew_casualty_rate": 0.0,
|
| 33 |
+
"mean_time_s": 0.676
|
| 34 |
},
|
| 35 |
"hard": {
|
| 36 |
"scores": [
|
| 37 |
+
7.8668,
|
| 38 |
+
1.3602,
|
| 39 |
+
-0.7466,
|
| 40 |
+
1.0443,
|
| 41 |
+
1.2813
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
],
|
| 43 |
+
"mean": 2.1612,
|
| 44 |
+
"std": 2.9554,
|
| 45 |
+
"mean_containment_pct": 1.0,
|
| 46 |
+
"mean_pop_saved_pct": 0.9023,
|
| 47 |
+
"mean_steps": 84.6,
|
| 48 |
"crew_casualty_rate": 0.0,
|
| 49 |
+
"mean_time_s": 1.301
|
| 50 |
}
|
| 51 |
},
|
| 52 |
"heuristic": {
|
| 53 |
"easy": {
|
| 54 |
"scores": [
|
| 55 |
+
7.6749,
|
| 56 |
+
7.575,
|
| 57 |
+
7.475,
|
| 58 |
+
7.475,
|
| 59 |
+
7.4749
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
],
|
| 61 |
+
"mean": 7.535,
|
| 62 |
+
"std": 0.08,
|
| 63 |
+
"mean_containment_pct": 1.0,
|
| 64 |
"mean_pop_saved_pct": 1.0,
|
| 65 |
+
"mean_steps": 26.6,
|
| 66 |
"crew_casualty_rate": 0.0,
|
| 67 |
+
"mean_time_s": 0.118
|
| 68 |
},
|
| 69 |
"medium": {
|
| 70 |
"scores": [
|
| 71 |
+
7.6001,
|
| 72 |
+
7.7001,
|
| 73 |
+
7.8,
|
| 74 |
+
7.7,
|
| 75 |
+
0.7683
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
],
|
| 77 |
+
"mean": 6.3137,
|
| 78 |
+
"std": 2.7734,
|
| 79 |
+
"mean_containment_pct": 1.0,
|
| 80 |
+
"mean_pop_saved_pct": 0.9746,
|
| 81 |
+
"mean_steps": 46.2,
|
| 82 |
"crew_casualty_rate": 0.0,
|
| 83 |
+
"mean_time_s": 0.48
|
| 84 |
},
|
| 85 |
"hard": {
|
| 86 |
"scores": [
|
| 87 |
+
7.8668,
|
| 88 |
+
7.867,
|
| 89 |
+
0.9443,
|
| 90 |
+
7.6667,
|
| 91 |
+
-0.6696
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
],
|
| 93 |
+
"mean": 4.735,
|
| 94 |
+
"std": 3.7892,
|
| 95 |
+
"mean_containment_pct": 1.0,
|
| 96 |
+
"mean_pop_saved_pct": 0.9279,
|
| 97 |
+
"mean_steps": 83.2,
|
| 98 |
"crew_casualty_rate": 0.0,
|
| 99 |
+
"mean_time_s": 1.487
|
| 100 |
}
|
| 101 |
}
|
| 102 |
}
|
training/grpo_v2_colab.ipynb
ADDED
|
@@ -0,0 +1,678 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# Wildfire Incident Command - GRPO Training (v2)\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"GRPO reinforcement learning on a wildfire incident command model, starting from the SFT checkpoint.\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"**Five critical issues fixed in this version:**\n",
|
| 12 |
+
"1. Prompt/reward state mismatch - dataset uses step-0 prompts only; reward replays the exact (tier, seed)\n",
|
| 13 |
+
"2. Truncated rollout - reward runs full episode to completion (heuristic continuation), terminal reward always included\n",
|
| 14 |
+
"3. Wasted inner model generations - MODEL_STEPS=1, only the sampled completion is applied\n",
|
| 15 |
+
"4. GRPO loop too slow - consequence of fix 3\n",
|
| 16 |
+
"5. parse_action(text, None) crash - standalone check_json_format() for format reward\n",
|
| 17 |
+
"\n",
|
| 18 |
+
"**Hardware:** A100 40GB on Colab"
|
| 19 |
+
]
|
| 20 |
+
},
|
| 21 |
+
{
|
| 22 |
+
"cell_type": "markdown",
|
| 23 |
+
"metadata": {},
|
| 24 |
+
"source": [
|
| 25 |
+
"## Section 1 - Install and Assert GPU"
|
| 26 |
+
]
|
| 27 |
+
},
|
| 28 |
+
{
|
| 29 |
+
"cell_type": "code",
|
| 30 |
+
"execution_count": null,
|
| 31 |
+
"metadata": {},
|
| 32 |
+
"outputs": [],
|
| 33 |
+
"source": [
|
| 34 |
+
"!pip install \"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git\"\n",
|
| 35 |
+
"!pip install trl==0.15.2 datasets==3.4.1 wandb"
|
| 36 |
+
]
|
| 37 |
+
},
|
| 38 |
+
{
|
| 39 |
+
"cell_type": "code",
|
| 40 |
+
"execution_count": null,
|
| 41 |
+
"metadata": {},
|
| 42 |
+
"outputs": [],
|
| 43 |
+
"source": [
|
| 44 |
+
"import torch\n",
|
| 45 |
+
"assert torch.cuda.is_available(), \"GPU not available - switch to a GPU runtime\"\n",
|
| 46 |
+
"gpu_name = torch.cuda.get_device_name(0)\n",
|
| 47 |
+
"gpu_mem = torch.cuda.get_device_properties(0).total_mem / 1e9\n",
|
| 48 |
+
"print(f\"GPU: {gpu_name} | VRAM: {gpu_mem:.1f} GB\")"
|
| 49 |
+
]
|
| 50 |
+
},
|
| 51 |
+
{
|
| 52 |
+
"cell_type": "markdown",
|
| 53 |
+
"metadata": {},
|
| 54 |
+
"source": [
|
| 55 |
+
"## Section 2 - Load SFT Checkpoint"
|
| 56 |
+
]
|
| 57 |
+
},
|
| 58 |
+
{
|
| 59 |
+
"cell_type": "code",
|
| 60 |
+
"execution_count": null,
|
| 61 |
+
"metadata": {},
|
| 62 |
+
"outputs": [],
|
| 63 |
+
"source": [
|
| 64 |
+
"from unsloth import FastLanguageModel\n",
|
| 65 |
+
"\n",
|
| 66 |
+
"# Option A: Load from HuggingFace Hub\n",
|
| 67 |
+
"SFT_MODEL = \"Eshit/wildfire-sft-7b\"\n",
|
| 68 |
+
"# Option B: Load from local zip (uncomment and adjust if needed)\n",
|
| 69 |
+
"# !unzip sft_final.zip -d sft_final_dir\n",
|
| 70 |
+
"# SFT_MODEL = \"./sft_final_dir/sft_final\"\n",
|
| 71 |
+
"\n",
|
| 72 |
+
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
|
| 73 |
+
" model_name=SFT_MODEL,\n",
|
| 74 |
+
" max_seq_length=2048,\n",
|
| 75 |
+
" load_in_4bit=True,\n",
|
| 76 |
+
")\n",
|
| 77 |
+
"\n",
|
| 78 |
+
"if tokenizer.pad_token is None:\n",
|
| 79 |
+
" tokenizer.pad_token = tokenizer.eos_token\n",
|
| 80 |
+
"\n",
|
| 81 |
+
"print(f\"Loaded SFT checkpoint: {SFT_MODEL}\")\n",
|
| 82 |
+
"model.print_trainable_parameters()"
|
| 83 |
+
]
|
| 84 |
+
},
|
| 85 |
+
{
|
| 86 |
+
"cell_type": "markdown",
|
| 87 |
+
"metadata": {},
|
| 88 |
+
"source": [
|
| 89 |
+
"## Section 3 - Constants and Controller Setup"
|
| 90 |
+
]
|
| 91 |
+
},
|
| 92 |
+
{
|
| 93 |
+
"cell_type": "code",
|
| 94 |
+
"execution_count": null,
|
| 95 |
+
"metadata": {},
|
| 96 |
+
"outputs": [],
|
| 97 |
+
"source": [
|
| 98 |
+
"import os, random, json, sys\n",
|
| 99 |
+
"import torch\n",
|
| 100 |
+
"\n",
|
| 101 |
+
"REPO_ROOT = \".\" # Adjust to repo root in Colab\n",
|
| 102 |
+
"if REPO_ROOT not in sys.path:\n",
|
| 103 |
+
" sys.path.insert(0, REPO_ROOT)\n",
|
| 104 |
+
"\n",
|
| 105 |
+
"from env.wildfire_env import WildfireEnv\n",
|
| 106 |
+
"from env.serialization import serialize_observation\n",
|
| 107 |
+
"from env.action_parser import parse_action\n",
|
| 108 |
+
"from agents.heuristic_agent import HeuristicAgent\n",
|
| 109 |
+
"from env.curriculum import CurriculumController\n",
|
| 110 |
+
"from datasets import Dataset\n",
|
| 111 |
+
"\n",
|
| 112 |
+
"SEED_POOL = list(range(100))\n",
|
| 113 |
+
"TIER_MAX_STEPS = {'easy': 80, 'medium': 150, 'hard': 300}\n",
|
| 114 |
+
"SYSTEM_PROMPT = (\n",
|
| 115 |
+
" 'You are an AI Incident Commander managing wildfire containment. '\n",
|
| 116 |
+
" 'You will receive a situation briefing each step. '\n",
|
| 117 |
+
" 'Respond with ONLY a valid JSON action object and nothing else. '\n",
|
| 118 |
+
" 'Example: {\"action_type\": \"idle\"}'\n",
|
| 119 |
+
")\n",
|
| 120 |
+
"\n",
|
| 121 |
+
"# Thresholds calibrated to full-episode reward with heuristic continuation.\n",
|
| 122 |
+
"# Promote easy->medium once model's first action consistently beats random (+6.23).\n",
|
| 123 |
+
"# Promote medium->hard once model demonstrates meaningful improvement over random (+1.31).\n",
|
| 124 |
+
"controller = CurriculumController(\n",
|
| 125 |
+
" start_tier='easy',\n",
|
| 126 |
+
" thresholds={'easy': 6.5, 'medium': 3.5},\n",
|
| 127 |
+
")\n",
|
| 128 |
+
"\n",
|
| 129 |
+
"os.makedirs('training/samples', exist_ok=True)\n",
|
| 130 |
+
"_reward_call_count = 0\n",
|
| 131 |
+
"\n",
|
| 132 |
+
"print(f\"Start tier: {controller.get_tier()}\")\n",
|
| 133 |
+
"print(f\"Seed pool: {len(SEED_POOL)} seeds\")\n",
|
| 134 |
+
"print(\"Env imports OK\")"
|
| 135 |
+
]
|
| 136 |
+
},
|
| 137 |
+
{
|
| 138 |
+
"cell_type": "markdown",
|
| 139 |
+
"metadata": {},
|
| 140 |
+
"source": [
|
| 141 |
+
"## Section 4 - Standalone JSON Format Checker\n",
|
| 142 |
+
"\n",
|
| 143 |
+
"Replaces parse_action for format reward - no obs object needed (Issue 5 fix)."
|
| 144 |
+
]
|
| 145 |
+
},
|
| 146 |
+
{
|
| 147 |
+
"cell_type": "code",
|
| 148 |
+
"execution_count": null,
|
| 149 |
+
"metadata": {},
|
| 150 |
+
"outputs": [],
|
| 151 |
+
"source": [
|
| 152 |
+
"import json as _json\n",
|
| 153 |
+
"import re as _re\n",
|
| 154 |
+
"from env.models import ActionType as _AT\n",
|
| 155 |
+
"\n",
|
| 156 |
+
"_VALID_ACTION_TYPES = {a.value for a in _AT}\n",
|
| 157 |
+
"\n",
|
| 158 |
+
"\n",
|
| 159 |
+
"def check_json_format(text: str) -> str:\n",
|
| 160 |
+
" \"\"\"\n",
|
| 161 |
+
" Validate LLM output format without needing an obs object.\n",
|
| 162 |
+
" Returns 'json_success', 'regex_fallback', or 'safe_idle'.\n",
|
| 163 |
+
" Does NOT use parse_action - avoids the obs.grid dependency.\n",
|
| 164 |
+
" \"\"\"\n",
|
| 165 |
+
" text = _re.sub(r'```(?:json)?\\s*', '', text).replace('```', '')\n",
|
| 166 |
+
" start = text.find('{')\n",
|
| 167 |
+
" if start == -1:\n",
|
| 168 |
+
" return 'safe_idle'\n",
|
| 169 |
+
" depth = 0\n",
|
| 170 |
+
" end = -1\n",
|
| 171 |
+
" for i, ch in enumerate(text[start:], start=start):\n",
|
| 172 |
+
" if ch == '{':\n",
|
| 173 |
+
" depth += 1\n",
|
| 174 |
+
" elif ch == '}':\n",
|
| 175 |
+
" depth -= 1\n",
|
| 176 |
+
" if depth == 0:\n",
|
| 177 |
+
" end = i\n",
|
| 178 |
+
" break\n",
|
| 179 |
+
" if end == -1:\n",
|
| 180 |
+
" return 'safe_idle'\n",
|
| 181 |
+
" try:\n",
|
| 182 |
+
" obj = _json.loads(text[start:end+1])\n",
|
| 183 |
+
" if not isinstance(obj, dict):\n",
|
| 184 |
+
" return 'safe_idle'\n",
|
| 185 |
+
" at = str(obj.get('action_type', '')).lower()\n",
|
| 186 |
+
" if at in _VALID_ACTION_TYPES:\n",
|
| 187 |
+
" return 'json_success'\n",
|
| 188 |
+
" return 'regex_fallback'\n",
|
| 189 |
+
" except Exception:\n",
|
| 190 |
+
" return 'regex_fallback'\n",
|
| 191 |
+
"\n",
|
| 192 |
+
"\n",
|
| 193 |
+
"assert check_json_format('{\"action_type\": \"idle\"}') == 'json_success'\n",
|
| 194 |
+
"assert check_json_format('{\"action_type\": \"bogus\"}') == 'regex_fallback'\n",
|
| 195 |
+
"assert check_json_format('no json here') == 'safe_idle'\n",
|
| 196 |
+
"print('check_json_format OK')"
|
| 197 |
+
]
|
| 198 |
+
},
|
| 199 |
+
{
|
| 200 |
+
"cell_type": "markdown",
|
| 201 |
+
"metadata": {},
|
| 202 |
+
"source": [
|
| 203 |
+
"## Section 5 - Reward Functions\n",
|
| 204 |
+
"\n",
|
| 205 |
+
"Two reward signals for GRPO:\n",
|
| 206 |
+
"- **reward_fn_outcome** - full-episode env reward (1 model step + heuristic continuation)\n",
|
| 207 |
+
"- **reward_fn_format** - JSON formatting quality (fast, no env needed)"
|
| 208 |
+
]
|
| 209 |
+
},
|
| 210 |
+
{
|
| 211 |
+
"cell_type": "code",
|
| 212 |
+
"execution_count": null,
|
| 213 |
+
"metadata": {},
|
| 214 |
+
"outputs": [],
|
| 215 |
+
"source": [
|
| 216 |
+
"def reward_fn_outcome(completions, prompts, tier=None, seed=None, **kwargs):\n",
|
| 217 |
+
" \"\"\"\n",
|
| 218 |
+
" Score each GRPO completion by:\n",
|
| 219 |
+
" 1. Resetting the env to the EXACT (tier, seed) that generated the prompt (Issue 1 fix).\n",
|
| 220 |
+
" 2. Applying the sampled completion as the single first action (MODEL_STEPS=1, Issue 3/4 fix).\n",
|
| 221 |
+
" 3. Running HeuristicAgent until episode completion (Issue 2 fix - captures terminal reward).\n",
|
| 222 |
+
"\n",
|
| 223 |
+
" tier and seed are dataset columns forwarded by GRPOTrainer.\n",
|
| 224 |
+
" \"\"\"\n",
|
| 225 |
+
" global _reward_call_count\n",
|
| 226 |
+
" _reward_call_count += 1\n",
|
| 227 |
+
" rewards = []\n",
|
| 228 |
+
"\n",
|
| 229 |
+
" for i, completion in enumerate(completions):\n",
|
| 230 |
+
" ep_tier = tier[i] if tier is not None else controller.get_tier()\n",
|
| 231 |
+
" ep_seed = seed[i] if seed is not None else random.choice(SEED_POOL)\n",
|
| 232 |
+
"\n",
|
| 233 |
+
" env = WildfireEnv()\n",
|
| 234 |
+
" obs = env.reset(task_id=ep_tier, seed=ep_seed)\n",
|
| 235 |
+
" total_reward = 0.0\n",
|
| 236 |
+
"\n",
|
| 237 |
+
" # Apply the sampled completion as step 0\n",
|
| 238 |
+
" text = completion if isinstance(completion, str) else completion[0]['content']\n",
|
| 239 |
+
" action, _ = parse_action(text, obs)\n",
|
| 240 |
+
" result = env.step(action)\n",
|
| 241 |
+
" total_reward += result.reward\n",
|
| 242 |
+
" obs = result.observation\n",
|
| 243 |
+
"\n",
|
| 244 |
+
" # Heuristic drives everything after (full episode to capture terminal reward)\n",
|
| 245 |
+
" heuristic = HeuristicAgent()\n",
|
| 246 |
+
" while not env.done:\n",
|
| 247 |
+
" action = heuristic.act(obs)\n",
|
| 248 |
+
" result = env.step(action)\n",
|
| 249 |
+
" total_reward += result.reward\n",
|
| 250 |
+
" obs = result.observation\n",
|
| 251 |
+
"\n",
|
| 252 |
+
" rewards.append(total_reward)\n",
|
| 253 |
+
"\n",
|
| 254 |
+
" # Update curriculum (once per batch, not per completion)\n",
|
| 255 |
+
" mean_r = sum(rewards) / len(rewards)\n",
|
| 256 |
+
" promoted = controller.after_episode(mean_r)\n",
|
| 257 |
+
" if promoted:\n",
|
| 258 |
+
" print(f' *** Curriculum promoted to: {promoted} (mean batch reward={mean_r:.2f}) ***')\n",
|
| 259 |
+
"\n",
|
| 260 |
+
" # Sample completions to disk for inspection\n",
|
| 261 |
+
" if _reward_call_count % 10 == 0:\n",
|
| 262 |
+
" sample_path = f'training/samples/call_{_reward_call_count}.txt'\n",
|
| 263 |
+
" with open(sample_path, 'w') as f:\n",
|
| 264 |
+
" f.write(f'call={_reward_call_count} tier={tier[0] if tier else \"?\"} reward={rewards[0]:.3f}\\n')\n",
|
| 265 |
+
" f.write('---\\n')\n",
|
| 266 |
+
" c = completions[0]\n",
|
| 267 |
+
" f.write(c if isinstance(c, str) else c[0]['content'])\n",
|
| 268 |
+
" f.write('\\n')\n",
|
| 269 |
+
"\n",
|
| 270 |
+
" return rewards\n",
|
| 271 |
+
"\n",
|
| 272 |
+
"\n",
|
| 273 |
+
"def reward_fn_format(completions, prompts, **kwargs):\n",
|
| 274 |
+
" \"\"\"\n",
|
| 275 |
+
" Scores JSON formatting quality using check_json_format() (no obs needed).\n",
|
| 276 |
+
" Runs independently of the env - fast and always well-defined.\n",
|
| 277 |
+
" \"\"\"\n",
|
| 278 |
+
" rewards = []\n",
|
| 279 |
+
" for completion in completions:\n",
|
| 280 |
+
" text = completion if isinstance(completion, str) else completion[0]['content']\n",
|
| 281 |
+
" status = check_json_format(text)\n",
|
| 282 |
+
" if status == 'json_success':\n",
|
| 283 |
+
" r = 0.15\n",
|
| 284 |
+
" elif status == 'regex_fallback':\n",
|
| 285 |
+
" r = 0.0\n",
|
| 286 |
+
" else:\n",
|
| 287 |
+
" r = -0.20\n",
|
| 288 |
+
" rewards.append(r)\n",
|
| 289 |
+
" return rewards\n",
|
| 290 |
+
"\n",
|
| 291 |
+
"\n",
|
| 292 |
+
"print('Reward functions defined.')"
|
| 293 |
+
]
|
| 294 |
+
},
|
| 295 |
+
{
|
| 296 |
+
"cell_type": "markdown",
|
| 297 |
+
"metadata": {},
|
| 298 |
+
"source": [
|
| 299 |
+
"## Section 6 - Dataset Builder (Step-0 Only)\n",
|
| 300 |
+
"\n",
|
| 301 |
+
"Each row stores the seed so reward_fn_outcome can replay the exact same env state.\n",
|
| 302 |
+
"No mid-episode offset - GRPO prompt and reward state are always step-0."
|
| 303 |
+
]
|
| 304 |
+
},
|
| 305 |
+
{
|
| 306 |
+
"cell_type": "code",
|
| 307 |
+
"execution_count": null,
|
| 308 |
+
"metadata": {},
|
| 309 |
+
"outputs": [],
|
| 310 |
+
"source": [
|
| 311 |
+
"def build_prompt_dataset(n=200):\n",
|
| 312 |
+
" \"\"\"\n",
|
| 313 |
+
" Build step-0 prompts for the current curriculum tier.\n",
|
| 314 |
+
" Stores the seed in each row so reward_fn can replay the exact same env state.\n",
|
| 315 |
+
" \"\"\"\n",
|
| 316 |
+
" rows = []\n",
|
| 317 |
+
" env_tmp = WildfireEnv()\n",
|
| 318 |
+
" tier = controller.get_tier()\n",
|
| 319 |
+
" max_steps = TIER_MAX_STEPS[tier]\n",
|
| 320 |
+
"\n",
|
| 321 |
+
" for i in range(n):\n",
|
| 322 |
+
" seed = SEED_POOL[i % len(SEED_POOL)]\n",
|
| 323 |
+
" obs = env_tmp.reset(task_id=tier, seed=seed)\n",
|
| 324 |
+
" prompt = serialize_observation(obs, 0, max_steps, tier=tier, prev_cells_burning=0)\n",
|
| 325 |
+
" rows.append({\n",
|
| 326 |
+
" 'prompt': [\n",
|
| 327 |
+
" {'role': 'system', 'content': SYSTEM_PROMPT},\n",
|
| 328 |
+
" {'role': 'user', 'content': prompt},\n",
|
| 329 |
+
" ],\n",
|
| 330 |
+
" 'tier': tier,\n",
|
| 331 |
+
" 'seed': seed,\n",
|
| 332 |
+
" })\n",
|
| 333 |
+
" return rows\n",
|
| 334 |
+
"\n",
|
| 335 |
+
"\n",
|
| 336 |
+
"_test_ds = build_prompt_dataset(3)\n",
|
| 337 |
+
"print(f\"Sample dataset row keys: {list(_test_ds[0].keys())}\")\n",
|
| 338 |
+
"print(f\"Tier: {_test_ds[0]['tier']}, Seed: {_test_ds[0]['seed']}\")\n",
|
| 339 |
+
"print(f\"Prompt roles: {[m['role'] for m in _test_ds[0]['prompt']]}\")\n",
|
| 340 |
+
"del _test_ds"
|
| 341 |
+
]
|
| 342 |
+
},
|
| 343 |
+
{
|
| 344 |
+
"cell_type": "markdown",
|
| 345 |
+
"metadata": {},
|
| 346 |
+
"source": [
|
| 347 |
+
"## Section 7 - CurriculumDatasetCallback\n",
|
| 348 |
+
"\n",
|
| 349 |
+
"Rebuilds the training dataset whenever the curriculum controller promotes to a new tier."
|
| 350 |
+
]
|
| 351 |
+
},
|
| 352 |
+
{
|
| 353 |
+
"cell_type": "code",
|
| 354 |
+
"execution_count": null,
|
| 355 |
+
"metadata": {},
|
| 356 |
+
"outputs": [],
|
| 357 |
+
"source": [
|
| 358 |
+
"from transformers import TrainerCallback\n",
|
| 359 |
+
"\n",
|
| 360 |
+
"\n",
|
| 361 |
+
"class CurriculumDatasetCallback(TrainerCallback):\n",
|
| 362 |
+
" def __init__(self, trainer_ref):\n",
|
| 363 |
+
" self._trainer = trainer_ref\n",
|
| 364 |
+
" self._last_tier = controller.get_tier()\n",
|
| 365 |
+
"\n",
|
| 366 |
+
" def on_step_end(self, args, state, control, **kwargs):\n",
|
| 367 |
+
" current_tier = controller.get_tier()\n",
|
| 368 |
+
" if current_tier != self._last_tier:\n",
|
| 369 |
+
" print(f' Rebuilding dataset for tier: {current_tier}')\n",
|
| 370 |
+
" new_ds = Dataset.from_list(build_prompt_dataset(200))\n",
|
| 371 |
+
" self._trainer.train_dataset = new_ds\n",
|
| 372 |
+
" self._last_tier = current_tier\n",
|
| 373 |
+
"\n",
|
| 374 |
+
"\n",
|
| 375 |
+
"print('CurriculumDatasetCallback defined.')"
|
| 376 |
+
]
|
| 377 |
+
},
|
| 378 |
+
{
|
| 379 |
+
"cell_type": "markdown",
|
| 380 |
+
"metadata": {},
|
| 381 |
+
"source": [
|
| 382 |
+
"## Section 8 - GRPOTrainer Setup"
|
| 383 |
+
]
|
| 384 |
+
},
|
| 385 |
+
{
|
| 386 |
+
"cell_type": "code",
|
| 387 |
+
"execution_count": null,
|
| 388 |
+
"metadata": {},
|
| 389 |
+
"outputs": [],
|
| 390 |
+
"source": [
|
| 391 |
+
"from trl import GRPOTrainer, GRPOConfig\n",
|
| 392 |
+
"\n",
|
| 393 |
+
"grpo_config = GRPOConfig(\n",
|
| 394 |
+
" output_dir='./grpo_checkpoints',\n",
|
| 395 |
+
" num_generations=8,\n",
|
| 396 |
+
" learning_rate=3e-6,\n",
|
| 397 |
+
" max_steps=400,\n",
|
| 398 |
+
" save_steps=20,\n",
|
| 399 |
+
" per_device_train_batch_size=1,\n",
|
| 400 |
+
" gradient_accumulation_steps=4,\n",
|
| 401 |
+
" max_completion_length=192,\n",
|
| 402 |
+
" logging_steps=1,\n",
|
| 403 |
+
" report_to='wandb',\n",
|
| 404 |
+
")\n",
|
| 405 |
+
"\n",
|
| 406 |
+
"FastLanguageModel.for_training(model)\n",
|
| 407 |
+
"\n",
|
| 408 |
+
"dataset = Dataset.from_list(build_prompt_dataset(200))\n",
|
| 409 |
+
"print(f'Initial dataset: {len(dataset)} rows, tier={controller.get_tier()}')\n",
|
| 410 |
+
"\n",
|
| 411 |
+
"trainer = GRPOTrainer(\n",
|
| 412 |
+
" model=model,\n",
|
| 413 |
+
" processing_class=tokenizer,\n",
|
| 414 |
+
" reward_funcs=[reward_fn_outcome, reward_fn_format],\n",
|
| 415 |
+
" args=grpo_config,\n",
|
| 416 |
+
" train_dataset=dataset,\n",
|
| 417 |
+
")\n",
|
| 418 |
+
"trainer.add_callback(CurriculumDatasetCallback(trainer))\n",
|
| 419 |
+
"\n",
|
| 420 |
+
"print('GRPOTrainer ready.')"
|
| 421 |
+
]
|
| 422 |
+
},
|
| 423 |
+
{
|
| 424 |
+
"cell_type": "markdown",
|
| 425 |
+
"metadata": {},
|
| 426 |
+
"source": [
|
| 427 |
+
"## Section 9 - Run Training"
|
| 428 |
+
]
|
| 429 |
+
},
|
| 430 |
+
{
|
| 431 |
+
"cell_type": "code",
|
| 432 |
+
"execution_count": null,
|
| 433 |
+
"metadata": {},
|
| 434 |
+
"outputs": [],
|
| 435 |
+
"source": [
|
| 436 |
+
"import wandb\n",
|
| 437 |
+
"wandb.init(project='wildfire-grpo', name='qwen7b-v2')\n",
|
| 438 |
+
"\n",
|
| 439 |
+
"print(f'Starting GRPO - {grpo_config.max_steps} steps, {grpo_config.num_generations} gen/prompt')\n",
|
| 440 |
+
"print(f'Reward: 1 model step at step-0, heuristic continuation to episode completion')\n",
|
| 441 |
+
"print(f'Start tier: {controller.get_tier()}')\n",
|
| 442 |
+
"\n",
|
| 443 |
+
"trainer.train()\n",
|
| 444 |
+
"print('Training complete.')\n",
|
| 445 |
+
"\n",
|
| 446 |
+
"history = controller.get_history()\n",
|
| 447 |
+
"stats = [{'step': ep, 'tier': t, 'mean_reward': r} for ep, t, r in history]\n",
|
| 448 |
+
"with open('./training_stats.json', 'w') as f:\n",
|
| 449 |
+
" json.dump(stats, f, indent=2)\n",
|
| 450 |
+
"print('Stats saved -> training_stats.json')"
|
| 451 |
+
]
|
| 452 |
+
},
|
| 453 |
+
{
|
| 454 |
+
"cell_type": "markdown",
|
| 455 |
+
"metadata": {},
|
| 456 |
+
"source": [
|
| 457 |
+
"## Section 10 - Evaluate vs Baselines\n",
|
| 458 |
+
"\n",
|
| 459 |
+
"Run 15 full episodes per tier (seeds 42-56), compare with heuristic and random baselines."
|
| 460 |
+
]
|
| 461 |
+
},
|
| 462 |
+
{
|
| 463 |
+
"cell_type": "code",
|
| 464 |
+
"execution_count": null,
|
| 465 |
+
"metadata": {},
|
| 466 |
+
"outputs": [],
|
| 467 |
+
"source": [
|
| 468 |
+
"class LLMAgent:\n",
|
| 469 |
+
" \"\"\"Wraps the trained model for evaluation. Must be re-instantiated per episode.\"\"\"\n",
|
| 470 |
+
"\n",
|
| 471 |
+
" def __init__(self, model, tokenizer, tier, max_steps):\n",
|
| 472 |
+
" self.model = model\n",
|
| 473 |
+
" self.tokenizer = tokenizer\n",
|
| 474 |
+
" self.tier = tier\n",
|
| 475 |
+
" self.max_steps = max_steps\n",
|
| 476 |
+
" self._step = 0\n",
|
| 477 |
+
" self._prev_burning = 0\n",
|
| 478 |
+
" self.json_success = self.regex_fallback = self.safe_idle = 0\n",
|
| 479 |
+
"\n",
|
| 480 |
+
" def act(self, obs):\n",
|
| 481 |
+
" prompt = serialize_observation(\n",
|
| 482 |
+
" obs, self._step, self.max_steps,\n",
|
| 483 |
+
" tier=self.tier,\n",
|
| 484 |
+
" prev_cells_burning=self._prev_burning,\n",
|
| 485 |
+
" )\n",
|
| 486 |
+
" self._prev_burning = obs.stats.cells_burning\n",
|
| 487 |
+
" messages = [\n",
|
| 488 |
+
" {'role': 'system', 'content': SYSTEM_PROMPT},\n",
|
| 489 |
+
" {'role': 'user', 'content': prompt},\n",
|
| 490 |
+
" ]\n",
|
| 491 |
+
" input_ids = self.tokenizer.apply_chat_template(\n",
|
| 492 |
+
" messages, tokenize=True,\n",
|
| 493 |
+
" add_generation_prompt=True, return_tensors='pt',\n",
|
| 494 |
+
" ).to(self.model.device)\n",
|
| 495 |
+
" with torch.no_grad():\n",
|
| 496 |
+
" out = self.model.generate(\n",
|
| 497 |
+
" input_ids, max_new_tokens=128,\n",
|
| 498 |
+
" pad_token_id=self.tokenizer.eos_token_id,\n",
|
| 499 |
+
" )\n",
|
| 500 |
+
" text = self.tokenizer.decode(out[0][input_ids.shape[1]:], skip_special_tokens=True)\n",
|
| 501 |
+
" action, status = parse_action(text, obs)\n",
|
| 502 |
+
" if status == 'json_success':\n",
|
| 503 |
+
" self.json_success += 1\n",
|
| 504 |
+
" elif status == 'regex_fallback':\n",
|
| 505 |
+
" self.regex_fallback += 1\n",
|
| 506 |
+
" else:\n",
|
| 507 |
+
" self.safe_idle += 1\n",
|
| 508 |
+
" self._step += 1\n",
|
| 509 |
+
" return action\n",
|
| 510 |
+
"\n",
|
| 511 |
+
"\n",
|
| 512 |
+
"print('LLMAgent class defined.')"
|
| 513 |
+
]
|
| 514 |
+
},
|
| 515 |
+
{
|
| 516 |
+
"cell_type": "code",
|
| 517 |
+
"execution_count": null,
|
| 518 |
+
"metadata": {},
|
| 519 |
+
"outputs": [],
|
| 520 |
+
"source": [
|
| 521 |
+
"import numpy as np\n",
|
| 522 |
+
"\n",
|
| 523 |
+
"with open('scripts/results.json', 'r') as f:\n",
|
| 524 |
+
" baselines = json.load(f)\n",
|
| 525 |
+
"\n",
|
| 526 |
+
"FastLanguageModel.for_inference(model)\n",
|
| 527 |
+
"\n",
|
| 528 |
+
"EVAL_SEEDS = list(range(42, 57))\n",
|
| 529 |
+
"TIERS = ['easy', 'medium', 'hard']\n",
|
| 530 |
+
"\n",
|
| 531 |
+
"results = {}\n",
|
| 532 |
+
"\n",
|
| 533 |
+
"for tier in TIERS:\n",
|
| 534 |
+
" max_steps = TIER_MAX_STEPS[tier]\n",
|
| 535 |
+
" tier_rewards = []\n",
|
| 536 |
+
" tier_pop_saved = []\n",
|
| 537 |
+
" tier_json_success = 0\n",
|
| 538 |
+
" tier_total_actions = 0\n",
|
| 539 |
+
"\n",
|
| 540 |
+
" print(f'\\nEvaluating {tier} tier...')\n",
|
| 541 |
+
"\n",
|
| 542 |
+
" for seed in EVAL_SEEDS:\n",
|
| 543 |
+
" agent = LLMAgent(model, tokenizer, tier, max_steps)\n",
|
| 544 |
+
" env = WildfireEnv()\n",
|
| 545 |
+
" obs = env.reset(task_id=tier, seed=seed)\n",
|
| 546 |
+
" total_reward = 0.0\n",
|
| 547 |
+
"\n",
|
| 548 |
+
" while not env.done:\n",
|
| 549 |
+
" action = agent.act(obs)\n",
|
| 550 |
+
" result = env.step(action)\n",
|
| 551 |
+
" total_reward += result.reward\n",
|
| 552 |
+
" obs = result.observation\n",
|
| 553 |
+
"\n",
|
| 554 |
+
" tier_rewards.append(total_reward)\n",
|
| 555 |
+
"\n",
|
| 556 |
+
" state = env.state()\n",
|
| 557 |
+
" total_pop = state['total_population']\n",
|
| 558 |
+
" pop_lost = state['population_lost']\n",
|
| 559 |
+
" pop_saved = 100.0 * (total_pop - pop_lost) / total_pop if total_pop > 0 else 100.0\n",
|
| 560 |
+
" tier_pop_saved.append(pop_saved)\n",
|
| 561 |
+
"\n",
|
| 562 |
+
" tier_json_success += agent.json_success\n",
|
| 563 |
+
" tier_total_actions += agent.json_success + agent.regex_fallback + agent.safe_idle\n",
|
| 564 |
+
"\n",
|
| 565 |
+
" print(f' seed={seed}: reward={total_reward:+.2f}, pop_saved={pop_saved:.0f}%')\n",
|
| 566 |
+
"\n",
|
| 567 |
+
" json_rate = 100.0 * tier_json_success / tier_total_actions if tier_total_actions > 0 else 0\n",
|
| 568 |
+
" results[tier] = {\n",
|
| 569 |
+
" 'mean': float(np.mean(tier_rewards)),\n",
|
| 570 |
+
" 'std': float(np.std(tier_rewards)),\n",
|
| 571 |
+
" 'pop_saved_pct': float(np.mean(tier_pop_saved)),\n",
|
| 572 |
+
" 'json_success_rate': json_rate,\n",
|
| 573 |
+
" }\n",
|
| 574 |
+
"\n",
|
| 575 |
+
"print()\n",
|
| 576 |
+
"print('=' * 65)\n",
|
| 577 |
+
"print('=== Evaluation: Trained Model vs Baselines ===')\n",
|
| 578 |
+
"print('Seeds: 42-56 (15 per tier)')\n",
|
| 579 |
+
"print('=' * 65)\n",
|
| 580 |
+
"header = f'{\"Tier\":<10} {\"Trained\":>12} {\"Heuristic\":>12} {\"Random\":>12} {\"vs Heur.\":>12}'\n",
|
| 581 |
+
"print(header)\n",
|
| 582 |
+
"print('-' * 65)\n",
|
| 583 |
+
"\n",
|
| 584 |
+
"any_tier_close = False\n",
|
| 585 |
+
"for tier in TIERS:\n",
|
| 586 |
+
" t = results[tier]\n",
|
| 587 |
+
" h_mean = baselines['heuristic'][tier]['mean']\n",
|
| 588 |
+
" h_std = baselines['heuristic'][tier]['std']\n",
|
| 589 |
+
" r_mean = baselines['random'][tier]['mean']\n",
|
| 590 |
+
" r_std = baselines['random'][tier]['std']\n",
|
| 591 |
+
" delta = t['mean'] - h_mean\n",
|
| 592 |
+
" marker = ' OK' if delta >= -1.0 else ''\n",
|
| 593 |
+
" if delta >= -1.0:\n",
|
| 594 |
+
" any_tier_close = True\n",
|
| 595 |
+
" print(\n",
|
| 596 |
+
" f'{tier:<10} '\n",
|
| 597 |
+
" f'{t[\"mean\"]:+.2f}+/-{t[\"std\"]:.1f} '\n",
|
| 598 |
+
" f'{h_mean:+.2f}+/-{h_std:.1f} '\n",
|
| 599 |
+
" f'{r_mean:+.2f}+/-{r_std:.1f} '\n",
|
| 600 |
+
" f'{delta:+.2f}{marker}'\n",
|
| 601 |
+
" )\n",
|
| 602 |
+
"\n",
|
| 603 |
+
"print()\n",
|
| 604 |
+
"print('JSON success rate: ', end='')\n",
|
| 605 |
+
"print(' '.join(f'{t}={results[t][\"json_success_rate\"]:.1f}%' for t in TIERS))\n",
|
| 606 |
+
"print('Pop saved rate: ', end='')\n",
|
| 607 |
+
"print(' '.join(f'{t}={results[t][\"pop_saved_pct\"]:.0f}%' for t in TIERS))\n",
|
| 608 |
+
"\n",
|
| 609 |
+
"assert any_tier_close, (\n",
|
| 610 |
+
" 'Trained model did not come within 1.0 of heuristic on any tier. '\n",
|
| 611 |
+
" 'Check training logs and sample completions.'\n",
|
| 612 |
+
")\n",
|
| 613 |
+
"print('\\nPASS: At least one tier within 1.0 of heuristic baseline.')\n",
|
| 614 |
+
"\n",
|
| 615 |
+
"FastLanguageModel.for_training(model)"
|
| 616 |
+
]
|
| 617 |
+
},
|
| 618 |
+
{
|
| 619 |
+
"cell_type": "markdown",
|
| 620 |
+
"metadata": {},
|
| 621 |
+
"source": [
|
| 622 |
+
"## Section 11 - Save and Push"
|
| 623 |
+
]
|
| 624 |
+
},
|
| 625 |
+
{
|
| 626 |
+
"cell_type": "code",
|
| 627 |
+
"execution_count": null,
|
| 628 |
+
"metadata": {},
|
| 629 |
+
"outputs": [],
|
| 630 |
+
"source": [
|
| 631 |
+
"model.save_pretrained('./grpo_final')\n",
|
| 632 |
+
"tokenizer.save_pretrained('./grpo_final')\n",
|
| 633 |
+
"print('Saved to ./grpo_final')"
|
| 634 |
+
]
|
| 635 |
+
},
|
| 636 |
+
{
|
| 637 |
+
"cell_type": "code",
|
| 638 |
+
"execution_count": null,
|
| 639 |
+
"metadata": {},
|
| 640 |
+
"outputs": [],
|
| 641 |
+
"source": [
|
| 642 |
+
"HF_USERNAME = 'Eshit' # <-- CHANGE THIS\n",
|
| 643 |
+
"model.push_to_hub(f'{HF_USERNAME}/wildfire-grpo-7b')\n",
|
| 644 |
+
"tokenizer.push_to_hub(f'{HF_USERNAME}/wildfire-grpo-7b')\n",
|
| 645 |
+
"print(f'Pushed to hub: {HF_USERNAME}/wildfire-grpo-7b')"
|
| 646 |
+
]
|
| 647 |
+
},
|
| 648 |
+
{
|
| 649 |
+
"cell_type": "code",
|
| 650 |
+
"execution_count": null,
|
| 651 |
+
"metadata": {},
|
| 652 |
+
"outputs": [],
|
| 653 |
+
"source": [
|
| 654 |
+
"!zip -r grpo_final.zip ./grpo_final\n",
|
| 655 |
+
"from google.colab import files\n",
|
| 656 |
+
"files.download('grpo_final.zip')\n",
|
| 657 |
+
"print('Download started.')"
|
| 658 |
+
]
|
| 659 |
+
}
|
| 660 |
+
],
|
| 661 |
+
"metadata": {
|
| 662 |
+
"accelerator": "GPU",
|
| 663 |
+
"colab": {
|
| 664 |
+
"gpuType": "A100",
|
| 665 |
+
"provenance": []
|
| 666 |
+
},
|
| 667 |
+
"kernelspec": {
|
| 668 |
+
"display_name": "Python 3",
|
| 669 |
+
"name": "python3"
|
| 670 |
+
},
|
| 671 |
+
"language_info": {
|
| 672 |
+
"name": "python",
|
| 673 |
+
"version": "3.10.0"
|
| 674 |
+
}
|
| 675 |
+
},
|
| 676 |
+
"nbformat": 4,
|
| 677 |
+
"nbformat_minor": 5
|
| 678 |
+
}
|
training/sft_colab.ipynb
ADDED
|
@@ -0,0 +1,387 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# Wildfire Incident Command — SFT Training\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"Supervised fine-tuning of **Qwen2.5-7B-Instruct** on wildfire incident command data.\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"- **Input:** `training/sft_data.jsonl` (generated by `scripts/generate_sft_data.py`)\n",
|
| 12 |
+
"- **Goal:** Teach the model to output valid JSON action objects given wildfire observations\n",
|
| 13 |
+
"- **Hardware:** A100 40GB on Colab"
|
| 14 |
+
]
|
| 15 |
+
},
|
| 16 |
+
{
|
| 17 |
+
"cell_type": "markdown",
|
| 18 |
+
"metadata": {},
|
| 19 |
+
"source": [
|
| 20 |
+
"## Section 1 — Install Dependencies"
|
| 21 |
+
]
|
| 22 |
+
},
|
| 23 |
+
{
|
| 24 |
+
"cell_type": "code",
|
| 25 |
+
"execution_count": null,
|
| 26 |
+
"metadata": {},
|
| 27 |
+
"outputs": [],
|
| 28 |
+
"source": [
|
| 29 |
+
"!pip install \"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git\"\n",
|
| 30 |
+
"!pip install trl==0.15.2 datasets==3.4.1"
|
| 31 |
+
]
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"cell_type": "code",
|
| 35 |
+
"execution_count": null,
|
| 36 |
+
"metadata": {},
|
| 37 |
+
"outputs": [],
|
| 38 |
+
"source": [
|
| 39 |
+
"import torch\n",
|
| 40 |
+
"assert torch.cuda.is_available(), \"GPU not available — switch to a GPU runtime\"\n",
|
| 41 |
+
"gpu_name = torch.cuda.get_device_name(0)\n",
|
| 42 |
+
"gpu_mem = torch.cuda.get_device_properties(0).total_mem / 1e9\n",
|
| 43 |
+
"print(f\"GPU: {gpu_name} | VRAM: {gpu_mem:.1f} GB\")"
|
| 44 |
+
]
|
| 45 |
+
},
|
| 46 |
+
{
|
| 47 |
+
"cell_type": "markdown",
|
| 48 |
+
"metadata": {},
|
| 49 |
+
"source": [
|
| 50 |
+
"## Section 2 — Load Model"
|
| 51 |
+
]
|
| 52 |
+
},
|
| 53 |
+
{
|
| 54 |
+
"cell_type": "code",
|
| 55 |
+
"execution_count": null,
|
| 56 |
+
"metadata": {},
|
| 57 |
+
"outputs": [],
|
| 58 |
+
"source": [
|
| 59 |
+
"from unsloth import FastLanguageModel\n",
|
| 60 |
+
"\n",
|
| 61 |
+
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
|
| 62 |
+
" model_name=\"unsloth/Qwen2.5-7B-Instruct\",\n",
|
| 63 |
+
" max_seq_length=2048,\n",
|
| 64 |
+
" load_in_4bit=True,\n",
|
| 65 |
+
")\n",
|
| 66 |
+
"\n",
|
| 67 |
+
"if tokenizer.pad_token is None:\n",
|
| 68 |
+
" tokenizer.pad_token = tokenizer.eos_token\n",
|
| 69 |
+
"\n",
|
| 70 |
+
"model = FastLanguageModel.get_peft_model(\n",
|
| 71 |
+
" model,\n",
|
| 72 |
+
" r=32,\n",
|
| 73 |
+
" lora_alpha=64,\n",
|
| 74 |
+
" lora_dropout=0.05,\n",
|
| 75 |
+
" target_modules=[\n",
|
| 76 |
+
" 'q_proj', 'k_proj', 'v_proj', 'o_proj',\n",
|
| 77 |
+
" 'gate_proj', 'up_proj', 'down_proj',\n",
|
| 78 |
+
" ],\n",
|
| 79 |
+
" bias=\"none\",\n",
|
| 80 |
+
" use_gradient_checkpointing=\"unsloth\",\n",
|
| 81 |
+
")\n",
|
| 82 |
+
"\n",
|
| 83 |
+
"print(f\"Model loaded. Trainable params: {model.print_trainable_parameters()}\")"
|
| 84 |
+
]
|
| 85 |
+
},
|
| 86 |
+
{
|
| 87 |
+
"cell_type": "markdown",
|
| 88 |
+
"metadata": {},
|
| 89 |
+
"source": [
|
| 90 |
+
"## Section 3 — Load Data"
|
| 91 |
+
]
|
| 92 |
+
},
|
| 93 |
+
{
|
| 94 |
+
"cell_type": "code",
|
| 95 |
+
"execution_count": null,
|
| 96 |
+
"metadata": {},
|
| 97 |
+
"outputs": [],
|
| 98 |
+
"source": [
|
| 99 |
+
"import json\n",
|
| 100 |
+
"from datasets import Dataset\n",
|
| 101 |
+
"from collections import Counter\n",
|
| 102 |
+
"\n",
|
| 103 |
+
"SFT_DATA_PATH = \"training/sft_data.jsonl\"\n",
|
| 104 |
+
"\n",
|
| 105 |
+
"raw_examples = []\n",
|
| 106 |
+
"with open(SFT_DATA_PATH, \"r\", encoding=\"utf-8\") as f:\n",
|
| 107 |
+
" for line in f:\n",
|
| 108 |
+
" raw_examples.append(json.loads(line))\n",
|
| 109 |
+
"\n",
|
| 110 |
+
"print(f\"Loaded {len(raw_examples)} raw examples\")\n",
|
| 111 |
+
"\n",
|
| 112 |
+
"tier_dist = Counter(ex[\"tier\"] for ex in raw_examples)\n",
|
| 113 |
+
"print(f\"Tier distribution: {dict(tier_dist)}\")"
|
| 114 |
+
]
|
| 115 |
+
},
|
| 116 |
+
{
|
| 117 |
+
"cell_type": "code",
|
| 118 |
+
"execution_count": null,
|
| 119 |
+
"metadata": {},
|
| 120 |
+
"outputs": [],
|
| 121 |
+
"source": [
|
| 122 |
+
"def format_example(ex):\n",
|
| 123 |
+
" \"\"\"Format a single SFT example into a full conversation string for causal LM loss.\"\"\"\n",
|
| 124 |
+
" messages = ex[\"messages\"]\n",
|
| 125 |
+
" completion = ex[\"completion\"]\n",
|
| 126 |
+
"\n",
|
| 127 |
+
" prompt_str = tokenizer.apply_chat_template(\n",
|
| 128 |
+
" messages,\n",
|
| 129 |
+
" tokenize=False,\n",
|
| 130 |
+
" add_generation_prompt=True,\n",
|
| 131 |
+
" )\n",
|
| 132 |
+
" full_text = prompt_str + completion + tokenizer.eos_token\n",
|
| 133 |
+
" return {\"text\": full_text}\n",
|
| 134 |
+
"\n",
|
| 135 |
+
"\n",
|
| 136 |
+
"formatted = [format_example(ex) for ex in raw_examples]\n",
|
| 137 |
+
"dataset = Dataset.from_list(formatted)\n",
|
| 138 |
+
"\n",
|
| 139 |
+
"split = dataset.train_test_split(test_size=0.05, seed=42)\n",
|
| 140 |
+
"train_dataset = split[\"train\"]\n",
|
| 141 |
+
"val_dataset = split[\"test\"]\n",
|
| 142 |
+
"\n",
|
| 143 |
+
"print(f\"Train: {len(train_dataset)} | Val: {len(val_dataset)}\")\n",
|
| 144 |
+
"print(f\"\\nSample (first 500 chars):\\n{formatted[0]['text'][:500]}\")"
|
| 145 |
+
]
|
| 146 |
+
},
|
| 147 |
+
{
|
| 148 |
+
"cell_type": "markdown",
|
| 149 |
+
"metadata": {},
|
| 150 |
+
"source": [
|
| 151 |
+
"## Section 4 — Train"
|
| 152 |
+
]
|
| 153 |
+
},
|
| 154 |
+
{
|
| 155 |
+
"cell_type": "code",
|
| 156 |
+
"execution_count": null,
|
| 157 |
+
"metadata": {},
|
| 158 |
+
"outputs": [],
|
| 159 |
+
"source": [
|
| 160 |
+
"from trl import SFTTrainer\n",
|
| 161 |
+
"from transformers import TrainingArguments\n",
|
| 162 |
+
"\n",
|
| 163 |
+
"trainer = SFTTrainer(\n",
|
| 164 |
+
" model=model,\n",
|
| 165 |
+
" tokenizer=tokenizer,\n",
|
| 166 |
+
" train_dataset=train_dataset,\n",
|
| 167 |
+
" eval_dataset=val_dataset,\n",
|
| 168 |
+
" dataset_text_field=\"text\",\n",
|
| 169 |
+
" max_seq_length=2048,\n",
|
| 170 |
+
" packing=True,\n",
|
| 171 |
+
" args=TrainingArguments(\n",
|
| 172 |
+
" output_dir=\"./sft_checkpoints\",\n",
|
| 173 |
+
" per_device_train_batch_size=2,\n",
|
| 174 |
+
" gradient_accumulation_steps=4,\n",
|
| 175 |
+
" num_train_epochs=1,\n",
|
| 176 |
+
" learning_rate=2e-4,\n",
|
| 177 |
+
" warmup_ratio=0.05,\n",
|
| 178 |
+
" lr_scheduler_type=\"cosine\",\n",
|
| 179 |
+
" logging_steps=10,\n",
|
| 180 |
+
" save_steps=100,\n",
|
| 181 |
+
" save_total_limit=2,\n",
|
| 182 |
+
" eval_strategy=\"steps\",\n",
|
| 183 |
+
" eval_steps=100,\n",
|
| 184 |
+
" fp16=not torch.cuda.is_bf16_supported(),\n",
|
| 185 |
+
" bf16=torch.cuda.is_bf16_supported(),\n",
|
| 186 |
+
" report_to=\"none\",\n",
|
| 187 |
+
" optim=\"adamw_8bit\",\n",
|
| 188 |
+
" seed=42,\n",
|
| 189 |
+
" ),\n",
|
| 190 |
+
")\n",
|
| 191 |
+
"\n",
|
| 192 |
+
"print(\"Starting SFT training...\")\n",
|
| 193 |
+
"trainer.train()\n",
|
| 194 |
+
"print(\"SFT training complete.\")"
|
| 195 |
+
]
|
| 196 |
+
},
|
| 197 |
+
{
|
| 198 |
+
"cell_type": "markdown",
|
| 199 |
+
"metadata": {},
|
| 200 |
+
"source": [
|
| 201 |
+
"## Section 5 — Quick Eval\n",
|
| 202 |
+
"\n",
|
| 203 |
+
"Run 10 full episodes on easy tier with the trained model driving every step.\n",
|
| 204 |
+
"Requires env imports — upload the repo or clone it."
|
| 205 |
+
]
|
| 206 |
+
},
|
| 207 |
+
{
|
| 208 |
+
"cell_type": "code",
|
| 209 |
+
"execution_count": null,
|
| 210 |
+
"metadata": {},
|
| 211 |
+
"outputs": [],
|
| 212 |
+
"source": [
|
| 213 |
+
"import sys, os\n",
|
| 214 |
+
"\n",
|
| 215 |
+
"# Adjust this path to wherever the repo root is in Colab\n",
|
| 216 |
+
"REPO_ROOT = \".\" # or e.g. \"/content/Wildfire-Containment-Simulator-main\"\n",
|
| 217 |
+
"if REPO_ROOT not in sys.path:\n",
|
| 218 |
+
" sys.path.insert(0, REPO_ROOT)\n",
|
| 219 |
+
"\n",
|
| 220 |
+
"from env.wildfire_env import WildfireEnv\n",
|
| 221 |
+
"from env.serialization import serialize_observation\n",
|
| 222 |
+
"from env.action_parser import parse_action\n",
|
| 223 |
+
"from env.models import TIER_EASY\n",
|
| 224 |
+
"\n",
|
| 225 |
+
"SYSTEM_PROMPT = (\n",
|
| 226 |
+
" \"You are an AI Incident Commander managing wildfire containment. \"\n",
|
| 227 |
+
" \"You will receive a situation briefing each step. \"\n",
|
| 228 |
+
" \"Respond with ONLY a valid JSON action object and nothing else. \"\n",
|
| 229 |
+
" 'Example: {\"action_type\": \"idle\"}'\n",
|
| 230 |
+
")\n",
|
| 231 |
+
"\n",
|
| 232 |
+
"print(\"Env imports OK\")"
|
| 233 |
+
]
|
| 234 |
+
},
|
| 235 |
+
{
|
| 236 |
+
"cell_type": "code",
|
| 237 |
+
"execution_count": null,
|
| 238 |
+
"metadata": {},
|
| 239 |
+
"outputs": [],
|
| 240 |
+
"source": [
|
| 241 |
+
"import numpy as np\n",
|
| 242 |
+
"\n",
|
| 243 |
+
"FastLanguageModel.for_inference(model)\n",
|
| 244 |
+
"\n",
|
| 245 |
+
"EVAL_SEEDS = range(42, 52)\n",
|
| 246 |
+
"TIER = \"easy\"\n",
|
| 247 |
+
"MAX_STEPS = TIER_EASY.episode_length\n",
|
| 248 |
+
"\n",
|
| 249 |
+
"rewards = []\n",
|
| 250 |
+
"pop_saved_pcts = []\n",
|
| 251 |
+
"parse_counts = {\"json_success\": 0, \"regex_fallback\": 0, \"safe_idle\": 0}\n",
|
| 252 |
+
"total_steps = 0\n",
|
| 253 |
+
"\n",
|
| 254 |
+
"for seed in EVAL_SEEDS:\n",
|
| 255 |
+
" env = WildfireEnv()\n",
|
| 256 |
+
" obs = env.reset(task_id=TIER, seed=seed)\n",
|
| 257 |
+
" episode_reward = 0.0\n",
|
| 258 |
+
" step_num = 0\n",
|
| 259 |
+
" prev_burning = 0\n",
|
| 260 |
+
"\n",
|
| 261 |
+
" while not env.done:\n",
|
| 262 |
+
" prompt = serialize_observation(\n",
|
| 263 |
+
" obs, step_num, MAX_STEPS,\n",
|
| 264 |
+
" tier=TIER, prev_cells_burning=prev_burning,\n",
|
| 265 |
+
" )\n",
|
| 266 |
+
" prev_burning = obs.stats.cells_burning\n",
|
| 267 |
+
"\n",
|
| 268 |
+
" messages = [\n",
|
| 269 |
+
" {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
|
| 270 |
+
" {\"role\": \"user\", \"content\": prompt},\n",
|
| 271 |
+
" ]\n",
|
| 272 |
+
" input_ids = tokenizer.apply_chat_template(\n",
|
| 273 |
+
" messages, tokenize=True,\n",
|
| 274 |
+
" add_generation_prompt=True, return_tensors=\"pt\",\n",
|
| 275 |
+
" ).to(model.device)\n",
|
| 276 |
+
"\n",
|
| 277 |
+
" with torch.no_grad():\n",
|
| 278 |
+
" out = model.generate(\n",
|
| 279 |
+
" input_ids,\n",
|
| 280 |
+
" max_new_tokens=128,\n",
|
| 281 |
+
" pad_token_id=tokenizer.eos_token_id,\n",
|
| 282 |
+
" )\n",
|
| 283 |
+
" text = tokenizer.decode(out[0][input_ids.shape[1]:], skip_special_tokens=True)\n",
|
| 284 |
+
"\n",
|
| 285 |
+
" action, status = parse_action(text, obs)\n",
|
| 286 |
+
" parse_counts[status] = parse_counts.get(status, 0) + 1\n",
|
| 287 |
+
"\n",
|
| 288 |
+
" result = env.step(action)\n",
|
| 289 |
+
" episode_reward += result.reward\n",
|
| 290 |
+
" obs = result.observation\n",
|
| 291 |
+
" step_num += 1\n",
|
| 292 |
+
"\n",
|
| 293 |
+
" total_steps += step_num\n",
|
| 294 |
+
" rewards.append(episode_reward)\n",
|
| 295 |
+
"\n",
|
| 296 |
+
" state = env.state()\n",
|
| 297 |
+
" total_pop = state[\"total_population\"]\n",
|
| 298 |
+
" pop_lost = state[\"population_lost\"]\n",
|
| 299 |
+
" pop_saved = 100.0 * (total_pop - pop_lost) / total_pop if total_pop > 0 else 100.0\n",
|
| 300 |
+
" pop_saved_pcts.append(pop_saved)\n",
|
| 301 |
+
"\n",
|
| 302 |
+
" print(f\" Seed {seed}: reward={episode_reward:+.2f}, steps={step_num}, pop_saved={pop_saved:.0f}%\")\n",
|
| 303 |
+
"\n",
|
| 304 |
+
"mean_reward = np.mean(rewards)\n",
|
| 305 |
+
"std_reward = np.std(rewards)\n",
|
| 306 |
+
"total_parses = sum(parse_counts.values())\n",
|
| 307 |
+
"json_rate = 100.0 * parse_counts[\"json_success\"] / total_parses if total_parses > 0 else 0\n",
|
| 308 |
+
"mean_pop = np.mean(pop_saved_pcts)\n",
|
| 309 |
+
"\n",
|
| 310 |
+
"print(f\"\\n{'='*50}\")\n",
|
| 311 |
+
"print(f\"SFT Quick Eval — {TIER} tier, seeds {EVAL_SEEDS.start}-{EVAL_SEEDS.stop-1}\")\n",
|
| 312 |
+
"print(f\"Mean reward: {mean_reward:+.2f} ± {std_reward:.2f}\")\n",
|
| 313 |
+
"print(f\"JSON success rate: {json_rate:.1f}%\")\n",
|
| 314 |
+
"print(f\"Mean pop saved: {mean_pop:.1f}%\")\n",
|
| 315 |
+
"print(f\"Parse breakdown: {dict(parse_counts)}\")\n",
|
| 316 |
+
"print(f\"{'='*50}\")\n",
|
| 317 |
+
"\n",
|
| 318 |
+
"assert mean_reward > 2.0, (\n",
|
| 319 |
+
" f\"SFT warm-up insufficient (mean_reward={mean_reward:.2f}) — do not proceed to GRPO\"\n",
|
| 320 |
+
")\n",
|
| 321 |
+
"print(\"\\n✓ SFT checkpoint passes warm-up gate. Safe to proceed to GRPO.\")\n",
|
| 322 |
+
"\n",
|
| 323 |
+
"FastLanguageModel.for_training(model)"
|
| 324 |
+
]
|
| 325 |
+
},
|
| 326 |
+
{
|
| 327 |
+
"cell_type": "markdown",
|
| 328 |
+
"metadata": {},
|
| 329 |
+
"source": [
|
| 330 |
+
"## Section 6 — Save & Export"
|
| 331 |
+
]
|
| 332 |
+
},
|
| 333 |
+
{
|
| 334 |
+
"cell_type": "code",
|
| 335 |
+
"execution_count": null,
|
| 336 |
+
"metadata": {},
|
| 337 |
+
"outputs": [],
|
| 338 |
+
"source": [
|
| 339 |
+
"model.save_pretrained(\"./sft_final\")\n",
|
| 340 |
+
"tokenizer.save_pretrained(\"./sft_final\")\n",
|
| 341 |
+
"print(\"Saved to ./sft_final\")"
|
| 342 |
+
]
|
| 343 |
+
},
|
| 344 |
+
{
|
| 345 |
+
"cell_type": "code",
|
| 346 |
+
"execution_count": null,
|
| 347 |
+
"metadata": {},
|
| 348 |
+
"outputs": [],
|
| 349 |
+
"source": [
|
| 350 |
+
"# Push to HuggingFace Hub — replace with your username\n",
|
| 351 |
+
"HF_USERNAME = \"Eshit\"\n",
|
| 352 |
+
"model.push_to_hub(f\"{HF_USERNAME}/wildfire-sft-7b\")\n",
|
| 353 |
+
"tokenizer.push_to_hub(f\"{HF_USERNAME}/wildfire-sft-7b\")\n",
|
| 354 |
+
"print(f\"Pushed to hub: {HF_USERNAME}/wildfire-sft-7b\")"
|
| 355 |
+
]
|
| 356 |
+
},
|
| 357 |
+
{
|
| 358 |
+
"cell_type": "code",
|
| 359 |
+
"execution_count": null,
|
| 360 |
+
"metadata": {},
|
| 361 |
+
"outputs": [],
|
| 362 |
+
"source": [
|
| 363 |
+
"!zip -r sft_final.zip ./sft_final\n",
|
| 364 |
+
"from google.colab import files\n",
|
| 365 |
+
"files.download(\"sft_final.zip\")\n",
|
| 366 |
+
"print(\"Download started.\")"
|
| 367 |
+
]
|
| 368 |
+
}
|
| 369 |
+
],
|
| 370 |
+
"metadata": {
|
| 371 |
+
"accelerator": "GPU",
|
| 372 |
+
"colab": {
|
| 373 |
+
"gpuType": "A100",
|
| 374 |
+
"provenance": []
|
| 375 |
+
},
|
| 376 |
+
"kernelspec": {
|
| 377 |
+
"display_name": "Python 3",
|
| 378 |
+
"name": "python3"
|
| 379 |
+
},
|
| 380 |
+
"language_info": {
|
| 381 |
+
"name": "python",
|
| 382 |
+
"version": "3.10.0"
|
| 383 |
+
}
|
| 384 |
+
},
|
| 385 |
+
"nbformat": 4,
|
| 386 |
+
"nbformat_minor": 5
|
| 387 |
+
}
|
training/sft_data.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|