import json
import os

from transformers import PreTrainedTokenizer


class ChessIntelligentTokenizer(PreTrainedTokenizer):
    """Tokenizer that splits whitespace-separated chess move strings.

    Each move is broken into an optional side token (``W``/``B``), an optional
    piece token (``P``, ``N``, ``R``, ``Q``, ``K``, ``B``), square tokens
    (``a1`` .. ``h8``), piece letters appearing later in the move, and the
    annotation characters ``( ) + * x o =``. Any other character is skipped.
    Consecutive moves are separated by a literal ``" "`` token so that
    :meth:`convert_tokens_to_string` can rebuild the original spacing.
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.json"}

    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"

    # Character classes shared by the default vocabulary and the move parser.
    _SIDES = "WB"
    _PIECES = "PNRQKB"
    _FILES = "abcdefgh"
    _RANKS = "12345678"
    _ANNOTATIONS = "()+*xo="

    def __init__(self, vocab_file=None, vocab=None, **kwargs):
        """Create the tokenizer from an explicit vocab, a vocab file, or defaults.

        Args:
            vocab_file: Path to a JSON ``{token: id}`` mapping. Used only when
                ``vocab`` is ``None``; if the path does not exist, the built-in
                vocabulary is used instead (no error is raised).
            vocab: In-memory ``{token: id}`` mapping; takes precedence.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``. Any
                ``pad_token``/``bos_token``/``eos_token``/``unk_token`` entries
                are discarded in favour of the class constants.
        """
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN

        # Discard caller-supplied special tokens so the fixed class constants
        # are always the ones handed to the base class below.
        for special in ("pad_token", "bos_token", "eos_token", "unk_token"):
            kwargs.pop(special, None)

        # The vocab is set up before super().__init__(), presumably because the
        # base initializer may call back into get_vocab() — confirm against the
        # installed transformers version before reordering.
        if vocab is not None:
            self._vocab = vocab
        elif vocab_file is not None and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            self._vocab = self._create_chess_vocab()
        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}

        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )

    def _create_chess_vocab(self):
        """Return the built-in ``{token: id}`` vocabulary.

        Ids follow list order: special tokens, side markers, piece letters,
        the 64 squares (file-major: a1..a8, b1..b8, ...), then annotation
        characters and the ``" "`` move separator.
        """
        tokens = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]
        tokens.extend(["W", "B"])
        tokens.extend(["P", "N", "R", "Q", "K", "B"])
        for f in self._FILES:
            for r in self._RANKS:
                tokens.append(f + r)
        tokens.extend(["(", ")", "x", "+", "*", "o", "=", " "])
        # NOTE(review): "B" is listed both as the black-side marker and as the
        # bishop piece. The comprehension keeps only the later index (11), so
        # both uses share one id and id 5 is never assigned. Left as-is so
        # existing ids do not shift; vocab_size accounts for the gap.
        return {token: idx for idx, token in enumerate(tokens)}

    @property
    def vocab_size(self):
        """Return the id space size: the highest token id plus one.

        The built-in vocabulary is not contiguous (see the duplicate ``"B"`` in
        ``_create_chess_vocab``): its highest id equals ``len(self._vocab)``.
        Returning ``len`` would let a model sized from ``vocab_size`` receive
        an out-of-range id, so the maximum id is used instead.
        """
        return max(self._vocab.values(), default=-1) + 1

    def get_vocab(self):
        """Return a shallow copy of the ``{token: id}`` mapping."""
        return dict(self._vocab)

    def _tokenize(self, text):
        """Split ``text`` into moves on whitespace and tokenize each move.

        A ``" "`` token is inserted between consecutive moves.
        """
        tokens = []
        for i, move in enumerate(text.split()):
            if i > 0:
                tokens.append(" ")
            tokens.extend(self._parse_move(move))
        return tokens

    def _parse_move(self, move):
        """Tokenize a single move string.

        A leading side letter and then a leading piece letter are taken as
        their own tokens. The remainder is scanned left to right:
        file+rank pairs become square tokens; annotation characters and piece
        letters (e.g. a promotion piece) become single-character tokens; any
        other character is skipped.
        """
        tokens = []
        idx = 0
        length = len(move)

        if idx < length and move[idx] in self._SIDES:
            tokens.append(move[idx])
            idx += 1
        if idx < length and move[idx] in self._PIECES:
            tokens.append(move[idx])
            idx += 1

        while idx < length:
            ch = move[idx]
            if idx + 1 < length and ch in self._FILES and move[idx + 1] in self._RANKS:
                tokens.append(move[idx:idx + 2])
                idx += 2
            elif ch in self._ANNOTATIONS or ch in self._PIECES:
                # Piece letters past the prefix (e.g. "=Q") were previously
                # dropped even though they have vocabulary entries.
                tokens.append(ch)
                idx += 1
            else:
                idx += 1
        return tokens

    def _convert_token_to_id(self, token):
        """Map a token to its id, falling back to the ``[UNK]`` id."""
        return self._vocab.get(token, self._vocab[self.UNK_TOKEN])

    def _convert_id_to_token(self, index):
        """Map an id to its token, falling back to ``[UNK]`` for unknown ids."""
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens):
        """Concatenate tokens into text, dropping the four special tokens.

        Because moves are separated by explicit ``" "`` tokens, plain
        concatenation restores the inter-move spacing.
        """
        special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
        return "".join(t for t in tokens if t not in special)

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """Write the vocabulary as JSON into ``save_directory``.

        Args:
            save_directory: Target directory; created if missing.
            filename_prefix: Optional prefix, joined to ``vocab.json`` with "-".

        Returns:
            A one-element tuple containing the written file path.
        """
        os.makedirs(save_directory, exist_ok=True)
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)
        return (vocab_file,)