Buckets:
| """PAWN training loop with checkpointing and monitoring. | |
| Uses `AdamW <https://arxiv.org/abs/1711.05101>`_ (Loshchilov & Hutter, | |
| 2017) with cosine LR decay (`Loshchilov & Hutter, 2016 | |
| <https://arxiv.org/abs/1608.03983>`_) and mixed-precision training | |
| (`Micikevicius et al., 2017 <https://arxiv.org/abs/1710.03740>`_). | |
| """ | |
| import json | |
| import math | |
| import os | |
| import signal | |
| import sys | |
| import time | |
| from datetime import datetime, timezone | |
| import psutil | |
| import torch | |
| import torch.nn as nn | |
| from torch.utils.data import DataLoader | |
| from pawn.config import CLMConfig, TrainingConfig | |
| from pawn.model import PAWNCLM, clm_loss | |
| from pawn.data import CLMDataset, create_validation_set | |
| from pawn.logging import MetricsLogger | |
| from pawn.data_utils import unpack_grid | |
| class CosineWithWarmup: | |
| """Cosine LR schedule with linear warmup. | |
| Based on SGDR (`Loshchilov & Hutter, 2016 | |
| <https://arxiv.org/abs/1608.03983>`_). | |
| """ | |
| def __init__( | |
| self, | |
| optimizer: torch.optim.Optimizer, | |
| warmup_steps: int, | |
| total_steps: int, | |
| min_lr_ratio: float = 0.1, | |
| ): | |
| self.optimizer = optimizer | |
| self.warmup_steps = warmup_steps | |
| self.total_steps = total_steps | |
| self.min_lr_ratio = min_lr_ratio | |
| self.base_lrs = [pg["lr"] for pg in optimizer.param_groups] | |
| self._step = 0 | |
| self._apply_lr(0) | |
| def _apply_lr(self, step: int) -> None: | |
| lr_scale = self._compute_lr_scale(step) | |
| for pg, base_lr in zip(self.optimizer.param_groups, self.base_lrs, strict=True): | |
| pg["lr"] = base_lr * lr_scale | |
| def step(self) -> None: | |
| self._step += 1 | |
| self._apply_lr(self._step) | |
| def _compute_lr_scale(self, step: int) -> float: | |
| if step < self.warmup_steps: | |
| return step / max(1, self.warmup_steps) | |
| progress = (step - self.warmup_steps) / max(1, self.total_steps - self.warmup_steps) | |
| progress = min(progress, 1.0) | |
| return self.min_lr_ratio + 0.5 * (1.0 - self.min_lr_ratio) * ( | |
| 1.0 + math.cos(math.pi * progress) | |
| ) | |
| def get_lr(self) -> float: | |
| return self.optimizer.param_groups[0]["lr"] | |
| def state_dict(self) -> dict[str, int]: | |
| return {"step": self._step} | |
| def load_state_dict(self, state: dict[str, int]) -> None: | |
| self._step = state["step"] | |
| self._apply_lr(self._step) | |
def _build_promo_grid_index() -> list[int]:
    """Return, for each promotion token in order, its grid index ``src*64 + dst``.

    Promotion tokens 4097..4272 encode 44 (src, dst) pairs x 4 piece types,
    laid out as ``PROMO_START + pair_idx * 4 + promo_type``. The four tokens
    of a pair therefore share one grid index, so each pair's index is
    repeated four times in the result.
    """
    import chess_engine

    # "promo_pairs" is a list of (src, dst) tuples, len=44.
    promo_pairs = chess_engine.export_move_vocabulary()["promo_pairs"]
    return [src * 64 + dst for src, dst in promo_pairs for _ in range(4)]
| # Lazily initialized on first use | |
| _PROMO_GRID_INDEX: list[int] | None = None | |
| def _get_promo_grid_index(device: str | torch.device) -> torch.Tensor: | |
| """Get the promo-token-to-grid-index mapping as a tensor on the given device.""" | |
| global _PROMO_GRID_INDEX | |
| if _PROMO_GRID_INDEX is None: | |
| _PROMO_GRID_INDEX = _build_promo_grid_index() | |
| return torch.tensor(_PROMO_GRID_INDEX, dtype=torch.long, device=device) | |
def compute_legal_move_rate(
    logits: torch.Tensor,
    legal_grid: torch.Tensor,
    loss_mask: torch.Tensor,
    game_lengths: torch.Tensor,
) -> float:
    """Compute the fraction of argmax predictions that are legal moves.

    Evaluated at positions 0 through game_lengths (inclusive), matching the
    loss_mask semantics (which includes the end-of-game PAD prediction).
    Only the first ``min(T, max_ply)`` positions are scored. A prediction
    counts as legal when it is a base grid token (1-4096) or a promotion
    token (4097-4272) whose (src, dst) cell is set in the legal grid; any
    other token counts as illegal but still contributes to the denominator.

    The lookup is done with a single ``gather`` over all positions instead of
    a Python loop over plies, which avoids one host/device sync per ply.

    Args:
        logits: (B, T, vocab_size)
        legal_grid: (B, max_ply, 64) bit-packed legal moves from engine
        loss_mask: (B, T) bool
        game_lengths: (B,) int

    Returns:
        Legal predictions divided by scored positions, or ``0.0`` when no
        position is scored.
    """
    B, T, _ = logits.shape
    max_ply = legal_grid.shape[1]
    n_plies = min(T, max_ply)

    with torch.no_grad():
        # Positions where the target is an actual move or end-of-game PAD:
        # 0..game_lengths in CLM indexing (matches loss_mask <= semantics).
        positions = torch.arange(T, device=logits.device).unsqueeze(0)
        move_mask = (positions <= game_lengths.unsqueeze(1)) & loss_mask
        move_mask = move_mask[:, :n_plies]  # (B, n_plies)

        valid_count = int(move_mask.sum().item())
        if valid_count == 0:
            return 0.0

        preds = logits[:, :n_plies].argmax(dim=-1)  # (B, n_plies)

        # Unpack to dense (B, max_ply, 64, 64), flatten the board dims to
        # 4096 and keep only the plies that are actually scored.
        legal_flat = unpack_grid(legal_grid).reshape(B, max_ply, 4096)
        legal_flat = legal_flat[:, :n_plies]

        # Promo token -> grid index lookup (lazily built, cached): (176,)
        promo_grid_idx = _get_promo_grid_index(logits.device)

        # Base grid tokens (1-4096): grid index = token - 1.
        is_base = (preds >= 1) & (preds <= 4096)
        base_idx = (preds - 1).clamp(0, 4095)

        # Promotion tokens (4097-4272): look up the (src, dst) grid index.
        is_promo = (preds >= 4097) & (preds <= 4272)
        promo_idx = promo_grid_idx[(preds - 4097).clamp(0, 175)]

        # The clamps keep indices in range for tokens of the other kind (or
        # of neither kind); those lookups are masked out below.
        grid_idx = torch.where(is_promo, promo_idx, base_idx)
        cell = legal_flat.gather(-1, grid_idx.unsqueeze(-1)).squeeze(-1)
        is_legal = (cell > 0.5) & (is_base | is_promo) & move_mask
        legal_count = is_legal.sum().item()

    return legal_count / valid_count
def _get_grad_norm(model: nn.Module) -> float:
    """Return the global L2 norm over all parameter gradients of ``model``.

    Each gradient's norm is taken in float32 and the per-parameter norms are
    then combined as the square root of the sum of their squares. Parameters
    without a gradient are ignored; returns ``0.0`` when none has one.
    """
    per_param_norms = [
        param.grad.data.float().norm()
        for param in model.parameters()
        if param.grad is not None
    ]
    if not per_param_norms:
        return 0.0
    sum_of_squares = torch.stack(per_param_norms).square().sum()
    return sum_of_squares.sqrt().item()
class CLMTrainer:
    """Training driver for a PAWN causal-LM run.

    Construction builds the metrics logger, model, AdamW optimizer,
    warmup/cosine scheduler, AMP grad scaler, training dataset and a fixed
    validation set. :meth:`train` then runs the optimisation loop with
    periodic logging, evaluation and checkpointing.
    """

    def __init__(
        self,
        train_cfg: TrainingConfig,
        model_cfg: CLMConfig,
        hf_repo: str | None = None,
    ):
        """Create all components of the run.

        Args:
            train_cfg: Training hyper-parameters. Its ``checkpoint_dir`` is
                overwritten with ``<run_dir>/checkpoints``.
            model_cfg: Model configuration passed to ``PAWNCLM``.
            hf_repo: Optional repository name; when set, every checkpoint is
                also pushed to branch ``run/<run directory name>``.
        """
        self.cfg = train_cfg
        self.model_cfg = model_cfg
        self.device = train_cfg.device
        self.global_step = 0
        self.hf_repo = hf_repo
        self.hf_branch: str | None = None

        self.logger = MetricsLogger(
            train_cfg.log_dir, run_prefix="run", device=self.device,
        )
        self.run_dir = str(self.logger.run_dir)
        # Checkpoints live inside the run directory chosen by the logger;
        # note that this mutates the caller's config object.
        self.cfg.checkpoint_dir = os.path.join(self.run_dir, "checkpoints")
        self._jsonl_path = str(self.logger.metrics_path)
        if self.hf_repo:
            self.hf_branch = f"run/{os.path.basename(self.run_dir)}"

        # `_model` is always the plain module; `model` may be replaced by a
        # torch.compile wrapper further down.
        self._model = PAWNCLM(model_cfg).to(self.device)
        self.model = self._model
        param_count = sum(p.numel() for p in self._model.parameters())
        print(f"Model parameters: {param_count:,}")
        print(f"Run directory: {self.run_dir}")

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=train_cfg.lr,
            weight_decay=train_cfg.weight_decay,
            betas=(0.9, 0.95),
        )
        self.scheduler = CosineWithWarmup(
            self.optimizer,
            warmup_steps=train_cfg.warmup_steps,
            total_steps=train_cfg.total_steps,
        )
        self.scaler = torch.amp.GradScaler(self.device, enabled=train_cfg.use_amp)

        self.dataset = CLMDataset(
            train_cfg.batch_size, train_cfg.max_ply, train_cfg.base_seed,
            discard_ply_limit=train_cfg.discard_ply_limit,
            no_outcome=train_cfg.no_outcome_token,
            mate_boost=train_cfg.mate_boost,
        )
        print("Generating validation set...")
        self.val_data = create_validation_set(
            train_cfg.val_games, train_cfg.max_ply, train_cfg.val_seed,
            discard_ply_limit=train_cfg.discard_ply_limit,
            no_outcome=train_cfg.no_outcome_token,
            mate_boost=train_cfg.mate_boost,
        )

        # W&B is strictly best-effort: any failure (including a missing
        # package) falls back to local logging only.
        self.wandb_run = None
        if train_cfg.use_wandb:
            try:
                import wandb
                self.wandb_run = wandb.init(
                    project=train_cfg.wandb_project,
                    config={
                        "model": model_cfg.__dict__,
                        "training": train_cfg.__dict__,
                    },
                )
            except Exception as e:  # ImportError is an Exception subclass
                print(f"W&B init failed: {e}. Continuing without W&B.")

        # torch.compile
        self._compiled = False
        if self.device != "cpu":
            try:
                self.model = torch.compile(self.model, mode="default")
                self._compiled = True
                print("torch.compile enabled")
            except Exception:
                print("torch.compile not available, using eager mode")
        else:
            print("Skipping torch.compile on CPU")

        self.logger.log_config(
            model=model_cfg.__dict__,
            training=train_cfg.__dict__,
            param_count=param_count,
            compiled=self._compiled,
            formulation="clm",
        )
        self.logger.write_config_json(
            model=model_cfg.__dict__,
            training=train_cfg.__dict__,
            param_count=param_count,
            compiled=self._compiled,
            formulation="clm",
        )

    def seed_logs(self, run_dirs: list[str], max_step: int):
        """Splice prior run logs into this run's JSONL.

        Reads ``metrics.jsonl`` from each directory in ``run_dirs`` (missing
        files are reported and skipped). Trailing NUL bytes are stripped,
        undecodable bytes are replaced, and lines that are not valid JSON
        objects are ignored. Records with ``step`` greater than ``max_step``
        are dropped (a missing ``step`` counts as 0); the rest are sorted by
        ``(step, type)`` and de-duplicated on ``(type, step)``, keeping the
        first occurrence. The result overwrites this run's metrics file; if
        nothing remains, the file is left untouched.

        Args:
            run_dirs: Prior run directories, in priority order for duplicates
                that share the same ``(step, type)``.
            max_step: Highest step to carry over.
        """
        from pathlib import Path

        all_records: list[dict] = []
        for rd in run_dirs:
            p = Path(rd) / "metrics.jsonl"
            if not p.exists():
                print(f" WARNING: {p} not found, skipping")
                continue
            # Binary read: the file may end in NUL padding or contain a
            # partially written line, neither of which should abort seeding.
            text = p.read_bytes().rstrip(b"\x00").decode(errors="replace")
            for line in text.splitlines():
                line = line.strip()
                if not line:
                    continue
                try:
                    record = json.loads(line)
                except json.JSONDecodeError:
                    continue
                # A line such as `42` is valid JSON but not a log record;
                # skipping it keeps the `.get` calls below from crashing.
                if isinstance(record, dict):
                    all_records.append(record)

        all_records = [r for r in all_records if r.get("step", 0) <= max_step]
        # list.sort is stable, so ties keep their run_dirs/file order.
        all_records.sort(key=lambda r: (r.get("step", 0), r.get("type", "")))

        seen: set[tuple[str, int]] = set()
        deduped: list[dict] = []
        for r in all_records:
            key = (r.get("type", ""), r.get("step", 0))
            if key not in seen:
                seen.add(key)
                deduped.append(r)

        if not deduped:
            print(" No prior log lines to seed.")
            return

        with open(self._jsonl_path, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(r, default=str) + "\n" for r in deduped)

        first_step = deduped[0].get("step", "?")
        last_step = deduped[-1].get("step", "?")
        print(f"Seeded {len(deduped)} log lines from prior runs "
              f"(steps {first_step}-{last_step})")

    def _log_jsonl(self, record: dict):
        """Low-level JSONL write for seed_logs compatibility."""
        self.logger._write(record)

    def train_step(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Run forward and backward on one micro-batch.

        The loss is divided by ``cfg.accumulation_steps`` before the backward
        pass so that gradients accumulated over that many calls form an
        average. No optimizer update happens here.

        Args:
            batch: Must provide ``"input_ids"``, ``"targets"`` and
                ``"loss_mask"``.

        Returns:
            The metrics returned by ``model.forward_train``.
        """
        self.model.train()
        input_ids = batch["input_ids"].to(self.device)
        targets = batch["targets"].to(self.device)
        loss_mask = batch["loss_mask"].to(self.device)

        model = self._eager_model()
        with torch.amp.autocast(self.device, enabled=self.cfg.use_amp):
            loss, metrics = model.forward_train(input_ids, loss_mask, targets)

        scaled_loss = loss / self.cfg.accumulation_steps
        self.scaler.scale(scaled_loss).backward()
        return metrics

    def optimizer_step(self) -> float:
        """Apply the accumulated gradients and advance the LR schedule.

        Gradients are unscaled first so that both the reported norm and the
        clipping threshold refer to true gradient magnitudes.

        Returns:
            The global gradient L2 norm measured before clipping.
        """
        self.scaler.unscale_(self.optimizer)
        grad_norm = _get_grad_norm(self._model)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.max_grad_norm)
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad(set_to_none=True)
        self.scheduler.step()
        return grad_norm

    def _eager_model(self) -> PAWNCLM:
        """Return the plain (non-``torch.compile``-wrapped) model."""
        return self._model

    def evaluate(self) -> dict[str, float]:
        """Evaluate the model on the fixed validation set.

        Runs under ``torch.no_grad()`` so no autograd graph is kept for the
        forward passes. Per-batch metrics from ``clm_loss`` plus top-5
        accuracy (and the legal-move rate when the validation data contains
        ``"legal_grid"``) are averaged with equal weight per batch.

        Returns:
            Plain-float metrics keyed ``"val/<name>"``, plus
            ``"val/perplexity"`` = ``exp(min(val/loss, 20))``.
        """
        model = self._eager_model()
        model.eval()

        n = self.val_data["input_ids"].shape[0]
        batch_size = self.cfg.batch_size
        totals: dict = {}
        n_batches = 0
        has_legal = "legal_grid" in self.val_data

        with torch.no_grad():
            for start in range(0, n, batch_size):
                end = min(start + batch_size, n)
                input_ids = self.val_data["input_ids"][start:end].to(self.device)
                targets = self.val_data["targets"][start:end].to(self.device)
                loss_mask = self.val_data["loss_mask"][start:end].to(self.device)

                with torch.amp.autocast(self.device, enabled=self.cfg.use_amp):
                    logits, _layer_outputs = model(input_ids, loss_mask)
                # Only the logits are needed; drop the per-layer outputs
                # right away to release their memory.
                del _layer_outputs
                _, metrics = clm_loss(logits, targets, loss_mask)

                # Top-5 accuracy
                valid_logits = logits[loss_mask]
                valid_targets = targets[loss_mask]
                top5 = valid_logits.topk(5, dim=-1).indices
                top5_acc = (top5 == valid_targets.unsqueeze(-1)).any(dim=-1).float().mean().item()
                metrics["top5_accuracy"] = top5_acc

                # Legal move rate (if legal grid available)
                if has_legal:
                    legal_grid = self.val_data["legal_grid"][start:end].to(self.device)
                    game_lengths = self.val_data["game_lengths"][start:end].to(self.device)
                    legal_rate = compute_legal_move_rate(logits, legal_grid, loss_mask, game_lengths)
                    metrics["legal_move_rate"] = legal_rate

                for k, v in metrics.items():
                    totals[k] = totals.get(k, 0.0) + v
                n_batches += 1

        # NOTE(review): the source this was recovered from had lost its
        # indentation; this call is assumed to run once after the loop rather
        # than once per batch - confirm.
        if self.device != "cpu" and torch.cuda.is_available():
            torch.cuda.empty_cache()

        # float() turns 0-dim tensors into Python floats so the result
        # matches the declared return type and serialises cleanly.
        avg = {f"val/{k}": float(v / n_batches) for k, v in totals.items()}
        # Cap the exponent so a diverged loss cannot overflow math.exp.
        avg["val/perplexity"] = math.exp(min(avg["val/loss"], 20.0))
        return avg

    def train(self):
        """Run the training loop until completion, pause or shutdown.

        Each optimizer step consumes ``cfg.accumulation_steps`` micro-batches.
        After every step the loop logs, evaluates and checkpoints at the
        configured intervals, then stops when ``cfg.total_steps`` is reached,
        when ``cfg.pause_after_steps`` is reached, or when SIGTERM/SIGINT was
        received. A checkpoint is written before stopping.

        The previous signal handlers are restored and the logger is closed
        even if the loop raises.
        """
        self.dataset.set_start_step(self.global_step)

        num_workers = self.cfg.num_workers
        loader = DataLoader(
            self.dataset,
            batch_size=None,
            num_workers=num_workers,
            pin_memory=(self.device != "cpu"),
            persistent_workers=(num_workers > 0),
            prefetch_factor=1 if num_workers > 0 else None,
        )

        _shutdown_requested = False
        _shutdown_signal = None

        def _graceful_exit(signum, frame):
            # Only record the request; the loop acts on it at the next
            # optimizer-step boundary so a checkpoint can be saved.
            nonlocal _shutdown_requested, _shutdown_signal
            _shutdown_requested = True
            _shutdown_signal = signum

        old_term = signal.signal(signal.SIGTERM, _graceful_exit)
        old_int = signal.signal(signal.SIGINT, _graceful_exit)

        try:
            os.makedirs(self.cfg.checkpoint_dir, exist_ok=True)

            accum_count = 0
            step_start = time.perf_counter()
            games_per_step = self.cfg.batch_size * self.cfg.accumulation_steps

            print(f"Starting training from step {self.global_step}", flush=True)
            print(f"JSONL log: {self._jsonl_path}", flush=True)

            for batch in loader:
                metrics = self.train_step(batch)
                accum_count += 1
                if accum_count < self.cfg.accumulation_steps:
                    continue

                grad_norm = self.optimizer_step()
                accum_count = 0
                self.global_step += 1

                step_time = time.perf_counter() - step_start
                games_per_sec = games_per_step / step_time if step_time > 0 else 0.0

                if self.global_step % self.cfg.log_interval == 0:
                    # .item() sync only at log intervals (metrics are tensors here)
                    loss_val = metrics['loss'].item()
                    acc_val = metrics['accuracy'].item()
                    lr = self.scheduler.get_lr()

                    print(
                        f"step {self.global_step:>7d} | "
                        f"loss {loss_val:.4f} | "
                        f"acc {acc_val:.3f} | "
                        f"lr {lr:.2e} | "
                        f"gn {grad_norm:.2f} | "
                        f"{games_per_sec:.0f} g/s | "
                        f"{step_time:.2f}s",
                        flush=True,
                    )

                    self.logger.log_train(
                        step=self.global_step,
                        lr=lr, grad_norm=grad_norm,
                        step_time=step_time, games_per_sec=games_per_sec,
                        **{"train/loss": loss_val, "train/accuracy": acc_val},  # type: ignore[arg-type]
                    )

                    if self.wandb_run:
                        self.wandb_run.log({
                            "train/loss": loss_val, "train/accuracy": acc_val,
                            "train/lr": lr, "train/grad_norm": grad_norm,
                            "train/step_time": step_time, "train/games_per_sec": games_per_sec,
                        }, step=self.global_step)

                if self.global_step % self.cfg.eval_interval == 0:
                    val_metrics = self.evaluate()
                    val_msg = (
                        f" val: loss {val_metrics['val/loss']:.4f} | "
                        f"acc {val_metrics['val/accuracy']:.3f} | "
                        f"top5 {val_metrics.get('val/top5_accuracy', 0):.3f} | "
                        f"ppl {val_metrics.get('val/perplexity', 0):.1f}"
                    )
                    if "val/legal_move_rate" in val_metrics:
                        val_msg += f" | legal {val_metrics['val/legal_move_rate']:.3f}"
                    print(val_msg, flush=True)
                    self.logger.log_val(step=self.global_step, **val_metrics)  # type: ignore[arg-type]
                    if self.wandb_run:
                        self.wandb_run.log(val_metrics, step=self.global_step)

                # Tracks whether this step's state is already on disk, so the
                # exit paths below do not write (and push) it a second time.
                checkpoint_saved = False
                if self.global_step % self.cfg.checkpoint_interval == 0:
                    self.save_checkpoint()
                    checkpoint_saved = True

                if self.global_step >= self.cfg.total_steps:
                    print(f"Training complete at step {self.global_step}")
                    if not checkpoint_saved:
                        self.save_checkpoint()
                    break

                if (self.cfg.pause_after_steps
                        and self.global_step >= self.cfg.pause_after_steps):
                    print(f"\n Paused at step {self.global_step} "
                          f"(pause_after_steps={self.cfg.pause_after_steps})")
                    if not checkpoint_saved:
                        self.save_checkpoint()
                    break

                if _shutdown_requested:
                    print(f"\nShutdown requested (signal {_shutdown_signal}), "
                          f"saving checkpoint at step {self.global_step}...")
                    if not checkpoint_saved:
                        self.save_checkpoint()
                    break

                # Restart the clock here so time spent on evaluation and
                # checkpointing is not counted towards the next step.
                step_start = time.perf_counter()
        finally:
            # signal.signal() returns None when the previous handler was not
            # installed from Python; None cannot be passed back in.
            for signum, previous in (
                (signal.SIGTERM, old_term),
                (signal.SIGINT, old_int),
            ):
                if previous is not None:
                    signal.signal(signum, previous)
            self.logger.close()

    def save_checkpoint(self, path: str | None = None):
        """Save model, optimizer, scheduler and scaler state.

        Args:
            path: Destination; defaults to
                ``<cfg.checkpoint_dir>/step_<global_step, 8 digits>``.

        When a repository and branch are configured the checkpoint is also
        pushed there. A failed push is reported but does not raise.
        """
        from pawn.checkpoint import save_pretrain_checkpoint

        if path is None:
            path = os.path.join(
                self.cfg.checkpoint_dir, f"step_{self.global_step:08d}"
            )

        # Always save the plain module, not the torch.compile wrapper.
        model: PAWNCLM = self._eager_model()
        save_pretrain_checkpoint(
            path,
            model,
            self.optimizer,
            self.scheduler,
            self.scaler,
            self.global_step,
            self.model_cfg.__dict__,
            self.cfg.__dict__,
        )
        print(f"Checkpoint saved: {path}")

        if self.hf_repo and self.hf_branch:
            from pawn.checkpoint import push_checkpoint_to_hf
            try:
                push_checkpoint_to_hf(
                    path, self.hf_repo, self.hf_branch,
                    metrics_path=self._jsonl_path,
                    step=self.global_step,
                )
                print(f"Pushed to HF: {self.hf_repo}@{self.hf_branch}")
            except Exception as e:
                # Best effort: the local checkpoint is already written.
                print(f"WARNING: HF push failed: {e}")

    def load_checkpoint(self, path: str):
        """Restore training state from ``path`` and set ``global_step``.

        The model, optimizer, scheduler and scaler are handed to
        ``load_pretrain_checkpoint``; the step counter is taken from the
        ``"global_step"`` entry of the metadata it returns.
        """
        from pawn.checkpoint import load_pretrain_checkpoint

        model: PAWNCLM = self._eager_model()
        meta = load_pretrain_checkpoint(
            path, model, self.optimizer, self.scheduler, self.scaler,
            device=self.device,
        )
        self.global_step = meta["global_step"]
        print(f"Resumed from step {self.global_step}")
Xet Storage Details
- Size:
- 21.5 kB
- Xet hash:
- 634d63104c69cfdeccc8d8aa8e643ac4d881888c419eea8abc487d3f0e16493a
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.