| """Lichess data preparation for FiLM behavioral cloning. |
| |
| Parses a PGN file, tokenizes via the Rust engine, and produces PyTorch |
| tensors ready for training. Legal move grids are computed per-batch during |
| training (not precomputed) to keep memory independent of dataset size. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import torch.utils.data |
|
|
| import chess_engine as engine |
|
|
| from pawn.config import ( |
| WHITE_CHECKMATES, |
| BLACK_CHECKMATES, |
| DRAW_BY_RULE, |
| PLY_LIMIT, |
| ) |
|
|
|
|
| |
| |
| |
|
|
# PGN result string -> outcome category consumed by _result_to_outcome().
# Any other string (e.g. "*" for unterminated games) maps to None via .get().
_RESULT_MAP = {
    "1-0": "white",
    "0-1": "black",
    "1/2-1/2": "draw",
}
|
|
|
|
def _result_to_outcome(results: list[str]) -> torch.Tensor:
    """Translate PGN result strings into outcome token IDs.

    Decisive games get the corresponding checkmate token even when the
    real termination was resignation or timeout — the move prefix is
    still valid strategic play, so the approximation is acceptable per
    the spec (§3.4). Unrecognized results fall back to PLY_LIMIT.
    """
    # Outcome category -> token ID; anything _RESULT_MAP doesn't know
    # keeps the PLY_LIMIT fallback written by torch.full below.
    token_for = {
        "white": WHITE_CHECKMATES,
        "black": BLACK_CHECKMATES,
        "draw": DRAW_BY_RULE,
    }
    outcomes = torch.full((len(results),), PLY_LIMIT, dtype=torch.long)
    for i, raw_result in enumerate(results):
        category = _RESULT_MAP.get(raw_result)
        if category is not None:
            outcomes[i] = token_for[category]
    return outcomes
|
|
|
|
| |
| |
| |
|
|
|
|
def compute_legal_indices(
    move_ids: np.ndarray,
    game_lengths: np.ndarray,
    seq_len: int,
    vocab_size: int = 4278,
) -> np.ndarray:
    """Return flat sparse indices for legal token masks (CPU only).

    Normalizes the inputs to contiguous int16 arrays and delegates to the
    Rust engine, which replays each game and emits flat i64 indices for
    scattering into a (B, seq_len, vocab_size) bool mask.
    """
    ids_i16 = np.ascontiguousarray(move_ids, dtype=np.int16)
    lengths_i16 = np.asarray(game_lengths, dtype=np.int16)
    return engine.compute_legal_token_masks_sparse(
        ids_i16, lengths_i16, seq_len, vocab_size,
    )
|
|
|
|
class LegalMaskBuilder:
    """Builds (B, T, V) legal-token masks by scattering sparse indices.

    The Rust engine (engine.compute_legal_token_masks_sparse) replays
    games and emits flat i64 indices — a few MB instead of the ~tens of
    MB a dense bool mask would take. The indices are staged on ``device``
    and scattered into a persistent boolean buffer.

    Two entry points:
      * ``scatter(indices, B)`` — device-only scatter of indices computed
        elsewhere (e.g. by ``LegalMaskCollate`` in DataLoader workers).
      * ``__call__(batch)`` — legacy convenience path that runs the Rust
        replay inline before scattering.
    """

    def __init__(self, batch_size: int, max_ply: int, vocab_size: int = 4278,
                 device: str = "cpu", max_index_buf: int = 4_000_000):
        self.vocab_size = vocab_size
        self.max_ply = max_ply
        self.T = max_ply + 1
        self.device = device
        # Persistent output buffer — avoids reallocating B*T*V bools per batch.
        self._mask_gpu = torch.zeros(batch_size, self.T, vocab_size,
                                     dtype=torch.bool, device=device)
        # Persistent staging buffer for the index transfer to `device`.
        self._idx_buf = torch.empty(max_index_buf, dtype=torch.long, device=device)

    def scatter(self, legal_indices: torch.Tensor, B: int) -> torch.Tensor:
        """Scatter pre-computed CPU indices into the persistent mask buffer.

        Reuses the pre-allocated index buffer when it is large enough;
        otherwise falls back to a one-off transfer of ``legal_indices``.
        Returns a (B, T, V) bool view into the persistent buffer.
        """
        capacity = self._mask_gpu.shape[0]
        if B > capacity:
            raise ValueError(
                f"B={B} exceeds pre-allocated batch_size={self._mask_gpu.shape[0]}"
            )
        mask_rows = self._mask_gpu[:B]
        mask_rows.zero_()
        count = legal_indices.shape[0]
        if count == 0:
            return mask_rows
        flat = mask_rows.view(-1)
        if count <= self._idx_buf.shape[0]:
            staged = self._idx_buf[:count]
            staged.copy_(legal_indices)
            flat.index_fill_(0, staged, True)
        else:
            # Index count exceeds the staging buffer: allocate once for
            # this oversized batch instead of failing.
            flat.index_fill_(0, legal_indices.to(self.device), True)
        return mask_rows

    def __call__(self, batch: dict) -> torch.Tensor:
        """Build a (B, T, V) legal mask from batch move_ids + game_lengths.

        Computes sparse indices via the Rust engine on the calling thread,
        then scatters into the persistent buffer. Prefer ``LegalMaskCollate``
        plus ``scatter()`` to keep the replay off the training critical path.
        """
        raw_moves = batch["move_ids"]
        raw_lengths = batch["game_length"]
        B = raw_moves.shape[0] if hasattr(raw_moves, 'shape') else len(raw_moves)
        if isinstance(raw_moves, torch.Tensor):
            raw_moves = raw_moves.numpy()
        moves_i16 = np.ascontiguousarray(raw_moves, dtype=np.int16)
        lengths_i16 = np.asarray(raw_lengths, dtype=np.int16)
        sparse = engine.compute_legal_token_masks_sparse(
            moves_i16, lengths_i16, self.T, self.vocab_size,
        )
        return self.scatter(torch.from_numpy(sparse), B)
|
|
|
|
class LegalMaskCollate:
    """DataLoader collate_fn that attaches sparse legal-mask indices.

    Performs default collation, then runs the Rust replay (through
    ``compute_legal_indices``) inside the worker process and stores the
    resulting flat CPU index tensor under ``legal_indices`` — keeping the
    expensive replay off the GPU training critical path.
    """

    def __init__(self, seq_len: int, vocab_size: int = 4278):
        self.seq_len = seq_len
        self.vocab_size = vocab_size

    def __call__(self, items: list[dict]) -> dict:
        """Collate ``items`` and append a CPU ``legal_indices`` tensor."""
        collated = torch.utils.data.default_collate(items)
        sparse = compute_legal_indices(
            collated["move_ids"].numpy(),
            np.asarray(collated["game_length"], dtype=np.int16),
            self.seq_len,
            self.vocab_size,
        )
        collated["legal_indices"] = torch.from_numpy(sparse)
        return collated
|
|
|
|
| |
| |
| |
|
|
|
|
def prepare_lichess_dataset(
    pgn_path: str | Path,
    max_ply: int = 255,
    max_games: int = 50_000,
    min_ply: int = 10,
) -> dict:
    """Parse a PGN or Parquet file and produce training-ready tensors.

    If pgn_path ends with .parquet, delegates to prepare_lichess_parquet().
    If pgn_path looks like a HuggingFace repo (contains '/'), loads from HF.

    Returns dict with:
        move_ids: (N, max_ply) int16 — tokenized moves
        game_lengths: (N,) int16
        input_ids: (N, seq_len) long — [outcome, move_0, ..., PAD]
        targets: (N, seq_len) long — shifted left
        loss_mask: (N, seq_len) bool
        outcome_tokens: (N,) long — token IDs from _result_to_outcome()
        n_games: int
    """
    pgn_path_str = str(pgn_path)
    if pgn_path_str.endswith(".parquet"):
        return prepare_lichess_parquet(
            parquet_path=pgn_path_str, max_ply=max_ply,
            max_games=max_games, min_ply=min_ply,
        )
    # A string containing '/' that does not exist on disk is treated as a
    # HuggingFace dataset repo id.
    if "/" in pgn_path_str and not Path(pgn_path_str).exists():
        return prepare_lichess_parquet(
            hf_repo=pgn_path_str, max_ply=max_ply,
            max_games=max_games, min_ply=min_ply,
        )
    pgn_path = Path(pgn_path)

    # Parse with min_ply=1 so the Rust parser keeps every game; the real
    # min_ply filter is applied below, after results have been aligned.
    print(f"Parsing PGN: {pgn_path}")
    move_ids, game_lengths, n_parsed = engine.parse_pgn_file(
        str(pgn_path), max_ply=max_ply, max_games=max_games, min_ply=1,
    )
    N = move_ids.shape[0]
    print(f" Parsed {n_parsed} PGN games, {N} tokenized")

    # No-op slices (N == move_ids.shape[0]); kept as a defensive bound.
    move_ids = move_ids[:N]
    game_lengths = game_lengths[:N]

    # NOTE(review): this assumes the first N of the n_parsed games are
    # exactly the ones that tokenized; if parse_pgn_file drops games
    # mid-stream, results would misalign with move_ids — confirm against
    # the engine's contract.
    results = _extract_results(pgn_path, n_parsed)[:N]

    # Drop too-short games, keeping move arrays and results in lockstep.
    if min_ply > 1:
        keep = game_lengths >= min_ply
        move_ids = move_ids[keep]
        game_lengths = game_lengths[keep]
        results = [r for r, k in zip(results, keep) if k]
        N = len(results)
        print(f" After min_ply={min_ply} filter: {N} games")

    outcome_tokens = _result_to_outcome(results)

    # +1 for the leading outcome-conditioning token in each sequence.
    seq_len = max_ply + 1

    # Local import — presumably to avoid a circular import at module load;
    # verify against pawn.data's import graph.
    from pawn.data import pack_clm_sequences
    batch = pack_clm_sequences(move_ids, game_lengths, outcome_tokens, seq_len)

    return {
        "move_ids": move_ids,
        "game_lengths": game_lengths,
        "input_ids": batch["input_ids"],
        "targets": batch["targets"],
        "loss_mask": batch["loss_mask"],
        "outcome_tokens": outcome_tokens,
        "n_games": N,
    }
|
|
|
|
def prepare_lichess_parquet(
    parquet_path: str | Path | None = None,
    hf_repo: str | None = None,
    max_ply: int = 255,
    max_games: int = 50_000,
    min_ply: int = 10,
) -> dict:
    """Load a Lichess Parquet dataset and produce training-ready tensors.

    Reads from a local Parquet file or a HuggingFace dataset repo.
    Expects columns: pgn (SAN move text), result (1-0/0-1/1/2-1/2).

    Returns the same dict format as prepare_lichess_dataset().

    Raises:
        ValueError: if neither parquet_path nor hf_repo is given.
    """
    # Imported lazily: polars is only needed on this code path.
    import polars as pl

    if hf_repo is not None:
        # Download every .parquet file in the dataset repo, then scan
        # them lazily as one frame.
        from huggingface_hub import hf_hub_download, HfApi
        api = HfApi()
        files = api.list_repo_files(hf_repo, repo_type="dataset")
        parquet_files = [f for f in files if f.endswith(".parquet")]
        local_files = [hf_hub_download(hf_repo, pf, repo_type="dataset")
                       for pf in parquet_files]
        lf = pl.scan_parquet(local_files)
    elif parquet_path is not None:
        lf = pl.scan_parquet(str(parquet_path))
    else:
        raise ValueError("Either parquet_path or hf_repo must be provided")

    # head() on the lazy frame caps the row count before materializing.
    df = (
        lf.select(["pgn", "result"])
        .head(max_games)
        .collect()
    )
    n_to_use = len(df)
    print(f"Loaded {n_to_use} games from Parquet")

    pgn_strings = df["pgn"].to_list()
    results = df["result"].to_list()

    # Split each PGN movetext into bare SAN tokens: strip {comments},
    # stop at the result marker, drop move-number tokens.
    import re
    games: list[list[str]] = []
    for pgn_text in pgn_strings:
        # Remove brace comments ({...}); does not handle nested braces.
        cleaned = re.sub(r'\{[^}]*\}', '', pgn_text)
        tokens = cleaned.split()
        moves = []
        for tok in tokens:
            if tok in ("1-0", "0-1", "1/2-1/2", "*"):
                break
            # Move numbers appear as "1." or "1..." — strip trailing dots
            # and skip pure digits. NOTE(review): glued forms like "1.e4"
            # would pass through unfiltered; confirm upstream data always
            # puts a space after the move number.
            stripped = tok.rstrip(".")
            if stripped and stripped.replace(".", "").isdigit():
                continue
            if not tok:  # defensive only — str.split() never yields ""
                continue
            moves.append(tok)
        games.append(moves)

    # Tokenize the SAN move lists via the Rust engine.
    print(f" Tokenizing {len(games)} games...")
    move_ids, game_lengths = engine.pgn_to_tokens(games, max_ply=max_ply)
    N = move_ids.shape[0]

    # Drop too-short games, keeping arrays and results in lockstep.
    # NOTE(review): assumes pgn_to_tokens preserves game order and count
    # relative to `results`; if it drops games, zip() silently truncates.
    if min_ply > 1:
        keep = game_lengths >= min_ply
        move_ids = move_ids[keep]
        game_lengths = game_lengths[keep]
        results = [r for r, k in zip(results, keep) if k]
        N = len(results)
        print(f" After min_ply={min_ply} filter: {N} games")

    outcome_tokens = _result_to_outcome(results)

    # +1 for the leading outcome-conditioning token in each sequence.
    seq_len = max_ply + 1
    from pawn.data import pack_clm_sequences
    batch = pack_clm_sequences(move_ids, game_lengths, outcome_tokens, seq_len)

    return {
        "move_ids": move_ids,
        "game_lengths": game_lengths,
        "input_ids": batch["input_ids"],
        "targets": batch["targets"],
        "loss_mask": batch["loss_mask"],
        "outcome_tokens": outcome_tokens,
        "n_games": N,
    }
|
|
|
|
| def _extract_results(pgn_path: Path, max_games: int) -> list[str]: |
| """Extract game results from PGN headers. |
| |
| Uses [Event header to delimit games, matching the Rust parser's |
| game-boundary detection. The previous approach (one result per |
| [Result] header) could miscount when headers were malformed. |
| """ |
| import re |
| results: list[str] = [] |
| current_result = "*" |
| in_game = False |
|
|
| with open(pgn_path) as f: |
| for line in f: |
| line = line.strip() |
| if line.startswith("[Event "): |
| if in_game: |
| results.append(current_result) |
| if len(results) >= max_games: |
| break |
| current_result = "*" |
| in_game = True |
| elif line.startswith('[Result "'): |
| m = re.search(r'"([^"]+)"', line) |
| if m: |
| current_result = m.group(1) |
| |
| if in_game and len(results) < max_games: |
| results.append(current_result) |
|
|
| return results |
|
|
|
|
| |
| |
| |
|
|
|
|
| class LichessDataset(torch.utils.data.Dataset): |
| """Map-style dataset for Lichess behavioral cloning.""" |
|
|
| def __init__(self, data: dict, start: int = 0, end: int | None = None): |
| end = end or data["n_games"] |
| self.input_ids = data["input_ids"][start:end] |
| self.targets = data["targets"][start:end] |
| self.loss_mask = data["loss_mask"][start:end] |
| self.move_ids = data["move_ids"][start:end] |
| self.game_lengths = data["game_lengths"][start:end] |
|
|
| def share_memory(self): |
| """Move tensors to shared memory so spawn workers avoid copies.""" |
| self.input_ids = self.input_ids.share_memory_() |
| self.targets = self.targets.share_memory_() |
| self.loss_mask = self.loss_mask.share_memory_() |
| self.move_ids = torch.from_numpy(np.array(self.move_ids)).share_memory_() |
| self.game_lengths = torch.from_numpy(np.array(self.game_lengths)).share_memory_() |
| return self |
|
|
| def __len__(self) -> int: |
| return len(self.input_ids) |
|
|
| def __getitem__(self, idx: int) -> dict[str, torch.Tensor | int]: |
| return { |
| "input_ids": self.input_ids[idx], |
| "targets": self.targets[idx], |
| "loss_mask": self.loss_mask[idx], |
| "move_ids": self.move_ids[idx], |
| "game_length": int(self.game_lengths[idx]), |
| } |
|
|