# Chess-Web / rl_train.py
# dpv007's picture
# Upload folder using huggingface_hub
# 769f59f verified
# Raw
# History Blame Contribute Delete
# 24.3 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import ChessTransformer
from data_loader import VOCAB
import os
import random
import chess
import chess.engine
import gradio as gr
import threading
import time
import collections
import numpy as np
# Global variables for UI monitoring
# Despite the name, actor worker 0 stores a FEN string here (boards[0].fen()).
current_game_pgn = ""
# Challenger value-head output for the displayed game, negated when Black is to move.
current_eval = 0.0
# Stockfish centipawn score of the displayed game divided by 500 and clamped to [-1, 1].
current_sf_eval = 0.0
# Running game-result counters, incremented by the actor threads under stats_lock.
champion_wins = 0
challenger_wins = 0
draws = 0
# "epoch" counts learner optimisation steps; "reward" holds the recent challenger win rate.
training_stats = {"loss": 0.0, "reward": 0.0, "epoch": 0}
# Colour of the challenger in the game shown on the dashboard (actor 0, game 0).
current_challenger_is_white = True
# Learner step at which the challenger was last promoted to champion.
last_promoted_step = 0
# Thread-safe queues and locks
class ReplayBuffer:
    """Fixed-capacity ring buffer of training samples guarded by a lock.

    Items are appended until ``capacity`` is reached; after that each new
    item overwrites the oldest one. Every public operation takes
    ``self.lock``, so the buffer can be shared between threads.
    """

    def __init__(self, capacity=50000):
        """Create an empty buffer.

        Args:
            capacity: Maximum number of items kept; must be positive.

        Raises:
            ValueError: If ``capacity`` is not positive. Without this check a
                zero capacity only fails later, inside ``append``.
        """
        if capacity <= 0:
            raise ValueError(f"capacity must be positive, got {capacity}")
        self.buffer = []
        self.capacity = capacity
        self.ptr = 0  # index of the slot the next overwrite will use
        self.lock = threading.Lock()

    def append(self, item):
        """Store ``item``, overwriting the oldest entry once the buffer is full."""
        with self.lock:
            if len(self.buffer) < self.capacity:
                self.buffer.append(item)
            else:
                self.buffer[self.ptr] = item
            # The pointer wraps to 0 exactly when the buffer first fills up,
            # so overwriting always starts at the oldest element.
            self.ptr = (self.ptr + 1) % self.capacity

    def __len__(self):
        """Return the number of items currently stored."""
        with self.lock:
            return len(self.buffer)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct items chosen uniformly at random.

        Raises:
            ValueError: If ``batch_size`` exceeds the number of stored items
                (raised by ``random.sample``).
        """
        with self.lock:
            return random.sample(self.buffer, batch_size)
# Shared experience store: actors append samples, the learner draws batches from it.
replay_buffer = ReplayBuffer(50000)
# Sliding window of the last 100 game results (1 = challenger win, -1 = champion win, 0 = draw).
recent_outcomes = collections.deque(maxlen=100)
# Taken when reading/writing the dashboard's displayed-game globals (FEN, evals, colours).
ui_lock = threading.Lock()
# Taken when reading/writing win/draw counters, recent_outcomes and training_stats.
stats_lock = threading.Lock()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Reverse lookup: token id -> token string.
INV_VOCAB = {v: k for k, v in VOCAB.items()}
# Global Models
# All three are populated by init_models() before the worker threads start.
champion = None
challenger = None
optimizer = None
def get_stockfish_engine(stockfish_path=None):
    """Start a Stockfish UCI engine if its binary is present.

    Args:
        stockfish_path: Path to the engine executable. Defaults to the
            bundled Windows AVX2 build under ``stockfish/``.

    Returns:
        A configured ``chess.engine.SimpleEngine`` (64 MB hash, one thread),
        or ``None`` when no file exists at ``stockfish_path``.

    Raises:
        Whatever ``popen_uci`` or ``configure`` raises. If configuration
        fails, the engine process is closed before the error propagates.
    """
    if stockfish_path is None:
        stockfish_path = os.path.join("stockfish", "stockfish-windows-x86-64-avx2.exe")
    if not os.path.exists(stockfish_path):
        return None

    sf_engine = chess.engine.SimpleEngine.popen_uci(stockfish_path)
    try:
        sf_engine.configure({"Hash": 64, "Threads": 1})
    except Exception:
        # Do not leave an orphaned engine subprocess behind on a failed setup.
        sf_engine.close()
        raise
    return sf_engine
def encode_history(history, max_length=120):
    """Encode a move-token history as a fixed-length id tensor.

    Only the most recent ``max_length`` tokens are kept; shorter histories
    are right-padded with id 0 (assumed to be the padding id — confirm
    against ``VOCAB``). Tokens missing from ``VOCAB`` map to the ``<unk>``
    id, or 0 when the vocabulary has no ``<unk>`` entry.

    Args:
        history: Sequence of token strings, oldest first.
        max_length: Length of the encoded sequence.

    Returns:
        A ``torch.long`` tensor of shape ``(1, max_length)`` on ``device``.
    """
    unk_id = VOCAB.get("<unk>", 0)  # looked up once instead of per token
    # Truncate before encoding so dropped tokens are never looked up.
    # The explicit guard avoids the ``seq[-0:]`` pitfall, which returns everything.
    recent = list(history)[-max_length:] if max_length > 0 else []
    seq = [VOCAB.get(tok, unk_id) for tok in recent]
    seq.extend([0] * (max_length - len(seq)))
    return torch.tensor(seq, dtype=torch.long, device=device).unsqueeze(0)
def sample_move(policy_logits, board, temperature=1.0):
    """Sample a legal move (UCI string) from the policy distribution.

    Logits of tokens that are not legal moves in ``board`` are masked to
    ``-inf`` before the softmax, so only legal moves can be drawn. If the
    distribution is unusable (NaN or all-zero, e.g. when no legal move is in
    the vocabulary) a uniformly random legal move is returned instead.

    Args:
        policy_logits: Tensor indexed as ``[0, token_id]`` (one row of logits).
        board: Object exposing ``legal_moves``, an iterable of moves with a
            ``uci()`` method. Must have at least one legal move.
        temperature: Softmax temperature the logits are divided by.

    Returns:
        The UCI string of the chosen move.
    """
    logits = policy_logits / temperature
    legal_moves = list(board.legal_moves)
    # A set makes the per-token legality test O(1) instead of a list scan.
    legal_ucis = {m.uci() for m in legal_moves}

    mask = torch.full_like(logits, float('-inf'))
    legal_indices = [idx for idx, token in INV_VOCAB.items() if token in legal_ucis]
    if legal_indices:
        # Single indexed write instead of one tensor update per legal move.
        mask[0, legal_indices] = 0.0

    masked_logits = logits + mask
    probs = F.softmax(masked_logits, dim=-1)
    if torch.isnan(probs).any() or probs.sum() == 0:
        return random.choice(legal_moves).uci()

    m = torch.multinomial(probs[0], 1).item()
    action = INV_VOCAB.get(m)
    if action not in legal_ucis:
        return random.choice(legal_moves).uci()
    return action
def actor_worker(worker_id):
    """Background thread playing BATCH_SIZE games concurrently.

    Each iteration of the outer loop plays one batch of challenger-vs-champion
    games (colours chosen at random per game), scores every position with
    Stockfish, and pushes one replay-buffer sample per challenger move:
    ``(encoded state, action id, normalised advantage, stockfish value)``.
    Game results update the global win/draw counters and ``recent_outcomes``.
    Worker 0 additionally publishes its first game to the dashboard globals.

    Args:
        worker_id: Index of this actor; only worker 0 feeds the UI.

    The function loops forever; it only returns by raising, in which case
    the Stockfish process it owns is closed.
    """
    global current_game_pgn, current_eval, current_sf_eval, current_challenger_is_white
    global champion_wins, challenger_wins, draws

    BATCH_SIZE = 16
    sf_engine = get_stockfish_engine()
    sf_limit = chess.engine.Limit(time=0.05)
    # float16 autocast is only meaningful on CUDA; disabling it elsewhere
    # avoids a warning on every forward pass of a CPU-only run.
    use_amp = device.type == "cuda"

    def evaluate_position(board):
        """Return White's evaluation in centipawns (+/-10000 for forced mates)."""
        if sf_engine is None:
            return 0.0
        info = sf_engine.analyse(board, sf_limit)
        score = info["score"].white()
        if score.is_mate():
            # ``mate()`` is 0 both for "mate delivered" and "is mated", so its
            # sign cannot tell them apart; ``score(mate_score=...)`` is
            # positive exactly when White is the side winning.
            return 10000 if score.score(mate_score=10000) > 0 else -10000
        return score.score()

    print(f"Actor {worker_id} started with Batch Size {BATCH_SIZE}.")

    try:
        while True:
            boards = [chess.Board() for _ in range(BATCH_SIZE)]
            histories = [["<bos>"] for _ in range(BATCH_SIZE)]
            active = [True for _ in range(BATCH_SIZE)]
            evals_current = [evaluate_position(b) for b in boards]
            challenger_is_white = [random.choice([True, False]) for _ in range(BATCH_SIZE)]

            if worker_id == 0:
                with ui_lock:
                    current_challenger_is_white = challenger_is_white[0]

            game_data = [
                {"states": [], "actions": [], "advantages": [], "sf_values": []}
                for _ in range(BATCH_SIZE)
            ]

            turn_count = 0
            while any(active) and turn_count < 200:
                turn_count += 1

                # Split the still-running games by which model is to move.
                challenger_indices = []
                champion_indices = []
                for i in range(BATCH_SIZE):
                    if not active[i]:
                        continue
                    if boards[i].turn == challenger_is_white[i]:
                        challenger_indices.append(i)
                    else:
                        champion_indices.append(i)

                p_chal, v_chal = None, None
                p_champ = None
                with torch.no_grad():
                    with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
                        if len(challenger_indices) > 0:
                            x_chal = torch.cat(
                                [encode_history(histories[i]) for i in challenger_indices], dim=0
                            )
                            p_chal, v_chal = challenger(x_chal)
                        if len(champion_indices) > 0:
                            x_champ = torch.cat(
                                [encode_history(histories[i]) for i in champion_indices], dim=0
                            )
                            p_champ, _ = champion(x_champ)

                # Process Challenger moves
                for idx, i in enumerate(challenger_indices):
                    p_logits = p_chal[idx, -1, :]
                    v = v_chal[idx, -1].item()
                    action = sample_move(p_logits.unsqueeze(0), boards[i])

                    if worker_id == 0 and i == 0:
                        with ui_lock:
                            current_game_pgn = boards[0].fen()
                            turn_white_ui = boards[0].turn
                            current_eval = v if turn_white_ui else -v
                            current_sf_eval = max(-1.0, min(1.0, evals_current[0] / 500.0))
                        # Pace the displayed game; sleep after releasing the
                        # lock so the dashboard poll is never blocked by it.
                        time.sleep(0.05)

                    game_data[i]["states"].append(encode_history(histories[i]).cpu())
                    game_data[i]["actions"].append(VOCAB.get(action, 0))

                    boards[i].push_uci(action)
                    eval_next_cp = evaluate_position(boards[i])

                    # After the push the turn has flipped, so this is the mover's colour.
                    turn_white_val = not boards[i].turn
                    # Advantage = evaluation swing caused by the move, from the mover's side.
                    adv = (eval_next_cp - evals_current[i]) if turn_white_val else (evals_current[i] - eval_next_cp)
                    game_data[i]["advantages"].append(adv)

                    # Value target: pre-move Stockfish score, mover's perspective, in [-1, 1].
                    sf_val = max(-1.0, min(1.0, evals_current[i] / 500.0))
                    if not turn_white_val:
                        sf_val = -sf_val
                    game_data[i]["sf_values"].append(sf_val)

                    evals_current[i] = eval_next_cp
                    histories[i].append(action)

                    # Stop early once the position is decisively won or the game has ended.
                    if abs(eval_next_cp) > 800 or boards[i].is_game_over():
                        active[i] = False

                # Process Champion moves
                for idx, i in enumerate(champion_indices):
                    p_logits = p_champ[idx, -1, :]
                    action = sample_move(p_logits.unsqueeze(0), boards[i])
                    boards[i].push_uci(action)
                    eval_next_cp = evaluate_position(boards[i])
                    evals_current[i] = eval_next_cp
                    histories[i].append(action)
                    if abs(eval_next_cp) > 800 or boards[i].is_game_over():
                        active[i] = False

            # Calculate rewards and push to replay buffer
            for i in range(BATCH_SIZE):
                outcome = boards[i].outcome()
                if outcome is None:
                    # Unfinished game: adjudicate from the last Stockfish evaluation.
                    if evals_current[i] >= 500:
                        reward = 1.0 if challenger_is_white[i] else -1.0
                    elif evals_current[i] <= -500:
                        reward = -1.0 if challenger_is_white[i] else 1.0
                    else:
                        reward = 0.0
                elif outcome.winner is None:
                    reward = 0.0
                else:
                    reward = 1.0 if outcome.winner == challenger_is_white[i] else -1.0

                with stats_lock:
                    if reward > 0:
                        challenger_wins += 1
                        recent_outcomes.append(1)
                    elif reward < 0:
                        champion_wins += 1
                        recent_outcomes.append(-1)
                    else:
                        draws += 1
                        recent_outcomes.append(0)

                states = game_data[i]["states"]
                if len(states) > 0:
                    adv_tensor = torch.tensor(game_data[i]["advantages"], dtype=torch.float32)
                    # Normalise per game; fall back to mean-centring when the
                    # spread is zero or undefined (single-move games give NaN).
                    if adv_tensor.std() > 0:
                        adv_tensor = (adv_tensor - adv_tensor.mean()) / (adv_tensor.std() + 1e-8)
                    else:
                        adv_tensor = adv_tensor - adv_tensor.mean()

                    for j in range(len(states)):
                        replay_buffer.append((
                            states[j],
                            game_data[i]["actions"][j],
                            adv_tensor[j].item(),
                            game_data[i]["sf_values"][j],
                        ))
    finally:
        # Release the engine subprocess if the worker dies with an exception.
        if sf_engine is not None:
            sf_engine.close()
def learner_worker():
    """Background thread that continuously samples the Replay Buffer and updates the Neural Network.

    Each step draws 128 samples, computes an advantage-weighted policy loss
    plus 0.5 x MSE value loss against the Stockfish targets, and applies one
    gradient-clipped optimiser step to the challenger. After every step the
    shared training stats are refreshed; when the challenger's recent win
    rate reaches 0.55 over at least 50 games it is copied into the champion
    and saved. A challenger checkpoint is written every 100 steps.

    The function loops forever and never returns.
    """
    global training_stats, last_promoted_step

    print("Learner started.")
    batch_size = 128
    # Mixed precision is only enabled on CUDA; on other devices autocast and
    # the gradient scaler are switched off explicitly instead of warning.
    use_amp = device.type == "cuda"
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

    while True:
        if len(replay_buffer) < batch_size:
            time.sleep(1)
            continue

        batch = replay_buffer.sample(batch_size)

        # Move batch to GPU
        s_batch = torch.cat([b[0] for b in batch]).to(device)
        a_batch = torch.tensor([b[1] for b in batch], dtype=torch.long, device=device)
        adv_batch = torch.tensor([b[2] for b in batch], dtype=torch.float32, device=device)
        sf_val_batch = torch.tensor([b[3] for b in batch], dtype=torch.float32, device=device)

        optimizer.zero_grad()
        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
            p, v_pred = challenger(s_batch)
            p_logits = p[:, -1, :]
            v_pred = v_pred[:, -1].squeeze(-1)

            log_prob = F.log_softmax(p_logits, dim=-1)
            action_log_probs = log_prob[torch.arange(batch_size, device=device), a_batch]

            # CRITICAL FIX: Only train on positive advantages (Advantage-Weighted Behavioral Cloning).
            # If we allow negative advantages, the optimizer pushes log_prob to negative infinity,
            # causing the policy_loss to explode into massive negative numbers (e.g. -59.7) and
            # completely destroying the Neural Network's weights (including the Value Head).
            positive_adv = torch.clamp(adv_batch, min=0.0)
            policy_loss = -(action_log_probs * positive_adv).mean()
            value_loss = F.mse_loss(v_pred, sf_val_batch)
            loss = policy_loss + 0.5 * value_loss

        scaler.scale(loss).backward()
        # Gradients must be unscaled before clipping so the threshold applies
        # to their true magnitude.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(challenger.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()

        loss_value = loss.item()
        promoted_win_rate = None
        # Only the shared bookkeeping happens under the lock; weight copies
        # and disk writes are done afterwards so actors and the dashboard are
        # not stalled behind file I/O.
        with stats_lock:
            training_stats["epoch"] += 1
            step = training_stats["epoch"]
            training_stats["loss"] = loss_value
            if len(recent_outcomes) > 0:
                win_rate = sum(1 for x in recent_outcomes if x == 1) / len(recent_outcomes)
                training_stats["reward"] = win_rate
                # If Challenger is consistently crushing the Champion, promote it!
                if len(recent_outcomes) >= 50 and win_rate >= 0.55:
                    promoted_win_rate = win_rate
                    last_promoted_step = step
                    recent_outcomes.clear()
            reported_win_rate = training_stats["reward"]

        if promoted_win_rate is not None:
            # Real newlines: the original "\\n" printed a literal backslash-n.
            print(f"\n>>> PROMOTING CHALLENGER! Win rate: {promoted_win_rate:.2f} <<<\n")
            champion.load_state_dict(challenger.state_dict())
            torch.save({
                "epoch": step,
                "model_state_dict": challenger.state_dict()
            }, "rl_weights/champion_latest.pth")

        if step % 100 == 0:
            print(f"Step {step} | Loss: {loss_value:.4f} | Win Rate: {reported_win_rate:.2f} | Buffer: {len(replay_buffer)}")
            torch.save({
                "epoch": step,
                "model_state_dict": challenger.state_dict()
            }, "rl_weights/challenger_latest.pth")
def _load_checkpoint(path):
    """Load ``path`` and return ``(state_dict, epoch)``; ``epoch`` is None when not stored."""
    ckpt = torch.load(path, map_location=device)
    epoch = None
    # Check the type first: the membership tests below are only meaningful
    # for the dict-style checkpoints this script writes.
    if isinstance(ckpt, dict):
        if "epoch" in ckpt:
            epoch = ckpt["epoch"]
        if "model_state_dict" in ckpt:
            ckpt = ckpt["model_state_dict"]
    return ckpt, epoch


def init_models():
    """Create the champion/challenger networks and the optimiser.

    Both networks start from the same weights: the latest promoted RL
    champion if ``rl_weights/champion_latest.pth`` exists, otherwise the
    base checkpoint ``weights/chess_fast_best.pth`` if present, otherwise a
    fresh initialisation. When a checkpoint stores an epoch, the training
    step counter resumes from it (and, for the RL checkpoint, so does
    ``last_promoted_step``). The champion is put in eval mode, the
    challenger in train mode, and an AdamW optimiser (lr=1e-5) is created
    for the challenger. Also creates the ``rl_weights`` directory.
    """
    global champion, challenger, optimizer, last_promoted_step

    os.makedirs("rl_weights", exist_ok=True)

    champion = ChessTransformer(vocab_size=len(VOCAB), d_model=512, nhead=8, num_layers=6, max_length=120).to(device)
    challenger = ChessTransformer(vocab_size=len(VOCAB), d_model=512, nhead=8, num_layers=6, max_length=120).to(device)

    latest_rl_weights = "rl_weights/champion_latest.pth"
    fast_weights = "weights/chess_fast_best.pth"

    if os.path.exists(latest_rl_weights):
        print(f"Loading latest RL champion from {latest_rl_weights}")
        state_dict, epoch = _load_checkpoint(latest_rl_weights)
        if epoch is not None:
            training_stats["epoch"] = epoch
            last_promoted_step = epoch
        champion.load_state_dict(state_dict)
        challenger.load_state_dict(state_dict)
    elif os.path.exists(fast_weights):
        print(f"Loading base Fast weights from {fast_weights}")
        state_dict, epoch = _load_checkpoint(fast_weights)
        if epoch is not None:
            training_stats["epoch"] = epoch
        champion.load_state_dict(state_dict)
        challenger.load_state_dict(state_dict)

    champion.eval()
    challenger.train()
    optimizer = torch.optim.AdamW(challenger.parameters(), lr=1e-5)
def build_ui():
    """Build the Gradio dashboard for watching self-play and training stats.

    The page shows a chessboard.js board flanked by two evaluation bars
    (network value and Stockfish) and a table of training metrics. A 0.5 s
    timer polls ``get_state`` / ``get_stats``; the state is passed through a
    hidden textbox whose change event runs the JavaScript that redraws the
    board, bars and player labels.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    import json  # not imported at module level; one local import serves both paths

    def get_state():
        """Return the displayed game's state as a JSON string for the browser."""
        try:
            with ui_lock:
                b = chess.Board(current_game_pgn)
                eval_val = current_eval
                sf_eval_val = current_sf_eval
                white_name = "Challenger" if current_challenger_is_white else "Champion"
                black_name = "Champion" if current_challenger_is_white else "Challenger"
            return json.dumps({
                "fen": b.fen(),
                "eval": eval_val,
                "sf_eval": sf_eval_val,
                "white_name": white_name,
                "black_name": black_name
            })
        except Exception:
            # Best-effort fallback (e.g. the FEN is still empty before the
            # first move). Unlike a bare ``except`` this no longer swallows
            # KeyboardInterrupt / SystemExit.
            return json.dumps({"fen": chess.STARTING_FEN, "eval": 0.0, "sf_eval": 0.0, "white_name": "", "black_name": ""})

    def get_stats():
        """Return the metric/value rows shown in the stats table."""
        with stats_lock:
            return [
                ["Training Steps", str(training_stats['epoch'])],
                ["Last Promoted Step", str(last_promoted_step)],
                ["Loss", f"{training_stats['loss']:.4f}"],
                ["Recent Win Rate", f"{training_stats['reward']:.2f}"],
                ["Replay Buffer Size", str(len(replay_buffer))],
                ["Challenger Wins", str(challenger_wins)],
                ["Champion Wins", str(champion_wins)],
                ["Draws", str(draws)]
            ]

    custom_css = """
@import url('https://fonts.googleapis.com/css2?family=Outfit:wght@400;700&display=swap');
body, .gradio-container {
font-family: 'Outfit', sans-serif !important;
background: linear-gradient(135deg, #0f2027, #203a43, #2c5364) !important;
background-attachment: fixed !important;
color: white !important;
}
.gradio-container { border: none !important; }
.glass-panel {
background: rgba(255, 255, 255, 0.1) !important;
backdrop-filter: blur(10px) !important;
border-radius: 12px !important;
border: 1px solid rgba(255, 255, 255, 0.18) !important;
padding: 20px !important;
box-shadow: 0 8px 32px 0 rgba(31, 38, 135, 0.37) !important;
}
.eval-bar-container {
width: 30px;
height: 400px;
background-color: #333;
border-radius: 4px;
border: 4px solid #fff;
position: relative;
overflow: hidden;
display: flex;
flex-direction: column-reverse;
box-shadow: 0 15px 35px rgba(0,0,0,0.5);
}
.eval-bar-fill {
width: 100%;
height: 50%;
background-color: #fff;
transition: height 0.5s cubic-bezier(0.4, 0, 0.2, 1);
}
.eval-marker {
position: absolute;
top: 50%;
left: 0;
width: 100%;
height: 2px;
background-color: #ff5e7e;
z-index: 10;
}
"""

    with gr.Blocks(title="Neurex RL Dashboard", css=custom_css) as demo:
        gr.HTML("<h1 style='text-align: center; color: white; font-weight: 700; font-size: 2.5rem; text-shadow: 2px 2px 10px rgba(0,0,0,0.5);'>🧠 Neurex RL Self-Play Dashboard</h1>")
        gr.HTML("<p style='text-align: center; color: #ddd; font-size: 1.1rem;'><b>ASYNCHRONOUS ALPHA-ZERO MODE</b> | Real-time Actor-Learner Architecture</p>")

        with gr.Row():
            with gr.Column(elem_classes=["glass-panel"]):
                board_html = """
<div style="display: flex; flex-direction: column; align-items: center;">
<div id="blackName" style="font-size: 1.3rem; font-weight: bold; margin-bottom: 12px; color: #ff5e7e; text-shadow: 1px 1px 5px rgba(0,0,0,0.5);">Black</div>
<div style="display: flex; align-items: center; gap: 15px; justify-content: center;">
<div class="eval-bar-container" title="Neural Network Evaluation">
<div class="eval-bar-fill" id="evalBar"></div>
<div class="eval-marker"></div>
</div>
<div id="board" style="width: 400px; box-shadow: 0 15px 35px rgba(0,0,0,0.5); border: 4px solid #fff; border-radius: 4px; overflow: hidden;"></div>
<div class="eval-bar-container" title="Stockfish Evaluation">
<div class="eval-bar-fill" id="sfEvalBar" style="background-color: #00ff88;"></div>
<div class="eval-marker"></div>
</div>
</div>
<div id="whiteName" style="font-size: 1.3rem; font-weight: bold; margin-top: 12px; color: #00ff88; text-shadow: 1px 1px 5px rgba(0,0,0,0.5);">White</div>
</div>
"""
                board_view = gr.HTML(board_html)
                # Hidden channel carrying the JSON state from Python to the page's JavaScript.
                current_state_box = gr.Textbox(visible=False)

            with gr.Column(elem_classes=["glass-panel"]):
                stats_view = gr.Dataframe(headers=["Metric", "Value"], interactive=False)

        timer = gr.Timer(0.5)
        timer.tick(get_state, inputs=[], outputs=[current_state_box])
        timer.tick(get_stats, inputs=[], outputs=[stats_view])

        js_callback = """
(state_str) => {
try {
let state = JSON.parse(state_str);
if (window.my_board) window.my_board.position(state.fen);
let evalBar = document.getElementById('evalBar');
if (evalBar) {
let heightPercent = ((state.eval + 1.0) / 2.0) * 100;
heightPercent = Math.max(0, Math.min(100, heightPercent));
evalBar.style.height = heightPercent + '%';
}
let sfEvalBar = document.getElementById('sfEvalBar');
if (sfEvalBar) {
let sfHeightPercent = ((state.sf_eval + 1.0) / 2.0) * 100;
sfHeightPercent = Math.max(0, Math.min(100, sfHeightPercent));
sfEvalBar.style.height = sfHeightPercent + '%';
}
let blackName = document.getElementById('blackName');
if (blackName && state.black_name) {
let dot = state.black_name === "Challenger" ? "🟢" : "🔴";
let color = state.black_name === "Challenger" ? "#00ff88" : "#ff5e7e";
blackName.innerText = dot + " " + state.black_name + " (Black)";
blackName.style.color = color;
}
let whiteName = document.getElementById('whiteName');
if (whiteName && state.white_name) {
let dot = state.white_name === "Challenger" ? "🟢" : "🔴";
let color = state.white_name === "Challenger" ? "#00ff88" : "#ff5e7e";
whiteName.innerText = dot + " " + state.white_name + " (White)";
whiteName.style.color = color;
}
} catch(e) {}
return state_str;
}
"""
        current_state_box.change(None, inputs=[current_state_box], js=js_callback)

        # Loads jQuery, then chessboard.js, then creates the board once the
        # #board element exists in the DOM.
        init_js = """
function() {
var jq = document.createElement('script');
jq.src = "https://code.jquery.com/jquery-3.5.1.min.js";
document.head.appendChild(jq);
var css = document.createElement('link');
css.rel = "stylesheet";
css.href = "https://unpkg.com/@chrisoakman/chessboardjs@1.0.0/dist/chessboard-1.0.0.min.css";
document.head.appendChild(css);
jq.onload = function() {
var cb = document.createElement('script');
cb.src = "https://unpkg.com/@chrisoakman/chessboardjs@1.0.0/dist/chessboard-1.0.0.min.js";
document.head.appendChild(cb);
cb.onload = function() {
let checkExist = setInterval(function() {
if (document.getElementById('board')) {
window.my_board = Chessboard('board', {
position: 'start',
pieceTheme: 'https://chessboardjs.com/img/chesspieces/wikipedia/{piece}.png'
});
clearInterval(checkExist);
}
}, 100);
};
};
}
"""
        demo.load(None, None, None, js=init_js)

    return demo
if __name__ == "__main__":
    # Build both networks and the optimiser before any worker thread touches them.
    init_models()
    # Spawn 4 actor threads; each one tries to open its own Stockfish instance
    for i in range(4):
        t = threading.Thread(target=actor_worker, args=(i,), daemon=True)
        t.start()
    # Spawn 1 Learner Thread to aggressively train the GPU
    t_learner = threading.Thread(target=learner_worker, daemon=True)
    t_learner.start()
    # Launch Gradio UI in main thread
    # The workers are daemon threads, so they end when this blocking call returns.
    demo = build_ui()
    demo.launch(server_name="0.0.0.0", prevent_thread_lock=False)