Spaces:
Running
Running
| import torch | |
| from rstor.data.stored_images_dataloader import RestorationDataset | |
| from numba import cuda | |
| from rstor.properties import DATASET_PATH, AUGMENTATION_FLIP, AUGMENTATION_ROTATE | |
def test_dataloader_stored():
    """Exercise ``RestorationDataset`` on the ``DATASET_PATH / "sample"`` folder.

    Covers, in order:

    1. default construction: two samples found, no frozen seed;
    2. custom construction (``size``, ``frozen_seed``): length and seed stored;
    3. item retrieval: an ``(item, target)`` pair of tensors with equal
       shapes, the input being ``(3, 64, 64)`` for ``size=(64, 64)``;
    4. two datasets built with the same ``frozen_seed`` return identical
       first samples (input AND target);
    5. same as 4 with the flip/rotate augmentations enabled.

    Every dataset is built with ``noise_stddev=(0, 0)``.

    Raises:
        unittest.SkipTest: when numba reports no CUDA device. pytest reports
            this as a skipped test rather than a silent pass.
    """
    import unittest

    if not cuda.is_available():
        # Returning early here would make the test count as "passed" without
        # checking anything; surface it as a skip instead.
        raise unittest.SkipTest("cuda unavailable")

    sample_path = DATASET_PATH / "sample"

    def _assert_first_sample_repeatable(**dataset_kwargs):
        """Build two datasets from identical kwargs and compare their first samples."""
        dataset1 = RestorationDataset(images_path=sample_path, **dataset_kwargs)
        dataset2 = RestorationDataset(images_path=sample_path, **dataset_kwargs)
        item1, item_tgt1 = dataset1[0]
        item2, item_tgt2 = dataset2[0]
        # torch.equal also rejects shape mismatches (no broadcasting).
        assert torch.equal(item1, item2)
        assert torch.equal(item_tgt1, item_tgt2)

    # Test case 1: default parameters
    dataset = RestorationDataset(noise_stddev=(0, 0), images_path=sample_path)
    assert len(dataset) == 2
    assert dataset.frozen_seed is None

    # Test case 2: custom parameters
    dataset = RestorationDataset(
        images_path=sample_path,
        size=(64, 64),
        frozen_seed=42,
        noise_stddev=(0, 0),
    )
    assert len(dataset) == 2
    assert dataset.frozen_seed == 42

    # Test case 3: item retrieval (uses the 64x64 dataset from test case 2)
    item, item_tgt = dataset[0]
    assert isinstance(item, torch.Tensor)
    assert isinstance(item_tgt, torch.Tensor)
    assert item.shape == item_tgt.shape
    assert item.shape == (3, 64, 64)

    # Test case 4: repeatable results with frozen seed
    _assert_first_sample_repeatable(frozen_seed=42, noise_stddev=(0, 0))

    # Test case 5: repeatable results with frozen seed and augmentation
    augmentation_list = [AUGMENTATION_FLIP, AUGMENTATION_ROTATE]
    _assert_first_sample_repeatable(
        frozen_seed=42,
        noise_stddev=(0, 0),
        augmentation_list=augmentation_list,
    )
    print("done")