# --- START OF FILE data/__init__.py ---
"""Dataset lookup and DataLoader construction for the ``data`` package."""

import importlib

import torch.utils.data

from data.base_dataset import BaseDataset


def find_dataset_using_name(dataset_name):
    """Return the dataset class registered under *dataset_name*.

    Imports the module ``data.<dataset_name>_dataset`` and searches it for a
    class whose name matches, case-insensitively, *dataset_name* with
    underscores removed followed by ``"dataset"`` (for ``"foo_bar"`` that is
    e.g. ``FooBarDataset``). The class must subclass :class:`BaseDataset`.
    If several names match, the last one in the module's namespace order
    wins.

    Args:
        dataset_name: Dataset identifier, e.g. the value of
            ``opt.dataset_mode``.

    Returns:
        The matching dataset class (the class itself, not an instance).

    Raises:
        ImportError: If ``data.<dataset_name>_dataset`` cannot be imported.
        ValueError: If that module defines no matching BaseDataset subclass.
    """
    dataset_filename = "data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    dataset = None
    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    for name, cls in datasetlib.__dict__.items():
        # issubclass() raises TypeError for non-class objects, so skip any
        # function or variable whose name happens to match the target.
        if (name.lower() == target_dataset_name.lower()
                and isinstance(cls, type)
                and issubclass(cls, BaseDataset)):
            dataset = cls

    if dataset is None:
        raise ValueError("In %s.py, there should be a subclass of BaseDataset "
                         "with class name that matches %s in lowercase."
                         % (dataset_filename, target_dataset_name))

    return dataset


def get_option_setter(dataset_name):
    """Return the ``modify_commandline_options`` hook of a dataset class.

    Args:
        dataset_name: Dataset identifier, resolved via
            :func:`find_dataset_using_name`.

    Returns:
        The ``modify_commandline_options`` attribute of the resolved class.
        It is presumably a static/class method that adds dataset-specific
        command-line options. NOTE(review): confirm against BaseDataset.

    Raises:
        ImportError, ValueError: Propagated from
            :func:`find_dataset_using_name`.
    """
    dataset_class = find_dataset_using_name(dataset_name)
    return dataset_class.modify_commandline_options


def create_dataloader(opt):
    """Instantiate the dataset selected by *opt* and wrap it in a DataLoader.

    Reads these attributes of *opt*:

    - ``dataset_mode``: the dataset name passed to
      :func:`find_dataset_using_name`.
    - ``batchSize``: the DataLoader batch size.
    - ``serial_batches``: when true, samples are *not* shuffled.
    - ``nThreads``: the number of DataLoader worker processes (cast to int).
    - ``isTrain``: when true, the final incomplete batch is dropped.

    The dataset instance is constructed as ``dataset_class(opt)``, and its
    class name and length are printed.

    Args:
        opt: Options object carrying the attributes listed above.

    Returns:
        A ``torch.utils.data.DataLoader`` over the dataset instance.
    """
    dataset = find_dataset_using_name(opt.dataset_mode)
    instance = dataset(opt)
    print("dataset [%s] of size %d was created" %
          (type(instance).__name__, len(instance)))

    dataloader = torch.utils.data.DataLoader(
        instance,
        batch_size=opt.batchSize,
        shuffle=not opt.serial_batches,
        num_workers=int(opt.nThreads),
        drop_last=opt.isTrain,
    )
    return dataloader