# chess-yentl-v2 / component_tokenizer.py
# Chess Challenge submission by Yentlcol (commit 50068c7, verified)
"""
Component-based Chess Tokenizer - Optimized for Parameter Efficiency.
This tokenizer decomposes chess moves into reusable components:
- Piece type (P, N, B, R, Q, K)
- Source square (a1-h8)
- Destination square (a1-h8)
- Modifiers (capture, check, castling, etc.)
Example:
"WPe2e4" → ["P", "e2", "e4"]
"BNg8f6(x)" → ["N", "g8", "f6", "(x)"]
This reduces vocabulary from ~1682 to ~80 tokens, saving 205K parameters.
"""
from __future__ import annotations

import itertools
import json
import os
from typing import Dict, List, Optional

from transformers import PreTrainedTokenizer
class ComponentChessTokenizer(PreTrainedTokenizer):
    """Component-based tokenizer for chess moves.

    Decomposes each extended-UCI move into reusable components
    ``[piece, from_square, to_square, modifiers...]``::

        "WPe2e4"       -> ["P", "e2", "e4"]
        "BNg8f6(x)"    -> ["N", "g8", "f6", "(x)"]
        "BPe7e8=Q(+)"  -> ["P", "e7", "e8", "=Q", "(+)"]

    Key advantages:
    - ~95% smaller vocabulary (1682 -> 85 tokens)
    - Saves ~205K embedding parameters
    - Better generalization to rare move combinations
    - Compositional understanding of chess structure
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.json"}

    # Special tokens
    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"
    SEP_TOKEN = "[SEP]"  # Separates components within a move

    # Chess piece types (6 tokens)
    PIECES = ["P", "N", "B", "R", "Q", "K"]

    # All squares on the board (64 tokens)
    FILES = "abcdefgh"
    RANKS = "12345678"

    # Move modifiers (10 tokens)
    MODIFIERS = [
        "(x)",      # capture
        "(+)",      # check
        "(+*)",     # checkmate
        "(o)",      # kingside castling
        "(O)",      # queenside castling
        "=Q",       # promotion to queen
        "=R",       # promotion to rook
        "=B",       # promotion to bishop
        "=N",       # promotion to knight
        "(e.p.)",   # en passant
    ]

    # ------------------------------------------------------------------
    # Precomputed O(1) lookup tables, hoisted out of the per-move hot
    # paths (previously rebuilt as lists on every _decompose_move and
    # convert_tokens_to_string call, with O(n) list membership tests).
    # itertools.product receives FILES/RANKS as *arguments*, which is
    # legal at class scope where a comprehension body would not be.
    # ------------------------------------------------------------------
    _VALID_SQUARES = frozenset(map("".join, itertools.product(FILES, RANKS)))
    _PIECE_SET = frozenset(PIECES)
    _MODIFIER_SET = frozenset(MODIFIERS)
    _CONTROL_TOKENS = frozenset({PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN})
    # Longest-first order guarantees a shorter modifier can never shadow a
    # longer one sharing its prefix during greedy matching.
    _MODIFIERS_BY_LENGTH = sorted(MODIFIERS, key=len, reverse=True)

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        vocab: Optional[Dict[str, int]] = None,
        **kwargs,
    ):
        """Initialize the component chess tokenizer.

        Args:
            vocab_file: Path to a JSON vocabulary file (token -> id).
            vocab: In-memory vocabulary mapping; takes precedence over
                ``vocab_file``.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``.
        """
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN

        # Drop caller-supplied special tokens so the fixed ones below win
        # (avoids duplicate-keyword errors in super().__init__).
        for key in ("pad_token", "bos_token", "eos_token", "unk_token"):
            kwargs.pop(key, None)

        # Load or create vocabulary.  NOTE(review): the vocab is set up
        # *before* super().__init__ as in the original, since the parent
        # initializer may query get_vocab()/vocab_size — keep this order.
        if vocab is not None:
            self._vocab = vocab
        elif vocab_file is not None and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            self._vocab = self._create_component_vocab()

        # Reverse mapping: id -> token.
        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}

        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )

    def _create_component_vocab(self) -> Dict[str, int]:
        """Create the deterministic component vocabulary.

        Vocabulary structure:
        - Special tokens (5): [PAD], [BOS], [EOS], [UNK], [SEP]
        - Pieces (6): P, N, B, R, Q, K
        - Squares (64): a1, a2, ..., h8
        - Modifiers (10): (x), (+), (+*), (o), (O), =Q, =R, =B, =N, (e.p.)

        Total: 85 tokens (vs 1682 in the original tokenizer).

        Returns:
            Mapping of token string -> integer id, ids assigned in the
            fixed order above (order matters for checkpoint compatibility,
            so an ordered list — not ``_VALID_SQUARES`` — is used here).
        """
        tokens = [
            self.PAD_TOKEN,
            self.BOS_TOKEN,
            self.EOS_TOKEN,
            self.UNK_TOKEN,
            self.SEP_TOKEN,
        ]
        tokens.extend(self.PIECES)
        # Squares in file-major order: a1..a8, b1..b8, ..., h1..h8.
        tokens.extend(f + r for f in self.FILES for r in self.RANKS)
        tokens.extend(self.MODIFIERS)
        return {token: idx for idx, token in enumerate(tokens)}

    @classmethod
    def build_vocab(cls) -> "ComponentChessTokenizer":
        """Build a tokenizer with the component vocabulary.

        No dataset needed — the vocabulary is deterministic, derived
        purely from chess rules.
        """
        return cls()

    @property
    def vocab_size(self) -> int:
        """Return the number of tokens in the vocabulary."""
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the vocabulary as a dict (token -> id)."""
        return dict(self._vocab)

    def _decompose_move(self, move: str) -> List[str]:
        """Decompose a move string into component tokens.

        Examples:
            "WPe2e4"      -> ["P", "e2", "e4"]
            "BNg8f6(x)"   -> ["N", "g8", "f6", "(x)"]
            "WKe1g1(o)"   -> ["K", "e1", "g1", "(o)"]
            "BPe7e8=Q(+)" -> ["P", "e7", "e8", "=Q", "(+)"]

        Args:
            move: Extended UCI move string (e.g., "WPe2e4"), or a control
                token, which is passed through unchanged.

        Returns:
            List of component tokens; ``[UNK]`` for anything unparseable
            (including the empty string, which previously leaked ``[""]``).
        """
        if not move:
            return [self.UNK_TOKEN]
        if move in self._CONTROL_TOKENS:
            return [move]

        # Strip the color prefix (W/B); color is redundant with ply parity.
        if move[0] in ("W", "B"):
            move = move[1:]
            if not move:
                return [self.UNK_TOKEN]

        # Format: <piece><from_square><to_square>[modifiers]
        piece, body = move[0], move[1:]
        if piece not in self._PIECE_SET:
            return [self.UNK_TOKEN]
        if len(body) < 4:
            # Not enough characters for two 2-char squares.
            return [self.UNK_TOKEN]

        from_square, to_square, trailer = body[:2], body[2:4], body[4:]
        if (
            from_square not in self._VALID_SQUARES
            or to_square not in self._VALID_SQUARES
        ):
            return [self.UNK_TOKEN]

        components = [piece, from_square, to_square]

        # Greedily match known modifiers, longest first; unknown characters
        # are skipped rather than aborting the whole move.
        i = 0
        while i < len(trailer):
            for modifier in self._MODIFIERS_BY_LENGTH:
                if trailer.startswith(modifier, i):
                    components.append(modifier)
                    i += len(modifier)
                    break
            else:
                i += 1
        return components

    def _tokenize(self, text: str) -> List[str]:
        """Tokenize a string of moves into component tokens.

        Args:
            text: Whitespace-separated moves (e.g., "WPe2e4 BPe7e5").

        Returns:
            Flat list of component tokens; control tokens pass through
            unchanged (handled inside ``_decompose_move``).
        """
        tokens: List[str] = []
        for move in text.split():
            tokens.extend(self._decompose_move(move))
        return tokens

    def _convert_token_to_id(self, token: str) -> int:
        """Convert a token to its id, falling back to the UNK id."""
        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 0))

    def _convert_id_to_token(self, index: int) -> str:
        """Convert an id to its token, falling back to [UNK]."""
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Reconstruct space-separated move strings from components.

        Note: the W/B color prefix is lost, but it is redundant — color
        can be inferred from the move's position in the game.

        Args:
            tokens: Component tokens, possibly interleaved with special
                tokens (which are dropped).

        Returns:
            Space-joined moves like "Pe2e4 Ng8f6(x)".  A trailing partial
            triple or out-of-place token is skipped, never emitted.
        """
        skip = self._CONTROL_TOKENS | {self.SEP_TOKEN}
        parts = [t for t in tokens if t not in skip]

        moves: List[str] = []
        i = 0
        n = len(parts)
        while i < n:
            # A move needs a full (piece, from, to) triple.
            if i + 2 >= n:
                break
            piece, src, dst = parts[i], parts[i + 1], parts[i + 2]
            if (
                piece in self._PIECE_SET
                and src in self._VALID_SQUARES
                and dst in self._VALID_SQUARES
            ):
                i += 3
                move = f"{piece}{src}{dst}"
                # Append any trailing modifier tokens.
                while i < n and parts[i] in self._MODIFIER_SET:
                    move += parts[i]
                    i += 1
                moves.append(move)
            else:
                # Resynchronize past an out-of-place token.
                i += 1
        return " ".join(moves)

    def save_vocabulary(
        self,
        save_directory: str,
        filename_prefix: Optional[str] = None,
    ) -> tuple:
        """Save the vocabulary as JSON into *save_directory*.

        Args:
            save_directory: Target directory (created if missing).
            filename_prefix: Optional prefix prepended as "<prefix>-".

        Returns:
            1-tuple containing the path of the written vocabulary file.
        """
        os.makedirs(save_directory, exist_ok=True)
        prefix = f"{filename_prefix}-" if filename_prefix else ""
        # Use the declared vocab_files_names mapping instead of a second
        # hard-coded "vocab.json" literal, so the two can never diverge.
        vocab_file = os.path.join(
            save_directory, prefix + self.vocab_files_names["vocab_file"]
        )
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)
        return (vocab_file,)