"""Asynchronous actor-learner self-play trainer for ChessTransformer, with a Gradio dashboard.

Actor threads play batches of games between a frozen "champion" and the
trainable "challenger", scoring every challenger move with Stockfish and
pushing (state, action, advantage, value-target) tuples into a shared replay
buffer. A learner thread samples that buffer to update the challenger and
promotes it to champion when its recent win rate is high enough. The Gradio
UI polls the module-level globals to show one live game and training stats.
"""
import collections
import json
import os
import random
import threading
import time

import chess
import chess.engine
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from model import ChessTransformer
from data_loader import VOCAB

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
MAX_SEQ_LEN = 120          # token context length fed to the model
ACTOR_BATCH_SIZE = 16      # games played concurrently by each actor thread
LEARNER_BATCH_SIZE = 128   # replay samples per optimizer step
NUM_ACTORS = 4             # actor threads; each opens its own Stockfish process
MAX_PLIES = 200            # hard cap on plies per game
RESIGN_CP = 800            # stop a game once |Stockfish eval| exceeds this
WIN_CP = 500               # adjudicate an unfinished game as decided beyond this
EVAL_SCALE_CP = 500.0      # centipawns that map to a value target of +/-1
MATE_SCORE_CP = 10000      # centipawn stand-in for forced mates
UI_MOVE_DELAY_S = 0.05     # pause after each displayed move so the UI can follow it
PROMOTION_MIN_GAMES = 50   # recent games required before promotion is considered
PROMOTION_WIN_RATE = 0.55  # challenger win rate needed for promotion

STOCKFISH_PATH = os.environ.get(
    "STOCKFISH_PATH", os.path.join("stockfish", "stockfish-windows-x86-64-avx2.exe")
)
FAST_WEIGHTS_PATH = "weights/chess_fast_best.pth"
CHAMPION_CKPT_PATH = "rl_weights/champion_latest.pth"
CHALLENGER_CKPT_PATH = "rl_weights/challenger_latest.pth"

# Special history tokens.
# NOTE(review): both are empty strings in the source as reviewed; they look like
# stand-ins for dedicated vocabulary tokens (a start-of-game token and an
# unknown-token fallback) - confirm the intended strings against data_loader.VOCAB.
START_TOKEN = ""
UNK_TOKEN = ""

# ---------------------------------------------------------------------------
# Shared state read by the UI (guarded by ui_lock / stats_lock)
# ---------------------------------------------------------------------------
current_game_pgn = ""  # despite the name, holds the FEN of the displayed game
current_eval = 0.0     # challenger value estimate, White's POV, in [-1, 1]
current_sf_eval = 0.0  # clipped Stockfish eval, White's POV, in [-1, 1]
champion_wins = 0
challenger_wins = 0
draws = 0
training_stats = {"loss": 0.0, "reward": 0.0, "epoch": 0}
current_challenger_is_white = True
last_promoted_step = 0


class ReplayBuffer:
    """Fixed-capacity ring buffer of training samples, shared between threads."""

    def __init__(self, capacity=50000):
        self.buffer = []
        self.capacity = capacity
        self.ptr = 0  # next slot to overwrite once the buffer is full
        self.lock = threading.Lock()

    def append(self, item):
        """Add item, overwriting the oldest entry once capacity is reached."""
        with self.lock:
            if len(self.buffer) < self.capacity:
                self.buffer.append(item)
            else:
                self.buffer[self.ptr] = item
                self.ptr = (self.ptr + 1) % self.capacity

    def __len__(self):
        with self.lock:
            return len(self.buffer)

    def sample(self, batch_size):
        """Return batch_size distinct stored items chosen uniformly at random.

        Raises ValueError (from random.sample) if fewer items are stored.
        """
        with self.lock:
            return random.sample(self.buffer, batch_size)


replay_buffer = ReplayBuffer(50000)
recent_outcomes = collections.deque(maxlen=100)  # +1 challenger win, -1 loss, 0 draw
ui_lock = threading.Lock()
stats_lock = threading.Lock()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# float16 autocast and gradient scaling are only enabled on CUDA.
USE_AMP = device.type == "cuda"

INV_VOCAB = {v: k for k, v in VOCAB.items()}

# Global models, created by init_models().
champion = None
challenger = None
optimizer = None


def _autocast():
    """Return a float16 autocast context on CUDA, or a disabled one elsewhere."""
    return torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP)


def _clip_unit(x):
    """Clamp x to the closed interval [-1, 1]."""
    return max(-1.0, min(1.0, x))


def get_stockfish_engine():
    """Start a Stockfish UCI engine (64 MB hash, 1 thread), or return None if the binary is missing.

    The binary path defaults to the bundled Windows build and can be
    overridden with the STOCKFISH_PATH environment variable.
    """
    if not os.path.exists(STOCKFISH_PATH):
        return None
    sf_engine = chess.engine.SimpleEngine.popen_uci(STOCKFISH_PATH)
    sf_engine.configure({"Hash": 64, "Threads": 1})
    return sf_engine


def encode_history(history, max_length=MAX_SEQ_LEN):
    """Encode a list of move tokens as a (1, max_length) long tensor on `device`.

    Unknown tokens map to UNK_TOKEN's id (or 0). Histories longer than
    max_length keep only their most recent tokens; shorter ones are padded on
    the right with 0.

    NOTE(review): callers read the model output at position -1, which for a
    right-padded history is a padding slot - confirm this matches how
    ChessTransformer / data_loader were trained.
    """
    unk_id = VOCAB.get(UNK_TOKEN, 0)
    seq = [VOCAB.get(tok, unk_id) for tok in history]
    if len(seq) > max_length:
        seq = seq[-max_length:]
    else:
        seq = seq + [0] * (max_length - len(seq))
    return torch.tensor(seq, dtype=torch.long, device=device).unsqueeze(0)


def sample_move(policy_logits, board, temperature=1.0):
    """Sample a legal move (UCI string) from policy_logits restricted to board's legal moves.

    Args:
        policy_logits: logits over VOCAB, indexed as [0, token_id].
        board: position whose legal moves are allowed.
        temperature: divisor applied to the logits before the softmax.

    Falls back to a uniformly random legal move when no legal move is in the
    vocabulary or the masked distribution is degenerate (NaN / all zero).
    """
    legal_moves = list(board.legal_moves)
    legal_ucis = {m.uci() for m in legal_moves}
    legal_ids = [VOCAB[uci] for uci in legal_ucis if uci in VOCAB]
    if not legal_ids:
        return random.choice(legal_moves).uci()

    # Work in float32 so fp16 autocast outputs cannot under/overflow here.
    logits = policy_logits.float() / temperature
    mask = torch.full_like(logits, float("-inf"))
    mask[0, legal_ids] = 0.0
    probs = F.softmax(logits + mask, dim=-1)
    if torch.isnan(probs).any() or probs.sum() == 0:
        return random.choice(legal_moves).uci()

    token_id = torch.multinomial(probs[0], 1).item()
    action = INV_VOCAB.get(token_id)
    if action not in legal_ucis:
        return random.choice(legal_moves).uci()
    return action


def _evaluate_position(sf_engine, board, limit):
    """Return Stockfish's evaluation of board in centipawns from White's POV (0.0 without an engine).

    Forced mates map to +/-(MATE_SCORE_CP - plies to mate). score(mate_score=...)
    also scores MateGiven (the side to move is already checkmated, for which
    mate() returns 0) as a win for the other side, rather than a loss.
    """
    if sf_engine is None:
        return 0.0
    info = sf_engine.analyse(board, limit)
    return info["score"].white().score(mate_score=MATE_SCORE_CP)


def _publish_ui_state(board, value, sf_eval_cp):
    """Publish the displayed game's position and evaluations for the dashboard."""
    global current_game_pgn, current_eval, current_sf_eval
    with ui_lock:
        current_game_pgn = board.fen()
        # The value head is trained on side-to-move targets (see the sf_values
        # recorded in actor_worker), so flip it to White's POV for the bar.
        current_eval = value if board.turn else -value
        current_sf_eval = _clip_unit(sf_eval_cp / EVAL_SCALE_CP)


def _game_reward(board, final_eval_cp, challenger_is_white):
    """Return +1 / -1 / 0 for a challenger win / loss / draw.

    Finished games use the board outcome. Games stopped early (resignation
    threshold or ply cap) are adjudicated from the final Stockfish eval
    (White's POV): beyond +/-WIN_CP counts as decided, otherwise a draw.
    """
    outcome = board.outcome()
    if outcome is None:
        if final_eval_cp >= WIN_CP:
            white_result = 1.0
        elif final_eval_cp <= -WIN_CP:
            white_result = -1.0
        else:
            return 0.0
        return white_result if challenger_is_white else -white_result
    if outcome.winner is None:
        return 0.0
    return 1.0 if outcome.winner == challenger_is_white else -1.0


def _record_outcome(reward):
    """Update the global win/loss/draw counters and the recent-outcome window."""
    global champion_wins, challenger_wins, draws
    with stats_lock:
        if reward > 0:
            challenger_wins += 1
            recent_outcomes.append(1)
        elif reward < 0:
            champion_wins += 1
            recent_outcomes.append(-1)
        else:
            draws += 1
            recent_outcomes.append(0)


def _push_game_to_buffer(data):
    """Normalise one game's advantages and append its challenger moves to the replay buffer.

    Advantages are standardised per game (zero mean, unit sample std); when
    the std is zero or undefined (a single move), only the mean is removed.
    Each appended item is (state, action_id, advantage, sf_value).
    """
    states = data["states"]
    if not states:
        return
    adv = torch.tensor(data["advantages"], dtype=torch.float32)
    # Guard numel: the sample std of one element is NaN (and warns).
    if adv.numel() > 1 and adv.std() > 0:
        adv = (adv - adv.mean()) / (adv.std() + 1e-8)
    else:
        adv = adv - adv.mean()
    for state, action_id, advantage, sf_value in zip(
        states, data["actions"], adv.tolist(), data["sf_values"]
    ):
        replay_buffer.append((state, action_id, advantage, sf_value))


def _encode_batch(histories, indices):
    """Stack the encoded histories of the given games into one (len(indices), MAX_SEQ_LEN) tensor."""
    return torch.cat([encode_history(histories[i]) for i in indices], dim=0)


def actor_worker(worker_id, batch_size=ACTOR_BATCH_SIZE):
    """Background thread: play batch_size challenger-vs-champion games at a time, forever.

    Every challenger move records the encoded history, the chosen action, the
    Stockfish centipawn gain from the mover's perspective (the advantage), and
    the clipped pre-move Stockfish eval from the mover's perspective (the value
    target). Finished batches update the win counters and the replay buffer.
    Worker 0 also streams its first game to the dashboard.

    NOTE(review): the challenger is updated in place by the learner thread while
    actors run inference on it, and it stays in train() mode during self-play
    (dropout, if the model has any, is active) - confirm this is intended.
    """
    global current_challenger_is_white

    sf_engine = get_stockfish_engine()
    sf_limit = chess.engine.Limit(time=0.05)

    def evaluate_position(board):
        return _evaluate_position(sf_engine, board, sf_limit)

    print(f"Actor {worker_id} started with Batch Size {batch_size}.")
    try:
        while True:
            boards = [chess.Board() for _ in range(batch_size)]
            histories = [[START_TOKEN] for _ in range(batch_size)]
            active = [True] * batch_size
            # All games start from the same position, so one engine call suffices.
            evals_current = [evaluate_position(boards[0])] * batch_size
            challenger_is_white = [random.choice([True, False]) for _ in range(batch_size)]
            if worker_id == 0:
                with ui_lock:
                    current_challenger_is_white = challenger_is_white[0]
            game_data = [
                {"states": [], "actions": [], "advantages": [], "sf_values": []}
                for _ in range(batch_size)
            ]

            turn_count = 0
            while any(active) and turn_count < MAX_PLIES:
                turn_count += 1
                challenger_indices = [
                    i for i in range(batch_size)
                    if active[i] and boards[i].turn == challenger_is_white[i]
                ]
                champion_indices = [
                    i for i in range(batch_size)
                    if active[i] and boards[i].turn != challenger_is_white[i]
                ]

                p_chal = v_chal = p_champ = None
                with torch.no_grad(), _autocast():
                    if challenger_indices:
                        p_chal, v_chal = challenger(_encode_batch(histories, challenger_indices))
                    if champion_indices:
                        p_champ, _ = champion(_encode_batch(histories, champion_indices))

                # Challenger moves: play them and record training data.
                for idx, i in enumerate(challenger_indices):
                    board = boards[i]
                    action = sample_move(p_chal[idx, -1, :].unsqueeze(0), board)
                    if worker_id == 0 and i == 0:
                        _publish_ui_state(board, v_chal[idx, -1].item(), evals_current[0])
                        time.sleep(UI_MOVE_DELAY_S)

                    data = game_data[i]
                    data["states"].append(encode_history(histories[i]).cpu())
                    data["actions"].append(VOCAB.get(action, 0))

                    mover_is_white = board.turn
                    eval_before = evals_current[i]
                    board.push_uci(action)
                    eval_after = evaluate_position(board)

                    # Centipawn gain for the side that just moved.
                    data["advantages"].append(
                        eval_after - eval_before if mover_is_white else eval_before - eval_after
                    )
                    sf_value = _clip_unit(eval_before / EVAL_SCALE_CP)
                    data["sf_values"].append(sf_value if mover_is_white else -sf_value)

                    evals_current[i] = eval_after
                    histories[i].append(action)
                    if abs(eval_after) > RESIGN_CP or board.is_game_over():
                        active[i] = False

                # Champion moves: play them; only the evaluation is tracked.
                for idx, i in enumerate(champion_indices):
                    board = boards[i]
                    action = sample_move(p_champ[idx, -1, :].unsqueeze(0), board)
                    board.push_uci(action)
                    evals_current[i] = evaluate_position(board)
                    histories[i].append(action)
                    if abs(evals_current[i]) > RESIGN_CP or board.is_game_over():
                        active[i] = False

            for i in range(batch_size):
                reward = _game_reward(boards[i], evals_current[i], challenger_is_white[i])
                _record_outcome(reward)
                _push_game_to_buffer(game_data[i])
    finally:
        # Release the Stockfish process if this worker dies.
        if sf_engine is not None:
            try:
                sf_engine.quit()
            except chess.engine.EngineError:
                pass  # engine already terminated; nothing left to release


def _save_challenger_checkpoint(step, path):
    """Write the challenger's weights and the step counter to path."""
    torch.save({"epoch": step, "model_state_dict": challenger.state_dict()}, path)


def learner_worker(batch_size=LEARNER_BATCH_SIZE):
    """Background thread: repeatedly sample the replay buffer and update the challenger.

    Each step minimises an advantage-weighted log-likelihood of the sampled
    actions (positive advantages only) plus 0.5 x the MSE between the value
    head and the Stockfish value targets. Also promotes the challenger to
    champion and writes periodic checkpoints.
    """
    global last_promoted_step
    print("Learner started.")
    # GradScaler is a pass-through when AMP is disabled (CPU runs).
    scaler = torch.amp.GradScaler("cuda", enabled=USE_AMP)

    while True:
        if len(replay_buffer) < batch_size:
            time.sleep(1)
            continue

        batch = replay_buffer.sample(batch_size)
        states, actions, advantages, sf_values = zip(*batch)
        s_batch = torch.cat(states).to(device)
        a_batch = torch.tensor(actions, dtype=torch.long, device=device)
        adv_batch = torch.tensor(advantages, dtype=torch.float32, device=device)
        sf_val_batch = torch.tensor(sf_values, dtype=torch.float32, device=device)

        optimizer.zero_grad()
        with _autocast():
            p, v_pred = challenger(s_batch)
            p_logits = p[:, -1, :]
            v_pred = v_pred[:, -1].squeeze(-1)
            log_prob = F.log_softmax(p_logits, dim=-1)
            # gather keeps the index on log_prob's device (indexing with a CPU
            # arange next to a CUDA action tensor mixes devices).
            action_log_probs = log_prob.gather(1, a_batch.unsqueeze(1)).squeeze(1)
            # Advantage-weighted behavioural cloning: only positive advantages are
            # used. Negative weights push log_prob of bad moves towards -inf, so the
            # policy loss becomes unbounded below and wrecks the network's weights
            # (including the value head).
            positive_adv = torch.clamp(adv_batch, min=0.0)
            policy_loss = -(action_log_probs * positive_adv).mean()
            value_loss = F.mse_loss(v_pred, sf_val_batch)
            loss = policy_loss + 0.5 * value_loss

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(challenger.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        loss_value = loss.item()

        promoted_win_rate = None
        with stats_lock:
            training_stats["epoch"] += 1
            step = training_stats["epoch"]
            training_stats["loss"] = loss_value
            if recent_outcomes:
                win_rate = recent_outcomes.count(1) / len(recent_outcomes)
                training_stats["reward"] = win_rate
                # Promote once the challenger consistently beats the champion.
                if len(recent_outcomes) >= PROMOTION_MIN_GAMES and win_rate >= PROMOTION_WIN_RATE:
                    promoted_win_rate = win_rate
                    last_promoted_step = step
                    recent_outcomes.clear()
            reported_win_rate = training_stats["reward"]

        # Weight copies and disk writes happen outside stats_lock so actors and
        # the UI are not blocked on them.
        # NOTE(review): actors may be running the champion while it is overwritten.
        if promoted_win_rate is not None:
            print(f"\n>>> PROMOTING CHALLENGER! Win rate: {promoted_win_rate:.2f} <<<\n")
            champion.load_state_dict(challenger.state_dict())
            _save_challenger_checkpoint(step, CHAMPION_CKPT_PATH)

        if step % 100 == 0:
            print(
                f"Step {step} | Loss: {loss_value:.4f} | "
                f"Win Rate: {reported_win_rate:.2f} | Buffer: {len(replay_buffer)}"
            )
            _save_challenger_checkpoint(step, CHALLENGER_CKPT_PATH)


def _build_model():
    """Construct a ChessTransformer with this script's fixed architecture on `device`."""
    return ChessTransformer(
        vocab_size=len(VOCAB), d_model=512, nhead=8, num_layers=6, max_length=MAX_SEQ_LEN
    ).to(device)


def _load_checkpoint(path):
    """Load path and return (state_dict, epoch); epoch is None when not stored.

    Accepts either a bare state_dict or a dict wrapping it under
    "model_state_dict" (optionally alongside "epoch").
    """
    ckpt = torch.load(path, map_location=device)
    epoch = None
    if isinstance(ckpt, dict):
        epoch = ckpt.get("epoch")
        ckpt = ckpt.get("model_state_dict", ckpt)
    return ckpt, epoch


def init_models():
    """Create champion, challenger and optimizer, warm-starting from the newest weights found.

    Prefers the latest RL champion checkpoint (restoring the step counter and
    last promotion step), then the supervised "fast" weights; otherwise both
    models keep their random initialisation.
    """
    global champion, challenger, optimizer, last_promoted_step
    os.makedirs("rl_weights", exist_ok=True)
    champion = _build_model()
    challenger = _build_model()

    state_dict = None
    if os.path.exists(CHAMPION_CKPT_PATH):
        print(f"Loading latest RL champion from {CHAMPION_CKPT_PATH}")
        state_dict, epoch = _load_checkpoint(CHAMPION_CKPT_PATH)
        if epoch is not None:
            training_stats["epoch"] = epoch
            last_promoted_step = epoch
    elif os.path.exists(FAST_WEIGHTS_PATH):
        print(f"Loading base Fast weights from {FAST_WEIGHTS_PATH}")
        state_dict, epoch = _load_checkpoint(FAST_WEIGHTS_PATH)
        if epoch is not None:
            training_stats["epoch"] = epoch

    if state_dict is not None:
        champion.load_state_dict(state_dict)
        challenger.load_state_dict(state_dict)

    champion.eval()
    challenger.train()
    optimizer = torch.optim.AdamW(challenger.parameters(), lr=1e-5)


def build_ui():
    """Build the Gradio dashboard: a live board with value/Stockfish bars plus a stats table.

    A 0.5 s timer pushes get_state() (a JSON string) into a hidden textbox.
    The textbox's change handler runs js_callback, which moves the
    chessboard.js board, the bars and the player labels. The same timer
    refreshes the stats table from get_stats().
    """

    def get_state():
        """Snapshot the displayed game as JSON; fall back to the start position if none is valid."""
        with ui_lock:
            fen = current_game_pgn
            eval_val = current_eval
            sf_eval_val = current_sf_eval
            challenger_white = current_challenger_is_white
        try:
            board = chess.Board(fen)
        except ValueError:
            # No game published yet (empty FEN) or an invalid FEN.
            return json.dumps({"fen": chess.STARTING_FEN, "eval": 0.0, "sf_eval": 0.0, "white_name": "", "black_name": ""})
        white_name = "Challenger" if challenger_white else "Champion"
        black_name = "Champion" if challenger_white else "Challenger"
        return json.dumps({
            "fen": board.fen(),
            "eval": eval_val,
            "sf_eval": sf_eval_val,
            "white_name": white_name,
            "black_name": black_name,
        })

    def get_stats():
        """Return [metric, value] rows for the stats table."""
        with stats_lock:
            return [
                ["Training Steps", str(training_stats['epoch'])],
                ["Last Promoted Step", str(last_promoted_step)],
                ["Loss", f"{training_stats['loss']:.4f}"],
                ["Recent Win Rate", f"{training_stats['reward']:.2f}"],
                ["Replay Buffer Size", str(len(replay_buffer))],
                ["Challenger Wins", str(challenger_wins)],
                ["Champion Wins", str(champion_wins)],
                ["Draws", str(draws)],
            ]

    custom_css = """
    @import url('https://fonts.googleapis.com/css2?family=Outfit:wght@400;700&display=swap');
    body, .gradio-container {
        font-family: 'Outfit', sans-serif !important;
        background: linear-gradient(135deg, #0f2027, #203a43, #2c5364) !important;
        background-attachment: fixed !important;
        color: white !important;
    }
    .gradio-container { border: none !important; }
    .glass-panel {
        background: rgba(255, 255, 255, 0.1) !important;
        backdrop-filter: blur(10px) !important;
        border-radius: 12px !important;
        border: 1px solid rgba(255, 255, 255, 0.18) !important;
        padding: 20px !important;
        box-shadow: 0 8px 32px 0 rgba(31, 38, 135, 0.37) !important;
    }
    .eval-bar-container {
        width: 30px;
        height: 400px;
        background-color: #333;
        border-radius: 4px;
        border: 4px solid #fff;
        position: relative;
        overflow: hidden;
        display: flex;
        flex-direction: column-reverse;
        box-shadow: 0 15px 35px rgba(0,0,0,0.5);
    }
    .eval-bar-fill {
        width: 100%;
        height: 50%;
        background-color: #fff;
        transition: height 0.5s cubic-bezier(0.4, 0, 0.2, 1);
    }
    .eval-marker {
        position: absolute;
        top: 50%;
        left: 0;
        width: 100%;
        height: 2px;
        background-color: #ff5e7e;
        z-index: 10;
    }
    """

    # Element ids used below (board, evalBar, sfEvalBar, blackName, whiteName)
    # are the ones js_callback and init_js look up.
    board_html = """
    <div style="display: flex; gap: 16px; align-items: center; justify-content: center;">
      <div class="eval-bar-container" title="Challenger value">
        <div id="evalBar" class="eval-bar-fill"></div>
        <div class="eval-marker"></div>
      </div>
      <div>
        <div id="blackName" style="font-weight: 700; margin-bottom: 8px;">Black</div>
        <div id="board" style="width: 400px;"></div>
        <div id="whiteName" style="font-weight: 700; margin-top: 8px;">White</div>
      </div>
      <div class="eval-bar-container" title="Stockfish">
        <div id="sfEvalBar" class="eval-bar-fill"></div>
        <div class="eval-marker"></div>
      </div>
    </div>
    """

    js_callback = """
    (state_str) => {
        try {
            let state = JSON.parse(state_str);
            if (window.my_board) window.my_board.position(state.fen);
            let evalBar = document.getElementById('evalBar');
            if (evalBar) {
                let heightPercent = ((state.eval + 1.0) / 2.0) * 100;
                heightPercent = Math.max(0, Math.min(100, heightPercent));
                evalBar.style.height = heightPercent + '%';
            }
            let sfEvalBar = document.getElementById('sfEvalBar');
            if (sfEvalBar) {
                let sfHeightPercent = ((state.sf_eval + 1.0) / 2.0) * 100;
                sfHeightPercent = Math.max(0, Math.min(100, sfHeightPercent));
                sfEvalBar.style.height = sfHeightPercent + '%';
            }
            let blackName = document.getElementById('blackName');
            if (blackName && state.black_name) {
                let dot = state.black_name === "Challenger" ? "🟢" : "🔴";
                let color = state.black_name === "Challenger" ? "#00ff88" : "#ff5e7e";
                blackName.innerText = dot + " " + state.black_name + " (Black)";
                blackName.style.color = color;
            }
            let whiteName = document.getElementById('whiteName');
            if (whiteName && state.white_name) {
                let dot = state.white_name === "Challenger" ? "🟢" : "🔴";
                let color = state.white_name === "Challenger" ? "#00ff88" : "#ff5e7e";
                whiteName.innerText = dot + " " + state.white_name + " (White)";
                whiteName.style.color = color;
            }
        } catch(e) {}
        return state_str;
    }
    """

    # Loads jQuery, then chessboard.js, then creates the board once #board exists.
    init_js = """
    function() {
        var jq = document.createElement('script');
        jq.src = "https://code.jquery.com/jquery-3.5.1.min.js";
        document.head.appendChild(jq);
        var css = document.createElement('link');
        css.rel = "stylesheet";
        css.href = "https://unpkg.com/@chrisoakman/chessboardjs@1.0.0/dist/chessboard-1.0.0.min.css";
        document.head.appendChild(css);
        jq.onload = function() {
            var cb = document.createElement('script');
            cb.src = "https://unpkg.com/@chrisoakman/chessboardjs@1.0.0/dist/chessboard-1.0.0.min.js";
            document.head.appendChild(cb);
            cb.onload = function() {
                let checkExist = setInterval(function() {
                    if (document.getElementById('board')) {
                        window.my_board = Chessboard('board', {
                            position: 'start',
                            pieceTheme: 'https://chessboardjs.com/img/chesspieces/wikipedia/{piece}.png'
                        });
                        clearInterval(checkExist);
                    }
                }, 100);
            };
        };
    }
    """

    with gr.Blocks(title="Neurex RL Dashboard", css=custom_css) as demo:
        gr.HTML(
            "<div style='text-align: center; padding: 10px;'>"
            "<h1 style='margin: 0;'>🧠 Neurex RL Self-Play Dashboard</h1>"
            "</div>"
        )
        gr.HTML(
            "<div style='text-align: center; opacity: 0.8;'>"
            "<p>ASYNCHRONOUS ALPHA-ZERO MODE | Real-time Actor-Learner Architecture</p>"
            "</div>"
        )
        with gr.Row():
            with gr.Column(elem_classes=["glass-panel"]):
                gr.HTML(board_html)
                current_state_box = gr.Textbox(visible=False)
            with gr.Column(elem_classes=["glass-panel"]):
                stats_view = gr.Dataframe(headers=["Metric", "Value"], interactive=False)

        timer = gr.Timer(0.5)
        timer.tick(get_state, inputs=[], outputs=[current_state_box])
        timer.tick(get_stats, inputs=[], outputs=[stats_view])

        current_state_box.change(None, inputs=[current_state_box], js=js_callback)
        demo.load(None, None, None, js=init_js)

    return demo


if __name__ == "__main__":
    init_models()

    # Spawn the actor threads; each opens its own Stockfish instance.
    for i in range(NUM_ACTORS):
        threading.Thread(target=actor_worker, args=(i,), daemon=True, name=f"actor-{i}").start()

    # One learner thread trains the challenger on the GPU (or CPU fallback).
    threading.Thread(target=learner_worker, daemon=True, name="learner").start()

    # The Gradio UI runs in the main thread.
    demo = build_ui()
    demo.launch(server_name="0.0.0.0", prevent_thread_lock=False)