chess-v1 / src /tokenizer.py
MDaytek's picture
Submission by MDaytek
ba848b8 verified
from transformers import PreTrainedTokenizer
import json
import os
class ChessTokenizer(PreTrainedTokenizer):
    """Whitespace tokenizer that maps chess move strings to integer ids.

    The vocabulary is read from a JSON file containing two objects:
    ``token_to_id`` (move string -> id) and ``id_to_token`` (id -> move
    string).  Moves missing from the vocabulary are mapped to the id of
    ``"[UNK]"``, or to ``0`` when the vocabulary has no ``"[UNK]"`` entry.
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(self, vocab_file="vocab.json", **kwargs):
        """Load the vocabulary and register the special tokens.

        Args:
            vocab_file: Path to the JSON vocabulary file.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``.  Special
                tokens given here take precedence over the built-in defaults.

        Raises:
            ValueError: If ``vocab_file`` does not exist.
        """
        if not os.path.exists(vocab_file):
            raise ValueError(f"CRITIQUE: {vocab_file} introuvable.")

        with open(vocab_file, "r", encoding="utf-8") as f:
            data = json.load(f)

        # The lookup tables must exist before the base initializer runs,
        # because it may already call back into get_vocab() and
        # _convert_token_to_id().
        self.token_to_id = data["token_to_id"]
        # JSON object keys are always strings; restore the integer ids.
        self.id_to_token = {int(k): v for k, v in data["id_to_token"].items()}

        # setdefault (rather than explicit keywords next to **kwargs) avoids a
        # duplicate-keyword TypeError when the caller supplies a special token.
        kwargs.setdefault("pad_token", "[PAD]")
        kwargs.setdefault("bos_token", "[BOS]")
        kwargs.setdefault("eos_token", "[EOS]")
        kwargs.setdefault("unk_token", "[UNK]")
        super().__init__(**kwargs)

    @property
    def vocab_size(self):
        """Number of entries in the token -> id table."""
        return len(self.token_to_id)

    def get_vocab(self):
        """Return a copy of the token -> id mapping.

        A copy is returned so that callers cannot mutate the tokenizer's
        internal table by accident.
        """
        return dict(self.token_to_id)

    def _tokenize(self, text, **kwargs):
        """Split ``text`` into moves on whitespace."""
        return text.split()

    def _convert_token_to_id(self, token):
        """Return the id of ``token``, falling back to the unknown-token id."""
        return self.token_to_id.get(token, self.token_to_id.get("[UNK]", 0))

    def _convert_id_to_token(self, index):
        """Return the move string for ``index``, or ``"[UNK]"`` if unknown."""
        return self.id_to_token.get(index, "[UNK]")

    def _encode_moves(self, game, max_len):
        """Encode one game into a list of ids, truncated to ``max_len``.

        ``game`` is either a whitespace-separated string of moves or an
        already-split sequence of move strings.
        """
        moves = game.split() if isinstance(game, str) else list(game)
        unk_id = self.token_to_id.get("[UNK]", 0)
        return [self.token_to_id.get(move, unk_id) for move in moves[:max_len]]

    def __call__(self, text, **kwargs):
        """Encode a game, or a list of games, into move ids.

        Args:
            text: A whitespace-separated string of moves, or a list of games.
                Each element of the list is encoded independently.
            **kwargs: Only ``max_length`` is read (default 256).  Each encoded
                game is truncated to that many ids; other keys are ignored.

        Returns:
            ``{"input_ids": ids}`` where ``ids`` is a flat list of ints for a
            single string and a list of such lists for a list input.  No
            padding is applied and no attention mask is produced.
        """
        max_len = kwargs.get("max_length", 256)
        if isinstance(text, list):
            return {"input_ids": [self._encode_moves(game, max_len) for game in text]}
        return {"input_ids": self._encode_moves(text.split(), max_len)}

    def save_pretrained(self, save_directory, **kwargs):
        """Write ``vocab.json`` and ``tokenizer_config.json`` to a directory.

        The directory is created if it does not exist.  Integer keys of
        ``id_to_token`` are serialized as JSON strings; ``__init__`` converts
        them back to ints on load.

        Args:
            save_directory: Target directory.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            Tuple ``(vocab_path, config_path)`` of the files written.
        """
        os.makedirs(save_directory, exist_ok=True)
        vocab_path = os.path.join(save_directory, "vocab.json")
        config_path = os.path.join(save_directory, "tokenizer_config.json")

        with open(vocab_path, "w", encoding="utf-8") as f:
            json.dump(
                {"token_to_id": self.token_to_id, "id_to_token": self.id_to_token},
                f,
            )
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump({"model_type": "chess_transformer"}, f)

        return (vocab_path, config_path)

    @classmethod
    def from_pretrained(cls, path, **kwargs):
        """Build a tokenizer from the ``vocab.json`` found in ``path``.

        If ``path`` contains no ``vocab.json``, the constructor is called with
        ``kwargs`` only, so it falls back to its default ``vocab_file`` and
        raises ``ValueError`` when that file is missing as well.
        """
        vocab_path = os.path.join(path, "vocab.json")
        if os.path.exists(vocab_path):
            return cls(vocab_file=vocab_path, **kwargs)
        return cls(**kwargs)