"""
Custom Chess Tokenizer for the Chess Challenge.
This tokenizer treats each move as a single token using the extended UCI notation
from the Lichess dataset (e.g., WPe2e4, BNg8f6).
The dataset format uses:
- W/B prefix for White/Black
- Piece letter: P=Pawn, N=Knight, B=Bishop, R=Rook, Q=Queen, K=King
- Source and destination squares (e.g., e2e4)
- Special suffixes: (x)=capture, (+)=check, (+*)=checkmate, (o)/(O)=castling
"""
from __future__ import annotations
import json, os, re
from typing import Dict, List, Optional, Tuple
from transformers import PreTrainedTokenizer
_MOVE_RE = re.compile(r"^(?P<side>[WB])(?P<piece>[PNBRQK])(?P<src>[a-h][1-8])(?P<dst>[a-h][1-8])(?P<suffix>.*)$")
_PROMO_RE = re.compile(r"=([QRBNqrbn])")
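# Illustration of the two regexes above: "WNg8f6(x+)" splits into
# side="W", piece="N", src="g8", dst="f6", suffix="(x+)", and a suffix
# containing "=Q" yields the promotion letter "Q" via _PROMO_RE.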
def _parse_suffix(suffix: str) -> Tuple[bool, bool, bool, Optional[str], Optional[str]]:
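    """
    Decode the optional suffix of an extended-UCI move into its flags.
    A minimal doctest sketch of the checks below:
    >>> _parse_suffix("(x+)")
    (True, True, False, None, None)
    >>> _parse_suffix("=Q")
    (False, False, False, None, 'q')
    """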
s = (suffix or "").strip()
is_capture = "x" in s
is_check = "+" in s
is_mate = "*" in s
castle = "O-O-O" if "(O)" in s else ("O-O" if "(o)" in s else None)
promo = None
m = _PROMO_RE.search(s)
if m:
promo = m.group(1).lower()
return is_capture, is_check, is_mate, castle, promo
class ChessTokenizer(PreTrainedTokenizer):
"""
A custom tokenizer for chess moves using extended UCI notation.
This tokenizer maps each possible chess move to a unique token ID.
The vocabulary is built from the training dataset to ensure all moves
encountered during training have a corresponding token.
Example:
>>> tokenizer = ChessTokenizer()
>>> tokenizer.encode("WPe2e4 BPe7e5")
[1, 42, 87, 2] # [BOS, e2e4, e7e5, EOS]
"""
model_input_names = ["input_ids", "attention_mask"]
vocab_files_names = {"vocab_file": "vocab.json"}
PAD_TOKEN = "[PAD]"
BOS_TOKEN = "[BOS]"
EOS_TOKEN = "[EOS]"
UNK_TOKEN = "[UNK]"
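    # In the default vocabulary these special tokens take ids 0-3 in the
    # order listed above.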
def __init__(self, vocab_file: Optional[str] = None, vocab: Optional[Dict[str, int]] = None, **kwargs):
"""
Initialize the chess tokenizer.
Args:
vocab_file: Path to a JSON file containing the vocabulary mapping.
vocab: Dictionary mapping tokens to IDs (alternative to vocab_file).
**kwargs: Additional arguments passed to PreTrainedTokenizer.
"""
        # This tokenizer fixes its own special tokens, so drop any
        # caller-supplied overrides before delegating to the base class.
        kwargs.pop("pad_token", None)
        kwargs.pop("bos_token", None)
        kwargs.pop("eos_token", None)
        kwargs.pop("unk_token", None)
        # The vocabulary must exist before super().__init__(), which may
        # query it while registering the special tokens.
if vocab is not None:
self._vocab = vocab
elif vocab_file and os.path.exists(vocab_file):
with open(vocab_file, "r", encoding="utf-8") as f:
self._vocab = json.load(f)
else:
self._vocab = self._create_default_vocab()
self._ids_to_tokens = {v: k for k, v in self._vocab.items()}
        super().__init__(
            pad_token=self.PAD_TOKEN,
            bos_token=self.BOS_TOKEN,
            eos_token=self.EOS_TOKEN,
            unk_token=self.UNK_TOKEN,
            **kwargs,
        )
@property
def vocab_size(self) -> int:
return len(self._vocab)
def get_vocab(self) -> Dict[str, int]:
return dict(self._vocab)
def _convert_token_to_id(self, token: str) -> int:
return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 0))
def _convert_id_to_token(self, index: int) -> str:
return self._ids_to_tokens.get(index, self.UNK_TOKEN)
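    # Out-of-vocabulary tokens fall back to the [UNK] id (or 0 if [UNK] is
    # missing from the vocab), and unknown ids decode to the [UNK] string.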
def _create_default_vocab(self) -> Dict[str, int]:
"""
Create a minimal default vocabulary with just special tokens.
For the full vocabulary, use `build_vocab_from_dataset()`.
This minimal vocab is just a placeholder - you should build from data.
"""
tokens: List[str] = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]
tokens += [f"[W{p}]" for p in "PNBRQK"]
tokens += [f"[B{p}]" for p in "PNBRQK"]
tokens += [f"[{f}{r}]" for f in "abcdefgh" for r in "12345678"]
tokens += ["[x]", "[+]", "[#]", "[O-O]", "[O-O-O]"]
tokens += [f"[={p}]" for p in "qrbn"]
return {tok: i for i, tok in enumerate(tokens)}
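    # The default vocabulary holds 4 special + 12 side/piece + 64 square +
    # 5 suffix + 4 promotion tokens = 89 entries, covering every token
    # `_tokenize` below can emit.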
def _tokenize(self, text: str) -> List[str]:
out: List[str] = []
for move in (text or "").strip().split():
# Raw UCI like e2e4 / e7e8q (no side/piece available)
if re.fullmatch(r"[a-h][1-8][a-h][1-8][qrbn]?", move):
src, dst = move[:2], move[2:4]
out += [f"[{src}]", f"[{dst}]"]
if len(move) == 5:
out += [f"[={move[4]}]"]
continue
m = _MOVE_RE.match(move)
if not m:
out.append(self.UNK_TOKEN)
continue
side = m.group("side") # "W" or "B"
piece = m.group("piece") # P/N/B/R/Q/K
src = f"[{m.group('src')}]"
dst = f"[{m.group('dst')}]"
is_cap, is_chk, is_mate, castle, promo = _parse_suffix(m.group("suffix") or "")
out += [f"[{side}{piece}]", src, dst]
if castle:
out.append(f"[{castle}]")
if is_cap:
out.append("[x]")
if is_mate:
out.append("[#]")
elif is_chk:
out.append("[+]")
if promo:
out.append(f"[={promo}]")
return out
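    # For illustration: "WPe2e4" -> ['[WP]', '[e2]', '[e4]'],
    # "WRa1d1(x+)" -> ['[WR]', '[a1]', '[d1]', '[x]', '[+]'],
    # and raw UCI "e7e8q" -> ['[e7]', '[e8]', '[=q]'].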
def convert_tokens_to_string(self, tokens: List[str]) -> str:
special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
return " ".join(t for t in tokens if t not in special)
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        os.makedirs(save_directory, exist_ok=True)
vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.json")
with open(vocab_file, "w", encoding="utf-8") as f:
json.dump(self._vocab, f, ensure_ascii=False, indent=2)
return (vocab_file,)
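
# A minimal smoke test, runnable with `python tokenizer.py`; the sample
# moves are illustrative, not taken from the dataset.
if __name__ == "__main__":
    tok = ChessTokenizer()
    moves = "WPe2e4 BPe7e5 WNg1f3(+)"
    tokens = tok.tokenize(moves)
    ids = tok.convert_tokens_to_ids(tokens)
    print(tokens)  # component tokens for each move
    print(ids)     # their ids in the default vocabulary
    # Round-trip the ids back to the space-joined token string.
    print(tok.convert_tokens_to_string(tok.convert_ids_to_tokens(ids)))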