File size: 11,600 Bytes
fe115bf
 
 
 
86f507b
fe115bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ced530
fe115bf
3a5b76e
5a39470
2038b66
fe115bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86f507b
fe115bf
86f507b
fe115bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a39470
9dd7378
fe115bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49913c2
 
fe115bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9dd7378
fe115bf
 
 
 
 
 
 
 
9dd7378
fe115bf
 
 
 
 
 
49913c2
 
 
 
fe115bf
5a39470
9dd7378
5a39470
 
86f507b
9dd7378
 
 
5a39470
 
fe115bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
"""
Tetris Environment for OpenEnv.
Full game logic with combo scoring reward system.
"""
__version__ = "0.6.0"  # configurable height_breach_penalty per instance

import random
import copy
from typing import Optional

# Standard Tetris pieces in their spawn orientation (other orientations are
# produced at runtime by rotate_cw / rotate_ccw).
# Each shape is a list of rows; 1 marks a filled cell, 0 an empty cell of the
# bounding box.
PIECES = {
    "I": [[1, 1, 1, 1]],
    "O": [[1, 1],
          [1, 1]],
    "T": [[0, 1, 0],
          [1, 1, 1]],
    "S": [[0, 1, 1],
          [1, 1, 0]],
    "Z": [[1, 1, 0],
          [0, 1, 1]],
    "L": [[1, 0],
          [1, 0],
          [1, 1]],
    "J": [[0, 1],
          [0, 1],
          [1, 1]],
}

# Board dimensions in cells (columns x rows).
BOARD_WIDTH = 10
BOARD_HEIGHT = 20

# Combo scoring: more lines cleared at once = disproportionately higher reward.
# Keyed by the number of lines cleared by a single locked piece. TetrisEnv.step
# falls back to ``lines * 400`` for any count not listed here.
LINE_REWARDS = {
    1: 100,
    2: 300,
    3: 700,
    4: 1500,  # "Tetris!" — the dream
}

# Reward-shaping constants used by TetrisEnv.step.
# STEP_PENALTY: base reward every step starts from.
STEP_PENALTY = -0.1
# HOLE_PENALTY_MULT: multiplied by the number of holes newly created in a step.
HOLE_PENALTY_MULT = -5
# GAME_OVER_PENALTY: added to the reward of the step that ends the game.
GAME_OVER_PENALTY = -50
# HEIGHT_BREACH_THRESHOLD: stack heights up to this value are never penalized.
HEIGHT_BREACH_THRESHOLD = 4
# HEIGHT_BREACH_PENALTY: default for TetrisEnv's height_breach_penalty argument.
HEIGHT_BREACH_PENALTY = -50  # per level above threshold, decays with pieces_locked


def rotate_cw(piece: list[list[int]]) -> list[list[int]]:
    """Return a new grid equal to ``piece`` turned a quarter-turn clockwise.

    The input is not modified. The width is taken from the first row, so
    the result has one row per input column and one column per input row.
    """
    height = len(piece)
    width = len(piece[0])
    # Output row ``col`` is input column ``col`` read from the bottom row up.
    return [
        [piece[height - 1 - k][col] for k in range(height)]
        for col in range(width)
    ]


def rotate_ccw(piece: list[list[int]]) -> list[list[int]]:
    """Return a new grid equal to ``piece`` turned a quarter-turn counter-clockwise.

    The input is not modified. The width is taken from the first row, so
    the result has one row per input column and one column per input row.
    """
    height = len(piece)
    width = len(piece[0])
    # Output row ``k`` is the k-th input column counted from the right,
    # read from the top row down.
    return [
        [piece[r][width - 1 - k] for r in range(height)]
        for k in range(width)
    ]


class TetrisEnv:
    """Tetris game state together with the reward shaping applied per step.

    The board is a list of ``BOARD_HEIGHT`` rows of ``BOARD_WIDTH`` cells,
    row 0 at the top; a cell is 0 (empty) or 1 (locked block). Each call to
    :meth:`step` applies one action followed by gravity and returns a dict
    holding a text rendering of the board, piece information, statistics,
    the step ``reward`` and the ``done`` flag.
    """

    # How much the per-level height-breach penalty moves toward zero for
    # every piece that has been locked so far.
    HEIGHT_BREACH_DECAY = 5

    def __init__(self, seed: Optional[int] = None, height_breach_penalty: float = HEIGHT_BREACH_PENALTY):
        """Create an environment and start a fresh game.

        Args:
            seed: Seed for the environment's private piece RNG; ``None``
                lets ``random.Random`` choose its own seed.
            height_breach_penalty: Per-level reward adjustment (normally
                negative) applied the first time the stack exceeds a new
                height above ``HEIGHT_BREACH_THRESHOLD``.
        """
        self.rng = random.Random(seed)
        self.height_breach_penalty = height_breach_penalty
        self.reset_state()

    def reset_state(self, seed: Optional[int] = None) -> None:
        """Clear the board and all counters, then deal the first two pieces.

        Args:
            seed: When given, the RNG is replaced by one seeded with this
                value; otherwise the existing RNG continues from its
                current state.
        """
        if seed is not None:
            self.rng = random.Random(seed)
        self.board = [[0] * BOARD_WIDTH for _ in range(BOARD_HEIGHT)]
        self.score = 0
        self.total_lines = 0
        self.steps = 0
        self.done = False
        self.current_piece = None
        self.current_piece_name = ""
        self.current_x = 0
        self.current_y = 0
        self.next_piece_name = ""
        self.next_piece = None
        # Highest stack height that has already been penalized (or is exempt).
        self.max_penalized_height = HEIGHT_BREACH_THRESHOLD
        self.pieces_locked = 0
        # The first call only fills the "next" slot; the second promotes it
        # to the current piece and draws a new "next".
        self._spawn_next()
        self._spawn_next()

    def _spawn_next(self) -> None:
        """Promote the next piece to current and draw a new next piece.

        The promoted piece is placed on the top row, horizontally centred.
        If it collides with locked blocks there, ``self.done`` is set.
        """
        self.current_piece = self.next_piece
        self.current_piece_name = self.next_piece_name
        self.next_piece_name = self.rng.choice(list(PIECES))
        # Copy so rotations never touch the shared PIECES templates.
        self.next_piece = copy.deepcopy(PIECES[self.next_piece_name])

        if self.current_piece is None:
            return

        piece_width = len(self.current_piece[0])
        self.current_x = BOARD_WIDTH // 2 - piece_width // 2
        self.current_y = 0
        if not self._is_valid_position(self.current_piece, self.current_x, self.current_y):
            self.done = True

    def _is_valid_position(self, piece: list[list[int]], x: int, y: int) -> bool:
        """Return True if ``piece`` fits with its top-left corner at (x, y).

        Only filled cells of the piece are tested; each must lie inside the
        board and on an empty board cell.
        """
        for row_idx, row in enumerate(piece):
            for col_idx, cell in enumerate(row):
                if cell == 0:
                    continue
                board_x = x + col_idx
                board_y = y + row_idx
                if board_x < 0 or board_x >= BOARD_WIDTH:
                    return False
                if board_y < 0 or board_y >= BOARD_HEIGHT:
                    return False
                if self.board[board_y][board_x] != 0:
                    return False
        return True

    def _lock_piece(self) -> None:
        """Write the current piece's filled cells into the board as 1s.

        Cells that fall outside the board are skipped.
        """
        for row_idx, row in enumerate(self.current_piece):
            for col_idx, cell in enumerate(row):
                if not cell:
                    continue
                bx = self.current_x + col_idx
                by = self.current_y + row_idx
                if 0 <= by < BOARD_HEIGHT and 0 <= bx < BOARD_WIDTH:
                    self.board[by][bx] = 1

    def _clear_lines(self) -> int:
        """Remove every completely filled row and refill from the top.

        Returns:
            The number of rows removed; ``self.total_lines`` grows by the
            same amount.
        """
        remaining = []
        lines_cleared = 0
        for row in self.board:
            if all(cell == 1 for cell in row):
                lines_cleared += 1
            else:
                remaining.append(row)

        # Pad with fresh empty rows at the top so the board keeps its height.
        missing = BOARD_HEIGHT - len(remaining)
        padding = [[0] * BOARD_WIDTH for _ in range(missing)]

        self.board = padding + remaining
        self.total_lines += lines_cleared
        return lines_cleared

    def _count_holes(self) -> int:
        """Count empty cells that have at least one locked block above them."""
        holes = 0
        for col in range(BOARD_WIDTH):
            found_block = False
            for row in range(BOARD_HEIGHT):
                if self.board[row][col] == 1:
                    found_block = True
                elif found_block and self.board[row][col] == 0:
                    holes += 1
        return holes

    def _max_height(self) -> int:
        """Return the height of the tallest column (0 for an empty board)."""
        # Rows are scanned top-down, so the first non-empty row is the peak.
        for row in range(BOARD_HEIGHT):
            if any(cell == 1 for cell in self.board[row]):
                return BOARD_HEIGHT - row
        return 0

    def _drop_piece(self) -> None:
        """Hard drop: move the current piece down until it would collide."""
        while self._is_valid_position(self.current_piece, self.current_x, self.current_y + 1):
            self.current_y += 1

    def _apply_move(self, action: str) -> None:
        """Apply a non-drop action to the current piece if the result fits.

        A move or rotation that would collide is ignored. "noop" and any
        unrecognised action name leave the piece where it is.
        """
        piece, x, y = self.current_piece, self.current_x, self.current_y
        if action == "left":
            x -= 1
        elif action == "right":
            x += 1
        elif action == "down":
            y += 1
        elif action == "rotate_cw":
            piece = rotate_cw(piece)
        elif action == "rotate_ccw":
            piece = rotate_ccw(piece)
        else:
            return
        if self._is_valid_position(piece, x, y):
            self.current_piece, self.current_x, self.current_y = piece, x, y

    def _lock_and_spawn(self) -> int:
        """Lock the current piece, clear full lines and spawn the next piece.

        Returns:
            The line-clear reward earned by this lock (0 when no line was
            completed). The same amount is added to ``self.score``.
        """
        self._lock_piece()
        self.pieces_locked += 1
        lines = self._clear_lines()
        line_reward = LINE_REWARDS.get(lines, lines * 400) if lines > 0 else 0
        self.score += line_reward
        self._spawn_next()
        return line_reward

    def _height_breach_reward(self) -> float:
        """Return the one-time penalty for newly breached height levels.

        Every level above ``self.max_penalized_height`` is charged once.
        The per-level amount is ``height_breach_penalty`` moved toward zero
        by ``HEIGHT_BREACH_DECAY`` per locked piece and capped at 0.
        ``pieces_locked`` already includes the piece that just locked, so
        with the default -50 penalty the sequence is: 1 piece locked → -45,
        9 pieces → -5, 10 or more → 0.
        """
        current_height = self._max_height()
        if current_height <= self.max_penalized_height:
            return 0

        penalty = 0
        penalty_per_level = min(
            0, self.height_breach_penalty + self.HEIGHT_BREACH_DECAY * self.pieces_locked
        )
        if penalty_per_level < 0:
            new_levels = current_height - self.max_penalized_height
            penalty = penalty_per_level * new_levels
        # Record the new high-water mark even when the penalty has decayed to
        # zero, so these levels are never charged later.
        self.max_penalized_height = current_height
        return penalty

    def get_board_with_piece(self) -> list[list[int]]:
        """Return a copy of the board with the current piece overlaid as 2s.

        The piece is omitted once the game is over.
        """
        display = copy.deepcopy(self.board)
        if self.current_piece and not self.done:
            for row_idx, row in enumerate(self.current_piece):
                for col_idx, cell in enumerate(row):
                    if not cell:
                        continue
                    bx = self.current_x + col_idx
                    by = self.current_y + row_idx
                    if 0 <= by < BOARD_HEIGHT and 0 <= bx < BOARD_WIDTH:
                        display[by][bx] = 2  # 2 = current piece
        return display

    def board_to_text(self) -> str:
        """Render the board as text for LLM observation.

        "." is an empty cell, "#" a locked block and "@" the current piece;
        the grid is framed by a "+---+" border and "|" side walls.
        """
        symbols = {0: ".", 1: "#", 2: "@"}
        border = "+" + "-" * BOARD_WIDTH + "+"
        lines = [border]
        for row in self.get_board_with_piece():
            lines.append("|" + "".join(symbols[c] for c in row) + "|")
        lines.append(border)
        return "\n".join(lines)

    def piece_to_text(self, piece: list[list[int]]) -> str:
        """Render a piece as text, "#" for filled cells and "." for empty."""
        return "\n".join("".join("#" if c else "." for c in row) for row in piece)

    def step(self, action: str) -> dict:
        """
        Execute one action. Valid actions:
        - "left": move piece left
        - "right": move piece right
        - "rotate_cw": rotate clockwise
        - "rotate_ccw": rotate counter-clockwise
        - "drop": hard drop and lock
        - "down": soft drop one row (gravity then moves it one more row)
        - "noop": do nothing (piece falls one row)

        The action is stripped and lower-cased first; any other string is
        treated like "noop". After every action except "drop", gravity
        moves the piece down one row, or locks it if it cannot move.

        Calling this after the game is over changes nothing and returns a
        result with reward 0.

        Returns dict with: observation, reward, done, info
        """
        if self.done:
            return self._make_result(0)

        self.steps += 1
        reward = STEP_PENALTY  # base penalty per step
        holes_before = self._count_holes()

        action = action.strip().lower()
        if action == "drop":
            # Hard drop locks immediately; no separate gravity tick.
            self._drop_piece()
            reward += self._lock_and_spawn()
        else:
            self._apply_move(action)
            # Gravity: fall one row, or lock if the piece is resting.
            if self._is_valid_position(self.current_piece, self.current_x, self.current_y + 1):
                self.current_y += 1
            else:
                reward += self._lock_and_spawn()

        # Penalty only for NEW holes created by this step
        new_holes = self._count_holes() - holes_before
        if new_holes > 0:
            reward += HOLE_PENALTY_MULT * new_holes

        reward += self._height_breach_reward()

        # self.done can only have been set by the spawn that followed a lock.
        if self.done:
            reward += GAME_OVER_PENALTY

        return self._make_result(reward)

    def _make_result(self, reward: float) -> dict:
        """Build the observation/result dict for the current state.

        Args:
            reward: Reward to report under the "reward" key.
        """
        return {
            "board": self.board_to_text(),
            "current_piece": self.current_piece_name,
            "current_piece_shape": self.piece_to_text(self.current_piece) if self.current_piece else "",
            "next_piece": self.next_piece_name,
            "next_piece_shape": self.piece_to_text(self.next_piece) if self.next_piece else "",
            "piece_x": self.current_x,
            "piece_y": self.current_y,
            "score": self.score,
            "total_lines": self.total_lines,
            "steps": self.steps,
            "max_height": self._max_height(),
            "holes": self._count_holes(),
            "reward": reward,
            "done": self.done,
        }

    def reset(self, seed: Optional[int] = None) -> dict:
        """Reset the environment and return the initial observation.

        Args:
            seed: Forwarded to :meth:`reset_state`; reseeds the RNG when
                not ``None``.
        """
        self.reset_state(seed)
        return self._make_result(0)