from __future__ import annotations import json import os import re from typing import Dict, List, Optional from transformers import PreTrainedTokenizer, AutoTokenizer class ChessTokenizer(PreTrainedTokenizer): vocab_files_names = {"vocab_file": "vocab.json"} model_input_names = ["input_ids", "attention_mask"] PIECES = ["WP", "WN", "WB", "WR", "WQ", "WK", "BP", "BN", "BB", "BR", "BQ", "BK"] SQUARES = [f"{c}{r}" for c in "abcdefgh" for r in "12345678"] SUFFIXES = ["(-)", "(x)", "(+)", "(#)", "(x+)", "(x#)", "(O)", "(o)", "(Q)", "=Q"] PAD_TOKEN = "[PAD]" BOS_TOKEN = "[BOS]" EOS_TOKEN = "[EOS]" UNK_TOKEN = "[UNK]" def __init__(self, vocab_file: Optional[str] = None, vocab: Optional[Dict[str, int]] = None, **kwargs): # 1. Build or Load Vocab self._vocab = vocab if vocab_file and os.path.exists(vocab_file): with open(vocab_file, "r", encoding="utf-8") as f: self._vocab = json.load(f) if not self._vocab: self._vocab = self._build_split_vocab() self._ids_to_tokens = {v: k for k, v in self._vocab.items()} pad_token = kwargs.pop("pad_token", self.PAD_TOKEN) bos_token = kwargs.pop("bos_token", self.BOS_TOKEN) eos_token = kwargs.pop("eos_token", self.EOS_TOKEN) unk_token = kwargs.pop("unk_token", self.UNK_TOKEN) super().__init__( pad_token=pad_token, bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs, ) def _build_split_vocab(self): tokens = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN] tokens += self.PIECES + self.SQUARES + self.SUFFIXES unique_tokens = sorted(list(set(tokens))) return {t: i for i, t in enumerate(unique_tokens)} def get_vocab(self) -> Dict[str, int]: return dict(self._vocab) @property def vocab_size(self) -> int: return len(self._vocab) def _tokenize(self, text: str) -> List[str]: moves = text.strip().split() tokens = [] pattern = re.compile(r"([WB][PNBRQK])([a-h][1-8])([a-h][1-8])(.*)") for move in moves: match = pattern.match(move) if match: p, s, t, suf = match.groups() tokens.extend([p, s, t]) tokens.append(self._normalize_suffix(suf)) else: tokens.extend(["WP", "a1", "a1", "(-)"]) return tokens def _normalize_suffix(self, suf: str) -> str: suf = suf.strip() if not suf: return "(-)" if suf.startswith("x"): if "+" in suf: return "(x+)" if "#" in suf: return "(x#)" return "(x)" if suf == "+": return "(+)" if suf == "#": return "(#)" if suf in {"O", "o"}: return f"({suf})" if suf in {"Q", "=Q"}: return "=Q" return "(-)" def _convert_token_to_id(self, token: str) -> int: return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN)) def _convert_id_to_token(self, index: int) -> str: return self._ids_to_tokens.get(index, self.UNK_TOKEN) def convert_tokens_to_string(self, tokens: List[str]) -> str: out = [] specials = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN} clean = [t for t in tokens if t not in specials] current_move = "" for i, t in enumerate(clean): if t == "(-)": pass else: current_move += t if (i + 1) % 4 == 0: out.append(current_move) current_move = "" if current_move: out.append(current_move) return " ".join(out) def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple: path = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.json") with open(path, "w") as f: json.dump(self._vocab, f) return (path,) @classmethod def build_vocab_from_dataset(cls, *args, **kwargs): print("Using static 4-Step Split vocabulary.") return cls() # Register AutoTokenizer.register("ChessTokenizer", ChessTokenizer)