| import torch |
| from data.field.mini_torchtext.field import RawField |
| from data.field.mini_torchtext.vocab import Vocab |
| from collections import Counter |
|
|
|
|
class LabelField(RawField):
    """Field that converts sequences of labels into index tensors.

    The field owns a vocabulary (``self.vocab``) that is created by
    :meth:`build_vocab` from label frequencies. The vocabulary must be built
    before :meth:`process` or :meth:`numericalize` is called.
    """

    def __init__(self, preprocessing=None, **kwargs):
        """Initialise the field with a not-yet-built vocabulary.

        This method was previously misspelled ``__self__`` and therefore was
        never invoked on construction, so ``self.vocab`` was not initialised.

        Args:
            preprocessing: Forwarded unchanged to ``RawField.__init__`` as the
                ``preprocessing`` keyword.
            **kwargs: Any further keyword options, forwarded to
                ``RawField.__init__`` so that callers which relied on the
                inherited constructor keep working.
        """
        # NOTE(review): the ``None`` default assumes RawField treats a missing
        # preprocessing step as "no preprocessing" -- confirm against RawField.
        super().__init__(preprocessing=preprocessing, **kwargs)
        self.vocab = None

    def build_vocab(self, *args, **kwargs):
        """Count the labels in the given sources and build ``self.vocab``.

        Args:
            *args: Each argument is either a ``torch.utils.data.Dataset``, in
                which case the examples of every dataset field that *is* this
                field object are used, or an iterable of examples used as-is.
                Every example is itself an iterable of labels.
            **kwargs: Accepted for interface compatibility; not used here.
        """
        sources = []
        for arg in args:
            if isinstance(arg, torch.utils.data.Dataset):
                # Identity check: only pick up columns bound to this very field.
                sources.extend(
                    arg.get_examples(name)
                    for name, field in arg.fields.items()
                    if field is self
                )
            else:
                sources.append(arg)

        counter = Counter()
        for data in sources:
            for example in data:
                counter.update(example)

        # No special tokens are added to the label vocabulary.
        self.vocab = Vocab(counter, specials=[])

    def process(self, example, device=None):
        """Numericalize a single example.

        Args:
            example: Iterable of labels present in ``self.vocab.stoi``.
            device: Optional device on which the tensors are created.

        Returns:
            The ``(tensor, length)`` pair produced by :meth:`numericalize`.
        """
        return self.numericalize(example, device=device)

    def numericalize(self, example, device=None):
        """Map the labels of one example to a tensor of shifted indices.

        Args:
            example: Iterable of labels present in ``self.vocab.stoi``.
            device: Optional device on which the tensors are created.

        Returns:
            A tuple ``(tensor, length)``: ``tensor`` is a 1-D ``torch.long``
            tensor holding ``stoi[label] + 1`` for every label, and ``length``
            is a 0-dim ``torch.long`` tensor holding the number of labels.
        """
        # Indices are shifted by one, so index 0 is never produced here.
        # NOTE(review): presumably 0 is reserved for padding -- confirm with
        # the code that consumes these tensors.
        indices = [self.vocab.stoi[label] + 1 for label in example]

        # torch.tensor honours ``device``; the legacy ``torch.LongTensor``
        # constructor is CPU-only and raises when given a CUDA device.
        length = torch.tensor(len(indices), dtype=torch.long, device=device)
        tensor = torch.tensor(indices, dtype=torch.long, device=device)

        return tensor, length
|
|