| import json |
| import os |
| import re |
| from transformers import PreTrainedTokenizer |
|
|
class RetroGPTTokenizer(PreTrainedTokenizer):
    """Regex-based tokenizer backed by a JSON ``token -> id`` vocabulary.

    Input text is split with a single regular expression (this looks like the
    common SMILES tokenization pattern: bracketed groups, ``Br``/``Cl``,
    single-letter symbols, bond/branch punctuation, ``%NN`` and digits). Each
    piece is then looked up in the ``stoi`` mapping loaded from ``vocab.json``.
    Characters the pattern does not match are silently dropped by
    ``_tokenize`` (``re.findall`` only returns matches).
    """

    vocab_files_names = {"vocab_file": "vocab.json"}
    model_input_names = ["input_ids", "attention_mask"]

    # Compiled once for the class; every instance exposes it as ``self.pattern``.
    _SMILES_PATTERN = re.compile(
        r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
    )

    def __init__(
        self,
        vocab_file=None,
        bos_token="<s>",
        eos_token="</s>",
        sep_token="<sep>",
        pad_token="<pad>",
        **kwargs,
    ):
        """Load the vocabulary and initialise the transformers base class.

        Args:
            vocab_file: Path to a JSON object mapping token strings to ids.
                Defaults to ``vocab.json`` next to this module. If the file
                does not exist, a warning is emitted and the tokenizer starts
                with an empty vocabulary.
            bos_token, eos_token, sep_token, pad_token: Special tokens passed
                through to ``PreTrainedTokenizer``.
            **kwargs: Forwarded unchanged to ``PreTrainedTokenizer.__init__``.
        """
        if vocab_file is None:
            # Fall back to the vocabulary shipped alongside this module.
            vocab_file = os.path.join(os.path.dirname(__file__), "vocab.json")

        if os.path.exists(vocab_file):
            # JSON is UTF-8; don't depend on the platform's locale encoding.
            with open(vocab_file, "r", encoding="utf-8") as f:
                # Normalise ids to int once so stoi, itos and get_vocab agree.
                self.stoi = {token: int(idx) for token, idx in json.load(f).items()}
        else:
            import warnings

            # Previously silent: an empty vocab makes every regular token
            # unresolvable, so make the fallback visible.
            warnings.warn(
                f"Vocabulary file {vocab_file!r} not found; "
                "starting with an empty vocabulary.",
                stacklevel=2,
            )
            self.stoi = {}

        self.itos = {idx: token for token, idx in self.stoi.items()}
        self.pattern = self._SMILES_PATTERN

        # The vocab attributes above are set before the base __init__ because
        # (in recent transformers versions) it registers the special tokens
        # and consults get_vocab() while doing so.
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            sep_token=sep_token,
            pad_token=pad_token,
            **kwargs,
        )

    @property
    def vocab_size(self):
        """Number of entries in the base vocabulary (excludes added tokens)."""
        return len(self.stoi)

    def get_vocab(self):
        """Return the full ``token -> id`` mapping, including added tokens.

        Tokens registered via ``add_tokens`` or as special tokens but absent
        from ``vocab.json`` live in the base class's ``added_tokens_encoder``.
        They are merged in so ``get_vocab()`` agrees with
        ``convert_tokens_to_ids``. Recent transformers versions also derive
        ``len(tokenizer)`` and fresh added-token ids from this mapping.
        """
        vocab = {token: int(idx) for token, idx in self.stoi.items()}
        # getattr default: tolerate calls made before the base __init__ has
        # created its added-token tables.
        vocab.update(getattr(self, "added_tokens_encoder", {}))
        return vocab

    def _tokenize(self, text):
        """Split ``text`` into regex-matched pieces; unmatched characters are dropped."""
        return self.pattern.findall(text)

    def _fallback_id(self):
        """Id used for tokens missing from ``stoi``: the pad token's id.

        The pad token may be absent from ``vocab.json`` and only registered as
        an added special token, so the added-token table is checked as well.
        Returns ``None`` when the pad token is unset or unknown to both.
        """
        pad = self.pad_token
        if pad is None:
            return None
        pad = str(pad)
        if pad in self.stoi:
            return self.stoi[pad]
        return getattr(self, "added_tokens_encoder", {}).get(pad)

    def _convert_token_to_id(self, token):
        """Map ``token`` to its id; tokens not in the vocabulary get the pad id."""
        token_id = self.stoi.get(token)
        if token_id is None:
            token_id = self._fallback_id()
        return token_id

    def _convert_id_to_token(self, index):
        """Map ``index`` back to its token; unknown ids decode to the pad token."""
        return self.itos.get(index, self.pad_token)

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """Write ``stoi`` as JSON into ``save_directory``.

        The file is named ``vocab.json``, or ``<filename_prefix>-vocab.json``
        when a prefix is given. Only the base ``stoi`` mapping is written;
        added tokens are not included here.

        Returns:
            A 1-tuple containing the path of the written file.
        """
        prefix = f"{filename_prefix}-" if filename_prefix else ""
        vocab_file = os.path.join(save_directory, prefix + "vocab.json")
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self.stoi, f, indent=2)
        return (vocab_file,)
|
|