# NOTE FOR COLAB USERS: Run in a separate cell first: # !pip -q install chess numpy torch matplotlib pandas """ Aggressive GRPO Chess Agent — T4/Colab Optimized """ import os, sys, csv, time, math, shutil, argparse, random import numpy as np import pandas as pd import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt try: import chess except ImportError: os.system("pip install -q chess") import chess import torch import torch.nn as nn import torch.nn.functional as F # ── Hardware flags ───────────────────────────────────────────────────────────── torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True if hasattr(torch, 'set_float32_matmul_precision'): torch.set_float32_matmul_precision('high') # ── Constants ────────────────────────────────────────────────────────────────── PIECE_VAL = { chess.PAWN: 1.0, chess.KNIGHT: 3.0, chess.BISHOP: 3.2, chess.ROOK: 5.0, chess.QUEEN: 9.0, chess.KING: 0.0, } RANDOM_BASELINE_ELO = 800 # estimated ELO of uniform-random player CONFIG = { "num_envs": 256, "grpo_group_size": 8, # G envs per group, all start from same opening position "ppo_epochs": 3, "mini_batch_size": 4096, "learning_rate": 2e-4, "weight_decay": 1e-4, "gamma": 0.98, # lower → discount future more → prefer fast wins "clip_epsilon": 0.15, "entropy_coef": 0.02, # low → exploit aggressive lines "value_coef": 0.5, "max_steps": 100, "opening_max_moves": 10, # randomize opening for GRPO diversity "checkpoint_dir": "./checkpoints", "save_interval": 50, "log_interval": 1, "elo_eval_interval": 100, # evaluate ELO every N iterations "elo_eval_games": 32, "max_runtime_hours": 4.5, # auto-save + download before Colab kills session "device": "cuda" if torch.cuda.is_available() else "cpu", "seed": 42, } # ── Action Space ─────────────────────────────────────────────────────────────── class ActionMapper: __slots__ = ['move_to_idx', 'idx_to_move', 'num_actions'] def __init__(self): self.move_to_idx: dict[str, int] = {} self.idx_to_move: list[str] = [] idx = 0 for f in range(64): for t in range(64): if f == t: continue uci = chess.SQUARE_NAMES[f] + chess.SQUARE_NAMES[t] self.move_to_idx[uci] = idx self.idx_to_move.append(uci) idx += 1 if chess.square_rank(f) in (1, 6) and \ abs(chess.square_file(f) - chess.square_file(t)) <= 1: for promo in "nbrq": puci = uci + promo self.move_to_idx[puci] = idx self.idx_to_move.append(puci) idx += 1 self.num_actions = idx ACTION_MAPPER = ActionMapper() # ── Board Encoding ───────────────────────────────────────────────────────────── def populate_states_fast(envs: list, active_mask: np.ndarray, bbs_np: np.ndarray, meta_np: np.ndarray) -> None: """Fill bbs_np [B,12] int64 and meta_np [B,3] float32 for active envs.""" for b in range(len(envs)): if not active_mask[b]: continue env = envs[b] w = env.occupied_co[chess.WHITE] bc = env.occupied_co[chess.BLACK] bbs_np[b, 0] = env.pawns & w; bbs_np[b, 1] = env.knights & w bbs_np[b, 2] = env.bishops & w; bbs_np[b, 3] = env.rooks & w bbs_np[b, 4] = env.queens & w; bbs_np[b, 5] = env.kings & w bbs_np[b, 6] = env.pawns & bc; bbs_np[b, 7] = env.knights & bc bbs_np[b, 8] = env.bishops & bc; bbs_np[b, 9] = env.rooks & bc bbs_np[b, 10] = env.queens & bc; bbs_np[b, 11] = env.kings & bc meta_np[b, 0] = 1.0 if env.turn else -1.0 meta_np[b, 1] = float(env.castling_rights) / 15.0 # [0,1] meta_np[b, 2] = 1.0 if env.ep_square is not None else 0.0 def get_legal_masks(envs: list, active_mask: np.ndarray): masks = np.zeros((len(envs), ACTION_MAPPER.num_actions), dtype=np.bool_) moves_list = [None] * len(envs) for b in range(len(envs)): if not active_mask[b]: continue legal = list(envs[b].legal_moves) moves_list[b] = legal for m in legal: masks[b, ACTION_MAPPER.move_to_idx[m.uci()]] = True return masks, moves_list # ── Neural Network ───────────────────────────────────────────────────────────── class ChessNet(nn.Module): def __init__(self, res_blocks: int = 8, channels: int = 128): super().__init__() self.conv_in = nn.Conv2d(14, channels, 3, padding=1, bias=False) self.bn_in = nn.BatchNorm2d(channels) self.res_blocks = nn.ModuleList([ nn.Sequential( nn.Conv2d(channels, channels, 3, padding=1, bias=False), nn.BatchNorm2d(channels), nn.ReLU(inplace=True), nn.Conv2d(channels, channels, 3, padding=1, bias=False), nn.BatchNorm2d(channels), ) for _ in range(res_blocks) ]) self.policy_head = nn.Sequential( nn.Conv2d(channels, 32, 1, bias=False), nn.BatchNorm2d(32), nn.ReLU(inplace=True), nn.Flatten(), nn.Linear(32 * 64, ACTION_MAPPER.num_actions), ) # No Tanh — shaped rewards exceed [-1,1]; unbounded linear output self.value_head = nn.Sequential( nn.Conv2d(channels, 32, 1, bias=False), nn.BatchNorm2d(32), nn.ReLU(inplace=True), nn.Flatten(), nn.Linear(32 * 64, 256), nn.ReLU(inplace=True), nn.Linear(256, 1), ) def forward(self, x): x = F.relu(self.bn_in(self.conv_in(x)), inplace=True) for blk in self.res_blocks: x = F.relu(x + blk(x), inplace=True) return self.policy_head(x), self.value_head(x) # ── ELO Tracker ─────────────────────────────────────────────────────────────── class ELOTracker: def __init__(self, initial_elo: float = 1200.0, K: float = 32.0): self.elo = initial_elo self.K = K def expected(self, opp_elo: float) -> float: return 1.0 / (1.0 + 10.0 ** ((opp_elo - self.elo) / 400.0)) def update(self, score: float, opp_elo: float) -> None: self.elo += self.K * (score - self.expected(opp_elo)) # ── Opening Position Generator ───────────────────────────────────────────────── def get_opening_position(max_moves: int = 10) -> chess.Board: """Play 0..max_moves random half-moves from start for GRPO diversity.""" board = chess.Board() for _ in range(random.randint(0, max_moves)): if board.is_game_over(): break board.push(random.choice(list(board.legal_moves))) return chess.Board(board.fen()) # detached copy # ── Auto-download ────────────────────────────────────────────────────────────── def auto_download(checkpoint_dir: str) -> None: """Sync to Google Drive if mounted, else trigger browser downloads.""" try: from google.colab import files as _cf drive_dst = '/content/drive/MyDrive/chess_agent' if os.path.exists('/content/drive/MyDrive'): os.makedirs(drive_dst, exist_ok=True) shutil.copytree(checkpoint_dir, drive_dst, dirs_exist_ok=True) print(f"[AutoSave] Synced → {drive_dst}") else: for fname in ['best.pt', 'latest.pt', 'training_log.csv', 'elo_log.csv', 'training_performance.png']: fpath = os.path.join(checkpoint_dir, fname) if os.path.exists(fpath): _cf.download(fpath) print(f"[AutoSave] Downloaded {fname}") except Exception as e: print(f"[AutoSave] {e}") # ── GRPO Trainer ─────────────────────────────────────────────────────────────── class GRPOTrainer: def __init__(self): self.device = CONFIG["device"] _model = ChessNet(res_blocks=8, channels=128) _model = _model.to(self.device).to(memory_format=torch.channels_last) try: print("Compiling model (reduce-overhead)…") self.model = torch.compile(_model, mode="reduce-overhead") except Exception: self.model = _model self.optimizer = torch.optim.AdamW( self.model.parameters(), lr=CONFIG["learning_rate"], weight_decay=CONFIG["weight_decay"], fused=torch.cuda.is_available(), ) self.scaler = torch.amp.GradScaler('cuda') self.start_iter = 0 self.best_win_rate = 0.0 self.elo_tracker = ELOTracker() # Shared shift tensor for bit-unpacking (avoid repeated allocation) self.shifts = torch.arange(64, dtype=torch.int64, device=self.device).view(1, 1, 64) os.makedirs(CONFIG["checkpoint_dir"], exist_ok=True) self.log_file = os.path.join(CONFIG["checkpoint_dir"], "training_log.csv") self.elo_log_file = os.path.join(CONFIG["checkpoint_dir"], "elo_log.csv") if not os.path.exists(self.log_file): with open(self.log_file, "w", newline="") as f: csv.writer(f).writerow([ "iteration", "p_loss", "v_loss", "v_mean", "fps", "win_rate", "draw_rate", "check_rate", "capture_rate", "avg_game_len", ]) if not os.path.exists(self.elo_log_file): with open(self.elo_log_file, "w", newline="") as f: csv.writer(f).writerow( ["iteration", "elo", "eval_wins", "eval_draws", "eval_losses"]) self._init_checkpointing() # ── Checkpointing ────────────────────────────────────────────────────────── def _init_checkpointing(self) -> None: latest = os.path.join(CONFIG["checkpoint_dir"], "latest.pt") if not os.path.exists(latest): return try: ckpt = torch.load(latest, map_location=self.device, weights_only=False) sd = ckpt['model_state_dict'] # Handle compiled (_orig_mod. prefix) vs uncompiled state dicts loaded = False for attempt in [ sd, {k.replace('_orig_mod.', ''): v for k, v in sd.items()}, {'_orig_mod.' + k: v for k, v in sd.items()}, ]: try: self.model.load_state_dict(attempt); loaded = True; break except RuntimeError: continue if not loaded: raise RuntimeError("All state dict key variants failed.") self.optimizer.load_state_dict(ckpt['optimizer_state_dict']) self.scaler.load_state_dict(ckpt['scaler_state_dict']) self.start_iter = ckpt.get('iteration', 0) + 1 self.elo_tracker.elo = ckpt.get('elo', 1200.0) self.best_win_rate = ckpt.get('best_win_rate', 0.0) print(f"Resumed from iter {self.start_iter} | " f"ELO {self.elo_tracker.elo:.0f} | best_win {self.best_win_rate:.3f}") except Exception as e: print(f"Checkpoint load failed ({e}). Starting fresh.") def save_checkpoint(self, iteration: int, is_best: bool = False) -> None: ckpt = { 'iteration': iteration, 'model_state_dict': self.model.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict(), 'scaler_state_dict': self.scaler.state_dict(), 'elo': self.elo_tracker.elo, 'best_win_rate': self.best_win_rate, 'config': CONFIG, } cdir = CONFIG["checkpoint_dir"] path = os.path.join(cdir, f"iter_{iteration:04d}.pt") # Atomic write: write to .tmp then os.replace (single syscall, crash-safe) torch.save(ckpt, path + ".tmp"); os.replace(path + ".tmp", path) latest = os.path.join(cdir, "latest.pt") shutil.copy2(path, latest + ".tmp"); os.replace(latest + ".tmp", latest) if is_best: best = os.path.join(cdir, "best.pt") shutil.copy2(path, best + ".tmp"); os.replace(best + ".tmp", best) # ── ELO Evaluation (batched, greedy) ────────────────────────────────────── def _elo_game_done(self, board: chess.Board, idx: int, agent_color, scores: np.ndarray, active: np.ndarray) -> None: if board.is_game_over(): res = board.result() if (res == "1-0" and agent_color == chess.WHITE) or \ (res == "0-1" and agent_color == chess.BLACK): scores[idx] = 1.0 elif res == "1/2-1/2": scores[idx] = 0.5 else: scores[idx] = 0.0 active[idx] = False def evaluate_elo(self, n_games: int = 32, max_ply: int = 200) -> tuple: """ Play n_games vs random opponent (batched GPU for agent moves). Half games as White, half as Black. Returns (wins, draws, losses) from agent's perspective. """ self.model.eval() boards = [chess.Board() for _ in range(n_games)] agent_colors = [chess.WHITE if i % 2 == 0 else chess.BLACK for i in range(n_games)] scores = np.full(n_games, 0.5, dtype=np.float32) # default: draw active = np.ones(n_games, dtype=bool) bbs_sub = np.zeros((n_games, 12), dtype=np.int64) meta_sub= np.zeros((n_games, 3), dtype=np.float32) for _ in range(max_ply): if not active.any(): break # Random moves (opponent turns) — CPU for i in [i for i in range(n_games) if active[i] and boards[i].turn != agent_colors[i]]: legal = list(boards[i].legal_moves) if legal: boards[i].push(random.choice(legal)) self._elo_game_done(boards[i], i, agent_colors[i], scores, active) # Agent moves (batched GPU) ag_idx = [i for i in range(n_games) if active[i] and boards[i].turn == agent_colors[i]] if not ag_idx: continue n = len(ag_idx) sub = [boards[i] for i in ag_idx] act_sub = np.ones(n, dtype=bool) populate_states_fast(sub, act_sub, bbs_sub[:n], meta_sub[:n]) bbs_t = torch.tensor(bbs_sub[:n], dtype=torch.int64, device=self.device) unpacked = ((bbs_t.unsqueeze(-1) >> self.shifts) & 1).float().view(n, 12, 8, 8) state = torch.zeros(n, 14, 8, 8, device=self.device, dtype=torch.float32) state[:, :12] = unpacked state[:, 12] = torch.tensor(meta_sub[:n, 0], device=self.device).view(n, 1, 1).expand(n, 8, 8) state[:, 13] = torch.tensor(meta_sub[:n, 1], device=self.device).view(n, 1, 1).expand(n, 8, 8) for lj in range(n): if meta_sub[lj, 2]: state[lj, 13, 0, 1] = float(meta_sub[lj, 2]) with torch.no_grad(), torch.amp.autocast('cuda'): logits, _ = self.model(state.to(memory_format=torch.channels_last)) logits = logits.float() masks_np, legal_lists = get_legal_masks(sub, act_sub) masks_t = torch.tensor(masks_np, dtype=torch.bool, device=self.device) logits = torch.where(masks_t, logits, torch.tensor(-60000.0, device=self.device)) best_acts = logits.argmax(dim=-1).cpu().numpy() # greedy for evaluation for lj, gi in enumerate(ag_idx): if not active[gi]: continue move_uci = ACTION_MAPPER.idx_to_move[best_acts[lj]] move = chess.Move.from_uci(move_uci) legal = legal_lists[lj] or list(boards[gi].legal_moves) if not legal: active[gi] = False; continue if move not in legal: move = random.choice(legal) boards[gi].push(move) self._elo_game_done(boards[gi], gi, agent_colors[gi], scores, active) wins = int((scores == 1.0).sum()) draws = int((scores == 0.5).sum()) losses = int((scores == 0.0).sum()) for s in scores: self.elo_tracker.update(float(s), RANDOM_BASELINE_ELO) return wins, draws, losses # ── Main Training Loop ───────────────────────────────────────────────────── def train(self, num_iterations: int) -> None: B = CONFIG["num_envs"] max_steps = CONFIG["max_steps"] G = CONFIG["grpo_group_size"] num_groups= B // G gamma = CONFIG["gamma"] t_start = time.time() max_rt = CONFIG["max_runtime_hours"] * 3600.0 # ── Preallocate GPU buffers (int8/bool minimizes VRAM footprint) ────── states_buf = torch.zeros((max_steps, B, 14, 8, 8), dtype=torch.int8, device=self.device) actions_buf = torch.zeros((max_steps, B), dtype=torch.int16, device=self.device) logprobs_buf= torch.zeros((max_steps, B), dtype=torch.float32, device=self.device) values_buf = torch.zeros((max_steps, B), dtype=torch.float32, device=self.device) rewards_buf = torch.zeros((max_steps, B), dtype=torch.float32, device=self.device) dones_buf = torch.zeros((max_steps, B), dtype=torch.bool, device=self.device) active_buf = torch.zeros((max_steps, B), dtype=torch.bool, device=self.device) bbs_np = np.zeros((B, 12), dtype=np.int64) # int64: no astype copy needed meta_np = np.zeros((B, 3), dtype=np.float32) vram_gb = (torch.cuda.get_device_properties(0).total_memory / 1e9 if torch.cuda.is_available() else 0.0) print(f"\n🚀 Aggressive GRPO Chess Agent") print(f" Envs:{B} | Groups:{num_groups}×G:{G} | Device:{self.device.upper()} | " f"VRAM:{vram_gb:.1f}GB") print(f" Reward: capture(0-0.3)+check(0.3)+checkmate_speed(1.0-1.5)" f"+draw_penalty(-0.5)+time(-0.003/step)") print(f" gamma:{gamma} | entropy:{CONFIG['entropy_coef']} | " f"lr:{CONFIG['learning_rate']}") for iteration in range(self.start_iter, num_iterations): # ── Runtime guard ────────────────────────────────────────────── elapsed = time.time() - t_start if elapsed > max_rt: print(f"\n⏱ {elapsed/3600:.2f}h reached. Saving & downloading…") self.save_checkpoint(iteration) self.plot_metrics() auto_download(CONFIG["checkpoint_dir"]) break iter_start = time.time() # Zero buffers in-place (no reallocation) states_buf.zero_(); actions_buf.zero_(); logprobs_buf.zero_() values_buf.zero_(); rewards_buf.zero_() dones_buf.fill_(False); active_buf.fill_(False) # ── GRPO: each group of G envs shares an opening position ────── fens = [get_opening_position(CONFIG["opening_max_moves"]).fen() for _ in range(num_groups)] envs: list[chess.Board] = [] for gi in range(num_groups): for _ in range(G): envs.append(chess.Board(fens[gi])) active = np.ones(B, dtype=bool) game_lengths = np.zeros(B, dtype=np.int32) # Per-iteration attack metrics white_wins = black_wins = draws_count = 0 total_checks = total_captures = 0 # ── PHASE 1: ROLLOUT ─────────────────────────────────────────── for t in range(max_steps): if not active.any(): break populate_states_fast(envs, active, bbs_np, meta_np) # Bit-unpack bitboards → int8 state tensor (no float copy) bbs_t = torch.as_tensor(bbs_np, dtype=torch.int64, device=self.device) unpacked = ((bbs_t.unsqueeze(-1) >> self.shifts) & 1).to(torch.int8) meta_t = torch.as_tensor(meta_np, dtype=torch.float32, device=self.device) # Pack into int8 buffer (scale float meta to [-127,127]) states_buf[t, :, :12, :, :] = unpacked.view(B, 12, 8, 8) states_buf[t, :, 12, :, :] = (meta_t[:, 0] * 127).clamp(-127, 127) \ .to(torch.int8).view(B, 1, 1).expand(B, 8, 8) states_buf[t, :, 13, :, :] = (meta_t[:, 1] * 127).clamp(0, 127) \ .to(torch.int8).view(B, 1, 1).expand(B, 8, 8) states_buf[t, :, 13, 0, 1]= (meta_t[:, 2] * 127).clamp(0, 127).to(torch.int8) active_buf[t] = torch.as_tensor(active, dtype=torch.bool, device=self.device) # Normalize int8→float32 for forward pass model_input = states_buf[t].to( dtype=torch.float32, memory_format=torch.channels_last) / 127.0 self.model.eval() with torch.no_grad(), torch.amp.autocast('cuda'): logits, values = self.model(model_input) masks_np, legal_moves_list = get_legal_masks(envs, active) masks_t = torch.as_tensor(masks_np, dtype=torch.bool, device=self.device) logits = logits.float() logits = torch.where(masks_t, logits, torch.tensor(-60000.0, device=self.device)) no_legal = ~masks_t.any(dim=-1, keepdim=True) logits.masked_fill_(no_legal, 0.0) probs = F.softmax(logits, dim=-1) dist = torch.distributions.Categorical(probs) actions = dist.sample() actions_buf[t] = actions.to(torch.int16) logprobs_buf[t] = dist.log_prob(actions) values_buf[t] = values.squeeze(-1) actions_cpu = actions.cpu().numpy() for b in range(B): if not active[b]: continue move_uci = ACTION_MAPPER.idx_to_move[actions_cpu[b]] move = chess.Move.from_uci(move_uci) if move not in legal_moves_list[b]: move = random.choice(legal_moves_list[b]) board = envs[b] mover_is_white = (board.turn == chess.WHITE) sign = 1.0 if mover_is_white else -1.0 # ── Reward: pre-push components ───────────────────── r = -0.003 * sign # time penalty (per-mover, white-perspective) if board.is_capture(move): if board.is_en_passant(move): cap_val = 1.0 else: cp = board.piece_at(move.to_square) cap_val = PIECE_VAL.get(cp.piece_type, 0.0) if cp else 0.0 r += sign * (cap_val / 9.0) * 0.3 # [0, 0.3] total_captures += 1 if move.promotion in (chess.QUEEN, chess.ROOK): r += sign * 0.15 # aggressive promotion board.push(move) game_lengths[b] += 1 # ── Reward: post-push components ──────────────────── if board.is_check(): r += sign * 0.3 # gave check total_checks += 1 if board.is_game_over(): if board.is_checkmate(): # Mover delivered checkmate speed_bonus = 0.5 * math.exp(-game_lengths[b] / 20.0) r += sign * (1.0 + speed_bonus) # ~1.0-1.5 if mover_is_white: white_wins += 1 else: black_wins += 1 else: # Draw (stalemate / 50-move / repetition / insufficient material) r -= 0.5 # flat penalty from white's perspective — attack to WIN draws_count += 1 dones_buf[t, b] = True active[b] = False rewards_buf[t, b] = r # end per-env loop # end rollout # ── PHASE 2: VECTORIZED RETURNS ──────────────────────────────── returns = torch.zeros(B, dtype=torch.float32, device=self.device) returns_buf = torch.zeros((max_steps, B), dtype=torch.float32, device=self.device) not_done_f = (~dones_buf).float() for step in reversed(range(max_steps)): returns = rewards_buf[step] + gamma * returns * not_done_f[step] returns_buf[step]= returns # ── PHASE 3: GRPO GROUP-WISE ADVANTAGE NORMALIZATION ─────────── # advantages shape [max_steps, B] adv_raw = returns_buf - values_buf active_f = active_buf.float() # Reshape to [max_steps, num_groups, G] and normalize within each group adv_3d = adv_raw.view(max_steps, num_groups, G) act_3d = active_f.view(max_steps, num_groups, G) g_count = act_3d.sum(dim=[0, 2]).clamp(min=1.0) # [num_groups] g_mean = (adv_3d * act_3d).sum(dim=[0, 2]) / g_count # [num_groups] g_sq_diff = ((adv_3d - g_mean.view(1, num_groups, 1)) ** 2 * act_3d).sum(dim=[0, 2]) g_std = (g_sq_diff / g_count).sqrt().clamp(min=1e-8) # [num_groups] adv_3d = (adv_3d - g_mean.view(1, num_groups, 1)) / \ g_std.view(1, num_groups, 1) adv_norm = adv_3d.view(max_steps, B) # Flatten, filter to active steps only valid_mask = active_buf.view(-1) flat_states = (states_buf.view(-1, 14, 8, 8)[valid_mask] .to(torch.float32, memory_format=torch.channels_last) .div_(127.0)) flat_actions = actions_buf.view(-1)[valid_mask].to(torch.int64) flat_old_lp = logprobs_buf.view(-1)[valid_mask] flat_returns = returns_buf.view(-1)[valid_mask] flat_advantages = adv_norm.view(-1)[valid_mask] dataset_size = flat_states.size(0) if dataset_size < 100: continue # skip degenerate rollout (all games ended instantly) # ── PHASE 4: PPO OPTIMIZATION ────────────────────────────────── self.model.train() total_p_loss = total_v_loss = 0.0 num_updates = 0 mb_size = CONFIG["mini_batch_size"] for _ in range(CONFIG["ppo_epochs"]): perm = torch.randperm(dataset_size, device=self.device) for start in range(0, dataset_size, mb_size): mb = perm[start: start + mb_size] with torch.amp.autocast('cuda'): new_logits, new_vals = self.model(flat_states[mb]) new_dist = torch.distributions.Categorical(logits=new_logits) new_lp = new_dist.log_prob(flat_actions[mb]) ratio = torch.exp(new_lp - flat_old_lp[mb]) adv = flat_advantages[mb] surr1 = ratio * adv surr2 = torch.clamp( ratio, 1.0 - CONFIG["clip_epsilon"], 1.0 + CONFIG["clip_epsilon"], ) * adv p_loss = -torch.min(surr1, surr2).mean() v_loss = F.mse_loss(new_vals.squeeze(-1), flat_returns[mb]) entropy = new_dist.entropy().mean() loss = (p_loss + CONFIG["value_coef"] * v_loss - CONFIG["entropy_coef"] * entropy) self.optimizer.zero_grad(set_to_none=True) self.scaler.scale(loss).backward() self.scaler.unscale_(self.optimizer) nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) self.scaler.step(self.optimizer) self.scaler.update() total_p_loss += p_loss.item() total_v_loss += v_loss.item() num_updates += 1 # ── PHASE 5: METRICS & LOGGING ──────────────────────────────── done_count = white_wins + black_wins + draws_count win_rate = white_wins / max(done_count, 1) draw_rate = draws_count / max(done_count, 1) active_steps = int(active_buf.sum().item()) check_rate = total_checks / max(active_steps, 1) capture_rate = total_captures / max(active_steps, 1) avg_game_len = float(game_lengths.mean()) fps = dataset_size / max(time.time() - iter_start, 1e-3) if (iteration + 1) % CONFIG["log_interval"] == 0: vram_alloc = (torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0) vram_res = (torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0.0) print( f"[{iteration+1:05d}] " f"P:{total_p_loss/max(1,num_updates):.4f} " f"V:{total_v_loss/max(1,num_updates):.4f} | " f"W:{win_rate:.3f} D:{draw_rate:.3f} " f"Chk:{check_rate:.4f} Cap:{capture_rate:.4f} " f"Len:{avg_game_len:.1f} | " f"ELO:{self.elo_tracker.elo:.0f} | " f"FPS:{fps:.0f} | " f"VRAM:{vram_alloc:.2f}/{vram_res:.2f}GB" ) with open(self.log_file, "a", newline="") as f: csv.writer(f).writerow([ iteration + 1, total_p_loss / max(1, num_updates), total_v_loss / max(1, num_updates), flat_returns.mean().item(), fps, win_rate, draw_rate, check_rate, capture_rate, avg_game_len, ]) # Save best checkpoint when win_rate improves if win_rate > self.best_win_rate: self.best_win_rate = win_rate self.save_checkpoint(iteration + 1, is_best=True) if (iteration + 1) % CONFIG["save_interval"] == 0: self.save_checkpoint(iteration + 1) self.plot_metrics() # ELO evaluation if (iteration + 1) % CONFIG["elo_eval_interval"] == 0: elo_before = self.elo_tracker.elo ew, ed, el = self.evaluate_elo(CONFIG["elo_eval_games"]) print( f" [ELO eval] {elo_before:.0f} → {self.elo_tracker.elo:.0f} | " f"W:{ew} D:{ed} L:{el} vs random({RANDOM_BASELINE_ELO})" ) with open(self.elo_log_file, "a", newline="") as f: csv.writer(f).writerow( [iteration + 1, self.elo_tracker.elo, ew, ed, el]) self.plot_metrics() # Aggressive cache reclaim (free fragmented blocks, not pinned allocs) torch.cuda.empty_cache() # ── Plotting ─────────────────────────────────────────────────────────────── def plot_metrics(self) -> None: if not os.path.exists(self.log_file): return df = pd.read_csv(self.log_file) if len(df) < 2: return elo_df = None if os.path.exists(self.elo_log_file): elo_df = pd.read_csv(self.elo_log_file) fig, axs = plt.subplots(3, 2, figsize=(14, 12)) fig.suptitle("Aggressive GRPO Chess Agent — Training Dashboard", fontsize=14) # Row 0: Losses axs[0, 0].plot(df['iteration'], df['p_loss'], color='steelblue', linewidth=1.2) axs[0, 0].set_title('Policy Loss'); axs[0, 0].set_xlabel('Iteration') axs[0, 1].plot(df['iteration'], df['v_loss'], color='tomato', linewidth=1.2) axs[0, 1].set_title('Value Loss'); axs[0, 1].set_xlabel('Iteration') # Row 1: Outcomes axs[1, 0].plot(df['iteration'], df['win_rate'], label='Win', color='green') axs[1, 0].plot(df['iteration'], df['draw_rate'], label='Draw', color='orange') axs[1, 0].set_title('Outcomes (White perspective)') axs[1, 0].legend(); axs[1, 0].set_xlabel('Iteration') # Row 1: Attack metrics axs[1, 1].plot(df['iteration'], df['check_rate'], label='Check/step', color='purple') axs[1, 1].plot(df['iteration'], df['capture_rate'], label='Capture/step', color='darkorange') axs[1, 1].set_title('Attack Metrics (↑ = more aggressive)') axs[1, 1].legend(); axs[1, 1].set_xlabel('Iteration') # Row 2: ELO Rating if elo_df is not None and len(elo_df) > 0: axs[2, 0].plot(elo_df['iteration'], elo_df['elo'], color='gold', linewidth=2.0, label='Agent ELO') axs[2, 0].axhline(RANDOM_BASELINE_ELO, linestyle='--', color='gray', alpha=0.8, label=f'Random ({RANDOM_BASELINE_ELO})') axs[2, 0].axhline(1200, linestyle=':', color='lightblue', alpha=0.6, label='Start (1200)') axs[2, 0].fill_between(elo_df['iteration'], RANDOM_BASELINE_ELO, elo_df['elo'], alpha=0.15, color='gold') axs[2, 0].set_title('ELO Rating vs Random Baseline') axs[2, 0].legend(); axs[2, 0].set_xlabel('Iteration') else: axs[2, 0].text(0.5, 0.5, f'ELO eval every {CONFIG["elo_eval_interval"]} iters', ha='center', va='center', transform=axs[2, 0].transAxes, color='gray', fontsize=11) axs[2, 0].set_title('ELO Rating (pending)') # Row 2: Average game length axs[2, 1].plot(df['iteration'], df['avg_game_len'], color='teal', linewidth=1.2) axs[2, 1].set_title('Avg Game Length (↓ = faster checkmates)') axs[2, 1].set_xlabel('Iteration') for ax in axs.flat: ax.grid(True, alpha=0.25) plt.tight_layout() out = os.path.join(CONFIG["checkpoint_dir"], "training_performance.png") plt.savefig(out, dpi=100, bbox_inches='tight') plt.close(fig) print(f" [Plot] saved → {out}") # ── Entry Point ──────────────────────────────────────────────────────────────── if __name__ == "__main__": parser = argparse.ArgumentParser( description="Aggressive GRPO Chess Agent (T4/Colab)") parser.add_argument("--iterations", type=int, default=10000, help="Total training iterations") parser.add_argument("--test-batch", action="store_true", help="Run 2 iterations for smoke-test") args, _ = parser.parse_known_args() torch.manual_seed(CONFIG["seed"]) np.random.seed(CONFIG["seed"]) random.seed(CONFIG["seed"]) # Print VRAM summary at startup if torch.cuda.is_available(): props = torch.cuda.get_device_properties(0) print(f"GPU: {props.name} | VRAM: {props.total_memory/1e9:.1f}GB | " f"SM: {props.multi_processor_count} | " f"Compute: {props.major}.{props.minor}") trainer = GRPOTrainer() trainer.train(2 if args.test_batch else args.iterations)