"""PAWN data pipeline: on-the-fly generation via Rust engine."""
import os
import threading
import time
from collections.abc import Iterator
import numpy as np
import torch
import torch.utils.data
import chess_engine as engine
from pawn.config import (
WHITE_CHECKMATES,
BLACK_CHECKMATES,
STALEMATE,
DRAW_BY_RULE,
PLY_LIMIT,
)
# Module-level cache of small position-index tensors (torch.arange results),
# keyed by a (tag, length) tuple so different lengths get distinct entries.
_positions_cache: dict[tuple[str, int], torch.Tensor] = {}
def _map_termination_to_outcome(
    term_codes: np.ndarray, game_lengths: np.ndarray
) -> torch.Tensor:
    """Translate engine termination codes into outcome token IDs.

    Engine codes: 0=Checkmate, 1=Stalemate, 2=SeventyFiveMoveRule,
    3=FivefoldRepetition, 4=InsufficientMaterial, 5=PlyLimit

    The side that delivered checkmate follows from game-length parity:
    an odd number of plies means white made the final (mating) move,
    an even number means black did.
    """
    terminations = torch.from_numpy(term_codes).long()
    lengths = torch.from_numpy(game_lengths).long()

    # Default is PLY_LIMIT: code 5 carries no explicit assignment below.
    result = torch.full_like(terminations, PLY_LIMIT)

    mate = terminations == 0
    result[mate & (lengths % 2 == 1)] = WHITE_CHECKMATES
    result[mate & (lengths % 2 == 0)] = BLACK_CHECKMATES
    result[terminations == 1] = STALEMATE
    # Codes 2, 3, 4 all collapse into the generic draw-by-rule token.
    result[torch.isin(terminations, torch.tensor([2, 3, 4]))] = DRAW_BY_RULE
    return result
def pack_clm_sequences(
    move_ids: np.ndarray,
    game_lengths: np.ndarray,
    outcome_tokens: torch.Tensor,
    seq_len: int,
) -> dict[str, torch.Tensor]:
    """Pack move arrays into CLM training tensors.

    Constructs input_ids = [outcome, move_1, ..., move_N, PAD, ...]
    and targets shifted left by 1.

    Args:
        move_ids: (B, max_ply) raw move token IDs.
        game_lengths: (B,) actual game lengths in plies.
        outcome_tokens: (B,) pre-computed outcome token IDs (4273-4277).
        seq_len: total CLM sequence length (e.g. 256).

    Returns:
        Dict with:
            input_ids: (B, seq_len) long tensor, [outcome, moves..., 0-pad].
            targets:   (B, seq_len) long tensor, input_ids shifted left by 1.
            loss_mask: (B, seq_len) bool tensor, True at positions
                       0..capped_length inclusive (the position right after
                       the last move is included — presumably so the model
                       is trained to predict the pad/end token there).
    """
    B = len(game_lengths)
    n_move_slots = seq_len - 1  # position 0 holds the outcome token
    max_ply = move_ids.shape[1]

    game_lengths_t = torch.from_numpy(game_lengths).long()
    move_ids_t = torch.from_numpy(move_ids).long()  # (B, max_ply)

    # input_ids: [outcome, move_0, ..., move_{N-1}, PAD(=0), ...]
    input_ids = torch.zeros(B, seq_len, dtype=torch.long)
    input_ids[:, 0] = outcome_tokens

    # Zero out any tokens past each game's length, matching the zero
    # padding of input_ids. (torch.arange here is trivially cheap — the
    # previous module-level memoization of these index tensors added
    # global mutable state for no measurable benefit and was removed.)
    move_mask = torch.arange(max_ply).unsqueeze(0) < game_lengths_t.unsqueeze(1)
    clean_moves = move_ids_t * move_mask

    # Place moves at positions 1..n_move_slots.
    n_to_copy = min(max_ply, n_move_slots)
    input_ids[:, 1 : n_to_copy + 1] = clean_moves[:, :n_to_copy]

    # Cap game_lengths to n_move_slots (handles the edge case where the
    # engine produces more moves than we have slots for).
    capped_lengths = game_lengths_t.clamp(max=n_move_slots)

    # Targets: input shifted left by 1; final target position stays PAD.
    targets = torch.zeros(B, seq_len, dtype=torch.long)
    targets[:, :-1] = input_ids[:, 1:]

    # Loss mask: True for positions 0 through capped_lengths[b] inclusive.
    loss_mask = torch.arange(seq_len).unsqueeze(0) <= capped_lengths.unsqueeze(1)

    return {
        "input_ids": input_ids,
        "targets": targets,
        "loss_mask": loss_mask,
    }
def _to_clm_batch(
    move_ids: np.ndarray,
    game_lengths: np.ndarray,
    term_codes: np.ndarray,
    seq_len: int,
) -> dict[str, torch.Tensor]:
    """Convert Rust engine output to CLM training tensors.

    Thin wrapper: derives each game's outcome token from its termination
    code, then hands everything to pack_clm_sequences.
    """
    return pack_clm_sequences(
        move_ids,
        game_lengths,
        _map_termination_to_outcome(term_codes, game_lengths),
        seq_len,
    )
class CLMDataset(torch.utils.data.IterableDataset):
    """Generates CLM training data on-the-fly via the Rust engine.

    Each iteration yields a complete batch. Seeds are deterministic:
    base_seed + step * num_workers + worker_id.
    """
    def __init__(self, batch_size: int, max_ply: int, base_seed: int,
                 discard_ply_limit: bool = False):
        """Store generation parameters; no work happens until iteration.

        Args:
            batch_size: number of games per yielded batch.
            max_ply: game-length cap (plies) passed to the engine.
            base_seed: base value for the deterministic per-batch seed.
            discard_ply_limit: forwarded verbatim to
                engine.generate_clm_batch.
        """
        super().__init__()
        self.batch_size = batch_size
        self.max_ply = max_ply
        self.base_seed = base_seed
        self.discard_ply_limit = discard_ply_limit
        # First step index used for seeding; lets a resumed run continue
        # the seed sequence instead of replaying it (see set_start_step).
        self._start_step = 0
        # PID of the process constructing the dataset. Captured here —
        # before the DataLoader forks workers — so workers can detect
        # the parent dying.
        self._main_pid = os.getpid()
    def set_start_step(self, step: int) -> None:
        # Resume the deterministic seed sequence from `step`.
        self._start_step = step
    def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
        worker_info = torch.utils.data.get_worker_info()
        worker_id = worker_info.id if worker_info else 0
        num_workers = worker_info.num_workers if worker_info else 1
        if worker_info is not None:
            # Running inside a DataLoader worker process: start a watchdog
            # so workers don't outlive a crashed/killed trainer process.
            main_pid = self._main_pid
            def _watchdog():
                while True:
                    time.sleep(2)
                    try:
                        # Signal 0 performs no action; it only checks that
                        # the target process still exists.
                        # NOTE(review): this liveness probe is POSIX-style —
                        # confirm target platforms if Windows is in scope.
                        os.kill(main_pid, 0)
                    except OSError:
                        # Parent is gone: hard-exit without cleanup.
                        os._exit(1)
            t = threading.Thread(target=_watchdog, daemon=True)
            t.start()
        step = self._start_step
        while True:
            # Deterministic, worker-unique seed for this batch.
            seed = self.base_seed + step * num_workers + worker_id
            input_ids, targets, loss_mask, _move_ids, _gl, _tc = \
                engine.generate_clm_batch(
                    self.batch_size, self.max_ply, seed,
                    discard_ply_limit=self.discard_ply_limit,
                )
            yield {
                "input_ids": torch.from_numpy(input_ids).long(),
                "targets": torch.from_numpy(targets).long(),
                "loss_mask": torch.from_numpy(loss_mask),
            }
            step += 1
def create_validation_set(
    n_games: int, max_ply: int, seed: int,
    discard_ply_limit: bool = False,
) -> dict[str, torch.Tensor]:
    """Build a fixed, reproducible validation batch.

    Besides the usual CLM tensors, attaches per-position legal-move grids
    so the evaluator can measure the model's legal move rate.

    Args:
        n_games: number of games to generate.
        max_ply: second argument to engine.generate_clm_batch; the prior
            docstring described it as the total CLM sequence length (256) —
            NOTE(review): confirm against the engine API, since CLMDataset
            passes a game-length cap in the same position.
        seed: RNG seed; fixed so the set is identical across runs.
        discard_ply_limit: forwarded to the engine.
    """
    generated = engine.generate_clm_batch(
        n_games, max_ply, seed, discard_ply_limit=discard_ply_limit,
    )
    input_ids, targets, loss_mask, move_ids, game_lengths, _term = generated

    batch = {
        "input_ids": torch.from_numpy(input_ids).long(),
        "targets": torch.from_numpy(targets).long(),
        "loss_mask": torch.from_numpy(loss_mask),
    }

    # Legal-move grid for legal-move-rate evaluation; the promotion
    # component of the masks is not needed here.
    legal_grid, _legal_promo = engine.compute_legal_move_masks(
        move_ids, game_lengths
    )
    batch["legal_grid"] = torch.from_numpy(legal_grid).long()
    batch["game_lengths"] = torch.from_numpy(game_lengths).long()
    return batch
|