# GptForChess / src/benchmark.py
# Hosting-page header (not part of the program): uploaded by robell05,
# commit 6d75857 ("serving model").
"""Evaluate trained reward and policy models on held-out test sets.
Test sets are built by src/build_datasets.py (run once before benchmarking):
stockfish_test_*.bin — 50K Stockfish-labeled positions (reward model)
policy_test_*.bin — 50K game sequences (policy model)
puzzle_test_*.bin — 100K puzzle sequences (puzzle solve rate)
Metrics:
Reward: MSE, MAE, Pearson r
Policy: loss, perplexity, top-1 move accuracy, top-5 move accuracy
Puzzle: first-move solve rate (top-1), all-moves solve rate
Usage:
arch -arm64 poetry run python src/benchmark.py
arch -arm64 poetry run python src/benchmark.py \\
--reward-model reward_model.pt \\
--policy-model policy_model.pt \\
--data-dir data/ \\
--batch-size 512
"""
import argparse
import time
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from model import ChessRewardModel, ChessPolicyModel, PAD_TOKEN
from train import (
ChessPositionDataset,
ChessPolicyDataset,
collate_fn_memmap,
collate_fn_policy,
eval_reward,
eval_policy,
eval_puzzle_solve_rate,
)
def _fmt(v: float, pct: bool = False) -> str:
return f"{v * 100:.2f}%" if pct else f"{v:.4f}"
def run_benchmark(
data_dir: Path,
reward_model_path: str | None,
policy_model_path: str | None,
batch_size: int = 512,
num_workers: int = 4,
device: str | None = None,
) -> dict:
"""Run all available benchmarks and return a results dict.
Returns a dict with keys 'reward', 'policy', 'puzzle' (each a sub-dict of
metrics), or omits a key if the test set / model is unavailable.
"""
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Benchmark device: {device}")
tokenizer_path = data_dir / "tokenizer.pt"
if not tokenizer_path.exists():
raise FileNotFoundError(f"Tokenizer not found at {tokenizer_path}")
tokenizer = torch.load(tokenizer_path, weights_only=False)
vocab_size = tokenizer.language_size
pad_id = tokenizer.symbol_to_token[PAD_TOKEN]
results = {}
# ── Reward model ──────────────────────────────────────────────────────────
reward_test_meta = data_dir / "stockfish_test_meta.pt"
if reward_model_path and Path(reward_model_path).exists() and reward_test_meta.exists():
print(f"\nLoading reward model from {reward_model_path}...")
reward_model = ChessRewardModel(vocab_size=vocab_size).to(device)
reward_model.load_state_dict(torch.load(reward_model_path, map_location=device, weights_only=True))
print("Loading reward test set...")
reward_test_ds = ChessPositionDataset.from_memmap(data_dir, "stockfish_test", tokenizer)
reward_test_loader = DataLoader(
reward_test_ds, batch_size=batch_size, shuffle=False,
collate_fn=collate_fn_memmap, num_workers=num_workers, pin_memory=True,
)
print(f" {len(reward_test_ds):,} test positions")
t0 = time.time()
m = eval_reward(reward_model, reward_test_loader, device)
print(
f" Reward | MSE={_fmt(m['mse'])} MAE={_fmt(m['mae'])}"
f" Pearson r={_fmt(m['pearson_r'])} ({time.time()-t0:.1f}s)"
)
results["reward"] = m
elif not reward_test_meta.exists():
print("\nReward test set not found (run build_datasets.py to create it). Skipping.")
elif not reward_model_path or not Path(reward_model_path).exists():
print(f"\nReward model not found at {reward_model_path}. Skipping.")
# ── Policy model ──────────────────────────────────────────────────────────
policy_test_meta = data_dir / "policy_test_meta.pt"
puzzle_test_meta = data_dir / "puzzle_test_meta.pt"
policy_model = None
if policy_model_path and Path(policy_model_path).exists():
print(f"\nLoading policy model from {policy_model_path}...")
policy_model = ChessPolicyModel(vocab_size=vocab_size).to(device)
policy_model.load_state_dict(torch.load(policy_model_path, map_location=device, weights_only=True))
if policy_model is not None and policy_test_meta.exists():
print("Loading policy test set...")
policy_test_ds = ChessPolicyDataset.from_memmap(data_dir, tokenizer, name="policy_test")
policy_test_loader = DataLoader(
policy_test_ds, batch_size=batch_size, shuffle=False,
collate_fn=collate_fn_policy, num_workers=num_workers, pin_memory=True,
)
print(f" {len(policy_test_ds):,} test sequences")
t0 = time.time()
m = eval_policy(policy_model, policy_test_loader, device, pad_id)
print(
f" Policy | loss={_fmt(m['loss'])} ppl={m['perplexity']:.2f}"
f" top1={_fmt(m['top1_acc'], pct=True)} top5={_fmt(m['top5_acc'], pct=True)}"
f" ({time.time()-t0:.1f}s)"
)
results["policy"] = m
elif policy_model is None:
print(f"\nPolicy model not found at {policy_model_path}. Skipping policy + puzzle eval.")
elif not policy_test_meta.exists():
print("\nPolicy test set not found (run build_datasets.py to create it). Skipping.")
if policy_model is not None and puzzle_test_meta.exists():
print("Loading puzzle test set...")
puzzle_test_ds = ChessPolicyDataset.from_memmap(data_dir, tokenizer, name="puzzle_test")
puzzle_test_loader = DataLoader(
puzzle_test_ds, batch_size=batch_size, shuffle=False,
collate_fn=collate_fn_policy, num_workers=num_workers, pin_memory=True,
)
print(f" {len(puzzle_test_ds):,} test puzzles")
t0 = time.time()
m = eval_puzzle_solve_rate(policy_model, puzzle_test_loader, device, pad_id)
print(
f" Puzzle | first_move={_fmt(m['first_move_solve_rate'], pct=True)}"
f" all_moves={_fmt(m['all_moves_solve_rate'], pct=True)}"
f" ({time.time()-t0:.1f}s)"
)
results["puzzle"] = m
elif policy_model is not None and not puzzle_test_meta.exists():
print("\nPuzzle test set not found (run build_datasets.py to create it). Skipping.")
return results
def _print_summary(results: dict) -> None:
print("\n" + "=" * 60)
print("BENCHMARK SUMMARY")
print("=" * 60)
if "reward" in results:
m = results["reward"]
print(f" Reward model MSE={m['mse']:.4f} MAE={m['mae']:.4f} Pearson r={m['pearson_r']:.4f}")
if "policy" in results:
m = results["policy"]
print(
f" Policy model loss={m['loss']:.4f} perplexity={m['perplexity']:.2f}"
f" top-1={m['top1_acc']*100:.2f}% top-5={m['top5_acc']*100:.2f}%"
)
if "puzzle" in results:
m = results["puzzle"]
print(
f" Puzzle eval first-move solve={m['first_move_solve_rate']*100:.2f}%"
f" all-moves solve={m['all_moves_solve_rate']*100:.2f}%"
)
if not results:
print(" No results — check that models and test sets exist.")
print("=" * 60)
def _build_argparser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
p.add_argument("--reward-model", default="reward_model.pt",
help="Path to reward_model.pt (default: reward_model.pt)")
p.add_argument("--policy-model", default="policy_model.pt",
help="Path to policy_model.pt (default: policy_model.pt)")
p.add_argument("--data-dir", type=Path, default=Path("data"),
help="Directory containing test set .bin files and tokenizer.pt (default: data/)")
p.add_argument("--batch-size", type=int, default=512,
help="Batch size for evaluation (default: 512)")
p.add_argument("--num-workers", type=int, default=4,
help="DataLoader worker count (default: 4)")
p.add_argument("--device", default=None,
help="Device override, e.g. 'cpu' or 'cuda:0' (default: auto-detect)")
return p
if __name__ == "__main__":
    # Script entry point: parse the CLI flags, run every benchmark whose model
    # and test set are available, then print the recap table.
    cli_args = _build_argparser().parse_args()
    _print_summary(
        run_benchmark(
            cli_args.data_dir,
            cli_args.reward_model,
            cli_args.policy_model,
            batch_size=cli_args.batch_size,
            num_workers=cli_args.num_workers,
            device=cli_args.device,
        )
    )