| # """ | |
| # Improved Chess Tokenizer (Structured) for the Chess Challenge. | |
| # Key idea: | |
| # - Decompose each move into sub-tokens: | |
| # SIDE_W / SIDE_B | |
| # piece (P,N,B,R,Q,K) | |
| # from-square (e2) | |
| # to-square (e4) | |
| # optional flags: CAPTURE, CHECK, MATE, CASTLE_SHORT, CASTLE_LONG | |
| # Important implementation detail: | |
| # - We MUST avoid token-string collisions. In particular, "B" is both: | |
| # - Black side ("B") | |
| # - Bishop piece ("B") | |
| # If we used raw "W"/"B" for side, the vocab dict would overwrite one of them, | |
| # creating non-contiguous IDs and leading to embedding "index out of range". | |
| # """ | |
| # from __future__ import annotations | |
| # import json | |
| # import os | |
| # import re | |
| # from typing import Dict, List, Optional | |
| # from transformers import PreTrainedTokenizer | |
| # class ChessTokenizer(PreTrainedTokenizer): | |
| # model_input_names = ["input_ids", "attention_mask"] | |
| # vocab_files_names = {"vocab_file": "vocab.json"} | |
| # # Special tokens | |
| # PAD_TOKEN = "[PAD]" | |
| # BOS_TOKEN = "[BOS]" | |
| # EOS_TOKEN = "[EOS]" | |
| # UNK_TOKEN = "[UNK]" | |
| # # Side tokens (avoid collision with piece "B" for Bishop) | |
| # SIDE_W = "SIDE_W" | |
| # SIDE_B = "SIDE_B" | |
| # def __init__( | |
| # self, | |
| # vocab_file: Optional[str] = None, | |
| # vocab: Optional[Dict[str, int]] = None, | |
| # **kwargs, | |
| # ): | |
| # self._pad_token = self.PAD_TOKEN | |
| # self._bos_token = self.BOS_TOKEN | |
| # self._eos_token = self.EOS_TOKEN | |
| # self._unk_token = self.UNK_TOKEN | |
| # # Avoid duplicate kwargs when HF loads from disk | |
| # kwargs.pop("pad_token", None) | |
| # kwargs.pop("bos_token", None) | |
| # kwargs.pop("eos_token", None) | |
| # kwargs.pop("unk_token", None) | |
| # if vocab is not None: | |
| # self._vocab = {str(k): int(v) for k, v in vocab.items()} | |
| # elif vocab_file is not None and os.path.exists(vocab_file): | |
| # with open(vocab_file, "r", encoding="utf-8") as f: | |
| # loaded = json.load(f) | |
| # self._vocab = {str(k): int(v) for k, v in loaded.items()} | |
| # else: | |
| # self._vocab = self._create_default_vocab() | |
| # # Ensure IDs are contiguous 0..(len-1) (robust to any old saved vocabs) | |
| # self._vocab = self._normalize_vocab_ids(self._vocab) | |
| # self._ids_to_tokens = {v: k for k, v in self._vocab.items()} | |
| # super().__init__( | |
| # pad_token=self._pad_token, | |
| # bos_token=self._bos_token, | |
| # eos_token=self._eos_token, | |
| # unk_token=self._unk_token, | |
| # **kwargs, | |
| # ) | |
| # @staticmethod | |
| # def _normalize_vocab_ids(vocab: Dict[str, int]) -> Dict[str, int]: | |
| # """ | |
| # Re-map token IDs to be contiguous and deterministic. | |
| # Sort by old id then by token string. | |
| # """ | |
| # items = sorted(vocab.items(), key=lambda kv: (kv[1], kv[0])) | |
| # return {tok: new_id for new_id, (tok, _) in enumerate(items)} | |
| # # ------------------------------------------------------------------ | |
| # # REQUIRED compatibility method (train.py expects this to exist) | |
| # # ------------------------------------------------------------------ | |
| # @classmethod | |
| # def build_vocab_from_dataset( | |
| # cls, | |
| # dataset_name: str = "dlouapre/lichess_2025-01_1M", | |
| # split: str = "train", | |
| # column: str = "text", | |
| # min_frequency: int = 1, | |
| # max_samples: Optional[int] = None, | |
| # ) -> "ChessTokenizer": | |
| # """ | |
| # Compatibility hook. | |
| # For the structured tokenizer, the vocabulary is fixed and does not | |
| # depend on dataset statistics. We keep this method so src/train.py | |
| # (template code) does not need to change. | |
| # """ | |
| # return cls() | |
| # # ------------------------------------------------------------------ | |
| # # Vocabulary construction | |
| # # ------------------------------------------------------------------ | |
| # def _create_default_vocab(self) -> Dict[str, int]: | |
| # special = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN] | |
| # sides = [self.SIDE_W, self.SIDE_B] # no collision with piece tokens | |
| # pieces = ["P", "N", "B", "R", "Q", "K"] | |
| # files = list("abcdefgh") | |
| # ranks = list("12345678") | |
| # squares = [f + r for f in files for r in ranks] # 64 tokens | |
| # flags = ["CAPTURE", "CHECK", "MATE", "CASTLE_SHORT", "CASTLE_LONG"] | |
| # tokens = special + sides + pieces + squares + flags | |
| # return {tok: i for i, tok in enumerate(tokens)} # contiguous by construction | |
| # @property | |
| # def vocab_size(self) -> int: | |
| # return len(self._vocab) | |
| # def get_vocab(self) -> Dict[str, int]: | |
| # return dict(self._vocab) | |
| # # ------------------------------------------------------------------ | |
| # # Tokenization logic | |
| # # ------------------------------------------------------------------ | |
| # MOVE_REGEX = re.compile( | |
| # r""" | |
| # (?P<side>[WB]) | |
| # (?P<piece>[PNBRQK]) | |
| # (?P<from>[a-h][1-8]) | |
| # (?P<to>[a-h][1-8]) | |
| # (?P<suffix>.*)? | |
| # """, | |
| # re.VERBOSE, | |
| # ) | |
| # def _tokenize(self, text: str) -> List[str]: | |
| # out: List[str] = [] | |
| # for move in text.strip().split(): | |
| # out.extend(self._decompose_move(move)) | |
| # return out | |
| # def _decompose_move(self, move: str) -> List[str]: | |
| # m = self.MOVE_REGEX.match(move) | |
| # if not m: | |
| # return [self.UNK_TOKEN] | |
| # side_raw = m.group("side") | |
| # side_tok = self.SIDE_W if side_raw == "W" else self.SIDE_B | |
| # tokens = [ | |
| # side_tok, | |
| # m.group("piece"), | |
| # m.group("from"), | |
| # m.group("to"), | |
| # ] | |
| # suffix = m.group("suffix") or "" | |
| # if "(x)" in suffix: | |
| # tokens.append("CAPTURE") | |
| # if "(+*)" in suffix: | |
| # tokens.append("MATE") | |
| # elif "(+)" in suffix: | |
| # tokens.append("CHECK") | |
| # if "(o)" in suffix: | |
| # tokens.append("CASTLE_SHORT") | |
| # if "(O)" in suffix: | |
| # tokens.append("CASTLE_LONG") | |
| # return tokens | |
| # # ------------------------------------------------------------------ | |
| # # ID conversion | |
| # # ------------------------------------------------------------------ | |
| # def _convert_token_to_id(self, token: str) -> int: | |
| # return self._vocab.get(token, self._vocab[self.UNK_TOKEN]) | |
| # def _convert_id_to_token(self, index: int) -> str: | |
| # return self._ids_to_tokens.get(index, self.UNK_TOKEN) | |
| # def convert_tokens_to_string(self, tokens: List[str]) -> str: | |
| # special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN} | |
| # return " ".join(t for t in tokens if t not in special) | |
| # # ------------------------------------------------------------------ | |
| # # Saving | |
| # # ------------------------------------------------------------------ | |
| # def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple: | |
| # os.makedirs(save_directory, exist_ok=True) | |
| # vocab_file = os.path.join( | |
| # save_directory, | |
| # (filename_prefix + "-" if filename_prefix else "") + "vocab.json", | |
| # ) | |
| # with open(vocab_file, "w", encoding="utf-8") as f: | |
| # json.dump(self._vocab, f, indent=2) | |
| # return (vocab_file,) | |
| # def count_vocab_from_dataset( | |
| # dataset_name: str = "dlouapre/lichess_2025-01_1M", | |
| # split: str = "train", | |
| # column: str = "text", | |
| # max_samples: Optional[int] = 10000, | |
| # ) -> Dict[str, int]: | |
| # """ | |
| # Count token frequencies after structured tokenization. | |
| # (Editor warning about 'datasets' can be ignored if terminal run works.) | |
| # """ | |
| # from collections import Counter | |
| # from datasets import load_dataset | |
| # dataset = load_dataset(dataset_name, split=split) | |
| # if max_samples is not None: | |
| # dataset = dataset.select(range(min(max_samples, len(dataset)))) | |
| # tok = ChessTokenizer() | |
| # counts = Counter() | |
| # for ex in dataset: | |
| # counts.update(tok._tokenize(ex[column])) | |
| # return dict(counts) | |
| """ | |
| Final Structured Chess Tokenizer for the Chess Challenge. | |
| Design goals: | |
| - Strong legality bias | |
| - Fixed, collision-free vocabulary | |
| - HF-compatible (Trainer, save/load, Hub) | |
| - Evaluator-friendly (square extraction still works) | |
| Move decomposition: | |
| PIECE | |
| FROM_<square> | |
| TO_<square> | |
| optional FLAGS | |
| Example: | |
| P FROM_e2 TO_e4 | |
| N FROM_g1 TO_f3 CHECK | |
| K FROM_e1 TO_g1 CASTLE_SHORT | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import os | |
| import re | |
| from typing import Dict, List, Optional | |
| from transformers import PreTrainedTokenizer | |
class ChessTokenizer(PreTrainedTokenizer):
    """Structured chess tokenizer with a fixed, collision-free vocabulary.

    Each move string is decomposed into sub-tokens:

        PIECE  FROM_<square>  TO_<square>  [CAPTURE] [CHECK|MATE] [CASTLE_*]

    Examples:
        ``P FROM_e2 TO_e4``
        ``N FROM_g1 TO_f3 CHECK``
        ``K FROM_e1 TO_g1 CASTLE_SHORT``

    The vocabulary is fixed (independent of dataset statistics), and the
    ``FROM_``/``TO_`` prefixes keep origin and destination squares as
    distinct token IDs, avoiding any token-string collisions.
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.json"}

    # Special tokens
    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"

    # Fixed role prefixes
    FROM_PREFIX = "FROM_"
    TO_PREFIX = "TO_"

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        vocab: Optional[Dict[str, int]] = None,
        **kwargs,
    ):
        """Build the tokenizer from an explicit vocab, a vocab file, or the default.

        Args:
            vocab_file: Path to a ``vocab.json`` written by ``save_vocabulary``.
            vocab: In-memory token -> id mapping; takes precedence over the file.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``.
        """
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN
        # Avoid duplicate special-token kwargs when HF reloads from disk.
        for k in ("pad_token", "bos_token", "eos_token", "unk_token"):
            kwargs.pop(k, None)
        if vocab is not None:
            self._vocab = {str(k): int(v) for k, v in vocab.items()}
        elif vocab_file and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = {str(k): int(v) for k, v in json.load(f).items()}
        else:
            self._vocab = self._create_default_vocab()
        # Re-map IDs to a contiguous 0..len-1 range; a non-contiguous vocab
        # (e.g. from an old saved file) would cause embedding index errors.
        self._vocab = self._normalize_vocab(self._vocab)
        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}
        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )

    @staticmethod
    def _normalize_vocab(vocab: Dict[str, int]) -> Dict[str, int]:
        """Re-map token IDs to be contiguous and deterministic.

        Sorts by (old id, token string) so the result is stable.

        BUGFIX: restored the ``@staticmethod`` decorator. The method is
        invoked as ``self._normalize_vocab(self._vocab)``; without the
        decorator, ``self`` would be bound to ``vocab`` and the real vocab
        would arrive as an unexpected second positional argument, raising
        ``TypeError`` during ``__init__``.
        """
        items = sorted(vocab.items(), key=lambda kv: (kv[1], kv[0]))
        return {tok: i for i, (tok, _) in enumerate(items)}

    # ------------------------------------------------------------
    # Required by train.py (kept for compatibility)
    # ------------------------------------------------------------
    @classmethod
    def build_vocab_from_dataset(
        cls,
        *args,
        **kwargs,
    ) -> "ChessTokenizer":
        """Compatibility hook for train.py.

        The structured vocabulary is fixed, so all dataset arguments are
        ignored and a default-constructed tokenizer is returned.

        BUGFIX: restored the ``@classmethod`` decorator so the hook can be
        called as ``ChessTokenizer.build_vocab_from_dataset(...)``; without
        it, ``cls`` would be bound to the first positional argument instead
        of the class.
        """
        return cls()

    # ------------------------------------------------------------
    # Vocabulary
    # ------------------------------------------------------------
    def _create_default_vocab(self) -> Dict[str, int]:
        """Construct the fixed vocabulary with contiguous IDs.

        Layout: 4 specials + 6 pieces + 64 FROM_ squares + 64 TO_ squares
        + 5 flags = 143 tokens, contiguous by construction.
        """
        special = [
            self.PAD_TOKEN,
            self.BOS_TOKEN,
            self.EOS_TOKEN,
            self.UNK_TOKEN,
        ]
        pieces = ["P", "N", "B", "R", "Q", "K"]
        files = "abcdefgh"
        ranks = "12345678"
        squares = [f + r for f in files for r in ranks]
        from_tokens = [self.FROM_PREFIX + sq for sq in squares]
        to_tokens = [self.TO_PREFIX + sq for sq in squares]
        flags = [
            "CAPTURE",
            "CHECK",
            "MATE",
            "CASTLE_SHORT",
            "CASTLE_LONG",
        ]
        tokens = special + pieces + from_tokens + to_tokens + flags
        return {tok: i for i, tok in enumerate(tokens)}

    @property
    def vocab_size(self) -> int:
        """Number of tokens in the vocabulary.

        BUGFIX: restored the ``@property`` decorator. HF's
        ``PreTrainedTokenizer`` API exposes ``vocab_size`` as a property;
        without it, ``tokenizer.vocab_size`` returns a bound method and
        breaks ``len(tokenizer)`` and model embedding sizing.
        """
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the token -> id mapping."""
        return dict(self._vocab)

    # ------------------------------------------------------------
    # Tokenization
    # ------------------------------------------------------------
    # One move: piece letter, from-square, to-square, then an optional flag
    # suffix such as "(x)", "(+)", "(+*)", "(o)" or "(O)".
    MOVE_REGEX = re.compile(
        r"""
        (?P<piece>[PNBRQK])
        (?P<from>[a-h][1-8])
        (?P<to>[a-h][1-8])
        (?P<suffix>.*)
        """,
        re.VERBOSE,
    )

    def _tokenize(self, text: str) -> List[str]:
        """Split whitespace-separated moves and decompose each one."""
        out: List[str] = []
        for move in text.strip().split():
            out.extend(self._decompose_move(move))
        return out

    def _decompose_move(self, move: str) -> List[str]:
        """Decompose one move string into structured sub-tokens.

        Unparseable moves map to a single ``[UNK]`` token.
        """
        m = self.MOVE_REGEX.search(move)
        if not m:
            return [self.UNK_TOKEN]
        tokens = [
            m.group("piece"),
            self.FROM_PREFIX + m.group("from"),
            self.TO_PREFIX + m.group("to"),
        ]
        suffix = m.group("suffix") or ""
        if "(x)" in suffix:
            tokens.append("CAPTURE")
        # "(+*)" contains "(+)", so mate must be tested before check.
        if "(+*)" in suffix:
            tokens.append("MATE")
        elif "(+)" in suffix:
            tokens.append("CHECK")
        if "(o)" in suffix:
            tokens.append("CASTLE_SHORT")
        if "(O)" in suffix:
            tokens.append("CASTLE_LONG")
        return tokens

    # ------------------------------------------------------------
    # ID conversion
    # ------------------------------------------------------------
    def _convert_token_to_id(self, token: str) -> int:
        """Map a token to its ID, falling back to the [UNK] ID."""
        return self._vocab.get(token, self._vocab[self.UNK_TOKEN])

    def _convert_id_to_token(self, index: int) -> str:
        """Map an ID back to its token, falling back to [UNK]."""
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Join tokens with single spaces, dropping special tokens."""
        special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
        return " ".join(t for t in tokens if t not in special)

    # ------------------------------------------------------------
    # Saving
    # ------------------------------------------------------------
    def save_vocabulary(
        self,
        save_directory: str,
        filename_prefix: Optional[str] = None,
    ) -> tuple:
        """Write the vocab to ``[<prefix>-]vocab.json`` and return its path.

        Returns:
            A 1-tuple containing the written file path (HF convention).
        """
        os.makedirs(save_directory, exist_ok=True)
        path = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, indent=2)
        return (path,)