""" 4-Step Split Tokenizer Splits moves into: [Piece] -> [From] -> [To] -> [Suffix] Minimizes vocabulary to ~150 tokens. """ from __future__ import annotations import json import os import re from typing import Dict, List, Optional from transformers import PreTrainedTokenizer, AutoTokenizer class ChessTokenizer(PreTrainedTokenizer): vocab_files_names = {"vocab_file": "vocab.json"} model_input_names = ["input_ids", "attention_mask"] # 1. Pieces PIECES = ["WP", "WN", "WB", "WR", "WQ", "WK", "BP", "BN", "BB", "BR", "BQ", "BK"] # 2. Squares SQUARES = [f"{c}{r}" for c in "abcdefgh" for r in "12345678"] # 3. Suffixes (Crucial: (-) represents "No Suffix/Quiet Move") SUFFIXES = ["(-)", "(x)", "(+)", "(#)", "(x+)", "(x#)", "(O)", "(o)", "(Q)", "=Q"] PAD_TOKEN = "[PAD]" BOS_TOKEN = "[BOS]" EOS_TOKEN = "[EOS]" UNK_TOKEN = "[UNK]" # def __init__(self, vocab_file: Optional[str] = None, vocab: Optional[Dict[str, int]] = None, **kwargs): # # 1. Build or Load Vocab first # self._vocab = vocab # if vocab_file and os.path.exists(vocab_file): # with open(vocab_file, "r", encoding="utf-8") as f: # self._vocab = json.load(f) # if not self._vocab: # self._vocab = self._build_split_vocab() # self._ids_to_tokens = {v: k for k, v in self._vocab.items()} # # 2. Call parent init with explicit tokens to prevent auto-add errors # super().__init__( # pad_token=self.PAD_TOKEN, # bos_token=self.BOS_TOKEN, # eos_token=self.EOS_TOKEN, # unk_token=self.UNK_TOKEN, # **kwargs, # ) def __init__(self, vocab_file: Optional[str] = None, vocab: Optional[Dict[str, int]] = None, **kwargs): # 1. Build or Load Vocab self._vocab = vocab if vocab_file and os.path.exists(vocab_file): with open(vocab_file, "r", encoding="utf-8") as f: self._vocab = json.load(f) if not self._vocab: self._vocab = self._build_split_vocab() self._ids_to_tokens = {v: k for k, v in self._vocab.items()} # 2. Handle Special Tokens Safely # We "pop" them from kwargs to prevent the "multiple values" error. # This prioritizes the loaded config (kwargs) if it exists, # falling back to your class constants if it doesn't. pad_token = kwargs.pop("pad_token", self.PAD_TOKEN) bos_token = kwargs.pop("bos_token", self.BOS_TOKEN) eos_token = kwargs.pop("eos_token", self.EOS_TOKEN) unk_token = kwargs.pop("unk_token", self.UNK_TOKEN) # 3. Call parent super().__init__( pad_token=pad_token, bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs, ) def _build_split_vocab(self): tokens = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN] tokens += self.PIECES + self.SQUARES + self.SUFFIXES # Sort and unique to be safe unique_tokens = sorted(list(set(tokens))) return {t: i for i, t in enumerate(unique_tokens)} def get_vocab(self) -> Dict[str, int]: """Required by Hugging Face PreTrainedTokenizer""" return dict(self._vocab) @property def vocab_size(self) -> int: return len(self._vocab) def _tokenize(self, text: str) -> List[str]: moves = text.strip().split() tokens = [] # Regex: (Piece)(Square)(Square)(Optional Suffix) pattern = re.compile(r"([WB][PNBRQK])([a-h][1-8])([a-h][1-8])(.*)") for move in moves: match = pattern.match(move) if match: p, s, t, suf = match.groups() tokens.extend([p, s, t]) tokens.append(suf if suf else "(-)") else: tokens.append(self.UNK_TOKEN) return tokens def _convert_token_to_id(self, token: str) -> int: return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN)) def _convert_id_to_token(self, index: int) -> str: return self._ids_to_tokens.get(index, self.UNK_TOKEN) def convert_tokens_to_string(self, tokens: List[str]) -> str: out = [] specials = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN} clean = [t for t in tokens if t not in specials] current_move = "" for i, t in enumerate(clean): if t == "(-)": pass else: current_move += t # Every 4th token completes a move if (i + 1) % 4 == 0: out.append(current_move) current_move = "" if current_move: out.append(current_move) return " ".join(out) def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple: path = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.json") with open(path, "w") as f: json.dump(self._vocab, f) return (path,) @classmethod def build_vocab_from_dataset(cls, *args, **kwargs): print("Using static 4-Step Split vocabulary.") return cls() # Register AutoTokenizer.register("ChessTokenizer", ChessTokenizer)