| |
| |
|
|
| |
| |
| |
| |
| |
| |
|
|
| import os |
| import PIL |
| import json |
|
|
| from torchvision import datasets, transforms |
| import numpy as np |
| from torch.utils.data import Dataset |
| import random |
|
|
| from timm.data import create_transform |
| from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD |
|
|
| import os |
| from PIL import Image |
| import numpy as np |
| import torch |
| |
| |
| class CelebAMaskDataset(Dataset): |
| def __init__(self, images_path, annot_path, target_classes, img_size=(256, 256), mask_size=(32, 32), return_downsized_image=False): |
| self.images_path = images_path |
| self.annot_path = annot_path |
| self.target_classes = target_classes |
| self.img_size = img_size |
| self.mask_size = mask_size |
| self.return_downsized_image = return_downsized_image |
| |
| |
| self.image_files = sorted([f for f in os.listdir(images_path) if f.endswith('.jpg')]) |
| print(f"Total {len(self.image_files)} images.") |
|
|
| self.annot_path_dict = {} |
| |
| for folder in os.listdir(annot_path): |
| folder_path = os.path.join(annot_path, folder) |
| |
| |
| if os.path.isdir(folder_path): |
| for filename in os.listdir(folder_path): |
| file_path = os.path.join(folder_path, filename) |
| |
| |
| if os.path.isfile(file_path): |
| self.annot_path_dict[filename] = file_path |
| print(f"Total {len(self.annot_path_dict)} masks.") |
| |
| |
| self.image_transform = transforms.Compose([ |
| transforms.Resize(img_size), |
| transforms.ToTensor(), |
| transforms.Normalize(mean=0.5, std=0.5), |
| ]) |
| if self.return_downsized_image: |
| print('return image, mask, downsized image') |
| self.image_transform_downsize = transforms.Compose([ |
| transforms.Resize((32,32)), |
| transforms.ToTensor(), |
| transforms.Normalize(mean=0.5, std=0.5), |
| ]) |
| self.mask_transform = transforms.Compose([ |
| transforms.Resize(mask_size, interpolation=Image.NEAREST) |
| ]) |
|
|
| def __len__(self): |
| return len(self.image_files) |
|
|
| def __getitem__(self, idx): |
| |
| img_filename = self.image_files[idx] |
| img_path = os.path.join(self.images_path, img_filename) |
| _image = Image.open(img_path).convert('RGB') |
| |
| image = self.image_transform(_image) |
| if self.return_downsized_image: |
| downsized_image = self.image_transform_downsize(_image) |
|
|
| |
| base_filename = os.path.splitext(img_filename)[0] |
| base_filename = str(base_filename).zfill(5) |
| |
| mask_dict = dict() |
| |
| for i, cls in enumerate(self.target_classes): |
| mask_filename = f'{base_filename}_{cls}.png' |
| mask_path = self.annot_path_dict.get(mask_filename) |
| if mask_path is not None: |
| mask = Image.open(mask_path).convert('L') |
| mask = self.mask_transform(mask) |
| mask = torch.tensor(np.array(mask) > 0, dtype=torch.bool) |
| if mask.sum() > 0: |
| mask_dict[cls] = mask |
| |
| if self.return_downsized_image: |
| return image, mask_dict, downsized_image |
| else: |
| return image, mask_dict |
|
|
| def get_tiny_imagenet(full_dataset): |
| class_list_file = '/data/projects/jeongwoo/tiny-imagenet/200_wnids.txt' |
|
|
| |
| with open(class_list_file, 'r') as f: |
| selected_classes = [line.strip() for line in f.readlines()] |
|
|
| |
| class_name_to_new_label = {class_name: idx for idx, class_name in enumerate(selected_classes)} |
|
|
| |
| filtered_samples = [] |
| for img_path, label in full_dataset.samples: |
| class_name = os.path.basename(os.path.dirname(img_path)) |
| if class_name in selected_classes: |
| new_label = class_name_to_new_label[class_name] |
| filtered_samples.append((img_path, new_label)) |
|
|
| |
| full_dataset.samples = filtered_samples |
| full_dataset.targets = [label for _, label in filtered_samples] |
|
|
| return full_dataset |
|
|
|
|
| class ADE20KPatchDataset(Dataset): |
| def __init__(self, root_dir, split='training', image_transform=None, annot_transform=None, return_downsized_image=False): |
| print(f"ADE20K: {split}") |
| self.image_transform = image_transform |
| self.annot_transform = annot_transform |
| |
| self.image_dir = os.path.join(root_dir, 'images', split) |
| self.annotation_dir = os.path.join(root_dir, 'annotations', split) |
|
|
| self.image_files = sorted(os.listdir(self.image_dir)) |
| self.annotation_files = sorted(os.listdir(self.annotation_dir)) |
|
|
| self.return_downsized_image = return_downsized_image |
| if self.return_downsized_image: |
| print('return image, mask, downsized image') |
| self.image_transform_downsize = transforms.Compose([ |
| transforms.Resize((32,32)), |
| transforms.ToTensor(), |
| transforms.Normalize(mean=0.5, std=0.5), |
| ]) |
| |
| def __len__(self): |
| return len(self.image_files) |
|
|
| def __getitem__(self, idx): |
| img_path = os.path.join(self.image_dir, self.image_files[idx]) |
| ann_path = os.path.join(self.annotation_dir, self.annotation_files[idx]) |
|
|
| _image = PIL.Image.open(img_path).convert("RGB") |
| annotation = PIL.Image.open(ann_path) |
|
|
| if self.image_transform: |
| image = self.image_transform(_image) |
| if self.annot_transform: |
| annotation = self.annot_transform(annotation) |
| |
| if self.return_downsized_image: |
| downsized_image = self.image_transform_downsize(_image) |
|
|
| if self.return_downsized_image: |
| return image, annotation, downsized_image |
| else: |
| return image, annotation |
| |
| class InatDataset(datasets.ImageFolder): |
| def __init__(self, is_train,**kwargs): |
| super().__init__(**kwargs) |
| self.mode = 'train' if is_train else 'val' |
| print('<Before>') |
| print(self.samples[100]) |
| print(len(self.samples)) |
| |
| annot_path = f'{self.root}/{self.mode}2019.json' |
| with open(annot_path, 'r') as file: |
| annot = json.load(file) |
| self.samples = [] |
| for img, tgt in zip(annot['images'], annot['annotations']): |
| self.samples.append([f"{self.root}/{img['file_name']}", tgt['category_id']]) |
| print('<After>') |
| print(self.samples[100]) |
| print(len(self.samples)) |
| |
| def __getitem__(self, index: int): |
| |
| path, target = self.samples[index] |
| sample = self.loader(path) |
| if self.transform is not None: |
| sample = self.transform(sample) |
| |
| return sample, target |
|
|
| class ImgDataset(datasets.ImageFolder): |
| def __init__(self, **kwargs): |
| super().__init__(**kwargs) |
| self.num_retries = 10 |
| |
| def __getitem__(self, index: int): |
| """ |
| Args: |
| index (int): Index |
| |
| Returns: |
| tuple: (sample, target) where target is class_index of the target class. |
| """ |
| failed = [] |
| for _ in range(self.num_retries): |
| path, target = self.samples[index] |
| try: |
| sample = self.loader(path) |
| except: |
| try: |
| sample = self.loader(path) |
| except: |
| failed.append(path) |
| index = random.randint(0, len(self.samples) - 1) |
| continue |
| if self.transform is not None: |
| sample = self.transform(sample) |
| if self.target_transform is not None: |
| target = self.target_transform(target) |
| |
| return sample, target |
| else: |
| print('Failed to load {} after {} retries'.format( |
| failed, self.num_retries |
| )) |
| |
| def build_dataset(is_train, args): |
| transform = build_transform(is_train, args) |
|
|
| if args.dataset is not None: |
| print(f'Downstream task with {args.dataset}') |
| if args.dataset.lower() == 'inat': |
| dataset = InatDataset(is_train = is_train, root=args.data_path, transform=transform) |
| elif args.dataset.lower() == 'cifar100': |
| from torchvision.datasets import CIFAR100 |
| dataset = CIFAR100(train = is_train, root=args.data_path, transform=transform) |
| elif args.dataset.lower() == 'cub': |
| from torchvision.datasets import ImageFolder |
| if is_train: |
| dataset = ImageFolder(root=args.data_path+'/train', transform=transform) |
| else: |
| dataset = ImageFolder(root=args.data_path+'/test', transform=transform) |
|
|
| else: |
| root = os.path.join(args.data_path, 'train' if is_train else 'val') |
| dataset = ImgDataset(root=root, transform=transform) |
|
|
| print(dataset) |
|
|
| return dataset |
|
|
|
|
| def build_transform(is_train, args): |
| mean = IMAGENET_DEFAULT_MEAN |
| std = IMAGENET_DEFAULT_STD |
| |
| if is_train: |
| |
| transform = create_transform( |
| input_size=args.input_size, |
| is_training=True, |
| color_jitter=args.color_jitter, |
| auto_augment=args.aa, |
| interpolation='bicubic', |
| re_prob=args.reprob, |
| re_mode=args.remode, |
| re_count=args.recount, |
| mean=mean, |
| std=std, |
| ) |
| if args.dataset is not None: |
| transform.transforms.insert(0, transforms.Lambda(lambda img: img.convert('RGB'))) |
|
|
| return transform |
|
|
| |
| t = [] |
| if args.dataset is not None: |
| t.append(transforms.Lambda(lambda image: image.convert("RGB"))) |
| if args.input_size <= 224: |
| crop_pct = 224 / 256 |
| else: |
| crop_pct = 1.0 |
| size = int(args.input_size / crop_pct) |
| t.append( |
| transforms.Resize(size, interpolation=PIL.Image.BICUBIC), |
| ) |
| t.append(transforms.CenterCrop(args.input_size)) |
|
|
| t.append(transforms.ToTensor()) |
| t.append(transforms.Normalize(mean, std)) |
| return transforms.Compose(t) |
|
|