File size: 12,178 Bytes
cf58b05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
import numpy as np
import torch
import bulletchess
from typing import List, Tuple, Optional
from .vocab import policy_index

# Reverse lookup table built once at import time: policy move string -> its
# position in policy_index, so mask construction uses O(1) dict lookups.
policy_to_idx = {move: slot for slot, move in enumerate(policy_index)}


def _board_to_12_piece_planes(board: bulletchess.Board) -> np.ndarray:
    """Render the piece placement of ``board`` as twelve binary 8x8 planes.

    Plane order is pawn, knight, bishop, rook, queen, king for white,
    followed by the same six piece types for black.

    Returns:
        A float32 array of shape (8, 8, 12) indexed as
        ``[square.index() // 8, square.index() % 8, plane]``, holding 1.0
        where the plane's piece occupies the square and 0.0 elsewhere.
    """
    kinds = (
        bulletchess.PAWN,
        bulletchess.KNIGHT,
        bulletchess.BISHOP,
        bulletchess.ROOK,
        bulletchess.QUEEN,
        bulletchess.KING,
    )
    sides = (bulletchess.WHITE, bulletchess.BLACK)

    # Fill plane-major (12, 8, 8) first; reorder to (8, 8, 12) on return.
    stack = np.zeros((len(sides) * len(kinds), 8, 8), dtype=np.float32)
    plane_no = 0
    for side in sides:
        for kind in kinds:
            # board[color, piece_type] yields the occupied squares.
            for sq in board[side, kind]:
                row, col = divmod(sq.index(), 8)
                stack[plane_no, row, col] = 1.0
            plane_no += 1
    return np.transpose(stack, (1, 2, 0))


def _castling_planes(board: bulletchess.Board) -> np.ndarray:
    """Return four constant 8x8 planes flagging the castling rights of ``board``.

    Plane order is [white queenside, white kingside, black queenside,
    black kingside]; it must match what the existing model expects
    (via ustotheirs).

    Returns:
        A float32 array of shape (4, 8, 8); each plane is all 1.0 when the
        corresponding right is present and all 0.0 otherwise.
    """
    rights = board.castling_rights
    ordered_rights = (
        bulletchess.WHITE_QUEENSIDE,
        bulletchess.WHITE_KINGSIDE,
        bulletchess.BLACK_QUEENSIDE,
        bulletchess.BLACK_KINGSIDE,
    )
    flags = [1.0 if right in rights else 0.0 for right in ordered_rights]
    return np.stack(
        [np.full((8, 8), flag, dtype=np.float32) for flag in flags],
        axis=0,
    )


def _mirror_board(board: bulletchess.Board) -> bulletchess.Board:
    """
    Build a colour-swapped, rank-flipped copy of ``board``.

    Every piece moves to the same file on the opposite rank (1<->8, 2<->7,
    ...) and changes colour. Castling rights are exchanged between white and
    black, the side to move is inverted, the en passant square (if any) is
    rank-flipped, and both move counters are copied unchanged.
    """
    # With index = rank * 8 + file, XOR with 56 (0b111000) maps rank -> 7 - rank
    # and leaves the file bits untouched.
    rank_flip = 56

    flipped = bulletchess.Board.empty()

    # Pieces: same file, opposite rank, opposite colour.
    for sq in bulletchess.SQUARES:
        occupant = board[sq]
        if occupant is None:
            continue
        target = bulletchess.SQUARES[sq.index() ^ rank_flip]
        flipped[target] = bulletchess.Piece(occupant.color.opposite, occupant.piece_type)

    # Castling rights: each right is handed to the other colour, same wing.
    rights = board.castling_rights
    handover = (
        (bulletchess.WHITE_KINGSIDE, bulletchess.BLACK_KINGSIDE),
        (bulletchess.WHITE_QUEENSIDE, bulletchess.BLACK_QUEENSIDE),
        (bulletchess.BLACK_KINGSIDE, bulletchess.WHITE_KINGSIDE),
        (bulletchess.BLACK_QUEENSIDE, bulletchess.WHITE_QUEENSIDE),
    )
    swapped = [gained for held, gained in handover if held in rights]
    flipped.castling_rights = (
        bulletchess.CastlingRights(swapped) if swapped else bulletchess.NO_CASTLING
    )

    # Side to move is inverted along with the colours.
    flipped.turn = board.turn.opposite

    # En passant target follows the same rank flip as the pieces.
    ep_square = board.en_passant_square
    if ep_square is not None:
        flipped.en_passant_square = bulletchess.SQUARES[ep_square.index() ^ rank_flip]

    # Move counters carry over as-is.
    flipped.halfmove_clock = board.halfmove_clock
    flipped.fullmove_number = board.fullmove_number

    return flipped


def _build_snapshots(board: bulletchess.Board) -> List[Optional[bulletchess.Board]]:
    """Collect the current position and up to seven preceding positions.

    Works on a copy of ``board``, so the caller's board is not modified.

    Returns:
        A list of exactly 8 entries. Entry 0 is a copy of the current
        position, entry ``k`` the position ``k`` plies earlier. Entries for
        plies that could not be undone are ``None``.
    """
    scratch = board.copy()
    snaps: List[Optional[bulletchess.Board]] = [scratch.copy()]
    for _ in range(7):
        try:
            # Relies on undo() raising once the move history is exhausted
            # (NOTE(review): confirm the exact exception type against the
            # bulletchess documentation).
            scratch.undo()
        except (IndexError, AttributeError):
            # Nothing further back to rewind to; stop instead of raising and
            # catching the same error for every remaining slot.
            break
        snaps.append(scratch.copy())
    # Pad so callers can always index 8 history slots.
    snaps.extend([None] * (8 - len(snaps)))
    return snaps


def _assemble_input_tensor(
    snapshots: List[Optional["bulletchess.Board"]],
    black_to_move: bool,
) -> torch.Tensor:
    """Stack history and flag planes into the (1, 112, 8, 8) network input.

    Args:
        snapshots: Exactly 8 history entries, current position first, already
            mirrored when black is to move. ``None`` marks a missing ply and
            is encoded as all-zero piece planes.
        black_to_move: Whether black is to move in the real (unmirrored)
            position; written to the side-to-move plane.

    Returns:
        A float32 tensor of shape (1, 112, 8, 8). Channels are 8 groups of
        12 piece planes plus 1 blank plane (104 total), followed by WQ, WK,
        BQ, BK castling planes, the black-to-move plane, two blank planes
        and an all-ones plane.

    Raises:
        ValueError: If the current position (``snapshots[0]``) is missing.
    """
    current = snapshots[0]
    if current is None:
        raise ValueError("snapshots[0] (the current position) must not be None")

    # 8 groups: each 12 piece planes + 1 blank = 13; total 104. Built in
    # (8, 8, C) layout and permuted at the end.
    channels: List[np.ndarray] = []
    for snap in snapshots:
        if snap is not None:
            channels.append(_board_to_12_piece_planes(snap))  # (8,8,12)
        else:
            channels.append(np.zeros((8, 8, 12), dtype=np.float32))
        # blank plane closing each history group
        channels.append(np.zeros((8, 8, 1), dtype=np.float32))

    # Special planes: WQ, WK, BQ, BK, is_black_to_move, blank, blank, ones
    specials = np.concatenate(
        [
            _castling_planes(current),  # (4,8,8)
            np.full((1, 8, 8), 1.0 if black_to_move else 0.0, dtype=np.float32),
            np.zeros((1, 8, 8), dtype=np.float32),
            np.zeros((1, 8, 8), dtype=np.float32),
            np.ones((1, 8, 8), dtype=np.float32),
        ],
        axis=0,
    )  # (8,8,8) as (plane, rank, file)

    stacked = np.concatenate(channels, axis=2)  # (8,8,104)
    specials_hwk = np.transpose(specials, (1, 2, 0))  # (8,8,8)
    final_hwk = np.concatenate([stacked, specials_hwk], axis=2)  # (8,8,112)

    # Convert to tensor (1,112,8,8)
    return torch.from_numpy(final_hwk).permute(2, 0, 1).unsqueeze(0).float()


def _legal_move_mask(white_view: "bulletchess.Board") -> np.ndarray:
    """Build the additive policy mask for a position seen with white to move.

    Args:
        white_view: The position from the side-to-move's perspective, i.e.
            the original board when white is to move, or its mirror when
            black is to move. White is therefore always the side to move.

    Returns:
        A float32 array of length 1858 holding 0 at the ``policy_to_idx``
        slot of every legal move and -1000 everywhere else.
    """
    mask = np.full(1858, -1000.0, dtype=np.float32)

    # Knight promotions appear in the policy vocabulary without a suffix, so
    # the trailing 'n' is dropped; every other move is used verbatim.
    legal_uci = set()
    for move in white_view.legal_moves():
        uci = move.uci()
        legal_uci.add(uci[:-1] if uci.endswith("n") else uci)

    for uci in legal_uci:
        slot = policy_to_idx.get(uci)
        if slot is not None:
            mask[slot] = 0.0

    # Castling is additionally exposed as "king moves onto the rook square"
    # (e1h1 / e1a1). The alias is only valid when e1g1 / e1c1 really is a
    # castling move, i.e. the piece on e1 is the king; a rook or queen sliding
    # e1->g1 must not unmask e1h1. Black aliases (e8h8 / e8a8) are not needed:
    # white is always to move here, so e8g8 / e8c8 can never be castling.
    e1_index = 4  # rank 0 * 8 + file 4, same indexing as _mirror_board
    e1_piece = white_view[bulletchess.SQUARES[e1_index]]
    if e1_piece is not None and e1_piece.piece_type == bulletchess.KING:
        for standard, king_to_rook in (("e1g1", "e1h1"), ("e1c1", "e1a1")):
            if standard in legal_uci:
                slot = policy_to_idx.get(king_to_rook)
                if slot is not None:
                    mask[slot] = 0.0

    return mask


def _encode_position(
    board: "bulletchess.Board",
    snapshots: List[Optional["bulletchess.Board"]],
) -> Tuple[torch.Tensor, np.ndarray]:
    """Encode ``board`` and its history into (input tensor, legal-move mask)."""
    # Always encode from white's perspective; mirror all snapshots if black to move
    black_to_move = board.turn == bulletchess.BLACK
    if black_to_move:
        snapshots = [_mirror_board(s) if s is not None else None for s in snapshots]

    tensor = _assemble_input_tensor(snapshots, black_to_move)
    # snapshots[0] is now the current position with white to move, which is
    # the perspective the mask needs; no second mirroring required.
    mask = _legal_move_mask(snapshots[0])
    return tensor, mask


def encode_moves_to_tensor(uci_moves: List[str], starting_fen: Optional[str] = None) -> Tuple[torch.Tensor, np.ndarray]:
    """Encode the position reached by playing ``uci_moves``.

    Args:
        uci_moves: Moves in UCI notation, applied in order.
        starting_fen: FEN of the position the moves start from; the default
            ``bulletchess.Board()`` is used when ``None``.

    Returns:
        ``(tensor, mask)``: a float32 tensor of shape (1, 112, 8, 8) that
        includes up to seven previous plies of history, and a float32 array
        of length 1858 with 0 for legal moves and -1000 otherwise. Both are
        expressed from the side-to-move's perspective (mirrored when black
        is to move).
    """
    board = bulletchess.Board.from_fen(starting_fen) if starting_fen is not None else bulletchess.Board()
    for mv in uci_moves:
        board.apply(bulletchess.Move.from_uci(mv))

    # Build history snapshots (current first)
    return _encode_position(board, _build_snapshots(board))


def encode_fen_to_tensor(fen: str) -> Tuple[torch.Tensor, np.ndarray]:
    """Encode a single position given as FEN, without any move history.

    Args:
        fen: The position in FEN notation.

    Returns:
        ``(tensor, mask)`` in the same format as ``encode_moves_to_tensor``;
        the seven history groups are all zeros.
    """
    board = bulletchess.Board.from_fen(fen)

    # History: only current snapshot, others are zeros
    snapshots: List[Optional["bulletchess.Board"]] = [board.copy()] + [None] * 7
    return _encode_position(board, snapshots)