File size: 7,167 Bytes
d7ecc62
 
 
 
 
36caadb
d7ecc62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36caadb
d7ecc62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230508d
d7ecc62
 
230508d
d7ecc62
 
230508d
d7ecc62
 
 
 
 
230508d
 
 
d7ecc62
 
 
 
230508d
d7ecc62
 
230508d
d7ecc62
 
 
 
 
 
230508d
d7ecc62
 
230508d
d7ecc62
 
 
 
 
230508d
d7ecc62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230508d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7ecc62
 
 
 
 
 
 
a188746
 
d7ecc62
 
 
 
a188746
d7ecc62
 
 
36caadb
d7ecc62
 
36caadb
d7ecc62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230508d
 
 
 
 
 
 
 
 
 
d7ecc62
 
 
 
a188746
 
d7ecc62
 
 
 
 
 
230508d
d7ecc62
230508d
 
 
 
 
 
 
 
 
d7ecc62
 
 
230508d
d7ecc62
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
"""PAWN data pipeline: on-the-fly generation via Rust engine."""

import os
import threading
import time
from collections.abc import Iterator

import numpy as np
import torch
import torch.utils.data

import chess_engine as engine

from pawn.config import (
    WHITE_CHECKMATES,
    BLACK_CHECKMATES,
    STALEMATE,
    DRAW_BY_RULE,
    PLY_LIMIT,
)


# Memoized position-index tensors (results of torch.arange), keyed by a
# (tag, length) pair such as ("engine", max_ply) or ("seq", seq_len).
_positions_cache: dict[tuple[str, int], torch.Tensor] = {}


def _map_termination_to_outcome(
    term_codes: np.ndarray, game_lengths: np.ndarray
) -> torch.Tensor:
    """Translate engine termination codes into outcome token IDs.

    Engine codes: 0=Checkmate, 1=Stalemate, 2=SeventyFiveMoveRule,
                  3=FivefoldRepetition, 4=InsufficientMaterial, 5=PlyLimit

    Checkmate attribution follows from parity of the game length:
    an odd number of plies means white made the final (mating) move,
    an even number means black did.
    """
    term_t = torch.from_numpy(term_codes).long()
    lengths_t = torch.from_numpy(game_lengths).long()

    # Start with PLY_LIMIT everywhere; code 5 never gets overwritten below.
    result = torch.full((len(term_codes),), PLY_LIMIT, dtype=torch.long)

    mate = term_t.eq(0)
    white_mated = lengths_t.remainder(2).eq(1)
    result[mate & white_mated] = WHITE_CHECKMATES
    result[mate & ~white_mated] = BLACK_CHECKMATES
    result[term_t.eq(1)] = STALEMATE
    # Codes 2, 3, 4 are all rule-based draws.
    result[torch.isin(term_t, torch.tensor([2, 3, 4]))] = DRAW_BY_RULE

    return result


def pack_clm_sequences(
    move_ids: np.ndarray,
    game_lengths: np.ndarray,
    outcome_tokens: torch.Tensor,
    seq_len: int,
) -> dict[str, torch.Tensor]:
    """Pack move arrays into CLM training tensors.

    Constructs input_ids = [outcome, move_1, ..., move_N, PAD, ...] and
    targets shifted left by 1, plus a per-position loss mask.

    Args:
        move_ids: (B, max_ply) raw move token IDs.
        game_lengths: (B,) actual game lengths (number of valid moves).
        outcome_tokens: (B,) pre-computed outcome token IDs.
        seq_len: total CLM sequence length (position 0 holds the outcome).

    Returns:
        Dict with "input_ids" (B, seq_len) long, "targets" (B, seq_len)
        long, and "loss_mask" (B, seq_len) bool covering positions 0
        through game_length inclusive (capped to the move-slot count).
    """
    B = len(game_lengths)
    n_move_slots = seq_len - 1  # position 0 is reserved for the outcome
    max_ply = move_ids.shape[1]

    game_lengths_t = torch.from_numpy(game_lengths).long()
    move_ids_t = torch.from_numpy(move_ids).long()  # (B, max_ply)

    # Build input_ids: [outcome, move_0, ..., move_{N-1}, PAD, ...].
    # PAD token is 0, so zero-init doubles as padding.
    input_ids = torch.zeros(B, seq_len, dtype=torch.long)
    input_ids[:, 0] = outcome_tokens

    # Zero out any engine output past each game's length.  torch.arange on
    # a few hundred elements is trivially cheap, so no cross-call caching:
    # a module-global cache would just be shared mutable state across
    # DataLoader workers/threads with no measurable benefit.
    move_positions = torch.arange(max_ply).unsqueeze(0)  # (1, max_ply)
    move_mask = move_positions < game_lengths_t.unsqueeze(1)
    clean_moves = move_ids_t * move_mask

    # Place moves at positions 1..n_move_slots.
    n_to_copy = min(max_ply, n_move_slots)
    input_ids[:, 1 : n_to_copy + 1] = clean_moves[:, :n_to_copy]

    # Cap game_lengths to n_move_slots (handles the edge case where the
    # engine produces more moves than we have slots for).
    capped_lengths = game_lengths_t.clamp(max=n_move_slots)

    # Targets: input shifted left by 1; the final position targets PAD.
    targets = torch.zeros(B, seq_len, dtype=torch.long)
    targets[:, :-1] = input_ids[:, 1:]

    # Loss mask: True for positions 0 through capped_lengths[b] inclusive
    # (the position after the last move predicts PAD, i.e. game end).
    seq_positions = torch.arange(seq_len).unsqueeze(0)  # (1, seq_len)
    loss_mask = seq_positions <= capped_lengths.unsqueeze(1)

    return {
        "input_ids": input_ids,
        "targets": targets,
        "loss_mask": loss_mask,
    }


def _to_clm_batch(
    move_ids: np.ndarray,
    game_lengths: np.ndarray,
    term_codes: np.ndarray,
    seq_len: int,
) -> dict[str, torch.Tensor]:
    """Convert Rust engine output to CLM training tensors.

    Thin convenience wrapper: derives the outcome token for each game
    from its termination code, then hands everything to
    pack_clm_sequences.
    """
    return pack_clm_sequences(
        move_ids,
        game_lengths,
        _map_termination_to_outcome(term_codes, game_lengths),
        seq_len,
    )


class CLMDataset(torch.utils.data.IterableDataset):
    """Streams CLM training batches generated on-the-fly by the Rust engine.

    Each iteration yields one complete batch.  Seeding is deterministic —
    base_seed + step * num_workers + worker_id — so runs are reproducible
    and workers never draw the same seed.
    """

    def __init__(self, batch_size: int, max_ply: int, base_seed: int,
                 discard_ply_limit: bool = False):
        super().__init__()
        self.batch_size = batch_size
        self.max_ply = max_ply
        self.base_seed = base_seed
        self.discard_ply_limit = discard_ply_limit
        self._start_step = 0
        # Captured in the parent process so forked workers can monitor it.
        self._main_pid = os.getpid()

    def set_start_step(self, step: int) -> None:
        self._start_step = step

    def _spawn_watchdog(self) -> None:
        """Start a daemon thread that kills this worker if the parent dies."""
        parent_pid = self._main_pid

        def _poll() -> None:
            while True:
                time.sleep(2)
                try:
                    os.kill(parent_pid, 0)  # signal 0: liveness probe only
                except OSError:
                    os._exit(1)

        threading.Thread(target=_poll, daemon=True).start()

    def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
        info = torch.utils.data.get_worker_info()
        if info is None:
            worker_id, num_workers = 0, 1
        else:
            worker_id, num_workers = info.id, info.num_workers
            self._spawn_watchdog()

        step = self._start_step
        while True:
            seed = self.base_seed + step * num_workers + worker_id
            generated = engine.generate_clm_batch(
                self.batch_size, self.max_ply, seed,
                discard_ply_limit=self.discard_ply_limit,
            )
            input_ids, targets, loss_mask = generated[:3]
            yield {
                "input_ids": torch.from_numpy(input_ids).long(),
                "targets": torch.from_numpy(targets).long(),
                "loss_mask": torch.from_numpy(loss_mask),
            }
            step += 1


def create_validation_set(
    n_games: int, max_ply: int, seed: int,
    discard_ply_limit: bool = False,
) -> dict[str, torch.Tensor]:
    """Generate a fixed validation set.

    Beyond the CLM tensors, also attaches per-position legal-move grids
    and game lengths so the legal move rate can be evaluated.

    Args:
        max_ply: total CLM sequence length (256).
    """
    generated = engine.generate_clm_batch(
        n_games, max_ply, seed, discard_ply_limit=discard_ply_limit,
    )
    input_ids, targets, loss_mask, move_ids, game_lengths = generated[:5]

    # Legal move masks for evaluating the model's legal move rate.
    legal_grid, _legal_promo = engine.compute_legal_move_masks(
        move_ids, game_lengths
    )

    return {
        "input_ids": torch.from_numpy(input_ids).long(),
        "targets": torch.from_numpy(targets).long(),
        "loss_mask": torch.from_numpy(loss_mask),
        "legal_grid": torch.from_numpy(legal_grid).long(),
        "game_lengths": torch.from_numpy(game_lengths).long(),
    }