Spaces:
Sleeping
Sleeping
| import os | |
| import random | |
| from PIL import Image, ImageFile | |
| from datasets import register | |
| from torch.utils.data import Dataset | |
| from torchvision import transforms | |
| import os | |
| import random | |
# Raise Pillow's pixel-count limit (its decompression-bomb guard) so that
# very large images can still be opened.
Image.MAX_IMAGE_PIXELS = 933_120_000

# Ask Pillow to decode files whose data is cut short instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# File-name suffixes accepted as images. Matching is case-sensitive, which is
# why the upper-case variants are listed explicitly.
IMAGE_EXTS = (
    '.png', '.PNG',
    '.jpg', '.JPG',
    '.jpeg', '.JPEG',
    '.webp',
)
class ClassFolder(Dataset):
    """Image dataset read from a directory that holds one sub-folder per class.

    Each immediate sub-directory of ``root_path`` is one class. Classes are
    numbered ``0 .. n_classes - 1`` in sorted order of the directory names.
    Inside each class directory, every file whose name ends with one of
    ``IMAGE_EXTS`` becomes a sample (files are also taken in sorted order).

    Attributes:
        files: Paths of all samples.
        labels: Integer class index of each sample, parallel to ``files``.
        n_classes: Number of class directories found. Also used as the
            "label dropped" index, see ``drop_label_p``.
    """

    def __init__(
        self,
        root_path,
        resize=None,
        square_crop=False,
        rand_crop=None,
        rand_flip=False,
        drop_label_p=0.0,
        image_only=False,
    ):
        """Scan ``root_path`` and store the per-sample processing options.

        Args:
            root_path: Directory containing the class sub-folders. May be
                absolute or relative.
            resize: ``None`` to skip resizing. An ``int`` rescales the image
                so its shorter side equals that value, keeping the aspect
                ratio. Any other value is handed to ``Image.resize`` as the
                target size unchanged.
            square_crop: If true, centre-crop to a square whose side is the
                shorter image side (applied after ``resize``).
            rand_crop: ``None`` to skip, otherwise the side length of a
                randomly positioned square crop.
            rand_flip: If true, apply ``transforms.RandomHorizontalFlip()``.
            drop_label_p: Probability with which the returned label is
                replaced by ``n_classes``.
            image_only: If true, ``__getitem__`` returns just the image
                instead of a dict.
        """
        folders = []
        print('root_path', root_path)
        for folder in sorted(os.listdir(root_path)):
            print('folder', folder)
            folder_path = os.path.join(root_path, folder)
            if os.path.isdir(folder_path):
                folders.append(folder_path)
        print('folders', folders)

        self.files = []
        self.labels = []
        for label, folder_path in enumerate(folders):
            # folder_path already contains root_path. Joining root_path onto
            # it a second time only works for absolute roots and produces a
            # non-existent "root/root/class" path for relative ones.
            for name in sorted(os.listdir(folder_path)):
                if name.endswith(IMAGE_EXTS):
                    self.files.append(os.path.join(folder_path, name))
                    self.labels.append(label)

        self.resize = resize
        self.square_crop = square_crop
        self.rand_crop = rand_crop
        self.rand_flip = transforms.RandomHorizontalFlip() if rand_flip else None
        self.n_classes = len(folders)
        self.drop_label_p = drop_label_p
        self.image_only = image_only

    def __len__(self):
        """Return the number of image files found."""
        return len(self.files)

    def _load_rgb(self, idx):
        """Return ``(image, label)`` for ``idx``, skipping unreadable files.

        When a file cannot be opened or decoded, its path is printed and the
        next index (wrapping around) is tried. Every file is attempted at
        most once, so a dataset with no readable file fails with a clear
        error instead of recursing without bound.

        Raises:
            IndexError: If ``idx`` is out of range (including an empty
                dataset).
            RuntimeError: If none of the files could be loaded.
        """
        n_files = len(self.files)
        for _ in range(max(n_files, 1)):
            path = self.files[idx]
            try:
                # The context manager closes the underlying file handle;
                # convert() returns a fully loaded image that is independent
                # of the file.
                with Image.open(path) as raw:
                    image = raw.convert('RGB')
            except Exception:
                # Deliberately best-effort: skip the bad file. Catching
                # Exception (not a bare except) lets KeyboardInterrupt and
                # SystemExit propagate.
                print('Error loading image:', path)
                idx = (idx + 1) % n_files
                continue
            return image, self.labels[idx]
        raise RuntimeError(
            'None of the %d image files could be loaded' % n_files
        )

    def __getitem__(self, idx):
        """Load one sample and apply the configured processing steps.

        Steps run in this order: resize, centre square crop, random square
        crop, random horizontal flip, label dropping.

        Returns:
            The RGB image alone when ``image_only`` is set, otherwise a dict
            with keys ``'image'`` and ``'class_labels'``.
        """
        image, label = self._load_rgb(idx)

        if self.resize is not None:
            size = self.resize
            if isinstance(size, int):
                # Scale so that the shorter side becomes `size`.
                w, h = image.size
                if w < h:
                    size = (size, int(h / w * size))
                else:
                    size = (int(w / h * size), size)
            image = image.resize(size, Image.LANCZOS)

        if self.square_crop:
            w, h = image.size
            side = min(w, h)
            left, upper = (w - side) // 2, (h - side) // 2
            image = image.crop((left, upper, left + side, upper + side))

        if self.rand_crop is not None:
            # NOTE: random.randint raises ValueError when the image is
            # smaller than the crop; callers must resize accordingly.
            w, h = image.size
            left = random.randint(0, w - self.rand_crop)
            upper = random.randint(0, h - self.rand_crop)
            image = image.crop(
                (left, upper, left + self.rand_crop, upper + self.rand_crop)
            )

        if self.rand_flip is not None:
            image = self.rand_flip(image)

        if self.drop_label_p > 0.0 and random.random() < self.drop_label_p:
            # One index past the real classes marks a dropped label.
            label = self.n_classes

        if self.image_only:
            return image
        return {
            'image': image,
            'class_labels': label,
        }