# RL-Chess / RL_Chess_Alpha.py
# Gregniuki's picture
# Upload RL_Chess_Alpha.py
# c8cc8b1 verified
import chess
import chess.engine
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import os
import pygame
import time
import socket
from collections import deque
from torch.utils.tensorboard import SummaryWriter
# --- ALPHAZERO CONSTANTS ---
# Default number of simulations for AlphaMCTS.search().
MCTS_SIMULATIONS = 100 # Low for training speed, high for evaluation
# Root exploration noise used by AlphaMCTS.search() in training mode:
# each root prior becomes (1 - EPSILON) * P + EPSILON * Dirichlet(ALPHA) sample.
DIRICHLET_EPSILON = 0.25
DIRICHLET_ALPHA = 0.3
# Learning rate handed to the Adam optimizer in AlphaAgent.
LR = 0.0001
# Number of ResBlock layers stacked in the AlphaChessNet trunk.
RESIDUAL_BLOCKS = 20 # FULL AlphaZero Scale
# Channel width of the entry convolution and of every ResBlock.
FILTERS = 256 # Expanded from 128 for 100MB capacity
class ResBlock(nn.Module):
    """Residual block: conv-BN-ReLU, conv-BN, add the input back, ReLU.

    Both convolutions are 3x3 with padding 1 and keep the channel count, so
    the output has exactly the same shape as the input.
    """

    def __init__(self, channels):
        super().__init__()
        # Attribute names and creation order are part of the checkpoint
        # format (state_dict keys) and of the weight-init RNG sequence.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``."""
        skip = x
        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = torch.relu(hidden)
        hidden = self.bn2(self.conv2(hidden))
        # Skip connection: the untouched input is added before the final ReLU.
        return torch.relu(hidden + skip)
class AlphaChessNet(nn.Module):
    """Policy/value network: conv stem, residual tower, two heads.

    Input is a batch of 18-plane 8x8 boards.  ``forward`` returns a pair
    ``(policy_logits, value)`` where ``policy_logits`` has 4096 entries per
    sample (unnormalised) and ``value`` is a single Tanh-squashed scalar per
    sample.
    """

    def __init__(self):
        super().__init__()

        # Stem: lift the 18 input planes to FILTERS channels.
        stem_layers = [
            nn.Conv2d(18, FILTERS, kernel_size=3, padding=1),
            nn.BatchNorm2d(FILTERS),
            nn.ReLU(),
        ]
        self.conv_entry = nn.Sequential(*stem_layers)

        # Residual tower.
        self.res_blocks = nn.ModuleList()
        for _ in range(RESIDUAL_BLOCKS):
            self.res_blocks.append(ResBlock(FILTERS))

        # Policy head: 1x1 conv down to 2 planes, then a dense layer to
        # 4096 move logits.
        policy_layers = [
            nn.Conv2d(FILTERS, 2, kernel_size=1),
            nn.BatchNorm2d(2),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(2 * 8 * 8, 4096),
        ]
        self.policy_head = nn.Sequential(*policy_layers)

        # Value head: 1x1 conv down to 1 plane, MLP, Tanh.
        value_layers = [
            nn.Conv2d(FILTERS, 1, kernel_size=1),
            nn.BatchNorm2d(1),
            # Identity stands where a ReLU used to be: it protects the
            # signal while keeping the Sequential indices (and therefore
            # the saved weight names) compatible.
            nn.Identity(),
            nn.Flatten(),
            nn.Linear(8 * 8, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Tanh(),
        ]
        self.value_head = nn.Sequential(*value_layers)

    def forward(self, x):
        """Return ``(policy_logits, value)`` for the batch ``x``."""
        features = self.conv_entry(x)
        for residual_block in self.res_blocks:
            features = residual_block(features)
        policy_logits = self.policy_head(features)
        value = self.value_head(features)
        return policy_logits, value
class MCTSNode:
    """One node of the search tree.

    Attributes:
        P: prior probability assigned to the move leading to this node.
        to_play: side to move at this node (AlphaMCTS uses +1 for White
            and -1 for Black).
        N: visit count.
        W: sum of the values backed up through this node.
        Q: mean backed-up value (AlphaMCTS maintains it as W / N).
        children: mapping from move to child MCTSNode; empty until expanded.
    """

    def __init__(self, prior, to_play):
        self.P, self.to_play = prior, to_play
        # Statistics start at zero and are filled in during backup.
        self.N = self.W = self.Q = 0
        self.children = {}

    def is_expanded(self):
        """Return True once at least one child has been attached."""
        return bool(self.children)
class AlphaMCTS:
    """PUCT Monte-Carlo tree search guided by a policy/value network.

    Node statistics (``W`` and ``Q``) are stored in absolute terms
    (+1 = good for White, -1 = good for Black).  Values produced by the
    network or by terminal positions are from the side-to-move's point of
    view and are converted with ``node.to_play`` (+1 White, -1 Black).
    """

    def __init__(self, model, device):
        self.model = model
        self.device = device

    def search(self, board, simulations=MCTS_SIMULATIONS, training=True):
        """Run ``simulations`` playouts from ``board`` and return the root node.

        Args:
            board: position to search from; it is copied, never mutated.
            simulations: number of select/expand/backup iterations.
            training: when True, Dirichlet noise is mixed into the root
                priors and a lower exploration constant is used.

        Returns:
            The root MCTSNode; the visit counts of its children hold the
            search result.
        """
        root = MCTSNode(0, 1 if board.turn == chess.WHITE else -1)
        self.expand(root, board)
        c_puct = 1.0 if training else 2.5  # More thorough exploration for evaluation

        # Root Dirichlet noise (AlphaZero exploration).  Applied once, right
        # before the first simulation; nothing is perturbed when no
        # simulation is going to run.
        if training and simulations > 0:
            self._add_root_noise(root)

        for _ in range(simulations):
            node = root
            search_board = board.copy()
            path = [node]

            # Selection: walk down until a node without children is reached.
            while node.is_expanded():
                move, node = self.select_child(node, c_puct)
                search_board.push(move)
                path.append(node)

            # Expansion returns the value from the perspective of the player
            # to move at the leaf ...
            leaf_value = self.expand(node, search_board)
            # ... which is translated to an absolute value (White = +1).
            value_white = leaf_value * node.to_play

            # Backup along the visited path.
            for visited in path:
                visited.W += value_white
                visited.N += 1
                visited.Q = visited.W / visited.N
        return root

    def _add_root_noise(self, root):
        """Blend a Dirichlet sample into the priors of ``root``'s children."""
        children = list(root.children.values())
        if not children:
            return
        noise = np.random.dirichlet([DIRICHLET_ALPHA] * len(children))
        for child, eta in zip(children, noise):
            child.P = (1 - DIRICHLET_EPSILON) * child.P + DIRICHLET_EPSILON * eta

    def select_child(self, node, c_puct):
        """Return the ``(move, child)`` pair with the highest PUCT score.

        The exploitation term is the child's absolute ``Q`` seen from the
        perspective of the player to move at ``node``.  Ties go to the
        first child in insertion order.  Returns ``(None, None)`` when the
        node has no children.
        """
        # The root is expanded before it has ever been visited (N == 0).
        # Without the clamp the prior term would be zero for every child on
        # the first simulation and the first legal move would always be
        # picked regardless of the policy.
        sqrt_parent = np.sqrt(max(node.N, 1))
        best_u = -float('inf')
        best_move = None
        best_child = None
        for move, child in node.children.items():
            u = (node.to_play * child.Q) + c_puct * child.P * (sqrt_parent / (1 + child.N))
            if u > best_u:
                best_u = u
                best_move = move
                best_child = child
        return best_move, best_child

    def _predict(self, batch):
        """Evaluate the network on ``batch`` in eval mode without gradients.

        ``torch.no_grad()`` alone does not stop BatchNorm from normalising
        with the statistics of the current (single-position) batch and from
        updating its running estimates, so the model is switched to eval
        mode for the call and put back into its previous mode afterwards.
        """
        was_training = self.model.training
        if was_training:
            self.model.eval()
        try:
            with torch.no_grad():
                return self.model(batch)
        finally:
            if was_training:
                self.model.train()

    def expand(self, node, board):
        """Attach one child per legal move to ``node`` and evaluate ``board``.

        Returns:
            The position value from the perspective of the player to move:
            0 for a drawn terminal position, -1 for a decided one, otherwise
            the network's value output.  Terminal positions get no children.
        """
        outcome = board.outcome()
        if outcome is not None:
            if outcome.winner is None:
                return 0
            # If the game is over and not a draw, the player whose turn it
            # is (the one who got checkmated) has lost.
            return -1

        input_state = self.get_state(board)  # Already (18, 8, 8), float32
        input_tensor = torch.from_numpy(input_state).unsqueeze(0).to(self.device)
        p_logits, v = self._predict(input_tensor)

        legal_moves = list(board.legal_moves)
        if legal_moves:
            mover_is_white = board.turn == chess.WHITE
            # Perspective-aware policy indices of the legal moves.
            indices = [self.move_to_index(move, mover_is_white) for move in legal_moves]
            # Softmax over the legal logits only.  This equals softmax over
            # all 4096 entries followed by renormalisation, but cannot
            # underflow to an all-zero prior vector.
            priors = torch.softmax(p_logits[0][indices], dim=0).cpu().numpy()
            for move, prior in zip(legal_moves, priors):
                node.children[move] = MCTSNode(float(prior), -node.to_play)
        return v.item()

    def get_state(self, board):
        """Encode ``board`` as an ``(18, 8, 8)`` float32 array.

        The board is always shown from the side to move's seat: when Black
        is to move, ranks and files are both reversed (180-degree rotation).

        Planes:
            0-5   pieces of the side to move, indexed by ``piece_type - 1``
            6-11  pieces of the opponent, same ordering
            12    side to move may castle kingside
            13    side to move may castle queenside
            14    opponent may castle kingside
            15    opponent may castle queenside
            16    halfmove clock, capped at 100 and scaled to [0, 1]
            17    left at zero by this encoder
        """
        state = np.zeros((18, 8, 8), dtype=np.float32)
        mover = board.turn
        is_white = mover == chess.WHITE

        for square in chess.SQUARES:
            piece = board.piece_at(square)
            if piece is None:
                continue
            rank = chess.square_rank(square)
            file = chess.square_file(square)
            if not is_white:
                # Full 180-degree rotation when Black sits at the bottom.
                rank, file = 7 - rank, 7 - file
            plane = piece.piece_type - 1
            if piece.color != mover:
                plane += 6
            state[plane, rank, file] = 1

        # Planes 12-15: castling rights, relative to the side to move.
        if board.has_kingside_castling_rights(mover):
            state[12].fill(1)
        if board.has_queenside_castling_rights(mover):
            state[13].fill(1)
        if board.has_kingside_castling_rights(not mover):
            state[14].fill(1)
        if board.has_queenside_castling_rights(not mover):
            state[15].fill(1)

        # Plane 16: halfmove clock (normalized).
        state[16].fill(min(board.halfmove_clock, 100) / 100.0)
        return state

    def move_to_index(self, move, is_white=True):
        """Map ``move`` to its policy index ``from * 64 + to`` (0..4095).

        Squares are numbered ``rank * 8 + file`` in the mover's perspective:
        when ``is_white`` is False both squares are rotated by 180 degrees,
        matching ``get_state``.  The promotion piece is not encoded, so all
        promotions between the same two squares share one index.
        """
        from_rank = chess.square_rank(move.from_square)
        from_file = chess.square_file(move.from_square)
        to_rank = chess.square_rank(move.to_square)
        to_file = chess.square_file(move.to_square)
        if not is_white:
            from_rank, from_file = 7 - from_rank, 7 - from_file
            to_rank, to_file = 7 - to_rank, 7 - to_file
        from_square = from_rank * 8 + from_file
        to_square = to_rank * 8 + to_file
        return from_square * 64 + to_square

    def index_to_move(self, index, is_white=True):
        """Inverse of ``move_to_index``: turn a policy index into a move.

        The returned ``chess.Move`` carries only the from/to squares; no
        promotion piece is set, so callers have to add one themselves for
        pawn moves onto the last rank.
        """
        from_index, to_index = divmod(index, 64)
        from_rank, from_file = divmod(from_index, 8)
        to_rank, to_file = divmod(to_index, 8)
        if not is_white:
            from_rank, from_file = 7 - from_rank, 7 - from_file
            to_rank, to_file = 7 - to_rank, 7 - to_file
        return chess.Move(from_rank * 8 + from_file, to_rank * 8 + to_file)
class AlphaAgent:
    """Bundles the network, its Adam optimizer and an MCTS instance."""

    def __init__(self, device='cuda'):
        self.device = device
        self.model = AlphaChessNet().to(device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=LR)
        self.mcts = AlphaMCTS(self.model, device)

    def _to_float_tensor(self, data):
        """Convert an array-like to a float32 tensor on the agent's device."""
        return torch.from_numpy(np.asarray(data, dtype=np.float32)).to(self.device)

    def train_step(self, states, mcts_probs, winners):
        """Run one optimisation step and return the scalar loss.

        Args:
            states: batch of encoded positions, (Batch, 18, 8, 8).
            mcts_probs: target move distributions; must have the same shape
                as the policy logits.
            winners: one value target per sample, given either as ``(Batch,)``
                or ``(Batch, 1)``.
                NOTE(review): presumably expressed from the side to move's
                perspective, like the value head as used by AlphaMCTS --
                confirm against the data pipeline.

        Returns:
            Policy cross-entropy plus value MSE, as a Python float.
        """
        states = self._to_float_tensor(states)
        mcts_probs = self._to_float_tensor(mcts_probs)
        # reshape(-1, 1) yields (Batch, 1) for both accepted input shapes, so
        # the subtraction below can never broadcast to (Batch, Batch, 1).
        winners = self._to_float_tensor(winners).reshape(-1, 1)

        # BatchNorm must use batch statistics while training, even if the
        # caller left the model in eval mode for play; the caller's mode is
        # restored once the step is done.
        was_training = self.model.training
        self.model.train()
        try:
            self.optimizer.zero_grad()
            p_logits, v = self.model(states)

            # AlphaZero Hybrid Loss: Policy (CrossEntropy) + Value (MSE)
            log_policy = torch.log_softmax(p_logits, dim=1)
            p_loss = -torch.mean(torch.sum(mcts_probs * log_policy, dim=1))
            v_loss = torch.mean((winners - v) ** 2)

            # Standard AlphaZero loss ratio (1:1). The 50x multiplier was
            # killing the policy head and causing catastrophic move collapse.
            loss = p_loss + v_loss
            loss.backward()
            self.optimizer.step()
        finally:
            self.model.train(was_training)
        return loss.item()
if __name__ == "__main__":
    # Running the module directly only prints a banner; no training or
    # search is started from here.
    banner = "RL-Chess-Alpha V1 Initialized. Use 'Distill-Alpha.py' to prime it with Stockfish knowledge."
    print(banner)