| import chess |
| import chess.engine |
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| import numpy as np |
| import random |
| import os |
| import pygame |
| import time |
| import socket |
| from collections import deque |
| from torch.utils.tensorboard import SummaryWriter |
|
|
| |
# --- Monte-Carlo tree search -------------------------------------------------
# Default number of simulations AlphaMCTS.search runs per call.
MCTS_SIMULATIONS = 100
# Concentration parameter of the Dirichlet noise drawn for the root priors.
DIRICHLET_ALPHA = 0.3
# Weight given to that noise when it is blended into the root priors.
DIRICHLET_EPSILON = 0.25

# --- Optimisation ------------------------------------------------------------
# Learning rate handed to the Adam optimiser created by AlphaAgent.
LR = 1e-4

# --- Network architecture ----------------------------------------------------
# Channel width of the convolutional trunk of AlphaChessNet.
FILTERS = 256
# Number of ResBlock modules stacked in the trunk.
RESIDUAL_BLOCKS = 20
|
|
class ResBlock(nn.Module):
    """Residual unit: two 3x3 conv + batch-norm stages around a skip connection.

    Channel count and spatial size are preserved, so blocks can be chained.
    """

    def __init__(self, channels):
        """Create both conv/batch-norm stages for ``channels`` feature maps."""
        super().__init__()
        # kernel_size=3 with padding=1 keeps height and width unchanged.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        """Return ``relu(x + branch(x))`` with the same shape as ``x``."""
        skip = x
        hidden = self.bn1(self.conv1(x)).relu()
        branch = self.bn2(self.conv2(hidden))
        # The in-place add targets the freshly computed branch tensor, so the
        # caller's input is never modified.
        branch.add_(skip)
        return branch.relu()
|
|
class AlphaChessNet(nn.Module):
    """Residual policy/value network over 8x8 board planes.

    ``forward`` takes a ``(batch, 18, 8, 8)`` float tensor and returns
    ``(policy_logits, value)``: ``(batch, 4096)`` un-normalised move logits and
    a ``(batch, 1)`` tanh-squashed scalar.
    """

    def __init__(self):
        """Build the stem, the residual tower and both output heads."""
        super().__init__()

        # Stem: lift the 18 input planes to the trunk width.
        stem = [
            nn.Conv2d(18, FILTERS, kernel_size=3, padding=1),
            nn.BatchNorm2d(FILTERS),
            nn.ReLU(),
        ]
        self.conv_entry = nn.Sequential(*stem)

        # Tower of identical residual blocks.
        self.res_blocks = nn.ModuleList(
            [ResBlock(FILTERS) for _ in range(RESIDUAL_BLOCKS)]
        )

        # Policy head: 1x1 conv down to 2 planes, then one logit per
        # (from_square, to_square) pair -> 64 * 64 = 4096 outputs.
        policy_layers = [
            nn.Conv2d(FILTERS, 2, kernel_size=1),
            nn.BatchNorm2d(2),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(2 * 8 * 8, 4096),
        ]
        self.policy_head = nn.Sequential(*policy_layers)

        # Value head: 1x1 conv down to a single plane, small MLP, tanh.
        value_layers = [
            nn.Conv2d(FILTERS, 1, kernel_size=1),
            nn.BatchNorm2d(1),
            # NOTE(review): no activation is applied at this position; the
            # Identity keeps the later layers at indices 4 and 6 - confirm
            # whether a ReLU was intended here.
            nn.Identity(),
            nn.Flatten(),
            nn.Linear(8 * 8, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Tanh(),
        ]
        self.value_head = nn.Sequential(*value_layers)

    def forward(self, x):
        """Return ``(policy_logits, value)`` for a batch of board planes."""
        features = self.conv_entry(x)
        for res_block in self.res_blocks:
            features = res_block(features)
        policy_logits = self.policy_head(features)
        value = self.value_head(features)
        return policy_logits, value
|
|
class MCTSNode:
    """One node of the search tree together with its visit statistics.

    Attributes (all updated in place by ``AlphaMCTS``):
        P: prior probability assigned to the move leading to this node.
        to_play: +1 or -1, the side to move at this node.
        N: number of simulations that passed through the node.
        W: sum of the values backed up through the node.
        Q: running mean ``W / N`` (0 until the first visit).
        children: mapping from move to child ``MCTSNode``.
    """

    def __init__(self, prior, to_play):
        """Create an unvisited, unexpanded node."""
        self.children = {}
        self.N = 0
        self.W = 0
        self.Q = 0
        self.to_play = to_play
        self.P = prior

    def is_expanded(self):
        """Return True once at least one child has been attached."""
        return bool(self.children)
|
|
class AlphaMCTS:
    """PUCT tree search guided by a policy/value network.

    Conventions used throughout the class:

    * ``MCTSNode.to_play`` is +1 when White is to move and -1 for Black.
    * ``MCTSNode.W`` and ``Q`` are accumulated from White's point of view;
      ``select_child`` flips the sign for the side that is choosing.
    * Board planes and policy indices are expressed from the side to move's
      perspective: for Black the board is rotated by 180 degrees.
    * A policy index encodes only ``(from_square, to_square)``, so every
      promotion choice of the same pawn move shares one index.
    """

    def __init__(self, model, device):
        """Store the evaluation network and the device its inputs are sent to."""
        self.model = model
        self.device = device

    def search(self, board, simulations=MCTS_SIMULATIONS, training=True):
        """Run PUCT simulations from ``board`` and return the root node.

        Args:
            board: ``chess.Board`` to search from. It is copied for every
                simulation and never mutated.
            simulations: Number of select / expand / back-up passes.
            training: When true, Dirichlet noise is blended into the root
                priors and the exploration constant is 1.0 instead of 2.5.

        Returns:
            The root ``MCTSNode``; its children hold the visit statistics.
        """
        # Positions are evaluated one at a time, so batch-norm has to use its
        # running statistics. In train mode it would normalise each position
        # by itself and overwrite those statistics, even under no_grad. The
        # caller's mode is restored on exit.
        was_training = getattr(self.model, "training", False)
        if was_training:
            self.model.eval()
        try:
            root = MCTSNode(0, 1 if board.turn == chess.WHITE else -1)
            self.expand(root, board)

            # Same trigger as before: noise is drawn once, ahead of the first
            # simulation, and only in training mode.
            if training and simulations > 0:
                self._add_root_noise(root)

            c_puct = 1.0 if training else 2.5
            for _ in range(simulations):
                self._simulate(root, board, c_puct)
            return root
        finally:
            if was_training:
                self.model.train()

    def _add_root_noise(self, root):
        """Blend Dirichlet noise into the priors of ``root``'s children."""
        children = list(root.children.values())
        if not children:
            return
        noise = np.random.dirichlet([DIRICHLET_ALPHA] * len(children))
        for child, eta in zip(children, noise):
            child.P = (1 - DIRICHLET_EPSILON) * child.P + DIRICHLET_EPSILON * eta

    def _simulate(self, root, board, c_puct):
        """Run one select / expand / back-up pass starting at ``root``."""
        node = root
        search_board = board.copy()
        path = [node]

        # Selection: follow the best PUCT child until an unexpanded node.
        while node.is_expanded():
            move, node = self.select_child(node, c_puct)
            search_board.push(move)
            path.append(node)

        # Expansion / evaluation: the value is from the leaf mover's view.
        leaf_value = self.expand(node, search_board)

        # Back-up: convert to White's point of view once, then add it to
        # every node on the path.
        value_white = leaf_value * node.to_play
        for visited in path:
            visited.W += value_white
            visited.N += 1
            visited.Q = visited.W / visited.N

    def select_child(self, node, c_puct):
        """Return the ``(move, child)`` pair of ``node`` with the best PUCT score.

        The score is ``to_play * Q + c_puct * P * sqrt(N_parent) / (1 + N_child)``.
        Ties keep the first child in insertion order. ``(None, None)`` is
        returned when ``node`` has no children.
        """
        sqrt_parent_visits = np.sqrt(node.N)
        best_score = -float('inf')
        best_move = None
        best_child = None

        for move, child in node.children.items():
            # child.Q is stored from White's view; to_play turns it into the
            # view of the side choosing at ``node``.
            exploit = node.to_play * child.Q
            explore = c_puct * child.P * (sqrt_parent_visits / (1 + child.N))
            score = exploit + explore
            if score > best_score:
                best_score = score
                best_move = move
                best_child = child
        return best_move, best_child

    def expand(self, node, board):
        """Attach children to ``node`` and return the value of ``board``.

        For a finished game no children are added and the exact result is
        returned. Otherwise the network is queried, one child per legal move
        is created with its normalised prior, and the network value is
        returned.

        Returns:
            A number in [-1, 1] from the point of view of the side to move
            on ``board``.
        """
        if board.is_game_over():
            outcome = board.outcome()
            if outcome.winner is None:
                return 0.0
            # In standard chess the side to move is always the mated side
            # (-1); comparing against board.turn keeps this right for
            # variants where the side to move can be the winner.
            return 1.0 if outcome.winner == board.turn else -1.0

        planes = self.get_state(board)
        batch = torch.as_tensor(planes, dtype=torch.float32).unsqueeze(0).to(self.device)

        with torch.no_grad():
            policy_logits, value = self.model(batch)

        policy = torch.softmax(policy_logits, dim=1).cpu().numpy()[0]
        white_to_move = board.turn == chess.WHITE

        for move in board.legal_moves:
            prior = float(policy[self.move_to_index(move, white_to_move)])
            node.children[move] = MCTSNode(prior, -node.to_play)

        self._normalise_priors(node)
        return value.item()

    @staticmethod
    def _normalise_priors(node):
        """Rescale the priors of ``node``'s children so they sum to one.

        When the probability mass on the legal moves is zero or not finite,
        the priors fall back to a uniform distribution. Otherwise every
        exploration term would be 0, or NaN, which makes ``select_child``
        return no move at all.
        """
        children = list(node.children.values())
        if not children:
            return
        total = sum(child.P for child in children)
        if np.isfinite(total) and total > 0:
            for child in children:
                child.P /= total
        else:
            uniform = 1.0 / len(children)
            for child in children:
                child.P = uniform

    def get_state(self, board):
        """Encode ``board`` as an ``(18, 8, 8)`` float32 array.

        Plane layout, always from the side to move's perspective:
            0-5   pieces of the side to move, indexed by ``piece_type - 1``
            6-11  pieces of the opponent, same order
            12/13 king-side / queen-side castling right of the side to move
            14/15 king-side / queen-side castling right of the opponent
            16    half-move clock, clipped to 100 and scaled to [0, 1]
            17    unused (always zero)
        """
        planes = np.zeros((18, 8, 8), dtype=np.float32)
        mover = board.turn
        rotate = mover != chess.WHITE

        for square, piece in board.piece_map().items():
            rank, file = chess.square_rank(square), chess.square_file(square)
            if rotate:
                # Black sees the board rotated by 180 degrees.
                rank, file = 7 - rank, 7 - file
            plane = piece.piece_type - 1
            if piece.color != mover:
                plane += 6
            planes[plane, rank, file] = 1

        castling_flags = (
            board.has_kingside_castling_rights(mover),
            board.has_queenside_castling_rights(mover),
            board.has_kingside_castling_rights(not mover),
            board.has_queenside_castling_rights(not mover),
        )
        for offset, has_right in enumerate(castling_flags):
            if has_right:
                planes[12 + offset].fill(1)

        planes[16].fill(min(board.halfmove_clock, 100) / 100.0)
        return planes

    def move_to_index(self, move, is_white=True):
        """Map ``move`` to its policy index ``from * 64 + to`` (0..4095).

        Squares are taken from the mover's perspective: when ``is_white`` is
        false both squares are rotated by 180 degrees. The promotion piece is
        not encoded.
        """
        fro, to = move.from_square, move.to_square
        if not is_white:
            # Flipping rank and file together is the same as 63 - square.
            fro, to = 63 - fro, 63 - to
        return fro * 64 + to

    def index_to_move(self, index, is_white=True, board=None):
        """Inverse of ``move_to_index``: turn a policy index into a move.

        Args:
            index: Policy index in ``0..4095``.
            is_white: Perspective the index was encoded from.
            board: Optional ``chess.Board``. The index carries no promotion
                piece, so without a board a pawn move to the last rank is
                returned without promotion and is not legal. With a board,
                such a move is returned as a queen promotion.

        Returns:
            A ``chess.Move`` in real board coordinates.
        """
        fro, to = divmod(int(index), 64)
        if not is_white:
            fro, to = 63 - fro, 63 - to

        promotion = None
        if (
            board is not None
            and chess.square_rank(to) in (0, 7)
            and board.piece_type_at(fro) == chess.PAWN
        ):
            promotion = chess.QUEEN
        return chess.Move(fro, to, promotion=promotion)
|
|
|
|
|
|
class AlphaAgent:
    """Bundle the network, its Adam optimiser and a tree search that uses it."""

    def __init__(self, device='cuda'):
        """Create the model on ``device`` with its optimiser and MCTS.

        Args:
            device: Torch device string or object. If a CUDA device is
                requested on a machine without CUDA, a warning is issued and
                the agent falls back to the CPU instead of crashing.
        """
        if str(device).startswith('cuda') and not torch.cuda.is_available():
            import warnings

            warnings.warn(
                "CUDA requested but not available; falling back to CPU.",
                RuntimeWarning,
            )
            device = 'cpu'

        self.device = device
        self.model = AlphaChessNet().to(device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=LR)
        self.mcts = AlphaMCTS(self.model, device)

    def train_step(self, states, mcts_probs, winners):
        """Run one optimiser step on a batch and return the scalar loss.

        Args:
            states: Batch of encoded positions accepted by the model.
            mcts_probs: Target move distributions, one row per position.
            winners: Target game results, one number per position.

        Returns:
            ``policy_loss + value_loss`` as a Python float. The policy loss is
            the cross-entropy against ``mcts_probs``; the value loss is the
            mean squared error against ``winners``.
        """
        # Make sure batch-norm is in training mode even if the model was
        # switched to eval for self-play or evaluation beforehand.
        self.model.train()

        state_batch = torch.as_tensor(
            np.asarray(states), dtype=torch.float32, device=self.device
        )
        target_policy = torch.as_tensor(
            np.asarray(mcts_probs), dtype=torch.float32, device=self.device
        )
        target_value = torch.as_tensor(
            np.asarray(winners), dtype=torch.float32, device=self.device
        ).reshape(-1)

        self.optimizer.zero_grad()
        policy_logits, value = self.model(state_batch)

        log_policy = torch.log_softmax(policy_logits, dim=1)
        policy_loss = -(target_policy * log_policy).sum(dim=1).mean()
        # Both sides are flattened so a (B,) value output cannot broadcast
        # against (B, 1) targets into a (B, B) matrix.
        value_loss = torch.mean((target_value - value.reshape(-1)) ** 2)

        loss = policy_loss + value_loss
        loss.backward()
        self.optimizer.step()
        return loss.item()
|
|
if __name__ == "__main__":
    # Running this module directly only prints a banner; nothing is trained.
    banner = "RL-Chess-Alpha V1 Initialized. Use 'Distill-Alpha.py' to prime it with Stockfish knowledge."
    print(banner)
|
|