| # """ | |
| # Custom Chess Tokenizer for the Chess Challenge. | |
| # This tokenizer treats each move as a single token using the extended UCI notation | |
| # from the Lichess dataset (e.g., WPe2e4, BNg8f6). | |
| # The dataset format uses: | |
| # - W/B prefix for White/Black | |
| # - Piece letter: P=Pawn, N=Knight, B=Bishop, R=Rook, Q=Queen, K=King | |
| # - Source and destination squares (e.g., e2e4) | |
| # - Special suffixes: (x)=capture, (+)=check, (+*)=checkmate, (o)/(O)=castling | |
| # """ | |
| # from __future__ import annotations | |
| # import json | |
| # import osz | |
| # from pathlib import Path | |
| # from typing import Dict, List, Optional | |
| # from transformers import PreTrainedTokenizer | |
| # class ChessTokenizer(PreTrainedTokenizer): | |
| # """ | |
| # A custom tokenizer for chess moves using extended UCI notation. | |
| # This tokenizer maps each possible chess move to a unique token ID. | |
| # The vocabulary is built from the training dataset to ensure all moves | |
| # encountered during training have a corresponding token. | |
| # Example: | |
| # >>> tokenizer = ChessTokenizer() | |
| # >>> tokenizer.encode("WPe2e4 BPe7e5") | |
| # [1, 42, 87, 2] # [BOS, e2e4, e7e5, EOS] | |
| # """ | |
| # model_input_names = ["input_ids", "attention_mask"] | |
| # vocab_files_names = {"vocab_file": "vocab.json"} | |
| # # Special tokens | |
| # PAD_TOKEN = "[PAD]" | |
| # BOS_TOKEN = "[BOS]" | |
| # EOS_TOKEN = "[EOS]" | |
| # UNK_TOKEN = "[UNK]" | |
| # def __init__( | |
| # self, | |
| # vocab_file: Optional[str] = None, | |
| # vocab: Optional[Dict[str, int]] = None, | |
| # **kwargs, | |
| # ): | |
| # """ | |
| # Initialize the chess tokenizer. | |
| # Args: | |
| # vocab_file: Path to a JSON file containing the vocabulary mapping. | |
| # vocab: Dictionary mapping tokens to IDs (alternative to vocab_file). | |
| # **kwargs: Additional arguments passed to PreTrainedTokenizer. | |
| # """ | |
| # # Initialize special tokens | |
| # self._pad_token = self.PAD_TOKEN | |
| # self._bos_token = self.BOS_TOKEN | |
| # self._eos_token = self.EOS_TOKEN | |
| # self._unk_token = self.UNK_TOKEN | |
| # # Remove any duplicate special-token entries passed through kwargs | |
| # # to avoid "multiple values for keyword" errors when loading from disk. | |
| # kwargs.pop("pad_token", None) | |
| # kwargs.pop("bos_token", None) | |
| # kwargs.pop("eos_token", None) | |
| # kwargs.pop("unk_token", None) | |
| # # Load or create vocabulary | |
| # if vocab is not None: | |
| # self._vocab = vocab | |
| # elif vocab_file is not None and os.path.exists(vocab_file): | |
| # with open(vocab_file, "r", encoding="utf-8") as f: | |
| # self._vocab = json.load(f) | |
| # else: | |
| # # Create a minimal vocabulary with just special tokens | |
| # # The full vocabulary should be built from the dataset | |
| # self._vocab = self._create_default_vocab() | |
| # # Create reverse mapping | |
| # self._ids_to_tokens = {v: k for k, v in self._vocab.items()} | |
| # # Call parent init AFTER setting up vocab | |
| # super().__init__( | |
| # pad_token=self._pad_token, | |
| # bos_token=self._bos_token, | |
| # eos_token=self._eos_token, | |
| # unk_token=self._unk_token, | |
| # **kwargs, | |
| # ) | |
| # def _create_default_vocab(self) -> Dict[str, int]: | |
| # """ | |
| # Create a minimal default vocabulary with just special tokens. | |
| # For the full vocabulary, use `build_vocab_from_dataset()`. | |
| # This minimal vocab is just a placeholder - you should build from data. | |
| # """ | |
| # special_tokens = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN] | |
| # vocab = {token: idx for idx, token in enumerate(special_tokens)} | |
| # return vocab | |
| # @classmethod | |
| # def build_vocab_from_iterator( | |
| # cls, | |
| # iterator, | |
| # min_frequency: int = 1, | |
| # ) -> "ChessTokenizer": | |
| # """ | |
| # Build a tokenizer vocabulary from an iterator of game strings. | |
| # Args: | |
| # iterator: An iterator yielding game strings (space-separated moves). | |
| # min_frequency: Minimum frequency for a token to be included. | |
| # Returns: | |
| # A ChessTokenizer with the built vocabulary. | |
| # """ | |
| # from collections import Counter | |
| # token_counts = Counter() | |
| # for game in iterator: | |
| # moves = game.strip().split() | |
| # token_counts.update(moves) | |
| # # Filter by frequency | |
| # tokens = [ | |
| # token for token, count in token_counts.items() | |
| # if count >= min_frequency | |
| # ] | |
| # # Sort for reproducibility | |
| # tokens = sorted(tokens) | |
| # # Build vocabulary | |
| # special_tokens = [cls.PAD_TOKEN, cls.BOS_TOKEN, cls.EOS_TOKEN, cls.UNK_TOKEN] | |
| # vocab = {token: idx for idx, token in enumerate(special_tokens + tokens)} | |
| # return cls(vocab=vocab) | |
| # @classmethod | |
| # def build_vocab_from_dataset( | |
| # cls, | |
| # dataset_name: str = "dlouapre/lichess_2025-01_1M", | |
| # split: str = "train", | |
| # column: str = "text", | |
| # min_frequency: int = 500, | |
| # max_samples: Optional[int] = 100000, | |
| # ) -> "ChessTokenizer": | |
| # """ | |
| # Build a tokenizer vocabulary from a Hugging Face dataset. | |
| # Args: | |
| # dataset_name: Name of the dataset on Hugging Face Hub. | |
| # split: Dataset split to use. | |
| # column: Column containing the game strings. | |
| # min_frequency: Minimum frequency for a token to be included (default: 500). | |
| # max_samples: Maximum number of samples to process (default: 100k). | |
| # Returns: | |
| # A ChessTokenizer with the built vocabulary. | |
| # """ | |
| # from datasets import load_dataset | |
| # dataset = load_dataset(dataset_name, split=split) | |
| # if max_samples is not None: | |
| # dataset = dataset.select(range(min(max_samples, len(dataset)))) | |
| # def game_iterator(): | |
| # for example in dataset: | |
| # yield example[column] | |
| # return cls.build_vocab_from_iterator(game_iterator(), min_frequency=min_frequency) | |
| # @property | |
| # def vocab_size(self) -> int: | |
| # """Return the size of the vocabulary.""" | |
| # return len(self._vocab) | |
| # def get_vocab(self) -> Dict[str, int]: | |
| # """Return the vocabulary as a dictionary.""" | |
| # return dict(self._vocab) | |
| # def _tokenize(self, text: str) -> List[str]: | |
| # """ | |
| # Tokenize a string of moves into a list of tokens. | |
| # Args: | |
| # text: A string of space-separated moves. | |
| # Returns: | |
| # List of move tokens. | |
| # """ | |
| # return text.strip().split() | |
| # def _convert_token_to_id(self, token: str) -> int: | |
| # """Convert a token to its ID.""" | |
| # return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 0)) | |
| # def _convert_id_to_token(self, index: int) -> str: | |
| # """Convert an ID to its token.""" | |
| # return self._ids_to_tokens.get(index, self.UNK_TOKEN) | |
| # def convert_tokens_to_string(self, tokens: List[str]) -> str: | |
| # """Convert a list of tokens back to a string.""" | |
| # # Filter out special tokens for cleaner output | |
| # special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN} | |
| # return " ".join(t for t in tokens if t not in special) | |
| # def save_vocabulary( | |
| # self, | |
| # save_directory: str, | |
| # filename_prefix: Optional[str] = None, | |
| # ) -> tuple: | |
| # """ | |
| # Save the vocabulary to a JSON file. | |
| # Args: | |
| # save_directory: Directory to save the vocabulary. | |
| # filename_prefix: Optional prefix for the filename. | |
| # Returns: | |
| # Tuple containing the path to the saved vocabulary file. | |
| # """ | |
| # if not os.path.isdir(save_directory): | |
| # os.makedirs(save_directory, exist_ok=True) | |
| # vocab_file = os.path.join( | |
| # save_directory, | |
| # (filename_prefix + "-" if filename_prefix else "") + "vocab.json", | |
| # ) | |
| # with open(vocab_file, "w", encoding="utf-8") as f: | |
| # json.dump(self._vocab, f, ensure_ascii=False, indent=2) | |
| # return (vocab_file,) | |
| # def count_vocab_from_dataset( | |
| # dataset_name: str = "dlouapre/lichess_2025-01_1M", | |
| # split: str = "train", | |
| # column: str = "text", | |
| # max_samples: Optional[int] = 10000, | |
| # ) -> Dict[str, int]: | |
| # """ | |
| # Count token frequencies in a dataset (useful for vocabulary analysis). | |
| # Args: | |
| # dataset_name: Name of the dataset on Hugging Face Hub. | |
| # split: Dataset split to use. | |
| # column: Column containing the game strings. | |
| # max_samples: Maximum number of samples to process. | |
| # Returns: | |
| # Dictionary mapping tokens to their frequencies. | |
| # """ | |
| # from collections import Counter | |
| # from datasets import load_dataset | |
| # dataset = load_dataset(dataset_name, split=split) | |
| # if max_samples is not None: | |
| # dataset = dataset.select(range(min(max_samples, len(dataset)))) | |
| # token_counts = Counter() | |
| # for example in dataset: | |
| # moves = example[column].strip().split() | |
| # token_counts.update(moves) | |
| # return dict(token_counts) | |
| """ | |
| Decomposed Chess Tokenizer (Idea 1) | |
| Each move in extended UCI is split into structured subtokens: | |
| Example: | |
| WPe2e4 -> ["W", "P", "e2", "e4"] | |
| BNg8f6(x) -> ["B", "N", "g8", "f6", "(x)"] | |
| WPe7e8=Q(+) -> ["W", "P", "e7", "e8", "=Q", "(+)"] | |
| WKe1g1(o) -> ["W", "K", "e1", "g1", "(o)"] | |
| A full game string (space-separated moves) is expanded move-by-move. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import os | |
| from typing import Dict, List, Optional, Tuple | |
| from transformers import PreTrainedTokenizer | |
class ChessTokenizer(PreTrainedTokenizer):
    """
    Decomposed chess tokenizer: each extended-UCI move becomes atomic subtokens.

    A move such as ``WPe7e8=Q(+)`` is split into
    ``COLOR PIECE FROM TO [PROMOTION] [SUFFIX]``, i.e.
    ``["W", "P", "e7", "e8", "=Q", "(+)"]``. The default vocabulary is fixed
    (special tokens, colors, pieces, the 64 squares, promotions and suffixes),
    so it never needs to be built from data.

    Note: the string ``"B"`` is both the Black color token and the Bishop
    piece token and therefore has a single id; its position inside the
    ``COLOR PIECE FROM TO`` pattern is what tells the two apart when decoding.
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.json"}

    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"

    # Allowed atomic tokens
    COLORS = ["W", "B"]
    PIECES = ["P", "N", "B", "R", "Q", "K"]
    # Common suffixes in the dataset/template utils
    SUFFIXES = [
        "(x)", "(+)", "(+*)",
        "(x+)", "(x+*)",
        "(o)", "(O)",
    ]
    PROMOTIONS = ["=Q", "=R", "=B", "=N"]

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        vocab: Optional[Dict[str, int]] = None,
        **kwargs,
    ):
        """
        Initialize the tokenizer.

        Args:
            vocab_file: Path to a JSON token->id mapping, used when ``vocab``
                is None and the file exists.
            vocab: Explicit token->id mapping; takes precedence over
                ``vocab_file``. It is copied, so later changes made by the
                caller do not desynchronize the tokenizer.
            **kwargs: Forwarded to ``PreTrainedTokenizer``. Special-token
                entries are dropped because this class always supplies its own.
        """
        # Special tokens
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN

        # Avoid "multiple values for keyword" errors when loading from disk.
        for key in ("pad_token", "bos_token", "eos_token", "unk_token"):
            kwargs.pop(key, None)

        if vocab is not None:
            self._vocab = dict(vocab)
        elif vocab_file is not None and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            self._vocab = self._create_default_vocab()
        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}

        # The vocab is set up first because the parent init may query it
        # (e.g. via get_vocab()) while registering the special tokens.
        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )

    # -----------------------
    # Vocab building
    # -----------------------
    def _create_default_vocab(self) -> Dict[str, int]:
        """
        Build the fixed structured vocabulary with contiguous ids 0..N-1.

        Order: special tokens, colors, pieces, squares (a1..h8, rank-major),
        promotions, suffixes.
        """
        special = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]
        squares = [f"{file}{rank}" for rank in "12345678" for file in "abcdefgh"]
        candidates = (
            special
            + self.COLORS
            + self.PIECES
            + squares
            + self.PROMOTIONS
            + self.SUFFIXES
        )
        # "B" occurs in both COLORS and PIECES. Enumerating the raw list would
        # leave a hole in the id range (max id == len(vocab)), which overflows
        # an embedding sized by vocab_size. dict.fromkeys keeps the first
        # occurrence and preserves order, so ids stay contiguous.
        unique_tokens = list(dict.fromkeys(candidates))
        return {tok: i for i, tok in enumerate(unique_tokens)}

    @classmethod
    def build_vocab_from_iterator(cls, iterator, min_frequency: int = 1) -> "ChessTokenizer":
        """
        Return a tokenizer with the fixed structured vocabulary.

        ``iterator`` and ``min_frequency`` are ignored (the vocab does not
        depend on data); they are kept for API compatibility.
        """
        return cls()

    @classmethod
    def build_vocab_from_dataset(
        cls,
        dataset_name: str = "dlouapre/lichess_2025-01_1M",
        split: str = "train",
        column: str = "text",
        min_frequency: int = 1,
        max_samples: Optional[int] = None,
    ) -> "ChessTokenizer":
        """
        Return a tokenizer with the fixed structured vocabulary.

        All arguments are ignored and no dataset is loaded; they are kept for
        API compatibility with the per-move tokenizer.
        """
        return cls()

    # -----------------------
    # Required tokenizer API
    # -----------------------
    @property
    def vocab_size(self) -> int:
        """
        Number of embedding rows required: one past the largest token id.

        Equals ``len(vocab)`` for a contiguous vocabulary, and stays safe for
        previously saved vocab files whose ids contain gaps.
        """
        return max(self._vocab.values(), default=-1) + 1

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the token->id mapping."""
        return dict(self._vocab)

    def _convert_token_to_id(self, token: str) -> int:
        """Map a token to its id, falling back to [UNK] (or 0 if [UNK] is absent)."""
        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 0))

    def _convert_id_to_token(self, index: int) -> str:
        """Map an id to its token, falling back to [UNK] for unknown ids."""
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Write the vocabulary as JSON to ``save_directory``.

        The file is named ``vocab.json``, or ``<prefix>-vocab.json`` when a
        prefix is given. The directory is created if needed.

        Returns:
            A one-element tuple with the path of the written file.
        """
        os.makedirs(save_directory, exist_ok=True)
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)
        return (vocab_file,)

    # -----------------------
    # Tokenization logic
    # -----------------------
    def _tokenize(self, text: str) -> List[str]:
        """
        Split a full game string into atomic subtokens.

        Input format is typically:
            "[BOS] WPe2e4 BPe7e5 WNg1f3 ..."
        or just
            "WPe2e4 BPe7e5 ..."
        Special tokens pass through unchanged; every other whitespace-separated
        word is expanded by ``_split_move_token``.
        """
        special = (self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN)
        out: List[str] = []
        for part in text.strip().split():
            if part in special:
                out.append(part)
            else:
                out.extend(self._split_move_token(part))
        return out

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """
        Reassemble atomic subtokens into a space-separated move string.

        Tokens are grouped as ``COLOR PIECE FROM TO [PROMO] [SUFFIX]``.
        Special tokens are ignored, and a trailing incomplete move (fewer than
        four subtokens) is dropped. If a group does not start a valid move,
        the moves decoded so far are kept and the remaining subtokens are
        appended unchanged.
        """
        special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
        toks = [t for t in tokens if t not in special]

        moves: List[str] = []
        i = 0
        # Need at least 4 tokens for a base move.
        while i + 3 < len(toks):
            color, piece, from_sq, to_sq = toks[i:i + 4]
            if not self._is_move_head(color, piece, from_sq, to_sq):
                # Broken structure: keep what was decoded and pass the rest
                # through as raw subtokens instead of discarding earlier moves.
                moves.extend(toks[i:])
                break
            i += 4
            move = f"{color}{piece}{from_sq}{to_sq}"
            # Optional promotion
            if i < len(toks) and toks[i] in self.PROMOTIONS:
                move += toks[i]
                i += 1
            # Optional suffix
            if i < len(toks) and toks[i] in self.SUFFIXES:
                move += toks[i]
                i += 1
            moves.append(move)
        return " ".join(moves)

    # -----------------------
    # Helpers
    # -----------------------
    def _is_square(self, s: str) -> bool:
        """Return True if ``s`` is a board square such as ``"e4"``."""
        return len(s) == 2 and s[0] in "abcdefgh" and s[1] in "12345678"

    def _is_move_head(self, color: str, piece: str, from_sq: str, to_sq: str) -> bool:
        """Return True if the four parts form a valid ``COLOR PIECE FROM TO`` head."""
        return (
            color in self.COLORS
            and piece in self.PIECES
            and self._is_square(from_sq)
            and self._is_square(to_sq)
        )

    def _split_move_token(self, move: str) -> List[str]:
        """
        Parse one extended-UCI move token into subtokens.

        The first six characters must be ``[W|B][Piece][from][to]``; otherwise
        ``[UNK]`` is returned. A ``=X`` promotion (piece letter
        case-insensitive) and a parenthesised suffix are appended when they
        are recognized; unrecognized promotions/suffixes are dropped.
        """
        if len(move) < 6:
            return [self.UNK_TOKEN]
        color, piece, from_sq, to_sq = move[0], move[1], move[2:4], move[4:6]
        if not self._is_move_head(color, piece, from_sq, to_sq):
            return [self.UNK_TOKEN]

        tokens = [color, piece, from_sq, to_sq]

        # Promotion like "=Q"
        eq = move.find("=")
        if eq != -1 and eq + 1 < len(move):
            promo = "=" + move[eq + 1].upper()
            if promo in self.PROMOTIONS:
                tokens.append(promo)

        # Suffix like "(x)", "(+)", "(x+*)", "(o)", "(O)": everything from the first "(".
        paren = move.find("(")
        if paren != -1:
            suffix = move[paren:]
            if suffix in self.SUFFIXES:
                tokens.append(suffix)
        return tokens