import math

import chess
import torch
import torch.nn as nn

from src.tokenizer import Tokenizer

CLS_TOKEN = "[CLS]"
PAD_TOKEN = "[PAD]"

PIECE_VALUES = {
    chess.PAWN: 1,
    chess.KNIGHT: 3,
    chess.BISHOP: 3,
    chess.ROOK: 5,
    chess.QUEEN: 9,
    chess.KING: 0,
}

BOARD_PLANES = 19

# Plane index for every (piece_type, color) pair: white pieces occupy planes
# 0-5 and black pieces planes 6-11, both in pawn..king order. Built once at
# import time rather than on every board_to_planes() call.
_PIECE_TYPES = (
    chess.PAWN,
    chess.KNIGHT,
    chess.BISHOP,
    chess.ROOK,
    chess.QUEEN,
    chess.KING,
)
_PIECE_TO_PLANE = {
    (piece_type, color): 6 * color_idx + piece_idx
    for color_idx, color in enumerate((chess.WHITE, chess.BLACK))
    for piece_idx, piece_type in enumerate(_PIECE_TYPES)
}


def board_to_planes(board: chess.Board) -> torch.Tensor:
    """chess.Board -> (19, 8, 8) float tensor.

    Plane layout (indexed ``[plane, rank, file]``):
      0-5    white pawn, knight, bishop, rook, queen, king (one-hot squares)
      6-11   black pawn, knight, bishop, rook, queen, king (one-hot squares)
      12     all ones when it is white's turn, else zeros
      13/14  white kingside / queenside castling right (all ones if held)
      15/16  black kingside / queenside castling right (all ones if held)
      17     one-hot en-passant square, if any
      18     halfmove clock, clipped to 100 and scaled to [0, 1]
    """
    planes = torch.zeros(BOARD_PLANES, 8, 8, dtype=torch.float32)

    for sq, piece in board.piece_map().items():
        r, c = chess.square_rank(sq), chess.square_file(sq)
        planes[_PIECE_TO_PLANE[(piece.piece_type, piece.color)], r, c] = 1.0

    if board.turn == chess.WHITE:
        planes[12].fill_(1.0)
    if board.has_kingside_castling_rights(chess.WHITE):
        planes[13].fill_(1.0)
    if board.has_queenside_castling_rights(chess.WHITE):
        planes[14].fill_(1.0)
    if board.has_kingside_castling_rights(chess.BLACK):
        planes[15].fill_(1.0)
    if board.has_queenside_castling_rights(chess.BLACK):
        planes[16].fill_(1.0)

    if board.ep_square is not None:
        r, c = chess.square_rank(board.ep_square), chess.square_file(board.ep_square)
        planes[17, r, c] = 1.0

    planes[18].fill_(min(board.halfmove_clock, 100) / 100.0)
    return planes


def _group_norm(channels: int, groups: int = 32) -> nn.GroupNorm:
    """Build a GroupNorm over ``channels`` using at most ``groups`` groups.

    ``nn.GroupNorm`` requires the channel count to be divisible by the group
    count. The starting point is ``min(groups, channels)``; when that does not
    divide ``channels`` (e.g. 48 channels with 32 groups) the largest smaller
    divisor is used instead of failing.

    Raises:
        ValueError: if ``channels`` or ``groups`` is not positive.
    """
    if channels < 1 or groups < 1:
        raise ValueError(
            f"channels and groups must be positive, got channels={channels}, groups={groups}"
        )
    num_groups = min(groups, channels)
    # Terminates at 1 at the latest, since 1 divides every channel count.
    while channels % num_groups:
        num_groups -= 1
    return nn.GroupNorm(num_groups=num_groups, num_channels=channels)


class ResidualBlock(nn.Module):
    """Two 3x3 conv + GroupNorm layers with an identity skip connection."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.norm1 = _group_norm(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.norm2 = _group_norm(channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(norm2(conv2(relu(norm1(conv1(x))))) + x)``; shape is preserved."""
        h = torch.relu(self.norm1(self.conv1(x)))
        h = self.norm2(self.conv2(h))
        return torch.relu(h + x)


class BoardCNN(nn.Module):
    """Residual CNN that turns board planes into 64 per-square feature vectors."""

    def __init__(self, d_model: int, channels: int = 128, num_blocks: int = 6):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(BOARD_PLANES, channels, 3, padding=1, bias=False),
            _group_norm(channels),
            nn.ReLU(inplace=True),
        )
        self.blocks = nn.Sequential(*[ResidualBlock(channels) for _ in range(num_blocks)])
        self.proj = nn.Linear(channels, d_model)
        # Learned per-square offset; its full weight matrix is added in forward().
        self.square_pos = nn.Embedding(64, d_model)

    def forward(self, planes: torch.Tensor) -> torch.Tensor:
        """Encode ``(N, BOARD_PLANES, 8, 8)`` planes into ``(N, 64, d_model)`` features."""
        x = self.stem(planes)
        x = self.blocks(x)  # (N, C, 8, 8)
        # Channels-last, then flatten rank-major: square index = rank * 8 + file.
        x = x.permute(0, 2, 3, 1).reshape(x.size(0), 64, -1)  # (N, 64, C)
        x = self.proj(x) + self.square_pos.weight  # (N, 64, d_model)
        return x


class CrossAttnBlock(nn.Module):
    """Pre-norm block: causal self-attention, gated board cross-attention, FFN."""

    def __init__(self, d_model: int, n_head: int, dim_ff: int, dropout: float):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, n_head, dropout=dropout, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, n_head, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(d_model, dim_ff),
            nn.GELU(),
            nn.Linear(dim_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

        # Gate initialised to 0, so tanh(gate) == 0 and cross-attention starts
        # disabled; the gate itself is a learnable parameter.
        self.cross_gate = nn.Parameter(torch.zeros(1))

    def forward(
        self,
        moves: torch.Tensor,
        board: torch.Tensor,
        key_padding_mask: torch.Tensor | None,
        attn_mask: torch.Tensor | None,
    ) -> torch.Tensor:
        """
        moves: (B, T, d)
        board: (B, T, 64, d)   -- per-position K/V banks
        key_padding_mask: (B, T) -- True = padded move position
        attn_mask: (T, T)        -- causal mask for self-attn

        Returns the updated move stream, shape (B, T, d).
        """
        # Causal self-attention over the move history.
        m = self.norm1(moves)
        sa, _ = self.self_attn(
            m, m, m,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        moves = moves + self.drop(sa)

        # Cross-attention: fold (B, T) into the batch axis so each move
        # position is a single query attending only to its own 64 squares.
        B, T, d = moves.shape
        q = self.norm2(moves).reshape(B * T, 1, d)
        kv = board.reshape(B * T, 64, d)
        ca, _ = self.cross_attn(q, kv, kv, need_weights=False)
        ca = ca.reshape(B, T, d)
        moves = moves + self.drop(self.cross_gate.tanh() * ca)

        # FFN
        moves = moves + self.drop(self.ff(self.norm3(moves)))
        return moves


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding followed by dropout."""

    def __init__(self, d_model: int, max_len: int = 512, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine columns,
        # so trim the frequency vector; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the first ``x.size(1)`` encodings to ``x`` (batch, seq, d_model), then dropout."""
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)


class ChessRewardModel(nn.Module):
    """Transformer encoder over move tokens; scores the game from the CLS position."""

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 768,
        nhead: int = 12,
        num_layers: int = 8,
        dim_feedforward: int = 3072,
        max_seq_len: int = 128,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.reward_head = nn.Linear(d_model, 1)

    def forward(
        self,
        token_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Args:
            token_ids: (batch, seq_len) int tensor with CLS prepended
            attention_mask: (batch, seq_len) bool tensor, True where padded
        Returns:
            (batch,) float tensor bounded to [-1, 1]
        """
        x = self.token_embedding(token_ids)
        x = self.pos_encoding(x)
        x = self.encoder(x, src_key_padding_mask=attention_mask)
        cls_hidden = x[:, 0, :]  # CLS token at position 0
        reward = self.reward_head(cls_hidden).squeeze(-1)
        return torch.tanh(reward)


class ChessPolicyModel(nn.Module):
    """Causal next-move predictor with per-position live-board cross-attention.

    Two streams flow through every block:
      - Move stream: token embeddings + sinusoidal positional encoding,
        doing causal self-attention over the move history.
      - Board stream: a (B, T, 64, d_model) bank of CNN-encoded board
        features where bank `t` is the state after token_ids[1..t] have
        been played.

    At each block, the move query at position t cross-attends only to its
    own 64 board-square keys — implicit causality via data layout, no
    masking needed. The board representation never depends on a token the
    model is being asked to predict, so multi-position LM-style training
    is leak-safe.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 768,
        nhead: int = 12,
        num_layers: int = 8,
        dim_feedforward: int = 3072,
        max_seq_len: int = 128,
        dropout: float = 0.1,
        cnn_channels: int = 128,
        cnn_blocks: int = 6,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.board_cnn = BoardCNN(d_model, cnn_channels, cnn_blocks)
        self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
        self.blocks = nn.ModuleList([
            CrossAttnBlock(d_model, nhead, dim_feedforward, dropout)
            for _ in range(num_layers)
        ])
        self.norm_out = nn.LayerNorm(d_model)
        self.prob_head = nn.Linear(d_model, vocab_size)

    def forward(
        self,
        token_ids: torch.Tensor,
        board_planes: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Args:
            token_ids: (B, T) int — CLS at position 0
            board_planes: (B, T, 19, 8, 8) float — per-position live planes;
                planes[:, t] is the board state after token_ids[1..t]
            attention_mask: (B, T) bool — True where padded
        Returns:
            (B, T, vocab_size) raw logits at every position
        """
        B, T = token_ids.shape
        moves = self.token_embedding(token_ids)
        moves = self.pos_encoding(moves)  # (B, T, d)

        # Vectorize the CNN over (B*T) boards — one big conv batch, not a loop.
        planes_flat = board_planes.reshape(B * T, BOARD_PLANES, 8, 8)
        board_feats = self.board_cnn(planes_flat)  # (B*T, 64, d)
        board_feats = board_feats.reshape(B, T, 64, -1)  # (B, T, 64, d)

        # Bool causal mask (True = masked future position) to match the bool
        # key_padding_mask. PyTorch deprecates mixing float and bool masks.
        causal = torch.triu(
            torch.ones(T, T, dtype=torch.bool, device=token_ids.device), diagonal=1
        )

        for blk in self.blocks:
            moves = blk(moves, board_feats, attention_mask, causal)

        moves = self.norm_out(moves)
        return self.prob_head(moves)  # (B, T, vocab)


class DummyRewardModel:
    """Material-count heuristic for MCTS testing."""

    def __call__(self, board: chess.Board) -> float:
        """Return tanh((white material - black material) / 10), in (-1, 1)."""
        score = 0.0
        for piece_type, value in PIECE_VALUES.items():
            score += len(board.pieces(piece_type, chess.WHITE)) * value
            score -= len(board.pieces(piece_type, chess.BLACK)) * value
        return math.tanh(score / 10.0)


class RewardModelInference:
    """Wraps ChessRewardModel + Tokenizer for use in minimax"""

    def __init__(self, model: ChessRewardModel, tokenizer: Tokenizer, device: str = "cpu"):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.cls_id = tokenizer.symbol_to_token[CLS_TOKEN]
        self.pad_id = tokenizer.symbol_to_token[PAD_TOKEN]
        self.model.eval()

    @torch.no_grad()
    def __call__(self, board: chess.Board, max_seq_len: int = 128) -> float:
        """Score ``board`` from its move history.

        Args:
            board: position whose ``move_stack`` is tokenized.
            max_seq_len: maximum sequence length including the CLS token.
        Returns:
            The model's scalar reward as a Python float.
        """
        moves_uci = [move.uci() for move in board.move_stack]
        # Keep the most recent moves to stay within the training sequence length.
        # CLS occupies position 0, so cap move history at max_seq_len - 1.
        # A plain ``moves_uci[-keep:]`` would keep *everything* when keep == 0
        # (``[-0:]`` is the whole list), so that case is handled explicitly.
        keep = max_seq_len - 1
        moves_uci = moves_uci[-keep:] if keep > 0 else []
        token_ids = [self.cls_id] + self.tokenizer.encode_moves(moves_uci)
        token_tensor = torch.tensor([token_ids], dtype=torch.long, device=self.device)
        reward = self.model(token_tensor)
        return reward.item()


class PolicyModelInference:
    """Wraps ChessPolicyModel + Tokenizer"""

    def __init__(self, model: ChessPolicyModel, tokenizer: Tokenizer, device: str = "cpu"):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.cls_id = tokenizer.symbol_to_token[CLS_TOKEN]
        self.pad_id = tokenizer.symbol_to_token[PAD_TOKEN]
        self.model.eval()

    def _model_context_len(self) -> int | None:
        """Return the length of the model's positional table, or None if not exposed."""
        pos_encoding = getattr(self.model, "pos_encoding", None)
        pe = getattr(pos_encoding, "pe", None)
        return None if pe is None else pe.size(1)

    @torch.no_grad()
    def __call__(self, board: chess.Board, max_seq_len: int | None = None) -> str:
        """Return the highest-scoring legal move for ``board`` as a UCI string.

        Args:
            board: current position; its ``move_stack`` supplies the history.
            max_seq_len: maximum sequence length including the CLS token.
                Defaults to the length of the model's positional-encoding
                table. Longer histories are cut to the most recent
                ``max_seq_len - 1`` moves, mirroring RewardModelInference.
                NOTE(review): whether the policy was trained on such late-game
                windows cannot be seen from this file — confirm against the
                training pipeline.
        Raises:
            ValueError: if ``board`` has no legal moves.
        """
        legal_move_ids = [self.tokenizer.symbol_to_token[move.uci()] for move in board.legal_moves]
        if not legal_move_ids:
            # With every logit masked to -inf, argmax would silently return
            # token 0, which is not a move for this position.
            raise ValueError("board has no legal moves; cannot select a move")

        history = list(board.move_stack)
        if max_seq_len is None:
            max_seq_len = self._model_context_len()
        if max_seq_len is None:
            window_start = 0
        else:
            keep = max(max_seq_len - 1, 0)  # CLS occupies position 0
            window_start = max(len(history) - keep, 0)

        # Replay the move history from the root position, snapshotting planes
        # at every kept position. planes[0] = board before the first kept move
        # (the model has only seen [CLS]); planes[t] = state after t kept
        # moves. For an untruncated game from the standard start this is
        # identical to replaying on a fresh chess.Board(); board.root() also
        # handles boards that were set up from a custom position.
        replay_board = board.root()
        for move in history[:window_start]:
            replay_board.push(move)

        moves_uci = []
        plane_list = [board_to_planes(replay_board)]
        for move in history[window_start:]:
            replay_board.push(move)
            moves_uci.append(move.uci())
            plane_list.append(board_to_planes(replay_board))

        token_ids = [self.cls_id] + self.tokenizer.encode_moves(moves_uci)
        token_tensor = torch.tensor([token_ids], dtype=torch.long, device=self.device)
        planes = torch.stack(plane_list).unsqueeze(0).to(self.device)  # (1, T, 19, 8, 8)

        logits = self.model(token_tensor, planes)  # (1, T, vocab_size)
        last_logits = logits[0, -1]  # last position predicts the next move

        # Additive mask: 0 for legal moves, -inf elsewhere. Sized from the
        # logits so it always matches the model's output dimension and device.
        mask = torch.full_like(last_logits, float("-inf"))
        mask[legal_move_ids] = 0.0

        best_move_idx = (last_logits + mask).argmax().item()
        return self.tokenizer.token_to_symbol[best_move_idx]