Spaces:
Sleeping
Sleeping
serving model
Browse files- .gitignore +17 -0
- Dockerfile +36 -0
- app.py +43 -0
- model/policy_model.pt +3 -0
- model/tokenizer.pt +3 -0
- requirements.txt +28 -0
- README.md → src/README.md +0 -0
- src/__init__.py +0 -0
- src/benchmark.py +204 -0
- src/build_datasets.py +697 -0
- src/minimax.py +116 -0
- src/model.py +327 -0
- src/tokenizer.py +104 -0
- src/train.py +1319 -0
.gitignore
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Virtual environments
|
| 2 |
+
.venv/
|
| 3 |
+
venv/
|
| 4 |
+
env/
|
| 5 |
+
|
| 6 |
+
# Python
|
| 7 |
+
__pycache__/
|
| 8 |
+
*.py[cod]
|
| 9 |
+
*.egg-info/
|
| 10 |
+
.pytest_cache/
|
| 11 |
+
|
| 12 |
+
# IDE
|
| 13 |
+
.vscode/
|
| 14 |
+
.idea/
|
| 15 |
+
|
| 16 |
+
# OS
|
| 17 |
+
.DS_Store
|
Dockerfile
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.12.2-slim
|
| 2 |
+
|
| 3 |
+
ENV PYTHONDONTWRITEBYTECODE=1 \
|
| 4 |
+
PYTHONUNBUFFERED=1 \
|
| 5 |
+
PIP_NO_CACHE_DIR=1
|
| 6 |
+
|
| 7 |
+
# build-essential covers the few deps with C extensions (numpy/pandas wheels
|
| 8 |
+
# are usually prebuilt for 3.12, but this is cheap insurance).
|
| 9 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 10 |
+
build-essential \
|
| 11 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 12 |
+
|
| 13 |
+
WORKDIR /app
|
| 14 |
+
|
| 15 |
+
# Deps first so editing app.py doesn't reinstall torch on every rebuild.
|
| 16 |
+
COPY requirements.txt .
|
| 17 |
+
RUN pip install --extra-index-url https://download.pytorch.org/whl/cpu -r requirements.txt
|
| 18 |
+
|
| 19 |
+
# Source next.
|
| 20 |
+
COPY src/ ./src/
|
| 21 |
+
COPY app.py ./
|
| 22 |
+
|
| 23 |
+
# Weights last — biggest layer, changes least often, so cache stays warm
|
| 24 |
+
# when you iterate on app.py.
|
| 25 |
+
COPY model/ ./model/
|
| 26 |
+
|
| 27 |
+
# torch.load on tokenizer.pt unpickles a `Tokenizer` instance that was
|
| 28 |
+
# saved from the module path `tokenizer` (src/tokenizer.py). Making /app/src
|
| 29 |
+
# importable lets the unpickler find that class.
|
| 30 |
+
ENV PYTHONPATH=/app/src
|
| 31 |
+
|
| 32 |
+
# HF Spaces routes external traffic to $PORT; 7860 is the convention.
|
| 33 |
+
# Single worker — the model lives in process memory and multiple workers
|
| 34 |
+
# would multiply the ~1.5 GB RSS.
|
| 35 |
+
EXPOSE 7860
|
| 36 |
+
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
|
app.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI, HTTPException
|
| 2 |
+
import chess
|
| 3 |
+
from contextlib import asynccontextmanager
|
| 4 |
+
import sys
|
| 5 |
+
from src.model import ChessPolicyModel, PolicyModelInference
|
| 6 |
+
from src.tokenizer import tokenizer
|
| 7 |
+
import torch
|
| 8 |
+
from pydantic import BaseModel
|
| 9 |
+
|
| 10 |
+
ml = {}
|
| 11 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the tokenizer and policy model once at startup; drop them at shutdown.

    On entry the tokenizer and the policy-model weights are read from
    ``./model`` onto the CPU and wrapped in a ``PolicyModelInference`` that is
    stored under ``ml["inference"]``. On exit the ``ml`` registry is cleared,
    even if the application raised while shutting down.
    """
    cpu = torch.device("cpu")

    # tokenizer.pt is a pickled Python object, so it needs the full unpickler.
    # NOTE(security): weights_only=False can execute arbitrary pickle code;
    # only ever point this at the trusted file bundled with the image.
    loaded_tokenizer = torch.load("./model/tokenizer.pt", weights_only=False, map_location=cpu)

    model = ChessPolicyModel(vocab_size=loaded_tokenizer.language_size)
    # The restricted unpickler is used for the weights, matching how
    # src/benchmark.py loads policy_model.pt.
    # NOTE(review): assumes the checkpoint is a plain state_dict — confirm.
    model.load_state_dict(
        torch.load("./model/policy_model.pt", weights_only=True, map_location=cpu)
    )
    ml["inference"] = PolicyModelInference(model, loaded_tokenizer, device="cpu")
    try:
        yield
    finally:
        # Release the model reference regardless of how shutdown went.
        ml.clear()
|
| 22 |
+
|
| 23 |
+
app = FastAPI(lifespan=lifespan)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# Request body for POST /inference.
class InferenceRequest(BaseModel):
    # Moves played so far, in order. model_inference feeds each entry to
    # chess.Board.push_uci starting from the initial position.
    moves: list[str]
|
| 29 |
+
|
| 30 |
+
@app.post("/inference")
def model_inference(req: InferenceRequest):
    """Replay the submitted moves and return the model's chosen move.

    The moves in ``req.moves`` are pushed, in order, onto a fresh
    ``chess.Board``; the resulting board is handed to the inference object
    stored in ``ml["inference"]``.

    Returns:
        ``{"move": <value returned by the inference object>}``.

    Raises:
        HTTPException(400): a move was rejected by ``Board.push_uci``.
        HTTPException(500): the inference object raised ``ValueError``.
    """
    board = chess.Board()
    for move in req.moves:
        try:
            board.push_uci(move)
        except ValueError as e:
            # Chain the cause so server logs keep the original traceback.
            raise HTTPException(status_code=400, detail=f"Incorrect move {move}: {e}") from e
    try:
        return {"move": ml["inference"](board)}
    except ValueError as e:
        raise HTTPException(status_code=500, detail=f"Model Failed to evaluate: {e}") from e
|
| 42 |
+
|
| 43 |
+
|
model/policy_model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b7c3556e2842ff3f6b20a11b9cc0759a89ca39b2d7cde30b54d8d3ba1bee573a
|
| 3 |
+
size 322812083
|
model/tokenizer.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:45887b381da27b8cd119274704fb0b4766125a9218f87643be8260017869471b
|
| 3 |
+
size 57977
|
requirements.txt
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
annotated-doc==0.0.4
|
| 2 |
+
annotated-types==0.7.0
|
| 3 |
+
anyio==4.13.0
|
| 4 |
+
chess==1.11.2
|
| 5 |
+
click==8.4.0
|
| 6 |
+
fastapi==0.136.1
|
| 7 |
+
filelock==3.29.0
|
| 8 |
+
fsspec==2026.4.0
|
| 9 |
+
h11==0.16.0
|
| 10 |
+
idna==3.15
|
| 11 |
+
Jinja2==3.1.6
|
| 12 |
+
MarkupSafe==3.0.3
|
| 13 |
+
mpmath==1.3.0
|
| 14 |
+
networkx==3.6.1
|
| 15 |
+
numpy==2.4.5
|
| 16 |
+
pandas==3.0.3
|
| 17 |
+
pydantic==2.13.4
|
| 18 |
+
pydantic_core==2.46.4
|
| 19 |
+
python-chess==1.999
|
| 20 |
+
python-dateutil==2.9.0.post0
|
| 21 |
+
setuptools==81.0.0
|
| 22 |
+
six==1.17.0
|
| 23 |
+
starlette==1.0.0
|
| 24 |
+
sympy==1.14.0
|
| 25 |
+
torch==2.12.0
|
| 26 |
+
typing-inspection==0.4.2
|
| 27 |
+
typing_extensions==4.15.0
|
| 28 |
+
uvicorn==0.47.0
|
README.md → src/README.md
RENAMED
|
File without changes
|
src/__init__.py
ADDED
|
File without changes
|
src/benchmark.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Evaluate trained reward and policy models on held-out test sets.
|
| 2 |
+
|
| 3 |
+
Test sets are built by src/build_datasets.py (run once before benchmarking):
|
| 4 |
+
stockfish_test_*.bin — 50K Stockfish-labeled positions (reward model)
|
| 5 |
+
policy_test_*.bin — 50K game sequences (policy model)
|
| 6 |
+
puzzle_test_*.bin — 100K puzzle sequences (puzzle solve rate)
|
| 7 |
+
|
| 8 |
+
Metrics:
|
| 9 |
+
Reward: MSE, MAE, Pearson r
|
| 10 |
+
Policy: loss, perplexity, top-1 move accuracy, top-5 move accuracy
|
| 11 |
+
Puzzle: first-move solve rate (top-1), all-moves solve rate
|
| 12 |
+
|
| 13 |
+
Usage:
|
| 14 |
+
arch -arm64 poetry run python src/benchmark.py
|
| 15 |
+
arch -arm64 poetry run python src/benchmark.py \\
|
| 16 |
+
--reward-model reward_model.pt \\
|
| 17 |
+
--policy-model policy_model.pt \\
|
| 18 |
+
--data-dir data/ \\
|
| 19 |
+
--batch-size 512
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import argparse
|
| 23 |
+
import time
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
|
| 26 |
+
import torch
|
| 27 |
+
from torch.utils.data import DataLoader
|
| 28 |
+
|
| 29 |
+
from model import ChessRewardModel, ChessPolicyModel, PAD_TOKEN
|
| 30 |
+
from train import (
|
| 31 |
+
ChessPositionDataset,
|
| 32 |
+
ChessPolicyDataset,
|
| 33 |
+
collate_fn_memmap,
|
| 34 |
+
collate_fn_policy,
|
| 35 |
+
eval_reward,
|
| 36 |
+
eval_policy,
|
| 37 |
+
eval_puzzle_solve_rate,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _fmt(v: float, pct: bool = False) -> str:
|
| 42 |
+
return f"{v * 100:.2f}%" if pct else f"{v:.4f}"
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def run_benchmark(
|
| 46 |
+
data_dir: Path,
|
| 47 |
+
reward_model_path: str | None,
|
| 48 |
+
policy_model_path: str | None,
|
| 49 |
+
batch_size: int = 512,
|
| 50 |
+
num_workers: int = 4,
|
| 51 |
+
device: str | None = None,
|
| 52 |
+
) -> dict:
|
| 53 |
+
"""Run all available benchmarks and return a results dict.
|
| 54 |
+
|
| 55 |
+
Returns a dict with keys 'reward', 'policy', 'puzzle' (each a sub-dict of
|
| 56 |
+
metrics), or omits a key if the test set / model is unavailable.
|
| 57 |
+
"""
|
| 58 |
+
if device is None:
|
| 59 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 60 |
+
print(f"Benchmark device: {device}")
|
| 61 |
+
|
| 62 |
+
tokenizer_path = data_dir / "tokenizer.pt"
|
| 63 |
+
if not tokenizer_path.exists():
|
| 64 |
+
raise FileNotFoundError(f"Tokenizer not found at {tokenizer_path}")
|
| 65 |
+
tokenizer = torch.load(tokenizer_path, weights_only=False)
|
| 66 |
+
vocab_size = tokenizer.language_size
|
| 67 |
+
pad_id = tokenizer.symbol_to_token[PAD_TOKEN]
|
| 68 |
+
|
| 69 |
+
results = {}
|
| 70 |
+
|
| 71 |
+
# ── Reward model ──────────────────────────────────────────────────────────
|
| 72 |
+
reward_test_meta = data_dir / "stockfish_test_meta.pt"
|
| 73 |
+
if reward_model_path and Path(reward_model_path).exists() and reward_test_meta.exists():
|
| 74 |
+
print(f"\nLoading reward model from {reward_model_path}...")
|
| 75 |
+
reward_model = ChessRewardModel(vocab_size=vocab_size).to(device)
|
| 76 |
+
reward_model.load_state_dict(torch.load(reward_model_path, map_location=device, weights_only=True))
|
| 77 |
+
|
| 78 |
+
print("Loading reward test set...")
|
| 79 |
+
reward_test_ds = ChessPositionDataset.from_memmap(data_dir, "stockfish_test", tokenizer)
|
| 80 |
+
reward_test_loader = DataLoader(
|
| 81 |
+
reward_test_ds, batch_size=batch_size, shuffle=False,
|
| 82 |
+
collate_fn=collate_fn_memmap, num_workers=num_workers, pin_memory=True,
|
| 83 |
+
)
|
| 84 |
+
print(f" {len(reward_test_ds):,} test positions")
|
| 85 |
+
|
| 86 |
+
t0 = time.time()
|
| 87 |
+
m = eval_reward(reward_model, reward_test_loader, device)
|
| 88 |
+
print(
|
| 89 |
+
f" Reward | MSE={_fmt(m['mse'])} MAE={_fmt(m['mae'])}"
|
| 90 |
+
f" Pearson r={_fmt(m['pearson_r'])} ({time.time()-t0:.1f}s)"
|
| 91 |
+
)
|
| 92 |
+
results["reward"] = m
|
| 93 |
+
elif not reward_test_meta.exists():
|
| 94 |
+
print("\nReward test set not found (run build_datasets.py to create it). Skipping.")
|
| 95 |
+
elif not reward_model_path or not Path(reward_model_path).exists():
|
| 96 |
+
print(f"\nReward model not found at {reward_model_path}. Skipping.")
|
| 97 |
+
|
| 98 |
+
# ── Policy model ──────────────────────────────────────────────────────────
|
| 99 |
+
policy_test_meta = data_dir / "policy_test_meta.pt"
|
| 100 |
+
puzzle_test_meta = data_dir / "puzzle_test_meta.pt"
|
| 101 |
+
policy_model = None
|
| 102 |
+
|
| 103 |
+
if policy_model_path and Path(policy_model_path).exists():
|
| 104 |
+
print(f"\nLoading policy model from {policy_model_path}...")
|
| 105 |
+
policy_model = ChessPolicyModel(vocab_size=vocab_size).to(device)
|
| 106 |
+
policy_model.load_state_dict(torch.load(policy_model_path, map_location=device, weights_only=True))
|
| 107 |
+
|
| 108 |
+
if policy_model is not None and policy_test_meta.exists():
|
| 109 |
+
print("Loading policy test set...")
|
| 110 |
+
policy_test_ds = ChessPolicyDataset.from_memmap(data_dir, tokenizer, name="policy_test")
|
| 111 |
+
policy_test_loader = DataLoader(
|
| 112 |
+
policy_test_ds, batch_size=batch_size, shuffle=False,
|
| 113 |
+
collate_fn=collate_fn_policy, num_workers=num_workers, pin_memory=True,
|
| 114 |
+
)
|
| 115 |
+
print(f" {len(policy_test_ds):,} test sequences")
|
| 116 |
+
|
| 117 |
+
t0 = time.time()
|
| 118 |
+
m = eval_policy(policy_model, policy_test_loader, device, pad_id)
|
| 119 |
+
print(
|
| 120 |
+
f" Policy | loss={_fmt(m['loss'])} ppl={m['perplexity']:.2f}"
|
| 121 |
+
f" top1={_fmt(m['top1_acc'], pct=True)} top5={_fmt(m['top5_acc'], pct=True)}"
|
| 122 |
+
f" ({time.time()-t0:.1f}s)"
|
| 123 |
+
)
|
| 124 |
+
results["policy"] = m
|
| 125 |
+
elif policy_model is None:
|
| 126 |
+
print(f"\nPolicy model not found at {policy_model_path}. Skipping policy + puzzle eval.")
|
| 127 |
+
elif not policy_test_meta.exists():
|
| 128 |
+
print("\nPolicy test set not found (run build_datasets.py to create it). Skipping.")
|
| 129 |
+
|
| 130 |
+
if policy_model is not None and puzzle_test_meta.exists():
|
| 131 |
+
print("Loading puzzle test set...")
|
| 132 |
+
puzzle_test_ds = ChessPolicyDataset.from_memmap(data_dir, tokenizer, name="puzzle_test")
|
| 133 |
+
puzzle_test_loader = DataLoader(
|
| 134 |
+
puzzle_test_ds, batch_size=batch_size, shuffle=False,
|
| 135 |
+
collate_fn=collate_fn_policy, num_workers=num_workers, pin_memory=True,
|
| 136 |
+
)
|
| 137 |
+
print(f" {len(puzzle_test_ds):,} test puzzles")
|
| 138 |
+
|
| 139 |
+
t0 = time.time()
|
| 140 |
+
m = eval_puzzle_solve_rate(policy_model, puzzle_test_loader, device, pad_id)
|
| 141 |
+
print(
|
| 142 |
+
f" Puzzle | first_move={_fmt(m['first_move_solve_rate'], pct=True)}"
|
| 143 |
+
f" all_moves={_fmt(m['all_moves_solve_rate'], pct=True)}"
|
| 144 |
+
f" ({time.time()-t0:.1f}s)"
|
| 145 |
+
)
|
| 146 |
+
results["puzzle"] = m
|
| 147 |
+
elif policy_model is not None and not puzzle_test_meta.exists():
|
| 148 |
+
print("\nPuzzle test set not found (run build_datasets.py to create it). Skipping.")
|
| 149 |
+
|
| 150 |
+
return results
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def _print_summary(results: dict) -> None:
|
| 154 |
+
print("\n" + "=" * 60)
|
| 155 |
+
print("BENCHMARK SUMMARY")
|
| 156 |
+
print("=" * 60)
|
| 157 |
+
if "reward" in results:
|
| 158 |
+
m = results["reward"]
|
| 159 |
+
print(f" Reward model MSE={m['mse']:.4f} MAE={m['mae']:.4f} Pearson r={m['pearson_r']:.4f}")
|
| 160 |
+
if "policy" in results:
|
| 161 |
+
m = results["policy"]
|
| 162 |
+
print(
|
| 163 |
+
f" Policy model loss={m['loss']:.4f} perplexity={m['perplexity']:.2f}"
|
| 164 |
+
f" top-1={m['top1_acc']*100:.2f}% top-5={m['top5_acc']*100:.2f}%"
|
| 165 |
+
)
|
| 166 |
+
if "puzzle" in results:
|
| 167 |
+
m = results["puzzle"]
|
| 168 |
+
print(
|
| 169 |
+
f" Puzzle eval first-move solve={m['first_move_solve_rate']*100:.2f}%"
|
| 170 |
+
f" all-moves solve={m['all_moves_solve_rate']*100:.2f}%"
|
| 171 |
+
)
|
| 172 |
+
if not results:
|
| 173 |
+
print(" No results — check that models and test sets exist.")
|
| 174 |
+
print("=" * 60)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def _build_argparser() -> argparse.ArgumentParser:
|
| 178 |
+
p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
|
| 179 |
+
p.add_argument("--reward-model", default="reward_model.pt",
|
| 180 |
+
help="Path to reward_model.pt (default: reward_model.pt)")
|
| 181 |
+
p.add_argument("--policy-model", default="policy_model.pt",
|
| 182 |
+
help="Path to policy_model.pt (default: policy_model.pt)")
|
| 183 |
+
p.add_argument("--data-dir", type=Path, default=Path("data"),
|
| 184 |
+
help="Directory containing test set .bin files and tokenizer.pt (default: data/)")
|
| 185 |
+
p.add_argument("--batch-size", type=int, default=512,
|
| 186 |
+
help="Batch size for evaluation (default: 512)")
|
| 187 |
+
p.add_argument("--num-workers", type=int, default=4,
|
| 188 |
+
help="DataLoader worker count (default: 4)")
|
| 189 |
+
p.add_argument("--device", default=None,
|
| 190 |
+
help="Device override, e.g. 'cpu' or 'cuda:0' (default: auto-detect)")
|
| 191 |
+
return p
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
# Script entry point: parse the CLI flags, run every benchmark whose model and
# test set are available, then print the consolidated summary.
if __name__ == "__main__":
    args = _build_argparser().parse_args()
    results = run_benchmark(
        data_dir=args.data_dir,
        reward_model_path=args.reward_model,
        policy_model_path=args.policy_model,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        device=args.device,
    )
    _print_summary(results)
|
src/build_datasets.py
ADDED
|
@@ -0,0 +1,697 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Build the two training datasets for hybrid two-phase training.
|
| 2 |
+
|
| 3 |
+
Runs three resumable stages:
|
| 4 |
+
1. Stream the Lichess HF dataset, filter by Elo + Termination, save two
|
| 5 |
+
disjoint raw-game subsets (movetext + Result only).
|
| 6 |
+
2. Build the shared tokenizer from the outcome subset and generate
|
| 7 |
+
outcome-labeled samples ({+1, 0, -1} from game Result).
|
| 8 |
+
3. Run parallel Stockfish on the disjoint subset to produce precisely
|
| 9 |
+
labeled samples (tanh(cp/400)).
|
| 10 |
+
|
| 11 |
+
Each stage skips if its output files already exist. Use --force to re-run.
|
| 12 |
+
|
| 13 |
+
Outputs (under --out-dir):
|
| 14 |
+
games_outcome.pt raw outcome-subset games
|
| 15 |
+
games_stockfish.pt raw stockfish-subset games
|
| 16 |
+
tokenizer.pt shared Tokenizer (built from outcome games)
|
| 17 |
+
outcome_{tokens,labels,lengths}.bin + outcome_meta.pt memmapped (token_ids, outcome_label) samples
|
| 18 |
+
stockfish_samples.pt list[(token_ids, stockfish_label)]
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import argparse
|
| 22 |
+
import random
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
|
| 25 |
+
import chess
|
| 26 |
+
import numpy as np
|
| 27 |
+
import torch
|
| 28 |
+
from datasets import load_dataset
|
| 29 |
+
from tqdm import tqdm
|
| 30 |
+
|
| 31 |
+
from model import CLS_TOKEN
|
| 32 |
+
from train import (
|
| 33 |
+
build_tokenizer_from_games,
|
| 34 |
+
generate_samples_stockfish_parallel,
|
| 35 |
+
parse_movetext,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
# Lichess Result → outcome label from white's perspective.
|
| 39 |
+
RESULT_TO_LABEL = {"1-0": 1.0, "0-1": -1.0, "1/2-1/2": 0.0}
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _save_as_memmap(
|
| 43 |
+
samples: list[tuple[list[int], float]], out_dir: Path, name: str, max_seq_len: int = 128
|
| 44 |
+
) -> None:
|
| 45 |
+
"""Save samples as memory-mapped arrays for fast DataLoader access.
|
| 46 |
+
|
| 47 |
+
Sequences longer than max_seq_len are truncated (keeps the most recent tokens,
|
| 48 |
+
since the CLS token is at position 0 we keep ids[:max_seq_len]).
|
| 49 |
+
|
| 50 |
+
Produces three files:
|
| 51 |
+
{name}_tokens.bin — (N, max_seq_len) int32, zero-padded
|
| 52 |
+
{name}_labels.bin — (N,) float32
|
| 53 |
+
{name}_lengths.bin — (N,) int32, actual sequence length per sample (capped at max_seq_len)
|
| 54 |
+
{name}_meta.pt — dict with 'n' and 'max_len'
|
| 55 |
+
"""
|
| 56 |
+
n = len(samples)
|
| 57 |
+
max_len = min(max(len(ids) for ids, _ in samples), max_seq_len)
|
| 58 |
+
print(f" memmap {name}: {n:,} samples, max_seq_len={max_len}")
|
| 59 |
+
|
| 60 |
+
tokens = np.memmap(out_dir / f"{name}_tokens.bin", dtype=np.int32, mode="w+", shape=(n, max_len))
|
| 61 |
+
labels = np.memmap(out_dir / f"{name}_labels.bin", dtype=np.float32, mode="w+", shape=(n,))
|
| 62 |
+
lengths = np.memmap(out_dir / f"{name}_lengths.bin", dtype=np.int32, mode="w+", shape=(n,))
|
| 63 |
+
|
| 64 |
+
for i, (ids, label) in enumerate(tqdm(samples, desc=f" writing {name}", unit="sample")):
|
| 65 |
+
ids = ids[:max_len]
|
| 66 |
+
l = len(ids)
|
| 67 |
+
tokens[i, :l] = ids
|
| 68 |
+
labels[i] = label
|
| 69 |
+
lengths[i] = l
|
| 70 |
+
|
| 71 |
+
tokens.flush()
|
| 72 |
+
labels.flush()
|
| 73 |
+
lengths.flush()
|
| 74 |
+
torch.save({"n": n, "max_len": max_len}, out_dir / f"{name}_meta.pt")
|
| 75 |
+
size_gb = (tokens.nbytes + labels.nbytes + lengths.nbytes) / 1024 ** 3
|
| 76 |
+
print(f" memmap {name} saved ({size_gb:.2f} GB)")
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def stage1_collect_games(args: argparse.Namespace) -> None:
    """Stage 1: stream Lichess games and save the raw policy / reward subsets.

    Games are kept only when both Elo ratings clear the relevant threshold and
    ``Termination == "Normal"``; just ``movetext`` and ``Result`` are stored.
    Writes ``games_outcome.pt`` (policy subset) and, unless ``--policy-only``
    is set, ``games_stockfish.pt`` (reward subset). The stage is skipped when
    the outputs it would write already exist and ``--force`` is not given.
    """
    policy_out = args.out_dir / "games_outcome.pt"
    reward_out = args.out_dir / "games_stockfish.pt"
    policy_only = getattr(args, "policy_only", False)

    # Resumability guard. With --policy-only the reward artifact is irrelevant,
    # so only the policy artifact decides whether to skip.
    if policy_only:
        if policy_out.exists() and not args.force:
            print(f"Stage 1: skipping — {policy_out.name} exists (--policy-only).")
            return
    elif policy_out.exists() and reward_out.exists() and not args.force:
        print(f"Stage 1: skipping — {policy_out.name} and {reward_out.name} exist.")
        return

    banner = "Stage 1: streaming Lichess/standard-chess-games (Termination == 'Normal'), "
    if policy_only:
        lower_elo = args.policy_min_elo
        print(
            banner
            + f"policy Elo >= {args.policy_min_elo} (target {args.policy_games:,}). "
            + "Reward subset skipped (--policy-only)."
        )
    else:
        lower_elo = min(args.reward_min_elo, args.policy_min_elo)
        print(
            banner
            + f"reward Elo >= {args.reward_min_elo} (target {args.reward_games:,}), "
            + f"policy Elo >= {args.policy_min_elo} (target {args.policy_games:,})..."
        )

    def _passes_prefilter(r) -> bool:
        # Coarse filter on the lower of the two thresholds; the per-subset
        # thresholds are re-checked in the loop below.
        if r.get("WhiteElo") is None or r.get("BlackElo") is None:
            return False
        if r["WhiteElo"] < lower_elo or r["BlackElo"] < lower_elo:
            return False
        return r.get("Termination") == "Normal"

    stream = load_dataset("Lichess/standard-chess-games", split="train", streaming=True)
    stream = stream.filter(_passes_prefilter)

    policy_games: list[dict] = []
    reward_games: list[dict] = []

    for row in tqdm(stream, desc="Stage 1: streaming", unit="game"):
        white = row.get("WhiteElo", 0)
        black = row.get("BlackElo", 0)
        record = {key: row.get(key) for key in ("movetext", "Result")}

        # NOTE(review): a game clearing both thresholds is appended to BOTH
        # lists, so the subsets can overlap even though the module docstring
        # describes them as disjoint — confirm which is intended.
        if not policy_only:
            if (
                len(reward_games) < args.reward_games
                and white >= args.reward_min_elo
                and black >= args.reward_min_elo
            ):
                reward_games.append(record)

        if (
            len(policy_games) < args.policy_games
            and white >= args.policy_min_elo
            and black >= args.policy_min_elo
        ):
            policy_games.append(record)

        # Stop streaming as soon as every requested subset is full.
        policy_full = len(policy_games) >= args.policy_games
        if policy_only:
            if policy_full:
                break
        elif len(reward_games) >= args.reward_games and policy_full:
            break

    policy_short = len(policy_games) < args.policy_games
    if policy_only:
        if policy_short:
            print(
                " WARNING: dataset exhausted before target — "
                f"got {len(policy_games):,} policy games."
            )
    elif len(reward_games) < args.reward_games or policy_short:
        print(
            " WARNING: dataset exhausted before target — "
            f"got {len(reward_games):,} reward + {len(policy_games):,} policy games."
        )

    if not policy_only:
        print(f"Stage 1: saving {reward_out} ({len(reward_games):,} games)...")
        torch.save(reward_games, reward_out)
    print(f"Stage 1: saving {policy_out} ({len(policy_games):,} games)...")
    torch.save(policy_games, policy_out)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def _generate_outcome_samples(games, tokenizer, max_positions_per_game, skip_ply):
    """Build (token_ids, outcome_label) samples for the phase-1 dataset.

    For each game with a recognised ``Result``, up to
    ``max_positions_per_game`` ply indices at or after ``skip_ply`` are drawn
    with an RNG seeded by the game's index (so sampling is reproducible). For
    every drawn ply that is reached, the sample is ``[CLS] + encoded moves
    played so far`` labelled with the game outcome.

    Games are skipped when the result is unknown, the movetext is empty, or
    there are fewer than ``max(2, skip_ply + 1)`` SAN moves. Replay stops at
    the first SAN that cannot be parsed or played; samples gathered before
    that point are kept.

    Returns:
        A list of ``(token_ids, label)`` tuples.
    """
    cls_id = tokenizer.symbol_to_token[CLS_TOKEN]
    samples: list[tuple[list[int], float]] = []
    with tqdm(games, desc="Stage 2: outcome samples", unit="game") as pbar:
        for idx, game in enumerate(pbar):
            result = game.get("Result")
            if result not in RESULT_TO_LABEL:
                continue
            label = RESULT_TO_LABEL[result]

            movetext = game.get("movetext", "")
            if not movetext:
                continue
            move_sans = parse_movetext(movetext)
            if len(move_sans) < max(2, skip_ply + 1):
                continue

            eligible = list(range(skip_ply, len(move_sans)))
            num_positions = min(max_positions_per_game, len(eligible))
            rng = random.Random(idx)  # per-game seed keeps the draw deterministic
            sample_indices = set(rng.sample(eligible, num_positions))

            board = chess.Board()
            valid_moves: list[str] = []
            for i, san in enumerate(move_sans):
                try:
                    move = board.parse_san(san)
                except ValueError:
                    # InvalidMoveError, IllegalMoveError and AmbiguousMoveError
                    # all derive from ValueError; catching the base class keeps
                    # one illegal SAN from aborting the whole dataset build.
                    break
                board.push(move)
                valid_moves.append(move.uci())
                if i in sample_indices:
                    token_ids = [cls_id] + tokenizer.encode_moves(valid_moves)
                    samples.append((token_ids, label))

            if (idx + 1) % 50_000 == 0:
                pbar.set_postfix(samples=f"{len(samples):,}")

    return samples
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def stage2_outcome_samples(args: argparse.Namespace) -> None:
    """Stage 2: build the tokenizer and write the outcome-labelled samples.

    Loads ``games_outcome.pt`` from ``args.out_dir``, saves the tokenizer to
    ``tokenizer.pt``, generates outcome samples and stores them as the
    ``outcome`` memmap. Skipped when both ``tokenizer.pt`` and
    ``outcome_meta.pt`` already exist and ``--force`` is not given.
    """
    out_dir = args.out_dir
    tokenizer_path = out_dir / "tokenizer.pt"
    meta_path = out_dir / "outcome_meta.pt"

    already_built = tokenizer_path.exists() and meta_path.exists()
    if already_built and not args.force:
        print(f"Stage 2: skipping — {tokenizer_path.name} and {meta_path.name} exist.")
        return

    raw_games_path = out_dir / "games_outcome.pt"
    print(f"Stage 2: loading outcome games from {raw_games_path}...")
    games = torch.load(raw_games_path, weights_only=False)

    print("Stage 2: building tokenizer from all UCI moves...")
    tokenizer = build_tokenizer_from_games()
    print(f"Stage 2: tokenizer vocab size = {tokenizer.language_size}")
    torch.save(tokenizer, tokenizer_path)

    print("Stage 2: generating outcome samples (up to 20 per game)...")
    samples = _generate_outcome_samples(games, tokenizer, max_positions_per_game=20, skip_ply=0)

    print(f"Stage 2: saving {len(samples):,} outcome samples as memmap...")
    _save_as_memmap(samples, out_dir, "outcome", max_seq_len=args.max_seq_len)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def _generate_policy_sequences(games, tokenizer, max_seq_len: int = 128) -> list[list[int]]:
    """Tokenize full game sequences for policy training.

    Each output sequence is [CLS, m1, m2, ..., mN], truncated to max_seq_len.
    Games with fewer than 2 valid UCI moves are skipped.

    Args:
        games: iterable of dict-like records; only the "movetext" key is read.
        tokenizer: object exposing ``symbol_to_token`` and ``encode_moves``.
        max_seq_len: maximum length of each returned sequence, CLS included.

    Returns:
        One token-id list per usable game.
    """
    cls_id = tokenizer.symbol_to_token[CLS_TOKEN]
    sequences: list[list[int]] = []
    with tqdm(games, desc="Stage 4: policy sequences", unit="game") as pbar:
        for game in pbar:
            movetext = game.get("movetext", "")
            if not movetext:
                continue
            move_sans = parse_movetext(movetext)
            if len(move_sans) < 2:
                continue

            # Replay the game from the initial position to convert SAN to UCI.
            board = chess.Board()
            move_ucis: list[str] = []
            for san in move_sans:
                try:
                    move = board.parse_san(san)
                except ValueError:
                    # parse_san signals invalid, illegal and ambiguous SAN with
                    # ValueError subclasses. Catching the base class also
                    # covers IllegalMoveError, which previously escaped and
                    # aborted the whole stage. Keep the valid prefix.
                    break
                board.push(move)
                move_ucis.append(move.uci())

            if len(move_ucis) < 2:
                continue
            # Reserve one slot for the leading CLS token.
            move_ucis = move_ucis[:max_seq_len - 1]
            sequences.append([cls_id] + tokenizer.encode_moves(move_ucis))
    return sequences
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def _save_policy_memmap(
|
| 269 |
+
sequences: list[list[int]], out_dir: Path, name: str, max_seq_len: int = 128,
|
| 270 |
+
fens: list[str] | None = None, fen_len: int = 100,
|
| 271 |
+
) -> None:
|
| 272 |
+
"""Save policy sequences as memory-mapped arrays (no labels).
|
| 273 |
+
|
| 274 |
+
Produces:
|
| 275 |
+
{name}_tokens.bin — (N, max_len) int32, zero-padded
|
| 276 |
+
{name}_lengths.bin — (N,) int32, actual sequence length per sample
|
| 277 |
+
{name}_meta.pt — dict with 'n', 'max_len', and (if fens given) 'fen_len'
|
| 278 |
+
|
| 279 |
+
If `fens` is provided, also writes {name}_fens.bin — (N, fen_len) uint8
|
| 280 |
+
holding zero-padded ASCII FEN strings, one per sample. Used by the
|
| 281 |
+
CNN-conditioned policy training to reconstruct each sample's starting board.
|
| 282 |
+
"""
|
| 283 |
+
n = len(sequences)
|
| 284 |
+
max_len = min(max(len(s) for s in sequences), max_seq_len)
|
| 285 |
+
print(f" memmap {name}: {n:,} sequences, max_seq_len={max_len}")
|
| 286 |
+
|
| 287 |
+
tokens = np.memmap(out_dir / f"{name}_tokens.bin", dtype=np.int32, mode="w+", shape=(n, max_len))
|
| 288 |
+
lengths = np.memmap(out_dir / f"{name}_lengths.bin", dtype=np.int32, mode="w+", shape=(n,))
|
| 289 |
+
|
| 290 |
+
for i, seq in enumerate(tqdm(sequences, desc=f" writing {name}", unit="seq")):
|
| 291 |
+
seq = seq[:max_len]
|
| 292 |
+
l = len(seq)
|
| 293 |
+
tokens[i, :l] = seq
|
| 294 |
+
lengths[i] = l
|
| 295 |
+
|
| 296 |
+
tokens.flush()
|
| 297 |
+
lengths.flush()
|
| 298 |
+
|
| 299 |
+
meta = {"n": n, "max_len": max_len}
|
| 300 |
+
extra_bytes = 0
|
| 301 |
+
if fens is not None:
|
| 302 |
+
assert len(fens) == n, f"fen count {len(fens)} mismatch with sequence count {n}"
|
| 303 |
+
fens_mm = np.memmap(out_dir / f"{name}_fens.bin", dtype=np.uint8, mode="w+", shape=(n, fen_len))
|
| 304 |
+
for i, fen in enumerate(fens):
|
| 305 |
+
b = fen.encode("ascii")[:fen_len]
|
| 306 |
+
fens_mm[i, :len(b)] = list(b)
|
| 307 |
+
fens_mm.flush()
|
| 308 |
+
meta["fen_len"] = fen_len
|
| 309 |
+
extra_bytes = fens_mm.nbytes
|
| 310 |
+
|
| 311 |
+
torch.save(meta, out_dir / f"{name}_meta.pt")
|
| 312 |
+
size_gb = (tokens.nbytes + lengths.nbytes + extra_bytes) / 1024 ** 3
|
| 313 |
+
print(f" memmap {name} saved ({size_gb:.3f} GB)")
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def _process_puzzle(
|
| 317 |
+
row: dict,
|
| 318 |
+
tokenizer_symbol_map: dict,
|
| 319 |
+
cls_id: int,
|
| 320 |
+
) -> tuple[list[int], str] | None:
|
| 321 |
+
"""Parse one Lichess puzzle row into a (token_sequence, FEN) pair.
|
| 322 |
+
|
| 323 |
+
Sequence layout: [CLS, setup_move, solver_move1, opp_response, solver_move2, ...]
|
| 324 |
+
|
| 325 |
+
The setup move (Moves[0]) is included as context so the model conditions on it
|
| 326 |
+
when predicting the solution. During training the loss on the setup move position
|
| 327 |
+
is masked out — we model P[m_n | S, m_{<n}] where S is the setup move.
|
| 328 |
+
|
| 329 |
+
The FEN is the puzzle's starting board position. It is persisted alongside the
|
| 330 |
+
token sequence so the CNN-conditioned policy training can reconstruct the
|
| 331 |
+
starting board planes (CNN's input) at __getitem__ time.
|
| 332 |
+
|
| 333 |
+
Returns None if any move is illegal, unknown to the tokenizer, or the sequence
|
| 334 |
+
has fewer than 3 tokens (CLS + setup + at least one solver move).
|
| 335 |
+
"""
|
| 336 |
+
fen = row.get("FEN", "")
|
| 337 |
+
moves_str = row.get("Moves", "")
|
| 338 |
+
if not fen or not moves_str:
|
| 339 |
+
return None
|
| 340 |
+
uci_moves = moves_str.strip().split()
|
| 341 |
+
if len(uci_moves) < 2: # need setup + at least one solver move
|
| 342 |
+
return None
|
| 343 |
+
try:
|
| 344 |
+
board = chess.Board(fen)
|
| 345 |
+
except ValueError:
|
| 346 |
+
return None
|
| 347 |
+
|
| 348 |
+
# Tokenize all moves: setup first (as context), then the full solution.
|
| 349 |
+
token_ids: list[int] = [cls_id]
|
| 350 |
+
for uci in uci_moves:
|
| 351 |
+
try:
|
| 352 |
+
move = chess.Move.from_uci(uci)
|
| 353 |
+
except ValueError:
|
| 354 |
+
return None
|
| 355 |
+
if move not in board.legal_moves:
|
| 356 |
+
return None
|
| 357 |
+
if uci not in tokenizer_symbol_map:
|
| 358 |
+
return None
|
| 359 |
+
token_ids.append(tokenizer_symbol_map[uci])
|
| 360 |
+
board.push(move)
|
| 361 |
+
|
| 362 |
+
if len(token_ids) < 3: # CLS + setup + at least one solver move
|
| 363 |
+
return None
|
| 364 |
+
return token_ids, fen
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def stage3_stockfish_samples(args: argparse.Namespace) -> None:
    """Stage 3: label reward-model positions with Stockfish and store them.

    Loads ``games_stockfish.pt`` and ``tokenizer.pt`` from ``args.out_dir``,
    runs the parallel Stockfish sampler and writes the result under the
    ``stockfish`` name. Skipped when ``stockfish_meta.pt`` already exists,
    unless ``args.force`` is set.
    """
    out_dir = args.out_dir
    meta_path = out_dir / "stockfish_meta.pt"
    if not args.force and meta_path.exists():
        print(f"Stage 3: skipping — {meta_path.name} exists.")
        return

    games_path = out_dir / "games_stockfish.pt"
    tokenizer_path = out_dir / "tokenizer.pt"
    print(f"Stage 3: loading {games_path} and {tokenizer_path}...")
    # NOTE: weights_only=False unpickles arbitrary objects; only point this at
    # artifacts produced by this pipeline.
    reward_games = torch.load(games_path, weights_only=False)
    vocab = torch.load(tokenizer_path, weights_only=False)

    print(
        f"Stage 3: running parallel Stockfish ({args.workers} workers, "
        f"depth {args.stockfish_depth}) on {len(reward_games):,} games..."
    )
    sampler_options = {
        "num_workers": args.workers,
        "stockfish_depth": args.stockfish_depth,
        "sample_rate": args.sample_rate,
        "skew_exponent": args.position_skew,
    }
    labelled = generate_samples_stockfish_parallel(reward_games, vocab, **sampler_options)

    print(f"Stage 3: saving {len(labelled):,} stockfish samples as memmap...")
    _save_as_memmap(labelled, out_dir, "stockfish", max_seq_len=args.max_seq_len)
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
def stage4_policy_sequences(args: argparse.Namespace) -> None:
    """Stage 4: tokenize whole games into policy sequences and store them.

    Loads ``games_outcome.pt`` and ``tokenizer.pt`` from ``args.out_dir`` and
    writes the tokenized games under the ``policy`` name. Skipped when
    ``policy_meta.pt`` already exists, unless ``args.force`` is set.
    """
    out_dir = args.out_dir
    meta_path = out_dir / "policy_meta.pt"
    if not args.force and meta_path.exists():
        print(f"Stage 4: skipping — {meta_path.name} exists.")
        return

    games_path = out_dir / "games_outcome.pt"
    tokenizer_path = out_dir / "tokenizer.pt"
    print(f"Stage 4: loading {games_path} and {tokenizer_path}...")
    # NOTE: weights_only=False unpickles arbitrary objects; only point this at
    # artifacts produced by this pipeline.
    policy_games = torch.load(games_path, weights_only=False)
    vocab = torch.load(tokenizer_path, weights_only=False)

    seq_cap = args.max_seq_len
    print(f"Stage 4: tokenizing {len(policy_games):,} games into policy sequences...")
    policy_seqs = _generate_policy_sequences(policy_games, vocab, max_seq_len=seq_cap)

    print(f"Stage 4: saving {len(policy_seqs):,} policy sequences as memmap...")
    _save_policy_memmap(policy_seqs, out_dir, "policy", max_seq_len=seq_cap)
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
def _write_test_subset_reward(out_dir: Path, src_name: str, dst_name: str, indices: np.ndarray) -> None:
    """Copy selected rows of a reward memmap (tokens, labels, lengths) to new files.

    The source shape is read from ``{src_name}_meta.pt``. Row ``indices[i]``
    of every source array becomes row ``i`` of the matching ``{dst_name}_*``
    array, and a ``{dst_name}_meta.pt`` describing the subset is written.
    """
    src_meta = torch.load(out_dir / f"{src_name}_meta.pt", weights_only=True)
    n_src = src_meta["n"]
    max_len = src_meta["max_len"]

    # Open every source array read-only before creating any destination file.
    tokens_in = np.memmap(out_dir / f"{src_name}_tokens.bin", dtype=np.int32, mode="r", shape=(n_src, max_len))
    labels_in = np.memmap(out_dir / f"{src_name}_labels.bin", dtype=np.float32, mode="r", shape=(n_src,))
    lengths_in = np.memmap(out_dir / f"{src_name}_lengths.bin", dtype=np.int32, mode="r", shape=(n_src,))

    n_test = len(indices)
    tokens_out = np.memmap(out_dir / f"{dst_name}_tokens.bin", dtype=np.int32, mode="w+", shape=(n_test, max_len))
    labels_out = np.memmap(out_dir / f"{dst_name}_labels.bin", dtype=np.float32, mode="w+", shape=(n_test,))
    lengths_out = np.memmap(out_dir / f"{dst_name}_lengths.bin", dtype=np.int32, mode="w+", shape=(n_test,))

    column_pairs = (
        (tokens_out, tokens_in),
        (labels_out, labels_in),
        (lengths_out, lengths_in),
    )
    for row, src_row in enumerate(tqdm(indices, desc=f"  writing {dst_name}", unit="sample")):
        for target, source in column_pairs:
            target[row] = source[src_row]

    for target, _ in column_pairs:
        target.flush()

    torch.save({"n": n_test, "max_len": max_len}, out_dir / f"{dst_name}_meta.pt")
    print(f"  {dst_name}: {n_test:,} samples written")
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
def _write_test_subset_policy(out_dir: Path, src_name: str, dst_name: str, indices: np.ndarray) -> None:
    """Copy selected rows of a policy memmap (tokens and lengths, no labels) to new files.

    The source shape is read from ``{src_name}_meta.pt``. Row ``indices[i]``
    of each source array becomes row ``i`` of the matching ``{dst_name}_*``
    array, and a ``{dst_name}_meta.pt`` describing the subset is written.
    """
    src_meta = torch.load(out_dir / f"{src_name}_meta.pt", weights_only=True)
    n_src = src_meta["n"]
    max_len = src_meta["max_len"]

    # Open the source arrays read-only before creating any destination file.
    tokens_in = np.memmap(out_dir / f"{src_name}_tokens.bin", dtype=np.int32, mode="r", shape=(n_src, max_len))
    lengths_in = np.memmap(out_dir / f"{src_name}_lengths.bin", dtype=np.int32, mode="r", shape=(n_src,))

    n_test = len(indices)
    tokens_out = np.memmap(out_dir / f"{dst_name}_tokens.bin", dtype=np.int32, mode="w+", shape=(n_test, max_len))
    lengths_out = np.memmap(out_dir / f"{dst_name}_lengths.bin", dtype=np.int32, mode="w+", shape=(n_test,))

    progress = tqdm(indices, desc=f"  writing {dst_name}", unit="seq")
    for row, src_row in enumerate(progress):
        tokens_out[row] = tokens_in[src_row]
        lengths_out[row] = lengths_in[src_row]

    tokens_out.flush()
    lengths_out.flush()

    subset_meta = {"n": n_test, "max_len": max_len}
    torch.save(subset_meta, out_dir / f"{dst_name}_meta.pt")
    print(f"  {dst_name}: {n_test:,} sequences written")
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
def stage_build_test_splits(args: argparse.Namespace, out_dir: Path) -> None:
    """Build held-out test sets for reward and policy models from existing memmaps.

    Each split draws its indices from its own generator seeded with 42, so the
    selection for one split never depends on whether the other split was built
    in the same run (a shared generator made the policy indices change
    whenever the reward split was skipped or already present).

    Saves the chosen indices to {name}_test_indices.npy so the corresponding
    training memmap loader can exclude them — making train and test disjoint
    even though they share the underlying .bin file.

    Produces:
        stockfish_test_*.bin / stockfish_test_meta.pt / stockfish_test_indices.npy
        policy_test_*.bin   / policy_test_meta.pt   / policy_test_indices.npy

    Args:
        args: parsed CLI namespace; reads ``force``, ``reward_test_size``,
            ``policy_test_size`` and the optional ``policy_only`` flag.
        out_dir: directory holding the source memmaps and receiving the splits.
    """
    seed = 42
    policy_only = getattr(args, "policy_only", False)

    # Reward test set — skipped when --policy-only since no Stockfish data exists.
    reward_test_meta = out_dir / "stockfish_test_meta.pt"
    sf_meta_path = out_dir / "stockfish_meta.pt"
    if policy_only:
        print("Test splits: stockfish_test skipped (--policy-only).")
    elif (not reward_test_meta.exists() or args.force) and sf_meta_path.exists():
        print(f"Test splits: building stockfish_test ({args.reward_test_size:,} samples)...")
        sf_meta = torch.load(sf_meta_path, weights_only=True)
        n = sf_meta["n"]
        test_n = min(args.reward_test_size, n)
        idx = np.random.default_rng(seed).choice(n, size=test_n, replace=False)
        idx.sort()
        _write_test_subset_reward(out_dir, "stockfish", "stockfish_test", idx)
        np.save(out_dir / "stockfish_test_indices.npy", idx)
        print(f"  saved stockfish_test_indices.npy ({test_n:,} indices excluded from training)")
    elif reward_test_meta.exists():
        print("Test splits: stockfish_test already exists, skipping.")
    else:
        # Previously this case was skipped without any message.
        print(f"Test splits: stockfish_test not built — {sf_meta_path.name} not found.")

    # Policy test set
    policy_test_meta = out_dir / "policy_test_meta.pt"
    pol_meta_path = out_dir / "policy_meta.pt"
    if (not policy_test_meta.exists() or args.force) and pol_meta_path.exists():
        print(f"Test splits: building policy_test ({args.policy_test_size:,} sequences)...")
        pol_meta = torch.load(pol_meta_path, weights_only=True)
        n = pol_meta["n"]
        test_n = min(args.policy_test_size, n)
        # Fresh generator: independent of the reward branch above.
        idx = np.random.default_rng(seed).choice(n, size=test_n, replace=False)
        idx.sort()
        _write_test_subset_policy(out_dir, "policy", "policy_test", idx)
        np.save(out_dir / "policy_test_indices.npy", idx)
        print(f"  saved policy_test_indices.npy ({test_n:,} indices excluded from training)")
    elif policy_test_meta.exists():
        print("Test splits: policy_test already exists, skipping.")
    else:
        print(f"Test splits: policy_test not built — {pol_meta_path.name} not found.")
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
def stage5_puzzle_samples(args: argparse.Namespace, tokenizer, out_dir: Path) -> None:
    """Stage 5: stream Lichess puzzles and write the puzzle train/test memmaps.

    Puzzles are filtered by the optional popularity / play-count thresholds and
    converted with `_process_puzzle`. The first ``args.puzzle_test_size`` valid
    puzzles form the ``puzzle_test`` split; later ones form the ``puzzle``
    training split, up to ``args.puzzle_count`` when that is set.

    The stage is skipped when both ``puzzle_meta.pt`` and
    ``puzzle_test_meta.pt`` exist and ``args.force`` is not set. Without
    ``--force`` only the missing split is written; with ``--force`` both are
    rewritten (previously a forced run streamed every puzzle and then wrote
    nothing because the existing outputs suppressed the save).

    Args:
        args: parsed CLI namespace.
        tokenizer: object exposing ``symbol_to_token``.
        out_dir: directory receiving the memmap files.
    """
    train_done = (out_dir / "puzzle_meta.pt").exists()
    test_done = (out_dir / "puzzle_test_meta.pt").exists()
    if train_done and test_done and not args.force:
        print("Stage 5: skipping — puzzle_meta.pt and puzzle_test_meta.pt exist.")
        return

    write_test = args.force or not test_done
    write_train = args.force or not train_done

    print("Stage 5: loading Lichess/chess-puzzles from HuggingFace...")
    ds = load_dataset("Lichess/chess-puzzles", split="train", streaming=True)

    min_pop = args.min_puzzle_popularity
    min_plays = args.min_puzzle_plays
    cls_id = tokenizer.symbol_to_token[CLS_TOKEN]
    sym_map = tokenizer.symbol_to_token
    test_seqs: list[list[int]] = []
    test_fens: list[str] = []
    train_seqs: list[list[int]] = []
    train_fens: list[str] = []
    skipped = 0
    test_target = args.puzzle_test_size
    train_target = args.puzzle_count

    with tqdm(ds, desc="Stage 5: puzzles", unit="puzzle") as pbar:
        for row in pbar:
            if min_pop is not None and row.get("Popularity", 0) < min_pop:
                continue
            if min_plays is not None and row.get("NbPlays", 0) < min_plays:
                continue
            result = _process_puzzle(row, sym_map, cls_id)
            if result is None:
                skipped += 1
                continue
            seq, fen = result
            # Fill the held-out split first so it is a stable prefix of the stream.
            if len(test_seqs) < test_target:
                test_seqs.append(seq)
                test_fens.append(fen)
            else:
                train_seqs.append(seq)
                train_fens.append(fen)
            pbar.set_postfix(test=len(test_seqs), train=len(train_seqs), skipped=skipped)
            if train_target is not None and len(train_seqs) >= train_target:
                break

    print(
        f"Stage 5: test={len(test_seqs):,} puzzles, "
        f"train={len(train_seqs):,} puzzles, "
        f"skipped={skipped:,} invalid."
    )
    if test_seqs and write_test:
        _save_policy_memmap(
            test_seqs, out_dir, "puzzle_test", max_seq_len=args.max_seq_len, fens=test_fens,
        )
    if train_seqs and write_train:
        _save_policy_memmap(
            train_seqs, out_dir, "puzzle", max_seq_len=args.max_seq_len, fens=train_fens,
        )
    elif not train_seqs:
        print("Stage 5: WARNING — no training puzzles collected.")
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
def main():
    """Parse the command line and run the dataset-building stages in order.

    Three modes exist: ``--puzzles-only`` runs Stage 5 alone (building the
    tokenizer if it is missing), ``--policy-only`` runs the pipeline without
    Stage 3, and the default runs every stage. The two mode flags are
    mutually exclusive. A size listing of the known artifacts is printed last.
    """
    parser = argparse.ArgumentParser(description=__doc__)

    # (flag, add_argument keyword options), kept in registration order.
    option_table = [
        ("--out-dir", dict(type=Path, default=Path("data"))),
        ("--policy-games", dict(type=int, default=1_000_000,
                                help="Number of games to collect for policy model training")),
        ("--reward-games", dict(type=int, default=1_000_000,
                                help="Number of games to collect for reward model (Stockfish eval)")),
        ("--policy-min-elo", dict(type=int, default=1800,
                                  help="Min Elo for both players in policy training games")),
        ("--reward-min-elo", dict(type=int, default=1500,
                                  help="Min Elo for both players in reward model training games")),
        ("--sample-rate", dict(type=float, default=0.25,
                               help="Fraction of positions to sample per game (scales with game length)")),
        ("--position-skew", dict(type=float, default=1.5,
                                 help="Power-law exponent weighting later positions; 1.0=linear, higher=more mid/late")),
        ("--workers", dict(type=int, default=16)),
        ("--stockfish-depth", dict(type=int, default=12)),
        ("--max-seq-len", dict(type=int, default=128,
                               help="Truncate token sequences to this length when writing .bin files")),
        ("--force", dict(action="store_true",
                         help="Re-run all stages even if their outputs already exist")),
        ("--puzzle-count", dict(type=int, default=None, dest="puzzle_count",
                                help="Max puzzles to include (default: all ~4.99M)")),
        ("--min-puzzle-popularity", dict(type=int, default=None, dest="min_puzzle_popularity",
                                         help="Min Lichess Popularity score (0-100 scale)")),
        ("--min-puzzle-plays", dict(type=int, default=None, dest="min_puzzle_plays",
                                    help="Min NbPlays for a puzzle to be included")),
        ("--skip-puzzles", dict(action="store_true",
                                help="Skip Stage 5 puzzle processing")),
        ("--puzzles-only", dict(action="store_true",
                                help="Only run Stage 5 (puzzle processing). Skips game collection, "
                                     "outcome/Stockfish/policy memmaps, and test splits. Requires "
                                     "tokenizer.pt to exist (or it will be built from the UCI vocab).")),
        ("--policy-only", dict(action="store_true",
                               help="Skip everything Stockfish/reward-related: Stage 1 collects only "
                                    "policy games, Stage 3 (Stockfish labeling) is skipped, and the "
                                    "stockfish_test split is not built. Stages 1/2/4/5 + policy_test "
                                    "still run, producing tokenizer.pt, policy_* / puzzle_* memmaps, "
                                    "and the policy_test split.")),
        ("--puzzle-test-size", dict(type=int, default=100_000, dest="puzzle_test_size",
                                    help="Number of puzzle sequences held out for the test set (default: 100000)")),
        ("--reward-test-size", dict(type=int, default=50_000, dest="reward_test_size",
                                    help="Number of reward positions held out for the test set (default: 50000)")),
        ("--policy-test-size", dict(type=int, default=50_000, dest="policy_test_size",
                                    help="Number of policy sequences held out for the test set (default: 50000)")),
    ]
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    out_dir = args.out_dir
    out_dir.mkdir(parents=True, exist_ok=True)

    if args.puzzles_only and args.policy_only:
        parser.error("--puzzles-only and --policy-only are mutually exclusive.")

    tokenizer_path = out_dir / "tokenizer.pt"

    if args.puzzles_only:
        print("--puzzles-only: skipping Stages 1-4 and test-split builder.")
        if tokenizer_path.exists():
            tokenizer = torch.load(tokenizer_path, weights_only=False)
        else:
            # Tokenizer is just the enumerated UCI vocab — no games needed.
            print("  tokenizer.pt missing; building from UCI vocab...")
            tokenizer = build_tokenizer_from_games()
            torch.save(tokenizer, tokenizer_path)
        stage5_puzzle_samples(args, tokenizer, out_dir)
    else:
        # The policy-only and full pipelines differ only in Stage 3.
        if args.policy_only:
            print("--policy-only: skipping Stage 3 (Stockfish labeling) and stockfish_test split.")
        stage1_collect_games(args)
        stage2_outcome_samples(args)
        if not args.policy_only:
            stage3_stockfish_samples(args)
        stage4_policy_sequences(args)

        if not args.skip_puzzles:
            if tokenizer_path.exists():
                tokenizer = torch.load(tokenizer_path, weights_only=False)
                stage5_puzzle_samples(args, tokenizer, out_dir)
            else:
                print("Stage 5: skipping — tokenizer.pt not found (run stages 1-2 first).")

        stage_build_test_splits(args, out_dir)

    print("\nAll stages complete. Artifacts:")
    artifact_names = (
        "games_outcome.pt",
        "games_stockfish.pt",
        "tokenizer.pt",
        "outcome_tokens.bin",
        "outcome_labels.bin",
        "outcome_lengths.bin",
        "outcome_meta.pt",
        "stockfish_tokens.bin",
        "stockfish_labels.bin",
        "stockfish_lengths.bin",
        "stockfish_meta.pt",
        "policy_tokens.bin",
        "policy_lengths.bin",
        "policy_meta.pt",
        "puzzle_tokens.bin",
        "puzzle_lengths.bin",
        "puzzle_fens.bin",
        "puzzle_meta.pt",
        "puzzle_test_tokens.bin",
        "puzzle_test_lengths.bin",
        "puzzle_test_fens.bin",
        "puzzle_test_meta.pt",
        "stockfish_test_tokens.bin",
        "stockfish_test_labels.bin",
        "stockfish_test_lengths.bin",
        "stockfish_test_meta.pt",
        "policy_test_tokens.bin",
        "policy_test_lengths.bin",
        "policy_test_meta.pt",
    )
    for artifact in artifact_names:
        path = out_dir / artifact
        # Artifacts that were not produced in this mode are listed as 0.0 MB.
        if path.exists():
            size_mb = path.stat().st_size / 1024 / 1024
        else:
            size_mb = 0
        print(f"  {path} ({size_mb:.1f} MB)")
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
# Run the dataset pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
|
src/minimax.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import chess
|
| 2 |
+
import math
|
| 3 |
+
from typing import Callable
|
| 4 |
+
|
| 5 |
+
from model import PIECE_VALUES
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def dummy_reward_fn(board: chess.Board) -> float:
    """Material-count heuristic: positive favors white.

    Sums ``PIECE_VALUES`` over White's pieces, subtracts the same sum for
    Black, and squashes the balance with ``tanh(balance / 10)``.
    """
    balance = 0.0
    for piece_type, value in PIECE_VALUES.items():
        white_count = len(board.pieces(piece_type, chess.WHITE))
        black_count = len(board.pieces(piece_type, chess.BLACK))
        balance += white_count * value
        balance -= black_count * value
    return math.tanh(balance / 10.0)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class MinimaxSearch:
    """Depth-limited minimax with top-N move pruning.

    At every node each legal move is applied and the resulting position is
    scored with the reward function. Only the ``top_n`` moves that look best
    for the side to move are searched further. White maximizes the reward,
    Black minimizes it.
    """

    def __init__(
        self,
        reward_fn: Callable[[chess.Board], float],
        depth: int = 3,
        top_n: int = 5,
    ):
        self.reward_fn = reward_fn  # position evaluator; higher favors White
        self.depth = depth          # plies to look ahead from the root
        self.top_n = top_n          # candidate moves kept per node

    def _shortlist(self, board: chess.Board, moves: list, maximizing: bool) -> list:
        """Return up to ``top_n`` of ``moves``, best first for the side to move."""
        rated = []
        for move in moves:
            board.push(move)
            rated.append((self.reward_fn(board), move))
            board.pop()
        # sorted() is stable, so equally rated moves keep generation order.
        ranked = sorted(rated, key=lambda pair: pair[0], reverse=maximizing)
        return [move for _, move in ranked[: self.top_n]]

    def search(self, board: chess.Board) -> chess.Move:
        """Return the best move for the current side to play.

        Raises:
            ValueError: if the position has no legal moves.
        """
        legal_moves = list(board.legal_moves)
        if not legal_moves:
            raise ValueError("No legal moves available")
        if len(legal_moves) == 1:
            # A forced move needs no search.
            return legal_moves[0]

        maximizing = board.turn == chess.WHITE
        candidates = self._shortlist(board, legal_moves, maximizing)

        # Fall back to the best shallow-scored move if no candidate improves
        # on the initial bound.
        best_move = candidates[0]
        best_value = float("-inf") if maximizing else float("inf")

        for move in candidates:
            board.push(move)
            value = self._minimax(board, self.depth - 1, not maximizing)
            board.pop()

            improved = value > best_value if maximizing else value < best_value
            if improved:
                best_value = value
                best_move = move

        return best_move

    def _minimax(self, board: chess.Board, depth: int, maximizing: bool) -> float:
        """Return the minimax value of ``board`` searched ``depth`` plies deep."""
        if depth <= 0 or board.is_game_over():
            return self._terminal_eval(board)

        legal_moves = list(board.legal_moves)
        if not legal_moves:
            return self._terminal_eval(board)

        candidates = self._shortlist(board, legal_moves, maximizing)

        if maximizing:
            pick = max
            best = float("-inf")
        else:
            pick = min
            best = float("inf")

        for move in candidates:
            board.push(move)
            best = pick(best, self._minimax(board, depth - 1, not maximizing))
            board.pop()
        return best

    def _terminal_eval(self, board: chess.Board) -> float:
        """Evaluate a terminal or leaf node."""
        if board.is_checkmate():
            # The side to move is checkmated
            white_is_mated = board.turn == chess.WHITE
            return -1.0 if white_is_mated else 1.0
        if board.is_game_over():
            # Any other finished game is scored as a draw.
            return 0.0
        return self.reward_fn(board)
|
src/model.py
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import chess
|
| 5 |
+
|
| 6 |
+
from src.tokenizer import Tokenizer
|
| 7 |
+
|
| 8 |
+
# Special vocabulary symbols. The dataset builder looks CLS_TOKEN up in the
# tokenizer and places its id at the start of every sequence.
# NOTE(review): PAD_TOKEN is presumably the padding symbol — its use is not
# visible in this part of the file; confirm.
CLS_TOKEN = "[CLS]"
PAD_TOKEN = "[PAD]"

# Material value per piece type; the king is given no material value.
PIECE_VALUES = {
    chess.PAWN: 1,
    chess.KNIGHT: 3,
    chess.BISHOP: 3,
    chess.ROOK: 5,
    chess.QUEEN: 9,
    chess.KING: 0,
}

# Number of feature planes written by board_to_planes: 12 piece planes,
# side to move, 4 castling rights, en-passant square, halfmove clock.
BOARD_PLANES = 19
|
| 21 |
+
|
| 22 |
+
def board_to_planes(board: chess.Board) -> torch.Tensor:
    """Encode a position as a (BOARD_PLANES, 8, 8) float32 tensor.

    Plane layout (rows are ranks, columns are files):
        0-5    white pawn, knight, bishop, rook, queen, king
        6-11   the same six piece types for black
        12     all ones when it is white's turn, zeros otherwise
        13-16  castling rights: white kingside, white queenside,
               black kingside, black queenside (all ones when available)
        17     one-hot en-passant square, if any
        18     halfmove clock, clipped to 100 and scaled to [0, 1]
    """
    piece_order = (
        chess.PAWN, chess.KNIGHT, chess.BISHOP,
        chess.ROOK, chess.QUEEN, chess.KING,
    )
    # (piece_type, color) -> plane index; white occupies 0-5, black 6-11.
    plane_index = {}
    for color_offset, color in ((0, chess.WHITE), (6, chess.BLACK)):
        for piece_idx, piece_type in enumerate(piece_order):
            plane_index[(piece_type, color)] = color_offset + piece_idx

    planes = torch.zeros(BOARD_PLANES, 8, 8, dtype=torch.float32)

    for square, piece in board.piece_map().items():
        rank = chess.square_rank(square)
        file = chess.square_file(square)
        planes[plane_index[(piece.piece_type, piece.color)], rank, file] = 1.0

    if board.turn == chess.WHITE:
        planes[12] = 1.0

    castling_flags = (
        (13, board.has_kingside_castling_rights(chess.WHITE)),
        (14, board.has_queenside_castling_rights(chess.WHITE)),
        (15, board.has_kingside_castling_rights(chess.BLACK)),
        (16, board.has_queenside_castling_rights(chess.BLACK)),
    )
    for plane, available in castling_flags:
        if available:
            planes[plane] = 1.0

    ep_square = board.ep_square
    if ep_square is not None:
        planes[17, chess.square_rank(ep_square), chess.square_file(ep_square)] = 1.0

    # Halfmove clock is capped at 100 so the plane stays in [0, 1].
    planes[18] = min(board.halfmove_clock, 100) / 100.0

    return planes
|
| 47 |
+
|
| 48 |
+
def _group_norm(channels: int, groups: int = 32) -> nn.GroupNorm:
|
| 49 |
+
return nn.GroupNorm(num_groups=min(groups, channels), num_channels=channels)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class ResidualBlock(nn.Module):
    """Two bias-free 3x3 conv + GroupNorm stages with an identity skip connection."""

    def __init__(self, channels: int):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.norm1 = _group_norm(channels)
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.norm2 = _group_norm(channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(x + norm2(conv2(relu(norm1(conv1(x))))))``."""
        skip = x
        out = self.conv1(x)
        out = self.norm1(out)
        out = torch.relu(out)
        out = self.conv2(out)
        out = self.norm2(out)
        # The skip is added before the final activation.
        out = out + skip
        return torch.relu(out)
|
| 64 |
+
class BoardCNN(nn.Module):
    """Residual CNN that turns board planes into one d_model vector per square."""

    def __init__(self, d_model, channels=128, num_blocks=6):
        super().__init__()
        stem_conv = nn.Conv2d(BOARD_PLANES, channels, 3, padding=1, bias=False)
        stem_norm = _group_norm(channels)
        self.stem = nn.Sequential(stem_conv, stem_norm, nn.ReLU(inplace=True))

        tower = [ResidualBlock(channels) for _ in range(num_blocks)]
        self.blocks = nn.Sequential(*tower)

        self.proj = nn.Linear(channels, d_model)
        # Learned per-square offset; the embedding table is used directly
        # through its weight matrix rather than by index lookup.
        self.square_pos = nn.Embedding(64, d_model)

    def forward(self, planes : torch.Tensor) -> torch.Tensor:
        """Map (N, BOARD_PLANES, 8, 8) planes to (N, 64, d_model) square features."""
        feats = self.blocks(self.stem(planes))             # (N, C, 8, 8)
        batch = feats.size(0)
        # Channels-last, then flatten the 8x8 grid into 64 square tokens.
        squares = feats.permute(0, 2, 3, 1).reshape(batch, 64, -1)  # (N, 64, C)
        return self.proj(squares) + self.square_pos.weight  # (N, 64, d_model)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class CrossAttnBlock(nn.Module):
    """Pre-norm block: self-attention over moves, gated cross-attention into
    each position's own board squares, then a feed-forward layer."""

    def __init__(self, d_model, n_head, dim_ff, dropout):
        super().__init__()
        attn_kwargs = dict(dropout=dropout, batch_first=True)
        self.self_attn = nn.MultiheadAttention(d_model, n_head, **attn_kwargs)
        self.cross_attn = nn.MultiheadAttention(d_model, n_head, **attn_kwargs)

        expand = nn.Linear(d_model, dim_ff)
        contract = nn.Linear(dim_ff, d_model)
        self.ff = nn.Sequential(expand, nn.GELU(), contract)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)
        # Gate starts at zero, so tanh(gate) == 0 and the cross-attention
        # branch contributes nothing until the gate is trained away from zero.
        self.cross_gate = nn.Parameter(torch.zeros(1))

    def forward(self, moves, board, key_padding_mask, attn_mask):
        """
        moves: (B, T, d)
        board: (B, T, 64, d) -- one 64-square key/value bank per position
        key_padding_mask: (B, T) -- True = padded move position
        attn_mask: (T, T) -- causal mask for self-attn
        Returns the updated move stream, shape (B, T, d).
        """
        batch, seq_len, width = moves.shape

        # 1) Self-attention over the move sequence.
        normed = self.norm1(moves)
        self_out = self.self_attn(
            normed, normed, normed,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]
        moves = moves + self.drop(self_out)

        # 2) Cross-attention: fold (B, T) into the batch axis so every move
        #    position is a length-1 query over its own 64 board squares.
        query = self.norm2(moves).reshape(batch * seq_len, 1, width)
        bank = board.reshape(batch * seq_len, 64, width)
        cross_out = self.cross_attn(query, bank, bank, need_weights=False)[0]
        cross_out = cross_out.reshape(batch, seq_len, width)
        gate = self.cross_gate.tanh()
        moves = moves + self.drop(gate * cross_out)

        # 3) Position-wise feed-forward.
        return moves + self.drop(self.ff(self.norm3(moves)))
|
| 125 |
+
|
| 126 |
+
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding followed by dropout.

    Args:
        d_model: embedding width (even or odd).
        max_len: number of positions in the precomputed table; inputs longer
            than this cannot be encoded.
        dropout: dropout probability applied after adding the encoding.
    """

    def __init__(self, d_model: int, max_len: int = 512, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term                      # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine column,
        # so trim the angles to avoid a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        # Registered as a buffer so it follows .to(device) and is saved in
        # the state dict without being a trainable parameter.
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add encodings for the first ``x.size(1)`` positions to ``x`` (B, T, d_model)."""
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class ChessRewardModel(nn.Module):
    """Transformer encoder that scores a move sequence from its CLS position."""

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 768,
        nhead: int = 12,
        num_layers: int = 8,
        dim_feedforward: int = 3072,
        max_seq_len: int = 128,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.reward_head = nn.Linear(d_model, 1)

    def forward(
        self,
        token_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Args:
            token_ids: (batch, seq_len) int tensor with CLS prepended
            attention_mask: (batch, seq_len) bool tensor, True where padded
        Returns:
            (batch,) float tensor bounded to [-1, 1]
        """
        hidden = self.pos_encoding(self.token_embedding(token_ids))
        hidden = self.encoder(hidden, src_key_padding_mask=attention_mask)
        # The reward is read from the CLS token at position 0.
        cls_state = hidden[:, 0]
        score = self.reward_head(cls_state).squeeze(-1)
        return torch.tanh(score)
|
| 181 |
+
|
| 182 |
+
class ChessPolicyModel(nn.Module):
    """Causal next-move predictor with per-position live-board cross-attention.

    Every block sees two streams:
      - Move stream: token embeddings plus sinusoidal positional encoding,
        processed with causal self-attention over the move history.
      - Board stream: a (B, T, 64, d_model) bank of CNN-encoded board
        features, where bank `t` is the state after token_ids[1..t] have
        been played. The move query at position t cross-attends only to its
        own 64 board-square keys, so causality comes from the data layout
        and needs no extra mask.

    Because the board features at a position never depend on the token the
    model is asked to predict there, training on all positions at once does
    not leak the target.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 768,
        nhead: int = 12,
        num_layers: int = 8,
        dim_feedforward: int = 3072,
        max_seq_len: int = 128,
        dropout: float = 0.1,
        cnn_channels: int = 128,
        cnn_blocks: int = 6,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.board_cnn = BoardCNN(d_model, cnn_channels, cnn_blocks)
        self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
        layers = [
            CrossAttnBlock(d_model, nhead, dim_feedforward, dropout)
            for _ in range(num_layers)
        ]
        self.blocks = nn.ModuleList(layers)
        self.norm_out = nn.LayerNorm(d_model)
        self.prob_head = nn.Linear(d_model, vocab_size)

    def forward(
        self,
        token_ids: torch.Tensor,
        board_planes: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Args:
            token_ids: (B, T) int — CLS at position 0
            board_planes: (B, T, 19, 8, 8) float — per-position live planes;
                planes[:, t] is the board state after token_ids[1..t]
            attention_mask: (B, T) bool — True where padded
        Returns:
            (B, T, vocab_size) raw logits at every position
        """
        batch, seq_len = token_ids.shape

        moves = self.pos_encoding(self.token_embedding(token_ids))   # (B, T, d)

        # Run the CNN once over all B*T boards instead of looping over positions.
        flat_planes = board_planes.reshape(batch * seq_len, BOARD_PLANES, 8, 8)
        board_feats = self.board_cnn(flat_planes)                     # (B*T, 64, d)
        board_feats = board_feats.reshape(batch, seq_len, 64, -1)     # (B, T, 64, d)

        # Boolean causal mask (True = future position, masked) so its dtype
        # matches the boolean key padding mask.
        future = torch.ones(
            seq_len, seq_len, dtype=torch.bool, device=token_ids.device
        ).triu(diagonal=1)

        for block in self.blocks:
            moves = block(moves, board_feats, attention_mask, future)

        return self.prob_head(self.norm_out(moves))                   # (B, T, vocab)
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
class DummyRewardModel:
    """Material-count heuristic for MCTS testing."""

    def __call__(self, board: chess.Board) -> float:
        """Return tanh((white material - black material) / 10) using PIECE_VALUES."""
        balance = 0.0
        for piece_type, value in PIECE_VALUES.items():
            white_count = len(board.pieces(piece_type, chess.WHITE))
            black_count = len(board.pieces(piece_type, chess.BLACK))
            balance += white_count * value
            balance -= black_count * value
        # Dividing by 10 keeps typical material gaps out of tanh's flat region.
        return math.tanh(balance / 10.0)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
class RewardModelInference:
    """Wraps ChessRewardModel + Tokenizer for use in minimax"""

    def __init__(self, model: ChessRewardModel, tokenizer: Tokenizer, device: str = "cpu"):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.cls_id = tokenizer.symbol_to_token[CLS_TOKEN]
        self.pad_id = tokenizer.symbol_to_token[PAD_TOKEN]
        # Inference only: switch off dropout.
        self.model.eval()

    @torch.no_grad()
    def __call__(self, board: chess.Board, max_seq_len: int = 128) -> float:
        """Score ``board`` from its move history.

        Args:
            board: position whose ``move_stack`` holds the game so far.
            max_seq_len: maximum number of tokens fed to the model, CLS
                included.

        Returns:
            The model's scalar output for the (possibly truncated) history.
        """
        moves_uci = [move.uci() for move in board.move_stack]

        # Keep the most recent moves to stay within the training sequence length.
        # CLS occupies position 0, so cap move history at max_seq_len - 1.
        # A cap of 0 needs its own branch because a [-0:] slice would keep
        # every move rather than none.
        keep = max(max_seq_len - 1, 0)
        moves_uci = moves_uci[-keep:] if keep else []

        token_ids = [self.cls_id] + self.tokenizer.encode_moves(moves_uci)
        token_tensor = torch.tensor([token_ids], dtype=torch.long, device=self.device)
        reward = self.model(token_tensor)
        return reward.item()
|
| 289 |
+
|
| 290 |
+
class PolicyModelInference:
    """Wraps ChessPolicyModel + Tokenizer"""

    def __init__(self, model: ChessPolicyModel, tokenizer: Tokenizer, device: str = "cpu"):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.cls_id = tokenizer.symbol_to_token[CLS_TOKEN]
        self.pad_id = tokenizer.symbol_to_token[PAD_TOKEN]
        # Inference only: switch off dropout.
        self.model.eval()

    def _context_limit(self) -> int | None:
        """Return the length of the model's positional table, if it exposes one."""
        table = getattr(getattr(self.model, "pos_encoding", None), "pe", None)
        return None if table is None else int(table.size(1))

    @torch.no_grad()
    def __call__(self, board: chess.Board, max_seq_len: int | None = None) -> str:
        """Return the UCI string of the highest-scoring legal move.

        Args:
            board: current position; its ``move_stack`` must replay from the
                standard starting position.
            max_seq_len: maximum number of tokens (CLS included) fed to the
                model. Defaults to the length of the model's positional
                table, so long games no longer overflow it.

        Raises:
            ValueError: if the position has no legal moves.
        """
        legal_moves = list(board.legal_moves)
        if not legal_moves:
            # With every logit masked to -inf, argmax would return an
            # arbitrary token, so fail explicitly instead.
            raise ValueError("no legal moves available in this position")

        history = list(board.move_stack)
        if max_seq_len is None:
            max_seq_len = self._context_limit()
        # CLS occupies position 0, so at most max_seq_len - 1 moves fit.
        skipped = 0
        if max_seq_len is not None:
            skipped = max(len(history) - max(max_seq_len - 1, 0), 0)
        window = history[skipped:]

        token_ids = [self.cls_id] + self.tokenizer.encode_moves([m.uci() for m in window])
        token_tensor = torch.tensor([token_ids], dtype=torch.long, device=self.device)

        # Replay the game on a fresh board, snapshotting planes at every
        # position of the window: planes[0] is the state the window starts
        # from (the standard starting board when nothing is truncated) and
        # planes[t] is the state after t window moves.
        # NOTE(review): for truncated histories this assumes training windows
        # also use the window's starting position for planes[0] — confirm
        # against ChessPolicyDataset._replay_planes.
        replay_board = chess.Board()
        for move in history[:skipped]:
            replay_board.push(move)
        plane_list = [board_to_planes(replay_board)]
        for move in window:
            replay_board.push(move)
            plane_list.append(board_to_planes(replay_board))
        planes = torch.stack(plane_list).unsqueeze(0).to(self.device)  # (1, T, 19, 8, 8)

        logits = self.model(token_tensor, planes)  # (1, T, vocab_size)
        last_logits = logits[0, -1]  # last position predicts the next move

        # Additive mask: 0 for legal moves, -inf elsewhere. It is sized from
        # the logits so it always matches the model's output width.
        legal_move_ids = [self.tokenizer.symbol_to_token[move.uci()] for move in legal_moves]
        mask = torch.full_like(last_logits, float('-inf'))
        mask[legal_move_ids] = 0.0
        best_move_idx = (last_logits + mask).argmax().item()

        return self.tokenizer.token_to_symbol[best_move_idx]
|
src/tokenizer.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
import typing
|
| 4 |
+
from collections import deque, defaultdict
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class Tokenizer():
    """Symbol vocabulary with optional pair merging (BPE-style).

    Attributes are plain dicts so the object can be pickled and reloaded:
    ``symbol_to_token`` / ``token_to_symbol`` map between symbols and integer
    ids, and ``language_size`` is the next free id.
    """

    def __init__(self):
        self.symbol_set : set = None
        self.symbol_to_token = {}
        self.token_to_symbol = {}
        self.language_size = 0
        self.corpus = None

    def train_tokenizer(self, input, max_language_size: int) -> None:
        """Build the vocabulary from a corpus.

        Args:
            input: a comma-separated string of symbols, or an iterable of
                symbols.
            max_language_size: merging stops once the vocabulary reaches this
                size, or earlier if no adjacent pair is left to merge.

        Base symbols get ids in sorted order so the mapping is reproducible
        across runs. After that, the most frequent adjacent pair is merged
        into a new symbol repeatedly.
        """
        symbols = input.split(",") if isinstance(input, str) else list(input)
        self.corpus = symbols
        self.symbol_set = set(symbols)
        # Sorted, because set iteration order for strings varies between
        # interpreter runs and would otherwise shuffle the token ids.
        for sym in sorted(self.symbol_set):
            self.symbol_to_token[sym] = self.language_size
            self.token_to_symbol[self.language_size] = sym
            self.language_size += 1

        # Convert the corpus from symbols to token ids.
        corpus = [self.symbol_to_token[sym] for sym in symbols]

        while self.language_size < max_language_size:
            # The first pair to reach the highest running count wins ties.
            common_pair = None
            highest_pair_count = 0
            pair_counts = defaultdict(int)
            for pair in zip(corpus, corpus[1:]):
                pair_counts[pair] += 1
                if pair_counts[pair] > highest_pair_count:
                    highest_pair_count = pair_counts[pair]
                    common_pair = pair
            if common_pair is None:
                # Fewer than two tokens remain: nothing left to merge.
                break

            merged_id = self.language_size
            synthetic_symbol = (
                self.token_to_symbol[common_pair[0]] + self.token_to_symbol[common_pair[1]]
            )
            self.symbol_to_token[synthetic_symbol] = merged_id
            self.token_to_symbol[merged_id] = synthetic_symbol
            self.language_size += 1

            # Replace every occurrence of the pair, scanning left to right.
            # Advancing one token on a miss means occurrences at odd offsets
            # are found as well.
            merged = []
            i = 0
            last = len(corpus) - 1
            while i <= last:
                if i < last and (corpus[i], corpus[i + 1]) == common_pair:
                    merged.append(merged_id)
                    i += 2
                else:
                    merged.append(corpus[i])
                    i += 1
            corpus = merged

        # The corpus is only needed during training.
        self.corpus = None

    def decode(self, tokens: list[int]) -> str:
        """Concatenate the symbols for ``tokens`` into one string."""
        return "".join([self.token_to_symbol[t] for t in tokens])

    def encode(self, message: str):
        """Greedily split ``message`` into known symbols and return their ids.

        The current symbol is extended one character at a time while the
        extension is still in the vocabulary; when it is not, the current
        symbol is emitted and matching restarts at that character.

        Raises:
            KeyError: if a character cannot start any known symbol.
        """
        vocab = self.symbol_to_token
        result_tokens = []
        curr_symbol = ""
        pos = 0
        while pos < len(message):
            candidate = curr_symbol + message[pos]
            if candidate in vocab:
                curr_symbol = candidate
                pos += 1
            elif curr_symbol:
                # Emit what matched so far and retry this character fresh.
                result_tokens.append(vocab[curr_symbol])
                curr_symbol = ""
            else:
                raise KeyError(f"symbol {message[pos]!r} is not in the vocabulary")
        if curr_symbol:
            result_tokens.append(vocab[curr_symbol])

        return result_tokens

    def encode_moves(self, moves: list[str]) -> list[int]:
        """Map each whole move string to its id (no splitting or merging)."""
        return [self.symbol_to_token[move] for move in moves]

    def add_special_tokens(self, tokens: list[str]) -> dict[str, int]:
        """Append ``tokens`` to the vocabulary and return their new ids."""
        mapping = {}
        for tok in tokens:
            self.symbol_to_token[tok] = self.language_size
            self.token_to_symbol[self.language_size] = tok
            mapping[tok] = self.language_size
            self.language_size += 1
        return mapping
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class DataLoader():
    """Reads a whole text file into ``corpus``."""

    corpus = None

    def __init__(self, file_name: str):
        """Load ``file_name`` as UTF-8 text.

        The encoding is explicit so the result does not depend on the
        platform's locale default.
        """
        with open(file_name, "r", encoding="utf-8") as f:
            self.corpus = f.read()
|
| 104 |
+
|
src/train.py
ADDED
|
@@ -0,0 +1,1319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import atexit
|
| 3 |
+
import contextlib
|
| 4 |
+
import math # noqa: F401 (used in eval_policy)
|
| 5 |
+
import multiprocessing as mp
|
| 6 |
+
import re
|
| 7 |
+
import shutil
|
| 8 |
+
import time
|
| 9 |
+
import numpy as np
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
import chess
|
| 12 |
+
import chess.engine
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from torch.utils.data import Dataset, DataLoader
|
| 16 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 17 |
+
from datasets import load_dataset
|
| 18 |
+
|
| 19 |
+
from tokenizer import Tokenizer
|
| 20 |
+
from model import (
|
| 21 |
+
ChessRewardModel,
|
| 22 |
+
ChessPolicyModel,
|
| 23 |
+
DummyRewardModel,
|
| 24 |
+
CLS_TOKEN,
|
| 25 |
+
PAD_TOKEN,
|
| 26 |
+
board_to_planes,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
STOCKFISH_PATH = shutil.which("stockfish") or "/usr/local/bin/stockfish"
|
| 30 |
+
|
| 31 |
+
RESULT_TOKENS = {"1-0", "0-1", "1/2-1/2", "*"}
|
| 32 |
+
MOVE_NUMBER_RE = re.compile(r"^\d+\.(\.\.)?$")
|
| 33 |
+
# Brace-delimited PGN comments like {[%eval 0.37]} and {[%clk 0:05:00]}.
|
| 34 |
+
# The [^}]* class stops at the first closing brace, so several comments in one movetext are matched separately.
|
| 35 |
+
BRACE_COMMENT_RE = re.compile(r"\{[^}]*\}")
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def normalize_cp(centipawns: int) -> float:
    """Squash a centipawn score into [-1, 1] with tanh; 400 cp maps to tanh(1)."""
    scaled = centipawns / 400.0
    return math.tanh(scaled)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def material_eval(board: chess.Board) -> float:
    """Score ``board`` by material balance; used as a fallback for Stockfish."""
    evaluator = DummyRewardModel()
    return evaluator(board)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class StockfishEvaluator:
    """Wraps a persistent Stockfish engine for batch evaluation.

    Can be used as a context manager so the engine process is shut down even
    if evaluation raises:

        with StockfishEvaluator() as evaluate:
            value = evaluate(board)
    """

    def __init__(self, engine_path: str = STOCKFISH_PATH, depth: int = 15):
        self.engine = chess.engine.SimpleEngine.popen_uci(engine_path)
        self.depth = depth

    def __call__(self, board: chess.Board) -> float:
        """Return the evaluation of ``board`` from White's point of view.

        Mate scores map to +1.0 when White is the mating side and -1.0 when
        Black is; other scores go through ``normalize_cp``.
        """
        info = self.engine.analyse(board, chess.engine.Limit(depth=self.depth))
        score = info["score"].white()
        if score.is_mate():
            # Compare against a level score rather than testing mate() > 0:
            # a mate already delivered reports mate() == 0 for the winner,
            # which the sign test would count as a loss.
            return 1.0 if score > chess.engine.Cp(0) else -1.0
        return normalize_cp(score.score())

    def close(self):
        """Shut down the engine process."""
        self.engine.quit()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def parse_movetext(movetext: str) -> list[str]:
    """Parse PGN movetext into a list of SAN moves.

    Removes everything that is not a move:
      - brace comments such as '{[%eval 0.37]}' and '{[%clk 0:05:00]}'
      - move numbers ('1.', '1...') and result markers ('1-0', '*', ...)
      - numeric annotation glyphs ('$1')
      - trailing annotation suffixes ('!', '?', '?!', '??', ...), which
        Lichess attaches to moves in analysed games

    Input format: '1. d4 {[%eval 0.13]} d5 2. Nf3?! ... 1-0'
    Returns: ['d4', 'd5', 'Nf3', ...]
    """
    cleaned = BRACE_COMMENT_RE.sub(" ", movetext)
    moves = []
    for tok in cleaned.split():
        if tok in RESULT_TOKENS or MOVE_NUMBER_RE.match(tok):
            continue
        if tok.startswith("$"):
            # Numeric annotation glyph, not a move.
            continue
        # Drop evaluation suffixes so only the SAN move remains.
        san = tok.rstrip("!?")
        if san:
            moves.append(san)
    return moves
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def load_filtered_dataset(min_elo: int = 1500, min_rows: int = 100_000):
    """Stream the Lichess HuggingFace dataset and collect filtered games.

    A game is kept when:
      - WhiteElo and BlackElo are both present and >= min_elo
      - Termination == 'Normal'

    Streaming stops as soon as ``min_rows`` games have been collected.

    Returns:
        A list of the collected rows.

    Raises:
        ValueError: if the stream ends before ``min_rows`` games are found.
    """

    def _keep(row) -> bool:
        # Checked lazily and in this order so a missing rating is never compared.
        if row["WhiteElo"] is None:
            return False
        if row["BlackElo"] is None:
            return False
        if row["WhiteElo"] < min_elo:
            return False
        if row["BlackElo"] < min_elo:
            return False
        return row.get("Termination") == "Normal"

    print("Loading Lichess dataset from HuggingFace...")
    ds = load_dataset("Lichess/standard-chess-games", split="train", streaming=True)

    print(f"Filtering for Elo >= {min_elo} and Termination == 'Normal'...")
    ds_filtered = ds.filter(_keep)

    # Materialize only as many rows as needed to meet the threshold.
    print(f"Collecting at least {min_rows:,} filtered games...")
    rows = []
    for row in ds_filtered:
        rows.append(row)
        if len(rows) % 50_000 == 0:
            print(f" collected {len(rows):,} games so far...")
        if len(rows) >= min_rows:
            break

    if len(rows) < min_rows:
        raise ValueError(
            f"Only found {len(rows):,} games matching filters, "
            f"need at least {min_rows:,}."
        )

    print(f"Collected {len(rows):,} games (target met).")
    return rows
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def _enumerate_all_uci_moves() -> list[str]:
    """Enumerate every UCI move string that can legally appear in a chess game.

    Uses direct geometric enumeration rather than board simulation to avoid
    edge cases where king placement blocks valid destination squares.
    Covers all piece movement patterns: lines (rook/queen), diagonals
    (bishop/queen), L-shapes (knight), and pawn promotions.

    Returns the moves in sorted order. The strings are gathered in a set,
    whose iteration order for `str` elements can change between interpreter
    runs (hash randomization); sorting makes the returned list -- and anything
    built from it -- reproducible.
    """
    promotion_pieces = (chess.QUEEN, chess.ROOK, chess.BISHOP, chess.KNIGHT)
    seen: set[str] = set()
    for from_sq in chess.SQUARES:
        fr = chess.square_rank(from_sq)
        ff = chess.square_file(from_sq)
        for to_sq in chess.SQUARES:
            if from_sq == to_sq:
                continue
            tr = chess.square_rank(to_sq)
            tf = chess.square_file(to_sq)
            dr = abs(tr - fr)
            df = abs(tf - ff)

            is_line = (dr == 0 or df == 0)  # rook / queen
            is_diag = (dr == df)            # bishop / queen
            is_knight = (dr == 2 and df == 1) or (dr == 1 and df == 2)

            if not (is_line or is_diag or is_knight):
                continue

            seen.add(chess.Move(from_sq, to_sq).uci())

            # Promotion variants: pawn on 7th rank advancing to 8th (or 2nd -> 1st),
            # straight ahead or capturing one file sideways.
            if ((fr == 6 and tr == 7) or (fr == 1 and tr == 0)) and df <= 1:
                for promo in promotion_pieces:
                    seen.add(chess.Move(from_sq, to_sq, promotion=promo).uci())

    return sorted(seen)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def _weighted_sample(eligible: list[int], k: int, skew_exponent: float, seed: int) -> set[int]:
    """Draw k entries of `eligible` without replacement, favouring later entries.

    The entry at rank i (0-based position inside `eligible`) gets weight
    (i + 1) ** skew_exponent, so skew_exponent=1.0 is linear weighting and
    larger values push more probability mass toward the end of the list.
    When k is at least len(eligible) every entry is returned and no random
    draw happens. The draw uses a fresh NumPy generator seeded with `seed`.
    """
    pool_size = len(eligible)
    draw = min(k, pool_size)
    if draw == pool_size:
        return set(eligible)

    raw = np.array(
        [(rank + 1) ** skew_exponent for rank in range(pool_size)],
        dtype=np.float64,
    )
    probabilities = raw / raw.sum()
    generator = np.random.default_rng(seed)
    picked_ranks = generator.choice(pool_size, size=draw, replace=False, p=probabilities)
    return {eligible[rank] for rank in picked_ranks}
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def build_tokenizer_from_games(games: list[dict] | None = None) -> Tokenizer:
    """Build a move-level tokenizer covering all 1968 UCI moves.

    The vocabulary comes from `_enumerate_all_uci_moves()`; the `games`
    argument is accepted but not read. CLS and PAD are registered as special
    tokens after training.
    """
    all_moves = _enumerate_all_uci_moves()
    vocab_size = len(set(all_moves))
    print(f" building tokenizer from {vocab_size:,} UCI moves (no BPE)")
    move_tokenizer = Tokenizer()
    move_tokenizer.train_tokenizer(all_moves, max_language_size=vocab_size)
    move_tokenizer.add_special_tokens([CLS_TOKEN, PAD_TOKEN])
    return move_tokenizer
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def _load_train_idx(out_dir: Path, name: str, n: int) -> np.ndarray | None:
|
| 196 |
+
"""If `{name}_test_indices.npy` exists, return the complement (training-only
|
| 197 |
+
indices into the full memmap). Returns None when no test split is recorded —
|
| 198 |
+
in that case the caller indexes into the memmap directly.
|
| 199 |
+
|
| 200 |
+
Names ending in '_test' always return None (the test memmap should not
|
| 201 |
+
exclude itself).
|
| 202 |
+
"""
|
| 203 |
+
if name.endswith("_test"):
|
| 204 |
+
return None
|
| 205 |
+
test_idx_file = out_dir / f"{name}_test_indices.npy"
|
| 206 |
+
if not test_idx_file.exists():
|
| 207 |
+
return None
|
| 208 |
+
test_idx = np.load(test_idx_file)
|
| 209 |
+
mask = np.ones(n, dtype=bool)
|
| 210 |
+
mask[test_idx] = False
|
| 211 |
+
return np.where(mask)[0]
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
class ChessPositionDataset(Dataset):
    """Position-level samples: a [CLS]-prefixed move-token prefix plus a scalar score.

    Two storage modes exist:

    - in-memory (`samples`): built from raw games by `__init__`, or supplied
      ready-made through `from_samples` / `from_file`. `__getitem__` returns
      `(token_tensor, score)`.
    - memory-mapped (`from_memmap`): pre-padded arrays read from disk.
      `__getitem__` returns `(token_tensor, padding_mask, label)`.
    """

    def __init__(
        self,
        games: list[dict],
        tokenizer: Tokenizer,
        eval_fn=material_eval,
        sample_rate: float = 0.25,
        skew_exponent: float = 1.5,
    ):
        """Replay `games` and score a sampled subset of positions with `eval_fn`.

        Args:
            games: rows carrying a "movetext" entry.
            tokenizer: move tokenizer; must know CLS_TOKEN.
            eval_fn: callable mapping a `chess.Board` to a float label.
            sample_rate: fraction of each game's moves to sample (at least one).
            skew_exponent: exponent passed to `_weighted_sample`.
        """
        self.tokenizer = tokenizer
        self.cls_id = tokenizer.symbol_to_token[CLS_TOKEN]
        self.samples: list[tuple[list[int], float]] = []
        self._memmap = False
        self._train_idx: np.ndarray | None = None
        self._generate_samples(games, eval_fn, sample_rate, skew_exponent)

    def _generate_samples(self, games, eval_fn, sample_rate, skew_exponent):
        """Fill `self.samples` by replaying every game and scoring sampled plies."""
        for idx, game in enumerate(games):
            movetext = game.get("movetext", "")
            if not movetext:
                continue
            move_sans = parse_movetext(movetext)
            if len(move_sans) < 2:
                continue

            board = chess.Board()
            eligible = list(range(len(move_sans)))
            # Scale sample count with game length -- longer games have more
            # evaluation swings and contribute proportionally more samples.
            num_positions = max(1, int(len(move_sans) * sample_rate))
            # Deterministic weighted sampling seeded by game index so serial
            # and parallel paths produce identical sample sets for the same input.
            sample_indices = _weighted_sample(eligible, num_positions, skew_exponent, seed=idx)

            valid_moves = []
            for i, san in enumerate(move_sans):
                try:
                    move = board.parse_san(san)
                except ValueError:
                    # python-chess signals invalid, illegal and ambiguous SAN
                    # with ValueError subclasses; any of them ends the replay
                    # of this game instead of aborting the whole build.
                    break
                board.push(move)
                valid_moves.append(move.uci())

                if i in sample_indices:
                    token_ids = [self.cls_id] + self.tokenizer.encode_moves(valid_moves)
                    score = eval_fn(board)
                    self.samples.append((token_ids, score))

            if (idx + 1) % 10_000 == 0:
                print(f" processed {idx + 1:,} games, {len(self.samples):,} positions...")

    def __len__(self) -> int:
        if self._memmap:
            if self._train_idx is not None:
                return len(self._train_idx)
            return len(self._mm_labels)
        return len(self.samples)

    def __getitem__(self, idx: int):
        if self._memmap:
            if self._train_idx is not None:
                # Translate the training-split position into a memmap row.
                idx = int(self._train_idx[idx])
            tokens = torch.from_numpy(np.array(self._mm_tokens[idx], dtype=np.int32)).long()
            length = int(self._mm_lengths[idx])
            mask = torch.arange(tokens.shape[0]) >= length  # True = padded
            return tokens, mask, float(self._mm_labels[idx])
        token_ids, score = self.samples[idx]
        return torch.tensor(token_ids, dtype=torch.long), score

    @classmethod
    def from_samples(cls, samples, tokenizer: Tokenizer):
        """Build a dataset from pre-generated (token_ids, score) samples."""
        inst = cls.__new__(cls)
        inst.tokenizer = tokenizer
        inst.cls_id = tokenizer.symbol_to_token[CLS_TOKEN]
        inst.samples = list(samples)
        inst._memmap = False
        # Keep the attribute set identical to an __init__-built instance.
        inst._train_idx = None
        return inst

    @classmethod
    def from_file(cls, samples_path: str, tokenizer: Tokenizer):
        """Load (token_ids, score) samples from a torch.save file."""
        # NOTE(security): weights_only=False unpickles arbitrary objects;
        # only load sample files produced by a trusted build.
        samples = torch.load(samples_path, weights_only=False)
        return cls.from_samples(samples, tokenizer)

    @classmethod
    def from_memmap(cls, out_dir: Path, name: str, tokenizer: Tokenizer):
        """Load pre-padded samples from memory-mapped arrays (fast DataLoader path).

        If a sibling file `{name}_test_indices.npy` exists, those indices are
        excluded from this dataset -- used to make training disjoint from the
        held-out test split that shares the same underlying .bin file.
        """
        meta = torch.load(out_dir / f"{name}_meta.pt", weights_only=True)
        n, max_len = meta["n"], meta["max_len"]
        inst = cls.__new__(cls)
        inst.tokenizer = tokenizer
        inst.cls_id = tokenizer.symbol_to_token[CLS_TOKEN]
        inst._memmap = True
        inst._mm_tokens = np.memmap(out_dir / f"{name}_tokens.bin", dtype=np.int32, mode="r", shape=(n, max_len))
        inst._mm_labels = np.memmap(out_dir / f"{name}_labels.bin", dtype=np.float32, mode="r", shape=(n,))
        inst._mm_lengths = np.memmap(out_dir / f"{name}_lengths.bin", dtype=np.int32, mode="r", shape=(n,))
        inst._train_idx = _load_train_idx(out_dir, name, n)
        return inst
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
# ---------------------------------------------------------------------------
|
| 321 |
+
# Parallel Stockfish-backed sample generation.
|
| 322 |
+
#
|
| 323 |
+
# One Stockfish subprocess per worker process. Each worker:
|
| 324 |
+
# 1. receives a game + its index (used to seed that game's NumPy random generator)
|
| 325 |
+
# 2. replays the game, samples positions, tokenizes move prefixes
|
| 326 |
+
# 3. evaluates each sampled position with its own Stockfish engine
|
| 327 |
+
# 4. returns a list of (token_ids, score) tuples
|
| 328 |
+
#
|
| 329 |
+
# The main process collects results via imap_unordered and flattens them.
|
| 330 |
+
# ---------------------------------------------------------------------------
|
| 331 |
+
|
| 332 |
+
# Per-process worker state. Every name is None at import time and is assigned
# by _init_worker inside each spawned pool worker.
_worker_engine = None        # engine handle opened by _init_worker, or None
_worker_tokenizer = None     # tokenizer used to encode move prefixes
_worker_cls_id = None        # token id prepended to every sample
_worker_sample_rate = None   # fraction of a game's moves to sample
_worker_skew = None          # exponent forwarded to _weighted_sample
_worker_depth = None         # depth limit used for engine analysis
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def _shutdown_worker():
    """Quit this worker's engine, if any, and forget the handle.

    Registered with atexit by _init_worker. Any Exception raised by
    `quit()` is ignored: this is best-effort cleanup at process exit.
    """
    global _worker_engine
    if _worker_engine is None:
        return
    with contextlib.suppress(Exception):
        _worker_engine.quit()
    _worker_engine = None
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def _init_worker(engine_path, depth, tokenizer, cls_id, sample_rate, skew_exponent):
    """Pool initializer: stash the shared settings and open this worker's engine.

    With `engine_path=None` no engine is started and evaluation falls back to
    material_eval, which lets tests exercise the parallel machinery without
    requiring Stockfish. Otherwise one UCI engine is opened for this process
    and `_shutdown_worker` is registered to close it at exit.
    """
    global _worker_engine, _worker_tokenizer, _worker_cls_id
    global _worker_sample_rate, _worker_skew, _worker_depth

    _worker_tokenizer, _worker_cls_id = tokenizer, cls_id
    _worker_sample_rate, _worker_skew = sample_rate, skew_exponent
    _worker_depth = depth

    if engine_path is None:
        _worker_engine = None
        return

    _worker_engine = chess.engine.SimpleEngine.popen_uci(engine_path)
    atexit.register(_shutdown_worker)
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def _worker_eval(board: chess.Board) -> float:
    """Score `board` from White's point of view using this worker's evaluator.

    Without an engine the result is `material_eval(board)`. With an engine the
    position is analysed to `_worker_depth`; a mate score maps to +1.0 when
    White is the side giving mate and -1.0 otherwise, and any other score is
    passed through `normalize_cp` as centipawns.
    """
    if _worker_engine is None:
        return material_eval(board)
    info = _worker_engine.analyse(board, chess.engine.Limit(depth=_worker_depth))
    score = info["score"].white()
    if score.is_mate():
        # Compare against an even score instead of testing `mate() > 0`:
        # a mate that is already on the board reports mate() == 0 for BOTH
        # sides (MateGiven for the winner, Mate(0) for the loser), so the
        # sign test would label a position where White has just delivered
        # checkmate as a loss. Score ordering ranks every winning mate above
        # Cp(0) and every losing mate below it.
        return 1.0 if score > chess.engine.Cp(0) else -1.0
    return normalize_cp(score.score())
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
def _process_game(game_with_seed):
    """Worker task: parse, replay, sample, tokenize, and evaluate one game.

    Args:
        game_with_seed: `(game, seed)` where `game` carries a "movetext"
            entry and `seed` drives the deterministic position sampling.

    Returns:
        A list of `(token_ids, score)` tuples, one per sampled position that
        was reached before the replay stopped. Empty when the game has no
        movetext or fewer than two parsed moves.
    """
    game, seed = game_with_seed
    movetext = game.get("movetext", "")
    if not movetext:
        return []
    move_sans = parse_movetext(movetext)
    if len(move_sans) < 2:
        return []

    eligible = list(range(len(move_sans)))
    num_positions = max(1, int(len(move_sans) * _worker_sample_rate))
    sample_indices = _weighted_sample(eligible, num_positions, _worker_skew, seed=seed)

    samples = []
    board = chess.Board()
    valid_moves = []
    for i, san in enumerate(move_sans):
        try:
            move = board.parse_san(san)
        except ValueError:
            # python-chess signals invalid, illegal and ambiguous SAN with
            # ValueError subclasses; any of them ends this game's replay
            # rather than raising out of the worker.
            break
        board.push(move)
        valid_moves.append(move.uci())

        if i in sample_indices:
            token_ids = [_worker_cls_id] + _worker_tokenizer.encode_moves(valid_moves)
            score = _worker_eval(board)
            samples.append((token_ids, score))

    return samples
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def generate_samples_stockfish_parallel(
    games: list[dict],
    tokenizer: Tokenizer,
    num_workers: int = 8,
    stockfish_depth: int = 12,
    sample_rate: float = 0.25,
    skew_exponent: float = 1.5,
    engine_path: str | None = STOCKFISH_PATH,
    chunksize: int = 8,
    progress_every: int = 1000,
) -> list[tuple[list[int], float]]:
    """Generate (token_ids, score) samples for `games` using a process pool.

    `num_workers` processes are started, each initialised by `_init_worker`
    with its own engine. Passing `engine_path=None` makes the workers use
    material_eval instead of Stockfish (used by tests to verify the parallel
    machinery without the binary).

    Each game contributes `max(1, game_length * sample_rate)` positions,
    weighted toward later moves by `skew_exponent`, and is sampled with its
    own index as the seed. Results are gathered with `imap_unordered`, so
    the order of the returned list follows worker completion order.
    """
    worker_args = (
        engine_path,
        stockfish_depth,
        tokenizer,
        tokenizer.symbol_to_token[CLS_TOKEN],
        sample_rate,
        skew_exponent,
    )
    indexed_games = [(game, idx) for idx, game in enumerate(games)]
    total_games = len(games)
    collected: list[tuple[list[int], float]] = []

    # spawn context: safest across macOS/Linux and avoids fork-safety issues
    # with chess.engine's subprocess.
    spawn_ctx = mp.get_context("spawn")
    with spawn_ctx.Pool(
        processes=num_workers,
        initializer=_init_worker,
        initargs=worker_args,
    ) as pool:
        results = pool.imap_unordered(_process_game, indexed_games, chunksize=chunksize)
        for done, game_samples in enumerate(results, start=1):
            collected.extend(game_samples)
            if progress_every and done % progress_every == 0:
                print(
                    f" processed {done:,}/{total_games:,} games, "
                    f"{len(collected):,} positions..."
                )

    return collected
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
def collate_fn(batch):
    """Right-pad variable-length token sequences and build the padding mask.

    `batch` is a sequence of `(tokens, label)` pairs. Returns
    `(padded, attention_mask, labels)`: `padded` is a long tensor filled with
    0 beyond each sequence, `attention_mask` is True exactly on those padded
    slots, and `labels` is a float tensor.
    """
    sequences, scores = zip(*batch)
    lengths = [len(seq) for seq in sequences]
    width = max(lengths)

    padded = torch.zeros(len(sequences), width, dtype=torch.long)
    for row, (seq, seq_len) in enumerate(zip(sequences, lengths)):
        padded[row, :seq_len] = seq

    # True = masked: every column at or beyond the row's real length.
    columns = torch.arange(width).unsqueeze(0)
    attention_mask = columns >= torch.tensor(lengths).unsqueeze(1)

    return padded, attention_mask, torch.tensor(scores, dtype=torch.float)
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
def collate_fn_memmap(batch):
    """Stack already-padded memmap samples into batch tensors.

    Each element is `(tokens, mask, label)`; since every row already has the
    same length, no per-batch padding is required.
    """
    token_rows, mask_rows, label_values = zip(*batch)
    stacked_tokens = torch.stack(token_rows)
    stacked_masks = torch.stack(mask_rows)
    label_tensor = torch.tensor(label_values, dtype=torch.float)
    return stacked_tokens, stacked_masks, label_tensor
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
def collate_fn_policy(batch):
    """Pad token sequences and per-position board planes for policy training.

    Each batch element is (tokens, planes, weight, source_tag) where
    `tokens` is shape (L,) long and `planes` is shape (L, C, H, W) float --
    one set of board planes per position in the sequence. Both are padded
    along the sequence dimension to the batch's max length. Padded positions
    get zero token, zero planes, and mask=True.

    The per-position plane shape (C, H, W) is read from the first sample
    rather than hard-coded, so the collate function keeps working if the
    plane encoding changes size.

    Returns (padded_tokens, attention_mask, planes, weights, sources).
    """
    tokens_list, planes_list, weights_list, sources_list = zip(*batch)
    batch_size = len(tokens_list)
    max_len = max(len(t) for t in tokens_list)
    plane_shape = tuple(planes_list[0].shape[1:])

    padded = torch.zeros(batch_size, max_len, dtype=torch.long)
    mask = torch.ones(batch_size, max_len, dtype=torch.bool)  # True = padded
    planes = torch.zeros(batch_size, max_len, *plane_shape)
    for row, (seq_tokens, seq_planes) in enumerate(zip(tokens_list, planes_list)):
        seq_len = len(seq_tokens)
        padded[row, :seq_len] = seq_tokens
        mask[row, :seq_len] = False
        planes[row, :seq_len] = seq_planes

    weights = torch.tensor(weights_list, dtype=torch.float)
    sources = torch.tensor(sources_list, dtype=torch.long)
    return padded, mask, planes, weights, sources
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
class MixedBatchSampler(torch.utils.data.Sampler):
    """Hard-balanced sampler over a ConcatDataset([games, puzzles]).

    Each batch contains exactly `n_game_per_batch` game indices (drawn from
    [0, n_game)) followed by `n_puzzle_per_batch` puzzle indices (drawn from
    [n_game, n_game + n_puzzle)). Both pools are shuffled and consumed in
    parallel; when a pool cannot supply a full quota it is re-shuffled, so
    the smaller (puzzle) pool is effectively oversampled to match the game
    stream.

    Degenerate pools are handled explicitly:

    - `n_puzzle == 0`: the puzzle quota is handed to games, so batches are
      all-game batches of `batch_size`.
    - `0 < n_puzzle < n_puzzle_per_batch`: the puzzle pool is re-shuffled as
      often as needed inside one batch, so puzzle indices repeat.

    `drop_last` is stored for callers that inspect it; an epoch always
    consists of full batches only (see `__len__`).
    """

    def __init__(
        self,
        n_game: int,
        n_puzzle: int,
        batch_size: int,
        game_ratio: float = 0.8,
        drop_last: bool = True,
    ):
        self.n_game = n_game
        self.n_puzzle = n_puzzle
        self.batch_size = batch_size
        n_game_per_batch = max(1, int(round(batch_size * game_ratio)))
        if n_puzzle <= 0:
            # Nothing to draw puzzles from: fill the whole batch with games.
            n_game_per_batch = max(n_game_per_batch, batch_size)
        self.n_game_per_batch = n_game_per_batch
        # Never negative, even if game_ratio pushes the game quota past batch_size.
        self.n_puzzle_per_batch = max(0, batch_size - n_game_per_batch)
        self.drop_last = drop_last

    @staticmethod
    def _take(perm, pos, k, n):
        """Pull `k` indices from the shuffled pool, re-shuffling on exhaustion.

        Returns `(picked, perm, pos)` with the possibly replaced permutation
        and the updated read position.
        """
        if k <= 0:
            return [], perm, pos
        if n <= 0:
            raise ValueError("cannot draw indices from an empty pool")
        picked = []
        while len(picked) < k:
            need = k - len(picked)
            if pos + need > n:
                # Not enough unread entries left: discard the tail, reshuffle.
                perm = torch.randperm(n).tolist()
                pos = 0
            chunk = perm[pos:pos + need]
            picked.extend(chunk)
            pos += len(chunk)
        return picked, perm, pos

    def __iter__(self):
        game_perm = torch.randperm(self.n_game).tolist()
        puzzle_perm = torch.randperm(self.n_puzzle).tolist() if self.n_puzzle > 0 else []
        gi, pi = 0, 0
        for _ in range(len(self)):
            games, game_perm, gi = self._take(
                game_perm, gi, self.n_game_per_batch, self.n_game
            )
            puzzles, puzzle_perm, pi = self._take(
                puzzle_perm, pi, self.n_puzzle_per_batch, self.n_puzzle
            )
            # Puzzle rows live after the game rows in the concatenated dataset.
            yield games + [self.n_game + p for p in puzzles]

    def __len__(self):
        # One pass over the (more numerous) game pool defines an epoch.
        return self.n_game // self.n_game_per_batch
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
class ChessPolicyDataset(Dataset):
    """Full game sequences for next-move prediction training.

    Each sample yields (token_ids, board_planes, weight, source_tag):

    - token_ids: full tokenized sequence [CLS, m1, m2, ..., mN]
    - board_planes: (L, 19, 8, 8) tensor of per-position planes built by
      replaying the move sequence. planes[0] is the starting board (the
      standard chess start for games, the puzzle FEN for puzzles);
      planes[t] is the board state after token_ids[1..t] have been played.
    - weight: per-sample loss weight taken from `loss_weight` (1.0 unless
      `from_memmap` is given another value).
    - source_tag: 0 = game, 1 = puzzle. Used by the mixed training loop to
      mask the setup-move target on puzzle samples.
    """

    def __init__(self, games: list[dict], tokenizer: Tokenizer, max_seq_len: int = 128):
        """Tokenize `games` in memory.

        Games without movetext, or with fewer than two replayable moves, are
        skipped. Sequences are truncated to `max_seq_len` tokens including CLS.
        """
        cls_id = tokenizer.symbol_to_token[CLS_TOKEN]
        self.tokenizer = tokenizer
        self._memmap = False
        self._train_idx: np.ndarray | None = None
        self._mm_fens = None
        self._fen_len = None
        self.source_tag: int = 0
        self.loss_weight: float = 1.0
        self.samples: list[list[int]] = []
        for game in games:
            movetext = game.get("movetext", "")
            if not movetext:
                continue
            move_sans = parse_movetext(movetext)
            if len(move_sans) < 2:
                continue
            board = chess.Board()
            move_ucis: list[str] = []
            for san in move_sans:
                try:
                    move = board.parse_san(san)
                except ValueError:
                    # Invalid, illegal and ambiguous SAN all surface as
                    # ValueError subclasses; keep the prefix replayed so far.
                    break
                board.push(move)
                move_ucis.append(move.uci())
            if len(move_ucis) < 2:
                continue
            move_ucis = move_ucis[:max_seq_len - 1]  # reserve slot for CLS
            self.samples.append([cls_id] + tokenizer.encode_moves(move_ucis))

    def _get_start_board(self, idx: int) -> chess.Board:
        """Resolve the starting board for the per-position replay.

        Puzzles with a `{name}_fens.bin` sidecar use the puzzle's FEN.
        Everything else (games, puzzles without FENs) starts from the
        standard chess starting position. A corrupt FEN -- undecodable bytes
        or a string python-chess rejects -- silently falls back to the
        starting position so the loader doesn't crash.
        """
        if self._memmap and self._mm_fens is not None:
            try:
                # Decoding sits inside the try: UnicodeDecodeError is a
                # ValueError, so non-ASCII garbage takes the same fallback.
                fen_str = bytes(self._mm_fens[idx]).rstrip(b"\x00").decode("ascii")
                return chess.Board(fen_str)
            except ValueError:
                return chess.Board()
        return chess.Board()

    def __len__(self) -> int:
        if self._memmap:
            if self._train_idx is not None:
                return len(self._train_idx)
            return len(self._mm_lengths)
        return len(self.samples)

    def __getitem__(self, idx: int):
        if self._memmap:
            if self._train_idx is not None:
                # Translate the training-split position into a memmap row;
                # the FEN lookup below uses the same translated row.
                idx = int(self._train_idx[idx])
            length = int(self._mm_lengths[idx])
            tokens = torch.from_numpy(np.array(self._mm_tokens[idx, :length], dtype=np.int32)).long()
        else:
            tokens = torch.tensor(self.samples[idx], dtype=torch.long)
        start_board = self._get_start_board(idx)
        planes = self._replay_planes(tokens.tolist(), start_board)
        return tokens, planes, self.loss_weight, self.source_tag

    @classmethod
    def from_memmap(
        cls,
        out_dir: Path,
        tokenizer: Tokenizer,
        name: str = "policy",
        source_tag: int = 0,
        loss_weight: float = 1.0,
    ):
        """Load pre-tokenized policy sequences from memory-mapped arrays.

        Args:
            name: filename prefix; use 'puzzle' to load puzzle_*.bin files.
            source_tag: 0 for game data, 1 for puzzle data (drives setup-move
                masking in the mixed training loop).
            loss_weight: per-sample weight applied to this dataset's samples in
                the weighted cross-entropy loss.

        If a sibling file `{name}_fens.bin` exists (and the meta file records
        `fen_len`), FENs are loaded and used to reconstruct each sample's
        starting-board planes. Otherwise the standard chess starting position
        is used.

        If `{name}_test_indices.npy` exists, those indices are excluded from
        this dataset -- used to make training disjoint from the held-out test
        split that shares the same underlying .bin file.
        """
        meta = torch.load(out_dir / f"{name}_meta.pt", weights_only=True)
        n, max_len = meta["n"], meta["max_len"]
        inst = cls.__new__(cls)
        inst._memmap = True
        inst._mm_tokens = np.memmap(out_dir / f"{name}_tokens.bin", dtype=np.int32, mode="r", shape=(n, max_len))
        inst._mm_lengths = np.memmap(out_dir / f"{name}_lengths.bin", dtype=np.int32, mode="r", shape=(n,))
        inst._train_idx = _load_train_idx(out_dir, name, n)

        fen_path = out_dir / f"{name}_fens.bin"
        if fen_path.exists() and "fen_len" in meta:
            inst._mm_fens = np.memmap(fen_path, dtype=np.uint8, mode="r", shape=(n, meta["fen_len"]))
            inst._fen_len = meta["fen_len"]
        else:
            inst._mm_fens = None
            inst._fen_len = None
            if source_tag == 1:
                # Puzzle data without FENs: CNN will see the standard starting
                # position for every puzzle, which is wrong. Loud warning.
                print(
                    f"WARNING: {name}_fens.bin not found — puzzle samples will "
                    f"feed the starting-position planes to the CNN, defeating "
                    f"the point of puzzle conditioning. Rebuild with the "
                    f"updated build_datasets.py to fix."
                )

        inst.tokenizer = tokenizer
        inst.source_tag = source_tag
        inst.loss_weight = loss_weight
        return inst

    def _replay_planes(self, token_ids: list[int], start_board: chess.Board) -> torch.Tensor:
        """Returns (L, 19, 8, 8) tensor of board planes per position.

        plane_t = state of the board after token_ids[1..t] have been played.
        plane_0 = start_board (the model has only seen [CLS] at that point).

        If a token in the sequence can't be turned into a move (an id missing
        from the tokenizer, a non-move special token mid-stream, corrupt
        data), planes are frozen at the last valid state and returned, so
        the worst case is a few positions with stale board input rather than
        a crashed worker. An empty sequence yields an empty tensor.
        """
        L = len(token_ids)
        planes = torch.zeros(L, 19, 8, 8)
        if L == 0:
            return planes
        board = start_board.copy()
        planes[0] = board_to_planes(board)
        for t in range(1, L):
            try:
                # LookupError covers both dict- and list-backed vocabularies;
                # ValueError covers python-chess's UCI parsing errors.
                uci = self.tokenizer.token_to_symbol[int(token_ids[t])]
                board.push(chess.Move.from_uci(uci))
            except (LookupError, ValueError):
                planes[t:] = planes[t - 1]
                return planes
            planes[t] = board_to_planes(board)
        return planes
|
| 727 |
+
|
| 728 |
+
def _fmt_duration(seconds: float) -> str:
    """Format a duration as 'Hh MMm SSs', or 'Mm SSs' when under an hour.

    Fractions of a second are truncated.
    """
    whole_seconds = int(seconds)
    hours, remainder = divmod(whole_seconds, 3600)
    minutes, secs = divmod(remainder, 60)
    if hours:
        return f"{hours}h {minutes:02d}m {secs:02d}s"
    return f"{minutes}m {secs:02d}s"
|
| 732 |
+
|
| 733 |
+
|
| 734 |
+
def _amp_ctx(device):
|
| 735 |
+
"""BF16 autocast on CUDA, no-op elsewhere.
|
| 736 |
+
|
| 737 |
+
BF16 is preferred over FP16 here: same dynamic range as FP32 (no GradScaler
|
| 738 |
+
needed) and full tensor-core acceleration on Ampere+ / Ada / Blackwell.
|
| 739 |
+
On Blackwell (RTX PRO 6000 / B200) this typically gives 2-3x training speedup
|
| 740 |
+
on transformer matmuls.
|
| 741 |
+
"""
|
| 742 |
+
dev = device if isinstance(device, str) else getattr(device, "type", "cpu")
|
| 743 |
+
if dev == "cuda":
|
| 744 |
+
return torch.autocast(device_type="cuda", dtype=torch.bfloat16)
|
| 745 |
+
return contextlib.nullcontext()
|
| 746 |
+
|
| 747 |
+
|
| 748 |
+
def _run_epoch_reward(model, loader, optimizer, device, writer, global_step, epoch_idx):
    """Single training epoch: MSE between model predictions and batch labels.

    Args:
        model: called as `model(tokens, attention_mask=mask)`.
        loader: yields `(tokens, mask, labels)` batches and supports `len()`.
        optimizer: optimizer stepping `model`'s parameters.
        device: target device for the batches; also selects the autocast mode.
        writer: object with `add_scalar(tag, value, step)`.
        global_step: running batch counter used as the per-batch logging step.
        epoch_idx: step used for the per-epoch loss scalar.

    Returns:
        `(avg_loss, global_step, epoch_elapsed_seconds)`. With an empty loader
        the average is NaN and no epoch scalar is written.
    """
    model.train()
    total_loss = 0.0
    samples_seen = 0
    n_batches = len(loader)
    log_every = max(1, n_batches // 20)
    epoch_start = time.time()

    for i, (batch_tokens, batch_mask, batch_labels) in enumerate(loader):
        batch_tokens = batch_tokens.to(device)
        batch_mask = batch_mask.to(device)
        batch_labels = batch_labels.to(device)

        with _amp_ctx(device):
            predictions = model(batch_tokens, attention_mask=batch_mask)
            loss = F.mse_loss(predictions, batch_labels)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Read the scalar once; it is reused for the running total and logging.
        batch_loss = loss.item()
        total_loss += batch_loss
        samples_seen += batch_tokens.size(0)

        writer.add_scalar("train/reward_batch_loss", batch_loss, global_step)
        global_step += 1

        batches_done = i + 1
        if batches_done % log_every == 0 or batches_done == n_batches:
            # Clamp so a coarse clock can never produce a zero divisor.
            elapsed = max(time.time() - epoch_start, 1e-9)
            eta = elapsed / batches_done * (n_batches - batches_done)
            # Count the samples actually processed rather than extrapolating
            # from the size of the most recent (possibly partial) batch.
            samples_per_sec = samples_seen / elapsed
            avg_so_far = total_loss / batches_done
            print(
                f" batch {batches_done:,}/{n_batches:,} "
                f"loss={avg_so_far:.4f} "
                f"{samples_per_sec:,.0f} samples/s "
                f"eta {_fmt_duration(eta)}"
            )

    epoch_elapsed = time.time() - epoch_start
    if n_batches:
        avg = total_loss / n_batches
        writer.add_scalar("train/reward_epoch_loss", avg, epoch_idx)
    else:
        # Nothing was trained; report that explicitly instead of dividing by zero.
        avg = float("nan")
    return avg, global_step, epoch_elapsed
|
| 791 |
+
|
| 792 |
+
def _run_epoch_policy_mixed(
    model, loader, optimizer, device, writer, global_step, epoch_idx, pad_id,
):
    """Single training epoch over mixed game + puzzle batches.

    Loader yields (tokens, mask, planes, weights, sources). For each batch:

    1. CNN-conditioned forward pass: position-0 embedding is replaced by the
       CNN's encoding of `planes` (starting board of the sequence).
    2. Per-position cross-entropy at every non-padded target position.
    3. Setup-move target is masked out for puzzle rows (source==1): the setup
       move is given as context, not a prediction target.
    4. Per-sample loss weight upweights puzzle samples (default 5x via the
       dataset's loss_weight field) — implemented as a position-weighted mean.

    Args:
        model: Policy model, called as
            ``model(tokens, planes, attention_mask=mask)``.
        loader: Sized iterable yielding the 5-tuples described above.
        optimizer: Optimizer that steps ``model``'s parameters.
        device: Device each batch is moved to.
        writer: Object exposing ``add_scalar(tag, value, step)``.
        global_step: Running batch counter for the per-batch loss curve.
        epoch_idx: Step value under which the epoch-average loss is logged.
        pad_id: Token id ignored by the loss; also used to blank out the
            puzzle setup-move target.

    Returns:
        ``(avg_loss, global_step, epoch_elapsed)``: mean per-batch loss
        (0.0 for an empty loader), the advanced batch counter, and the epoch
        wall time in seconds.
    """
    model.train()
    total_loss = 0.0
    samples_seen = 0
    n_batches = len(loader)
    log_every = max(1, n_batches // 20)
    epoch_start = time.time()

    for i, (batch_tokens, batch_mask, batch_planes, batch_weights, batch_sources) in enumerate(loader):
        batch_tokens = batch_tokens.to(device, non_blocking=True)
        batch_mask = batch_mask.to(device, non_blocking=True)
        batch_planes = batch_planes.to(device, non_blocking=True)
        batch_weights = batch_weights.to(device, non_blocking=True)
        batch_sources = batch_sources.to(device, non_blocking=True)

        # Next-token setup: inputs drop the last position, targets drop the first.
        input_tokens = batch_tokens[:, :-1]
        input_mask = batch_mask[:, :-1]
        input_planes = batch_planes[:, :-1]  # planes are per-position; slice with tokens
        targets = batch_tokens[:, 1:].contiguous()

        # Mask the setup-move target (position 0 of the shifted target) for
        # puzzle rows — it's the opponent's forcing move given as context.
        is_puzzle = (batch_sources == 1)
        if is_puzzle.any():
            targets = targets.clone()
            targets[is_puzzle, 0] = pad_id

        with _amp_ctx(device):
            logits = model(input_tokens, input_planes, attention_mask=input_mask)
            B, T, V = logits.shape
            ce = F.cross_entropy(
                logits.reshape(-1, V),
                targets.reshape(-1),
                ignore_index=pad_id,
                reduction="none",
            ).reshape(B, T)
            # Weighted mean over non-pad positions: each position carries its
            # row's sample weight in both numerator and denominator.
            position_mask = (targets != pad_id).float()
            sample_weights = batch_weights.unsqueeze(1)
            weighted = ce * position_mask * sample_weights
            denom = (position_mask * sample_weights).sum().clamp(min=1.0)
            loss = weighted.sum() / denom

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Read the scalar once and reuse it for both the running total and
        # the TensorBoard point.
        loss_value = loss.item()
        total_loss += loss_value
        # Count real samples so throughput does not depend on the size of
        # whichever batch happens to be current.
        samples_seen += batch_tokens.size(0)

        writer.add_scalar("train_policy/batch_loss", loss_value, global_step)
        global_step += 1

        batches_done = i + 1
        if batches_done % log_every == 0 or batches_done == n_batches:
            # Clamp so a zero timer delta cannot divide by zero.
            elapsed = max(time.time() - epoch_start, 1e-9)
            eta = elapsed / batches_done * (n_batches - batches_done)
            samples_per_sec = samples_seen / elapsed
            avg_so_far = total_loss / batches_done
            print(
                f" batch {batches_done:,}/{n_batches:,} "
                f"loss={avg_so_far:.4f} "
                f"{samples_per_sec:,.0f} samples/s "
                f"eta {_fmt_duration(eta)}"
            )

    epoch_elapsed = time.time() - epoch_start
    avg = total_loss / max(n_batches, 1)
    writer.add_scalar("train_policy/epoch_loss", avg, epoch_idx)
    return avg, global_step, epoch_elapsed
|
| 873 |
+
|
| 874 |
+
|
| 875 |
+
def eval_reward(model, loader, device) -> dict:
    """Evaluate reward model on a test loader. Returns MSE, MAE, and Pearson r.

    Args:
        model: Reward model, called as ``model(tokens, attention_mask=mask)``.
        loader: Iterable yielding ``(tokens, mask, labels)`` batches. Labels
            are used as yielded (not moved to ``device``).
        device: Device the tokens and mask are moved to for the forward pass.

    Returns:
        Dict with float entries ``"mse"``, ``"mae"`` and ``"pearson_r"``,
        computed over all predictions concatenated on the CPU. All three are
        NaN when the loader yields no batches.
    """
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad(), _amp_ctx(device):
        for batch_tokens, batch_mask, batch_labels in loader:
            preds = model(batch_tokens.to(device), attention_mask=batch_mask.to(device))
            # Upcast before leaving the device so metrics are computed in fp32.
            all_preds.append(preds.float().cpu())
            all_labels.append(batch_labels)

    if not all_preds:
        # torch.cat([]) raises; report "no data" rather than crashing the
        # training run that called us.
        nan = float("nan")
        return {"mse": nan, "mae": nan, "pearson_r": nan}

    preds = torch.cat(all_preds)
    labels = torch.cat(all_labels)
    mse = F.mse_loss(preds, labels).item()
    mae = (preds - labels).abs().mean().item()
    # Pearson r: cosine similarity of the mean-centred vectors. The clamp
    # keeps a constant prediction or label vector from dividing by zero.
    p_centered = preds - preds.mean()
    l_centered = labels - labels.mean()
    denom = (p_centered.norm() * l_centered.norm()).clamp(min=1e-8)
    pearson_r = (p_centered * l_centered).sum() / denom
    return {"mse": mse, "mae": mae, "pearson_r": pearson_r.item()}
|
| 894 |
+
|
| 895 |
+
|
| 896 |
+
def eval_policy(model, loader, device, pad_id: int) -> dict:
    """Evaluate policy model on a test loader. Returns loss, perplexity, top-1/top-5 acc.

    Loader yields (tokens, mask, planes, weights, sources). Weights and sources
    are ignored here — eval is uniform across samples.

    Args:
        model: Policy model, called as
            ``model(tokens, planes, attention_mask=mask)``.
        loader: Iterable yielding the 5-tuples described above.
        device: Device each batch is moved to.
        pad_id: Token id excluded from the loss and from accuracy counts.

    Returns:
        Dict with ``"loss"`` (mean cross-entropy per non-pad target),
        ``"perplexity"`` (``exp`` of the loss, with the loss capped at 20),
        ``"top1_acc"`` and ``"top5_acc"``. Every value is 0-based safe: with
        no valid positions the denominators fall back to 1.
    """
    model.eval()
    total_loss = 0.0
    total_correct1 = 0
    total_correct5 = 0
    total_positions = 0
    with torch.no_grad(), _amp_ctx(device):
        for batch_tokens, batch_mask, batch_planes, _, _ in loader:
            batch_tokens = batch_tokens.to(device)
            batch_mask = batch_mask.to(device)
            batch_planes = batch_planes.to(device)
            # Next-token setup: inputs drop the last position, targets the first.
            input_tokens = batch_tokens[:, :-1]
            input_mask = batch_mask[:, :-1]
            input_planes = batch_planes[:, :-1]
            targets = batch_tokens[:, 1:].contiguous()
            logits = model(input_tokens, input_planes, attention_mask=input_mask)
            flat_logits = logits.reshape(-1, logits.size(-1))
            flat_targets = targets.reshape(-1)
            valid = flat_targets != pad_id
            total_loss += F.cross_entropy(flat_logits, flat_targets, ignore_index=pad_id, reduction="sum").item()
            total_positions += valid.sum().item()
            # topk(k) raises when k exceeds the vocabulary size, so clamp k;
            # with 5 or more classes this is exactly top-5.
            k = min(5, flat_logits.size(-1))
            top_k = flat_logits[valid].topk(k, dim=-1).indices
            valid_targets = flat_targets[valid]
            total_correct1 += (top_k[:, 0] == valid_targets).sum().item()
            total_correct5 += (top_k == valid_targets.unsqueeze(1)).any(dim=1).sum().item()
    avg_loss = total_loss / max(total_positions, 1)
    return {
        "loss": avg_loss,
        # Cap the exponent so a diverged model cannot overflow math.exp.
        "perplexity": math.exp(min(avg_loss, 20)),
        "top1_acc": total_correct1 / max(total_positions, 1),
        "top5_acc": total_correct5 / max(total_positions, 1),
    }
|
| 933 |
+
|
| 934 |
+
|
| 935 |
+
def eval_puzzle_solve_rate(model, loader, device, pad_id: int) -> dict:
    """Evaluate puzzle solve rate: % of solver positions where model's top-1 matches
    ground truth. Sequence layout: [CLS, setup, solver1, opp1, solver2, ...]

    Solver moves are at token positions 2, 4, 6, ... (logit positions 1, 3, 5, ...).
    The setup move at token position 1 (logit 0) is excluded — it's context, not a
    prediction target. Also reports first-move solve rate (logit position 1 only).

    Args:
        model: Policy model, called as
            ``model(tokens, planes, attention_mask=mask)``.
        loader: Iterable yielding ``(tokens, mask, planes, _, _)`` batches.
        device: Device each batch is moved to.
        pad_id: Token id marking padded (non-existent) solver moves.

    Returns:
        Dict with ``"first_move_solve_rate"`` and ``"all_moves_solve_rate"``,
        each a fraction of non-pad solver targets predicted correctly (0.0
        when there are none).
    """
    model.eval()
    first_correct = 0
    first_total = 0
    all_correct = 0
    all_total = 0
    with torch.no_grad(), _amp_ctx(device):
        for batch_tokens, batch_mask, batch_planes, _, _ in loader:
            batch_tokens = batch_tokens.to(device)
            batch_mask = batch_mask.to(device)
            batch_planes = batch_planes.to(device)
            input_tokens = batch_tokens[:, :-1]
            input_mask = batch_mask[:, :-1]
            input_planes = batch_planes[:, :-1]
            logits = model(input_tokens, input_planes, attention_mask=input_mask)

            # Score every solver move in one shot instead of looping over
            # positions: logit positions 1, 3, 5, ... line up column-for-column
            # with token positions 2, 4, 6, ...
            solver_preds = logits[:, 1::2].argmax(dim=-1)
            solver_targets = batch_tokens[:, 2::2]
            valid = solver_targets != pad_id
            hits = (solver_preds == solver_targets) & valid

            all_correct += hits.sum().item()
            all_total += valid.sum().item()
            # Column 0 is the first solver move (logit position 1); sequences
            # shorter than 3 tokens have no solver column at all.
            if valid.size(1) > 0:
                first_correct += hits[:, 0].sum().item()
                first_total += valid[:, 0].sum().item()
    return {
        "first_move_solve_rate": first_correct / max(first_total, 1),
        "all_moves_solve_rate": all_correct / max(all_total, 1),
    }
|
| 979 |
+
|
| 980 |
+
|
| 981 |
+
def train(
    tokenizer_path,
    stockfish_samples_path,
    outcome_games_path,
    epochs,
    policy_epochs,
    batch_size,
    learning_rate,
    max_seq_len,
    log_dir,
    num_workers,
    puzzle_data_dir=None,
    puzzle_epochs=5,  # kept for CLI compat; no longer used (mixed training merges phases)
    puzzle_loss_weight=5.0,
    puzzle_ratio=0.2,
    skip_reward=False,
    keep_last_n_checkpoints=3,
):
    """Train the reward model then the policy model.

    Phase 1: MSE on Stockfish-labeled positions (reward model).
    Phase 2: Mixed game + puzzle policy training. Each batch is hard-balanced
    at `puzzle_ratio` (default 20% puzzle) and puzzle samples carry a
    `puzzle_loss_weight` (default 5x) in the weighted cross-entropy loss.
    Games feed the CNN the standard chess starting board (constant signal,
    effectively a no-op); puzzles feed the FEN-derived board.

    If `skip_reward` is True, Phase 1 is skipped entirely — the reward dataset
    is not loaded, no reward model is created, and `reward_model.pt` on disk is
    untouched. Use this for iterating on Phase 2 without burning hours on a
    Phase 1 that hasn't changed.

    Requires stockfish memmap files, outcome games, and (for mixed training)
    puzzle memmaps with FENs built by src/build_datasets.py.
    """
    # NOTE: the docstring above doubles as the CLI description (the argument
    # parser reads train.__doc__), so its wording is user-facing text.
    #
    # Returns (reward_model, policy_model, tokenizer); reward_model is None
    # when skip_reward is set. Side effects: writes reward_model.pt,
    # policy_model_epoch_NN.pt and policy_model.pt into the working directory
    # and TensorBoard events into log_dir.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    amp_dtype = "bfloat16 autocast" if device == "cuda" else "fp32 (CPU)"
    print(f"Using device: {device} ({amp_dtype})")
    # Pinned host memory only helps host->GPU copies, so request it only
    # when CUDA is in use.
    pin_memory = device == "cuda"

    print(f"Loading tokenizer from {tokenizer_path}...")
    # SECURITY: weights_only=False unpickles arbitrary objects — only load
    # tokenizer files from a trusted source.
    tokenizer = torch.load(tokenizer_path, weights_only=False)
    vocab_size = tokenizer.language_size
    pad_id = tokenizer.symbol_to_token[PAD_TOKEN]

    writer = SummaryWriter(log_dir=log_dir)

    # ── Test loaders (optional, skip silently if test sets not built yet) ───────
    out_dir = Path(stockfish_samples_path).parent
    reward_test_loader = None
    if not skip_reward and (out_dir / "stockfish_test_meta.pt").exists():
        reward_test_ds = ChessPositionDataset.from_memmap(out_dir, "stockfish_test", tokenizer)
        reward_test_loader = DataLoader(
            reward_test_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn_memmap,
            num_workers=num_workers, pin_memory=pin_memory,
        )
        print(f"Reward test set: {len(reward_test_ds):,} positions loaded")

    policy_data_dir = Path(outcome_games_path).parent
    policy_test_loader = None
    if (policy_data_dir / "policy_test_meta.pt").exists():
        policy_test_ds = ChessPolicyDataset.from_memmap(policy_data_dir, tokenizer, name="policy_test")
        policy_test_loader = DataLoader(
            policy_test_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn_policy,
            num_workers=num_workers, pin_memory=pin_memory,
        )
        print(f"Policy test set: {len(policy_test_ds):,} sequences loaded")

    # Puzzle files live in --puzzle-data when given, otherwise next to the
    # policy files. Resolved once and shared by the test and train datasets.
    puzzle_dir = Path(puzzle_data_dir) if puzzle_data_dir is not None else policy_data_dir
    puzzle_test_loader = None
    if (puzzle_dir / "puzzle_test_meta.pt").exists():
        puzzle_test_ds = ChessPolicyDataset.from_memmap(
            puzzle_dir, tokenizer, name="puzzle_test",
            source_tag=1, loss_weight=1.0,  # eval uses uniform weighting
        )
        puzzle_test_loader = DataLoader(
            puzzle_test_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn_policy,
            num_workers=num_workers, pin_memory=pin_memory,
        )
        print(f"Puzzle test set: {len(puzzle_test_ds):,} sequences loaded")

    # ── Phase 1: reward model ────────────────────────────────────────────────
    reward_model = None
    global_step = 0
    if skip_reward:
        print("\n── Phase 1: SKIPPED (--skip-reward) — existing reward_model.pt untouched.")
    else:
        # Prefer the prebuilt memmap; fall back to the raw samples file.
        sf_meta = out_dir / "stockfish_meta.pt"
        if sf_meta.exists():
            print(f"Loading Stockfish samples from memmap ({out_dir}/stockfish_*)...")
            sf_ds = ChessPositionDataset.from_memmap(out_dir, "stockfish", tokenizer)
            sf_collate = collate_fn_memmap
        else:
            print(f"Loading Stockfish samples from {stockfish_samples_path}...")
            sf_ds = ChessPositionDataset.from_file(stockfish_samples_path, tokenizer)
            sf_collate = collate_fn
        print(f"Reward dataset: {len(sf_ds):,} positions")

        sf_loader = DataLoader(
            sf_ds, batch_size=batch_size, shuffle=True, collate_fn=sf_collate,
            num_workers=num_workers, pin_memory=pin_memory,
        )

        reward_model = ChessRewardModel(vocab_size=vocab_size, max_seq_len=max_seq_len).to(device)
        reward_optimizer = torch.optim.AdamW(reward_model.parameters(), lr=learning_rate)

        print(f"\n── Phase 1: reward model — {epochs} epochs, lr={learning_rate}")
        phase1_start = time.time()
        for epoch in range(epochs):
            epoch_num = epoch + 1
            print(f" [epoch {epoch_num}/{epochs}] starting...")
            # NOTE(review): the epoch-average train loss is logged at the
            # 0-based `epoch`, while test metrics below use the 1-based
            # `epoch_num`.
            avg_loss, global_step, epoch_secs = _run_epoch_reward(
                reward_model, sf_loader, reward_optimizer, device, writer, global_step, epoch
            )
            epochs_left = epochs - epoch_num
            print(
                f" [epoch {epoch_num}/{epochs}] "
                f"loss={avg_loss:.4f} "
                f"epoch_time={_fmt_duration(epoch_secs)} "
                f"eta={_fmt_duration(epoch_secs * epochs_left)}"
            )
            if reward_test_loader is not None:
                m = eval_reward(reward_model, reward_test_loader, device)
                writer.add_scalar("test/reward_mse", m["mse"], epoch_num)
                writer.add_scalar("test/reward_mae", m["mae"], epoch_num)
                writer.add_scalar("test/reward_pearson_r", m["pearson_r"], epoch_num)
                print(
                    f" [test] mse={m['mse']:.4f} mae={m['mae']:.4f} r={m['pearson_r']:.4f}"
                )

        print(f"Phase 1 complete in {_fmt_duration(time.time() - phase1_start)}")
        torch.save(reward_model.state_dict(), "reward_model.pt")
        print("Reward model saved to reward_model.pt")

    # ── Phase 2: mixed game + puzzle policy training ─────────────────────────
    policy_meta = policy_data_dir / "policy_meta.pt"
    if policy_meta.exists():
        print(f"Loading policy sequences from memmap ({policy_data_dir}/policy_*)...")
        game_ds = ChessPolicyDataset.from_memmap(
            policy_data_dir, tokenizer, name="policy",
            source_tag=0, loss_weight=1.0,
        )
    else:
        print(f"Loading outcome games from {outcome_games_path} (tokenizing on-the-fly)...")
        # SECURITY: weights_only=False unpickles arbitrary objects — trusted
        # files only.
        outcome_games = torch.load(outcome_games_path, weights_only=False)
        game_ds = ChessPolicyDataset(outcome_games, tokenizer, max_seq_len=max_seq_len)
    print(f"Game dataset: {len(game_ds):,} sequences")

    # Puzzle dataset (optional — falls back to game-only training if absent).
    # If --puzzle-data isn't passed, look for puzzle_*.bin alongside policy_*.bin
    # so a `build_datasets.py --policy-only` layout (everything in data/) is
    # picked up automatically without an extra CLI flag.
    puzzle_ds = None
    if (puzzle_dir / "puzzle_meta.pt").exists():
        puzzle_ds = ChessPolicyDataset.from_memmap(
            puzzle_dir, tokenizer, name="puzzle",
            source_tag=1, loss_weight=puzzle_loss_weight,
        )
        print(f"Puzzle dataset ({puzzle_dir}): {len(puzzle_ds):,} sequences (loss_weight={puzzle_loss_weight}x)")
    elif puzzle_data_dir is not None:
        print(f"WARNING: --puzzle-data given but {puzzle_dir}/puzzle_meta.pt not found.")

    if puzzle_ds is not None:
        mixed_ds = torch.utils.data.ConcatDataset([game_ds, puzzle_ds])
        sampler = MixedBatchSampler(
            n_game=len(game_ds),
            n_puzzle=len(puzzle_ds),
            batch_size=batch_size,
            game_ratio=1.0 - puzzle_ratio,
        )
        print(
            f"Mixed batch composition: {sampler.n_game_per_batch} game + "
            f"{sampler.n_puzzle_per_batch} puzzle per batch (puzzle_ratio={puzzle_ratio})"
        )
        policy_loader = DataLoader(
            mixed_ds, batch_sampler=sampler, collate_fn=collate_fn_policy,
            num_workers=num_workers, pin_memory=pin_memory,
        )
    else:
        policy_loader = DataLoader(
            game_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn_policy,
            num_workers=num_workers, pin_memory=pin_memory,
        )

    policy_model = ChessPolicyModel(vocab_size=vocab_size, max_seq_len=max_seq_len).to(device)
    policy_optimizer = torch.optim.AdamW(policy_model.parameters(), lr=learning_rate)
    # Phase 2 restarts the per-batch step counter (it logs under its own tags).
    global_step = 0

    def _run_policy_test(epoch_num: int, tb_prefix: str) -> None:
        """Evaluate on whichever policy/puzzle test sets exist and log the metrics."""
        if policy_test_loader is not None:
            m = eval_policy(policy_model, policy_test_loader, device, pad_id)
            writer.add_scalar(f"{tb_prefix}/policy_loss", m["loss"], epoch_num)
            writer.add_scalar(f"{tb_prefix}/policy_perplexity", m["perplexity"], epoch_num)
            writer.add_scalar(f"{tb_prefix}/policy_top1_acc", m["top1_acc"], epoch_num)
            writer.add_scalar(f"{tb_prefix}/policy_top5_acc", m["top5_acc"], epoch_num)
            print(
                f" [policy test] loss={m['loss']:.4f} ppl={m['perplexity']:.2f}"
                f" top1={m['top1_acc']:.3f} top5={m['top5_acc']:.3f}"
            )
        if puzzle_test_loader is not None:
            m = eval_puzzle_solve_rate(policy_model, puzzle_test_loader, device, pad_id)
            writer.add_scalar(f"{tb_prefix}/puzzle_first_move", m["first_move_solve_rate"], epoch_num)
            writer.add_scalar(f"{tb_prefix}/puzzle_all_moves", m["all_moves_solve_rate"], epoch_num)
            print(
                f" [puzzle test] first_move={m['first_move_solve_rate']:.3f}"
                f" all_moves={m['all_moves_solve_rate']:.3f}"
            )

    def _save_epoch_checkpoint(epoch_num: int) -> None:
        """Save the policy model after a completed epoch.

        Each checkpoint is `policy_model_epoch_{NN}.pt` next to the final
        `policy_model.pt`. If `keep_last_n_checkpoints > 0`, older
        checkpoints are pruned to cap disk usage. The final end-of-Phase-2
        save (`policy_model.pt`) is kept regardless of this setting and
        always reflects the last completed epoch.
        """
        ckpt_path = Path(f"policy_model_epoch_{epoch_num:02d}.pt")
        torch.save(policy_model.state_dict(), ckpt_path)
        print(f" [checkpoint] saved {ckpt_path.name}")
        if keep_last_n_checkpoints and keep_last_n_checkpoints > 0:
            # Order by write time (file name breaks ties), not by name alone.
            # A lexicographic sort puts "epoch_100" before "epoch_11", and it
            # ranks leftovers from an earlier, longer run above the file just
            # written — which would then be pruned immediately.
            existing = sorted(
                Path(".").glob("policy_model_epoch_*.pt"),
                key=lambda p: (p.stat().st_mtime, p.name),
            )
            for stale_path in existing[:-keep_last_n_checkpoints]:
                try:
                    stale_path.unlink()
                except OSError as exc:
                    # Pruning is best-effort: report the failure, keep training.
                    print(f" [checkpoint] WARNING: could not remove {stale_path.name}: {exc}")

    def _log_cross_gates(epoch_num: int) -> None:
        """Log per-block cross-attention gate values to TensorBoard.

        Each CrossAttnBlock has a single learned scalar `cross_gate` whose
        tanh controls how much board cross-attention contributes through
        its residual (init=0 means cross-attn starts disabled). Tracking
        these over epochs shows which layers opened the board pathway and
        how fast — flat-at-zero across all layers means the model decided
        cross-attention wasn't worth it.

        TB tags:
        cross_gate/block_{i} effective gate tanh(α) ∈ (-1, 1)
        cross_gate_raw/block_{i} raw parameter α (unbounded)
        """
        blocks = getattr(policy_model, "blocks", None)
        if blocks is None:
            return  # Older model variants without CrossAttnBlock stack.
        gates_tanh = {}
        for i, blk in enumerate(blocks):
            raw = blk.cross_gate.detach()
            tanh_val = raw.tanh().item()
            raw_val = raw.item()
            writer.add_scalar(f"cross_gate/block_{i}", tanh_val, epoch_num)
            writer.add_scalar(f"cross_gate_raw/block_{i}", raw_val, epoch_num)
            gates_tanh[f"block_{i}"] = tanh_val
        # Overlay all blocks on a single chart for easy at-a-glance comparison.
        writer.add_scalars("cross_gate_all", gates_tanh, epoch_num)
        gate_summary = " ".join(f"L{i}={v:+.3f}" for i, v in enumerate(gates_tanh.values()))
        print(f" [cross_gate] {gate_summary}")

    print(f"\n── Phase 2: mixed policy training — {policy_epochs} epochs, lr={learning_rate}")
    phase2_start = time.time()
    # Log initial gate values (all zeros at init) so TB charts start at epoch 0.
    _log_cross_gates(0)
    for epoch in range(policy_epochs):
        epoch_num = epoch + 1
        print(f" [epoch {epoch_num}/{policy_epochs}] starting...")
        avg_loss, global_step, epoch_secs = _run_epoch_policy_mixed(
            policy_model, policy_loader, policy_optimizer, device, writer, global_step, epoch, pad_id,
        )
        epochs_left = policy_epochs - epoch_num
        print(
            f" [epoch {epoch_num}/{policy_epochs}] "
            f"loss={avg_loss:.4f} "
            f"epoch_time={_fmt_duration(epoch_secs)} "
            f"eta={_fmt_duration(epoch_secs * epochs_left)}"
        )
        _run_policy_test(epoch_num, "test_mixed")
        _log_cross_gates(epoch_num)
        _save_epoch_checkpoint(epoch_num)

    print(f"Phase 2 complete in {_fmt_duration(time.time() - phase2_start)}")
    torch.save(policy_model.state_dict(), "policy_model.pt")
    print("Policy model saved to policy_model.pt")

    # Flush pending TensorBoard events and release the event file.
    writer.close()

    return reward_model, policy_model, tokenizer
|
| 1267 |
+
|
| 1268 |
+
|
| 1269 |
+
def _build_argparser():
    """Create the command-line parser for the training script.

    The parser description is ``train``'s docstring, and every option's
    destination matches one of ``train``'s keyword arguments.
    """
    parser = argparse.ArgumentParser(description=train.__doc__)

    # Options that need no help text or explicit dest, registered in order.
    plain_options = (
        ("--tokenizer-path", {"default": "data/tokenizer.pt"}),
        ("--stockfish-samples-path", {"default": "data/stockfish_samples.pt"}),
        ("--outcome-games-path", {"default": "data/games_outcome.pt"}),
        ("--epochs", {"type": int, "default": 15}),
        ("--policy-epochs", {"type": int, "default": 15}),
        ("--batch-size", {"type": int, "default": 1024}),
        ("--learning-rate", {"type": float, "default": 3e-5}),
        ("--max-seq-len", {"type": int, "default": 128}),
        ("--log-dir", {"default": "runs/chess_models"}),
        ("--num-workers", {"type": int, "default": 8}),
    )
    for flag, options in plain_options:
        parser.add_argument(flag, **options)

    # Puzzle / mixed-training options.
    parser.add_argument(
        "--puzzle-data",
        default=None,
        dest="puzzle_data_dir",
        help="Directory containing puzzle_tokens.bin / puzzle_lengths.bin / puzzle_fens.bin / puzzle_meta.pt",
    )
    parser.add_argument(
        "--puzzle-epochs",
        type=int,
        default=5,
        dest="puzzle_epochs",
        help="(Deprecated, retained for CLI compat) — mixed training merges game/puzzle into Phase 2.",
    )
    parser.add_argument(
        "--puzzle-loss-weight",
        type=float,
        default=5.0,
        dest="puzzle_loss_weight",
        help=(
            "Per-sample loss weight applied to puzzle samples in the mixed-training "
            "weighted cross-entropy (default 5.0)."
        ),
    )
    parser.add_argument(
        "--puzzle-ratio",
        type=float,
        default=0.2,
        dest="puzzle_ratio",
        help="Fraction of each mixed batch drawn from the puzzle dataset (default 0.2).",
    )

    # Run-control options.
    parser.add_argument(
        "--skip-reward",
        action="store_true",
        dest="skip_reward",
        help=(
            "Skip Phase 1 (reward model training). Existing reward_model.pt is "
            "left untouched. Use when iterating on Phase 2 only."
        ),
    )
    parser.add_argument(
        "--keep-last-n-checkpoints",
        type=int,
        default=3,
        dest="keep_last_n_checkpoints",
        help=(
            "Number of per-epoch policy_model_epoch_NN.pt files to keep on disk "
            "(default 3). Use 0 to keep all epochs. Final policy_model.pt is kept "
            "regardless."
        ),
    )
    return parser
|
| 1298 |
+
|
| 1299 |
+
|
| 1300 |
+
if __name__ == "__main__":
    # Script entry point: parse the CLI and hand each option to train() under
    # the keyword of the same name.
    cli = _build_argparser().parse_args()
    train_kwarg_names = (
        "tokenizer_path",
        "stockfish_samples_path",
        "outcome_games_path",
        "epochs",
        "policy_epochs",
        "batch_size",
        "learning_rate",
        "max_seq_len",
        "log_dir",
        "num_workers",
        "puzzle_data_dir",
        "puzzle_epochs",
        "puzzle_loss_weight",
        "puzzle_ratio",
        "skip_reward",
        "keep_last_n_checkpoints",
    )
    train(**{name: getattr(cli, name) for name in train_kwarg_names})
|