chess_model_lucas_meb_4 / tokenizer.py
luluM's picture
Chess Challenge submission by luluM
6866407 verified
"""
Optimized Chess Tokenizer for the Chess Challenge.
Strategies for smaller vocabulary:
1. Remove rare moves (high min_frequency threshold)
2. Decompose moves into sub-tokens (piece + squares)
3. Merge similar move patterns
This tokenizer uses a hybrid approach:
- Common moves as single tokens (efficient for frequent patterns)
- Sub-token decomposition for rare moves (better generalization)
"""
from __future__ import annotations
import json
import os
import re
from collections import Counter
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from transformers import PreTrainedTokenizer
class ChessTokenizer(PreTrainedTokenizer):
    """
    Chess move tokenizer with a compact, hybrid vocabulary.

    Input text is split on whitespace. A word that is present in the
    vocabulary is emitted as a single token. Any other word is, when
    ``use_decomposition`` is enabled, parsed as a move and split into
    piece / from-square / to-square / optional-suffix sub-tokens
    (``P:``, ``F:``, ``T:``, ``S:``). Words that can be represented in
    neither way become ``[UNK]``.
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.json"}

    # Special tokens
    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"

    # Sub-token markers for decomposed moves
    PIECE_PREFIX = "P:"  # P:WP, P:BN, etc.
    FROM_PREFIX = "F:"  # F:e2, F:g1, etc.
    TO_PREFIX = "T:"  # T:e4, T:f3, etc.
    SUFFIX_PREFIX = "S:"  # S:(x), S:(+), etc.

    # Move grammar: colour+piece, from-square, to-square, optional "(...)"
    # suffix. Compiled once here instead of on every _parse_move call.
    _MOVE_PATTERN = re.compile(
        r'^([WB][PNBRQK])([a-h][1-8])([a-h][1-8])(\([^)]+\))?$'
    )

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        vocab: Optional[Dict[str, int]] = None,
        use_decomposition: bool = True,
        **kwargs,
    ):
        """
        Create the tokenizer from an in-memory vocab, a vocab file, or defaults.

        Args:
            vocab_file: Path to a JSON file mapping token -> id. Used only
                when ``vocab`` is None and the file exists.
            vocab: Token -> id mapping. Takes precedence over ``vocab_file``.
            use_decomposition: Whether out-of-vocabulary moves are split into
                sub-tokens instead of being mapped directly to ``[UNK]``.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``. Any
                ``pad_token`` / ``bos_token`` / ``eos_token`` / ``unk_token``
                entries are discarded in favour of the class constants.
        """
        # Initialize special tokens
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN

        # Whether to use sub-token decomposition
        self.use_decomposition = use_decomposition

        # Drop caller-supplied special tokens so they are not passed twice
        # to the base-class constructor below.
        for name in ("pad_token", "bos_token", "eos_token", "unk_token"):
            kwargs.pop(name, None)

        # Load or create vocabulary
        if vocab is not None:
            self._vocab = vocab
        elif vocab_file is not None and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            self._vocab = self._create_default_vocab()

        # Create reverse mapping
        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}

        # Full-move tokens: everything that is neither a special token
        # ("[...") nor one of the prefixed sub-tokens.
        self._full_move_tokens = {
            t for t in self._vocab
            if not t.startswith(("[", "P:", "F:", "T:", "S:"))
        }

        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            # Forwarded so the base class sees this flag with its other init
            # kwargs. NOTE(review): presumably this is what lets
            # save_pretrained / from_pretrained round-trip the setting --
            # confirm against the installed transformers version.
            use_decomposition=use_decomposition,
            **kwargs,
        )

    def _create_default_vocab(self) -> Dict[str, int]:
        """Create a minimal vocabulary containing only the special tokens."""
        special_tokens = [
            self.PAD_TOKEN,
            self.BOS_TOKEN,
            self.EOS_TOKEN,
            self.UNK_TOKEN,
        ]
        return {token: idx for idx, token in enumerate(special_tokens)}

    @staticmethod
    def _parse_move(move: str) -> Optional[Tuple[str, str, str, str]]:
        """
        Parse a move into (color+piece, from_square, to_square, suffix).

        Example: "WPe2e4" -> ("WP", "e2", "e4", "")
                 "BNg8f6(x)" -> ("BN", "g8", "f6", "(x)")

        Returns:
            The four components (suffix is "" when absent), or None when the
            string does not match the move grammar.
        """
        match = ChessTokenizer._MOVE_PATTERN.match(move)
        if match is None:
            return None
        piece, from_sq, to_sq, suffix = match.groups()
        return (piece, from_sq, to_sq, suffix or "")

    def _decompose_move(self, move: str) -> List[str]:
        """
        Decompose a move into prefixed sub-tokens.

        Example: "WPe2e4" -> ["P:WP", "F:e2", "T:e4"]
                 "BNg8f6(x)" -> ["P:BN", "F:g8", "T:f6", "S:(x)"]

        Returns:
            The sub-token list, or ``[UNK_TOKEN]`` when the move cannot be
            parsed.
        """
        parsed = self._parse_move(move)
        if parsed is None:
            return [self.UNK_TOKEN]

        piece, from_sq, to_sq, suffix = parsed
        tokens = [
            f"{self.PIECE_PREFIX}{piece}",
            f"{self.FROM_PREFIX}{from_sq}",
            f"{self.TO_PREFIX}{to_sq}",
        ]
        if suffix:
            tokens.append(f"{self.SUFFIX_PREFIX}{suffix}")
        return tokens

    def _tokenize(self, text: str) -> List[str]:
        """
        Split ``text`` on whitespace and map each word to one or more tokens.

        In-vocabulary words (full moves, special tokens, or sub-tokens) are
        kept as-is. Other words are decomposed when ``use_decomposition`` is
        set and every resulting sub-token is in the vocabulary; otherwise
        they become ``[UNK]``.
        """
        tokens: List[str] = []
        for word in text.strip().split():
            # _full_move_tokens is derived from _vocab, so one membership
            # test covers full moves, special tokens and sub-tokens alike.
            if word in self._vocab:
                tokens.append(word)
                continue

            if self.use_decomposition:
                sub_tokens = self._decompose_move(word)
                # Only emit the decomposition if it is fully representable.
                if all(t in self._vocab for t in sub_tokens):
                    tokens.extend(sub_tokens)
                    continue

            tokens.append(self.UNK_TOKEN)
        return tokens

    @staticmethod
    def _count_moves(
        dataset_name: str,
        split: str,
        column: str,
        max_samples: Optional[int],
    ) -> Counter:
        """Load a dataset split and count whitespace-separated moves in ``column``."""
        # Imported lazily so the tokenizer itself does not require `datasets`.
        from datasets import load_dataset

        dataset = load_dataset(dataset_name, split=split)
        if max_samples is not None:
            dataset = dataset.select(range(min(max_samples, len(dataset))))

        move_counts: Counter = Counter()
        for example in dataset:
            move_counts.update(example[column].strip().split())
        return move_counts

    @classmethod
    def build_vocab_from_dataset(
        cls,
        dataset_name: str = "dlouapre/lichess_2025-01_1M",
        split: str = "train",
        column: str = "text",
        min_frequency: int = 1000,
        max_vocab_size: int = 1500,
        max_samples: Optional[int] = 200000,
        use_decomposition: bool = True,
    ) -> "ChessTokenizer":
        """
        Build an optimized vocabulary from a dataset.

        Strategy:
        1. Count all moves.
        2. Reserve ids for the special tokens and, when ``use_decomposition``
           is set, for every piece / square / suffix sub-token.
        3. Fill the remaining slots (up to ``max_vocab_size``) with the most
           frequent moves whose count is at least ``min_frequency``.

        Args:
            dataset_name: Name passed to ``datasets.load_dataset``.
            split: Dataset split to read.
            column: Column holding the space-separated move text.
            min_frequency: Minimum count for a move to get its own token.
            max_vocab_size: Upper bound on the vocabulary size. Special and
                sub-tokens are always added, so the result can exceed this
                bound when it is smaller than their combined count.
            max_samples: Maximum number of examples to read (None = all).
            use_decomposition: Whether to include sub-tokens and enable
                decomposition in the returned tokenizer.

        Returns:
            A tokenizer initialised with the constructed vocabulary.
        """
        print(f"Building vocabulary from {dataset_name}...")
        move_counts = cls._count_moves(dataset_name, split, column, max_samples)
        print(f"Total unique moves: {len(move_counts)}")

        # Start with special tokens
        vocab = {
            cls.PAD_TOKEN: 0,
            cls.BOS_TOKEN: 1,
            cls.EOS_TOKEN: 2,
            cls.UNK_TOKEN: 3,
        }

        if use_decomposition:
            pieces = ["WP", "WN", "WB", "WR", "WQ", "WK",
                      "BP", "BN", "BB", "BR", "BQ", "BK"]
            squares = [f"{f}{r}" for f in "abcdefgh" for r in "12345678"]
            suffixes = ["(x)", "(+)", "(x+)", "(+*)", "(x+*)", "(o)", "(O)",
                        "(Q)", "(R)", "(B)", "(N)"]

            # Id order: pieces, then from/to interleaved per square, then
            # suffixes.
            sub_tokens = [f"{cls.PIECE_PREFIX}{p}" for p in pieces]
            for sq in squares:
                sub_tokens.append(f"{cls.FROM_PREFIX}{sq}")
                sub_tokens.append(f"{cls.TO_PREFIX}{sq}")
            sub_tokens.extend(f"{cls.SUFFIX_PREFIX}{s}" for s in suffixes)

            for token in sub_tokens:
                vocab[token] = len(vocab)

        num_sub_tokens = len(vocab) - 4

        # Rank by descending frequency (ties broken alphabetically) so that
        # truncation keeps the MOST frequent moves, not the alphabetically
        # first ones.
        ranked = sorted(move_counts.items(), key=lambda item: (-item[1], item[0]))
        candidates = [
            move for move, count in ranked
            if count >= min_frequency and move not in vocab
        ]

        # Clamp at zero: a negative slice bound would drop moves from the
        # end of the list instead of selecting none.
        available_slots = max(0, max_vocab_size - len(vocab))

        # Sort the kept moves so id assignment is reproducible.
        frequent_moves = sorted(candidates[:available_slots])
        for move in frequent_moves:
            vocab[move] = len(vocab)

        print(f"Final vocabulary size: {len(vocab)}")
        print(" - Special tokens: 4")
        print(f" - Sub-tokens: {num_sub_tokens}")
        print(f" - Full moves: {len(frequent_moves)}")
        return cls(vocab=vocab, use_decomposition=use_decomposition)

    @classmethod
    def build_simple_vocab(
        cls,
        dataset_name: str = "dlouapre/lichess_2025-01_1M",
        split: str = "train",
        column: str = "text",
        min_frequency: int = 2000,
        max_samples: Optional[int] = 200000,
    ) -> "ChessTokenizer":
        """
        Build a simple vocabulary without decomposition.

        Keeps every move whose count is at least ``min_frequency`` as a full
        token; the returned tokenizer maps all other moves to ``[UNK]``.

        Args:
            dataset_name: Name passed to ``datasets.load_dataset``.
            split: Dataset split to read.
            column: Column holding the space-separated move text.
            min_frequency: Minimum count for a move to be kept.
            max_samples: Maximum number of examples to read (None = all).

        Returns:
            A tokenizer with ``use_decomposition=False``.
        """
        print(f"Building simple vocabulary from {dataset_name}...")
        move_counts = cls._count_moves(dataset_name, split, column, max_samples)

        vocab = {
            cls.PAD_TOKEN: 0,
            cls.BOS_TOKEN: 1,
            cls.EOS_TOKEN: 2,
            cls.UNK_TOKEN: 3,
        }

        # Keep only frequent moves, in sorted order for reproducible ids.
        frequent_moves = sorted(
            move for move, count in move_counts.items()
            if count >= min_frequency
        )
        for idx, move in enumerate(frequent_moves, start=4):
            vocab[move] = idx

        print(f"Vocabulary size: {len(vocab)}")
        return cls(vocab=vocab, use_decomposition=False)

    @property
    def vocab_size(self) -> int:
        """Number of entries in the vocabulary."""
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the token -> id mapping."""
        return dict(self._vocab)

    def _convert_token_to_id(self, token: str) -> int:
        """Map a token to its id, falling back to the ``[UNK]`` id (or 0)."""
        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 0))

    def _convert_id_to_token(self, index: int) -> str:
        """Map an id to its token, falling back to ``[UNK]``."""
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """
        Convert tokens back to a space-separated move string.

        Special tokens are dropped. A ``P:`` token starts a decomposed move:
        the immediately following ``F:``, ``T:`` and ``S:`` tokens (each
        optional, checked in that order) are merged with it into one move.
        All other tokens are copied through unchanged.
        """
        special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
        follow_up_prefixes = (self.FROM_PREFIX, self.TO_PREFIX, self.SUFFIX_PREFIX)

        result: List[str] = []
        i = 0
        n = len(tokens)
        while i < n:
            token = tokens[i]
            i += 1

            if token in special:
                continue

            if not token.startswith(self.PIECE_PREFIX):
                result.append(token)
                continue

            # Reconstruct the move from the piece plus whichever of the
            # from / to / suffix sub-tokens directly follow it.
            move = token[len(self.PIECE_PREFIX):]
            for prefix in follow_up_prefixes:
                if i < n and tokens[i].startswith(prefix):
                    move += tokens[i][len(prefix):]
                    i += 1
            result.append(move)

        return " ".join(result)

    def save_vocabulary(
        self,
        save_directory: str,
        filename_prefix: Optional[str] = None,
    ) -> Tuple[str]:
        """
        Write the vocabulary as JSON to ``<save_directory>/[<prefix>-]vocab.json``.

        The directory is created if it does not exist.

        Returns:
            A one-element tuple containing the path of the written file.
        """
        os.makedirs(save_directory, exist_ok=True)
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)
        return (vocab_file,)
def count_vocab_from_dataset(
    dataset_name: str = "dlouapre/lichess_2025-01_1M",
    split: str = "train",
    column: str = "text",
    max_samples: Optional[int] = 10000,
) -> Dict[str, int]:
    """
    Count how often each whitespace-separated token occurs in a dataset column.

    Args:
        dataset_name: Name passed to ``datasets.load_dataset``.
        split: Dataset split to read.
        column: Column whose text is split on whitespace.
        max_samples: Read at most this many leading examples (None = all).

    Returns:
        A plain dict mapping each token to its occurrence count.
    """
    from collections import Counter
    from datasets import load_dataset

    examples = load_dataset(dataset_name, split=split)
    if max_samples is not None:
        # Never ask for more rows than the split actually has.
        keep = min(max_samples, len(examples))
        examples = examples.select(range(keep))

    frequencies = Counter(
        move
        for row in examples
        for move in row[column].strip().split()
    )
    return dict(frequencies)