from __future__ import annotations import json import re from pathlib import Path from typing import Dict, List, Optional, Tuple from transformers import PreTrainedTokenizer _SQUARE_RE = re.compile(r"[a-h][1-8]") _PROMO_RE = re.compile(r"=([QRBNqrbn])") def _all_squares() -> List[str]: files = "abcdefgh" ranks = "12345678" return [f + r for r in ranks for f in files] class ChessSquareTokenizer(PreTrainedTokenizer): """ We read strings like "WPe2e4" or "BPd7d8=Q" and turn them into tokens. We also insert [EOS] after each move so generation can stop cleanly. """ vocab_files_names = {"vocab_file": "vocab.json"} model_input_names = ["input_ids", "attention_mask"] PAD_TOKEN = "[PAD]" BOS_TOKEN = "[BOS]" EOS_TOKEN = "[EOS]" UNK_TOKEN = "[UNK]" W_TOKEN = "W" B_TOKEN = "B" def __init__( self, vocab_file: Optional[str] = None, vocab: Optional[Dict[str, int]] = None, **kwargs, ): self._pad_token = self.PAD_TOKEN self._bos_token = self.BOS_TOKEN self._eos_token = self.EOS_TOKEN self._unk_token = self.UNK_TOKEN kwargs.pop("pad_token", None) kwargs.pop("bos_token", None) kwargs.pop("eos_token", None) kwargs.pop("unk_token", None) if vocab is not None: self._vocab = dict(vocab) elif vocab_file is not None and Path(vocab_file).exists(): self._vocab = json.loads(Path(vocab_file).read_text(encoding="utf-8")) else: self._vocab = self._build_default_vocab() self._ids_to_tokens = {i: t for t, i in self._vocab.items()} super().__init__( pad_token=self._pad_token, bos_token=self._bos_token, eos_token=self._eos_token, unk_token=self._unk_token, **kwargs, ) @staticmethod def _build_default_vocab() -> Dict[str, int]: special = [ ChessSquareTokenizer.PAD_TOKEN, ChessSquareTokenizer.BOS_TOKEN, ChessSquareTokenizer.EOS_TOKEN, ChessSquareTokenizer.UNK_TOKEN, ] turns = [ChessSquareTokenizer.W_TOKEN, ChessSquareTokenizer.B_TOKEN] squares = _all_squares() promos = ["q", "r", "b", "n"] tokens = special + turns + squares + promos return {t: i for i, t in enumerate(tokens)} @property def vocab_size(self) -> int: return len(self._vocab) def get_vocab(self) -> Dict[str, int]: return dict(self._vocab) def _convert_token_to_id(self, token: str) -> int: return self._vocab.get(token, self._vocab[self.UNK_TOKEN]) def _convert_id_to_token(self, index: int) -> str: return self._ids_to_tokens.get(index, self.UNK_TOKEN) def _tokenize(self, text: str) -> List[str]: # Input is a list of moves separated by spaces. tokens: List[str] = [] for chunk in text.strip().split(): if chunk in (self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN): tokens.append(chunk) continue # Moves in the dataset start with W or B. if chunk and chunk[0] in ("W", "B"): tokens.append(chunk[0]) from_sq, to_sq, promo = self._parse_move_chunk(chunk) if from_sq is None or to_sq is None: tokens.append(self.UNK_TOKEN) continue tokens.append(from_sq) tokens.append(to_sq) if promo is not None: tokens.append(promo) # End-of-move marker. tokens.append(self.EOS_TOKEN) return tokens @staticmethod def _parse_move_chunk(chunk: str) -> Tuple[Optional[str], Optional[str], Optional[str]]: # Grab the first two squares we see. squares = _SQUARE_RE.findall(chunk) if len(squares) < 2: return None, None, None from_sq, to_sq = squares[0], squares[1] # Promotion shows up like "=Q". promo = None m = _PROMO_RE.search(chunk) if m: promo = m.group(1).lower() if promo not in {"q", "r", "b", "n"}: promo = None return from_sq, to_sq, promo def convert_tokens_to_string(self, tokens: List[str]) -> str: # Keep squares and promo tokens, drop PAD for cleanliness. return " ".join(t for t in tokens if t != self.PAD_TOKEN) def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple: save_dir = Path(save_directory) save_dir.mkdir(parents=True, exist_ok=True) fname = (filename_prefix + "-" if filename_prefix else "") + "vocab.json" path = save_dir / fname path.write_text(json.dumps(self._vocab, indent=2), encoding="utf-8") return (str(path),)