# Last commit ad92ece (Eshit): Improve wildfire metrics and training assets
"""
Wildfire Containment Simulator — Main Environment.
Implements the OpenEnv API: step(), reset(), state().
Orchestrates grid, fire spread, weather, resources, and reward computation.
"""
from __future__ import annotations
import logging
from typing import Optional
import numpy as np
from pydantic import ValidationError
from .models import (
Action, ActionType, Observation, StepResult, ClusterStats,
FireState, FuelType, TierConfig, TIER_EASY, TIER_MEDIUM, TIER_HARD,
)
from .grid import Grid
from .fire_spread import FireSpreadEngine
from .weather import WeatherEngine
from .resources import ResourceManager
from .reward import RewardCalculator
from .briefing import generate_briefing, OperationalBriefing
logger = logging.getLogger(__name__)
class WildfireEnv:
    """
    Wildfire Containment Simulator environment.

    Simulates a grid-based wildfire where an AI agent dispatches
    firefighting resources to contain the fire before it reaches
    populated zones.

    API:
        reset(task_id, seed) -> Observation
        step(action) -> StepResult
        state() -> dict
    """

    TIER_MAP = {
        "easy": TIER_EASY,
        "medium": TIER_MEDIUM,
        "hard": TIER_HARD,
    }

    def __init__(self, config: Optional[TierConfig] = None):
        """Create an environment shell; simulation components are built by ``reset()``.

        Args:
            config: Tier configuration used until ``reset()`` replaces it with the
                tier selected by ``task_id``. Defaults to ``TIER_EASY``.
        """
        self.config = config or TIER_EASY
        self.rng = np.random.default_rng(42)
        self.current_step = 0
        self.done = False
        # Components (initialized in reset)
        self.grid: Optional[Grid] = None
        self.fire_engine: Optional[FireSpreadEngine] = None
        self.weather: Optional[WeatherEngine] = None
        self.resources: Optional[ResourceManager] = None
        self.reward_calc: Optional[RewardCalculator] = None
        # Rolling event history (last 20 entries); observations expose the last 5.
        self.events_log: list[str] = []
        # Episode-level tracking for the step/terminal reward structure
        self._prev_action: Optional[Action] = None
        self._invalid_action_count: int = 0
        self._crew_casualty_occurred: bool = False
        self._prev_state: Optional[dict] = None
        self.active_briefing: Optional[OperationalBriefing] = None
        # Last observation returned to the agent (agent's view, not ground truth)
        self._current_obs: Optional[Observation] = None

    def reset(self, task_id: str = "easy", seed: int = 42) -> Observation:
        """
        Initialize the environment for a new episode.

        Args:
            task_id: One of "easy", "medium", "hard". Unknown values fall back
                to the easy tier (a warning is logged).
            seed: Random seed for reproducibility; one generator is shared by
                the grid, fire engine, weather engine and briefing.

        Returns:
            Initial observation, with the operational briefing attached.
        """
        if task_id not in self.TIER_MAP:
            logger.warning("Unknown task_id %r; falling back to the 'easy' tier.", task_id)
        self.config = self.TIER_MAP.get(task_id, TIER_EASY)
        self.rng = np.random.default_rng(seed)
        self.current_step = 0
        self.done = False
        self.events_log = []
        self._prev_action = None
        self._invalid_action_count = 0
        self._crew_casualty_occurred = False
        self._prev_state = None

        # Initialize components (all share self.rng, so construction order matters
        # for reproducibility).
        self.grid = Grid(self.config, self.rng)
        self.fire_engine = FireSpreadEngine(self.grid, self.rng)
        self.weather = WeatherEngine(self.config, self.rng)
        self.resources = ResourceManager(self.config, self.grid)
        self.reward_calc = RewardCalculator(self.config)
        self.reward_calc.reset()
        self.resources.reset()
        self.weather.reset()

        # Ignite initial fire points
        self._ignite_initial_fires()
        # Generate operational briefing for this episode
        self.active_briefing = generate_briefing(self.config, self.rng, self.grid)

        # Log the start event *before* building the observation so the agent's
        # first observation already shows it in recent_events.
        self.events_log.append("Episode started. Fire ignited.")
        obs = self._build_observation()
        obs.briefing = self.active_briefing
        self._current_obs = obs
        return obs

    def step(self, action: Action) -> StepResult:
        """
        Execute one simulation step.

        Tick sequence:
            1. Validate action
            2. Execute action (skipped when invalid)
            3. Spread fire
            4. Update intensities (handled inside spread)
            5. Apply suppression
            6. Evolve weather
            7. Update moisture
            8. Propagate smoke
            9. Compute reward
            10. Check termination
            11. Build observation

        Between steps 8 and 9 the environment also ticks tanker cooldowns,
        expires recon reveals, fires scheduled events (staggered ignition,
        crew loss) and advances ``current_step``.

        Args:
            action: The agent's chosen action.

        Returns:
            StepResult with observation, reward, done flag, and info dict.
            Once the episode is over, further calls return reward 0.0 and an
            ``"error"`` entry in ``info``.

        Raises:
            RuntimeError: If ``reset()`` has not been called yet.
        """
        if self.grid is None:
            raise RuntimeError("Environment not initialized. Call reset() first.")
        if self.done:
            return StepResult(
                observation=self._build_observation(),
                reward=0.0,
                done=True,
                info={"error": "Episode already finished"},
            )

        cfg = self.config
        step_events: list[str] = []
        # Snapshot state before this step's changes (input to reward deltas).
        prev_state = self._snapshot_state()

        # ── Step 1: Validate action ──
        # Redundancy is judged against the last *valid* action.
        action_was_redundant = self._is_redundant(action)
        valid, msg = self._validate_action(action)
        if valid:
            # ── Step 2: Execute action ──
            step_events.extend(self._execute_action(action))
            self._prev_action = action
        else:
            # Invalid actions skip execution but the world still advances.
            self.reward_calc.record_invalid_action()
            self._invalid_action_count += 1
            self.resources.wasted_actions += 1
            step_events.append(f"Invalid action: {msg}")

        # ── Step 3-4: Spread fire + update intensities ──
        ws = self.weather.state
        step_events.extend(
            self.fire_engine.spread_step(ws.wind_speed_kmh, ws.wind_direction_deg)
        )
        # ── Step 5: Apply suppression ──
        step_events.extend(self.resources.apply_suppression())
        # ── Step 6: Evolve weather ──
        step_events.extend(self.weather.step(self.current_step))
        # NOTE(review): `ws` was read before weather.step(). If WeatherEngine.step()
        # rebinds `self.state` instead of mutating it, steps 7-8 below use the
        # pre-evolution weather — confirm which is intended.
        # ── Step 7: Update moisture ──
        self.grid.update_moisture(ws.rain_active, ws.humidity_pct)
        # ── Step 8: Propagate smoke ──
        self.grid.propagate_smoke(ws.wind_direction_deg, ws.wind_speed_kmh)

        # ── Resource bookkeeping ──
        self.resources.tick_tanker_cooldowns()
        self.resources.expire_reveals(self.current_step)

        # ── Handle staggered ignition (hard tier) ──
        # Compared against current_step *before* the increment below.
        if (cfg.staggered_ignition_step is not None
                and self.current_step == cfg.staggered_ignition_step):
            if self._ignite_staggered_fire():
                step_events.append("NEW IGNITION: Additional fire started!")

        # ── Handle crew loss (hard tier) ──
        if (cfg.enable_crew_loss
                and cfg.crew_loss_step == self.current_step
                and cfg.crew_loss_id):
            step_events.extend(self.resources.apply_crew_loss(cfg.crew_loss_id))
        # Latch the casualty flag whenever the resource manager reports any casualty.
        if self.resources.crew_casualties:
            self._crew_casualty_occurred = True

        self.current_step += 1

        # Log a hold-message when fire is extinguished before min_active_steps so
        # agents (and the LLM) understand the episode must continue for monitoring.
        burning_now = (self.grid.count_by_state(FireState.BURNING)
                       + self.grid.count_by_state(FireState.EMBER))
        if burning_now == 0 and self.current_step < cfg.min_active_steps:
            step_events.append(
                f"All fires contained. Holding perimeter until step "
                f"{cfg.min_active_steps} (min_active_steps)."
            )

        # ── Step 9: Compute reward ──
        legacy_reward = self.reward_calc.compute_reward(self.grid, self.resources, self.current_step)
        current_state = self._snapshot_state()
        step_reward = self.reward_calc.compute_step_reward(
            prev_state, current_state, valid, action_was_redundant
        )

        # ── Step 10: Check termination ──
        self.done = self._check_termination()
        terminal_reward = 0.0
        if self.done:
            terminal_state = dict(current_state)
            terminal_state["crew_casualty_occurred"] = self._crew_casualty_occurred
            terminal_state["invalid_action_count"] = self._invalid_action_count
            if self.active_briefing:
                terminal_state["priority_zones"] = self.active_briefing.priority_populated_zones
            terminal_state["_grid_ref"] = self.grid
            terminal_reward = self.reward_calc.compute_terminal_reward(
                terminal_state, self.current_step, cfg.episode_length
            )
        reward = step_reward + terminal_reward

        # Record this step's events (keep the last 20) *before* building the
        # observation, so recent_events reflects what just happened rather than
        # lagging one step behind.
        self.events_log = (self.events_log + step_events)[-20:]

        # ── Step 11: Build observation ──
        obs = self._build_observation()
        info = {
            "step": self.current_step,
            "events": step_events,
            "legacy_reward": round(legacy_reward, 4),
            "reward_breakdown": self.reward_calc.get_component_breakdown(
                self.grid, self.resources, self.current_step
            ),
        }
        result = StepResult(
            observation=obs,
            reward=round(reward, 4),
            done=self.done,
            info=info,
        )
        self._current_obs = result.observation
        return result

    def state(self) -> dict:
        """
        Return full ground-truth state for grading/debugging.

        NOT for agent use — contains information hidden from the agent
        (no fog-of-war or smoke occlusion, true weather).

        Returns:
            A dict with tier, step, done flag, per-cell grid state, weather,
            resources, reward breakdown and population/burn totals; or a dict
            with a single ``"error"`` key if ``reset()`` has not been called.
        """
        if self.grid is None:
            return {"error": "Environment not initialized. Call reset() first."}
        full_grid = [
            [self._cell_truth(r, c) for c in range(self.grid.cols)]
            for r in range(self.grid.rows)
        ]
        return {
            "tier": self.config.tier_name,
            "current_step": self.current_step,
            "done": self.done,
            "grid": full_grid,
            "weather": self.weather.get_true_state().model_dump(),
            "resources": self.resources.get_resource_state().model_dump(),
            "reward_breakdown": self.reward_calc.get_component_breakdown(
                self.grid, self.resources, self.current_step
            ),
            "total_population": self.grid.get_total_population(),
            "population_lost": self.grid.get_population_lost(),
            "cells_burned": self.grid.get_burned_count(),
            "total_burnable": self.grid.get_total_burnable(),
        }

    # ══════════════════════════════════════════════════
    # PRIVATE METHODS
    # ══════════════════════════════════════════════════

    def _cell_truth(self, r: int, c: int) -> dict:
        """Return the unoccluded static + dynamic state of cell (r, c) for ``state()``."""
        static = self.grid.static_grid[r][c]
        dynamic = self.grid.dynamic_grid[r][c]
        return {
            "row": r, "col": c,
            "fuel_type": static.fuel_type.value,
            "fuel_load": static.fuel_load,
            "elevation_m": static.elevation_m,
            "is_populated": static.is_populated,
            "population": static.population,
            "fire_state": dynamic.fire_state.value,
            "fire_intensity": round(dynamic.fire_intensity, 4),
            "moisture": round(dynamic.moisture, 4),
            "time_burning": dynamic.time_burning,
            "suppression_level": round(dynamic.suppression_level, 4),
            "smoke_density": round(dynamic.smoke_density, 4),
            "crew_present": dynamic.crew_present,
        }

    def _snapshot_state(self) -> dict:
        """Capture a lightweight state dict for reward delta computation.

        Returns:
            ``containment_pct`` as a fraction in [0, 1] (1.0 when there is no
            perimeter), plus ``pop_lost`` and ``total_pop``.
        """
        total, contained = self.grid.get_fire_perimeter()
        containment_pct = contained / total if total > 0 else 1.0
        return {
            "containment_pct": containment_pct,
            "pop_lost": self.grid.get_population_lost(),
            "total_pop": self.grid.get_total_population(),
        }

    def _is_redundant(self, action: Action) -> bool:
        """True if action is a meaningless repeat of the previous valid action.

        Actions that use target coordinates (DROP_RETARDANT, DEPLOY_CREW, RECON_FLIGHT)
        are redundant when the type + target cell match. Directional actions (MOVE_CREW,
        BUILD_FIREBREAK) require the same crew_id AND direction to be redundant — two
        consecutive MOVE_CREW steps by different crews, or in different directions, are
        valid patrol behaviour and must not be penalised.
        """
        prev = self._prev_action
        if prev is None or action.action_type != prev.action_type:
            return False
        # Coordinate-targeted actions: redundant when same cell is targeted again
        if action.target_row is not None or prev.target_row is not None:
            return (action.target_row == prev.target_row
                    and action.target_col == prev.target_col)
        # Crew directional actions: redundant only when same crew moves same direction
        if action.crew_id is not None:
            return (action.crew_id == prev.crew_id
                    and action.direction == prev.direction)
        return False

    def _ignite_initial_fires(self) -> None:
        """Place initial fire ignition points based on tier config.

        Ignition candidates are shifted away from populated cells to ensure
        a minimum survivable distance, reducing unwinnable-scenario variance.
        Intensity is set high enough (0.65) that a single tanker drop (-0.4)
        leaves residual fire (0.25) so the episode cannot be solved in 1-2
        steps. The fire must spread, be actively managed, and burn for at
        least min_active_steps before the episode can end.
        """
        rows, cols = self.config.grid_rows, self.config.grid_cols
        intensity = 0.65
        # Minimum Manhattan distance from any populated cell per tier
        min_pop_dist = {"easy": 4, "medium": 6, "hard": 7}.get(self.config.tier_name, 5)

        anchors_by_tier = {
            # Two ignition points spread across the grid so crews must split
            "easy": [(rows // 2, cols // 3), (rows // 2, 2 * cols // 3)],
            # Three ignition points: forces genuine multi-front management
            "medium": [
                (rows // 4, cols // 3),
                (2 * rows // 3, 2 * cols // 3),
                (rows // 2, cols // 2),
            ],
        }
        # Any other tier (hard): two initial points; a further one comes later
        # via staggered ignition at config.staggered_ignition_step.
        default_anchors = [(rows // 4, cols // 4), (rows // 2, 3 * cols // 4)]

        for anchor_r, anchor_c in anchors_by_tier.get(self.config.tier_name, default_anchors):
            r, c = self._find_ignition_candidate(anchor_r, anchor_c, min_pop_dist)
            self.grid.ignite_cell(r, c, intensity=intensity)

    def _find_ignition_candidate(self, target_r: int, target_c: int, min_pop_dist: int) -> tuple[int, int]:
        """Return the nearest valid ignition cell to (target_r, target_c).

        A valid cell is in bounds, is not WATER/ROAD fuel, and lies at least
        ``min_pop_dist`` (Manhattan) from every populated cell. Cells are
        examined in expanding Manhattan rings around the target, so the first
        match is a nearest one. The rings extend far enough to cover the whole
        grid; if no cell qualifies, the original target is returned unchanged.
        """
        rows, cols = self.config.grid_rows, self.config.grid_cols
        static_grid = self.grid.static_grid
        pop_cells = [
            (r, c)
            for r in range(rows)
            for c in range(cols)
            if static_grid[r][c].is_populated
        ]

        def _nearest_pop_dist(r: int, c: int) -> int:
            if not pop_cells:
                return 9999
            return min(abs(r - pr) + abs(c - pc) for pr, pc in pop_cells)

        # The Manhattan distance from an in-bounds target to any cell is at most
        # (rows - 1) + (cols - 1), so this range covers the entire grid.
        for radius in range(rows + cols - 1):
            for dr in range(-radius, radius + 1):
                rem = radius - abs(dr)
                # Exactly the ring points |dr| + |dc| == radius, in ascending dc order.
                for dc in ((-rem, rem) if rem else (0,)):
                    r, c = target_r + dr, target_c + dc
                    if not self.grid._in_bounds(r, c):
                        continue
                    if static_grid[r][c].fuel_type in (FuelType.WATER, FuelType.ROAD):
                        continue
                    if _nearest_pop_dist(r, c) >= min_pop_dist:
                        return r, c
        return target_r, target_c

    def _ignite_staggered_fire(self) -> bool:
        """Ignite one additional fire point for the hard tier.

        Scans a 5x5 window whose top-left corner is (3*rows//4, cols//3) and
        ignites the first in-bounds, burnable (not WATER/ROAD), UNBURNED cell.

        Returns:
            True if a cell was ignited, False if no eligible cell was found.
        """
        rows, cols = self.config.grid_rows, self.config.grid_cols
        # Place in an area likely to cause problems
        target_r = 3 * rows // 4
        target_c = cols // 3
        for dr in range(5):
            for dc in range(5):
                r, c = target_r + dr, target_c + dc
                if not self.grid._in_bounds(r, c):
                    continue
                # Same non-burnable exclusion as _find_ignition_candidate.
                if self.grid.static_grid[r][c].fuel_type in (FuelType.WATER, FuelType.ROAD):
                    continue
                if self.grid.dynamic_grid[r][c].fire_state == FireState.UNBURNED:
                    self.grid.ignite_cell(r, c, intensity=0.7)
                    return True
        logger.warning(
            "Staggered ignition found no eligible cell near (%d, %d); skipped.",
            target_r, target_c,
        )
        return False

    def _validate_action(self, action: Action) -> tuple[bool, str]:
        """Semantically validate an action. Returns (is_valid, error_message).

        Pydantic validation already ran on construction; here we only check
        that coordinate-targeted actions (DEPLOY_CREW, DROP_RETARDANT,
        RECON_FLIGHT) carry a target that lies inside the grid. Resource-level
        checks (crew/tanker availability) happen when the action executes.
        """
        coord_targeted = (
            ActionType.DEPLOY_CREW,
            ActionType.DROP_RETARDANT,
            ActionType.RECON_FLIGHT,
        )
        if action.action_type not in coord_targeted:
            return True, ""
        row, col = action.target_row, action.target_col
        if row is None or col is None:
            return False, f"Action {action.action_type} requires target_row and target_col"
        try:
            in_bounds = self.grid._in_bounds(row, col)
        except (TypeError, ValueError) as e:
            # Malformed coordinate values are reported as an invalid action
            # rather than crashing the step.
            return False, str(e)
        if not in_bounds:
            return False, f"Target ({row},{col}) out of bounds"
        return True, ""

    def _execute_action(self, action: Action) -> list[str]:
        """Execute a validated action and return its event messages.

        Any resource operation that reports failure (ok=False) counts as a
        wasted action. IDLE never fails; unknown action types produce no events.
        """
        at = action.action_type
        res = self.resources
        if at == ActionType.IDLE:
            reason = action.reason or "No action taken"
            return [f"IDLE: {reason}"]
        if at == ActionType.DEPLOY_CREW:
            ok, msg = res.deploy_crew(action.crew_id, action.target_row, action.target_col)
        elif at == ActionType.MOVE_CREW:
            ok, msg = res.move_crew(action.crew_id, action.direction)
        elif at == ActionType.DROP_RETARDANT:
            ok, msg = res.drop_retardant(action.tanker_id, action.target_row, action.target_col)
        elif at == ActionType.BUILD_FIREBREAK:
            ok, msg = res.build_firebreak(action.crew_id, action.direction)
        elif at == ActionType.RECON_FLIGHT:
            ok, msg = res.recon_flight(action.target_row, action.target_col, self.current_step)
        else:
            return []
        if not ok:
            res.wasted_actions += 1
        return [msg]

    def _check_termination(self) -> bool:
        """Check whether the episode should end.

        Called after ``current_step`` has been incremented. Ends on:
          * the time limit (current_step >= episode_length);
          * no BURNING/EMBER cells, once min_active_steps have elapsed and any
            scheduled staggered ignition has had its chance to fire;
          * all population lost (catastrophic failure).
        """
        cfg = self.config
        # Time limit
        if self.current_step >= cfg.episode_length:
            return True

        # Fire fully contained (no burning cells)
        burning = self.grid.count_by_state(FireState.BURNING)
        ember = self.grid.count_by_state(FireState.EMBER)
        if burning == 0 and ember == 0:
            # Enforce minimum active steps — prevents trivial 1-2 step episodes
            # where a single tanker drop or natural burnout ends the episode
            # before the agent has taken any meaningful sequence of actions.
            if self.current_step < cfg.min_active_steps:
                return False
            # Don't terminate before staggered ignition fires (hard tier). The
            # ignition runs in the step() call that *starts* at
            # current_step == stagger, i.e. after which current_step == stagger + 1,
            # so we must keep going while current_step <= stagger.
            stagger = cfg.staggered_ignition_step
            if stagger is not None and self.current_step <= stagger:
                return False
            return True

        # All populated zones burned (catastrophic failure)
        total_pop = self.grid.get_total_population()
        lost_pop = self.grid.get_population_lost()
        return total_pop > 0 and lost_pop >= total_pop

    def _build_observation(self) -> Observation:
        """Build the agent's observation with appropriate noise/occlusion."""
        # Grid observation with fog/smoke
        grid_obs = self.grid.build_observation(
            enable_fog=self.config.enable_fog_of_war,
            fog_radius=self.config.fog_visibility_radius,
            crew_positions=self.resources.get_crew_positions(),
            revealed_cells=self.resources.revealed_cells,
        )
        # Weather observation (possibly noisy)
        weather_obs = self.weather.get_observation()
        # Resource state (fully observable)
        resource_state = self.resources.get_resource_state()

        # Stats
        total_burnable = self.grid.get_total_burnable()
        cells_burned = self.grid.get_burned_count()
        cells_burning = self.grid.count_by_state(FireState.BURNING)
        total_pop = self.grid.get_total_population()
        pop_lost = self.grid.get_population_lost()
        area_saved_pct = (
            round(100.0 * (total_burnable - cells_burned) / total_burnable, 1)
            if total_burnable > 0 else 100.0
        )
        civilians_saved_pct = (
            round(100.0 * (total_pop - pop_lost) / total_pop, 1)
            if total_pop > 0 else 100.0
        )
        stats = ClusterStats(
            cells_burned=cells_burned,
            cells_burning=cells_burning,
            cells_saved=total_burnable - cells_burned - cells_burning,
            population_threatened=self._count_threatened_population(),
            population_lost=pop_lost,
            total_population=total_pop,
            containment_pct=self._compute_containment_pct(),
            area_saved_pct=area_saved_pct,
            civilians_saved_pct=civilians_saved_pct,
            current_step=self.current_step,
            max_steps=self.config.episode_length,
            firebreaks_built=self.resources.total_firebreaks_built,
            retardant_drops=self.resources.total_retardant_drops,
        )
        return Observation(
            grid=grid_obs,
            weather=weather_obs,
            resources=resource_state,
            stats=stats,
            # Recent events (last 5)
            recent_events=self.events_log[-5:],
        )

    def _count_threatened_population(self, radius: int = 3) -> int:
        """Count population in populated cells near active fire.

        A populated cell is threatened when it lies within ``radius`` cells of a
        burning cell (square window, i.e. Chebyshev distance) and is not itself
        BURNING or BURNED_OUT. Each cell is counted at most once.
        """
        excluded_states = (FireState.BURNED_OUT, FireState.BURNING)
        rows, cols = self.grid.rows, self.grid.cols
        seen: set[tuple[int, int]] = set()
        threatened = 0
        for br, bc in self.grid.get_burning_cells():
            for r in range(max(0, br - radius), min(rows, br + radius + 1)):
                for c in range(max(0, bc - radius), min(cols, bc + radius + 1)):
                    if (r, c) in seen:
                        continue
                    seen.add((r, c))
                    static = self.grid.static_grid[r][c]
                    if (static.is_populated
                            and self.grid.dynamic_grid[r][c].fire_state not in excluded_states):
                        threatened += static.population
        return threatened

    def _compute_containment_pct(self) -> float:
        """Return fire containment as a percentage (100.0 when there is no perimeter)."""
        total, contained = self.grid.get_fire_perimeter()
        if total == 0:
            return 100.0
        return round(100.0 * contained / total, 1)