| | """This script is the test script for Deep3DFaceRecon_pytorch |
| | Pytorch Deep3D_Recon is 8x faster than TF-based, 16s/iter ==> 2s/iter |
| | """ |
| |
|
| | import os |
| | |
| | import torch |
| | import torch.nn as nn |
| | from .deep_3drecon_models.facerecon_model import FaceReconModel |
| | from .util.preprocess import align_img |
| | from PIL import Image |
| | import numpy as np |
| | from .util.load_mats import load_lm3d |
| | import torch |
| | import pickle as pkl |
| | from PIL import Image |
| |
|
| | from utils.commons.tensor_utils import convert_to_tensor, convert_to_np |
| |
|
# Load the pickled option object that configures FaceReconModel and supplies
# `bfm_folder` to load_lm3d in Reconstructor.__init__.
#
# The path is first tried relative to the current working directory (the
# historical behaviour). If no such file exists there, fall back to a file of
# the same name next to this module, so importing the module does not depend
# on the directory Python was started from.
# NOTE(review): the fallback assumes the pickle ships alongside this module -
# confirm against the repository layout.
# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only use a trusted reconstructor_opt.pkl.
_opt_path = "deep_3drecon/reconstructor_opt.pkl"
if not os.path.exists(_opt_path):
    _opt_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "reconstructor_opt.pkl"
    )
with open(_opt_path, "rb") as f:
    opt = pkl.load(f)
| | |
class Reconstructor(nn.Module):
    """Inference wrapper that aligns face images and runs ``FaceReconModel``.

    The wrapped model is built from the module-level ``opt`` object. Each
    input image is aligned with ``align_img`` against the standard landmark
    set loaded by ``load_lm3d(opt.bfm_folder)`` and then fed to the model,
    whose ``output_coeff`` attribute is returned as a NumPy array.
    """

    def __init__(self, device='cuda:0'):
        """Build the model, place it on ``device`` and switch to eval mode.

        Args:
            device: Value assigned to ``self.model.device`` before
                ``parallelize()`` is called. Defaults to ``'cuda:0'``, the
                value that used to be hard-coded.
        """
        super().__init__()
        self.model = FaceReconModel(opt)
        self.model.setup(opt)
        self.model.device = device
        self.model.parallelize()

        self.model.eval()
        self.lm3d_std = load_lm3d(opt.bfm_folder)

    @staticmethod
    def _flip_landmarks_y(lm, height):
        """Return a copy of ``lm`` shaped (-1, 2) with y mapped to ``height - 1 - y``."""
        lm = lm.reshape([-1, 2])
        # reshape() usually returns a view; copy so that the in-place flip
        # below never writes through to the caller's landmark array.
        lm = lm.clone() if isinstance(lm, torch.Tensor) else lm.copy()
        lm[:, -1] = height - 1 - lm[:, -1]
        return lm

    def preprocess_data(self, im, lm, lm3d_std):
        """Align one image and its landmarks for the reconstruction model.

        Args:
            im: Image with a three-dimensional shape whose first axis is the
                height (``H, W, C`` unpacking); converted with
                ``convert_to_np`` and wrapped in a PIL image.
            lm: Landmark coordinates, reshaped to ``(-1, 2)``. The last
                column is flipped to ``H - 1 - y`` on a copy; the argument
                itself is left untouched.
                NOTE(review): presumably the flip converts top-origin rows to
                the bottom-origin convention ``align_img`` expects - confirm.
            lm3d_std: Reference landmarks forwarded to ``align_img``.

        Returns:
            tuple: ``(im, lm)`` where ``im`` is a float32 tensor of the
            aligned image scaled by 1/255, channels first, with a leading
            batch axis of size 1, and ``lm`` is a tensor of the aligned
            landmarks with a leading batch axis of size 1.
        """
        # Three-way unpacking keeps the original rank check on `im`.
        H, _, _ = im.shape
        lm = self._flip_landmarks_y(lm, H)

        _, im, lm, _ = align_img(
            Image.fromarray(convert_to_np(im)),
            convert_to_np(lm),
            convert_to_np(lm3d_std),
        )
        im = torch.tensor(np.array(im) / 255., dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
        lm = torch.tensor(lm).unsqueeze(0)
        return im, lm

    @torch.no_grad()
    def recon_coeff(self, batched_images, batched_lm5, return_image=True, batch_mode=True):
        """Predict reconstruction coefficients for a batch of images.

        Args:
            batched_images: Indexable batch of images; ``shape[0]`` is the
                batch size and every item is passed to ``preprocess_data``.
            batched_lm5: Indexable batch of landmarks, one entry per image.
            return_image: Accepted for interface compatibility; currently
                ignored - the aligned images are always returned.
            batch_mode: If true, run the model once on the concatenated
                batch; otherwise run it once per sample.

        Returns:
            tuple: ``(batch_coeff, batch_align_img)``. ``batch_coeff`` is the
            model's ``output_coeff`` as a NumPy array (per-sample outputs are
            concatenated along axis 0 when ``batch_mode`` is false).
            ``batch_align_img`` is a ``uint8`` array of the aligned images in
            channels-last layout with a leading batch axis.
        """
        bs = batched_images.shape[0]
        data_lst = []
        for i in range(bs):
            align_im, lm = self.preprocess_data(batched_images[i], batched_lm5[i], self.lm3d_std)
            data_lst.append({'imgs': align_im, 'lms': lm})

        if not batch_mode:
            coeff_lst = []
            align_lst = []
            for data in data_lst:
                # Grab this sample's aligned image before the model sees the
                # dict, so the conversion below always uses the CPU tensor
                # produced by preprocess_data for *this* sample.
                sample_img = data['imgs']
                self.model.set_input(data)
                self.model.forward()
                coeff_lst.append(self.model.output_coeff.cpu().numpy())
                align_lst.append(
                    (sample_img.squeeze(0).permute(1, 2, 0) * 255).int().numpy().astype(np.uint8)
                )
            batch_coeff = np.concatenate(coeff_lst)
            batch_align_img = np.stack(align_lst)
        else:
            imgs = torch.cat([d['imgs'] for d in data_lst])
            lms = torch.cat([d['lms'] for d in data_lst])
            self.model.set_input({'imgs': imgs, 'lms': lms})
            self.model.forward()
            batch_coeff = self.model.output_coeff.cpu().numpy()
            batch_align_img = (imgs.permute(0, 2, 3, 1) * 255).int().numpy().astype(np.uint8)
        return batch_coeff, batch_align_img

    def forward(self, batched_images, batched_lm5, return_image=True):
        """Module entry point; delegates to ``recon_coeff`` in batch mode."""
        return self.recon_coeff(batched_images, batched_lm5, return_image)
| |
|
| |
|
| | |
| |
|