| |
| |
| |
| |
|
|
| import torch |
|
|
| from . import FairseqDataset |
|
|
|
|
class ConcatSentencesDataset(FairseqDataset):
    """Present several equal-length datasets as one, where item *i* is the
    concatenation (via ``torch.cat``) of item *i* from each wrapped dataset.

    Batching, ordering, and collation are delegated to the first wrapped
    dataset; per-index sizes are the elementwise sum across all of them.
    """

    def __init__(self, *datasets):
        super().__init__()
        self.datasets = datasets
        # Items are paired up by index, so every wrapped dataset must agree
        # on length. (Lazy generator keeps the empty-varargs case a no-op.)
        assert all(
            len(candidate) == len(datasets[0]) for candidate in datasets
        ), "datasets must have the same length"

    def __getitem__(self, index):
        """Return the concatenation of item ``index`` from every dataset."""
        pieces = [dataset[index] for dataset in self.datasets]
        return torch.cat(pieces)

    def __len__(self):
        # All datasets share one length (asserted in __init__).
        return len(self.datasets[0])

    def collater(self, samples):
        """Merge *samples* into a mini-batch using the first dataset's collater."""
        return self.datasets[0].collater(samples)

    @property
    def sizes(self):
        # Elementwise (per-index) sum of size arrays across datasets.
        return sum(dataset.sizes for dataset in self.datasets)

    def num_tokens(self, index):
        """Total token count at ``index``, summed over all datasets."""
        total = 0
        for dataset in self.datasets:
            total += dataset.num_tokens(index)
        return total

    def size(self, index):
        """Total size at ``index``, summed over all datasets."""
        total = 0
        for dataset in self.datasets:
            total += dataset.size(index)
        return total

    def ordered_indices(self):
        """Defer batching order to the first wrapped dataset."""
        return self.datasets[0].ordered_indices()

    @property
    def supports_prefetch(self):
        # True if at least one wrapped dataset can prefetch.
        for dataset in self.datasets:
            if getattr(dataset, "supports_prefetch", False):
                return True
        return False

    def prefetch(self, indices):
        """Forward *indices* to every dataset that supports prefetching."""
        for dataset in self.datasets:
            if getattr(dataset, "supports_prefetch", False):
                dataset.prefetch(indices)

    def set_epoch(self, epoch):
        """Propagate the epoch number to every wrapped dataset that tracks it."""
        super().set_epoch(epoch)
        for dataset in self.datasets:
            if hasattr(dataset, "set_epoch"):
                dataset.set_epoch(epoch)
|
|