"""
Custom Chess Tokenizer for the Chess Challenge.

This tokenizer decomposes moves into atomic tokens:
Piece -> Source Square -> Target Square -> Suffixes.

Example: "WPe2e4"     -> ['P', 'e2', 'e4'] (color is implicit to save context)
Example: "BBc8b7(x+)" -> ['B', 'c8', 'b7', '(x)', '(+)']
"""

from __future__ import annotations

import json
import os
import re
from typing import Dict, List, Optional

from transformers import PreTrainedTokenizer

# Patterns are compiled once at import time instead of once per move.
# Full move: color, piece, source square, target square, free-form suffix.
_MOVE_RE = re.compile(r"([WB])([PNBRQK])([a-h][1-8])([a-h][1-8])(.*)")
# Any board square, used by the fallback parser.
_SQUARE_RE = re.compile(r"[a-h][1-8]")
# Content of one parenthesised suffix group, e.g. "x+" in "(x+)".
_SUFFIX_GROUP_RE = re.compile(r"\(([^()]*)\)")
# Individual flags inside a group; "+*" must be tried before the lone "+".
_SUFFIX_FLAG_RE = re.compile(r"\+\*|.")


class ChessTokenizer(PreTrainedTokenizer):
    """Tokenizer that splits chess moves into piece, square and suffix tokens.

    The default vocabulary is fixed: 4 special tokens, 6 pieces, 64 squares
    and 6 suffixes. The color prefix (W/B) of a move is parsed but never
    emitted as a token.
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.json"}

    # Special tokens
    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"

    # Atomic components
    PIECES = ["P", "N", "B", "R", "Q", "K"]
    FILES = "abcdefgh"
    RANKS = "12345678"
    SUFFIXES = ["(x)", "(+)", "(+*)", "(o)", "(O)", "(=)"]

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        vocab: Optional[Dict[str, int]] = None,
        **kwargs,
    ):
        """Build the tokenizer.

        Args:
            vocab_file: Optional path to a JSON token -> id mapping (for
                example the file written by ``save_vocabulary``). Used only
                when ``vocab`` is not given and the file exists.
            vocab: Optional in-memory token -> id mapping. Takes precedence
                over ``vocab_file``.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``. Any
                special-token overrides are discarded; the class constants
                are always used.
        """
        # Initialize special tokens
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN

        # Drop caller-supplied special tokens so they cannot clash with the
        # explicit keyword arguments passed to super().__init__ below.
        for name in ("pad_token", "bos_token", "eos_token", "unk_token"):
            kwargs.pop(name, None)

        # Vocabulary priority: explicit dict > saved file > fixed vocabulary.
        if vocab is not None:
            # Copy so later mutation of the caller's dict cannot desync
            # the forward and reverse mappings.
            self._vocab = dict(vocab)
        elif vocab_file is not None and os.path.isfile(vocab_file):
            with open(vocab_file, encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            self._vocab = self._create_fixed_vocab()

        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}

        # The vocabulary must exist before the base initializer runs.
        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )

    def _create_fixed_vocab(self) -> Dict[str, int]:
        """Create the fixed vocabulary of 80 atomic tokens.

        IDs are assigned in this order: special tokens, pieces, squares
        (a1, a2, ..., h8; file-major), suffixes.
        """
        ordered_tokens = [
            self.PAD_TOKEN,
            self.BOS_TOKEN,
            self.EOS_TOKEN,
            self.UNK_TOKEN,
        ]
        ordered_tokens.extend(self.PIECES)
        # Squares are atomic tokens for better spatial learning.
        ordered_tokens.extend(f"{f}{r}" for f in self.FILES for r in self.RANKS)
        ordered_tokens.extend(self.SUFFIXES)
        return {token: idx for idx, token in enumerate(ordered_tokens)}

    @classmethod
    def build_vocab_from_dataset(
        cls,
        dataset_name: str = "dlouapre/lichess_2025-01_1M",
        **kwargs,
    ) -> "ChessTokenizer":
        """Return a tokenizer with the fixed vocabulary.

        ``dataset_name`` and ``**kwargs`` are accepted for interface
        compatibility and ignored: the vocabulary is fixed, so the dataset
        is never scanned.
        """
        print("Initializing Fixed Vocabulary Tokenizer (Deconstructed Strategy)...")
        return cls()

    @property
    def vocab_size(self) -> int:
        """Number of entries in the vocabulary."""
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the token -> id mapping."""
        return dict(self._vocab)

    def _tokenize(self, text: str) -> List[str]:
        """Decompose a whitespace-separated move string into atomic tokens.

        Input:  "[BOS] WPe2e4 BNg8f6"
        Output: ['[BOS]', 'P', 'e2', 'e4', 'N', 'g8', 'f6']

        Moves matching ``[W|B][piece][source][target][suffix]`` yield piece,
        source and target followed by any recognised suffix tokens. Anything
        else goes through a best-effort fallback, and yields ``[UNK]`` when
        nothing recognisable is found.
        """
        special_tokens = {
            self.PAD_TOKEN,
            self.BOS_TOKEN,
            self.EOS_TOKEN,
            self.UNK_TOKEN,
        }

        tokens: List[str] = []
        # str.split() with no argument already strips and skips empty items.
        for move in text.split():
            # Special tokens pass through untouched (data.py compatibility).
            if move in special_tokens:
                tokens.append(move)
                continue

            match = _MOVE_RE.match(move)
            if match:
                # Group 1 (color) is intentionally ignored.
                _, piece, src, dst, suffix = match.groups()
                tokens.extend((piece, src, dst))
                tokens.extend(self._split_suffix(suffix))
            else:
                tokens.extend(self._tokenize_fallback(move))

        return tokens

    def _split_suffix(self, suffix: str) -> List[str]:
        """Split the trailing part of a move into known suffix tokens.

        An exact vocabulary match is returned as-is. Otherwise each
        parenthesised group is examined: a group that is itself in the
        vocabulary is kept, and a combined group such as "(x+)" is split
        into its flags ("(x)", "(+)"). Flags not in the vocabulary are
        dropped.
        """
        if not suffix:
            return []
        if suffix in self._vocab:
            return [suffix]

        found: List[str] = []
        for inner in _SUFFIX_GROUP_RE.findall(suffix):
            whole = f"({inner})"
            if whole in self._vocab:
                found.append(whole)
                continue
            for flag in _SUFFIX_FLAG_RE.findall(inner):
                candidate = f"({flag})"
                if candidate in self._vocab:
                    found.append(candidate)
        return found

    def _loose_piece(self, move: str) -> Optional[str]:
        """Return the first piece letter in ``move``, or None.

        A leading W/B followed directly by a piece letter is a color prefix
        and is skipped, so "BQd8" yields "Q" rather than reading the color
        "B" as a bishop. "Bxb7" still yields "B".
        """
        has_color_prefix = (
            len(move) > 1 and move[0] in "WB" and move[1] in self.PIECES
        )
        body = move[1:] if has_color_prefix else move
        for char in body:
            if char in self.PIECES:
                return char
        return None

    def _tokenize_fallback(self, move: str) -> List[str]:
        """Best-effort tokenization of a move in an unexpected format.

        Emits at most one piece, then every square found, then every known
        suffix contained in the move (in ``SUFFIXES`` order). Returns
        ``[UNK]`` when none of these is present.
        """
        tokens: List[str] = []

        piece = self._loose_piece(move)
        if piece is not None:
            tokens.append(piece)

        tokens.extend(_SQUARE_RE.findall(move))
        tokens.extend(s for s in self.SUFFIXES if s in move)

        return tokens or [self.UNK_TOKEN]

    def _convert_token_to_id(self, token: str) -> int:
        """Map a token to its id, falling back to the id of ``[UNK]``."""
        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN))

    def _convert_id_to_token(self, index: int) -> str:
        """Map an id to its token, falling back to ``[UNK]``."""
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Join tokens with spaces, omitting the four special tokens."""
        special = {
            self.PAD_TOKEN,
            self.BOS_TOKEN,
            self.EOS_TOKEN,
            self.UNK_TOKEN,
        }
        return " ".join(t for t in tokens if t not in special)

    def save_vocabulary(
        self,
        save_directory: str,
        filename_prefix: Optional[str] = None,
    ) -> tuple:
        """Write the vocabulary as JSON and return the file path in a tuple.

        Args:
            save_directory: Target directory; created if missing.
            filename_prefix: Optional prefix, joined to "vocab.json" with
                a hyphen.

        Returns:
            A one-element tuple holding the path of the written file.
        """
        os.makedirs(save_directory, exist_ok=True)

        prefix = f"{filename_prefix}-" if filename_prefix else ""
        vocab_file = os.path.join(save_directory, prefix + "vocab.json")

        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)

        return (vocab_file,)