import os
import random

from PIL import Image, ImageFile
from datasets import register
from torch.utils.data import Dataset
from torchvision import transforms

# Raise Pillow's decompression-bomb threshold so very large images can be opened.
Image.MAX_IMAGE_PIXELS = 933120000
# Let Pillow decode files whose data stream is truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Suffixes accepted as images; passed as a tuple to ``str.endswith``.
IMAGE_EXTS = ('.png', '.PNG', '.jpg', '.JPG', '.jpeg', '.JPEG', '.webp')


@register('class_folder')
class ClassFolder(Dataset):
    """Image dataset laid out as one sub-directory per class.

    Every immediate sub-directory of ``root_path`` is one class; classes are
    numbered ``0..n_classes - 1`` in sorted directory-name order. Files inside
    a class directory whose name ends with one of ``IMAGE_EXTS`` are the
    samples of that class (sub-directories are not descended into).

    Args:
        root_path: Directory containing the class sub-directories.
        resize: ``None`` for no resizing; an ``int`` to scale the shorter side
            to that length while keeping the aspect ratio; any other value is
            passed to ``PIL.Image.Image.resize`` as the target size.
        square_crop: If true, center-crop to a square whose side is the
            shorter image side (applied after ``resize``).
        rand_crop: ``None`` or the side length of a random square crop
            (applied after ``square_crop``). The image must be at least this
            large in both dimensions, otherwise ``random.randint`` raises
            ``ValueError``.
        rand_flip: If true, apply ``transforms.RandomHorizontalFlip()``.
        drop_label_p: Probability of replacing the label with ``n_classes``.
        image_only: If true, ``__getitem__`` returns just the image instead
            of a dict.
    """

    def __init__(self, root_path, resize=None, square_crop=False, rand_crop=None,
                 rand_flip=False, drop_label_p=0.0, image_only=False):
        folders = []
        print('root_path', root_path)
        for folder in sorted(os.listdir(root_path)):
            print('folder', folder)
            folder_path = os.path.join(root_path, folder)
            if os.path.isdir(folder_path):
                folders.append(folder_path)
        print('folders', folders)

        self.files = []
        self.labels = []
        for i, folder_path in enumerate(folders):
            # ``folder_path`` already contains ``root_path``; joining the root
            # a second time would produce ``root/root/class`` for relative
            # roots and break the scan.
            for file in sorted(os.listdir(folder_path)):
                if file.endswith(IMAGE_EXTS):
                    self.files.append(os.path.join(folder_path, file))
                    self.labels.append(i)

        self.resize = resize
        self.square_crop = square_crop
        self.rand_crop = rand_crop
        self.rand_flip = transforms.RandomHorizontalFlip() if rand_flip else None
        self.n_classes = len(folders)
        self.drop_label_p = drop_label_p
        self.image_only = image_only

    def __len__(self):
        """Return the number of image files found under ``root_path``."""
        return len(self.files)

    def _load_sample(self, idx):
        """Return ``(RGB image, label)`` for ``idx``, skipping unreadable files.

        On a load failure the next index (wrapping around) is tried. After
        every file has been tried once without success a ``RuntimeError`` is
        raised instead of recursing without bound. An out-of-range ``idx``
        raises ``IndexError``.
        """
        n_files = len(self.files)
        for _ in range(max(n_files, 1)):
            # Indexing happens outside the ``try`` so IndexError propagates.
            path = self.files[idx]
            label = self.labels[idx]
            try:
                # The context manager closes the file handle promptly;
                # ``convert`` returns an independent, fully loaded image.
                with Image.open(path) as img:
                    return img.convert('RGB'), label
            except Exception as exc:  # best effort: skip any undecodable file
                print('Error loading image:', path, repr(exc))
                idx = (idx + 1) % n_files
        raise RuntimeError('ClassFolder: none of the image files could be loaded')

    def __getitem__(self, idx):
        """Load sample ``idx`` and apply the configured transformations.

        Order: resize, center square crop, random crop, random horizontal
        flip, then label dropping.

        Returns:
            The image alone when ``image_only`` is set, otherwise a dict with
            keys ``'image'`` and ``'class_labels'``.
        """
        image, label = self._load_sample(idx)

        if self.resize is not None:
            r = self.resize
            if isinstance(r, int):
                # Scale the shorter side to ``r``, preserving aspect ratio.
                w, h = image.size
                if w < h:
                    r = (r, int(h / w * r))
                else:
                    r = (int(w / h * r), r)
            image = image.resize(r, Image.LANCZOS)

        if self.square_crop:
            w, h = image.size
            side = min(w, h)
            left, upper = (w - side) // 2, (h - side) // 2
            image = image.crop((left, upper, left + side, upper + side))

        if self.rand_crop is not None:
            w, h = image.size
            left = random.randint(0, w - self.rand_crop)
            upper = random.randint(0, h - self.rand_crop)
            image = image.crop((left, upper, left + self.rand_crop, upper + self.rand_crop))

        if self.rand_flip is not None:
            image = self.rand_flip(image)

        # With probability ``drop_label_p`` the label is replaced by the extra
        # index ``n_classes``.
        if self.drop_label_p > 0.0 and random.random() < self.drop_label_p:
            label = self.n_classes

        if self.image_only:
            return image
        return {
            'image': image,
            'class_labels': label,
        }