"""Build the two training datasets for hybrid two-phase training. Runs three resumable stages: 1. Stream the Lichess HF dataset, filter by Elo + Termination, save two disjoint raw-game subsets (movetext + Result only). 2. Build the shared tokenizer from the outcome subset and generate outcome-labeled samples ({+1, 0, -1} from game Result). 3. Run parallel Stockfish on the disjoint subset to produce precisely labeled samples (tanh(cp/400)). Each stage skips if its output files already exist. Use --force to re-run. Outputs (under --out-dir): games_outcome.pt raw outcome-subset games games_stockfish.pt raw stockfish-subset games tokenizer.pt shared Tokenizer (built from outcome games) outcome_samples.pt list[(token_ids, outcome_label)] stockfish_samples.pt list[(token_ids, stockfish_label)] """ import argparse import random from pathlib import Path import chess import numpy as np import torch from datasets import load_dataset from tqdm import tqdm from model import CLS_TOKEN from train import ( build_tokenizer_from_games, generate_samples_stockfish_parallel, parse_movetext, ) # Lichess Result → outcome label from white's perspective. RESULT_TO_LABEL = {"1-0": 1.0, "0-1": -1.0, "1/2-1/2": 0.0} def _save_as_memmap( samples: list[tuple[list[int], float]], out_dir: Path, name: str, max_seq_len: int = 128 ) -> None: """Save samples as memory-mapped arrays for fast DataLoader access. Sequences longer than max_seq_len are truncated (keeps the most recent tokens, since the CLS token is at position 0 we keep ids[:max_seq_len]). Produces three files: {name}_tokens.bin — (N, max_seq_len) int32, zero-padded {name}_labels.bin — (N,) float32 {name}_lengths.bin — (N,) int32, actual sequence length per sample (capped at max_seq_len) {name}_meta.pt — dict with 'n' and 'max_len' """ n = len(samples) max_len = min(max(len(ids) for ids, _ in samples), max_seq_len) print(f" memmap {name}: {n:,} samples, max_seq_len={max_len}") tokens = np.memmap(out_dir / f"{name}_tokens.bin", dtype=np.int32, mode="w+", shape=(n, max_len)) labels = np.memmap(out_dir / f"{name}_labels.bin", dtype=np.float32, mode="w+", shape=(n,)) lengths = np.memmap(out_dir / f"{name}_lengths.bin", dtype=np.int32, mode="w+", shape=(n,)) for i, (ids, label) in enumerate(tqdm(samples, desc=f" writing {name}", unit="sample")): ids = ids[:max_len] l = len(ids) tokens[i, :l] = ids labels[i] = label lengths[i] = l tokens.flush() labels.flush() lengths.flush() torch.save({"n": n, "max_len": max_len}, out_dir / f"{name}_meta.pt") size_gb = (tokens.nbytes + labels.nbytes + lengths.nbytes) / 1024 ** 3 print(f" memmap {name} saved ({size_gb:.2f} GB)") def stage1_collect_games(args: argparse.Namespace) -> None: policy_games_path = args.out_dir / "games_outcome.pt" reward_games_path = args.out_dir / "games_stockfish.pt" policy_only = getattr(args, "policy_only", False) if policy_only: # Reward subset is irrelevant when we're only training the policy model. # Skip-condition checks only the policy artifact. if policy_games_path.exists() and not args.force: print(f"Stage 1: skipping — {policy_games_path.name} exists (--policy-only).") return else: if policy_games_path.exists() and reward_games_path.exists() and not args.force: print(f"Stage 1: skipping — {policy_games_path.name} and {reward_games_path.name} exist.") return if policy_only: lower_elo = args.policy_min_elo print( f"Stage 1: streaming Lichess/standard-chess-games (Termination == 'Normal'), " f"policy Elo >= {args.policy_min_elo} (target {args.policy_games:,}). " f"Reward subset skipped (--policy-only)." ) else: lower_elo = min(args.reward_min_elo, args.policy_min_elo) print( f"Stage 1: streaming Lichess/standard-chess-games (Termination == 'Normal'), " f"reward Elo >= {args.reward_min_elo} (target {args.reward_games:,}), " f"policy Elo >= {args.policy_min_elo} (target {args.policy_games:,})..." ) ds = load_dataset("Lichess/standard-chess-games", split="train", streaming=True) # Pre-filter by the lower of the two thresholds to skip clearly ineligible games. ds = ds.filter( lambda r: ( r.get("WhiteElo") is not None and r.get("BlackElo") is not None and r["WhiteElo"] >= lower_elo and r["BlackElo"] >= lower_elo and r.get("Termination") == "Normal" ) ) policy_games: list[dict] = [] reward_games: list[dict] = [] keep_keys = ("movetext", "Result") for row in tqdm(ds, desc="Stage 1: streaming", unit="game"): white_elo = row.get("WhiteElo", 0) black_elo = row.get("BlackElo", 0) minimal = {k: row.get(k) for k in keep_keys} if ( not policy_only and len(reward_games) < args.reward_games and white_elo >= args.reward_min_elo and black_elo >= args.reward_min_elo ): reward_games.append(minimal) if len(policy_games) < args.policy_games and white_elo >= args.policy_min_elo and black_elo >= args.policy_min_elo: policy_games.append(minimal) if policy_only: if len(policy_games) >= args.policy_games: break elif len(reward_games) >= args.reward_games and len(policy_games) >= args.policy_games: break if policy_only: if len(policy_games) < args.policy_games: print( f" WARNING: dataset exhausted before target — " f"got {len(policy_games):,} policy games." ) elif len(reward_games) < args.reward_games or len(policy_games) < args.policy_games: print( f" WARNING: dataset exhausted before target — " f"got {len(reward_games):,} reward + {len(policy_games):,} policy games." ) if not policy_only: print(f"Stage 1: saving {reward_games_path} ({len(reward_games):,} games)...") torch.save(reward_games, reward_games_path) print(f"Stage 1: saving {policy_games_path} ({len(policy_games):,} games)...") torch.save(policy_games, policy_games_path) def _generate_outcome_samples(games, tokenizer, max_positions_per_game, skip_ply): """Build (token_ids, outcome_label) samples for the phase-1 dataset.""" cls_id = tokenizer.symbol_to_token[CLS_TOKEN] samples: list[tuple[list[int], float]] = [] with tqdm(games, desc="Stage 2: outcome samples", unit="game") as pbar: for idx, game in enumerate(pbar): result = game.get("Result") if result not in RESULT_TO_LABEL: continue label = RESULT_TO_LABEL[result] movetext = game.get("movetext", "") if not movetext: continue move_sans = parse_movetext(movetext) if len(move_sans) < max(2, skip_ply + 1): continue eligible = list(range(skip_ply, len(move_sans))) num_positions = min(max_positions_per_game, len(eligible)) rng = random.Random(idx) sample_indices = set(rng.sample(eligible, num_positions)) board = chess.Board() valid_moves: list[str] = [] for i, san in enumerate(move_sans): try: move = board.parse_san(san) board.push(move) valid_moves.append(move.uci()) except (chess.InvalidMoveError, chess.AmbiguousMoveError): break if i in sample_indices: token_ids = [cls_id] + tokenizer.encode_moves(valid_moves) samples.append((token_ids, label)) if (idx + 1) % 50_000 == 0: pbar.set_postfix(samples=f"{len(samples):,}") return samples def stage2_outcome_samples(args: argparse.Namespace) -> None: tokenizer_path = args.out_dir / "tokenizer.pt" meta_path = args.out_dir / "outcome_meta.pt" if tokenizer_path.exists() and meta_path.exists() and not args.force: print(f"Stage 2: skipping — {tokenizer_path.name} and {meta_path.name} exist.") return raw_games_path = args.out_dir / "games_outcome.pt" print(f"Stage 2: loading outcome games from {raw_games_path}...") games = torch.load(raw_games_path, weights_only=False) print("Stage 2: building tokenizer from all UCI moves...") tokenizer = build_tokenizer_from_games() print(f"Stage 2: tokenizer vocab size = {tokenizer.language_size}") torch.save(tokenizer, tokenizer_path) print("Stage 2: generating outcome samples (up to 20 per game)...") samples = _generate_outcome_samples( games, tokenizer, max_positions_per_game=20, skip_ply=0, ) print(f"Stage 2: saving {len(samples):,} outcome samples as memmap...") _save_as_memmap(samples, args.out_dir, "outcome", max_seq_len=args.max_seq_len) def _generate_policy_sequences(games, tokenizer, max_seq_len: int = 128) -> list[list[int]]: """Tokenize full game sequences for policy training. Each output sequence is [CLS, m1, m2, ..., mN], truncated to max_seq_len. Games with fewer than 2 valid UCI moves are skipped. """ cls_id = tokenizer.symbol_to_token[CLS_TOKEN] sequences: list[list[int]] = [] with tqdm(games, desc="Stage 4: policy sequences", unit="game") as pbar: for game in pbar: movetext = game.get("movetext", "") if not movetext: continue move_sans = parse_movetext(movetext) if len(move_sans) < 2: continue board = chess.Board() move_ucis: list[str] = [] for san in move_sans: try: move = board.parse_san(san) board.push(move) move_ucis.append(move.uci()) except (chess.InvalidMoveError, chess.AmbiguousMoveError): break if len(move_ucis) < 2: continue move_ucis = move_ucis[:max_seq_len - 1] sequences.append([cls_id] + tokenizer.encode_moves(move_ucis)) return sequences def _save_policy_memmap( sequences: list[list[int]], out_dir: Path, name: str, max_seq_len: int = 128, fens: list[str] | None = None, fen_len: int = 100, ) -> None: """Save policy sequences as memory-mapped arrays (no labels). Produces: {name}_tokens.bin — (N, max_len) int32, zero-padded {name}_lengths.bin — (N,) int32, actual sequence length per sample {name}_meta.pt — dict with 'n', 'max_len', and (if fens given) 'fen_len' If `fens` is provided, also writes {name}_fens.bin — (N, fen_len) uint8 holding zero-padded ASCII FEN strings, one per sample. Used by the CNN-conditioned policy training to reconstruct each sample's starting board. """ n = len(sequences) max_len = min(max(len(s) for s in sequences), max_seq_len) print(f" memmap {name}: {n:,} sequences, max_seq_len={max_len}") tokens = np.memmap(out_dir / f"{name}_tokens.bin", dtype=np.int32, mode="w+", shape=(n, max_len)) lengths = np.memmap(out_dir / f"{name}_lengths.bin", dtype=np.int32, mode="w+", shape=(n,)) for i, seq in enumerate(tqdm(sequences, desc=f" writing {name}", unit="seq")): seq = seq[:max_len] l = len(seq) tokens[i, :l] = seq lengths[i] = l tokens.flush() lengths.flush() meta = {"n": n, "max_len": max_len} extra_bytes = 0 if fens is not None: assert len(fens) == n, f"fen count {len(fens)} mismatch with sequence count {n}" fens_mm = np.memmap(out_dir / f"{name}_fens.bin", dtype=np.uint8, mode="w+", shape=(n, fen_len)) for i, fen in enumerate(fens): b = fen.encode("ascii")[:fen_len] fens_mm[i, :len(b)] = list(b) fens_mm.flush() meta["fen_len"] = fen_len extra_bytes = fens_mm.nbytes torch.save(meta, out_dir / f"{name}_meta.pt") size_gb = (tokens.nbytes + lengths.nbytes + extra_bytes) / 1024 ** 3 print(f" memmap {name} saved ({size_gb:.3f} GB)") def _process_puzzle( row: dict, tokenizer_symbol_map: dict, cls_id: int, ) -> tuple[list[int], str] | None: """Parse one Lichess puzzle row into a (token_sequence, FEN) pair. Sequence layout: [CLS, setup_move, solver_move1, opp_response, solver_move2, ...] The setup move (Moves[0]) is included as context so the model conditions on it when predicting the solution. During training the loss on the setup move position is masked out — we model P[m_n | S, m_{ None: meta_path = args.out_dir / "stockfish_meta.pt" if meta_path.exists() and not args.force: print(f"Stage 3: skipping — {meta_path.name} exists.") return games_path = args.out_dir / "games_stockfish.pt" tokenizer_path = args.out_dir / "tokenizer.pt" print(f"Stage 3: loading {games_path} and {tokenizer_path}...") games = torch.load(games_path, weights_only=False) tokenizer = torch.load(tokenizer_path, weights_only=False) print( f"Stage 3: running parallel Stockfish ({args.workers} workers, " f"depth {args.stockfish_depth}) on {len(games):,} games..." ) samples = generate_samples_stockfish_parallel( games, tokenizer, num_workers=args.workers, stockfish_depth=args.stockfish_depth, sample_rate=args.sample_rate, skew_exponent=args.position_skew, ) print(f"Stage 3: saving {len(samples):,} stockfish samples as memmap...") _save_as_memmap(samples, args.out_dir, "stockfish", max_seq_len=args.max_seq_len) def stage4_policy_sequences(args: argparse.Namespace) -> None: meta_path = args.out_dir / "policy_meta.pt" if meta_path.exists() and not args.force: print(f"Stage 4: skipping — {meta_path.name} exists.") return games_path = args.out_dir / "games_outcome.pt" tokenizer_path = args.out_dir / "tokenizer.pt" print(f"Stage 4: loading {games_path} and {tokenizer_path}...") games = torch.load(games_path, weights_only=False) tokenizer = torch.load(tokenizer_path, weights_only=False) print(f"Stage 4: tokenizing {len(games):,} games into policy sequences...") sequences = _generate_policy_sequences(games, tokenizer, max_seq_len=args.max_seq_len) print(f"Stage 4: saving {len(sequences):,} policy sequences as memmap...") _save_policy_memmap(sequences, args.out_dir, "policy", max_seq_len=args.max_seq_len) def _write_test_subset_reward(out_dir: Path, src_name: str, dst_name: str, indices: np.ndarray) -> None: """Write a subset of a reward memmap (tokens+labels+lengths) to new files.""" meta = torch.load(out_dir / f"{src_name}_meta.pt", weights_only=True) n_src, max_len = meta["n"], meta["max_len"] src_tokens = np.memmap(out_dir / f"{src_name}_tokens.bin", dtype=np.int32, mode="r", shape=(n_src, max_len)) src_labels = np.memmap(out_dir / f"{src_name}_labels.bin", dtype=np.float32, mode="r", shape=(n_src,)) src_lengths = np.memmap(out_dir / f"{src_name}_lengths.bin", dtype=np.int32, mode="r", shape=(n_src,)) n_test = len(indices) dst_tokens = np.memmap(out_dir / f"{dst_name}_tokens.bin", dtype=np.int32, mode="w+", shape=(n_test, max_len)) dst_labels = np.memmap(out_dir / f"{dst_name}_labels.bin", dtype=np.float32, mode="w+", shape=(n_test,)) dst_lengths = np.memmap(out_dir / f"{dst_name}_lengths.bin", dtype=np.int32, mode="w+", shape=(n_test,)) for i, idx in enumerate(tqdm(indices, desc=f" writing {dst_name}", unit="sample")): dst_tokens[i] = src_tokens[idx] dst_labels[i] = src_labels[idx] dst_lengths[i] = src_lengths[idx] dst_tokens.flush() dst_labels.flush() dst_lengths.flush() torch.save({"n": n_test, "max_len": max_len}, out_dir / f"{dst_name}_meta.pt") print(f" {dst_name}: {n_test:,} samples written") def _write_test_subset_policy(out_dir: Path, src_name: str, dst_name: str, indices: np.ndarray) -> None: """Write a subset of a policy memmap (tokens+lengths, no labels) to new files.""" meta = torch.load(out_dir / f"{src_name}_meta.pt", weights_only=True) n_src, max_len = meta["n"], meta["max_len"] src_tokens = np.memmap(out_dir / f"{src_name}_tokens.bin", dtype=np.int32, mode="r", shape=(n_src, max_len)) src_lengths = np.memmap(out_dir / f"{src_name}_lengths.bin", dtype=np.int32, mode="r", shape=(n_src,)) n_test = len(indices) dst_tokens = np.memmap(out_dir / f"{dst_name}_tokens.bin", dtype=np.int32, mode="w+", shape=(n_test, max_len)) dst_lengths = np.memmap(out_dir / f"{dst_name}_lengths.bin", dtype=np.int32, mode="w+", shape=(n_test,)) for i, idx in enumerate(tqdm(indices, desc=f" writing {dst_name}", unit="seq")): dst_tokens[i] = src_tokens[idx] dst_lengths[i] = src_lengths[idx] dst_tokens.flush() dst_lengths.flush() torch.save({"n": n_test, "max_len": max_len}, out_dir / f"{dst_name}_meta.pt") print(f" {dst_name}: {n_test:,} sequences written") def stage_build_test_splits(args: argparse.Namespace, out_dir: Path) -> None: """Build held-out test sets for reward and policy models from existing memmaps. Uses a fixed random seed (42) so the same indices are always selected. Saves the chosen indices to {name}_test_indices.npy so the corresponding training memmap loader can exclude them — making train and test disjoint even though they share the underlying .bin file. Produces: stockfish_test_*.bin / stockfish_test_meta.pt / stockfish_test_indices.npy policy_test_*.bin / policy_test_meta.pt / policy_test_indices.npy """ rng = np.random.default_rng(42) policy_only = getattr(args, "policy_only", False) # Reward test set — skipped when --policy-only since no Stockfish data exists. reward_test_meta = out_dir / "stockfish_test_meta.pt" sf_meta_path = out_dir / "stockfish_meta.pt" if policy_only: print("Test splits: stockfish_test skipped (--policy-only).") elif (not reward_test_meta.exists() or args.force) and sf_meta_path.exists(): print(f"Test splits: building stockfish_test ({args.reward_test_size:,} samples)...") sf_meta = torch.load(sf_meta_path, weights_only=True) n = sf_meta["n"] test_n = min(args.reward_test_size, n) idx = rng.choice(n, size=test_n, replace=False) idx.sort() _write_test_subset_reward(out_dir, "stockfish", "stockfish_test", idx) np.save(out_dir / "stockfish_test_indices.npy", idx) print(f" saved stockfish_test_indices.npy ({test_n:,} indices excluded from training)") elif reward_test_meta.exists(): print("Test splits: stockfish_test already exists, skipping.") # Policy test set policy_test_meta = out_dir / "policy_test_meta.pt" pol_meta_path = out_dir / "policy_meta.pt" if (not policy_test_meta.exists() or args.force) and pol_meta_path.exists(): print(f"Test splits: building policy_test ({args.policy_test_size:,} sequences)...") pol_meta = torch.load(pol_meta_path, weights_only=True) n = pol_meta["n"] test_n = min(args.policy_test_size, n) idx = rng.choice(n, size=test_n, replace=False) idx.sort() _write_test_subset_policy(out_dir, "policy", "policy_test", idx) np.save(out_dir / "policy_test_indices.npy", idx) print(f" saved policy_test_indices.npy ({test_n:,} indices excluded from training)") elif policy_test_meta.exists(): print("Test splits: policy_test already exists, skipping.") def stage5_puzzle_samples(args: argparse.Namespace, tokenizer, out_dir: Path) -> None: train_done = (out_dir / "puzzle_meta.pt").exists() test_done = (out_dir / "puzzle_test_meta.pt").exists() if train_done and test_done and not args.force: print("Stage 5: skipping — puzzle_meta.pt and puzzle_test_meta.pt exist.") return print("Stage 5: loading Lichess/chess-puzzles from HuggingFace...") ds = load_dataset("Lichess/chess-puzzles", split="train", streaming=True) min_pop = args.min_puzzle_popularity min_plays = args.min_puzzle_plays cls_id = tokenizer.symbol_to_token[CLS_TOKEN] sym_map = tokenizer.symbol_to_token test_seqs: list[list[int]] = [] test_fens: list[str] = [] train_seqs: list[list[int]] = [] train_fens: list[str] = [] skipped = 0 test_target = args.puzzle_test_size train_target = args.puzzle_count with tqdm(ds, desc="Stage 5: puzzles", unit="puzzle") as pbar: for row in pbar: if min_pop is not None and row.get("Popularity", 0) < min_pop: continue if min_plays is not None and row.get("NbPlays", 0) < min_plays: continue result = _process_puzzle(row, sym_map, cls_id) if result is None: skipped += 1 continue seq, fen = result if len(test_seqs) < test_target: test_seqs.append(seq) test_fens.append(fen) else: train_seqs.append(seq) train_fens.append(fen) pbar.set_postfix(test=len(test_seqs), train=len(train_seqs), skipped=skipped) if train_target is not None and len(train_seqs) >= train_target: break print( f"Stage 5: test={len(test_seqs):,} puzzles, " f"train={len(train_seqs):,} puzzles, " f"skipped={skipped:,} invalid." ) if test_seqs and not test_done: _save_policy_memmap( test_seqs, out_dir, "puzzle_test", max_seq_len=args.max_seq_len, fens=test_fens, ) if train_seqs and not train_done: _save_policy_memmap( train_seqs, out_dir, "puzzle", max_seq_len=args.max_seq_len, fens=train_fens, ) elif not train_seqs: print("Stage 5: WARNING — no training puzzles collected.") def main(): parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--out-dir", type=Path, default=Path("data")) parser.add_argument("--policy-games", type=int, default=1_000_000, help="Number of games to collect for policy model training") parser.add_argument("--reward-games", type=int, default=1_000_000, help="Number of games to collect for reward model (Stockfish eval)") parser.add_argument("--policy-min-elo", type=int, default=1800, help="Min Elo for both players in policy training games") parser.add_argument("--reward-min-elo", type=int, default=1500, help="Min Elo for both players in reward model training games") parser.add_argument("--sample-rate", type=float, default=0.25, help="Fraction of positions to sample per game (scales with game length)") parser.add_argument("--position-skew", type=float, default=1.5, help="Power-law exponent weighting later positions; 1.0=linear, higher=more mid/late") parser.add_argument("--workers", type=int, default=16) parser.add_argument("--stockfish-depth", type=int, default=12) parser.add_argument("--max-seq-len", type=int, default=128, help="Truncate token sequences to this length when writing .bin files") parser.add_argument( "--force", action="store_true", help="Re-run all stages even if their outputs already exist", ) parser.add_argument("--puzzle-count", type=int, default=None, dest="puzzle_count", help="Max puzzles to include (default: all ~4.99M)") parser.add_argument("--min-puzzle-popularity", type=int, default=None, dest="min_puzzle_popularity", help="Min Lichess Popularity score (0-100 scale)") parser.add_argument("--min-puzzle-plays", type=int, default=None, dest="min_puzzle_plays", help="Min NbPlays for a puzzle to be included") parser.add_argument("--skip-puzzles", action="store_true", help="Skip Stage 5 puzzle processing") parser.add_argument("--puzzles-only", action="store_true", help="Only run Stage 5 (puzzle processing). Skips game collection, " "outcome/Stockfish/policy memmaps, and test splits. Requires " "tokenizer.pt to exist (or it will be built from the UCI vocab).") parser.add_argument("--policy-only", action="store_true", help="Skip everything Stockfish/reward-related: Stage 1 collects only " "policy games, Stage 3 (Stockfish labeling) is skipped, and the " "stockfish_test split is not built. Stages 1/2/4/5 + policy_test " "still run, producing tokenizer.pt, policy_* / puzzle_* memmaps, " "and the policy_test split.") parser.add_argument("--puzzle-test-size", type=int, default=100_000, dest="puzzle_test_size", help="Number of puzzle sequences held out for the test set (default: 100000)") parser.add_argument("--reward-test-size", type=int, default=50_000, dest="reward_test_size", help="Number of reward positions held out for the test set (default: 50000)") parser.add_argument("--policy-test-size", type=int, default=50_000, dest="policy_test_size", help="Number of policy sequences held out for the test set (default: 50000)") args = parser.parse_args() args.out_dir.mkdir(parents=True, exist_ok=True) if args.puzzles_only and args.policy_only: parser.error("--puzzles-only and --policy-only are mutually exclusive.") if args.puzzles_only: print("--puzzles-only: skipping Stages 1-4 and test-split builder.") tokenizer_path = args.out_dir / "tokenizer.pt" if tokenizer_path.exists(): tokenizer = torch.load(tokenizer_path, weights_only=False) else: # Tokenizer is just the enumerated UCI vocab — no games needed. print(" tokenizer.pt missing; building from UCI vocab...") tokenizer = build_tokenizer_from_games() torch.save(tokenizer, tokenizer_path) stage5_puzzle_samples(args, tokenizer, args.out_dir) elif args.policy_only: print("--policy-only: skipping Stage 3 (Stockfish labeling) and stockfish_test split.") stage1_collect_games(args) stage2_outcome_samples(args) stage4_policy_sequences(args) tokenizer_path = args.out_dir / "tokenizer.pt" if not args.skip_puzzles and tokenizer_path.exists(): tokenizer = torch.load(tokenizer_path, weights_only=False) stage5_puzzle_samples(args, tokenizer, args.out_dir) elif not args.skip_puzzles: print("Stage 5: skipping — tokenizer.pt not found (run stages 1-2 first).") stage_build_test_splits(args, args.out_dir) else: stage1_collect_games(args) stage2_outcome_samples(args) stage3_stockfish_samples(args) stage4_policy_sequences(args) tokenizer_path = args.out_dir / "tokenizer.pt" if not args.skip_puzzles and tokenizer_path.exists(): tokenizer = torch.load(tokenizer_path, weights_only=False) stage5_puzzle_samples(args, tokenizer, args.out_dir) elif not args.skip_puzzles: print("Stage 5: skipping — tokenizer.pt not found (run stages 1-2 first).") stage_build_test_splits(args, args.out_dir) print("\nAll stages complete. Artifacts:") for name in ( "games_outcome.pt", "games_stockfish.pt", "tokenizer.pt", "outcome_tokens.bin", "outcome_labels.bin", "outcome_lengths.bin", "outcome_meta.pt", "stockfish_tokens.bin", "stockfish_labels.bin", "stockfish_lengths.bin", "stockfish_meta.pt", "policy_tokens.bin", "policy_lengths.bin", "policy_meta.pt", "puzzle_tokens.bin", "puzzle_lengths.bin", "puzzle_fens.bin", "puzzle_meta.pt", "puzzle_test_tokens.bin", "puzzle_test_lengths.bin", "puzzle_test_fens.bin", "puzzle_test_meta.pt", "stockfish_test_tokens.bin", "stockfish_test_labels.bin", "stockfish_test_lengths.bin", "stockfish_test_meta.pt", "policy_test_tokens.bin", "policy_test_lengths.bin", "policy_test_meta.pt", ): path = args.out_dir / name size_mb = path.stat().st_size / 1024 / 1024 if path.exists() else 0 print(f" {path} ({size_mb:.1f} MB)") if __name__ == "__main__": main()