import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import chess MOVES_PER_SQUARE = 73 POLICY_SIZE = 64 * MOVES_PER_SQUARE class ResidualBlock(nn.Module): def __init__(self, channels: int) -> None: super().__init__() self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) def forward(self, x: torch.Tensor) -> torch.Tensor: residual = x out = F.relu(self.conv1(x)) out = self.conv2(out) out = out + residual return F.relu(out) class TinyPCN(nn.Module): def __init__(self, board_channels: int = 18, policy_size: int = POLICY_SIZE) -> None: """Tiny policy-value net: shared trunk plus separate heads.""" super().__init__() self.conv1 = nn.Conv2d(board_channels, 32, kernel_size=3, padding=1) self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1) self.res_block = ResidualBlock(32) self.policy_conv = nn.Conv2d(32, 32, kernel_size=1) self.policy_fc = nn.Linear(32 * 8 * 8, policy_size) self.value_conv = nn.Conv2d(32, 1, kernel_size=1) self.value_fc1 = nn.Linear(8 * 8, 64) self.value_fc2 = nn.Linear(64, 1) def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: x = F.relu(self.conv1(x)) x = F.relu(self.conv2(x)) x = self.res_block(x) p = F.relu(self.policy_conv(x)) p = p.view(p.size(0), -1) policy_logits = self.policy_fc(p) v = F.relu(self.value_conv(x)) v = v.view(v.size(0), -1) v = F.relu(self.value_fc1(v)) value = torch.tanh(self.value_fc2(v)) return policy_logits, value def board_to_18_planes(board: chess.Board) -> torch.FloatTensor: """Return 18 AlphaZero-style planes for the given board.""" planes = np.zeros((18, 8, 8), dtype=np.float32) for square, piece in board.piece_map().items(): row = 7 - (square // 8) col = square % 8 color_offset = 0 if piece.color == chess.WHITE else 6 plane_idx = (piece.piece_type - 1) + color_offset planes[plane_idx, row, col] = 1.0 planes[12, :, :] = 1.0 if board.has_kingside_castling_rights(chess.WHITE) else 0.0 planes[13, :, :] = 1.0 if board.has_queenside_castling_rights(chess.WHITE) else 0.0 planes[14, :, :] = 1.0 if board.has_kingside_castling_rights(chess.BLACK) else 0.0 planes[15, :, :] = 1.0 if board.has_queenside_castling_rights(chess.BLACK) else 0.0 planes[16, :, :] = 1.0 if board.turn == chess.WHITE else 0.0 if board.ep_square is not None: ep_row = 7 - (board.ep_square // 8) ep_col = board.ep_square % 8 planes[17, ep_row, ep_col] = 1.0 return torch.from_numpy(planes) def board_to_20_planes(board: chess.Board) -> torch.FloatTensor: """Return 20 planes (18 standard plus repetition and move count).""" planes18 = board_to_18_planes(board).numpy() extra = np.zeros((2, 8, 8), dtype=np.float32) try: repetition = board.is_repetition() except Exception: repetition = False extra[0, :, :] = 1.0 if repetition else 0.0 move_norm = min(board.fullmove_number / 100.0, 1.0) extra[1, :, :] = float(move_norm) planes20 = np.concatenate([planes18, extra], axis=0) return torch.from_numpy(planes20) def encode_board(board: chess.Board, variant: str = "18") -> torch.FloatTensor: if variant == "18": return board_to_18_planes(board) if variant == "20": return board_to_20_planes(board) raise ValueError("variant must be '18' or '20'") _RAY_OFFSETS = ((1, 0), (-1, 0), (0, 1), (0, -1), (1, 1), (1, -1), (-1, 1), (-1, -1)) _KNIGHT_OFFSETS = ((2, 1), (1, 2), (-1, 2), (-2, 1), (-2, -1), (-1, -2), (1, -2), (2, -1)) _PROMOTION_OFFSETS = ((1, 0), (1, 1), (1, -1), (2, 0)) _PROMOTION_PIECES = ("q", "r", "b", "n") _move_to_index: dict[tuple[int, int, str | None], int] = {} _index_to_move: dict[int, tuple[int, int, str | None]] = {} def _init_move_tables() -> None: idx = 0 for sq in range(64): row0, col0 = divmod(sq, 8) for dx, dy in _RAY_OFFSETS: for step in range(1, 8): row = row0 + dx * step col = col0 + dy * step if 0 <= row < 8 and 0 <= col < 8: target = row * 8 + col _move_to_index[(sq, target, None)] = idx _index_to_move[idx] = (sq, target, None) idx += 1 for dx, dy in _KNIGHT_OFFSETS: row = row0 + dx col = col0 + dy if 0 <= row < 8 and 0 <= col < 8: target = row * 8 + col _move_to_index[(sq, target, None)] = idx _index_to_move[idx] = (sq, target, None) idx += 1 for dx, dy in _PROMOTION_OFFSETS: row = row0 + dx col = col0 + dy if 0 <= row < 8 and 0 <= col < 8: target = row * 8 + col for promo in _PROMOTION_PIECES: _move_to_index[(sq, target, promo)] = idx _index_to_move[idx] = (sq, target, promo) idx += 1 else: idx += len(_PROMOTION_PIECES) while idx % MOVES_PER_SQUARE != 0: _index_to_move[idx] = None idx += 1 _init_move_tables() def _promotion_symbol(piece_type: int | None) -> str | None: if piece_type is None: return None return chess.Piece(piece_type, chess.WHITE).symbol().lower() def encode_move(move: chess.Move, board: chess.Board) -> int: from_sq = move.from_square to_sq = move.to_square promo_symbol = _promotion_symbol(move.promotion) if board.color_at(move.from_square) == chess.BLACK: from_sq = chess.square_mirror(from_sq) to_sq = chess.square_mirror(to_sq) key = (from_sq, to_sq, promo_symbol) return _move_to_index.get(key, -1) def decode_move(index: int, board: chess.Board | None = None) -> chess.Move | None: triple = _index_to_move.get(index) if triple is None: return None from_sq, to_sq, promo = triple if board is not None and board.turn == chess.BLACK: from_sq = chess.square_mirror(from_sq) to_sq = chess.square_mirror(to_sq) promotion = chess.Piece.from_symbol(promo.upper()).piece_type if promo else None return chess.Move(from_sq, to_sq, promotion=promotion)