# Diffusion-Transformer / dataset.py
import glob
import os
import pickle
import random

import numpy as np
import torch
import torchvision
from PIL import Image
from torch.utils.data.dataset import Dataset
from tqdm import tqdm
class CelebDataset(Dataset):
r"""
Celeb dataset will by default centre crop and resize the images.
This can be replaced by any other dataset. As long as all the images
are under one directory.
"""
def __init__(self, split, im_path, im_size=256, im_channels=3, im_ext='jpg',
use_latents=False, latent_path=None, condition_config=None):
self.split = split
self.im_size = im_size
self.im_channels = im_channels
self.im_ext = im_ext
self.im_path = im_path
self.latent_maps = None
self.use_latents = False
self.condition_types = [] if condition_config is None else condition_config['condition_types']
self.idx_to_cls_map = {}
self.cls_to_idx_map = {}
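        # Expected condition_config layout (an illustrative sketch; only the
        # keys read below are required):
        # {
        #     'condition_types': ['text', 'image'],
        #     'image_condition_config': {
        #         'image_condition_input_channels': 18,
        #         'image_condition_h': 512,
        #         'image_condition_w': 512,
        #     }
        # }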
if 'image' in self.condition_types:
self.mask_channels = condition_config['image_condition_config']['image_condition_input_channels']
self.mask_h = condition_config['image_condition_config']['image_condition_h']
self.mask_w = condition_config['image_condition_config']['image_condition_w']
        self.images, self.texts, self.masks = self.load_images(im_path)
        # Honour the use_latents flag, which was previously accepted but never
        # used. This is a hedged sketch: it assumes latents were pre-computed
        # and saved as pickled {image_path: latent_tensor} dicts under
        # latent_path; adapt it to however your training script saves them.
        if use_latents and latent_path is not None:
            latent_maps = {}
            for fpath in glob.glob(os.path.join(latent_path, '*.pkl')):
                with open(fpath, 'rb') as f:
                    latent_maps.update(pickle.load(f))
            if len(latent_maps) > 0:
                self.use_latents = True
                self.latent_maps = latent_maps
            else:
                print('Latents not found at {}'.format(latent_path))
    def load_images(self, im_path):
        r"""
        Collects the paths of all images under the given directory, along with
        their captions and segmentation masks when conditioning is enabled.
        """
        assert os.path.exists(im_path), "images path {} does not exist".format(im_path)
        ims = []
        fnames = []
        for ext in ('png', 'jpg', 'jpeg'):
            fnames += glob.glob(os.path.join(im_path, 'CelebA-HQ-img/*.{}'.format(ext)))
texts = []
masks = []
        if 'image' in self.condition_types:
            label_list = ['skin', 'nose', 'eye_g', 'l_eye', 'r_eye', 'l_brow', 'r_brow', 'l_ear', 'r_ear',
                          'mouth', 'u_lip', 'l_lip', 'hair', 'hat', 'ear_r', 'neck_l', 'neck', 'cloth']
            self.idx_to_cls_map = {idx: label for idx, label in enumerate(label_list)}
            self.cls_to_idx_map = {label: idx for idx, label in enumerate(label_list)}
        for fname in tqdm(fnames):
            ims.append(fname)
            if 'text' in self.condition_types:
                # Each image has a matching caption file with one caption per line.
                im_name = os.path.split(fname)[1].split('.')[0]
                captions_im = []
                with open(os.path.join(im_path, 'celeba-caption/{}.txt'.format(im_name))) as f:
                    for line in f.readlines():
                        captions_im.append(line.strip())
                texts.append(captions_im)
            if 'image' in self.condition_types:
                # Masks are stored by integer image id under CelebAMask-HQ-mask.
                im_name = int(os.path.split(fname)[1].split('.')[0])
                masks.append(os.path.join(im_path, 'CelebAMask-HQ-mask', '{}.png'.format(im_name)))
        if 'text' in self.condition_types:
            assert len(texts) == len(ims), "Condition Type Text but could not find captions for all images"
        if 'image' in self.condition_types:
            assert len(masks) == len(ims), "Condition Type Image but could not find masks for all images"
        print('Found {} images'.format(len(ims)))
        print('Found {} masks'.format(len(masks)))
        print('Found {} captions'.format(len(texts)))
return ims, texts, masks
def get_mask(self, index):
r"""
Method to get the mask of WxH
for given index and convert it into
Classes x W x H mask image
:param index:
:return:
"""
mask_im = Image.open(self.masks[index])
mask_im = np.array(mask_im)
        im_base = np.zeros((self.mask_h, self.mask_w, self.mask_channels))
        for orig_idx in range(len(self.idx_to_cls_map)):
            # Pixel value 0 is background; class k of label_list is stored as k + 1.
            im_base[mask_im == (orig_idx + 1), orig_idx] = 1
mask = torch.from_numpy(im_base).permute(2, 0, 1).float()
return mask
def __len__(self):
return len(self.images)
def __getitem__(self, index):
######## Set Conditioning Info ########
cond_inputs = {}
if 'text' in self.condition_types:
            cond_inputs['text'] = random.choice(self.texts[index])
if 'image' in self.condition_types:
mask = self.get_mask(index)
cond_inputs['image'] = mask
#######################################
if self.use_latents:
latent = self.latent_maps[self.images[index]]
if len(self.condition_types) == 0:
return latent
else:
return latent, cond_inputs
else:
            # Force RGB so png inputs with an alpha channel still yield 3 channels.
            im = Image.open(self.images[index]).convert('RGB')
im_tensor = torchvision.transforms.Compose([
torchvision.transforms.Resize(self.im_size),
torchvision.transforms.CenterCrop(self.im_size),
torchvision.transforms.ToTensor(),
])(im)
im.close()
# Convert input to -1 to 1 range.
im_tensor = (2 * im_tensor) - 1
if len(self.condition_types) == 0:
return im_tensor
else:
return im_tensor, cond_inputs
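# A minimal usage sketch, assuming the standard CelebAMask-HQ layout
# (<root>/CelebA-HQ-img, <root>/celeba-caption, <root>/CelebAMask-HQ-mask).
# The path and config values below are illustrative, not part of this module.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = CelebDataset(
        split='train',
        im_path='data/CelebAMask-HQ',  # hypothetical dataset root
        im_size=256,
        condition_config={
            'condition_types': ['image'],
            'image_condition_config': {
                'image_condition_input_channels': 18,  # == len(label_list)
                'image_condition_h': 512,
                'image_condition_w': 512,
            },
        },
    )
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    im, cond = next(iter(loader))
    # e.g. torch.Size([4, 3, 256, 256]) torch.Size([4, 18, 512, 512])
    print(im.shape, cond['image'].shape)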