# chess-bot / model.py
# Author: AubreeL. Uploaded with huggingface_hub (commit e763fc3, verified).
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import chess
# Number of policy slots reserved for each source square; the move tables
# pack every square's moves into its own block of this many indices.
MOVES_PER_SQUARE = 73
# Total policy-head width: one block per board square (64 * 73 = 4672).
POLICY_SIZE = MOVES_PER_SQUARE * 64
class ResidualBlock(nn.Module):
    """Two same-width 3x3 convolutions with an identity skip connection.

    Spatial size and channel count are preserved (padding=1, channels in ==
    channels out), so the input can be added back before the final ReLU.
    """

    def __init__(self, channels: int) -> None:
        super().__init__()
        # Registration order (conv1, then conv2) fixes the state_dict key order.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return relu(conv2(relu(conv1(x))) + x)."""
        hidden = self.conv1(x).relu()
        hidden = self.conv2(hidden)
        return (hidden + x).relu()
class TinyPCN(nn.Module):
    """Tiny policy-value net: shared convolutional trunk plus separate heads.

    Expects input shaped ``(N, board_channels, 8, 8)``. The 8x8 board size is
    baked into the fully connected layers of both heads.

    Args:
        board_channels: number of input planes (the encoders in this module
            produce 18 or 20).
        policy_size: number of policy logits produced per position.
    """

    def __init__(self, board_channels: int = 18, policy_size: int = POLICY_SIZE) -> None:
        """Build the trunk (two 3x3 convs + one residual block) and both heads."""
        super().__init__()
        # Shared trunk; padding=1 keeps the 8x8 spatial size throughout.
        self.conv1 = nn.Conv2d(board_channels, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.res_block = ResidualBlock(32)
        # Policy head: 1x1 conv, then a linear map from 32*8*8 features to logits.
        self.policy_conv = nn.Conv2d(32, 32, kernel_size=1)
        self.policy_fc = nn.Linear(32 * 8 * 8, policy_size)
        # Value head: squeeze to one plane, then a small MLP down to one scalar.
        self.value_conv = nn.Conv2d(32, 1, kernel_size=1)
        self.value_fc1 = nn.Linear(8 * 8, 64)
        self.value_fc2 = nn.Linear(64, 1)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(policy_logits, value)`` shaped ``(N, policy_size)`` and ``(N, 1)``.

        ``policy_logits`` are raw scores (no softmax applied here); ``value``
        is squashed into [-1, 1] by ``tanh``.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.res_block(x)

        # flatten(start_dim=1) instead of .view(N, -1): view() raises on
        # non-contiguous activations (e.g. channels_last) and on an empty
        # batch, while flatten() copies only when needed and produces the
        # same (C, H, W) element order for contiguous tensors.
        p = F.relu(self.policy_conv(x))
        policy_logits = self.policy_fc(torch.flatten(p, start_dim=1))

        v = F.relu(self.value_conv(x))
        v = F.relu(self.value_fc1(torch.flatten(v, start_dim=1)))
        value = torch.tanh(self.value_fc2(v))
        return policy_logits, value
def board_to_18_planes(board: chess.Board) -> torch.FloatTensor:
    """Return 18 AlphaZero-style planes for ``board`` as a float32 (18, 8, 8) tensor.

    A square maps to row ``7 - square // 8`` and column ``square % 8``, so
    row 0 holds the highest rank. Plane layout:

    * 0-5:   White pieces, plane ``piece_type - 1``
    * 6-11:  Black pieces, plane ``6 + piece_type - 1``
    * 12-15: castling rights (White king-side, White queen-side,
      Black king-side, Black queen-side), each filled with 1.0 when held
    * 16:    filled with 1.0 when White is to move
    * 17:    a single 1.0 on the en-passant square, if there is one

    NOTE(review): the board is never flipped for Black, even though
    ``encode_move`` mirrors Black's moves. Confirm this matches training.
    """
    planes = np.zeros((18, 8, 8), dtype=np.float32)

    for square, piece in board.piece_map().items():
        rank, file = divmod(square, 8)
        first_plane = 0 if piece.color == chess.WHITE else 6
        planes[first_plane + piece.piece_type - 1, 7 - rank, file] = 1.0

    castling_rights = (
        board.has_kingside_castling_rights(chess.WHITE),
        board.has_queenside_castling_rights(chess.WHITE),
        board.has_kingside_castling_rights(chess.BLACK),
        board.has_queenside_castling_rights(chess.BLACK),
    )
    for offset, held in enumerate(castling_rights):
        if held:
            planes[12 + offset] = 1.0

    if board.turn == chess.WHITE:
        planes[16] = 1.0

    ep_square = board.ep_square
    if ep_square is not None:
        ep_rank, ep_file = divmod(ep_square, 8)
        planes[17, 7 - ep_rank, ep_file] = 1.0

    return torch.from_numpy(planes)
def board_to_20_planes(board: chess.Board) -> torch.FloatTensor:
    """Return a float32 (20, 8, 8) tensor: the 18 standard planes plus two extras.

    * Plane 18 is all ones when ``board.is_repetition()`` is truthy. It is all
      zeros otherwise, including when that call raises.
    * Plane 19 is filled with ``min(board.fullmove_number / 100, 1.0)``.

    NOTE(review): in python-chess, ``is_repetition()`` with no argument checks
    for a threefold repetition. Confirm that is the intended signal.
    """
    base_planes = board_to_18_planes(board).numpy()

    # Deliberately best-effort: any failure while checking repetition is
    # treated as "no repetition" rather than aborting the encoding.
    try:
        repeated = board.is_repetition()
    except Exception:
        repeated = False

    repetition_plane = np.full((1, 8, 8), 1.0 if repeated else 0.0, dtype=np.float32)
    move_fraction = min(board.fullmove_number / 100.0, 1.0)
    move_plane = np.full((1, 8, 8), float(move_fraction), dtype=np.float32)

    stacked = np.concatenate([base_planes, repetition_plane, move_plane], axis=0)
    return torch.from_numpy(stacked)
def encode_board(board: chess.Board, variant: str = "18") -> torch.FloatTensor:
    """Encode ``board`` with the plane set named by ``variant``.

    Args:
        board: position to encode.
        variant: ``"18"`` selects ``board_to_18_planes``; ``"20"`` selects
            ``board_to_20_planes``.

    Raises:
        ValueError: if ``variant`` is neither ``"18"`` nor ``"20"``.
    """
    encoders = (("18", board_to_18_planes), ("20", board_to_20_planes))
    for name, encoder in encoders:
        if variant == name:
            return encoder(board)
    raise ValueError("variant must be '18' or '20'")
_RAY_OFFSETS = ((1, 0), (-1, 0), (0, 1), (0, -1), (1, 1), (1, -1), (-1, 1), (-1, -1))
_KNIGHT_OFFSETS = ((2, 1), (1, 2), (-1, 2), (-2, 1), (-2, -1), (-1, -2), (1, -2), (2, -1))
_PROMOTION_OFFSETS = ((1, 0), (1, 1), (1, -1), (2, 0))
_PROMOTION_PIECES = ("q", "r", "b", "n")
_move_to_index: dict[tuple[int, int, str | None], int] = {}
_index_to_move: dict[int, tuple[int, int, str | None]] = {}
def _init_move_tables() -> None:
    """Populate ``_move_to_index`` and ``_index_to_move`` (invoked once at import).

    Keys are ``(from_square, to_square, promotion_symbol_or_None)`` in the
    un-mirrored orientation; ``encode_move``/``decode_move`` mirror squares
    for Black. Each source square ``sq`` owns the index block
    ``[sq * MOVES_PER_SQUARE, (sq + 1) * MOVES_PER_SQUARE)``. A square uses at
    most 27 ray + 8 knight + 16 promotion = 51 indices, and the rest of its
    block is padding.

    The enumeration order below determines which policy logit each move maps
    to. Reordering any loop changes every index, so presumably any trained
    checkpoint depends on it exactly as written.
    """
    idx = 0
    for sq in range(64):
        # row0 = sq // 8 (rank index), col0 = sq % 8 (file index).
        row0, col0 = divmod(sq, 8)
        # Sliding moves: 8 directions x 1..7 steps. Off-board targets are
        # skipped WITHOUT consuming an index, so the slot a given
        # (direction, distance) lands in varies from square to square.
        for dx, dy in _RAY_OFFSETS:
            for step in range(1, 8):
                row = row0 + dx * step
                col = col0 + dy * step
                if 0 <= row < 8 and 0 <= col < 8:
                    target = row * 8 + col
                    _move_to_index[(sq, target, None)] = idx
                    _index_to_move[idx] = (sq, target, None)
                    idx += 1
        # Knight jumps; off-board targets likewise consume no index.
        for dx, dy in _KNIGHT_OFFSETS:
            row = row0 + dx
            col = col0 + dy
            if 0 <= row < 8 and 0 <= col < 8:
                target = row * 8 + col
                _move_to_index[(sq, target, None)] = idx
                _index_to_move[idx] = (sq, target, None)
                idx += 1
        # Promotions: one index per (offset, piece), created for every source
        # square (not only the 7th rank). Unlike the loops above, an off-board
        # target still reserves len(_PROMOTION_PIECES) indices, and those
        # indices are absent from both dicts.
        # NOTE(review): the (2, 0) offset is a two-rank push, which cannot
        # promote in standard chess, so these slots appear unused. Confirm intent.
        for dx, dy in _PROMOTION_OFFSETS:
            row = row0 + dx
            col = col0 + dy
            if 0 <= row < 8 and 0 <= col < 8:
                target = row * 8 + col
                for promo in _PROMOTION_PIECES:
                    _move_to_index[(sq, target, promo)] = idx
                    _index_to_move[idx] = (sq, target, promo)
                    idx += 1
            else:
                idx += len(_PROMOTION_PIECES)
        # Pad up to this square's block boundary so the next square starts at
        # (sq + 1) * MOVES_PER_SQUARE; padding indices map to None.
        while idx % MOVES_PER_SQUARE != 0:
            _index_to_move[idx] = None
            idx += 1


# Build the lookup tables once, at import time.
_init_move_tables()
def _promotion_symbol(piece_type: int | None) -> str | None:
    """Return the lowercase symbol for ``piece_type`` (e.g. ``"q"``), or None for None.

    The lowercase form matches the entries of ``_PROMOTION_PIECES`` used as
    keys in the move tables.
    """
    if piece_type is not None:
        uppercase = chess.Piece(piece_type, chess.WHITE).symbol()
        return uppercase.lower()
    return None
def encode_move(move: chess.Move, board: chess.Board) -> int:
    """Map ``move`` to its policy index from the side-to-move's perspective.

    When Black is to move, both squares are mirrored with
    ``chess.square_mirror`` before lookup, so both colours share one table.
    The perspective comes from ``board.turn``, the same rule ``decode_move``
    applies, so an encoded index decodes back to the same move with the same
    board. The previous ``board.color_at(from_square)`` test disagreed with
    ``decode_move`` whenever the from-square was empty or held the
    non-mover's piece. ``board`` should be the position *before* ``move``
    is played.

    Returns:
        The policy index, or ``-1`` if the move has no slot in the table.
    """
    from_sq, to_sq = move.from_square, move.to_square
    if board.turn == chess.BLACK:
        from_sq = chess.square_mirror(from_sq)
        to_sq = chess.square_mirror(to_sq)
    key = (from_sq, to_sq, _promotion_symbol(move.promotion))
    return _move_to_index.get(key, -1)
def decode_move(index: int, board: chess.Board | None = None) -> chess.Move | None:
    """Turn a policy index back into a ``chess.Move`` (inverse of ``encode_move``).

    Args:
        index: policy slot to decode.
        board: optional position. When given and Black is to move, the
            squares are mirrored back into Black's orientation. Without a
            board, the move is decoded in the un-mirrored orientation.

    Returns:
        The decoded move, or None when ``index`` maps to no move (padding,
        reserved slots, or an unknown index). Legality is not checked.
    """
    entry = _index_to_move.get(index)
    if entry is None:
        return None

    src, dst, promo = entry
    black_to_move = board is not None and board.turn == chess.BLACK
    if black_to_move:
        src, dst = chess.square_mirror(src), chess.square_mirror(dst)

    promotion = None
    if promo:
        promotion = chess.Piece.from_symbol(promo.upper()).piece_type
    return chess.Move(src, dst, promotion=promotion)