| from confs import * |
| from pathlib import Path |
| import numpy as np |
| import cv2 |
| from PIL import Image |
| import torch |
| import torch.utils.data as data |
| import torchvision.transforms as T |
| from einops import rearrange |
| import albumentations |
|
|
| from util_face import * |
| from util_4dataset import * |
| from util_cv2 import cv2_resize_auto_interpolation |
| from Mediapipe_Result_Cache import Mediapipe_Result_Cache |
|
|
|
|
def resize_A(img, dataset_name, size=(512, 512), interpolation=None):
    """Resize ``img`` to ``size`` unless it already has that shape.

    Parameters
    ----------
    img : PIL.Image.Image or np.ndarray
        Input image; the return type matches the input type.
    dataset_name : str
        Unused here; kept for interface compatibility with callers.
    size : tuple, default (512, 512)
        Target size forwarded to ``cv2_resize_auto_interpolation``.
        NOTE(review): assumed to be (height, width) so it matches the
        ``img.shape[:2]`` comparison order — confirm against the helper.
    interpolation : optional
        Forwarded to ``cv2_resize_auto_interpolation`` (auto-chosen if None).

    Returns
    -------
    The resized image, as PIL.Image if the input was PIL, else np.ndarray.
    """
    is_pil = isinstance(img, Image.Image)
    if is_pil:
        img = np.array(img)
    # Bug fix: compare against the requested ``size`` rather than a
    # hard-coded (512, 512), so callers passing a non-default size get
    # the resize they asked for.
    if img.shape[:2] != tuple(size):
        img = cv2_resize_auto_interpolation(img, size, interpolation=interpolation)
    if is_pil:
        img = Image.fromarray(img)
    return img
|
|
|
|
def un_norm_clip(x1):
    """Invert CLIP's per-channel image normalization.

    Accepts a (C, H, W) or (B, C, H, W) tensor and maps each of the first
    three channels back via ``x * std + mean`` using CLIP's statistics.
    The caller's tensor is never modified (a copy is taken first).
    """
    # CLIP image-normalization statistics, in (R, G, B) channel order.
    means = (0.48145466, 0.4578275, 0.40821073)
    stds = (0.26862954, 0.26130258, 0.27577711)

    out = x1 * 1.0  # multiply-by-one copies, so x1 stays untouched
    added_batch_dim = len(out.shape) == 3
    if added_batch_dim:
        out = out.unsqueeze(0)
    for ch, (mu, sigma) in enumerate(zip(means, stds)):
        out[:, ch, :, :] = out[:, ch, :, :] * sigma + mu
    if added_batch_dim:
        out = out.squeeze(0)
    return out
|
|
|
|
def un_norm(x):
    """Map values from [-1, 1] back to [0, 1] (inverse of mean=std=0.5 norm)."""
    return (x + 1.0) * 0.5
|
|
|
|
def _dilate(_mask, kernel_size, iterations):
    """Grow a boolean mask by morphological dilation with a square kernel.

    ``_mask`` is converted to uint8 for OpenCV, dilated with a
    ``kernel_size`` x ``kernel_size`` all-ones kernel for ``iterations``
    passes, and returned as a boolean array.
    """
    square_kernel = np.ones((kernel_size, kernel_size), np.uint8)
    grown = cv2.dilate(_mask.astype(np.uint8), square_kernel, iterations=iterations)
    return grown.astype(bool)
|
|
|
|
def dilate_4_task0(sm_mask):
    """Build the preserved-region mask for task 0.

    Selects three groups of semantic labels from ``sm_mask``, dilates each
    group with its own kernel size / iteration count, and returns the union
    as a boolean mask. (Label IDs presumably follow the face-parsing label
    scheme used elsewhere in this file — confirm.)
    """
    sm_mask = np.array(sm_mask)
    # (labels, kernel_size, iterations) per group; the result is the union
    # of the dilated groups, so group order does not matter.
    groups = (
        ([2, 3, 10, 5], 7, 1),
        ([3, 10], 10, 3),
        ([1], 7, 2),
    )
    combined = np.zeros(sm_mask.shape, dtype=bool)
    for labels, ksize, iters in groups:
        combined |= _dilate(np.isin(sm_mask, labels), ksize, iters)
    return combined
|
|
|
|
class Dataset_custom(data.Dataset):
    """Paired target/reference image dataset for test-time inference.

    Each item pairs a target image (to be inpainted/edited) with a reference
    image (supplying appearance), plus an inpaint mask, per-task encoder
    inputs, and optionally cached Mediapipe landmarks. ``task`` selects
    which semantic labels are preserved; NOTE(review): exact task semantics
    (e.g. 0 = face, 1 = hair) are inferred from the label sets below —
    confirm against the training configuration.
    """

    # Normalization constants used with get_tensor (maps [0, 1] -> [-1, 1]).
    mean = (0.5, 0.5, 0.5)
    std = (0.5, 0.5, 0.5)

    def get_img4clip(
        self,
        img,
        sm_mask,
        preserve,
        for_clip=True,
        add_semantic_head=False,
        mask_after_npisin=None,
        for_inpaint512=False,
    ):
        """Mask an image to the ``preserve`` semantic labels and tensorize it.

        Parameters
        ----------
        img : PIL.Image.Image or np.ndarray
            Source image.
        sm_mask : array-like
            Semantic label map aligned with ``img``.
        preserve : sequence of int (or any value when ``mask_after_npisin``
            is given)
            Label IDs to keep; everything else is zeroed out.
        for_clip : bool
            True -> CLIP normalization and 224x224 output;
            False -> mean/std normalization and 512x512 output.
        add_semantic_head : bool
            If set, overlays a colorized semantic map via ``add_colorSM``
            but returns the pre-overlay mask.
        mask_after_npisin : np.ndarray or None
            Precomputed boolean mask; skips the ``np.isin`` selection.
        for_inpaint512 : bool
            For task 0 only: also remove the forehead region from the mask.

        Returns
        -------
        (image_tensor, mask)
            The masked, resized, normalized tensor and the mask tensor.
        """
        sm_mask = np.array(sm_mask)
        if mask_after_npisin is None:
            # ``and 0`` is a hard-coded author toggle: this branch is
            # intentionally disabled but kept for reference.
            if self.task == 0 and 0:
                mask = dilate_4_task0(sm_mask)
            else:
                mask = np.isin(sm_mask, preserve)
                # ``and 1`` is another always-on author toggle.
                if self.task == 0 and 1 and for_inpaint512:
                    forehead_mask = get_forehead_mask(sm_mask)
                    mask = mask & ~forehead_mask
        else:
            mask = mask_after_npisin

        if isinstance(img, np.ndarray):
            img = Image.fromarray(img)
        if add_semantic_head:
            # Remember the plain mask: the caller gets it back even though
            # add_colorSM replaces both image and mask below.
            mask_before_colorSM = mask
            img, mask = add_colorSM(img, sm_mask, preserve, None)
        mask = mask_after_npisin__2__tensor(mask)

        if for_clip:
            image_tensor = get_tensor_clip()(img)
        else:
            image_tensor = get_tensor(mean=self.mean, std=self.std)(img)
        image_tensor = T.Resize([512, 512])(image_tensor)
        image_tensor = image_tensor * mask
        # Round-trip back to a uint8 image so albumentations can resize it,
        # then re-normalize at the branch-specific resolution.
        if for_clip:
            image_tensor = 255.0 * rearrange(un_norm_clip(image_tensor), "c h w -> h w c").cpu().numpy()
            _size = 224
        else:
            image_tensor = 255.0 * rearrange(un_norm(image_tensor), "c h w -> h w c").cpu().numpy()
            _size = 512

        image_tensor = albumentations.Resize(height=_size, width=_size)(image=image_tensor)
        image_tensor = Image.fromarray(image_tensor["image"].astype(np.uint8))
        if for_clip:
            image_tensor = get_tensor_clip()(image_tensor)
        else:
            image_tensor = get_tensor(mean=self.mean, std=self.std)(image_tensor)
            # Re-apply the mask only on the 512 branch, where the mask's
            # resolution still matches the tensor's.
            image_tensor = image_tensor * mask
        if add_semantic_head:
            mask = mask_after_npisin__2__tensor(mask_before_colorSM)
        return image_tensor, mask

    def __init__(
        self,
        state,
        task,
        paths_tgt,
        paths_ref,
        name="custom",
    ):
        """Build a paired test-time dataset.

        Parameters
        ----------
        state : str
            Must be "test"; only inference is supported by this class.
        task : int
            One of 0, 1, 2, 3; selects the edit task and preserved labels.
        paths_tgt, paths_ref : sequence
            Equal-length sequences of target / reference image paths.
        name : str
            Dataset name, forwarded to ``resize_A``.

        Raises
        ------
        ValueError
            If ``task`` is not one of 0, 1, 2, 3.
        """
        # Per-task switch table. NOTE(review): USE_filter_mediapipe_fail_swap
        # and USE_pts are not read in this class; kept to document the
        # intended per-task configuration.
        if task == 0:
            USE_filter_mediapipe_fail_swap = 1
            USE_pts = 1
            READ_mediapipe_result_from_cache = 1
        elif task == 1:
            USE_filter_mediapipe_fail_swap = 0
            USE_pts = 0
            READ_mediapipe_result_from_cache = 1
        elif task == 2:
            USE_filter_mediapipe_fail_swap = 1
            USE_pts = 1
            READ_mediapipe_result_from_cache = 1
        elif task == 3:
            USE_filter_mediapipe_fail_swap = 0
            USE_pts = 1
            READ_mediapipe_result_from_cache = 1
        else:
            # Previously an unknown task fell through and surfaced as a
            # confusing NameError below; fail fast with a clear message.
            raise ValueError(f"Unsupported task: {task!r} (expected 0, 1, 2 or 3)")
        self.READ_mediapipe_result_from_cache = READ_mediapipe_result_from_cache

        assert state == "test"
        self.state = state
        self.image_size = 512
        self.kernel = np.ones((1, 1), np.uint8)
        self.name = name

        assert paths_tgt is not None and paths_ref is not None, "paths_tgt and paths_ref are required"
        assert len(paths_tgt) == len(paths_ref), "paths_tgt and paths_ref must be the same length"
        self.paths_tgt = list(paths_tgt)
        self.paths_ref = list(paths_ref)

        if READ_mediapipe_result_from_cache:
            self.mediapipe_Result_Cache = Mediapipe_Result_Cache()
        self.task = task

    def __getitem__(self, index):
        """Assemble one inference sample.

        Returns the 4-tuple ``(image_tensor_resize, prior_image_tensor,
        ret, out_stem)`` where ``ret`` is the dict of model inputs
        (inpaint image/mask, reference tensors, per-task encoder inputs,
        optional cached Mediapipe landmarks) and ``out_stem`` names the
        output file from the target and reference stems.
        """
        task = self.task
        path_tgt = self.paths_tgt[index]
        path_ref = self.paths_ref[index]

        img_tgt = Image.open(path_tgt).convert("RGB")
        img_tgt = resize_A(img_tgt, self.name)

        # Select which semantic labels of the target are preserved.
        # NOTE(review): label IDs presumably follow a face-parsing scheme
        # (skin, brows, eyes, hair, ...) — confirm against the mask source.
        mask_path = path_img_2_path_mask(path_tgt)
        if self.task == 0:
            preserve = [1, 2, 3, 10, 5, 6, 7, 9]
            if 0:  # disabled author toggle: smaller label set
                preserve = [1, 2, 3, 10, 5]
            sm_mask_tgt = Image.open(mask_path).convert("L")
            sm_mask_tgt = np.array(sm_mask_tgt)
            if 0:  # disabled author toggle: dilated mask variant
                mask_tgt = dilate_4_task0(sm_mask_tgt)
            else:
                mask_tgt = np.isin(sm_mask_tgt, preserve)
                if self.task == 0 and 1:  # ``and 1``: always-on toggle
                    forehead_mask = get_forehead_mask(sm_mask_tgt)
                    mask_tgt = mask_tgt & ~forehead_mask
        elif self.task == 1:
            preserve = [4]
            mask_tgt = path_img_2_mask(path_tgt, preserve)
        elif self.task == 3:
            preserve = [1, 2, 3, 10, 4, 5, 6, 7, 9]
            mask_tgt = path_img_2_mask(path_tgt, preserve)
        elif self.task == 2:
            preserve = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 20, 21]
            sm_mask_tgt = Image.open(mask_path).convert("L")
            sm_mask_tgt = np.array(sm_mask_tgt)
            mask_tgt = np.isin(sm_mask_tgt, preserve)

        # Boolean mask -> 0/255 "L" image -> tensor; inverted so that
        # preserved pixels are 0 and editable pixels are 1 (pre-flip below).
        converted_mask = np.zeros_like(mask_tgt)
        converted_mask[mask_tgt] = 255
        mask_tgt = Image.fromarray(converted_mask).convert("L")
        mask_tensor = 1 - get_tensor(normalize=False, toTensor=True)(mask_tgt)

        image_tensor = get_tensor(mean=self.mean, std=self.std)(img_tgt)
        image_tensor_resize = T.Resize([self.image_size, self.image_size])(image_tensor)
        mask_tensor_resize = T.Resize([self.image_size, self.image_size])(mask_tensor)

        if task == 2:
            # Task 2 feeds the full (unmasked) target as the inpaint input.
            inpaint_tensor_resize = image_tensor_resize
        else:
            inpaint_tensor_resize = image_tensor_resize * mask_tensor_resize
        if 1:  # always-on toggle: flip so mask==1 marks the preserved region
            mask_tensor_resize = 1 - mask_tensor_resize

        if 1:  # always-on toggle: load the reference image + semantic mask
            mask_path_ref = path_img_2_path_mask(path_ref)
            sm_mask_ref = Image.open(mask_path_ref).convert("L")
            sm_mask_ref = np.array(sm_mask_ref)
            img_ref = cv2.imread(str(path_ref))
            img_ref = cv2.cvtColor(img_ref, cv2.COLOR_BGR2RGB)
            img_ref = resize_A(img_ref, self.name)

        if task != 2:
            ref_image_tensor, ref_mask_tensor = self.get_img4clip(
                img_ref, sm_mask_ref, preserve, for_clip=True, add_semantic_head=0
            )
            if task == 3:
                # Task 3 additionally needs a face-only (non-CLIP) reference.
                ref_image_faceOnly_tensor, _ = self.get_img4clip(
                    img_ref,
                    sm_mask_ref,
                    [1, 2, 3, 10, 5, 6, 7, 9],
                    for_clip=False,
                    add_semantic_head=0,
                )
        else:
            # Task 2 reuses the target as its own "reference".
            ref_image_tensor = inpaint_tensor_resize

        ret = {
            "inpaint_image": inpaint_tensor_resize,
            "inpaint_mask": mask_tensor_resize,
            "ref_imgs": ref_image_tensor,
            "task": self.task,
        }

        # Per-task encoder inputs (keys consumed by the conditioning heads).
        if self.task == 0:
            ret["enInputs"] = {
                "face_ID-in": ref_image_tensor,
                "face-clip-in": ref_image_tensor,
            }
        elif self.task == 1:
            ret["enInputs"] = {
                "hair-clip-in": ref_image_tensor,
            }
        elif self.task == 2:
            tgt_nonBg_tensor, _ = self.get_img4clip(img_tgt, sm_mask_tgt, preserve)
            ret["enInputs"] = {
                "face_ID-in": tgt_nonBg_tensor,
                "head-clip-in": tgt_nonBg_tensor,
            }
        elif self.task == 3:
            ret["enInputs"] = {
                "face_ID-in": ref_image_faceOnly_tensor,
                "head-clip-in": ref_image_tensor,
            }

        # Reference-UNet inputs (REFNET / CH14 come from confs).
        if (REFNET.ENABLE and REFNET.task2layerNum[task] > 0) or CH14:
            if task != 2:
                ref_imgs_4unet, ref_mask_4unet = self.get_img4clip(
                    img_ref, sm_mask_ref, preserve, for_clip=False, add_semantic_head=0
                )
            else:
                # Task 2: whole target image, all-ones mask.
                ref_imgs_4unet, ref_mask_4unet = self.get_img4clip(
                    img_tgt,
                    sm_mask_tgt,
                    "any",
                    for_clip=False,
                    add_semantic_head=0,
                    mask_after_npisin=np.ones_like(sm_mask_tgt).astype(bool),
                )
            ref_imgs_4unet = T.Resize([self.image_size, self.image_size])(ref_imgs_4unet)
            ref_mask_512 = T.Resize([self.image_size, self.image_size])(ref_mask_4unet)
            ret["ref_imgs_4unet"] = ref_imgs_4unet
            ret["ref_mask_512"] = ref_mask_512

        if self.READ_mediapipe_result_from_cache:
            # Landmarks are taken from the reference for task 2, else from
            # the target (state is always "test" per __init__).
            if self.state == "test":
                if task == 2:
                    _p_lmk = path_ref
                else:
                    _p_lmk = path_tgt
            else:
                _p_lmk = path_tgt
            ret["mediapipe_lmkAll"] = self.mediapipe_Result_Cache.get(_p_lmk)
            if ret["mediapipe_lmkAll"] is None:
                raise RuntimeError(
                    f"Missing Mediapipe cache for input image: {_p_lmk}. "
                    "Precompute landmarks and ensure cache exists before inference."
                )

        if self.state == "test":
            # The string "None" (not the None object) is the expected
            # placeholder for the prior image downstream.
            prior_image_tensor = "None"
            out_stem = f"{Path(path_tgt).stem}-{Path(path_ref).stem}"
            if task == 2:
                ref512, _ = self.get_img4clip(
                    img_ref, sm_mask_ref, preserve, for_clip=False, add_semantic_head=0
                )
                ref512 = T.Resize([self.image_size, self.image_size])(ref512)
                ret["ref512"] = ref512
            ret = (image_tensor_resize, prior_image_tensor, ret, out_stem)
        return ret

    def __len__(self):
        """Number of target/reference pairs."""
        return len(self.paths_tgt)
|
|