import logging import glob import torch class FolderData(torch.utils.data.Dataset): def __init__(self, path, transforms, extensions=['.jpg', '.png'], recursive=False, verbose=False): self.verbose = verbose if self.verbose: logger = logging.getLogger(__name__) if len(extensions) == 0: if self.verbose: logger.error("Expected at list one extension, but none was received.") raise ValueError if self.verbose: logger.info("Constructing the list of images.") additional_pattern = '/**/*' if recursive else '/*' files = [] for extension in extensions: files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive)) if self.verbose: logger.info("Finished searching for images. %s images found", len(files)) logger.info("Preparing to run the detection.") self.files = files self.transforms = transforms def __getitem__(self, idx): image_path = self.files[idx] image = self.transforms(image_path) return image_path, image def __len__(self): return len(self.files)