"""
Coordinate Chess Tokenizer (Vocab Size = 72).

Compatible with Hugging Face AutoTokenizer and existing Evaluation scripts.
"""

from __future__ import annotations

import json
import os
import re
from typing import Dict, List, Optional, Tuple, Union

from transformers import PreTrainedTokenizer


class ChessTokenizer(PreTrainedTokenizer):
    """Tokenizer that splits chess moves into square and promotion tokens.

    Each move becomes ``<from-square> <to-square> [<promotion>]``, e.g.
    ``"WPe2e4"`` -> ``["e2", "e4"]`` and ``"e7e8q"`` -> ``["e7", "e8", "q"]``.
    The default vocabulary has 72 entries: 4 special tokens, the 4 promotion
    pieces and the 64 board squares.
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.json"}

    # Special tokens
    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"

    # Captures the from/to squares of a move embedded in a UCI or extended-UCI
    # chunk ("e2e4", "WPe2e4"), plus an optional promotion piece written as
    # "q", "Q" or "=Q" ("e7e8q", "WPe7e8=Q").
    # The negative lookahead keeps a file letter that starts the *next* square
    # (the "b" of "b7" in "e2e4b7b5") from being mistaken for a promotion.
    MOVE_REGEX = re.compile(r"([a-h][1-8])([a-h][1-8])(?:=?([qrbnQRBN])(?![1-8]))?")

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        **kwargs,
    ):
        """Create the tokenizer.

        Args:
            vocab_file: Optional path to a ``vocab.json`` such as the one
                written by :meth:`save_vocabulary`. When it is ``None`` or the
                path does not exist, the fixed 72-token vocabulary is built.
            **kwargs: Forwarded to ``PreTrainedTokenizer``. Any ``pad_token``,
                ``bos_token``, ``eos_token`` or ``unk_token`` entries are
                discarded, because this class fixes its special tokens.
        """
        # Initialize special tokens
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN

        # Clean kwargs to avoid duplication errors during loading: the special
        # tokens are passed explicitly to super().__init__() below.
        for key in ("pad_token", "bos_token", "eos_token", "unk_token"):
            kwargs.pop(key, None)

        # Load the saved vocabulary when loading from a checkpoint; otherwise
        # build the fixed one. This happens before super().__init__()
        # (NOTE(review): presumably because the base class may query the vocab
        # during its own init -- confirm for the installed transformers version).
        if vocab_file is not None and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            self._vocab = self._create_fixed_vocab()

        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}

        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )

    def _create_fixed_vocab(self) -> Dict[str, int]:
        """Return the deterministic 72-token vocabulary.

        Layout:
            * ids 0-3: special tokens ``[PAD]``, ``[BOS]``, ``[EOS]``, ``[UNK]``
            * ids 4-7: promotion pieces ``q``, ``r``, ``b``, ``n``
            * ids 8-71: squares in rank-major order (a1, b1, ..., h1, a2, ..., h8)
        """
        tokens = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]
        tokens += ["q", "r", "b", "n"]
        tokens += [file + rank for rank in "12345678" for file in "abcdefgh"]
        return {token: idx for idx, token in enumerate(tokens)}

    @property
    def vocab_size(self) -> int:
        """Number of entries in the vocabulary."""
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the token -> id mapping."""
        return dict(self._vocab)

    def _tokenize(self, text: str) -> List[str]:
        """Split ``text`` into square, promotion and special tokens.

        Each whitespace-separated chunk is handled independently:

        * an exact special token (e.g. ``"[BOS]"``) is kept as-is;
        * otherwise, the first move found in the chunk is emitted as its
          from/to squares plus an optional lowercase promotion piece
          (``"WPe2e4"`` -> ``["e2", "e4"]``,
          ``"WPe7e8=Q"`` -> ``["e7", "e8", "q"]``);
        * a chunk with no move that is itself a vocabulary token (e.g. an
          isolated ``"a1"``) is kept;
        * anything else becomes ``[UNK]``.

        Only the first move of a chunk is used; the rest of it is ignored.
        """
        special_set = {self.BOS_TOKEN, self.EOS_TOKEN, self.PAD_TOKEN, self.UNK_TOKEN}
        tokens: List[str] = []

        for chunk in text.split():
            if chunk in special_set:
                tokens.append(chunk)
                continue

            match = self.MOVE_REGEX.search(chunk)
            if match:
                start_sq, end_sq, promotion = match.groups()
                tokens.extend((start_sq, end_sq))
                if promotion:
                    # The vocabulary only holds lowercase promotion pieces.
                    tokens.append(promotion.lower())
            elif chunk in self._vocab:
                tokens.append(chunk)
            else:
                tokens.append(self.UNK_TOKEN)

        return tokens

    def _convert_token_to_id(self, token: str) -> int:
        """Map a token to its id, falling back to the ``[UNK]`` id."""
        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN))

    def _convert_id_to_token(self, index: int) -> str:
        """Map an id to its token, falling back to ``[UNK]``."""
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Join tokens into a space-separated string.

        All special tokens (including ``[UNK]``) are dropped. Coordinates are
        separated by spaces, e.g. ``["e2", "e4"]`` -> ``"e2 e4"``; per the
        original note, evaluate.py handles these spaces via regex.
        """
        special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
        return " ".join(t for t in tokens if t not in special)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Write the vocabulary to ``[<prefix>-]vocab.json`` in ``save_directory``.

        The directory is created if needed. Returns a one-element tuple with
        the path of the written file, as Hugging Face's saving logic expects.
        """
        os.makedirs(save_directory, exist_ok=True)
        filename = f"{filename_prefix}-vocab.json" if filename_prefix else "vocab.json"
        vocab_file = os.path.join(save_directory, filename)

        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)

        return (vocab_file,)

    @classmethod
    def build_vocab_from_dataset(
        cls,
        dataset_name: str = "dlouapre/lichess_2025-01_1M",
        split: str = "train",
        column: str = "text",
        min_frequency: int = 500,  # Ignored
        max_samples: Optional[int] = 100000,  # Ignored
    ) -> "ChessTokenizer":
        """Return a tokenizer with the fixed vocabulary.

        This is a stand-in that keeps the API train.py calls. The vocabulary
        is fixed, so no dataset is scanned and every argument is ignored.
        """
        tokenizer = cls()
        print(
            f"Coordinate Tokenizer: Using fixed vocabulary (size {tokenizer.vocab_size}). "
            "Ignoring dataset scan."
        )
        return tokenizer