Spaces:
Sleeping
Sleeping
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
| import torch | |
| from . import FairseqDataset | |
class ConcatSentencesDataset(FairseqDataset):
    """Wrap several parallel datasets and concatenate their i-th items.

    All wrapped datasets must have the same length. ``__getitem__`` returns
    ``torch.cat`` of the per-dataset items at the same index, so each wrapped
    dataset is expected to yield a 1-D tensor per item.
    """

    def __init__(self, *datasets):
        super().__init__()
        self.datasets = datasets
        assert all(
            len(ds) == len(datasets[0]) for ds in datasets
        ), "datasets must have the same length"

    def __getitem__(self, index):
        # Concatenate the corresponding item from every wrapped dataset.
        return torch.cat([ds[index] for ds in self.datasets])

    def __len__(self):
        return len(self.datasets[0])

    def collater(self, samples):
        # Delegate batching to the first dataset's collater; the wrapped
        # datasets are assumed to share a compatible collate function.
        return self.datasets[0].collater(samples)

    @property
    def sizes(self):
        # FIX: declared as a property. This method itself reads ``ds.sizes``
        # as an attribute; as a plain method, nesting one
        # ConcatSentencesDataset inside another would make ``ds.sizes`` a
        # bound method and ``sum(...)`` would raise a TypeError.
        return sum(ds.sizes for ds in self.datasets)

    def num_tokens(self, index):
        return sum(ds.num_tokens(index) for ds in self.datasets)

    def size(self, index):
        return sum(ds.size(index) for ds in self.datasets)

    def ordered_indices(self):
        # All wrapped datasets share indices, so the first one's ordering
        # is valid for every dataset.
        return self.datasets[0].ordered_indices()

    @property
    def supports_prefetch(self):
        # FIX: declared as a property. ``prefetch`` below tests
        # ``getattr(ds, "supports_prefetch", False)`` as a boolean value; a
        # plain bound method is always truthy, which would wrongly trigger
        # ``prefetch`` on a nested ConcatSentencesDataset that does not
        # actually support it.
        return any(getattr(ds, "supports_prefetch", False) for ds in self.datasets)

    def prefetch(self, indices):
        # Forward prefetch only to the datasets that advertise support.
        for ds in self.datasets:
            if getattr(ds, "supports_prefetch", False):
                ds.prefetch(indices)

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        # Propagate the epoch to any wrapped dataset that tracks it.
        for ds in self.datasets:
            if hasattr(ds, "set_epoch"):
                ds.set_epoch(epoch)