| import torch | |
| from transformers import AutoTokenizer | |
class DaedalusTokenizer:
    """Thin wrapper around a Hugging Face tokenizer with project defaults.

    ``AutoTokenizer`` is a factory class: it cannot be subclassed and
    instantiated directly (its ``__init__`` raises and it provides no
    ``encode_plus``). This class therefore *wraps* a concrete tokenizer
    instead of inheriting from ``AutoTokenizer``.

    Args:
        config: Object providing ``max_seq_length`` (used as the padding /
            truncation length in :meth:`encode`).
        tokenizer: Optional pre-built tokenizer to wrap. When omitted, one is
            loaded with ``AutoTokenizer.from_pretrained(config.name_or_path)``.

    Raises:
        ValueError: If no ``tokenizer`` is given and ``config`` has no
            usable ``name_or_path``.
    """

    def __init__(self, config, tokenizer=None):
        self.config = config
        if tokenizer is None:
            # NOTE(review): the original code never specified which pretrained
            # tokenizer to load; this assumes the config carries a HF-style
            # ``name_or_path`` attribute — confirm against the config class.
            name_or_path = getattr(config, "name_or_path", None)
            if not name_or_path:
                raise ValueError(
                    "DaedalusTokenizer needs either an explicit `tokenizer` or a "
                    "config with `name_or_path` to load one from."
                )
            tokenizer = AutoTokenizer.from_pretrained(name_or_path)
        self.tokenizer = tokenizer

    def __getattr__(self, name):
        """Forward unknown attributes (e.g. ``pad_token_id``) to the wrapped tokenizer."""
        # Guard against infinite recursion when ``tokenizer`` itself is not yet
        # set (e.g. on an instance created without running ``__init__``).
        if name == "tokenizer":
            raise AttributeError(name)
        return getattr(self.tokenizer, name)

    def encode(self, text, **kwargs):
        """Tokenize ``text``, padded/truncated to ``config.max_seq_length``.

        Extra keyword arguments are forwarded to the wrapped tokenizer and
        override the defaults below. Returns whatever the tokenizer's
        ``__call__`` returns (a ``BatchEncoding`` for Hugging Face tokenizers).
        """
        options = {
            "max_length": self.config.max_seq_length,
            "padding": "max_length",
            "truncation": True,
        }
        options.update(kwargs)
        return self.tokenizer(text, **options)

    def decode(self, ids, **kwargs):
        """Convert token ``ids`` back to text, dropping special tokens by default."""
        # Delegate to the wrapped tokenizer; the original called ``self.decode``
        # and recursed forever.
        kwargs.setdefault("skip_special_tokens", True)
        return self.tokenizer.decode(ids, **kwargs)