| # Copyright (c) Microsoft Corporation. | |
| # Licensed under the MIT License. | |
| import importlib | |
| import torch.utils.data | |
| from data.base_dataset import BaseDataset | |
| from data.face_dataset import FaceTestDataset | |
def create_dataloader(opt):
    """Create a ``torch.utils.data.DataLoader`` over a ``FaceTestDataset``.

    A ``FaceTestDataset`` is constructed and initialized with ``opt``; a
    one-line summary (dataset class name and length) is printed before the
    loader is built.

    Args:
        opt: Options object. It is handed to ``FaceTestDataset.initialize``
            and this function reads ``batchSize``, ``serial_batches``,
            ``nThreads`` and ``isTrain`` from it.

    Returns:
        The ``DataLoader`` wrapping the initialized dataset.
    """
    dataset = FaceTestDataset()
    dataset.initialize(opt)

    dataset_name = type(dataset).__name__
    dataset_size = len(dataset)
    print("dataset [%s] of size %d was created" % (dataset_name, dataset_size))

    loader_kwargs = {
        "batch_size": opt.batchSize,
        # Shuffling is turned off when serial batches are requested.
        "shuffle": not opt.serial_batches,
        "num_workers": int(opt.nThreads),
        "drop_last": opt.isTrain,
    }
    return torch.utils.data.DataLoader(dataset, **loader_kwargs)