""" Custom Chess Tokenizer for the Chess Challenge. This tokenizer treats each move as a sequence of structured tokens derived from the extended UCI notation from the Lichess dataset (e.g., WPe2e4, BNg8f6). The dataset format uses: - W/B prefix for White/Black - Piece letter: P=Pawn, N=Knight, B=Bishop, R=Rook, Q=Queen, K=King - Source and destination squares (e.g., e2e4) - Special suffixes: (x)=capture, (+)=check, (+*)=checkmate, (o)/(O)=castling """ from __future__ import annotations import json import os import re from typing import Dict, List, Optional, Sequence, Union from transformers import PreTrainedTokenizer _MOVE_RE = re.compile( r"^(?P[WB])" r"(?P[PNBRQK])" r"(?P[a-h][1-8])" r"(?P[a-h][1-8])" r"(?P.*)$" ) class ChessTokenizer(PreTrainedTokenizer): """ A structured tokenizer for chess moves. Each move is decomposed into: SIDE_(W/B), PIECE_(P/N/B/R/Q/K), SQ_, SQ_, and optional flags: CAPTURE, CHECK, MATE, CASTLE, PROMO_(Q/R/B/N). This avoids UNK explosions when using a move-as-token vocabulary. """ model_input_names = ["input_ids", "attention_mask"] vocab_files_names = {"vocab_file": "vocab.json"} # Special tokens PAD_TOKEN = "[PAD]" BOS_TOKEN = "[BOS]" EOS_TOKEN = "[EOS]" UNK_TOKEN = "[UNK]" # Fixed token set SIDE_W = "SIDE_W" SIDE_B = "SIDE_B" PIECES = ["P", "N", "B", "R", "Q", "K"] PROMO_PREFIX = "PROMO_" CAPTURE = "CAPTURE" CHECK = "CHECK" MATE = "MATE" CASTLE = "CASTLE" def __init__( self, vocab_file: Optional[str] = None, vocab: Optional[Dict[str, int]] = None, **kwargs, ): kwargs.pop("pad_token", None) kwargs.pop("bos_token", None) kwargs.pop("eos_token", None) kwargs.pop("unk_token", None) self._pad_token = self.PAD_TOKEN self._bos_token = self.BOS_TOKEN self._eos_token = self.EOS_TOKEN self._unk_token = self.UNK_TOKEN if vocab is not None: self._vocab = vocab elif vocab_file is not None and os.path.exists(vocab_file): with open(vocab_file, "r", encoding="utf-8") as f: self._vocab = json.load(f) else: self._vocab = self._build_fixed_vocab() self._ids_to_tokens = {v: k for k, v in self._vocab.items()} super().__init__( pad_token=self._pad_token, bos_token=self._bos_token, eos_token=self._eos_token, unk_token=self._unk_token, **kwargs, ) def _build_fixed_vocab(self) -> Dict[str, int]: special = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN] sides = [self.SIDE_W, self.SIDE_B] pieces = [f"PIECE_{p}" for p in self.PIECES] squares = [f"SQ_{file}{rank}" for file in "abcdefgh" for rank in "12345678"] promos = [f"{self.PROMO_PREFIX}{p}" for p in ["Q", "R", "B", "N"]] flags = [self.CAPTURE, self.CHECK, self.MATE, self.CASTLE] tokens = special + sides + pieces + squares + promos + flags return {tok: i for i, tok in enumerate(tokens)} @classmethod def build_vocab_from_dataset(cls, *args, **kwargs) -> "ChessTokenizer": """ Kept for API compatibility with the template training script. This tokenizer uses a fixed vocabulary (no dataset-dependent pruning). """ return cls() @classmethod def build_vocab_from_iterator(cls, *args, **kwargs) -> "ChessTokenizer": """ Kept for API compatibility. This tokenizer uses a fixed vocabulary. """ return cls() @property def vocab_size(self) -> int: return len(self._vocab) def get_vocab(self) -> Dict[str, int]: return dict(self._vocab) def _tokenize(self, text: str) -> List[str]: tokens: List[str] = [] moves = text.strip().split() for mv in moves: tokens.extend(self._tokenize_move(mv)) return tokens def _tokenize_move(self, move: str) -> List[str]: m = _MOVE_RE.match(move) if not m: return [self.UNK_TOKEN] side = m.group("side") piece = m.group("piece") src = m.group("src") dst = m.group("dst") rest = m.group("rest") or "" out: List[str] = [] out.append(self.SIDE_W if side == "W" else self.SIDE_B) out.append(f"PIECE_{piece}") out.append(f"SQ_{src}") out.append(f"SQ_{dst}") promo = self._parse_promotion(rest) if promo is not None: out.append(f"{self.PROMO_PREFIX}{promo}") if "(x)" in rest or "x" in rest: out.append(self.CAPTURE) if "(+*)" in rest or "++" in rest or "#" in rest: out.append(self.MATE) elif "(+)" in rest or "+" in rest: out.append(self.CHECK) if "(o)" in rest or "(O)" in rest or "O-O" in rest: out.append(self.CASTLE) return out def _parse_promotion(self, rest: str) -> Optional[str]: m = re.search(r"=([QRBNqrbn])", rest) if m: return m.group(1).upper() return None def _convert_token_to_id(self, token: str) -> int: return self._vocab.get(token, self._vocab[self.UNK_TOKEN]) def _convert_id_to_token(self, index: int) -> str: return self._ids_to_tokens.get(index, self.UNK_TOKEN) def convert_tokens_to_string(self, tokens: List[str]) -> str: special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN} out: List[str] = [] for t in tokens: if t in special: continue out.append(t) return " ".join(out) def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple: if not os.path.isdir(save_directory): os.makedirs(save_directory, exist_ok=True) vocab_file = os.path.join( save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.json", ) with open(vocab_file, "w", encoding="utf-8") as f: json.dump(self._vocab, f, ensure_ascii=False, indent=2) return (vocab_file,) def decode(self, token_ids: Union[int, Sequence[int]], skip_special_tokens: bool = False, **kwargs) -> str: if isinstance(token_ids, int): ids = [token_ids] elif "torch" in str(type(token_ids)): ids = token_ids.detach().cpu().flatten().tolist() else: ids = list(token_ids) toks = [self._convert_id_to_token(i) for i in ids] if skip_special_tokens: special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN} toks = [t for t in toks if t not in special] return self.convert_tokens_to_string(toks)