metalmind / src /datasets /image_dataset.py
IELTS8's picture
Upload folder using huggingface_hub
ada3f28 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import os
from logging import getLogger
import torch
import torchvision
# Module-level seed constant; it is not referenced in the visible part of this file.
_GLOBAL_SEED = 0
# getLogger() called without a name returns the root logger.
logger = getLogger()
class ImageFolder(torchvision.datasets.ImageFolder):
    """torchvision ``ImageFolder`` pointed at the train or val split of a dataset."""

    def __init__(
        self,
        root,
        image_folder='imagenet_full_size/061417/',
        transform=None,
        train=True,
    ):
        """
        Resolve the split directory and initialise the parent dataset on it.

        The directory that is loaded is ``root/image_folder/train/`` when
        ``train`` is truthy and ``root/image_folder/val/`` otherwise.

        :param root: root network directory for ImageFolder data
        :param image_folder: path to images inside root network directory
        :param transform: passed through unchanged to the parent ImageFolder
        :param train: whether to load train data (or validation)
        """
        # Pick the sub-directory that holds the requested split.
        if train:
            split_dir = 'train/'
        else:
            split_dir = 'val/'

        split_root = os.path.join(root, image_folder, split_dir)
        logger.info(f'data-path {split_root}')

        super().__init__(root=split_root, transform=transform)
        logger.info('Initialized ImageFolder')
def make_imagedataset(
    transform,
    batch_size,
    collator=None,
    pin_mem=True,
    num_workers=8,
    world_size=1,
    rank=0,
    root_path=None,
    image_folder=None,
    training=True,
    copy_data=False,
    drop_last=True,
    persistent_workers=False,
    subset_file=None
):
    """
    Build an ImageFolder dataset together with a distributed sampler and loader.

    :param transform: transform handed to the ImageFolder dataset
    :param batch_size: per-process batch size used by the DataLoader
    :param collator: optional ``collate_fn`` for the DataLoader
    :param pin_mem: forwarded to the DataLoader as ``pin_memory``
    :param num_workers: number of DataLoader worker processes
    :param world_size: number of replicas given to the DistributedSampler
    :param rank: rank of the current process given to the DistributedSampler
    :param root_path: root directory of the data; it is joined into a
        filesystem path by ImageFolder, so the ``None`` default is not usable
    :param image_folder: path to images inside ``root_path``; like
        ``root_path`` it must be a path string, not ``None``
    :param training: load the train split when truthy, otherwise validation
    :param copy_data: accepted for interface compatibility; not used here
    :param drop_last: forwarded to the DataLoader
    :param persistent_workers: keep DataLoader workers alive between epochs;
        ignored (with a warning) when ``num_workers`` is 0
    :param subset_file: accepted for interface compatibility; not used here
    :return: tuple ``(dataset, data_loader, dist_sampler)``
    """
    dataset = ImageFolder(
        root=root_path,
        image_folder=image_folder,
        transform=transform,
        train=training)
    logger.info('ImageFolder dataset created')

    dist_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset=dataset,
        num_replicas=world_size,
        rank=rank)

    # DataLoader raises ValueError for persistent_workers=True when there are
    # no worker processes; with nothing to persist, drop the flag rather than
    # failing a single-process (num_workers=0) run.
    keep_workers_alive = persistent_workers
    if persistent_workers and num_workers == 0:
        logger.warning(
            'persistent_workers requires num_workers > 0; disabling it')
        keep_workers_alive = False

    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=collator,
        sampler=dist_sampler,
        batch_size=batch_size,
        drop_last=drop_last,
        pin_memory=pin_mem,
        num_workers=num_workers,
        persistent_workers=keep_workers_alive)
    logger.info('ImageFolder unsupervised data loader created')

    return dataset, data_loader, dist_sampler