| | """ |
| | Chess Tokenizer (Refactored). |
| | |
| | Architecture: |
| | - Splits chess moves into atomic component tokens. |
| | - Structure: [Actor] -> [Source_Square] -> [Target_Square] -> [Promotion?] |
| | - Output format: "WP", "e2_f", "e4_t" |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import json |
| | import os |
| | import re |
| | from typing import Dict, List, Optional, Any, Tuple |
| |
|
| | from transformers import PreTrainedTokenizer |
| |
|
| |
|
class ChessTokenizer(PreTrainedTokenizer):
    """
    A tokenizer that breaks chess moves into explicit actor and coordinate tokens.

    Each whitespace-separated move string such as ``"WPe2e4"`` becomes three
    atomic tokens: actor (``"WP"``), source square (``"e2_f"``) and target
    square (``"e4_t"``), plus a single lowercase promotion token when the
    move carries an ``=Q``-style suffix. Designed for high-precision state
    tracking.
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.json"}

    # Fixed alphabet from which the deterministic vocabulary is generated.
    TOKENS_SPECIAL = ["[PAD]", "[BOS]", "[EOS]", "[UNK]"]
    CHARS_PIECE = "PNBRQK"
    CHARS_COLOR = "WB"
    CHARS_FILE = "abcdefgh"
    CHARS_RANK = "12345678"
    CHARS_PROMO = {"q", "r", "b", "n"}

    # Groups: color, piece, source square, target square, free-form suffix
    # (the suffix may carry a promotion such as "=Q").
    PATTERN_MOVE = re.compile(r"^([WB])([PNBRQK])([a-h][1-8])([a-h][1-8])(.*)$")

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        vocab: Optional[Dict[str, int]] = None,
        **kwargs: Any,
    ):
        """
        Build the tokenizer from an explicit mapping, a vocab file, or the
        generated default vocabulary (in that order of precedence).

        Args:
            vocab_file: Path to a JSON file mapping token -> id.
            vocab: In-memory token -> id mapping; wins over ``vocab_file``.
            **kwargs: Forwarded to ``PreTrainedTokenizer``. Special-token
                overrides are discarded because this tokenizer fixes them.
        """
        # Special tokens are fixed; set the private attributes before
        # super().__init__ so the base class sees consistent values.
        self._pad_token = self.TOKENS_SPECIAL[0]
        self._bos_token = self.TOKENS_SPECIAL[1]
        self._eos_token = self.TOKENS_SPECIAL[2]
        self._unk_token = self.TOKENS_SPECIAL[3]

        # Drop any caller-supplied special-token overrides.
        for token_arg in ("pad_token", "bos_token", "eos_token", "unk_token"):
            kwargs.pop(token_arg, None)

        if vocab:
            # Defensive copy: a caller mutating its dict afterwards must not
            # desynchronize self._vocab from self._id_to_token.
            self._vocab = dict(vocab)
        elif vocab_file and os.path.isfile(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            self._vocab = self._generate_vocabulary()

        # Reverse map must exist before super().__init__, which may invoke
        # vocab-dependent methods (get_vocab, id<->token conversion).
        self._id_to_token = {v: k for k, v in self._vocab.items()}

        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )

    def _generate_vocabulary(self) -> Dict[str, int]:
        """Construct the fixed token -> id dictionary.

        Layout (in id order): special tokens, color+piece actors, all 64
        source squares (``_f``), all 64 target squares (``_t``), then the
        sorted promotion characters.
        """
        token_list = list(self.TOKENS_SPECIAL)

        # Actor tokens: "WP", "WN", ..., "BK".
        token_list.extend(
            f"{c}{p}" for c in self.CHARS_COLOR for p in self.CHARS_PIECE
        )

        # Square tokens, rank-major to keep ids stable and predictable.
        squares = [f"{f}{r}" for r in self.CHARS_RANK for f in self.CHARS_FILE]
        token_list.extend(f"{sq}_f" for sq in squares)
        token_list.extend(f"{sq}_t" for sq in squares)

        # Promotion tokens, sorted for deterministic ids.
        token_list.extend(sorted(self.CHARS_PROMO))

        return {token: idx for idx, token in enumerate(token_list)}

    @property
    def vocab_size(self) -> int:
        """Number of tokens in the vocabulary."""
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the token -> id mapping."""
        return self._vocab.copy()

    def _tokenize(self, text: str) -> List[str]:
        """
        Parse a string of moves into atomic tokens.

        Input: "WPe2e4 BNg8f6"
        Output: ["WP", "e2_f", "e4_t", "BN", "g8_f", "f6_t"]

        Items that are already special tokens pass through unchanged;
        anything that does not match PATTERN_MOVE becomes the unk token.
        """
        if not text:
            return []

        tokens: List[str] = []
        special_set = set(self.TOKENS_SPECIAL)

        for item in text.strip().split():
            if item in special_set:
                tokens.append(item)
                continue

            match = self.PATTERN_MOVE.match(item)
            if not match:
                tokens.append(self.unk_token)
                continue

            color, piece, src, dst, suffix = match.groups()

            tokens.append(f"{color}{piece}")
            tokens.append(f"{src}_f")
            tokens.append(f"{dst}_t")

            # Promotion is encoded as "=<piece>" in the suffix.
            # NOTE(review): a bare promotion suffix without "=" (e.g.
            # "e7e8q") is ignored here — confirm that format never occurs.
            _, sep, rest = suffix.partition("=")
            if sep and rest:
                promo_char = rest[0].lower()
                if promo_char in self.CHARS_PROMO:
                    tokens.append(promo_char)

        return tokens

    def _convert_token_to_id(self, token: str) -> int:
        """Map a token to its id, falling back to the unk id."""
        return self._vocab.get(token, self.unk_token_id)

    def _convert_id_to_token(self, index: int) -> str:
        """Map an id back to its token, falling back to the unk token."""
        return self._id_to_token.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Join tokens into a space-separated string, filtering out specials."""
        special_set = set(self.TOKENS_SPECIAL)
        return " ".join(t for t in tokens if t not in special_set)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Write the vocabulary to ``<save_directory>/[<prefix>-]vocab.json``.

        Returns:
            A 1-tuple with the path of the written file, per the
            ``PreTrainedTokenizer.save_vocabulary`` contract.
        """
        os.makedirs(save_directory, exist_ok=True)

        # Fix: the prefixed name must keep the canonical vocab filename
        # (previously a leaked "(unknown)" placeholder was written instead,
        # producing a file from_pretrained could never locate).
        filename = self.vocab_files_names["vocab_file"]
        if filename_prefix:
            filename = f"{filename_prefix}-{filename}"

        full_path = os.path.join(save_directory, filename)

        with open(full_path, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, indent=2, ensure_ascii=False)

        return (full_path,)

    @classmethod
    def build_vocab_from_dataset(cls, *args: Any, **kwargs: Any) -> "ChessTokenizer":
        """Compatibility method for training pipelines.

        The vocabulary is fixed and dataset-independent, so all arguments
        are accepted and ignored; a default-constructed tokenizer is returned.
        """
        return cls()