kbsooo commited on
Commit
1e2624a
·
verified ·
1 Parent(s): 3c2b820

Upload envs/fruitbox_env.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. envs/fruitbox_env.py +269 -0
envs/fruitbox_env.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Improved FruitBox environment that addresses several issues in the baseline:
3
+ - Optional backward board generation for solvable boards (high coverage).
4
+ - Illegal actions advance time and can carry a penalty; episodes end when no legal actions.
5
+ - Incremental action-mask updates so we do not rescan every rectangle on illegal steps.
6
+ - Reward can include zero-valued cells to encourage strategies that make use of already-cleared (zero) cells.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from dataclasses import dataclass
12
+ from typing import Dict, Optional, Tuple, List
13
+
14
+ import gymnasium as gym
15
+ import numpy as np
16
+ from gymnasium import spaces
17
+
18
+ from envs.backward_generator import BackwardBoardGenerator
19
+
20
+
21
@dataclass
class FruitBoxImprovedConfig:
    """Settings consumed by ``FruitBoxEnvImproved``.

    All fields have defaults, so ``FruitBoxImprovedConfig()`` yields a
    10 x 17 board that uses the backward generator and does not render.
    """

    # --- Board geometry ---
    rows: int = 10
    cols: int = 17

    # --- Rewards ---
    # Reward granted for each non-zero cell removed by a legal move.
    reward_per_cell: float = 1.0
    # Reward for each already-cleared (zero) cell covered by a legal move;
    # the default of 0.0 means such cells add nothing.
    reward_per_zero_cell: float = 0.0
    # Reward returned when the chosen action is illegal.
    illegal_action_reward: float = -1.0

    # Safety cap on episode length; the original game is limited by time,
    # not by a number of steps.
    max_steps: int = 500

    # --- Board generation ---
    use_backward_generator: bool = True
    # Only consulted when use_backward_generator is True.
    target_coverage: float = 0.95
    # Only consulted by the fallback random generator.
    enforce_total_sum_mod_10: bool = True

    # --- Rendering ---
    # Either "ansi" or None.
    render_mode: Optional[str] = None
37
+
38
+
39
class FruitBoxEnvImproved(gym.Env):
    """FruitBox puzzle exposed through the Gymnasium ``Env`` interface.

    The board is a ``rows x cols`` grid of integers in 0..9, where 0 means an
    empty cell. Each discrete action is the index of one axis-aligned
    rectangle ``(r1, c1, r2, c2)``. An action is legal when the values inside
    its rectangle sum to exactly 10; a legal action sets every cell of the
    rectangle to 0. Illegal actions leave the board untouched but still
    consume a step. The episode terminates when no legal action remains and
    is truncated once ``cfg.max_steps`` steps have been taken.

    ``info["action_mask"]`` is a boolean array of length ``n_actions`` marking
    the currently legal actions.
    """

    metadata = {"render_modes": ["ansi"], "render_fps": 30}

    def __init__(self, config: Optional[FruitBoxImprovedConfig] = None, **kwargs):
        """Build the action table and the caches used for fast mask updates.

        Args:
            config: Environment settings. When omitted, a config is built
                from ``kwargs`` (or from defaults if there are none).
            **kwargs: Config field overrides. When ``config`` is given they
                are applied to a copy, so the caller's object is not modified.

        Raises:
            TypeError: If ``kwargs`` names something that is not a config field.
            AssertionError: If ``rows`` or ``cols`` is not positive.
        """
        super().__init__()
        if config is None:
            cfg = FruitBoxImprovedConfig(**kwargs)
        elif kwargs:
            # Copy instead of setattr-ing on the caller's object: a config
            # shared between several environments must not be changed by one
            # of them. replace() also rejects unknown field names.
            from dataclasses import replace

            cfg = replace(config, **kwargs)
        else:
            cfg = config
        self.cfg: FruitBoxImprovedConfig = cfg
        # The Gymnasium Env API exposes the active render mode as an attribute.
        self.render_mode = cfg.render_mode

        R, C = self.cfg.rows, self.cfg.cols
        assert R > 0 and C > 0, "rows and cols must be positive"

        # Observation: integers 0..9 (0 means empty)
        self.observation_space = spaces.Box(low=0, high=9, shape=(R, C), dtype=np.int8)

        # Actions: every axis-aligned rectangle (r1, c1, r2, c2) with
        # r1 <= r2 and c1 <= c2. The enumeration order defines action indices.
        rects = [
            (r1, c1, r2, c2)
            for r1 in range(R)
            for r2 in range(r1, R)
            for c1 in range(C)
            for c2 in range(c1, C)
        ]
        self.rects: np.ndarray = np.array(rects, dtype=np.int32)  # (N, 4)
        self.n_actions: int = self.rects.shape[0]
        self.action_space = spaces.Discrete(self.n_actions)

        # Corner indices into the padded prefix-sum table, one entry per
        # rectangle, so all rectangle sums can be computed in one expression.
        self._idx_r1 = self.rects[:, 0]
        self._idx_c1 = self.rects[:, 1]
        self._idx_r2p = self.rects[:, 2] + 1  # r2 + 1
        self._idx_c2p = self.rects[:, 3] + 1  # c2 + 1

        # Flat cell index -> indices of rectangles containing that cell
        # (used for incremental updates after a clear).
        self._cell_to_rects: List[np.ndarray] = self._build_cell_to_rects()

        self.board: np.ndarray = np.zeros((R, C), dtype=np.int16)
        self.steps: int = 0
        self.np_random = np.random.default_rng()
        # Second value returned by the backward generator for the current
        # board; None when the board came from the random fallback.
        self._last_solution = None

        # Cached per-rectangle sums and legality mask
        self._rect_sums: np.ndarray = np.zeros(self.n_actions, dtype=np.int32)
        self._action_mask: np.ndarray = np.zeros(self.n_actions, dtype=bool)

    # ---------- utilities ----------
    def _build_cell_to_rects(self) -> List[np.ndarray]:
        """Return, for each flat cell index ``r * cols + c``, the rectangle indices covering it."""
        C = self.cfg.cols
        mapping: List[List[int]] = [[] for _ in range(self.cfg.rows * C)]
        for idx, (r1, c1, r2, c2) in enumerate(self.rects.tolist()):
            for r in range(r1, r2 + 1):
                base = r * C
                for c in range(c1, c2 + 1):
                    mapping[base + c].append(idx)
        return [np.array(indices, dtype=np.int32) for indices in mapping]

    @staticmethod
    def _padded_prefix_sums(arr: np.ndarray) -> np.ndarray:
        """Return (R+1, C+1) padded summed-area table."""
        R, C = arr.shape
        ps = np.zeros((R + 1, C + 1), dtype=np.int32)
        ps[1:, 1:] = arr.cumsum(axis=0).cumsum(axis=1)
        return ps

    def _rect_sums_vectorized(self, ps: np.ndarray) -> np.ndarray:
        """Compute sums for all rectangles using padded prefix sums (vectorized)."""
        return (
            ps[self._idx_r2p, self._idx_c2p]
            - ps[self._idx_r1, self._idx_c2p]
            - ps[self._idx_r2p, self._idx_c1]
            + ps[self._idx_r1, self._idx_c1]
        )

    def _gen_board(self) -> np.ndarray:
        """Generate a fresh board as an ``int16`` array of shape (rows, cols).

        With ``cfg.use_backward_generator`` the board comes from
        ``BackwardBoardGenerator.generate`` (seeded from this environment's
        RNG) and the second value it returns is kept in ``_last_solution``.
        Otherwise cells are drawn uniformly from 1..9 and, if
        ``cfg.enforce_total_sum_mod_10`` is set, random cells are increased
        (never above 9) until the board total is a multiple of 10 or 100
        attempts have been made.
        """
        R, C = self.cfg.rows, self.cfg.cols
        if self.cfg.use_backward_generator:
            gen_seed = int(self.np_random.integers(0, 1_000_000_000))
            generator = BackwardBoardGenerator(rows=R, cols=C, seed=gen_seed)
            board, solution = generator.generate(target_coverage=self.cfg.target_coverage)
            self._last_solution = solution
            return board.astype(np.int16, copy=False)

        # Fallback: random board with sum % 10 adjusted. No solution is known
        # for such a board, so drop any solution left from an earlier one.
        self._last_solution = None
        low, high = 1, 9
        board = self.np_random.integers(low, high + 1, size=(R, C), dtype=np.int16)
        if self.cfg.enforce_total_sum_mod_10:
            delta = int((10 - (board.sum() % 10)) % 10)
            tries = 0
            while delta > 0 and tries < 100:
                r = int(self.np_random.integers(0, R))
                c = int(self.np_random.integers(0, C))
                inc = min(9 - int(board[r, c]), delta)
                if inc > 0:
                    board[r, c] += inc
                    delta -= inc
                tries += 1
        return board

    def _compute_full_mask(self, board: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Compute sums and mask for all rectangles."""
        sums = self._rect_sums_vectorized(self._padded_prefix_sums(board))
        return sums.astype(np.int32, copy=False), sums == 10

    def _update_after_clear(self, r1: int, c1: int, r2: int, c2: int, cleared_vals: np.ndarray):
        """
        Incrementally update rectangle sums/mask after setting a region to zero.
        cleared_vals is the pre-zeroing values of shape (r2-r1+1, c2-c1+1).

        Only rectangles that contain at least one cell whose value actually
        changed (non-zero before the clear) are touched.
        """
        cell_rows, cell_cols = np.nonzero(cleared_vals)
        if cell_rows.size == 0:
            return

        C = self.cfg.cols
        rect_groups = [
            self._cell_to_rects[(int(r1) + dr) * C + (int(c1) + dc)]
            for dr, dc in zip(cell_rows.tolist(), cell_cols.tolist())
        ]
        touched = np.concatenate(rect_groups)
        # One removed value per (cell, rectangle) pair, aligned with `touched`.
        amounts = np.repeat(
            cleared_vals[cell_rows, cell_cols].astype(np.int32),
            [group.size for group in rect_groups],
        )
        # ufunc.at accumulates over repeated indices, which plain fancy-index
        # assignment would not: a rectangle covering several cleared cells
        # appears in `touched` once per cell.
        np.subtract.at(self._rect_sums, touched, amounts)

        affected = np.unique(touched)
        self._action_mask[affected] = self._rect_sums[affected] == 10

    def _observation(self) -> np.ndarray:
        """Return the board as a new ``int8`` array clipped to 0..9."""
        return self.board.clip(0, 9).astype(np.int8, copy=False)

    # ---------- Gymnasium API ----------
    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None) -> Tuple[np.ndarray, dict]:
        """Start a new episode on a freshly generated board.

        Args:
            seed: When given, the environment RNG is re-created from it.
            options: Accepted for API compatibility; not used.

        Returns:
            ``(obs, info)`` where ``info["action_mask"]`` is a copy of the
            current legality mask.
        """
        if seed is not None:
            self.np_random = np.random.default_rng(seed)
        self.steps = 0
        self.board = self._gen_board().astype(np.int16, copy=False)
        self._rect_sums, self._action_mask = self._compute_full_mask(self.board)
        # Hand out a copy: step() updates the internal mask in place, and a
        # caller that stores this info must not see it change afterwards.
        info = {"action_mask": self._action_mask.copy()}
        return self._observation(), info

    def step(self, action: int):
        """Apply one action.

        An out-of-range action, or one whose rectangle does not sum to 10, is
        illegal: the board is unchanged, the step counter still advances and
        the reward is ``cfg.illegal_action_reward``. A legal action zeroes its
        rectangle and is rewarded per non-zero and per zero cell covered.

        Returns:
            ``(obs, reward, terminated, truncated, info)``; ``info`` holds a
            copy of the updated ``"action_mask"`` and an ``"illegal_action"``
            flag.
        """
        assert isinstance(action, (int, np.integer)), "action must be an integer index"

        is_legal = 0 <= action < self.n_actions and bool(self._action_mask[action])
        # Time advances whether or not the move was legal.
        self.steps += 1

        if is_legal:
            r1, c1, r2, c2 = (int(v) for v in self.rects[action])
            cleared_vals = self.board[r1 : r2 + 1, c1 : c2 + 1].copy()
            cells_nonzero = int(np.count_nonzero(cleared_vals > 0))
            cells_zero = cleared_vals.size - cells_nonzero

            self.board[r1 : r2 + 1, c1 : c2 + 1] = 0
            reward = (
                self.cfg.reward_per_cell * float(cells_nonzero)
                + self.cfg.reward_per_zero_cell * float(cells_zero)
            )
            self._update_after_clear(r1, c1, r2, c2, cleared_vals)
        else:
            reward = float(self.cfg.illegal_action_reward)

        terminated = not self._action_mask.any()
        truncated = self.steps >= self.cfg.max_steps
        # Copy for the same reason as in reset(): the internal mask is
        # mutated in place by later steps.
        info = {"action_mask": self._action_mask.copy(), "illegal_action": not is_legal}
        return self._observation(), float(reward), terminated, truncated, info

    # ---------- helpers ----------
    def legal_actions(self) -> np.ndarray:
        """Return the indices of all currently legal actions."""
        return np.nonzero(self._action_mask)[0]

    def sample_valid_action(self) -> Optional[int]:
        """Return a random legal action index, or None if there is none."""
        legal = self.legal_actions()
        if legal.size == 0:
            return None
        return int(self.np_random.choice(legal))

    # ---------- rendering ----------
    def render(self):
        """Return a text drawing of the board, or None unless render mode is "ansi"."""
        if self.cfg.render_mode != "ansi":
            return
        # A board row is "| " + digits joined by single spaces + " |", i.e.
        # 2 * cols + 3 characters, so the border needs 2 * cols + 1 dashes
        # between its corner "+" signs to line up with the rows.
        border = "+" + "-" * (2 * self.cfg.cols + 1) + "+"
        lines = [f"Steps={self.steps}", border]
        for r in range(self.cfg.rows):
            row_vals = " ".join(f"{int(v):1d}" for v in self.board[r])
            lines.append(f"| {row_vals} |")
        lines.append(border)
        return "\n".join(lines)

    def close(self):
        """Nothing to release."""
        pass
250
+
251
+
252
# ---- quick smoke test ----
# Plays one episode by always taking the lowest-indexed legal action and
# printing the board after every step.
if __name__ == "__main__":
    env = FruitBoxEnvImproved(FruitBoxImprovedConfig(render_mode="ansi"))
    _, info = env.reset(seed=0)
    print("Initial legal actions:", int(np.count_nonzero(info["action_mask"])))

    episode_return = 0.0
    done = False
    # Stop as soon as the episode ends or the mask shows no legal move.
    while not done and info["action_mask"].any():
        first_legal = int(np.flatnonzero(info["action_mask"])[0])
        _, step_reward, terminated, truncated, info = env.step(first_legal)
        episode_return += step_reward
        if env.cfg.render_mode == "ansi":
            print(env.render())
        done = terminated or truncated

    print("Episode total reward:", episode_return)