# GptForChess / src/model.py
# robell05's picture
# serving model
# 6d75857
import math
import torch
import torch.nn as nn
import chess
from src.tokenizer import Tokenizer
# Special tokens looked up in the tokenizer vocabulary; [CLS] is prepended at
# position 0 of every model input sequence.
CLS_TOKEN = "[CLS]"
PAD_TOKEN = "[PAD]"

# Material values used by the heuristic evaluator. The king is worth 0 so it
# never contributes to the material balance.
PIECE_VALUES = dict(
    zip(
        (chess.PAWN, chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN, chess.KING),
        (1, 3, 3, 5, 9, 0),
    )
)

# Feature planes emitted by board_to_planes: 12 piece planes (6 types x 2
# colours) + side to move + 4 castling rights + en passant + halfmove clock.
BOARD_PLANES = 19
def board_to_planes(board: chess.Board) -> torch.Tensor:
    """Encode a position as a (BOARD_PLANES, 8, 8) = (19, 8, 8) float32 tensor.

    Indexing is ``[plane, rank, file]`` with rank/file in 0..7. Plane layout:
      0-5    white pawn, knight, bishop, rook, queen, king (1.0 on occupied squares)
      6-11   black pawn, knight, bishop, rook, queen, king
      12     all ones when White is to move, all zeros otherwise
      13-16  castling rights: white kingside, white queenside,
             black kingside, black queenside (whole plane set to 1.0)
      17     1.0 on ``board.ep_square`` when it is not None
      18     halfmove clock clipped at 100 and divided by 100 (range [0, 1])
    """
    out = torch.zeros(BOARD_PLANES, 8, 8, dtype=torch.float32)

    # (piece_type, colour) -> plane index; white occupies 0-5, black 6-11.
    piece_order = (chess.PAWN, chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN, chess.KING)
    plane_of = {}
    for side_idx, side in enumerate((chess.WHITE, chess.BLACK)):
        for kind_idx, kind in enumerate(piece_order):
            plane_of[(kind, side)] = 6 * side_idx + kind_idx

    for square, occupant in board.piece_map().items():
        plane = plane_of[(occupant.piece_type, occupant.color)]
        out[plane, chess.square_rank(square), chess.square_file(square)] = 1.0

    if board.turn == chess.WHITE:
        out[12].fill_(1.0)

    castling_checks = (
        (13, board.has_kingside_castling_rights, chess.WHITE),
        (14, board.has_queenside_castling_rights, chess.WHITE),
        (15, board.has_kingside_castling_rights, chess.BLACK),
        (16, board.has_queenside_castling_rights, chess.BLACK),
    )
    for plane, has_right, side in castling_checks:
        if has_right(side):
            out[plane].fill_(1.0)

    ep = board.ep_square
    if ep is not None:
        out[17, chess.square_rank(ep), chess.square_file(ep)] = 1.0

    out[18].fill_(min(board.halfmove_clock, 100) / 100.0)
    return out
def _group_norm(channels: int, groups: int = 32) -> nn.GroupNorm:
return nn.GroupNorm(num_groups=min(groups, channels), num_channels=channels)
class ResidualBlock(nn.Module):
    """Two-conv residual block: conv-norm-ReLU-conv-norm, add the input, ReLU.

    Both convolutions are 3x3 with padding 1 and no bias, so the channel count
    and spatial size of the input are preserved.
    """

    def __init__(self, channels: int):
        super().__init__()
        # Attribute names/order are kept as-is: they are the state_dict keys
        # and determine parameter initialisation order.
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.norm1 = _group_norm(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.norm2 = _group_norm(channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        skip = x
        out = self.conv1(x)
        out = self.norm1(out)
        out = torch.relu(out)
        out = self.conv2(out)
        out = self.norm2(out)
        # activation is applied after the skip connection is added
        return torch.relu(out + skip)
class BoardCNN(nn.Module):
    """Convolutional encoder turning board planes into 64 per-square tokens.

    Input:  ``(N, BOARD_PLANES, 8, 8)`` float planes.
    Output: ``(N, 64, d_model)``. The 8x8 grid is flattened row-major, so token
    ``i`` comes from grid cell ``(i // 8, i % 8)``. A learned per-square
    embedding is added to every token.
    """

    def __init__(self, d_model, channels=128, num_blocks=6):
        super().__init__()
        # Stem: 3x3 conv lifting BOARD_PLANES -> channels, then GroupNorm + ReLU.
        self.stem = nn.Sequential(
            nn.Conv2d(BOARD_PLANES, channels, 3, padding=1, bias=False),
            _group_norm(channels),
            nn.ReLU(inplace=True),
        )
        tower = [ResidualBlock(channels) for _ in range(num_blocks)]
        self.blocks = nn.Sequential(*tower)
        self.proj = nn.Linear(channels, d_model)
        self.square_pos = nn.Embedding(64, d_model)

    def forward(self, planes: torch.Tensor) -> torch.Tensor:
        feats = self.blocks(self.stem(planes))  # (N, C, 8, 8)
        batch = feats.size(0)
        # Move channels last, then flatten the 8x8 grid into 64 tokens.
        tokens = feats.permute(0, 2, 3, 1).reshape(batch, 64, -1)  # (N, 64, C)
        # square_pos.weight is (64, d_model) and broadcasts over the batch.
        return self.proj(tokens) + self.square_pos.weight  # (N, 64, d_model)
class CrossAttnBlock(nn.Module):
    """Pre-norm block: causal self-attention, gated per-position cross-attention, FFN.

    Each sub-layer is applied as ``x + dropout(sublayer(norm(x)))``. The
    cross-attention residual is additionally scaled by ``tanh(cross_gate)``.
    """

    def __init__(self, d_model, n_head, dim_ff, dropout):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, n_head, dropout=dropout, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, n_head, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(d_model, dim_ff),
            nn.GELU(),
            nn.Linear(dim_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)
        # Gate starts at 0 so tanh(gate) == 0: cross-attention contributes
        # nothing at initialisation and is blended in as the gate is learned.
        self.cross_gate = nn.Parameter(torch.zeros(1))

    def forward(self, moves, board, key_padding_mask, attn_mask):
        """Run one block.

        Args:
            moves: (B, T, d) move-stream activations.
            board: (B, T, 64, d) per-position board banks used as keys/values.
            key_padding_mask: (B, T) bool, True marks a padded move position
                (may be None).
            attn_mask: (T, T) causal mask for self-attention.

        Returns:
            (B, T, d) updated move-stream activations.
        """
        # 1) Causal self-attention over the move history.
        normed = self.norm1(moves)
        attended, _ = self.self_attn(
            normed,
            normed,
            normed,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        moves = moves + self.drop(attended)

        # 2) Cross-attention: fold T into the batch so each (b, t) query sees
        #    only its own 64-square bank (no mask needed).
        batch, steps, width = moves.shape
        query = self.norm2(moves).reshape(batch * steps, 1, width)
        bank = board.reshape(batch * steps, 64, width)
        looked_up, _ = self.cross_attn(query, bank, bank, need_weights=False)
        gate = self.cross_gate.tanh()
        moves = moves + self.drop(gate * looked_up.reshape(batch, steps, width))

        # 3) Position-wise feed-forward.
        return moves + self.drop(self.ff(self.norm3(moves)))
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a (batch, seq, d_model) input, then dropout.

    For position ``p`` and pair index ``i``:
        pe[p, 2i]   = sin(p * 10000 ** (-2i / d_model))
        pe[p, 2i+1] = cos(p * 10000 ** (-2i / d_model))
    The table is precomputed for ``max_len`` positions and registered as the
    buffer ``pe`` with shape ``(1, max_len, d_model)``.
    """

    def __init__(self, d_model: int, max_len: int = 512, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even column; trim
        # so the assignment does not fail with a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(x + pe[:, :seq_len])`` for ``x`` of shape (batch, seq_len, d_model).

        Raises:
            RuntimeError: if ``seq_len`` exceeds the precomputed ``max_len``.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise RuntimeError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {max_len}"
            )
        x = x + self.pe[:, :seq_len]
        return self.dropout(x)
class ChessRewardModel(nn.Module):
    """Transformer encoder that maps a tokenized move sequence to a scalar score.

    The input sequence starts with a [CLS] token. Its final hidden state is
    projected to one value and squashed with tanh into [-1, 1].
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 768,
        nhead: int = 12,
        num_layers: int = 8,
        dim_feedforward: int = 3072,
        max_seq_len: int = 128,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
        # A single template layer; TransformerEncoder stacks num_layers copies of it.
        template = nn.TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(template, num_layers=num_layers)
        self.reward_head = nn.Linear(d_model, 1)

    def forward(
        self,
        token_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Score a batch of sequences.

        Args:
            token_ids: (batch, seq_len) int tensor with [CLS] at position 0.
            attention_mask: optional (batch, seq_len) bool tensor, True at
                padded positions.

        Returns:
            (batch,) float tensor in [-1, 1].
        """
        embedded = self.pos_encoding(self.token_embedding(token_ids))
        hidden = self.encoder(embedded, src_key_padding_mask=attention_mask)
        # Read out the [CLS] slot only.
        score = self.reward_head(hidden[:, 0]).squeeze(-1)
        return score.tanh()
class ChessPolicyModel(nn.Module):
    """Causal next-move predictor that cross-attends to a live board at every step.

    Two streams flow through each block:
      * Move stream: token embeddings plus sinusoidal positional encoding,
        processed with causal self-attention over the move history.
      * Board stream: a (B, T, 64, d_model) bank of CNN-encoded board features,
        where bank ``t`` is the position after token_ids[1..t] have been played.

    In every block, the move query at position ``t`` cross-attends only to its
    own 64 board-square keys. Causality therefore comes from the data layout,
    and no extra mask is needed. Because the board bank at ``t`` never depends
    on a token the model must predict, multi-position LM-style training does
    not leak future moves.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 768,
        nhead: int = 12,
        num_layers: int = 8,
        dim_feedforward: int = 3072,
        max_seq_len: int = 128,
        dropout: float = 0.1,
        cnn_channels: int = 128,
        cnn_blocks: int = 6,
    ):
        super().__init__()
        # Construction order is preserved: it fixes state_dict layout and the
        # order in which parameters draw from the RNG.
        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.board_cnn = BoardCNN(d_model, cnn_channels, cnn_blocks)
        self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
        self.blocks = nn.ModuleList(
            [CrossAttnBlock(d_model, nhead, dim_feedforward, dropout) for _ in range(num_layers)]
        )
        self.norm_out = nn.LayerNorm(d_model)
        self.prob_head = nn.Linear(d_model, vocab_size)

    def forward(
        self,
        token_ids: torch.Tensor,
        board_planes: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute next-move logits at every position.

        Args:
            token_ids: (B, T) int, [CLS] at position 0.
            board_planes: (B, T, 19, 8, 8) float. ``board_planes[:, t]`` is the
                position after token_ids[1..t].
            attention_mask: optional (B, T) bool, True where padded.

        Returns:
            (B, T, vocab_size) raw logits.
        """
        batch, steps = token_ids.shape
        moves = self.pos_encoding(self.token_embedding(token_ids))  # (B, T, d)

        # Encode all B*T boards in one CNN pass rather than looping.
        flat_boards = board_planes.reshape(batch * steps, BOARD_PLANES, 8, 8)
        board_feats = self.board_cnn(flat_boards).reshape(batch, steps, 64, -1)  # (B, T, 64, d)

        # Boolean causal mask (True = future position, disallowed). It matches
        # the bool key_padding_mask, since mixing float and bool masks is deprecated.
        causal = torch.ones(steps, steps, dtype=torch.bool, device=token_ids.device).triu(diagonal=1)

        for block in self.blocks:
            moves = block(moves, board_feats, attention_mask, causal)
        return self.prob_head(self.norm_out(moves))  # (B, T, vocab)
class DummyRewardModel:
    """Material-count heuristic standing in for a learned reward (for MCTS testing).

    The score is ``tanh(material / 10)``, where ``material`` is White's total
    PIECE_VALUES minus Black's. It lies in [-1, 1], and positive values favour White.
    """

    def __call__(self, board: chess.Board) -> float:
        material = sum(
            value * (len(board.pieces(kind, chess.WHITE)) - len(board.pieces(kind, chess.BLACK)))
            for kind, value in PIECE_VALUES.items()
        )
        return math.tanh(material / 10.0)
class RewardModelInference:
    """Wrap ChessRewardModel + Tokenizer so a chess.Board can be scored directly (e.g. in minimax)."""

    def __init__(self, model: ChessRewardModel, tokenizer: Tokenizer, device: str = "cpu"):
        self.model = model
        self.tokenizer = tokenizer
        # Device for the input tensors. The model is not moved here; this
        # presumably relies on the caller having placed it on `device` (confirm).
        self.device = device
        self.cls_id = tokenizer.symbol_to_token[CLS_TOKEN]
        self.pad_id = tokenizer.symbol_to_token[PAD_TOKEN]  # not used by __call__ (batch of one)
        self.model.eval()

    @torch.no_grad()
    def __call__(self, board: chess.Board, max_seq_len: int = 128) -> float:
        """Score ``board`` from its move history.

        Args:
            board: position whose ``move_stack`` is tokenized.
            max_seq_len: maximum token length including [CLS]. Only the most
                recent ``max_seq_len - 1`` moves are fed to the model.

        Returns:
            The model's scalar output as a Python float.

        Raises:
            ValueError: if ``max_seq_len < 1`` (no room for the [CLS] token).
        """
        if max_seq_len < 1:
            raise ValueError(f"max_seq_len must be >= 1 to fit the [CLS] token, got {max_seq_len}")
        moves_uci = [move.uci() for move in board.move_stack]
        # Keep the most recent moves. An explicit start index is used because
        # slicing with -(max_seq_len - 1) would keep everything when the cap is 0.
        start = max(0, len(moves_uci) - (max_seq_len - 1))
        moves_uci = moves_uci[start:]
        token_ids = [self.cls_id] + self.tokenizer.encode_moves(moves_uci)
        token_tensor = torch.tensor([token_ids], dtype=torch.long, device=self.device)
        reward = self.model(token_tensor)
        return reward.item()
class PolicyModelInference:
    """Wrap ChessPolicyModel + Tokenizer to choose a legal move for a chess.Board."""

    def __init__(self, model: ChessPolicyModel, tokenizer: Tokenizer, device: str = "cpu"):
        self.model = model
        self.tokenizer = tokenizer
        # Device for the input tensors. The model is not moved here; this
        # presumably relies on the caller having placed it on `device` (confirm).
        self.device = device
        self.cls_id = tokenizer.symbol_to_token[CLS_TOKEN]
        self.pad_id = tokenizer.symbol_to_token[PAD_TOKEN]  # not used by __call__ (batch of one)
        self.model.eval()

    @torch.no_grad()
    def __call__(self, board: chess.Board, max_seq_len: int | None = None) -> str:
        """Return the UCI string of the highest-scoring legal move in ``board``.

        The model receives [CLS] followed by the fed moves, plus one plane
        snapshot per token. planes[0] is the position before the first fed
        move, and planes[t] is the position after t fed moves.

        Args:
            board: position to move from; its ``move_stack`` is the history.
            max_seq_len: maximum token length including [CLS]. Defaults to the
                length of the model's positional-encoding table, so histories
                that fit are fed unchanged. Longer histories keep only the most
                recent ``max_seq_len - 1`` moves, with planes[0] set to the
                position reached after the dropped prefix.

        Returns:
            UCI string of the chosen move, restricted to ``board.legal_moves``.

        Raises:
            ValueError: if ``board`` has no legal moves, or ``max_seq_len < 1``.
        """
        if max_seq_len is None:
            # Longest sequence the model's PositionalEncoding table supports.
            max_seq_len = self.model.pos_encoding.pe.size(1)
        if max_seq_len < 1:
            raise ValueError(f"max_seq_len must be >= 1 to fit the [CLS] token, got {max_seq_len}")

        legal_moves = list(board.legal_moves)
        if not legal_moves:
            # Otherwise every logit would be masked to -inf and argmax would
            # return an arbitrary non-move token.
            raise ValueError("board has no legal moves (game is over)")

        history = list(board.move_stack)
        cut = max(0, len(history) - (max_seq_len - 1))
        prefix, window = history[:cut], history[cut:]
        # NOTE(review): confirm the training pipeline windows long games the
        # same way (planes[0] = position after the dropped prefix).

        # Replay from the game's root position. For games from the standard
        # start this equals chess.Board(), matching the training pipeline
        # (ChessPolicyDataset._replay_planes with start_board=chess.Board()).
        # It also stays correct for boards set up from a FEN.
        replay_board = board.root()
        for move in prefix:
            replay_board.push(move)
        plane_list = [board_to_planes(replay_board)]
        for move in window:
            replay_board.push(move)
            plane_list.append(board_to_planes(replay_board))

        # Assumes encode_moves yields exactly one id per move, so token count
        # equals plane count (len(window) + 1) — TODO confirm against Tokenizer.
        token_ids = [self.cls_id] + self.tokenizer.encode_moves([m.uci() for m in window])
        token_tensor = torch.tensor([token_ids], dtype=torch.long, device=self.device)
        planes = torch.stack(plane_list).unsqueeze(0).to(self.device)  # (1, T, 19, 8, 8)

        logits = self.model(token_tensor, planes)  # (1, T, vocab_size)
        last_logits = logits[0, -1]  # last position predicts the next move

        legal_move_ids = [self.tokenizer.symbol_to_token[m.uci()] for m in legal_moves]
        mask = torch.full((self.tokenizer.language_size,), float('-inf'), device=self.device)
        mask[legal_move_ids] = 0.0
        best_move_idx = (last_logits + mask).argmax().item()
        return self.tokenizer.token_to_symbol[best_move_idx]