# tanguy-chess7 / tokenizer.py
# Tanguy85's picture
# Chess Challenge submission by Tanguy85
# f3100e5 verified
"""
Custom Chess Tokenizer for the Chess Challenge.
This tokenizer supports TWO tokenization modes:
1) tokenization_mode="move" (original)
- Each move is a single token using the extended UCI notation
from the Lichess dataset (e.g., WPe2e4, BNg8f6, WPe7e8=Q(x+), ...).
- Vocabulary is usually built from the dataset (frequency threshold).
2) tokenization_mode="uci_square" (recommended for good legal-move performance with small vocab)
- Each move is decomposed into 3 tokens:
[from_square, to_square, promotion_or_-]
Example:
"WPe2e4" -> ["e2", "e4", "-"]
"WPe7e8=Q(+)" -> ["e7", "e8", "q"]
- Fixed vocabulary that can express ANY UCI move:
specials (4) + squares (64) + promo tokens (5) = 73 tokens.
Why uci_square helps:
- You can keep vocab tiny (70-150 range) WITHOUT losing expressivity,
so the model can still output any move.
"""
from __future__ import annotations
import json
import os
import re
from typing import Dict, List, Optional
from transformers import PreTrainedTokenizer
class ChessTokenizer(PreTrainedTokenizer):
    """
    A custom tokenizer for chess moves.

    - "move" mode: every whitespace-separated move (e.g. "WPe2e4") is one token.
    - "uci_square" mode: every move becomes three tokens,
      [from_square, to_square, promotion_or_-].
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.json"}

    # Special tokens
    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"

    # Patterns used by the "uci_square" parser, compiled once for the class.
    _SQUARE_RE = re.compile(r"[a-h][1-8]")
    # Explicit promotion marker, e.g. "=Q".
    _PROMO_MARKED_RE = re.compile(r"=([qrbnQRBN])")
    # Bare promotion letter right after the destination square (plain UCI,
    # e.g. "e7e8q"); the lookahead rejects a letter that starts another square.
    _PROMO_TRAILING_RE = re.compile(r"[qrbnQRBN](?![1-8])")
    _PROMO_PIECES = frozenset("qrbn")

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        vocab: Optional[Dict[str, int]] = None,
        **kwargs,
    ):
        """
        Initialize the chess tokenizer.

        Args:
            vocab_file: Path to a JSON file containing the vocabulary mapping.
                Used only when ``vocab`` is None and the file exists.
            vocab: Dictionary mapping tokens to IDs (takes precedence over
                ``vocab_file``).
            kwargs:
                - tokenization_mode: "move" (default) or "uci_square"
                - plus usual HF tokenizer kwargs

        Raises:
            ValueError: if ``tokenization_mode`` is not a supported mode.
        """
        # Initialize special tokens
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN

        # Read tokenization_mode from kwargs (and keep it for save/load)
        tokenization_mode = kwargs.pop("tokenization_mode", "move")
        if tokenization_mode not in ("move", "uci_square"):
            raise ValueError(f"Unknown tokenization_mode={tokenization_mode!r}")
        self.tokenization_mode = tokenization_mode

        # Remove any duplicate special-token entries passed through kwargs
        # to avoid "multiple values for keyword" errors when loading from disk.
        for special_kwarg in ("pad_token", "bos_token", "eos_token", "unk_token"):
            kwargs.pop(special_kwarg, None)

        # Load or create vocabulary
        if vocab is not None:
            self._vocab = vocab
        elif vocab_file is not None and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            # No usable vocabulary source (including a vocab_file path that
            # does not exist): fall back to a minimal vocabulary with just the
            # special tokens. Build a real one from a dataset or with
            # build_uci_square_vocab().
            self._vocab = self._create_default_vocab()

        # Create reverse mapping
        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}

        # Ensure tokenization_mode is saved in tokenizer_config.json
        kwargs["tokenization_mode"] = self.tokenization_mode

        # Call parent init AFTER setting up vocab
        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )

    def _create_default_vocab(self) -> Dict[str, int]:
        """
        Return a minimal default vocabulary holding only the special tokens,
        numbered 0..3 in the order PAD, BOS, EOS, UNK.
        """
        special_tokens = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]
        return {token: idx for idx, token in enumerate(special_tokens)}

    @classmethod
    def build_vocab_from_iterator(
        cls,
        iterator,
        min_frequency: int = 1,
    ) -> "ChessTokenizer":
        """
        Build a "move" tokenizer vocabulary from an iterator of game strings.

        Args:
            iterator: yields game strings (space-separated moves).
            min_frequency: minimum frequency for a token to be included.

        Returns:
            ChessTokenizer(tokenization_mode="move") with the built vocabulary:
            the four special tokens first, then the kept moves in sorted order.
        """
        from collections import Counter

        token_counts = Counter()
        for game in iterator:
            token_counts.update(game.strip().split())

        special_tokens = [cls.PAD_TOKEN, cls.BOS_TOKEN, cls.EOS_TOKEN, cls.UNK_TOKEN]
        reserved = set(special_tokens)
        # Special tokens that occur in the data are skipped here: listing them
        # twice would make the dict keep the later index and leave a gap in
        # the id range, so some ids would be >= vocab_size.
        tokens = sorted(
            token
            for token, count in token_counts.items()
            if count >= min_frequency and token not in reserved
        )
        vocab = {token: idx for idx, token in enumerate(special_tokens + tokens)}
        return cls(vocab=vocab, tokenization_mode="move")

    @classmethod
    def build_vocab_from_dataset(
        cls,
        dataset_name: str = "dlouapre/lichess_2025-01_1M",
        split: str = "train",
        column: str = "text",
        min_frequency: int = 500,
        max_samples: Optional[int] = 100000,
    ) -> "ChessTokenizer":
        """
        Build a "move" tokenizer vocabulary from a Hugging Face dataset.

        Args:
            dataset_name: dataset on HF Hub.
            split: dataset split.
            column: column containing game strings.
            min_frequency: minimum frequency for a token to be included.
            max_samples: max number of samples to process (None = all).

        Returns:
            ChessTokenizer(tokenization_mode="move") with the built vocabulary.
        """
        from datasets import load_dataset

        dataset = load_dataset(dataset_name, split=split)
        if max_samples is not None:
            dataset = dataset.select(range(min(max_samples, len(dataset))))

        def game_iterator():
            for example in dataset:
                yield example[column]

        return cls.build_vocab_from_iterator(game_iterator(), min_frequency=min_frequency)

    @classmethod
    def build_uci_square_vocab(cls) -> "ChessTokenizer":
        """
        Build a fixed tiny vocab that can express ANY UCI move using 3 tokens:
        [from_square, to_square, promotion_or_-].

        Vocab:
            - 4 specials
            - 64 squares (a1..h8)
            - 5 promo tokens: "-", "q", "r", "b", "n"
        Total = 73 tokens.
        """
        special = [cls.PAD_TOKEN, cls.BOS_TOKEN, cls.EOS_TOKEN, cls.UNK_TOKEN]
        files = "abcdefgh"
        ranks = "12345678"
        squares = [f"{f}{r}" for r in ranks for f in files]  # 64
        promo = ["-", "q", "r", "b", "n"]  # 5
        vocab = {tok: i for i, tok in enumerate(special + squares + promo)}
        return cls(vocab=vocab, tokenization_mode="uci_square")

    @property
    def vocab_size(self) -> int:
        """Number of entries in the vocabulary mapping."""
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the token -> id mapping."""
        return dict(self._vocab)

    @classmethod
    def _split_move(cls, tok: str) -> Optional[List[str]]:
        """
        Split one move string into [from_sq, to_sq, promo_or_-].

        Returns None when fewer than two squares can be found in ``tok``.
        """
        # Typical dataset format:
        #   [W|B][Piece][from_sq][to_sq]... possibly "(x)" "(+)" "(o)" "=Q" etc.
        # Examples: WPe2e4, BNg8f6, WPe7e8=Q(+), WPe5d6(x)
        if len(tok) >= 6 and tok[0] in ("W", "B"):
            from_sq, to_sq = tok[2:4], tok[4:6]
            if cls._SQUARE_RE.fullmatch(from_sq) and cls._SQUARE_RE.fullmatch(to_sq):
                promo = "-"
                # Only the character right after the first "=" is considered.
                _, marker, after = tok.partition("=")
                if marker and after:
                    piece = after[0].lower()
                    if piece in cls._PROMO_PIECES:
                        promo = piece
                return [from_sq, to_sq, promo]

        # Fallback: take the first two squares found anywhere in the token.
        found = list(cls._SQUARE_RE.finditer(tok))
        if len(found) < 2:
            return None

        # Look for a promotion only AFTER the destination square. Scanning the
        # whole token would misread the file letter "b" (as in "b1c3") or a
        # leading piece letter (as in "Nb1c3") as a promotion piece.
        tail = tok[found[1].end():]
        promo_match = cls._PROMO_MARKED_RE.search(tail) or cls._PROMO_TRAILING_RE.match(tail)
        if promo_match is None:
            promo = "-"
        elif promo_match.re is cls._PROMO_MARKED_RE:
            promo = promo_match.group(1).lower()
        else:
            promo = promo_match.group(0).lower()
        return [found[0].group(0), found[1].group(0), promo]

    def _tokenize(self, text: str) -> List[str]:
        """
        Tokenize a string.

        - mode="move": split on whitespace (original dataset tokens like "WPe2e4").
        - mode="uci_square": each dataset move token -> [from_sq, to_sq, promo_or_-]
            Example: "WPe2e4"   -> ["e2", "e4", "-"]
                     "WPe7e8=Q" -> ["e7", "e8", "q"]
          Words already present in the vocabulary are passed through unchanged,
          and words from which no move can be parsed become the UNK token.
        """
        tokens = text.strip().split()
        if self.tokenization_mode != "uci_square":
            return tokens

        out: List[str] = []
        for tok in tokens:
            # Keep vocabulary entries (special tokens, squares, promo tokens)
            # as-is if they appear in text.
            if tok in self._vocab:
                out.append(tok)
                continue
            parts = self._split_move(tok)
            if parts is None:
                out.append(self.UNK_TOKEN)
            else:
                out.extend(parts)
        return out

    def _convert_token_to_id(self, token: str) -> int:
        """Map a token to its id, falling back to the UNK id (or 0 if UNK is absent)."""
        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 0))

    def _convert_id_to_token(self, index: int) -> str:
        """Map an id to its token, falling back to the UNK token for unknown ids."""
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Join tokens with single spaces, dropping the four special tokens."""
        # Filter out special tokens for cleaner output
        special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
        return " ".join(t for t in tokens if t not in special)

    def save_vocabulary(
        self,
        save_directory: str,
        filename_prefix: Optional[str] = None,
    ) -> tuple:
        """
        Write the vocabulary as JSON to ``<save_directory>/[<prefix>-]vocab.json``.

        The directory is created if it does not exist.

        Returns:
            A 1-tuple containing the path of the written vocabulary file.
        """
        os.makedirs(save_directory, exist_ok=True)
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)
        return (vocab_file,)
def count_vocab_from_dataset(
    dataset_name: str = "dlouapre/lichess_2025-01_1M",
    split: str = "train",
    column: str = "text",
    max_samples: Optional[int] = 10000,
) -> Dict[str, int]:
    """
    Tally how often each whitespace-separated token occurs in a dataset column.

    Args:
        dataset_name: name passed to ``datasets.load_dataset``.
        split: dataset split to load.
        column: column holding the game strings.
        max_samples: only the first ``max_samples`` rows are read
            (None reads the whole split).

    Returns:
        A plain dict mapping each token to its occurrence count.
    """
    from collections import Counter
    from datasets import load_dataset

    games = load_dataset(dataset_name, split=split)
    if max_samples is not None:
        # Never ask for more rows than the split actually has.
        limit = min(max_samples, len(games))
        games = games.select(range(limit))

    frequencies = Counter()
    for row in games:
        frequencies.update(row[column].strip().split())
    return dict(frequencies)