""" Custom Chess Tokenizer for the Chess Challenge. This tokenizer uses sub-token decomposition to achieve a minimal vocabulary by breaking moves into atomic components (squares + modifiers). Example: WPe2e4 → ["e2", "e4"] BNg8f6(+) → ["g8", "f6", "+"] This approach trades sequence length (3x longer) for vocabulary size (77 vs 1800+). """ from __future__ import annotations import json import os from typing import Dict, List, Optional from transformers import PreTrainedTokenizer class ChessTokenizer(PreTrainedTokenizer): """ Sub-token chess tokenizer with minimal fixed vocabulary. Decomposes each move into: - Source square (e.g., e2) - Destination square (e.g., e4) - Optional modifiers (x, +, +*, Q/R/B/N, O/o) Vocabulary composition: - 64 squares (a1-h8) - 9 modifiers (x, +, +*, Q, R, B, N, O, o) - 4 special tokens ([PAD], [BOS], [EOS], [UNK]) Total: 77 tokens """ model_input_names = ["input_ids", "attention_mask"] vocab_files_names = {"vocab_file": "vocab.json"} # Special tokens PAD_TOKEN = "[PAD]" BOS_TOKEN = "[BOS]" EOS_TOKEN = "[EOS]" UNK_TOKEN = "[UNK]" def __init__( self, vocab_file: Optional[str] = None, vocab: Optional[Dict[str, int]] = None, **kwargs, ): """ Initialize the chess tokenizer. """ # Initialize special tokens self._pad_token = self.PAD_TOKEN self._bos_token = self.BOS_TOKEN self._eos_token = self.EOS_TOKEN self._unk_token = self.UNK_TOKEN kwargs.pop("pad_token", None) kwargs.pop("bos_token", None) kwargs.pop("eos_token", None) kwargs.pop("unk_token", None) # Load or create vocabulary if vocab is not None: self._vocab = vocab elif vocab_file is not None and os.path.exists(vocab_file): with open(vocab_file, "r", encoding="utf-8") as f: self._vocab = json.load(f) else: self._vocab = self.build_minimal_vocab().get_vocab() # Reverse mapping self._ids_to_tokens = {v: k for k, v in self._vocab.items()} # Call parent init super().__init__( pad_token=self._pad_token, bos_token=self._bos_token, eos_token=self._eos_token, unk_token=self._unk_token, **kwargs, ) @classmethod def build_minimal_vocab(cls) -> "ChessTokenizer": """ Build tokenizer with minimal fixed vocabulary (77 tokens). """ files = "abcdefgh" ranks = "12345678" squares = [f + r for f in files for r in ranks] modifiers = ["x", "+", "+*", "Q", "R", "B", "N", "O", "o"] special_tokens = [cls.PAD_TOKEN, cls.BOS_TOKEN, cls.EOS_TOKEN, cls.UNK_TOKEN] vocab_tokens = special_tokens + squares + modifiers vocab = {tok: i for i, tok in enumerate(vocab_tokens)} return cls(vocab=vocab) @property def vocab_size(self) -> int: return len(self._vocab) def get_vocab(self) -> Dict[str, int]: return dict(self._vocab) def _tokenize(self, text: str) -> List[str]: """ Tokenize moves into squares + modifiers. Examples: WPe2e4 -> ["e2", "e4"] BNg8f6(+) -> ["g8", "f6", "+"] WKe1g1(O) -> ["e1", "g1", "O"] """ tokens = [] for move in text.strip().split(): if len(move) < 4: continue core = move[2:] # Remove color + piece # Squares from_sq = core[0:2] to_sq = core[2:4] tokens.extend([from_sq, to_sq]) # Modifiers suffix = core[4:] if "x" in suffix: tokens.append("x") if "+*" in suffix: tokens.append("+*") elif "+" in suffix: tokens.append("+") for promo in ["Q", "R", "B", "N"]: if f"({promo})" in suffix: tokens.append(promo) # Castling if "O" in move or "o" in move: tokens.append("O") return tokens def _convert_token_to_id(self, token: str) -> int: return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 0)) def _convert_id_to_token(self, index: int) -> str: return self._ids_to_tokens.get(index, self.UNK_TOKEN) def convert_tokens_to_string(self, tokens: List[str]) -> str: """Convert sub-tokens back to string representation.""" special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN} return " ".join(t for t in tokens if t not in special) def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple: if not os.path.isdir(save_directory): os.makedirs(save_directory, exist_ok=True) vocab_file = os.path.join( save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.json", ) with open(vocab_file, "w", encoding="utf-8") as f: json.dump(self._vocab, f, ensure_ascii=False, indent=2) return (vocab_file,) # ===== Example usage ===== if __name__ == "__main__": tokenizer = ChessTokenizer.build_minimal_vocab() print(f"Vocabulary size: {tokenizer.vocab_size}") test_games = [ "WPe2e4 BPe7e5", "WNg1f3 BNb8c6", "WBb5c6(x) BPd7d6", "WPe7e8(Q) BKe8d7", "WKe1g1(O) BKe8c8(o)", ] for game in test_games: print(f"\nOriginal: {game}") tokens = tokenizer._tokenize(game) print(f"Tokens: {tokens}") ids = tokenizer.convert_tokens_to_ids(tokens) print(f"IDs: {ids}") decoded = tokenizer.convert_ids_to_tokens(ids) print(f"Decoded: {decoded}")