""" Decomposed Chess Tokenizer (v2) for the Chess Challenge. This tokenizer factorizes each move into a small set of reusable tokens: - One token for (color + piece): e.g. "WP", "BN" - One token for the from-square with role suffix: e.g. "e2_f" - One token for the to-square with role suffix: e.g. "e4_t" - Optional promotion token: "q", "r", "b", "n" It is compatible with the teacher evaluator's supported formats: - Standard: "WPe2e4", "BNg8f6", with optional annotations "(x)", "(+)", "(o)/(O)", "(Q)" - Decomposed: "WP e2_f e4_t" - UCI: "e2e4", "e7e8q" - UCI spaced: "e2 e4" The tokenizer parses those inputs and emits the decomposed tokens above. """ from __future__ import annotations import json import os import re from pathlib import Path from typing import Dict, List, Optional from transformers import PreTrainedTokenizer class ChessTokenizer(PreTrainedTokenizer): model_input_names = ["input_ids", "attention_mask"] vocab_files_names = {"vocab_file": "vocab.json"} PAD_TOKEN = "[PAD]" BOS_TOKEN = "[BOS]" EOS_TOKEN = "[EOS]" UNK_TOKEN = "[UNK]" _COLOR_PIECE_RE = re.compile(r"^[WB][PNBRQK]$") _SQUARE_RE = re.compile(r"[a-h][1-8]") _SQUARE_ROLE_RE = re.compile(r"^([a-h][1-8])_([ft])$", re.IGNORECASE) _PLAIN_SQUARE_RE = re.compile(r"^[a-h][1-8]$", re.IGNORECASE) def __init__( self, vocab_file: Optional[str] = None, vocab: Optional[Dict[str, int]] = None, **kwargs, ): self._pad_token = self.PAD_TOKEN self._bos_token = self.BOS_TOKEN self._eos_token = self.EOS_TOKEN self._unk_token = self.UNK_TOKEN # Remove any duplicate special-token entries passed through kwargs to avoid collisions. kwargs.pop("pad_token", None) kwargs.pop("bos_token", None) kwargs.pop("eos_token", None) kwargs.pop("unk_token", None) if vocab is not None: self._vocab = vocab elif vocab_file is not None and os.path.exists(vocab_file): with open(vocab_file, "r", encoding="utf-8") as f: self._vocab = json.load(f) else: self._vocab = self._create_default_vocab() self._ids_to_tokens = {v: k for k, v in self._vocab.items()} super().__init__( pad_token=self._pad_token, bos_token=self._bos_token, eos_token=self._eos_token, unk_token=self._unk_token, **kwargs, ) @classmethod def build_vocab_from_dataset( cls, *_, **__, ) -> "ChessTokenizer2": """ Kept for API compatibility with `train.py`. The v2 tokenizer uses a fixed vocabulary (colors/pieces/squares/promotions), so dataset statistics are not required. """ return cls() def _create_default_vocab(self) -> Dict[str, int]: special_tokens = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN] color_pieces = [ f"{color}{piece}" for color in ("W", "B") for piece in ("P", "N", "B", "R", "Q", "K") ] squares = [f"{file}{rank}" for rank in range(1, 9) for file in "abcdefgh"] square_from = [f"{sq}_f" for sq in squares] square_to = [f"{sq}_t" for sq in squares] promotions = ["q", "r", "b", "n"] # Deterministic order for reproducibility. all_tokens = special_tokens + color_pieces + square_from + square_to + promotions return {tok: idx for idx, tok in enumerate(all_tokens)} @property def vocab_size(self) -> int: return len(self._vocab) def get_vocab(self) -> Dict[str, int]: return dict(self._vocab) def _tokenize(self, text: str) -> List[str]: parts = text.strip().split() if not parts: return [] out: List[str] = [] next_role = "f" # Used only when squares arrive without _f/_t. for part in parts: if part in {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}: out.append(part) next_role = "f" continue # Decomposed color+piece token: "WP", "BN", ... 
            if self._COLOR_PIECE_RE.match(part.upper()):
                out.append(part.upper())
                next_role = "f"
                continue

            # Square with role suffix: "e2_f" / "e4_t"
            m_role = self._SQUARE_ROLE_RE.match(part)
            if m_role:
                sq = m_role.group(1).lower()
                role = m_role.group(2).lower()
                out.append(f"{sq}_{role}")
                next_role = "t" if role == "f" else "f"
                continue

            # Plain square: "e2" (assign role by position)
            if self._PLAIN_SQUARE_RE.match(part):
                sq = part.lower()
                out.append(f"{sq}_{next_role}")
                next_role = "t" if next_role == "f" else "f"
                continue

            # Promotion token as its own chunk: "q", "=Q", "(Q)" etc.
            if self._looks_like_promo_only(part):
                # Bare "q"/"r"/"b"/"n" has no "(...)" or "=" wrapper, so fall back to
                # the stripped letter when _extract_promotion finds nothing.
                promo = self._extract_promotion(part) or part.strip().strip("()=").lower()
                out.append(promo)
                continue

            # Standard / UCI move chunk: "WPe2e4(x+)", "e2e4", "e7e8=Q", ...
            move_tokens = self._tokenize_move_chunk(part)
            if move_tokens:
                out.extend(move_tokens)
                next_role = "f"
                continue

            # Skip pure annotation chunks if they appear separated (rare).
            if re.fullmatch(r"[\(\)\+\*xoO=]+", part):
                continue

            out.append(self.UNK_TOKEN)

        return out

    def _looks_like_promo_only(self, part: str) -> bool:
        part_stripped = part.strip()
        if re.fullmatch(r"[qrbnQRBN]", part_stripped):
            return True
        if re.fullmatch(r"=[qrbnQRBN]", part_stripped):
            return True
        if re.fullmatch(r"\([qrbnQRBN]\)", part_stripped):
            return True
        return False

    def _extract_promotion(self, text: str) -> Optional[str]:
        text_lower = text.lower()
        m = re.search(r"\(([qrbn])\)", text_lower)
        if m:
            return m.group(1)
        m = re.search(r"=([qrbn])", text_lower)
        if m:
            return m.group(1)
        return None

    def _tokenize_move_chunk(self, chunk: str) -> List[str]:
        chunk_stripped = chunk.strip()
        if not chunk_stripped:
            return []

        chunk_lower = chunk_stripped.lower()
        squares = self._SQUARE_RE.findall(chunk_lower)
        if len(squares) < 2:
            return []

        from_sq, to_sq = squares[0], squares[1]

        color_piece = None
        if len(chunk_stripped) >= 2 and self._COLOR_PIECE_RE.match(chunk_stripped[:2].upper()):
            color_piece = chunk_stripped[:2].upper()

        tokens: List[str] = []
        if color_piece:
            tokens.append(color_piece)
        tokens.append(f"{from_sq}_f")
        tokens.append(f"{to_sq}_t")

        # Promotion: look right after the destination square.
        after_to = chunk_lower.find(to_sq)
        if after_to != -1:
            remaining = chunk_lower[after_to + 2 : after_to + 6]
            m = re.search(r"=?([qrbn])", remaining)
            if m:
                tokens.append(m.group(1))

        # Also support dataset-style "(Q)" promotions.
        promo = self._extract_promotion(chunk_stripped)
        if promo and promo not in tokens:
            tokens.append(promo)

        return tokens

    def _convert_token_to_id(self, token: str) -> int:
        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 0))

    def _convert_id_to_token(self, index: int) -> str:
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
        return " ".join(t for t in tokens if t not in special)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
        if not os.path.isdir(save_directory):
            os.makedirs(save_directory, exist_ok=True)
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)
        return (vocab_file,)
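

# Minimal usage sketch (an illustration, not part of the evaluator API): it assumes
# `transformers` is installed and simply exercises the default fixed vocabulary above,
# showing how the standard, decomposed, and UCI input formats map onto the same tokens.
if __name__ == "__main__":
    tok = ChessTokenizer()

    # Standard-format moves, including a "(Q)" promotion annotation; this is expected
    # to yield tokens along the lines of
    # ['WP', 'e2_f', 'e4_t', 'BN', 'g8_f', 'f6_t', 'WP', 'e7_f', 'e8_t', 'q'].
    print(tok.tokenize("WPe2e4 BNg8f6 WPe7e8(Q)"))

    # Decomposed and UCI inputs should round-trip through the same vocabulary.
    ids = tok("WP e2_f e4_t", add_special_tokens=False)["input_ids"]
    print(ids, tok.convert_ids_to_tokens(ids))
    print(tok.tokenize("e7e8q"))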