| | """This script defines the custom dataset for Deep3DFaceRecon_pytorch |
| | """ |
| |
|
| | import os.path |
| | from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine |
| | from data.image_folder import make_dataset |
| | from PIL import Image |
| | import random |
| | import util.util as util |
| | import numpy as np |
| | import json |
| | import torch |
| | from scipy.io import loadmat, savemat |
| | import pickle |
| | from util.preprocess import align_img, estimate_norm |
| | from util.load_mats import load_lm3d |
| |
|
| |
|
def default_flist_reader(flist):
    """Read a file list and return one stripped entry per line.

    flist format: impath label\nimpath label\n ...(same to caffe's filelist)

    Parameters:
        flist (str) -- path to the list file

    Returns:
        list of str -- the lines of the file with surrounding whitespace
        removed (blank lines become empty strings, as before)
    """
    # Iterate the file object directly instead of readlines() so the whole
    # file is not materialized in memory; the resulting list is identical.
    with open(flist, 'r') as rf:
        return [line.strip() for line in rf]
| |
|
def jason_flist_reader(flist):
    """Load a file list stored as JSON and return the parsed object.

    (Name kept as-is — including the 'jason' spelling — for caller
    compatibility.)
    """
    with open(flist, 'r') as handle:
        return json.load(handle)
| |
|
def parse_label(label):
    """Convert array-like label data to a float32 torch tensor."""
    # np.array(..., dtype=...) makes a fresh float32 copy, so wrapping it
    # with from_numpy cannot alias the caller's data.
    buffer = np.array(label, dtype=np.float32)
    return torch.from_numpy(buffer)
| |
|
| |
|
class FlistDataset(BaseDataset):
    """Dataset that serves (image, mask, landmark) triplets listed in a flist.

    It requires one directory to host training images '/path/to/data/train'.
    You can train the model with the dataset flag '--dataroot /path/to/data'.
    """

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseDataset.__init__(self, opt)

        # Standard 3D landmark template used to align every face.
        self.lm3d_std = load_lm3d(opt.bfm_folder)

        # The flist holds mask paths relative to the data root; image and
        # landmark paths are derived from them in __getitem__.
        relative_names = default_flist_reader(opt.flist)
        self.msk_paths = [os.path.join(opt.data_root, name) for name in relative_names]

        self.size = len(self.msk_paths)
        self.opt = opt

        split_name = 'train' if opt.isTrain else 'val'
        if '_' in opt.flist:
            # Tag the split with the flist's filename prefix (before its first '_').
            flist_basename = opt.flist.split(os.sep)[-1]
            split_name = split_name + '_' + flist_basename.split('_')[0]
        self.name = split_name

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index (int) -- a random integer for data indexing

        Returns a dictionary that contains:
            imgs (tensor)   -- an image in the input domain
            msks (tensor)   -- its corresponding attention mask (first channel only)
            lms (tensor)    -- its corresponding landmarks
            M (tensor)      -- normalization matrix estimated from the landmarks
            im_paths (str)  -- image path
            aug_flag (bool) -- whether augmentation was applied
            dataset (str)   -- split name of this dataset
        """
        msk_path = self.msk_paths[index % self.size]
        # Image lives next to the mask, minus the 'mask/' directory component.
        image_path = msk_path.replace('mask/', '')
        # Landmarks: swap 'mask' for 'landmarks' and replace the extension with .txt.
        lm_stem = msk_path.replace('mask', 'landmarks')
        lm_path = '.'.join(lm_stem.split('.')[:-1]) + '.txt'

        raw_image = Image.open(image_path).convert('RGB')
        raw_mask = Image.open(msk_path).convert('RGB')
        raw_landmarks = np.loadtxt(lm_path).astype(np.float32)

        # Align image, landmarks and mask against the standard landmark template.
        _, face_img, face_lm, face_msk = align_img(
            raw_image, raw_landmarks, self.lm3d_std, raw_mask)

        aug_flag = self.opt.use_aug and self.opt.isTrain
        if aug_flag:
            face_img, face_lm, face_msk = self._augmentation(
                face_img, face_lm, self.opt, face_msk)

        _, height = face_img.size
        M = estimate_norm(face_lm, height)

        to_tensor = get_transform()
        img_tensor = to_tensor(face_img)
        msk_tensor = to_tensor(face_msk)[:1, ...]

        return {'imgs': img_tensor,
                'lms': parse_label(face_lm),
                'msks': msk_tensor,
                'M': parse_label(M),
                'im_paths': image_path,
                'aug_flag': aug_flag,
                'dataset': self.name}

    def _augmentation(self, img, lm, opt, msk=None):
        """Apply a random affine warp (and possible flip) to image, landmarks and mask."""
        affine, affine_inv, flip = get_affine_mat(opt, img.size)
        warped_img = apply_img_affine(img, affine_inv)
        # NOTE: landmarks are warped using the post-warp image size, matching
        # the original ordering of operations.
        warped_lm = apply_lm_affine(lm, affine, flip, warped_img.size)
        warped_msk = msk
        if msk is not None:
            warped_msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)
        return warped_img, warped_lm, warped_msk

    def __len__(self):
        """Return the total number of images in the dataset."""
        return self.size
| |
|