| | import importlib |
| | import torch.utils.data |
| | from data_test_flow.dd_dataset import DDDataset |
| |
|
def CreateDataLoader(opt):
    """Build a ``CustomDatasetDataLoader`` configured from ``opt`` and return it.

    ``opt`` is passed unchanged to ``CustomDatasetDataLoader.initialize``.
    """
    loader = CustomDatasetDataLoader()
    loader.initialize(opt)
    return loader
| |
|
| | |
| | |
| | |
| | |
| |
|
class BaseDataLoader:
    """Minimal base for data-loader wrappers.

    Keeps a reference to the options object and provides a placeholder
    ``load_data`` that subclasses are expected to override.
    """

    def __init__(self):
        """Create an empty loader; configuration happens in ``initialize``."""

    def initialize(self, opt):
        """Store ``opt`` on the instance as ``self.opt``."""
        self.opt = opt

    def load_data(self):
        """Placeholder implementation: always returns ``None``."""
        return None
| |
|
class CustomDatasetDataLoader(BaseDataLoader):
    """Iterable wrapper around a ``DDDataset`` and a ``torch`` ``DataLoader``.

    The options object supplied to ``initialize`` is read for
    ``batch_size``, ``shuffle``, ``num_threads`` and ``max_dataset_size``.
    """

    def name(self):
        """Return the identifying name of this loader."""
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        """Create the dataset and the underlying ``DataLoader`` from ``opt``."""
        BaseDataLoader.initialize(self, opt)

        self.dataset = DDDataset()
        self.dataset.initialize(opt)

        # A disabled alternative used to sit here: a DataLoader built with
        # shuffle=False and
        # sampler=torch.utils.data.distributed.DistributedSampler(self.dataset).
        loader_kwargs = {
            'batch_size': opt.batch_size,
            'shuffle': opt.shuffle,
            # Passed through to DataLoader; per the PyTorch docs this drops
            # a trailing incomplete batch.
            'drop_last': True,
            'num_workers': int(opt.num_threads),
        }
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset, **loader_kwargs)

    def load_data(self):
        """Return this object itself, which is directly iterable."""
        return self

    def __len__(self):
        """Return the dataset length, capped at ``opt.max_dataset_size``."""
        dataset_length = len(self.dataset)
        return min(dataset_length, self.opt.max_dataset_size)

    def __iter__(self):
        """Yield batches from the underlying loader.

        Iteration ends once ``batch_index * opt.batch_size`` reaches
        ``opt.max_dataset_size``.
        """
        for batch_index, batch in enumerate(self.dataloader):
            cap_reached = (
                batch_index * self.opt.batch_size
                >= self.opt.max_dataset_size
            )
            if cap_reached:
                return
            yield batch
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|