import json

import torch
|
|
class CharTokenizer:
    def __init__(self, corpus=None, vocab=None):
        if vocab is not None:
            self.vocab = vocab
        elif corpus is not None:
            self.vocab = self._build_vocab(corpus)
        else:
            raise ValueError("Either a corpus or a vocab must be supplied")
        # Inverse mapping: the position in this list is the token id.
        self.id2vocab = [char for char, index in sorted(self.vocab.items(), key=lambda item: item[1])]

    def _tokenize(self, text):
        # Character-level tokenization: each character is its own token.
        return list(text)

    def __call__(self, prompt, text=None, add_eos_token=False):
        # Unknown characters fall back to id 0, i.e. the <pad> token.
        token_ids = [self.vocab.get(token, 0) for token in self._tokenize(prompt)]
        if text is not None:
            text_token_ids = [self.vocab.get(token, 0) for token in self._tokenize(text)]
            token_ids = token_ids + [self.vocab["<bos>"]] + text_token_ids
        if add_eos_token:
            token_ids = token_ids + [self.vocab["<eos>"]]
        # unsqueeze(0) adds a batch dimension; the returned dict mirrors the
        # {"input_ids", "attention_mask"} convention of Hugging Face tokenizers.
        input_ids_tensor = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)
        attention_masks = torch.ones_like(input_ids_tensor)
        return {"input_ids": input_ids_tensor, "attention_mask": attention_masks}

    def _build_vocab(self, corpus):
        vocab = {"<pad>": 0}
        # Reserve ids 1-7 for the digit tokens "3" through "9", presumably
        # used as verse-length control codes in prompts.
        for verse_length in range(3, 10):
            vocab[str(verse_length)] = len(vocab)
        for doc in corpus:
            for char in self._tokenize(doc):
                if char not in vocab:
                    vocab[char] = len(vocab)
        vocab["<bos>"] = len(vocab)
        vocab["<eos>"] = len(vocab)
        return vocab

    def decode(self, token_ids):
        # Accepts a tensor of ids in any shape; flattens it before lookup.
        chars = [self.id2vocab[token_id] for token_id in token_ids.flatten().tolist()]
        # Strip the special tokens before reassembling the string.
        filtered_chars = [char for char in chars if char not in ["<eos>", "<bos>", "<pad>"]]
        return "".join(filtered_chars)

    def save(self, filepath):
        # Persist only the vocab; id2vocab is rebuilt by __init__ on load.
        with open(filepath, "w") as f:
            json.dump(self.vocab, f)

    @classmethod
    def load(cls, filepath):
        with open(filepath) as f:
            vocab = json.load(f)
        return cls(vocab=vocab)
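

# Usage sketch (not part of the class above; the two-line corpus and the
# file name are hypothetical, chosen only to show the full round trip).
if __name__ == "__main__":
    corpus = ["hello world", "goodbye world"]
    tokenizer = CharTokenizer(corpus=corpus)

    # Encode a prompt/text pair; <bos> is inserted between prompt and text.
    encoded = tokenizer("hello", text="world", add_eos_token=True)
    print(encoded["input_ids"].shape)  # torch.Size([1, 12]): 5 + <bos> + 5 + <eos>

    # decode() drops the special tokens again.
    print(tokenizer.decode(encoded["input_ids"]))  # helloworld

    # Save and reload the vocabulary.
    tokenizer.save("char_vocab.json")
    restored = CharTokenizer.load("char_vocab.json")
    assert restored.vocab == tokenizer.vocab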
|
|