"""
Optimized Chess Tokenizer for the Chess Challenge.

Strategies for smaller vocabulary:
1. Remove rare moves (high min_frequency threshold)
2. Decompose moves into sub-tokens (piece + squares)
3. Merge similar move patterns

This tokenizer uses a hybrid approach:
- Common moves as single tokens (efficient for frequent patterns)
- Sub-token decomposition for rare moves (better generalization)
"""

from __future__ import annotations

import json
import os
import re
from collections import Counter
from pathlib import Path
from typing import Dict, List, Optional, Tuple

from transformers import PreTrainedTokenizer


class ChessTokenizer(PreTrainedTokenizer):
    """
    Optimized chess tokenizer with smaller vocabulary.

    Uses move decomposition for rare moves to reduce vocabulary size
    while maintaining good coverage. A move that is present in the
    vocabulary is emitted as one token; any other move is split into
    piece / from-square / to-square / optional-suffix sub-tokens when
    ``use_decomposition`` is enabled, and mapped to ``[UNK]`` otherwise.
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.json"}

    # Special tokens
    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"

    # Sub-token markers for decomposed moves
    PIECE_PREFIX = "P:"  # P:WP, P:BN, etc.
    FROM_PREFIX = "F:"  # F:e2, F:g1, etc.
    TO_PREFIX = "T:"  # T:e4, T:f3, etc.
    SUFFIX_PREFIX = "S:"  # S:(x), S:(+), etc.

    # Any vocabulary entry starting with one of these is NOT a full move:
    # "[" marks special tokens, the rest mark decomposition sub-tokens.
    _NON_MOVE_PREFIXES = ("[", PIECE_PREFIX, FROM_PREFIX, TO_PREFIX, SUFFIX_PREFIX)

    # color+piece, from-square, to-square, optional single "(...)" suffix.
    # Compiled once; used with fullmatch() so no anchors are needed.
    _MOVE_PATTERN = re.compile(
        r"([WB][PNBRQK])([a-h][1-8])([a-h][1-8])(\([^)]+\))?"
    )

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        vocab: Optional[Dict[str, int]] = None,
        use_decomposition: bool = True,
        **kwargs,
    ):
        """
        Create a tokenizer from an in-memory vocab, a vocab file, or defaults.

        Args:
            vocab_file: Path to a JSON file mapping token -> id. Used only
                when ``vocab`` is None and the path exists.
            vocab: Token -> id mapping. Takes precedence over ``vocab_file``.
            use_decomposition: Whether out-of-vocabulary moves are split
                into sub-tokens instead of being mapped to ``[UNK]``.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``; any
                pad/bos/eos/unk token overrides are discarded.
        """
        # Initialize special tokens
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN

        # Whether to use sub-token decomposition
        self.use_decomposition = use_decomposition

        # Drop special-token kwargs so they are not passed twice below.
        for name in ("pad_token", "bos_token", "eos_token", "unk_token"):
            kwargs.pop(name, None)

        # Load or create vocabulary
        if vocab is not None:
            # Copy so later mutation of the caller's dict cannot make
            # _vocab disagree with the reverse mapping built below.
            self._vocab = dict(vocab)
        elif vocab_file is not None and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            self._vocab = self._create_default_vocab()

        # Create reverse mapping
        self._ids_to_tokens = {idx: token for token, idx in self._vocab.items()}

        # Build set of full-move tokens for fast lookup
        self._full_move_tokens = {
            token
            for token in self._vocab
            if not token.startswith(self._NON_MOVE_PREFIXES)
        }

        # NOTE(review): vocabulary state is set before the base initializer
        # runs, presumably because the base class queries the vocabulary
        # during init -- confirm against the transformers version in use.
        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )

    def _create_default_vocab(self) -> Dict[str, int]:
        """Create minimal default vocabulary (the four special tokens, ids 0-3)."""
        special_tokens = [
            self.PAD_TOKEN,
            self.BOS_TOKEN,
            self.EOS_TOKEN,
            self.UNK_TOKEN,
        ]
        return {token: idx for idx, token in enumerate(special_tokens)}

    @classmethod
    def _parse_move(cls, move: str) -> Optional[Tuple[str, str, str, str]]:
        """
        Parse a move into components: (color+piece, from_square, to_square, suffix).

        Example:
            "WPe2e4" -> ("WP", "e2", "e4", "")
            "BNg8f6(x)" -> ("BN", "g8", "f6", "(x)")

        Returns:
            The four components (suffix is "" when absent), or None when
            the whole string does not match the move pattern.
        """
        match = cls._MOVE_PATTERN.fullmatch(move)
        if match is None:
            return None
        piece, from_sq, to_sq, suffix = match.groups()
        return (piece, from_sq, to_sq, suffix or "")

    def _decompose_move(self, move: str) -> List[str]:
        """
        Decompose a move into sub-tokens.

        Example:
            "WPe2e4" -> ["P:WP", "F:e2", "T:e4"]
            "BNg8f6(x)" -> ["P:BN", "F:g8", "T:f6", "S:(x)"]

        Returns:
            The sub-token list, or ``[UNK_TOKEN]`` when the move cannot
            be parsed.
        """
        parsed = self._parse_move(move)
        if parsed is None:
            return [self.UNK_TOKEN]

        piece, from_sq, to_sq, suffix = parsed
        tokens = [
            f"{self.PIECE_PREFIX}{piece}",
            f"{self.FROM_PREFIX}{from_sq}",
            f"{self.TO_PREFIX}{to_sq}",
        ]
        if suffix:
            tokens.append(f"{self.SUFFIX_PREFIX}{suffix}")
        return tokens

    def _tokenize(self, text: str) -> List[str]:
        """
        Tokenize whitespace-separated text into tokens.

        Uses full-move tokens for common moves, decomposes rare moves.
        A word whose decomposition contains any sub-token missing from
        the vocabulary is replaced by a single ``[UNK]``.
        """
        tokens: List[str] = []
        for word in text.strip().split():
            if word in self._vocab:
                # Full move, special token, or sub-token: keep as-is.
                tokens.append(word)
                continue

            if self.use_decomposition:
                # Rare move - decompose into sub-tokens
                sub_tokens = self._decompose_move(word)
                if all(t in self._vocab for t in sub_tokens):
                    tokens.extend(sub_tokens)
                    continue

            tokens.append(self.UNK_TOKEN)
        return tokens

    @classmethod
    def _decomposition_tokens(cls) -> List[str]:
        """Return every sub-token used for decomposition, in vocabulary-id order."""
        pieces = [
            "WP", "WN", "WB", "WR", "WQ", "WK",
            "BP", "BN", "BB", "BR", "BQ", "BK",
        ]
        squares = [f"{f}{r}" for f in "abcdefgh" for r in "12345678"]
        suffixes = [
            "(x)", "(+)", "(x+)", "(+*)", "(x+*)",
            "(o)", "(O)", "(Q)", "(R)", "(B)", "(N)",
        ]

        tokens = [f"{cls.PIECE_PREFIX}{p}" for p in pieces]
        # From/to tokens are interleaved per square to keep id layout stable.
        for sq in squares:
            tokens.append(f"{cls.FROM_PREFIX}{sq}")
            tokens.append(f"{cls.TO_PREFIX}{sq}")
        tokens.extend(f"{cls.SUFFIX_PREFIX}{s}" for s in suffixes)
        return tokens

    @staticmethod
    def _select_frequent_moves(
        move_counts: Dict[str, int],
        min_frequency: int,
        reserved: Dict[str, int],
        limit: Optional[int] = None,
    ) -> List[str]:
        """
        Pick the moves that become full tokens.

        Args:
            move_counts: Move -> occurrence count.
            min_frequency: Minimum count for a move to be eligible.
            reserved: Tokens already in the vocabulary; never selected
                again, so their ids cannot be overwritten.
            limit: Maximum number of moves to keep (None = unlimited;
                values below zero are treated as zero).

        Returns:
            The selected moves in alphabetical order.
        """
        eligible = [
            move
            for move, count in move_counts.items()
            if count >= min_frequency and move not in reserved
        ]
        if limit is not None:
            # Rank by frequency BEFORE truncating so the cap drops the
            # rarest moves; ties break alphabetically for determinism.
            eligible.sort(key=lambda move: (-move_counts[move], move))
            eligible = eligible[: max(0, limit)]
        # Alphabetical order gives reproducible id assignment.
        return sorted(eligible)

    @classmethod
    def build_vocab_from_dataset(
        cls,
        dataset_name: str = "dlouapre/lichess_2025-01_1M",
        split: str = "train",
        column: str = "text",
        min_frequency: int = 1000,
        max_vocab_size: int = 1500,
        max_samples: Optional[int] = 200000,
        use_decomposition: bool = True,
    ) -> "ChessTokenizer":
        """
        Build optimized vocabulary from dataset.

        Strategy:
        1. Count all moves
        2. Keep frequent moves as full tokens
        3. Add sub-tokens for decomposition
        4. Limit total vocabulary size

        Args:
            dataset_name: Dataset identifier passed to ``load_dataset``.
            split: Dataset split to read.
            column: Column holding whitespace-separated moves.
            min_frequency: Minimum count for a move to get its own token.
            max_vocab_size: Cap on the total vocabulary size. Special and
                sub-tokens are always included; the cap only limits how
                many full-move tokens are added (most frequent first).
            max_samples: Number of leading examples to scan (None = all).
            use_decomposition: Whether to include sub-tokens and enable
                decomposition on the returned tokenizer.

        Returns:
            A tokenizer initialised with the built vocabulary.
        """
        print(f"Building vocabulary from {dataset_name}...")
        move_counts = count_vocab_from_dataset(
            dataset_name=dataset_name,
            split=split,
            column=column,
            max_samples=max_samples,
        )
        print(f"Total unique moves: {len(move_counts)}")

        # Special tokens take ids 0-3, sub-tokens (if any) follow.
        special_tokens = [cls.PAD_TOKEN, cls.BOS_TOKEN, cls.EOS_TOKEN, cls.UNK_TOKEN]
        sub_tokens = cls._decomposition_tokens() if use_decomposition else []
        vocab = {
            token: idx for idx, token in enumerate(special_tokens + sub_tokens)
        }

        # Fill the remaining slots with the most frequent full moves.
        frequent_moves = cls._select_frequent_moves(
            move_counts,
            min_frequency,
            reserved=vocab,
            limit=max_vocab_size - len(vocab),
        )
        for idx, move in enumerate(frequent_moves, start=len(vocab)):
            vocab[move] = idx

        print(f"Final vocabulary size: {len(vocab)}")
        print(f" - Special tokens: {len(special_tokens)}")
        print(f" - Sub-tokens: {len(sub_tokens)}")
        print(f" - Full moves: {len(frequent_moves)}")

        return cls(vocab=vocab, use_decomposition=use_decomposition)

    @classmethod
    def build_simple_vocab(
        cls,
        dataset_name: str = "dlouapre/lichess_2025-01_1M",
        split: str = "train",
        column: str = "text",
        min_frequency: int = 2000,
        max_samples: Optional[int] = 200000,
    ) -> "ChessTokenizer":
        """
        Build simple vocabulary without decomposition.

        Just keeps frequent moves, maps rare to UNK.

        Args:
            dataset_name: Dataset identifier passed to ``load_dataset``.
            split: Dataset split to read.
            column: Column holding whitespace-separated moves.
            min_frequency: Minimum count for a move to be kept.
            max_samples: Number of leading examples to scan (None = all).

        Returns:
            A tokenizer with decomposition disabled.
        """
        print(f"Building simple vocabulary from {dataset_name}...")
        move_counts = count_vocab_from_dataset(
            dataset_name=dataset_name,
            split=split,
            column=column,
            max_samples=max_samples,
        )

        # Keep only frequent moves
        vocab = {
            cls.PAD_TOKEN: 0,
            cls.BOS_TOKEN: 1,
            cls.EOS_TOKEN: 2,
            cls.UNK_TOKEN: 3,
        }
        frequent_moves = cls._select_frequent_moves(
            move_counts, min_frequency, reserved=vocab
        )
        for idx, move in enumerate(frequent_moves, start=len(vocab)):
            vocab[move] = idx

        print(f"Vocabulary size: {len(vocab)}")
        return cls(vocab=vocab, use_decomposition=False)

    @property
    def vocab_size(self) -> int:
        """Number of entries in the vocabulary."""
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the token -> id mapping."""
        return dict(self._vocab)

    def _convert_token_to_id(self, token: str) -> int:
        """Map a token to its id, falling back to the ``[UNK]`` id (or 0)."""
        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 0))

    def _convert_id_to_token(self, index: int) -> str:
        """Map an id to its token, falling back to ``[UNK]``."""
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """
        Convert tokens back to string, reconstructing decomposed moves.

        Special tokens are dropped. A ``P:`` token starts a decomposed
        move and absorbs an immediately following ``F:``, ``T:`` and
        ``S:`` token (each optional, in that order). Every other token
        is emitted unchanged. Results are joined with single spaces.
        """
        special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
        follow_up_prefixes = (self.FROM_PREFIX, self.TO_PREFIX, self.SUFFIX_PREFIX)

        result: List[str] = []
        i = 0
        while i < len(tokens):
            token = tokens[i]

            if token in special:
                i += 1
                continue

            if token.startswith(self.PIECE_PREFIX):
                # Reconstruct move from sub-tokens
                parts = [token[len(self.PIECE_PREFIX):]]
                for prefix in follow_up_prefixes:
                    if i + 1 < len(tokens) and tokens[i + 1].startswith(prefix):
                        parts.append(tokens[i + 1][len(prefix):])
                        i += 1
                result.append("".join(parts))
            else:
                result.append(token)

            i += 1

        return " ".join(result)

    def save_vocabulary(
        self,
        save_directory: str,
        filename_prefix: Optional[str] = None,
    ) -> Tuple[str]:
        """
        Write the vocabulary as JSON into ``save_directory``.

        Args:
            save_directory: Target directory; created if missing.
            filename_prefix: Optional prefix; the file is named
                ``<prefix>-vocab.json`` when given, else ``vocab.json``.

        Returns:
            A one-element tuple containing the written file path.
        """
        os.makedirs(save_directory, exist_ok=True)

        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)

        return (vocab_file,)


def count_vocab_from_dataset(
    dataset_name: str = "dlouapre/lichess_2025-01_1M",
    split: str = "train",
    column: str = "text",
    max_samples: Optional[int] = 10000,
) -> Dict[str, int]:
    """
    Count token frequencies in dataset.

    Args:
        dataset_name: Dataset identifier passed to ``load_dataset``.
        split: Dataset split to read.
        column: Column holding whitespace-separated tokens.
        max_samples: Number of leading examples to scan (None = all).

    Returns:
        Mapping of token -> number of occurrences.
    """
    # Imported lazily so the tokenizer itself works without `datasets`.
    from datasets import load_dataset

    dataset = load_dataset(dataset_name, split=split)

    if max_samples is not None:
        dataset = dataset.select(range(min(max_samples, len(dataset))))

    token_counts: Counter = Counter()
    for example in dataset:
        token_counts.update(example[column].strip().split())

    return dict(token_counts)