|
|
import numpy as np |
|
|
import torch |
|
|
import bulletchess |
|
|
from typing import List, Tuple, Optional |
|
|
from .vocab import policy_index |
|
|
|
|
|
|
|
|
# Reverse lookup for the policy vocabulary: move string -> position in
# `policy_index`. Built once at import so mask construction is an O(1) lookup.
policy_to_idx = {move: position for position, move in enumerate(policy_index)}
|
|
|
|
|
|
|
|
def _board_to_12_piece_planes(board: bulletchess.Board) -> np.ndarray:
    """Return one-hot piece occupancy planes for ``board``.

    The result is a float32 array of shape (8, 8, 12). The last axis is
    ordered white pawn, knight, bishop, rook, queen, king, followed by the
    same six piece types for black. The first two axes are
    ``square.index() // 8`` and ``square.index() % 8``.
    """
    kinds = (
        bulletchess.PAWN,
        bulletchess.KNIGHT,
        bulletchess.BISHOP,
        bulletchess.ROOK,
        bulletchess.QUEEN,
        bulletchess.KING,
    )
    # Colour-major ordering: all white planes first, then all black planes.
    layout = [(side, kind) for side in (bulletchess.WHITE, bulletchess.BLACK) for kind in kinds]

    stack = np.zeros((len(layout), 8, 8), dtype=np.float32)
    for plane, (side, kind) in enumerate(layout):
        # Indexing the board with (colour, piece type) yields the occupied squares.
        for occupied in board[side, kind]:
            row, col = divmod(occupied.index(), 8)
            stack[plane, row, col] = 1.0

    # Planes-first -> planes-last.
    return np.transpose(stack, (1, 2, 0))
|
|
|
|
|
|
|
|
def _castling_planes(board: bulletchess.Board) -> np.ndarray:
    """Return four constant planes describing the castling rights of ``board``.

    The result is a float32 array of shape (4, 8, 8). Each plane is filled
    with 1.0 when the corresponding right is present and 0.0 otherwise, in
    the order: white queenside, white kingside, black queenside, black
    kingside.
    """
    rights = board.castling_rights
    plane_order = (
        bulletchess.WHITE_QUEENSIDE,
        bulletchess.WHITE_KINGSIDE,
        bulletchess.BLACK_QUEENSIDE,
        bulletchess.BLACK_KINGSIDE,
    )
    flags = [1.0 if right in rights else 0.0 for right in plane_order]
    return np.stack(
        [np.full((8, 8), flag, dtype=np.float32) for flag in flags],
        axis=0,
    )
|
|
|
|
|
|
|
|
def _mirror_board(board: bulletchess.Board) -> bulletchess.Board:
    """Return a colour-swapped, rank-flipped copy of ``board``.

    Every piece moves to the square with the same ``index() % 8`` and the
    opposite ``index() // 8`` (row r -> 7 - r) and changes colour. Castling
    rights are swapped between the two colours, the side to move is
    inverted, and the en-passant square (if any) is flipped the same way as
    the pieces. The halfmove clock and fullmove number are copied unchanged.
    The input board is not modified.
    """
    flipped = bulletchess.Board.empty()

    # Relocate every piece to its mirrored square with the opposite colour.
    for origin in bulletchess.SQUARES:
        occupant = board[origin]
        if occupant is None:
            continue
        row, col = divmod(origin.index(), 8)
        target = bulletchess.SQUARES[(7 - row) * 8 + col]
        flipped[target] = bulletchess.Piece(occupant.color.opposite, occupant.piece_type)

    # Each right held by one colour becomes the same-side right of the other.
    source_rights = board.castling_rights
    colour_swap = (
        (bulletchess.WHITE_KINGSIDE, bulletchess.BLACK_KINGSIDE),
        (bulletchess.WHITE_QUEENSIDE, bulletchess.BLACK_QUEENSIDE),
        (bulletchess.BLACK_KINGSIDE, bulletchess.WHITE_KINGSIDE),
        (bulletchess.BLACK_QUEENSIDE, bulletchess.WHITE_QUEENSIDE),
    )
    swapped = [theirs for ours, theirs in colour_swap if ours in source_rights]
    if swapped:
        flipped.castling_rights = bulletchess.CastlingRights(swapped)
    else:
        flipped.castling_rights = bulletchess.NO_CASTLING

    flipped.turn = board.turn.opposite

    en_passant = board.en_passant_square
    if en_passant is not None:
        ep_row, ep_col = divmod(en_passant.index(), 8)
        flipped.en_passant_square = bulletchess.SQUARES[(7 - ep_row) * 8 + ep_col]

    flipped.halfmove_clock = board.halfmove_clock
    flipped.fullmove_number = board.fullmove_number

    return flipped
|
|
|
|
|
|
|
|
def _build_snapshots(board: bulletchess.Board) -> List[bulletchess.Board]:
    """Return eight history slots for ``board``, most recent position first.

    Slot 0 is a copy of the current position. Each following slot is a copy
    taken after undoing one more move on a private working copy, so the
    caller's board is never modified. When undoing (or the copy that follows
    it) raises ``IndexError`` or ``AttributeError``, that slot holds ``None``
    instead of a board.
    """
    cursor = board.copy()
    history: List[Optional[bulletchess.Board]] = [cursor.copy()]

    while len(history) < 8:
        try:
            cursor.undo()
            earlier = cursor.copy()
        except (IndexError, AttributeError):
            # No earlier position could be produced for this slot.
            earlier = None
        history.append(earlier)

    return history
|
|
|
|
|
|
|
|
def _history_planes(snapshots: List[Optional[bulletchess.Board]]) -> np.ndarray:
    """Stack 13 planes per history slot into an (8, 8, 13 * len(snapshots)) array."""
    planes: List[np.ndarray] = []
    for snapshot in snapshots:
        if snapshot is not None:
            planes.append(_board_to_12_piece_planes(snapshot))
        else:
            # Missing history is encoded as an empty board.
            planes.append(np.zeros((8, 8, 12), dtype=np.float32))
        # One extra all-zero plane per slot; it is never populated here.
        # NOTE(review): presumably the repetition flag of the LC0 layout - confirm.
        planes.append(np.zeros((8, 8, 1), dtype=np.float32))
    return np.concatenate(planes, axis=2)


def _legal_move_mask(position: bulletchess.Board) -> np.ndarray:
    """Build the additive policy mask for ``position``.

    Returns a float32 vector of length 1858 holding 0 at the policy index of
    every legal move and -1000 everywhere else. Legal moves whose UCI string
    is absent from ``policy_to_idx`` are skipped.
    """
    mask = np.full(1858, -1000, dtype=np.float32)

    legal = set()
    for move in position.legal_moves():
        uci = move.uci()
        # A trailing 'n' (knight promotion) is dropped so the move is looked
        # up under its plain from/to string.
        legal.add(uci[:-1] if uci[-1] == 'n' else uci)

    for uci in legal:
        slot = policy_to_idx.get(uci)
        if slot is not None:
            mask[slot] = 0

    # Castling is additionally exposed under its king-takes-rook spelling.
    # The alias is only unmasked when the side actually holds the matching
    # castling right: with the right present the king is on its home square,
    # so a legal e1g1/e1c1/e8g8/e8c8 can only be the castling move. Without
    # this check a rook or queen travelling e1->g1 would wrongly unmask e1h1.
    rights = position.castling_rights
    aliases = (
        (bulletchess.WHITE_KINGSIDE, "e1g1", "e1h1"),
        (bulletchess.WHITE_QUEENSIDE, "e1c1", "e1a1"),
        (bulletchess.BLACK_KINGSIDE, "e8g8", "e8h8"),
        (bulletchess.BLACK_QUEENSIDE, "e8c8", "e8a8"),
    )
    for right, king_move, rook_capture in aliases:
        if king_move in legal and right in rights:
            slot = policy_to_idx.get(rook_capture)
            if slot is not None:
                mask[slot] = 0

    return mask


def _encode_position(
    board: bulletchess.Board,
    snapshots: List[Optional[bulletchess.Board]],
) -> Tuple[torch.Tensor, np.ndarray]:
    """Encode ``board`` plus its history slots into (input tensor, legal-move mask).

    ``snapshots[0]`` must be a copy of ``board``; later entries are earlier
    positions or ``None``. When Black is to move every snapshot is mirrored
    with ``_mirror_board`` so the network always sees the side to move as
    White.
    """
    black_to_move = board.turn == bulletchess.BLACK
    if black_to_move:
        snapshots = [_mirror_board(s) if s is not None else None for s in snapshots]

    current = snapshots[0]

    # Eight trailing planes: the four castling planes of the (possibly
    # mirrored) current position, a side-to-move plane taken from the
    # original board, two all-zero planes and one all-ones plane.
    # NOTE(review): the two zero planes look like unused rule-50 / move-count
    # slots of the LC0 layout - confirm.
    specials = np.concatenate(
        [
            _castling_planes(current),
            np.full((1, 8, 8), 1.0 if black_to_move else 0.0, dtype=np.float32),
            np.zeros((1, 8, 8), dtype=np.float32),
            np.zeros((1, 8, 8), dtype=np.float32),
            np.ones((1, 8, 8), dtype=np.float32),
        ],
        axis=0,
    )

    final_hwk = np.concatenate(
        [_history_planes(snapshots), np.transpose(specials, (1, 2, 0))],
        axis=2,
    )
    # (H, W, C) -> (1, C, H, W)
    final_tensor = torch.from_numpy(final_hwk).permute(2, 0, 1).unsqueeze(0).float()

    # `current` is already the side-to-move-as-White view of `board`, so it
    # is reused for move generation instead of mirroring the board again.
    return final_tensor, _legal_move_mask(current)


def encode_moves_to_tensor(uci_moves: List[str], starting_fen: Optional[str] = None) -> Tuple[torch.Tensor, np.ndarray]:
    """Encode the position reached by playing ``uci_moves``.

    Args:
        uci_moves: Moves in UCI notation, applied in order.
        starting_fen: FEN of the position the moves start from; when ``None``
            a default ``bulletchess.Board()`` is used.

    Returns:
        ``(tensor, mask)``. ``tensor`` is float32 with shape (1, 112, 8, 8):
        8 history slots of 13 planes each followed by 8 flag planes. ``mask``
        is a float32 vector of length 1858 with 0 for legal moves and -1000
        otherwise.

    Errors raised by ``bulletchess`` while parsing the FEN or parsing and
    applying a move propagate to the caller.
    """
    board = bulletchess.Board.from_fen(starting_fen) if starting_fen is not None else bulletchess.Board()
    for uci in uci_moves:
        board.apply(bulletchess.Move.from_uci(uci))

    return _encode_position(board, _build_snapshots(board))


def encode_fen_to_tensor(fen: str) -> Tuple[torch.Tensor, np.ndarray]:
    """Encode the single position described by ``fen``.

    No move history is available from a FEN, so the seven older history
    slots are encoded as empty boards.

    Returns:
        ``(tensor, mask)`` with the same layout as ``encode_moves_to_tensor``:
        a float32 tensor of shape (1, 112, 8, 8) and a float32 mask of length
        1858 with 0 for legal moves and -1000 otherwise.

    Errors raised by ``bulletchess`` while parsing the FEN propagate to the
    caller.
    """
    board = bulletchess.Board.from_fen(fen)
    return _encode_position(board, [board.copy()] + [None] * 7)
|
|
|
|
|
|