Buckets:
| #!/usr/bin/env python3 | |
| """Ensemble evaluation: load multiple adapter checkpoints, evaluate each | |
| individually, then report ensemble (softmax averaging with legal mask). | |
| Usage: | |
| ensemble_eval.py --model name:strategy:ckpt_path [--model ...] [...] | |
| """ | |
| import argparse | |
| import json | |
| import sys | |
| from pathlib import Path | |
| sys.path.insert(0, "/opt/pawn") | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.utils.data import DataLoader | |
| from safetensors.torch import load_file | |
| from pawn.config import CLMConfig | |
| from pawn.model import PAWNCLM | |
| from pawn import model as model_module | |
| from pawn.checkpoint import load_backbone_weights | |
| from pawn.lichess_data import LichessDataset, LegalMaskBuilder, LegalMaskCollate | |
| from pawn.shard_loader import load_val_shards | |
| from pawn.gpu import configure_gpu, apply_gpu_config | |
def make_unfreeze_wrapper(backbone):
    """Expose a bare backbone through the ``forward_hidden``/``project_head`` API.

    Used for the "unfreeze" strategy, where the fine-tuned weights live in the
    backbone itself rather than in a separate adapter module.

    Args:
        backbone: backbone module providing ``cfg``, ``embed``, ``layers``,
            ``final_norm``, ``lm_head``, ``causal_mask``, ``rope_cos`` and
            ``rope_sin``.

    Returns:
        An ``nn.Module`` wrapping ``backbone``.
    """

    class _UnfreezeWrapper(nn.Module):
        def __init__(self, bb):
            super().__init__()
            # Keep the attribute name "bb": it prefixes this wrapper's state_dict keys.
            self.bb = bb
            self.cfg = bb.cfg

        def forward_hidden(self, input_ids, attention_mask=None):
            """Run embeddings -> every layer -> final norm; return hidden states."""
            net = self.bb
            hidden = net.embed(input_ids)
            seq_len = input_ids.shape[1]

            # Without a padding mask the layers receive mask=None.
            # NOTE(review): confirm the layer applies causal masking itself in that
            # case — the evaluation loop in this script never passes attention_mask.
            attn_mask = None
            if attention_mask is not None:
                # Causal mask AND key-padding mask; for a 2-D (batch, T) padding mask
                # this broadcasts to (batch, 1, T, T).
                key_keep = attention_mask.unsqueeze(1).unsqueeze(2)
                attn_mask = net.causal_mask[:seq_len, :seq_len].unsqueeze(0) & key_keep

            cos = net.rope_cos[:, :, :seq_len, :]
            sin = net.rope_sin[:, :, :seq_len, :]
            for block in net.layers:
                hidden = block(hidden, cos, sin, attn_mask)
            return net.final_norm(hidden)

        def project_head(self, x):
            """Map hidden states to vocabulary logits with the backbone's LM head."""
            return self.bb.lm_head(x)

    return _UnfreezeWrapper(backbone)
def _load_adapter_state(module, adapter_state, ckpt_path):
    """Non-strictly load ``adapter_state`` into ``module``, refusing a no-op load.

    ``strict=False`` is required because the adapter file covers only part of
    the module, so missing keys are expected. Unexpected keys are different:
    they are checkpoint tensors that matched nothing and were silently dropped.
    If that is *every* tensor, the model would be evaluated without its adapter.

    Raises:
        ValueError: if none of the checkpoint tensors match ``module``.
    """
    incompatible = module.load_state_dict(adapter_state, strict=False)
    dropped = list(incompatible.unexpected_keys)
    if adapter_state and len(dropped) == len(adapter_state):
        raise ValueError(
            f"No tensor in {ckpt_path / 'adapter.safetensors'} matches the "
            f"model (first key: {dropped[0]!r}); wrong --model strategy?"
        )
    if dropped:
        print(
            f"WARNING: {len(dropped)} of {len(adapter_state)} tensors in "
            f"{ckpt_path} were not loaded (e.g. {dropped[0]!r})",
            file=sys.stderr,
        )


def load_model(strategy, ckpt_path, device, backbone_repo="thomas-schweich/pawn-base"):
    """Build the backbone, load a checkpoint's adapter weights, return it in eval mode.

    Args:
        strategy: one of "bottleneck", "unfreeze" or "rosa".
        ckpt_path: directory holding ``config.json`` (adapter hyper-parameters)
            and ``adapter.safetensors`` (adapter weights).
        device: device the backbone and adapter are moved to.
        backbone_repo: identifier passed to ``load_backbone_weights`` for the
            base weights.

    Returns:
        The model in eval mode; the evaluation loop calls ``forward_hidden``
        and ``project_head`` on it.

    Raises:
        ValueError: for an unknown strategy, or when no adapter tensor matches
            the constructed model.
    """
    # Validate first so a typo fails before the backbone is downloaded/loaded.
    if strategy not in ("bottleneck", "unfreeze", "rosa"):
        raise ValueError(f"Unknown strategy: {strategy}")

    ckpt_path = Path(ckpt_path)
    adapter_cfg = json.loads((ckpt_path / "config.json").read_text())
    state_dict, model_config = load_backbone_weights(backbone_repo, device)
    cfg = CLMConfig(**model_config) if model_config else CLMConfig()
    backbone = PAWNCLM(cfg).to(device)
    backbone.load_state_dict(state_dict)
    adapter_state = load_file(str(ckpt_path / "adapter.safetensors"))

    if strategy == "unfreeze":
        # Fine-tuned weights replace (a subset of) the backbone's own parameters.
        _load_adapter_state(backbone, adapter_state, ckpt_path)
        model = make_unfreeze_wrapper(backbone).to(device)
    else:
        from pawn.adapters.bottleneck import BottleneckCLM

        if strategy == "bottleneck":
            layers = tuple(adapter_cfg["adapter_layers"])
            adapt_attn = adapter_cfg.get("adapt_attn", True)
            adapt_ffn = adapter_cfg.get("adapt_ffn", True)
        else:  # "rosa"
            # Retro-bottleneck: per the original note, the backbone carries merged
            # sparse weights and is wrapped with a bottleneck adapter.
            # NOTE(review): no merge happens here, so the merged weights must come
            # from adapter.safetensors — confirm.
            layers = tuple(adapter_cfg.get("adapter_layers") or (4, 5, 6, 7))
            adapt_attn = adapt_ffn = True
        model = BottleneckCLM(
            backbone,
            bottleneck_dim=adapter_cfg["bottleneck_dim"],
            layers=layers,
            adapt_attn=adapt_attn,
            adapt_ffn=adapt_ffn,
        ).to(device)
        _load_adapter_state(model, adapter_state, ckpt_path)

    model.eval()
    return model
def _batch_sums(log_probs, ranking, targets):
    """Return ``(summed NLL, top-1 hits, top-5 hits)`` for one batch of positions.

    ``log_probs`` (N, V) feeds the loss; ``ranking`` (N, V) holds the scores
    whose argmax / top-k are taken as the predictions.
    """
    loss_sum = F.nll_loss(log_probs, targets, reduction="sum").item()
    top1_hits = (ranking.argmax(dim=-1) == targets).sum().item()
    k = min(5, ranking.shape[-1])
    topk_idx = ranking.topk(k, dim=-1).indices
    top5_hits = (topk_idx == targets.unsqueeze(-1)).any(dim=-1).sum().item()
    return loss_sum, top1_hits, top5_hits


def eval_models(models, val_loader, mask_builder, device, amp_dtype):
    """Evaluate each model individually and as a softmax-averaged ensemble.

    For every batch, each model's logits at loss-masked positions are
    restricted to legal moves (illegal entries set to -inf). Per-model metrics
    use those logits directly. The ensemble averages the per-model softmax
    distributions and scores the mean.

    Args:
        models: non-empty list of ``(name, model)`` pairs; each model must
            provide ``forward_hidden(input_ids)`` and ``project_head(hidden)``.
        val_loader: iterable of dict batches with ``input_ids``, ``targets``,
            ``loss_mask`` and optionally ``legal_indices``.
        mask_builder: produces the legal-move mask, via
            ``scatter(legal_indices, batch_size)`` when the batch carries
            ``legal_indices``, otherwise by being called on the batch.
        device: device batch tensors are moved to; its type also selects the
            autocast backend.
        amp_dtype: autocast dtype for the forward passes.

    Returns:
        ``(per_model, ensemble, total)``: a list of ``{"loss", "top1", "top5"}``
        dicts (same order as ``models``), the same dict for the ensemble, and
        the number of evaluated positions. Metrics are position-weighted means.

    Raises:
        ValueError: if ``models`` is empty or the loader yields no
            loss-masked positions.
    """
    if not models:
        raise ValueError("eval_models needs at least one model")

    metric_keys = ("loss", "top1", "top5")
    per_model = [dict.fromkeys(metric_keys, 0.0) for _ in models]
    ens = dict.fromkeys(metric_keys, 0.0)
    total = 0
    # autocast wants the device *type* ("cuda", "cpu", ...), not e.g. "cuda:1".
    device_type = torch.device(device).type

    # Pure inference: without no_grad every model's autograd graph (all
    # activations) would be kept alive for the whole batch.
    with torch.no_grad():
        for batch in val_loader:
            ids = batch["input_ids"].to(device, non_blocking=True)
            tgt = batch["targets"].to(device, non_blocking=True)
            msk = batch["loss_mask"].to(device, non_blocking=True)
            if "legal_indices" in batch:
                legal_mask = mask_builder.scatter(batch["legal_indices"], ids.shape[0])
            else:
                legal_mask = mask_builder(batch)
            valid_legal = legal_mask[msk]
            valid_targets = tgt[msk]
            n_pos = valid_targets.shape[0]
            if n_pos == 0:
                continue

            all_logits = []
            for _name, model in models:
                with torch.amp.autocast(device_type, dtype=amp_dtype, enabled=True):
                    hidden = model.forward_hidden(ids)
                    lg = model.project_head(hidden[msk]).float()
                all_logits.append(lg.masked_fill(~valid_legal, float("-inf")))

            # Per-model: cross-entropy == NLL of log_softmax; rank by raw logits.
            for acc, lg in zip(per_model, all_logits):
                sums = _batch_sums(F.log_softmax(lg, dim=-1), lg, valid_targets)
                for key, value in zip(metric_keys, sums):
                    acc[key] += value

            # Ensemble: mean of per-model softmax distributions. The clamp avoids
            # log(0) for moves every model rules out.
            probs = torch.stack([F.softmax(lg, dim=-1) for lg in all_logits]).mean(dim=0)
            sums = _batch_sums(probs.clamp(min=1e-30).log(), probs, valid_targets)
            for key, value in zip(metric_keys, sums):
                ens[key] += value
            total += n_pos

    if total == 0:
        raise ValueError("eval_models: val_loader yielded no loss-masked positions")
    per_model = [{key: value / total for key, value in acc.items()} for acc in per_model]
    ens = {key: value / total for key, value in ens.items()}
    return per_model, ens, total
def main():
    """Command-line entry point: load every ``--model``, evaluate, print a report.

    Each ``--model`` is ``name:strategy:checkpoint_path``. Validation data comes
    from ``load_val_shards`` using the ``--pgn``, ``--elo-min``/``--elo-max``,
    ``--min-ply``, ``--val-games`` and ``--cache-dir`` arguments. Per-model and
    softmax-ensemble metrics are printed to stdout.
    """
    p = argparse.ArgumentParser(
        description="Evaluate adapter checkpoints individually and as a "
                    "softmax-averaged ensemble."
    )
    p.add_argument("--model", action="append", required=True,
                   help="name:strategy:checkpoint_path")
    p.add_argument("--pgn", default="thomas-schweich/pawn-lichess-full")
    p.add_argument("--elo-min", type=int, default=1800)
    p.add_argument("--elo-max", type=int, default=1900)
    p.add_argument("--val-games", type=int, default=50000)
    p.add_argument("--min-ply", type=int, default=10)
    p.add_argument("--batch-size", type=int, default=256)
    p.add_argument("--cache-dir", default="/dev/shm/pawn_cache")
    args = p.parse_args()

    # Parse model specs before any GPU/checkpoint work so a malformed spec fails
    # fast with a usage message. maxsplit=2 keeps extra colons in the path.
    specs = []
    for spec in args.model:
        parts = spec.split(":", 2)
        if len(parts) != 3:
            p.error(f"--model expects name:strategy:checkpoint_path, got {spec!r}")
        specs.append(tuple(parts))

    device = "cuda"
    gpu_cfg = configure_gpu(device=device, sdpa_math=True, no_compile=True)
    # Third argument is passed as the identity function (no_compile=True above);
    # presumably a compile hook — models run uncompiled.
    apply_gpu_config(gpu_cfg, model_module, lambda x: x)
    amp_dtype = torch.bfloat16

    # Load models; `models` keeps (name, model) pairs in --model order.
    models = []
    for name, strategy, ckpt_path in specs:
        print(f"Loading {name} ({strategy}) from {ckpt_path}...")
        models.append((name, load_model(strategy, ckpt_path, device)))

    # Build val loader
    print("Loading val shards...")
    val_data = load_val_shards(
        args.pgn, elo_min=args.elo_min, elo_max=args.elo_max,
        min_ply=args.min_ply, max_games=args.val_games,
        cache_dir=args.cache_dir,
    )
    val_ds = LichessDataset(val_data, start=0, end=val_data["n_games"])
    # NOTE(review): seq_len / vocab_size / max_ply are hard-coded; confirm they
    # match the backbone config.
    collate = LegalMaskCollate(seq_len=256, vocab_size=4278)
    val_loader = DataLoader(
        val_ds, batch_size=args.batch_size, shuffle=False,
        num_workers=0, pin_memory=True, collate_fn=collate,
    )
    mask_builder = LegalMaskBuilder(
        batch_size=args.batch_size, max_ply=255,
        vocab_size=4278, device=device,
    )

    print(f"Evaluating ({len(val_ds):,} games)...")
    per_model, ensemble, total = eval_models(
        models, val_loader, mask_builder, device, amp_dtype
    )

    print("\n=== Individual ===")
    # per_model is in the same order as `models` (2-tuples of name, model).
    for (name, _model), metrics in zip(models, per_model):
        print(
            f" {name:20s} "
            f"loss={metrics['loss']:.4f} "
            f"top1={metrics['top1']*100:.2f}% "
            f"top5={metrics['top5']*100:.2f}%"
        )
    print("\n=== Ensemble ===")
    print(
        f" softmax_avg "
        f"loss={ensemble['loss']:.4f} "
        f"top1={ensemble['top1']*100:.2f}% "
        f"top5={ensemble['top5']*100:.2f}%"
    )
    print(f"\n({total:,} positions evaluated)")
# Run the CLI only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
Xet Storage Details
- Size: 8.91 kB
- Xet hash: f94679f77e3bf13493c726f1f4b2a801b005f60cec54c1a3c68c2ca8f3e52370
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.