Maxlegrec commited on
Commit
cf58b05
·
verified ·
1 Parent(s): 8378d33

Upload encoding_simple.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. encoding_simple.py +317 -0
encoding_simple.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import bulletchess
4
+ from typing import List, Tuple, Optional
5
+ from .vocab import policy_index
6
+
7
+ # Precompute policy string to index mapping for O(1) lookups
8
+ policy_to_idx = {u: i for i, u in enumerate(policy_index)}
9
+
10
+
11
+ def _board_to_12_piece_planes(board: bulletchess.Board) -> np.ndarray:
12
+ piece_types = [bulletchess.PAWN, bulletchess.KNIGHT, bulletchess.BISHOP, bulletchess.ROOK, bulletchess.QUEEN, bulletchess.KING]
13
+ piece_colors = [bulletchess.WHITE, bulletchess.BLACK]
14
+
15
+ planes = []
16
+ for color in piece_colors:
17
+ for piece_type in piece_types:
18
+ mask = np.zeros((8, 8), dtype=np.float32)
19
+ # Use board[color, piece_type] to get Bitboard, then iterate over squares
20
+ bitboard = board[color, piece_type]
21
+ for square in bitboard:
22
+ # In bulletchess, squares have an index() method that returns 0-63
23
+ square_idx = square.index()
24
+ rank = square_idx // 8
25
+ file = square_idx % 8
26
+ mask[rank][file] = 1.0
27
+ planes.append(mask)
28
+ # Shape (8,8,12)
29
+ return np.transpose(np.array(planes, dtype=np.float32), (1, 2, 0))
30
+
31
+
32
+ def _castling_planes(board: bulletchess.Board) -> np.ndarray:
33
+ # Order must match existing model expectation via ustotheirs:
34
+ # [WQ, WK, BQ, BK]
35
+ wq = 1.0 if bulletchess.WHITE_QUEENSIDE in board.castling_rights else 0.0
36
+ wk = 1.0 if bulletchess.WHITE_KINGSIDE in board.castling_rights else 0.0
37
+ bq = 1.0 if bulletchess.BLACK_QUEENSIDE in board.castling_rights else 0.0
38
+ bk = 1.0 if bulletchess.BLACK_KINGSIDE in board.castling_rights else 0.0
39
+ planes = [
40
+ np.full((8, 8), wq, dtype=np.float32),
41
+ np.full((8, 8), wk, dtype=np.float32),
42
+ np.full((8, 8), bq, dtype=np.float32),
43
+ np.full((8, 8), bk, dtype=np.float32),
44
+ ]
45
+ return np.stack(planes, axis=0) # (4,8,8)
46
+
47
+
48
+ def _mirror_board(board: bulletchess.Board) -> bulletchess.Board:
49
+ """
50
+ Fast mirror implementation for bulletchess.Board.
51
+ Mirrors the board (flips ranks 1<->8, 2<->7, etc.) and flips colors.
52
+ """
53
+ # Create empty board
54
+ mirrored = bulletchess.Board.empty()
55
+
56
+ # Mirror all pieces
57
+ for square in bulletchess.SQUARES:
58
+ piece = board[square]
59
+ if piece is not None:
60
+ # Calculate mirrored square: flip rank (0-7 -> 7-0), keep file
61
+ square_idx = square.index()
62
+ rank = square_idx // 8
63
+ file = square_idx % 8
64
+ mirrored_rank = 7 - rank
65
+ mirrored_idx = mirrored_rank * 8 + file
66
+ mirrored_square = bulletchess.SQUARES[mirrored_idx]
67
+
68
+ # Flip piece color
69
+ mirrored_color = piece.color.opposite
70
+ mirrored[mirrored_square] = bulletchess.Piece(mirrored_color, piece.piece_type)
71
+
72
+ # Mirror castling rights: swap white<->black
73
+ # Build castling rights by checking each type and creating CastlingRights
74
+ new_castling_types = []
75
+ if bulletchess.WHITE_KINGSIDE in board.castling_rights:
76
+ new_castling_types.append(bulletchess.BLACK_KINGSIDE)
77
+ if bulletchess.WHITE_QUEENSIDE in board.castling_rights:
78
+ new_castling_types.append(bulletchess.BLACK_QUEENSIDE)
79
+ if bulletchess.BLACK_KINGSIDE in board.castling_rights:
80
+ new_castling_types.append(bulletchess.WHITE_KINGSIDE)
81
+ if bulletchess.BLACK_QUEENSIDE in board.castling_rights:
82
+ new_castling_types.append(bulletchess.WHITE_QUEENSIDE)
83
+
84
+ # Build CastlingRights from list of types
85
+ if new_castling_types:
86
+ mirrored.castling_rights = bulletchess.CastlingRights(new_castling_types)
87
+ else:
88
+ mirrored.castling_rights = bulletchess.NO_CASTLING
89
+
90
+ # Flip turn
91
+ mirrored.turn = board.turn.opposite
92
+
93
+ # Mirror en passant square if exists
94
+ if board.en_passant_square is not None:
95
+ ep_idx = board.en_passant_square.index()
96
+ ep_rank = ep_idx // 8
97
+ ep_file = ep_idx % 8
98
+ mirrored_ep_rank = 7 - ep_rank
99
+ mirrored_ep_idx = mirrored_ep_rank * 8 + ep_file
100
+ mirrored.en_passant_square = bulletchess.SQUARES[mirrored_ep_idx]
101
+
102
+ # Copy move counters
103
+ mirrored.halfmove_clock = board.halfmove_clock
104
+ mirrored.fullmove_number = board.fullmove_number
105
+
106
+ return mirrored
107
+
108
+
109
+ def _build_snapshots(board: bulletchess.Board) -> List[bulletchess.Board]:
110
+ # snapshots[0] is current, snapshots[1] one ply ago, ... up to 7 plies ago
111
+ temp = board.copy()
112
+ snaps: List[bulletchess.Board] = [temp.copy()]
113
+ for _ in range(7):
114
+ # Check if there are moves to undo by checking if undo() returns None
115
+ try:
116
+ temp.undo()
117
+ snaps.append(temp.copy())
118
+ except (IndexError, AttributeError):
119
+ # No more moves to undo
120
+ snaps.append(None) # type: ignore
121
+ return snaps
122
+
123
+
124
+ def encode_moves_to_tensor(uci_moves: List[str], starting_fen: Optional[str] = None) -> Tuple[torch.Tensor, np.ndarray]:
125
+ board = bulletchess.Board.from_fen(starting_fen) if starting_fen is not None else bulletchess.Board()
126
+ for mv in uci_moves:
127
+ move = bulletchess.Move.from_uci(mv)
128
+ board.apply(move)
129
+
130
+ # Build history snapshots (current first)
131
+ snapshots = _build_snapshots(board)
132
+
133
+ # Always encode from white's perspective; mirror all snapshots if black to move
134
+ mirror = (board.turn == bulletchess.BLACK)
135
+ if mirror:
136
+ snapshots = [_mirror_board(s) if s is not None else None for s in snapshots]
137
+
138
+ # Assemble 112-channel tensor
139
+ # 8 groups: each 12 piece planes + 1 blank = 13; total 104
140
+ channels: List[np.ndarray] = []
141
+ for i in range(8):
142
+ if snapshots[i] is not None:
143
+ planes12 = _board_to_12_piece_planes(snapshots[i]) # (8,8,12)
144
+ channels.append(planes12)
145
+ else:
146
+ channels.append(np.zeros((8, 8, 12), dtype=np.float32))
147
+ # blank plane
148
+ channels.append(np.zeros((8, 8, 1), dtype=np.float32))
149
+
150
+ # Special planes: WQ, WK, BQ, BK, is_black_to_move, blank, blank, ones
151
+ current_for_flags = snapshots[0]
152
+ assert current_for_flags is not None
153
+ castling = _castling_planes(current_for_flags) # (4,8,8)
154
+ is_black_to_move = 1.0 if (board.turn == bulletchess.BLACK) else 0.0
155
+ specials = [
156
+ castling[0:1, :, :], # WQ
157
+ castling[1:2, :, :], # WK
158
+ castling[2:3, :, :], # BQ
159
+ castling[3:4, :, :], # BK
160
+ np.full((1, 8, 8), is_black_to_move, dtype=np.float32),
161
+ np.zeros((1, 8, 8), dtype=np.float32),
162
+ np.zeros((1, 8, 8), dtype=np.float32),
163
+ np.ones((1, 8, 8), dtype=np.float32),
164
+ ]
165
+
166
+ # Concatenate to (8,8,112)
167
+ stacked = np.concatenate(channels, axis=2) # (8,8,104)
168
+ specials_hwk = np.transpose(np.concatenate(specials, axis=0), (1, 2, 0)) # (8,8,8)
169
+ final_hwk = np.concatenate([stacked, specials_hwk], axis=2) # (8,8,112)
170
+
171
+ # Convert to tensor (1,112,8,8)
172
+ final_tensor = torch.from_numpy(final_hwk).permute(2, 0, 1).unsqueeze(0).float()
173
+
174
+ # Legal moves mask built from mirrored board to match policy_index perspective
175
+ board_for_mask = _mirror_board(board) if (board.turn == bulletchess.BLACK) else board.copy()
176
+ lm = np.ones(1858, dtype=np.float32) * (-1000)
177
+
178
+ # Collect all legal moves as UCI strings
179
+ legal_moves_uci = set()
180
+ for possible in board_for_mask.legal_moves():
181
+ u = possible.uci()
182
+ if u[-1] != 'n':
183
+ legal_moves_uci.add(u)
184
+ else:
185
+ legal_moves_uci.add(u[:-1])
186
+
187
+ # Mark all legal moves
188
+ for u in legal_moves_uci:
189
+ idx = policy_to_idx.get(u)
190
+ if idx is not None:
191
+ lm[idx] = 0
192
+
193
+ # Add castling moves as king-to-rook-square moves ONLY if the corresponding
194
+ # standard castling move is actually legal (to verify castling is possible)
195
+ # White kingside: e1h1 (king to rook square) if e1g1 is legal
196
+ if "e1g1" in legal_moves_uci:
197
+ castling_move = "e1h1"
198
+ idx = policy_to_idx.get(castling_move)
199
+ if idx is not None:
200
+ lm[idx] = 0
201
+
202
+ # White queenside: e1a1 (king to rook square) if e1c1 is legal
203
+ if "e1c1" in legal_moves_uci:
204
+ castling_move = "e1a1"
205
+ idx = policy_to_idx.get(castling_move)
206
+ if idx is not None:
207
+ lm[idx] = 0
208
+
209
+ # Black kingside: e8h8 (king to rook square) if e8g8 is legal
210
+ if "e8g8" in legal_moves_uci:
211
+ castling_move = "e8h8"
212
+ idx = policy_to_idx.get(castling_move)
213
+ if idx is not None:
214
+ lm[idx] = 0
215
+
216
+ # Black queenside: e8a8 (king to rook square) if e8c8 is legal
217
+ if "e8c8" in legal_moves_uci:
218
+ castling_move = "e8a8"
219
+ idx = policy_to_idx.get(castling_move)
220
+ if idx is not None:
221
+ lm[idx] = 0
222
+
223
+ return final_tensor, lm
224
+
225
+
226
+ def encode_fen_to_tensor(fen: str) -> Tuple[torch.Tensor, np.ndarray]:
227
+ board = bulletchess.Board.from_fen(fen)
228
+
229
+ # History: only current snapshot, others are zeros
230
+ snapshots = [board.copy()] + [None] * 7
231
+
232
+ # Mirror snapshots if black to move so encoding is from white's perspective
233
+ if board.turn == bulletchess.BLACK:
234
+ snapshots = [_mirror_board(s) if s is not None else None for s in snapshots]
235
+
236
+ # Assemble 112-channel tensor
237
+ channels: List[np.ndarray] = []
238
+ for i in range(8):
239
+ if snapshots[i] is not None:
240
+ planes12 = _board_to_12_piece_planes(snapshots[i])
241
+ channels.append(planes12)
242
+ else:
243
+ channels.append(np.zeros((8, 8, 12), dtype=np.float32))
244
+ channels.append(np.zeros((8, 8, 1), dtype=np.float32))
245
+
246
+ current_for_flags = snapshots[0]
247
+ assert current_for_flags is not None
248
+ castling = _castling_planes(current_for_flags)
249
+ is_black_to_move = 1.0 if (board.turn == bulletchess.BLACK) else 0.0
250
+ specials = [
251
+ castling[0:1, :, :],
252
+ castling[1:2, :, :],
253
+ castling[2:3, :, :],
254
+ castling[3:4, :, :],
255
+ np.full((1, 8, 8), is_black_to_move, dtype=np.float32),
256
+ np.zeros((1, 8, 8), dtype=np.float32),
257
+ np.zeros((1, 8, 8), dtype=np.float32),
258
+ np.ones((1, 8, 8), dtype=np.float32),
259
+ ]
260
+
261
+ stacked = np.concatenate(channels, axis=2)
262
+ specials_hwk = np.transpose(np.concatenate(specials, axis=0), (1, 2, 0))
263
+ final_hwk = np.concatenate([stacked, specials_hwk], axis=2)
264
+
265
+ final_tensor = torch.from_numpy(final_hwk).permute(2, 0, 1).unsqueeze(0).float()
266
+
267
+ # Legal moves mask from mirrored perspective when black to move
268
+ board_for_mask = _mirror_board(board) if (board.turn == bulletchess.BLACK) else board.copy()
269
+ lm = np.ones(1858, dtype=np.float32) * (-1000)
270
+
271
+ # Collect all legal moves as UCI strings
272
+ legal_moves_uci = set()
273
+ for possible in board_for_mask.legal_moves():
274
+ u = possible.uci()
275
+ if u[-1] != 'n':
276
+ legal_moves_uci.add(u)
277
+ else:
278
+ legal_moves_uci.add(u[:-1])
279
+
280
+ # Mark all legal moves
281
+ for u in legal_moves_uci:
282
+ idx = policy_to_idx.get(u)
283
+ if idx is not None:
284
+ lm[idx] = 0
285
+
286
+ # Add castling moves as king-to-rook-square moves ONLY if the corresponding
287
+ # standard castling move is actually legal (to verify castling is possible)
288
+ # White kingside: e1h1 (king to rook square) if e1g1 is legal
289
+ if "e1g1" in legal_moves_uci:
290
+ castling_move = "e1h1"
291
+ idx = policy_to_idx.get(castling_move)
292
+ if idx is not None:
293
+ lm[idx] = 0
294
+
295
+ # White queenside: e1a1 (king to rook square) if e1c1 is legal
296
+ if "e1c1" in legal_moves_uci:
297
+ castling_move = "e1a1"
298
+ idx = policy_to_idx.get(castling_move)
299
+ if idx is not None:
300
+ lm[idx] = 0
301
+
302
+ # Black kingside: e8h8 (king to rook square) if e8g8 is legal
303
+ if "e8g8" in legal_moves_uci:
304
+ castling_move = "e8h8"
305
+ idx = policy_to_idx.get(castling_move)
306
+ if idx is not None:
307
+ lm[idx] = 0
308
+
309
+ # Black queenside: e8a8 (king to rook square) if e8c8 is legal
310
+ if "e8c8" in legal_moves_uci:
311
+ castling_move = "e8a8"
312
+ idx = policy_to_idx.get(castling_move)
313
+ if idx is not None:
314
+ lm[idx] = 0
315
+
316
+ return final_tensor, lm
317
+