""" Custom Chess Tokenizer for the Chess Challenge. Decomposed Chess Tokenizer (coverage, no UNKs in practice for well-formed moves). Each dataset move like: WPe2e4 WBb5c6(x+) WPe7e8=Q(+) is tokenized into: ["WP", "e2_f", "e4_t"] # normal ["WB", "b5_f", "c6_t"] # capture/check ignored ["WP", "e7_f", "e8_t", "q"] # promotion token appended """ from __future__ import annotations import json import os import re from typing import Dict, List, Optional from transformers import PreTrainedTokenizer class ChessTokenizer(PreTrainedTokenizer): model_input_names = ["input_ids", "attention_mask"] vocab_files_names = {"vocab_file": "vocab.json"} PAD_TOKEN = "[PAD]" BOS_TOKEN = "[BOS]" EOS_TOKEN = "[EOS]" UNK_TOKEN = "[UNK]" SQUARE_RE = re.compile(r"([a-h][1-8])([a-h][1-8])") PROMO_RE = re.compile(r"=([QRBN])", re.IGNORECASE) def __init__( self, vocab_file: Optional[str] = None, vocab: Optional[Dict[str, int]] = None, **kwargs, ): self._pad_token = self.PAD_TOKEN self._bos_token = self.BOS_TOKEN self._eos_token = self.EOS_TOKEN self._unk_token = self.UNK_TOKEN # avoid duplicate kwargs on load kwargs.pop("pad_token", None) kwargs.pop("bos_token", None) kwargs.pop("eos_token", None) kwargs.pop("unk_token", None) if vocab is not None: self._vocab = dict(vocab) elif vocab_file is not None and os.path.exists(vocab_file): with open(vocab_file, "r", encoding="utf-8") as f: self._vocab = json.load(f) else: self._vocab = self._build_fixed_vocab() self._ids_to_tokens = {v: k for k, v in self._vocab.items()} super().__init__( pad_token=self._pad_token, bos_token=self._bos_token, eos_token=self._eos_token, unk_token=self._unk_token, **kwargs, ) def _build_fixed_vocab(self) -> Dict[str, int]: special = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN] # color+piece tokens pieces = ["P", "N", "B", "R", "Q", "K"] cp = [f"W{p}" for p in pieces] + [f"B{p}" for p in pieces] # squares with role suffix files = "abcdefgh" ranks = "12345678" squares = [f"{f}{r}" for f in files for r in ranks] from_tokens = [f"{sq}_f" for sq in squares] to_tokens = [f"{sq}_t" for sq in squares] # promotions as separate token (lowercase) promo = ["q", "r", "b", "n"] tokens = special + cp + from_tokens + to_tokens + promo return {t: i for i, t in enumerate(tokens)} @classmethod def build_vocab_from_dataset( cls, dataset_name: str = "dlouapre/lichess_2025-01_1M", split: str = "train", column: str = "text", min_frequency: int = 0, max_samples: Optional[int] = None, save_dir: Optional[str] = None, ) -> "ChessTokenizer": tok = cls() if save_dir is not None: tok.save_pretrained(save_dir) return tok @property def vocab_size(self) -> int: return len(self._vocab) def get_vocab(self) -> Dict[str, int]: return dict(self._vocab) def _tokenize(self, text: str) -> List[str]: text = text.strip() if not text: return [] raw = text.split() out: List[str] = [] for tok in raw: # keep explicit BOS/EOS if they appear in text if tok in (self.BOS_TOKEN, self.EOS_TOKEN, self.PAD_TOKEN, self.UNK_TOKEN): out.append(tok) continue # Expect at least color+piece at positions 0,1 if len(tok) < 6: out.append(self.UNK_TOKEN) continue color = tok[0] # W/B piece = tok[1] # P/N/B/R/Q/K cp = f"{color}{piece}" # Find squares anywhere in token (works even with suffixes like (x+), (o), etc.) 
            m = self.SQUARE_RE.search(tok)
            if not m:
                out.append(self.UNK_TOKEN)
                continue
            from_sq, to_sq = m.group(1), m.group(2)
            out.extend([cp, f"{from_sq}_f", f"{to_sq}_t"])
            # Promotion, e.g. "=Q" -> "q".
            pm = self.PROMO_RE.search(tok)
            if pm:
                out.append(pm.group(1).lower())
        return out

    def _convert_token_to_id(self, token: str) -> int:
        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 0))

    def _convert_id_to_token(self, index: int) -> str:
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
        return " ".join(t for t in tokens if t not in special)

    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        os.makedirs(save_directory, exist_ok=True)
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)
        return (vocab_file,)


def count_vocab_from_dataset(
    dataset_name: str = "dlouapre/lichess_2025-01_1M",
    split: str = "train",
    column: str = "text",
    max_samples: Optional[int] = 10000,
) -> Dict[str, int]:
    """
    With a fixed-vocab tokenizer, "count vocab from dataset" is not very
    meaningful. Kept for API compatibility; returns the fixed vocab.
    """
    return ChessTokenizer().get_vocab()


# Make the tokenizer discoverable through the AutoTokenizer machinery.
AutoTokenizer.register("ChessTokenizer", ChessTokenizer)
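

if __name__ == "__main__":
    # Illustrative smoke test, not part of the tokenizer API: a minimal
    # sketch showing the decomposition scheme end to end. The sample
    # moves reuse the example formats from the module docstring; expected
    # outputs are stated under that assumption.
    tok = ChessTokenizer()
    sample = "WPe2e4 WBb5c6(x+) WPe7e8=Q(+)"

    tokens = tok.tokenize(sample)
    print(tokens)
    # Expected:
    # ['WP', 'e2_f', 'e4_t', 'WB', 'b5_f', 'c6_t', 'WP', 'e7_f', 'e8_t', 'q']

    # Round-trip: tokens -> ids -> tokens -> string (specials stripped).
    ids = tok.convert_tokens_to_ids(tokens)
    print(tok.convert_tokens_to_string(tok.convert_ids_to_tokens(ids)))
    # Expected: 'WP e2_f e4_t WB b5_f c6_t WP e7_f e8_t q'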