File size: 5,607 Bytes
363abf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""
Random agent baseline for the Wildfire Containment Simulator.

Selects random valid actions each step. Serves as the lower-bound
baseline for score comparison.
"""

from __future__ import annotations

import numpy as np

from env.models import (
    Action, ActionType, Observation, Direction,
    FireState, FuelType, DIRECTION_DELTAS,
)


class RandomAgent:
    """Agent that picks a random valid action each step.

    Each call to :meth:`act` builds a list of candidate actions and returns
    one of them uniformly at random:

    * DEPLOY_CREW for each active, undeployed crew (random safe target cell),
    * MOVE_CREW for each active, deployed crew (random passable direction),
    * DROP_RETARDANT for each active tanker off cooldown (random burning cell),
    * BUILD_FIREBREAK for each active, deployed crew that has a valid
      adjacent target, while firebreak budget remains,
    * IDLE, which is always available.

    All randomness comes from one seeded ``numpy.random.Generator``, so a
    given seed reproduces the same choices for the same observations.
    """

    # Upper bound on how many safe cells a deploy target is drawn from;
    # larger safe-cell lists are randomly subsampled down to this size.
    SAFE_CELL_SAMPLE_SIZE = 20

    def __init__(self, seed: int = 42):
        """Create the agent with a deterministic RNG seeded by *seed*."""
        self.rng = np.random.default_rng(seed)

    def act(self, obs: Observation) -> Action:
        """Select a random valid action given the current observation.

        Args:
            obs: Current environment observation.

        Returns:
            One action drawn uniformly from the candidate list. IDLE is
            always a candidate, so a valid action is always returned.
        """
        candidates: list[Action] = []
        # Every builder draws from self.rng, so this call order is part of
        # what makes seeded runs reproducible -- do not reorder.
        candidates.extend(self._deploy_candidates(obs))
        candidates.extend(self._move_candidates(obs))
        candidates.extend(self._retardant_candidates(obs))
        candidates.extend(self._firebreak_candidates(obs))

        # IDLE: always available
        candidates.append(Action(
            action_type=ActionType.IDLE,
            reason="Random agent waiting",
        ))

        idx = self.rng.integers(0, len(candidates))
        return candidates[idx]

    # ------------------------------------------------------------------
    # Candidate builders (one per action family)
    # ------------------------------------------------------------------

    def _deploy_candidates(self, obs: Observation) -> list[Action]:
        """Build one DEPLOY_CREW action per active, undeployed crew.

        The grid scan for safe cells does not depend on the crew, so it runs
        once per step. The random subsample is still drawn separately for
        each crew, so every crew gets an independent draw.
        """
        waiting = [crew for crew in obs.resources.crews
                   if crew.is_active and not crew.is_deployed]
        if not waiting:
            return []

        all_safe = self._scan_safe_cells(obs)
        if not all_safe:
            return []

        actions: list[Action] = []
        for crew in waiting:
            safe_cells = self._subsample(all_safe)
            r, c = safe_cells[self.rng.integers(0, len(safe_cells))]
            actions.append(Action(
                action_type=ActionType.DEPLOY_CREW,
                crew_id=crew.crew_id,
                target_row=r, target_col=c,
            ))
        return actions

    def _move_candidates(self, obs: Observation) -> list[Action]:
        """Build one MOVE_CREW action (random passable direction) per deployed crew."""
        actions: list[Action] = []
        for crew in obs.resources.crews:
            if not (crew.is_active and crew.is_deployed):
                continue
            valid_dirs = self._get_valid_move_dirs(obs, crew.row, crew.col)
            if valid_dirs:
                d = valid_dirs[self.rng.integers(0, len(valid_dirs))]
                actions.append(Action(
                    action_type=ActionType.MOVE_CREW,
                    crew_id=crew.crew_id,
                    direction=d,
                ))
        return actions

    def _retardant_candidates(self, obs: Observation) -> list[Action]:
        """Build one DROP_RETARDANT action per ready tanker, aimed at a random burning cell.

        A tanker is ready when it is active and its cooldown is zero. The
        burning-cell scan is tanker-independent, so it runs once per step.
        """
        ready = [tanker for tanker in obs.resources.tankers
                 if tanker.is_active and tanker.cooldown_remaining == 0]
        if not ready:
            return []

        burning = self._get_burning_cells(obs)
        if not burning:
            return []

        actions: list[Action] = []
        for tanker in ready:
            r, c = burning[self.rng.integers(0, len(burning))]
            actions.append(Action(
                action_type=ActionType.DROP_RETARDANT,
                tanker_id=tanker.tanker_id,
                target_row=r, target_col=c,
            ))
        return actions

    def _firebreak_candidates(self, obs: Observation) -> list[Action]:
        """Build at most one BUILD_FIREBREAK action per deployed crew.

        Returns nothing when the firebreak budget is exhausted. For each
        active, deployed crew, the directions are shuffled and the first one
        whose adjacent cell passes ``_is_valid_firebreak_target`` is used.
        """
        if obs.resources.firebreak_budget <= 0:
            return []

        actions: list[Action] = []
        for crew in obs.resources.crews:
            if not (crew.is_active and crew.is_deployed):
                continue
            dirs = list(Direction)
            self.rng.shuffle(dirs)
            for d in dirs:
                dr, dc = DIRECTION_DELTAS[d]
                if self._is_valid_firebreak_target(obs, crew.row + dr, crew.col + dc):
                    actions.append(Action(
                        action_type=ActionType.BUILD_FIREBREAK,
                        crew_id=crew.crew_id,
                        direction=d,
                    ))
                    break
        return actions

    # ------------------------------------------------------------------
    # Grid helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _grid_shape(obs: Observation) -> tuple[int, int]:
        """Return ``(rows, cols)`` of ``obs.grid``; ``(0, 0)`` for an empty grid."""
        rows = len(obs.grid)
        cols = len(obs.grid[0]) if rows > 0 else 0
        return rows, cols

    def _scan_safe_cells(self, obs: Observation) -> list[tuple[int, int]]:
        """Return every deployable cell as ``(row, col)``, with no sampling.

        A cell is deployable when its fire state is UNBURNED, FIREBREAK or
        SUPPRESSED, its fuel is not WATER, and no crew is already on it.
        """
        safe_states = (FireState.UNBURNED, FireState.FIREBREAK, FireState.SUPPRESSED)
        return [
            (cell.row, cell.col)
            for row in obs.grid
            for cell in row
            if cell.fire_state in safe_states
            and cell.fuel_type != FuelType.WATER
            and not cell.crew_present
        ]

    def _subsample(self, cells: list[tuple[int, int]]) -> list[tuple[int, int]]:
        """Randomly cut *cells* down to at most ``SAFE_CELL_SAMPLE_SIZE`` entries.

        Lists that are already small enough are returned unchanged, and no
        RNG draw is made for them.
        """
        limit = self.SAFE_CELL_SAMPLE_SIZE
        if len(cells) <= limit:
            return cells
        indices = self.rng.choice(len(cells), limit, replace=False)
        return [cells[i] for i in indices]

    def _get_safe_cells(self, obs: Observation) -> list[tuple[int, int]]:
        """Get cells that are safe to deploy a crew to.

        Returns a random subsample (at most ``SAFE_CELL_SAMPLE_SIZE`` cells)
        of the cells returned by ``_scan_safe_cells``.
        """
        return self._subsample(self._scan_safe_cells(obs))

    def _get_valid_move_dirs(self, obs: Observation, row: int, col: int) -> list[Direction]:
        """Get directions a crew can move from (row, col).

        A direction is valid when the neighbouring cell is inside the grid,
        is not WATER, and has a fire state other than UNKNOWN.
        NOTE(review): unlike deploy targets, burning cells and cells already
        holding a crew are not excluded here -- confirm the environment
        handles such moves as intended.
        """
        rows, cols = self._grid_shape(obs)
        valid: list[Direction] = []
        for d in Direction:
            dr, dc = DIRECTION_DELTAS[d]
            nr, nc = row + dr, col + dc
            if not (0 <= nr < rows and 0 <= nc < cols):
                continue
            cell = obs.grid[nr][nc]
            if cell.fuel_type != FuelType.WATER and cell.fire_state != FireState.UNKNOWN:
                valid.append(d)
        return valid

    def _get_burning_cells(self, obs: Observation) -> list[tuple[int, int]]:
        """Get ``(row, col)`` of every cell whose fire state is BURNING."""
        return [
            (cell.row, cell.col)
            for row in obs.grid
            for cell in row
            if cell.fire_state == FireState.BURNING
        ]

    def _is_valid_firebreak_target(self, obs: Observation, row: int, col: int) -> bool:
        """Check if a cell is valid for firebreak construction.

        True only for in-bounds cells that are UNBURNED and whose fuel is
        neither WATER nor URBAN.
        """
        rows, cols = self._grid_shape(obs)
        if not (0 <= row < rows and 0 <= col < cols):
            return False
        cell = obs.grid[row][col]
        return (cell.fire_state == FireState.UNBURNED
                and cell.fuel_type not in (FuelType.WATER, FuelType.URBAN))