import json
import re
from pathlib import Path
from typing import Dict, List, Optional

from transformers import PreTrainedTokenizer


class ChessTokenizer(PreTrainedTokenizer):
    """
    Chess move tokenizer compatible with HuggingFace transformers.

    Can be loaded with:
        AutoTokenizer.from_pretrained("ankanmbz/chess-tok", trust_remote_code=True)
    """

    vocab_files_names = {
        "vocab_file": "vocab.json",
    }

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        unk_token="<unk>",
        pad_token="<pad>",
        bos_token="<sos>",
        eos_token="<eos>",
        **kwargs,
    ):
        # Load the vocab before calling super().__init__(), which may query
        # get_vocab() while registering the special tokens.
        with open(vocab_file, "r", encoding="utf-8") as f:
            self.encoder = json.load(f)
        self.decoder = {v: k for k, v in self.encoder.items()}

        # Move grammar: side-to-move markers ("w.", "b."), a piece glyph plus
        # destination square, bare squares, and capture/check/mate/separator
        # marks.
        self.pattern = r"w\.|b\.|[♔♕♖♗♘♙♚♛♜♝♞♟][a-h][1-8]|[a-h][1-8]|\.\.[+#x\.]*|\.x\.[+#]*|\.\.x\.[+#]*|\+#|\+|x|\."

        super().__init__(
            unk_token=unk_token,
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        return len(self.encoder)

    def get_vocab(self) -> Dict[str, int]:
        return dict(self.encoder)

    def _tokenize(self, text: str) -> List[str]:
        """Tokenize a chess move string into tokens."""
        return re.findall(self.pattern, text)

    def _convert_token_to_id(self, token: str) -> int:
        """Convert a token to an id using the vocab."""
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index: int) -> str:
        """Convert an id to a token using the vocab."""
        return self.decoder.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Convert a list of tokens back into a move string."""
        return "".join(tokens)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs by adding special tokens.

        Single sequence: ``<sos> X <eos>``
        Pair of sequences: ``<sos> A <eos> B <eos>``
        """
        bos = [self.bos_token_id] if self.bos_token_id is not None else []
        eos = [self.eos_token_id] if self.eos_token_id is not None else []

        if token_ids_1 is None:
            return bos + token_ids_0 + eos
        return bos + token_ids_0 + eos + token_ids_1 + eos

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """Return a mask with 1 for special tokens and 0 for sequence tokens."""
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )

        bos = [1] if self.bos_token_id is not None else []
        eos = [1] if self.eos_token_id is not None else []

        if token_ids_1 is None:
            return bos + ([0] * len(token_ids_0)) + eos
        return bos + ([0] * len(token_ids_0)) + eos + ([0] * len(token_ids_1)) + eos

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Create token type IDs (not used for chess, but required by the interface)."""
        bos = [0] if self.bos_token_id is not None else []
        eos = [0] if self.eos_token_id is not None else []

        if token_ids_1 is None:
            return bos + ([0] * len(token_ids_0)) + eos
        return bos + ([0] * len(token_ids_0)) + eos + ([1] * len(token_ids_1)) + eos

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
        """Save the vocabulary to a directory."""
        if not Path(save_directory).is_dir():
            print(f"Vocabulary path {save_directory} should be a directory")
            return ()

        vocab_file = Path(save_directory) / (
            (filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["vocab_file"]
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self.encoder, f, ensure_ascii=False, indent=2)

        return (str(vocab_file),)

    def prepare_for_tokenization(self, text: str, **kwargs):
        """Prepare text before tokenization; must return a (text, kwargs) tuple."""
        return (text, kwargs)

    def _decode(
        self,
        token_ids: List[int],
        skip_special_tokens: bool = False,
        **kwargs,
    ) -> str:
        """Decode token ids to a string."""
        tokens = [self._convert_id_to_token(token_id) for token_id in token_ids]

        if skip_special_tokens:
            special = {self.pad_token, self.bos_token, self.eos_token, self.unk_token}
            tokens = [token for token in tokens if token not in special]

        return self.convert_tokens_to_string(tokens)
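

# Minimal usage sketch (assumes a vocab.json produced by build_hf_tokenizer
# below; "w.e2e4" is a hypothetical move string, adjust to your encoding):
#
#   tok = ChessTokenizer("chess-tok-hf/vocab.json")
#   ids = tok("w.e2e4")["input_ids"]           # <sos> w. e2 e4 <eos>
#   tok.decode(ids, skip_special_tokens=True)  # -> "w.e2e4"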


def build_hf_tokenizer(dataset_path, output_dir="chess-tok-hf"):
    """Build a HuggingFace-compatible tokenizer from a parquet dataset."""
    from collections import Counter

    import pandas as pd

    print("Building HuggingFace-compatible tokenizer...")
    print(f"Dataset: {dataset_path}")
    print(f"Output: {output_dir}")

    # Cap the scan at 1M rows to keep vocabulary building fast.
    df = pd.read_parquet(dataset_path)
    df = df.head(1_000_000)
    print(f"✓ Loaded {len(df):,} rows")

    # Count token frequencies with the same move grammar ChessTokenizer uses.
    token_freq = Counter()
    pattern = r"w\.|b\.|[♔♕♖♗♘♙♚♛♜♝♞♟][a-h][1-8]|[a-h][1-8]|\.\.[+#x\.]*|\.x\.[+#]*|\.\.x\.[+#]*|\+#|\+|x|\."

    for moves_list in df['moves_custom']:
        for move in moves_list:
            for token in re.findall(pattern, move):
                token_freq[token] += 1

    print(f"✓ Found {len(token_freq)} unique tokens")

    # Reserve fixed ids for the special tokens.
    special_tokens = {
        "<pad>": 0,
        "<sos>": 1,
        "<eos>": 2,
        "<unk>": 3,
    }

    vocab = special_tokens.copy()
    current_id = len(vocab)

    # Add printable ASCII characters as single-character fallback tokens.
    for i in range(32, 127):
        char = chr(i)
        if char not in vocab:
            vocab[char] = current_id
            current_id += 1

    # Add corpus tokens in descending frequency order.
    for token, _ in sorted(token_freq.items(), key=lambda x: x[1], reverse=True):
        if token not in vocab:
            vocab[token] = current_id
            current_id += 1

    print(f"✓ Vocabulary size: {len(vocab)}")

    output_path = Path(output_dir)
    output_path.mkdir(exist_ok=True, parents=True)

    with open(output_path / "vocab.json", "w", encoding="utf-8") as f:
        json.dump(vocab, f, ensure_ascii=False, indent=2)
    print("✓ Saved: vocab.json")

    # auto_map lets AutoTokenizer resolve the custom class from tokenizer.py
    # when the tokenizer is loaded with trust_remote_code=True.
    tokenizer_config = {
        "tokenizer_class": "ChessTokenizer",
        "auto_map": {
            "AutoTokenizer": ["tokenizer.ChessTokenizer", None]
        },
        "model_max_length": 512,
        "pad_token": "<pad>",
        "bos_token": "<sos>",
        "eos_token": "<eos>",
        "unk_token": "<unk>",
        "clean_up_tokenization_spaces": True
    }

    with open(output_path / "tokenizer_config.json", "w") as f:
        json.dump(tokenizer_config, f, indent=2)
    print("✓ Saved: tokenizer_config.json")

    special_tokens_map = {
        "pad_token": "<pad>",
        "bos_token": "<sos>",
        "eos_token": "<eos>",
        "unk_token": "<unk>"
    }

    with open(output_path / "special_tokens_map.json", "w") as f:
        json.dump(special_tokens_map, f, indent=2)
    print("✓ Saved: special_tokens_map.json")

    # Ship this module alongside the vocab so AutoTokenizer can import the
    # custom class.
    import shutil
    shutil.copy(__file__, output_path / "tokenizer.py")
    print("✓ Saved: tokenizer.py")


if __name__ == "__main__":
    dataset_path = "/vast/users/ankan.deria/Document/TinyRecursiveModels/data/chees_data/dataset.parquet"
    build_hf_tokenizer(dataset_path, output_dir="chess-tok-hf")
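
    # Quick smoke test of the exported tokenizer (a sketch: assumes the build
    # above succeeded; trust_remote_code=True is needed because the class is
    # shipped as custom code in tokenizer.py, and "w.e2e4" is a hypothetical
    # move string):
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("chess-tok-hf", trust_remote_code=True)
    ids = tok("w.e2e4")["input_ids"]
    print(tok.decode(ids, skip_special_tokens=True))  # expect "w.e2e4"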