# wildfire_env/server/wildfire_environment.py
# shankerram3's picture
# Upload folder using huggingface_hub
# 8c6fae6 verified
import os
import random
import uuid
# Support both in-repo and standalone imports
try:
# In-repo imports (when running from OpenEnv repository)
from openenv.core.env_server import Environment
from ..models import WildfireAction, WildfireObservation, WildfireState
except ImportError:
# Standalone imports (when environment is standalone with openenv-core from pip)
from openenv_core.env_server import Environment
from wildfire_env.models import WildfireAction, WildfireObservation, WildfireState
# Helpers
# Compass label -> (dx, dy) unit offset on the grid. The y axis grows
# downward, so "N" is dy = -1. "CALM" is the zero vector (no wind bias).
# Insertion order matters: a random wind is drawn from list(DIRS_8).
DIRS_8 = {
    "N": (0, -1),
    "NE": (1, -1),
    "E": (1, 0),
    "SE": (1, 1),
    "S": (0, 1),
    "SW": (-1, 1),
    "W": (-1, 0),
    "NW": (-1, -1),
    "CALM": (0, 0),
}
def idx(x: int, y: int, w: int) -> int:
    """Return the flat row-major index of cell (x, y) in a grid of width w.

    All three arguments are coerced with ``int()`` first, in the order
    x, y, w, so numeric strings and floats are accepted (floats truncate).
    """
    col, row, stride = map(int, (x, y, w))
    return row * stride + col
def in_bounds(x: int, y: int, w: int, h: int) -> bool:
    """Return True when (x, y) lies inside a w-by-h grid.

    Arguments are coerced with ``int()`` (in the order x, y, w, h) before
    the half-open range checks ``0 <= x < w`` and ``0 <= y < h``.
    """
    col, row, cols, rows = map(int, (x, y, w, h))
    if col < 0 or col >= cols:
        return False
    return 0 <= row < rows
class WildfireEnvironment(Environment):
    """
    Weather-aware wildfire simulation on a flat, row-major grid.

    Grid encodings:
        0 = ash (burned out)
        1 = fuel / vegetation
        2 = burning
        3 = firebreak
        4 = watered / damp

    Each step:
        - agent acts (water/break/wait)
        - burning spreads to neighbors with wind + humidity effects
        - burning cells burn for multiple ticks, then become ash

    The constructor honours these optional environment variables:
    ``WILDFIRE_WIDTH``, ``WILDFIRE_HEIGHT``, ``WILDFIRE_HUMIDITY`` and
    ``WILDFIRE_WIND`` (a key of ``DIRS_8``, matched case-insensitively).
    """

    # Cell codes stored in the grid.
    ASH = 0
    FUEL = 1
    BURNING = 2
    FIREBREAK = 3
    DAMP = 4

    # Spread-model tuning constants.
    DAMP_LIFETIME = 6        # ticks a watered cell stays damp before reverting to fuel
    DOWNWIND_MULT = 2.0      # ignition multiplier when spreading along the wind
    UPWIND_MULT = 0.5        # ignition multiplier when spreading against the wind
    DIAGONAL_MULT = 0.6      # diagonal neighbours ignite less readily

    # 8-neighbour offsets. The order is significant: it fixes the order of
    # RNG draws, and therefore the results for a given seed.
    _NEIGHBORS = (
        (-1, 0), (1, 0), (0, -1), (0, 1),
        (-1, -1), (1, -1), (-1, 1), (1, 1),
    )

    def __init__(
        self,
        width: int = 32,
        height: int = 32,
        base_ignite_prob: float = 0.30,
        wind_bias: float = 0.20,  # kept for compatibility (not directly used in B model)
        diag_factor: float = 0.7,  # kept for compatibility (not directly used in B model)
        humidity: float = 0.25,
        init_sources: int = 2,
        seed: int = 3407,
        max_steps: int = 128,
        water_capacity: int = 8,  # low on purpose: encourage strategic water use
        break_capacity: int = 50,
    ):
        """Store the configuration; the episode state is created by reset().

        Raises:
            ValueError: if the effective width or height (after env-var
                overrides) is not a positive integer.
        """
        super().__init__()

        # --- Env-var overrides (optional) ---
        width = int(os.environ.get("WILDFIRE_WIDTH", width))
        height = int(os.environ.get("WILDFIRE_HEIGHT", height))
        humidity = float(os.environ.get("WILDFIRE_HUMIDITY", humidity))
        forced_wind = os.environ.get("WILDFIRE_WIND")

        # Fail early with a clear message instead of an opaque randrange
        # error inside reset() or a ZeroDivisionError inside step().
        if width < 1 or height < 1:
            raise ValueError(
                f"grid dimensions must be positive, got {width}x{height}"
            )

        # Store config
        self.w = width
        self.h = height
        self.base_ignite_prob = base_ignite_prob
        self.wind_bias = wind_bias
        self.diag_factor = diag_factor
        self.init_humidity = humidity
        self.init_sources = init_sources
        self.rng = random.Random(seed)
        self.max_steps = max_steps
        self.init_water = water_capacity
        self.init_breaks = break_capacity
        # Normalise so "ne" / " NE " match the DIRS_8 keys; an unknown value
        # still falls back to a random wind in reset().
        self.forced_wind = forced_wind.strip().upper() if forced_wind else None

        # burn lifetime in ticks (balanced model)
        self.burn_lifetime = 3

        # The real state is built in reset(). WildfireState() cannot be
        # instantiated directly here due to Pydantic/dataclass conflicts.
        self._state: WildfireState | None = None

    # --- Core API ---
    def reset(self) -> WildfireObservation:
        """Start a new episode and return its first observation.

        The grid is filled with fuel, a wind direction and humidity are
        chosen, and ``init_sources`` random cells are ignited. Two sources
        may land on the same cell. The RNG is not re-seeded, so successive
        episodes differ.
        """
        w, h = self.w, self.h
        n_cells = w * h

        # Start with all fuel
        grid = [self.FUEL] * n_cells

        # Wind (forced if provided and valid, otherwise random)
        if self.forced_wind in DIRS_8:
            wind_dir = self.forced_wind
        else:
            wind_dir = self.rng.choice(list(DIRS_8))

        # Humidity: small variation around the configured value, clamped to [0, 1]
        jitter = self.rng.uniform(-0.05, 0.05)
        humidity = min(1.0, max(0.0, self.init_humidity + jitter))

        # Place initial fires
        for _ in range(self.init_sources):
            x = self.rng.randrange(w)
            y = self.rng.randrange(h)
            grid[idx(x, y, w)] = self.BURNING

        # Use model_construct to bypass Pydantic validation for
        # dataclass/Pydantic compatibility.
        self._state = WildfireState.model_construct(
            episode_id=str(uuid.uuid4()),
            step_count=0,
            total_burned=0,
            total_extinguished=0,
            last_action="reset",
            width=w,
            height=h,
            wind_dir=wind_dir,
            humidity=humidity,
            remaining_water=self.init_water,
            remaining_breaks=self.init_breaks,
            grid=grid,
            burn_timers=[0] * n_cells,
        )
        return self._make_observation(reward_hint=0.0)

    def step(self, action: WildfireAction) -> WildfireObservation:
        """Apply one agent action, advance the fire one tick, return the result.

        If no episode has been started (or only the empty placeholder state
        created by the ``state`` property exists), an episode is started
        first via reset().

        Reward components:
            - action reward from ``_apply_water`` / ``_apply_break``
            - -0.05 for an invalid action or exhausted resource
            - -0.15 per net additional burning cell, +0.10 per net fewer
            - -0.05 per cell that turned to ash this tick
            - -0.01 time penalty
            - end-of-episode bonuses when the episode finishes
        """
        if self._state is None or not self._state.grid:
            self.reset()
        st = self._state
        reward = 0.0

        # --- Agent action effects ---
        if (
            action.action == "water"
            and st.remaining_water > 0
            and action.x is not None
            and action.y is not None
        ):
            reward += self._apply_water(action.x, action.y)
        elif (
            action.action == "break"
            and st.remaining_breaks > 0
            and action.x is not None
            and action.y is not None
        ):
            reward += self._apply_break(action.x, action.y)
        elif action.action == "wait":
            pass
        else:
            reward -= 0.05  # invalid or exhausted resources

        # --- Natural fire dynamics ---
        prev_burning = self._burning_count()
        # Only _spread_fire() creates ash, so its return value is exactly
        # the change in ash count for this tick.
        newly_burned = self._spread_fire()
        new_burning = self._burning_count()

        st.total_burned += newly_burned
        st.step_count += 1
        st.last_action = action.action

        # --- Spread vs containment shaping ---
        spread_delta = new_burning - prev_burning
        if spread_delta > 0:
            reward -= 0.15 * spread_delta  # strong penalty: focus on containment
        elif spread_delta < 0:
            reward += 0.10 * abs(spread_delta)  # reward shrinkage

        # Mild penalty for newly burned cells (area loss)
        if newly_burned > 0:
            reward -= 0.05 * newly_burned

        # Small time penalty to prefer fast control
        reward -= 0.01

        done = self._is_done()

        # --- End of episode bonuses ---
        if done:
            total_cells = self.w * self.h
            saved_ratio = self._saved_cells() / total_cells
            burned_ratio = st.grid.count(self.ASH) / total_cells
            # Big containment bonus
            if new_burning == 0:
                reward += 0.5 + 0.5 * saved_ratio
            # Fallback proportional reward
            reward += 0.2 * (1.0 - burned_ratio)

        obs = self._make_observation(reward_hint=reward)
        obs.done = done
        obs.reward = reward
        return obs

    # --- Internal mechanics ---
    def _apply_water(self, x: int, y: int) -> float:
        """Water cell (x, y) and return the action reward.

        Out-of-bounds targets return -0.05 and an empty tank returns -0.5;
        neither consumes water. Otherwise one unit of water is spent:
        burning -> damp (+0.25), fuel -> damp (-0.10), anything else -0.05.
        """
        st = self._state
        x, y = int(x), int(y)  # coordinates may arrive as floats/strings
        if not in_bounds(x, y, self.w, self.h):
            return -0.05

        # Strong penalty if no water left
        if st.remaining_water <= 0:
            return -0.5

        i = idx(x, y, self.w)
        # Guard against a state grid that is smaller than w * h
        if i >= len(st.grid):
            return -0.05

        cell = st.grid[i]
        if cell == self.BURNING:
            st.grid[i] = self.DAMP  # extinguish & dampen
            st.burn_timers[i] = 0
            st.total_extinguished += 1
            reward = 0.25
        elif cell == self.FUEL:
            st.grid[i] = self.DAMP  # dampen fuel (mild penalty to avoid spamming)
            st.burn_timers[i] = 0
            reward = -0.10
        else:
            # redundant watering, or watering ash / firebreak
            reward = -0.05

        st.remaining_water -= 1
        return reward

    def _apply_break(self, x: int, y: int) -> float:
        """Build a firebreak at (x, y) and return the action reward.

        Out-of-bounds targets return -0.05 without consuming a break.
        Otherwise one break is spent: fuel/damp -> firebreak (+0.15),
        burning -> firebreak (-0.02), existing firebreak -0.01, ash -0.02.
        """
        st = self._state
        x, y = int(x), int(y)  # coordinates may arrive as floats/strings
        if not in_bounds(x, y, self.w, self.h):
            return -0.05

        i = idx(x, y, self.w)
        # Guard against a state grid that is smaller than w * h
        if i >= len(st.grid):
            return -0.05

        cell = st.grid[i]
        if cell in (self.FUEL, self.DAMP):
            st.grid[i] = self.FIREBREAK
            st.burn_timers[i] = 0
            reward = 0.15  # make firebreaks attractive
        elif cell == self.BURNING:
            st.grid[i] = self.FIREBREAK
            st.burn_timers[i] = 0
            reward = -0.02
        elif cell == self.FIREBREAK:
            reward = -0.01
        else:
            reward = -0.02

        st.remaining_breaks -= 1
        return reward

    def _spread_fire(self) -> int:
        """
        Advance the fire by one tick and return the number of cells that
        turned to ash.

        Balanced wildfire spread model:
        - burning cells persist for ``burn_lifetime`` ticks before turning to ash
        - 8-direction spread (diagonals weaker)
        - wind accelerates in wind direction, weakens upwind
        - humidity suppresses ignition probability
        - damp cells (4) are immune to ignition and revert to fuel after
          ``DAMP_LIFETIME`` ticks

        ``burn_timers`` doubles as the damp timer for watered cells.
        """
        st = self._state
        grid = st.grid
        timers = st.burn_timers
        w, h = self.w, self.h
        n_cells = len(grid)

        fuel, burning, damp = self.FUEL, self.BURNING, self.DAMP
        wind = DIRS_8.get(st.wind_dir, (0, 0))
        upwind = (-wind[0], -wind[1])
        # Same left-to-right float evaluation as base * humidity * wind * diag
        base_p = self.base_ignite_prob * (1.0 - st.humidity)

        ignite_flags = [False] * (w * h)

        # First pass: age burning cells and roll ignition for fuel neighbours.
        # Index/bounds math is inlined here (hot loop); values are already ints.
        for y in range(h):
            row_start = y * w
            for x in range(w):
                i = row_start + x
                if i >= n_cells or grid[i] != burning:
                    continue
                timers[i] += 1
                for dx, dy in self._NEIGHBORS:
                    nx, ny = x + dx, y + dy
                    if not (0 <= nx < w and 0 <= ny < h):
                        continue
                    ni = ny * w + nx
                    # Only fuel can ignite; damp, ash, firebreak and
                    # already-burning cells are skipped without an RNG draw.
                    if ni >= n_cells or grid[ni] != fuel:
                        continue

                    if (dx, dy) == wind:
                        wind_mult = self.DOWNWIND_MULT
                    elif (dx, dy) == upwind:
                        wind_mult = self.UPWIND_MULT
                    else:
                        wind_mult = 1.0
                    diag_mult = self.DIAGONAL_MULT if (dx != 0 and dy != 0) else 1.0

                    p = max(0.0, min(1.0, base_p * wind_mult * diag_mult))
                    if self.rng.random() < p:
                        ignite_flags[ni] = True

        # Second pass: apply transitions onto a copy so that all decisions
        # are based on the grid as it was at the start of the tick.
        new_grid = grid[:]
        newly_burned = 0
        n_timers = len(timers)
        n_flags = len(ignite_flags)
        for i, cell in enumerate(grid):
            if i >= n_timers:
                continue
            if cell == burning:
                # burns for burn_lifetime ticks before turning to ash
                if timers[i] >= self.burn_lifetime:
                    new_grid[i] = self.ASH
                    newly_burned += 1
            elif i < n_flags and ignite_flags[i] and new_grid[i] == fuel:
                new_grid[i] = burning
                timers[i] = 0
            elif cell == damp:
                # Water stays damp for several ticks before reverting to fuel
                timers[i] += 1
                if timers[i] >= self.DAMP_LIFETIME:
                    new_grid[i] = fuel

        st.grid = new_grid
        return newly_burned

    def _burning_count(self) -> int:
        """Return the number of cells currently burning."""
        return self._state.grid.count(self.BURNING)

    def _saved_cells(self) -> int:
        """Return the number of cells not turned to ash.

        Includes fuel, burning, firebreak and damp cells.
        """
        grid = self._state.grid
        return sum(
            grid.count(code)
            for code in (self.FUEL, self.BURNING, self.FIREBREAK, self.DAMP)
        )

    def _is_done(self) -> bool:
        """Return True when nothing is burning or the step budget is spent."""
        return self._burning_count() == 0 or self._state.step_count >= self.max_steps

    def _make_observation(self, reward_hint: float = 0.0) -> WildfireObservation:
        """Build an observation snapshot (with a copied grid) from the state."""
        st = self._state
        # Use model_construct to bypass Pydantic validation for
        # dataclass/Pydantic compatibility.
        return WildfireObservation.model_construct(
            grid=st.grid[:],
            width=self.w,
            height=self.h,
            step=st.step_count,
            wind_dir=st.wind_dir,
            humidity=st.humidity,
            burning_count=self._burning_count(),
            remaining_water=st.remaining_water,
            remaining_breaks=st.remaining_breaks,
            burned_count=st.grid.count(self.ASH),
            reward_hint=reward_hint,
        )

    # --- Required abstract property implementation ---
    @property
    def state(self) -> WildfireState:
        """Return the current environment state.

        If accessed before reset(), an empty placeholder state (0x0 grid)
        is created and cached; step() replaces it with a real episode.
        """
        if self._state is None:
            # Use model_construct to bypass Pydantic validation for
            # dataclass/Pydantic compatibility.
            self._state = WildfireState.model_construct(
                episode_id="",
                step_count=0,
                total_burned=0,
                total_extinguished=0,
                last_action="reset",
                width=0,
                height=0,
                wind_dir="CALM",
                humidity=0.25,
                remaining_water=self.init_water,
                remaining_breaks=self.init_breaks,
                grid=[],
                burn_timers=[],
            )
        return self._state