""" Custom Chess Tokenizer for the Chess Challenge. This tokenizer treats each move as a single token using the extended UCI notation from the Lichess dataset (e.g., WPe2e4, BNg8f6). The dataset format uses: - W/B prefix for White/Black - Piece letter: P=Pawn, N=Knight, B=Bishop, R=Rook, Q=Queen, K=King - Source and destination squares (e.g., e2e4) - Special suffixes: (x)=capture, (+)=check, (+*)=checkmate, (o)/(O)=castling """ from __future__ import annotations import json import os from pathlib import Path from typing import Dict, List, Optional from transformers import PreTrainedTokenizer class ChessTokenizer(PreTrainedTokenizer): """ A custom tokenizer for chess moves using extended UCI notation. This tokenizer maps each possible chess move to a unique token ID. The vocabulary is built from the training dataset to ensure all moves encountered during training have a corresponding token. Example: >>> tokenizer = ChessTokenizer() >>> tokenizer.encode("WPe2e4 BPe7e5") [1, 42, 87, 2] # [BOS, e2e4, e7e5, EOS] """ model_input_names = ["input_ids", "attention_mask"] vocab_files_names = {"vocab_file": "vocab.json"} # Special tokens PAD_TOKEN = "[PAD]" BOS_TOKEN = "[BOS]" EOS_TOKEN = "[EOS]" UNK_TOKEN = "[UNK]" EOM_TOKEN = "[EOM]" # End of Move - marks boundary between moves def __init__( self, vocab_file: Optional[str] = None, vocab: Optional[Dict[str, int]] = None, component_mode: bool = False, **kwargs, ): """ Initialize the chess tokenizer. Args: vocab_file: Path to a JSON file containing the vocabulary mapping. vocab: Dictionary mapping tokens to IDs (alternative to vocab_file). component_mode: If True, tokenize moves into components (WP, e2, e4). **kwargs: Additional arguments passed to PreTrainedTokenizer. """ # Initialize special tokens self._pad_token = self.PAD_TOKEN self._bos_token = self.BOS_TOKEN self._eos_token = self.EOS_TOKEN self._unk_token = self.UNK_TOKEN self._eom_token = self.EOM_TOKEN # Component mode flag (for splitting moves into parts) self._component_mode = component_mode # Remove any duplicate special-token entries passed through kwargs # to avoid "multiple values for keyword" errors when loading from disk. kwargs.pop("pad_token", None) kwargs.pop("bos_token", None) kwargs.pop("eos_token", None) kwargs.pop("unk_token", None) kwargs.pop("eom_token", None) kwargs.pop("component_mode", None) # Load or create vocabulary if vocab is not None: self._vocab = vocab elif vocab_file is not None and os.path.exists(vocab_file): with open(vocab_file, "r", encoding="utf-8") as f: self._vocab = json.load(f) else: # Create a minimal vocabulary with just special tokens # The full vocabulary should be built from the dataset self._vocab = self._create_default_vocab() # Create reverse mapping self._ids_to_tokens = {v: k for k, v in self._vocab.items()} # Call parent init AFTER setting up vocab super().__init__( pad_token=self._pad_token, bos_token=self._bos_token, eos_token=self._eos_token, unk_token=self._unk_token, component_mode=component_mode, # This gets saved to tokenizer_config.json **kwargs, ) # Store EOM token ID for easy access self.eom_token_id = self._vocab.get(self.EOM_TOKEN, -1) def _create_default_vocab(self) -> Dict[str, int]: """ Create a minimal default vocabulary with just special tokens. For the full vocabulary, use `build_vocab_from_dataset()`. This minimal vocab is just a placeholder - you should build from data. """ special_tokens = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN] vocab = {token: idx for idx, token in enumerate(special_tokens)} return vocab @classmethod def build_vocab_from_iterator( cls, iterator, min_frequency: int = 1, ) -> "ChessTokenizer": """ Build a tokenizer vocabulary from an iterator of game strings. Args: iterator: An iterator yielding game strings (space-separated moves). min_frequency: Minimum frequency for a token to be included. Returns: A ChessTokenizer with the built vocabulary. """ from collections import Counter token_counts = Counter() for game in iterator: moves = game.strip().split() token_counts.update(moves) # Filter by frequency tokens = [ token for token, count in token_counts.items() if count >= min_frequency ] # Sort for reproducibility tokens = sorted(tokens) # Build vocabulary special_tokens = [cls.PAD_TOKEN, cls.BOS_TOKEN, cls.EOS_TOKEN, cls.UNK_TOKEN] vocab = {token: idx for idx, token in enumerate(special_tokens + tokens)} return cls(vocab=vocab) @classmethod def build_vocab_from_dataset( cls, dataset_name: str = "dlouapre/lichess_2025-01_1M", split: str = "train", column: str = "text", min_frequency: int = 500, max_samples: Optional[int] = 100000, ) -> "ChessTokenizer": """ Build a tokenizer vocabulary from a Hugging Face dataset. Args: dataset_name: Name of the dataset on Hugging Face Hub. split: Dataset split to use. column: Column containing the game strings. min_frequency: Minimum frequency for a token to be included (default: 500). max_samples: Maximum number of samples to process (default: 100k). Returns: A ChessTokenizer with the built vocabulary. """ from datasets import load_dataset dataset = load_dataset(dataset_name, split=split) if max_samples is not None: dataset = dataset.select(range(min(max_samples, len(dataset)))) def game_iterator(): for example in dataset: yield example[column] return cls.build_vocab_from_iterator(game_iterator(), min_frequency=min_frequency) @classmethod def build_vocab_more_detailed( cls, ) -> "ChessTokenizer": """ Build a component-based tokenizer for chess moves. Instead of one token per move (WPe2e4), splits into components: WPe2e4 -> [WP, e2, e4] BNg8f6(x) -> [BN, g8, f6, (x)] This gives ~90 tokens instead of ~1200, with better generalization. Returns: A ChessTokenizer with component vocabulary. """ # Combined color+piece tokens (avoids B collision between Black and Bishop) tokens_pieces = [ "WP", "WN", "WB", "WR", "WQ", "WK", # White pieces "BP", "BN", "BB", "BR", "BQ", "BK", # Black pieces ] # the positions: files = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'] ranks = ['1', '2', '3', '4', '5', '6', '7', '8'] tokens_positions = [f + r for f in files for r in ranks] # the special suffixes: tokens_suffixes = [ "(x)", # capture "(+)", # check "(x+)", # capture + check "(+*)", # checkmate "(x+*)", # capture + checkmate "(o)", # kingside castling "(O)", # queenside castling "(xE)", # en passant "=Q", # promotion to queen "=R", # promotion to rook "=B", # promotion to bishop "=N", # promotion to knight ] # Combine all tokens tokens = tokens_pieces + tokens_positions + tokens_suffixes # Build vocabulary with [EOM] for move boundaries # [EOM] helps the model understand when a move ends special_tokens = [cls.PAD_TOKEN, cls.BOS_TOKEN, cls.EOS_TOKEN, cls.UNK_TOKEN, cls.EOM_TOKEN] vocab = {token: idx for idx, token in enumerate(special_tokens + tokens)} for ind, token in enumerate(special_tokens+tokens): print(f"Token {ind}: {token}") # Pass component_mode=True so it gets saved to tokenizer_config.json return cls(vocab=vocab, component_mode=True) @property def vocab_size(self) -> int: """Return the size of the vocabulary.""" return len(self._vocab) def get_vocab(self) -> Dict[str, int]: """Return the vocabulary as a dictionary.""" return dict(self._vocab) def _tokenize(self, text: str) -> List[str]: """ Tokenize a string of moves into a list of tokens. If component_mode is enabled, splits each move into parts: WPe2e4 -> [W, P, e2, e4, " "] BNg8f6(x) -> [B, N, g8, f6, (x), " "] Args: text: A string of space-separated moves. Returns: List of tokens. """ if getattr(self, '_component_mode', False): return self._tokenize_components(text) return text.strip().split() def _tokenize_components(self, text: str) -> List[str]: """ Tokenize moves into component parts with [EOM] boundaries. Move format: [Color][Piece][from_square][to_square][suffix] [EOM] Example: WPe2e4 -> [WP, e2, e4, EOM] BNg8f6(x) -> [BN, g8, f6, (x), EOM] """ import re tokens = [] moves = text.strip().split() for i, move in enumerate(moves): # Skip special tokens if move in [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN, self.EOM_TOKEN]: tokens.append(move) continue # Parse move: ColorPiece + from_square + to_square + optional suffix # Pattern: (W|B)(P|N|B|R|Q|K)([a-h][1-8])([a-h][1-8])(suffix)? pattern = r'^([WB])([PNBRQK])([a-h][1-8])([a-h][1-8])(.*)$' match = re.match(pattern, move) if match: color, piece, from_sq, to_sq, suffix = match.groups() # Combined color+piece token (e.g., "WP", "BN", "BB") tokens.append(color + piece) tokens.extend([from_sq, to_sq]) # Handle suffix (could be combination like "(x+)" or "=Q") if suffix: # Try to match known suffixes suffix_pattern = r'(\(x\+\*\)|\(x\+\)|\(\+\*\)|\(xE\)|\(x\)|\(\+\)|\(o\)|\(O\)|=Q|=R|=B|=N)' suffix_matches = re.findall(suffix_pattern, suffix) tokens.extend(suffix_matches) # Add [EOM] to mark end of this move tokens.append(self.EOM_TOKEN) else: # Fallback: add as unknown + EOM tokens.append(self.UNK_TOKEN) tokens.append(self.EOM_TOKEN) return tokens def _convert_token_to_id(self, token: str) -> int: """Convert a token to its ID.""" return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 0)) def _convert_id_to_token(self, index: int) -> str: """Convert an ID to its token.""" token = self._ids_to_tokens.get(index, self.UNK_TOKEN) # Convert [EOM] to whitespace for evaluator compatibility # This makes _generate_until_whitespace stop after one move if token == self.EOM_TOKEN: return " " return token # Color+piece tokens that mark the start of a new move _MOVE_START_TOKENS = {"WP", "WN", "WB", "WR", "WQ", "WK", "BP", "BN", "BB", "BR", "BQ", "BK"} def convert_tokens_to_string(self, tokens: List[str]) -> str: """Convert a list of tokens back to a string. In component mode, reconstructs moves by replacing [EOM] with spaces. CRITICAL: [EOM] must decode to a non-empty whitespace string so that the evaluator's _generate_until_whitespace stops after one move. """ # Filter out special tokens except EOM for cleaner output special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN} if getattr(self, '_component_mode', False): # Reconstruct with [EOM] as space delimiter result = [] for token in tokens: if token == self.EOM_TOKEN: # MUST be non-empty whitespace for evaluator result.append(" ") elif token not in special: result.append(token) # Don't strip! We need the trailing space from [EOM] return "".join(result) # Non-component mode: just join with spaces filtered = [t for t in tokens if t not in special] return " ".join(filtered) # ========================================================================= # Structured Generation Support Methods # ========================================================================= def get_token_category(self, token: str) -> str: """Categorize a token into: piece, square, suffix, eom, or special. Args: token: Token string to categorize. Returns: Category name: 'piece', 'square', 'suffix', 'eom', or 'special'. """ if token in [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]: return 'special' if token == self.EOM_TOKEN: return 'eom' if self.is_piece_token(token): return 'piece' if self.is_square_token(token): return 'square' if self.is_suffix_token(token): return 'suffix' return 'unknown' def is_piece_token(self, token: str) -> bool: """Check if token is a piece token (WP, BN, etc.).""" return token in ['WP', 'WN', 'WB', 'WR', 'WQ', 'WK', 'BP', 'BN', 'BB', 'BR', 'BQ', 'BK'] def is_square_token(self, token: str) -> bool: """Check if token is a square token (e2, g8, etc.).""" if len(token) != 2: return False return token[0] in 'abcdefgh' and token[1] in '12345678' def is_suffix_token(self, token: str) -> bool: """Check if token is a suffix token ((x), (+), =Q, etc.).""" return token in ['(x)', '(+)', '(x+)', '(+*)', '(x+*)', '(o)', '(O)', '(xE)', '=Q', '=R', '=B', '=N'] def is_eom_token(self, token: str) -> bool: """Check if token is the [EOM] token.""" return token == self.EOM_TOKEN def get_token_color(self, token: str) -> Optional[str]: """Get the color ('W' or 'B') from a piece token, None otherwise.""" if self.is_piece_token(token) and len(token) >= 2: return token[0] # 'W' or 'B' return None def build_vocabulary_masks(self) -> dict: """Build boolean masks for each token category. Returns: Dictionary with keys: 'piece', 'square', 'suffix', 'eom', 'white_piece', 'black_piece'. Each value is a boolean list/tensor of length vocab_size. """ import torch vocab_size = len(self._vocab) masks = { 'piece': [False] * vocab_size, 'square': [False] * vocab_size, 'suffix': [False] * vocab_size, 'eom': [False] * vocab_size, 'white_piece': [False] * vocab_size, 'black_piece': [False] * vocab_size, } for token, token_id in self._vocab.items(): if self.is_piece_token(token): masks['piece'][token_id] = True color = self.get_token_color(token) if color == 'W': masks['white_piece'][token_id] = True elif color == 'B': masks['black_piece'][token_id] = True elif self.is_square_token(token): masks['square'][token_id] = True elif self.is_suffix_token(token): masks['suffix'][token_id] = True elif self.is_eom_token(token): masks['eom'][token_id] = True # Convert to tensors return {k: torch.tensor(v, dtype=torch.bool) for k, v in masks.items()} def analyze_generation_state(self, input_ids: torch.Tensor) -> dict: """Analyze the current generation state to determine next expected token. Args: input_ids: Tensor of shape (batch_size, seq_len) with token IDs. Returns: Dictionary with: - 'position': 0 (piece), 1 (from_square), 2 (to_square), 3 (suffix/eom) - 'expected_color': 'W' or 'B' - 'last_eom_idx': Index of last [EOM] token in sequence """ batch_size = input_ids.shape[0] results = [] for b in range(batch_size): seq = input_ids[b].tolist() # Find last [EOM] or [BOS] last_eom_idx = -1 for i in range(len(seq) - 1, -1, -1): token = self._ids_to_tokens.get(seq[i], self.UNK_TOKEN) if token in [self.EOM_TOKEN, self.BOS_TOKEN]: last_eom_idx = i break # Count tokens since last [EOM]/[BOS] (excluding padding) tokens_since_boundary = [] for i in range(last_eom_idx + 1, len(seq)): token = self._ids_to_tokens.get(seq[i], self.UNK_TOKEN) if token != self.PAD_TOKEN: tokens_since_boundary.append(token) # Determine position in move structure: [Piece][Square][Square][Suffix?][EOM] num_tokens = len(tokens_since_boundary) if num_tokens == 0: position = 0 # Expect piece elif num_tokens == 1: position = 1 # Expect from_square elif num_tokens == 2: position = 2 # Expect to_square else: position = 3 # Expect suffix or [EOM] # Determine expected color by counting complete moves # Count [EOM] tokens to get move number eom_count = sum(1 for i in seq if self._ids_to_tokens.get(i, '') == self.EOM_TOKEN) expected_color = 'W' if eom_count % 2 == 0 else 'B' results.append({ 'position': position, 'expected_color': expected_color, 'last_eom_idx': last_eom_idx, }) # For single batch, return dict directly; for multi-batch, return list return results[0] if batch_size == 1 else results def save_vocabulary( self, save_directory: str, filename_prefix: Optional[str] = None, ) -> tuple: """ Save the vocabulary to a JSON file. Args: save_directory: Directory to save the vocabulary. filename_prefix: Optional prefix for the filename. Returns: Tuple containing the path to the saved vocabulary file. """ if not os.path.isdir(save_directory): os.makedirs(save_directory, exist_ok=True) vocab_file = os.path.join( save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.json", ) with open(vocab_file, "w", encoding="utf-8") as f: json.dump(self._vocab, f, ensure_ascii=False, indent=2) return (vocab_file,) def count_vocab_from_dataset( dataset_name: str = "dlouapre/lichess_2025-01_1M", split: str = "train", column: str = "text", max_samples: Optional[int] = 10000, ) -> Dict[str, int]: """ Count token frequencies in a dataset (useful for vocabulary analysis). Args: dataset_name: Name of the dataset on Hugging Face Hub. split: Dataset split to use. column: Column containing the game strings. max_samples: Maximum number of samples to process. Returns: Dictionary mapping tokens to their frequencies. """ from collections import Counter from datasets import load_dataset dataset = load_dataset(dataset_name, split=split) if max_samples is not None: dataset = dataset.select(range(min(max_samples, len(dataset)))) token_counts = Counter() for example in dataset: moves = example[column].strip().split() token_counts.update(moves) return dict(token_counts)