"""
Component-based Chess Tokenizer - Optimized for Parameter Efficiency.

This tokenizer decomposes chess moves into reusable components:
- Piece type (P, N, B, R, Q, K)
- Source square (a1-h8)
- Destination square (a1-h8)
- Modifiers (capture, check, castling, etc.)

Example:
    "WPe2e4" → ["P", "e2", "e4"]
    "BNg8f6(x)" → ["N", "g8", "f6", "(x)"]

The resulting vocabulary has 85 tokens (5 reserved + 6 pieces + 64 squares
+ 10 modifiers) instead of the ~1682 whole-move tokens of the original
tokenizer; the original author quotes a saving of about 205K embedding
parameters.
"""

from __future__ import annotations

import json
import os
from typing import Dict, List, Optional

from transformers import PreTrainedTokenizer


def _all_squares(files: str, ranks: str) -> List[str]:
    """Return every file+rank square name, file-major ("a1", "a2", ..., "h8")."""
    # Kept at module level on purpose: a nested comprehension written
    # directly in a class body cannot see class-level names such as RANKS.
    return [file_ + rank for file_ in files for rank in ranks]


class ComponentChessTokenizer(PreTrainedTokenizer):
    """
    Component-based tokenizer for chess moves.

    Decomposes moves into: [piece, from_square, to_square, modifiers...]

    Key advantages (as stated by the original author):
    - ~95% smaller vocabulary (1682 → 85 tokens)
    - Saves about 205K embedding parameters
    - Better generalization to rare move combinations
    - Compositional understanding of chess structure
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.json"}

    # Special tokens
    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"
    SEP_TOKEN = "[SEP]"  # Separates components within a move

    # Every reserved token, in vocabulary order (ids 0-4). Shared by
    # vocabulary creation, tokenization and detokenization so that all
    # three agree on what counts as "special".
    RESERVED_TOKENS = (PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN, SEP_TOKEN)

    # Chess piece types (6 tokens)
    PIECES = ["P", "N", "B", "R", "Q", "K"]

    # All squares on the board (64 tokens)
    FILES = "abcdefgh"
    RANKS = "12345678"
    # Derived once, when the class is defined, from FILES and RANKS.
    SQUARES = tuple(_all_squares(FILES, RANKS))
    # Hash-based lookup for validating squares without a linear scan.
    _SQUARE_SET = frozenset(SQUARES)

    # Move modifiers (10 tokens)
    MODIFIERS = [
        "(x)",     # capture
        "(+)",     # check
        "(+*)",    # checkmate
        "(o)",     # kingside castling
        "(O)",     # queenside castling
        "=Q",      # promotion to queen
        "=R",      # promotion to rook
        "=B",      # promotion to bishop
        "=N",      # promotion to knight
        "(e.p.)",  # en passant
    ]

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        vocab: Optional[Dict[str, int]] = None,
        **kwargs,
    ):
        """
        Initialize the component chess tokenizer.

        Args:
            vocab_file: Path to a JSON file mapping token -> id. Used only
                when ``vocab`` is None and the path exists.
            vocab: Ready-made token -> id mapping; takes precedence over
                ``vocab_file``.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``. Any
                ``pad_token`` / ``bos_token`` / ``eos_token`` / ``unk_token``
                entries are discarded in favour of the class constants.

        When neither ``vocab`` nor an existing ``vocab_file`` is supplied,
        the built-in component vocabulary is generated.
        """
        # Initialize special tokens
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN

        # Drop caller-supplied special tokens so they are not passed twice
        # to the parent constructor below.
        for name in ("pad_token", "bos_token", "eos_token", "unk_token"):
            kwargs.pop(name, None)

        # Load or create vocabulary
        if vocab is not None:
            self._vocab = vocab
        elif vocab_file is not None and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            # Also reached when vocab_file is given but does not exist.
            self._vocab = self._create_component_vocab()

        # Create reverse mapping
        self._ids_to_tokens = {idx: token for token, idx in self._vocab.items()}

        # The parent constructor runs last, once the vocabulary is in place.
        # NOTE(review): presumably the base class consults the vocabulary
        # during initialization - confirm against the transformers version
        # in use.
        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )

    def _create_component_vocab(self) -> Dict[str, int]:
        """
        Create the component vocabulary.

        Vocabulary structure:
        - Special tokens (5): [PAD], [BOS], [EOS], [UNK], [SEP]
        - Pieces (6): P, N, B, R, Q, K
        - Squares (64): a1, a2, ..., h8
        - Modifiers (10): (x), (+), (+*), (o), (O), =Q, =R, =B, =N, (e.p.)

        Total: 85 tokens (vs 1682 in original tokenizer)

        Returns:
            Mapping of token -> id, with ids assigned in the order above.
        """
        tokens = [
            *self.RESERVED_TOKENS,
            *self.PIECES,
            *self.SQUARES,
            *self.MODIFIERS,
        ]
        return {token: idx for idx, token in enumerate(tokens)}

    @classmethod
    def build_vocab(cls) -> "ComponentChessTokenizer":
        """
        Build tokenizer with component vocabulary.

        No dataset needed - the vocabulary is generated from the class
        constants alone.
        """
        return cls()

    @property
    def vocab_size(self) -> int:
        """Return the size of the vocabulary."""
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the vocabulary as a dictionary."""
        return dict(self._vocab)

    def _decompose_move(self, move: str) -> List[str]:
        """
        Decompose a move string into components.

        Examples:
            "WPe2e4" → ["P", "e2", "e4"]
            "BNg8f6(x)" → ["N", "g8", "f6", "(x)"]
            "WKe1g1(o)" → ["K", "e1", "g1", "(o)"]
            "BPe7e8=Q(+)" → ["P", "e7", "e8", "=Q", "(+)"]

        Args:
            move: Extended UCI move string (e.g., "WPe2e4")

        Returns:
            List of component tokens. A reserved token (including [SEP]) or
            an empty string is returned unchanged as a one-element list. A
            move with an unknown piece letter, fewer than four characters
            after the piece, or an invalid square yields ``[UNK_TOKEN]``.
        """
        if not move or move in self.RESERVED_TOKENS:
            return [move]

        # Remove color prefix (W/B)
        if move.startswith(("W", "B")):
            move = move[1:]

        if not move:
            return [self.UNK_TOKEN]

        # Extract piece type
        piece, rest = move[0], move[1:]
        if piece not in self.PIECES:
            return [self.UNK_TOKEN]

        # What follows the piece letter is: from-square, to-square, then
        # zero or more modifiers, e.g. "e2e4", "g1f3(x)", "e1g1(o)".
        if len(rest) < 4:
            # Not enough characters for two squares
            return [self.UNK_TOKEN]

        from_square, to_square, remaining = rest[0:2], rest[2:4], rest[4:]
        squares = self._SQUARE_SET
        if from_square not in squares or to_square not in squares:
            return [self.UNK_TOKEN]

        components = [piece, from_square, to_square]

        # Parse modifiers: (x), (+), (+*), (o), (O), =Q, =R, =B, =N, (e.p.)
        pos = 0
        while pos < len(remaining):
            # str.startswith(prefix, start) tests in place, without slicing.
            matched = next(
                (m for m in self.MODIFIERS if remaining.startswith(m, pos)),
                None,
            )
            if matched is None:
                # Unknown character: skipped silently (best effort).
                pos += 1
            else:
                components.append(matched)
                pos += len(matched)

        return components

    def _tokenize(self, text: str) -> List[str]:
        """
        Tokenize a string of moves into component tokens.

        Args:
            text: Space-separated moves (e.g., "WPe2e4 BPe7e5 WNg1f3")

        Returns:
            List of component tokens. Reserved tokens appearing as their own
            whitespace-separated word pass through unchanged.
        """
        tokens: List[str] = []
        for move in text.split():
            # _decompose_move handles reserved tokens itself.
            tokens.extend(self._decompose_move(move))
        return tokens

    def _convert_token_to_id(self, token: str) -> int:
        """Convert a token to its ID, falling back to the [UNK] id (or 0)."""
        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 0))

    def _convert_id_to_token(self, index: int) -> str:
        """Convert an ID to its token, falling back to [UNK]."""
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """
        Convert component tokens back to move strings.

        Reserved tokens are dropped first. The remainder is scanned for
        ``piece, from_square, to_square`` triples, each optionally followed
        by modifiers; tokens that do not start such a triple are skipped, as
        is an incomplete group at the end of the sequence.

        Note: We lose the W/B color prefix, but it's redundant
        (can be inferred from move position).

        Args:
            tokens: Component tokens, e.g. ["P", "e2", "e4", "N", "g8", "f6"]

        Returns:
            Space-separated moves without color prefix, e.g. "Pe2e4 Ng8f6"
        """
        # Filter out special tokens
        reserved = set(self.RESERVED_TOKENS)
        tokens = [t for t in tokens if t not in reserved]

        squares = self._SQUARE_SET
        total = len(tokens)

        # Reconstruct moves from components
        moves = []
        i = 0
        # A move needs at least three tokens: piece, from_square, to_square.
        while i + 2 < total:
            piece, from_sq, to_sq = tokens[i:i + 3]
            if piece in self.PIECES and from_sq in squares and to_sq in squares:
                i += 3
                # Collect the run of modifiers that belongs to this move.
                modifiers_start = i
                while i < total and tokens[i] in self.MODIFIERS:
                    i += 1
                moves.append(
                    piece + from_sq + to_sq + "".join(tokens[modifiers_start:i])
                )
            else:
                # Skip invalid tokens
                i += 1

        return " ".join(moves)

    def save_vocabulary(
        self,
        save_directory: str,
        filename_prefix: Optional[str] = None,
    ) -> tuple:
        """
        Save the vocabulary to a JSON file.

        Args:
            save_directory: Target directory; created if it does not exist.
            filename_prefix: Optional prefix; the file is written as
                "<prefix>-vocab.json" when given, otherwise "vocab.json".

        Returns:
            One-element tuple containing the path of the written file.
        """
        os.makedirs(save_directory, exist_ok=True)

        prefix = f"{filename_prefix}-" if filename_prefix else ""
        vocab_file = os.path.join(save_directory, prefix + "vocab.json")

        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)

        return (vocab_file,)