Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import torch | |
class DiscreteTokenizer(object):
    """Pack sequences of discrete token ids into fixed-length tensors.

    Vocabulary layout:
        0 .. num_bins * num_bins - 1   data tokens (presumably a 2-D grid of
                                       bins -- NOTE(review): confirm with caller)
        num_bins * num_bins + 0        bos
        num_bins * num_bins + 1        eos
        num_bins * num_bins + 2        sep
        num_bins * num_bins + 3        pad
        num_bins * num_bins + 4        cls (only when ``add_cls`` is True)

    ``len(tokenizer)`` returns the total vocabulary size, special tokens included.
    When ``add_cls`` is False, no ``cls`` attribute is defined.
    """

    def __init__(self, num_bins, seq_len, add_cls=False):
        """
        Args:
            num_bins: number of bins; the data vocabulary has ``num_bins ** 2`` ids.
            seq_len: fixed length of every sequence produced by ``__call__``.
            add_cls: if True, a ``cls`` token is appended after every packed
                sub-sequence and the vocabulary grows by one.
        """
        self.num_bins = num_bins
        vocab_size = num_bins * num_bins
        self.seq_len = seq_len
        self.add_cls = add_cls
        # Special tokens are placed directly after the data-token range.
        self.bos = vocab_size + 0
        self.eos = vocab_size + 1
        self.sep = vocab_size + 2
        self.pad = vocab_size + 3
        if add_cls:
            self.cls = vocab_size + 4
            self.vocab_size = vocab_size + 5
        else:
            self.vocab_size = vocab_size + 4

    def __len__(self):
        """Return the total vocabulary size, special tokens included."""
        return self.vocab_size

    def _padding(self, seq, pad_value, dtype):
        """Right-pad ``seq`` with ``pad_value`` up to ``seq_len``; return a tensor.

        The input is copied first, so the caller's list is never modified.
        (Previously ``seq.extend`` mutated the argument in place.)
        Sequences that are already ``seq_len`` or longer are returned unpadded
        and are not truncated.
        """
        padded = list(seq)
        # A negative repeat count yields an empty list, so no length check is needed.
        padded.extend([pad_value] * (self.seq_len - len(padded)))
        return torch.tensor(np.array(padded), dtype=dtype)

    def __call__(self, seq, add_bos, add_eos, dtype, return_indices=False):
        """Greedily pack sub-sequences of token ids into one ``seq_len`` tensor.

        Each packed sub-sequence is followed by ``cls`` (if ``add_cls``) and then
        ``sep``. A sub-sequence is added only if it still leaves room for those
        trailing tokens. One that does not fit is skipped, and packing continues
        with later (possibly shorter) sub-sequences. The final ``sep`` is dropped,
        and the result is right-padded with ``pad`` to ``seq_len``.

        Args:
            seq: iterable of sub-sequences. Each must support ``len()`` and iteration.
            add_bos: prepend the ``bos`` token.
            add_eos: overwrite the LAST position of the (padded) sequence with
                ``eos``. When padding occurred, ``eos`` therefore sits at index
                ``seq_len - 1`` after the pad tokens, not directly after the data.
            dtype: torch dtype of the returned tensor.
            return_indices: also return the indices (into ``seq``) of the
                sub-sequences that were packed.

        Returns:
            A 1-D tensor, or ``(tensor, indices)`` when ``return_indices`` is True.
        """
        out = [self.bos] if add_bos else []
        # Room that must remain after a sub-sequence for its trailing cls and/or sep.
        num_extra = 2 if self.add_cls else 1
        indices = []
        for i, sub in enumerate(seq):
            # Skip (do not truncate) sub-sequences that would overflow seq_len.
            if len(out) + len(sub) + num_extra > self.seq_len:
                continue
            out.extend(sub)
            indices.append(i)
            if self.add_cls:
                out.append(self.cls)
            out.append(self.sep)
        # The separator after the last packed sub-sequence is not needed.
        # Dropping it also frees the final slot that add_eos may overwrite.
        if out and out[-1] == self.sep:
            out.pop()
        if self.seq_len > len(out):
            out.extend([self.pad] * (self.seq_len - len(out)))
        if add_eos:
            out[-1] = self.eos
        if return_indices:
            return torch.tensor(out, dtype=dtype), indices
        return torch.tensor(out, dtype=dtype)