robell05 commited on
Commit
6d75857
·
1 Parent(s): 4c1be1b

serving model

Browse files
.gitignore ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Virtual environments
2
+ .venv/
3
+ venv/
4
+ env/
5
+
6
+ # Python
7
+ __pycache__/
8
+ *.py[cod]
9
+ *.egg-info/
10
+ .pytest_cache/
11
+
12
+ # IDE
13
+ .vscode/
14
+ .idea/
15
+
16
+ # OS
17
+ .DS_Store
Dockerfile ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.12.2-slim
2
+
3
+ ENV PYTHONDONTWRITEBYTECODE=1 \
4
+ PYTHONUNBUFFERED=1 \
5
+ PIP_NO_CACHE_DIR=1
6
+
7
+ # build-essential covers the few deps with C extensions (numpy/pandas wheels
8
+ # are usually prebuilt for 3.12, but this is cheap insurance).
9
+ RUN apt-get update && apt-get install -y --no-install-recommends \
10
+ build-essential \
11
+ && rm -rf /var/lib/apt/lists/*
12
+
13
+ WORKDIR /app
14
+
15
+ # Deps first so editing app.py doesn't reinstall torch on every rebuild.
16
+ COPY requirements.txt .
17
+ RUN pip install --extra-index-url https://download.pytorch.org/whl/cpu -r requirements.txt
18
+
19
+ # Source next.
20
+ COPY src/ ./src/
21
+ COPY app.py ./
22
+
23
+ # Weights last — biggest layer, changes least often, so cache stays warm
24
+ # when you iterate on app.py.
25
+ COPY model/ ./model/
26
+
27
+ # torch.load on tokenizer.pt unpickles a `Tokenizer` instance that was
28
+ # saved from the module path `tokenizer` (src/tokenizer.py). Making /app/src
29
+ # importable lets the unpickler find that class.
30
+ ENV PYTHONPATH=/app/src
31
+
32
+ # HF Spaces routes external traffic to $PORT; 7860 is the convention.
33
+ # Single worker — the model lives in process memory and multiple workers
34
+ # would multiply the ~1.5 GB RSS.
35
+ EXPOSE 7860
36
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
app.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ import chess
3
+ from contextlib import asynccontextmanager
4
+ import sys
5
+ from src.model import ChessPolicyModel, PolicyModelInference
6
+ from src.tokenizer import tokenizer
7
+ import torch
8
+ from pydantic import BaseModel
9
+
10
+ ml = {}
11
+ @asynccontextmanager
12
+ async def lifespan(app: FastAPI):
13
+ tokenizer = torch.load("./model/tokenizer.pt", weights_only=False, map_location=torch.device('cpu'))
14
+
15
+ model = ChessPolicyModel(vocab_size=tokenizer.language_size)
16
+ model.load_state_dict(
17
+ torch.load("./model/policy_model.pt", weights_only=False, map_location=torch.device('cpu'))
18
+ )
19
+ ml["inference"] = PolicyModelInference(model, tokenizer, device="cpu")
20
+ yield
21
+ ml.clear()
22
+
23
+ app = FastAPI(lifespan=lifespan)
24
+
25
+
26
+
27
+ class InferenceRequest(BaseModel):
28
+ moves: list[str]
29
+
30
+ @app.post("/inference")
31
+ def model_inference(req: InferenceRequest):
32
+ board = chess.Board()
33
+ for move in req.moves:
34
+ try:
35
+ board.push_uci(move)
36
+ except ValueError as e:
37
+ raise HTTPException(status_code=400, detail=f"Incorrect move {move}: {e}")
38
+ try:
39
+ return {"move" : ml["inference"](board)}
40
+ except ValueError as e:
41
+ raise HTTPException(status_code=500, detail=f"Model Failed to evaluate: {e}")
42
+
43
+
model/policy_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b7c3556e2842ff3f6b20a11b9cc0759a89ca39b2d7cde30b54d8d3ba1bee573a
3
+ size 322812083
model/tokenizer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:45887b381da27b8cd119274704fb0b4766125a9218f87643be8260017869471b
3
+ size 57977
requirements.txt ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ annotated-doc==0.0.4
2
+ annotated-types==0.7.0
3
+ anyio==4.13.0
4
+ chess==1.11.2
5
+ click==8.4.0
6
+ fastapi==0.136.1
7
+ filelock==3.29.0
8
+ fsspec==2026.4.0
9
+ h11==0.16.0
10
+ idna==3.15
11
+ Jinja2==3.1.6
12
+ MarkupSafe==3.0.3
13
+ mpmath==1.3.0
14
+ networkx==3.6.1
15
+ numpy==2.4.5
16
+ pandas==3.0.3
17
+ pydantic==2.13.4
18
+ pydantic_core==2.46.4
19
+ python-chess==1.999
20
+ python-dateutil==2.9.0.post0
21
+ setuptools==81.0.0
22
+ six==1.17.0
23
+ starlette==1.0.0
24
+ sympy==1.14.0
25
+ torch==2.12.0
26
+ typing-inspection==0.4.2
27
+ typing_extensions==4.15.0
28
+ uvicorn==0.47.0
README.md → src/README.md RENAMED
File without changes
src/__init__.py ADDED
File without changes
src/benchmark.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Evaluate trained reward and policy models on held-out test sets.
2
+
3
+ Test sets are built by src/build_datasets.py (run once before benchmarking):
4
+ stockfish_test_*.bin — 50K Stockfish-labeled positions (reward model)
5
+ policy_test_*.bin — 50K game sequences (policy model)
6
+ puzzle_test_*.bin — 100K puzzle sequences (puzzle solve rate)
7
+
8
+ Metrics:
9
+ Reward: MSE, MAE, Pearson r
10
+ Policy: loss, perplexity, top-1 move accuracy, top-5 move accuracy
11
+ Puzzle: first-move solve rate (top-1), all-moves solve rate
12
+
13
+ Usage:
14
+ arch -arm64 poetry run python src/benchmark.py
15
+ arch -arm64 poetry run python src/benchmark.py \\
16
+ --reward-model reward_model.pt \\
17
+ --policy-model policy_model.pt \\
18
+ --data-dir data/ \\
19
+ --batch-size 512
20
+ """
21
+
22
+ import argparse
23
+ import time
24
+ from pathlib import Path
25
+
26
+ import torch
27
+ from torch.utils.data import DataLoader
28
+
29
+ from model import ChessRewardModel, ChessPolicyModel, PAD_TOKEN
30
+ from train import (
31
+ ChessPositionDataset,
32
+ ChessPolicyDataset,
33
+ collate_fn_memmap,
34
+ collate_fn_policy,
35
+ eval_reward,
36
+ eval_policy,
37
+ eval_puzzle_solve_rate,
38
+ )
39
+
40
+
41
+ def _fmt(v: float, pct: bool = False) -> str:
42
+ return f"{v * 100:.2f}%" if pct else f"{v:.4f}"
43
+
44
+
45
+ def run_benchmark(
46
+ data_dir: Path,
47
+ reward_model_path: str | None,
48
+ policy_model_path: str | None,
49
+ batch_size: int = 512,
50
+ num_workers: int = 4,
51
+ device: str | None = None,
52
+ ) -> dict:
53
+ """Run all available benchmarks and return a results dict.
54
+
55
+ Returns a dict with keys 'reward', 'policy', 'puzzle' (each a sub-dict of
56
+ metrics), or omits a key if the test set / model is unavailable.
57
+ """
58
+ if device is None:
59
+ device = "cuda" if torch.cuda.is_available() else "cpu"
60
+ print(f"Benchmark device: {device}")
61
+
62
+ tokenizer_path = data_dir / "tokenizer.pt"
63
+ if not tokenizer_path.exists():
64
+ raise FileNotFoundError(f"Tokenizer not found at {tokenizer_path}")
65
+ tokenizer = torch.load(tokenizer_path, weights_only=False)
66
+ vocab_size = tokenizer.language_size
67
+ pad_id = tokenizer.symbol_to_token[PAD_TOKEN]
68
+
69
+ results = {}
70
+
71
+ # ── Reward model ──────────────────────────────────────────────────────────
72
+ reward_test_meta = data_dir / "stockfish_test_meta.pt"
73
+ if reward_model_path and Path(reward_model_path).exists() and reward_test_meta.exists():
74
+ print(f"\nLoading reward model from {reward_model_path}...")
75
+ reward_model = ChessRewardModel(vocab_size=vocab_size).to(device)
76
+ reward_model.load_state_dict(torch.load(reward_model_path, map_location=device, weights_only=True))
77
+
78
+ print("Loading reward test set...")
79
+ reward_test_ds = ChessPositionDataset.from_memmap(data_dir, "stockfish_test", tokenizer)
80
+ reward_test_loader = DataLoader(
81
+ reward_test_ds, batch_size=batch_size, shuffle=False,
82
+ collate_fn=collate_fn_memmap, num_workers=num_workers, pin_memory=True,
83
+ )
84
+ print(f" {len(reward_test_ds):,} test positions")
85
+
86
+ t0 = time.time()
87
+ m = eval_reward(reward_model, reward_test_loader, device)
88
+ print(
89
+ f" Reward | MSE={_fmt(m['mse'])} MAE={_fmt(m['mae'])}"
90
+ f" Pearson r={_fmt(m['pearson_r'])} ({time.time()-t0:.1f}s)"
91
+ )
92
+ results["reward"] = m
93
+ elif not reward_test_meta.exists():
94
+ print("\nReward test set not found (run build_datasets.py to create it). Skipping.")
95
+ elif not reward_model_path or not Path(reward_model_path).exists():
96
+ print(f"\nReward model not found at {reward_model_path}. Skipping.")
97
+
98
+ # ── Policy model ──────────────────────────────────────────────────────────
99
+ policy_test_meta = data_dir / "policy_test_meta.pt"
100
+ puzzle_test_meta = data_dir / "puzzle_test_meta.pt"
101
+ policy_model = None
102
+
103
+ if policy_model_path and Path(policy_model_path).exists():
104
+ print(f"\nLoading policy model from {policy_model_path}...")
105
+ policy_model = ChessPolicyModel(vocab_size=vocab_size).to(device)
106
+ policy_model.load_state_dict(torch.load(policy_model_path, map_location=device, weights_only=True))
107
+
108
+ if policy_model is not None and policy_test_meta.exists():
109
+ print("Loading policy test set...")
110
+ policy_test_ds = ChessPolicyDataset.from_memmap(data_dir, tokenizer, name="policy_test")
111
+ policy_test_loader = DataLoader(
112
+ policy_test_ds, batch_size=batch_size, shuffle=False,
113
+ collate_fn=collate_fn_policy, num_workers=num_workers, pin_memory=True,
114
+ )
115
+ print(f" {len(policy_test_ds):,} test sequences")
116
+
117
+ t0 = time.time()
118
+ m = eval_policy(policy_model, policy_test_loader, device, pad_id)
119
+ print(
120
+ f" Policy | loss={_fmt(m['loss'])} ppl={m['perplexity']:.2f}"
121
+ f" top1={_fmt(m['top1_acc'], pct=True)} top5={_fmt(m['top5_acc'], pct=True)}"
122
+ f" ({time.time()-t0:.1f}s)"
123
+ )
124
+ results["policy"] = m
125
+ elif policy_model is None:
126
+ print(f"\nPolicy model not found at {policy_model_path}. Skipping policy + puzzle eval.")
127
+ elif not policy_test_meta.exists():
128
+ print("\nPolicy test set not found (run build_datasets.py to create it). Skipping.")
129
+
130
+ if policy_model is not None and puzzle_test_meta.exists():
131
+ print("Loading puzzle test set...")
132
+ puzzle_test_ds = ChessPolicyDataset.from_memmap(data_dir, tokenizer, name="puzzle_test")
133
+ puzzle_test_loader = DataLoader(
134
+ puzzle_test_ds, batch_size=batch_size, shuffle=False,
135
+ collate_fn=collate_fn_policy, num_workers=num_workers, pin_memory=True,
136
+ )
137
+ print(f" {len(puzzle_test_ds):,} test puzzles")
138
+
139
+ t0 = time.time()
140
+ m = eval_puzzle_solve_rate(policy_model, puzzle_test_loader, device, pad_id)
141
+ print(
142
+ f" Puzzle | first_move={_fmt(m['first_move_solve_rate'], pct=True)}"
143
+ f" all_moves={_fmt(m['all_moves_solve_rate'], pct=True)}"
144
+ f" ({time.time()-t0:.1f}s)"
145
+ )
146
+ results["puzzle"] = m
147
+ elif policy_model is not None and not puzzle_test_meta.exists():
148
+ print("\nPuzzle test set not found (run build_datasets.py to create it). Skipping.")
149
+
150
+ return results
151
+
152
+
153
+ def _print_summary(results: dict) -> None:
154
+ print("\n" + "=" * 60)
155
+ print("BENCHMARK SUMMARY")
156
+ print("=" * 60)
157
+ if "reward" in results:
158
+ m = results["reward"]
159
+ print(f" Reward model MSE={m['mse']:.4f} MAE={m['mae']:.4f} Pearson r={m['pearson_r']:.4f}")
160
+ if "policy" in results:
161
+ m = results["policy"]
162
+ print(
163
+ f" Policy model loss={m['loss']:.4f} perplexity={m['perplexity']:.2f}"
164
+ f" top-1={m['top1_acc']*100:.2f}% top-5={m['top5_acc']*100:.2f}%"
165
+ )
166
+ if "puzzle" in results:
167
+ m = results["puzzle"]
168
+ print(
169
+ f" Puzzle eval first-move solve={m['first_move_solve_rate']*100:.2f}%"
170
+ f" all-moves solve={m['all_moves_solve_rate']*100:.2f}%"
171
+ )
172
+ if not results:
173
+ print(" No results — check that models and test sets exist.")
174
+ print("=" * 60)
175
+
176
+
177
+ def _build_argparser() -> argparse.ArgumentParser:
178
+ p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
179
+ p.add_argument("--reward-model", default="reward_model.pt",
180
+ help="Path to reward_model.pt (default: reward_model.pt)")
181
+ p.add_argument("--policy-model", default="policy_model.pt",
182
+ help="Path to policy_model.pt (default: policy_model.pt)")
183
+ p.add_argument("--data-dir", type=Path, default=Path("data"),
184
+ help="Directory containing test set .bin files and tokenizer.pt (default: data/)")
185
+ p.add_argument("--batch-size", type=int, default=512,
186
+ help="Batch size for evaluation (default: 512)")
187
+ p.add_argument("--num-workers", type=int, default=4,
188
+ help="DataLoader worker count (default: 4)")
189
+ p.add_argument("--device", default=None,
190
+ help="Device override, e.g. 'cpu' or 'cuda:0' (default: auto-detect)")
191
+ return p
192
+
193
+
194
+ if __name__ == "__main__":
195
+ args = _build_argparser().parse_args()
196
+ results = run_benchmark(
197
+ data_dir=args.data_dir,
198
+ reward_model_path=args.reward_model,
199
+ policy_model_path=args.policy_model,
200
+ batch_size=args.batch_size,
201
+ num_workers=args.num_workers,
202
+ device=args.device,
203
+ )
204
+ _print_summary(results)
src/build_datasets.py ADDED
@@ -0,0 +1,697 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Build the two training datasets for hybrid two-phase training.
2
+
3
+ Runs three resumable stages:
4
+ 1. Stream the Lichess HF dataset, filter by Elo + Termination, save two
5
+ disjoint raw-game subsets (movetext + Result only).
6
+ 2. Build the shared tokenizer from the outcome subset and generate
7
+ outcome-labeled samples ({+1, 0, -1} from game Result).
8
+ 3. Run parallel Stockfish on the disjoint subset to produce precisely
9
+ labeled samples (tanh(cp/400)).
10
+
11
+ Each stage skips if its output files already exist. Use --force to re-run.
12
+
13
+ Outputs (under --out-dir):
14
+ games_outcome.pt raw outcome-subset games
15
+ games_stockfish.pt raw stockfish-subset games
16
+ tokenizer.pt shared Tokenizer (built from outcome games)
17
+ outcome_samples.pt list[(token_ids, outcome_label)]
18
+ stockfish_samples.pt list[(token_ids, stockfish_label)]
19
+ """
20
+
21
+ import argparse
22
+ import random
23
+ from pathlib import Path
24
+
25
+ import chess
26
+ import numpy as np
27
+ import torch
28
+ from datasets import load_dataset
29
+ from tqdm import tqdm
30
+
31
+ from model import CLS_TOKEN
32
+ from train import (
33
+ build_tokenizer_from_games,
34
+ generate_samples_stockfish_parallel,
35
+ parse_movetext,
36
+ )
37
+
38
+ # Lichess Result → outcome label from white's perspective.
39
+ RESULT_TO_LABEL = {"1-0": 1.0, "0-1": -1.0, "1/2-1/2": 0.0}
40
+
41
+
42
+ def _save_as_memmap(
43
+ samples: list[tuple[list[int], float]], out_dir: Path, name: str, max_seq_len: int = 128
44
+ ) -> None:
45
+ """Save samples as memory-mapped arrays for fast DataLoader access.
46
+
47
+ Sequences longer than max_seq_len are truncated (keeps the most recent tokens,
48
+ since the CLS token is at position 0 we keep ids[:max_seq_len]).
49
+
50
+ Produces three files:
51
+ {name}_tokens.bin — (N, max_seq_len) int32, zero-padded
52
+ {name}_labels.bin — (N,) float32
53
+ {name}_lengths.bin — (N,) int32, actual sequence length per sample (capped at max_seq_len)
54
+ {name}_meta.pt — dict with 'n' and 'max_len'
55
+ """
56
+ n = len(samples)
57
+ max_len = min(max(len(ids) for ids, _ in samples), max_seq_len)
58
+ print(f" memmap {name}: {n:,} samples, max_seq_len={max_len}")
59
+
60
+ tokens = np.memmap(out_dir / f"{name}_tokens.bin", dtype=np.int32, mode="w+", shape=(n, max_len))
61
+ labels = np.memmap(out_dir / f"{name}_labels.bin", dtype=np.float32, mode="w+", shape=(n,))
62
+ lengths = np.memmap(out_dir / f"{name}_lengths.bin", dtype=np.int32, mode="w+", shape=(n,))
63
+
64
+ for i, (ids, label) in enumerate(tqdm(samples, desc=f" writing {name}", unit="sample")):
65
+ ids = ids[:max_len]
66
+ l = len(ids)
67
+ tokens[i, :l] = ids
68
+ labels[i] = label
69
+ lengths[i] = l
70
+
71
+ tokens.flush()
72
+ labels.flush()
73
+ lengths.flush()
74
+ torch.save({"n": n, "max_len": max_len}, out_dir / f"{name}_meta.pt")
75
+ size_gb = (tokens.nbytes + labels.nbytes + lengths.nbytes) / 1024 ** 3
76
+ print(f" memmap {name} saved ({size_gb:.2f} GB)")
77
+
78
+
79
+ def stage1_collect_games(args: argparse.Namespace) -> None:
80
+ policy_games_path = args.out_dir / "games_outcome.pt"
81
+ reward_games_path = args.out_dir / "games_stockfish.pt"
82
+ policy_only = getattr(args, "policy_only", False)
83
+
84
+ if policy_only:
85
+ # Reward subset is irrelevant when we're only training the policy model.
86
+ # Skip-condition checks only the policy artifact.
87
+ if policy_games_path.exists() and not args.force:
88
+ print(f"Stage 1: skipping — {policy_games_path.name} exists (--policy-only).")
89
+ return
90
+ else:
91
+ if policy_games_path.exists() and reward_games_path.exists() and not args.force:
92
+ print(f"Stage 1: skipping — {policy_games_path.name} and {reward_games_path.name} exist.")
93
+ return
94
+
95
+ if policy_only:
96
+ lower_elo = args.policy_min_elo
97
+ print(
98
+ f"Stage 1: streaming Lichess/standard-chess-games (Termination == 'Normal'), "
99
+ f"policy Elo >= {args.policy_min_elo} (target {args.policy_games:,}). "
100
+ f"Reward subset skipped (--policy-only)."
101
+ )
102
+ else:
103
+ lower_elo = min(args.reward_min_elo, args.policy_min_elo)
104
+ print(
105
+ f"Stage 1: streaming Lichess/standard-chess-games (Termination == 'Normal'), "
106
+ f"reward Elo >= {args.reward_min_elo} (target {args.reward_games:,}), "
107
+ f"policy Elo >= {args.policy_min_elo} (target {args.policy_games:,})..."
108
+ )
109
+
110
+ ds = load_dataset("Lichess/standard-chess-games", split="train", streaming=True)
111
+ # Pre-filter by the lower of the two thresholds to skip clearly ineligible games.
112
+ ds = ds.filter(
113
+ lambda r: (
114
+ r.get("WhiteElo") is not None
115
+ and r.get("BlackElo") is not None
116
+ and r["WhiteElo"] >= lower_elo
117
+ and r["BlackElo"] >= lower_elo
118
+ and r.get("Termination") == "Normal"
119
+ )
120
+ )
121
+
122
+ policy_games: list[dict] = []
123
+ reward_games: list[dict] = []
124
+ keep_keys = ("movetext", "Result")
125
+
126
+ for row in tqdm(ds, desc="Stage 1: streaming", unit="game"):
127
+ white_elo = row.get("WhiteElo", 0)
128
+ black_elo = row.get("BlackElo", 0)
129
+ minimal = {k: row.get(k) for k in keep_keys}
130
+
131
+ if (
132
+ not policy_only
133
+ and len(reward_games) < args.reward_games
134
+ and white_elo >= args.reward_min_elo
135
+ and black_elo >= args.reward_min_elo
136
+ ):
137
+ reward_games.append(minimal)
138
+
139
+ if len(policy_games) < args.policy_games and white_elo >= args.policy_min_elo and black_elo >= args.policy_min_elo:
140
+ policy_games.append(minimal)
141
+
142
+ if policy_only:
143
+ if len(policy_games) >= args.policy_games:
144
+ break
145
+ elif len(reward_games) >= args.reward_games and len(policy_games) >= args.policy_games:
146
+ break
147
+
148
+ if policy_only:
149
+ if len(policy_games) < args.policy_games:
150
+ print(
151
+ f" WARNING: dataset exhausted before target — "
152
+ f"got {len(policy_games):,} policy games."
153
+ )
154
+ elif len(reward_games) < args.reward_games or len(policy_games) < args.policy_games:
155
+ print(
156
+ f" WARNING: dataset exhausted before target — "
157
+ f"got {len(reward_games):,} reward + {len(policy_games):,} policy games."
158
+ )
159
+
160
+ if not policy_only:
161
+ print(f"Stage 1: saving {reward_games_path} ({len(reward_games):,} games)...")
162
+ torch.save(reward_games, reward_games_path)
163
+ print(f"Stage 1: saving {policy_games_path} ({len(policy_games):,} games)...")
164
+ torch.save(policy_games, policy_games_path)
165
+
166
+
167
+ def _generate_outcome_samples(games, tokenizer, max_positions_per_game, skip_ply):
168
+ """Build (token_ids, outcome_label) samples for the phase-1 dataset."""
169
+ cls_id = tokenizer.symbol_to_token[CLS_TOKEN]
170
+ samples: list[tuple[list[int], float]] = []
171
+ with tqdm(games, desc="Stage 2: outcome samples", unit="game") as pbar:
172
+ for idx, game in enumerate(pbar):
173
+ result = game.get("Result")
174
+ if result not in RESULT_TO_LABEL:
175
+ continue
176
+ label = RESULT_TO_LABEL[result]
177
+
178
+ movetext = game.get("movetext", "")
179
+ if not movetext:
180
+ continue
181
+ move_sans = parse_movetext(movetext)
182
+ if len(move_sans) < max(2, skip_ply + 1):
183
+ continue
184
+
185
+ eligible = list(range(skip_ply, len(move_sans)))
186
+ num_positions = min(max_positions_per_game, len(eligible))
187
+ rng = random.Random(idx)
188
+ sample_indices = set(rng.sample(eligible, num_positions))
189
+
190
+ board = chess.Board()
191
+ valid_moves: list[str] = []
192
+ for i, san in enumerate(move_sans):
193
+ try:
194
+ move = board.parse_san(san)
195
+ board.push(move)
196
+ valid_moves.append(move.uci())
197
+ except (chess.InvalidMoveError, chess.AmbiguousMoveError):
198
+ break
199
+ if i in sample_indices:
200
+ token_ids = [cls_id] + tokenizer.encode_moves(valid_moves)
201
+ samples.append((token_ids, label))
202
+
203
+ if (idx + 1) % 50_000 == 0:
204
+ pbar.set_postfix(samples=f"{len(samples):,}")
205
+
206
+ return samples
207
+
208
+
209
+ def stage2_outcome_samples(args: argparse.Namespace) -> None:
210
+ tokenizer_path = args.out_dir / "tokenizer.pt"
211
+ meta_path = args.out_dir / "outcome_meta.pt"
212
+ if tokenizer_path.exists() and meta_path.exists() and not args.force:
213
+ print(f"Stage 2: skipping — {tokenizer_path.name} and {meta_path.name} exist.")
214
+ return
215
+
216
+ raw_games_path = args.out_dir / "games_outcome.pt"
217
+ print(f"Stage 2: loading outcome games from {raw_games_path}...")
218
+ games = torch.load(raw_games_path, weights_only=False)
219
+
220
+ print("Stage 2: building tokenizer from all UCI moves...")
221
+ tokenizer = build_tokenizer_from_games()
222
+ print(f"Stage 2: tokenizer vocab size = {tokenizer.language_size}")
223
+ torch.save(tokenizer, tokenizer_path)
224
+
225
+ print("Stage 2: generating outcome samples (up to 20 per game)...")
226
+ samples = _generate_outcome_samples(
227
+ games,
228
+ tokenizer,
229
+ max_positions_per_game=20,
230
+ skip_ply=0,
231
+ )
232
+ print(f"Stage 2: saving {len(samples):,} outcome samples as memmap...")
233
+ _save_as_memmap(samples, args.out_dir, "outcome", max_seq_len=args.max_seq_len)
234
+
235
+
236
+ def _generate_policy_sequences(games, tokenizer, max_seq_len: int = 128) -> list[list[int]]:
237
+ """Tokenize full game sequences for policy training.
238
+
239
+ Each output sequence is [CLS, m1, m2, ..., mN], truncated to max_seq_len.
240
+ Games with fewer than 2 valid UCI moves are skipped.
241
+ """
242
+ cls_id = tokenizer.symbol_to_token[CLS_TOKEN]
243
+ sequences: list[list[int]] = []
244
+ with tqdm(games, desc="Stage 4: policy sequences", unit="game") as pbar:
245
+ for game in pbar:
246
+ movetext = game.get("movetext", "")
247
+ if not movetext:
248
+ continue
249
+ move_sans = parse_movetext(movetext)
250
+ if len(move_sans) < 2:
251
+ continue
252
+ board = chess.Board()
253
+ move_ucis: list[str] = []
254
+ for san in move_sans:
255
+ try:
256
+ move = board.parse_san(san)
257
+ board.push(move)
258
+ move_ucis.append(move.uci())
259
+ except (chess.InvalidMoveError, chess.AmbiguousMoveError):
260
+ break
261
+ if len(move_ucis) < 2:
262
+ continue
263
+ move_ucis = move_ucis[:max_seq_len - 1]
264
+ sequences.append([cls_id] + tokenizer.encode_moves(move_ucis))
265
+ return sequences
266
+
267
+
268
+ def _save_policy_memmap(
269
+ sequences: list[list[int]], out_dir: Path, name: str, max_seq_len: int = 128,
270
+ fens: list[str] | None = None, fen_len: int = 100,
271
+ ) -> None:
272
+ """Save policy sequences as memory-mapped arrays (no labels).
273
+
274
+ Produces:
275
+ {name}_tokens.bin — (N, max_len) int32, zero-padded
276
+ {name}_lengths.bin — (N,) int32, actual sequence length per sample
277
+ {name}_meta.pt — dict with 'n', 'max_len', and (if fens given) 'fen_len'
278
+
279
+ If `fens` is provided, also writes {name}_fens.bin — (N, fen_len) uint8
280
+ holding zero-padded ASCII FEN strings, one per sample. Used by the
281
+ CNN-conditioned policy training to reconstruct each sample's starting board.
282
+ """
283
+ n = len(sequences)
284
+ max_len = min(max(len(s) for s in sequences), max_seq_len)
285
+ print(f" memmap {name}: {n:,} sequences, max_seq_len={max_len}")
286
+
287
+ tokens = np.memmap(out_dir / f"{name}_tokens.bin", dtype=np.int32, mode="w+", shape=(n, max_len))
288
+ lengths = np.memmap(out_dir / f"{name}_lengths.bin", dtype=np.int32, mode="w+", shape=(n,))
289
+
290
+ for i, seq in enumerate(tqdm(sequences, desc=f" writing {name}", unit="seq")):
291
+ seq = seq[:max_len]
292
+ l = len(seq)
293
+ tokens[i, :l] = seq
294
+ lengths[i] = l
295
+
296
+ tokens.flush()
297
+ lengths.flush()
298
+
299
+ meta = {"n": n, "max_len": max_len}
300
+ extra_bytes = 0
301
+ if fens is not None:
302
+ assert len(fens) == n, f"fen count {len(fens)} mismatch with sequence count {n}"
303
+ fens_mm = np.memmap(out_dir / f"{name}_fens.bin", dtype=np.uint8, mode="w+", shape=(n, fen_len))
304
+ for i, fen in enumerate(fens):
305
+ b = fen.encode("ascii")[:fen_len]
306
+ fens_mm[i, :len(b)] = list(b)
307
+ fens_mm.flush()
308
+ meta["fen_len"] = fen_len
309
+ extra_bytes = fens_mm.nbytes
310
+
311
+ torch.save(meta, out_dir / f"{name}_meta.pt")
312
+ size_gb = (tokens.nbytes + lengths.nbytes + extra_bytes) / 1024 ** 3
313
+ print(f" memmap {name} saved ({size_gb:.3f} GB)")
314
+
315
+
316
+ def _process_puzzle(
317
+ row: dict,
318
+ tokenizer_symbol_map: dict,
319
+ cls_id: int,
320
+ ) -> tuple[list[int], str] | None:
321
+ """Parse one Lichess puzzle row into a (token_sequence, FEN) pair.
322
+
323
+ Sequence layout: [CLS, setup_move, solver_move1, opp_response, solver_move2, ...]
324
+
325
+ The setup move (Moves[0]) is included as context so the model conditions on it
326
+ when predicting the solution. During training the loss on the setup move position
327
+ is masked out — we model P[m_n | S, m_{<n}] where S is the setup move.
328
+
329
+ The FEN is the puzzle's starting board position. It is persisted alongside the
330
+ token sequence so the CNN-conditioned policy training can reconstruct the
331
+ starting board planes (CNN's input) at __getitem__ time.
332
+
333
+ Returns None if any move is illegal, unknown to the tokenizer, or the sequence
334
+ has fewer than 3 tokens (CLS + setup + at least one solver move).
335
+ """
336
+ fen = row.get("FEN", "")
337
+ moves_str = row.get("Moves", "")
338
+ if not fen or not moves_str:
339
+ return None
340
+ uci_moves = moves_str.strip().split()
341
+ if len(uci_moves) < 2: # need setup + at least one solver move
342
+ return None
343
+ try:
344
+ board = chess.Board(fen)
345
+ except ValueError:
346
+ return None
347
+
348
+ # Tokenize all moves: setup first (as context), then the full solution.
349
+ token_ids: list[int] = [cls_id]
350
+ for uci in uci_moves:
351
+ try:
352
+ move = chess.Move.from_uci(uci)
353
+ except ValueError:
354
+ return None
355
+ if move not in board.legal_moves:
356
+ return None
357
+ if uci not in tokenizer_symbol_map:
358
+ return None
359
+ token_ids.append(tokenizer_symbol_map[uci])
360
+ board.push(move)
361
+
362
+ if len(token_ids) < 3: # CLS + setup + at least one solver move
363
+ return None
364
+ return token_ids, fen
365
+
366
+
367
+ def stage3_stockfish_samples(args: argparse.Namespace) -> None:
368
+ meta_path = args.out_dir / "stockfish_meta.pt"
369
+ if meta_path.exists() and not args.force:
370
+ print(f"Stage 3: skipping — {meta_path.name} exists.")
371
+ return
372
+
373
+ games_path = args.out_dir / "games_stockfish.pt"
374
+ tokenizer_path = args.out_dir / "tokenizer.pt"
375
+ print(f"Stage 3: loading {games_path} and {tokenizer_path}...")
376
+ games = torch.load(games_path, weights_only=False)
377
+ tokenizer = torch.load(tokenizer_path, weights_only=False)
378
+
379
+ print(
380
+ f"Stage 3: running parallel Stockfish ({args.workers} workers, "
381
+ f"depth {args.stockfish_depth}) on {len(games):,} games..."
382
+ )
383
+ samples = generate_samples_stockfish_parallel(
384
+ games,
385
+ tokenizer,
386
+ num_workers=args.workers,
387
+ stockfish_depth=args.stockfish_depth,
388
+ sample_rate=args.sample_rate,
389
+ skew_exponent=args.position_skew,
390
+ )
391
+
392
+ print(f"Stage 3: saving {len(samples):,} stockfish samples as memmap...")
393
+ _save_as_memmap(samples, args.out_dir, "stockfish", max_seq_len=args.max_seq_len)
394
+
395
+
396
+ def stage4_policy_sequences(args: argparse.Namespace) -> None:
397
+ meta_path = args.out_dir / "policy_meta.pt"
398
+ if meta_path.exists() and not args.force:
399
+ print(f"Stage 4: skipping — {meta_path.name} exists.")
400
+ return
401
+
402
+ games_path = args.out_dir / "games_outcome.pt"
403
+ tokenizer_path = args.out_dir / "tokenizer.pt"
404
+ print(f"Stage 4: loading {games_path} and {tokenizer_path}...")
405
+ games = torch.load(games_path, weights_only=False)
406
+ tokenizer = torch.load(tokenizer_path, weights_only=False)
407
+
408
+ print(f"Stage 4: tokenizing {len(games):,} games into policy sequences...")
409
+ sequences = _generate_policy_sequences(games, tokenizer, max_seq_len=args.max_seq_len)
410
+ print(f"Stage 4: saving {len(sequences):,} policy sequences as memmap...")
411
+ _save_policy_memmap(sequences, args.out_dir, "policy", max_seq_len=args.max_seq_len)
412
+
413
+
414
+ def _write_test_subset_reward(out_dir: Path, src_name: str, dst_name: str, indices: np.ndarray) -> None:
415
+ """Write a subset of a reward memmap (tokens+labels+lengths) to new files."""
416
+ meta = torch.load(out_dir / f"{src_name}_meta.pt", weights_only=True)
417
+ n_src, max_len = meta["n"], meta["max_len"]
418
+ src_tokens = np.memmap(out_dir / f"{src_name}_tokens.bin", dtype=np.int32, mode="r", shape=(n_src, max_len))
419
+ src_labels = np.memmap(out_dir / f"{src_name}_labels.bin", dtype=np.float32, mode="r", shape=(n_src,))
420
+ src_lengths = np.memmap(out_dir / f"{src_name}_lengths.bin", dtype=np.int32, mode="r", shape=(n_src,))
421
+ n_test = len(indices)
422
+ dst_tokens = np.memmap(out_dir / f"{dst_name}_tokens.bin", dtype=np.int32, mode="w+", shape=(n_test, max_len))
423
+ dst_labels = np.memmap(out_dir / f"{dst_name}_labels.bin", dtype=np.float32, mode="w+", shape=(n_test,))
424
+ dst_lengths = np.memmap(out_dir / f"{dst_name}_lengths.bin", dtype=np.int32, mode="w+", shape=(n_test,))
425
+ for i, idx in enumerate(tqdm(indices, desc=f" writing {dst_name}", unit="sample")):
426
+ dst_tokens[i] = src_tokens[idx]
427
+ dst_labels[i] = src_labels[idx]
428
+ dst_lengths[i] = src_lengths[idx]
429
+ dst_tokens.flush()
430
+ dst_labels.flush()
431
+ dst_lengths.flush()
432
+ torch.save({"n": n_test, "max_len": max_len}, out_dir / f"{dst_name}_meta.pt")
433
+ print(f" {dst_name}: {n_test:,} samples written")
434
+
435
+
436
+ def _write_test_subset_policy(out_dir: Path, src_name: str, dst_name: str, indices: np.ndarray) -> None:
437
+ """Write a subset of a policy memmap (tokens+lengths, no labels) to new files."""
438
+ meta = torch.load(out_dir / f"{src_name}_meta.pt", weights_only=True)
439
+ n_src, max_len = meta["n"], meta["max_len"]
440
+ src_tokens = np.memmap(out_dir / f"{src_name}_tokens.bin", dtype=np.int32, mode="r", shape=(n_src, max_len))
441
+ src_lengths = np.memmap(out_dir / f"{src_name}_lengths.bin", dtype=np.int32, mode="r", shape=(n_src,))
442
+ n_test = len(indices)
443
+ dst_tokens = np.memmap(out_dir / f"{dst_name}_tokens.bin", dtype=np.int32, mode="w+", shape=(n_test, max_len))
444
+ dst_lengths = np.memmap(out_dir / f"{dst_name}_lengths.bin", dtype=np.int32, mode="w+", shape=(n_test,))
445
+ for i, idx in enumerate(tqdm(indices, desc=f" writing {dst_name}", unit="seq")):
446
+ dst_tokens[i] = src_tokens[idx]
447
+ dst_lengths[i] = src_lengths[idx]
448
+ dst_tokens.flush()
449
+ dst_lengths.flush()
450
+ torch.save({"n": n_test, "max_len": max_len}, out_dir / f"{dst_name}_meta.pt")
451
+ print(f" {dst_name}: {n_test:,} sequences written")
452
+
453
+
454
+ def stage_build_test_splits(args: argparse.Namespace, out_dir: Path) -> None:
455
+ """Build held-out test sets for reward and policy models from existing memmaps.
456
+
457
+ Uses a fixed random seed (42) so the same indices are always selected.
458
+ Saves the chosen indices to {name}_test_indices.npy so the corresponding
459
+ training memmap loader can exclude them — making train and test disjoint
460
+ even though they share the underlying .bin file.
461
+
462
+ Produces:
463
+ stockfish_test_*.bin / stockfish_test_meta.pt / stockfish_test_indices.npy
464
+ policy_test_*.bin / policy_test_meta.pt / policy_test_indices.npy
465
+ """
466
+ rng = np.random.default_rng(42)
467
+ policy_only = getattr(args, "policy_only", False)
468
+
469
+ # Reward test set — skipped when --policy-only since no Stockfish data exists.
470
+ reward_test_meta = out_dir / "stockfish_test_meta.pt"
471
+ sf_meta_path = out_dir / "stockfish_meta.pt"
472
+ if policy_only:
473
+ print("Test splits: stockfish_test skipped (--policy-only).")
474
+ elif (not reward_test_meta.exists() or args.force) and sf_meta_path.exists():
475
+ print(f"Test splits: building stockfish_test ({args.reward_test_size:,} samples)...")
476
+ sf_meta = torch.load(sf_meta_path, weights_only=True)
477
+ n = sf_meta["n"]
478
+ test_n = min(args.reward_test_size, n)
479
+ idx = rng.choice(n, size=test_n, replace=False)
480
+ idx.sort()
481
+ _write_test_subset_reward(out_dir, "stockfish", "stockfish_test", idx)
482
+ np.save(out_dir / "stockfish_test_indices.npy", idx)
483
+ print(f" saved stockfish_test_indices.npy ({test_n:,} indices excluded from training)")
484
+ elif reward_test_meta.exists():
485
+ print("Test splits: stockfish_test already exists, skipping.")
486
+
487
+ # Policy test set
488
+ policy_test_meta = out_dir / "policy_test_meta.pt"
489
+ pol_meta_path = out_dir / "policy_meta.pt"
490
+ if (not policy_test_meta.exists() or args.force) and pol_meta_path.exists():
491
+ print(f"Test splits: building policy_test ({args.policy_test_size:,} sequences)...")
492
+ pol_meta = torch.load(pol_meta_path, weights_only=True)
493
+ n = pol_meta["n"]
494
+ test_n = min(args.policy_test_size, n)
495
+ idx = rng.choice(n, size=test_n, replace=False)
496
+ idx.sort()
497
+ _write_test_subset_policy(out_dir, "policy", "policy_test", idx)
498
+ np.save(out_dir / "policy_test_indices.npy", idx)
499
+ print(f" saved policy_test_indices.npy ({test_n:,} indices excluded from training)")
500
+ elif policy_test_meta.exists():
501
+ print("Test splits: policy_test already exists, skipping.")
502
+
503
+
504
+ def stage5_puzzle_samples(args: argparse.Namespace, tokenizer, out_dir: Path) -> None:
505
+ train_done = (out_dir / "puzzle_meta.pt").exists()
506
+ test_done = (out_dir / "puzzle_test_meta.pt").exists()
507
+ if train_done and test_done and not args.force:
508
+ print("Stage 5: skipping — puzzle_meta.pt and puzzle_test_meta.pt exist.")
509
+ return
510
+
511
+ print("Stage 5: loading Lichess/chess-puzzles from HuggingFace...")
512
+ ds = load_dataset("Lichess/chess-puzzles", split="train", streaming=True)
513
+
514
+ min_pop = args.min_puzzle_popularity
515
+ min_plays = args.min_puzzle_plays
516
+ cls_id = tokenizer.symbol_to_token[CLS_TOKEN]
517
+ sym_map = tokenizer.symbol_to_token
518
+ test_seqs: list[list[int]] = []
519
+ test_fens: list[str] = []
520
+ train_seqs: list[list[int]] = []
521
+ train_fens: list[str] = []
522
+ skipped = 0
523
+ test_target = args.puzzle_test_size
524
+ train_target = args.puzzle_count
525
+
526
+ with tqdm(ds, desc="Stage 5: puzzles", unit="puzzle") as pbar:
527
+ for row in pbar:
528
+ if min_pop is not None and row.get("Popularity", 0) < min_pop:
529
+ continue
530
+ if min_plays is not None and row.get("NbPlays", 0) < min_plays:
531
+ continue
532
+ result = _process_puzzle(row, sym_map, cls_id)
533
+ if result is None:
534
+ skipped += 1
535
+ continue
536
+ seq, fen = result
537
+ if len(test_seqs) < test_target:
538
+ test_seqs.append(seq)
539
+ test_fens.append(fen)
540
+ else:
541
+ train_seqs.append(seq)
542
+ train_fens.append(fen)
543
+ pbar.set_postfix(test=len(test_seqs), train=len(train_seqs), skipped=skipped)
544
+ if train_target is not None and len(train_seqs) >= train_target:
545
+ break
546
+
547
+ print(
548
+ f"Stage 5: test={len(test_seqs):,} puzzles, "
549
+ f"train={len(train_seqs):,} puzzles, "
550
+ f"skipped={skipped:,} invalid."
551
+ )
552
+ if test_seqs and not test_done:
553
+ _save_policy_memmap(
554
+ test_seqs, out_dir, "puzzle_test", max_seq_len=args.max_seq_len, fens=test_fens,
555
+ )
556
+ if train_seqs and not train_done:
557
+ _save_policy_memmap(
558
+ train_seqs, out_dir, "puzzle", max_seq_len=args.max_seq_len, fens=train_fens,
559
+ )
560
+ elif not train_seqs:
561
+ print("Stage 5: WARNING — no training puzzles collected.")
562
+
563
+
564
+ def main():
565
+ parser = argparse.ArgumentParser(description=__doc__)
566
+ parser.add_argument("--out-dir", type=Path, default=Path("data"))
567
+ parser.add_argument("--policy-games", type=int, default=1_000_000,
568
+ help="Number of games to collect for policy model training")
569
+ parser.add_argument("--reward-games", type=int, default=1_000_000,
570
+ help="Number of games to collect for reward model (Stockfish eval)")
571
+ parser.add_argument("--policy-min-elo", type=int, default=1800,
572
+ help="Min Elo for both players in policy training games")
573
+ parser.add_argument("--reward-min-elo", type=int, default=1500,
574
+ help="Min Elo for both players in reward model training games")
575
+ parser.add_argument("--sample-rate", type=float, default=0.25,
576
+ help="Fraction of positions to sample per game (scales with game length)")
577
+ parser.add_argument("--position-skew", type=float, default=1.5,
578
+ help="Power-law exponent weighting later positions; 1.0=linear, higher=more mid/late")
579
+ parser.add_argument("--workers", type=int, default=16)
580
+ parser.add_argument("--stockfish-depth", type=int, default=12)
581
+ parser.add_argument("--max-seq-len", type=int, default=128,
582
+ help="Truncate token sequences to this length when writing .bin files")
583
+ parser.add_argument(
584
+ "--force",
585
+ action="store_true",
586
+ help="Re-run all stages even if their outputs already exist",
587
+ )
588
+ parser.add_argument("--puzzle-count", type=int, default=None, dest="puzzle_count",
589
+ help="Max puzzles to include (default: all ~4.99M)")
590
+ parser.add_argument("--min-puzzle-popularity", type=int, default=None, dest="min_puzzle_popularity",
591
+ help="Min Lichess Popularity score (0-100 scale)")
592
+ parser.add_argument("--min-puzzle-plays", type=int, default=None, dest="min_puzzle_plays",
593
+ help="Min NbPlays for a puzzle to be included")
594
+ parser.add_argument("--skip-puzzles", action="store_true",
595
+ help="Skip Stage 5 puzzle processing")
596
+ parser.add_argument("--puzzles-only", action="store_true",
597
+ help="Only run Stage 5 (puzzle processing). Skips game collection, "
598
+ "outcome/Stockfish/policy memmaps, and test splits. Requires "
599
+ "tokenizer.pt to exist (or it will be built from the UCI vocab).")
600
+ parser.add_argument("--policy-only", action="store_true",
601
+ help="Skip everything Stockfish/reward-related: Stage 1 collects only "
602
+ "policy games, Stage 3 (Stockfish labeling) is skipped, and the "
603
+ "stockfish_test split is not built. Stages 1/2/4/5 + policy_test "
604
+ "still run, producing tokenizer.pt, policy_* / puzzle_* memmaps, "
605
+ "and the policy_test split.")
606
+ parser.add_argument("--puzzle-test-size", type=int, default=100_000, dest="puzzle_test_size",
607
+ help="Number of puzzle sequences held out for the test set (default: 100000)")
608
+ parser.add_argument("--reward-test-size", type=int, default=50_000, dest="reward_test_size",
609
+ help="Number of reward positions held out for the test set (default: 50000)")
610
+ parser.add_argument("--policy-test-size", type=int, default=50_000, dest="policy_test_size",
611
+ help="Number of policy sequences held out for the test set (default: 50000)")
612
+ args = parser.parse_args()
613
+
614
+ args.out_dir.mkdir(parents=True, exist_ok=True)
615
+
616
+ if args.puzzles_only and args.policy_only:
617
+ parser.error("--puzzles-only and --policy-only are mutually exclusive.")
618
+
619
+ if args.puzzles_only:
620
+ print("--puzzles-only: skipping Stages 1-4 and test-split builder.")
621
+ tokenizer_path = args.out_dir / "tokenizer.pt"
622
+ if tokenizer_path.exists():
623
+ tokenizer = torch.load(tokenizer_path, weights_only=False)
624
+ else:
625
+ # Tokenizer is just the enumerated UCI vocab — no games needed.
626
+ print(" tokenizer.pt missing; building from UCI vocab...")
627
+ tokenizer = build_tokenizer_from_games()
628
+ torch.save(tokenizer, tokenizer_path)
629
+ stage5_puzzle_samples(args, tokenizer, args.out_dir)
630
+ elif args.policy_only:
631
+ print("--policy-only: skipping Stage 3 (Stockfish labeling) and stockfish_test split.")
632
+ stage1_collect_games(args)
633
+ stage2_outcome_samples(args)
634
+ stage4_policy_sequences(args)
635
+
636
+ tokenizer_path = args.out_dir / "tokenizer.pt"
637
+ if not args.skip_puzzles and tokenizer_path.exists():
638
+ tokenizer = torch.load(tokenizer_path, weights_only=False)
639
+ stage5_puzzle_samples(args, tokenizer, args.out_dir)
640
+ elif not args.skip_puzzles:
641
+ print("Stage 5: skipping — tokenizer.pt not found (run stages 1-2 first).")
642
+
643
+ stage_build_test_splits(args, args.out_dir)
644
+ else:
645
+ stage1_collect_games(args)
646
+ stage2_outcome_samples(args)
647
+ stage3_stockfish_samples(args)
648
+ stage4_policy_sequences(args)
649
+
650
+ tokenizer_path = args.out_dir / "tokenizer.pt"
651
+ if not args.skip_puzzles and tokenizer_path.exists():
652
+ tokenizer = torch.load(tokenizer_path, weights_only=False)
653
+ stage5_puzzle_samples(args, tokenizer, args.out_dir)
654
+ elif not args.skip_puzzles:
655
+ print("Stage 5: skipping — tokenizer.pt not found (run stages 1-2 first).")
656
+
657
+ stage_build_test_splits(args, args.out_dir)
658
+
659
+ print("\nAll stages complete. Artifacts:")
660
+ for name in (
661
+ "games_outcome.pt",
662
+ "games_stockfish.pt",
663
+ "tokenizer.pt",
664
+ "outcome_tokens.bin",
665
+ "outcome_labels.bin",
666
+ "outcome_lengths.bin",
667
+ "outcome_meta.pt",
668
+ "stockfish_tokens.bin",
669
+ "stockfish_labels.bin",
670
+ "stockfish_lengths.bin",
671
+ "stockfish_meta.pt",
672
+ "policy_tokens.bin",
673
+ "policy_lengths.bin",
674
+ "policy_meta.pt",
675
+ "puzzle_tokens.bin",
676
+ "puzzle_lengths.bin",
677
+ "puzzle_fens.bin",
678
+ "puzzle_meta.pt",
679
+ "puzzle_test_tokens.bin",
680
+ "puzzle_test_lengths.bin",
681
+ "puzzle_test_fens.bin",
682
+ "puzzle_test_meta.pt",
683
+ "stockfish_test_tokens.bin",
684
+ "stockfish_test_labels.bin",
685
+ "stockfish_test_lengths.bin",
686
+ "stockfish_test_meta.pt",
687
+ "policy_test_tokens.bin",
688
+ "policy_test_lengths.bin",
689
+ "policy_test_meta.pt",
690
+ ):
691
+ path = args.out_dir / name
692
+ size_mb = path.stat().st_size / 1024 / 1024 if path.exists() else 0
693
+ print(f" {path} ({size_mb:.1f} MB)")
694
+
695
+
696
+ if __name__ == "__main__":
697
+ main()
src/minimax.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import chess
2
+ import math
3
+ from typing import Callable
4
+
5
+ from model import PIECE_VALUES
6
+
7
+
8
+ def dummy_reward_fn(board: chess.Board) -> float:
9
+ """Material-count heuristic: positive favors white."""
10
+ score = 0.0
11
+ for piece_type in PIECE_VALUES:
12
+ score += len(board.pieces(piece_type, chess.WHITE)) * PIECE_VALUES[piece_type]
13
+ score -= len(board.pieces(piece_type, chess.BLACK)) * PIECE_VALUES[piece_type]
14
+ return math.tanh(score / 10.0)
15
+
16
+
17
+ class MinimaxSearch:
18
+ """Minimax search with top-N move pruning.
19
+
20
+ At each node, evaluates all legal moves with the reward function,
21
+ keeps the top N candidates, and recurses to the given depth.
22
+ Alternates between maximizing (white) and minimizing (black).
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ reward_fn: Callable[[chess.Board], float],
28
+ depth: int = 3,
29
+ top_n: int = 5,
30
+ ):
31
+ self.reward_fn = reward_fn
32
+ self.depth = depth
33
+ self.top_n = top_n
34
+
35
+ def search(self, board: chess.Board) -> chess.Move:
36
+ """Return the best move for the current side to play."""
37
+ legal_moves = list(board.legal_moves)
38
+ if not legal_moves:
39
+ raise ValueError("No legal moves available")
40
+ if len(legal_moves) == 1:
41
+ return legal_moves[0]
42
+
43
+ maximizing = board.turn == chess.WHITE
44
+
45
+ # Score every legal move with a shallow reward evaluation
46
+ scored_moves = []
47
+ for move in legal_moves:
48
+ board.push(move)
49
+ score = self.reward_fn(board)
50
+ board.pop()
51
+ scored_moves.append((score, move))
52
+
53
+ # Keep top N candidates (best for current side)
54
+ scored_moves.sort(key=lambda x: x[0], reverse=maximizing)
55
+ candidates = scored_moves[: self.top_n]
56
+
57
+ # Recurse on each candidate to find the best
58
+ best_move = candidates[0][1]
59
+ best_value = float("-inf") if maximizing else float("inf")
60
+
61
+ for _, move in candidates:
62
+ board.push(move)
63
+ value = self._minimax(board, self.depth - 1, not maximizing)
64
+ board.pop()
65
+
66
+ if maximizing and value > best_value:
67
+ best_value = value
68
+ best_move = move
69
+ elif not maximizing and value < best_value:
70
+ best_value = value
71
+ best_move = move
72
+
73
+ return best_move
74
+
75
+ def _minimax(self, board: chess.Board, depth: int, maximizing: bool) -> float:
76
+ if depth <= 0 or board.is_game_over():
77
+ return self._terminal_eval(board)
78
+
79
+ legal_moves = list(board.legal_moves)
80
+ if not legal_moves:
81
+ return self._terminal_eval(board)
82
+
83
+ # Score all moves, keep top N for the current side
84
+ scored_moves = []
85
+ for move in legal_moves:
86
+ board.push(move)
87
+ score = self.reward_fn(board)
88
+ board.pop()
89
+ scored_moves.append((score, move))
90
+
91
+ scored_moves.sort(key=lambda x: x[0], reverse=maximizing)
92
+ candidates = scored_moves[: self.top_n]
93
+
94
+ if maximizing:
95
+ best = float("-inf")
96
+ for _, move in candidates:
97
+ board.push(move)
98
+ best = max(best, self._minimax(board, depth - 1, False))
99
+ board.pop()
100
+ return best
101
+ else:
102
+ best = float("inf")
103
+ for _, move in candidates:
104
+ board.push(move)
105
+ best = min(best, self._minimax(board, depth - 1, True))
106
+ board.pop()
107
+ return best
108
+
109
+ def _terminal_eval(self, board: chess.Board) -> float:
110
+ """Evaluate a terminal or leaf node."""
111
+ if board.is_checkmate():
112
+ # The side to move is checkmated
113
+ return -1.0 if board.turn == chess.WHITE else 1.0
114
+ if board.is_game_over():
115
+ return 0.0
116
+ return self.reward_fn(board)
src/model.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import chess
5
+
6
+ from src.tokenizer import Tokenizer
7
+
8
+ CLS_TOKEN = "[CLS]"
9
+ PAD_TOKEN = "[PAD]"
10
+
11
+ PIECE_VALUES = {
12
+ chess.PAWN: 1,
13
+ chess.KNIGHT: 3,
14
+ chess.BISHOP: 3,
15
+ chess.ROOK: 5,
16
+ chess.QUEEN: 9,
17
+ chess.KING: 0,
18
+ }
19
+
20
+ BOARD_PLANES = 19
21
+
22
+ def board_to_planes(board: chess.Board) -> torch.Tensor:
23
+ """chess.Board -> (19, 8, 8) float tensor."""
24
+ planes = torch.zeros(BOARD_PLANES, 8, 8, dtype=torch.float32)
25
+ pieces = [chess.PAWN, chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN, chess.KING]
26
+ colors = [chess.WHITE, chess.BLACK]
27
+ piece_to_plane = {(piece, color) : 6 * color_num + piece_num for piece_num, piece in enumerate(pieces) for color_num, color in enumerate(colors)}
28
+
29
+ for sq, piece in board.piece_map().items():
30
+ r, c = chess.square_rank(sq), chess.square_file(sq)
31
+
32
+ planes[piece_to_plane[(piece.piece_type, piece.color)], r, c] = 1.0
33
+
34
+ if board.turn == chess.WHITE:
35
+ planes[12].fill_(1.0)
36
+
37
+ if board.has_kingside_castling_rights(chess.WHITE): planes[13].fill_(1.0)
38
+ if board.has_queenside_castling_rights(chess.WHITE): planes[14].fill_(1.0)
39
+ if board.has_kingside_castling_rights(chess.BLACK): planes[15].fill_(1.0)
40
+ if board.has_queenside_castling_rights(chess.BLACK): planes[16].fill_(1.0)
41
+ if board.ep_square is not None:
42
+ r, c = chess.square_rank(board.ep_square), chess.square_file(board.ep_square)
43
+ planes[17, r, c] = 1.0
44
+ planes[18].fill_(min(board.halfmove_clock, 100) / 100.0)
45
+
46
+ return planes
47
+
48
+ def _group_norm(channels: int, groups: int = 32) -> nn.GroupNorm:
49
+ return nn.GroupNorm(num_groups=min(groups, channels), num_channels=channels)
50
+
51
+
52
+ class ResidualBlock(nn.Module):
53
+ def __init__(self, channels: int):
54
+ super().__init__()
55
+ self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
56
+ self.norm1 = _group_norm(channels)
57
+ self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
58
+ self.norm2 = _group_norm(channels)
59
+
60
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
61
+ h = torch.relu(self.norm1(self.conv1(x)))
62
+ h = self.norm2(self.conv2(h))
63
+ return torch.relu(h + x)
64
+ class BoardCNN(nn.Module):
65
+ def __init__(self, d_model, channels=128, num_blocks=6):
66
+ super().__init__()
67
+ self.stem = nn.Sequential(
68
+ nn.Conv2d(BOARD_PLANES, channels, 3, padding=1, bias=False),
69
+ _group_norm(channels),
70
+ nn.ReLU(inplace=True),
71
+ )
72
+ self.blocks = nn.Sequential(*[ResidualBlock(channels) for _ in range(num_blocks)])
73
+ self.proj = nn.Linear(channels, d_model)
74
+ self.square_pos = nn.Embedding(64, d_model)
75
+
76
+ def forward(self, planes : torch.Tensor) -> torch.Tensor:
77
+ x = self.stem(planes)
78
+ x = self.blocks(x) # (N, C, 8, 8)
79
+ x = x.permute(0, 2, 3, 1).reshape(x.size(0), 64, -1) # (n, 64, C)
80
+ x = self.proj(x) + self.square_pos.weight # (n, 64, d_model)
81
+ return x
82
+
83
+
84
+ class CrossAttnBlock(nn.Module):
85
+ def __init__(self, d_model, n_head, dim_ff, dropout):
86
+ super().__init__()
87
+ self.self_attn = nn.MultiheadAttention(d_model, n_head, dropout = dropout, batch_first=True)
88
+
89
+ self.cross_attn = nn.MultiheadAttention(d_model, n_head, dropout = dropout, batch_first = True)
90
+
91
+ self.ff = nn.Sequential(
92
+ nn.Linear(d_model, dim_ff), nn.GELU(), nn.Linear(dim_ff, d_model)
93
+ )
94
+
95
+ self.norm1 = nn.LayerNorm(d_model)
96
+ self.norm2 = nn.LayerNorm(d_model)
97
+ self.norm3 = nn.LayerNorm(d_model)
98
+ self.drop = nn.Dropout(dropout)
99
+ #Adding this gate which is init to 0 so cross-attn starts disabled
100
+ self.cross_gate = nn.Parameter(torch.zeros(1))
101
+
102
+ def forward(self, moves, board, key_padding_mask, attn_mask):
103
+ """
104
+ moves: (B, T, d)
105
+ board: (B, T, 64, d) -- per-position K/V banks
106
+ key_padding_mask: (B, T) -- True = padded move position
107
+ attn_mask: (T, T) -- causal mask for self-attn
108
+ """
109
+ m = self.norm1(moves)
110
+ sa, _ = self.self_attn(m, m, m, attn_mask = attn_mask, key_padding_mask=key_padding_mask, need_weights=False)
111
+
112
+ moves = moves + self.drop(sa)
113
+
114
+ B, T, d = moves.shape
115
+ q = self.norm2(moves).reshape(B * T, 1, d)
116
+ kv = board.reshape(B * T, 64, d)
117
+ ca, _ = self.cross_attn(q, kv, kv, need_weights = False)
118
+ ca = ca.reshape(B, T, d)
119
+ moves = moves + self.drop(self.cross_gate.tanh() * ca)
120
+
121
+ # FFN
122
+
123
+ moves = moves + self.drop(self.ff(self.norm3(moves)))
124
+ return moves
125
+
126
+ class PositionalEncoding(nn.Module):
127
+ def __init__(self, d_model: int, max_len: int = 512, dropout: float = 0.1):
128
+ super().__init__()
129
+ self.dropout = nn.Dropout(p=dropout)
130
+ pe = torch.zeros(max_len, d_model)
131
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
132
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
133
+ pe[:, 0::2] = torch.sin(position * div_term)
134
+ pe[:, 1::2] = torch.cos(position * div_term)
135
+ pe = pe.unsqueeze(0) # (1, max_len, d_model)
136
+ self.register_buffer("pe", pe)
137
+
138
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
139
+ x = x + self.pe[:, :x.size(1)]
140
+ return self.dropout(x)
141
+
142
+
143
+ class ChessRewardModel(nn.Module):
144
+ def __init__(
145
+ self,
146
+ vocab_size: int,
147
+ d_model: int = 768,
148
+ nhead: int = 12,
149
+ num_layers: int = 8,
150
+ dim_feedforward: int = 3072,
151
+ max_seq_len: int = 128,
152
+ dropout: float = 0.1,
153
+ ):
154
+ super().__init__()
155
+ self.token_embedding = nn.Embedding(vocab_size, d_model)
156
+ self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
157
+ encoder_layer = nn.TransformerEncoderLayer(
158
+ d_model, nhead, dim_feedforward, dropout, batch_first=True
159
+ )
160
+ self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
161
+ self.reward_head = nn.Linear(d_model, 1)
162
+
163
+ def forward(
164
+ self,
165
+ token_ids: torch.Tensor,
166
+ attention_mask: torch.Tensor | None = None,
167
+ ) -> torch.Tensor:
168
+ """
169
+ Args:
170
+ token_ids: (batch, seq_len) int tensor with CLS prepended
171
+ attention_mask: (batch, seq_len) bool tensor, True where padded
172
+ Returns:
173
+ (batch,) float tensor bounded to [-1, 1]
174
+ """
175
+ x = self.token_embedding(token_ids)
176
+ x = self.pos_encoding(x)
177
+ x = self.encoder(x, src_key_padding_mask=attention_mask)
178
+ cls_hidden = x[:, 0, :] # CLS token at position 0
179
+ reward = self.reward_head(cls_hidden).squeeze(-1)
180
+ return torch.tanh(reward)
181
+
182
+ class ChessPolicyModel(nn.Module):
183
+ """Causal next-move predictor with per-position live-board cross-attention.
184
+
185
+ Two streams flow through every block:
186
+ - Move stream: token embeddings + sinusoidal positional encoding, doing
187
+ causal self-attention over the move history.
188
+ - Board stream: a (B, T, 64, d_model) bank of CNN-encoded board features
189
+ where bank `t` is the state after token_ids[1..t] have been played.
190
+ At each block, the move query at position t cross-attends only to its
191
+ own 64 board-square keys — implicit causality via data layout, no
192
+ masking needed.
193
+
194
+ The board representation never depends on a token the model is being
195
+ asked to predict, so multi-position LM-style training is leak-safe.
196
+ """
197
+ def __init__(
198
+ self,
199
+ vocab_size: int,
200
+ d_model: int = 768,
201
+ nhead: int = 12,
202
+ num_layers: int = 8,
203
+ dim_feedforward: int = 3072,
204
+ max_seq_len: int = 128,
205
+ dropout: float = 0.1,
206
+ cnn_channels: int = 128,
207
+ cnn_blocks: int = 6,
208
+ ):
209
+ super().__init__()
210
+ self.vocab_size = vocab_size
211
+ self.token_embedding = nn.Embedding(vocab_size, d_model)
212
+ self.board_cnn = BoardCNN(d_model, cnn_channels, cnn_blocks)
213
+ self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
214
+ self.blocks = nn.ModuleList([
215
+ CrossAttnBlock(d_model, nhead, dim_feedforward, dropout)
216
+ for _ in range(num_layers)
217
+ ])
218
+ self.norm_out = nn.LayerNorm(d_model)
219
+ self.prob_head = nn.Linear(d_model, vocab_size)
220
+
221
+ def forward(
222
+ self,
223
+ token_ids: torch.Tensor,
224
+ board_planes: torch.Tensor,
225
+ attention_mask: torch.Tensor | None = None,
226
+ ) -> torch.Tensor:
227
+ """
228
+ Args:
229
+ token_ids: (B, T) int — CLS at position 0
230
+ board_planes: (B, T, 19, 8, 8) float — per-position live planes;
231
+ planes[:, t] is the board state after token_ids[1..t]
232
+ attention_mask: (B, T) bool — True where padded
233
+ Returns:
234
+ (B, T, vocab_size) raw logits at every position
235
+ """
236
+ B, T = token_ids.shape
237
+
238
+ moves = self.token_embedding(token_ids)
239
+ moves = self.pos_encoding(moves) # (B, T, d)
240
+
241
+ # Vectorize the CNN over (B*T) boards — one big conv batch, not a loop.
242
+ planes_flat = board_planes.reshape(B * T, BOARD_PLANES, 8, 8)
243
+ board_feats = self.board_cnn(planes_flat) # (B*T, 64, d)
244
+ board_feats = board_feats.reshape(B, T, 64, -1) # (B, T, 64, d)
245
+
246
+ # Bool causal mask (True = masked future position) to match the bool
247
+ # key_padding_mask. PyTorch deprecates mixing float and bool masks.
248
+ causal = torch.triu(
249
+ torch.ones(T, T, dtype=torch.bool, device=token_ids.device), diagonal=1
250
+ )
251
+ for blk in self.blocks:
252
+ moves = blk(moves, board_feats, attention_mask, causal)
253
+
254
+ moves = self.norm_out(moves)
255
+ return self.prob_head(moves) # (B, T, vocab)
256
+
257
+
258
+ class DummyRewardModel:
259
+ """Material-count heuristic for MCTS testing."""
260
+ def __call__(self, board: chess.Board) -> float:
261
+ score = 0.0
262
+ for piece_type in PIECE_VALUES:
263
+ score += len(board.pieces(piece_type, chess.WHITE)) * PIECE_VALUES[piece_type]
264
+ score -= len(board.pieces(piece_type, chess.BLACK)) * PIECE_VALUES[piece_type]
265
+ return math.tanh(score / 10.0)
266
+
267
+
268
+ class RewardModelInference:
269
+ """Wraps ChessRewardModel + Tokenizer for use in minimax"""
270
+ def __init__(self, model: ChessRewardModel, tokenizer: Tokenizer, device: str = "cpu"):
271
+ self.model = model
272
+ self.tokenizer = tokenizer
273
+ self.device = device
274
+ self.cls_id = tokenizer.symbol_to_token[CLS_TOKEN]
275
+ self.pad_id = tokenizer.symbol_to_token[PAD_TOKEN]
276
+ self.model.eval()
277
+
278
+ @torch.no_grad()
279
+ def __call__(self, board: chess.Board, max_seq_len: int = 128) -> float:
280
+ moves_uci = [move.uci() for move in board.move_stack]
281
+
282
+ # Keep the most recent moves to stay within the training sequence length.
283
+ # CLS occupies position 0, so cap move history at max_seq_len - 1.
284
+ moves_uci = moves_uci[-(max_seq_len - 1):]
285
+ token_ids = [self.cls_id] + self.tokenizer.encode_moves(moves_uci)
286
+ token_tensor = torch.tensor([token_ids], dtype=torch.long, device=self.device)
287
+ reward = self.model(token_tensor)
288
+ return reward.item()
289
+
290
+ class PolicyModelInference:
291
+ """Wraps ChessPolicyModel + Tokenizer"""
292
+
293
+ def __init__(self, model: ChessPolicyModel, tokenizer: Tokenizer, device: str = "cpu"):
294
+ self.model = model
295
+ self.tokenizer = tokenizer
296
+ self.device = device
297
+ self.cls_id = tokenizer.symbol_to_token[CLS_TOKEN]
298
+ self.pad_id = tokenizer.symbol_to_token[PAD_TOKEN]
299
+ self.model.eval()
300
+
301
+ @torch.no_grad()
302
+ def __call__(self, board: chess.Board) -> str:
303
+ moves_uci = [move.uci() for move in board.move_stack]
304
+ token_ids = [self.cls_id] + self.tokenizer.encode_moves(moves_uci)
305
+ token_tensor = torch.tensor([token_ids], dtype=torch.long, device=self.device)
306
+
307
+ # Replay the full move history on a fresh board, snapshotting planes
308
+ # at every position. planes[0] = standard starting board (model has
309
+ # only seen [CLS]); planes[t] = state after the first t moves played.
310
+ # This matches the training pipeline (ChessPolicyDataset._replay_planes
311
+ # with start_board=chess.Board()) exactly.
312
+ replay_board = chess.Board()
313
+ plane_list = [board_to_planes(replay_board)]
314
+ for uci in moves_uci:
315
+ replay_board.push(chess.Move.from_uci(uci))
316
+ plane_list.append(board_to_planes(replay_board))
317
+ planes = torch.stack(plane_list).unsqueeze(0).to(self.device) # (1, T, 19, 8, 8)
318
+
319
+ logits = self.model(token_tensor, planes) # (1, T, vocab_size)
320
+ last_logits = logits[0, -1] # last position predicts the next move
321
+
322
+ legal_move_ids = [self.tokenizer.symbol_to_token[move.uci()] for move in board.legal_moves]
323
+ mask = torch.full((self.tokenizer.language_size,), float('-inf'), device=self.device)
324
+ mask[legal_move_ids] = 0.0
325
+ best_move_idx = (last_logits + mask).argmax().item()
326
+
327
+ return self.tokenizer.token_to_symbol[best_move_idx]
src/tokenizer.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import typing
4
+ from collections import deque, defaultdict
5
+
6
+
7
+ class Tokenizer():
8
+ def __init__(self):
9
+
10
+ self.symbol_set : set = None
11
+ self.symbol_to_token = {}
12
+ self.token_to_symbol = {}
13
+ self.language_size = 0
14
+ self.corpus = None
15
+ def train_tokenizer(self, input, max_language_size: int) -> None:
16
+ if type(input) == str:
17
+ self.corpus = input.split(",")
18
+ else:
19
+ self.corpus = input
20
+ self.symbol_set = set(self.corpus)
21
+ for sym in self.symbol_set:
22
+ self.symbol_to_token[sym] = self.language_size
23
+ self.token_to_symbol[self.language_size] = sym
24
+ self.language_size += 1
25
+ # Converted everythign to tokens from symbolic form
26
+ self.corpus = np.array([self.symbol_to_token[sym] for sym in self.corpus], dtype=int)
27
+
28
+ while self.language_size < max_language_size:
29
+ temp_corpus = self.corpus
30
+ common_pair = None
31
+ highest_pair_count = 0
32
+ pair_counts = defaultdict(int)
33
+ for i in range(len(temp_corpus)-1):
34
+ pair = (temp_corpus[i], temp_corpus[i+1])
35
+ pair_counts[pair] += 1
36
+ if (pair_counts[pair] > highest_pair_count):
37
+ highest_pair_count = pair_counts[pair]
38
+ common_pair = pair
39
+ synthetic_symbol = self.token_to_symbol[common_pair[0]] + self.token_to_symbol[common_pair[1]]
40
+
41
+ self.symbol_to_token[synthetic_symbol] = self.language_size
42
+ self.token_to_symbol[self.language_size] = synthetic_symbol
43
+
44
+ self.language_size += 1
45
+ combine_tokens = deque(temp_corpus)
46
+ self.corpus = []
47
+
48
+ while (len(combine_tokens) > 1):
49
+ first_elem = combine_tokens.popleft()
50
+ second_elem = combine_tokens.popleft()
51
+
52
+ if ((first_elem, second_elem) == common_pair):
53
+ combine_tokens.appendleft(self.language_size - 1)
54
+
55
+ else:
56
+ self.corpus.append(first_elem)
57
+ self.corpus.append(second_elem)
58
+ if (len(combine_tokens) > 0):
59
+ self.corpus.append(combine_tokens.popleft())
60
+
61
+ self.corpus = None
62
+
63
+ def decode(self, tokens: list[int]) -> str:
64
+ return "".join([self.token_to_symbol[t] for t in tokens])
65
+
66
+ def encode(self, message: str):
67
+ char_list = list(message)
68
+ char_inputs = deque(char_list)
69
+
70
+ result_tokens = []
71
+ curr_symbol = ""
72
+ while (len(char_inputs) > 0):
73
+ f_char = char_inputs.popleft()
74
+ curr_symbol += f_char
75
+
76
+ if (curr_symbol not in self.symbol_to_token.keys()):
77
+ curr_symbol = curr_symbol[:-1]
78
+ result_tokens.append(self.symbol_to_token[curr_symbol])
79
+ char_inputs.appendleft(f_char)
80
+ curr_symbol = ""
81
+ if (len(curr_symbol) > 0):
82
+ result_tokens.append(self.symbol_to_token[curr_symbol])
83
+
84
+ return result_tokens
85
+
86
+ def encode_moves(self, moves: list[str]) -> list[int]:
87
+ return [self.symbol_to_token[move] for move in moves]
88
+
89
+ def add_special_tokens(self, tokens: list[str]) -> dict[str, int]:
90
+ mapping = {}
91
+ for tok in tokens:
92
+ self.symbol_to_token[tok] = self.language_size
93
+ self.token_to_symbol[self.language_size] = tok
94
+ mapping[tok] = self.language_size
95
+ self.language_size += 1
96
+ return mapping
97
+
98
+
99
+ class DataLoader():
100
+ corpus = None
101
+ def __init__(self, file_name: str):
102
+ with open(file_name, "r") as f:
103
+ self.corpus = f.read()
104
+
src/train.py ADDED
@@ -0,0 +1,1319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import atexit
3
+ import contextlib
4
+ import math # noqa: F401 (used in eval_policy)
5
+ import multiprocessing as mp
6
+ import re
7
+ import shutil
8
+ import time
9
+ import numpy as np
10
+ from pathlib import Path
11
+ import chess
12
+ import chess.engine
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from torch.utils.data import Dataset, DataLoader
16
+ from torch.utils.tensorboard import SummaryWriter
17
+ from datasets import load_dataset
18
+
19
+ from tokenizer import Tokenizer
20
+ from model import (
21
+ ChessRewardModel,
22
+ ChessPolicyModel,
23
+ DummyRewardModel,
24
+ CLS_TOKEN,
25
+ PAD_TOKEN,
26
+ board_to_planes,
27
+ )
28
+
29
+ STOCKFISH_PATH = shutil.which("stockfish") or "/usr/local/bin/stockfish"
30
+
31
+ RESULT_TOKENS = {"1-0", "0-1", "1/2-1/2", "*"}
32
+ MOVE_NUMBER_RE = re.compile(r"^\d+\.(\.\.)?$")
33
+ # Brace-delimited PGN comments like {[%eval 0.37]} and {[%clk 0:05:00]}.
34
+ # Non-greedy to handle multiple comments in one movetext.
35
+ BRACE_COMMENT_RE = re.compile(r"\{[^}]*\}")
36
+
37
+
38
+ def normalize_cp(centipawns: int) -> float:
39
+ """Map centipawn score to [-1, 1] using tanh scaling."""
40
+ return math.tanh(centipawns / 400.0)
41
+
42
+
43
+ def material_eval(board: chess.Board) -> float:
44
+ """Material-count evaluation as a fallback for Stockfish."""
45
+ return DummyRewardModel()(board)
46
+
47
+
48
+ class StockfishEvaluator:
49
+ """Wraps a persistent Stockfish engine for batch evaluation."""
50
+ def __init__(self, engine_path: str = STOCKFISH_PATH, depth: int = 15):
51
+ self.engine = chess.engine.SimpleEngine.popen_uci(engine_path)
52
+ self.depth = depth
53
+
54
+ def __call__(self, board: chess.Board) -> float:
55
+ info = self.engine.analyse(board, chess.engine.Limit(depth=self.depth))
56
+ score = info["score"].white()
57
+ if score.is_mate():
58
+ return 1.0 if score.mate() > 0 else -1.0
59
+ return normalize_cp(score.score())
60
+
61
+ def close(self):
62
+ self.engine.quit()
63
+
64
+
65
+ def parse_movetext(movetext: str) -> list[str]:
66
+ """Parse PGN movetext into a list of SAN moves.
67
+
68
+ Handles Lichess-style annotations like '{[%eval 0.37]}' and '{[%clk 0:05:00]}'
69
+ by stripping them before tokenization. Without this, annotated games get
70
+ truncated mid-replay when parse_san chokes on comment fragments.
71
+
72
+ Input format: '1. d4 {[%eval 0.13]} d5 2. Nf3 ... 1-0'
73
+ Returns: ['d4', 'd5', 'Nf3', ...]
74
+ """
75
+ cleaned = BRACE_COMMENT_RE.sub(" ", movetext)
76
+ tokens = cleaned.split()
77
+ moves = []
78
+ for tok in tokens:
79
+ if tok in RESULT_TOKENS:
80
+ continue
81
+ if MOVE_NUMBER_RE.match(tok):
82
+ continue
83
+ moves.append(tok)
84
+ return moves
85
+
86
+
87
+ def load_filtered_dataset(min_elo: int = 1500, min_rows: int = 100_000):
88
+ """Load and filter the Lichess HuggingFace dataset.
89
+
90
+ Filters:
91
+ - WhiteElo >= min_elo AND BlackElo >= min_elo
92
+ - Termination == 'Normal'
93
+
94
+ Returns the filtered dataset and raises ValueError if too few rows.
95
+ """
96
+ print("Loading Lichess dataset from HuggingFace...")
97
+ ds = load_dataset("Lichess/standard-chess-games", split="train", streaming=True)
98
+
99
+ print(f"Filtering for Elo >= {min_elo} and Termination == 'Normal'...")
100
+ ds_filtered = ds.filter(
101
+ lambda row: (
102
+ row["WhiteElo"] is not None
103
+ and row["BlackElo"] is not None
104
+ and row["WhiteElo"] >= min_elo
105
+ and row["BlackElo"] >= min_elo
106
+ and row.get("Termination") == "Normal"
107
+ )
108
+ )
109
+
110
+ # Materialize enough rows to validate the threshold
111
+ print(f"Collecting at least {min_rows:,} filtered games...")
112
+ rows = []
113
+ for row in ds_filtered:
114
+ rows.append(row)
115
+ if len(rows) % 50_000 == 0:
116
+ print(f" collected {len(rows):,} games so far...")
117
+ if len(rows) >= min_rows:
118
+ break
119
+
120
+ if len(rows) < min_rows:
121
+ raise ValueError(
122
+ f"Only found {len(rows):,} games matching filters, "
123
+ f"need at least {min_rows:,}."
124
+ )
125
+
126
+ print(f"Collected {len(rows):,} games (target met).")
127
+ return rows
128
+
129
+
130
+ def _enumerate_all_uci_moves() -> list[str]:
131
+ """Enumerate every UCI move string that can legally appear in a chess game.
132
+
133
+ Uses direct geometric enumeration rather than board simulation to avoid
134
+ edge cases where king placement blocks valid destination squares.
135
+ Covers all piece movement patterns: lines (rook/queen), diagonals
136
+ (bishop/queen), L-shapes (knight), and pawn promotions.
137
+ """
138
+ seen: set[str] = set()
139
+ for from_sq in chess.SQUARES:
140
+ fr = chess.square_rank(from_sq)
141
+ ff = chess.square_file(from_sq)
142
+ for to_sq in chess.SQUARES:
143
+ if from_sq == to_sq:
144
+ continue
145
+ tr = chess.square_rank(to_sq)
146
+ tf = chess.square_file(to_sq)
147
+ dr = abs(tr - fr)
148
+ df = abs(tf - ff)
149
+
150
+ is_line = (dr == 0 or df == 0) # rook / queen
151
+ is_diag = (dr == df) # bishop / queen
152
+ is_knight = (dr == 2 and df == 1) or (dr == 1 and df == 2)
153
+
154
+ if not (is_line or is_diag or is_knight):
155
+ continue
156
+
157
+ seen.add(chess.Move(from_sq, to_sq).uci())
158
+
159
+ # Promotion variants: pawn on 7th rank advancing to 8th (or 2nd→1st)
160
+ if ((fr == 6 and tr == 7) or (fr == 1 and tr == 0)) and df <= 1:
161
+ for promo in (chess.QUEEN, chess.ROOK, chess.BISHOP, chess.KNIGHT):
162
+ seen.add(chess.Move(from_sq, to_sq, promotion=promo).uci())
163
+
164
+ return list(seen)
165
+
166
+
167
+ def _weighted_sample(eligible: list[int], k: int, skew_exponent: float, seed: int) -> set[int]:
168
+ """Sample k positions from eligible without replacement, skewed toward later positions.
169
+
170
+ Weights grow as (position_rank + 1)^skew_exponent so later positions in a game
171
+ are proportionally more likely to be selected. skew_exponent=1.0 gives linear
172
+ weighting; higher values concentrate more mass at the end of the game.
173
+ """
174
+ n = len(eligible)
175
+ k = min(k, n)
176
+ if k == n:
177
+ return set(eligible)
178
+ weights = np.array([(i + 1) ** skew_exponent for i in range(n)], dtype=np.float64)
179
+ weights /= weights.sum()
180
+ rng = np.random.default_rng(seed)
181
+ chosen = rng.choice(n, size=k, replace=False, p=weights)
182
+ return {eligible[i] for i in chosen}
183
+
184
+
185
+ def build_tokenizer_from_games(games: list[dict] | None = None) -> Tokenizer:
186
+ """Build a move-level tokenizer covering all 1968 UCI moves."""
187
+ uci_moves = _enumerate_all_uci_moves()
188
+ print(f" building tokenizer from {len(set(uci_moves)):,} UCI moves (no BPE)")
189
+ tokenizer = Tokenizer()
190
+ tokenizer.train_tokenizer(uci_moves, max_language_size=len(set(uci_moves)))
191
+ tokenizer.add_special_tokens([CLS_TOKEN, PAD_TOKEN])
192
+ return tokenizer
193
+
194
+
195
+ def _load_train_idx(out_dir: Path, name: str, n: int) -> np.ndarray | None:
196
+ """If `{name}_test_indices.npy` exists, return the complement (training-only
197
+ indices into the full memmap). Returns None when no test split is recorded —
198
+ in that case the caller indexes into the memmap directly.
199
+
200
+ Names ending in '_test' always return None (the test memmap should not
201
+ exclude itself).
202
+ """
203
+ if name.endswith("_test"):
204
+ return None
205
+ test_idx_file = out_dir / f"{name}_test_indices.npy"
206
+ if not test_idx_file.exists():
207
+ return None
208
+ test_idx = np.load(test_idx_file)
209
+ mask = np.ones(n, dtype=bool)
210
+ mask[test_idx] = False
211
+ return np.where(mask)[0]
212
+
213
+
214
+ class ChessPositionDataset(Dataset):
215
+ def __init__(
216
+ self,
217
+ games: list[dict],
218
+ tokenizer: Tokenizer,
219
+ eval_fn=material_eval,
220
+ sample_rate: float = 0.25,
221
+ skew_exponent: float = 1.5,
222
+ ):
223
+ self.tokenizer = tokenizer
224
+ self.cls_id = tokenizer.symbol_to_token[CLS_TOKEN]
225
+ self.samples: list[tuple[list[int], float]] = []
226
+ self._memmap = False
227
+ self._train_idx: np.ndarray | None = None
228
+ self._generate_samples(games, eval_fn, sample_rate, skew_exponent)
229
+
230
+ def _generate_samples(self, games, eval_fn, sample_rate, skew_exponent):
231
+ for idx, game in enumerate(games):
232
+ movetext = game.get("movetext", "")
233
+ if not movetext:
234
+ continue
235
+ move_sans = parse_movetext(movetext)
236
+ if len(move_sans) < 2:
237
+ continue
238
+
239
+ board = chess.Board()
240
+ eligible = list(range(len(move_sans)))
241
+ # Scale sample count with game length — longer games have more
242
+ # evaluation swings and contribute proportionally more samples.
243
+ num_positions = max(1, int(len(move_sans) * sample_rate))
244
+ # Deterministic weighted sampling seeded by game index so serial
245
+ # and parallel paths produce identical sample sets for the same input.
246
+ sample_indices = _weighted_sample(eligible, num_positions, skew_exponent, seed=idx)
247
+
248
+ valid_moves = []
249
+ for i, san in enumerate(move_sans):
250
+ try:
251
+ move = board.parse_san(san)
252
+ board.push(move)
253
+ valid_moves.append(move.uci())
254
+ except (chess.InvalidMoveError, chess.AmbiguousMoveError):
255
+ break
256
+
257
+ if i in sample_indices:
258
+ token_ids = [self.cls_id] + self.tokenizer.encode_moves(valid_moves)
259
+ score = eval_fn(board)
260
+ self.samples.append((token_ids, score))
261
+
262
+ if (idx + 1) % 10_000 == 0:
263
+ print(f" processed {idx + 1:,} games, {len(self.samples):,} positions...")
264
+
265
+ def __len__(self) -> int:
266
+ if self._memmap:
267
+ if self._train_idx is not None:
268
+ return len(self._train_idx)
269
+ return len(self._mm_labels)
270
+ return len(self.samples)
271
+
272
+ def __getitem__(self, idx: int):
273
+ if self._memmap:
274
+ if self._train_idx is not None:
275
+ idx = int(self._train_idx[idx])
276
+ tokens = torch.from_numpy(np.array(self._mm_tokens[idx], dtype=np.int32)).long()
277
+ length = int(self._mm_lengths[idx])
278
+ mask = torch.arange(tokens.shape[0]) >= length # True = padded
279
+ return tokens, mask, float(self._mm_labels[idx])
280
+ token_ids, score = self.samples[idx]
281
+ return torch.tensor(token_ids, dtype=torch.long), score
282
+
283
+ @classmethod
284
+ def from_samples(cls, samples, tokenizer: Tokenizer):
285
+ """Build a dataset from pre-generated (token_ids, score) samples."""
286
+ inst = cls.__new__(cls)
287
+ inst.tokenizer = tokenizer
288
+ inst.cls_id = tokenizer.symbol_to_token[CLS_TOKEN]
289
+ inst.samples = list(samples)
290
+ inst._memmap = False
291
+ return inst
292
+
293
+ @classmethod
294
+ def from_file(cls, samples_path: str, tokenizer: Tokenizer):
295
+ """Load (token_ids, score) samples from a torch.save file."""
296
+ samples = torch.load(samples_path, weights_only=False)
297
+ return cls.from_samples(samples, tokenizer)
298
+
299
+ @classmethod
300
+ def from_memmap(cls, out_dir: Path, name: str, tokenizer: Tokenizer):
301
+ """Load pre-padded samples from memory-mapped arrays (fast DataLoader path).
302
+
303
+ If a sibling file `{name}_test_indices.npy` exists, those indices are
304
+ excluded from this dataset — used to make training disjoint from the
305
+ held-out test split that shares the same underlying .bin file.
306
+ """
307
+ meta = torch.load(out_dir / f"{name}_meta.pt", weights_only=True)
308
+ n, max_len = meta["n"], meta["max_len"]
309
+ inst = cls.__new__(cls)
310
+ inst.tokenizer = tokenizer
311
+ inst.cls_id = tokenizer.symbol_to_token[CLS_TOKEN]
312
+ inst._memmap = True
313
+ inst._mm_tokens = np.memmap(out_dir / f"{name}_tokens.bin", dtype=np.int32, mode="r", shape=(n, max_len))
314
+ inst._mm_labels = np.memmap(out_dir / f"{name}_labels.bin", dtype=np.float32, mode="r", shape=(n,))
315
+ inst._mm_lengths = np.memmap(out_dir / f"{name}_lengths.bin", dtype=np.int32, mode="r", shape=(n,))
316
+ inst._train_idx = _load_train_idx(out_dir, name, n)
317
+ return inst
318
+
319
+
320
+ # ---------------------------------------------------------------------------
321
+ # Parallel Stockfish-backed sample generation.
322
+ #
323
+ # One Stockfish subprocess per worker process. Each worker:
324
+ # 1. receives a game + its index (used to seed a local random.Random)
325
+ # 2. replays the game, samples positions, tokenizes move prefixes
326
+ # 3. evaluates each sampled position with its own Stockfish engine
327
+ # 4. returns a list of (token_ids, score) tuples
328
+ #
329
+ # The main process collects results via imap_unordered and flattens them.
330
+ # ---------------------------------------------------------------------------
331
+
332
+ # Module-level state populated by _init_worker in each spawned process.
333
+ _worker_engine = None
334
+ _worker_tokenizer = None
335
+ _worker_cls_id = None
336
+ _worker_sample_rate = None
337
+ _worker_skew = None
338
+ _worker_depth = None
339
+
340
+
341
+ def _shutdown_worker():
342
+ """Called at worker exit to cleanly quit the Stockfish engine."""
343
+ global _worker_engine
344
+ if _worker_engine is not None:
345
+ try:
346
+ _worker_engine.quit()
347
+ except Exception:
348
+ pass
349
+ _worker_engine = None
350
+
351
+
352
+ def _init_worker(engine_path, depth, tokenizer, cls_id, sample_rate, skew_exponent):
353
+ """Pool initializer: create one Stockfish engine per worker.
354
+
355
+ If engine_path is None, workers fall back to material_eval. This lets
356
+ tests exercise the parallel machinery without requiring Stockfish.
357
+ """
358
+ global _worker_engine, _worker_tokenizer, _worker_cls_id
359
+ global _worker_sample_rate, _worker_skew, _worker_depth
360
+ _worker_tokenizer = tokenizer
361
+ _worker_cls_id = cls_id
362
+ _worker_sample_rate = sample_rate
363
+ _worker_skew = skew_exponent
364
+ _worker_depth = depth
365
+ if engine_path is not None:
366
+ _worker_engine = chess.engine.SimpleEngine.popen_uci(engine_path)
367
+ atexit.register(_shutdown_worker)
368
+ else:
369
+ _worker_engine = None
370
+
371
+
372
+ def _worker_eval(board: chess.Board) -> float:
373
+ if _worker_engine is None:
374
+ return material_eval(board)
375
+ info = _worker_engine.analyse(board, chess.engine.Limit(depth=_worker_depth))
376
+ score = info["score"].white()
377
+ if score.is_mate():
378
+ return 1.0 if score.mate() > 0 else -1.0
379
+ return normalize_cp(score.score())
380
+
381
+
382
+ def _process_game(game_with_seed):
383
+ """Worker task: parse, replay, sample, tokenize, and evaluate one game."""
384
+ game, seed = game_with_seed
385
+ movetext = game.get("movetext", "")
386
+ if not movetext:
387
+ return []
388
+ move_sans = parse_movetext(movetext)
389
+ if len(move_sans) < 2:
390
+ return []
391
+
392
+ eligible = list(range(len(move_sans)))
393
+ num_positions = max(1, int(len(move_sans) * _worker_sample_rate))
394
+ sample_indices = _weighted_sample(eligible, num_positions, _worker_skew, seed=seed)
395
+
396
+ samples = []
397
+ board = chess.Board()
398
+ valid_moves = []
399
+ for i, san in enumerate(move_sans):
400
+ try:
401
+ move = board.parse_san(san)
402
+ board.push(move)
403
+ valid_moves.append(move.uci())
404
+ except (chess.InvalidMoveError, chess.AmbiguousMoveError):
405
+ break
406
+
407
+ if i in sample_indices:
408
+ token_ids = [_worker_cls_id] + _worker_tokenizer.encode_moves(valid_moves)
409
+ score = _worker_eval(board)
410
+ samples.append((token_ids, score))
411
+
412
+ return samples
413
+
414
+
415
+ def generate_samples_stockfish_parallel(
416
+ games: list[dict],
417
+ tokenizer: Tokenizer,
418
+ num_workers: int = 8,
419
+ stockfish_depth: int = 12,
420
+ sample_rate: float = 0.25,
421
+ skew_exponent: float = 1.5,
422
+ engine_path: str | None = STOCKFISH_PATH,
423
+ chunksize: int = 8,
424
+ progress_every: int = 1000,
425
+ ) -> list[tuple[list[int], float]]:
426
+ """Parallel Stockfish-backed sample generation.
427
+
428
+ Spawns `num_workers` processes, each owning one Stockfish subprocess.
429
+ If `engine_path` is None, workers use material_eval instead of Stockfish
430
+ (used by tests to verify the parallel machinery without the binary).
431
+
432
+ Each game contributes `max(1, game_length * sample_rate)` positions,
433
+ weighted toward mid/late game by `skew_exponent`. Sampling is seeded
434
+ per-game-index for determinism across runs and worker counts.
435
+ """
436
+ cls_id = tokenizer.symbol_to_token[CLS_TOKEN]
437
+ tasks = [(game, idx) for idx, game in enumerate(games)]
438
+
439
+ # spawn context: safest across macOS/Linux and avoids fork-safety issues
440
+ # with chess.engine's subprocess.
441
+ ctx = mp.get_context("spawn")
442
+ samples: list[tuple[list[int], float]] = []
443
+ with ctx.Pool(
444
+ processes=num_workers,
445
+ initializer=_init_worker,
446
+ initargs=(engine_path, stockfish_depth, tokenizer, cls_id, sample_rate, skew_exponent),
447
+ ) as pool:
448
+ for i, game_samples in enumerate(
449
+ pool.imap_unordered(_process_game, tasks, chunksize=chunksize)
450
+ ):
451
+ samples.extend(game_samples)
452
+ if progress_every and (i + 1) % progress_every == 0:
453
+ print(
454
+ f" processed {i + 1:,}/{len(games):,} games, "
455
+ f"{len(samples):,} positions..."
456
+ )
457
+
458
+ return samples
459
+
460
+
461
+ def collate_fn(batch):
462
+ """Pad token sequences and create attention mask."""
463
+ tokens, labels = zip(*batch)
464
+ max_len = max(len(t) for t in tokens)
465
+ padded = torch.zeros(len(tokens), max_len, dtype=torch.long)
466
+ attention_mask = torch.ones(len(tokens), max_len, dtype=torch.bool) # True = masked
467
+
468
+ for i, t in enumerate(tokens):
469
+ padded[i, :len(t)] = t
470
+ attention_mask[i, :len(t)] = False
471
+
472
+ labels_tensor = torch.tensor(labels, dtype=torch.float)
473
+ return padded, attention_mask, labels_tensor
474
+
475
+
476
+ def collate_fn_memmap(batch):
477
+ """Collate pre-padded memmap samples — just stack, no per-batch padding needed."""
478
+ tokens, masks, labels = zip(*batch)
479
+ return torch.stack(tokens), torch.stack(masks), torch.tensor(labels, dtype=torch.float)
480
+
481
+
482
+ def collate_fn_policy(batch):
483
+ """Pad token sequences and per-position board planes for policy training.
484
+
485
+ Each batch element is (tokens, planes, weight, source_tag) where
486
+ `tokens` is shape (L,) long and `planes` is shape (L, 19, 8, 8) float —
487
+ one set of board planes per position in the sequence. We pad both
488
+ along the sequence dimension to the batch's max length. Padded
489
+ positions get zero token, zero planes, and mask=True; downstream
490
+ loss masking ignores them.
491
+
492
+ Returns (padded_tokens, attention_mask, planes, weights, sources).
493
+ """
494
+ tokens_list, planes_list, weights_list, sources_list = zip(*batch)
495
+ B = len(tokens_list)
496
+ max_len = max(len(t) for t in tokens_list)
497
+ padded = torch.zeros(B, max_len, dtype=torch.long)
498
+ mask = torch.ones(B, max_len, dtype=torch.bool) # True = padded
499
+ planes = torch.zeros(B, max_len, 19, 8, 8)
500
+ for i, (t, p) in enumerate(zip(tokens_list, planes_list)):
501
+ L = len(t)
502
+ padded[i, :L] = t
503
+ mask[i, :L] = False
504
+ planes[i, :L] = p
505
+ weights = torch.tensor(weights_list, dtype=torch.float)
506
+ sources = torch.tensor(sources_list, dtype=torch.long)
507
+ return padded, mask, planes, weights, sources
508
+
509
+
510
+ class MixedBatchSampler(torch.utils.data.Sampler):
511
+ """Hard-balanced sampler over a ConcatDataset([games, puzzles]).
512
+
513
+ Each batch contains exactly `n_game_per_batch` game indices (drawn from
514
+ [0, n_game)) and `n_puzzle_per_batch` puzzle indices (drawn from
515
+ [n_game, n_game + n_puzzle)). Both pools are shuffled and consumed in
516
+ parallel; when the smaller (puzzle) pool runs out it gets re-shuffled,
517
+ so puzzles are effectively oversampled to match the game stream.
518
+
519
+ This guarantees a consistent gradient signal per batch and prevents the
520
+ puzzle samples from being statistical outliers under BatchNorm (already
521
+ moot now that the CNN uses GroupNorm, but still matters for loss-level
522
+ balance).
523
+ """
524
+ def __init__(
525
+ self,
526
+ n_game: int,
527
+ n_puzzle: int,
528
+ batch_size: int,
529
+ game_ratio: float = 0.8,
530
+ drop_last: bool = True,
531
+ ):
532
+ self.n_game = n_game
533
+ self.n_puzzle = n_puzzle
534
+ self.batch_size = batch_size
535
+ self.n_game_per_batch = max(1, int(round(batch_size * game_ratio)))
536
+ self.n_puzzle_per_batch = batch_size - self.n_game_per_batch
537
+ self.drop_last = drop_last
538
+
539
+ def __iter__(self):
540
+ game_perm = torch.randperm(self.n_game).tolist()
541
+ puzzle_perm = torch.randperm(self.n_puzzle).tolist() if self.n_puzzle > 0 else []
542
+ gi, pi = 0, 0
543
+ for _ in range(len(self)):
544
+ if gi + self.n_game_per_batch > self.n_game:
545
+ game_perm = torch.randperm(self.n_game).tolist()
546
+ gi = 0
547
+ if self.n_puzzle_per_batch > 0 and pi + self.n_puzzle_per_batch > self.n_puzzle:
548
+ puzzle_perm = torch.randperm(self.n_puzzle).tolist()
549
+ pi = 0
550
+ batch = []
551
+ for _ in range(self.n_game_per_batch):
552
+ batch.append(game_perm[gi]); gi += 1
553
+ for _ in range(self.n_puzzle_per_batch):
554
+ batch.append(self.n_game + puzzle_perm[pi]); pi += 1
555
+ yield batch
556
+
557
+ def __len__(self):
558
+ # One pass over the (more numerous) game pool defines an epoch.
559
+ return self.n_game // self.n_game_per_batch
560
+
561
+
562
+ class ChessPolicyDataset(Dataset):
563
+ """Full game sequences for next-move prediction training.
564
+
565
+ Each sample yields (token_ids, board_planes, weight, source_tag):
566
+
567
+ - token_ids: full tokenized sequence [CLS, m1, m2, ..., mN]
568
+ - board_planes: (L, 19, 8, 8) tensor of per-position planes built by
569
+ replaying the move sequence. planes[0] is the starting
570
+ board (the standard chess start for games, the puzzle
571
+ FEN for puzzles); planes[t] is the board state after
572
+ token_ids[1..t] have been played. This is the
573
+ information-leak-safe per-position anchor that lets the
574
+ model cross-attend to the live board at every step.
575
+ - weight: per-sample loss weight (1.0 for games, default 5.0 for
576
+ puzzles) so puzzle samples have outsized gradient pull.
577
+ - source_tag: 0 = game, 1 = puzzle. Used by the mixed training loop to
578
+ mask the setup-move target on puzzle samples.
579
+ """
580
+ def __init__(self, games: list[dict], tokenizer: Tokenizer, max_seq_len: int = 128):
581
+ cls_id = tokenizer.symbol_to_token[CLS_TOKEN]
582
+ self.tokenizer = tokenizer
583
+ self._memmap = False
584
+ self._train_idx: np.ndarray | None = None
585
+ self._mm_fens = None
586
+ self._fen_len = None
587
+ self.source_tag: int = 0
588
+ self.loss_weight: float = 1.0
589
+ self.samples: list[list[int]] = []
590
+ for game in games:
591
+ movetext = game.get("movetext", "")
592
+ if not movetext:
593
+ continue
594
+ move_sans = parse_movetext(movetext)
595
+ if len(move_sans) < 2:
596
+ continue
597
+ board = chess.Board()
598
+ move_ucis: list[str] = []
599
+ for san in move_sans:
600
+ try:
601
+ move = board.parse_san(san)
602
+ board.push(move)
603
+ move_ucis.append(move.uci())
604
+ except (chess.InvalidMoveError, chess.AmbiguousMoveError):
605
+ break
606
+ if len(move_ucis) < 2:
607
+ continue
608
+ move_ucis = move_ucis[:max_seq_len - 1] # reserve slot for CLS
609
+ self.samples.append([cls_id] + tokenizer.encode_moves(move_ucis))
610
+
611
+ def _get_start_board(self, idx: int) -> chess.Board:
612
+ """Resolve the starting board for the per-position replay.
613
+
614
+ Puzzles with a `{name}_fens.bin` sidecar use the puzzle's FEN.
615
+ Everything else (games, puzzles without FENs) starts from the
616
+ standard chess starting position. A corrupt FEN silently falls
617
+ back to the starting position so the loader doesn't crash.
618
+ """
619
+ if self._memmap and self._mm_fens is not None:
620
+ fen_bytes = bytes(self._mm_fens[idx])
621
+ fen_str = fen_bytes.rstrip(b"\x00").decode("ascii")
622
+ try:
623
+ return chess.Board(fen_str)
624
+ except ValueError:
625
+ return chess.Board()
626
+ return chess.Board()
627
+
628
+ def __len__(self) -> int:
629
+ if self._memmap:
630
+ if self._train_idx is not None:
631
+ return len(self._train_idx)
632
+ return len(self._mm_lengths)
633
+ return len(self.samples)
634
+
635
+ def __getitem__(self, idx: int):
636
+ if self._memmap:
637
+ if self._train_idx is not None:
638
+ idx = int(self._train_idx[idx])
639
+ length = int(self._mm_lengths[idx])
640
+ tokens = torch.from_numpy(np.array(self._mm_tokens[idx, :length], dtype=np.int32)).long()
641
+ else:
642
+ tokens = torch.tensor(self.samples[idx], dtype=torch.long)
643
+ start_board = self._get_start_board(idx)
644
+ planes = self._replay_planes(tokens.tolist(), start_board)
645
+ return tokens, planes, self.loss_weight, self.source_tag
646
+
647
+ @classmethod
648
+ def from_memmap(
649
+ cls,
650
+ out_dir: Path,
651
+ tokenizer: Tokenizer,
652
+ name: str = "policy",
653
+ source_tag: int = 0,
654
+ loss_weight: float = 1.0,
655
+ ):
656
+ """Load pre-tokenized policy sequences from memory-mapped arrays.
657
+
658
+ Args:
659
+ name: filename prefix; use 'puzzle' to load puzzle_*.bin files.
660
+ source_tag: 0 for game data, 1 for puzzle data (drives setup-move
661
+ masking in the mixed training loop).
662
+ loss_weight: per-sample weight applied to this dataset's samples in
663
+ the weighted cross-entropy loss.
664
+
665
+ If a sibling file `{name}_fens.bin` exists, FENs are loaded and used
666
+ to reconstruct each sample's starting-board planes. Otherwise the
667
+ standard chess starting position is used.
668
+
669
+ If `{name}_test_indices.npy` exists, those indices are excluded from
670
+ this dataset — used to make training disjoint from the held-out test
671
+ split that shares the same underlying .bin file.
672
+ """
673
+ meta = torch.load(out_dir / f"{name}_meta.pt", weights_only=True)
674
+ n, max_len = meta["n"], meta["max_len"]
675
+ inst = cls.__new__(cls)
676
+ inst._memmap = True
677
+ inst._mm_tokens = np.memmap(out_dir / f"{name}_tokens.bin", dtype=np.int32, mode="r", shape=(n, max_len))
678
+ inst._mm_lengths = np.memmap(out_dir / f"{name}_lengths.bin", dtype=np.int32, mode="r", shape=(n,))
679
+ inst._train_idx = _load_train_idx(out_dir, name, n)
680
+
681
+ fen_path = out_dir / f"{name}_fens.bin"
682
+ if fen_path.exists() and "fen_len" in meta:
683
+ inst._mm_fens = np.memmap(fen_path, dtype=np.uint8, mode="r", shape=(n, meta["fen_len"]))
684
+ inst._fen_len = meta["fen_len"]
685
+ else:
686
+ inst._mm_fens = None
687
+ inst._fen_len = None
688
+ if source_tag == 1:
689
+ # Puzzle data without FENs: CNN will see the standard starting
690
+ # position for every puzzle, which is wrong. Loud warning.
691
+ print(
692
+ f"WARNING: {name}_fens.bin not found — puzzle samples will "
693
+ f"feed the starting-position planes to the CNN, defeating "
694
+ f"the point of puzzle conditioning. Rebuild with the "
695
+ f"updated build_datasets.py to fix."
696
+ )
697
+
698
+ inst.tokenizer = tokenizer
699
+ inst.source_tag = source_tag
700
+ inst.loss_weight = loss_weight
701
+ return inst
702
+ def _replay_planes(self, token_ids: list[int], start_board: chess.Board) -> torch.Tensor:
703
+ """Returns (L, 19, 8, 8) tensor of board planes per position.
704
+
705
+ plane_t = state of the board after token_ids[1..t] have been played.
706
+ plane_0 = start_board (the model has only seen [CLS] at that point).
707
+
708
+ If a token in the sequence isn't a parseable UCI move (corrupt
709
+ data, non-move special token mid-stream), we freeze planes at the
710
+ last valid state and return. The loss already masks padded targets,
711
+ so the worst case is a few positions with stale board input rather
712
+ than a crashed worker.
713
+ """
714
+ L = len(token_ids)
715
+ planes = torch.zeros(L, 19, 8, 8)
716
+ board = start_board.copy()
717
+ planes[0] = board_to_planes(board)
718
+ for t in range(1, L):
719
+ uci = self.tokenizer.token_to_symbol[int(token_ids[t])]
720
+ try:
721
+ board.push(chess.Move.from_uci(uci))
722
+ except (chess.InvalidMoveError, ValueError):
723
+ planes[t:] = planes[t - 1]
724
+ return planes
725
+ planes[t] = board_to_planes(board)
726
+ return planes
727
+
728
+ def _fmt_duration(seconds: float) -> str:
729
+ h, m = divmod(int(seconds), 3600)
730
+ m, s = divmod(m, 60)
731
+ return f"{h}h {m:02d}m {s:02d}s" if h else f"{m}m {s:02d}s"
732
+
733
+
734
+ def _amp_ctx(device):
735
+ """BF16 autocast on CUDA, no-op elsewhere.
736
+
737
+ BF16 is preferred over FP16 here: same dynamic range as FP32 (no GradScaler
738
+ needed) and full tensor-core acceleration on Ampere+ / Ada / Blackwell.
739
+ On Blackwell (RTX PRO 6000 / B200) this typically gives 2-3x training speedup
740
+ on transformer matmuls.
741
+ """
742
+ dev = device if isinstance(device, str) else getattr(device, "type", "cpu")
743
+ if dev == "cuda":
744
+ return torch.autocast(device_type="cuda", dtype=torch.bfloat16)
745
+ return contextlib.nullcontext()
746
+
747
+
748
+ def _run_epoch_reward(model, loader, optimizer, device, writer, global_step, epoch_idx):
749
+ """Single training epoch: MSE against Stockfish labels."""
750
+ model.train()
751
+ total_loss = 0.0
752
+ n_batches = len(loader)
753
+ log_every = max(1, n_batches // 20)
754
+ epoch_start = time.time()
755
+
756
+ for i, (batch_tokens, batch_mask, batch_labels) in enumerate(loader):
757
+ batch_tokens = batch_tokens.to(device)
758
+ batch_mask = batch_mask.to(device)
759
+ batch_labels = batch_labels.to(device)
760
+
761
+ with _amp_ctx(device):
762
+ predictions = model(batch_tokens, attention_mask=batch_mask)
763
+ loss = F.mse_loss(predictions, batch_labels)
764
+
765
+ optimizer.zero_grad()
766
+ loss.backward()
767
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
768
+ optimizer.step()
769
+ total_loss += loss.item()
770
+
771
+ writer.add_scalar("train/reward_batch_loss", loss.item(), global_step)
772
+ global_step += 1
773
+
774
+ if (i + 1) % log_every == 0 or (i + 1) == n_batches:
775
+ elapsed = time.time() - epoch_start
776
+ batches_done = i + 1
777
+ eta = elapsed / batches_done * (n_batches - batches_done)
778
+ samples_per_sec = batches_done * batch_tokens.size(0) / elapsed
779
+ avg_so_far = total_loss / batches_done
780
+ print(
781
+ f" batch {batches_done:,}/{n_batches:,} "
782
+ f"loss={avg_so_far:.4f} "
783
+ f"{samples_per_sec:,.0f} samples/s "
784
+ f"eta {_fmt_duration(eta)}"
785
+ )
786
+
787
+ epoch_elapsed = time.time() - epoch_start
788
+ avg = total_loss / n_batches
789
+ writer.add_scalar("train/reward_epoch_loss", avg, epoch_idx)
790
+ return avg, global_step, epoch_elapsed
791
+
792
+ def _run_epoch_policy_mixed(
793
+ model, loader, optimizer, device, writer, global_step, epoch_idx, pad_id,
794
+ ):
795
+ """Single training epoch over mixed game + puzzle batches.
796
+
797
+ Loader yields (tokens, mask, planes, weights, sources). For each batch:
798
+
799
+ 1. CNN-conditioned forward pass: position-0 embedding is replaced by the
800
+ CNN's encoding of `planes` (starting board of the sequence).
801
+ 2. Per-position cross-entropy at every non-padded target position.
802
+ 3. Setup-move target is masked out for puzzle rows (source==1): the setup
803
+ move is given as context, not a prediction target.
804
+ 4. Per-sample loss weight upweights puzzle samples (default 5x via the
805
+ dataset's loss_weight field) — implemented as a position-weighted mean.
806
+ """
807
+ model.train()
808
+ total_loss = 0.0
809
+ n_batches = len(loader)
810
+ log_every = max(1, n_batches // 20)
811
+ epoch_start = time.time()
812
+
813
+ for i, (batch_tokens, batch_mask, batch_planes, batch_weights, batch_sources) in enumerate(loader):
814
+ batch_tokens = batch_tokens.to(device, non_blocking=True)
815
+ batch_mask = batch_mask.to(device, non_blocking=True)
816
+ batch_planes = batch_planes.to(device, non_blocking=True)
817
+ batch_weights = batch_weights.to(device, non_blocking=True)
818
+ batch_sources = batch_sources.to(device, non_blocking=True)
819
+
820
+ input_tokens = batch_tokens[:, :-1]
821
+ input_mask = batch_mask[:, :-1]
822
+ input_planes = batch_planes[:, :-1] # planes are per-position; slice with tokens
823
+ targets = batch_tokens[:, 1:].contiguous()
824
+
825
+ # Mask the setup-move target (position 0 of the shifted target) for
826
+ # puzzle rows — it's the opponent's forcing move given as context.
827
+ is_puzzle = (batch_sources == 1)
828
+ if is_puzzle.any():
829
+ targets = targets.clone()
830
+ targets[is_puzzle, 0] = pad_id
831
+
832
+ with _amp_ctx(device):
833
+ logits = model(input_tokens, input_planes, attention_mask=input_mask)
834
+ B, T, V = logits.shape
835
+ ce = F.cross_entropy(
836
+ logits.reshape(-1, V),
837
+ targets.reshape(-1),
838
+ ignore_index=pad_id,
839
+ reduction="none",
840
+ ).reshape(B, T)
841
+ position_mask = (targets != pad_id).float()
842
+ sample_weights = batch_weights.unsqueeze(1)
843
+ weighted = ce * position_mask * sample_weights
844
+ denom = (position_mask * sample_weights).sum().clamp(min=1.0)
845
+ loss = weighted.sum() / denom
846
+
847
+ optimizer.zero_grad()
848
+ loss.backward()
849
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
850
+ optimizer.step()
851
+ total_loss += loss.item()
852
+
853
+ writer.add_scalar("train_policy/batch_loss", loss.item(), global_step)
854
+ global_step += 1
855
+
856
+ if (i + 1) % log_every == 0 or (i + 1) == n_batches:
857
+ elapsed = time.time() - epoch_start
858
+ batches_done = i + 1
859
+ eta = elapsed / batches_done * (n_batches - batches_done)
860
+ samples_per_sec = batches_done * batch_tokens.size(0) / elapsed
861
+ avg_so_far = total_loss / batches_done
862
+ print(
863
+ f" batch {batches_done:,}/{n_batches:,} "
864
+ f"loss={avg_so_far:.4f} "
865
+ f"{samples_per_sec:,.0f} samples/s "
866
+ f"eta {_fmt_duration(eta)}"
867
+ )
868
+
869
+ epoch_elapsed = time.time() - epoch_start
870
+ avg = total_loss / max(n_batches, 1)
871
+ writer.add_scalar("train_policy/epoch_loss", avg, epoch_idx)
872
+ return avg, global_step, epoch_elapsed
873
+
874
+
875
+ def eval_reward(model, loader, device) -> dict:
876
+ """Evaluate reward model on a test loader. Returns MSE, MAE, and Pearson r."""
877
+ model.eval()
878
+ all_preds, all_labels = [], []
879
+ with torch.no_grad(), _amp_ctx(device):
880
+ for batch_tokens, batch_mask, batch_labels in loader:
881
+ preds = model(batch_tokens.to(device), attention_mask=batch_mask.to(device))
882
+ all_preds.append(preds.float().cpu())
883
+ all_labels.append(batch_labels)
884
+ preds = torch.cat(all_preds)
885
+ labels = torch.cat(all_labels)
886
+ mse = F.mse_loss(preds, labels).item()
887
+ mae = (preds - labels).abs().mean().item()
888
+ # Pearson r
889
+ p_centered = preds - preds.mean()
890
+ l_centered = labels - labels.mean()
891
+ denom = (p_centered.norm() * l_centered.norm()).clamp(min=1e-8)
892
+ pearson_r = (p_centered * l_centered).sum() / denom
893
+ return {"mse": mse, "mae": mae, "pearson_r": pearson_r.item()}
894
+
895
+
896
+ def eval_policy(model, loader, device, pad_id: int) -> dict:
897
+ """Evaluate policy model on a test loader. Returns loss, perplexity, top-1/top-5 acc.
898
+
899
+ Loader yields (tokens, mask, planes, weights, sources). Weights and sources
900
+ are ignored here — eval is uniform across samples.
901
+ """
902
+ model.eval()
903
+ total_loss = 0.0
904
+ total_correct1 = 0
905
+ total_correct5 = 0
906
+ total_positions = 0
907
+ with torch.no_grad(), _amp_ctx(device):
908
+ for batch_tokens, batch_mask, batch_planes, _, _ in loader:
909
+ batch_tokens = batch_tokens.to(device)
910
+ batch_mask = batch_mask.to(device)
911
+ batch_planes = batch_planes.to(device)
912
+ input_tokens = batch_tokens[:, :-1]
913
+ input_mask = batch_mask[:, :-1]
914
+ input_planes = batch_planes[:, :-1]
915
+ targets = batch_tokens[:, 1:].contiguous()
916
+ logits = model(input_tokens, input_planes, attention_mask=input_mask)
917
+ flat_logits = logits.reshape(-1, logits.size(-1))
918
+ flat_targets = targets.reshape(-1)
919
+ valid = flat_targets != pad_id
920
+ total_loss += F.cross_entropy(flat_logits, flat_targets, ignore_index=pad_id, reduction="sum").item()
921
+ total_positions += valid.sum().item()
922
+ top5 = flat_logits[valid].topk(5, dim=-1).indices
923
+ valid_targets = flat_targets[valid]
924
+ total_correct1 += (top5[:, 0] == valid_targets).sum().item()
925
+ total_correct5 += (top5 == valid_targets.unsqueeze(1)).any(dim=1).sum().item()
926
+ avg_loss = total_loss / max(total_positions, 1)
927
+ return {
928
+ "loss": avg_loss,
929
+ "perplexity": math.exp(min(avg_loss, 20)),
930
+ "top1_acc": total_correct1 / max(total_positions, 1),
931
+ "top5_acc": total_correct5 / max(total_positions, 1),
932
+ }
933
+
934
+
935
+ def eval_puzzle_solve_rate(model, loader, device, pad_id: int) -> dict:
936
+ """Evaluate puzzle solve rate: % of solver positions where model's top-1 matches
937
+ ground truth. Sequence layout: [CLS, setup, solver1, opp1, solver2, ...]
938
+
939
+ Solver moves are at token positions 2, 4, 6, ... (logit positions 1, 3, 5, ...).
940
+ The setup move at token position 1 (logit 0) is excluded — it's context, not a
941
+ prediction target. Also reports first-move solve rate (logit position 1 only).
942
+ """
943
+ model.eval()
944
+ first_correct = 0
945
+ first_total = 0
946
+ all_correct = 0
947
+ all_total = 0
948
+ with torch.no_grad(), _amp_ctx(device):
949
+ for batch_tokens, batch_mask, batch_planes, _, _ in loader:
950
+ batch_tokens = batch_tokens.to(device)
951
+ batch_mask = batch_mask.to(device)
952
+ batch_planes = batch_planes.to(device)
953
+ input_tokens = batch_tokens[:, :-1]
954
+ input_mask = batch_mask[:, :-1]
955
+ input_planes = batch_planes[:, :-1]
956
+ logits = model(input_tokens, input_planes, attention_mask=input_mask)
957
+ seq_len = batch_tokens.size(1)
958
+ # Solver logit positions: 1, 3, 5, ... → target positions: 2, 4, 6, ...
959
+ for solver_logit_pos in range(1, seq_len - 1, 2):
960
+ solver_token_pos = solver_logit_pos + 1
961
+ if solver_token_pos >= seq_len:
962
+ break
963
+ targets = batch_tokens[:, solver_token_pos]
964
+ valid = targets != pad_id
965
+ if not valid.any():
966
+ continue
967
+ preds = logits[:, solver_logit_pos].argmax(dim=-1)
968
+ correct = (preds[valid] == targets[valid]).sum().item()
969
+ n_valid = valid.sum().item()
970
+ all_correct += correct
971
+ all_total += n_valid
972
+ if solver_logit_pos == 1:
973
+ first_correct += correct
974
+ first_total += n_valid
975
+ return {
976
+ "first_move_solve_rate": first_correct / max(first_total, 1),
977
+ "all_moves_solve_rate": all_correct / max(all_total, 1),
978
+ }
979
+
980
+
981
+ def train(
982
+ tokenizer_path,
983
+ stockfish_samples_path,
984
+ outcome_games_path,
985
+ epochs,
986
+ policy_epochs,
987
+ batch_size,
988
+ learning_rate,
989
+ max_seq_len,
990
+ log_dir,
991
+ num_workers,
992
+ puzzle_data_dir=None,
993
+ puzzle_epochs=5, # kept for CLI compat; no longer used (mixed training merges phases)
994
+ puzzle_loss_weight=5.0,
995
+ puzzle_ratio=0.2,
996
+ skip_reward=False,
997
+ keep_last_n_checkpoints=3,
998
+ ):
999
+ """Train the reward model then the policy model.
1000
+
1001
+ Phase 1: MSE on Stockfish-labeled positions (reward model).
1002
+ Phase 2: Mixed game + puzzle policy training. Each batch is hard-balanced
1003
+ at `puzzle_ratio` (default 20% puzzle) and puzzle samples carry a
1004
+ `puzzle_loss_weight` (default 5x) in the weighted cross-entropy loss.
1005
+ Games feed the CNN the standard chess starting board (constant signal,
1006
+ effectively a no-op); puzzles feed the FEN-derived board.
1007
+
1008
+ If `skip_reward` is True, Phase 1 is skipped entirely — the reward dataset
1009
+ is not loaded, no reward model is created, and `reward_model.pt` on disk is
1010
+ untouched. Use this for iterating on Phase 2 without burning hours on a
1011
+ Phase 1 that hasn't changed.
1012
+
1013
+ Requires stockfish memmap files, outcome games, and (for mixed training)
1014
+ puzzle memmaps with FENs built by src/build_datasets.py.
1015
+ """
1016
+ device = "cuda" if torch.cuda.is_available() else "cpu"
1017
+ amp_dtype = "bfloat16 autocast" if device == "cuda" else "fp32 (CPU)"
1018
+ print(f"Using device: {device} ({amp_dtype})")
1019
+
1020
+ print(f"Loading tokenizer from {tokenizer_path}...")
1021
+ tokenizer = torch.load(tokenizer_path, weights_only=False)
1022
+ vocab_size = tokenizer.language_size
1023
+ pad_id = tokenizer.symbol_to_token[PAD_TOKEN]
1024
+
1025
+ writer = SummaryWriter(log_dir=log_dir)
1026
+
1027
+ # ── Test loaders (optional, skip silently if test sets not built yet) ───────
1028
+ out_dir = Path(stockfish_samples_path).parent
1029
+ reward_test_loader = None
1030
+ if not skip_reward and (out_dir / "stockfish_test_meta.pt").exists():
1031
+ reward_test_ds = ChessPositionDataset.from_memmap(out_dir, "stockfish_test", tokenizer)
1032
+ reward_test_loader = DataLoader(
1033
+ reward_test_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn_memmap,
1034
+ num_workers=num_workers, pin_memory=True,
1035
+ )
1036
+ print(f"Reward test set: {len(reward_test_ds):,} positions loaded")
1037
+
1038
+ policy_data_dir_early = Path(outcome_games_path).parent
1039
+ policy_test_loader = None
1040
+ if (policy_data_dir_early / "policy_test_meta.pt").exists():
1041
+ policy_test_ds = ChessPolicyDataset.from_memmap(policy_data_dir_early, tokenizer, name="policy_test")
1042
+ policy_test_loader = DataLoader(
1043
+ policy_test_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn_policy,
1044
+ num_workers=num_workers, pin_memory=True,
1045
+ )
1046
+ print(f"Policy test set: {len(policy_test_ds):,} sequences loaded")
1047
+
1048
+ puzzle_test_loader = None
1049
+ _puzzle_dir = Path(puzzle_data_dir) if puzzle_data_dir is not None else policy_data_dir_early
1050
+ if (_puzzle_dir / "puzzle_test_meta.pt").exists():
1051
+ puzzle_test_ds = ChessPolicyDataset.from_memmap(
1052
+ _puzzle_dir, tokenizer, name="puzzle_test",
1053
+ source_tag=1, loss_weight=1.0, # eval uses uniform weighting
1054
+ )
1055
+ puzzle_test_loader = DataLoader(
1056
+ puzzle_test_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn_policy,
1057
+ num_workers=num_workers, pin_memory=True,
1058
+ )
1059
+ print(f"Puzzle test set: {len(puzzle_test_ds):,} sequences loaded")
1060
+
1061
+ # ── Phase 1: reward model ────────────────────────────────────────────────
1062
+ reward_model = None
1063
+ global_step = 0
1064
+ if skip_reward:
1065
+ print("\n── Phase 1: SKIPPED (--skip-reward) — existing reward_model.pt untouched.")
1066
+ else:
1067
+ sf_meta = out_dir / "stockfish_meta.pt"
1068
+ if sf_meta.exists():
1069
+ print(f"Loading Stockfish samples from memmap ({out_dir}/stockfish_*)...")
1070
+ sf_ds = ChessPositionDataset.from_memmap(out_dir, "stockfish", tokenizer)
1071
+ sf_collate = collate_fn_memmap
1072
+ else:
1073
+ print(f"Loading Stockfish samples from {stockfish_samples_path}...")
1074
+ sf_ds = ChessPositionDataset.from_file(stockfish_samples_path, tokenizer)
1075
+ sf_collate = collate_fn
1076
+ print(f"Reward dataset: {len(sf_ds):,} positions")
1077
+
1078
+ sf_loader = DataLoader(
1079
+ sf_ds, batch_size=batch_size, shuffle=True, collate_fn=sf_collate,
1080
+ num_workers=num_workers, pin_memory=True,
1081
+ )
1082
+
1083
+ reward_model = ChessRewardModel(vocab_size=vocab_size, max_seq_len=max_seq_len).to(device)
1084
+ reward_optimizer = torch.optim.AdamW(reward_model.parameters(), lr=learning_rate)
1085
+
1086
+ print(f"\n── Phase 1: reward model — {epochs} epochs, lr={learning_rate}")
1087
+ phase1_start = time.time()
1088
+ for epoch in range(epochs):
1089
+ epoch_num = epoch + 1
1090
+ print(f" [epoch {epoch_num}/{epochs}] starting...")
1091
+ avg_loss, global_step, epoch_secs = _run_epoch_reward(
1092
+ reward_model, sf_loader, reward_optimizer, device, writer, global_step, epoch
1093
+ )
1094
+ epochs_left = epochs - epoch_num
1095
+ print(
1096
+ f" [epoch {epoch_num}/{epochs}] "
1097
+ f"loss={avg_loss:.4f} "
1098
+ f"epoch_time={_fmt_duration(epoch_secs)} "
1099
+ f"eta={_fmt_duration(epoch_secs * epochs_left)}"
1100
+ )
1101
+ if reward_test_loader is not None:
1102
+ m = eval_reward(reward_model, reward_test_loader, device)
1103
+ writer.add_scalar("test/reward_mse", m["mse"], epoch_num)
1104
+ writer.add_scalar("test/reward_mae", m["mae"], epoch_num)
1105
+ writer.add_scalar("test/reward_pearson_r", m["pearson_r"], epoch_num)
1106
+ print(
1107
+ f" [test] mse={m['mse']:.4f} mae={m['mae']:.4f} r={m['pearson_r']:.4f}"
1108
+ )
1109
+
1110
+ print(f"Phase 1 complete in {_fmt_duration(time.time() - phase1_start)}")
1111
+ torch.save(reward_model.state_dict(), "reward_model.pt")
1112
+ print("Reward model saved to reward_model.pt")
1113
+
1114
+ # ── Phase 2: mixed game + puzzle policy training ─────────────────────────
1115
+ policy_data_dir = policy_data_dir_early # already computed above
1116
+ policy_meta = policy_data_dir / "policy_meta.pt"
1117
+ if policy_meta.exists():
1118
+ print(f"Loading policy sequences from memmap ({policy_data_dir}/policy_*)...")
1119
+ game_ds = ChessPolicyDataset.from_memmap(
1120
+ policy_data_dir, tokenizer, name="policy",
1121
+ source_tag=0, loss_weight=1.0,
1122
+ )
1123
+ else:
1124
+ print(f"Loading outcome games from {outcome_games_path} (tokenizing on-the-fly)...")
1125
+ outcome_games = torch.load(outcome_games_path, weights_only=False)
1126
+ game_ds = ChessPolicyDataset(outcome_games, tokenizer, max_seq_len=max_seq_len)
1127
+ print(f"Game dataset: {len(game_ds):,} sequences")
1128
+
1129
+ # Puzzle dataset (optional — falls back to game-only training if absent).
1130
+ # If --puzzle-data isn't passed, look for puzzle_*.bin alongside policy_*.bin
1131
+ # so a `build_datasets.py --policy-only` layout (everything in data/) is
1132
+ # picked up automatically without an extra CLI flag.
1133
+ puzzle_ds = None
1134
+ pdir = Path(puzzle_data_dir) if puzzle_data_dir is not None else policy_data_dir_early
1135
+ if (pdir / "puzzle_meta.pt").exists():
1136
+ puzzle_ds = ChessPolicyDataset.from_memmap(
1137
+ pdir, tokenizer, name="puzzle",
1138
+ source_tag=1, loss_weight=puzzle_loss_weight,
1139
+ )
1140
+ print(f"Puzzle dataset ({pdir}): {len(puzzle_ds):,} sequences (loss_weight={puzzle_loss_weight}x)")
1141
+ elif puzzle_data_dir is not None:
1142
+ print(f"WARNING: --puzzle-data given but {pdir}/puzzle_meta.pt not found.")
1143
+
1144
+ if puzzle_ds is not None:
1145
+ mixed_ds = torch.utils.data.ConcatDataset([game_ds, puzzle_ds])
1146
+ sampler = MixedBatchSampler(
1147
+ n_game=len(game_ds),
1148
+ n_puzzle=len(puzzle_ds),
1149
+ batch_size=batch_size,
1150
+ game_ratio=1.0 - puzzle_ratio,
1151
+ )
1152
+ print(
1153
+ f"Mixed batch composition: {sampler.n_game_per_batch} game + "
1154
+ f"{sampler.n_puzzle_per_batch} puzzle per batch (puzzle_ratio={puzzle_ratio})"
1155
+ )
1156
+ policy_loader = DataLoader(
1157
+ mixed_ds, batch_sampler=sampler, collate_fn=collate_fn_policy,
1158
+ num_workers=num_workers, pin_memory=True,
1159
+ )
1160
+ else:
1161
+ policy_loader = DataLoader(
1162
+ game_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn_policy,
1163
+ num_workers=num_workers, pin_memory=True,
1164
+ )
1165
+
1166
+ policy_model = ChessPolicyModel(vocab_size=vocab_size, max_seq_len=max_seq_len).to(device)
1167
+ policy_optimizer = torch.optim.AdamW(policy_model.parameters(), lr=learning_rate)
1168
+ global_step = 0
1169
+
1170
+ def _run_policy_test(epoch_num: int, tb_prefix: str) -> None:
1171
+ if policy_test_loader is not None:
1172
+ m = eval_policy(policy_model, policy_test_loader, device, pad_id)
1173
+ writer.add_scalar(f"{tb_prefix}/policy_loss", m["loss"], epoch_num)
1174
+ writer.add_scalar(f"{tb_prefix}/policy_perplexity", m["perplexity"], epoch_num)
1175
+ writer.add_scalar(f"{tb_prefix}/policy_top1_acc", m["top1_acc"], epoch_num)
1176
+ writer.add_scalar(f"{tb_prefix}/policy_top5_acc", m["top5_acc"], epoch_num)
1177
+ print(
1178
+ f" [policy test] loss={m['loss']:.4f} ppl={m['perplexity']:.2f}"
1179
+ f" top1={m['top1_acc']:.3f} top5={m['top5_acc']:.3f}"
1180
+ )
1181
+ if puzzle_test_loader is not None:
1182
+ m = eval_puzzle_solve_rate(policy_model, puzzle_test_loader, device, pad_id)
1183
+ writer.add_scalar(f"{tb_prefix}/puzzle_first_move", m["first_move_solve_rate"], epoch_num)
1184
+ writer.add_scalar(f"{tb_prefix}/puzzle_all_moves", m["all_moves_solve_rate"], epoch_num)
1185
+ print(
1186
+ f" [puzzle test] first_move={m['first_move_solve_rate']:.3f}"
1187
+ f" all_moves={m['all_moves_solve_rate']:.3f}"
1188
+ )
1189
+
1190
+ def _save_epoch_checkpoint(epoch_num: int) -> None:
1191
+ """Save the policy model after a completed epoch.
1192
+
1193
+ Each checkpoint is `policy_model_epoch_{NN}.pt` next to the final
1194
+ `policy_model.pt`. If `keep_last_n_checkpoints > 0`, older
1195
+ checkpoints are pruned to cap disk usage. The final end-of-Phase-2
1196
+ save (`policy_model.pt`) is kept regardless of this setting and
1197
+ always reflects the last completed epoch.
1198
+ """
1199
+ ckpt_path = Path(f"policy_model_epoch_{epoch_num:02d}.pt")
1200
+ torch.save(policy_model.state_dict(), ckpt_path)
1201
+ print(f" [checkpoint] saved {ckpt_path.name}")
1202
+ if keep_last_n_checkpoints and keep_last_n_checkpoints > 0:
1203
+ existing = sorted(Path(".").glob("policy_model_epoch_*.pt"))
1204
+ stale = existing[:-keep_last_n_checkpoints]
1205
+ for p in stale:
1206
+ try:
1207
+ p.unlink()
1208
+ except OSError:
1209
+ pass
1210
+
1211
+ def _log_cross_gates(epoch_num: int) -> None:
1212
+ """Log per-block cross-attention gate values to TensorBoard.
1213
+
1214
+ Each CrossAttnBlock has a single learned scalar `cross_gate` whose
1215
+ tanh controls how much board cross-attention contributes through
1216
+ its residual (init=0 means cross-attn starts disabled). Tracking
1217
+ these over epochs shows which layers opened the board pathway and
1218
+ how fast — flat-at-zero across all layers means the model decided
1219
+ cross-attention wasn't worth it.
1220
+
1221
+ TB tags:
1222
+ cross_gate/block_{i} effective gate tanh(α) ∈ (-1, 1)
1223
+ cross_gate_raw/block_{i} raw parameter α (unbounded)
1224
+ """
1225
+ blocks = getattr(policy_model, "blocks", None)
1226
+ if blocks is None:
1227
+ return # Older model variants without CrossAttnBlock stack.
1228
+ gates_tanh = {}
1229
+ for i, blk in enumerate(blocks):
1230
+ raw = blk.cross_gate.detach()
1231
+ tanh_val = raw.tanh().item()
1232
+ raw_val = raw.item()
1233
+ writer.add_scalar(f"cross_gate/block_{i}", tanh_val, epoch_num)
1234
+ writer.add_scalar(f"cross_gate_raw/block_{i}", raw_val, epoch_num)
1235
+ gates_tanh[f"block_{i}"] = tanh_val
1236
+ # Overlay all blocks on a single chart for easy at-a-glance comparison.
1237
+ writer.add_scalars("cross_gate_all", gates_tanh, epoch_num)
1238
+ gate_summary = " ".join(f"L{i}={v:+.3f}" for i, v in enumerate(gates_tanh.values()))
1239
+ print(f" [cross_gate] {gate_summary}")
1240
+
1241
+ print(f"\n── Phase 2: mixed policy training — {policy_epochs} epochs, lr={learning_rate}")
1242
+ phase2_start = time.time()
1243
+ # Log initial gate values (all zeros at init) so TB charts start at epoch 0.
1244
+ _log_cross_gates(0)
1245
+ for epoch in range(policy_epochs):
1246
+ epoch_num = epoch + 1
1247
+ print(f" [epoch {epoch_num}/{policy_epochs}] starting...")
1248
+ avg_loss, global_step, epoch_secs = _run_epoch_policy_mixed(
1249
+ policy_model, policy_loader, policy_optimizer, device, writer, global_step, epoch, pad_id,
1250
+ )
1251
+ epochs_left = policy_epochs - epoch_num
1252
+ print(
1253
+ f" [epoch {epoch_num}/{policy_epochs}] "
1254
+ f"loss={avg_loss:.4f} "
1255
+ f"epoch_time={_fmt_duration(epoch_secs)} "
1256
+ f"eta={_fmt_duration(epoch_secs * epochs_left)}"
1257
+ )
1258
+ _run_policy_test(epoch_num, "test_mixed")
1259
+ _log_cross_gates(epoch_num)
1260
+ _save_epoch_checkpoint(epoch_num)
1261
+
1262
+ print(f"Phase 2 complete in {_fmt_duration(time.time() - phase2_start)}")
1263
+ torch.save(policy_model.state_dict(), "policy_model.pt")
1264
+ print("Policy model saved to policy_model.pt")
1265
+
1266
+ return reward_model, policy_model, tokenizer
1267
+
1268
+
1269
+ def _build_argparser():
1270
+ p = argparse.ArgumentParser(description=train.__doc__)
1271
+ p.add_argument("--tokenizer-path", default="data/tokenizer.pt")
1272
+ p.add_argument("--stockfish-samples-path", default="data/stockfish_samples.pt")
1273
+ p.add_argument("--outcome-games-path", default="data/games_outcome.pt")
1274
+ p.add_argument("--epochs", type=int, default=15)
1275
+ p.add_argument("--policy-epochs", type=int, default=15)
1276
+ p.add_argument("--batch-size", type=int, default=1024)
1277
+ p.add_argument("--learning-rate", type=float, default=3e-5)
1278
+ p.add_argument("--max-seq-len", type=int, default=128)
1279
+ p.add_argument("--log-dir", default="runs/chess_models")
1280
+ p.add_argument("--num-workers", type=int, default=8)
1281
+ p.add_argument("--puzzle-data", default=None, dest="puzzle_data_dir",
1282
+ help="Directory containing puzzle_tokens.bin / puzzle_lengths.bin / puzzle_fens.bin / puzzle_meta.pt")
1283
+ p.add_argument("--puzzle-epochs", type=int, default=5, dest="puzzle_epochs",
1284
+ help="(Deprecated, retained for CLI compat) — mixed training merges game/puzzle into Phase 2.")
1285
+ p.add_argument("--puzzle-loss-weight", type=float, default=5.0, dest="puzzle_loss_weight",
1286
+ help="Per-sample loss weight applied to puzzle samples in the mixed-training "
1287
+ "weighted cross-entropy (default 5.0).")
1288
+ p.add_argument("--puzzle-ratio", type=float, default=0.2, dest="puzzle_ratio",
1289
+ help="Fraction of each mixed batch drawn from the puzzle dataset (default 0.2).")
1290
+ p.add_argument("--skip-reward", action="store_true", dest="skip_reward",
1291
+ help="Skip Phase 1 (reward model training). Existing reward_model.pt is "
1292
+ "left untouched. Use when iterating on Phase 2 only.")
1293
+ p.add_argument("--keep-last-n-checkpoints", type=int, default=3, dest="keep_last_n_checkpoints",
1294
+ help="Number of per-epoch policy_model_epoch_NN.pt files to keep on disk "
1295
+ "(default 3). Use 0 to keep all epochs. Final policy_model.pt is kept "
1296
+ "regardless.")
1297
+ return p
1298
+
1299
+
1300
+ if __name__ == "__main__":
1301
+ args = _build_argparser().parse_args()
1302
+ train(
1303
+ tokenizer_path=args.tokenizer_path,
1304
+ stockfish_samples_path=args.stockfish_samples_path,
1305
+ outcome_games_path=args.outcome_games_path,
1306
+ epochs=args.epochs,
1307
+ policy_epochs=args.policy_epochs,
1308
+ batch_size=args.batch_size,
1309
+ learning_rate=args.learning_rate,
1310
+ max_seq_len=args.max_seq_len,
1311
+ log_dir=args.log_dir,
1312
+ num_workers=args.num_workers,
1313
+ puzzle_data_dir=args.puzzle_data_dir,
1314
+ puzzle_epochs=args.puzzle_epochs,
1315
+ puzzle_loss_weight=args.puzzle_loss_weight,
1316
+ puzzle_ratio=args.puzzle_ratio,
1317
+ skip_reward=args.skip_reward,
1318
+ keep_last_n_checkpoints=args.keep_last_n_checkpoints,
1319
+ )