| | from typing import Iterator, Optional |
| |
|
| | import torch |
| | from torch.utils.data import Dataset, IterableDataset |
| |
|
| |
|
class ValidationWrapper(Dataset):
    """Wraps a dataset so that PyTorch Lightning's validation step can be turned into a
    visualization step.

    The wrapper reports a fixed ``length`` regardless of the wrapped dataset's size.
    The ``index`` passed to ``__getitem__`` only bounds iteration; it does not select
    which sample is returned:

    * If the wrapped dataset is an ``IterableDataset``, items are pulled sequentially
      from a cached iterator. The iterator is restarted when it runs out, so repeated
      validation epochs keep producing samples.
    * Otherwise a uniformly random sample index and a random context-view count are
      drawn with ``torch.randint`` and forwarded as a tuple key.
    """

    dataset: Dataset
    # Lazily created iterator over an IterableDataset; None until first use.
    dataset_iterator: Optional[Iterator]
    length: int

    def __init__(self, dataset: Dataset, length: int) -> None:
        super().__init__()
        self.dataset = dataset
        self.length = length
        self.dataset_iterator = None

    def __len__(self):
        return self.length

    def __getitem__(self, index: int):
        """Return one validation/visualization sample.

        Args:
            index: Position in ``[-length, length)``. Out-of-range ints raise
                ``IndexError`` so that plain ``for``/``list()`` iteration over the
                wrapper terminates. The value otherwise does not pick the sample.

        Raises:
            IndexError: If an int ``index`` is out of range, or if a wrapped
                ``IterableDataset`` yields no items at all.
        """
        # Only ints are bounds-checked; other keys keep the original pass-through behavior.
        if isinstance(index, int) and not -self.length <= index < self.length:
            raise IndexError(
                f"index {index} out of range for ValidationWrapper of length {self.length}"
            )

        if isinstance(self.dataset, IterableDataset):
            return self._next_streamed_item()

        random_index = torch.randint(0, len(self.dataset), tuple())
        # NOTE(review): torch.randint(2, n + 1) requires n >= 2, so the wrapped dataset's
        # view_sampler.num_context_views must be at least 2 — confirm for all configs.
        random_context_num = torch.randint(
            2, self.dataset.view_sampler.num_context_views + 1, tuple()
        )

        # Presumably the wrapped dataset accepts (sample_index, num_context_views) keys;
        # verify against its __getitem__.
        return self.dataset[random_index.item(), random_context_num.item()]

    def _next_streamed_item(self):
        """Pull the next item from the wrapped IterableDataset, restarting it once if exhausted."""
        # NOTE(review): with multiple DataLoader workers, each worker's copy of this
        # wrapper builds its own iterator — confirm duplicated samples are acceptable.
        if self.dataset_iterator is None:
            self.dataset_iterator = iter(self.dataset)
        try:
            return next(self.dataset_iterator)
        except StopIteration:
            pass

        # The previous pass is exhausted: start a fresh one so later epochs still get data.
        self.dataset_iterator = iter(self.dataset)
        try:
            return next(self.dataset_iterator)
        except StopIteration:
            # A fresh pass is also empty; fail loudly rather than restarting forever.
            raise IndexError(
                "ValidationWrapper: wrapped IterableDataset yielded no items"
            ) from None
| |
|