| import pickle | |
| import re | |
| from typing import List | |
class TAIMGANTokenizer:
    """Word-level tokenizer backed by a pickled captions vocabulary.

    The pickle at ``captions_path`` is indexed like a sequence: element 2
    is stored as ``ix_to_word`` (token id -> word) and element 3 as
    ``word_to_ix`` (word -> token id).

    The id of the padding token doubles as the id for out-of-vocabulary
    words: ``encode`` maps every unknown word to ``pad_token_id``, and
    ``decode`` renders that id as ``pad_repr``.
    """

    def __init__(
        self,
        captions_path: str,
        *,
        pad_token: str = "<end>",
        pad_repr: str = "[PAD]",
        token_regex: str = r'\w+',
    ):
        """Load the vocabulary and configure tokenization.

        Args:
            captions_path: Path to the pickled captions file.
            pad_token: Vocabulary word whose id is used for padding and
                for unknown words.
            pad_repr: Text emitted by ``decode`` for the padding id.
            token_regex: Regular expression whose matches are the tokens
                of a text.

        Raises:
            KeyError: If ``pad_token`` is missing from the vocabulary.
        """
        # NOTE(security): unpickling can execute arbitrary code embedded in
        # the file. Only pass captions files that come from a trusted source.
        with open(captions_path, "rb") as ckpt_file:
            captions = pickle.load(ckpt_file)

        self.ix_to_word = captions[2]
        self.word_to_ix = captions[3]
        self.token_regex = token_regex
        self.pad_token_id = self.word_to_ix[pad_token]
        self.pad_repr = pad_repr

    def encode(self, text: str) -> List[int]:
        """Convert ``text`` into a list of token ids.

        The text is lower-cased and split into the matches of
        ``token_regex``; words absent from the vocabulary are encoded as
        ``pad_token_id``.
        """
        words = re.findall(self.token_regex, text.lower())
        return [self.word_to_ix.get(word, self.pad_token_id) for word in words]

    def decode(self, tokens: List[int]) -> str:
        """Convert token ids back into a space-separated string.

        Ids equal to ``pad_token_id`` are rendered as ``pad_repr``; every
        other id is looked up in ``ix_to_word``, so an id missing from it
        raises the lookup error of that container.
        """
        return ' '.join([
            self.pad_repr if token == self.pad_token_id
            else self.ix_to_word[token]
            for token in tokens
        ])