"""AlphaZero-style chess network, MCTS search and training agent (RL-Chess-Alpha V1)."""
import os
import random
import socket
import time
from collections import deque

import chess
import chess.engine
import numpy as np
import pygame
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

# NOTE(review): chess.engine, os, random, socket, time, pygame, deque and
# SummaryWriter are not used in this module; they are kept so that nothing
# importing them via this module breaks — confirm before removing.

# --- ALPHAZERO CONSTANTS ---
MCTS_SIMULATIONS = 100  # Low for training speed, high for evaluation
DIRICHLET_EPSILON = 0.25  # Weight of Dirichlet noise mixed into root priors
DIRICHLET_ALPHA = 0.3  # Concentration parameter of the root Dirichlet noise
LR = 0.0001
RESIDUAL_BLOCKS = 20  # FULL AlphaZero Scale
FILTERS = 256  # Expanded from 128 for 100MB capacity


class ResBlock(nn.Module):
    """Residual block: two 3x3 conv + BatchNorm layers with an identity skip."""

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = x
        x = torch.relu(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        x += residual
        return torch.relu(x)


class AlphaChessNet(nn.Module):
    """Policy/value network.

    Input: (batch, 18, 8, 8) planes as produced by ``AlphaMCTS.get_state``.
    Output: ``(policy_logits, value)`` where ``policy_logits`` is
    (batch, 4096) — one logit per from-square/to-square pair, see
    ``AlphaMCTS.move_to_index`` — and ``value`` is (batch, 1) in [-1, 1]
    (Tanh output).

    The submodule layout (and thus the state_dict keys) must not change,
    or previously saved checkpoints will no longer load.
    """

    def __init__(self):
        super().__init__()
        self.conv_entry = nn.Sequential(
            nn.Conv2d(18, FILTERS, kernel_size=3, padding=1),
            nn.BatchNorm2d(FILTERS),
            nn.ReLU(),
        )
        self.res_blocks = nn.ModuleList(
            [ResBlock(FILTERS) for _ in range(RESIDUAL_BLOCKS)]
        )

        # Policy Head
        self.policy_head = nn.Sequential(
            nn.Conv2d(FILTERS, 2, kernel_size=1),
            nn.BatchNorm2d(2),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(2 * 8 * 8, 4096),
        )

        # Value Head
        self.value_head = nn.Sequential(
            nn.Conv2d(FILTERS, 1, kernel_size=1),
            nn.BatchNorm2d(1),
            # Identity replaces a former ReLU: negative BatchNorm outputs now
            # pass through, while the Sequential indices (and so the
            # state_dict keys of the Linear layers) stay weight-compatible.
            nn.Identity(),
            nn.Flatten(),
            nn.Linear(8 * 8, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        x = self.conv_entry(x)
        for block in self.res_blocks:
            x = block(x)
        p = self.policy_head(x)
        v = self.value_head(x)
        return p, v


class MCTSNode:
    """One node of the search tree.

    Attributes:
        P: prior probability of the move that leads to this node.
        to_play: +1 if White is to move in this node's position, -1 if Black.
        N: visit count.
        W: sum of backed-up values, from White's perspective.
        Q: mean backed-up value W / N, from White's perspective.
        children: dict mapping ``chess.Move`` -> ``MCTSNode``.
    """

    def __init__(self, prior, to_play):
        self.P = prior
        self.to_play = to_play
        self.N = 0
        self.W = 0
        self.Q = 0
        self.children = {}

    def is_expanded(self):
        return len(self.children) > 0


class AlphaMCTS:
    """PUCT Monte-Carlo tree search guided by an ``AlphaChessNet``."""

    def __init__(self, model, device):
        self.model = model
        self.device = device

    def search(self, board, simulations=MCTS_SIMULATIONS, training=True):
        """Run ``simulations`` PUCT simulations from ``board``; return the root.

        Node values (W, Q) are stored from White's perspective (+1 = White
        wins). With ``training=True`` Dirichlet noise is mixed into the root
        priors and a smaller exploration constant is used.
        """
        root = MCTSNode(0, 1 if board.turn == chess.WHITE else -1)
        self.expand(root, board)
        c_puct = 1.0 if training else 2.5  # More thorough exploration for evaluation

        # Root Dirichlet noise (AlphaZero exploration); only relevant when at
        # least one simulation will consult the priors.
        if training and simulations > 0:
            self._add_dirichlet_noise(root)

        for _ in range(simulations):
            node = root
            search_board = board.copy()
            path = [node]

            while node.is_expanded():
                move, node = self.select_child(node, c_puct)
                search_board.push(move)
                path.append(node)

            # expand() returns the value from the perspective of the player
            # to move at the leaf; convert to White's perspective (+1/-1).
            leaf_value = self.expand(node, search_board)
            value_white = leaf_value * node.to_play

            for visited in path:
                visited.W += value_white
                visited.N += 1
                visited.Q = visited.W / visited.N
        return root

    def _add_dirichlet_noise(self, root):
        """Mix Dirichlet(alpha) noise into the priors of the root's children."""
        children = list(root.children.values())
        if not children:
            return
        noise = np.random.dirichlet([DIRICHLET_ALPHA] * len(children))
        for child, eta in zip(children, noise):
            child.P = (1 - DIRICHLET_EPSILON) * child.P + DIRICHLET_EPSILON * eta

    def select_child(self, node, c_puct):
        """Return ``(move, child)`` maximizing the PUCT score for ``node``'s player.

        Q is stored from White's perspective, so it is multiplied by
        ``node.to_play`` to express it from the perspective of the player
        choosing the move. Ties keep the first child in iteration order.
        """
        # Loop-invariant part of the exploration term.
        sqrt_parent_visits = np.sqrt(node.N)
        best_score = -float('inf')
        best_move = None
        best_child = None
        for move, child in node.children.items():
            score = (node.to_play * child.Q) + c_puct * child.P * (
                sqrt_parent_visits / (1 + child.N)
            )
            if score > best_score:
                best_score = score
                best_move = move
                best_child = child
        return best_move, best_child

    def _evaluate(self, board):
        """Run the network on one position.

        Returns ``(policy_logits, value)``: a 1-D tensor of 4096 logits and a
        float value for the side to move. The model is evaluated in eval mode
        so that BatchNorm uses its running statistics and those statistics
        are not overwritten by single-position batches (``no_grad`` alone does
        not stop running-stat updates). The caller's train/eval mode is
        restored afterwards.
        """
        input_tensor = (
            torch.from_numpy(self.get_state(board)).unsqueeze(0).to(self.device)
        )
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                p_logits, v = self.model(input_tensor)
        finally:
            self.model.train(was_training)
        return p_logits[0], v.item()

    def expand(self, node, board):
        """Expand ``node`` (whose position is ``board``) and return its value.

        The value is from the perspective of the player to move on ``board``:
        0 for a finished drawn game, -1 for a finished decisive game (the side
        to move has been checkmated), otherwise the network's value estimate.
        For non-terminal positions one child per legal move is created, with
        priors given by a softmax over the legal moves' policy logits.
        """
        outcome = board.outcome()
        if outcome is not None:
            if outcome.winner is None:
                return 0
            # If the game is over and not a draw, the person whose turn it is
            # (the one who got checkmated) has lost. Therefore it is -1.
            return -1

        policy_logits, value = self._evaluate(board)

        is_white = board.turn == chess.WHITE
        legal_moves = list(board.legal_moves)
        indices = [self.move_to_index(move, is_white) for move in legal_moves]
        # Softmax restricted to the legal moves: identical to a full 4096-way
        # softmax renormalized over the legal moves, but it cannot underflow
        # to an all-zero prior. Promotions to different pieces share one
        # index (promotion type is not encoded) and so share one logit.
        priors = torch.softmax(policy_logits[indices], dim=0).cpu().numpy()
        for move, prior in zip(legal_moves, priors):
            node.children[move] = MCTSNode(float(prior), -node.to_play)
        return value

    def get_state(self, board):
        """Encode ``board`` as an (18, 8, 8) float32 array from the mover's view.

        When Black is to move the board is rotated 180 degrees (rank and file
        both flipped), matching ``move_to_index``.

        Planes:
            0-5   pieces of the side to move (pawn..king)
            6-11  pieces of the opponent (pawn..king)
            12    side to move has kingside castling rights
            13    side to move has queenside castling rights
            14    opponent has kingside castling rights
            15    opponent has queenside castling rights
            16    halfmove clock, min(clock, 100) / 100
            17    never written here, always zero (presumably reserved — confirm
                  against training data producers)
        """
        state = np.zeros((18, 8, 8), dtype=np.float32)
        us = board.turn
        them = not board.turn
        is_white = us == chess.WHITE

        for square, piece in board.piece_map().items():
            rank, file = chess.square_rank(square), chess.square_file(square)
            if not is_white:
                # Full 180-degree rotation when Black sits at the bottom.
                rank, file = 7 - rank, 7 - file
            plane = piece.piece_type - 1
            if piece.color != us:
                plane += 6
            state[plane][rank][file] = 1

        # Castling rights, relative to the side to move.
        if board.has_kingside_castling_rights(us):
            state[12].fill(1)
        if board.has_queenside_castling_rights(us):
            state[13].fill(1)
        if board.has_kingside_castling_rights(them):
            state[14].fill(1)
        if board.has_queenside_castling_rights(them):
            state[15].fill(1)

        # Plane 16: Halfmove Clock (normalized)
        state[16].fill(min(board.halfmove_clock, 100) / 100.0)
        return state

    def move_to_index(self, move, is_white=True):
        """Map a move to its policy index ``from * 64 + to`` in the mover's frame.

        For Black (``is_white=False``) both squares are rotated 180 degrees,
        matching ``get_state``. The promotion piece is not encoded, so all
        promotions between the same two squares map to the same index.
        """
        fro_r, fro_f = chess.square_rank(move.from_square), chess.square_file(move.from_square)
        to_r, to_f = chess.square_rank(move.to_square), chess.square_file(move.to_square)
        if not is_white:
            fro_r, fro_f = 7 - fro_r, 7 - fro_f
            to_r, to_f = 7 - to_r, 7 - to_f
        fro = fro_r * 8 + fro_f
        to = to_r * 8 + to_f
        return fro * 64 + to

    def index_to_move(self, index, is_white=True, board=None):
        """Inverse of ``move_to_index``.

        The policy index carries no promotion piece. If ``board`` is given and
        the decoded move is a pawn reaching the first or last rank, it is
        returned as a queen promotion so that it is a legal move; without
        ``board`` the move is returned without a promotion (as before).
        """
        fro_idx, to_idx = divmod(index, 64)
        fro_r, fro_f = divmod(fro_idx, 8)
        to_r, to_f = divmod(to_idx, 8)
        if not is_white:
            fro_r, fro_f = 7 - fro_r, 7 - fro_f
            to_r, to_f = 7 - to_r, 7 - to_f
        from_square = fro_r * 8 + fro_f
        to_square = to_r * 8 + to_f

        promotion = None
        if (
            board is not None
            and board.piece_type_at(from_square) == chess.PAWN
            and chess.square_rank(to_square) in (0, 7)
        ):
            promotion = chess.QUEEN
        return chess.Move(from_square, to_square, promotion=promotion)


class AlphaAgent:
    """Bundles the network, its Adam optimizer and an MCTS that uses it."""

    def __init__(self, device='cuda'):
        self.device = device
        self.model = AlphaChessNet().to(device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=LR)
        self.mcts = AlphaMCTS(self.model, device)

    def train_step(self, states, mcts_probs, winners):
        """Perform one optimizer step and return the scalar loss.

        Args:
            states: batch of (18, 8, 8) encoded positions.
            mcts_probs: batch of 4096-long target move distributions.
            winners: batch of scalar value targets, compared against the
                network's Tanh value output.
        """
        # Ensure BatchNorm uses batch statistics and updates its running
        # stats during training, even if the model was left in eval mode.
        self.model.train()

        states = torch.FloatTensor(np.array(states)).to(self.device)
        mcts_probs = torch.FloatTensor(np.array(mcts_probs)).to(self.device)
        winners = torch.FloatTensor(np.array(winners)).unsqueeze(1).to(self.device)

        self.optimizer.zero_grad()
        p_logits, v = self.model(states)

        # AlphaZero Hybrid Loss: Policy (CrossEntropy) + Value (MSE)
        p_loss = -torch.mean(
            torch.sum(mcts_probs * torch.log_softmax(p_logits, dim=1), dim=1)
        )
        v_loss = torch.mean((winners - v) ** 2)

        # Standard AlphaZero loss ratio (1:1). The 50x multiplier was
        # killing the policy head and causing catastrophic move collapse.
        loss = p_loss + v_loss
        loss.backward()
        self.optimizer.step()
        return loss.item()


if __name__ == "__main__":
    print("RL-Chess-Alpha V1 Initialized. Use 'Distill-Alpha.py' to prime it with Stockfish knowledge.")