# (Hugging Face Spaces page header — status: Paused)
# -*- coding: utf-8 -*-
"""
# File name: script_get_sean_code.py
# Time : 2021/11/16 15:56
# Author: xyguoo@163.com
# Description: For every image in ``<GLOBAL_DATA_ROOT>/<dataset>/images_256``
#   (one dataset per entry of ``DATASET_NAME``), compute a code with
#   ``HairEditor.get_code`` from the image and its parsing map in
#   ``<GLOBAL_DATA_ROOT>/<dataset>/label`` (same file name). Each code is
#   pickled to
#   ``<GLOBAL_DATA_ROOT>/hair_info_all_dataset/sean_code/<dataset>___<stem>.pkl``,
#   and the directory is then merged into
#   ``<GLOBAL_DATA_ROOT>/sean_code_dict.pkl`` with ``merge_pickle_dir_to_dict``.
"""
import os
import pickle
import sys

import cv2
import tqdm

# Make project-local packages importable when run from the repository root.
sys.path.append('.')

from dataset_scripts.utils import merge_pickle_dir_to_dict
from global_value_utils import GLOBAL_DATA_ROOT, DATASET_NAME
from hair_editor import HairEditor


def _read_rgb(path):
    """Read the image at *path* with OpenCV and return it converted BGR -> RGB.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file (``cv2.imread``
            signals failure by returning ``None`` rather than raising).
    """
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError('Cannot read image: %s' % path)
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


def _read_gray(path):
    """Read the image at *path* as a single-channel grayscale array.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
    """
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise FileNotFoundError('Cannot read parsing map: %s' % path)
    return img


data_name = DATASET_NAME
root_dir = GLOBAL_DATA_ROOT
imgs_sub_dir = 'images_256'
target_dir = os.path.join(root_dir, 'hair_info_all_dataset/sean_code')
he = HairEditor(load_mask_model=False)

# Collect (dataset, image path) pairs. The dataset is recorded here, where it
# is known for certain, instead of being guessed later by searching the path
# for a dataset-name substring (which picks the wrong dataset when one name
# is contained in another, or in root_dir).
samples = []
for d in data_name:
    data_dir = os.path.join(root_dir, d, imgs_sub_dir)
    samples += [(d, os.path.join(data_dir, pp)) for pp in os.listdir(data_dir)]
# Process in order of full path, as before.
samples.sort(key=lambda sample: sample[1])

os.makedirs(target_dir, exist_ok=True)

for dataset_name, img_path in tqdm.tqdm(samples):
    base_name = os.path.basename(img_path)
    hair_img = _read_rgb(img_path)
    # The parsing map shares the image's file name, under '<dataset>/label'.
    hair_parsing = _read_gray(os.path.join(root_dir, dataset_name, 'label', base_name))
    # resize / convert to the inputs HairEditor expects
    hair_img = he.preprocess_img(hair_img)
    hair_parsing = he.preprocess_mask(hair_parsing)
    cur_code = he.get_code(hair_img, hair_parsing)
    # Move to CPU/NumPy and keep the first entry along axis 0
    # (presumably the batch dimension — confirm against HairEditor.get_code).
    cur_code = cur_code.cpu().numpy()[0]
    # splitext handles extensions of any length (base_name[:-4] assumed 3 chars).
    stem = os.path.splitext(base_name)[0]
    target_file = os.path.join(target_dir, '%s___%s.pkl' % (dataset_name, stem))
    with open(target_file, 'wb') as f:
        pickle.dump(cur_code, f)

merge_pickle_dir_to_dict(target_dir, os.path.join(root_dir, 'sean_code_dict.pkl'))