| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import os |
| |
|
| | from logging import getLogger |
| |
|
| | import torch |
| | import torchvision |
| |
|
# Module-level seed constant; nothing in the visible part of this file reads it.
_GLOBAL_SEED = 0
# getLogger() without a name returns the root logger, so the messages emitted
# below follow whatever logging configuration the application installs.
logger = getLogger()
| |
|
| |
|
class ImageFolder(torchvision.datasets.ImageFolder):
    """torchvision ImageFolder pointed at the train or val split directory."""

    def __init__(
        self,
        root,
        image_folder='imagenet_full_size/061417/',
        transform=None,
        train=True,
    ):
        """
        Resolve the split directory and initialise the parent dataset on it.

        :param root: root network directory for ImageFolder data
        :param image_folder: path to images inside root network directory
        :param transform: forwarded unchanged to the parent ImageFolder
        :param train: whether to load train data (or validation)
        """
        # Pick the split sub-directory that is appended to the image folder.
        if train:
            split_dir = 'train/'
        else:
            split_dir = 'val/'

        data_path = os.path.join(root, image_folder, split_dir)
        logger.info(f'data-path {data_path}')

        # The parent class receives the fully resolved split directory as root.
        super().__init__(root=data_path, transform=transform)
        logger.info('Initialized ImageFolder')
| |
|
| |
|
def make_imagedataset(
    transform,
    batch_size,
    collator=None,
    pin_mem=True,
    num_workers=8,
    world_size=1,
    rank=0,
    root_path=None,
    image_folder=None,
    training=True,
    copy_data=False,
    drop_last=True,
    persistent_workers=False,
    subset_file=None
):
    """
    Build an ImageFolder dataset plus its distributed sampler and data loader.

    :param transform: transform forwarded to the ImageFolder dataset
    :param batch_size: batch size given to the DataLoader
    :param collator: forwarded to the DataLoader as ``collate_fn``
    :param pin_mem: forwarded to the DataLoader as ``pin_memory``
    :param num_workers: number of DataLoader workers
    :param world_size: forwarded to DistributedSampler as ``num_replicas``
    :param rank: forwarded to DistributedSampler as ``rank``
    :param root_path: forwarded to ImageFolder as ``root``
    :param image_folder: path to images inside ``root_path``; when None,
        ImageFolder's own default is used
    :param training: whether to load train data (or validation)
    :param copy_data: accepted but not used by this function
    :param drop_last: forwarded to the DataLoader as ``drop_last``
    :param persistent_workers: forwarded to the DataLoader; only takes effect
        when ``num_workers`` is greater than 0
    :param subset_file: accepted but not used by this function
    :return: tuple ``(dataset, data_loader, dist_sampler)``
    """
    dataset_kwargs = dict(root=root_path, transform=transform, train=training)
    # Forwarding image_folder=None would override ImageFolder's string default
    # and make os.path.join fail, so only pass it when the caller supplied one.
    if image_folder is not None:
        dataset_kwargs['image_folder'] = image_folder
    dataset = ImageFolder(**dataset_kwargs)
    logger.info('ImageFolder dataset created')

    dist_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset=dataset,
        num_replicas=world_size,
        rank=rank)

    # DataLoader raises ValueError for persistent_workers=True when there are
    # no worker processes, so only enable it when workers actually exist.
    keep_workers = bool(persistent_workers) and num_workers > 0
    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=collator,
        sampler=dist_sampler,
        batch_size=batch_size,
        drop_last=drop_last,
        pin_memory=pin_mem,
        num_workers=num_workers,
        persistent_workers=keep_workers)
    logger.info('ImageFolder unsupervised data loader created')

    return dataset, data_loader, dist_sampler
| |
|