# RetroGPT-lite / tokenization_retrogpt.py
# kssrikar4's picture
# Upload 7 files
# d3d71fe verified
import json
import os
import re
from transformers import PreTrainedTokenizer
class RetroGPTTokenizer(PreTrainedTokenizer):
    """Regex-based tokenizer backed by a JSON token-to-id vocabulary.

    Text is split with a single compiled regular expression and every
    resulting token is looked up in ``stoi`` (token -> id), which is loaded
    from ``vocab.json``. ``itos`` is the inverse mapping (id -> token).

    NOTE(review): the pattern (bracketed groups, ``Br``/``Cl``, ``%NN``
    two-digit references, bond symbols, ``>``) looks like SMILES / reaction
    SMILES notation -- confirm against the training pipeline.
    """

    vocab_files_names = {"vocab_file": "vocab.json"}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(self, vocab_file=None, bos_token="<s>", eos_token="</s>", sep_token="<sep>", pad_token="<pad>", **kwargs):
        """Load the vocabulary and compile the tokenization pattern.

        Args:
            vocab_file: Path to a JSON file mapping token strings to ids.
                When ``None``, ``vocab.json`` next to this module is used.
                If the file does not exist the vocabulary is left empty.
            bos_token: Beginning-of-sequence token, forwarded to the base class.
            eos_token: End-of-sequence token, forwarded to the base class.
            sep_token: Separator token, forwarded to the base class.
            pad_token: Padding token, forwarded to the base class. Its id is
                also used as the fallback for unknown tokens.
            **kwargs: Passed through unchanged to ``PreTrainedTokenizer``.
        """
        if vocab_file is None:
            # Look for vocab.json in the same directory as the script
            vocab_file = os.path.join(os.path.dirname(__file__), "vocab.json")

        if os.path.exists(vocab_file):
            # JSON is UTF-8 by specification; do not depend on the locale's
            # default encoding when reading it.
            with open(vocab_file, "r", encoding="utf-8") as f:
                self.stoi = json.load(f)
        else:
            self.stoi = {}
        self.itos = {int(v): k for k, v in self.stoi.items()}

        # Exact pattern from app.py
        self.pattern = re.compile(r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])")

        # The vocabulary and pattern are set up before the base constructor
        # runs. NOTE(review): presumably because the base class queries the
        # vocab while registering special tokens -- confirm for the installed
        # transformers version before reordering.
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            sep_token=sep_token,
            pad_token=pad_token,
            **kwargs
        )

    @property
    def vocab_size(self):
        """Return the number of entries in the loaded vocabulary."""
        return len(self.stoi)

    def get_vocab(self):
        """Return a new dict mapping each vocabulary token to its integer id."""
        return {k: int(v) for k, v in self.stoi.items()}

    def _tokenize(self, text):
        """Split ``text`` into the list of substrings matched by the pattern.

        Characters that no alternative of the pattern matches (for example
        whitespace) are skipped silently; they do not appear in the result.
        """
        return self.pattern.findall(text)

    def _convert_token_to_id(self, token):
        """Return the id of ``token``, falling back to the pad token's id.

        Returns ``None`` when neither ``token`` nor the pad token is present
        in the vocabulary.
        """
        return self.stoi.get(token, self.stoi.get(self.pad_token))

    def _convert_id_to_token(self, index):
        """Return the token for ``index``, or the pad token if it is unknown."""
        return self.itos.get(index, self.pad_token)

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """Write the vocabulary as JSON into ``save_directory``.

        Args:
            save_directory: Target directory; created if it does not exist.
            filename_prefix: Optional prefix; when given the file is named
                ``"<prefix>-vocab.json"`` instead of ``"vocab.json"``.

        Returns:
            A one-element tuple containing the path of the written file.
        """
        # Allow direct calls on a directory that has not been created yet.
        os.makedirs(save_directory, exist_ok=True)
        vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.json")
        # Write with the same explicit encoding used when loading.
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self.stoi, f, indent=2)
        return (vocab_file,)