Spaces:
Runtime error
Runtime error
| from typing import Iterator, Optional | |
| import torch | |
| from torch.utils.data import Dataset, IterableDataset | |
class ValidationWrapper(Dataset):
    """Wraps a dataset so that PyTorch Lightning's validation step can be turned
    into a visualization step.

    The wrapper reports a fixed, caller-chosen length and ignores the index it is
    asked for:

    * If the wrapped dataset is an ``IterableDataset``, each lookup returns the
      next element of the stream. The stream is restarted when it runs out, so
      the wrapper can serve ``length`` items per validation run for any number
      of runs.
    * Otherwise a random sample index and a random number of context views are
      drawn, and the wrapped dataset is indexed with the tuple
      ``(sample_index, num_context_views)``.
    """

    dataset: Dataset
    dataset_iterator: Optional[Iterator]
    length: int

    def __init__(self, dataset: Dataset, length: int) -> None:
        """Store the wrapped dataset and the length to report.

        Args:
            dataset: The dataset to draw samples from.
            length: The value returned by ``len()`` for this wrapper.
        """
        super().__init__()
        self.dataset = dataset
        self.length = length
        # Created lazily on first access, see ``_next_from_iterable``.
        self.dataset_iterator = None

    def __len__(self):
        """Return the length given at construction, not the wrapped dataset's."""
        return self.length

    def __getitem__(self, index: int):
        """Return a sample from the wrapped dataset; ``index`` is ignored.

        For map-style datasets the wrapped dataset must expose
        ``view_sampler.num_context_views`` with a value of at least 2, and its
        ``__getitem__`` must accept a ``(sample_index, num_context_views)`` tuple.
        """
        if isinstance(self.dataset, IterableDataset):
            return self._next_from_iterable()

        # torch.randint's upper bound is exclusive, so the sample index lies in
        # [0, len(dataset)) and the context-view count in [2, num_context_views].
        random_index = torch.randint(0, len(self.dataset), tuple())
        random_context_num = torch.randint(
            2, self.dataset.view_sampler.num_context_views + 1, tuple()
        )
        return self.dataset[random_index.item(), random_context_num.item()]

    def _next_from_iterable(self):
        """Return the next element of the wrapped iterable dataset.

        The iterator is cached across calls. When it is exhausted it is
        recreated once and read again; if the fresh iterator yields nothing
        (empty dataset), ``StopIteration`` propagates to the caller.
        """
        if self.dataset_iterator is None:
            self.dataset_iterator = iter(self.dataset)
        try:
            return next(self.dataset_iterator)
        except StopIteration:
            # Without a restart, an exhausted stream would make this and every
            # later lookup fail, cutting validation short.
            self.dataset_iterator = iter(self.dataset)
            return next(self.dataset_iterator)