""" Chess Tokenizer (Refactored). Architecture: - Splits chess moves into atomic component tokens. - Structure: [Actor] -> [Source_Square] -> [Target_Square] -> [Promotion?] - Output format: "WP", "e2_f", "e4_t" """ from __future__ import annotations import json import os import re from typing import Dict, List, Optional, Any, Tuple from transformers import PreTrainedTokenizer class ChessTokenizer(PreTrainedTokenizer): """ A tokenizer that breaks chess moves into explicit actor and coordinate tokens. Designed for high-precision state tracking. """ model_input_names = ["input_ids", "attention_mask"] vocab_files_names = {"vocab_file": "vocab.json"} # --- Configuration Constants --- TOKENS_SPECIAL = ["[PAD]", "[BOS]", "[EOS]", "[UNK]"] CHARS_PIECE = "PNBRQK" CHARS_COLOR = "WB" CHARS_FILE = "abcdefgh" CHARS_RANK = "12345678" CHARS_PROMO = {"q", "r", "b", "n"} # Regex to validate and parse standard Lichess moves (e.g., WPe2e4) # Group 1: Color, 2: Piece, 3: Source, 4: Target, 5: Suffix PATTERN_MOVE = re.compile(r"^([WB])([PNBRQK])([a-h][1-8])([a-h][1-8])(.*)$") def __init__( self, vocab_file: Optional[str] = None, vocab: Optional[Dict[str, int]] = None, **kwargs: Any, ): # Initialize special tokens for the parent class self._pad_token = self.TOKENS_SPECIAL[0] self._bos_token = self.TOKENS_SPECIAL[1] self._eos_token = self.TOKENS_SPECIAL[2] self._unk_token = self.TOKENS_SPECIAL[3] # Clean kwargs to prevent collisions for token_arg in ["pad_token", "bos_token", "eos_token", "unk_token"]: kwargs.pop(token_arg, None) # 1. Load Vocabulary if vocab: self._vocab = vocab elif vocab_file and os.path.isfile(vocab_file): with open(vocab_file, "r", encoding="utf-8") as f: self._vocab = json.load(f) else: self._vocab = self._generate_vocabulary() # 2. Build ID-to-Token Map self._id_to_token = {v: k for k, v in self._vocab.items()} super().__init__( pad_token=self._pad_token, bos_token=self._bos_token, eos_token=self._eos_token, unk_token=self._unk_token, **kwargs, ) def _generate_vocabulary(self) -> Dict[str, int]: """Constructs the fixed dictionary of tokens.""" token_list = list(self.TOKENS_SPECIAL) # A. Actor Tokens (e.g., WP, BN) token_list.extend( f"{c}{p}" for c in self.CHARS_COLOR for p in self.CHARS_PIECE ) # B. Coordinate Tokens (Source & Target) squares = [f"{f}{r}" for r in self.CHARS_RANK for f in self.CHARS_FILE] token_list.extend(f"{sq}_f" for sq in squares) # From token_list.extend(f"{sq}_t" for sq in squares) # To # C. Promotion Tokens (Sorted for consistency) token_list.extend(sorted(self.CHARS_PROMO)) return {token: idx for idx, token in enumerate(token_list)} @property def vocab_size(self) -> int: return len(self._vocab) def get_vocab(self) -> Dict[str, int]: return self._vocab.copy() def _tokenize(self, text: str) -> List[str]: """ Parses a string of moves into atomic tokens. Input: "WPe2e4 BNg8f6" Output: ["WP", "e2_f", "e4_t", "BN", "g8_f", "f6_t"] """ if not text: return [] tokens = [] raw_items = text.strip().split() special_set = set(self.TOKENS_SPECIAL) for item in raw_items: # Pass through special tokens immediately if item in special_set: tokens.append(item) continue # Parse move structure match = self.PATTERN_MOVE.match(item) if not match: tokens.append(self.unk_token) continue # Deconstruct parts color, piece, src, dst, suffix = match.groups() # 1. Actor (Who) tokens.append(f"{color}{piece}") # 2. Origin (Where from) tokens.append(f"{src}_f") # 3. Destination (Where to) tokens.append(f"{dst}_t") # 4. Promotion (Transformation) # Check for suffixes like "=Q" or trailing chars if suffix: if "=" in suffix: # Look for the character immediately following '=' eq_idx = suffix.find("=") if eq_idx + 1 < len(suffix): promo_char = suffix[eq_idx + 1].lower() if promo_char in self.CHARS_PROMO: tokens.append(promo_char) return tokens def _convert_token_to_id(self, token: str) -> int: return self._vocab.get(token, self.unk_token_id) def _convert_id_to_token(self, index: int) -> str: return self._id_to_token.get(index, self.unk_token) def convert_tokens_to_string(self, tokens: List[str]) -> str: """Joins tokens into a space-separated string, filtering out specials.""" special_set = set(self.TOKENS_SPECIAL) return " ".join(t for t in tokens if t not in special_set) def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: if not os.path.exists(save_directory): os.makedirs(save_directory, exist_ok=True) filename = "vocab.json" if filename_prefix: filename = f"{filename_prefix}-{filename}" full_path = os.path.join(save_directory, filename) with open(full_path, "w", encoding="utf-8") as f: json.dump(self._vocab, f, indent=2, ensure_ascii=False) return (full_path,) @classmethod def build_vocab_from_dataset(cls, *args: Any, **kwargs: Any) -> "ChessTokenizer": """Compatibility method for training pipelines.""" return cls()