Spaces:
Sleeping
Sleeping
| """Evaluate trained reward and policy models on held-out test sets. | |
| Test sets are built by src/build_datasets.py (run once before benchmarking): | |
| stockfish_test_*.bin — 50K Stockfish-labeled positions (reward model) | |
| policy_test_*.bin — 50K game sequences (policy model) | |
| puzzle_test_*.bin — 100K puzzle sequences (puzzle solve rate) | |
| Metrics: | |
| Reward: MSE, MAE, Pearson r | |
| Policy: loss, perplexity, top-1 move accuracy, top-5 move accuracy | |
| Puzzle: first-move solve rate (top-1), all-moves solve rate | |
| Usage: | |
| arch -arm64 poetry run python src/benchmark.py | |
| arch -arm64 poetry run python src/benchmark.py \\ | |
| --reward-model reward_model.pt \\ | |
| --policy-model policy_model.pt \\ | |
| --data-dir data/ \\ | |
| --batch-size 512 | |
| """ | |
| import argparse | |
| import time | |
| from pathlib import Path | |
| import torch | |
| from torch.utils.data import DataLoader | |
| from model import ChessRewardModel, ChessPolicyModel, PAD_TOKEN | |
| from train import ( | |
| ChessPositionDataset, | |
| ChessPolicyDataset, | |
| collate_fn_memmap, | |
| collate_fn_policy, | |
| eval_reward, | |
| eval_policy, | |
| eval_puzzle_solve_rate, | |
| ) | |
| def _fmt(v: float, pct: bool = False) -> str: | |
| return f"{v * 100:.2f}%" if pct else f"{v:.4f}" | |
| def run_benchmark( | |
| data_dir: Path, | |
| reward_model_path: str | None, | |
| policy_model_path: str | None, | |
| batch_size: int = 512, | |
| num_workers: int = 4, | |
| device: str | None = None, | |
| ) -> dict: | |
| """Run all available benchmarks and return a results dict. | |
| Returns a dict with keys 'reward', 'policy', 'puzzle' (each a sub-dict of | |
| metrics), or omits a key if the test set / model is unavailable. | |
| """ | |
| if device is None: | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| print(f"Benchmark device: {device}") | |
| tokenizer_path = data_dir / "tokenizer.pt" | |
| if not tokenizer_path.exists(): | |
| raise FileNotFoundError(f"Tokenizer not found at {tokenizer_path}") | |
| tokenizer = torch.load(tokenizer_path, weights_only=False) | |
| vocab_size = tokenizer.language_size | |
| pad_id = tokenizer.symbol_to_token[PAD_TOKEN] | |
| results = {} | |
| # ── Reward model ────────────────────────────────────────────────────────── | |
| reward_test_meta = data_dir / "stockfish_test_meta.pt" | |
| if reward_model_path and Path(reward_model_path).exists() and reward_test_meta.exists(): | |
| print(f"\nLoading reward model from {reward_model_path}...") | |
| reward_model = ChessRewardModel(vocab_size=vocab_size).to(device) | |
| reward_model.load_state_dict(torch.load(reward_model_path, map_location=device, weights_only=True)) | |
| print("Loading reward test set...") | |
| reward_test_ds = ChessPositionDataset.from_memmap(data_dir, "stockfish_test", tokenizer) | |
| reward_test_loader = DataLoader( | |
| reward_test_ds, batch_size=batch_size, shuffle=False, | |
| collate_fn=collate_fn_memmap, num_workers=num_workers, pin_memory=True, | |
| ) | |
| print(f" {len(reward_test_ds):,} test positions") | |
| t0 = time.time() | |
| m = eval_reward(reward_model, reward_test_loader, device) | |
| print( | |
| f" Reward | MSE={_fmt(m['mse'])} MAE={_fmt(m['mae'])}" | |
| f" Pearson r={_fmt(m['pearson_r'])} ({time.time()-t0:.1f}s)" | |
| ) | |
| results["reward"] = m | |
| elif not reward_test_meta.exists(): | |
| print("\nReward test set not found (run build_datasets.py to create it). Skipping.") | |
| elif not reward_model_path or not Path(reward_model_path).exists(): | |
| print(f"\nReward model not found at {reward_model_path}. Skipping.") | |
| # ── Policy model ────────────────────────────────────────────────────────── | |
| policy_test_meta = data_dir / "policy_test_meta.pt" | |
| puzzle_test_meta = data_dir / "puzzle_test_meta.pt" | |
| policy_model = None | |
| if policy_model_path and Path(policy_model_path).exists(): | |
| print(f"\nLoading policy model from {policy_model_path}...") | |
| policy_model = ChessPolicyModel(vocab_size=vocab_size).to(device) | |
| policy_model.load_state_dict(torch.load(policy_model_path, map_location=device, weights_only=True)) | |
| if policy_model is not None and policy_test_meta.exists(): | |
| print("Loading policy test set...") | |
| policy_test_ds = ChessPolicyDataset.from_memmap(data_dir, tokenizer, name="policy_test") | |
| policy_test_loader = DataLoader( | |
| policy_test_ds, batch_size=batch_size, shuffle=False, | |
| collate_fn=collate_fn_policy, num_workers=num_workers, pin_memory=True, | |
| ) | |
| print(f" {len(policy_test_ds):,} test sequences") | |
| t0 = time.time() | |
| m = eval_policy(policy_model, policy_test_loader, device, pad_id) | |
| print( | |
| f" Policy | loss={_fmt(m['loss'])} ppl={m['perplexity']:.2f}" | |
| f" top1={_fmt(m['top1_acc'], pct=True)} top5={_fmt(m['top5_acc'], pct=True)}" | |
| f" ({time.time()-t0:.1f}s)" | |
| ) | |
| results["policy"] = m | |
| elif policy_model is None: | |
| print(f"\nPolicy model not found at {policy_model_path}. Skipping policy + puzzle eval.") | |
| elif not policy_test_meta.exists(): | |
| print("\nPolicy test set not found (run build_datasets.py to create it). Skipping.") | |
| if policy_model is not None and puzzle_test_meta.exists(): | |
| print("Loading puzzle test set...") | |
| puzzle_test_ds = ChessPolicyDataset.from_memmap(data_dir, tokenizer, name="puzzle_test") | |
| puzzle_test_loader = DataLoader( | |
| puzzle_test_ds, batch_size=batch_size, shuffle=False, | |
| collate_fn=collate_fn_policy, num_workers=num_workers, pin_memory=True, | |
| ) | |
| print(f" {len(puzzle_test_ds):,} test puzzles") | |
| t0 = time.time() | |
| m = eval_puzzle_solve_rate(policy_model, puzzle_test_loader, device, pad_id) | |
| print( | |
| f" Puzzle | first_move={_fmt(m['first_move_solve_rate'], pct=True)}" | |
| f" all_moves={_fmt(m['all_moves_solve_rate'], pct=True)}" | |
| f" ({time.time()-t0:.1f}s)" | |
| ) | |
| results["puzzle"] = m | |
| elif policy_model is not None and not puzzle_test_meta.exists(): | |
| print("\nPuzzle test set not found (run build_datasets.py to create it). Skipping.") | |
| return results | |
| def _print_summary(results: dict) -> None: | |
| print("\n" + "=" * 60) | |
| print("BENCHMARK SUMMARY") | |
| print("=" * 60) | |
| if "reward" in results: | |
| m = results["reward"] | |
| print(f" Reward model MSE={m['mse']:.4f} MAE={m['mae']:.4f} Pearson r={m['pearson_r']:.4f}") | |
| if "policy" in results: | |
| m = results["policy"] | |
| print( | |
| f" Policy model loss={m['loss']:.4f} perplexity={m['perplexity']:.2f}" | |
| f" top-1={m['top1_acc']*100:.2f}% top-5={m['top5_acc']*100:.2f}%" | |
| ) | |
| if "puzzle" in results: | |
| m = results["puzzle"] | |
| print( | |
| f" Puzzle eval first-move solve={m['first_move_solve_rate']*100:.2f}%" | |
| f" all-moves solve={m['all_moves_solve_rate']*100:.2f}%" | |
| ) | |
| if not results: | |
| print(" No results — check that models and test sets exist.") | |
| print("=" * 60) | |
| def _build_argparser() -> argparse.ArgumentParser: | |
| p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) | |
| p.add_argument("--reward-model", default="reward_model.pt", | |
| help="Path to reward_model.pt (default: reward_model.pt)") | |
| p.add_argument("--policy-model", default="policy_model.pt", | |
| help="Path to policy_model.pt (default: policy_model.pt)") | |
| p.add_argument("--data-dir", type=Path, default=Path("data"), | |
| help="Directory containing test set .bin files and tokenizer.pt (default: data/)") | |
| p.add_argument("--batch-size", type=int, default=512, | |
| help="Batch size for evaluation (default: 512)") | |
| p.add_argument("--num-workers", type=int, default=4, | |
| help="DataLoader worker count (default: 4)") | |
| p.add_argument("--device", default=None, | |
| help="Device override, e.g. 'cpu' or 'cuda:0' (default: auto-detect)") | |
| return p | |
if __name__ == "__main__":
    # Script entry point: parse the CLI options, run every benchmark whose
    # model and test set are available, then print the recap.
    args = _build_argparser().parse_args()
    benchmark_kwargs = {
        "data_dir": args.data_dir,
        "reward_model_path": args.reward_model,
        "policy_model_path": args.policy_model,
        "batch_size": args.batch_size,
        "num_workers": args.num_workers,
        "device": args.device,
    }
    results = run_benchmark(**benchmark_kwargs)
    _print_summary(results)