import glob
import os
import pickle
import random

import numpy as np
import torch
import torchvision
from PIL import Image
from torch.utils.data.dataset import Dataset
from tqdm import tqdm
|
|
|
|
class CelebDataset(Dataset):
    r"""
    Celeb dataset will by default centre crop and resize the images.
    This can be replaced by any other dataset. As long as all the images
    are under one directory.

    Optional conditioning inputs served alongside each sample:
      * 'text'  -> one caption (str) per sample, read from
                   ``im_path/celeba-caption/<image_name>.txt``
      * 'image' -> a one-hot class-mask tensor, read from
                   ``im_path/CelebAMask-HQ-mask/<image_name>.png``

    When ``use_latents`` is True and pickled latent maps covering every
    image are found under ``latent_path``, pre-computed latents are
    returned instead of pixel tensors.
    """

    def __init__(self, split, im_path, im_size=256, im_channels=3, im_ext='jpg',
                 use_latents=False, latent_path=None, condition_config=None):
        r"""
        :param split: dataset split name (stored only; no filtering happens here)
        :param im_path: root directory containing 'CelebA-HQ-img' (plus the
            caption / mask folders when conditioning is enabled)
        :param im_size: target side length for Resize + CenterCrop
        :param im_channels: number of image channels (stored only)
        :param im_ext: preferred image extension (stored only; loading scans
            png/jpg/jpeg regardless)
        :param use_latents: if True, attempt to serve pre-computed latents
        :param latent_path: directory holding pickled latent dictionaries
        :param condition_config: dict with key 'condition_types' and, for
            image conditioning, an 'image_condition_config' sub-dict
        """
        self.split = split
        self.im_size = im_size
        self.im_channels = im_channels
        self.im_ext = im_ext
        self.im_path = im_path
        self.latent_maps = None
        self.use_latents = False

        self.condition_types = [] if condition_config is None else condition_config['condition_types']

        self.idx_to_cls_map = {}
        self.cls_to_idx_map = {}

        if 'image' in self.condition_types:
            self.mask_channels = condition_config['image_condition_config']['image_condition_input_channels']
            self.mask_h = condition_config['image_condition_config']['image_condition_h']
            self.mask_w = condition_config['image_condition_config']['image_condition_w']

        self.images, self.texts, self.masks = self.load_images(im_path)

        # Fix: honour the ``use_latents`` flag — previously it was discarded
        # (``self.use_latents`` stayed False), making the latent branch of
        # __getitem__ unreachable and ``latent_path`` dead.
        if use_latents and latent_path is not None:
            latent_maps = self._load_latents(latent_path)
            if len(latent_maps) == len(self.images):
                self.use_latents = True
                self.latent_maps = latent_maps
                print('Found {} latents'.format(len(self.latent_maps)))
            else:
                # Fall back to decoding pixels rather than crashing later on
                # a missing latent key.
                print('Latents not found or incomplete, using images instead')

    def _load_latents(self, latent_path):
        r"""
        Merge every pickled latent dictionary found directly under
        ``latent_path`` into one dict.

        NOTE(review): assumes each ``*.pkl`` file holds a dict keyed by image
        path, matching the ``self.latent_maps[self.images[index]]`` lookup in
        __getitem__ — confirm against the latent-generation script.

        :param latent_path: directory containing ``*.pkl`` latent files
        :return: dict mapping image path -> latent
        """
        latent_maps = {}
        for fpath in glob.glob(os.path.join(latent_path, '*.pkl')):
            with open(fpath, 'rb') as f:
                latent_maps.update(pickle.load(f))
        return latent_maps

    def load_images(self, im_path):
        r"""
        Gets all images from the path specified
        and stacks them all up.

        Also gathers the per-image caption lists and mask paths when the
        corresponding condition types are enabled, asserting that every
        image has its conditioning counterpart.

        :param im_path: dataset root directory
        :return: (image paths, caption lists, mask paths)
        """
        assert os.path.exists(
            im_path), "images path {} does not exist".format(im_path)
        ims = []
        # Extensions are scanned explicitly (png/jpg/jpeg) regardless of
        # the ``im_ext`` constructor argument.
        fnames = glob.glob(os.path.join(
            im_path, 'CelebA-HQ-img/*.{}'.format('png')))
        fnames += glob.glob(os.path.join(im_path,
                                         'CelebA-HQ-img/*.{}'.format('jpg')))
        fnames += glob.glob(os.path.join(im_path,
                                         'CelebA-HQ-img/*.{}'.format('jpeg')))
        texts = []
        masks = []

        if 'image' in self.condition_types:
            # Fixed CelebAMask-HQ label order; mask channel k corresponds to
            # label_list[k].
            label_list = ['skin', 'nose', 'eye_g', 'l_eye', 'r_eye', 'l_brow', 'r_brow', 'l_ear', 'r_ear', 'mouth',
                          'u_lip', 'l_lip', 'hair', 'hat', 'ear_r', 'neck_l', 'neck', 'cloth']
            self.idx_to_cls_map = {idx: label_list[idx]
                                   for idx in range(len(label_list))}
            self.cls_to_idx_map = {
                label_list[idx]: idx for idx in range(len(label_list))}

        for fname in tqdm(fnames):
            ims.append(fname)

            if 'text' in self.condition_types:
                # Caption file shares the image's base name, one caption per line.
                im_name = os.path.split(fname)[1].split('.')[0]
                captions_im = []
                with open(os.path.join(im_path, 'celeba-caption/{}.txt'.format(im_name))) as f:
                    for line in f.readlines():
                        captions_im.append(line.strip())
                texts.append(captions_im)

            if 'image' in self.condition_types:
                # int() strips any leading zeros so the mask filename matches.
                im_name = int(os.path.split(fname)[1].split('.')[0])
                masks.append(os.path.join(
                    im_path, 'CelebAMask-HQ-mask', '{}.png'.format(im_name)))
        if 'text' in self.condition_types:
            assert len(texts) == len(
                ims), "Condition Type Text but could not find captions for all images"
        if 'image' in self.condition_types:
            assert len(masks) == len(
                ims), "Condition Type Image but could not find masks for all images"
        print('Found {} images'.format(len(ims)))
        print('Found {} masks'.format(len(masks)))
        print('Found {} captions'.format(len(texts)))
        return ims, texts, masks

    def get_mask(self, index):
        r"""
        Method to get the mask of WxH
        for given index and convert it into
        Classes x W x H mask image.

        Mask pixel value ``k + 1`` activates channel ``k``; pixel value 0
        (background) leaves all channels zero.

        :param index: sample index
        :return: float tensor of shape (mask_channels, mask_h, mask_w)
        """
        mask_im = Image.open(self.masks[index])
        mask_im = np.array(mask_im)
        im_base = np.zeros((self.mask_h, self.mask_w, self.mask_channels))
        for orig_idx in range(len(self.idx_to_cls_map)):
            # Shift by one because raw mask value 0 is background.
            im_base[mask_im == (orig_idx + 1), orig_idx] = 1
        mask = torch.from_numpy(im_base).permute(2, 0, 1).float()
        return mask

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        r"""
        Return ``latent`` / ``(latent, cond_inputs)`` when latents are in use,
        otherwise ``im_tensor`` / ``(im_tensor, cond_inputs)``.
        Image tensors are scaled to [-1, 1].
        """
        cond_inputs = {}
        if 'text' in self.condition_types:
            # Each image has several captions; sample one uniformly.
            cond_inputs['text'] = random.choice(self.texts[index])
        if 'image' in self.condition_types:
            cond_inputs['image'] = self.get_mask(index)

        if self.use_latents:
            latent = self.latent_maps[self.images[index]]
            if len(self.condition_types) == 0:
                return latent
            return latent, cond_inputs

        # ``with`` guarantees the file handle is released even if a transform
        # raises; convert('RGB') keeps grayscale/RGBA files from breaking the
        # 3-channel pipeline.
        with Image.open(self.images[index]) as im:
            im_tensor = torchvision.transforms.Compose([
                torchvision.transforms.Resize(self.im_size),
                torchvision.transforms.CenterCrop(self.im_size),
                torchvision.transforms.ToTensor(),
            ])(im.convert('RGB'))

        # Map from [0, 1] to [-1, 1]
        im_tensor = (2 * im_tensor) - 1
        if len(self.condition_types) == 0:
            return im_tensor
        return im_tensor, cond_inputs
|
|