# Layer-wise character decomposition utilities: run the layer-diffusion
# pipeline to extract transparent body-part layers, predict per-part depth
# with Marigold, and post-process the layers (left/right splits, PSD export).
| import os | |
| import os.path as osp | |
| from modules.layerdiffuse.diffusers_kdiffusion_sdxl import KDiffusionStableDiffusionXLPipeline, UNetFrameConditionModel | |
| from modules.layerdiffuse.vae import TransparentVAE | |
| from modules.layerdiffuse.layerdiff3d import UNetFrameConditionModel | |
| from modules.marigold import MarigoldDepthPipeline | |
| from utils.cv import center_square_pad_resize, img_alpha_blending, smart_resize | |
| from utils.torch_utils import seed_everything | |
| from utils.io_utils import json2dict, dict2json, load_parts, save_tmp_img, load_part, save_psd | |
| from utils.torchcv import cluster_inpaint_part | |
| from psd_tools import PSDImage | |
| from safetensors.torch import load_file | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
# Part tags handled by the v2 layer-diffusion models; also the canonical
# ordering used when pairing depth predictions back to tags in apply_marigold.
VALID_BODY_PARTS_V2 = [
    'hair', 'headwear', 'face', 'eyes', 'eyewear', 'ears', 'earwear', 'nose', 'mouth',
    'neck', 'neckwear', 'topwear', 'handwear', 'bottomwear', 'legwear', 'footwear',
    'tail', 'wings', 'objects'
]
# Module-level pipeline cache, lazily built on the first apply_layerdiff call.
layerdiff_pipeline: KDiffusionStableDiffusionXLPipeline = None
def apply_layerdiff(
        imgp: str, pretrained: str, num_inference_steps=30, seed=0,
        save_dir='workspace/layerdiff_output', target_tag_list=VALID_BODY_PARTS_V2,
        resolution=1280, vae_ckpt=None, unet_ckpt=None):
    """Decompose a character illustration into transparent body-part layers.

    Lazily builds the module-level ``layerdiff_pipeline`` on first call, then
    runs it on ``imgp`` and writes one RGBA PNG per part tag under
    ``save_dir/<image basename>/``.  v3 UNets additionally run a second pass
    on a zoomed head crop for fine facial layers.

    Args:
        imgp: path to the input image (converted to RGBA).
        pretrained: pretrained pipeline directory / hub id.
        num_inference_steps: diffusion sampling steps.
        seed: seed for the sampler's torch.Generator.
        save_dir: root output directory.
        target_tag_list: part tags to generate (used by v2 UNets only).
        resolution: square working resolution the input is padded/resized to.
        vae_ckpt: optional safetensors checkpoint overriding VAE /
            trans-decoder weights.
        unet_ckpt: optional checkpoint overriding the UNet weights.

    Raises:
        ValueError: if the UNet reports an unknown tag version.
    """
    global layerdiff_pipeline
    if layerdiff_pipeline is None:
        trans_vae = TransparentVAE.from_pretrained(pretrained, subfolder='trans_vae')
        if unet_ckpt is None:
            unet = UNetFrameConditionModel.from_pretrained(pretrained, subfolder='unet')
        else:
            print(f'load unet from {unet_ckpt}')
            unet = UNetFrameConditionModel.from_pretrained(unet_ckpt)
        layerdiff_pipeline = KDiffusionStableDiffusionXLPipeline.from_pretrained(
            pretrained,
            trans_vae=trans_vae, unet=unet,
            scheduler=None
        )
        if vae_ckpt is not None:
            td_sd = {}
            vae_sd = {}
            sd = load_file(vae_ckpt)
            for k, v in sd.items():
                if k.startswith('trans_decoder.'):
                    # BUG FIX: str.lstrip strips a character *set*, not a
                    # prefix -- it would also eat any leading 't'/'r'/'a'/...
                    # characters of the weight name.  Slice the prefix off.
                    td_sd[k[len('trans_decoder.'):]] = v
                elif k.startswith('vae.'):
                    vae_sd[k.replace('vae.', '')] = v
            if len(vae_sd) > 0:
                layerdiff_pipeline.vae.load_state_dict(vae_sd)
                print(f'load vae from {vae_ckpt}')
            if len(td_sd) > 0:
                layerdiff_pipeline.trans_vae.decoder.load_state_dict(td_sd)
                # message previously said 'load vae' -- these are decoder weights
                print(f'load trans_vae decoder from {vae_ckpt}')
        layerdiff_pipeline.vae.to(dtype=torch.bfloat16, device='cuda')
        layerdiff_pipeline.trans_vae.to(dtype=torch.bfloat16, device='cuda')
        layerdiff_pipeline.unet.to(dtype=torch.bfloat16, device='cuda')
        layerdiff_pipeline.text_encoder.to(dtype=torch.bfloat16, device='cuda')
        layerdiff_pipeline.text_encoder_2.to(dtype=torch.bfloat16, device='cuda')
    pipeline = layerdiff_pipeline
    saved = osp.join(save_dir, osp.splitext(osp.basename(imgp))[0])
    os.makedirs(saved, exist_ok=True)
    input_img = np.array(Image.open(imgp).convert('RGBA'))
    fullpage, pad_size, pad_pos = center_square_pad_resize(input_img, resolution, return_pad_info=True)
    # scale maps working (resolution-sized) coordinates back to original coords.
    scale = pad_size[0] / resolution
    Image.fromarray(fullpage).save(osp.join(saved, 'src_img.png'))
    rng = torch.Generator(device=pipeline.unet.device).manual_seed(seed)
    tag_version = pipeline.unet.get_tag_version()
    if tag_version == 'v2':
        pipeline_output = pipeline(
            strength=1.0,
            num_inference_steps=num_inference_steps,
            batch_size=1,
            generator=rng,
            guidance_scale=1.0,
            prompt=target_tag_list,
            negative_prompt='',
            fullpage=fullpage
        )
        images = pipeline_output.images
        for rst, tag in zip(images, target_tag_list):
            savename = osp.join(saved, f'{tag}.png')
            Image.fromarray(rst).save(savename)
    elif tag_version == 'v3':
        def _crop_head(img, xywh):
            # Expand the head bbox by up to 1/5 of its size on each side
            # (clamped to image borders) before cropping, but only when the
            # bbox is small relative to the image.
            x, y, w, h = xywh
            ih, iw = img.shape[:2]
            x1 = x
            y1 = y
            x2 = x + w
            y2 = y + h
            if w < iw // 2:
                px = min(iw - x - w, x, w // 5)
                x1 = min(max(x - px, 0), iw)
                x2 = min(max(x + w + px, 0), iw)
            if h < ih // 2:
                py = min(ih - y - h, y, h // 5)
                y2 = min(max(y + h + py, 0), ih)
                y1 = min(max(y - py, 0), ih)
            return img[y1: y2, x1: x2], (x1, y1, x2, y2)
        body_tag_list = ['front hair', 'back hair', 'head', 'neck', 'neckwear', 'topwear', 'handwear', 'bottomwear', 'legwear', 'footwear', 'tail', 'wings', 'objects']
        pipeline_output = pipeline(
            strength=1.0,
            num_inference_steps=num_inference_steps,
            batch_size=1,
            generator=rng,
            guidance_scale=1.0,
            prompt=body_tag_list,
            negative_prompt='',
            fullpage=fullpage,
            group_index=0
        )
        images = pipeline_output.images
        for rst, tag in zip(pipeline_output.images, body_tag_list):
            savename = osp.join(saved, f'{tag}.png')
            Image.fromarray(rst).save(savename)
        head_img = images[2]  # index 2 == 'head' in body_tag_list
        head_tag_list = ['headwear', 'face', 'irides', 'eyebrow', 'eyewhite', 'eyelash', 'eyewear', 'ears', 'earwear', 'nose', 'mouth']
        # Head bbox in working coords, then mapped back to original-image coords.
        hx0, hy0, hw, hh = cv2.boundingRect(cv2.findNonZero((head_img[..., -1] > 15).astype(np.uint8)))
        hx = int(hx0 * scale) - pad_pos[0]
        hy = int(hy0 * scale) - pad_pos[1]
        hw = int(hw * scale)
        hh = int(hh * scale)
        input_head, (hx1, hy1, hx2, hy2) = _crop_head(input_img, [hx, hy, hw, hh])
        # Map the (expanded) crop origin back into full-page working coords.
        hx1 = int(hx1 / scale + pad_pos[0] / scale)
        hy1 = int(hy1 / scale + pad_pos[1] / scale)
        ih, iw = input_head.shape[:2]
        # NOTE: pad_size/pad_pos are rebound to the *head* crop here, while
        # `scale` intentionally stays the full-page scale so all coordinates
        # below remain in full-page working space.
        input_head, pad_size, pad_pos = center_square_pad_resize(input_head, resolution, return_pad_info=True)
        Image.fromarray(input_head).save(osp.join(saved, 'src_head.png'))
        pipeline_output = pipeline(
            strength=1.0,
            num_inference_steps=num_inference_steps,
            batch_size=1,
            generator=rng,
            guidance_scale=1.0,
            prompt=head_tag_list,
            negative_prompt='',
            fullpage=input_head,
            group_index=1
        )
        canvas = np.zeros((resolution, resolution, 4), dtype=np.uint8)
        py1, py2, px1, px2 = (np.array([pad_pos[1], pad_pos[1] + ih, pad_pos[0], pad_pos[0] + iw]) / scale).astype(np.int64)
        scale_size = (int(pad_size[0] / scale), int(pad_size[1] / scale))
        for rst, tag in zip(pipeline_output.images, head_tag_list):
            # Un-pad each head layer and paste it at its full-page position.
            rst = smart_resize(rst, scale_size)[py1: py2, px1: px2]
            full = canvas.copy()
            full[hy1: hy1 + rst.shape[0], hx1: hx1 + rst.shape[1]] = rst
            savename = osp.join(saved, f'{tag}.png')
            Image.fromarray(full).save(savename)
    else:
        # BUG FIX: a bare `raise` outside an except block fails with
        # "RuntimeError: No active exception to re-raise".
        raise ValueError(f'unsupported tag version: {tag_version}')
# Module-level depth pipeline cache, lazily built on the first apply_marigold call.
marigold_pipeline: MarigoldDepthPipeline = None
def apply_marigold(srcp, pretrained: str, num_inference_steps=30, seed=0, save_dir='workspace/layerdiff_output', target_tag_list=VALID_BODY_PARTS_V2, resolution=1280, normalize_depth=False):
    """Predict per-part depth maps for previously extracted layers.

    Loads the part PNGs written by apply_layerdiff from ``save_dir/<src
    basename>/``, rebuilds the 'eyes' and 'hair' composites from their v3
    sub-layers when those files exist, runs the Marigold depth pipeline over
    all parts plus the alpha-composited full page, and writes
    ``<tag>_depth.png`` files, ``info.json`` and ``reconstruction.png`` back
    into the same directory.

    Args:
        srcp: path to the source image (used for its basename and RGB).
        pretrained: pretrained Marigold pipeline directory / hub id.
        num_inference_steps: not used in this function body -- kept for
            interface symmetry with apply_layerdiff (confirm intended).
        seed: passed to seed_everything before inference.
        save_dir: root directory containing the per-image part folder.
        target_tag_list: not used in this function body.
        resolution: square working resolution of the part images.
        normalize_depth: if True, per-part min/max normalize each depth map
            and record 'depth_max'/'depth_min' in info.json.
    """
    global marigold_pipeline
    if marigold_pipeline is None:
        unet = UNetFrameConditionModel.from_pretrained(pretrained, subfolder='unet')
        marigold_pipeline = MarigoldDepthPipeline.from_pretrained(pretrained, unet=unet)
        marigold_pipeline.to(device='cuda', dtype=torch.bfloat16)
    pipe = marigold_pipeline
    srcname = osp.basename(osp.splitext(srcp)[0])
    img_list = []
    caption_list = []  # populated below but not consumed afterwards
    exist_list = []    # populated below but not consumed afterwards
    empty_array = np.zeros((resolution, resolution, 4), dtype=np.uint8)
    blended_alpha = np.zeros((resolution, resolution), dtype=np.float32)
    fullpage = center_square_pad_resize(np.array(Image.open(srcp).convert('RGBA')), resolution)
    saved = osp.join(save_dir, srcname)
    # Composite tags that can be rebuilt from finer v3 sub-layer files.
    compose_list = {'eyes': ['eyewhite', 'irides', 'eyelash', 'eyebrow'], 'hair': ['back hair', 'front hair']}
    # One image slot per canonical tag; missing parts become empty RGBA frames.
    for tag in VALID_BODY_PARTS_V2:
        tagp = osp.join(saved, f'{tag}.png')
        if osp.exists(tagp):
            exist_list.append(True)
            caption_list.append(tag)
            tag_arr = np.array(Image.open(tagp))
            # Zero out near-transparent alpha to suppress decoding noise.
            tag_arr[..., -1][tag_arr[..., -1] < 15] = 0
            # blended_alpha += tag_arr[..., -1].astype(np.float32) / 255
            img_list.append(tag_arr)
        else:
            img_list.append(empty_array)
            exist_list.append(False)
    compose_dict = {}
    for c, clist in compose_list.items():
        imlist = []
        taglist = []
        for tag in clist:
            p = osp.join(saved, tag + '.png')
            if osp.exists(p):
                tag_arr = np.array(Image.open(p))
                tag_arr[..., -1][tag_arr[..., -1] < 15] = 0
                imlist.append(tag_arr)
                taglist.append(tag)
        if len(imlist) > 0:
            # Replace the composite slot (e.g. 'eyes') with the blend of its
            # sub-layers and remember the pieces for per-sub-layer depth below.
            img = img_alpha_blending(imlist, premultiplied=False)
            img_list[VALID_BODY_PARTS_V2.index(c)] = img
            compose_dict[c] = {'taglist': taglist, 'imlist': imlist}
    # The union of all part alphas becomes the full page's alpha channel.
    for img in img_list:
        blended_alpha += img[..., -1].astype(np.float32) / 255
    blended_alpha = np.clip(blended_alpha, 0, 1) * 255
    blended_alpha = blended_alpha.astype(np.uint8)
    fullpage[..., -1] = blended_alpha
    # Full page appended last, so depth_pred[-1] belongs to the composite.
    img_list.append(fullpage)
    seed_everything(seed)
    pipe_out = pipe(
        # tensor2img(img, 'pil', denormalize=True, mean=127.5, std=127.5),
        color_map=None,
        show_progress_bar=False,
        img_list = img_list
    )
    depth_pred: torch.Tensor = pipe_out.depth_tensor
    depth_pred = depth_pred.to(device='cpu', dtype=torch.float32).numpy()
    drawables = [{'img': img, 'depth': depth} for img, depth in zip(img_list, depth_pred)]
    drawables = drawables[:-1]  # drop the full-page entry before blending
    blended = img_alpha_blending(drawables, premultiplied=False)
    infop = osp.join(saved, 'info.json')
    if osp.exists(infop):
        info = json2dict(infop)
    else:
        info = {'parts': {}}
    parts = info['parts']
    for ii, depth in enumerate(depth_pred[:-1]):
        if normalize_depth:
            depth_max, depth_min = depth.max(), depth.min()
            depth = np.clip((depth - depth_min) / (depth_max - depth_min + 1e-7) * 255, 0, 255).astype(np.uint8)
        else:
            depth = (np.clip(depth, 0, 1) * 255).astype(np.uint8)
        # depth = depth[..., None][..., [-1] * 3].copy()
        tag = VALID_BODY_PARTS_V2[ii]
        if tag in compose_dict:
            # NOTE(review): blended_alpha is uint8 (max 255), so `> 256` is
            # always False -- perhaps intended as an all-False initial
            # occlusion mask, but `>= 255` may have been meant; confirm.
            mask = blended_alpha > 256
            # Walk sub-layers in reverse, filling pixels occluded by earlier
            # layers with the median depth of the sub-layer's visible region.
            for t, im in zip(compose_dict[tag]['taglist'][::-1], compose_dict[tag]['imlist'][::-1]):
                mask_local = im[..., -1] > 15
                mask_invis = np.bitwise_and(mask, mask_local)
                depth_local = np.full((resolution, resolution), fill_value=255, dtype=np.uint8)
                depth_local[mask_local] = depth[mask_local]
                if np.any(mask_invis):
                    depth_local[mask_invis] = np.median(depth[np.bitwise_and(mask_local, np.bitwise_not(mask_invis))])
                mask = np.bitwise_or(mask, mask_local)
                parts_info = parts.get(t, {})
                savep = osp.join(saved, f'{t}_depth.png')
                Image.fromarray(depth_local).save(savep)
                parts[t] = parts_info
                if normalize_depth:
                    parts_info['depth_max'] = depth_max
                    parts_info['depth_min'] = depth_min
            continue
        parts_info = parts.get(tag, {})
        savep = osp.join(saved, f'{tag}_depth.png')
        Image.fromarray(depth).save(savep)
        parts[tag] = parts_info
        if normalize_depth:
            parts_info['depth_max'] = depth_max
            parts_info['depth_min'] = depth_min
    dict2json(info, infop)
    Image.fromarray(blended).save(osp.join(saved, 'reconstruction.png'))
def label_lr_split(labels, stats, id1, id2):
    """Return the two component masks ordered left-to-right.

    Builds a 0/255 uint8 mask for each of the two connected-component ids
    and orders the (mask, stats) pairs by the horizontal centre of their
    bounding boxes (x + w / 2), leftmost first.
    """
    pairs = []
    for cid in (id1, id2):
        mask = np.where(labels == cid, 255, 0).astype(np.uint8)
        center_x = stats[cid][0] + stats[cid][2] / 2
        pairs.append((center_x, mask, stats[cid]))
    # Swap so the leftmost component comes first; ties keep (id1, id2) order.
    if pairs[1][0] < pairs[0][0]:
        pairs.reverse()
    return pairs[0][1], pairs[1][1], pairs[0][2], pairs[1][2]
def save_part(tag, saved, part_dict, crop=True, save_part_info=False, save_to_disk=True):
    """Finalize one part layer: crop to content, quantize depth, save or return.

    Mutates ``part_dict`` in place: removes 'img'/'depth'/'mask', records
    'depth_median' and 'tag', and adjusts 'xyxy' when cropping.

    Args:
        tag: part name used for file names and the 'tag' field.
        saved: output directory (used only when ``save_to_disk``).
        part_dict: dict holding 'img' (HxWx4 uint8), 'depth' (HxW float in
            [0, 1]), 'xyxy', and optionally 'mask'.
        crop: crop image/depth to the alpha bounding box.
        save_part_info: also dump the metadata dict as ``<tag>.json``.
        save_to_disk: write PNGs; otherwise return arrays inside the dict.

    Returns:
        The mutated ``part_dict``.
    """
    img = part_dict.pop('img')
    if 'mask' in part_dict:
        part_dict.pop('mask')
    depth = part_dict.pop('depth')
    mask = img[..., -1] > 10
    has_alpha = bool(np.any(mask))
    # BUG FIX: np.median over an empty selection returns NaN (with a
    # warning); fall back to the far plane (1.0) used as default elsewhere.
    depth_median = np.median(depth[mask]) if has_alpha else 1.0
    if crop and has_alpha:
        # BUG FIX: cv2.findNonZero returns None for an all-zero mask, which
        # crashes boundingRect -- hence the has_alpha guard.
        xywh = cv2.boundingRect(cv2.findNonZero(mask.astype(np.uint8)))
        xyxy = np.array(xywh)  # np.array already copies; no .copy() needed
        xyxy[2] += xyxy[0]
        xyxy[3] += xyxy[1]
        depth = depth[xyxy[1]: xyxy[3], xyxy[0]: xyxy[2]]
        img = img[xyxy[1]: xyxy[3], xyxy[0]: xyxy[2]]
        x1, y1, x2, y2 = part_dict['xyxy']
        part_dict['xyxy'] = [x1 + xyxy[0], y1 + xyxy[1], x1 + xyxy[2], y1 + xyxy[3]]
    # Quantize depth from [0, 1] float to uint8.
    depth = np.clip(depth, 0, 1) * 255
    depth = np.round(depth).astype(np.uint8)
    part_dict['depth_median'] = depth_median
    part_dict['tag'] = tag
    if save_to_disk:
        Image.fromarray(img).save(osp.join(saved, tag + '.png'))
        Image.fromarray(depth).save(osp.join(saved, tag + '_depth.png'))
        if save_part_info:
            dict2json(part_dict, osp.join(saved, tag + '.json'))
    else:
        part_dict['img'] = img
        part_dict['depth'] = depth
    return part_dict
def process_cuts(img, depth, src_xyxy, tgt_bbox, p=5, mask=None):
    """Crop one connected component out of a part layer.

    Args:
        img: HxWx4 RGBA part image (crop-local coordinates).
        depth: HxW float depth map; 1.0 is treated as the far plane.
        src_xyxy: the part's bounding box in full-page coordinates.
        tgt_bbox: (x, y, w, h, ...) component stats relative to the part.
        p: unused padding parameter, kept for interface compatibility.
        mask: optional component mask; when given, pixels outside it (after a
            1px elliptical dilation) are cleared from alpha and pushed to the
            far plane in depth.

    Returns:
        Tuple of (cropped img copy, cropped depth, full-page xyxy,
        depth median of the masked region -- 1 when no/empty mask).
    """
    tx1, ty1, tx2, ty2 = tgt_bbox[:4]
    # tgt_bbox is (x, y, w, h); convert to an exclusive x2/y2.
    tx2 += tx1
    ty2 += ty1
    img = img[ty1: ty2, tx1: tx2].copy()
    depth = depth[ty1: ty2, tx1: tx2]
    depth_median = 1
    if mask is not None:
        # `> 15` already produces a fresh boolean array, so no .copy() needed.
        mask = (mask[ty1: ty2, tx1: tx2] > 15).astype(np.uint8)
        ksize = 1
        element = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2 * ksize + 1, 2 * ksize + 1), (ksize, ksize))
        mask = cv2.dilate(mask, element)
        img[..., -1] *= mask
        depth = 1 - (1 - depth) * mask
        if np.any(mask):
            # BUG FIX: `depth[mask]` with a uint8 0/1 mask performs *integer*
            # fancy indexing (selecting rows 0 and 1), not boolean masking;
            # convert to bool so the median covers the masked pixels.
            depth_median = np.median(depth[mask.astype(bool)])
    fxyxy = [tx1 + src_xyxy[0], ty1 + src_xyxy[1], tx2 + src_xyxy[0], ty2 + src_xyxy[1]]
    return img, depth, fxyxy, depth_median
def part_lr_split(tag, part_info):
    """Split a part into its two largest connected components, left and right.

    Returns ``{f'{tag}-r': ..., f'{tag}-l': ...}`` when the mask contains at
    least two foreground components, otherwise ``{tag: part_info}``
    unchanged.  The leftmost component is labelled '-r' (the character's
    right side) and the rightmost '-l'.
    """
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
        part_info['mask'].astype(np.uint8) * 255, connectivity=8)
    tag2pinfo = {}
    # stats row 0 is the background component, so > 2 rows means at least
    # two foreground components exist to split.
    if len(stats) > 2:
        stats = np.array(stats)
        # Sort components by area (last stats column), descending; drop the
        # largest which is the background.
        stats_order = np.argsort(stats[..., -1])[::-1][1:]
        arml_mask, armr_mask, statsl, statsr = label_lr_split(labels, stats, stats_order[0], stats_order[1])
        # NOTE: the original code re-sliced arml_mask/armr_mask by the
        # returned full-page xyxy after these calls; both results were
        # unused, so the dead assignments are removed here.
        img, depth, xyxy, depth_median = process_cuts(part_info['img'], part_info['depth'], part_info['xyxy'], statsl, mask=arml_mask)
        tag2pinfo[f'{tag}-r'] = {'img': img, 'xyxy': xyxy, 'depth': depth, 'depth_median': depth_median, 'tag': f'{tag}-r'}
        img, depth, xyxy, depth_median = process_cuts(part_info['img'], part_info['depth'], part_info['xyxy'], statsr, mask=armr_mask)
        tag2pinfo[f'{tag}-l'] = {'img': img, 'xyxy': xyxy, 'depth': depth, 'depth_median': depth_median, 'tag': f'{tag}-l'}
    else:
        tag2pinfo[tag] = part_info
    return tag2pinfo
def tag_lr_split(tag: str, tag2pinfo):
    """If *tag* is present in *tag2pinfo*, replace it in place with its
    left/right split produced by part_lr_split."""
    part_info = tag2pinfo.pop(tag, None)
    if part_info is not None:
        tag2pinfo.update(part_lr_split(tag, part_info))
def further_extr(srcd: str, rotate=True, save_to_psd=False, tblr_split=True):
    """Post-process extracted parts: split paired parts, fix layering, save.

    Loads parts from *srcd*, splits a combined 'eyes' layer into individual
    eyes and brows, left/right-splits symmetric parts, clusters hair into
    front/back layers, copies the original page colors under nose/mouth,
    nudges facial-feature depth medians relative to the face, then saves
    everything either as individual files under ``srcd/optimized/`` or as a
    layered PSD next to *srcd*.

    Args:
        srcd: directory produced by apply_layerdiff / apply_marigold.
        rotate: forwarded to load_parts.
        save_to_psd: if True, dump one PSD (plus depth PSD and JSON)
            instead of per-part PNGs.
        tblr_split: if True, also left/right-split handwear, v3 eye
            sub-layers and ears.
    """
    saved = osp.join(srcd, 'optimized')
    os.makedirs(saved, exist_ok=True)
    fullpage, infos, part_dict_list = load_parts(srcd, rotate=rotate)
    tag2pinfo = {}
    for pinfo in part_dict_list:
        tag = pinfo['tag']
        tag2pinfo[tag] = pinfo
    if 'eyes' in tag2pinfo:
        # v2 models emit a single 'eyes' layer containing both eyes and both
        # brows; split it into four components when enough are present.
        part_info = tag2pinfo.pop('eyes')
        num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
            part_info['mask'].astype(np.uint8) * 255, connectivity=8)
        if len(stats) > 2:
            stats = np.array(stats)
            # >= 5 rows means background + at least 4 foreground components.
            # NOTE(review): when 2 < len(stats) < 5 the popped 'eyes' entry is
            # silently dropped (neither split nor restored) -- confirm intended.
            if len(stats[..., -1]) >= 5:
                # Components by area, descending, background excluded: the two
                # largest are the eyes, the next two the brows.
                stats_order = np.argsort(stats[..., -1])[::-1][1:]
                eyel_mask, eyer_mask, statsl, statsr = label_lr_split(labels, stats, stats_order[0], stats_order[1])
                img, depth, xyxy, _ = process_cuts(part_info['img'], part_info['depth'], part_info['xyxy'], statsl)
                tag2pinfo['eyer'] = {'img': img, 'xyxy': xyxy, 'depth': depth}
                img, depth, xyxy, _ = process_cuts(part_info['img'], part_info['depth'], part_info['xyxy'], statsr)
                tag2pinfo['eyel'] = {'img': img, 'xyxy': xyxy, 'depth': depth}
                browl_mask, browr_mask, statsl, statsr = label_lr_split(labels, stats, stats_order[2], stats_order[3])
                img, depth, xyxy, _ = process_cuts(part_info['img'], part_info['depth'], part_info['xyxy'], statsl)
                tag2pinfo['browr'] = {'img': img, 'xyxy': xyxy, 'depth': depth}
                img, depth, xyxy, _ = process_cuts(part_info['img'], part_info['depth'], part_info['xyxy'], statsr)
                tag2pinfo['browl'] = {'img': img, 'xyxy': xyxy, 'depth': depth}
        else:
            tag2pinfo['eyes'] = part_info
    if tblr_split:
        tag_lr_split('handwear', tag2pinfo)
        # v3 per-eye sub-layers, each split into a left and a right piece.
        eyetags_v3 = ['eyewhite', 'irides', 'eyelash', 'eyebrow']
        for tag in eyetags_v3:
            tag_lr_split(tag, tag2pinfo)
        tag_lr_split('ears', tag2pinfo)
    if 'hair' in tag2pinfo:
        # Cluster hair into two layers ordered by depth median:
        # front (smaller median) and back.
        part_info = tag2pinfo.pop('hair')
        parts = cluster_inpaint_part(**part_info)
        parts.sort(key=lambda x: x['depth_median'])
        tag2pinfo['hairf'] = parts[0]
        tag2pinfo['hairb'] = parts[1]
    if 'nose' in tag2pinfo:
        # Replace generated RGB with the original page colors (alpha kept).
        xyxy = tag2pinfo['nose']['xyxy']
        tag2pinfo['nose']['img'][..., :3] = fullpage[xyxy[1]: xyxy[3], xyxy[0]: xyxy[2], :3]
    if 'mouth' in tag2pinfo:
        xyxy = tag2pinfo['mouth']['xyxy']
        tag2pinfo['mouth']['img'][..., :3] = fullpage[xyxy[1]: xyxy[3], xyxy[0]: xyxy[2], :3]
    part_dict_list = []
    save_dir = osp.dirname(saved)
    psd_savep = osp.join(osp.dirname(save_dir), osp.basename(save_dir) + '.psd')
    for t in tag2pinfo:
        # NOTE(review): `t not in tag2pinfo` is always False while iterating
        # tag2pinfo's own keys -- dead guard, likely left over from an
        # earlier fixed tag list.
        if t not in tag2pinfo:
            print(f'{t} is not valid')
            continue
        part_dict = tag2pinfo[t]
        part_dict = save_part(t, saved, part_dict, save_to_disk=not save_to_psd)
        if save_to_psd:
            part_dict_list.append(part_dict)
    if 'face' in tag2pinfo:
        # Force facial features slightly in front of the face and ears
        # slightly behind it (smaller depth_median = closer, per hairf/hairb).
        # NOTE(review): this runs *after* save_part has written the per-part
        # depth PNGs; only the JSON/PSD metadata sees the adjusted medians.
        for t in ['nose', 'mouth', 'eyes']:
            if t in tag2pinfo:
                if tag2pinfo[t]['depth_median'] > tag2pinfo['face']['depth_median']:
                    tag2pinfo[t]['depth_median'] = tag2pinfo['face']['depth_median'] - 0.001
        # NOTE(review): part_lr_split emits 'ears-l'/'ears-r', so 'earr' and
        # 'earl' here can never match after a split -- confirm intended.
        for t in ['earr', 'earl', 'ears']:
            if t in tag2pinfo:
                tag2pinfo[t]['depth_median'] = tag2pinfo['face']['depth_median'] + 0.001
    frame_size = fullpage.shape[:2]
    if save_to_psd:
        dump_parts_psd(tag2pinfo, frame_size, psd_savep, part_dict_list=part_dict_list)
        print(f'psd saved to {psd_savep}')
    else:
        dict2json({'parts': tag2pinfo, 'frame_size': frame_size}, osp.join(saved, 'info.json'))
def dump_parts_psd(tag2pinfo, frame_size, psd_savep, part_dict_list=None):
    """Write the parts as a layered PSD plus a depth PSD and metadata JSON.

    Layers are ordered by 'depth_median', farthest first.  After writing,
    the 'img'/'depth'/'mask' arrays are stripped from tag2pinfo so the
    remaining metadata can be JSON-serialized.
    """
    if part_dict_list is None:
        part_dict_list = list(tag2pinfo.values())
    psd_depth_savep = osp.splitext(psd_savep)[0] + '_depth.psd'
    part_dict_list.sort(key=lambda d: d['depth_median'], reverse=True)
    save_psd(psd_savep, part_dict_list, frame_size[0], frame_size[1])
    save_psd(psd_depth_savep, part_dict_list, frame_size[0], frame_size[1], mode='L', img_key='depth')
    for pdict in tag2pinfo.values():
        for key in ('img', 'depth', 'mask'):
            pdict.pop(key, None)
    dict2json({'parts': tag2pinfo, 'frame_size': frame_size}, psd_savep + '.json')
def psd2partdicts(srcp):
    """Load part dicts back from a PSD pair and its sidecar JSON.

    Reads RGBA layers from *srcp* and depth layers from the matching
    '*_depth.psd', attaching them to the part entries keyed by layer name.
    Returns the parsed metadata dict with 'img'/'depth' arrays restored.
    """
    partdict = json2dict(srcp + '.json')
    tag2part = partdict['parts']
    for layer in PSDImage.open(srcp):
        # Layers come back as float arrays in [0, 1]; restore uint8 RGBA.
        tag2part[layer.name]['img'] = np.round(layer.numpy() * 255).astype(np.uint8)
    depth_path = osp.splitext(srcp)[0] + '_depth.psd'
    for layer in PSDImage.open(depth_path):
        # Grayscale depth: keep a single channel as float.
        tag2part[layer.name]['depth'] = layer.numpy()[..., 0]
    return partdict
def seg_wdepth(srcp, *args, **kwargs):
    """Cluster-segment a saved part (using its depth) into finer sub-parts.

    Loads the part at *srcp*, runs cluster_inpaint_part on it, and saves the
    resulting sub-parts (with metadata JSON) into a subdirectory named after
    the part's tag.
    """
    srcd = osp.dirname(srcp)
    part_dict = load_part(srcp)
    tag = part_dict['tag']
    rst_list = cluster_inpaint_part(**part_dict)
    saved = osp.join(srcd, tag)
    if not rst_list:
        print(f'seg_wdepth: failed to seg more parts')
        return
    os.makedirs(saved, exist_ok=True)
    for idx, part in enumerate(rst_list):
        save_part(f'{tag}-{idx}', saved, part, save_part_info=True)
    print(f'sub part saved to {saved}')
def seg_wdepth_psd(srcp, target_tags, savep=None):
    """Cluster-segment selected tags of a PSD's parts and re-dump the PSD.

    Each requested tag is removed from the part set and replaced by its
    cluster sub-parts ('<tag>-0', '<tag>-1', ...).  Tags with an empty alpha
    mask are dropped without replacement.
    """
    part_infos = psd2partdicts(srcp)
    if savep is None:
        savep = osp.splitext(srcp)[0] + '_wdepth.psd'
    if isinstance(target_tags, str):
        target_tags = target_tags.split(',')
    else:
        assert isinstance(target_tags, list)
    parts = part_infos['parts']
    valid_tags = list(parts.keys())
    for tag in target_tags:
        if tag not in parts:
            print(f'{tag} is not in {valid_tags}')
            continue
        part_dict = parts.pop(tag)
        mask = part_dict['img'][..., -1] > 10
        if not np.any(mask):
            continue
        part_dict['mask'] = mask
        for idx, part in enumerate(cluster_inpaint_part(**part_dict)):
            sub_tag = f'{tag}-{idx}'
            part['tag'] = sub_tag
            parts[sub_tag] = part
    dump_parts_psd(parts, part_infos['frame_size'], savep)
    print(f'psd saved to {savep}')
def seg_wlr(srcp, *args, **kwargs):
    """Split a saved part into left/right components and save the results.

    Loads the part at *srcp*, runs part_lr_split on it, and saves the split
    pieces (with metadata JSON) into a subdirectory named after the tag.
    Extra positional/keyword arguments are accepted and ignored.
    """
    srcd = osp.dirname(srcp)
    part_dict = load_part(srcp)
    tag = part_dict['tag']
    rst_dict = part_lr_split(tag, part_dict)
    saved = osp.join(srcd, tag)
    if len(rst_dict) > 1:
        os.makedirs(saved, exist_ok=True)
        for sub_tag, part in rst_dict.items():
            save_part(sub_tag, saved, part, save_part_info=True)
        print(f'sub part saved to {saved}')
    else:
        # BUG FIX: message previously said 'seg_wdepth' (copied from sibling).
        print(f'seg_wlr: failed to seg more parts')
def seg_wlr_psd(srcp, target_tags, savep=None):
    """Left/right-split selected tags of a PSD's parts and re-dump the PSD.

    Each requested tag is removed from the part set and replaced by the
    entries produced by part_lr_split.  Tags with an empty alpha mask are
    dropped without replacement.
    """
    part_infos = psd2partdicts(srcp)
    if savep is None:
        savep = osp.splitext(srcp)[0] + '_lrsplit.psd'
    if isinstance(target_tags, str):
        target_tags = target_tags.split(',')
    else:
        assert isinstance(target_tags, list)
    parts = part_infos['parts']
    valid_tags = list(parts.keys())
    for tag in target_tags:
        if tag not in parts:
            print(f'{tag} is not in {valid_tags}')
            continue
        part_dict = parts.pop(tag)
        mask = part_dict['img'][..., -1] > 10
        if not np.any(mask):
            continue
        part_dict['mask'] = mask
        parts.update(part_lr_split(tag, part_dict))
    dump_parts_psd(parts, part_infos['frame_size'], savep)
    print(f'psd saved to {savep}')