|
|
"""This module defines the file-list-driven custom dataset (FlistDataset) used by Deep3DFaceRecon_pytorch.
"""
|
|
import json |
|
|
import os.path |
|
|
import pickle |
|
|
import random |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import util.util as util |
|
|
from data.base_dataset import apply_img_affine |
|
|
from data.base_dataset import apply_lm_affine |
|
|
from data.base_dataset import BaseDataset |
|
|
from data.base_dataset import get_affine_mat |
|
|
from data.base_dataset import get_transform |
|
|
from data.image_folder import make_dataset |
|
|
from PIL import Image |
|
|
from scipy.io import loadmat |
|
|
from scipy.io import savemat |
|
|
from util.load_mats import load_lm3d |
|
|
from util.preprocess import align_img |
|
|
from util.preprocess import estimate_norm |
|
|
|
|
|
|
|
|
def default_flist_reader(flist):
    """Read a plain-text file list and return its non-blank entries.

    Each line is stripped of surrounding whitespace and returned as a single
    string. Lines are NOT split into "path label" fields, so anything after the
    path on a line stays part of the returned entry. Blank (or whitespace-only)
    lines are skipped; otherwise they would yield "" entries that callers turn
    into bogus paths.

    Parameters:
        flist (str) -- path to the list file

    Returns:
        list[str] -- the stripped, non-empty lines in file order
    """
    with open(flist, "r") as rf:
        stripped = (line.strip() for line in rf)
        return [entry for entry in stripped if entry]
|
|
|
|
|
|
|
|
def jason_flist_reader(flist):
    """Parse the JSON file at ``flist`` and return the decoded object.

    Parameters:
        flist (str) -- path to a JSON file

    Returns:
        whatever top-level JSON value the file holds (dict, list, ...)
    """
    with open(flist, "r") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)
|
|
|
|
|
|
|
|
def parse_label(label):
    """Convert an array-like ``label`` into a new float32 torch tensor.

    The result is always a fresh copy (``torch.tensor`` copies its input), so
    later changes to ``label`` do not affect the returned tensor.
    """
    as_float32 = np.asarray(label, dtype=np.float32)
    return torch.tensor(as_float32)
|
|
|
|
|
|
|
|
class FlistDataset(BaseDataset):
    """Dataset driven by a plain-text list of mask image paths.

    ``opt.flist`` names a text file with one mask path per line; each entry is
    joined with ``opt.data_root``. For every mask path, the matching RGB image
    and landmark ``.txt`` file are located by rewriting the path (see
    ``_derive_paths``). Image, mask and landmarks are then aligned with
    ``align_img`` and optionally augmented.
    """

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
                Read here: bfm_folder, flist, data_root, isTrain.
                Read in __getitem__: use_aug, isTrain.
        """
        BaseDataset.__init__(self, opt)

        # Reference landmarks passed to align_img as its alignment target
        # (presumably the standard 3D landmarks of the BFM model -- confirm in util.load_mats).
        self.lm3d_std = load_lm3d(opt.bfm_folder)

        msk_names = default_flist_reader(opt.flist)
        self.msk_paths = [os.path.join(opt.data_root, name) for name in msk_names]

        self.size = len(self.msk_paths)
        self.opt = opt

        # Split name: "train"/"val", plus "_<prefix>" when the list file's
        # basename looks like "<prefix>_....txt". The underscore test is done on
        # the basename so an underscore in a parent directory does not leak the
        # whole filename into the name.
        self.name = "train" if opt.isTrain else "val"
        flist_basename = os.path.basename(opt.flist)
        if "_" in flist_basename:
            self.name += "_" + flist_basename.split("_")[0]

    @staticmethod
    def _derive_paths(msk_path):
        """Return ``(img_path, lm_path)`` for a given mask path.

        - img_path: ``msk_path`` with every "mask/" occurrence removed.
        - lm_path: ``msk_path`` with every "mask" occurrence replaced by
          "landmarks" and the file extension replaced by ".txt".

        NOTE(review): both are plain substring replacements, so "mask" anywhere
        else in the path (e.g. a parent directory such as "masks/") is rewritten too.
        """
        img_path = msk_path.replace("mask/", "")
        # splitext only strips the final extension of the file name; it is not
        # confused by dots in directory names or by extension-less files.
        lm_stem, _ = os.path.splitext(msk_path.replace("mask", "landmarks"))
        return img_path, lm_stem + ".txt"

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index (int) -- a random integer for data indexing; wrapped modulo the dataset size

        Returns a dictionary with keys:
            imgs (tensor)    -- the aligned (and possibly augmented) RGB image, after get_transform()
            lms (tensor)     -- its landmarks as float32, loaded from the .txt landmark file
                                (presumably 2D image coordinates -- confirm)
            msks (tensor)    -- first channel of the transformed attention mask
            M (tensor)       -- float32 matrix returned by estimate_norm(landmarks, image height)
            im_paths (str)   -- the image path
            aug_flag (bool)  -- whether augmentation was applied (opt.use_aug and opt.isTrain)
            dataset (str)    -- the split name computed in __init__
        """
        msk_path = self.msk_paths[index % self.size]
        img_path, lm_path = self._derive_paths(msk_path)

        # Context managers release the file handles promptly; convert() returns
        # a new, fully loaded image that does not depend on the open file.
        with Image.open(img_path) as src_img:
            raw_img = src_img.convert("RGB")
        with Image.open(msk_path) as src_msk:
            raw_msk = src_msk.convert("RGB")
        raw_lm = np.loadtxt(lm_path).astype(np.float32)

        _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk)

        aug_flag = self.opt.use_aug and self.opt.isTrain
        if aug_flag:
            img, lm, msk = self._augmentation(img, lm, self.opt, msk)

        _, H = img.size  # PIL size is (width, height)
        M = estimate_norm(lm, H)
        transform = get_transform()
        img_tensor = transform(img)
        msk_tensor = transform(msk)[:1, ...]  # keep only the first channel of the mask
        lm_tensor = parse_label(lm)
        M_tensor = parse_label(M)

        return {
            "imgs": img_tensor,
            "lms": lm_tensor,
            "msks": msk_tensor,
            "M": M_tensor,
            "im_paths": img_path,
            "aug_flag": aug_flag,
            "dataset": self.name,
        }

    def _augmentation(self, img, lm, opt, msk=None):
        """Apply one random affine augmentation to the image, landmarks and (optional) mask.

        The transform comes from get_affine_mat(opt, img.size). The inverse
        matrix warps the images; the forward matrix and the flip flag are
        applied to the landmarks. The mask is warped with bilinear resampling.

        Returns:
            (img, lm, msk) -- the transformed triple; msk stays None if not given
        """
        affine, affine_inv, flip = get_affine_mat(opt, img.size)
        img = apply_img_affine(img, affine_inv)
        lm = apply_lm_affine(lm, affine, flip, img.size)
        if msk is not None:
            msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)
        return img, lm, msk

    def __len__(self):
        """Return the total number of images in the dataset."""
        return self.size
|
|
|