# chess-bee-1 / tokenizer.py
# Chess Challenge submission by FAdrien (revision b6a6ce6)
"""
Custom Chess Tokenizer for the Chess Challenge.
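This tokenizer splits each move, written in the extended UCI notation of the
Lichess dataset (e.g., WPe2e4, BNg8f6), into sub-tokens: a side token ([W]/[B]),
the piece letter, the source square, the destination square and, for promotions,
the lowercase promotion piece (q, r, b, n). The suffix annotations listed below
are discarded before splitting.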
The dataset format uses:
- W/B prefix for White/Black
- Piece letter: P=Pawn, N=Knight, B=Bishop, R=Rook, Q=Queen, K=King
- Source and destination squares (e.g., e2e4)
- Special suffixes: (x)=capture, (+)=check, (+*)=checkmate, (o)/(O)=castling
"""
from __future__ import annotations
import json
import os
from pathlib import Path
from typing import Dict, List, Optional
from transformers import PreTrainedTokenizer
import re
# Pattern used to split one move into named groups. For example, "WBb5c6(x)"
# yields color="W", piece="B", src="b5", dst="c6", suffix="(x)".
# The pieces below concatenate to a single raw regex string.
TOKEN_PATTERN_REGEX = (
    r'^(?P<color>[WB])'      # side to move: White or Black
    r'(?P<piece>[PNBRQK])'   # moving piece letter
    r'(?P<src>[a-h][1-8])'   # source square
    r'(?P<dst>[a-h][1-8])'   # destination square
    r'(?P<suffix>.*)$'       # whatever remains (annotations, promotion, ...)
)
TOKEN_PATTERN = re.compile(TOKEN_PATTERN_REGEX)
# Annotation characters the tokenizer ignores: capture (x), check (+),
# checkmate marker (*, or # if present), castling (o / O) and en passant (E).
# Every key maps to the empty string. Insertion order matters: the
# parenthesised marks are deleted first, and the empty "()" they leave behind
# is deleted last.
REPLACE_RULES = dict.fromkeys(["x", "+", "*", "#", "o", "O", "E", "()"], "")
def normalize(text: str) -> str:
    """Return *text* stripped of surrounding whitespace, with every
    REPLACE_RULES key replaced by its value, applied in the dict's order."""
    cleaned = text.strip()
    for old in REPLACE_RULES:
        cleaned = cleaned.replace(old, REPLACE_RULES[old])
    return cleaned
def decompose_into_groups(move: str) -> List[str]:
    """
    Split one move into its TOKEN_PATTERN groups.

    Args:
        move: A single move such as "WBb5c6(x)".

    Returns:
        [color, piece, src, dst, suffix], e.g. ["W", "B", "b5", "c6", "(x)"].

    Raises:
        ValueError: if ``move`` does not match TOKEN_PATTERN.
    """
    match = TOKEN_PATTERN.match(move)
    if match is None:
        # Without this check a malformed move surfaced as an opaque
        # "'NoneType' object has no attribute 'group'" AttributeError.
        raise ValueError(f"Cannot parse chess move {move!r}")
    return [match.group(name) for name in ("color", "piece", "src", "dst", "suffix")]
def extract_promotion(suffix: str) -> Optional[str]:
    """
    Return the promotion piece found in a move suffix, lowercased.

    The first character of the upper-cased suffix that is one of Q, R, B or N
    is returned as "q", "r", "b" or "n". The suffix may contain arbitrary
    extra characters (e.g. "(Q)" or "=Q"). Returns None for an empty/None
    suffix or when no promotion letter is present.
    """
    for char in (suffix or "").upper():
        if char in "QRBN":
            return char.lower()
    return None
class ChessTokenizer(PreTrainedTokenizer):
    """
    A custom tokenizer for chess moves written in extended UCI notation.

    Each move (e.g. ``WPe2e4``) is normalized with ``normalize`` (which drops
    capture/check/checkmate/castling/en-passant annotations) and then split
    into sub-tokens: a side token (``[W]`` or ``[B]``), the piece letter, the
    source square, the destination square and, for promotions only, the
    lowercase promotion piece (``q``, ``r``, ``b`` or ``n``).

    Without an explicit ``vocab``/``vocab_file`` the tokenizer uses a fixed
    default vocabulary that contains every sub-token it can emit.

    Example:
        >>> tokenizer = ChessTokenizer()
        >>> tokenizer.tokenize("WPe2e4 BPe7e5")
        ['[W]', 'P', 'e2', 'e4', '[B]', 'P', 'e7', 'e5']
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.json"}

    # Special tokens
    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"
    # Side-to-move tokens emitted at the start of every move.
    WHITE = "[W]"
    BLACK = "[B]"

    # Move components. Piece letters are uppercase and promotion pieces are
    # lowercase, so "B" (a bishop moving) and "b" (promotion to a bishop) are
    # distinct tokens.
    PIECES = ["P", "N", "B", "R", "Q", "K"]
    SQUARES = [f + r for f in "abcdefgh" for r in "12345678"]
    PROMOS = ["q", "r", "b", "n"]
    # Separator appended after each move when include_move_separator is True.
    MOVE_SEP = "|"

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        vocab: Optional[Dict[str, int]] = None,
        **kwargs,
    ):
        """
        Initialize the chess tokenizer.

        Args:
            vocab_file: Path to a JSON file containing the vocabulary mapping.
                Used only when ``vocab`` is None and the file exists.
            vocab: Dictionary mapping tokens to IDs (takes precedence over
                ``vocab_file``).
            **kwargs: Additional arguments passed to PreTrainedTokenizer.
        """
        # Special tokens are fixed by the class.
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN
        # Must be set before _create_default_vocab() reads it below.
        self.include_move_separator = False

        # Remove any duplicate special-token entries passed through kwargs
        # to avoid "multiple values for keyword" errors when loading from disk.
        for key in ("pad_token", "bos_token", "eos_token", "unk_token"):
            kwargs.pop(key, None)

        # Load or create vocabulary
        if vocab is not None:
            self._vocab = vocab
        elif vocab_file is not None and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            # Fixed vocabulary covering every sub-token _tokenize can emit.
            self._vocab = self._create_default_vocab()

        # Reverse mapping used by _convert_id_to_token.
        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}

        # Call parent init AFTER setting up the vocabulary.
        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )

    def _create_default_vocab(self) -> Dict[str, int]:
        """
        Create the fixed default vocabulary.

        IDs are assigned in this order: [PAD], [BOS], [EOS], [UNK], [W], [B]
        (plus MOVE_SEP when ``include_move_separator`` is set), then the six
        piece letters, the 64 squares (a1..a8, b1..b8, ..., h8) and the four
        lowercase promotion pieces.
        """
        special_tokens = [
            self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN,
            self.WHITE, self.BLACK,
        ]
        if self.include_move_separator:
            special_tokens.append(self.MOVE_SEP)
        ordered = special_tokens + self.PIECES + self.SQUARES + self.PROMOS
        return {token: idx for idx, token in enumerate(ordered)}

    @classmethod
    def _move_to_tokens(cls, move: str, include_move_separator: bool = False) -> List[str]:
        """
        Split one already-normalized move (e.g. ``WPe7e8(Q)``) into sub-tokens.

        Returns ``[side, piece, src, dst]``, followed by the lowercase promotion
        piece when the suffix contains one, followed by MOVE_SEP when
        ``include_move_separator`` is True. A move that does not match
        TOKEN_PATTERN propagates the error raised by ``decompose_into_groups``.
        """
        color, piece, src, dst, suffix = decompose_into_groups(move)
        # The side token comes from the color group ("W"/"B"), not the piece.
        tokens = [cls.WHITE if color == "W" else cls.BLACK, piece, src, dst]
        promotion = extract_promotion(suffix)
        if promotion is not None:
            tokens.append(promotion)
        if include_move_separator:
            tokens.append(cls.MOVE_SEP)
        return tokens

    @classmethod
    def build_vocab_from_iterator(
        cls,
        iterator,
        min_frequency: int = 1,
    ) -> "ChessTokenizer":
        """
        Build a tokenizer vocabulary from an iterator of game strings.

        Games are split into the same sub-tokens that ``_tokenize`` emits,
        so every kept token is reachable at tokenization time.

        Args:
            iterator: An iterator yielding game strings (space-separated moves).
            min_frequency: Minimum frequency for a token to be included.

        Returns:
            A ChessTokenizer whose vocabulary is [PAD], [BOS], [EOS], [UNK]
            followed by the kept tokens in sorted order.
        """
        from collections import Counter

        token_counts = Counter()
        for game in iterator:
            for move in normalize(game).split():
                token_counts.update(cls._move_to_tokens(move))

        # Filter by frequency, then sort for reproducibility.
        tokens = sorted(
            token for token, count in token_counts.items()
            if count >= min_frequency
        )

        special_tokens = [cls.PAD_TOKEN, cls.BOS_TOKEN, cls.EOS_TOKEN, cls.UNK_TOKEN]
        vocab = {token: idx for idx, token in enumerate(special_tokens + tokens)}
        return cls(vocab=vocab)

    @classmethod
    def build_vocab_from_dataset(
        cls,
        dataset_name: str = "dlouapre/lichess_2025-01_1M",
        split: str = "train",
        column: str = "text",
        min_frequency: int = 500,
        max_samples: Optional[int] = 100000,
        use_default_vocab: bool = True,
    ) -> "ChessTokenizer":
        """
        Build a tokenizer vocabulary, optionally from a Hugging Face dataset.

        By default (``use_default_vocab=True``) this returns a tokenizer with
        the fixed default vocabulary, which already contains every sub-token.
        In that case the dataset arguments are ignored and nothing is
        downloaded. With ``use_default_vocab=False`` the vocabulary is
        counted from the dataset via ``build_vocab_from_iterator``.

        Args:
            dataset_name: Name of the dataset on Hugging Face Hub.
            split: Dataset split to use.
            column: Column containing the game strings.
            min_frequency: Minimum frequency for a token to be included (default: 500).
            max_samples: Maximum number of samples to process (default: 100k).
            use_default_vocab: Return the fixed default vocabulary without
                loading the dataset (default: True).

        Returns:
            A ChessTokenizer with the chosen vocabulary.
        """
        if use_default_vocab:
            # Nothing to learn from data for the fixed vocabulary; skip the download.
            return cls()

        from datasets import load_dataset

        dataset = load_dataset(dataset_name, split=split)
        if max_samples is not None:
            dataset = dataset.select(range(min(max_samples, len(dataset))))
        games = (example[column] for example in dataset)
        return cls.build_vocab_from_iterator(games, min_frequency=min_frequency)

    @property
    def vocab_size(self) -> int:
        """Return the size of the vocabulary."""
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the vocabulary as a dictionary."""
        return dict(self._vocab)

    def _tokenize(self, text: str) -> List[str]:
        """
        Tokenize a string of moves into a list of sub-tokens.

        Args:
            text: A string of space-separated moves.

        Returns:
            The concatenated ``_move_to_tokens`` output of every move.
        """
        tokens: List[str] = []
        for move in normalize(text).split():
            tokens.extend(self._move_to_tokens(move, self.include_move_separator))
        return tokens

    def decode(
        self,
        token_ids,
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: bool = True,
        **kwargs,
    ) -> str:
        """
        Decode token IDs to a string, attaching promotion pieces to their move.

        ``convert_tokens_to_string`` joins sub-tokens with spaces, so a
        promotion comes out as ``"e7 e8 q"``. This rewrites it to
        ``"e7 e8q"`` so the promotion piece sits directly after the
        destination square.
        """
        text = super().decode(
            token_ids,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )
        # Drop the whitespace before a standalone q/r/b/n token (case
        # sensitive: pieces are uppercase). The lookahead also accepts
        # end-of-string, so a promotion in the final move is joined too.
        return re.sub(r"\s([qrbn])(?=\s|$)", r"\1", text)

    def _convert_token_to_id(self, token: str) -> int:
        """Convert a token to its ID (falls back to the [UNK] ID, else 0)."""
        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 0))

    def _convert_id_to_token(self, index: int) -> str:
        """Convert an ID to its token ([UNK] for unknown IDs)."""
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Join tokens with spaces, dropping [PAD], [BOS], [EOS] and [UNK]."""
        special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
        return " ".join(t for t in tokens if t not in special)

    def save_vocabulary(
        self,
        save_directory: str,
        filename_prefix: Optional[str] = None,
    ) -> tuple:
        """
        Save the vocabulary to a JSON file.

        Args:
            save_directory: Directory to save the vocabulary (created if missing).
            filename_prefix: Optional prefix for the filename.

        Returns:
            Tuple containing the path to the saved vocabulary file.
        """
        os.makedirs(save_directory, exist_ok=True)
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)
        return (vocab_file,)
def count_vocab_from_dataset(
    dataset_name: str = "dlouapre/lichess_2025-01_1M",
    split: str = "train",
    column: str = "text",
    max_samples: Optional[int] = 10000,
) -> Dict[str, int]:
    """
    Count raw move frequencies in a dataset (useful for vocabulary analysis).

    The counted items are the whitespace-separated move strings exactly as
    stored in ``column``. They are neither normalized nor split into the
    sub-tokens that ChessTokenizer emits.

    Args:
        dataset_name: Name of the dataset on Hugging Face Hub.
        split: Dataset split to use.
        column: Column containing the game strings.
        max_samples: Maximum number of samples to process (None = all).

    Returns:
        Dictionary mapping each move string to its number of occurrences.
    """
    from collections import Counter
    from datasets import load_dataset

    data = load_dataset(dataset_name, split=split)
    if max_samples is not None:
        limit = min(max_samples, len(data))
        data = data.select(range(limit))

    counts = Counter(
        move
        for row in data
        for move in row[column].strip().split()
    )
    return dict(counts)
if __name__ == '__main__':
    # Smoke test: round-trip one move through encode/decode and print it.
    # A full game string in the dataset format ("WPe2e4 BPc7c5 ...") works too.
    demo_move = "BKd6c5=Q"
    demo_tokenizer = ChessTokenizer()
    print(demo_tokenizer.decode(demo_tokenizer.encode(demo_move)))
    # To inspect the dataset-based builder instead:
    #   tokenizer = ChessTokenizer.build_vocab_from_dataset(min_frequency=500)
    #   print(tokenizer.vocab_size)
    #   print(tokenizer.get_vocab())