# ykolo-baseline / tokenizer.py
# Chess Challenge submission by ykolo (commit 7b8a427, verified)
"""
Compositional Chess Tokenizer for the Chess Challenge.
This tokenizer decomposes chess moves into meaningful components:
- Color (W/B), Piece (P/N/B/R/Q/K), Squares, Actions, Modifiers
Reduces the vocabulary from 3803 to 86 token ids (85 distinct strings, because "B" is shared by the color Black and the Bishop piece) while enabling better generalization.
Example:
WPe2e4 -> [W, P, e2, ->, e4]
BNg8f6(x) -> [B, N, g8, x, f6]
"""
from __future__ import annotations
import json
import os
import re
from typing import Dict, List, Optional
from transformers import PreTrainedTokenizer
class ChessTokenizer(PreTrainedTokenizer):
    """Compositional tokenizer for chess moves (86 token ids vs 3803 baseline).

    A move such as ``WPe2e4`` or ``BNg8f6(x)`` is split into
    ``[color, piece, from_square, action, to_square, *modifiers]`` and a
    castling move into ``[color, "O-O"]`` / ``[color, "O-O-O"]``.

    Note: the string ``"B"`` stands for both the color Black and the Bishop
    piece, so the built vocabulary holds 85 distinct strings over ids 0..85
    (id 5 is never produced).  ``vocab_size`` and ``len()`` therefore report
    ``max id + 1`` so that every id fits in an embedding table.
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.json"}

    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"

    # Use ASCII-safe tokens to avoid encoding issues and mismatches
    MOVE_ARROW = "->"
    CAPTURE_CROSS = "x"

    # Token groups used to validate / resynchronise while decoding.
    _COLORS = ("W", "B")
    _CASTLING_TOKENS = ("O-O", "O-O-O")
    _MODIFIER_TOKENS = ("+", "+*", "=Q", "=R", "=B", "=N")

    # Castling tokens carry no squares; decode them as the standard king move.
    _CASTLING_KING_SQUARES = {
        ("W", "O-O"): "e1g1",
        ("W", "O-O-O"): "e1c1",
        ("B", "O-O"): "e8g8",
        ("B", "O-O-O"): "e8c8",
    }

    # Promotion suffix: optional capture "x", optional "+" / "+*", then the
    # promoted piece, e.g. (Q), (xQ), (+Q), (x+Q), (x+*Q).
    _PROMOTION_RE = re.compile(r"\((?:x\s*)?(?:\+\*?\s*)?([QRBN])\)")

    def __init__(self, vocab_file: Optional[str] = None, vocab: Optional[Dict[str, int]] = None, **kwargs):
        """Initialize compositional chess tokenizer.

        Args:
            vocab_file: Path to a ``vocab.json`` mapping token -> id. Used only
                when ``vocab`` is not given and the file exists.
            vocab: Explicit token -> id mapping; takes precedence over the file.
            **kwargs: Forwarded to ``PreTrainedTokenizer``. Any ``pad_token``,
                ``bos_token``, ``eos_token`` or ``unk_token`` entries are
                discarded in favour of the class constants.
        """
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN

        # Special tokens are fixed by this class; drop caller overrides so they
        # are not passed twice to super().__init__.
        kwargs.pop("pad_token", None)
        kwargs.pop("bos_token", None)
        kwargs.pop("eos_token", None)
        kwargs.pop("unk_token", None)

        if vocab is not None:
            self._vocab = vocab
        elif vocab_file is not None and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            print("Building compositional vocabulary (86 tokens)...")
            self._vocab = self._build_compositional_vocab()

        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}

        # The vocab must exist before the base initializer runs, since it may
        # query it.
        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )

    def _build_compositional_vocab(self) -> Dict[str, int]:
        """Build the compositional vocabulary (86 ids, 85 distinct strings).

        NOTE: "B" is inserted first as a color (id 5) and again as a piece
        (id 8); the second assignment wins, leaving id 5 unused. The ids are
        kept as-is for compatibility with already-saved vocab files.
        """
        vocab = {}
        idx = 0

        # Special tokens (4 tokens)
        for token in [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]:
            vocab[token] = idx
            idx += 1

        # Colors (2 tokens: W, B)
        for color in ["W", "B"]:
            vocab[color] = idx
            idx += 1

        # Pieces (6 tokens: P, N, B, R, Q, K) -- "B" overwrites the color entry
        for piece in ["P", "N", "B", "R", "Q", "K"]:
            vocab[piece] = idx
            idx += 1

        # Squares (64 tokens: a1-h8)
        for f in ["a", "b", "c", "d", "e", "f", "g", "h"]:
            for r in ["1", "2", "3", "4", "5", "6", "7", "8"]:
                vocab[f + r] = idx
                idx += 1

        # Actions (2 tokens: "->" move, "x" capture)
        vocab[self.MOVE_ARROW] = idx
        idx += 1
        vocab[self.CAPTURE_CROSS] = idx
        idx += 1

        # Modifiers (6 tokens: + check, +* checkmate, =Q/R/B/N promotions)
        for mod in ["+", "+*", "=Q", "=R", "=B", "=N"]:
            vocab[mod] = idx
            idx += 1

        # Special moves (2 tokens: O-O, O-O-O)
        for move in ["O-O", "O-O-O"]:
            vocab[move] = idx
            idx += 1

        return vocab

    @property
    def vocab_size(self) -> int:
        """Return the size of the id space (largest id + 1).

        ``len(self._vocab)`` would under-count when ids are not contiguous
        (the built vocab skips id 5), so the largest id would not fit in an
        embedding of that size.
        """
        return max(self._vocab.values(), default=-1) + 1

    def __len__(self) -> int:
        """Return the full id-space size, including any tokens added later.

        Overridden so ``len(tokenizer)`` agrees with ``vocab_size``, and so ids
        for newly added tokens do not collide with existing ones.
        """
        # May be called before the base class has set up its added-token maps.
        added = getattr(self, "added_tokens_encoder", None) or {}
        highest_added = max(added.values(), default=-1)
        return max(self.vocab_size, highest_added + 1)

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the token -> id vocabulary dictionary."""
        return dict(self._vocab)

    def _decompose_move(self, move: str) -> List[str]:
        """Decompose a chess move into component tokens.

        Args:
            move: Chess move such as ``WPe2e4``, ``BNg8f6(x)``, ``WQd1h5(x+)``
                or ``WPb7a8(xQ)``.

        Returns:
            ``[color, piece, from_sq, action, to_sq, *modifiers]``,
            ``[color, "O-O" | "O-O-O"]`` for castling, or ``[UNK]`` when the
            move is too short to parse.
        """
        move = move.strip()
        lowered = move.lower()

        # Castling: queen-side first, because "o-o" is a substring of "o-o-o".
        if "o-o-o" in lowered:
            return [move[0], "O-O-O"]
        if "o-o" in lowered:
            return [move[0], "O-O"]

        # Basic validation: color + piece + two squares.
        if len(move) < 6:
            return [self.UNK_TOKEN]

        suffix = move[6:]
        # Every capture suffix starts with "(x": (x), (x+), (x+*), (xQ), (x+Q)...
        is_capture = "(x" in suffix

        tokens = [
            move[0],  # Color (W/B)
            move[1],  # Piece (P/N/B/R/Q/K)
            move[2:4],  # From square (e.g. e2)
            self.CAPTURE_CROSS if is_capture else self.MOVE_ARROW,
            move[4:6],  # To square (e.g. e4)
        ]

        # Check / checkmate; also covers promotions with check such as (+Q).
        if "+*" in suffix:
            tokens.append("+*")
        elif "+" in suffix:
            tokens.append("+")

        promotion = self._PROMOTION_RE.search(suffix)
        if promotion:
            tokens.append(f"={promotion.group(1)}")

        return tokens

    def _tokenize(self, text: str) -> List[str]:
        """Tokenize a string of space-separated moves into component tokens.

        Args:
            text: String of space-separated chess moves.

        Returns:
            Concatenation of the component tokens of every move.
        """
        all_tokens = []
        for move in text.strip().split():
            all_tokens.extend(self._decompose_move(move))
        return all_tokens

    def _convert_token_to_id(self, token: str) -> int:
        """Convert a token to its id, falling back to the [UNK] id."""
        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 3))

    def _convert_id_to_token(self, index: int) -> str:
        """Convert an id to its token, falling back to [UNK] for unused ids."""
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Convert component tokens back to move strings.

        Regular moves are rebuilt as ``{color}{piece}{from}{to}`` followed by a
        single suffix such as ``(x)``, ``(+)``, ``(x+*)`` or ``(xQ)``, which is
        the form ``_decompose_move`` parses. Castling tokens are rebuilt as the
        king's move (e.g. ``WKe1g1``). Tokens that cannot start a valid move
        are skipped so decoding resynchronises on the next color token.

        Args:
            tokens: List of component tokens.

        Returns:
            Space-separated reconstructed chess moves.
        """
        special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
        tokens = [t for t in tokens if t not in special]
        n = len(tokens)

        moves: List[str] = []
        i = 0
        while i < n:
            color = tokens[i]
            if color not in self._COLORS:
                i += 1
                continue

            # Castling: [color, O-O | O-O-O]
            if i + 1 < n and tokens[i + 1] in self._CASTLING_TOKENS:
                squares = self._CASTLING_KING_SQUARES[(color, tokens[i + 1])]
                moves.append(f"{color}K{squares}")
                i += 2
                continue

            # Regular move needs 5 tokens with a valid action in 4th position.
            if i + 4 >= n or tokens[i + 3] not in (self.MOVE_ARROW, self.CAPTURE_CROSS):
                i += 1
                continue

            _, piece, from_sq, action, to_sq = tokens[i:i + 5]

            # Only consume genuine modifier tokens after the destination square.
            check = ""
            promotion = ""
            j = i + 5
            while j < n and tokens[j] in self._MODIFIER_TOKENS:
                mod = tokens[j]
                if mod.startswith("="):
                    promotion = mod[1]
                else:
                    check = mod
                j += 1

            suffix = ("x" if action == self.CAPTURE_CROSS else "") + check + promotion
            move = f"{color}{piece}{from_sq}{to_sq}"
            if suffix:
                move += f"({suffix})"
            moves.append(move)
            i = j

        return " ".join(moves)

    def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]:
        """Build model inputs wrapped in BOS/EOS.

        BOS/EOS are only added when not already present at the start/end.
        Pair inputs are concatenated and wrapped once.
        """
        bos_id = self._vocab[self.BOS_TOKEN]
        eos_id = self._vocab[self.EOS_TOKEN]

        def wrap(ids: List[int]) -> List[int]:
            out = list(ids)
            # Only add BOS if missing
            if not out or out[0] != bos_id:
                out = [bos_id] + out
            # Only add EOS if missing (out is non-empty here)
            if out[-1] != eos_id:
                out = out + [eos_id]
            return out

        if token_ids_1 is None:
            return wrap(token_ids_0)
        # For pair inputs, keep it simple/consistent: wrap concatenation
        return wrap(token_ids_0 + token_ids_1)

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False
    ) -> List[int]:
        """Return a mask with 1 for special tokens and 0 for move tokens.

        If ``already_has_special_tokens`` is True, every BOS/EOS/PAD/UNK id in
        the given ids is marked. Otherwise the mask matches, position by
        position, what ``build_inputs_with_special_tokens`` would return.
        """
        bos_id = self._vocab[self.BOS_TOKEN]
        eos_id = self._vocab[self.EOS_TOKEN]
        pad_id = self._vocab[self.PAD_TOKEN]
        unk_id = self._vocab[self.UNK_TOKEN]
        special_ids = {bos_id, eos_id, pad_id, unk_id}

        def mask_for(ids: List[int]) -> List[int]:
            return [1 if tid in special_ids else 0 for tid in ids]

        if already_has_special_tokens:
            if token_ids_1 is None:
                return mask_for(token_ids_0)
            return mask_for(token_ids_0 + token_ids_1)

        # build_inputs_with_special_tokens always yields [BOS, ..., EOS]
        # (length >= 2) but skips BOS/EOS already present, so derive the mask
        # from its output to keep the lengths equal.
        final_ids = self.build_inputs_with_special_tokens(token_ids_0, token_ids_1)
        return [1] + [0] * (len(final_ids) - 2) + [1]

    def create_token_type_ids_from_sequences(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Create token type IDs (all zeros; there is a single sequence type)."""
        # Build final ids consistently with build_inputs_with_special_tokens
        final_ids = self.build_inputs_with_special_tokens(token_ids_0, token_ids_1)
        return [0] * len(final_ids)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
        """Save the vocabulary as ``[prefix-]vocab.json`` in ``save_directory``.

        Returns:
            One-element tuple with the written file path.
        """
        os.makedirs(save_directory, exist_ok=True)
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
        )
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)
        return (vocab_file,)