# src/tokenizer.py
from __future__ import annotations

import json
import os
from typing import Dict, List, Optional, Tuple

from transformers import PreTrainedTokenizer

# --- Fixed vocab pieces ---
# The order of these lists determines the ids assigned by _build_fixed_vocab,
# so they stay lists; the frozensets below exist only for fast membership tests.
_SQUARES = [f"{file}{rank}" for rank in "12345678" for file in "abcdefgh"]
_PROMOS = ["=Q", "=R", "=B", "=N"]

_SQUARE_SET = frozenset(_SQUARES)
_PROMO_SET = frozenset(_PROMOS)


class SquaresOnlyChessTokenizer(PreTrainedTokenizer):
    """
    Tokenizer designed to MINIMIZE illegal-move formatting issues under the
    provided evaluate.py, WITHOUT modifying evaluate.py.

    Key idea:
      - evaluate.py extracts UCI using move_token[2:4] + move_token[4:6]
      - so decoded move strings must look like:
            "W" + <piece letter> + from_sq + to_sq [+ "=Q/R/B/N"]
        e.g. "WPe2e4", "WNg8f6", "WPe7e8=Q"
      - evaluate.py stops generation on whitespace; we therefore include a
        SPACE token as a move separator.

    Encoding (per move):
        from_sq, to_sq, promo?, " "   (space is a separator token)
    Decoding (per move):
        "WP" + from_sq + to_sq + promo?   (constant prefix)

    We strip all suffixes like (x), (+), (+*), (o)/(O) since the evaluator
    doesn't use them.
    """

    vocab_files_names = {"vocab_file": "vocab.json"}
    model_input_names = ["input_ids", "attention_mask"]

    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"
    MOVE_SEP = " "  # IMPORTANT: whitespace => evaluator stops on separator

    # Constant two-character prefix emitted before each decoded move so that
    # the squares land at string offsets [2:4] and [4:6].
    _MOVE_PREFIX = "WP"

    def __init__(
        self,
        vocab: Optional[Dict[str, int]] = None,
        vocab_file: Optional[str] = None,
        **kwargs,
    ):
        """
        Build the tokenizer.

        Args:
            vocab: Explicit token -> id mapping. Takes precedence over
                ``vocab_file`` when given.
            vocab_file: Path to a JSON vocab. Used only if ``vocab`` is None
                and the file exists; otherwise the fixed built-in vocab is
                used.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``. Any
                pad/bos/eos/unk token overrides are discarded in favour of
                the class constants.
        """
        # Avoid duplicates when loading/saving: the special tokens are always
        # passed explicitly to super().__init__ below.
        for name in ("pad_token", "bos_token", "eos_token", "unk_token"):
            kwargs.pop(name, None)

        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN

        if vocab is not None:
            self._vocab = vocab
        elif vocab_file is not None and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            self._vocab = self._build_fixed_vocab()

        # Reverse mapping; must exist before super().__init__ is called.
        self._ids_to_tokens = {i: t for t, i in self._vocab.items()}

        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )

    # -------------------------
    # Vocab
    # -------------------------
    @classmethod
    def _build_fixed_vocab(cls) -> Dict[str, int]:
        """Return the built-in vocab: specials, separator, 64 squares, 4 promos."""
        toks = [cls.PAD_TOKEN, cls.BOS_TOKEN, cls.EOS_TOKEN, cls.UNK_TOKEN]
        toks += [cls.MOVE_SEP]
        toks += _SQUARES
        toks += _PROMOS
        return {t: i for i, t in enumerate(toks)}

    @property
    def vocab_size(self) -> int:
        """Number of entries in the vocab."""
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return a shallow copy of the token -> id mapping."""
        return dict(self._vocab)

    # -------------------------
    # Helpers: parse / normalize
    # -------------------------
    def _is_control_token(self, token: str) -> bool:
        """Return True if ``token`` is one of PAD/BOS/EOS/UNK."""
        return token in (
            self.PAD_TOKEN,
            self.BOS_TOKEN,
            self.EOS_TOKEN,
            self.UNK_TOKEN,
        )

    @staticmethod
    def _strip_suffixes(token: str) -> str:
        """Drop everything from the first "(" on: "(x)", "(+)", "(+*)", "(o)", "(O)"."""
        return token.split("(", 1)[0]

    @staticmethod
    def _extract_squares_and_promo(
        base: str,
    ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        """
        Parse a suffix-free move string.

        ``base`` is expected to look like:
            WPe2e4
            BNg8f6
            WPe7e8=Q

        Returns:
            ``(from_sq, to_sq, promo)`` where ``promo`` is a token such as
            ``"=Q"`` or None. All three are None when ``base`` is too short
            or either square is not a valid board square.
        """
        if len(base) < 6:
            return None, None, None

        from_sq = base[2:4].lower()
        to_sq = base[4:6].lower()
        if from_sq not in _SQUARE_SET or to_sq not in _SQUARE_SET:
            return None, None, None

        promo = None
        # The promotion marker can only follow the two squares. Read exactly
        # "=" plus one letter so trailing annotations (e.g. "=Q+") do not
        # cause the promotion to be dropped.
        eq_pos = base.find("=", 6)
        if eq_pos != -1:
            candidate = base[eq_pos:eq_pos + 2].upper()
            if candidate in _PROMO_SET:
                promo = candidate

        return from_sq, to_sq, promo

    # -------------------------
    # Tokenization API
    # -------------------------
    def _tokenize(self, text: str) -> List[str]:
        """
        Tokenize a string of whitespace-separated moves.

        Special tokens are preserved if present. Each move becomes:
            from, to, promo?, " "
        A move that cannot be parsed becomes: UNK, " ".
        """
        out: List[str] = []
        for tok in text.strip().split():
            if self._is_control_token(tok):
                out.append(tok)
                continue

            base = self._strip_suffixes(tok)
            from_sq, to_sq, promo = self._extract_squares_and_promo(base)

            if from_sq is None or to_sq is None:
                out.append(self.UNK_TOKEN)
                out.append(self.MOVE_SEP)
                continue

            out.append(from_sq)
            out.append(to_sq)
            if promo is not None:
                out.append(promo)
            out.append(self.MOVE_SEP)

        return out

    def _convert_token_to_id(self, token: str) -> int:
        """Map a token to its id, falling back to the UNK id."""
        return self._vocab.get(token, self._vocab[self.UNK_TOKEN])

    def _convert_id_to_token(self, index: int) -> str:
        """Map an id to its token, falling back to the UNK token."""
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """
        Reconstruct a text compatible with evaluate.py.

        Each move is rendered as: "WP" + from + to + promo?
        Moves are separated by actual spaces (MOVE_SEP token).
        PAD/BOS/EOS/UNK tokens are skipped.
        """
        parts: List[str] = []
        at_move_start = True

        for tok in tokens:
            if self._is_control_token(tok):
                continue

            if tok == self.MOVE_SEP:
                parts.append(" ")
                at_move_start = True
                continue

            if tok in _PROMO_SET:
                parts.append(tok)
                continue

            # Squares, and any unexpected token (should be rare), are emitted
            # verbatim; the first one of each move gets the constant prefix.
            if at_move_start:
                parts.append(self._MOVE_PREFIX)
                at_move_start = False
            parts.append(tok)

        return "".join(parts)

    # -------------------------
    # Saving / loading
    # -------------------------
    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        """
        Write the vocab as JSON into ``save_directory``.

        The file is named "vocab.json", prefixed with ``filename_prefix`` and
        a dash when a prefix is given. The directory is created if missing.

        Returns:
            A one-element tuple containing the written file path.
        """
        os.makedirs(save_directory, exist_ok=True)
        path = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)
        return (path,)


ChessTokenizer = SquaresOnlyChessTokenizer