# GRPO_chessagent / model_aggressive.py
# NOTE FOR COLAB USERS: Run in a separate cell first:
# !pip -q install chess numpy torch matplotlib pandas
"""
Aggressive GRPO Chess Agent — T4/Colab Optimized
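
Typical invocation (CLI flags are defined in the __main__ block at the bottom of this file):
    python model_aggressive.py --iterations 10000   # default iteration count
    python model_aggressive.py --test-batch         # 2-iteration smoke test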
"""
import os, sys, csv, time, math, shutil, argparse, random
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
try:
import chess
except ImportError:
os.system("pip install -q chess")
import chess
import torch
import torch.nn as nn
import torch.nn.functional as F
# ── Hardware flags ─────────────────────────────────────────────────────────────
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
if hasattr(torch, 'set_float32_matmul_precision'):
torch.set_float32_matmul_precision('high')
# ── Constants ──────────────────────────────────────────────────────────────────
PIECE_VAL = {
chess.PAWN: 1.0, chess.KNIGHT: 3.0, chess.BISHOP: 3.2,
chess.ROOK: 5.0, chess.QUEEN: 9.0, chess.KING: 0.0,
}
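# These values feed the capture-shaping term in the rollout: reward += (value / 9) * 0.3,
# so winning a queen earns the full +0.3 capture bonus.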
RANDOM_BASELINE_ELO = 800 # estimated ELO of uniform-random player
CONFIG = {
"num_envs": 256,
"grpo_group_size": 8, # G envs per group, all start from same opening position
"ppo_epochs": 3,
"mini_batch_size": 4096,
"learning_rate": 2e-4,
"weight_decay": 1e-4,
"gamma": 0.98, # lower β†’ discount future more β†’ prefer fast wins
"clip_epsilon": 0.15,
"entropy_coef": 0.02, # low β†’ exploit aggressive lines
"value_coef": 0.5,
"max_steps": 100,
"opening_max_moves": 10, # randomize opening for GRPO diversity
"checkpoint_dir": "./checkpoints",
"save_interval": 50,
"log_interval": 1,
"elo_eval_interval": 100, # evaluate ELO every N iterations
"elo_eval_games": 32,
"max_runtime_hours": 4.5, # auto-save + download before Colab kills session
"device": "cuda" if torch.cuda.is_available() else "cpu",
"seed": 42,
}
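# With these defaults: 256 envs / group size 8 = 32 GRPO groups per iteration,
# each group rolling out up to 100 plies from a shared random opening.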
# ── Action Space ───────────────────────────────────────────────────────────────
class ActionMapper:
__slots__ = ['move_to_idx', 'idx_to_move', 'num_actions']
def __init__(self):
self.move_to_idx: dict[str, int] = {}
self.idx_to_move: list[str] = []
idx = 0
for f in range(64):
for t in range(64):
if f == t: continue
uci = chess.SQUARE_NAMES[f] + chess.SQUARE_NAMES[t]
self.move_to_idx[uci] = idx
self.idx_to_move.append(uci)
idx += 1
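# Promotion actions are over-generated: every source square on the 2nd/7th rank gets
# 4 promotion variants (n/b/r/q) toward any target within one file. Unreachable entries
# are harmless; the legal-move mask filters them out at runtime.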
if chess.square_rank(f) in (1, 6) and \
abs(chess.square_file(f) - chess.square_file(t)) <= 1:
for promo in "nbrq":
puci = uci + promo
self.move_to_idx[puci] = idx
self.idx_to_move.append(puci)
idx += 1
self.num_actions = idx
ACTION_MAPPER = ActionMapper()
# ── Board Encoding ─────────────────────────────────────────────────────────────
def populate_states_fast(envs: list, active_mask: np.ndarray,
bbs_np: np.ndarray, meta_np: np.ndarray) -> None:
"""Fill bbs_np [B,12] int64 and meta_np [B,3] float32 for active envs."""
for b in range(len(envs)):
if not active_mask[b]: continue
env = envs[b]
w = env.occupied_co[chess.WHITE]
bc = env.occupied_co[chess.BLACK]
bbs_np[b, 0] = env.pawns & w; bbs_np[b, 1] = env.knights & w
bbs_np[b, 2] = env.bishops & w; bbs_np[b, 3] = env.rooks & w
bbs_np[b, 4] = env.queens & w; bbs_np[b, 5] = env.kings & w
bbs_np[b, 6] = env.pawns & bc; bbs_np[b, 7] = env.knights & bc
bbs_np[b, 8] = env.bishops & bc; bbs_np[b, 9] = env.rooks & bc
bbs_np[b, 10] = env.queens & bc; bbs_np[b, 11] = env.kings & bc
meta_np[b, 0] = 1.0 if env.turn else -1.0
meta_np[b, 1] = bin(env.castling_rights).count("1") / 4.0 # castling_rights is a rook-square bitmask; popcount of rights, scaled to [0,1]
meta_np[b, 2] = 1.0 if env.ep_square is not None else 0.0
def get_legal_masks(envs: list, active_mask: np.ndarray):
masks = np.zeros((len(envs), ACTION_MAPPER.num_actions), dtype=np.bool_)
moves_list = [None] * len(envs)
for b in range(len(envs)):
if not active_mask[b]: continue
legal = list(envs[b].legal_moves)
moves_list[b] = legal
for m in legal:
masks[b, ACTION_MAPPER.move_to_idx[m.uci()]] = True
return masks, moves_list
# ── Neural Network ─────────────────────────────────────────────────────────────
class ChessNet(nn.Module):
def __init__(self, res_blocks: int = 8, channels: int = 128):
super().__init__()
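# Input: 14 planes of 8x8: 12 piece bitboards (white P,N,B,R,Q,K then black),
# a side-to-move plane, and a castling/en-passant plane.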
self.conv_in = nn.Conv2d(14, channels, 3, padding=1, bias=False)
self.bn_in = nn.BatchNorm2d(channels)
self.res_blocks = nn.ModuleList([
nn.Sequential(
nn.Conv2d(channels, channels, 3, padding=1, bias=False),
nn.BatchNorm2d(channels), nn.ReLU(inplace=True),
nn.Conv2d(channels, channels, 3, padding=1, bias=False),
nn.BatchNorm2d(channels),
) for _ in range(res_blocks)
])
self.policy_head = nn.Sequential(
nn.Conv2d(channels, 32, 1, bias=False), nn.BatchNorm2d(32),
nn.ReLU(inplace=True), nn.Flatten(),
nn.Linear(32 * 64, ACTION_MAPPER.num_actions),
)
# No Tanh — shaped rewards exceed [-1,1]; unbounded linear output
self.value_head = nn.Sequential(
nn.Conv2d(channels, 32, 1, bias=False), nn.BatchNorm2d(32),
nn.ReLU(inplace=True), nn.Flatten(),
nn.Linear(32 * 64, 256), nn.ReLU(inplace=True),
nn.Linear(256, 1),
)
def forward(self, x):
x = F.relu(self.bn_in(self.conv_in(x)), inplace=True)
for blk in self.res_blocks:
x = F.relu(x + blk(x), inplace=True)
return self.policy_head(x), self.value_head(x)
# ── ELO Tracker ───────────────────────────────────────────────────────────────
class ELOTracker:
def __init__(self, initial_elo: float = 1200.0, K: float = 32.0):
self.elo = initial_elo
self.K = K
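# Standard Elo model: expected score E = 1 / (1 + 10^((R_opp - R_self) / 400));
# update R <- R + K * (actual_score - E).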
def expected(self, opp_elo: float) -> float:
return 1.0 / (1.0 + 10.0 ** ((opp_elo - self.elo) / 400.0))
def update(self, score: float, opp_elo: float) -> None:
self.elo += self.K * (score - self.expected(opp_elo))
# ── Opening Position Generator ─────────────────────────────────────────────────
def get_opening_position(max_moves: int = 10) -> chess.Board:
"""Play 0..max_moves random half-moves from start for GRPO diversity."""
board = chess.Board()
for _ in range(random.randint(0, max_moves)):
if board.is_game_over(): break
board.push(random.choice(list(board.legal_moves)))
return chess.Board(board.fen()) # detached copy
# ── Auto-download ──────────────────────────────────────────────────────────────
def auto_download(checkpoint_dir: str) -> None:
"""Sync to Google Drive if mounted, else trigger browser downloads."""
try:
from google.colab import files as _cf
drive_dst = '/content/drive/MyDrive/chess_agent'
if os.path.exists('/content/drive/MyDrive'):
os.makedirs(drive_dst, exist_ok=True)
shutil.copytree(checkpoint_dir, drive_dst, dirs_exist_ok=True)
print(f"[AutoSave] Synced β†’ {drive_dst}")
else:
for fname in ['best.pt', 'latest.pt', 'training_log.csv',
'elo_log.csv', 'training_performance.png']:
fpath = os.path.join(checkpoint_dir, fname)
if os.path.exists(fpath):
_cf.download(fpath)
print(f"[AutoSave] Downloaded {fname}")
except Exception as e:
print(f"[AutoSave] {e}")
# ── GRPO Trainer ───────────────────────────────────────────────────────────────
class GRPOTrainer:
def __init__(self):
self.device = CONFIG["device"]
_model = ChessNet(res_blocks=8, channels=128)
_model = _model.to(self.device).to(memory_format=torch.channels_last)
try:
print("Compiling model (reduce-overhead)…")
self.model = torch.compile(_model, mode="reduce-overhead")
except Exception:
self.model = _model
self.optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=CONFIG["learning_rate"],
weight_decay=CONFIG["weight_decay"],
fused=torch.cuda.is_available(),
)
self.scaler = torch.amp.GradScaler('cuda')
self.start_iter = 0
self.best_win_rate = 0.0
self.elo_tracker = ELOTracker()
# Shared shift tensor for bit-unpacking (avoid repeated allocation)
self.shifts = torch.arange(64, dtype=torch.int64,
device=self.device).view(1, 1, 64)
os.makedirs(CONFIG["checkpoint_dir"], exist_ok=True)
self.log_file = os.path.join(CONFIG["checkpoint_dir"], "training_log.csv")
self.elo_log_file = os.path.join(CONFIG["checkpoint_dir"], "elo_log.csv")
if not os.path.exists(self.log_file):
with open(self.log_file, "w", newline="") as f:
csv.writer(f).writerow([
"iteration", "p_loss", "v_loss", "v_mean", "fps",
"win_rate", "draw_rate", "check_rate", "capture_rate", "avg_game_len",
])
if not os.path.exists(self.elo_log_file):
with open(self.elo_log_file, "w", newline="") as f:
csv.writer(f).writerow(
["iteration", "elo", "eval_wins", "eval_draws", "eval_losses"])
self._init_checkpointing()
# ── Checkpointing ──────────────────────────────────────────────────────────
def _init_checkpointing(self) -> None:
latest = os.path.join(CONFIG["checkpoint_dir"], "latest.pt")
if not os.path.exists(latest):
return
try:
ckpt = torch.load(latest, map_location=self.device, weights_only=False)
sd = ckpt['model_state_dict']
# Handle compiled (_orig_mod. prefix) vs uncompiled state dicts
loaded = False
for attempt in [
sd,
{k.replace('_orig_mod.', ''): v for k, v in sd.items()},
{'_orig_mod.' + k: v for k, v in sd.items()},
]:
try:
self.model.load_state_dict(attempt); loaded = True; break
except RuntimeError:
continue
if not loaded:
raise RuntimeError("All state dict key variants failed.")
self.optimizer.load_state_dict(ckpt['optimizer_state_dict'])
self.scaler.load_state_dict(ckpt['scaler_state_dict'])
self.start_iter = ckpt.get('iteration', 0) + 1
self.elo_tracker.elo = ckpt.get('elo', 1200.0)
self.best_win_rate = ckpt.get('best_win_rate', 0.0)
print(f"Resumed from iter {self.start_iter} | "
f"ELO {self.elo_tracker.elo:.0f} | best_win {self.best_win_rate:.3f}")
except Exception as e:
print(f"Checkpoint load failed ({e}). Starting fresh.")
def save_checkpoint(self, iteration: int, is_best: bool = False) -> None:
ckpt = {
'iteration': iteration,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'scaler_state_dict': self.scaler.state_dict(),
'elo': self.elo_tracker.elo,
'best_win_rate': self.best_win_rate,
'config': CONFIG,
}
cdir = CONFIG["checkpoint_dir"]
path = os.path.join(cdir, f"iter_{iteration:04d}.pt")
# Atomic write: write to .tmp then os.replace (single syscall, crash-safe)
torch.save(ckpt, path + ".tmp"); os.replace(path + ".tmp", path)
latest = os.path.join(cdir, "latest.pt")
shutil.copy2(path, latest + ".tmp"); os.replace(latest + ".tmp", latest)
if is_best:
best = os.path.join(cdir, "best.pt")
shutil.copy2(path, best + ".tmp"); os.replace(best + ".tmp", best)
# ── ELO Evaluation (batched, greedy) ──────────────────────────────────────
def _elo_game_done(self, board: chess.Board, idx: int, agent_color,
scores: np.ndarray, active: np.ndarray) -> None:
if board.is_game_over():
res = board.result()
if (res == "1-0" and agent_color == chess.WHITE) or \
(res == "0-1" and agent_color == chess.BLACK):
scores[idx] = 1.0
elif res == "1/2-1/2":
scores[idx] = 0.5
else:
scores[idx] = 0.0
active[idx] = False
def evaluate_elo(self, n_games: int = 32, max_ply: int = 200) -> tuple:
"""
Play n_games vs random opponent (batched GPU for agent moves).
Half games as White, half as Black.
Returns (wins, draws, losses) from agent's perspective.
"""
self.model.eval()
boards = [chess.Board() for _ in range(n_games)]
agent_colors = [chess.WHITE if i % 2 == 0 else chess.BLACK
for i in range(n_games)]
scores = np.full(n_games, 0.5, dtype=np.float32) # default: draw
active = np.ones(n_games, dtype=bool)
bbs_sub = np.zeros((n_games, 12), dtype=np.uint64)  # uint64: bitboards may set bit 63 (square h8)
meta_sub= np.zeros((n_games, 3), dtype=np.float32)
for _ in range(max_ply):
if not active.any(): break
# Random moves (opponent turns) β€” CPU
for i in [i for i in range(n_games)
if active[i] and boards[i].turn != agent_colors[i]]:
legal = list(boards[i].legal_moves)
if legal: boards[i].push(random.choice(legal))
self._elo_game_done(boards[i], i, agent_colors[i], scores, active)
# Agent moves (batched GPU)
ag_idx = [i for i in range(n_games)
if active[i] and boards[i].turn == agent_colors[i]]
if not ag_idx:
continue
n = len(ag_idx)
sub = [boards[i] for i in ag_idx]
act_sub = np.ones(n, dtype=bool)
populate_states_fast(sub, act_sub, bbs_sub[:n], meta_sub[:n])
bbs_t = torch.tensor(bbs_sub[:n].view(np.int64), dtype=torch.int64, device=self.device)  # reinterpret uint64 bits as int64
unpacked = ((bbs_t.unsqueeze(-1) >> self.shifts) & 1).float().view(n, 12, 8, 8)
state = torch.zeros(n, 14, 8, 8, device=self.device, dtype=torch.float32)
state[:, :12] = unpacked
state[:, 12] = torch.tensor(meta_sub[:n, 0], device=self.device).view(n, 1, 1).expand(n, 8, 8)
state[:, 13] = torch.tensor(meta_sub[:n, 1], device=self.device).view(n, 1, 1).expand(n, 8, 8)
for lj in range(n):
if meta_sub[lj, 2]:
state[lj, 13, 0, 1] = float(meta_sub[lj, 2])
with torch.no_grad(), torch.amp.autocast('cuda'):
logits, _ = self.model(state.to(memory_format=torch.channels_last))
logits = logits.float()
masks_np, legal_lists = get_legal_masks(sub, act_sub)
masks_t = torch.tensor(masks_np, dtype=torch.bool, device=self.device)
logits = torch.where(masks_t, logits,
torch.tensor(-60000.0, device=self.device))
best_acts = logits.argmax(dim=-1).cpu().numpy() # greedy for evaluation
for lj, gi in enumerate(ag_idx):
if not active[gi]: continue
move_uci = ACTION_MAPPER.idx_to_move[best_acts[lj]]
move = chess.Move.from_uci(move_uci)
legal = legal_lists[lj] or list(boards[gi].legal_moves)
if not legal:
active[gi] = False; continue
if move not in legal:
move = random.choice(legal)
boards[gi].push(move)
self._elo_game_done(boards[gi], gi, agent_colors[gi], scores, active)
wins = int((scores == 1.0).sum())
draws = int((scores == 0.5).sum())
losses = int((scores == 0.0).sum())
for s in scores:
self.elo_tracker.update(float(s), RANDOM_BASELINE_ELO)
return wins, draws, losses
# ── Main Training Loop ─────────────────────────────────────────────────────
def train(self, num_iterations: int) -> None:
B = CONFIG["num_envs"]
max_steps = CONFIG["max_steps"]
G = CONFIG["grpo_group_size"]
num_groups= B // G
gamma = CONFIG["gamma"]
t_start = time.time()
max_rt = CONFIG["max_runtime_hours"] * 3600.0
# ── Preallocate GPU buffers (int8/bool minimizes VRAM footprint) ──────
states_buf = torch.zeros((max_steps, B, 14, 8, 8), dtype=torch.int8, device=self.device)
actions_buf = torch.zeros((max_steps, B), dtype=torch.int16, device=self.device)
logprobs_buf= torch.zeros((max_steps, B), dtype=torch.float32, device=self.device)
values_buf = torch.zeros((max_steps, B), dtype=torch.float32, device=self.device)
rewards_buf = torch.zeros((max_steps, B), dtype=torch.float32, device=self.device)
dones_buf = torch.zeros((max_steps, B), dtype=torch.bool, device=self.device)
active_buf = torch.zeros((max_steps, B), dtype=torch.bool, device=self.device)
bbs_np = np.zeros((B, 12), dtype=np.uint64) # uint64: bitboards with bit 63 set (h8) overflow int64; reinterpreted for torch below
meta_np = np.zeros((B, 3), dtype=np.float32)
vram_gb = (torch.cuda.get_device_properties(0).total_memory / 1e9
if torch.cuda.is_available() else 0.0)
print(f"\nπŸš€ Aggressive GRPO Chess Agent")
print(f" Envs:{B} | Groups:{num_groups}Γ—G:{G} | Device:{self.device.upper()} | "
f"VRAM:{vram_gb:.1f}GB")
print(f" Reward: capture(0-0.3)+check(0.3)+checkmate_speed(1.0-1.5)"
f"+draw_penalty(-0.5)+time(-0.003/step)")
print(f" gamma:{gamma} | entropy:{CONFIG['entropy_coef']} | "
f"lr:{CONFIG['learning_rate']}")
for iteration in range(self.start_iter, num_iterations):
# ── Runtime guard ──────────────────────────────────────────────
elapsed = time.time() - t_start
if elapsed > max_rt:
print(f"\n⏱ {elapsed/3600:.2f}h reached. Saving & downloading…")
self.save_checkpoint(iteration)
self.plot_metrics()
auto_download(CONFIG["checkpoint_dir"])
break
iter_start = time.time()
# Zero buffers in-place (no reallocation)
states_buf.zero_(); actions_buf.zero_(); logprobs_buf.zero_()
values_buf.zero_(); rewards_buf.zero_()
dones_buf.fill_(False); active_buf.fill_(False)
# ── GRPO: each group of G envs shares an opening position ──────
fens = [get_opening_position(CONFIG["opening_max_moves"]).fen()
for _ in range(num_groups)]
envs: list[chess.Board] = []
for gi in range(num_groups):
for _ in range(G):
envs.append(chess.Board(fens[gi]))
active = np.ones(B, dtype=bool)
game_lengths = np.zeros(B, dtype=np.int32)
# Per-iteration attack metrics
white_wins = black_wins = draws_count = 0
total_checks = total_captures = 0
# ── PHASE 1: ROLLOUT ───────────────────────────────────────────
for t in range(max_steps):
if not active.any(): break
populate_states_fast(envs, active, bbs_np, meta_np)
# Bit-unpack bitboards → int8 state tensor (no float copy).
# uint64 bitboards are reinterpreted as int64; (bb >> k) & 1 still extracts bit k
# (square k) even when bit 63 makes the reinterpreted value negative.
bbs_t = torch.as_tensor(bbs_np.view(np.int64), dtype=torch.int64, device=self.device)
unpacked = ((bbs_t.unsqueeze(-1) >> self.shifts) & 1).to(torch.int8)
meta_t = torch.as_tensor(meta_np, dtype=torch.float32, device=self.device)
# Pack into int8 buffer (scale float meta to [-127,127])
states_buf[t, :, :12, :, :] = unpacked.view(B, 12, 8, 8)
states_buf[t, :, 12, :, :] = (meta_t[:, 0] * 127).clamp(-127, 127) \
.to(torch.int8).view(B, 1, 1).expand(B, 8, 8)
states_buf[t, :, 13, :, :] = (meta_t[:, 1] * 127).clamp(0, 127) \
.to(torch.int8).view(B, 1, 1).expand(B, 8, 8)
states_buf[t, :, 13, 0, 1]= (meta_t[:, 2] * 127).clamp(0, 127).to(torch.int8)
active_buf[t] = torch.as_tensor(active, dtype=torch.bool, device=self.device)
# Normalize int8β†’float32 for forward pass
model_input = states_buf[t].to(
dtype=torch.float32, memory_format=torch.channels_last) / 127.0
self.model.eval()
with torch.no_grad(), torch.amp.autocast('cuda'):
logits, values = self.model(model_input)
masks_np, legal_moves_list = get_legal_masks(envs, active)
masks_t = torch.as_tensor(masks_np, dtype=torch.bool, device=self.device)
logits = logits.float()
logits = torch.where(masks_t, logits,
torch.tensor(-60000.0, device=self.device))
no_legal = ~masks_t.any(dim=-1, keepdim=True)
logits.masked_fill_(no_legal, 0.0)
probs = F.softmax(logits, dim=-1)
dist = torch.distributions.Categorical(probs)
actions = dist.sample()
actions_buf[t] = actions.to(torch.int16)
logprobs_buf[t] = dist.log_prob(actions)
values_buf[t] = values.squeeze(-1)
actions_cpu = actions.cpu().numpy()
for b in range(B):
if not active[b]: continue
move_uci = ACTION_MAPPER.idx_to_move[actions_cpu[b]]
move = chess.Move.from_uci(move_uci)
if move not in legal_moves_list[b]:
move = random.choice(legal_moves_list[b])
board = envs[b]
mover_is_white = (board.turn == chess.WHITE)
sign = 1.0 if mover_is_white else -1.0
# ── Reward: pre-push components ─────────────────────
r = -0.003 * sign # time penalty (per-mover, white-perspective)
if board.is_capture(move):
if board.is_en_passant(move):
cap_val = 1.0
else:
cp = board.piece_at(move.to_square)
cap_val = PIECE_VAL.get(cp.piece_type, 0.0) if cp else 0.0
r += sign * (cap_val / 9.0) * 0.3 # [0, 0.3]
total_captures += 1
if move.promotion in (chess.QUEEN, chess.ROOK):
r += sign * 0.15 # aggressive promotion
board.push(move)
game_lengths[b] += 1
# ── Reward: post-push components ────────────────────
if board.is_check():
r += sign * 0.3 # gave check
total_checks += 1
if board.is_game_over():
if board.is_checkmate():
# Mover delivered checkmate
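# speed_bonus = 0.5 * exp(-plies/20): about +0.48 for an immediate mate, ~+0.07 after 40 plies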
speed_bonus = 0.5 * math.exp(-game_lengths[b] / 20.0)
r += sign * (1.0 + speed_bonus) # ~1.0-1.5
if mover_is_white: white_wins += 1
else: black_wins += 1
else:
# Draw (stalemate / 50-move / repetition / insufficient material)
r -= 0.5 # flat penalty from white's perspective β€” attack to WIN
draws_count += 1
dones_buf[t, b] = True
active[b] = False
rewards_buf[t, b] = r
# end per-env loop
# end rollout
# ── PHASE 2: VECTORIZED RETURNS ────────────────────────────────
returns = torch.zeros(B, dtype=torch.float32, device=self.device)
returns_buf = torch.zeros((max_steps, B), dtype=torch.float32, device=self.device)
not_done_f = (~dones_buf).float()
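# Backward recursion over the rollout, vectorized across all B envs:
# G_t = r_t + gamma * (1 - done_t) * G_{t+1}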
for step in reversed(range(max_steps)):
returns = rewards_buf[step] + gamma * returns * not_done_f[step]
returns_buf[step]= returns
# ── PHASE 3: GRPO GROUP-WISE ADVANTAGE NORMALIZATION ───────────
# advantages shape [max_steps, B]
adv_raw = returns_buf - values_buf
active_f = active_buf.float()
# Reshape to [max_steps, num_groups, G] and normalize within each group
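# Per group g, over its active env-steps: A_hat = (A - mean_g) / max(std_g, 1e-8),
# i.e. the group-relative advantage used by GRPO.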
adv_3d = adv_raw.view(max_steps, num_groups, G)
act_3d = active_f.view(max_steps, num_groups, G)
g_count = act_3d.sum(dim=[0, 2]).clamp(min=1.0) # [num_groups]
g_mean = (adv_3d * act_3d).sum(dim=[0, 2]) / g_count # [num_groups]
g_sq_diff = ((adv_3d - g_mean.view(1, num_groups, 1)) ** 2
* act_3d).sum(dim=[0, 2])
g_std = (g_sq_diff / g_count).sqrt().clamp(min=1e-8) # [num_groups]
adv_3d = (adv_3d - g_mean.view(1, num_groups, 1)) / \
g_std.view(1, num_groups, 1)
adv_norm = adv_3d.view(max_steps, B)
# Flatten, filter to active steps only
valid_mask = active_buf.view(-1)
flat_states = (states_buf.view(-1, 14, 8, 8)[valid_mask]
.to(torch.float32, memory_format=torch.channels_last)
.div_(127.0))
flat_actions = actions_buf.view(-1)[valid_mask].to(torch.int64)
flat_old_lp = logprobs_buf.view(-1)[valid_mask]
flat_returns = returns_buf.view(-1)[valid_mask]
flat_advantages = adv_norm.view(-1)[valid_mask]
dataset_size = flat_states.size(0)
if dataset_size < 100:
continue # skip degenerate rollout (all games ended instantly)
# ── PHASE 4: PPO OPTIMIZATION ──────────────────────────────────
self.model.train()
total_p_loss = total_v_loss = 0.0
num_updates = 0
mb_size = CONFIG["mini_batch_size"]
for _ in range(CONFIG["ppo_epochs"]):
perm = torch.randperm(dataset_size, device=self.device)
for start in range(0, dataset_size, mb_size):
mb = perm[start: start + mb_size]
with torch.amp.autocast('cuda'):
new_logits, new_vals = self.model(flat_states[mb])
new_dist = torch.distributions.Categorical(logits=new_logits)
new_lp = new_dist.log_prob(flat_actions[mb])
ratio = torch.exp(new_lp - flat_old_lp[mb])
adv = flat_advantages[mb]
surr1 = ratio * adv
surr2 = torch.clamp(
ratio,
1.0 - CONFIG["clip_epsilon"],
1.0 + CONFIG["clip_epsilon"],
) * adv
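# PPO clipped surrogate: maximize E[min(ratio * A, clip(ratio, 1-eps, 1+eps) * A)]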
p_loss = -torch.min(surr1, surr2).mean()
v_loss = F.mse_loss(new_vals.squeeze(-1), flat_returns[mb])
entropy = new_dist.entropy().mean()
loss = (p_loss
+ CONFIG["value_coef"] * v_loss
- CONFIG["entropy_coef"] * entropy)
self.optimizer.zero_grad(set_to_none=True)
self.scaler.scale(loss).backward()
self.scaler.unscale_(self.optimizer)
nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
self.scaler.step(self.optimizer)
self.scaler.update()
total_p_loss += p_loss.item()
total_v_loss += v_loss.item()
num_updates += 1
# ── PHASE 5: METRICS & LOGGING ────────────────────────────────
done_count = white_wins + black_wins + draws_count
win_rate = white_wins / max(done_count, 1)
draw_rate = draws_count / max(done_count, 1)
active_steps = int(active_buf.sum().item())
check_rate = total_checks / max(active_steps, 1)
capture_rate = total_captures / max(active_steps, 1)
avg_game_len = float(game_lengths.mean())
fps = dataset_size / max(time.time() - iter_start, 1e-3)
if (iteration + 1) % CONFIG["log_interval"] == 0:
vram_alloc = (torch.cuda.memory_allocated() / 1e9
if torch.cuda.is_available() else 0.0)
vram_res = (torch.cuda.memory_reserved() / 1e9
if torch.cuda.is_available() else 0.0)
print(
f"[{iteration+1:05d}] "
f"P:{total_p_loss/max(1,num_updates):.4f} "
f"V:{total_v_loss/max(1,num_updates):.4f} | "
f"W:{win_rate:.3f} D:{draw_rate:.3f} "
f"Chk:{check_rate:.4f} Cap:{capture_rate:.4f} "
f"Len:{avg_game_len:.1f} | "
f"ELO:{self.elo_tracker.elo:.0f} | "
f"FPS:{fps:.0f} | "
f"VRAM:{vram_alloc:.2f}/{vram_res:.2f}GB"
)
with open(self.log_file, "a", newline="") as f:
csv.writer(f).writerow([
iteration + 1,
total_p_loss / max(1, num_updates),
total_v_loss / max(1, num_updates),
flat_returns.mean().item(),
fps, win_rate, draw_rate,
check_rate, capture_rate, avg_game_len,
])
# Save best checkpoint when win_rate improves
if win_rate > self.best_win_rate:
self.best_win_rate = win_rate
self.save_checkpoint(iteration + 1, is_best=True)
if (iteration + 1) % CONFIG["save_interval"] == 0:
self.save_checkpoint(iteration + 1)
self.plot_metrics()
# ELO evaluation
if (iteration + 1) % CONFIG["elo_eval_interval"] == 0:
elo_before = self.elo_tracker.elo
ew, ed, el = self.evaluate_elo(CONFIG["elo_eval_games"])
print(
f" [ELO eval] {elo_before:.0f} β†’ {self.elo_tracker.elo:.0f} | "
f"W:{ew} D:{ed} L:{el} vs random({RANDOM_BASELINE_ELO})"
)
with open(self.elo_log_file, "a", newline="") as f:
csv.writer(f).writerow(
[iteration + 1, self.elo_tracker.elo, ew, ed, el])
self.plot_metrics()
# Aggressive cache reclaim (free fragmented blocks, not pinned allocs)
torch.cuda.empty_cache()
# ── Plotting ───────────────────────────────────────────────────────────────
def plot_metrics(self) -> None:
if not os.path.exists(self.log_file): return
df = pd.read_csv(self.log_file)
if len(df) < 2: return
elo_df = None
if os.path.exists(self.elo_log_file):
elo_df = pd.read_csv(self.elo_log_file)
fig, axs = plt.subplots(3, 2, figsize=(14, 12))
fig.suptitle("Aggressive GRPO Chess Agent — Training Dashboard", fontsize=14)
# Row 0: Losses
axs[0, 0].plot(df['iteration'], df['p_loss'], color='steelblue', linewidth=1.2)
axs[0, 0].set_title('Policy Loss'); axs[0, 0].set_xlabel('Iteration')
axs[0, 1].plot(df['iteration'], df['v_loss'], color='tomato', linewidth=1.2)
axs[0, 1].set_title('Value Loss'); axs[0, 1].set_xlabel('Iteration')
# Row 1: Outcomes
axs[1, 0].plot(df['iteration'], df['win_rate'], label='Win', color='green')
axs[1, 0].plot(df['iteration'], df['draw_rate'], label='Draw', color='orange')
axs[1, 0].set_title('Outcomes (White perspective)')
axs[1, 0].legend(); axs[1, 0].set_xlabel('Iteration')
# Row 1: Attack metrics
axs[1, 1].plot(df['iteration'], df['check_rate'], label='Check/step', color='purple')
axs[1, 1].plot(df['iteration'], df['capture_rate'], label='Capture/step', color='darkorange')
axs[1, 1].set_title('Attack Metrics (↑ = more aggressive)')
axs[1, 1].legend(); axs[1, 1].set_xlabel('Iteration')
# Row 2: ELO Rating
if elo_df is not None and len(elo_df) > 0:
axs[2, 0].plot(elo_df['iteration'], elo_df['elo'],
color='gold', linewidth=2.0, label='Agent ELO')
axs[2, 0].axhline(RANDOM_BASELINE_ELO, linestyle='--',
color='gray', alpha=0.8, label=f'Random ({RANDOM_BASELINE_ELO})')
axs[2, 0].axhline(1200, linestyle=':', color='lightblue',
alpha=0.6, label='Start (1200)')
axs[2, 0].fill_between(elo_df['iteration'], RANDOM_BASELINE_ELO,
elo_df['elo'], alpha=0.15, color='gold')
axs[2, 0].set_title('ELO Rating vs Random Baseline')
axs[2, 0].legend(); axs[2, 0].set_xlabel('Iteration')
else:
axs[2, 0].text(0.5, 0.5, f'ELO eval every {CONFIG["elo_eval_interval"]} iters',
ha='center', va='center', transform=axs[2, 0].transAxes,
color='gray', fontsize=11)
axs[2, 0].set_title('ELO Rating (pending)')
# Row 2: Average game length
axs[2, 1].plot(df['iteration'], df['avg_game_len'], color='teal', linewidth=1.2)
axs[2, 1].set_title('Avg Game Length (↓ = faster checkmates)')
axs[2, 1].set_xlabel('Iteration')
for ax in axs.flat:
ax.grid(True, alpha=0.25)
plt.tight_layout()
out = os.path.join(CONFIG["checkpoint_dir"], "training_performance.png")
plt.savefig(out, dpi=100, bbox_inches='tight')
plt.close(fig)
print(f" [Plot] saved β†’ {out}")
# ── Entry Point ────────────────────────────────────────────────────────────────
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Aggressive GRPO Chess Agent (T4/Colab)")
parser.add_argument("--iterations", type=int, default=10000,
help="Total training iterations")
parser.add_argument("--test-batch", action="store_true",
help="Run 2 iterations for smoke-test")
args, _ = parser.parse_known_args()
torch.manual_seed(CONFIG["seed"])
np.random.seed(CONFIG["seed"])
random.seed(CONFIG["seed"])
# Print VRAM summary at startup
if torch.cuda.is_available():
props = torch.cuda.get_device_properties(0)
print(f"GPU: {props.name} | VRAM: {props.total_memory/1e9:.1f}GB | "
f"SM: {props.multi_processor_count} | "
f"Compute: {props.major}.{props.minor}")
trainer = GRPOTrainer()
trainer.train(2 if args.test_batch else args.iterations)