thomas-schweich's picture
download
raw
21.5 kB
"""PAWN training loop with checkpointing and monitoring.
Uses `AdamW <https://arxiv.org/abs/1711.05101>`_ (Loshchilov & Hutter,
2017) with cosine LR decay (`Loshchilov & Hutter, 2016
<https://arxiv.org/abs/1608.03983>`_) and mixed-precision training
(`Micikevicius et al., 2017 <https://arxiv.org/abs/1710.03740>`_).
"""
import json
import math
import os
import signal
import sys
import time
from datetime import datetime, timezone
import psutil
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from pawn.config import CLMConfig, TrainingConfig
from pawn.model import PAWNCLM, clm_loss
from pawn.data import CLMDataset, create_validation_set
from pawn.logging import MetricsLogger
from pawn.data_utils import unpack_grid
class CosineWithWarmup:
"""Cosine LR schedule with linear warmup.
Based on SGDR (`Loshchilov & Hutter, 2016
<https://arxiv.org/abs/1608.03983>`_).
"""
def __init__(
self,
optimizer: torch.optim.Optimizer,
warmup_steps: int,
total_steps: int,
min_lr_ratio: float = 0.1,
):
self.optimizer = optimizer
self.warmup_steps = warmup_steps
self.total_steps = total_steps
self.min_lr_ratio = min_lr_ratio
self.base_lrs = [pg["lr"] for pg in optimizer.param_groups]
self._step = 0
self._apply_lr(0)
def _apply_lr(self, step: int) -> None:
lr_scale = self._compute_lr_scale(step)
for pg, base_lr in zip(self.optimizer.param_groups, self.base_lrs, strict=True):
pg["lr"] = base_lr * lr_scale
def step(self) -> None:
self._step += 1
self._apply_lr(self._step)
def _compute_lr_scale(self, step: int) -> float:
if step < self.warmup_steps:
return step / max(1, self.warmup_steps)
progress = (step - self.warmup_steps) / max(1, self.total_steps - self.warmup_steps)
progress = min(progress, 1.0)
return self.min_lr_ratio + 0.5 * (1.0 - self.min_lr_ratio) * (
1.0 + math.cos(math.pi * progress)
)
def get_lr(self) -> float:
return self.optimizer.param_groups[0]["lr"]
def state_dict(self) -> dict[str, int]:
return {"step": self._step}
def load_state_dict(self, state: dict[str, int]) -> None:
self._step = state["step"]
self._apply_lr(self._step)
def _build_promo_grid_index() -> list[int]:
    """Map every promotion token to its ``src * 64 + dst`` grid index.

    Promotion tokens 4097..4272 encode 44 (src, dst) pairs x 4 piece types,
    laid out as ``PROMO_START + pair_idx * 4 + promo_type``. The four tokens
    of one pair therefore share a grid index, so the returned list repeats
    each pair's index four times, in pair order.

    Returns:
        One grid index per promotion token (4 entries per promo pair).
    """
    # Imported lazily so the engine is only required when this table is built.
    import chess_engine

    # vocab["promo_pairs"] is a list of (src, dst) tuples, len=44.
    pairs = chess_engine.export_move_vocabulary()["promo_pairs"]
    return [src * 64 + dst for src, dst in pairs for _ in range(4)]
# Lazily initialized on first use
_PROMO_GRID_INDEX: list[int] | None = None
def _get_promo_grid_index(device: str | torch.device) -> torch.Tensor:
"""Get the promo-token-to-grid-index mapping as a tensor on the given device."""
global _PROMO_GRID_INDEX
if _PROMO_GRID_INDEX is None:
_PROMO_GRID_INDEX = _build_promo_grid_index()
return torch.tensor(_PROMO_GRID_INDEX, dtype=torch.long, device=device)
def compute_legal_move_rate(
    logits: torch.Tensor,
    legal_grid: torch.Tensor,
    loss_mask: torch.Tensor,
    game_lengths: torch.Tensor,
) -> float:
    """Compute the fraction of argmax predictions that are legal moves.

    Evaluated at positions 0 through game_lengths (inclusive), matching the
    loss_mask semantics (which includes the end-of-game PAD prediction).
    Only positions that are also covered by ``legal_grid`` (i.e. below
    ``min(T, max_ply)``) are counted.

    A prediction counts as legal when it is either a base grid token
    (1..4096, grid index ``token - 1``) or a promotion token (4097..4272,
    mapped to its ``src * 64 + dst`` grid index) whose grid cell is set in
    the unpacked legal-move grid. Any other token counts as illegal.

    Args:
        logits: (B, T, vocab_size)
        legal_grid: (B, max_ply, 64) bit-packed legal moves from engine
        loss_mask: (B, T) bool
        game_lengths: (B,) int

    Returns:
        ``legal predictions / evaluated positions``, or ``0.0`` when no
        position is evaluated.
    """
    B, T, _ = logits.shape
    max_ply = legal_grid.shape[1]
    # Positions beyond either tensor's extent cannot be checked.
    n_plies = min(T, max_ply)

    with torch.no_grad():
        # Positions where target is an actual move or end-of-game PAD:
        # 0..game_lengths in CLM indexing (matches loss_mask <= semantics)
        positions = torch.arange(T, device=logits.device).unsqueeze(0)
        move_mask = (positions <= game_lengths.unsqueeze(1)) & loss_mask
        move_mask = move_mask[:, :n_plies]  # (B, n_plies)

        valid_count = int(move_mask.sum().item())
        if valid_count == 0:
            return 0.0

        preds = logits[:, :n_plies].argmax(dim=-1)  # (B, n_plies)

        # Unpack legal grid to dense (B, max_ply, 64, 64), flatten the last
        # two axes to a 4096-wide grid index, and keep the checked plies.
        legal_flat = unpack_grid(legal_grid).reshape(B, max_ply, 4096)[:, :n_plies]

        # Base grid tokens (1-4096): grid index = token - 1. The clamp keeps
        # the gather in range for non-base tokens; is_base masks them out.
        is_base = (preds >= 1) & (preds <= 4096)
        base_idx = (preds - 1).clamp(0, 4095)
        base_legal = legal_flat.gather(-1, base_idx.unsqueeze(-1)).squeeze(-1) > 0.5

        # Promotion tokens (4097-4272): look up the (src, dst) grid index.
        promo_table = _get_promo_grid_index(logits.device)  # (176,)
        is_promo = (preds >= 4097) & (preds <= 4272)
        promo_offset = (preds - 4097).clamp(0, 175)
        promo_idx = promo_table[promo_offset]  # (B, n_plies)
        promo_legal = legal_flat.gather(-1, promo_idx.unsqueeze(-1)).squeeze(-1) > 0.5

        # The base and promo token ranges are disjoint, so OR-ing the two
        # hit masks counts each position at most once.
        legal_hits = ((is_base & base_legal) | (is_promo & promo_legal)) & move_mask
        return legal_hits.sum().item() / valid_count
def _get_grad_norm(model: nn.Module) -> float:
    """Return the global L2 norm of all gradients currently held by ``model``.

    Each parameter's gradient norm is computed in float32; the per-parameter
    norms are then combined as ``sqrt(sum(norm_i ** 2))``. Parameters without
    a gradient are skipped, and ``0.0`` is returned when none has one.
    """
    per_param_norms = [
        param.grad.data.float().norm()
        for param in model.parameters()
        if param.grad is not None
    ]
    if len(per_param_norms) == 0:
        return 0.0
    sum_of_squares = torch.stack(per_param_norms).square().sum()
    return sum_of_squares.sqrt().item()
class CLMTrainer:
    """Owns the model, optimizer, data and logging for one PAWN training run.

    Construction builds everything needed to train: the model, an AdamW
    optimizer, the warmup/cosine LR schedule, a grad scaler, the training
    dataset, a validation set, an optional W&B run and (off-CPU) a
    ``torch.compile`` wrapper. :meth:`train` then runs the loop until
    ``total_steps`` is reached, ``pause_after_steps`` is hit, or a
    SIGTERM/SIGINT is received; a checkpoint is saved in each case.

    NOTE(review): forward passes in :meth:`train_step` and :meth:`evaluate`
    go through the uncompiled ``self._model`` (see :meth:`_eager_model`); the
    compiled ``self.model`` is only used for ``.train()`` and
    ``.parameters()``. Confirm that this is intended.
    """

    def __init__(
        self,
        train_cfg: TrainingConfig,
        model_cfg: CLMConfig,
        hf_repo: str | None = None,
    ):
        """Build all training state.

        Args:
            train_cfg: Training hyper-parameters. Its ``checkpoint_dir`` is
                overwritten to ``<run_dir>/checkpoints``.
            model_cfg: Model architecture configuration.
            hf_repo: Optional Hugging Face repo; when set, every checkpoint
                is also pushed to the branch ``run/<run directory name>``.
        """
        self.cfg = train_cfg
        self.model_cfg = model_cfg
        self.device = train_cfg.device
        self.global_step = 0
        self.hf_repo = hf_repo
        self.hf_branch: str | None = None

        self.logger = MetricsLogger(
            train_cfg.log_dir, run_prefix="run", device=self.device,
        )
        self.run_dir = str(self.logger.run_dir)
        # Checkpoints always live inside this run's directory.
        self.cfg.checkpoint_dir = os.path.join(self.run_dir, "checkpoints")
        self._jsonl_path = str(self.logger.metrics_path)
        if self.hf_repo:
            self.hf_branch = f"run/{os.path.basename(self.run_dir)}"

        # ``_model`` is always the eager module; ``model`` may later be
        # replaced by its torch.compile wrapper.
        self._model = PAWNCLM(model_cfg).to(self.device)
        self.model = self._model
        param_count = sum(p.numel() for p in self._model.parameters())
        print(f"Model parameters: {param_count:,}")
        print(f"Run directory: {self.run_dir}")

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=train_cfg.lr,
            weight_decay=train_cfg.weight_decay,
            betas=(0.9, 0.95),
        )
        self.scheduler = CosineWithWarmup(
            self.optimizer,
            warmup_steps=train_cfg.warmup_steps,
            total_steps=train_cfg.total_steps,
        )
        self.scaler = torch.amp.GradScaler(self.device, enabled=train_cfg.use_amp)

        self.dataset = CLMDataset(
            train_cfg.batch_size, train_cfg.max_ply, train_cfg.base_seed,
            discard_ply_limit=train_cfg.discard_ply_limit,
            no_outcome=train_cfg.no_outcome_token,
            mate_boost=train_cfg.mate_boost,
        )

        print("Generating validation set...")
        self.val_data = create_validation_set(
            train_cfg.val_games, train_cfg.max_ply, train_cfg.val_seed,
            discard_ply_limit=train_cfg.discard_ply_limit,
            no_outcome=train_cfg.no_outcome_token,
            mate_boost=train_cfg.mate_boost,
        )

        # W&B is best-effort: any failure (missing package, auth, network)
        # is reported and training continues without it.
        self.wandb_run = None
        if train_cfg.use_wandb:
            try:
                import wandb
                self.wandb_run = wandb.init(
                    project=train_cfg.wandb_project,
                    config={
                        "model": model_cfg.__dict__,
                        "training": train_cfg.__dict__,
                    },
                )
            except Exception as e:  # ImportError is already an Exception
                print(f"W&B init failed: {e}. Continuing without W&B.")

        # torch.compile
        self._compiled = False
        if self.device != "cpu":
            try:
                self.model = torch.compile(self.model, mode="default")
                self._compiled = True
                print("torch.compile enabled")
            except Exception:
                print("torch.compile not available, using eager mode")
        else:
            print("Skipping torch.compile on CPU")

        self.logger.log_config(
            model=model_cfg.__dict__,
            training=train_cfg.__dict__,
            param_count=param_count,
            compiled=self._compiled,
            formulation="clm",
        )
        self.logger.write_config_json(
            model=model_cfg.__dict__,
            training=train_cfg.__dict__,
            param_count=param_count,
            compiled=self._compiled,
            formulation="clm",
        )

    def seed_logs(self, run_dirs: list[str], max_step: int):
        """Splice prior run logs into this run's JSONL.

        Reads ``metrics.jsonl`` from each directory in ``run_dirs``, keeps
        records whose ``step`` is at most ``max_step``, sorts them by
        ``(step, type)``, drops duplicates of the same ``(type, step)``
        (first occurrence wins) and overwrites this run's JSONL file with
        the result. Missing files, undecodable lines and lines that are not
        JSON objects are skipped. Nothing is written when no record
        survives.

        Args:
            run_dirs: Prior run directories, in priority order for
                de-duplication ties.
            max_step: Highest step to carry over.
        """
        from pathlib import Path

        all_records: list[dict] = []
        for rd in run_dirs:
            p = Path(rd) / "metrics.jsonl"
            if not p.exists():
                print(f" WARNING: {p} not found, skipping")
                continue
            with open(p, "rb") as f:
                data = f.read()
            # Strip trailing NUL padding before decoding.
            text = data.rstrip(b"\x00").decode(errors="replace")
            for line in text.splitlines():
                line = line.strip()
                if not line:
                    continue
                try:
                    record = json.loads(line)
                except json.JSONDecodeError:
                    continue
                # Valid JSON that is not an object has no step/type to key on.
                if isinstance(record, dict):
                    all_records.append(record)

        all_records = [r for r in all_records if r.get("step", 0) <= max_step]
        all_records.sort(key=lambda r: (r.get("step", 0), r.get("type", "")))

        seen: set[tuple[str, int]] = set()
        deduped: list[dict] = []
        for r in all_records:
            key = (r.get("type", ""), r.get("step", 0))
            if key not in seen:
                seen.add(key)
                deduped.append(r)

        if not deduped:
            print(" No prior log lines to seed.")
            return

        with open(self._jsonl_path, "w") as f:
            f.writelines(json.dumps(r, default=str) + "\n" for r in deduped)

        first_step = deduped[0].get("step", "?")
        last_step = deduped[-1].get("step", "?")
        print(f"Seeded {len(deduped)} log lines from prior runs "
              f"(steps {first_step}-{last_step})")

    def _log_jsonl(self, record: dict):
        """Low-level JSONL write for seed_logs compatibility."""
        self.logger._write(record)

    def train_step(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Run forward and backward for one micro-batch.

        The loss is divided by ``accumulation_steps`` before the (scaled)
        backward pass so that accumulated gradients average over
        micro-batches. No optimizer step is taken here.

        Args:
            batch: Mapping with ``input_ids``, ``targets`` and ``loss_mask``.

        Returns:
            The metrics returned by ``forward_train`` (the training loop
            calls ``.item()`` on ``loss`` and ``accuracy``).
        """
        self.model.train()
        input_ids = batch["input_ids"].to(self.device)
        targets = batch["targets"].to(self.device)
        loss_mask = batch["loss_mask"].to(self.device)

        model = self._eager_model()
        with torch.amp.autocast(self.device, enabled=self.cfg.use_amp):
            loss, metrics = model.forward_train(input_ids, loss_mask, targets)

        scaled_loss = loss / self.cfg.accumulation_steps
        self.scaler.scale(scaled_loss).backward()
        return metrics

    def optimizer_step(self) -> float:
        """Apply the accumulated gradients and advance the LR schedule.

        Gradients are unscaled, measured, clipped to ``max_grad_norm``,
        applied through the grad scaler and then cleared.

        Returns:
            The global gradient norm measured before clipping.
        """
        self.scaler.unscale_(self.optimizer)
        grad_norm = _get_grad_norm(self._model)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.max_grad_norm)
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad(set_to_none=True)
        self.scheduler.step()
        return grad_norm

    def _eager_model(self) -> PAWNCLM:
        """Return the uncompiled model instance."""
        return self._model

    @torch.no_grad()
    def evaluate(self) -> dict[str, float]:
        """Evaluate on the held-out validation set.

        The validation tensors are processed in chunks of ``batch_size``.
        Each chunk contributes the metrics from ``clm_loss`` plus top-5
        accuracy and, when the validation set carries a ``legal_grid``, the
        legal-move rate. Chunk metrics are averaged with equal weight per
        chunk.

        Returns:
            Metrics keyed as ``val/<name>``, plus ``val/perplexity`` computed
            as ``exp(min(val/loss, 20))``.
        """
        model = self._eager_model()
        model.eval()

        n = self.val_data["input_ids"].shape[0]
        batch_size = self.cfg.batch_size
        total_metrics: dict[str, float] = {}
        n_batches = 0
        has_legal = "legal_grid" in self.val_data

        for start in range(0, n, batch_size):
            end = min(start + batch_size, n)
            input_ids = self.val_data["input_ids"][start:end].to(self.device)
            targets = self.val_data["targets"][start:end].to(self.device)
            loss_mask = self.val_data["loss_mask"][start:end].to(self.device)

            with torch.amp.autocast(self.device, enabled=self.cfg.use_amp):
                logits, _layer_outputs = model(input_ids, loss_mask)
                # Per-layer outputs are not needed for evaluation.
                del _layer_outputs
                _, metrics = clm_loss(logits, targets, loss_mask)

            # Top-5 accuracy
            valid_logits = logits[loss_mask]
            valid_targets = targets[loss_mask]
            top5 = valid_logits.topk(5, dim=-1).indices
            top5_acc = (top5 == valid_targets.unsqueeze(-1)).any(dim=-1).float().mean().item()
            metrics["top5_accuracy"] = top5_acc

            # Legal move rate (if legal grid available)
            if has_legal:
                legal_grid = self.val_data["legal_grid"][start:end].to(self.device)
                game_lengths = self.val_data["game_lengths"][start:end].to(self.device)
                legal_rate = compute_legal_move_rate(logits, legal_grid, loss_mask, game_lengths)
                metrics["legal_move_rate"] = legal_rate

            for k, v in metrics.items():
                total_metrics[k] = total_metrics.get(k, 0.0) + v
            n_batches += 1

        if self.device != "cpu" and torch.cuda.is_available():
            torch.cuda.empty_cache()

        avg = {f"val/{k}": v / n_batches for k, v in total_metrics.items()}
        # Cap the exponent so a diverged loss cannot overflow.
        avg["val/perplexity"] = math.exp(min(avg["val/loss"], 20.0))
        return avg

    def train(self):
        """Run the training loop from ``global_step`` onwards.

        One optimizer step is taken every ``accumulation_steps``
        micro-batches. After each optimizer step the loop logs, evaluates
        and checkpoints at their configured intervals, then stops (after
        saving a checkpoint) when ``total_steps`` is reached,
        ``pause_after_steps`` is reached, or SIGTERM/SIGINT was received.

        While the loop runs, SIGTERM and SIGINT only set a flag that is
        checked after the next optimizer step. The previous handlers are
        restored when the loop ends, including when it ends with an
        exception. The metrics logger is closed on normal completion.
        """
        self.dataset.set_start_step(self.global_step)

        num_workers = self.cfg.num_workers
        loader = DataLoader(
            self.dataset,
            batch_size=None,
            num_workers=num_workers,
            pin_memory=(self.device != "cpu"),
            persistent_workers=(num_workers > 0),
            prefetch_factor=1 if num_workers > 0 else None,
        )

        _shutdown_requested = False
        _shutdown_signal = None

        def _graceful_exit(signum, frame):
            nonlocal _shutdown_requested, _shutdown_signal
            _shutdown_requested = True
            _shutdown_signal = signum

        old_term = signal.signal(signal.SIGTERM, _graceful_exit)
        old_int = signal.signal(signal.SIGINT, _graceful_exit)
        try:
            os.makedirs(self.cfg.checkpoint_dir, exist_ok=True)

            accum_count = 0
            step_start = time.perf_counter()
            games_per_step = self.cfg.batch_size * self.cfg.accumulation_steps

            print(f"Starting training from step {self.global_step}", flush=True)
            print(f"JSONL log: {self._jsonl_path}", flush=True)

            for batch in loader:
                metrics = self.train_step(batch)
                accum_count += 1
                if accum_count < self.cfg.accumulation_steps:
                    # Keep accumulating gradients.
                    continue

                grad_norm = self.optimizer_step()
                accum_count = 0
                self.global_step += 1

                step_time = time.perf_counter() - step_start
                # Guard against a zero-length interval on coarse clocks.
                games_per_sec = games_per_step / max(step_time, 1e-9)

                if self.global_step % self.cfg.log_interval == 0:
                    # .item() sync only at log intervals (metrics are tensors here)
                    loss_val = metrics['loss'].item()
                    acc_val = metrics['accuracy'].item()
                    lr = self.scheduler.get_lr()

                    print(
                        f"step {self.global_step:>7d} | "
                        f"loss {loss_val:.4f} | "
                        f"acc {acc_val:.3f} | "
                        f"lr {lr:.2e} | "
                        f"gn {grad_norm:.2f} | "
                        f"{games_per_sec:.0f} g/s | "
                        f"{step_time:.2f}s",
                        flush=True,
                    )

                    self.logger.log_train(
                        step=self.global_step,
                        lr=lr, grad_norm=grad_norm,
                        step_time=step_time, games_per_sec=games_per_sec,
                        **{"train/loss": loss_val, "train/accuracy": acc_val},  # type: ignore[arg-type]
                    )

                    if self.wandb_run:
                        self.wandb_run.log({
                            "train/loss": loss_val, "train/accuracy": acc_val,
                            "train/lr": lr, "train/grad_norm": grad_norm,
                            "train/step_time": step_time, "train/games_per_sec": games_per_sec,
                        }, step=self.global_step)

                if self.global_step % self.cfg.eval_interval == 0:
                    val_metrics = self.evaluate()
                    val_msg = (
                        f" val: loss {val_metrics['val/loss']:.4f} | "
                        f"acc {val_metrics['val/accuracy']:.3f} | "
                        f"top5 {val_metrics.get('val/top5_accuracy', 0):.3f} | "
                        f"ppl {val_metrics.get('val/perplexity', 0):.1f}"
                    )
                    if "val/legal_move_rate" in val_metrics:
                        val_msg += f" | legal {val_metrics['val/legal_move_rate']:.3f}"
                    print(val_msg, flush=True)
                    self.logger.log_val(step=self.global_step, **val_metrics)  # type: ignore[arg-type]
                    if self.wandb_run:
                        self.wandb_run.log(val_metrics, step=self.global_step)

                if self.global_step % self.cfg.checkpoint_interval == 0:
                    self.save_checkpoint()

                if self.global_step >= self.cfg.total_steps:
                    print(f"Training complete at step {self.global_step}")
                    self.save_checkpoint()
                    break

                if (self.cfg.pause_after_steps
                        and self.global_step >= self.cfg.pause_after_steps):
                    print(f"\n Paused at step {self.global_step} "
                          f"(pause_after_steps={self.cfg.pause_after_steps})")
                    self.save_checkpoint()
                    break

                if _shutdown_requested:
                    print(f"\nShutdown requested (signal {_shutdown_signal}), "
                          f"saving checkpoint at step {self.global_step}...")
                    self.save_checkpoint()
                    break

                # Start timing the next optimizer step only after logging,
                # evaluation and checkpointing are done.
                step_start = time.perf_counter()
        finally:
            # Always give the process its original signal handlers back,
            # even if the loop raised; otherwise Ctrl-C would stay swallowed.
            signal.signal(signal.SIGTERM, old_term)
            signal.signal(signal.SIGINT, old_int)

        self.logger.close()

    def save_checkpoint(self, path: str | None = None):
        """Save a training checkpoint and optionally push it to Hugging Face.

        Args:
            path: Destination; defaults to
                ``<checkpoint_dir>/step_<global_step, 8 digits>``.

        A failed Hugging Face push is reported as a warning and does not
        raise; the local checkpoint has already been written at that point.
        """
        from pawn.checkpoint import save_pretrain_checkpoint

        if path is None:
            path = os.path.join(
                self.cfg.checkpoint_dir, f"step_{self.global_step:08d}"
            )

        # Save the eager module, not the torch.compile wrapper.
        model: PAWNCLM = self._eager_model()
        save_pretrain_checkpoint(
            path,
            model,
            self.optimizer,
            self.scheduler,
            self.scaler,
            self.global_step,
            self.model_cfg.__dict__,
            self.cfg.__dict__,
        )
        print(f"Checkpoint saved: {path}")

        if self.hf_repo and self.hf_branch:
            from pawn.checkpoint import push_checkpoint_to_hf
            try:
                push_checkpoint_to_hf(
                    path, self.hf_repo, self.hf_branch,
                    metrics_path=self._jsonl_path,
                    step=self.global_step,
                )
                print(f"Pushed to HF: {self.hf_repo}@{self.hf_branch}")
            except Exception as e:
                print(f"WARNING: HF push failed: {e}")

    def load_checkpoint(self, path: str):
        """Restore model, optimizer, scheduler and scaler state from ``path``.

        ``global_step`` is set from the checkpoint metadata so that
        :meth:`train` resumes from that step.
        """
        from pawn.checkpoint import load_pretrain_checkpoint

        model: PAWNCLM = self._eager_model()
        meta = load_pretrain_checkpoint(
            path, model, self.optimizer, self.scheduler, self.scaler,
            device=self.device,
        )
        self.global_step = meta["global_step"]
        print(f"Resumed from step {self.global_step}")

Xet Storage Details

Size:
21.5 kB
·
Xet hash:
634d63104c69cfdeccc8d8aa8e643ac4d881888c419eea8abc487d3f0e16493a

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.