"""Lichess data preparation for FiLM behavioral cloning. Parses a PGN file, tokenizes via the Rust engine, and produces PyTorch tensors ready for training. Legal move grids are computed per-batch during training (not precomputed) to keep memory independent of dataset size. """ from __future__ import annotations from pathlib import Path import numpy as np import torch import torch.utils.data import chess_engine as engine from pawn.config import ( WHITE_CHECKMATES, BLACK_CHECKMATES, DRAW_BY_RULE, PLY_LIMIT, ) # --------------------------------------------------------------------------- # PGN result → outcome token # --------------------------------------------------------------------------- _RESULT_MAP = { "1-0": "white", "0-1": "black", "1/2-1/2": "draw", } def _result_to_outcome(results: list[str]) -> torch.Tensor: """Map PGN result strings to outcome token IDs. For decisive games we use the checkmate token even though the actual termination was likely resignation/time — the prefix of moves is still valid strategic play and the outcome token approximation is acceptable per the spec (§3.4). """ outcomes = torch.full((len(results),), PLY_LIMIT, dtype=torch.long) for i, result in enumerate(results): mapped = _RESULT_MAP.get(result) if mapped == "white": outcomes[i] = WHITE_CHECKMATES elif mapped == "black": outcomes[i] = BLACK_CHECKMATES elif mapped == "draw": outcomes[i] = DRAW_BY_RULE return outcomes # --------------------------------------------------------------------------- # Legal token mask via fused Rust computation # --------------------------------------------------------------------------- def compute_legal_indices( move_ids: np.ndarray, game_lengths: np.ndarray, seq_len: int, vocab_size: int = 4278, ) -> np.ndarray: """Compute flat sparse indices for legal token masks (CPU only). Calls the Rust engine to replay games and returns flat i64 indices suitable for scattering into a (B, seq_len, vocab_size) bool mask. """ move_ids = np.ascontiguousarray(move_ids, dtype=np.int16) game_lengths = np.asarray(game_lengths, dtype=np.int16) return engine.compute_legal_token_masks_sparse( move_ids, game_lengths, seq_len, vocab_size, ) class LegalMaskBuilder: """Legal token mask via sparse Rust computation + GPU scatter. Calls engine.compute_legal_token_masks_sparse which replays games and returns flat i64 indices (~2 MB) instead of a dense bool mask (~70 MB). Indices are transferred to GPU and scattered into a pre-allocated buffer. Two usage modes: 1. ``scatter(indices, B)`` — fast GPU-only path for pre-computed indices (from ``LegalMaskCollate`` or precomputation). 2. ``__call__(batch)`` — legacy path that computes indices inline. """ def __init__(self, batch_size: int, max_ply: int, vocab_size: int = 4278, device: str = "cpu", max_index_buf: int = 4_000_000): self.vocab_size = vocab_size self.max_ply = max_ply self.T = max_ply + 1 # seq_len = outcome token + max_ply move slots self.device = device # Pre-allocated GPU output buffer self._mask_gpu = torch.zeros(batch_size, self.T, vocab_size, dtype=torch.bool, device=device) # Pre-allocated GPU index buffer to avoid per-batch allocation self._idx_buf = torch.empty(max_index_buf, dtype=torch.long, device=device) def scatter(self, legal_indices: torch.Tensor, B: int) -> torch.Tensor: """Scatter pre-computed CPU indices into the GPU mask buffer. Uses a pre-allocated index buffer to avoid per-batch GPU allocation. Falls back to a fresh allocation if the buffer is too small. """ if B > self._mask_gpu.shape[0]: raise ValueError( f"B={B} exceeds pre-allocated batch_size={self._mask_gpu.shape[0]}" ) mask_view = self._mask_gpu[:B] mask_view.zero_() n = legal_indices.shape[0] if n > 0: if n <= self._idx_buf.shape[0]: self._idx_buf[:n].copy_(legal_indices) mask_view.view(-1).index_fill_(0, self._idx_buf[:n], True) else: idx_gpu = legal_indices.to(self.device) mask_view.view(-1).index_fill_(0, idx_gpu, True) return mask_view def __call__(self, batch: dict) -> torch.Tensor: """Build (B, T, V) legal mask from batch move_ids + game_lengths. Computes sparse indices via Rust and scatters to the GPU buffer. For better performance, use ``LegalMaskCollate`` with DataLoader workers to compute indices off the critical path, then call ``scatter()`` directly. """ move_ids = batch["move_ids"] game_lengths_raw = batch["game_length"] B = move_ids.shape[0] if hasattr(move_ids, 'shape') else len(move_ids) if isinstance(move_ids, torch.Tensor): move_ids = move_ids.numpy() move_ids = np.ascontiguousarray(move_ids, dtype=np.int16) game_lengths = np.asarray(game_lengths_raw, dtype=np.int16) indices = engine.compute_legal_token_masks_sparse( move_ids, game_lengths, self.T, self.vocab_size, ) return self.scatter(torch.from_numpy(indices), B) class LegalMaskCollate: """Collate that computes legal mask indices in DataLoader workers. Wraps default collation and appends a ``legal_indices`` CPU tensor to each batch so the Rust replay runs in worker processes, off the GPU training critical path. """ def __init__(self, seq_len: int, vocab_size: int = 4278): self.seq_len = seq_len self.vocab_size = vocab_size def __call__(self, items: list[dict]) -> dict: batch = torch.utils.data.default_collate(items) move_ids = batch["move_ids"].numpy() game_lengths = np.asarray(batch["game_length"], dtype=np.int16) indices = compute_legal_indices( move_ids, game_lengths, self.seq_len, self.vocab_size, ) batch["legal_indices"] = torch.from_numpy(indices) return batch # --------------------------------------------------------------------------- # PGN → tokenized dataset with legal move masks # --------------------------------------------------------------------------- def prepare_lichess_dataset( pgn_path: str | Path, max_ply: int = 255, max_games: int = 50_000, min_ply: int = 10, ) -> dict: """Parse a PGN or Parquet file and produce training-ready tensors. If pgn_path ends with .parquet, delegates to prepare_lichess_parquet(). If pgn_path looks like a HuggingFace repo (contains '/'), loads from HF. Returns dict with: move_ids: (N, max_ply) int16 — tokenized moves game_lengths: (N,) int16 input_ids: (N, seq_len) long — [outcome, move_0, ..., PAD] targets: (N, seq_len) long — shifted left loss_mask: (N, seq_len) bool n_games: int """ pgn_path_str = str(pgn_path) if pgn_path_str.endswith(".parquet"): return prepare_lichess_parquet( parquet_path=pgn_path_str, max_ply=max_ply, max_games=max_games, min_ply=min_ply, ) # Check if it looks like a HF repo ID (e.g. "user/dataset") if "/" in pgn_path_str and not Path(pgn_path_str).exists(): return prepare_lichess_parquet( hf_repo=pgn_path_str, max_ply=max_ply, max_games=max_games, min_ply=min_ply, ) pgn_path = Path(pgn_path) # Parse with min_ply=1 so every parseable game appears in the output, # keeping result extraction aligned. We apply min_ply in Python below. print(f"Parsing PGN: {pgn_path}") move_ids, game_lengths, n_parsed = engine.parse_pgn_file( str(pgn_path), max_ply=max_ply, max_games=max_games, min_ply=1, ) N = move_ids.shape[0] print(f" Parsed {n_parsed} PGN games, {N} tokenized") move_ids = move_ids[:N] game_lengths = game_lengths[:N] # Extract results — aligned with engine output since min_ply=1 results = _extract_results(pgn_path, n_parsed)[:N] # Apply min_ply filter in Python on aligned arrays if min_ply > 1: keep = game_lengths >= min_ply move_ids = move_ids[keep] game_lengths = game_lengths[keep] results = [r for r, k in zip(results, keep) if k] N = len(results) print(f" After min_ply={min_ply} filter: {N} games") outcome_tokens = _result_to_outcome(results) seq_len = max_ply + 1 # outcome token + max_ply move slots from pawn.data import pack_clm_sequences batch = pack_clm_sequences(move_ids, game_lengths, outcome_tokens, seq_len) return { "move_ids": move_ids, "game_lengths": game_lengths, "input_ids": batch["input_ids"], "targets": batch["targets"], "loss_mask": batch["loss_mask"], "outcome_tokens": outcome_tokens, "n_games": N, } def prepare_lichess_parquet( parquet_path: str | Path = None, hf_repo: str = None, max_ply: int = 255, max_games: int = 50_000, min_ply: int = 10, ) -> dict: """Load a Lichess Parquet dataset and produce training-ready tensors. Reads from a local Parquet file or a HuggingFace dataset repo. Expects columns: pgn (SAN move text), result (1-0/0-1/1/2-1/2). Returns the same dict format as prepare_lichess_dataset(). """ import polars as pl if hf_repo is not None: from huggingface_hub import hf_hub_download, HfApi api = HfApi() files = api.list_repo_files(hf_repo, repo_type="dataset") parquet_files = [f for f in files if f.endswith(".parquet")] local_files = [hf_hub_download(hf_repo, pf, repo_type="dataset") for pf in parquet_files] lf = pl.scan_parquet(local_files) elif parquet_path is not None: lf = pl.scan_parquet(str(parquet_path)) else: raise ValueError("Either parquet_path or hf_repo must be provided") # Lazy: select only needed columns, limit rows, then collect df = ( lf.select(["pgn", "result"]) .head(max_games) .collect() ) n_to_use = len(df) print(f"Loaded {n_to_use} games from Parquet") pgn_strings = df["pgn"].to_list() results = df["result"].to_list() # Split PGN text into move lists, stripping comments, move numbers, results import re games: list[list[str]] = [] for pgn_text in pgn_strings: # Strip { ... } comments (clock annotations, etc.) cleaned = re.sub(r'\{[^}]*\}', '', pgn_text) tokens = cleaned.split() moves = [] for tok in tokens: if tok in ("1-0", "0-1", "1/2-1/2", "*"): break # Skip move numbers (1. 2. 12... etc.) stripped = tok.rstrip(".") if stripped and stripped.replace(".", "").isdigit(): continue if not tok: continue moves.append(tok) games.append(moves) # Tokenize via Rust engine (batch) print(f" Tokenizing {len(games)} games...") move_ids, game_lengths = engine.pgn_to_tokens(games, max_ply=max_ply) N = move_ids.shape[0] # Apply min_ply filter if min_ply > 1: keep = game_lengths >= min_ply move_ids = move_ids[keep] game_lengths = game_lengths[keep] results = [r for r, k in zip(results, keep) if k] N = len(results) print(f" After min_ply={min_ply} filter: {N} games") outcome_tokens = _result_to_outcome(results) seq_len = max_ply + 1 from pawn.data import pack_clm_sequences batch = pack_clm_sequences(move_ids, game_lengths, outcome_tokens, seq_len) return { "move_ids": move_ids, "game_lengths": game_lengths, "input_ids": batch["input_ids"], "targets": batch["targets"], "loss_mask": batch["loss_mask"], "outcome_tokens": outcome_tokens, "n_games": N, } def _extract_results(pgn_path: Path, max_games: int) -> list[str]: """Extract game results from PGN headers. Uses [Event header to delimit games, matching the Rust parser's game-boundary detection. The previous approach (one result per [Result] header) could miscount when headers were malformed. """ import re results: list[str] = [] current_result = "*" in_game = False with open(pgn_path) as f: for line in f: line = line.strip() if line.startswith("[Event "): if in_game: results.append(current_result) if len(results) >= max_games: break current_result = "*" in_game = True elif line.startswith('[Result "'): m = re.search(r'"([^"]+)"', line) if m: current_result = m.group(1) # Flush last game if in_game and len(results) < max_games: results.append(current_result) return results # --------------------------------------------------------------------------- # Dataset class # --------------------------------------------------------------------------- class LichessDataset(torch.utils.data.Dataset): """Map-style dataset for Lichess behavioral cloning.""" def __init__(self, data: dict, start: int = 0, end: int | None = None): end = end or data["n_games"] self.input_ids = data["input_ids"][start:end] self.targets = data["targets"][start:end] self.loss_mask = data["loss_mask"][start:end] self.move_ids = data["move_ids"][start:end] self.game_lengths = data["game_lengths"][start:end] def share_memory(self): """Move tensors to shared memory so spawn workers avoid copies.""" self.input_ids = self.input_ids.share_memory_() self.targets = self.targets.share_memory_() self.loss_mask = self.loss_mask.share_memory_() self.move_ids = torch.from_numpy(np.array(self.move_ids)).share_memory_() self.game_lengths = torch.from_numpy(np.array(self.game_lengths)).share_memory_() return self def __len__(self) -> int: return len(self.input_ids) def __getitem__(self, idx: int) -> dict[str, torch.Tensor | int]: return { "input_ids": self.input_ids[idx], "targets": self.targets[idx], "loss_mask": self.loss_mask[idx], "move_ids": self.move_ids[idx], "game_length": int(self.game_lengths[idx]), }