velmen-chess-model_v2 / tokenizer.py
velmen's picture
Chess Challenge submission by velmen
e00ee6e verified
"""
Custom Chess Tokenizer V3 for the Chess Challenge.
Enhanced version with additional chess-specific tokens for:
- Castling moves (O-O, O-O-O)
- Check/checkmate indicators (+, #)
- Capture indicator (x)
- Turn indicators ([WHITE], [BLACK])
This provides richer context while keeping vocabulary minimal (79 tokens total).
"""
from __future__ import annotations
import json
import os
from pathlib import Path
from typing import Dict, List, Optional
import re
from transformers import PreTrainedTokenizer
class ChessTokenizer(PreTrainedTokenizer):
    """
    Chess tokenizer that splits move text into squares and notation tokens.

    Default vocabulary (79 tokens):
    - 4 special tokens: [PAD], [BOS], [EOS], [UNK]
    - 64 square tokens: a1-h8
    - 4 promotion tokens: q, r, b, n
    - 2 castling tokens: O-O, O-O-O
    - 3 modifier tokens: +, #, x (check, checkmate, capture)
    - 2 turn tokens: [WHITE], [BLACK]
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.json"}

    # Special tokens
    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"
    WHITE_TOKEN = "[WHITE]"
    BLACK_TOKEN = "[BLACK]"

    # Token groups, listed in the order they receive ids in the default vocab.
    _PROMOTIONS = ("q", "r", "b", "n")
    _CASTLING = ("O-O", "O-O-O")
    _MODIFIERS = ("+", "#", "x")

    # Strict board-square matcher used when regrouping tokens into moves.
    _SQUARE_RE = re.compile(r"[a-h][1-8]")

    # Parenthesised suffixes rewritten to vocabulary tokens before matching.
    # No replacement text contains parentheses, so the rewrites cannot
    # interfere with one another.
    _NOTATION_ALIASES = (
        ("(Q)", "q"),
        ("(R)", "r"),
        ("(B)", "b"),
        ("(N)", "n"),
        ("(x)", "x"),
        ("(+)", "+"),
        ("(#)", "#"),
        ("(+*)", "#"),  # Checkmate variant
        ("(o)", "O-O"),  # Kingside castling
        ("(O)", "O-O-O"),  # Queenside castling
    )

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        vocab: Optional[Dict[str, int]] = None,
        **kwargs,
    ):
        """
        Build the tokenizer from an explicit vocab, a vocab file, or defaults.

        Args:
            vocab_file: Path to a JSON file mapping token -> id. Used only
                when ``vocab`` is None and the path exists.
            vocab: Token -> id mapping. Takes precedence over ``vocab_file``.
            **kwargs: Forwarded to the base initializer. Any ``pad_token``,
                ``bos_token``, ``eos_token`` or ``unk_token`` entries are
                discarded in favour of the fixed class constants.

        If neither source is usable, the default vocabulary is generated.
        """
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN

        # The special tokens are passed explicitly to the base initializer
        # below; drop caller-supplied duplicates so each keyword is given once.
        for name in ("pad_token", "bos_token", "eos_token", "unk_token"):
            kwargs.pop(name, None)

        # Matches: castling, turn indicators, squares, promotions, modifiers.
        self.token_pattern = re.compile(
            r'O-O-O|O-O|'  # Castling (O-O-O must come first)
            r'\[WHITE\]|\[BLACK\]|'  # Turn indicators
            r'[a-h][1-8]|'  # Squares
            r'[qrbn]|'  # Promotions
            r'[+#x]'  # Check, checkmate, capture
        )

        if vocab is not None:
            # Copy so later mutation of the caller's dict cannot desynchronise
            # the forward map from the reverse map built below.
            self._vocab = dict(vocab)
        elif vocab_file is not None and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            self._vocab = self._create_default_vocab()

        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}

        # The vocabulary is populated before the base initializer runs.
        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )

    def _create_default_vocab(self) -> Dict[str, int]:
        """
        Create the default vocabulary (79 tokens) with contiguous ids.

        Id layout: special tokens 0-3, squares 4-67 (a1, a2, ..., h8),
        promotions 68-71, castling 72-73, modifiers 74-76, turn tokens 77-78.
        """
        ordered = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]
        ordered += [f"{file}{rank}" for file in "abcdefgh" for rank in "12345678"]
        ordered += self._PROMOTIONS
        ordered += self._CASTLING
        ordered += self._MODIFIERS
        ordered += [self.WHITE_TOKEN, self.BLACK_TOKEN]
        return {token: idx for idx, token in enumerate(ordered)}

    def _tokenize(self, text: str) -> List[str]:
        """
        Split ``text`` into vocabulary tokens.

        Parenthesised suffixes such as ``(Q)``, ``(x)``, ``(+)``, ``(#)``,
        ``(+*)``, ``(o)`` and ``(O)`` are first rewritten to their token
        form. The text is then scanned with ``token_pattern``; characters
        that match no alternative (spaces, uppercase piece letters, ...) are
        dropped.
        """
        for alias, replacement in self._NOTATION_ALIASES:
            text = text.replace(alias, replacement)
        return self.token_pattern.findall(text)

    def _convert_token_to_id(self, token: str) -> int:
        """Return the id of ``token``, or the [UNK] id (0 if [UNK] is absent)."""
        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 0))

    def _convert_id_to_token(self, index: int) -> str:
        """Return the token for ``index``, or [UNK] for an unknown id."""
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """
        Regroup tokens into space-separated moves.

        - [PAD], [BOS], [EOS] and [UNK] are removed.
        - Two consecutive squares form a move: e2, e4 -> e2e4
        - A promotion directly after a two-square move attaches: a7a8q
        - Modifiers attach to the preceding word: e2e4, x, + -> e2e4x+
        - Castling and turn indicators are emitted as their own words; a
          turn indicator is never extended by a following token.

        The kind of the last emitted word is tracked explicitly rather than
        inferred from its length, so words such as ``O-O+`` or ``b+`` are
        not mistaken for a complete move or a lone square.
        """
        dropped = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
        turn_tokens = {self.WHITE_TOKEN, self.BLACK_TOKEN}

        words: List[str] = []
        # Kind of words[-1]:
        #   "from"   - a lone square awaiting its destination
        #   "move"   - exactly two squares; may still take a promotion
        #   "closed" - accepts modifiers only
        #   "fixed"  - accepts nothing (turn indicators)
        state = "closed"

        for token in tokens:
            if token in dropped:
                continue

            if token in turn_tokens:
                words.append(token)
                state = "fixed"
            elif token in self._CASTLING:
                words.append(token)
                state = "closed"
            elif token in self._PROMOTIONS:
                if state == "move":
                    words[-1] += token
                else:
                    words.append(token)
                state = "closed"
            elif token in self._MODIFIERS:
                if words and state != "fixed":
                    words[-1] += token
                else:
                    words.append(token)
                state = "closed"
            elif self._SQUARE_RE.fullmatch(token):
                if state == "from":
                    # Destination square completes the pending move.
                    words[-1] += token
                    state = "move"
                else:
                    words.append(token)
                    state = "from"
            else:
                # Unrecognised token: keep it as its own word.
                words.append(token)
                state = "closed"

        return " ".join(words)

    def add_turn_indicators(self, text: str, add_white_indicator: bool = True) -> str:
        """
        Prefix every move with an alternating turn token.

        Args:
            text: Game string (whitespace-separated moves).
            add_white_indicator: If True, the first move is tagged [WHITE]
                and the tags alternate from there. If False, the first move
                is tagged [BLACK] instead.

        Returns:
            Game string with a turn token before each move.
        """
        if add_white_indicator:
            first, second = self.WHITE_TOKEN, self.BLACK_TOKEN
        else:
            first, second = self.BLACK_TOKEN, self.WHITE_TOKEN

        parts: List[str] = []
        for ply, move in enumerate(text.split()):
            parts.append(first if ply % 2 == 0 else second)
            parts.append(move)
        return " ".join(parts)

    def save_vocabulary(
        self,
        save_directory: str,
        filename_prefix: Optional[str] = None,
    ) -> tuple:
        """
        Write the vocabulary to ``<save_directory>/[<prefix>-]vocab.json``.

        The directory is created if it does not exist.

        Returns:
            A one-element tuple containing the path of the written file.
        """
        os.makedirs(save_directory, exist_ok=True)
        prefix = filename_prefix + "-" if filename_prefix else ""
        vocab_file = os.path.join(save_directory, prefix + "vocab.json")
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)
        return (vocab_file,)

    @classmethod
    def build_vocab_from_iterator(cls, iterator, min_frequency=1):
        """Return a tokenizer with the fixed default vocabulary; arguments are ignored."""
        return cls()

    @classmethod
    def build_vocab_from_dataset(cls, **kwargs):
        """Return a tokenizer with the fixed default vocabulary; arguments are ignored."""
        return cls()

    @property
    def vocab_size(self) -> int:
        """Return the number of entries in the vocabulary (79 by default)."""
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the token -> id mapping."""
        return dict(self._vocab)