# chess_vre / tokenizer.py
# Chess Challenge submission by VreVre (revision b80658c).
"""
Custom Chess Tokenizer for the Chess Challenge.
This tokenizer treats each move as a single token using the extended UCI notation
from the Lichess dataset (e.g., WPe2e4, BNg8f6).
The dataset format uses:
- W/B prefix for White/Black
- Piece letter: P=Pawn, N=Knight, B=Bishop, R=Rook, Q=Queen, K=King
- Source and destination squares (e.g., e2e4)
- Special suffixes: (x)=capture, (+)=check, (+*)=checkmate, (o)/(O)=castling
"""
from __future__ import annotations
import json
import os
from pathlib import Path
from typing import Dict, List, Optional
from transformers import PreTrainedTokenizer
import re
class ChessTokenizer(PreTrainedTokenizer):
    """
    A custom tokenizer for chess moves in extended UCI notation.

    Each move (e.g. ``WNg1f3`` or ``BPd7c6(x)``) is split into sub-tokens:
    color (``W``/``B``), piece letter, source square, destination square,
    optional suffix tokens (``CAP``, ``CHECK``, ``MATE``, ``CASTLE_K``,
    ``CASTLE_Q``) and a closing ``EOM`` (end-of-move) token. A move that does
    not match the expected notation becomes a single ``[UNK]`` token.

    The vocabulary is built from the training data (see
    ``build_vocab_from_dataset``) so that every sub-token seen often enough
    has an ID.

    Note that the color ``B`` (Black) and the piece ``B`` (Bishop) share one
    token; their position inside a move tells them apart.

    Example:
        >>> tokenizer = ChessTokenizer.build_vocab_from_iterator(["WPe2e4 BPe7e5"])
        >>> tokenizer.tokenize("WPe2e4 BPe7e5")
        ['W', 'P', 'e2', 'e4', 'EOM', 'B', 'P', 'e7', 'e5', 'EOM']
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.json"}

    # Special tokens
    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"
    EOM_TOKEN = "EOM"  # End of Move token

    # Alternative notation (capture marker between the squares, promotion
    # letter, bare "+"/"#" suffix). Not referenced by the methods of this
    # class; kept so external code reading the attribute keeps working.
    MOVE_REGEX = re.compile(
        r"""
        (?P<color>[WB])
        (?P<piece>[PNBRQK])
        (?P<from>[a-h][1-8])
        (?P<capture>x)?
        (?P<to>[a-h][1-8])
        (?P<promotion>[QRBN])?
        (?P<check>[+#])?
        """,
        re.VERBOSE,
    )

    # Notation actually parsed by `_tokenize` and the vocabulary builders:
    # two squares followed by optional parenthesised suffixes, e.g.
    # "WPe4d5(x)", "WQh5f7(x)(+*)", "WKe1g1(o)". Compiled once here instead
    # of on every call.
    _DATASET_MOVE_REGEX = re.compile(
        r"""
        (?P<color>[WB])
        (?P<piece>[PNBRQK])
        (?P<from>[a-h][1-8])
        (?P<to>[a-h][1-8])
        (?P<capture>[(][x][)])?
        (?P<check>[(][+][)])?
        (?P<mate>[(][+][*][)])?
        (?P<castle>[(][oO][)])?
        """,
        re.VERBOSE,
    )

    # Suffix sub-tokens mapped to the dataset text they stand for; used to
    # turn tokens back into the original notation when decoding.
    _SUFFIX_TEXT = {
        "CAP": "(x)",
        "CHECK": "(+)",
        "MATE": "(+*)",
        "CASTLE_K": "(o)",
        "CASTLE_Q": "(O)",
    }

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        vocab: Optional[Dict[str, int]] = None,
        **kwargs,
    ):
        """
        Initialize the chess tokenizer.

        Args:
            vocab_file: Path to a JSON file containing the vocabulary mapping.
            vocab: Dictionary mapping tokens to IDs (takes precedence over
                vocab_file).
            **kwargs: Additional arguments passed to PreTrainedTokenizer.

        If neither a vocab nor an existing vocab_file is given, a minimal
        vocabulary holding only the special tokens is used.
        """
        # Initialize special tokens
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN

        # Remove any duplicate special-token entries passed through kwargs
        # to avoid "multiple values for keyword" errors when loading from disk.
        kwargs.pop("pad_token", None)
        kwargs.pop("bos_token", None)
        kwargs.pop("eos_token", None)
        kwargs.pop("unk_token", None)

        # Load or create vocabulary
        if vocab is not None:
            self._vocab = vocab
        elif vocab_file is not None and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            # Minimal placeholder; the full vocabulary should be built from data.
            self._vocab = self._create_default_vocab()

        # Reverse mapping used by _convert_id_to_token.
        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}

        # The parent __init__ queries the vocabulary, so it must run last.
        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )

    def _create_default_vocab(self) -> Dict[str, int]:
        """
        Create a minimal default vocabulary with just the special tokens.

        This is only a placeholder; use `build_vocab_from_dataset()` (or one
        of the iterator builders) to get a usable vocabulary.
        """
        special_tokens = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]
        return {token: idx for idx, token in enumerate(special_tokens)}

    @classmethod
    def _split_move(cls, move: str) -> Optional[List[str]]:
        """
        Split one move into its sub-tokens, without the closing EOM token.

        Returns None when `move` does not match `_DATASET_MOVE_REGEX`.
        Suffix tokens are emitted in the order the regex accepts them:
        CAP, CHECK, MATE, then CASTLE_K / CASTLE_Q.
        """
        m = cls._DATASET_MOVE_REGEX.fullmatch(move)
        if m is None:
            return None
        parts = [m["color"], m["piece"], m["from"], m["to"]]
        if m["capture"]:
            parts.append("CAP")
        # The groups capture the full parenthesised text ("(+)", "(+*)"),
        # so test for presence rather than comparing to "+" / "#".
        if m["check"]:
            parts.append("CHECK")
        if m["mate"]:
            parts.append("MATE")
        if m["castle"] == "(o)":
            parts.append("CASTLE_K")
        elif m["castle"] == "(O)":
            parts.append("CASTLE_Q")
        return parts

    @classmethod
    def build_vocab_from_iterator(
        cls,
        iterator,
        min_frequency: int = 1,
    ) -> "ChessTokenizer":
        """
        Build a tokenizer vocabulary from an iterator of game strings.

        The vocabulary consists of the same sub-tokens `_tokenize` produces.
        Whole moves are not stored as tokens: `_tokenize` never looks them
        up, so such a vocabulary would encode every move as [UNK].

        Args:
            iterator: An iterator yielding game strings (space-separated moves).
            min_frequency: Minimum frequency for a sub-token to be included.

        Returns:
            A ChessTokenizer with the built vocabulary.
        """
        return cls.build_vocab_from_iterator_bis(iterator, min_frequency=min_frequency)

    @classmethod
    def build_vocab_from_iterator_bis(
        cls,
        iterator,
        min_frequency: int = 1,
    ) -> "ChessTokenizer":
        """
        Build a sub-token vocabulary from an iterator of game strings.

        Moves that cannot be parsed are skipped. [UNK] is already a special
        token, so counting it as a regular token would add it twice and
        orphan its original id. EOM is always added. Regular tokens are
        sorted for reproducible ids and placed after the four special tokens.

        Args:
            iterator: An iterator yielding game strings (space-separated moves).
            min_frequency: Minimum frequency for a sub-token to be included.

        Returns:
            A ChessTokenizer with the built vocabulary.
        """
        from collections import Counter

        token_counts = Counter()
        for game in iterator:
            for move in game.split():
                parts = cls._split_move(move)
                if parts is not None:
                    token_counts.update(parts)

        kept = {tok for tok, count in token_counts.items() if count >= min_frequency}
        kept.add(cls.EOM_TOKEN)
        special_tokens = [cls.PAD_TOKEN, cls.BOS_TOKEN, cls.EOS_TOKEN, cls.UNK_TOKEN]
        tokens = sorted(kept.difference(special_tokens))

        vocab = {tok: i for i, tok in enumerate(special_tokens + tokens)}
        return cls(vocab=vocab)

    @classmethod
    def build_vocab_from_dataset(
        cls,
        dataset_name: str = "dlouapre/lichess_2025-01_1M",
        split: str = "train",
        column: str = "text",
        min_frequency: int = 500,
        max_samples: Optional[int] = 100000,
    ) -> "ChessTokenizer":
        """
        Build a tokenizer vocabulary from a Hugging Face dataset.

        Args:
            dataset_name: Name of the dataset on Hugging Face Hub.
            split: Dataset split to use.
            column: Column containing the game strings.
            min_frequency: Minimum frequency for a token to be included (default: 500).
            max_samples: Maximum number of samples to process (default: 100k);
                None processes the whole split.

        Returns:
            A ChessTokenizer with the built vocabulary.
        """
        from datasets import load_dataset

        dataset = load_dataset(dataset_name, split=split)
        if max_samples is not None:
            dataset = dataset.select(range(min(max_samples, len(dataset))))

        def game_iterator():
            for example in dataset:
                yield example[column]

        return cls.build_vocab_from_iterator_bis(game_iterator(), min_frequency=min_frequency)

    @property
    def vocab_size(self) -> int:
        """Return the size of the ID space (largest ID + 1; 0 for an empty vocab)."""
        return max(self._vocab.values(), default=-1) + 1

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the vocabulary as a dictionary."""
        return dict(self._vocab)

    def _tokenize(self, text: str) -> List[str]:
        """
        Split a whitespace-separated string of moves into sub-tokens.

        Every parseable move contributes its sub-tokens followed by EOM.
        An unparseable move contributes a single [UNK] with no EOM.
        """
        tokens: List[str] = []
        for move in text.split():
            parts = self._split_move(move)
            if parts is None:
                tokens.append(self.UNK_TOKEN)
                continue
            tokens.extend(parts)
            tokens.append(self.EOM_TOKEN)
        return tokens

    def _convert_token_to_id(self, token: str) -> int:
        """Convert a token to its ID (the [UNK] ID, or 0, if unknown)."""
        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 0))

    def _convert_id_to_token(self, index: int) -> str:
        """Convert an ID to its token ([UNK] if the ID is unmapped)."""
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """
        Reassemble sub-tokens into moves in the dataset notation.

        Special tokens are skipped. Suffix tokens are turned back into their
        parenthesised text ("(x)", "(+)", "(+*)", "(o)", "(O)") and stay
        attached to their move. Each EOM closes the current move and emits it
        followed by one space. A trailing move without EOM is appended as-is.
        """
        special = {
            self.PAD_TOKEN,
            self.BOS_TOKEN,
            self.EOS_TOKEN,
            self.UNK_TOKEN,
        }
        pieces: List[str] = []
        current: List[str] = []
        for tok in tokens:
            if tok in special:
                continue
            if tok == self.EOM_TOKEN:
                pieces.append("".join(current) + " ")
                current = []
            elif tok in self._SUFFIX_TEXT:
                current.append(self._SUFFIX_TEXT[tok])
            elif tok in {"W", "B", "P", "N", "R", "Q", "K"}:
                current.append(tok)
            else:
                # Remaining tokens are squares, which are two characters long.
                current.append(tok[:2])
        pieces.append("".join(current))
        return "".join(pieces)

    def save_vocabulary(
        self,
        save_directory: str,
        filename_prefix: Optional[str] = None,
    ) -> tuple:
        """
        Save the vocabulary to a JSON file.

        Args:
            save_directory: Directory to save the vocabulary (created if missing).
            filename_prefix: Optional prefix for the filename.

        Returns:
            Tuple containing the path to the saved vocabulary file.
        """
        os.makedirs(save_directory, exist_ok=True)
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)
        return (vocab_file,)
def count_vocab_from_dataset(
    dataset_name: str = "dlouapre/lichess_2025-01_1M",
    split: str = "train",
    column: str = "text",
    max_samples: Optional[int] = 10000,
) -> Dict[str, int]:
    """
    Tally how often each whole move string appears in a dataset.

    Useful for vocabulary analysis. Moves are the whitespace-separated
    entries of each game string.

    Args:
        dataset_name: Name of the dataset on Hugging Face Hub.
        split: Dataset split to use.
        column: Column containing the game strings.
        max_samples: Cap on the number of rows read; None reads the whole split.

    Returns:
        Dictionary mapping each move string to its number of occurrences.
    """
    from collections import Counter
    from datasets import load_dataset

    games = load_dataset(dataset_name, split=split)
    if max_samples is not None:
        row_limit = min(max_samples, len(games))
        games = games.select(range(row_limit))

    # str.split() with no argument already ignores leading/trailing whitespace.
    frequencies = Counter(
        move
        for row in games
        for move in row[column].split()
    )
    return dict(frequencies)