""" Custom Chess Tokenizer for the Chess Challenge. We build a vocabulary with: - W/B prefix for White/Black - Piece letter: P=Pawn, N=Knight, B=Bishop, R=Rook, Q=Queen, K=King - Source and rank and file: e.g e 2 - Destination and rank and file: e.g e 4 - Special suffixes: (x)=capture, (+)=check, (+*)=checkmate, (o)/(O)=castling """ from __future__ import annotations import json import os from pathlib import Path import shutil import inspect from typing import Dict, List, Optional from transformers import PreTrainedTokenizer from datasets import load_dataset class ChessTokenizer(PreTrainedTokenizer): model_input_names = ["input_ids", "attention_mask"] vocab_files_names = {"vocab_file": "vocab.json"} # Special tokens PAD_TOKEN = "[PAD]" BOS_TOKEN = "[BOS]" EOS_TOKEN = "[EOS]" UNK_TOKEN = "[UNK]" SEP_TOKEN = "[SEP]" def __init__( self, vocab_file: Optional[str] = None, vocab: Optional[Dict[str, int]] = None, **kwargs, ): self._pad_token = self.PAD_TOKEN self._bos_token = self.BOS_TOKEN self._eos_token = self.EOS_TOKEN self._unk_token = self.UNK_TOKEN self._sep_token = self.SEP_TOKEN kwargs.pop("pad_token", None) kwargs.pop("bos_token", None) kwargs.pop("eos_token", None) kwargs.pop("unk_token", None) kwargs.pop("sep_token", None) print("Initializing ChessTokenizer") print(f" vocab_file: {vocab_file}") print(f" vocab provided: {vocab is not None}") print(f" vocab: {vocab}") print(os.listdir(".")) vocab = { "[PAD]": 0, "[BOS]": 1, "[EOS]": 2, "[UNK]": 3, "[SEP]": 4, "(+)": 5, "(+*)": 6, "(+*B)": 7, "(+*N)": 8, "(+*Q)": 9, "(+*R)": 10, "(+B)": 11, "(+N)": 12, "(+Q)": 13, "(+R)": 14, "(B)": 15, "(N)": 16, "(O)": 17, "(O+)": 18, "(O+*)": 19, "(Q)": 20, "(R)": 21, "(o)": 22, "(o+)": 23, "(o+*)": 24, "(x)": 25, "(x+)": 26, "(x+*)": 27, "(x+*B)": 28, "(x+*Q)": 29, "(x+*R)": 30, "(x+B)": 31, "(x+N)": 32, "(x+Q)": 33, "(x+R)": 34, "(xB)": 35, "(xE)": 36, "(xE+)": 37, "(xE+*)": 38, "(xN)": 39, "(xQ)": 40, "(xR)": 41, "B": 42, "K": 43, "N": 44, "P": 45, "Q": 46, "R": 47, "W": 48, "a1": 49, "a2": 50, "a3": 51, "a4": 52, "a5": 53, "a6": 54, "a7": 55, "a8": 56, "b1": 57, "b2": 58, "b3": 59, "b4": 60, "b5": 61, "b6": 62, "b7": 63, "b8": 64, "c1": 65, "c2": 66, "c3": 67, "c4": 68, "c5": 69, "c6": 70, "c7": 71, "c8": 72, "d1": 73, "d2": 74, "d3": 75, "d4": 76, "d5": 77, "d6": 78, "d7": 79, "d8": 80, "e1": 81, "e2": 82, "e3": 83, "e4": 84, "e5": 85, "e6": 86, "e7": 87, "e8": 88, "f1": 89, "f2": 90, "f3": 91, "f4": 92, "f5": 93, "f6": 94, "f7": 95, "f8": 96, "g1": 97, "g2": 98, "g3": 99, "g4": 100, "g5": 101, "g6": 102, "g7": 103, "g8": 104, "h1": 105, "h2": 106, "h3": 107, "h4": 108, "h5": 109, "h6": 110, "h7": 111, "h8": 112, } if vocab is not None: self._vocab = vocab elif vocab_file is not None and os.path.exists(vocab_file): with open(vocab_file, "r", encoding="utf-8") as f: self._vocab = json.load(f) else: print("No vocabulary provided; creating default minimal vocab.") self._vocab = self._create_default_vocab() self._ids_to_tokens = {v: k for k, v in self._vocab.items()} super().__init__( pad_token=self._pad_token, bos_token=self._bos_token, eos_token=self._eos_token, unk_token=self._unk_token, sep_token=self._sep_token, **kwargs, ) def _create_default_vocab(self) -> Dict[str, int]: special_tokens = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN, self.SEP_TOKEN] vocab = {token: idx for idx, token in enumerate(special_tokens)} return vocab @classmethod def build_vocab_from_dataset( cls, dataset_name: str = "dlouapre/lichess_2025-01_1M", split: str = "train", column: str = "text", min_frequency: Optional[int] = 1, max_samples: Optional[int] = None, save_path: Optional[str] = None, ) -> "ChessTokenizer": if save_path is None: cwd = os.getcwd() save_path = os.path.join(cwd, "chess_tokenizer_vocab.json") if os.path.exists(save_path): try: with open(save_path, "r", encoding="utf-8") as f: print("Loading existing tokenizer vocab from", save_path) vocab = json.load(f) return cls(vocab=vocab) except Exception: pass dataset = load_dataset(dataset_name, split=split) samples = dataset[column] tokens = set() for game in samples: if not isinstance(game, str): continue moves = game.strip().split() for move in moves: if len(move) < 2: continue color = move[0] piece = move[1] from_square = move[2:4] if len(move) >= 4 else '' to_square = move[4:6] if len(move) >= 6 else '' suffix = move[6:] if len(move) > 6 else '' tokens.add(color) tokens.add(piece) tokens.add(from_square) tokens.add(to_square) if suffix: tokens.add(suffix) tokens = sorted(tokens) special_tokens = [cls.PAD_TOKEN, cls.BOS_TOKEN, cls.EOS_TOKEN, cls.UNK_TOKEN, cls.SEP_TOKEN] vocab: Dict[str, int] = {} idx = 0 for st in special_tokens: vocab[st] = idx idx += 1 for t in tokens: if t in vocab: continue vocab[t] = idx idx += 1 tokenizer = cls(vocab=vocab) try: if save_path is None: cwd = os.getcwd() save_path = os.path.join(cwd, "chess_tokenizer_vocab.json") tmp_path = save_path + ".tmp" with open(tmp_path, "w", encoding="utf-8") as f: json.dump(vocab, f, ensure_ascii=False, indent=2) os.replace(tmp_path, save_path) except Exception: # Non-fatal: ignore save errors but don't leave temp files behind. try: if 'tmp_path' in locals() and os.path.exists(tmp_path): os.remove(tmp_path) except Exception: pass return tokenizer @property def vocab_size(self) -> int: """Return the size of the vocabulary.""" return len(self._vocab) def get_vocab(self) -> Dict[str, int]: """Return the vocabulary as a dictionary.""" return dict(self._vocab) def _tokenize(self, text: str) -> List[str]: """ Tokenize a string of moves into a list of tokens. Args: text: A string of space-separated moves. Returns: List of move tokens. """ tokens: List[str] = [] for move in text.strip().split(): if len(move) < 2: continue color, piece, from_square, to_square, suffix = self._decompose_move(move) tokens.append(color) tokens.append(piece) tokens.append(from_square) tokens.append(to_square) if suffix: tokens.append(suffix) tokens.append(self._sep_token) return tokens[:-1] # Remove last SEP token @staticmethod def _decompose_move(move: str): """Decompose a move string into components: color, piece, from_square, to_square, suffix. Returns a 5-tuple of strings (empty strings for missing parts). """ color = move[0] piece = move[1] if len(move) >= 2 else '' from_square = move[2:4] if len(move) >= 4 else '' to_square = move[4:6] if len(move) >= 6 else '' suffix = move[6:] if len(move) > 6 else '' return color, piece, from_square, to_square, suffix def _convert_token_to_id(self, token: str) -> int: """Convert a token to its ID.""" return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 0)) def _convert_id_to_token(self, index: int) -> str: """Convert an ID to its token.""" return self._ids_to_tokens.get(index, self.UNK_TOKEN) def convert_tokens_to_string(self, tokens: List[str]) -> str: """Convert a list of tokens back to a string.""" # Filter out special tokens for cleaner output special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN} return " ".join(t for t in tokens if t not in special) def decode(self, token_ids: List[int], skip_special_tokens: bool = True) -> str: """Decode a list of token IDs back to a string.""" tokens = [self._convert_id_to_token(int(tid)) for tid in token_ids] if skip_special_tokens: special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN} # SEP token should be replace by space tokens = [t if t != self.SEP_TOKEN else " " for t in tokens if t not in special] return "".join(tokens) def save_vocabulary( self, save_directory: str, filename_prefix: Optional[str] = None, ) -> tuple: """ Save the vocabulary to a JSON file. Args: save_directory: Directory to save the vocabulary. filename_prefix: Optional prefix for the filename. Returns: Tuple containing the path to the saved vocabulary file. """ if not os.path.isdir(save_directory): os.makedirs(save_directory, exist_ok=True) vocab_file = os.path.join( save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.json", ) with open(vocab_file, "w", encoding="utf-8") as f: json.dump(self._vocab, f, ensure_ascii=False, indent=2) return (vocab_file,) def save_pretrained( self, save_directory: str, filename_prefix: Optional[str] = None, save_tokenizer_code: bool = True, ) -> None: """Save tokenizer files to a directory in a HF-compatible layout. This writes the vocab JSON (via `save_vocabulary`), a small `tokenizer_config.json` describing special tokens and the vocab filename, and optionally copies the tokenizer module source file into the directory so others can import the implementation. """ if not os.path.isdir(save_directory): os.makedirs(save_directory, exist_ok=True) # Save the vocabulary file vocab_file_tuple = self.save_vocabulary(save_directory, filename_prefix) vocab_file = vocab_file_tuple[0] # Write a minimal tokenizer config config = { "tokenizer_class": self.__class__.__name__, "vocab_file": os.path.basename(vocab_file), "pad_token": self.PAD_TOKEN, "bos_token": self.BOS_TOKEN, "eos_token": self.EOS_TOKEN, "unk_token": self.UNK_TOKEN, } config_path = os.path.join(save_directory, "tokenizer_config.json") with open(config_path, "w", encoding="utf-8") as f: json.dump(config, f, ensure_ascii=False, indent=2) # Optionally copy this module file so the tokenizer class implementation # is available alongside the saved vocab/config. This helps when # transferring the saved tokenizer to another environment. if save_tokenizer_code: try: src_file = Path(inspect.getsourcefile(self.__class__)) dst_file = Path(save_directory) / src_file.name shutil.copy2(src_file, dst_file) except Exception: # Non-fatal; we still saved vocab and config pass def count_vocab_from_dataset( dataset_name: str = "dlouapre/lichess_2025-01_1M", split: str = "train", column: str = "text", max_samples: Optional[int] = 10000, ) -> Dict[str, int]: """ Count token frequencies in a dataset (useful for vocabulary analysis). Args: dataset_name: Name of the dataset on Hugging Face Hub. split: Dataset split to use. column: Column containing the game strings. max_samples: Maximum number of samples to process. Returns: Dictionary mapping tokens to their frequencies. """ from collections import Counter from datasets import load_dataset dataset = load_dataset(dataset_name, split=split) if max_samples is not None: dataset = dataset.select(range(min(max_samples, len(dataset)))) tokenizer = ChessTokenizer() token_counts = Counter() for example in dataset: token_counts.update(tokenizer._tokenize(example[column])) return dict(token_counts)