| | |
| | |
| |
|
| | import torch.utils.data |
| | import random |
| | from data.base_data_loader import BaseDataLoader |
| | from data import online_dataset_for_old_photos as dts_ray_bigfile |
| |
|
| |
|
def CreateDataset(opt):
    """Build and initialize the dataset selected by ``opt.training_dataset``.

    Args:
        opt: Options object. Reads ``opt.training_dataset`` and, for the
            ``'mapping'`` mode, ``opt.random_hole``. The whole object is
            forwarded to the dataset's ``initialize`` method.

    Returns:
        The dataset instance after ``initialize(opt)`` has been called on it.

    Raises:
        ValueError: If ``opt.training_dataset`` is not one of
            ``'domain_A'``, ``'domain_B'`` or ``'mapping'``.
    """
    mode = opt.training_dataset

    if mode in ('domain_A', 'domain_B'):
        # Both single-domain modes share the same unpaired dataset class.
        dataset = dts_ray_bigfile.UnPairOldPhotos_SR()
    elif mode == 'mapping':
        if opt.random_hole:
            dataset = dts_ray_bigfile.PairOldPhotos_with_hole()
        else:
            dataset = dts_ray_bigfile.PairOldPhotos()
    else:
        # Fail loudly here instead of crashing later on ``None.name()``.
        raise ValueError(
            "unknown training_dataset %r; expected 'domain_A', 'domain_B' "
            "or 'mapping'" % (mode,)
        )

    print("dataset [%s] was created" % (dataset.name()))
    dataset.initialize(opt)
    return dataset
| |
|
class CustomDatasetDataLoader(BaseDataLoader):
    """Data loader that wraps the dataset built by ``CreateDataset``."""

    def name(self):
        """Return the identifying name of this loader."""
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        """Create the dataset and the ``DataLoader`` that iterates over it.

        Reads ``opt.batchSize``, ``opt.serial_batches`` and ``opt.nThreads``.
        Batches are shuffled unless ``opt.serial_batches`` is truthy, and
        ``drop_last=True`` is passed to the ``DataLoader``.
        """
        BaseDataLoader.initialize(self, opt)
        self.dataset = CreateDataset(opt)

        shuffle_batches = not opt.serial_batches
        worker_count = int(opt.nThreads)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batchSize,
            shuffle=shuffle_batches,
            num_workers=worker_count,
            drop_last=True,
        )

    def load_data(self):
        """Return the ``DataLoader`` created by ``initialize``."""
        return self.dataloader

    def __len__(self):
        """Return the dataset length, capped at ``opt.max_dataset_size``."""
        # NOTE(review): self.opt is presumably set by
        # BaseDataLoader.initialize -- confirm against the base class.
        dataset_size = len(self.dataset)
        size_cap = self.opt.max_dataset_size
        return min(dataset_size, size_cap)
| |
|