""" Compositional Chess Tokenizer for the Chess Challenge. This tokenizer decomposes chess moves into meaningful components: - Color (W/B), Piece (P/N/B/R/Q/K), Squares, Actions, Modifiers Reduces vocabulary from 3803 to 86 tokens while enabling better generalization. Example: WPe2e4 -> [W, P, e2, ->, e4] BNg8f6(x) -> [B, N, g8, x, f6] """ from __future__ import annotations import json import os import re from typing import Dict, List, Optional from transformers import PreTrainedTokenizer class ChessTokenizer(PreTrainedTokenizer): """Compositional tokenizer for chess moves (86 tokens vs 3803 baseline).""" model_input_names = ["input_ids", "attention_mask"] vocab_files_names = {"vocab_file": "vocab.json"} PAD_TOKEN = "[PAD]" BOS_TOKEN = "[BOS]" EOS_TOKEN = "[EOS]" UNK_TOKEN = "[UNK]" # Use ASCII-safe tokens to avoid encoding issues and mismatches MOVE_ARROW = "->" CAPTURE_CROSS = "x" def __init__(self, vocab_file: Optional[str] = None, vocab: Optional[Dict[str, int]] = None, **kwargs): """Initialize compositional chess tokenizer.""" self._pad_token = self.PAD_TOKEN self._bos_token = self.BOS_TOKEN self._eos_token = self.EOS_TOKEN self._unk_token = self.UNK_TOKEN kwargs.pop("pad_token", None) kwargs.pop("bos_token", None) kwargs.pop("eos_token", None) kwargs.pop("unk_token", None) if vocab is not None: self._vocab = vocab elif vocab_file is not None and os.path.exists(vocab_file): with open(vocab_file, "r", encoding="utf-8") as f: self._vocab = json.load(f) else: print("Building compositional vocabulary (86 tokens)...") self._vocab = self._build_compositional_vocab() self._ids_to_tokens = {v: k for k, v in self._vocab.items()} super().__init__( pad_token=self._pad_token, bos_token=self._bos_token, eos_token=self._eos_token, unk_token=self._unk_token, **kwargs, ) def _build_compositional_vocab(self) -> Dict[str, int]: """Build 86-token compositional vocabulary.""" vocab = {} idx = 0 # Special tokens (4 tokens) for token in [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]: vocab[token] = idx idx += 1 # Colors (2 tokens: W, B) for color in ["W", "B"]: vocab[color] = idx idx += 1 # Pieces (6 tokens: P, N, B, R, Q, K) for piece in ["P", "N", "B", "R", "Q", "K"]: vocab[piece] = idx idx += 1 # Squares (64 tokens: a1-h8) for f in ["a", "b", "c", "d", "e", "f", "g", "h"]: for r in ["1", "2", "3", "4", "5", "6", "7", "8"]: vocab[f + r] = idx idx += 1 # Actions (2 tokens: →move, ×capture) vocab[self.MOVE_ARROW] = idx idx += 1 vocab[self.CAPTURE_CROSS] = idx idx += 1 # Modifiers (6 tokens: +check, +*checkmate, =Q/R/B/N promotions) for mod in ["+", "+*", "=Q", "=R", "=B", "=N"]: vocab[mod] = idx idx += 1 # Special moves (2 tokens: O-O, O-O-O) for move in ["O-O", "O-O-O"]: vocab[move] = idx idx += 1 return vocab @property def vocab_size(self) -> int: """Return vocabulary size.""" return len(self._vocab) def get_vocab(self) -> Dict[str, int]: """Return vocabulary dictionary.""" return dict(self._vocab) def _decompose_move(self, move: str) -> List[str]: """Decompose chess move into component tokens. Args: move: Chess move in format WPe2e4 or BNg8f6(x) Returns: List of component tokens [color, piece, from_sq, action, to_sq, modifiers...] """ move = move.strip() # Handle castling if "O-O-O" in move or "o-o-o" in move.lower(): return [move[0], "O-O-O"] # Color + castling elif "O-O" in move or "o-o" in move.lower(): return [move[0], "O-O"] # Color + castling # Basic validation if len(move) < 6: return [self.UNK_TOKEN] # Extract basic components tokens = [ move[0], # Color (W/B) move[1], # Piece (P/N/B/R/Q/K) move[2:4] # From square (e.g., e2) ] # Determine action (capture vs move) is_capture = "(x)" in move or "(x+" in move or "(x+*)" in move tokens.append(self.CAPTURE_CROSS if is_capture else self.MOVE_ARROW) # To square tokens.append(move[4:6]) # Add modifiers (check, checkmate) if "(+*)" in move or "(x+*)" in move: tokens.append("+*") elif "(+)" in move or "(x+)" in move: tokens.append("+") # Handle promotions (e.g., (Q), (+Q), (xQ)) promotion_match = re.search(r'\((?:x\s*)?(?:\+\s*)?([QRBN])\)', move) if promotion_match: tokens.append(f"={promotion_match.group(1)}") return tokens def _tokenize(self, text: str) -> List[str]: """Tokenize string of moves into component tokens. Args: text: String of space-separated chess moves Returns: List of all component tokens """ all_tokens = [] for move in text.strip().split(): all_tokens.extend(self._decompose_move(move)) return all_tokens def _convert_token_to_id(self, token: str) -> int: """Convert token to ID.""" return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 3)) def _convert_id_to_token(self, index: int) -> str: """Convert ID to token.""" return self._ids_to_tokens.get(index, self.UNK_TOKEN) def convert_tokens_to_string(self, tokens: List[str]) -> str: """Convert component tokens back to move strings. Args: tokens: List of component tokens Returns: String of reconstructed chess moves """ # Filter out special tokens special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN} tokens = [t for t in tokens if t not in special] moves = [] i = 0 while i < len(tokens): # Handle castling if i + 1 < len(tokens) and tokens[i + 1] in ["O-O", "O-O-O"]: # Castling is just color + castle notation i += 2 continue # Regular move: need at least 5 tokens [color, piece, from_sq, action, to_sq] if i + 4 < len(tokens): color, piece, from_sq, action, to_sq = tokens[i:i+5] move = f"{color}{piece}{from_sq}{to_sq}" # Collect modifiers that follow j = i + 5 while j < len(tokens) and tokens[j] not in ["W", "B"]: mod = tokens[j] if mod == "+": move += "(+)" elif mod == "+*": move += "(+*)" elif mod.startswith("="): move += f"({mod[1]})" # Promotion j += 1 # Add capture marker if action was capture if action == self.CAPTURE_CROSS and "(x)" not in move: move += "(x)" moves.append(move) i = j else: i += 1 return " ".join(moves) def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]: """Build model inputs with BOS/EOS tokens.""" bos_id = self._vocab[self.BOS_TOKEN] eos_id = self._vocab[self.EOS_TOKEN] def wrap(ids: List[int]) -> List[int]: out = list(ids) # Only add BOS if missing if len(out) == 0 or out[0] != bos_id: out = [bos_id] + out # Only add EOS if missing if len(out) == 0 or out[-1] != eos_id: out = out + [eos_id] return out if token_ids_1 is None: return wrap(token_ids_0) # For pair inputs, keep it simple/consistent: wrap concatenation return wrap(token_ids_0 + token_ids_1) def get_special_tokens_mask( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False ) -> List[int]: """Return mask identifying special tokens.""" bos_id = self._vocab[self.BOS_TOKEN] eos_id = self._vocab[self.EOS_TOKEN] pad_id = self._vocab[self.PAD_TOKEN] unk_id = self._vocab[self.UNK_TOKEN] def mask_for(ids: List[int]) -> List[int]: return [1 if tid in {bos_id, eos_id, pad_id, unk_id} else 0 for tid in ids] if already_has_special_tokens: if token_ids_1 is None: return mask_for(token_ids_0) return mask_for(token_ids_0 + token_ids_1) # No special tokens yet: build mask consistent with build_inputs_with_special_tokens if token_ids_1 is None: return [1] + ([0] * len(token_ids_0)) + [1] return [1] + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + [1] def create_token_type_ids_from_sequences( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None ) -> List[int]: """Create token type IDs (all zeros for single sequence type).""" # Build final ids consistently with build_inputs_with_special_tokens final_ids = self.build_inputs_with_special_tokens(token_ids_0, token_ids_1) return [0] * len(final_ids) def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple: """Save vocabulary to JSON file.""" if not os.path.isdir(save_directory): os.makedirs(save_directory, exist_ok=True) vocab_file = os.path.join( save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.json" ) with open(vocab_file, "w", encoding="utf-8") as f: json.dump(self._vocab, f, ensure_ascii=False, indent=2) return (vocab_file,)