| | |
| | |
| | |
| | |
| |
|
| | from torch.utils.data.dataloader import default_collate |
| |
|
| | from . import FairseqDataset |
| |
|
| |
|
class BaseWrapperDataset(FairseqDataset):
    """Wrap another dataset and forward the ``FairseqDataset`` API to it.

    Each accessor and batching hook below delegates to ``self.dataset``.
    Subclasses can override individual methods to change specific behavior
    while inheriting the pass-through for everything else.
    """

    def __init__(self, dataset):
        super().__init__()
        # The wrapped dataset that every call below is forwarded to.
        self.dataset = dataset

    def __getitem__(self, index):
        """Return the wrapped dataset's item at ``index``."""
        wrapped = self.dataset
        return wrapped[index]

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.dataset)

    def collater(self, samples):
        """Merge ``samples`` into a batch.

        Uses the wrapped dataset's own ``collater`` when it defines one;
        otherwise falls back to PyTorch's ``default_collate``.
        """
        if not hasattr(self.dataset, "collater"):
            return default_collate(samples)
        return self.dataset.collater(samples)

    @property
    def sizes(self):
        """The wrapped dataset's ``sizes`` attribute."""
        return self.dataset.sizes

    def num_tokens(self, index):
        """Forward ``num_tokens(index)`` to the wrapped dataset."""
        return self.dataset.num_tokens(index)

    def size(self, index):
        """Forward ``size(index)`` to the wrapped dataset."""
        return self.dataset.size(index)

    def ordered_indices(self):
        """Forward ``ordered_indices()`` to the wrapped dataset."""
        return self.dataset.ordered_indices()

    @property
    def supports_prefetch(self):
        """Whether the wrapped dataset supports prefetching.

        Evaluates to ``False`` when the wrapped dataset has no
        ``supports_prefetch`` attribute.
        """
        try:
            return self.dataset.supports_prefetch
        except AttributeError:
            # Same fallback that getattr(..., False) would produce.
            return False

    def attr(self, attr: str, index: int):
        """Forward ``attr(attr, index)`` to the wrapped dataset."""
        return self.dataset.attr(attr, index)

    def prefetch(self, indices):
        """Ask the wrapped dataset to prefetch ``indices`` (returns ``None``)."""
        self.dataset.prefetch(indices)

    def get_batch_shapes(self):
        """Forward ``get_batch_shapes()`` to the wrapped dataset."""
        return self.dataset.get_batch_shapes()

    def batch_by_size(
        self,
        indices,
        max_tokens=None,
        max_sentences=None,
        required_batch_size_multiple=1,
    ):
        """Forward batching of ``indices`` to the wrapped dataset.

        All limit arguments are passed through unchanged, by keyword.
        """
        limits = {
            "max_tokens": max_tokens,
            "max_sentences": max_sentences,
            "required_batch_size_multiple": required_batch_size_multiple,
        }
        return self.dataset.batch_by_size(indices, **limits)

    def filter_indices_by_size(self, indices, max_sizes):
        """Forward ``filter_indices_by_size`` to the wrapped dataset."""
        return self.dataset.filter_indices_by_size(indices, max_sizes)

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        """The wrapped dataset's ``can_reuse_epoch_itr_across_epochs`` value."""
        return self.dataset.can_reuse_epoch_itr_across_epochs

    def set_epoch(self, epoch):
        """Run the parent ``set_epoch``, then forward ``epoch`` to the wrapped
        dataset when it defines ``set_epoch``."""
        super().set_epoch(epoch)
        if not hasattr(self.dataset, "set_epoch"):
            return
        self.dataset.set_epoch(epoch)
| |
|