# model_kheng / tokenizer.py
# Chess Challenge submission by Gusthavok (commit 7883f46, verified).
"""
Custom Chess Tokenizer for the Chess Challenge.
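This tokenizer treats each move as a single token using the extended UCI notation
from the Lichess dataset (e.g., WPe2e4, BNg8f6). In optional component mode, each
move is instead split into color+piece, source square, destination square, optional
suffix tokens, and an [EOM] end-of-move marker (e.g., WPe2e4 -> WP e2 e4 [EOM]).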
The dataset format uses:
- W/B prefix for White/Black
- Piece letter: P=Pawn, N=Knight, B=Bishop, R=Rook, Q=Queen, K=King
- Source and destination squares (e.g., e2e4)
- Special suffixes: (x)=capture, (+)=check, (+*)=checkmate, (o)=kingside castling, (O)=queenside castling
"""
from __future__ import annotations

import json
import os
import re
from pathlib import Path
from typing import Dict, List, Optional, Union

from transformers import PreTrainedTokenizer
class ChessTokenizer(PreTrainedTokenizer):
    """
    A custom tokenizer for chess moves using extended UCI notation.

    Two tokenization schemes are supported:

    * Move mode (default): each space-separated move (e.g. ``WPe2e4``) is one
      token. The vocabulary is built from the training dataset so that every
      move seen during training has a corresponding token.
    * Component mode (``component_mode=True``): each move is split into a
      color+piece token, a source square, a destination square, optional
      suffix tokens, and a trailing ``[EOM]`` marker
      (see ``build_vocab_more_detailed``).

    Example:
        >>> tokenizer = ChessTokenizer.build_vocab_from_iterator(["WPe2e4 BPe7e5"])
        >>> tokenizer.tokenize("WPe2e4 BPe7e5")
        ['WPe2e4', 'BPe7e5']

    Token IDs depend on the vocabulary. A tokenizer constructed without a
    vocabulary knows only the special tokens, so every move maps to ``[UNK]``.
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.json"}

    # Special tokens
    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"
    EOM_TOKEN = "[EOM]"  # End of Move - marks boundary between moves

    # Component-mode token tables, shared by vocabulary construction and the
    # token classifiers below. ORDER MATTERS: build_vocab_more_detailed assigns
    # IDs in this order, so reordering would change the IDs of saved vocabularies.
    # Combined color+piece tokens (avoids the "B" collision between Black and Bishop).
    _PIECE_TOKENS = (
        "WP", "WN", "WB", "WR", "WQ", "WK",  # White pieces
        "BP", "BN", "BB", "BR", "BQ", "BK",  # Black pieces
    )
    _SUFFIX_TOKENS = (
        "(x)",    # capture
        "(+)",    # check
        "(x+)",   # capture + check
        "(+*)",   # checkmate
        "(x+*)",  # capture + checkmate
        "(o)",    # kingside castling
        "(O)",    # queenside castling
        "(xE)",   # en passant
        "=Q",     # promotion to queen
        "=R",     # promotion to rook
        "=B",     # promotion to bishop
        "=N",     # promotion to knight
    )

    # Color+piece tokens that mark the start of a new move
    _MOVE_START_TOKENS = set(_PIECE_TOKENS)

    # Move layout: (W|B)(piece letter)(from square)(to square)(any suffix text).
    _MOVE_RE = re.compile(r'^([WB])([PNBRQK])([a-h][1-8])([a-h][1-8])(.*)$')
    # Known suffix tokens; findall() returns every occurrence in the suffix text.
    _SUFFIX_RE = re.compile(
        r'(\(x\+\*\)|\(x\+\)|\(\+\*\)|\(xE\)|\(x\)|\(\+\)|\(o\)|\(O\)|=Q|=R|=B|=N)'
    )

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        vocab: Optional[Dict[str, int]] = None,
        component_mode: bool = False,
        **kwargs,
    ):
        """
        Initialize the chess tokenizer.

        Args:
            vocab_file: Path to a JSON file containing the vocabulary mapping.
                Ignored when ``vocab`` is given. If the path does not exist, the
                minimal special-token vocabulary is used instead (no error).
            vocab: Dictionary mapping tokens to IDs (takes precedence over vocab_file).
            component_mode: If True, tokenize moves into components
                (e.g. WPe2e4 -> WP, e2, e4, [EOM]).
            **kwargs: Additional arguments passed to PreTrainedTokenizer.
        """
        # Initialize special tokens
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN
        self._eom_token = self.EOM_TOKEN

        # Component mode flag (for splitting moves into parts)
        self._component_mode = component_mode

        # Remove any duplicate special-token entries passed through kwargs
        # to avoid "multiple values for keyword" errors when loading from disk.
        for key in ("pad_token", "bos_token", "eos_token", "unk_token", "eom_token", "component_mode"):
            kwargs.pop(key, None)

        # Load or create vocabulary
        if vocab is not None:
            self._vocab = vocab
        elif vocab_file is not None and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            # Minimal vocabulary with only special tokens; the full vocabulary
            # should be built from data with one of the build_vocab_* methods.
            self._vocab = self._create_default_vocab()

        # Reverse mapping for ID -> token lookups
        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}

        # Call parent init AFTER setting up vocab
        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            component_mode=component_mode,  # This gets saved to tokenizer_config.json
            **kwargs,
        )

        # [EOM] ID for easy access; -1 when the vocabulary has no [EOM]
        # (e.g. move-mode vocabularies).
        self.eom_token_id = self._vocab.get(self.EOM_TOKEN, -1)

    def _create_default_vocab(self) -> Dict[str, int]:
        """
        Create a minimal placeholder vocabulary containing only special tokens.

        Produces [PAD]=0, [BOS]=1, [EOS]=2, [UNK]=3 ([EOM] is not included).
        Build the real vocabulary with ``build_vocab_from_dataset()`` or another
        ``build_vocab_*`` classmethod.
        """
        special_tokens = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]
        return {token: idx for idx, token in enumerate(special_tokens)}

    @classmethod
    def build_vocab_from_iterator(
        cls,
        iterator,
        min_frequency: int = 1,
    ) -> "ChessTokenizer":
        """
        Build a move-mode tokenizer vocabulary from an iterator of game strings.

        Special tokens get IDs 0-3 ([PAD], [BOS], [EOS], [UNK]); the remaining
        moves follow in sorted order so the mapping is reproducible.

        Args:
            iterator: An iterator yielding game strings (space-separated moves).
            min_frequency: Minimum frequency for a token to be included.

        Returns:
            A ChessTokenizer with the built vocabulary.
        """
        from collections import Counter

        special_tokens = [cls.PAD_TOKEN, cls.BOS_TOKEN, cls.EOS_TOKEN, cls.UNK_TOKEN]

        token_counts = Counter()
        for game in iterator:
            token_counts.update(game.strip().split())

        # Skip data tokens that collide with a special token: otherwise the dict
        # below would reassign that key to a later ID, leaving a hole in the ID
        # range (vocab_size would no longer equal max ID + 1).
        reserved = set(special_tokens)
        tokens = sorted(
            token
            for token, count in token_counts.items()
            if count >= min_frequency and token not in reserved
        )

        vocab = {token: idx for idx, token in enumerate(special_tokens + tokens)}
        return cls(vocab=vocab)

    @classmethod
    def build_vocab_from_dataset(
        cls,
        dataset_name: str = "dlouapre/lichess_2025-01_1M",
        split: str = "train",
        column: str = "text",
        min_frequency: int = 500,
        max_samples: Optional[int] = 100000,
    ) -> "ChessTokenizer":
        """
        Build a move-mode tokenizer vocabulary from a Hugging Face dataset.

        Args:
            dataset_name: Name of the dataset on Hugging Face Hub.
            split: Dataset split to use.
            column: Column containing the game strings.
            min_frequency: Minimum frequency for a token to be included (default: 500).
            max_samples: Maximum number of samples to process (default: 100k);
                None processes the whole split.

        Returns:
            A ChessTokenizer with the built vocabulary.
        """
        from datasets import load_dataset

        dataset = load_dataset(dataset_name, split=split)
        if max_samples is not None:
            dataset = dataset.select(range(min(max_samples, len(dataset))))

        games = (example[column] for example in dataset)
        return cls.build_vocab_from_iterator(games, min_frequency=min_frequency)

    @classmethod
    def build_vocab_more_detailed(
        cls,
    ) -> "ChessTokenizer":
        """
        Build a component-based tokenizer for chess moves.

        Instead of one token per move (WPe2e4), moves are split into components:
            WPe2e4 -> [WP, e2, e4, [EOM]]
            BNg8f6(x) -> [BN, g8, f6, (x), [EOM]]

        The vocabulary is 5 special tokens + 12 piece tokens + 64 squares +
        12 suffixes = 93 tokens, instead of ~1200 move tokens. Every token and
        its ID is printed to stdout.

        Returns:
            A ChessTokenizer with component vocabulary and component_mode=True.
        """
        # Squares in file-major order: a1, a2, ..., a8, b1, ..., h8.
        files = "abcdefgh"
        ranks = "12345678"
        tokens_positions = [f + r for f in files for r in ranks]

        tokens = list(cls._PIECE_TOKENS) + tokens_positions + list(cls._SUFFIX_TOKENS)

        # [EOM] marks move boundaries so the model knows when a move ends.
        special_tokens = [cls.PAD_TOKEN, cls.BOS_TOKEN, cls.EOS_TOKEN, cls.UNK_TOKEN, cls.EOM_TOKEN]
        ordered = special_tokens + tokens
        vocab = {token: idx for idx, token in enumerate(ordered)}

        for ind, token in enumerate(ordered):
            print(f"Token {ind}: {token}")

        # Pass component_mode=True so it gets saved to tokenizer_config.json
        return cls(vocab=vocab, component_mode=True)

    @property
    def vocab_size(self) -> int:
        """Return the size of the vocabulary."""
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the vocabulary as a dictionary."""
        return dict(self._vocab)

    def _tokenize(self, text: str) -> List[str]:
        """
        Tokenize a string of moves into a list of tokens.

        In component mode, each move is split into parts via
        ``_tokenize_components``:
            WPe2e4 -> [WP, e2, e4, [EOM]]
            BNg8f6(x) -> [BN, g8, f6, (x), [EOM]]
        Otherwise the text is split on whitespace, one token per move.

        Args:
            text: A string of space-separated moves.

        Returns:
            List of tokens.
        """
        if getattr(self, '_component_mode', False):
            return self._tokenize_components(text)
        return text.strip().split()

    def _tokenize_components(self, text: str) -> List[str]:
        """
        Tokenize moves into component parts with [EOM] boundaries.

        Move format: [Color][Piece][from_square][to_square][suffix] [EOM]

        Example:
            WPe2e4 -> [WP, e2, e4, [EOM]]
            BNg8f6(x) -> [BN, g8, f6, (x), [EOM]]

        Special tokens appearing as whitespace-separated words are passed
        through unchanged. Words that do not match the move layout become
        [UNK] followed by [EOM].
        """
        specials = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN, self.EOM_TOKEN}
        tokens: List[str] = []

        for move in text.strip().split():
            if move in specials:
                tokens.append(move)
                continue

            match = self._MOVE_RE.match(move)
            if match is None:
                # Fallback: add as unknown + EOM
                tokens.extend([self.UNK_TOKEN, self.EOM_TOKEN])
                continue

            color, piece, from_sq, to_sq, suffix = match.groups()
            # Combined color+piece token (e.g. "WP", "BN", "BB"), then squares.
            tokens.extend([color + piece, from_sq, to_sq])
            if suffix:
                # NOTE(review): only recognized suffixes are emitted; any other
                # suffix text is silently dropped.
                tokens.extend(self._SUFFIX_RE.findall(suffix))
            # Mark the end of this move
            tokens.append(self.EOM_TOKEN)

        return tokens

    def _convert_token_to_id(self, token: str) -> int:
        """Convert a token to its ID, falling back to the [UNK] ID (or 0)."""
        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 0))

    def _convert_id_to_token(self, index: int) -> str:
        """
        Convert an ID to its token; unknown IDs map to [UNK].

        [EOM] is returned as a single space for evaluator compatibility: it makes
        _generate_until_whitespace stop after one move. As a consequence, these
        tokens never contain "[EOM]", and " " maps back to the [UNK] ID through
        _convert_token_to_id.
        """
        token = self._ids_to_tokens.get(index, self.UNK_TOKEN)
        if token == self.EOM_TOKEN:
            return " "
        return token

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Convert a list of tokens back to a string.

        [PAD], [BOS], [EOS] and [UNK] are dropped. In component mode, tokens are
        concatenated and [EOM] becomes a single space. In move mode, tokens are
        joined with single spaces.

        CRITICAL: [EOM] must decode to a non-empty whitespace string so that
        the evaluator's _generate_until_whitespace stops after one move.
        """
        skip = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}

        if getattr(self, '_component_mode', False):
            # Don't strip: the trailing space produced by [EOM] is required.
            return "".join(
                " " if token == self.EOM_TOKEN else token
                for token in tokens
                if token not in skip
            )

        return " ".join(token for token in tokens if token not in skip)

    # =========================================================================
    # Structured Generation Support Methods
    # =========================================================================

    def get_token_category(self, token: str) -> str:
        """Categorize a token.

        Args:
            token: Token string to categorize.

        Returns:
            One of 'special', 'eom', 'piece', 'square', 'suffix', or 'unknown'
            (for anything that matches none of the others).
        """
        if token in (self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN):
            return 'special'
        if token == self.EOM_TOKEN:
            return 'eom'
        if self.is_piece_token(token):
            return 'piece'
        if self.is_square_token(token):
            return 'square'
        if self.is_suffix_token(token):
            return 'suffix'
        return 'unknown'

    def is_piece_token(self, token: str) -> bool:
        """Check if token is a color+piece token (WP, BN, etc.)."""
        return token in self._MOVE_START_TOKENS

    def is_square_token(self, token: str) -> bool:
        """Check if token is a square token (e2, g8, etc.)."""
        if len(token) != 2:
            return False
        return token[0] in 'abcdefgh' and token[1] in '12345678'

    def is_suffix_token(self, token: str) -> bool:
        """Check if token is a suffix token ((x), (+), =Q, etc.)."""
        return token in self._SUFFIX_TOKENS

    def is_eom_token(self, token: str) -> bool:
        """Check if token is the [EOM] token."""
        return token == self.EOM_TOKEN

    def get_token_color(self, token: str) -> Optional[str]:
        """Get the color ('W' or 'B') from a piece token, None otherwise."""
        if self.is_piece_token(token):
            return token[0]
        return None

    def build_vocabulary_masks(self) -> dict:
        """Build boolean masks for each token category.

        Returns:
            Dictionary with keys 'piece', 'square', 'suffix', 'eom',
            'white_piece', 'black_piece'. Each value is a torch.bool tensor of
            length len(vocab), True at the IDs of tokens in that category.

        NOTE(review): assumes IDs are contiguous in 0..len(vocab)-1; an ID at or
        above len(vocab) raises IndexError.
        """
        import torch

        vocab_size = len(self._vocab)
        masks = {
            key: [False] * vocab_size
            for key in ('piece', 'square', 'suffix', 'eom', 'white_piece', 'black_piece')
        }

        for token, token_id in self._vocab.items():
            if self.is_piece_token(token):
                masks['piece'][token_id] = True
                color = self.get_token_color(token)
                if color == 'W':
                    masks['white_piece'][token_id] = True
                elif color == 'B':
                    masks['black_piece'][token_id] = True
            elif self.is_square_token(token):
                masks['square'][token_id] = True
            elif self.is_suffix_token(token):
                masks['suffix'][token_id] = True
            elif self.is_eom_token(token):
                masks['eom'][token_id] = True

        return {k: torch.tensor(v, dtype=torch.bool) for k, v in masks.items()}

    def analyze_generation_state(self, input_ids: torch.Tensor) -> Union[dict, List[dict]]:
        """Analyze the current generation state to determine the next expected token.

        Args:
            input_ids: 2-D tensor of shape (batch_size, seq_len) with token IDs.

        Returns:
            For batch_size == 1, a single dict; otherwise a list of dicts, one per
            row, each with:
              - 'position': 0 (piece), 1 (from_square), 2 (to_square),
                3 (suffix/eom), from the number of non-[PAD] tokens after the
                last [EOM]/[BOS]
              - 'expected_color': 'W' if the row contains an even number of
                [EOM] tokens, else 'B' (assumes the row starts at White's first
                move — TODO confirm for truncated contexts)
              - 'last_eom_idx': index of the last [EOM] or [BOS] token, or -1
        """
        batch_size = input_ids.shape[0]
        boundary = (self.EOM_TOKEN, self.BOS_TOKEN)
        results = []

        for seq in input_ids.tolist():
            tokens = [self._ids_to_tokens.get(i, self.UNK_TOKEN) for i in seq]

            # Find last [EOM] or [BOS]
            last_eom_idx = -1
            for i in range(len(tokens) - 1, -1, -1):
                if tokens[i] in boundary:
                    last_eom_idx = i
                    break

            # Tokens since the last boundary (excluding padding) locate us in
            # the move structure: [Piece][Square][Square][Suffix?][EOM]
            num_tokens = sum(1 for t in tokens[last_eom_idx + 1:] if t != self.PAD_TOKEN)
            position = min(num_tokens, 3)

            # Completed moves so far decide whose turn it is.
            eom_count = tokens.count(self.EOM_TOKEN)
            expected_color = 'W' if eom_count % 2 == 0 else 'B'

            results.append({
                'position': position,
                'expected_color': expected_color,
                'last_eom_idx': last_eom_idx,
            })

        # For single batch, return dict directly; for multi-batch, return list
        return results[0] if batch_size == 1 else results

    def save_vocabulary(
        self,
        save_directory: str,
        filename_prefix: Optional[str] = None,
    ) -> tuple:
        """
        Save the vocabulary to a JSON file.

        Args:
            save_directory: Directory to save the vocabulary (created if missing).
            filename_prefix: Optional prefix for the filename ("<prefix>-vocab.json").

        Returns:
            Tuple containing the path to the saved vocabulary file.
        """
        os.makedirs(save_directory, exist_ok=True)

        prefix = f"{filename_prefix}-" if filename_prefix else ""
        vocab_file = os.path.join(save_directory, prefix + "vocab.json")

        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)

        return (vocab_file,)
def count_vocab_from_dataset(
    dataset_name: str = "dlouapre/lichess_2025-01_1M",
    split: str = "train",
    column: str = "text",
    max_samples: Optional[int] = 10000,
) -> Dict[str, int]:
    """
    Tally how often each space-separated move appears in a Hub dataset.

    Handy for deciding on a vocabulary (e.g. choosing ``min_frequency``).

    Args:
        dataset_name: Name of the dataset on Hugging Face Hub.
        split: Dataset split to use.
        column: Column holding the game strings.
        max_samples: Cap on the number of rows examined; None means all rows.

    Returns:
        Dictionary mapping each move token to its occurrence count, in
        first-seen order.
    """
    from collections import Counter
    from datasets import load_dataset

    games = load_dataset(dataset_name, split=split)
    if max_samples is not None:
        limit = min(max_samples, len(games))
        games = games.select(range(limit))

    counts = Counter(
        move
        for row in games
        for move in row[column].strip().split()
    )
    return dict(counts)