| """PAWN data pipeline: on-the-fly generation via Rust engine.""" |
|
|
| import os |
| import threading |
| import time |
| from collections.abc import Iterator |
|
|
| import numpy as np |
| import torch |
| import torch.utils.data |
|
|
| import chess_engine as engine |
|
|
| from pawn.config import ( |
| WHITE_CHECKMATES, |
| BLACK_CHECKMATES, |
| STALEMATE, |
| DRAW_BY_RULE, |
| PLY_LIMIT, |
| ) |
|
|
|
|
# Module-level cache of reusable `torch.arange(...).unsqueeze(0)` rows, keyed by
# (tag, length) — avoids reallocating position-index tensors on every batch in
# pack_clm_sequences. Safe across DataLoader workers (each process has its own).
_positions_cache: dict[tuple[str, int], torch.Tensor] = {}
|
|
|
|
def _map_termination_to_outcome(
    term_codes: np.ndarray, game_lengths: np.ndarray
) -> torch.Tensor:
    """Translate engine termination codes into outcome token IDs.

    Engine codes: 0=Checkmate, 1=Stalemate, 2=SeventyFiveMoveRule,
    3=FivefoldRepetition, 4=InsufficientMaterial, 5=PlyLimit

    For checkmate, the winner follows from game-length parity:
    - Odd game_length (last ply index even) -> white delivered checkmate
    - Even game_length (last ply index odd) -> black delivered checkmate
    """
    codes = torch.from_numpy(term_codes).long()
    lengths = torch.from_numpy(game_lengths).long()

    # Default everything to PLY_LIMIT (code 5); other codes overwrite below.
    result = torch.full((len(term_codes),), PLY_LIMIT, dtype=torch.long)

    mate = codes == 0
    white_moved_last = lengths % 2 == 1  # odd ply count: white played last
    result[mate & white_moved_last] = WHITE_CHECKMATES
    result[mate & ~white_moved_last] = BLACK_CHECKMATES
    result[codes == 1] = STALEMATE
    # 75-move rule, fivefold repetition, insufficient material: rule draws.
    result[(codes >= 2) & (codes <= 4)] = DRAW_BY_RULE

    return result
|
|
|
|
def pack_clm_sequences(
    move_ids: np.ndarray,
    game_lengths: np.ndarray,
    outcome_tokens: torch.Tensor,
    seq_len: int,
) -> dict[str, torch.Tensor]:
    """Pack raw move arrays into CLM training tensors.

    Layout: input_ids = [outcome, move_1, ..., move_N, PAD(0), ...], with
    targets being input_ids shifted left by one position.

    Args:
        move_ids: (B, max_ply) raw move token IDs
        game_lengths: (B,) actual game lengths
        outcome_tokens: (B,) pre-computed outcome token IDs (4273-4277)
        seq_len: total CLM sequence length (256)

    Returns:
        Dict with "input_ids" and "targets" ((B, seq_len), long) plus
        "loss_mask" ((B, seq_len), bool).
    """
    batch = len(game_lengths)
    move_slots = seq_len - 1  # one slot is taken by the leading outcome token
    max_ply = move_ids.shape[1]

    lengths = torch.from_numpy(game_lengths).long()
    moves = torch.from_numpy(move_ids).long()

    # Outcome token first; all remaining slots start as PAD (0).
    input_ids = torch.zeros(batch, seq_len, dtype=torch.long)
    input_ids[:, 0] = outcome_tokens

    # Reuse a cached arange row to build the per-game validity mask.
    ply_key = ("engine", max_ply)
    if (ply_idx := _positions_cache.get(ply_key)) is None:
        ply_idx = torch.arange(max_ply).unsqueeze(0)
        _positions_cache[ply_key] = ply_idx
    # Zero out anything past each game's end (relies on PAD id being 0).
    masked_moves = moves * (ply_idx < lengths.unsqueeze(1))

    copy_n = min(max_ply, move_slots)
    input_ids[:, 1 : copy_n + 1] = masked_moves[:, :copy_n]

    # Games longer than the available slots are truncated for loss masking.
    capped = lengths.clamp(max=move_slots)

    # Next-token targets: shift left by one; the last column stays PAD (0).
    targets = torch.zeros(batch, seq_len, dtype=torch.long)
    targets[:, :-1] = input_ids[:, 1:]

    seq_key = ("seq", seq_len)
    if (seq_idx := _positions_cache.get(seq_key)) is None:
        seq_idx = torch.arange(seq_len).unsqueeze(0)
        _positions_cache[seq_key] = seq_idx
    # Positions 0..capped_length inclusive contribute to the loss — the
    # extra position appears to train prediction of the trailing PAD as an
    # end-of-game marker (NOTE(review): confirm this is intentional).
    loss_mask = seq_idx <= capped.unsqueeze(1)

    return {
        "input_ids": input_ids,
        "targets": targets,
        "loss_mask": loss_mask,
    }
|
|
|
|
def _to_clm_batch(
    move_ids: np.ndarray,
    game_lengths: np.ndarray,
    term_codes: np.ndarray,
    seq_len: int,
) -> dict[str, torch.Tensor]:
    """Convert Rust engine output to CLM training tensors.

    Convenience wrapper: derives outcome tokens from the termination codes
    and hands everything to pack_clm_sequences.
    """
    return pack_clm_sequences(
        move_ids,
        game_lengths,
        _map_termination_to_outcome(term_codes, game_lengths),
        seq_len,
    )
|
|
|
|
class CLMDataset(torch.utils.data.IterableDataset):
    """Generates CLM training data on-the-fly via the Rust engine.

    Each iteration yields a complete batch. Seeds are deterministic:
    base_seed + step * num_workers + worker_id.
    """

    def __init__(self, batch_size: int, max_ply: int, base_seed: int,
                 discard_ply_limit: bool = False):
        super().__init__()
        self.batch_size = batch_size
        self.max_ply = max_ply
        self.base_seed = base_seed
        self.discard_ply_limit = discard_ply_limit
        # Step to resume the seed schedule from (see set_start_step).
        self._start_step = 0
        # PID of the process constructing the dataset — recorded before any
        # DataLoader workers fork, so workers can watch the parent.
        self._main_pid = os.getpid()

    def set_start_step(self, step: int) -> None:
        """Resume the deterministic seed schedule from `step`."""
        self._start_step = step

    def _spawn_watchdog(self) -> None:
        """Start a daemon thread that exits the worker if the parent dies."""
        parent_pid = self._main_pid

        def _watch() -> None:
            while True:
                time.sleep(2)
                try:
                    os.kill(parent_pid, 0)  # signal 0: existence check only
                except OSError:
                    os._exit(1)  # parent gone — hard-exit the worker

        threading.Thread(target=_watch, daemon=True).start()

    def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
        info = torch.utils.data.get_worker_info()
        if info is None:
            worker_id, num_workers = 0, 1
        else:
            worker_id, num_workers = info.id, info.num_workers
            self._spawn_watchdog()

        step = self._start_step
        while True:
            seed = self.base_seed + step * num_workers + worker_id
            input_ids, targets, loss_mask, _move_ids, _gl, _tc = \
                engine.generate_clm_batch(
                    self.batch_size, self.max_ply, seed,
                    discard_ply_limit=self.discard_ply_limit,
                )
            yield {
                "input_ids": torch.from_numpy(input_ids).long(),
                "targets": torch.from_numpy(targets).long(),
                "loss_mask": torch.from_numpy(loss_mask),
            }
            step += 1
|
|
|
|
def create_validation_set(
    n_games: int, max_ply: int, seed: int,
    discard_ply_limit: bool = False,
) -> dict[str, torch.Tensor]:
    """Generate a fixed validation set.

    Also computes legal move masks for legal move rate evaluation.

    Args:
        max_ply: total CLM sequence length (256).
    """
    (input_ids, targets, loss_mask,
     move_ids, game_lengths, _tc) = engine.generate_clm_batch(
        n_games, max_ply, seed, discard_ply_limit=discard_ply_limit,
    )

    # Per-position legal-move grids, used to measure the model's legal
    # move rate during evaluation (promotion masks are unused here).
    legal_grid, _legal_promo = engine.compute_legal_move_masks(
        move_ids, game_lengths
    )

    return {
        "input_ids": torch.from_numpy(input_ids).long(),
        "targets": torch.from_numpy(targets).long(),
        "loss_mask": torch.from_numpy(loss_mask),
        "legal_grid": torch.from_numpy(legal_grid).long(),
        "game_lengths": torch.from_numpy(game_lengths).long(),
    }
|
|