Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| # coding=utf-8 | |
| import torch | |
| from data.field.mini_torchtext.field import NestedField as TorchTextNestedField | |
class NestedField(TorchTextNestedField):
    """Nested field that hands whole examples to its ``nesting_field``.

    Padding and numericalization are each delegated to ``self.nesting_field``
    in a single call, and ``build_vocab`` leaves both fields pointing at the
    same vocabulary object.
    """

    def pad(self, example):
        """Pad ``example`` with the nesting field.

        Returns:
            Whatever ``self.nesting_field.pad`` returns when
            ``self.include_lengths`` is false. Otherwise a
            ``(padded, sentence_length, word_lengths)`` triple, where
            ``sentence_length`` is ``len(example)`` and ``word_lengths`` is
            the second item returned by the nesting field.
        """
        inner = self.nesting_field
        # The nesting field reports lengths only when this field does.
        inner.include_lengths = self.include_lengths

        if self.include_lengths:
            sentence_length = len(example)
            padded, word_lengths = inner.pad(example)
            return padded, sentence_length, word_lengths
        return inner.pad(example)

    def numericalize(self, arr, device=None):
        """Numericalize a padded example through the nesting field.

        Args:
            arr: The padded example, or -- when ``self.include_lengths`` is
                true -- an ``(example, sentence_length, word_lengths)``
                triple such as the one ``pad`` produces.
            device: Forwarded to the nesting field and to ``torch.tensor``.

        Returns:
            The nesting field's result alone, or a
            ``(result, sentence_length, word_lengths)`` tuple in which both
            lengths have been turned into tensors with ``dtype=self.dtype``.
        """
        inner = self.nesting_field
        # Lengths are handled here, so the nesting field gets the bare
        # example with its own length reporting switched off.
        inner.include_lengths = False
        if self.include_lengths:
            arr, sentence_length, word_lengths = arr
        result = inner.numericalize(arr, device=device)
        # NOTE: the flag is set to True unconditionally afterwards; it is not
        # restored to whatever value it had before this call.
        inner.include_lengths = True

        if not self.include_lengths:
            return result

        sentence_length = torch.tensor(sentence_length, dtype=self.dtype, device=device)
        word_lengths = torch.tensor(word_lengths, dtype=self.dtype, device=device)
        return (result, sentence_length, word_lengths)

    def build_vocab(self, *args, **kwargs):
        """Build one vocabulary shared by this field and its nesting field.

        Args:
            *args: Data sources. A ``torch.utils.data.Dataset`` contributes
                ``get_examples(name)`` for every entry of its ``fields``
                mapping whose value is this very field; any other argument
                is used as given.
            **kwargs: Passed through to ``self.nesting_field.build_vocab``.
        """
        collections = []
        for candidate in args:
            if not isinstance(candidate, torch.utils.data.Dataset):
                collections.append(candidate)
                continue
            # Identity check: only the columns bound to this exact field.
            for name, field in candidate.fields.items():
                if field is self:
                    collections.append(candidate.get_examples(name))

        # Remove one level of nesting: every element of every collection
        # becomes its own positional source for the nesting field.
        examples = [item for collection in collections for item in collection]

        # Only the vocabulary is built here; vectors are not loaded.
        self.nesting_field.build_vocab(*examples, **kwargs)
        # Lookup starts *after* TorchTextNestedField in the MRO, so any
        # build_vocab defined on that class itself is bypassed.
        super(TorchTextNestedField, self).build_vocab()
        self.vocab.extend(self.nesting_field.vocab)
        self.vocab.freqs = self.nesting_field.vocab.freqs.copy()
        # From here on both fields share the same vocabulary object.
        self.nesting_field.vocab = self.vocab