File size: 15,840 Bytes
62b53b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a1f6bc
 
 
dd7277b
62b53b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e95764
 
 
2f38b12
 
62b53b4
 
 
 
 
 
 
 
 
 
 
 
0e95764
62b53b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c6fae6
 
62b53b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd7277b
5a1f6bc
fefc730
 
5a1f6bc
 
 
 
 
 
 
 
 
 
 
 
 
 
62b53b4
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438

import os
import random
import uuid

# Support both in-repo and standalone imports
try:
    # In-repo imports (when running from OpenEnv repository)
    from openenv.core.env_server import Environment
    from ..models import WildfireAction, WildfireObservation, WildfireState
except ImportError:
    # Standalone imports (when environment is standalone with openenv-core from pip)
    from openenv_core.env_server import Environment
    from wildfire_env.models import WildfireAction, WildfireObservation, WildfireState

# Helpers
# Wind-direction name -> (dx, dy) grid offset. "N" maps to dy = -1 and "S" to
# dy = +1, so y grows southward. Despite the name, the table has nine entries:
# the eight compass directions plus "CALM", which is the zero vector.
DIRS_8 = {
    "N":  (0, -1), "NE": (1, -1), "E":  (1, 0), "SE": (1, 1),
    "S":  (0,  1), "SW": (-1, 1), "W":  (-1, 0), "NW": (-1, -1),
    "CALM": (0, 0),
}

def idx(x: int, y: int, w: int) -> int:
    """Return the flat row-major index of cell (x, y) in a grid of width w.

    Every argument is passed through ``int()`` first, so floats are truncated
    toward zero and numeric strings are accepted.
    """
    col, row, stride = map(int, (x, y, w))
    return row * stride + col

def in_bounds(x: int, y: int, w: int, h: int) -> bool:
    """Return True when (x, y) lies inside a w-by-h grid, i.e. 0 <= x < w and 0 <= y < h.

    Every argument is passed through ``int()`` before the comparison.
    """
    col, row, cols, rows = map(int, (x, y, w, h))
    if col < 0 or col >= cols:
        return False
    return 0 <= row < rows


class WildfireEnvironment(Environment):
    """

    Weather-aware wildfire simulation.

    Grid encodings:
      0 = ash (burned out)
      1 = fuel / vegetation
      2 = burning
      3 = firebreak
      4 = watered / damp (cannot ignite; reverts to fuel after a few ticks)

    Each step:
      - agent acts (water/break/wait)
      - burning spreads to neighbors with wind + humidity effects
      - burning cells burn for multiple ticks, then become ash
    """

    def __init__(
        self,
        width: int = 32,
        height: int = 32,
        base_ignite_prob: float = 0.30,
        wind_bias: float = 0.20,      # kept for compatibility (not directly used in B model)
        diag_factor: float = 0.7,     # kept for compatibility (not directly used in B model)
        humidity: float = 0.25,
        init_sources: int = 2,
        seed: int = 3407,
        max_steps: int = 128,
        water_capacity: int = 8,      # kept small to encourage strategic water use
        break_capacity: int = 50,
    ):
        """Store the simulation configuration; no episode exists until reset().

        Args:
            width: Grid width in cells (overridden by ``WILDFIRE_WIDTH``).
            height: Grid height in cells (overridden by ``WILDFIRE_HEIGHT``).
            base_ignite_prob: Base ignition probability used by the spread model.
            wind_bias: Stored only; the spread model in this class does not read it.
            diag_factor: Stored only; the spread model in this class does not read it.
            humidity: Baseline humidity (overridden by ``WILDFIRE_HUMIDITY``);
                ``reset()`` jitters it slightly and clamps it to [0, 1].
            init_sources: Number of ignition draws made by ``reset()``.
            seed: Seed for this environment's private ``random.Random``.
            max_steps: Step count at which an episode is considered done.
            water_capacity: Water budget granted at each reset.
            break_capacity: Firebreak budget granted at each reset.

        ``WILDFIRE_WIND``, when set, is stored as the forced wind direction;
        ``reset()`` only honours it if it is a key of ``DIRS_8``.
        """
        super().__init__()

        # Optional environment-variable overrides take precedence over arguments.
        env = os.environ
        self.w = int(env.get("WILDFIRE_WIDTH", width))
        self.h = int(env.get("WILDFIRE_HEIGHT", height))
        self.init_humidity = float(env.get("WILDFIRE_HUMIDITY", humidity))
        self.forced_wind = env.get("WILDFIRE_WIND", None)

        # Spread-model parameters.
        self.base_ignite_prob = base_ignite_prob
        self.wind_bias = wind_bias
        self.diag_factor = diag_factor
        self.burn_lifetime = 3  # ticks a cell keeps burning before it turns to ash

        # Episode parameters.
        self.init_sources = init_sources
        self.max_steps = max_steps
        self.init_water = water_capacity
        self.init_breaks = break_capacity
        self.rng = random.Random(seed)

        # The real state is built in reset(); the `state` property covers the
        # case where it is read before that (WildfireState() cannot be
        # default-constructed here because of Pydantic/dataclass conflicts).
        self._state: WildfireState | None = None

    # --- Core API ---

    def reset(self) -> WildfireObservation:
        """Start a new episode and return its initial observation.

        The grid is filled with fuel, a wind direction and a humidity value are
        chosen, ``init_sources`` ignition points are drawn (draws may land on
        the same cell), and budgets and counters are restored.
        """
        cols, rows = int(self.w), int(self.h)
        cell_count = cols * rows

        # Every cell starts out as fuel.
        grid = [1] * cell_count

        # A forced wind is honoured only when it names a known direction.
        forced = self.forced_wind
        if forced and forced in DIRS_8:
            wind_dir = forced
        else:
            wind_dir = self.rng.choice(list(DIRS_8))

        # Baseline humidity with a small random offset, clamped to [0, 1].
        jitter = self.rng.uniform(-0.05, 0.05)
        humidity = min(1.0, max(0.0, self.init_humidity + jitter))

        # Light the initial fires.
        for _ in range(self.init_sources):
            col = self.rng.randrange(cols)
            row = self.rng.randrange(rows)
            flat = idx(col, row, cols)
            if 0 <= flat < len(grid):
                grid[flat] = 2

        # model_construct bypasses Pydantic validation (dataclass/Pydantic
        # compatibility).
        fields = {
            "episode_id": str(uuid.uuid4()),
            "step_count": 0,
            "total_burned": 0,
            "total_extinguished": 0,
            "last_action": "reset",
            "width": cols,
            "height": rows,
            "wind_dir": wind_dir,
            "humidity": humidity,
            "remaining_water": self.init_water,
            "remaining_breaks": self.init_breaks,
            "grid": grid,
            "burn_timers": [0] * cell_count,
        }
        self._state = WildfireState.model_construct(**fields)

        return self._make_observation(reward_hint=0.0)

    def step(self, action: WildfireAction) -> WildfireObservation:
        """Apply one agent action, advance the fire by one tick, and score it.

        If no episode has been started yet (``reset()`` was never called, or
        only the empty placeholder built by the ``state`` property exists), an
        episode is started automatically instead of failing on the missing
        state.

        Args:
            action: Agent action. ``action.action`` is ``"water"``, ``"break"``
                or ``"wait"``. ``"water"`` and ``"break"`` also need
                ``action.x``/``action.y`` and a non-empty budget; any other
                combination costs a flat -0.05.

        Returns:
            The post-tick observation with ``done`` and ``reward`` set;
            ``reward_hint`` carries the same reward value.
        """
        # Without this guard a step() before reset() raised AttributeError on
        # the None state, or ran on an empty placeholder grid and returned a
        # meaningless terminal reward.
        if self._state is None or not self._state.grid:
            self.reset()

        st = self._state
        reward = 0.0

        # --- Agent action effects ---
        if (
            action.action == "water"
            and st.remaining_water > 0
            and action.x is not None
            and action.y is not None
        ):
            reward += self._apply_water(action.x, action.y)
        elif (
            action.action == "break"
            and st.remaining_breaks > 0
            and action.x is not None
            and action.y is not None
        ):
            reward += self._apply_break(action.x, action.y)
        elif action.action == "wait":
            pass
        else:
            reward -= 0.05  # invalid or exhausted resources

        # --- Natural fire dynamics ---
        # Counts are taken after the agent's action, so the deltas below only
        # measure what the fire tick itself changed.
        prev_burning = self._burning_count()
        prev_burned = sum(1 for v in st.grid if v == 0)

        newly_burned = self._spread_fire()
        new_burning = self._burning_count()
        now_burned = sum(1 for v in st.grid if v == 0)

        st.total_burned += newly_burned
        st.step_count += 1
        st.last_action = action.action

        # --- Spread vs containment shaping ---
        spread_delta = new_burning - prev_burning
        burned_delta = now_burned - prev_burned

        # Strong penalty for spread, smaller reward for shrinkage
        if spread_delta > 0:
            reward -= 0.15 * spread_delta
        elif spread_delta < 0:
            reward += 0.10 * abs(spread_delta)

        # Mild penalty for newly burned cells (area loss)
        if burned_delta > 0:
            reward -= 0.05 * burned_delta

        # Small time penalty to prefer fast control
        reward -= 0.01

        done = self._is_done()

        # --- End of episode bonuses ---
        if done:
            total_cells = self.w * self.h
            saved_ratio = self._saved_cells() / total_cells
            burned_ratio = now_burned / total_cells
            burning_left = self._burning_count()

            # Big containment bonus when nothing is burning any more
            if burning_left == 0:
                reward += 0.5 + 0.5 * saved_ratio

            # Fallback proportional reward
            reward += 0.2 * (1.0 - burned_ratio)

        obs = self._make_observation(reward_hint=reward)
        obs.done = done
        obs.reward = reward
        return obs


    # --- Internal mechanics ---

    def _apply_water(self, x: int, y: int) -> float:
        """Spend one unit of water on cell (x, y) and return the shaping reward.

        Returns -0.05 for an out-of-range target and -0.5 when the water
        budget is empty; neither case consumes water. Otherwise one unit is
        consumed and the reward is +0.25 for dousing a burning cell, -0.10 for
        dampening fuel, and -0.05 for any other cell (already damp, ash or
        firebreak).
        """
        state = self._state
        col, row = int(x), int(y)

        if not in_bounds(col, row, self.w, self.h):
            return -0.05
        if state.remaining_water <= 0:
            return -0.5

        cell = idx(col, row, self.w)
        if not 0 <= cell < len(state.grid):
            return -0.05

        occupant = state.grid[cell]
        if occupant in (1, 2):
            # Burning and fuel cells both become damp, with the timer restarted.
            state.grid[cell] = 4
            state.burn_timers[cell] = 0
            if occupant == 2:
                state.total_extinguished += 1
                outcome = 0.25
            else:
                outcome = -0.10  # mild penalty to discourage spamming water on fuel
        else:
            outcome = -0.05  # redundant watering, or watering ash / a firebreak

        state.remaining_water -= 1
        return outcome

    def _apply_break(self, x: int, y: int) -> float:
        """Spend one firebreak on cell (x, y) and return the shaping reward.

        Returns -0.05 for an out-of-range target and -0.5 when the firebreak
        budget is empty; neither case consumes a firebreak. Otherwise one is
        consumed and the reward is +0.15 for converting fuel or a damp cell,
        -0.02 for converting a burning cell, -0.01 for an existing firebreak
        and -0.02 for ash.
        """
        st = self._state
        # Ensure x and y are integers (defensive type conversion)
        x, y = int(x), int(y)
        if not in_bounds(x, y, self.w, self.h):
            return -0.05

        # Same exhausted-budget guard as _apply_water, so a direct call can
        # never drive remaining_breaks negative.
        if st.remaining_breaks <= 0:
            return -0.5

        i = idx(x, y, self.w)
        # Safety check: ensure index is within grid bounds
        if i < 0 or i >= len(st.grid):
            return -0.05

        reward = 0.0

        if st.grid[i] in (1, 4):
            st.grid[i] = 3
            st.burn_timers[i] = 0
            reward += 0.15  # the largest reward here, to make firebreaks attractive
        elif st.grid[i] == 2:
            st.grid[i] = 3
            st.burn_timers[i] = 0
            reward -= 0.02
        elif st.grid[i] == 3:
            reward -= 0.01  # redundant firebreak
        else:
            reward -= 0.02  # firebreak on ash

        st.remaining_breaks -= 1
        return reward

    def _spread_fire(self) -> int:
        """
        Advance the fire by one tick and return how many cells turned to ash.

        Balanced wildfire spread model:
          - burning cells persist for ``burn_lifetime`` ticks before turning to ash
          - 8-direction spread; diagonal neighbors use a 0.6 multiplier
          - the neighbor lying at the wind offset gets a 2.0 multiplier, the
            opposite neighbor 0.5, all others 1.0 (with "CALM" no offset matches)
          - ignition probability is scaled by (1 - humidity)
          - damp cells (4) never ignite and revert to fuel once their timer reaches 6

        Ignition is decided from the pre-tick grid and written to a copy, so a
        cell ignited this tick does not spread until the next one. The
        multipliers are hard-coded; ``self.wind_bias`` and ``self.diag_factor``
        are not read here. ``burn_timers`` is updated in place and serves both
        as the burn clock for burning cells and the damp clock for code-4 cells.
        """
        st = self._state
        new_grid = st.grid[:]
        newly_burned = 0

        # Ensure w and h are integers (defensive type conversion)
        w, h = int(self.w), int(self.h)

        # 8-neighbor model: orthogonal offsets first, then diagonals
        neighbors = [(-1, 0), (1, 0), (0, -1), (0, 1),
                     (-1, -1), (1, -1), (-1, 1), (1, 1)]
        # Unknown wind names fall back to the zero vector (same as "CALM")
        wx, wy = DIRS_8.get(st.wind_dir, (0, 0))

        base = self.base_ignite_prob
        humidity_factor = (1.0 - st.humidity)

        # Cells selected for ignition this tick; applied in the second pass
        ignite_flags = [False] * (w * h)

        # First pass: evaluate ignitions, increment burn timers
        for y in range(h):
            for x in range(w):
                i = idx(x, y, w)
                # Safety check: ensure index is within grid bounds
                if i < 0 or i >= len(st.grid):
                    continue
                cell = st.grid[i]

                if cell == 2:  # burning
                    st.burn_timers[i] += 1

                    for dx, dy in neighbors:
                        nx, ny = x + dx, y + dy
                        if not in_bounds(nx, ny, w, h):
                            continue
                        ni = idx(nx, ny, w)
                        # Safety check: ensure neighbor index is within grid bounds
                        if ni < 0 or ni >= len(st.grid):
                            continue
                        target = st.grid[ni]

                        # Only fuel (1) can ignite; damp cells (4) are called out
                        # explicitly although the `!= 1` test below would skip them too
                        if target == 4:
                            # Watered/damp cells (code 4) do not ignite at all while in this state
                            continue
                        if target != 1:
                            continue

                        # Wind multiplier: boost the neighbor at the wind offset,
                        # damp the one at the opposite offset
                        if (dx, dy) == (wx, wy):
                            wind_mult = 2.0
                        elif (dx, dy) == (-wx, -wy):
                            wind_mult = 0.5
                        else:
                            wind_mult = 1.0

                        # Diagonals weaker
                        diag_mult = 0.6 if (dx != 0 and dy != 0) else 1.0

                        p = base * humidity_factor * wind_mult * diag_mult
                        p = max(0.0, min(1.0, p))
                        # One random draw per (burning cell, fuel neighbor) pair,
                        # even when the neighbor is already flagged; the draw
                        # order therefore determines seeded trajectories
                        if self.rng.random() < p:
                            # Safety check: ensure ni is within ignite_flags bounds
                            if 0 <= ni < len(ignite_flags):
                                ignite_flags[ni] = True

        # Second pass: apply transitions (branches keyed on the pre-tick cell value)
        for i, cell in enumerate(st.grid):
            # Safety check: ensure index is within bounds for all arrays
            if i < 0 or i >= len(new_grid) or i >= len(st.burn_timers):
                continue

            if cell == 2:
                # burns for burn_lifetime ticks before turning to ash
                if st.burn_timers[i] >= self.burn_lifetime:
                    new_grid[i] = 0  # ash
                    newly_burned += 1
                else:
                    new_grid[i] = 2  # keep burning
            elif i < len(ignite_flags) and ignite_flags[i] and new_grid[i] == 1:
                # Freshly ignited fuel starts its burn clock at zero
                new_grid[i] = 2
                st.burn_timers[i] = 0
            elif cell == 4:
                # Water stays damp for several ticks before reverting to fuel;
                # the timer is not cleared on revert, only on the next
                # ignition, watering or firebreak
                st.burn_timers[i] += 1
                if st.burn_timers[i] >= 6:   # was 3; extend to make water useful
                    new_grid[i] = 1

        st.grid = new_grid
        return newly_burned

    def _burning_count(self) -> int:
        """Return the number of cells currently burning (grid code 2)."""
        cells = self._state.grid
        return sum(cell == 2 for cell in cells)

    def _saved_cells(self) -> int:
        """Return the number of cells holding fuel, fire, a firebreak or water.

        These are the grid codes 1-4, i.e. everything that is not ash.
        """
        not_ash = (1, 2, 3, 4)
        return sum(cell in not_ash for cell in self._state.grid)

    def _is_done(self) -> bool:
        """Return True once nothing is burning or the step limit is reached."""
        if self._burning_count() == 0:
            return True
        return self._state.step_count >= self.max_steps

    def _make_observation(self, reward_hint: float = 0.0) -> WildfireObservation:
        """Build an observation snapshot of the current state.

        The grid is copied, so later simulation ticks do not mutate an
        observation that has already been returned.
        """
        state = self._state
        snapshot = {
            "grid": state.grid[:],
            "width": self.w,
            "height": self.h,
            "step": state.step_count,
            "wind_dir": state.wind_dir,
            "humidity": state.humidity,
            "burning_count": self._burning_count(),
            "remaining_water": state.remaining_water,
            "remaining_breaks": state.remaining_breaks,
            "burned_count": sum(cell == 0 for cell in state.grid),
            "reward_hint": reward_hint,
        }
        # model_construct bypasses Pydantic validation (dataclass/Pydantic
        # compatibility).
        return WildfireObservation.model_construct(**snapshot)

    # --- Required abstract property implementation ---
    @property
    def state(self) -> WildfireState:
        """Return the current environment state.

        When read before ``reset()``, an empty placeholder (zero-sized grid,
        calm wind, full budgets) is created, stored and returned.
        """
        if self._state is not None:
            return self._state

        # model_construct bypasses Pydantic validation (dataclass/Pydantic
        # compatibility).
        placeholder = {
            "episode_id": "",
            "step_count": 0,
            "total_burned": 0,
            "total_extinguished": 0,
            "last_action": "reset",
            "width": 0,
            "height": 0,
            "wind_dir": "CALM",
            "humidity": 0.25,
            "remaining_water": self.init_water,
            "remaining_breaks": self.init_breaks,
            "grid": [],
            "burn_timers": [],
        }
        self._state = WildfireState.model_construct(**placeholder)
        return self._state