File size: 8,647 Bytes
6d75857
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
"""Evaluate trained reward and policy models on held-out test sets.

Test sets are built by src/build_datasets.py (run once before benchmarking):
  stockfish_test_*.bin  — 50K Stockfish-labeled positions  (reward model)
  policy_test_*.bin     — 50K game sequences               (policy model)
  puzzle_test_*.bin     — 100K puzzle sequences            (puzzle solve rate)

Metrics:
  Reward:  MSE, MAE, Pearson r
  Policy:  loss, perplexity, top-1 move accuracy, top-5 move accuracy
  Puzzle:  first-move solve rate (top-1), all-moves solve rate

Usage:
  arch -arm64 poetry run python src/benchmark.py
  arch -arm64 poetry run python src/benchmark.py \\
    --reward-model reward_model.pt \\
    --policy-model policy_model.pt \\
    --data-dir data/ \\
    --batch-size 512
"""

import argparse
import time
from pathlib import Path

import torch
from torch.utils.data import DataLoader

from model import ChessRewardModel, ChessPolicyModel, PAD_TOKEN
from train import (
    ChessPositionDataset,
    ChessPolicyDataset,
    collate_fn_memmap,
    collate_fn_policy,
    eval_reward,
    eval_policy,
    eval_puzzle_solve_rate,
)


def _fmt(v: float, pct: bool = False) -> str:
    """Format a metric for display.

    Plain values get four decimal places; when *pct* is true the value is
    treated as a fraction and rendered as a two-decimal percentage.
    """
    if not pct:
        return f"{v:.4f}"
    return f"{v * 100:.2f}%"


def _make_eval_loader(
    dataset, collate_fn, batch_size: int, num_workers: int, device: str
) -> DataLoader:
    """Wrap *dataset* in an unshuffled DataLoader suitable for evaluation."""
    # Page-locked host memory only speeds up host->CUDA copies. For any other
    # target it is wasted work (and PyTorch warns when no accelerator exists),
    # so enable it only when evaluating on a CUDA device.
    return DataLoader(
        dataset, batch_size=batch_size, shuffle=False,
        collate_fn=collate_fn, num_workers=num_workers,
        pin_memory=torch.device(device).type == "cuda",
    )


def _load_model(model_cls, label: str, path: str, vocab_size: int, device: str):
    """Instantiate *model_cls*, load its state dict from *path*, and set eval mode."""
    print(f"\nLoading {label} model from {path}...")
    model = model_cls(vocab_size=vocab_size).to(device)
    model.load_state_dict(torch.load(path, map_location=device, weights_only=True))
    # Inference mode (affects dropout / norm layers, if the model has any);
    # harmless if the eval_* helpers also switch modes themselves.
    model.eval()
    return model


def _bench_reward(
    data_dir: Path, reward_model_path: str, tokenizer, vocab_size: int,
    batch_size: int, num_workers: int, device: str,
) -> dict:
    """Evaluate the reward model on the ``stockfish_test`` split; return its metrics.

    The model lives only inside this function, so its device memory can be
    reclaimed before the policy model is loaded.
    """
    model = _load_model(ChessRewardModel, "reward", reward_model_path, vocab_size, device)

    print("Loading reward test set...")
    dataset = ChessPositionDataset.from_memmap(data_dir, "stockfish_test", tokenizer)
    loader = _make_eval_loader(dataset, collate_fn_memmap, batch_size, num_workers, device)
    print(f"  {len(dataset):,} test positions")

    t0 = time.perf_counter()
    m = eval_reward(model, loader, device)
    print(
        f"  Reward  |  MSE={_fmt(m['mse'])}  MAE={_fmt(m['mae'])}"
        f"  Pearson r={_fmt(m['pearson_r'])}  ({time.perf_counter()-t0:.1f}s)"
    )
    return m


def _bench_policy(
    policy_model, data_dir: Path, tokenizer, pad_id: int,
    batch_size: int, num_workers: int, device: str,
) -> dict:
    """Evaluate *policy_model* on the ``policy_test`` split; return its metrics."""
    print("Loading policy test set...")
    dataset = ChessPolicyDataset.from_memmap(data_dir, tokenizer, name="policy_test")
    loader = _make_eval_loader(dataset, collate_fn_policy, batch_size, num_workers, device)
    print(f"  {len(dataset):,} test sequences")

    t0 = time.perf_counter()
    m = eval_policy(policy_model, loader, device, pad_id)
    print(
        f"  Policy  |  loss={_fmt(m['loss'])}  ppl={m['perplexity']:.2f}"
        f"  top1={_fmt(m['top1_acc'], pct=True)}  top5={_fmt(m['top5_acc'], pct=True)}"
        f"  ({time.perf_counter()-t0:.1f}s)"
    )
    return m


def _bench_puzzle(
    policy_model, data_dir: Path, tokenizer, pad_id: int,
    batch_size: int, num_workers: int, device: str,
) -> dict:
    """Measure puzzle solve rates of *policy_model* on the ``puzzle_test`` split."""
    print("Loading puzzle test set...")
    dataset = ChessPolicyDataset.from_memmap(data_dir, tokenizer, name="puzzle_test")
    loader = _make_eval_loader(dataset, collate_fn_policy, batch_size, num_workers, device)
    print(f"  {len(dataset):,} test puzzles")

    t0 = time.perf_counter()
    m = eval_puzzle_solve_rate(policy_model, loader, device, pad_id)
    print(
        f"  Puzzle  |  first_move={_fmt(m['first_move_solve_rate'], pct=True)}"
        f"  all_moves={_fmt(m['all_moves_solve_rate'], pct=True)}"
        f"  ({time.perf_counter()-t0:.1f}s)"
    )
    return m


def run_benchmark(
    data_dir: Path,
    reward_model_path: str | None,
    policy_model_path: str | None,
    batch_size: int = 512,
    num_workers: int = 4,
    device: str | None = None,
) -> dict:
    """Run all available benchmarks and return a results dict.

    Each stage is skipped (with a printed message) when its model checkpoint
    or test split is missing, so a partial setup still yields partial results.

    Args:
        data_dir: Directory containing ``tokenizer.pt`` and the ``*_test``
            memmap splits. A plain ``str`` path is also accepted.
        reward_model_path: Reward-model state-dict path, or ``None`` to skip.
        policy_model_path: Policy-model state-dict path, or ``None`` to skip
            both the policy and the puzzle evaluation.
        batch_size: Evaluation batch size.
        num_workers: DataLoader worker process count.
        device: Torch device string; when ``None``, CUDA if available else CPU.

    Returns:
        A dict with keys 'reward', 'policy', 'puzzle' (each the metrics dict
        returned by the matching ``eval_*`` function); a key is omitted when
        its test set or model is unavailable.

    Raises:
        FileNotFoundError: If ``tokenizer.pt`` is missing from *data_dir*.
    """
    data_dir = Path(data_dir)
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Benchmark device: {device}")

    tokenizer_path = data_dir / "tokenizer.pt"
    if not tokenizer_path.exists():
        raise FileNotFoundError(f"Tokenizer not found at {tokenizer_path}")
    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only point --data-dir at tokenizer files you trust.
    tokenizer = torch.load(tokenizer_path, weights_only=False)
    vocab_size = tokenizer.language_size
    pad_id = tokenizer.symbol_to_token[PAD_TOKEN]

    results = {}

    # ── Reward model ──────────────────────────────────────────────────────────
    reward_test_meta = data_dir / "stockfish_test_meta.pt"
    reward_model_found = bool(reward_model_path) and Path(reward_model_path).exists()
    if not reward_test_meta.exists():
        print("\nReward test set not found (run build_datasets.py to create it). Skipping.")
    elif not reward_model_found:
        print(f"\nReward model not found at {reward_model_path}. Skipping.")
    else:
        results["reward"] = _bench_reward(
            data_dir, reward_model_path, tokenizer, vocab_size,
            batch_size, num_workers, device,
        )

    # ── Policy model (shared by the policy and puzzle evaluations) ────────────
    if not (policy_model_path and Path(policy_model_path).exists()):
        print(f"\nPolicy model not found at {policy_model_path}. Skipping policy + puzzle eval.")
        return results
    policy_model = _load_model(ChessPolicyModel, "policy", policy_model_path, vocab_size, device)

    if (data_dir / "policy_test_meta.pt").exists():
        results["policy"] = _bench_policy(
            policy_model, data_dir, tokenizer, pad_id, batch_size, num_workers, device,
        )
    else:
        print("\nPolicy test set not found (run build_datasets.py to create it). Skipping.")

    if (data_dir / "puzzle_test_meta.pt").exists():
        results["puzzle"] = _bench_puzzle(
            policy_model, data_dir, tokenizer, pad_id, batch_size, num_workers, device,
        )
    else:
        print("\nPuzzle test set not found (run build_datasets.py to create it). Skipping.")

    return results


def _print_summary(results: dict) -> None:
    """Print a framed recap of whichever metric groups *results* contains.

    Recognised groups are 'reward', 'policy' and 'puzzle'; absent groups are
    skipped, and an empty *results* prints a hint instead of metrics.
    """
    rule = "=" * 60
    print(f"\n{rule}")
    print("BENCHMARK SUMMARY")
    print(rule)

    if "reward" in results:
        rw = results["reward"]
        print(
            f"  Reward model  MSE={rw['mse']:.4f}  MAE={rw['mae']:.4f}"
            f"  Pearson r={rw['pearson_r']:.4f}"
        )

    if "policy" in results:
        pol = results["policy"]
        print(
            f"  Policy model  loss={pol['loss']:.4f}  perplexity={pol['perplexity']:.2f}"
            f"  top-1={pol['top1_acc'] * 100:.2f}%  top-5={pol['top5_acc'] * 100:.2f}%"
        )

    if "puzzle" in results:
        pz = results["puzzle"]
        print(
            f"  Puzzle eval   first-move solve={pz['first_move_solve_rate'] * 100:.2f}%"
            f"  all-moves solve={pz['all_moves_solve_rate'] * 100:.2f}%"
        )

    if not results:
        print("  No results — check that models and test sets exist.")

    print(rule)


def _build_argparser() -> argparse.ArgumentParser:
    """Create the command-line parser for the benchmark script.

    The module docstring is reused as the ``--help`` description; the raw
    formatter keeps its hand-wrapped usage examples intact.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    add = parser.add_argument

    # Checkpoint locations.
    add("--reward-model", default="reward_model.pt",
        help="Path to reward_model.pt (default: reward_model.pt)")
    add("--policy-model", default="policy_model.pt",
        help="Path to policy_model.pt (default: policy_model.pt)")

    # Data and loader settings.
    add("--data-dir", type=Path, default=Path("data"),
        help="Directory containing test set .bin files and tokenizer.pt (default: data/)")
    add("--batch-size", type=int, default=512,
        help="Batch size for evaluation (default: 512)")
    add("--num-workers", type=int, default=4,
        help="DataLoader worker count (default: 4)")

    # Hardware selection.
    add("--device", default=None,
        help="Device override, e.g. 'cpu' or 'cuda:0' (default: auto-detect)")

    return parser


if __name__ == "__main__":
    # CLI entry point: parse flags, run every benchmark whose inputs exist,
    # then print the recap table.
    cli = _build_argparser().parse_args()
    _print_summary(
        run_benchmark(
            data_dir=cli.data_dir,
            reward_model_path=cli.reward_model,
            policy_model_path=cli.policy_model,
            batch_size=cli.batch_size,
            num_workers=cli.num_workers,
            device=cli.device,
        )
    )