thomas-schweich's picture
download
raw
8.91 kB
#!/usr/bin/env python3
"""Ensemble evaluation: load multiple adapter checkpoints, evaluate each
individually, then report ensemble (softmax averaging with legal mask).
Usage:
ensemble_eval.py --model name:strategy:ckpt_path [--model ...] [...]
"""
import argparse
import json
import sys
from pathlib import Path
sys.path.insert(0, "/opt/pawn")
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from safetensors.torch import load_file
from pawn.config import CLMConfig
from pawn.model import PAWNCLM
from pawn import model as model_module
from pawn.checkpoint import load_backbone_weights
from pawn.lichess_data import LichessDataset, LegalMaskBuilder, LegalMaskCollate
from pawn.shard_loader import load_val_shards
from pawn.gpu import configure_gpu, apply_gpu_config
def make_unfreeze_wrapper(backbone):
    """Expose a fully fine-tuned ("unfreeze") backbone via the adapter API.

    The returned ``nn.Module`` offers the two methods the evaluation loop
    calls on every model:

    * ``forward_hidden(input_ids, attention_mask=None)``: re-runs the
      backbone's embedding, its transformer layers and ``final_norm``.
    * ``project_head(x)``: applies the backbone's ``lm_head``.

    The backbone is stored as ``.bb`` and its ``cfg`` is mirrored on the
    wrapper. No ``forward`` method is defined.
    """

    class _UnfreezeWrapper(nn.Module):
        def __init__(self, bb):
            super().__init__()
            self.bb = bb
            self.cfg = bb.cfg

        def forward_hidden(self, input_ids, attention_mask=None):
            """Return the final-norm hidden states for ``input_ids``.

            With an ``attention_mask``, the backbone's causal mask (cropped
            to the sequence length) is AND-ed with the padding mask
            broadcast over two new axes. Without one, the layers receive
            ``None`` as their mask.
            """
            net = self.bb
            hidden = net.embed(input_ids)
            seq_len = input_ids.shape[1]

            layer_mask = None
            if attention_mask is not None:
                causal = net.causal_mask[:seq_len, :seq_len].unsqueeze(0)
                layer_mask = causal & attention_mask[:, None, None]

            # Rotary tables are precomputed for the max length; crop to T.
            cos = net.rope_cos[:, :, :seq_len, :]
            sin = net.rope_sin[:, :, :seq_len, :]
            for block in net.layers:
                hidden = block(hidden, cos, sin, layer_mask)
            return net.final_norm(hidden)

        def project_head(self, x):
            """Map hidden states to vocabulary logits via the backbone head."""
            head = self.bb.lm_head
            return head(x)

    return _UnfreezeWrapper(backbone)
def _load_partial_state(module, state_dict, label):
    """Load ``state_dict`` into ``module`` non-strictly, reporting dropped keys.

    Adapter checkpoints hold only a subset of the model's parameters, so
    missing keys are expected. Unexpected keys, however, mean checkpoint
    tensors matched no parameter and were silently discarded (e.g. a naming
    mismatch). The evaluated model would then differ from the trained one
    without any error, so those keys are reported on stderr.

    Returns the ``(missing_keys, unexpected_keys)`` result of
    ``load_state_dict``.
    """
    result = module.load_state_dict(state_dict, strict=False)
    if result.unexpected_keys:
        preview = ", ".join(result.unexpected_keys[:5])
        print(
            f"WARNING: {label}: {len(result.unexpected_keys)} checkpoint "
            f"tensor(s) matched no model parameter and were ignored "
            f"(e.g. {preview})",
            file=sys.stderr,
        )
    return result


def load_model(strategy, ckpt_path, device):
    """Build a model for ``strategy``, load its fine-tuned weights, set eval mode.

    Every strategy starts from the ``thomas-schweich/pawn-base`` backbone
    weights. It then applies ``<ckpt_path>/adapter.safetensors``, with the
    adapter configured by ``<ckpt_path>/config.json``:

    * ``"bottleneck"``: wraps the backbone in ``BottleneckCLM``, using
      ``adapter_layers`` / ``bottleneck_dim`` and the optional
      ``adapt_attn`` / ``adapt_ffn`` flags (default True) from the config.
    * ``"unfreeze"``: loads the checkpoint directly into the backbone and
      wraps it with ``make_unfreeze_wrapper``.
    * ``"rosa"``: like ``"bottleneck"``, but with both adapter kinds enabled
      and layers defaulting to ``(4, 5, 6, 7)`` when the config lists none.

    Args:
        strategy: one of ``"bottleneck"``, ``"unfreeze"``, ``"rosa"``.
        ckpt_path: checkpoint directory (str or Path).
        device: torch device the model is moved to.

    Returns:
        The model in eval mode. It exposes ``forward_hidden`` and
        ``project_head``.

    Raises:
        ValueError: for an unknown strategy. This is checked before any file
            or backbone weights are loaded.
    """
    if strategy not in ("bottleneck", "unfreeze", "rosa"):
        raise ValueError(f"Unknown strategy: {strategy}")

    ckpt_path = Path(ckpt_path)
    adapter_cfg = json.loads((ckpt_path / "config.json").read_text())

    state_dict, model_config = load_backbone_weights(
        "thomas-schweich/pawn-base", device
    )
    cfg = CLMConfig(**model_config) if model_config else CLMConfig()
    backbone = PAWNCLM(cfg).to(device)
    backbone.load_state_dict(state_dict)

    adapter_state = load_file(str(ckpt_path / "adapter.safetensors"))
    label = f"{strategy} checkpoint {ckpt_path}"

    if strategy == "unfreeze":
        _load_partial_state(backbone, adapter_state, label)
        model = make_unfreeze_wrapper(backbone).to(device)
    else:
        from pawn.adapters.bottleneck import BottleneckCLM

        if strategy == "bottleneck":
            layers = tuple(adapter_cfg["adapter_layers"])
            adapt_attn = adapter_cfg.get("adapt_attn", True)
            adapt_ffn = adapter_cfg.get("adapt_ffn", True)
        else:
            # "rosa" (retro-bottleneck): the checkpoint holds merged sparse
            # backbone weights plus a wrapping bottleneck adapter.
            layers = tuple(adapter_cfg.get("adapter_layers") or (4, 5, 6, 7))
            adapt_attn = adapt_ffn = True
        model = BottleneckCLM(
            backbone,
            bottleneck_dim=adapter_cfg["bottleneck_dim"],
            layers=layers,
            adapt_attn=adapt_attn,
            adapt_ffn=adapt_ffn,
        ).to(device)
        _load_partial_state(model, adapter_state, label)

    model.eval()
    return model
def _accumulate_metrics(stats, log_probs, targets, k=5):
    """Add one batch's summed NLL and top-1 / top-k hit counts into ``stats``.

    ``log_probs`` is ``(n_positions, vocab)``; illegal moves may be ``-inf``.
    The top-k count is stored under ``"top5"``. ``k`` is clipped to the
    vocabulary size so that ``topk`` cannot fail on tiny vocabularies.
    """
    stats["loss"] += F.nll_loss(log_probs, targets, reduction="sum").item()
    stats["top1"] += (log_probs.argmax(dim=-1) == targets).sum().item()
    topk_idx = log_probs.topk(min(k, log_probs.shape[-1]), dim=-1).indices
    stats["top5"] += (topk_idx == targets.unsqueeze(-1)).any(dim=-1).sum().item()


@torch.no_grad()
def eval_models(models, val_loader, mask_builder, device, amp_dtype):
    """Evaluate each model and their softmax-average ensemble on ``val_loader``.

    Logits at illegal moves are masked to ``-inf`` before any metric is
    computed. Every model, and therefore the ensemble, thus spreads
    probability only over legal moves. The ensemble distribution is the mean
    of the per-model softmaxes. It is computed in log space as
    ``logsumexp(log_softmax_i) - log(n_models)`` to avoid underflow.

    Args:
        models: list of ``(name, model)`` pairs. Each model must expose
            ``forward_hidden(input_ids)`` and ``project_head(hidden)``.
        val_loader: iterable of batch dicts with ``input_ids``, ``targets``
            and ``loss_mask``. A batch may also carry ``legal_indices``.
        mask_builder: produces the legal-move mask. Batches with
            ``legal_indices`` use ``mask_builder.scatter(legal_indices,
            batch_size)``; other batches use ``mask_builder(batch)``.
        device: device the batch tensors are moved to. Its type (e.g.
            ``"cuda"``) also selects the autocast backend.
        amp_dtype: autocast dtype for the forward pass, or ``None`` to run
            without autocast.

    Returns:
        ``(per_model, ensemble, total)``:
        * ``per_model``: one ``{"loss", "top1", "top5"}`` dict per model, in
          input order.
        * ``ensemble``: the same dict for the ensemble.
        * ``total``: the number of scored positions.
        Metrics are position-weighted means. They are NaN when no position
        was scored, rather than raising ``ZeroDivisionError``.

    Raises:
        ValueError: if ``models`` is empty.
    """
    import math

    n_models = len(models)
    if n_models == 0:
        raise ValueError("eval_models needs at least one model")

    per_model = [
        {"loss": 0.0, "top1": 0.0, "top5": 0.0} for _ in range(n_models)
    ]
    ens = {"loss": 0.0, "top1": 0.0, "top5": 0.0}
    total = 0
    autocast_device = torch.device(device).type
    log_n_models = math.log(n_models)

    for batch in val_loader:
        ids = batch["input_ids"].to(device, non_blocking=True)
        tgt = batch["targets"].to(device, non_blocking=True)
        msk = batch["loss_mask"].to(device, non_blocking=True)
        if "legal_indices" in batch:
            legal_mask = mask_builder.scatter(
                batch["legal_indices"], ids.shape[0]
            )
        else:
            legal_mask = mask_builder(batch)

        # Keep only scored positions: (n_pos, vocab) / (n_pos,).
        valid_legal = legal_mask[msk]
        valid_targets = tgt[msk]
        n_pos = valid_targets.shape[0]
        if n_pos == 0:
            continue

        all_log_probs = []
        for stats, (_name, model) in zip(per_model, models):
            with torch.amp.autocast(
                autocast_device, dtype=amp_dtype, enabled=amp_dtype is not None
            ):
                hidden = model.forward_hidden(ids)
                logits = model.project_head(hidden[msk]).float()
            logits = logits.masked_fill(~valid_legal, float("-inf"))
            log_probs = F.log_softmax(logits, dim=-1)
            all_log_probs.append(log_probs)
            _accumulate_metrics(stats, log_probs, valid_targets)

        # log(mean_i p_i) computed stably; rows stay -inf at illegal moves.
        ens_log_probs = (
            torch.logsumexp(torch.stack(all_log_probs, dim=0), dim=0)
            - log_n_models
        )
        _accumulate_metrics(ens, ens_log_probs, valid_targets)
        total += n_pos

    # NaN makes "nothing evaluated" explicit instead of dividing by zero.
    denom = total if total > 0 else float("nan")
    for stats in (*per_model, ens):
        for key in stats:
            stats[key] /= denom
    return per_model, ens, total
def main():
    """CLI entry point: evaluate several checkpoints and their ensemble.

    Each ``--model`` is ``name:strategy:checkpoint_path``. The split uses
    ``maxsplit=2``, so the checkpoint path itself may contain ``:``. Every
    model is loaded on CUDA and evaluated on validation games filtered by Elo
    and minimum ply. Individual and softmax-average ensemble loss, top-1 and
    top-5 are then printed.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--model", action="append", required=True,
                   help="name:strategy:checkpoint_path")
    p.add_argument("--pgn", default="thomas-schweich/pawn-lichess-full")
    p.add_argument("--elo-min", type=int, default=1800)
    p.add_argument("--elo-max", type=int, default=1900)
    p.add_argument("--val-games", type=int, default=50000)
    p.add_argument("--min-ply", type=int, default=10)
    p.add_argument("--batch-size", type=int, default=256)
    p.add_argument("--cache-dir", default="/dev/shm/pawn_cache")
    args = p.parse_args()

    device = "cuda"
    gpu_cfg = configure_gpu(device=device, sdpa_math=True, no_compile=True)
    apply_gpu_config(gpu_cfg, model_module, lambda x: x)
    amp_dtype = torch.bfloat16

    # Parse model specs; reject malformed ones with a clean usage error
    # instead of an opaque unpacking ValueError.
    specs = []
    for spec in args.model:
        parts = spec.split(":", 2)
        if len(parts) != 3 or not all(parts):
            p.error(
                f"--model expects name:strategy:checkpoint_path, got {spec!r}"
            )
        specs.append(tuple(parts))

    # Load models
    models = []
    for name, strategy, ckpt_path in specs:
        print(f"Loading {name} ({strategy}) from {ckpt_path}...")
        models.append((name, load_model(strategy, ckpt_path, device)))

    # Build val loader
    print("Loading val shards...")
    val_data = load_val_shards(
        args.pgn, elo_min=args.elo_min, elo_max=args.elo_max,
        min_ply=args.min_ply, max_games=args.val_games,
        cache_dir=args.cache_dir,
    )
    val_ds = LichessDataset(val_data, start=0, end=val_data["n_games"])
    collate = LegalMaskCollate(seq_len=256, vocab_size=4278)
    val_loader = DataLoader(
        val_ds, batch_size=args.batch_size, shuffle=False,
        num_workers=0, pin_memory=True, collate_fn=collate,
    )
    mask_builder = LegalMaskBuilder(
        batch_size=args.batch_size, max_ply=255,
        vocab_size=4278, device=device,
    )

    print(f"Evaluating ({len(val_ds):,} games)...")
    per_model, ensemble, total = eval_models(
        models, val_loader, mask_builder, device, amp_dtype
    )

    print("\n=== Individual ===")
    # specs entries are (name, strategy, path) triples; print the full name.
    for (name, _strategy, _path), m in zip(specs, per_model):
        print(
            f" {name:20s} "
            f"loss={m['loss']:.4f} "
            f"top1={m['top1']*100:.2f}% "
            f"top5={m['top5']*100:.2f}%"
        )
    print("\n=== Ensemble ===")
    print(
        f" softmax_avg "
        f"loss={ensemble['loss']:.4f} "
        f"top1={ensemble['top1']*100:.2f}% "
        f"top5={ensemble['top5']*100:.2f}%"
    )
    print(f"\n({total:,} positions evaluated)")
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()

Xet Storage Details

Size:
8.91 kB
·
Xet hash:
f94679f77e3bf13493c726f1f4b2a801b005f60cec54c1a3c68c2ca8f3e52370

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.