| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from model import ChessTransformer
|
| from data_loader import VOCAB
|
| import os
|
| import random
|
| import chess
|
| import chess.engine
|
| import gradio as gr
|
| import threading
|
| import time
|
| import collections
|
| import numpy as np
|
|
|
|
|
# Live state of the game displayed on the dashboard. Actor 0 writes these
# under ui_lock and build_ui's get_state reads them. Despite its name,
# current_game_pgn holds a FEN string (it is fed to chess.Board).
current_game_pgn = ""
current_eval = current_sf_eval = 0.0

# Cumulative game results, incremented by the actors under stats_lock.
champion_wins = challenger_wins = draws = 0

# Learner progress for the dashboard; "reward" holds the recent challenger win rate.
training_stats = {"loss": 0.0, "reward": 0.0, "epoch": 0}

# Color the challenger plays in the displayed game.
current_challenger_is_white = True

# Training step at which the challenger was last promoted to champion.
last_promoted_step = 0
|
|
|
|
|
class ReplayBuffer:
    """Fixed-capacity, thread-safe ring buffer of training transitions.

    Items are appended until ``capacity`` is reached; from then on each new
    item overwrites the oldest stored one. Every operation takes
    ``self.lock`` so actor threads can append while the learner samples.
    """

    def __init__(self, capacity=50000):
        self.capacity = capacity
        self.buffer = []
        # Slot the next overwrite goes to; it advances during the fill phase
        # too, so it wraps to 0 exactly when the buffer becomes full.
        self.ptr = 0
        self.lock = threading.Lock()

    def append(self, item):
        """Store ``item``, overwriting the oldest entry once the buffer is full."""
        with self.lock:
            is_full = len(self.buffer) >= self.capacity
            if is_full:
                self.buffer[self.ptr] = item
            else:
                self.buffer.append(item)
            self.ptr = (self.ptr + 1) % self.capacity

    def __len__(self):
        """Return how many items are currently stored."""
        with self.lock:
            count = len(self.buffer)
        return count

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored items chosen uniformly at random.

        Raises:
            ValueError: from ``random.sample`` when fewer than ``batch_size``
                items are stored.
        """
        with self.lock:
            picked = random.sample(self.buffer, batch_size)
        return picked
|
|
|
# Shared state between the actor threads, the learner thread and the UI.
replay_buffer = ReplayBuffer(50000)
# Last 100 results from the challenger's side: 1 win, -1 loss, 0 draw.
recent_outcomes = collections.deque(maxlen=100)
# Guards the current_* globals shown on the dashboard.
ui_lock = threading.Lock()
# Guards the win/draw counters, training_stats and recent_outcomes.
stats_lock = threading.Lock()

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Reverse vocabulary: token id -> token string.
INV_VOCAB = {token_id: token for token, token_id in VOCAB.items()}

# Created by init_models().
champion = challenger = optimizer = None
|
|
|
def get_stockfish_engine(stockfish_path=None):
    """Start a Stockfish UCI engine, or return None if the binary is missing.

    Args:
        stockfish_path: path to the Stockfish executable. Defaults to the
            bundled Windows build ``stockfish/stockfish-windows-x86-64-avx2.exe``.

    Returns:
        A ``chess.engine.SimpleEngine`` configured with a 64 MB hash and one
        search thread, or ``None`` when ``stockfish_path`` does not exist.
        The caller owns the engine and should ``quit()`` it when done.
    """
    if stockfish_path is None:
        stockfish_path = os.path.join("stockfish", "stockfish-windows-x86-64-avx2.exe")
    if not os.path.exists(stockfish_path):
        return None

    sf_engine = chess.engine.SimpleEngine.popen_uci(stockfish_path)
    try:
        sf_engine.configure({"Hash": 64, "Threads": 1})
    except BaseException:
        # Don't leave the engine subprocess running if configuration fails.
        sf_engine.quit()
        raise
    return sf_engine
|
|
|
def encode_history(history, max_length=120):
    """Encode a token history as a ``(1, max_length)`` long tensor on ``device``.

    Tokens absent from VOCAB map to the ``<unk>`` id (or 0 if VOCAB has no
    ``<unk>``). Histories longer than ``max_length`` keep only their last
    ``max_length`` ids; shorter ones are right-padded with id 0.

    NOTE(review): actor_worker/learner_worker read the model output at
    position -1, which is a padding slot whenever the history is shorter than
    ``max_length``. Confirm that ChessTransformer / its pre-training data
    expect right padding.
    """
    fallback_id = VOCAB.get("<unk>", 0)
    token_ids = [VOCAB.get(tok, fallback_id) for tok in history]

    overflow = len(token_ids) - max_length
    if overflow > 0:
        token_ids = token_ids[-max_length:]
    else:
        token_ids.extend([0] * -overflow)

    return torch.tensor(token_ids, dtype=torch.long, device=device)[None, :]
|
|
|
def sample_move(policy_logits, board, temperature=1.0):
    """Sample a legal UCI move for ``board`` from the model's policy logits.

    Args:
        policy_logits: logits over the token vocabulary. They are indexed as
            ``[0, idx]``, so a leading batch dimension of size 1 is expected
            (callers pass ``p[idx, -1, :].unsqueeze(0)``).
        board: ``chess.Board`` whose side to move is choosing.
        temperature: softmax temperature. Values <= 0 pick the highest-scoring
            legal move deterministically.

    Returns:
        A UCI string that is legal on ``board``. If no legal move has a
        vocabulary entry, or the resulting distribution is degenerate, a
        uniformly random legal move is returned instead.
    """
    legal_moves = list(board.legal_moves)
    legal_ucis = {move.uci() for move in legal_moves}

    # Vocabulary ids whose token is a legal move here. Using a set keeps each
    # membership test O(1) instead of scanning the legal-move list per token.
    legal_ids = [idx for idx, token in INV_VOCAB.items() if token in legal_ucis]
    if not legal_ids:
        # No legal move is representable in the vocabulary.
        return random.choice(legal_moves).uci()

    # Work in float32: under autocast the model may emit float16 logits.
    logits = policy_logits.float()
    # One indexed write rather than a separate device write per legal token.
    mask = torch.full_like(logits, float("-inf"))
    mask[0, torch.tensor(legal_ids, dtype=torch.long, device=logits.device)] = 0.0
    masked_logits = logits + mask

    if temperature <= 0:
        # Dividing by a zero temperature would yield NaNs (and a random move);
        # treat it as greedy selection instead.
        chosen = int(torch.argmax(masked_logits[0]).item())
    else:
        probs = F.softmax(masked_logits / temperature, dim=-1)
        if torch.isnan(probs).any() or probs.sum() == 0:
            return random.choice(legal_moves).uci()
        chosen = torch.multinomial(probs[0], 1).item()

    action = INV_VOCAB.get(chosen)
    if action not in legal_ucis:
        return random.choice(legal_moves).uci()
    return action
|
|
|
def _game_reward(board, final_eval_cp, challenger_is_white):
    """Score a finished self-play game from the challenger's side.

    Returns 1.0 for a challenger win, -1.0 for a champion win and 0.0 for a
    draw. Games that stopped without a chess outcome (eval adjudication or the
    ply cap in actor_worker) are decided by the last White-POV Stockfish eval:
    >= +500 cp is a White win, <= -500 cp a Black win, anything else a draw.
    """
    outcome = board.outcome()
    if outcome is None:
        if final_eval_cp >= 500:
            return 1.0 if challenger_is_white else -1.0
        if final_eval_cp <= -500:
            return -1.0 if challenger_is_white else 1.0
        return 0.0
    if outcome.winner is None:
        return 0.0
    return 1.0 if outcome.winner == challenger_is_white else -1.0


def _normalize_advantages(advantages):
    """Standardize one game's per-move advantages to zero mean and unit std.

    Falls back to plain mean-centering when there are fewer than two values
    (the sample std is undefined) or all values are equal (std is 0).
    """
    adv = torch.tensor(advantages, dtype=torch.float32)
    if adv.numel() > 1 and adv.std() > 0:
        return (adv - adv.mean()) / (adv.std() + 1e-8)
    return adv - adv.mean()


def actor_worker(worker_id):
    """Self-play actor thread: plays batches of 16 challenger-vs-champion games forever.

    Both models are evaluated with one batched forward pass per side per ply.
    Every challenger move is scored with ``evaluate_position`` (Stockfish, or
    0 when no engine is available). The change in eval from the mover's side
    is the move's advantage, and the pre-move eval scaled to [-1, 1] is its
    value target. When a batch ends, results are tallied into the shared
    counters under stats_lock, and each game's (state, action, normalized
    advantage, value target) tuples are pushed to ``replay_buffer``.

    Args:
        worker_id: index of this actor. Only worker 0 publishes game 0 to the
            dashboard globals (and is slowed down slightly to do so).
    """
    global current_game_pgn, current_eval, current_sf_eval, current_challenger_is_white
    global champion_wins, challenger_wins, draws

    BATCH_SIZE = 16
    sf_engine = get_stockfish_engine()
    sf_limit = chess.engine.Limit(time=0.05)

    def evaluate_position(board):
        """Return the Stockfish eval of ``board`` in centipawns from White's side.

        Mates (forced or on the board) are reported as +/-10000. Returns 0.0
        when no engine is available.
        """
        if sf_engine is None:
            return 0.0
        info = sf_engine.analyse(board, sf_limit)
        score = info["score"].white()
        if score.is_mate():
            # score.mate() is 0 both for "White is mated" (Mate(0)) and for
            # "Black is mated" (MateGiven), so its sign cannot tell them apart.
            # score(mate_score=...) is positive exactly when White is mating.
            return 10000 if score.score(mate_score=10000) > 0 else -10000
        return score.score()

    print(f"Actor {worker_id} started with Batch Size {BATCH_SIZE}.")

    try:
        while True:
            boards = [chess.Board() for _ in range(BATCH_SIZE)]
            histories = [["<bos>"] for _ in range(BATCH_SIZE)]
            active = [True] * BATCH_SIZE
            # White-POV centipawn eval of each board's current position.
            evals_current = [evaluate_position(b) for b in boards]
            challenger_is_white = [random.choice([True, False]) for _ in range(BATCH_SIZE)]

            if worker_id == 0:
                with ui_lock:
                    current_challenger_is_white = challenger_is_white[0]

            game_data = [
                {"states": [], "actions": [], "advantages": [], "sf_values": []}
                for _ in range(BATCH_SIZE)
            ]

            # One ply per active game per iteration, capped at 200 plies.
            turn_count = 0
            while any(active) and turn_count < 200:
                turn_count += 1

                challenger_indices = [
                    i for i in range(BATCH_SIZE)
                    if active[i] and boards[i].turn == challenger_is_white[i]
                ]
                champion_indices = [
                    i for i in range(BATCH_SIZE)
                    if active[i] and boards[i].turn != challenger_is_white[i]
                ]

                x_chal = p_chal = v_chal = None
                p_champ = None
                with torch.no_grad():
                    with torch.autocast(device_type='cuda', dtype=torch.float16):
                        if challenger_indices:
                            x_chal = torch.cat([encode_history(histories[i]) for i in challenger_indices], dim=0)
                            p_chal, v_chal = challenger(x_chal)
                        if champion_indices:
                            x_champ = torch.cat([encode_history(histories[i]) for i in champion_indices], dim=0)
                            p_champ, _ = champion(x_champ)

                # Challenger moves: these are the ones recorded for training.
                for idx, i in enumerate(challenger_indices):
                    v = v_chal[idx, -1].item()
                    action = sample_move(p_chal[idx, -1, :].unsqueeze(0), boards[i])

                    if worker_id == 0 and i == 0:
                        with ui_lock:
                            current_game_pgn = boards[0].fen()
                            # v is from the side to move; the eval bar is White-POV.
                            current_eval = v if boards[0].turn else -v
                            current_sf_eval = max(-1.0, min(1.0, evals_current[0] / 500.0))
                        time.sleep(0.05)  # slow game 0 down so the dashboard can follow it

                    # Row idx of x_chal is exactly encode_history(histories[i]).
                    game_data[i]["states"].append(x_chal[idx:idx + 1].cpu())
                    game_data[i]["actions"].append(VOCAB.get(action, 0))

                    boards[i].push_uci(action)
                    eval_next_cp = evaluate_position(boards[i])

                    # After the push it is the opponent's turn, so the mover is `not turn`.
                    mover_is_white = not boards[i].turn
                    delta = eval_next_cp - evals_current[i]
                    game_data[i]["advantages"].append(delta if mover_is_white else -delta)

                    # Value target: pre-move eval, scaled to [-1, 1], from the mover's side.
                    sf_val = max(-1.0, min(1.0, evals_current[i] / 500.0))
                    game_data[i]["sf_values"].append(sf_val if mover_is_white else -sf_val)

                    evals_current[i] = eval_next_cp
                    histories[i].append(action)

                    if abs(eval_next_cp) > 800 or boards[i].is_game_over():
                        active[i] = False

                # Champion moves: played but not recorded.
                for idx, i in enumerate(champion_indices):
                    action = sample_move(p_champ[idx, -1, :].unsqueeze(0), boards[i])
                    boards[i].push_uci(action)
                    evals_current[i] = evaluate_position(boards[i])
                    histories[i].append(action)

                    if abs(evals_current[i]) > 800 or boards[i].is_game_over():
                        active[i] = False

            for i in range(BATCH_SIZE):
                reward = _game_reward(boards[i], evals_current[i], challenger_is_white[i])
                with stats_lock:
                    if reward > 0:
                        challenger_wins += 1
                        recent_outcomes.append(1)
                    elif reward < 0:
                        champion_wins += 1
                        recent_outcomes.append(-1)
                    else:
                        draws += 1
                        recent_outcomes.append(0)

                data = game_data[i]
                if not data["states"]:
                    continue
                adv_tensor = _normalize_advantages(data["advantages"])
                for j, state in enumerate(data["states"]):
                    replay_buffer.append((state, data["actions"][j], adv_tensor[j].item(), data["sf_values"][j]))
    finally:
        # Only reached if the loop raises; don't leave the engine process behind.
        if sf_engine is not None:
            try:
                sf_engine.quit()
            except Exception:
                pass  # best-effort cleanup; let the original error propagate
|
|
|
|
|
def learner_worker():
    """Learner thread: trains the challenger on replay-buffer samples forever.

    Each step samples 128 transitions and minimizes
    ``policy_loss + 0.5 * value_loss`` with float16 autocast and a GradScaler:

    - ``policy_loss`` is the advantage-weighted negative log-probability of
      the played move, using only positive advantages.
    - ``value_loss`` is the MSE against the Stockfish-derived value target.

    The learner also updates ``training_stats``. When at least 50 recent
    games are recorded and the challenger won >= 55% of them, it copies the
    challenger into the champion and resets the outcome window. Checkpoints
    are written to ``rl_weights/`` on promotion and every 100 steps.
    """
    global last_promoted_step
    print("Learner started.")

    batch_size = 128
    scaler = torch.amp.GradScaler('cuda')

    while True:
        if len(replay_buffer) < batch_size:
            time.sleep(1)
            continue

        batch = replay_buffer.sample(batch_size)
        states, actions, advantages, sf_values = zip(*batch)

        s_batch = torch.cat(states).to(device)
        a_batch = torch.tensor(actions, dtype=torch.long, device=device)
        adv_batch = torch.tensor(advantages, dtype=torch.float32, device=device)
        sf_val_batch = torch.tensor(sf_values, dtype=torch.float32, device=device)

        optimizer.zero_grad()

        with torch.autocast(device_type='cuda', dtype=torch.float16):
            p, v_pred = challenger(s_batch)
            p_logits = p[:, -1, :]
            v_pred = v_pred[:, -1].squeeze(-1)

            log_prob = F.log_softmax(p_logits, dim=-1)
            action_log_probs = log_prob[torch.arange(batch_size, device=device), a_batch]

            # Only reinforce moves that improved the position; the rest get zero weight.
            positive_adv = torch.clamp(adv_batch, min=0.0)
            policy_loss = -(action_log_probs * positive_adv).mean()
            value_loss = F.mse_loss(v_pred, sf_val_batch)
            loss = policy_loss + 0.5 * value_loss

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # clip the true (unscaled) gradients
        torch.nn.utils.clip_grad_norm_(challenger.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()

        loss_value = loss.item()
        promoted = False
        win_rate = None
        with stats_lock:
            training_stats["epoch"] += 1
            training_stats["loss"] = loss_value
            step = training_stats["epoch"]

            if recent_outcomes:
                win_rate = sum(1 for x in recent_outcomes if x == 1) / len(recent_outcomes)
                training_stats["reward"] = win_rate

                if len(recent_outcomes) >= 50 and win_rate >= 0.55:
                    last_promoted_step = step
                    champion.load_state_dict(challenger.state_dict())
                    # Start a fresh window: earlier results were against the old champion.
                    recent_outcomes.clear()
                    promoted = True

            recent_win_rate = training_stats["reward"]

        # Logging and checkpoint I/O happen outside stats_lock so the UI and
        # actors are not blocked while files are written.
        if promoted:
            print(f"\n>>> PROMOTING CHALLENGER! Win rate: {win_rate:.2f} <<<\n")
            torch.save({
                "epoch": step,
                "model_state_dict": challenger.state_dict()
            }, "rl_weights/champion_latest.pth")

        if step % 100 == 0:
            print(f"Step {step} | Loss: {loss_value:.4f} | Win Rate: {recent_win_rate:.2f} | Buffer: {len(replay_buffer)}")
            torch.save({
                "epoch": step,
                "model_state_dict": challenger.state_dict()
            }, "rl_weights/challenger_latest.pth")
|
|
|
|
|
def _load_checkpoint(path):
    """Load a checkpoint from ``path`` and return ``(state_dict, epoch)``.

    Accepts either a bare state dict or a dict holding one under
    ``"model_state_dict"`` next to an optional ``"epoch"``. ``epoch`` is
    ``None`` when the file does not record one.
    """
    ckpt = torch.load(path, map_location=device)
    epoch = None
    # Check the type before probing keys: a non-dict checkpoint object may
    # not support `in` / `.get`.
    if isinstance(ckpt, dict):
        epoch = ckpt.get("epoch")
        ckpt = ckpt.get("model_state_dict", ckpt)
    return ckpt, epoch


def init_models():
    """Build the champion/challenger models, load starting weights and create the optimizer.

    Weights are loaded from ``rl_weights/champion_latest.pth`` if it exists;
    in that case its epoch also restores the step counter and the last
    promotion step. Otherwise they come from ``weights/chess_fast_best.pth``,
    and with neither file both models keep their fresh initialization.

    The champion is put in eval mode and the challenger in train mode.
    AdamW (lr=1e-5) is created over the challenger's parameters.
    """
    global champion, challenger, optimizer, last_promoted_step

    os.makedirs("rl_weights", exist_ok=True)

    champion = ChessTransformer(vocab_size=len(VOCAB), d_model=512, nhead=8, num_layers=6, max_length=120).to(device)
    challenger = ChessTransformer(vocab_size=len(VOCAB), d_model=512, nhead=8, num_layers=6, max_length=120).to(device)

    latest_rl_weights = "rl_weights/champion_latest.pth"
    fast_weights = "weights/chess_fast_best.pth"

    if os.path.exists(latest_rl_weights):
        print(f"Loading latest RL champion from {latest_rl_weights}")
        state_dict, epoch = _load_checkpoint(latest_rl_weights)
        if epoch is not None:
            training_stats["epoch"] = epoch
            last_promoted_step = epoch
        champion.load_state_dict(state_dict)
        challenger.load_state_dict(state_dict)
    elif os.path.exists(fast_weights):
        print(f"Loading base Fast weights from {fast_weights}")
        state_dict, epoch = _load_checkpoint(fast_weights)
        if epoch is not None:
            training_stats["epoch"] = epoch
        champion.load_state_dict(state_dict)
        challenger.load_state_dict(state_dict)

    champion.eval()
    challenger.train()

    optimizer = torch.optim.AdamW(challenger.parameters(), lr=1e-5)
|
|
|
|
|
def build_ui():
    """Build the Gradio dashboard that visualizes self-play and training.

    A 0.5 s timer polls two callbacks:

    - ``get_state`` returns a JSON snapshot of the displayed game. It is
      written into a hidden textbox whose change event runs JS to redraw the
      chessboard, both eval bars and the player names.
    - ``get_stats`` returns the training and game counters for a table.

    jQuery and chessboard.js are loaded from CDNs when the page loads.

    Returns:
        The ``gr.Blocks`` app (not yet launched).
    """
    import json

    def get_state():
        """Return a JSON snapshot of the game currently shown on the dashboard."""
        # Copy the shared values under the lock; build the board outside it.
        with ui_lock:
            fen = current_game_pgn
            eval_val = current_eval
            sf_eval_val = current_sf_eval
            challenger_is_white = current_challenger_is_white

        try:
            board = chess.Board(fen)
        except ValueError:
            # No valid position published yet (the initial value is "").
            return json.dumps({"fen": chess.STARTING_FEN, "eval": 0.0, "sf_eval": 0.0, "white_name": "", "black_name": ""})

        return json.dumps({
            "fen": board.fen(),
            "eval": eval_val,
            "sf_eval": sf_eval_val,
            "white_name": "Challenger" if challenger_is_white else "Champion",
            "black_name": "Champion" if challenger_is_white else "Challenger"
        })

    def get_stats():
        """Return [metric, value] rows for the stats table."""
        with stats_lock:
            return [
                ["Training Steps", str(training_stats['epoch'])],
                ["Last Promoted Step", str(last_promoted_step)],
                ["Loss", f"{training_stats['loss']:.4f}"],
                ["Recent Win Rate", f"{training_stats['reward']:.2f}"],
                ["Replay Buffer Size", str(len(replay_buffer))],
                ["Challenger Wins", str(challenger_wins)],
                ["Champion Wins", str(champion_wins)],
                ["Draws", str(draws)]
            ]

    custom_css = """
    @import url('https://fonts.googleapis.com/css2?family=Outfit:wght@400;700&display=swap');

    body, .gradio-container {
        font-family: 'Outfit', sans-serif !important;
        background: linear-gradient(135deg, #0f2027, #203a43, #2c5364) !important;
        background-attachment: fixed !important;
        color: white !important;
    }

    .gradio-container { border: none !important; }

    .glass-panel {
        background: rgba(255, 255, 255, 0.1) !important;
        backdrop-filter: blur(10px) !important;
        border-radius: 12px !important;
        border: 1px solid rgba(255, 255, 255, 0.18) !important;
        padding: 20px !important;
        box-shadow: 0 8px 32px 0 rgba(31, 38, 135, 0.37) !important;
    }

    .eval-bar-container {
        width: 30px;
        height: 400px;
        background-color: #333;
        border-radius: 4px;
        border: 4px solid #fff;
        position: relative;
        overflow: hidden;
        display: flex;
        flex-direction: column-reverse;
        box-shadow: 0 15px 35px rgba(0,0,0,0.5);
    }
    .eval-bar-fill {
        width: 100%;
        height: 50%;
        background-color: #fff;
        transition: height 0.5s cubic-bezier(0.4, 0, 0.2, 1);
    }
    .eval-marker {
        position: absolute;
        top: 50%;
        left: 0;
        width: 100%;
        height: 2px;
        background-color: #ff5e7e;
        z-index: 10;
    }
    """

    with gr.Blocks(title="Neurex RL Dashboard", css=custom_css) as demo:
        gr.HTML("<h1 style='text-align: center; color: white; font-weight: 700; font-size: 2.5rem; text-shadow: 2px 2px 10px rgba(0,0,0,0.5);'>🧠 Neurex RL Self-Play Dashboard</h1>")
        gr.HTML("<p style='text-align: center; color: #ddd; font-size: 1.1rem;'><b>ASYNCHRONOUS ALPHA-ZERO MODE</b> | Real-time Actor-Learner Architecture</p>")

        with gr.Row():
            with gr.Column(elem_classes=["glass-panel"]):
                board_html = """
                <div style="display: flex; flex-direction: column; align-items: center;">
                    <div id="blackName" style="font-size: 1.3rem; font-weight: bold; margin-bottom: 12px; color: #ff5e7e; text-shadow: 1px 1px 5px rgba(0,0,0,0.5);">Black</div>
                    <div style="display: flex; align-items: center; gap: 15px; justify-content: center;">
                        <div class="eval-bar-container" title="Neural Network Evaluation">
                            <div class="eval-bar-fill" id="evalBar"></div>
                            <div class="eval-marker"></div>
                        </div>
                        <div id="board" style="width: 400px; box-shadow: 0 15px 35px rgba(0,0,0,0.5); border: 4px solid #fff; border-radius: 4px; overflow: hidden;"></div>
                        <div class="eval-bar-container" title="Stockfish Evaluation">
                            <div class="eval-bar-fill" id="sfEvalBar" style="background-color: #00ff88;"></div>
                            <div class="eval-marker"></div>
                        </div>
                    </div>
                    <div id="whiteName" style="font-size: 1.3rem; font-weight: bold; margin-top: 12px; color: #00ff88; text-shadow: 1px 1px 5px rgba(0,0,0,0.5);">White</div>
                </div>
                """
                board_view = gr.HTML(board_html)
                current_state_box = gr.Textbox(visible=False)
            with gr.Column(elem_classes=["glass-panel"]):
                stats_view = gr.Dataframe(headers=["Metric", "Value"], interactive=False)

        timer = gr.Timer(0.5)
        timer.tick(get_state, inputs=[], outputs=[current_state_box])
        timer.tick(get_stats, inputs=[], outputs=[stats_view])

        # Runs in the browser whenever the hidden state box changes.
        js_callback = """
        (state_str) => {
            try {
                let state = JSON.parse(state_str);
                if (window.my_board) window.my_board.position(state.fen);

                let evalBar = document.getElementById('evalBar');
                if (evalBar) {
                    let heightPercent = ((state.eval + 1.0) / 2.0) * 100;
                    heightPercent = Math.max(0, Math.min(100, heightPercent));
                    evalBar.style.height = heightPercent + '%';
                }

                let sfEvalBar = document.getElementById('sfEvalBar');
                if (sfEvalBar) {
                    let sfHeightPercent = ((state.sf_eval + 1.0) / 2.0) * 100;
                    sfHeightPercent = Math.max(0, Math.min(100, sfHeightPercent));
                    sfEvalBar.style.height = sfHeightPercent + '%';
                }

                let blackName = document.getElementById('blackName');
                if (blackName && state.black_name) {
                    let dot = state.black_name === "Challenger" ? "🟢" : "🔴";
                    let color = state.black_name === "Challenger" ? "#00ff88" : "#ff5e7e";
                    blackName.innerText = dot + " " + state.black_name + " (Black)";
                    blackName.style.color = color;
                }

                let whiteName = document.getElementById('whiteName');
                if (whiteName && state.white_name) {
                    let dot = state.white_name === "Challenger" ? "🟢" : "🔴";
                    let color = state.white_name === "Challenger" ? "#00ff88" : "#ff5e7e";
                    whiteName.innerText = dot + " " + state.white_name + " (White)";
                    whiteName.style.color = color;
                }
            } catch(e) {}
            return state_str;
        }
        """
        current_state_box.change(None, inputs=[current_state_box], js=js_callback)

        # Loads jQuery, then chessboard.js, then mounts the board once #board exists.
        init_js = """
        function() {
            var jq = document.createElement('script');
            jq.src = "https://code.jquery.com/jquery-3.5.1.min.js";
            document.head.appendChild(jq);

            var css = document.createElement('link');
            css.rel = "stylesheet";
            css.href = "https://unpkg.com/@chrisoakman/chessboardjs@1.0.0/dist/chessboard-1.0.0.min.css";
            document.head.appendChild(css);

            jq.onload = function() {
                var cb = document.createElement('script');
                cb.src = "https://unpkg.com/@chrisoakman/chessboardjs@1.0.0/dist/chessboard-1.0.0.min.js";
                document.head.appendChild(cb);
                cb.onload = function() {
                    let checkExist = setInterval(function() {
                        if (document.getElementById('board')) {
                            window.my_board = Chessboard('board', {
                                position: 'start',
                                pieceTheme: 'https://chessboardjs.com/img/chesspieces/wikipedia/{piece}.png'
                            });
                            clearInterval(checkExist);
                        }
                    }, 100);
                };
            };
        }
        """
        demo.load(None, None, None, js=init_js)

    return demo
|
|
|
if __name__ == "__main__":
    init_models()

    # Four self-play actors feed the replay buffer.
    actor_threads = [
        threading.Thread(target=actor_worker, args=(worker_id,), daemon=True)
        for worker_id in range(4)
    ]
    for actor_thread in actor_threads:
        actor_thread.start()

    # A single learner consumes the buffer and trains the challenger.
    learner_thread = threading.Thread(target=learner_worker, daemon=True)
    learner_thread.start()

    # Serve the dashboard. With prevent_thread_lock=False, launch() keeps the
    # main thread (and with it the daemon workers) alive.
    demo = build_ui()
    demo.launch(server_name="0.0.0.0", prevent_thread_lock=False)
|
|
|