# pawn/data.py
# NOTE(review): the original capture carried GitHub page chrome here
# (commit 230508d, "Safetensors migration, checkpoint integrity, and
# multi-model training." #1, author thomas-schweich); converted to
# comments so the module is importable.
"""PAWN data pipeline: on-the-fly generation via Rust engine."""
import os
import threading
import time
from collections.abc import Iterator
import numpy as np
import torch
import torch.utils.data
import chess_engine as engine
from pawn.config import (
WHITE_CHECKMATES,
BLACK_CHECKMATES,
STALEMATE,
DRAW_BY_RULE,
PLY_LIMIT,
)
_positions_cache: dict[tuple[str, int], torch.Tensor] = {}
def _map_termination_to_outcome(
    term_codes: np.ndarray, game_lengths: np.ndarray
) -> torch.Tensor:
    """Translate engine termination codes into outcome token IDs.

    Engine codes: 0=Checkmate, 1=Stalemate, 2=SeventyFiveMoveRule,
    3=FivefoldRepetition, 4=InsufficientMaterial, 5=PlyLimit.

    White moves on even ply indices, so an odd total game length means
    the final ply was white's (white delivered any checkmate), while an
    even length means it was black's.
    """
    codes = torch.from_numpy(term_codes).long()
    lengths = torch.from_numpy(game_lengths).long()
    # Default to PLY_LIMIT so engine code 5 needs no explicit branch.
    result = torch.full((codes.shape[0],), PLY_LIMIT, dtype=torch.long)
    mate = codes == 0
    white_moved_last = lengths % 2 == 1
    result[mate & white_moved_last] = WHITE_CHECKMATES
    result[mate & ~white_moved_last] = BLACK_CHECKMATES
    result[codes == 1] = STALEMATE
    # Codes 2, 3 and 4 are all draws by rule.
    result[(codes >= 2) & (codes <= 4)] = DRAW_BY_RULE
    return result
def pack_clm_sequences(
    move_ids: np.ndarray,
    game_lengths: np.ndarray,
    outcome_tokens: torch.Tensor,
    seq_len: int,
) -> dict[str, torch.Tensor]:
    """Pack move arrays into CLM training tensors.

    Constructs input_ids = [outcome, move_1, ..., move_N, PAD(0), ...]
    and targets shifted left by one (targets[i] = input_ids[i + 1], with
    0 in the final slot).

    Args:
        move_ids: (B, max_ply) raw move token IDs from the engine;
            entries past each game's length are zeroed out here.
        game_lengths: (B,) actual game lengths in plies.
        outcome_tokens: (B,) pre-computed outcome token IDs (4273-4277).
        seq_len: total CLM sequence length (e.g. 256).

    Returns:
        Dict with "input_ids" (B, seq_len) long, "targets" (B, seq_len)
        long, and "loss_mask" (B, seq_len) bool. loss_mask[b] covers
        positions 0..min(game_lengths[b], seq_len - 1) inclusive, so the
        position after the last move (whose target is PAD) is trained.
    """
    B = len(game_lengths)
    n_move_slots = seq_len - 1  # position 0 holds the outcome token
    max_ply = move_ids.shape[1]
    game_lengths_t = torch.from_numpy(game_lengths).long()
    move_ids_t = torch.from_numpy(move_ids).long()  # (B, max_ply)

    # Build input_ids: [outcome, move_0, ..., move_{N-1}, PAD(0), ...].
    input_ids = torch.zeros(B, seq_len, dtype=torch.long)
    input_ids[:, 0] = outcome_tokens

    # Zero out any tokens past each game's length. torch.arange on a few
    # hundred elements is negligible next to batch generation, so the
    # previous module-level tensor cache (shared mutable state) is not
    # needed here.
    ply_positions = torch.arange(max_ply).unsqueeze(0)
    move_mask = ply_positions < game_lengths_t.unsqueeze(1)
    clean_moves = move_ids_t * move_mask

    # Copy moves into slots 1..n_move_slots, truncating if the engine
    # produced more plies than we have slots for.
    n_to_copy = min(max_ply, n_move_slots)
    input_ids[:, 1 : n_to_copy + 1] = clean_moves[:, :n_to_copy]

    # Cap game lengths to the slot count so truncated games still get a
    # valid loss mask.
    capped_lengths = game_lengths_t.clamp(max=n_move_slots)

    # Targets: input shifted left by 1; the last position targets 0.
    targets = torch.zeros(B, seq_len, dtype=torch.long)
    targets[:, :-1] = input_ids[:, 1:]

    # Loss mask: True for positions 0 through capped_lengths[b] inclusive.
    seq_positions = torch.arange(seq_len).unsqueeze(0)
    loss_mask = seq_positions <= capped_lengths.unsqueeze(1)

    return {
        "input_ids": input_ids,
        "targets": targets,
        "loss_mask": loss_mask,
    }
def _to_clm_batch(
    move_ids: np.ndarray,
    game_lengths: np.ndarray,
    term_codes: np.ndarray,
    seq_len: int,
) -> dict[str, torch.Tensor]:
    """Convert raw Rust engine output into CLM training tensors.

    Thin wrapper: derives outcome tokens from the termination codes and
    hands everything off to pack_clm_sequences.
    """
    return pack_clm_sequences(
        move_ids,
        game_lengths,
        _map_termination_to_outcome(term_codes, game_lengths),
        seq_len,
    )
class CLMDataset(torch.utils.data.IterableDataset):
    """Infinite stream of CLM batches generated by the Rust engine.

    Each iteration yields one complete batch. Seeding is deterministic —
    base_seed + step * num_workers + worker_id — so runs are reproducible
    and concurrent workers never draw the same seed.
    """

    def __init__(self, batch_size: int, max_ply: int, base_seed: int,
                 discard_ply_limit: bool = False):
        super().__init__()
        self.batch_size = batch_size
        self.max_ply = max_ply
        self.base_seed = base_seed
        self.discard_ply_limit = discard_ply_limit
        self._start_step = 0
        # Captured in the parent process so DataLoader workers can
        # monitor it from their watchdog threads.
        self._main_pid = os.getpid()

    def set_start_step(self, step: int) -> None:
        """Resume seed progression from `step` (e.g. after a restart)."""
        self._start_step = step

    def _spawn_parent_watchdog(self) -> None:
        """Start a daemon thread that hard-exits this worker process
        once the parent process disappears (signal 0 is a pure
        existence probe; it delivers nothing)."""
        parent = self._main_pid

        def _watch() -> None:
            while True:
                time.sleep(2)
                try:
                    os.kill(parent, 0)
                except OSError:
                    os._exit(1)

        threading.Thread(target=_watch, daemon=True).start()

    def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
        info = torch.utils.data.get_worker_info()
        wid, n_workers = (info.id, info.num_workers) if info else (0, 1)

        # Only spawn the watchdog when actually running inside a worker.
        if info is not None:
            self._spawn_parent_watchdog()

        step = self._start_step
        while True:
            seed = self.base_seed + step * n_workers + wid
            input_ids, targets, loss_mask, _move_ids, _gl, _tc = \
                engine.generate_clm_batch(
                    self.batch_size, self.max_ply, seed,
                    discard_ply_limit=self.discard_ply_limit,
                )
            yield {
                "input_ids": torch.from_numpy(input_ids).long(),
                "targets": torch.from_numpy(targets).long(),
                "loss_mask": torch.from_numpy(loss_mask),
            }
            step += 1
def create_validation_set(
    n_games: int, max_ply: int, seed: int,
    discard_ply_limit: bool = False,
) -> dict[str, torch.Tensor]:
    """Generate a fixed, seeded validation set.

    On top of the CLM tensors, attaches a legal-move grid and the game
    lengths so the evaluator can measure the legal move rate.

    Args:
        max_ply: total CLM sequence length (256).
    """
    input_ids, targets, loss_mask, move_ids, game_lengths, _tc = \
        engine.generate_clm_batch(
            n_games, max_ply, seed, discard_ply_limit=discard_ply_limit,
        )
    # Legal move masks for each position, for legal-move-rate evaluation.
    legal_grid, _legal_promo = engine.compute_legal_move_masks(
        move_ids, game_lengths
    )
    return {
        "input_ids": torch.from_numpy(input_ids).long(),
        "targets": torch.from_numpy(targets).long(),
        "loss_mask": torch.from_numpy(loss_mask),
        "legal_grid": torch.from_numpy(legal_grid).long(),
        "game_lengths": torch.from_numpy(game_lengths).long(),
    }