|
|
""" |
|
|
Custom Chess Tokenizer for the Chess Challenge.

Each move is split into several tokens, so the vocabulary contains:

- W/B prefix for White/Black
- Piece letter: P=Pawn, N=Knight, B=Bishop, R=Rook, Q=Queen, K=King
- Source square (file and rank), e.g. e2
- Destination square (file and rank), e.g. e4
- Special suffixes: (x)=capture, (+)=check, (+*)=checkmate, (o)/(O)=castling
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import json |
|
|
import os |
|
|
from pathlib import Path |
|
|
import shutil |
|
|
import inspect |
|
|
from typing import Dict, List, Optional |
|
|
|
|
|
from transformers import PreTrainedTokenizer |
|
|
from datasets import load_dataset |
|
|
|
|
|
|
|
|
class ChessTokenizer(PreTrainedTokenizer):
    """Tokenizer that splits every chess move into its components.

    A move string such as ``WPe2e4(x)`` is read positionally: character 0 is
    the color, character 1 the piece, characters 2-3 the source square,
    characters 4-5 the destination square and anything after that the
    suffix. Each non-empty component becomes one token, and consecutive
    moves are separated by ``[SEP]``.
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.json"}

    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"
    SEP_TOKEN = "[SEP]"

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        vocab: Optional[Dict[str, int]] = None,
        **kwargs,
    ):
        """Create the tokenizer.

        The vocabulary is taken from the first available source:

        1. ``vocab`` if it is given,
        2. the JSON file at ``vocab_file`` if that path exists,
        3. the built-in table returned by ``_create_default_vocab``.

        Args:
            vocab_file: Path to a JSON file mapping token -> id.
            vocab: Token -> id mapping; copied so later changes by the
                caller do not affect the tokenizer.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__`` after the
                five special-token keys have been removed.
        """
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN
        self._sep_token = self.SEP_TOKEN

        # The special tokens are passed explicitly to super().__init__ below;
        # drop any copies in kwargs so they are not supplied twice.
        for key in ("pad_token", "bos_token", "eos_token", "unk_token", "sep_token"):
            kwargs.pop(key, None)

        if vocab is not None:
            self._vocab = dict(vocab)
        elif vocab_file is not None and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            self._vocab = self._create_default_vocab()

        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}

        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            sep_token=self._sep_token,
            **kwargs,
        )

    def _create_default_vocab(self) -> Dict[str, int]:
        """Return the built-in vocabulary used when none is supplied.

        Ids are assigned by position in the concatenated list below, so the
        order of every list is significant: special tokens get 0-4, suffixes
        5-41, color/piece letters 42-48 and the squares a1..h8 get 49-112.
        """
        special_tokens = [
            self.PAD_TOKEN,
            self.BOS_TOKEN,
            self.EOS_TOKEN,
            self.UNK_TOKEN,
            self.SEP_TOKEN,
        ]
        suffixes = [
            "(+)", "(+*)", "(+*B)", "(+*N)", "(+*Q)", "(+*R)",
            "(+B)", "(+N)", "(+Q)", "(+R)",
            "(B)", "(N)", "(O)", "(O+)", "(O+*)", "(Q)", "(R)",
            "(o)", "(o+)", "(o+*)",
            "(x)", "(x+)", "(x+*)", "(x+*B)", "(x+*Q)", "(x+*R)",
            "(x+B)", "(x+N)", "(x+Q)", "(x+R)",
            "(xB)", "(xE)", "(xE+)", "(xE+*)", "(xN)", "(xQ)", "(xR)",
        ]
        # "B" appears once and is shared by the color and piece positions.
        letters = ["B", "K", "N", "P", "Q", "R", "W"]
        # File-major order: a1, a2, ..., a8, b1, ..., h8.
        squares = [
            file_letter + rank_digit
            for file_letter in "abcdefgh"
            for rank_digit in "12345678"
        ]
        ordered = special_tokens + suffixes + letters + squares
        return {token: idx for idx, token in enumerate(ordered)}

    @classmethod
    def build_vocab_from_dataset(
        cls,
        dataset_name: str = "dlouapre/lichess_2025-01_1M",
        split: str = "train",
        column: str = "text",
        min_frequency: Optional[int] = 1,
        max_samples: Optional[int] = None,
        save_path: Optional[str] = None,
    ) -> "ChessTokenizer":
        """Build a tokenizer whose vocabulary is collected from a dataset.

        If ``save_path`` already holds a JSON vocabulary, that file is loaded
        and the dataset is not read at all. Otherwise every move of every
        game is decomposed, the components are counted, and the resulting
        vocabulary is written to ``save_path`` (best effort).

        Args:
            dataset_name: Dataset identifier passed to ``load_dataset``.
            split: Dataset split to read.
            column: Column holding the space-separated game strings.
            min_frequency: Components seen fewer times than this are left
                out; ``None`` is treated as 1 (keep everything).
            max_samples: Read at most this many rows; ``None`` reads all.
            save_path: Cache file for the vocabulary. Defaults to
                ``chess_tokenizer_vocab.json`` in the current directory.

        Returns:
            A ``ChessTokenizer`` using the cached or freshly built vocabulary.
        """
        from collections import Counter

        if save_path is None:
            save_path = os.path.join(os.getcwd(), "chess_tokenizer_vocab.json")

        if os.path.exists(save_path):
            cached = None
            try:
                with open(save_path, "r", encoding="utf-8") as f:
                    cached = json.load(f)
            except (OSError, ValueError) as err:
                # Unreadable or corrupt cache: fall through and rebuild it.
                print(f"Ignoring unreadable vocab cache {save_path}: {err}")
            if isinstance(cached, dict):
                print("Loading existing tokenizer vocab from", save_path)
                return cls(vocab=cached)

        dataset = load_dataset(dataset_name, split=split)
        if max_samples is not None:
            dataset = dataset.select(range(min(max_samples, len(dataset))))

        counts = Counter()
        for game in dataset[column]:
            if not isinstance(game, str):
                continue
            for move in game.split():
                if len(move) < 2:
                    continue
                # Skip empty components so "" never becomes a vocab entry.
                counts.update(part for part in cls._decompose_move(move) if part)

        threshold = 1 if min_frequency is None else min_frequency
        tokens = sorted(tok for tok, seen in counts.items() if seen >= threshold)

        special_tokens = [
            cls.PAD_TOKEN,
            cls.BOS_TOKEN,
            cls.EOS_TOKEN,
            cls.UNK_TOKEN,
            cls.SEP_TOKEN,
        ]
        vocab: Dict[str, int] = {tok: idx for idx, tok in enumerate(special_tokens)}
        for tok in tokens:
            if tok not in vocab:
                vocab[tok] = len(vocab)

        tokenizer = cls(vocab=vocab)

        # Write to a temporary file first and rename it, so an interrupted
        # write never leaves a truncated cache file behind.
        tmp_path = save_path + ".tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(vocab, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, save_path)
        except OSError as err:
            # Caching is optional; report the problem and keep the tokenizer.
            print(f"Could not cache vocabulary at {save_path}: {err}")
            if os.path.exists(tmp_path):
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass  # leftover temp file is harmless

        return tokenizer

    @property
    def vocab_size(self) -> int:
        """Return the size of the vocabulary."""
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the token -> id mapping."""
        return dict(self._vocab)

    def _tokenize(self, text: str) -> List[str]:
        """Tokenize a string of moves into a list of tokens.

        Args:
            text: A string of whitespace-separated moves.

        Returns:
            The non-empty components of each move, with ``[SEP]`` between
            consecutive moves (none after the last). Moves shorter than two
            characters are skipped.
        """
        tokens: List[str] = []
        for move in text.strip().split():
            if len(move) < 2:
                continue
            if tokens:
                tokens.append(self.SEP_TOKEN)
            tokens.extend(part for part in self._decompose_move(move) if part)
        return tokens

    @staticmethod
    def _decompose_move(move: str):
        """Split a move string into (color, piece, from_square, to_square, suffix).

        The split is purely positional. A component is the empty string when
        the move is too short to contain it; an empty ``move`` yields five
        empty strings.
        """
        color = move[:1]
        piece = move[1:2]
        from_square = move[2:4] if len(move) >= 4 else ''
        to_square = move[4:6] if len(move) >= 6 else ''
        suffix = move[6:]
        return color, piece, from_square, to_square, suffix

    def _convert_token_to_id(self, token: str) -> int:
        """Convert a token to its id, falling back to the ``[UNK]`` id (or 0)."""
        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 0))

    def _convert_id_to_token(self, index: int) -> str:
        """Convert an id to its token, falling back to ``[UNK]``."""
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Join tokens with single spaces, dropping PAD/BOS/EOS/UNK.

        ``[SEP]`` is kept and move components are not glued back together;
        use ``decode`` to rebuild the original move strings.
        """
        special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
        return " ".join(t for t in tokens if t not in special)

    def decode(self, token_ids: List[int], skip_special_tokens: bool = True, **kwargs) -> str:
        """Decode token ids back to a string of moves.

        Args:
            token_ids: A sequence of ids, a single int, or any object with a
                ``tolist()`` method.
            skip_special_tokens: When true, PAD/BOS/EOS/UNK are dropped and
                each ``[SEP]`` becomes a single space, so components of one
                move are concatenated and moves are space-separated. When
                false, all tokens are concatenated unchanged.
            **kwargs: Accepted and ignored, so that callers passing extra
                keyword arguments do not fail.

        Returns:
            The decoded string.
        """
        if hasattr(token_ids, "tolist"):
            token_ids = token_ids.tolist()
        if isinstance(token_ids, int):
            token_ids = [token_ids]

        tokens = [self._convert_id_to_token(int(tid)) for tid in token_ids]
        if skip_special_tokens:
            special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
            tokens = [t if t != self.SEP_TOKEN else " " for t in tokens if t not in special]
        return "".join(tokens)

    def save_vocabulary(
        self,
        save_directory: str,
        filename_prefix: Optional[str] = None,
    ) -> tuple:
        """Save the vocabulary to a JSON file.

        Args:
            save_directory: Directory to save the vocabulary; created if it
                does not exist.
            filename_prefix: Optional prefix, joined to ``vocab.json`` with
                a hyphen.

        Returns:
            One-element tuple containing the path of the written file.
        """
        os.makedirs(save_directory, exist_ok=True)

        prefix = filename_prefix + "-" if filename_prefix else ""
        vocab_file = os.path.join(save_directory, prefix + "vocab.json")

        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)

        return (vocab_file,)

    def save_pretrained(
        self,
        save_directory: str,
        filename_prefix: Optional[str] = None,
        save_tokenizer_code: bool = True,
        **kwargs,
    ) -> None:
        """Save tokenizer files to a directory.

        Writes the vocab JSON (via ``save_vocabulary``), a small
        ``tokenizer_config.json`` naming the tokenizer class, the vocab file
        and the special tokens, and optionally copies the source file that
        defines this class into the directory.

        Args:
            save_directory: Target directory; created if it does not exist.
            filename_prefix: Optional prefix for the vocab filename.
            save_tokenizer_code: Copy the defining source file when true.
                The copy is best effort and a failure is only reported.
            **kwargs: Accepted and ignored, so that callers passing extra
                keyword arguments do not fail.
        """
        os.makedirs(save_directory, exist_ok=True)

        vocab_file = self.save_vocabulary(save_directory, filename_prefix)[0]

        config = {
            "tokenizer_class": self.__class__.__name__,
            "vocab_file": os.path.basename(vocab_file),
            "pad_token": self.PAD_TOKEN,
            "bos_token": self.BOS_TOKEN,
            "eos_token": self.EOS_TOKEN,
            "unk_token": self.UNK_TOKEN,
            "sep_token": self.SEP_TOKEN,
        }
        config_path = os.path.join(save_directory, "tokenizer_config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(config, f, ensure_ascii=False, indent=2)

        if save_tokenizer_code:
            try:
                # getsourcefile may return None, which makes Path() raise
                # TypeError; copy2 raises OSError on filesystem problems.
                src_file = Path(inspect.getsourcefile(self.__class__))
                dst_file = Path(save_directory) / src_file.name
                shutil.copy2(src_file, dst_file)
            except (OSError, TypeError) as err:
                print(f"Could not copy tokenizer source into {save_directory}: {err}")
|
|
|
|
|
|
|
|
def count_vocab_from_dataset(
    dataset_name: str = "dlouapre/lichess_2025-01_1M",
    split: str = "train",
    column: str = "text",
    max_samples: Optional[int] = 10000,
) -> Dict[str, int]:
    """
    Tally how often each token occurs in a dataset of games.

    Every row's ``column`` value is run through ``ChessTokenizer._tokenize``
    and the resulting tokens are counted.

    Args:
        dataset_name: Name of the dataset passed to ``load_dataset``.
        split: Dataset split to use.
        column: Column containing the game strings.
        max_samples: Upper bound on the number of rows read; ``None``
            reads the whole split.

    Returns:
        Dictionary mapping tokens to their frequencies.
    """
    from collections import Counter
    from datasets import load_dataset

    games = load_dataset(dataset_name, split=split)

    if max_samples is not None:
        # Never ask for more rows than the split actually has.
        limit = min(max_samples, len(games))
        games = games.select(range(limit))

    splitter = ChessTokenizer()
    frequencies = Counter()

    for row in games:
        for token in splitter._tokenize(row[column]):
            frequencies[token] += 1

    return dict(frequencies)