# GptForChess / src / build_datasets.py
# Page header carried over from the hosting site: author robell05,
# commit "serving model" (6d75857).
"""Build the two training datasets for hybrid two-phase training.
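
Runs five resumable stages, followed by a held-out test-split builder:

1. Stream the Lichess HF dataset, filter by Elo + Termination, and save two
   raw-game subsets (movetext + Result only); the subsets are intended to be
   disjoint.
2. Build the shared tokenizer (the enumerated UCI move vocabulary) and
   generate outcome-labeled samples ({+1, 0, -1} from game Result) from the
   outcome subset.
3. Run parallel Stockfish on the Stockfish subset to produce precisely
   labeled samples (tanh(cp/400)).
4. Tokenize the outcome-subset games into full-game policy sequences.
5. Tokenize Lichess puzzles into train/test policy sequences, stored together
   with each puzzle's starting FEN.

Each stage skips if its output files already exist. Use --force to re-run.

Outputs (under --out-dir):
    games_outcome.pt                       raw outcome-subset games
    games_stockfish.pt                     raw stockfish-subset games
    tokenizer.pt                           shared Tokenizer
    outcome_{tokens,labels,lengths}.bin    outcome-labeled samples (+ outcome_meta.pt)
    stockfish_{tokens,labels,lengths}.bin  Stockfish-labeled samples (+ stockfish_meta.pt)
    policy_{tokens,lengths}.bin            full-game policy sequences (+ policy_meta.pt)
    puzzle_{tokens,lengths,fens}.bin       puzzle training sequences (+ puzzle_meta.pt)
    puzzle_test_{tokens,lengths,fens}.bin  puzzle test sequences (+ puzzle_test_meta.pt)
    stockfish_test_* / policy_test_*       held-out test subsets and *_test_indices.npy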
"""
import argparse
import random
from pathlib import Path
import chess
import numpy as np
import torch
from datasets import load_dataset
from tqdm import tqdm
from model import CLS_TOKEN
from train import (
build_tokenizer_from_games,
generate_samples_stockfish_parallel,
parse_movetext,
)
# Maps the Lichess "Result" field to a scalar outcome label, always expressed
# from White's perspective: White win = +1, Black win = -1, draw = 0.
# Any other Result value is absent from the table on purpose.
RESULT_TO_LABEL = {
    "1-0": 1.0,
    "0-1": -1.0,
    "1/2-1/2": 0.0,
}
def _save_as_memmap(
samples: list[tuple[list[int], float]], out_dir: Path, name: str, max_seq_len: int = 128
) -> None:
"""Save samples as memory-mapped arrays for fast DataLoader access.
Sequences longer than max_seq_len are truncated (keeps the most recent tokens,
since the CLS token is at position 0 we keep ids[:max_seq_len]).
Produces three files:
{name}_tokens.bin — (N, max_seq_len) int32, zero-padded
{name}_labels.bin — (N,) float32
{name}_lengths.bin — (N,) int32, actual sequence length per sample (capped at max_seq_len)
{name}_meta.pt — dict with 'n' and 'max_len'
"""
n = len(samples)
max_len = min(max(len(ids) for ids, _ in samples), max_seq_len)
print(f" memmap {name}: {n:,} samples, max_seq_len={max_len}")
tokens = np.memmap(out_dir / f"{name}_tokens.bin", dtype=np.int32, mode="w+", shape=(n, max_len))
labels = np.memmap(out_dir / f"{name}_labels.bin", dtype=np.float32, mode="w+", shape=(n,))
lengths = np.memmap(out_dir / f"{name}_lengths.bin", dtype=np.int32, mode="w+", shape=(n,))
for i, (ids, label) in enumerate(tqdm(samples, desc=f" writing {name}", unit="sample")):
ids = ids[:max_len]
l = len(ids)
tokens[i, :l] = ids
labels[i] = label
lengths[i] = l
tokens.flush()
labels.flush()
lengths.flush()
torch.save({"n": n, "max_len": max_len}, out_dir / f"{name}_meta.pt")
size_gb = (tokens.nbytes + labels.nbytes + lengths.nbytes) / 1024 ** 3
print(f" memmap {name} saved ({size_gb:.2f} GB)")
def stage1_collect_games(args: argparse.Namespace) -> None:
    """Stage 1: stream Lichess games and save the raw policy / reward subsets.

    Streams ``Lichess/standard-chess-games``, keeps games whose ``Termination``
    is ``"Normal"`` and whose two players both meet the relevant Elo floor, and
    stores only the ``movetext`` and ``Result`` fields of each kept game.

    Writes ``games_outcome.pt`` (policy / outcome subset) and, unless
    ``--policy-only`` is set, ``games_stockfish.pt`` (reward / Stockfish
    subset).

    The two subsets are disjoint: every streamed game is assigned to at most
    one of them. A game is offered to the reward subset first and to the policy
    subset only if the reward subset did not take it, so each subset remains a
    plain in-stream-order sample of the games passing its own Elo floor.

    The stage is skipped when its output file(s) already exist, unless
    ``args.force`` is set.
    """
    policy_games_path = args.out_dir / "games_outcome.pt"
    reward_games_path = args.out_dir / "games_stockfish.pt"
    policy_only = getattr(args, "policy_only", False)

    if policy_only:
        # Reward subset is irrelevant when we're only training the policy model.
        # Skip-condition checks only the policy artifact.
        if policy_games_path.exists() and not args.force:
            print(f"Stage 1: skipping — {policy_games_path.name} exists (--policy-only).")
            return
    else:
        if policy_games_path.exists() and reward_games_path.exists() and not args.force:
            print(f"Stage 1: skipping — {policy_games_path.name} and {reward_games_path.name} exist.")
            return

    if policy_only:
        lower_elo = args.policy_min_elo
        print(
            f"Stage 1: streaming Lichess/standard-chess-games (Termination == 'Normal'), "
            f"policy Elo >= {args.policy_min_elo} (target {args.policy_games:,}). "
            f"Reward subset skipped (--policy-only)."
        )
    else:
        lower_elo = min(args.reward_min_elo, args.policy_min_elo)
        print(
            f"Stage 1: streaming Lichess/standard-chess-games (Termination == 'Normal'), "
            f"reward Elo >= {args.reward_min_elo} (target {args.reward_games:,}), "
            f"policy Elo >= {args.policy_min_elo} (target {args.policy_games:,})..."
        )

    ds = load_dataset("Lichess/standard-chess-games", split="train", streaming=True)
    # Pre-filter by the lower of the two thresholds to skip clearly ineligible games.
    ds = ds.filter(
        lambda r: (
            r.get("WhiteElo") is not None
            and r.get("BlackElo") is not None
            and r["WhiteElo"] >= lower_elo
            and r["BlackElo"] >= lower_elo
            and r.get("Termination") == "Normal"
        )
    )

    policy_games: list[dict] = []
    reward_games: list[dict] = []
    keep_keys = ("movetext", "Result")

    for row in tqdm(ds, desc="Stage 1: streaming", unit="game"):
        white_elo = row.get("WhiteElo", 0)
        black_elo = row.get("BlackElo", 0)
        minimal = {k: row.get(k) for k in keep_keys}

        wants_reward = (
            not policy_only
            and len(reward_games) < args.reward_games
            and white_elo >= args.reward_min_elo
            and black_elo >= args.reward_min_elo
        )
        wants_policy = (
            len(policy_games) < args.policy_games
            and white_elo >= args.policy_min_elo
            and black_elo >= args.policy_min_elo
        )
        # `elif`, not a second `if`: a game must never land in both subsets,
        # otherwise the outcome-labeled and Stockfish-labeled data overlap.
        if wants_reward:
            reward_games.append(minimal)
        elif wants_policy:
            policy_games.append(minimal)

        # Stop streaming as soon as every requested subset has hit its target.
        if policy_only:
            if len(policy_games) >= args.policy_games:
                break
        elif len(reward_games) >= args.reward_games and len(policy_games) >= args.policy_games:
            break

    if policy_only:
        if len(policy_games) < args.policy_games:
            print(
                f" WARNING: dataset exhausted before target — "
                f"got {len(policy_games):,} policy games."
            )
    elif len(reward_games) < args.reward_games or len(policy_games) < args.policy_games:
        print(
            f" WARNING: dataset exhausted before target — "
            f"got {len(reward_games):,} reward + {len(policy_games):,} policy games."
        )

    if not policy_only:
        print(f"Stage 1: saving {reward_games_path} ({len(reward_games):,} games)...")
        torch.save(reward_games, reward_games_path)
    print(f"Stage 1: saving {policy_games_path} ({len(policy_games):,} games)...")
    torch.save(policy_games, policy_games_path)
def _generate_outcome_samples(games, tokenizer, max_positions_per_game, skip_ply):
    """Build (token_ids, outcome_label) samples for the phase-1 dataset.

    For every game with a recognised ``Result`` and enough moves, up to
    ``max_positions_per_game`` ply indices are drawn (without replacement) from
    ``range(skip_ply, number_of_moves)``. Each drawn index ``i`` yields one
    sample: ``[CLS] + encoded moves[0..i]`` labeled with the game outcome from
    White's perspective (``RESULT_TO_LABEL``).

    The per-game RNG is seeded with the game's position in ``games``, so the
    selection is reproducible for a given input list.

    Replay stops at the first SAN move that cannot be applied to the board;
    sample indices beyond that point are silently dropped for that game.

    Returns:
        list of ``(token_ids, label)`` tuples.
    """
    cls_id = tokenizer.symbol_to_token[CLS_TOKEN]
    samples: list[tuple[list[int], float]] = []

    with tqdm(games, desc="Stage 2: outcome samples", unit="game") as pbar:
        for idx, game in enumerate(pbar):
            result = game.get("Result")
            if result not in RESULT_TO_LABEL:
                continue
            label = RESULT_TO_LABEL[result]

            movetext = game.get("movetext", "")
            if not movetext:
                continue
            move_sans = parse_movetext(movetext)
            if len(move_sans) < max(2, skip_ply + 1):
                continue

            eligible = range(skip_ply, len(move_sans))
            num_positions = min(max_positions_per_game, len(eligible))
            rng = random.Random(idx)
            sample_indices = set(rng.sample(eligible, num_positions))

            board = chess.Board()
            valid_moves: list[str] = []
            for i, san in enumerate(move_sans):
                try:
                    move = board.parse_san(san)
                except ValueError:
                    # python-chess reports invalid, illegal and ambiguous SAN
                    # through ValueError subclasses; catching the base class
                    # covers all three, so one bad move cannot abort the stage.
                    break
                board.push(move)
                valid_moves.append(move.uci())
                if i in sample_indices:
                    token_ids = [cls_id] + tokenizer.encode_moves(valid_moves)
                    samples.append((token_ids, label))

            # Refresh the running sample count only occasionally to keep the
            # progress bar cheap (games skipped above never reach this line).
            if (idx + 1) % 50_000 == 0:
                pbar.set_postfix(samples=f"{len(samples):,}")

    return samples
def stage2_outcome_samples(args: argparse.Namespace) -> None:
    """Stage 2: build the shared tokenizer and the outcome-labeled memmap.

    Loads ``games_outcome.pt``, saves the tokenizer to ``tokenizer.pt`` and
    writes the ``outcome_*`` memmap files (up to 20 sampled positions per game,
    no plies skipped).

    Skipped when both ``tokenizer.pt`` and ``outcome_meta.pt`` already exist,
    unless ``args.force`` is set.
    """
    out_dir = args.out_dir
    tokenizer_file = out_dir / "tokenizer.pt"
    meta_file = out_dir / "outcome_meta.pt"

    outputs_present = tokenizer_file.exists() and meta_file.exists()
    if outputs_present and not args.force:
        print(f"Stage 2: skipping — {tokenizer_file.name} and {meta_file.name} exist.")
        return

    games_file = out_dir / "games_outcome.pt"
    print(f"Stage 2: loading outcome games from {games_file}...")
    outcome_games = torch.load(games_file, weights_only=False)

    # The tokenizer builder is called without arguments: it does not consume
    # the loaded games.
    print("Stage 2: building tokenizer from all UCI moves...")
    tok = build_tokenizer_from_games()
    print(f"Stage 2: tokenizer vocab size = {tok.language_size}")
    torch.save(tok, tokenizer_file)

    print("Stage 2: generating outcome samples (up to 20 per game)...")
    labeled = _generate_outcome_samples(
        outcome_games, tok, max_positions_per_game=20, skip_ply=0
    )

    print(f"Stage 2: saving {len(labeled):,} outcome samples as memmap...")
    _save_as_memmap(labeled, out_dir, "outcome", max_seq_len=args.max_seq_len)
def _generate_policy_sequences(games, tokenizer, max_seq_len: int = 128) -> list[list[int]]:
    """Tokenize full game sequences for policy training.

    Each output sequence is [CLS, m1, m2, ..., mN], truncated to max_seq_len
    tokens in total (so at most ``max_seq_len - 1`` moves follow the CLS token).

    Every game is replayed from the standard starting position to convert SAN
    to UCI. Replay stops at the first SAN move that cannot be applied to the
    board; the moves played up to that point are kept.

    Games with fewer than 2 valid UCI moves are skipped.
    """
    cls_id = tokenizer.symbol_to_token[CLS_TOKEN]
    sequences: list[list[int]] = []

    with tqdm(games, desc="Stage 4: policy sequences", unit="game") as pbar:
        for game in pbar:
            movetext = game.get("movetext", "")
            if not movetext:
                continue
            move_sans = parse_movetext(movetext)
            if len(move_sans) < 2:
                continue

            board = chess.Board()
            move_ucis: list[str] = []
            for san in move_sans:
                try:
                    move = board.parse_san(san)
                except ValueError:
                    # python-chess reports invalid, illegal and ambiguous SAN
                    # through ValueError subclasses; catching the base class
                    # covers all three, so one bad move cannot abort the stage.
                    break
                board.push(move)
                move_ucis.append(move.uci())

            if len(move_ucis) < 2:
                continue

            # Reserve one slot for the CLS token.
            move_ucis = move_ucis[:max_seq_len - 1]
            sequences.append([cls_id] + tokenizer.encode_moves(move_ucis))

    return sequences
def _save_policy_memmap(
sequences: list[list[int]], out_dir: Path, name: str, max_seq_len: int = 128,
fens: list[str] | None = None, fen_len: int = 100,
) -> None:
"""Save policy sequences as memory-mapped arrays (no labels).
Produces:
{name}_tokens.bin — (N, max_len) int32, zero-padded
{name}_lengths.bin — (N,) int32, actual sequence length per sample
{name}_meta.pt — dict with 'n', 'max_len', and (if fens given) 'fen_len'
If `fens` is provided, also writes {name}_fens.bin — (N, fen_len) uint8
holding zero-padded ASCII FEN strings, one per sample. Used by the
CNN-conditioned policy training to reconstruct each sample's starting board.
"""
n = len(sequences)
max_len = min(max(len(s) for s in sequences), max_seq_len)
print(f" memmap {name}: {n:,} sequences, max_seq_len={max_len}")
tokens = np.memmap(out_dir / f"{name}_tokens.bin", dtype=np.int32, mode="w+", shape=(n, max_len))
lengths = np.memmap(out_dir / f"{name}_lengths.bin", dtype=np.int32, mode="w+", shape=(n,))
for i, seq in enumerate(tqdm(sequences, desc=f" writing {name}", unit="seq")):
seq = seq[:max_len]
l = len(seq)
tokens[i, :l] = seq
lengths[i] = l
tokens.flush()
lengths.flush()
meta = {"n": n, "max_len": max_len}
extra_bytes = 0
if fens is not None:
assert len(fens) == n, f"fen count {len(fens)} mismatch with sequence count {n}"
fens_mm = np.memmap(out_dir / f"{name}_fens.bin", dtype=np.uint8, mode="w+", shape=(n, fen_len))
for i, fen in enumerate(fens):
b = fen.encode("ascii")[:fen_len]
fens_mm[i, :len(b)] = list(b)
fens_mm.flush()
meta["fen_len"] = fen_len
extra_bytes = fens_mm.nbytes
torch.save(meta, out_dir / f"{name}_meta.pt")
size_gb = (tokens.nbytes + lengths.nbytes + extra_bytes) / 1024 ** 3
print(f" memmap {name} saved ({size_gb:.3f} GB)")
def _process_puzzle(
row: dict,
tokenizer_symbol_map: dict,
cls_id: int,
) -> tuple[list[int], str] | None:
"""Parse one Lichess puzzle row into a (token_sequence, FEN) pair.
Sequence layout: [CLS, setup_move, solver_move1, opp_response, solver_move2, ...]
The setup move (Moves[0]) is included as context so the model conditions on it
when predicting the solution. During training the loss on the setup move position
is masked out — we model P[m_n | S, m_{<n}] where S is the setup move.
The FEN is the puzzle's starting board position. It is persisted alongside the
token sequence so the CNN-conditioned policy training can reconstruct the
starting board planes (CNN's input) at __getitem__ time.
Returns None if any move is illegal, unknown to the tokenizer, or the sequence
has fewer than 3 tokens (CLS + setup + at least one solver move).
"""
fen = row.get("FEN", "")
moves_str = row.get("Moves", "")
if not fen or not moves_str:
return None
uci_moves = moves_str.strip().split()
if len(uci_moves) < 2: # need setup + at least one solver move
return None
try:
board = chess.Board(fen)
except ValueError:
return None
# Tokenize all moves: setup first (as context), then the full solution.
token_ids: list[int] = [cls_id]
for uci in uci_moves:
try:
move = chess.Move.from_uci(uci)
except ValueError:
return None
if move not in board.legal_moves:
return None
if uci not in tokenizer_symbol_map:
return None
token_ids.append(tokenizer_symbol_map[uci])
board.push(move)
if len(token_ids) < 3: # CLS + setup + at least one solver move
return None
return token_ids, fen
def stage3_stockfish_samples(args: argparse.Namespace) -> None:
    """Stage 3: label the Stockfish game subset and write the stockfish memmap.

    Loads ``games_stockfish.pt`` and ``tokenizer.pt``, hands them to
    ``generate_samples_stockfish_parallel`` together with the worker count,
    search depth, sample rate and position skew from ``args``, and writes the
    returned samples as the ``stockfish_*`` memmap files.

    Skipped when ``stockfish_meta.pt`` already exists, unless ``args.force``
    is set.
    """
    out_dir = args.out_dir
    done_marker = out_dir / "stockfish_meta.pt"
    if done_marker.exists() and not args.force:
        print(f"Stage 3: skipping — {done_marker.name} exists.")
        return

    games_file = out_dir / "games_stockfish.pt"
    tokenizer_file = out_dir / "tokenizer.pt"
    print(f"Stage 3: loading {games_file} and {tokenizer_file}...")
    reward_games = torch.load(games_file, weights_only=False)
    tok = torch.load(tokenizer_file, weights_only=False)

    print(
        f"Stage 3: running parallel Stockfish ({args.workers} workers, "
        f"depth {args.stockfish_depth}) on {len(reward_games):,} games..."
    )
    labeling_options = {
        "num_workers": args.workers,
        "stockfish_depth": args.stockfish_depth,
        "sample_rate": args.sample_rate,
        "skew_exponent": args.position_skew,
    }
    labeled = generate_samples_stockfish_parallel(reward_games, tok, **labeling_options)

    print(f"Stage 3: saving {len(labeled):,} stockfish samples as memmap...")
    _save_as_memmap(labeled, out_dir, "stockfish", max_seq_len=args.max_seq_len)
def stage4_policy_sequences(args: argparse.Namespace) -> None:
    """Stage 4: tokenize the outcome-subset games into policy sequences.

    Loads ``games_outcome.pt`` and ``tokenizer.pt``, converts each game into a
    token sequence and writes the ``policy_*`` memmap files.

    Skipped when ``policy_meta.pt`` already exists, unless ``args.force`` is
    set.
    """
    out_dir = args.out_dir
    done_marker = out_dir / "policy_meta.pt"
    if done_marker.exists() and not args.force:
        print(f"Stage 4: skipping — {done_marker.name} exists.")
        return

    games_file = out_dir / "games_outcome.pt"
    tokenizer_file = out_dir / "tokenizer.pt"
    print(f"Stage 4: loading {games_file} and {tokenizer_file}...")
    policy_games = torch.load(games_file, weights_only=False)
    tok = torch.load(tokenizer_file, weights_only=False)

    print(f"Stage 4: tokenizing {len(policy_games):,} games into policy sequences...")
    seqs = _generate_policy_sequences(policy_games, tok, max_seq_len=args.max_seq_len)

    print(f"Stage 4: saving {len(seqs):,} policy sequences as memmap...")
    _save_policy_memmap(seqs, out_dir, "policy", max_seq_len=args.max_seq_len)
def _write_test_subset_reward(out_dir: Path, src_name: str, dst_name: str, indices: np.ndarray) -> None:
    """Copy selected rows of a reward memmap set into a new memmap set.

    Reads ``{src_name}_meta.pt`` for the source shape, then copies the rows
    listed in ``indices`` (in the given order) from the source tokens, labels
    and lengths files into ``{dst_name}_tokens.bin``, ``{dst_name}_labels.bin``
    and ``{dst_name}_lengths.bin``. Finally writes ``{dst_name}_meta.pt`` with
    the new row count and the unchanged ``max_len``.
    """
    src_meta = torch.load(out_dir / f"{src_name}_meta.pt", weights_only=True)
    total_rows = src_meta["n"]
    width = src_meta["max_len"]

    def _open(prefix, suffix, dtype, mode, shape):
        # All files of one set share the "{prefix}_{suffix}.bin" naming scheme.
        return np.memmap(out_dir / f"{prefix}_{suffix}.bin", dtype=dtype, mode=mode, shape=shape)

    tokens_in = _open(src_name, "tokens", np.int32, "r", (total_rows, width))
    labels_in = _open(src_name, "labels", np.float32, "r", (total_rows,))
    lengths_in = _open(src_name, "lengths", np.int32, "r", (total_rows,))

    subset_rows = len(indices)
    tokens_out = _open(dst_name, "tokens", np.int32, "w+", (subset_rows, width))
    labels_out = _open(dst_name, "labels", np.float32, "w+", (subset_rows,))
    lengths_out = _open(dst_name, "lengths", np.int32, "w+", (subset_rows,))

    progress = tqdm(indices, desc=f" writing {dst_name}", unit="sample")
    for dst_row, src_row in enumerate(progress):
        tokens_out[dst_row] = tokens_in[src_row]
        labels_out[dst_row] = labels_in[src_row]
        lengths_out[dst_row] = lengths_in[src_row]

    for written in (tokens_out, labels_out, lengths_out):
        written.flush()

    torch.save({"n": subset_rows, "max_len": width}, out_dir / f"{dst_name}_meta.pt")
    print(f" {dst_name}: {subset_rows:,} samples written")
def _write_test_subset_policy(out_dir: Path, src_name: str, dst_name: str, indices: np.ndarray) -> None:
    """Copy selected rows of a policy memmap set (no labels) into a new set.

    Reads ``{src_name}_meta.pt`` for the source shape, then copies the rows
    listed in ``indices`` (in the given order) from the source tokens and
    lengths files into ``{dst_name}_tokens.bin`` and ``{dst_name}_lengths.bin``.
    Finally writes ``{dst_name}_meta.pt`` with the new row count and the
    unchanged ``max_len``. A FEN file, if the source has one, is not copied.
    """
    src_meta = torch.load(out_dir / f"{src_name}_meta.pt", weights_only=True)
    total_rows = src_meta["n"]
    width = src_meta["max_len"]

    def _open(prefix, suffix, mode, shape):
        # Both files of a policy set are int32 and named "{prefix}_{suffix}.bin".
        return np.memmap(out_dir / f"{prefix}_{suffix}.bin", dtype=np.int32, mode=mode, shape=shape)

    tokens_in = _open(src_name, "tokens", "r", (total_rows, width))
    lengths_in = _open(src_name, "lengths", "r", (total_rows,))

    subset_rows = len(indices)
    tokens_out = _open(dst_name, "tokens", "w+", (subset_rows, width))
    lengths_out = _open(dst_name, "lengths", "w+", (subset_rows,))

    progress = tqdm(indices, desc=f" writing {dst_name}", unit="seq")
    for dst_row, src_row in enumerate(progress):
        tokens_out[dst_row] = tokens_in[src_row]
        lengths_out[dst_row] = lengths_in[src_row]

    for written in (tokens_out, lengths_out):
        written.flush()

    torch.save({"n": subset_rows, "max_len": width}, out_dir / f"{dst_name}_meta.pt")
    print(f" {dst_name}: {subset_rows:,} sequences written")
def stage_build_test_splits(args: argparse.Namespace, out_dir: Path) -> None:
    """Build held-out test sets for reward and policy models from existing memmaps.

    Each split draws its indices from its own generator seeded with 42, so the
    indices chosen for a split depend only on that split's source size and
    requested test size — not on whether the other split was built in the same
    run (e.g. with --policy-only, or when one split already exists).

    The chosen indices are saved to {name}_test_indices.npy so the
    corresponding training memmap loader can exclude them — making train and
    test disjoint even though they share the underlying .bin file. The indices
    file is written before the test memmap, so a run interrupted in between
    never leaves a finished test set (marked by its meta file) without its
    indices.

    A split is built when its source meta file exists and either its test meta
    file is missing or ``args.force`` is set. The reward split is never built
    under ``--policy-only``.

    Produces:
        stockfish_test_*.bin / stockfish_test_meta.pt / stockfish_test_indices.npy
        policy_test_*.bin / policy_test_meta.pt / policy_test_indices.npy
    """
    seed = 42
    policy_only = getattr(args, "policy_only", False)

    # Reward test set — skipped when --policy-only since no Stockfish data exists.
    reward_test_meta = out_dir / "stockfish_test_meta.pt"
    sf_meta_path = out_dir / "stockfish_meta.pt"
    if policy_only:
        print("Test splits: stockfish_test skipped (--policy-only).")
    elif (not reward_test_meta.exists() or args.force) and sf_meta_path.exists():
        print(f"Test splits: building stockfish_test ({args.reward_test_size:,} samples)...")
        sf_meta = torch.load(sf_meta_path, weights_only=True)
        n = sf_meta["n"]
        test_n = min(args.reward_test_size, n)
        # Fresh generator per split: see the docstring for why it is not shared.
        idx = np.random.default_rng(seed).choice(n, size=test_n, replace=False)
        idx.sort()
        np.save(out_dir / "stockfish_test_indices.npy", idx)
        print(f" saved stockfish_test_indices.npy ({test_n:,} indices excluded from training)")
        _write_test_subset_reward(out_dir, "stockfish", "stockfish_test", idx)
    elif reward_test_meta.exists():
        print("Test splits: stockfish_test already exists, skipping.")

    # Policy test set
    policy_test_meta = out_dir / "policy_test_meta.pt"
    pol_meta_path = out_dir / "policy_meta.pt"
    if (not policy_test_meta.exists() or args.force) and pol_meta_path.exists():
        print(f"Test splits: building policy_test ({args.policy_test_size:,} sequences)...")
        pol_meta = torch.load(pol_meta_path, weights_only=True)
        n = pol_meta["n"]
        test_n = min(args.policy_test_size, n)
        idx = np.random.default_rng(seed).choice(n, size=test_n, replace=False)
        idx.sort()
        np.save(out_dir / "policy_test_indices.npy", idx)
        print(f" saved policy_test_indices.npy ({test_n:,} indices excluded from training)")
        _write_test_subset_policy(out_dir, "policy", "policy_test", idx)
    elif policy_test_meta.exists():
        print("Test splits: policy_test already exists, skipping.")
def stage5_puzzle_samples(args: argparse.Namespace, tokenizer, out_dir: Path) -> None:
    """Stage 5: tokenize Lichess puzzles into train and test policy memmaps.

    Streams ``Lichess/chess-puzzles``, optionally filtering rows by
    ``Popularity`` (``args.min_puzzle_popularity``) and ``NbPlays``
    (``args.min_puzzle_plays``). The first ``args.puzzle_test_size`` valid
    puzzles form the test set; all later valid puzzles go to the training set,
    up to ``args.puzzle_count`` when that is not None.

    Writes the ``puzzle_test_*`` and ``puzzle_*`` memmaps, each with a FEN file.
    An existing set is left untouched unless ``args.force`` is set, in which
    case both sets are rewritten. The stage is skipped entirely when both meta
    files exist and ``args.force`` is not set.
    """
    train_done = (out_dir / "puzzle_meta.pt").exists()
    test_done = (out_dir / "puzzle_test_meta.pt").exists()
    if train_done and test_done and not args.force:
        print("Stage 5: skipping — puzzle_meta.pt and puzzle_test_meta.pt exist.")
        return

    print("Stage 5: loading Lichess/chess-puzzles from HuggingFace...")
    ds = load_dataset("Lichess/chess-puzzles", split="train", streaming=True)

    min_pop = args.min_puzzle_popularity
    min_plays = args.min_puzzle_plays
    cls_id = tokenizer.symbol_to_token[CLS_TOKEN]
    sym_map = tokenizer.symbol_to_token

    test_seqs: list[list[int]] = []
    test_fens: list[str] = []
    train_seqs: list[list[int]] = []
    train_fens: list[str] = []
    skipped = 0
    test_target = args.puzzle_test_size
    train_target = args.puzzle_count

    with tqdm(ds, desc="Stage 5: puzzles", unit="puzzle") as pbar:
        for row in pbar:
            if min_pop is not None and row.get("Popularity", 0) < min_pop:
                continue
            if min_plays is not None and row.get("NbPlays", 0) < min_plays:
                continue

            result = _process_puzzle(row, sym_map, cls_id)
            if result is None:
                skipped += 1
                continue
            seq, fen = result

            # Fill the test set first, even if it already exists on disk, so
            # the training set never contains those same leading puzzles.
            if len(test_seqs) < test_target:
                test_seqs.append(seq)
                test_fens.append(fen)
            else:
                train_seqs.append(seq)
                train_fens.append(fen)

            pbar.set_postfix(test=len(test_seqs), train=len(train_seqs), skipped=skipped)
            if train_target is not None and len(train_seqs) >= train_target:
                break

    print(
        f"Stage 5: test={len(test_seqs):,} puzzles, "
        f"train={len(train_seqs):,} puzzles, "
        f"skipped={skipped:,} invalid."
    )

    # --force must rewrite existing outputs; without this the stage would
    # stream the whole dataset and then save nothing.
    write_test = args.force or not test_done
    write_train = args.force or not train_done

    if test_seqs and write_test:
        _save_policy_memmap(
            test_seqs, out_dir, "puzzle_test", max_seq_len=args.max_seq_len, fens=test_fens,
        )
    if train_seqs and write_train:
        _save_policy_memmap(
            train_seqs, out_dir, "puzzle", max_seq_len=args.max_seq_len, fens=train_fens,
        )
    elif not train_seqs:
        print("Stage 5: WARNING — no training puzzles collected.")
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the dataset build script."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--out-dir", type=Path, default=Path("data"))
    parser.add_argument("--policy-games", type=int, default=1_000_000,
                        help="Number of games to collect for policy model training")
    parser.add_argument("--reward-games", type=int, default=1_000_000,
                        help="Number of games to collect for reward model (Stockfish eval)")
    parser.add_argument("--policy-min-elo", type=int, default=1800,
                        help="Min Elo for both players in policy training games")
    parser.add_argument("--reward-min-elo", type=int, default=1500,
                        help="Min Elo for both players in reward model training games")
    parser.add_argument("--sample-rate", type=float, default=0.25,
                        help="Fraction of positions to sample per game (scales with game length)")
    parser.add_argument("--position-skew", type=float, default=1.5,
                        help="Power-law exponent weighting later positions; 1.0=linear, higher=more mid/late")
    parser.add_argument("--workers", type=int, default=16)
    parser.add_argument("--stockfish-depth", type=int, default=12)
    parser.add_argument("--max-seq-len", type=int, default=128,
                        help="Truncate token sequences to this length when writing .bin files")
    parser.add_argument(
        "--force",
        action="store_true",
        help="Re-run all stages even if their outputs already exist",
    )
    parser.add_argument("--puzzle-count", type=int, default=None, dest="puzzle_count",
                        help="Max puzzles to include (default: all ~4.99M)")
    parser.add_argument("--min-puzzle-popularity", type=int, default=None, dest="min_puzzle_popularity",
                        help="Min Lichess Popularity score (-100 to 100 scale)")
    parser.add_argument("--min-puzzle-plays", type=int, default=None, dest="min_puzzle_plays",
                        help="Min NbPlays for a puzzle to be included")
    parser.add_argument("--skip-puzzles", action="store_true",
                        help="Skip Stage 5 puzzle processing")
    parser.add_argument("--puzzles-only", action="store_true",
                        help="Only run Stage 5 (puzzle processing). Skips game collection, "
                             "outcome/Stockfish/policy memmaps, and test splits. Requires "
                             "tokenizer.pt to exist (or it will be built from the UCI vocab).")
    parser.add_argument("--policy-only", action="store_true",
                        help="Skip everything Stockfish/reward-related: Stage 1 collects only "
                             "policy games, Stage 3 (Stockfish labeling) is skipped, and the "
                             "stockfish_test split is not built. Stages 1/2/4/5 + policy_test "
                             "still run, producing tokenizer.pt, policy_* / puzzle_* memmaps, "
                             "and the policy_test split.")
    parser.add_argument("--puzzle-test-size", type=int, default=100_000, dest="puzzle_test_size",
                        help="Number of puzzle sequences held out for the test set (default: 100000)")
    parser.add_argument("--reward-test-size", type=int, default=50_000, dest="reward_test_size",
                        help="Number of reward positions held out for the test set (default: 50000)")
    parser.add_argument("--policy-test-size", type=int, default=50_000, dest="policy_test_size",
                        help="Number of policy sequences held out for the test set (default: 50000)")
    return parser


def main():
    """Parse the command line, run the selected stages, and list the artifacts.

    Modes:
        --puzzles-only  run Stage 5 only (building tokenizer.pt if it is missing)
        --policy-only   run Stages 1, 2, 4, 5 and the policy test split
        (default)       run Stages 1-5 and both test splits

    Stage 5 is skipped in the last two modes when --skip-puzzles is given or
    tokenizer.pt is missing.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # Reject conflicting modes before touching the filesystem, so a usage
    # error does not leave an empty output directory behind.
    if args.puzzles_only and args.policy_only:
        parser.error("--puzzles-only and --policy-only are mutually exclusive.")

    args.out_dir.mkdir(parents=True, exist_ok=True)
    tokenizer_path = args.out_dir / "tokenizer.pt"

    if args.puzzles_only:
        print("--puzzles-only: skipping Stages 1-4 and test-split builder.")
        if tokenizer_path.exists():
            tokenizer = torch.load(tokenizer_path, weights_only=False)
        else:
            # Tokenizer is just the enumerated UCI vocab — no games needed.
            print(" tokenizer.pt missing; building from UCI vocab...")
            tokenizer = build_tokenizer_from_games()
            torch.save(tokenizer, tokenizer_path)
        stage5_puzzle_samples(args, tokenizer, args.out_dir)
    else:
        # The policy-only and full pipelines differ only in Stage 3; the
        # stockfish_test split is suppressed inside stage_build_test_splits.
        if args.policy_only:
            print("--policy-only: skipping Stage 3 (Stockfish labeling) and stockfish_test split.")
        stage1_collect_games(args)
        stage2_outcome_samples(args)
        if not args.policy_only:
            stage3_stockfish_samples(args)
        stage4_policy_sequences(args)

        if not args.skip_puzzles:
            if tokenizer_path.exists():
                tokenizer = torch.load(tokenizer_path, weights_only=False)
                stage5_puzzle_samples(args, tokenizer, args.out_dir)
            else:
                print("Stage 5: skipping — tokenizer.pt not found (run stages 1-2 first).")

        stage_build_test_splits(args, args.out_dir)

    print("\nAll stages complete. Artifacts:")
    # Every known artifact is listed; files that were not produced in this
    # mode are reported with a size of 0.0 MB.
    for name in (
        "games_outcome.pt",
        "games_stockfish.pt",
        "tokenizer.pt",
        "outcome_tokens.bin",
        "outcome_labels.bin",
        "outcome_lengths.bin",
        "outcome_meta.pt",
        "stockfish_tokens.bin",
        "stockfish_labels.bin",
        "stockfish_lengths.bin",
        "stockfish_meta.pt",
        "policy_tokens.bin",
        "policy_lengths.bin",
        "policy_meta.pt",
        "puzzle_tokens.bin",
        "puzzle_lengths.bin",
        "puzzle_fens.bin",
        "puzzle_meta.pt",
        "puzzle_test_tokens.bin",
        "puzzle_test_lengths.bin",
        "puzzle_test_fens.bin",
        "puzzle_test_meta.pt",
        "stockfish_test_tokens.bin",
        "stockfish_test_labels.bin",
        "stockfish_test_lengths.bin",
        "stockfish_test_meta.pt",
        "policy_test_tokens.bin",
        "policy_test_lengths.bin",
        "policy_test_meta.pt",
    ):
        path = args.out_dir / name
        size_mb = path.stat().st_size / 1024 / 1024 if path.exists() else 0
        print(f" {path} ({size_mb:.1f} MB)")


if __name__ == "__main__":
    main()