File size: 30,487 Bytes
6d75857
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
"""Build the two training datasets for hybrid two-phase training.

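Runs the following resumable stages:
  1. Stream the Lichess HF dataset, filter by Elo + Termination, and save two
     raw-game subsets (movetext + Result only): the outcome/policy subset and
     the stockfish/reward subset. A game that meets both Elo thresholds may
     be stored in both subsets.
  2. Build the shared tokenizer and generate outcome-labeled samples
     ({+1, 0, -1} from game Result) from the outcome subset.
  3. Run parallel Stockfish on the stockfish subset to produce precisely
     labeled samples (tanh(cp/400)). Skipped with --policy-only.
  4. Tokenize full games from the outcome subset into policy sequences.
  5. Convert Lichess puzzles into policy sequences plus their starting FENs.
     Skipped with --skip-puzzles.
Afterwards, held-out test splits are carved out of the stockfish and policy
memmaps.

Each stage skips if its output files already exist. Use --force to re-run.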

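Outputs (under --out-dir):
  games_outcome.pt           raw outcome-subset games
  games_stockfish.pt         raw stockfish-subset games
  tokenizer.pt               shared Tokenizer
  outcome_*.bin / _meta.pt   memmapped outcome samples (tokens, labels, lengths)
  stockfish_*.bin / _meta.pt memmapped Stockfish samples (tokens, labels, lengths)
  policy_*.bin / _meta.pt    memmapped policy sequences (tokens, lengths)
  puzzle_*.bin / _meta.pt    memmapped puzzle sequences (tokens, lengths, fens)
  stockfish_test_*, policy_test_*, puzzle_test_*   held-out test splits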
"""

import argparse
import random
from pathlib import Path

import chess
import numpy as np
import torch
from datasets import load_dataset
from tqdm import tqdm

from model import CLS_TOKEN
from train import (
    build_tokenizer_from_games,
    generate_samples_stockfish_parallel,
    parse_movetext,
)

# Lichess Result → outcome label from white's perspective
# (white win = +1.0, black win = -1.0, draw = 0.0). Games whose Result is not
# a key of this mapping are skipped by _generate_outcome_samples.
RESULT_TO_LABEL = {"1-0": 1.0, "0-1": -1.0, "1/2-1/2": 0.0}


def _save_as_memmap(
    samples: list[tuple[list[int], float]], out_dir: Path, name: str, max_seq_len: int = 128
) -> None:
    """Save labeled samples as memory-mapped arrays for fast DataLoader access.

    Sequences longer than ``max_seq_len`` are truncated to their first
    ``max_seq_len`` tokens (``ids[:max_seq_len]``): the leading token and the
    earliest tokens are kept, the tail is dropped.

    Produces four files under ``out_dir``:
      {name}_tokens.bin   — (N, max_len) int32, zero-padded
      {name}_labels.bin   — (N,) float32
      {name}_lengths.bin  — (N,) int32, stored sequence length per sample (capped at max_len)
      {name}_meta.pt      — dict with 'n' and 'max_len'

    ``max_len`` is the longest sample length, capped at ``max_seq_len``.

    Args:
        samples: ``(token_ids, label)`` pairs to persist.
        out_dir: existing directory that receives the files.
        name: file-name prefix for the four artifacts.
        max_seq_len: upper bound on the stored sequence length.

    Raises:
        ValueError: if ``samples`` is empty (there is nothing to size the
            arrays from, and an empty file cannot be memory-mapped).
    """
    if not samples:
        raise ValueError(f"cannot write memmap {name!r}: no samples to save")

    n = len(samples)
    max_len = min(max(len(ids) for ids, _ in samples), max_seq_len)
    print(f"  memmap {name}: {n:,} samples, max_seq_len={max_len}")

    # mode="w+" creates (or overwrites) the backing file; rows shorter than
    # max_len keep their trailing zeros as padding.
    tokens = np.memmap(out_dir / f"{name}_tokens.bin", dtype=np.int32, mode="w+", shape=(n, max_len))
    labels = np.memmap(out_dir / f"{name}_labels.bin", dtype=np.float32, mode="w+", shape=(n,))
    lengths = np.memmap(out_dir / f"{name}_lengths.bin", dtype=np.int32, mode="w+", shape=(n,))

    for i, (ids, label) in enumerate(tqdm(samples, desc=f"  writing {name}", unit="sample")):
        ids = ids[:max_len]
        length = len(ids)
        tokens[i, :length] = ids
        labels[i] = label
        lengths[i] = length

    tokens.flush()
    labels.flush()
    lengths.flush()
    # The meta file is written last, after all arrays have been flushed.
    torch.save({"n": n, "max_len": max_len}, out_dir / f"{name}_meta.pt")
    size_gb = (tokens.nbytes + labels.nbytes + lengths.nbytes) / 1024 ** 3
    print(f"  memmap {name} saved ({size_gb:.2f} GB)")


def stage1_collect_games(args: argparse.Namespace) -> None:
    """Stage 1: stream Lichess games and persist the raw policy/reward subsets.

    Reads ``args.out_dir``, ``args.force``, ``args.policy_games`` and
    ``args.policy_min_elo``; unless ``args.policy_only`` is set it also reads
    ``args.reward_games`` and ``args.reward_min_elo``.

    Rows are kept only when both Elo values are present, ``Termination`` is
    ``"Normal"`` and both players reach the relevant Elo threshold. Only the
    ``movetext`` and ``Result`` fields of a row are stored. A row that meets
    both thresholds is appended to both subsets.

    Writes ``games_outcome.pt`` (policy subset) and, unless ``--policy-only``,
    ``games_stockfish.pt`` (reward subset). The stage is skipped when the
    required outputs already exist and ``args.force`` is false.
    """
    outcome_path = args.out_dir / "games_outcome.pt"
    stockfish_path = args.out_dir / "games_stockfish.pt"
    only_policy = getattr(args, "policy_only", False)

    # Resume support: return early when this stage's artifacts are present.
    if only_policy:
        # The reward subset is irrelevant here, so only the policy artifact counts.
        if outcome_path.exists() and not args.force:
            print(f"Stage 1: skipping — {outcome_path.name} exists (--policy-only).")
            return
    elif outcome_path.exists() and stockfish_path.exists() and not args.force:
        print(f"Stage 1: skipping — {outcome_path.name} and {stockfish_path.name} exist.")
        return

    headline = "Stage 1: streaming Lichess/standard-chess-games (Termination == 'Normal'), "
    if only_policy:
        elo_floor = args.policy_min_elo
        print(
            headline
            + f"policy Elo >= {args.policy_min_elo} (target {args.policy_games:,}). "
            + "Reward subset skipped (--policy-only)."
        )
    else:
        elo_floor = min(args.reward_min_elo, args.policy_min_elo)
        print(
            headline
            + f"reward Elo >= {args.reward_min_elo} (target {args.reward_games:,}), "
            + f"policy Elo >= {args.policy_min_elo} (target {args.policy_games:,})..."
        )

    def _passes_prefilter(rec) -> bool:
        # Cheap pre-filter on the lower of the two thresholds; the per-subset
        # thresholds are applied again inside the loop below.
        if rec.get("WhiteElo") is None or rec.get("BlackElo") is None:
            return False
        if rec["WhiteElo"] < elo_floor or rec["BlackElo"] < elo_floor:
            return False
        return rec.get("Termination") == "Normal"

    stream = load_dataset("Lichess/standard-chess-games", split="train", streaming=True)
    stream = stream.filter(_passes_prefilter)

    policy_bucket: list[dict] = []
    reward_bucket: list[dict] = []

    for record in tqdm(stream, desc="Stage 1: streaming", unit="game"):
        w_elo = record.get("WhiteElo", 0)
        b_elo = record.get("BlackElo", 0)
        slim = {field: record.get(field) for field in ("movetext", "Result")}

        if not only_policy and len(reward_bucket) < args.reward_games:
            if w_elo >= args.reward_min_elo and b_elo >= args.reward_min_elo:
                reward_bucket.append(slim)

        if len(policy_bucket) < args.policy_games:
            if w_elo >= args.policy_min_elo and b_elo >= args.policy_min_elo:
                policy_bucket.append(slim)

        # Stop streaming as soon as every requested subset has reached its target.
        policy_full = len(policy_bucket) >= args.policy_games
        reward_full = only_policy or len(reward_bucket) >= args.reward_games
        if reward_full and policy_full:
            break

    policy_short = len(policy_bucket) < args.policy_games
    if only_policy:
        if policy_short:
            print(
                "  WARNING: dataset exhausted before target — "
                + f"got {len(policy_bucket):,} policy games."
            )
    elif len(reward_bucket) < args.reward_games or policy_short:
        print(
            "  WARNING: dataset exhausted before target — "
            + f"got {len(reward_bucket):,} reward + {len(policy_bucket):,} policy games."
        )

    if not only_policy:
        print(f"Stage 1: saving {stockfish_path} ({len(reward_bucket):,} games)...")
        torch.save(reward_bucket, stockfish_path)
    print(f"Stage 1: saving {outcome_path} ({len(policy_bucket):,} games)...")
    torch.save(policy_bucket, outcome_path)


def _generate_outcome_samples(games, tokenizer, max_positions_per_game, skip_ply):
    """Build (token_ids, outcome_label) samples for the phase-1 dataset.

    For every game whose ``Result`` is a key of ``RESULT_TO_LABEL`` and that
    has non-empty movetext, up to ``max_positions_per_game`` ply indices at or
    after ``skip_ply`` are drawn without replacement. The draw uses an RNG
    seeded with the game's index in ``games``, so it is repeatable for a given
    input order.

    Each sampled ply yields ``[CLS] + tokenizer.encode_moves(prefix)``, where
    ``prefix`` holds the UCI moves up to and including that ply, paired with
    the game's outcome label.

    Replay of a game stops at the first SAN move python-chess rejects; samples
    emitted before that point are kept.

    Returns:
        list of ``(token_ids, label)`` tuples.
    """
    cls_id = tokenizer.symbol_to_token[CLS_TOKEN]
    samples: list[tuple[list[int], float]] = []
    with tqdm(games, desc="Stage 2: outcome samples", unit="game") as pbar:
        for idx, game in enumerate(pbar):
            result = game.get("Result")
            if result not in RESULT_TO_LABEL:
                continue
            label = RESULT_TO_LABEL[result]

            movetext = game.get("movetext", "")
            if not movetext:
                continue
            move_sans = parse_movetext(movetext)
            if len(move_sans) < max(2, skip_ply + 1):
                continue

            eligible = list(range(skip_ply, len(move_sans)))
            num_positions = min(max_positions_per_game, len(eligible))
            # Per-game seed keeps the sampled plies stable across re-runs.
            rng = random.Random(idx)
            sample_indices = set(rng.sample(eligible, num_positions))

            board = chess.Board()
            valid_moves: list[str] = []
            for i, san in enumerate(move_sans):
                try:
                    move = board.parse_san(san)
                except ValueError:
                    # python-chess signals invalid, illegal and ambiguous SAN
                    # with ValueError subclasses; catching the base class
                    # covers all three, so one bad move cannot abort the stage.
                    break
                board.push(move)
                valid_moves.append(move.uci())
                if i in sample_indices:
                    token_ids = [cls_id] + tokenizer.encode_moves(valid_moves)
                    samples.append((token_ids, label))

            if (idx + 1) % 50_000 == 0:
                pbar.set_postfix(samples=f"{len(samples):,}")

    return samples


def stage2_outcome_samples(args: argparse.Namespace) -> None:
    """Stage 2: build the tokenizer and the outcome-labeled sample memmap.

    Loads ``games_outcome.pt`` from ``args.out_dir``, builds and saves the
    tokenizer as ``tokenizer.pt``, generates up to 20 outcome samples per game
    and writes them as the ``outcome_*`` memmap (truncated to
    ``args.max_seq_len``).

    Skipped when both ``tokenizer.pt`` and ``outcome_meta.pt`` exist and
    ``args.force`` is false.
    """
    out_dir = args.out_dir
    tokenizer_file = out_dir / "tokenizer.pt"
    meta_file = out_dir / "outcome_meta.pt"

    outputs_present = tokenizer_file.exists() and meta_file.exists()
    if outputs_present and not args.force:
        print(f"Stage 2: skipping — {tokenizer_file.name} and {meta_file.name} exist.")
        return

    source_file = out_dir / "games_outcome.pt"
    print(f"Stage 2: loading outcome games from {source_file}...")
    outcome_games = torch.load(source_file, weights_only=False)

    print("Stage 2: building tokenizer from all UCI moves...")
    vocab = build_tokenizer_from_games()
    print(f"Stage 2: tokenizer vocab size = {vocab.language_size}")
    torch.save(vocab, tokenizer_file)

    print("Stage 2: generating outcome samples (up to 20 per game)...")
    labeled = _generate_outcome_samples(
        outcome_games, vocab, max_positions_per_game=20, skip_ply=0
    )

    print(f"Stage 2: saving {len(labeled):,} outcome samples as memmap...")
    _save_as_memmap(labeled, out_dir, "outcome", max_seq_len=args.max_seq_len)


def _generate_policy_sequences(games, tokenizer, max_seq_len: int = 128) -> list[list[int]]:
    """Tokenize full game sequences for policy training.

    Each output sequence is [CLS, m1, m2, ..., mN], truncated to max_seq_len
    tokens (CLS included). Games with fewer than 2 valid UCI moves are skipped.

    A game is replayed from the standard starting position; replay stops at
    the first SAN move python-chess rejects and the moves accepted up to that
    point are used.
    """
    cls_id = tokenizer.symbol_to_token[CLS_TOKEN]
    sequences: list[list[int]] = []
    with tqdm(games, desc="Stage 4: policy sequences", unit="game") as pbar:
        for game in pbar:
            movetext = game.get("movetext", "")
            if not movetext:
                continue
            move_sans = parse_movetext(movetext)
            if len(move_sans) < 2:
                continue
            board = chess.Board()
            move_ucis: list[str] = []
            for san in move_sans:
                try:
                    move = board.parse_san(san)
                except ValueError:
                    # python-chess signals invalid, illegal and ambiguous SAN
                    # with ValueError subclasses; catching the base class
                    # covers all three, so one bad move cannot abort the stage.
                    break
                board.push(move)
                move_ucis.append(move.uci())
            if len(move_ucis) < 2:
                continue
            # Reserve one slot for the CLS token.
            move_ucis = move_ucis[:max_seq_len - 1]
            sequences.append([cls_id] + tokenizer.encode_moves(move_ucis))
    return sequences


def _save_policy_memmap(
    sequences: list[list[int]], out_dir: Path, name: str, max_seq_len: int = 128,
    fens: list[str] | None = None, fen_len: int = 100,
) -> None:
    """Save policy sequences as memory-mapped arrays (no labels).

    Produces:
      {name}_tokens.bin   — (N, max_len) int32, zero-padded
      {name}_lengths.bin  — (N,) int32, actual sequence length per sample
      {name}_meta.pt      — dict with 'n', 'max_len', and (if fens given) 'fen_len'

    If `fens` is provided, also writes {name}_fens.bin — (N, fen_len) uint8
    holding zero-padded ASCII FEN strings, one per sample. Used by the
    CNN-conditioned policy training to reconstruct each sample's starting board.
    """
    n = len(sequences)
    max_len = min(max(len(s) for s in sequences), max_seq_len)
    print(f"  memmap {name}: {n:,} sequences, max_seq_len={max_len}")

    tokens = np.memmap(out_dir / f"{name}_tokens.bin", dtype=np.int32, mode="w+", shape=(n, max_len))
    lengths = np.memmap(out_dir / f"{name}_lengths.bin", dtype=np.int32, mode="w+", shape=(n,))

    for i, seq in enumerate(tqdm(sequences, desc=f"  writing {name}", unit="seq")):
        seq = seq[:max_len]
        l = len(seq)
        tokens[i, :l] = seq
        lengths[i] = l

    tokens.flush()
    lengths.flush()

    meta = {"n": n, "max_len": max_len}
    extra_bytes = 0
    if fens is not None:
        assert len(fens) == n, f"fen count {len(fens)} mismatch with sequence count {n}"
        fens_mm = np.memmap(out_dir / f"{name}_fens.bin", dtype=np.uint8, mode="w+", shape=(n, fen_len))
        for i, fen in enumerate(fens):
            b = fen.encode("ascii")[:fen_len]
            fens_mm[i, :len(b)] = list(b)
        fens_mm.flush()
        meta["fen_len"] = fen_len
        extra_bytes = fens_mm.nbytes

    torch.save(meta, out_dir / f"{name}_meta.pt")
    size_gb = (tokens.nbytes + lengths.nbytes + extra_bytes) / 1024 ** 3
    print(f"  memmap {name} saved ({size_gb:.3f} GB)")


def _process_puzzle(
    row: dict,
    tokenizer_symbol_map: dict,
    cls_id: int,
) -> tuple[list[int], str] | None:
    """Parse one Lichess puzzle row into a (token_sequence, FEN) pair.

    Sequence layout: [CLS, setup_move, solver_move1, opp_response, solver_move2, ...]

    The setup move (Moves[0]) is included as context so the model conditions on it
    when predicting the solution. During training the loss on the setup move position
    is masked out — we model P[m_n | S, m_{<n}] where S is the setup move.

    The FEN is the puzzle's starting board position. It is persisted alongside the
    token sequence so the CNN-conditioned policy training can reconstruct the
    starting board planes (CNN's input) at __getitem__ time.

    Returns None if any move is illegal, unknown to the tokenizer, or the sequence
    has fewer than 3 tokens (CLS + setup + at least one solver move).
    """
    fen = row.get("FEN", "")
    moves_str = row.get("Moves", "")
    if not fen or not moves_str:
        return None
    uci_moves = moves_str.strip().split()
    if len(uci_moves) < 2:  # need setup + at least one solver move
        return None
    try:
        board = chess.Board(fen)
    except ValueError:
        return None

    # Tokenize all moves: setup first (as context), then the full solution.
    token_ids: list[int] = [cls_id]
    for uci in uci_moves:
        try:
            move = chess.Move.from_uci(uci)
        except ValueError:
            return None
        if move not in board.legal_moves:
            return None
        if uci not in tokenizer_symbol_map:
            return None
        token_ids.append(tokenizer_symbol_map[uci])
        board.push(move)

    if len(token_ids) < 3:  # CLS + setup + at least one solver move
        return None
    return token_ids, fen


def stage3_stockfish_samples(args: argparse.Namespace) -> None:
    """Stage 3: label the reward subset with Stockfish and write its memmap.

    Loads ``games_stockfish.pt`` and ``tokenizer.pt`` from ``args.out_dir``,
    runs ``generate_samples_stockfish_parallel`` with the worker count, depth,
    sample rate and position skew taken from ``args``, and stores the result
    as the ``stockfish_*`` memmap (truncated to ``args.max_seq_len``).

    Skipped when ``stockfish_meta.pt`` exists and ``args.force`` is false.
    """
    out_dir = args.out_dir
    done_marker = out_dir / "stockfish_meta.pt"
    if done_marker.exists() and not args.force:
        print(f"Stage 3: skipping — {done_marker.name} exists.")
        return

    games_file = out_dir / "games_stockfish.pt"
    tokenizer_file = out_dir / "tokenizer.pt"
    print(f"Stage 3: loading {games_file} and {tokenizer_file}...")
    reward_games = torch.load(games_file, weights_only=False)
    vocab = torch.load(tokenizer_file, weights_only=False)

    n_workers = args.workers
    depth = args.stockfish_depth
    print(
        f"Stage 3: running parallel Stockfish ({n_workers} workers, "
        + f"depth {depth}) on {len(reward_games):,} games..."
    )
    labeled = generate_samples_stockfish_parallel(
        reward_games,
        vocab,
        num_workers=n_workers,
        stockfish_depth=depth,
        sample_rate=args.sample_rate,
        skew_exponent=args.position_skew,
    )

    print(f"Stage 3: saving {len(labeled):,} stockfish samples as memmap...")
    _save_as_memmap(labeled, out_dir, "stockfish", max_seq_len=args.max_seq_len)


def stage4_policy_sequences(args: argparse.Namespace) -> None:
    """Stage 4: tokenize the outcome-subset games into the policy memmap.

    Loads ``games_outcome.pt`` and ``tokenizer.pt`` from ``args.out_dir``,
    turns each game into a full token sequence and writes the ``policy_*``
    memmap (truncated to ``args.max_seq_len``).

    Skipped when ``policy_meta.pt`` exists and ``args.force`` is false.
    """
    out_dir = args.out_dir
    done_marker = out_dir / "policy_meta.pt"
    if done_marker.exists() and not args.force:
        print(f"Stage 4: skipping — {done_marker.name} exists.")
        return

    games_file = out_dir / "games_outcome.pt"
    tokenizer_file = out_dir / "tokenizer.pt"
    print(f"Stage 4: loading {games_file} and {tokenizer_file}...")
    policy_games = torch.load(games_file, weights_only=False)
    vocab = torch.load(tokenizer_file, weights_only=False)

    print(f"Stage 4: tokenizing {len(policy_games):,} games into policy sequences...")
    tokenized = _generate_policy_sequences(policy_games, vocab, max_seq_len=args.max_seq_len)

    print(f"Stage 4: saving {len(tokenized):,} policy sequences as memmap...")
    _save_policy_memmap(tokenized, out_dir, "policy", max_seq_len=args.max_seq_len)


def _write_test_subset_reward(out_dir: Path, src_name: str, dst_name: str, indices: np.ndarray) -> None:
    """Copy the rows selected by ``indices`` from a reward memmap into a new one.

    Reads ``{src_name}_meta.pt`` for the source row count and row width, then
    copies tokens, labels and lengths row by row into ``{dst_name}_*.bin`` and
    writes ``{dst_name}_meta.pt`` with the new row count and the same width.
    """
    src_meta = torch.load(out_dir / f"{src_name}_meta.pt", weights_only=True)
    total_rows = src_meta["n"]
    width = src_meta["max_len"]

    def _bin(prefix: str, field: str) -> Path:
        return out_dir / f"{prefix}_{field}.bin"

    # Sources are opened read-only before any destination file is created.
    tokens_in = np.memmap(_bin(src_name, "tokens"), dtype=np.int32, mode="r", shape=(total_rows, width))
    labels_in = np.memmap(_bin(src_name, "labels"), dtype=np.float32, mode="r", shape=(total_rows,))
    lengths_in = np.memmap(_bin(src_name, "lengths"), dtype=np.int32, mode="r", shape=(total_rows,))

    subset_rows = len(indices)
    tokens_out = np.memmap(_bin(dst_name, "tokens"), dtype=np.int32, mode="w+", shape=(subset_rows, width))
    labels_out = np.memmap(_bin(dst_name, "labels"), dtype=np.float32, mode="w+", shape=(subset_rows,))
    lengths_out = np.memmap(_bin(dst_name, "lengths"), dtype=np.int32, mode="w+", shape=(subset_rows,))

    progress = tqdm(indices, desc=f"  writing {dst_name}", unit="sample")
    for row, src_row in enumerate(progress):
        tokens_out[row] = tokens_in[src_row]
        labels_out[row] = labels_in[src_row]
        lengths_out[row] = lengths_in[src_row]

    for written in (tokens_out, labels_out, lengths_out):
        written.flush()

    torch.save({"n": subset_rows, "max_len": width}, out_dir / f"{dst_name}_meta.pt")
    print(f"  {dst_name}: {subset_rows:,} samples written")


def _write_test_subset_policy(out_dir: Path, src_name: str, dst_name: str, indices: np.ndarray) -> None:
    """Copy the rows selected by ``indices`` from a policy memmap into a new one.

    Policy memmaps carry tokens and lengths only (no labels). Reads
    ``{src_name}_meta.pt`` for the source row count and row width, copies the
    selected rows into ``{dst_name}_*.bin`` and writes ``{dst_name}_meta.pt``
    with the new row count and the same width.
    """
    src_meta = torch.load(out_dir / f"{src_name}_meta.pt", weights_only=True)
    total_rows = src_meta["n"]
    width = src_meta["max_len"]

    def _bin(prefix: str, field: str) -> Path:
        return out_dir / f"{prefix}_{field}.bin"

    # Sources are opened read-only before any destination file is created.
    tokens_in = np.memmap(_bin(src_name, "tokens"), dtype=np.int32, mode="r", shape=(total_rows, width))
    lengths_in = np.memmap(_bin(src_name, "lengths"), dtype=np.int32, mode="r", shape=(total_rows,))

    subset_rows = len(indices)
    tokens_out = np.memmap(_bin(dst_name, "tokens"), dtype=np.int32, mode="w+", shape=(subset_rows, width))
    lengths_out = np.memmap(_bin(dst_name, "lengths"), dtype=np.int32, mode="w+", shape=(subset_rows,))

    progress = tqdm(indices, desc=f"  writing {dst_name}", unit="seq")
    for row, src_row in enumerate(progress):
        tokens_out[row] = tokens_in[src_row]
        lengths_out[row] = lengths_in[src_row]

    for written in (tokens_out, lengths_out):
        written.flush()

    torch.save({"n": subset_rows, "max_len": width}, out_dir / f"{dst_name}_meta.pt")
    print(f"  {dst_name}: {subset_rows:,} sequences written")


def stage_build_test_splits(args: argparse.Namespace, out_dir: Path) -> None:
    """Build held-out test sets for reward and policy models from existing memmaps.

    Each split draws its indices from its own generator seeded with 42, so the
    indices chosen for a split depend only on that split's source size and
    requested test size, not on whether the other split was built in the same
    run.

    The chosen indices are saved to {name}_test_indices.npy so the
    corresponding training memmap loader can exclude them, making train and
    test disjoint even though they share the underlying .bin file.

    A split is (re)built when its ``*_test_meta.pt`` is missing or
    ``args.force`` is set, provided the source ``*_meta.pt`` exists. The
    reward split is skipped entirely when ``args.policy_only`` is set.

    Produces:
      stockfish_test_*.bin / stockfish_test_meta.pt / stockfish_test_indices.npy
      policy_test_*.bin    / policy_test_meta.pt    / policy_test_indices.npy
    """
    seed = 42
    policy_only = getattr(args, "policy_only", False)

    # Reward test set — skipped when --policy-only since no Stockfish data exists.
    reward_test_meta = out_dir / "stockfish_test_meta.pt"
    sf_meta_path = out_dir / "stockfish_meta.pt"
    if policy_only:
        print("Test splits: stockfish_test skipped (--policy-only).")
    elif (not reward_test_meta.exists() or args.force) and sf_meta_path.exists():
        print(f"Test splits: building stockfish_test ({args.reward_test_size:,} samples)...")
        sf_meta = torch.load(sf_meta_path, weights_only=True)
        n = sf_meta["n"]
        test_n = min(args.reward_test_size, n)
        idx = np.random.default_rng(seed).choice(n, size=test_n, replace=False)
        idx.sort()
        _write_test_subset_reward(out_dir, "stockfish", "stockfish_test", idx)
        np.save(out_dir / "stockfish_test_indices.npy", idx)
        print(f"  saved stockfish_test_indices.npy ({test_n:,} indices excluded from training)")
    elif reward_test_meta.exists():
        print("Test splits: stockfish_test already exists, skipping.")
    else:
        # Nothing to split yet; say so instead of silently doing nothing.
        print(f"Test splits: stockfish_test skipped — {sf_meta_path.name} not found.")

    # Policy test set
    policy_test_meta = out_dir / "policy_test_meta.pt"
    pol_meta_path = out_dir / "policy_meta.pt"
    if (not policy_test_meta.exists() or args.force) and pol_meta_path.exists():
        print(f"Test splits: building policy_test ({args.policy_test_size:,} sequences)...")
        pol_meta = torch.load(pol_meta_path, weights_only=True)
        n = pol_meta["n"]
        test_n = min(args.policy_test_size, n)
        # Fresh generator: the draw must not depend on the reward branch above.
        idx = np.random.default_rng(seed).choice(n, size=test_n, replace=False)
        idx.sort()
        _write_test_subset_policy(out_dir, "policy", "policy_test", idx)
        np.save(out_dir / "policy_test_indices.npy", idx)
        print(f"  saved policy_test_indices.npy ({test_n:,} indices excluded from training)")
    elif policy_test_meta.exists():
        print("Test splits: policy_test already exists, skipping.")
    else:
        print(f"Test splits: policy_test skipped — {pol_meta_path.name} not found.")


def stage5_puzzle_samples(args: argparse.Namespace, tokenizer, out_dir: Path) -> None:
    """Stage 5: convert Lichess puzzles into train/test policy memmaps.

    Streams ``Lichess/chess-puzzles``, optionally filters rows by
    ``args.min_puzzle_popularity`` and ``args.min_puzzle_plays``, and turns
    each remaining row into a token sequence plus starting FEN via
    ``_process_puzzle``. The first ``args.puzzle_test_size`` valid puzzles go
    to the test set; later ones go to the training set, until
    ``args.puzzle_count`` training puzzles are collected (no limit when None).

    Writes the ``puzzle_test_*`` and ``puzzle_*`` memmaps (with FENs). A set
    is written when its meta file is missing or ``args.force`` is set, so
    ``--force`` regenerates existing outputs. The stage is skipped when both
    meta files exist and ``args.force`` is false.
    """
    train_done = (out_dir / "puzzle_meta.pt").exists()
    test_done = (out_dir / "puzzle_test_meta.pt").exists()
    if train_done and test_done and not args.force:
        print("Stage 5: skipping — puzzle_meta.pt and puzzle_test_meta.pt exist.")
        return

    print("Stage 5: loading Lichess/chess-puzzles from HuggingFace...")
    ds = load_dataset("Lichess/chess-puzzles", split="train", streaming=True)

    min_pop = args.min_puzzle_popularity
    min_plays = args.min_puzzle_plays
    cls_id = tokenizer.symbol_to_token[CLS_TOKEN]
    sym_map = tokenizer.symbol_to_token
    test_seqs: list[list[int]] = []
    test_fens: list[str] = []
    train_seqs: list[list[int]] = []
    train_fens: list[str] = []
    skipped = 0
    test_target = args.puzzle_test_size
    train_target = args.puzzle_count

    with tqdm(ds, desc="Stage 5: puzzles", unit="puzzle") as pbar:
        for row in pbar:
            if min_pop is not None and row.get("Popularity", 0) < min_pop:
                continue
            if min_plays is not None and row.get("NbPlays", 0) < min_plays:
                continue
            result = _process_puzzle(row, sym_map, cls_id)
            if result is None:
                skipped += 1
                continue
            seq, fen = result
            # The test set is filled first, so it is always a prefix of the
            # valid puzzles in stream order.
            if len(test_seqs) < test_target:
                test_seqs.append(seq)
                test_fens.append(fen)
            else:
                train_seqs.append(seq)
                train_fens.append(fen)
            pbar.set_postfix(test=len(test_seqs), train=len(train_seqs), skipped=skipped)
            if train_target is not None and len(train_seqs) >= train_target:
                break

    print(
        f"Stage 5: test={len(test_seqs):,} puzzles, "
        f"train={len(train_seqs):,} puzzles, "
        f"skipped={skipped:,} invalid."
    )
    # With --force, existing outputs are overwritten; otherwise only the
    # missing set is written.
    if test_seqs and (args.force or not test_done):
        _save_policy_memmap(
            test_seqs, out_dir, "puzzle_test", max_seq_len=args.max_seq_len, fens=test_fens,
        )
    if train_seqs and (args.force or not train_done):
        _save_policy_memmap(
            train_seqs, out_dir, "puzzle", max_seq_len=args.max_seq_len, fens=train_fens,
        )
    elif not train_seqs:
        print("Stage 5: WARNING — no training puzzles collected.")


def main():
    """Parse the command line and run the requested pipeline stages.

    ``--puzzles-only`` runs Stage 5 alone (building the tokenizer first if
    ``tokenizer.pt`` is missing). Otherwise Stages 1, 2, 4 always run, Stage 3
    runs unless ``--policy-only`` is given, Stage 5 runs unless
    ``--skip-puzzles`` is given (and only if ``tokenizer.pt`` exists), and the
    test splits are built last. A size listing of the known artifacts is
    printed at the end; missing files are shown as 0.0 MB.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument("--out-dir", type=Path, default=Path("data"))
    cli.add_argument(
        "--policy-games", type=int, default=1_000_000,
        help="Number of games to collect for policy model training",
    )
    cli.add_argument(
        "--reward-games", type=int, default=1_000_000,
        help="Number of games to collect for reward model (Stockfish eval)",
    )
    cli.add_argument(
        "--policy-min-elo", type=int, default=1800,
        help="Min Elo for both players in policy training games",
    )
    cli.add_argument(
        "--reward-min-elo", type=int, default=1500,
        help="Min Elo for both players in reward model training games",
    )
    cli.add_argument(
        "--sample-rate", type=float, default=0.25,
        help="Fraction of positions to sample per game (scales with game length)",
    )
    cli.add_argument(
        "--position-skew", type=float, default=1.5,
        help="Power-law exponent weighting later positions; 1.0=linear, higher=more mid/late",
    )
    cli.add_argument("--workers", type=int, default=16)
    cli.add_argument("--stockfish-depth", type=int, default=12)
    cli.add_argument(
        "--max-seq-len", type=int, default=128,
        help="Truncate token sequences to this length when writing .bin files",
    )
    cli.add_argument(
        "--force", action="store_true",
        help="Re-run all stages even if their outputs already exist",
    )
    cli.add_argument(
        "--puzzle-count", type=int, default=None,
        help="Max puzzles to include (default: all ~4.99M)",
    )
    cli.add_argument(
        "--min-puzzle-popularity", type=int, default=None,
        help="Min Lichess Popularity score (0-100 scale)",
    )
    cli.add_argument(
        "--min-puzzle-plays", type=int, default=None,
        help="Min NbPlays for a puzzle to be included",
    )
    cli.add_argument(
        "--skip-puzzles", action="store_true",
        help="Skip Stage 5 puzzle processing",
    )
    cli.add_argument(
        "--puzzles-only", action="store_true",
        help="Only run Stage 5 (puzzle processing). Skips game collection, "
             "outcome/Stockfish/policy memmaps, and test splits. Requires "
             "tokenizer.pt to exist (or it will be built from the UCI vocab).",
    )
    cli.add_argument(
        "--policy-only", action="store_true",
        help="Skip everything Stockfish/reward-related: Stage 1 collects only "
             "policy games, Stage 3 (Stockfish labeling) is skipped, and the "
             "stockfish_test split is not built. Stages 1/2/4/5 + policy_test "
             "still run, producing tokenizer.pt, policy_* / puzzle_* memmaps, "
             "and the policy_test split.",
    )
    cli.add_argument(
        "--puzzle-test-size", type=int, default=100_000,
        help="Number of puzzle sequences held out for the test set (default: 100000)",
    )
    cli.add_argument(
        "--reward-test-size", type=int, default=50_000,
        help="Number of reward positions held out for the test set (default: 50000)",
    )
    cli.add_argument(
        "--policy-test-size", type=int, default=50_000,
        help="Number of policy sequences held out for the test set (default: 50000)",
    )
    args = cli.parse_args()

    out_dir = args.out_dir
    out_dir.mkdir(parents=True, exist_ok=True)

    if args.puzzles_only and args.policy_only:
        cli.error("--puzzles-only and --policy-only are mutually exclusive.")

    tokenizer_file = out_dir / "tokenizer.pt"

    if args.puzzles_only:
        print("--puzzles-only: skipping Stages 1-4 and test-split builder.")
        if tokenizer_file.exists():
            tokenizer = torch.load(tokenizer_file, weights_only=False)
        else:
            # The tokenizer is built without any games, so it can be created here.
            print("  tokenizer.pt missing; building from UCI vocab...")
            tokenizer = build_tokenizer_from_games()
            torch.save(tokenizer, tokenizer_file)
        stage5_puzzle_samples(args, tokenizer, out_dir)
    else:
        # The full and --policy-only pipelines differ only in Stage 3.
        if args.policy_only:
            print("--policy-only: skipping Stage 3 (Stockfish labeling) and stockfish_test split.")
        stage1_collect_games(args)
        stage2_outcome_samples(args)
        if not args.policy_only:
            stage3_stockfish_samples(args)
        stage4_policy_sequences(args)

        if not args.skip_puzzles:
            if tokenizer_file.exists():
                tokenizer = torch.load(tokenizer_file, weights_only=False)
                stage5_puzzle_samples(args, tokenizer, out_dir)
            else:
                print("Stage 5: skipping — tokenizer.pt not found (run stages 1-2 first).")

        stage_build_test_splits(args, out_dir)

    print("\nAll stages complete. Artifacts:")
    artifact_names = (
        "games_outcome.pt",
        "games_stockfish.pt",
        "tokenizer.pt",
        "outcome_tokens.bin",
        "outcome_labels.bin",
        "outcome_lengths.bin",
        "outcome_meta.pt",
        "stockfish_tokens.bin",
        "stockfish_labels.bin",
        "stockfish_lengths.bin",
        "stockfish_meta.pt",
        "policy_tokens.bin",
        "policy_lengths.bin",
        "policy_meta.pt",
        "puzzle_tokens.bin",
        "puzzle_lengths.bin",
        "puzzle_fens.bin",
        "puzzle_meta.pt",
        "puzzle_test_tokens.bin",
        "puzzle_test_lengths.bin",
        "puzzle_test_fens.bin",
        "puzzle_test_meta.pt",
        "stockfish_test_tokens.bin",
        "stockfish_test_labels.bin",
        "stockfish_test_lengths.bin",
        "stockfish_test_meta.pt",
        "policy_test_tokens.bin",
        "policy_test_lengths.bin",
        "policy_test_meta.pt",
    )
    for filename in artifact_names:
        artifact = out_dir / filename
        if artifact.exists():
            megabytes = artifact.stat().st_size / 1024 / 1024
        else:
            megabytes = 0
        print(f"  {artifact}  ({megabytes:.1f} MB)")


# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()