"""Batch-convert processed character samples into layered PSD files.

Every sub-directory of ``--srcd`` is treated as one sample and goes through
three stages:

1. ``talking_head.preprocess.further_extr`` (with ``rotate=False``).
2. Key-point detection and face-part separation with ``smug``; ``pose.json``
   and the separator output are written to ``<sample>/face_parsing``.
3. PSD export: the parts listed in ``<sample>/optimized/info.json`` (sorted by
   descending median depth), followed by the eye layers from
   ``face_parsing``, are saved to ``<saved>/<sample>.psd``.
"""
import argparse
import os
import os.path as osp
import sys

# Cap the BLAS / OpenMP thread pools. These are set before numpy is imported
# below, since those runtimes typically read the variables at load time.
default_n_threads = 8
os.environ['OPENBLAS_NUM_THREADS'] = f"{default_n_threads}"
os.environ['MKL_NUM_THREADS'] = f"{default_n_threads}"
os.environ['OMP_NUM_THREADS'] = f"{default_n_threads}"

# Put the parent of this file's directory on the import path, presumably so
# the project packages imported below (``utils``) resolve when run as a script.
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))

import numpy as np
from tqdm import tqdm
from PIL import Image

from utils.io_utils import json2dict, dict2json, save_psd, load_img_depth
from utils.torch_utils import seed_everything

# Eye layers taken from the face-parsing output. Each name is both the stem of
# a ``<name>.png`` file in ``face_parsing`` and a key in its ``parts.json``.
eye_mesh_list = [
    '1-2-3-2-2+eyebgs-l', '1-2-3-3-2+irides-l', '1-2-3-1-2+eyelashs-l',
    '1-2-3-2-1+eyebgs-r', '1-2-3-3-1+irides-r', '1-2-3-1-1+eyelashs-r',
]


def _list_samples(srcd):
    """Return the sample directories directly under *srcd*, sorted by name.

    Plain files (e.g. ``.DS_Store`` or logs) are skipped: every stage treats
    an entry as a directory, and a stray file would otherwise abort the run.
    Sorting gives a deterministic processing order (``os.listdir`` does not).
    """
    return [
        osp.join(srcd, name)
        for name in sorted(os.listdir(srcd))
        if osp.isdir(osp.join(srcd, name))
    ]


def _run_preprocess(srcd):
    """Stage 1: run ``further_extr`` (without rotation) on every sample."""
    from talking_head.preprocess import further_extr

    for srcp in tqdm(_list_samples(srcd)):
        further_extr(srcp, rotate=False)


def _run_face_parsing(srcd):
    """Stage 2: detect key points and separate face parts for every sample.

    Writes ``pose.json`` plus the separator output into
    ``<sample>/face_parsing``. Failures for a single sample are printed and
    the batch continues.
    """
    # NOTE(review): ``set_logger`` is unused; the import is kept only in case
    # importing ``smug.utils.logging`` has side effects. Drop it if it has none.
    from smug.utils.logging import set_logger  # noqa: F401
    from smug import KeyPointsController, IRIAMSeparationController

    controller_kp = KeyPointsController()
    controller_sep = IRIAMSeparationController()

    for srcp in tqdm(_list_samples(srcd)):
        faceparsing_output = osp.join(srcp, 'face_parsing')
        os.makedirs(faceparsing_output, exist_ok=True)
        # pose.json alone (written before separation runs) does not count as
        # done, so a sample whose separation step failed is retried.
        if len(os.listdir(faceparsing_output)) > 1:
            continue
        try:
            # Decode only samples that still need work, and inside the try so
            # a missing/corrupt image is reported like any other failure.
            with Image.open(osp.join(srcp, 'src_img.png')) as im:
                src_img = np.array(im.convert('RGB'))
            result = controller_kp.run(src_img)
            landmarks = result.face_landmarks
            landmark_dict = {
                'pose': result.pose,
                'face_landmarks': {
                    'left': landmarks.left,
                    'top': landmarks.top,
                    'points': landmarks.points,
                },
            }
            dict2json(landmark_dict, osp.join(faceparsing_output, 'pose.json'))
            controller_sep.run(src_img, {'output_folder': faceparsing_output})
        except Exception as e:  # best-effort per sample; keep the batch going
            print(f'failed to process {srcp}: {e}')


def _load_eye_layers(faceparsing_output):
    """Collect the ``eye_mesh_list`` layers for one sample.

    Returns ``None`` when ``parts.json`` is absent. Otherwise returns a list of
    ``{'img', 'xyxy', 'layer_name'}`` dicts in ``eye_mesh_list`` order, where
    ``xyxy`` is ``[x, y, x + w, y + h]`` from that layer's ``parts.json`` entry.
    Collection stops at the first missing ``<name>.png``.
    """
    parts_path = osp.join(faceparsing_output, 'parts.json')
    if not osp.exists(parts_path):
        return None

    face_parts = json2dict(parts_path)
    eye_layers = []
    for name in eye_mesh_list:
        layer_path = osp.join(faceparsing_output, name) + '.png'
        if not osp.exists(layer_path):
            # NOTE(review): the layers collected before the missing file are
            # still exported (a partial eye set). Confirm this is intended
            # rather than dropping the eye layers or skipping the sample.
            break
        with Image.open(layer_path) as layer_img:
            img = np.array(layer_img)
        box = face_parts[name]
        x, y, w, h = box['x'], box['y'], box['w'], box['h']
        eye_layers.append({'img': img, 'xyxy': [x, y, x + w, y + h], 'layer_name': name})
    return eye_layers


def _load_body_layers(op_dir):
    """Load the parts described by ``<op_dir>/info.json`` as PSD layers.

    Parts lacking ``xyxy`` get a box covering their whole image; parts lacking
    ``depth_median`` get the median of ``depth`` over pixels with
    ``mask > 127``. Layers are returned sorted by descending ``depth_median``.
    """
    src_info = json2dict(osp.join(op_dir, 'info.json'))
    # Presumably attaches the 'img', 'mask' and 'depth' arrays read below
    # (they cannot come from JSON); confirm against utils.io_utils.
    load_img_depth(op_dir, src_info)

    layers = []
    for tag, part in src_info['parts'].items():
        if 'xyxy' not in part:
            height, width = part['img'].shape[:2]
            part['xyxy'] = [0, 0, width, height]
        if 'depth_median' not in part:
            # NOTE(review): a mask with no pixel > 127 makes this NaN (with a
            # RuntimeWarning), which leaves the sort order below undefined.
            part['depth_median'] = np.median(part['depth'][part['mask'] > 127])
        part['tag'] = tag
        layers.append({
            'img': part['img'],
            'xyxy': part['xyxy'],
            'layer_name': tag,
            'depth_median': part['depth_median'],
        })
    layers.sort(key=lambda layer: layer['depth_median'], reverse=True)
    return layers


def _export_psds(srcd, saved):
    """Stage 3: write ``<saved>/<sample>.psd`` for every sample.

    Body layers (depth-sorted) come first, followed by the eye layers when
    ``face_parsing/parts.json`` exists.
    """
    for srcp in tqdm(_list_samples(srcd)):
        eye_layers = _load_eye_layers(osp.join(srcp, 'face_parsing'))
        layers = _load_body_layers(osp.join(srcp, 'optimized'))
        if eye_layers is not None:
            layers += eye_layers
        psd_savep = osp.join(saved, osp.basename(srcp)) + '.psd'
        # Fixed 1024 x 1024 canvas; argument order is presumably
        # (width, height) — confirm against save_psd.
        save_psd(psd_savep, layers, 1024, 1024)


def main():
    """Parse the CLI and run the three stages in order."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--saved', type=str, default='workspace/datasets/tmp_output',
                        help='output directory for the generated PSD files')
    parser.add_argument('--srcd', type=str, default='workspace/datasets/testcaseall_output',
                        help='directory holding one sub-directory per sample')
    args = parser.parse_args()

    os.makedirs(args.saved, exist_ok=True)
    seed_everything(0)

    _run_preprocess(args.srcd)
    _run_face_parsing(args.srcd)
    _export_psds(args.srcd, args.saved)


if __name__ == '__main__':
    main()