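"""Dataloader factory for rstor training.

Builds train/validation DataLoader pairs either from an on-the-fly
synthetic dead-leaves generator or from images stored on disk,
depending on the configuration dictionary.
"""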
from torch.utils.data import DataLoader
from rstor.data.synthetic_dataloader import DeadLeavesDataset, DeadLeavesDatasetGPU
from rstor.data.stored_images_dataloader import RestorationDataset
from rstor.properties import (
    DATALOADER, BATCH_SIZE, TRAIN, VALIDATION, LENGTH, CONFIG_DEAD_LEAVES, SIZE, NAME, CONFIG_DEGRADATION,
    DATASET_SYNTH_LIST, DATASET_DIV2K, DATASET_DL_DIV2K_512, DATASET_DL_DIV2K_1024,
    DATASET_PATH
)
from typing import Optional
from random import seed, shuffle


def get_data_loader_synthetic(config, frozen_seed=42):
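    """Build train/validation dataloaders from an on-the-fly dead-leaves
    generator (GPU version when config[DATALOADER]["gpu_gen"] is set)."""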
    # print(config[DATALOADER].get(CONFIG_DEAD_LEAVES, {}))
    if config[DATALOADER].get("gpu_gen", False):
        print("Using GPU dead leaves generator")
        ds = DeadLeavesDatasetGPU
    else:
        ds = DeadLeavesDataset
    # Training samples are regenerated freely (frozen_seed=None),
    # while the validation set uses a frozen seed for reproducibility.
    dl_train = ds(config[DATALOADER][SIZE], config[DATALOADER][LENGTH][TRAIN],
                  frozen_seed=None, **config[DATALOADER].get(CONFIG_DEAD_LEAVES, {}))
    dl_valid = ds(config[DATALOADER][SIZE], config[DATALOADER][LENGTH][VALIDATION],
                  frozen_seed=frozen_seed, **config[DATALOADER].get(CONFIG_DEAD_LEAVES, {}))
    dl_dict = create_dataloaders(config, dl_train, dl_valid)
    return dl_dict


def create_dataloaders(config, dl_train, dl_valid) -> dict:
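    """Wrap the train/validation datasets into DataLoaders, keyed by
    TRAIN and VALIDATION (training is shuffled, validation is not)."""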
    dl_dict = {
        TRAIN: DataLoader(
            dl_train,
            shuffle=True,
            batch_size=config[DATALOADER][BATCH_SIZE][TRAIN],
        ),
        VALIDATION: DataLoader(
            dl_valid,
            shuffle=False,
            batch_size=config[DATALOADER][BATCH_SIZE][VALIDATION]
        ),
        # TEST: DataLoader(dl_test, shuffle=False, batch_size=config[DATALOADER][BATCH_SIZE][TEST])
    }
    return dl_dict


def get_data_loader_from_disk(config, frozen_seed: Optional[int] = 42) -> dict:
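    """Build train/validation dataloaders from images stored on disk:
    DIV2K or one of the pre-rendered synthetic datasets."""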
    ds = RestorationDataset
    dataset_name = config[DATALOADER][NAME]  # NAME is required in the config at this point
    if dataset_name == DATASET_DIV2K:
        dataset_root = DATASET_PATH/DATASET_DIV2K
        train_root = dataset_root/"DIV2K_train_HR"/"DIV2K_train_HR"
        valid_root = dataset_root/"DIV2K_valid_HR"/"DIV2K_valid_HR"
        train_files = sorted(list(train_root.glob("*.png")))
        train_files = 5*train_files  # Repeat the file list to reach ~4000 elements per epoch
        valid_files = sorted(list(valid_root.glob("*.png")))
    elif dataset_name in DATASET_SYNTH_LIST:
        dataset_root = DATASET_PATH/dataset_name
        all_files = sorted(list(dataset_root.glob("*.png")))
        seed(frozen_seed)
        shuffle(all_files)  # Easy way to perform cross validation if needed
        cut_index = int(0.9*len(all_files))  # 90%/10% train/validation split
        train_files = all_files[:cut_index]
        valid_files = all_files[cut_index:]
    else:
        raise ValueError(f"Unknown dataset name {dataset_name}")
    dl_train = ds(
        train_files,
        size=config[DATALOADER][SIZE],
        frozen_seed=None,
        **config[DATALOADER].get(CONFIG_DEGRADATION, {})
    )
    dl_valid = ds(
        valid_files,
        size=config[DATALOADER][SIZE],
        frozen_seed=frozen_seed,
        **config[DATALOADER].get(CONFIG_DEGRADATION, {})
    )
    dl_dict = create_dataloaders(config, dl_train, dl_valid)
    return dl_dict


def get_data_loader(config, frozen_seed=42):
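    """Entry point: dispatch to the on-disk loader when a dataset NAME is
    configured, otherwise fall back to the synthetic generator."""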
    dataset_name = config[DATALOADER].get(NAME, False)
    if dataset_name:
        return get_data_loader_from_disk(config, frozen_seed)
    else:
        return get_data_loader_synthetic(config, frozen_seed)
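

# Minimal config sketch for the synthetic path (a sketch only: the key
# structure matches the accesses above, but the values are illustrative
# and CONFIG_DEAD_LEAVES options are left at their defaults):
#
#     config = {
#         DATALOADER: {
#             SIZE: (128, 128),
#             LENGTH: {TRAIN: 4000, VALIDATION: 400},
#             BATCH_SIZE: {TRAIN: 16, VALIDATION: 16},
#         }
#     }
#     dl_dict = get_data_loader(config)  # no NAME key -> synthetic generator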


if __name__ == "__main__":
    # Example usage: synthetic generator and stored-image datasets
    for dataset_name in [DATASET_DIV2K, None, DATASET_DL_DIV2K_512, DATASET_DL_DIV2K_1024]:
        if dataset_name is None:
            # Synthetic dataset, instantiated directly
            dead_leaves_dataset = DeadLeavesDatasetGPU(colored=True)
            dl = DataLoader(dead_leaves_dataset, batch_size=4, shuffle=True)
        else:
            # Stored-images dataset, built through the config dispatch
            config = {
                DATALOADER: {
                    NAME: dataset_name,
                    SIZE: (128, 128),
                    BATCH_SIZE: {
                        TRAIN: 4,
                        VALIDATION: 4
                    },
                }
            }
            dl_dict = get_data_loader(config)
            dl = dl_dict[TRAIN]
            # dl = dl_dict[VALIDATION]
        for i, (batch_inp, batch_target) in enumerate(dl):
            # Expected: channel-first batches [batch_size, 3, size[0], size[1]]
            # (the permute below assumes NCHW)
            print(batch_inp.shape, batch_target.shape)
            if i == 1:  # Stop after two batches, just for demonstration
                import matplotlib.pyplot as plt
                # Permute to NHWC and stack the batch images vertically for display
                plt.subplot(1, 2, 1)
                plt.imshow(batch_inp.permute(0, 2, 3, 1).reshape(-1, batch_inp.shape[-1], 3).cpu().numpy())
                plt.title("Degraded")
                plt.subplot(1, 2, 2)
                plt.imshow(batch_target.permute(0, 2, 3, 1).reshape(-1, batch_inp.shape[-1], 3).cpu().numpy())
                plt.title("Target")
                plt.show()
                # print(batch_target)
                break