from confs import *
from pathlib import Path
import numpy as np
import cv2
from PIL import Image
import torch
import torch.utils.data as data
import torchvision.transforms as T
from einops import rearrange
import albumentations
from util_face import *
from util_4dataset import *
from util_cv2 import cv2_resize_auto_interpolation
from Mediapipe_Result_Cache import Mediapipe_Result_Cache

# Per-channel (R, G, B) mean / std that ``un_norm_clip`` inverts.
_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

# Semantic-mask label ids treated as the "preserve" region, per task.
_TASK_PRESERVE_LABELS = {
    0: [1, 2, 3, 10, 5, 6, 7, 9],
    1: [4],
    2: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 20, 21],
    3: [1, 2, 3, 10, 4, 5, 6, 7, 9],
}

# Per-task flags:
#   (USE_filter_mediapipe_fail_swap, USE_pts, READ_mediapipe_result_from_cache)
# Only the last flag is consumed by ``Dataset_custom``; the others record the
# original per-task configuration.
_TASK_FLAGS = {
    0: (1, 1, 1),
    1: (0, 0, 1),
    2: (1, 1, 1),
    3: (0, 1, 1),
}


def resize_A(img, dataset_name, size=(512, 512), interpolation=None):
    """Resize ``img`` to ``size`` unless it already has that spatial size.

    Accepts a PIL image or a numpy array and returns the same kind.

    Args:
        img: PIL image or numpy array (H, W[, C]).
        dataset_name: unused; kept for call-site compatibility.
        size: target size forwarded to ``cv2_resize_auto_interpolation``.
            NOTE(review): assumed to follow the cv2 ``(width, height)``
            convention — confirm against ``util_cv2``.
        interpolation: forwarded to ``cv2_resize_auto_interpolation``.
    """
    is_pil = isinstance(img, Image.Image)
    if is_pil:
        img = np.array(img)
    # Compare against the *requested* size (previously a hard-coded 512x512,
    # which silently ignored a non-default ``size`` for 512x512 inputs).
    if isinstance(size, int):
        target_hw = (size, size)
    else:
        target_hw = (size[1], size[0])
    if img.shape[:2] != target_hw:
        img = cv2_resize_auto_interpolation(img, size, interpolation=interpolation)
    if is_pil:
        img = Image.fromarray(img)
    return img


def un_norm_clip(x1):
    """Invert CLIP mean/std normalisation on a (C, H, W) or (N, C, H, W) tensor.

    The input is not modified; a new (float) tensor is returned.
    """
    x = x1 * 1.0  # copy (and promote to float) so the caller's tensor is untouched
    reduce = x.dim() == 3
    if reduce:
        x = x.unsqueeze(0)
    for channel, (mean, std) in enumerate(zip(_CLIP_MEAN, _CLIP_STD)):
        x[:, channel, :, :] = x[:, channel, :, :] * std + mean
    if reduce:
        x = x.squeeze(0)
    return x


def un_norm(x):
    """Map values normalised with mean=0.5, std=0.5 (range [-1, 1]) back to [0, 1]."""
    return (x + 1.0) / 2.0


def _dilate(_mask, kernel_size, iterations):
    """Binary-dilate ``_mask`` with a square ``kernel_size`` kernel; return a bool array."""
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    dilated = cv2.dilate(_mask.astype(np.uint8), kernel, iterations=iterations)
    return dilated.astype(bool)


def dilate_4_task0(sm_mask):
    """Build a dilated keep-mask from a semantic label map for task 0.

    The union of three dilated label groups:
    labels {2, 3, 10, 5} (7x7 kernel, 1 iteration), labels {3, 10}
    (10x10 kernel, 3 iterations) and label {1} (7x7 kernel, 2 iterations).
    """
    sm_mask = np.array(sm_mask)
    mask1 = _dilate(np.isin(sm_mask, [2, 3, 10, 5]), 7, 1)
    mask2 = _dilate(np.isin(sm_mask, [3, 10]), 10, 3)
    mask3 = _dilate(np.isin(sm_mask, [1]), 7, 2)
    return mask1 | mask2 | mask3


class Dataset_custom(data.Dataset):
    """Test-time dataset pairing target images with reference images.

    Each item is built from ``paths_tgt[i]`` / ``paths_ref[i]`` plus the
    semantic label maps located via ``path_img_2_path_mask``. Behaviour
    depends on ``task`` (0-3), which selects the label set to preserve and
    which encoder inputs are produced.
    """

    mean = (0.5, 0.5, 0.5)
    std = (0.5, 0.5, 0.5)

    def get_img4clip(
        self,
        img,
        sm_mask,
        preserve,
        for_clip=True,
        add_semantic_head=False,
        mask_after_npisin=None,
        for_inpaint512=False,
    ):
        """Blank ``img`` outside the ``preserve`` labels and return it as a tensor.

        Args:
            img: RGB image (PIL image or numpy array).
            sm_mask: semantic label map (anything ``np.array`` accepts).
            preserve: label ids to keep; ignored when ``mask_after_npisin``
                is given.
            for_clip: if true, normalise with ``get_tensor_clip`` and output
                224x224; otherwise use ``get_tensor(mean, std)`` and 512x512.
            add_semantic_head: if true, ``img`` is first passed through
                ``add_colorSM`` (which also supplies the mask used for
                blanking); the returned mask still comes from the label mask.
            mask_after_npisin: precomputed boolean keep-mask; overrides
                ``preserve``.
            for_inpaint512: for task 0, also removes the region returned by
                ``get_forehead_mask`` from the keep-mask.

        Returns:
            ``(image_tensor, mask)`` where ``mask`` is the output of
            ``mask_after_npisin__2__tensor``.
        """
        sm_mask = np.array(sm_mask)
        if mask_after_npisin is None:
            mask = np.isin(sm_mask, preserve)
            if self.task == 0 and for_inpaint512:
                mask = mask & ~get_forehead_mask(sm_mask)
        else:
            mask = mask_after_npisin

        if isinstance(img, np.ndarray):
            img = Image.fromarray(img)
        if add_semantic_head:
            mask_before_colorSM = mask
            img, mask = add_colorSM(img, sm_mask, preserve, None)
        mask = mask_after_npisin__2__tensor(mask)

        if for_clip:
            image_tensor = get_tensor_clip()(img)
        else:
            image_tensor = get_tensor(mean=self.mean, std=self.std)(img)
        image_tensor = T.Resize([512, 512])(image_tensor)
        image_tensor = image_tensor * mask

        # Round-trip through uint8 so the blanked image is resized and
        # re-normalised exactly like a regular input image.
        if for_clip:
            image_np = 255.0 * rearrange(un_norm_clip(image_tensor), "c h w -> h w c").cpu().numpy()
            out_size = 224
        else:
            image_np = 255.0 * rearrange(un_norm(image_tensor), "c h w -> h w c").cpu().numpy()
            out_size = 512
        image_np = albumentations.Resize(height=out_size, width=out_size)(image=image_np)["image"]
        image_pil = Image.fromarray(image_np.astype(np.uint8))

        if for_clip:
            image_tensor = get_tensor_clip()(image_pil)
        else:
            image_tensor = get_tensor(mean=self.mean, std=self.std)(image_pil)
            # Re-zero the blanked region: the uint8 round-trip presumably turns
            # exact zeros into small non-zero values after re-normalisation.
            # NOTE(review): nesting reconstructed from whitespace-collapsed
            # source — on the CLIP path the 224x224 tensor would not broadcast
            # against a 512x512 mask, so the re-mask is applied here only.
            image_tensor = image_tensor * mask

        if add_semantic_head:
            mask = mask_after_npisin__2__tensor(mask_before_colorSM)
        return image_tensor, mask

    def __init__(
        self,
        state,
        task,
        paths_tgt,
        paths_ref,
        name="custom",
    ):
        """Create the dataset.

        Args:
            state: must be ``"test"``.
            task: task id, one of 0, 1, 2, 3.
            paths_tgt: target image paths.
            paths_ref: reference image paths, same length as ``paths_tgt``.
            name: dataset name, forwarded to ``resize_A``.

        Raises:
            ValueError: if ``task`` is not a supported task id.
        """
        if task not in _TASK_FLAGS:
            raise ValueError(f"Unsupported task {task!r}; expected one of {sorted(_TASK_FLAGS)}")
        _use_filter_mediapipe_fail_swap, _use_pts, read_cache = _TASK_FLAGS[task]
        self.READ_mediapipe_result_from_cache = read_cache
        assert state == "test"
        self.state = state
        self.image_size = 512
        self.kernel = np.ones((1, 1), np.uint8)
        self.name = name
        assert paths_tgt is not None and paths_ref is not None, "paths_tgt and paths_ref are required"
        assert len(paths_tgt) == len(paths_ref), "paths_tgt and paths_ref must be the same length"
        self.paths_tgt = list(paths_tgt)
        self.paths_ref = list(paths_ref)
        if read_cache:
            self.mediapipe_Result_Cache = Mediapipe_Result_Cache()
        self.task = task

    def _load_target_mask(self, path_tgt):
        """Return ``(preserve, mask_tgt, sm_mask_tgt)`` for a target image.

        ``preserve`` is a fresh copy of the task's label list and ``mask_tgt``
        is the keep-mask over those labels. ``sm_mask_tgt`` is the raw label
        map; it is loaded only for tasks 0 and 2 and is ``None`` otherwise.
        """
        task = self.task
        if task not in _TASK_PRESERVE_LABELS:
            raise ValueError(f"Unsupported task {task!r}")
        preserve = list(_TASK_PRESERVE_LABELS[task])
        sm_mask_tgt = None
        if task in (0, 2):
            with Image.open(path_img_2_path_mask(path_tgt)) as im:
                sm_mask_tgt = np.array(im.convert("L"))
            mask_tgt = np.isin(sm_mask_tgt, preserve)
            if task == 0:
                mask_tgt = mask_tgt & ~get_forehead_mask(sm_mask_tgt)
        else:
            mask_tgt = path_img_2_mask(path_tgt, preserve)
        return preserve, mask_tgt, sm_mask_tgt

    def __getitem__(self, index):
        """Build one sample.

        In ``"test"`` state returns
        ``(image_tensor_resize, "None", ret, out_stem)`` where ``ret`` is a
        dict with ``inpaint_image``, ``inpaint_mask``, ``ref_imgs``, ``task``,
        ``enInputs`` and, depending on configuration, ``ref_imgs_4unet``,
        ``ref_mask_512``, ``mediapipe_lmkAll`` and ``ref512``.

        Raises:
            FileNotFoundError: if the reference image cannot be read.
            RuntimeError: if the Mediapipe landmark cache has no entry.
        """
        task = self.task
        path_tgt = self.paths_tgt[index]
        path_ref = self.paths_ref[index]

        # ---- target image and its preserve mask ----
        with Image.open(path_tgt) as im:
            img_tgt = im.convert("RGB")
        img_tgt = resize_A(img_tgt, self.name)

        preserve, mask_tgt, sm_mask_tgt = self._load_target_mask(path_tgt)

        # Explicit uint8 0/255 image instead of a bool-dtype array fed to PIL.
        mask_u8 = np.where(np.asarray(mask_tgt).astype(bool), 255, 0).astype(np.uint8)
        mask_tgt = Image.fromarray(mask_u8).convert("L")

        # 1 outside the preserved labels, 0 on them.
        mask_tensor = 1 - get_tensor(normalize=False, toTensor=True)(mask_tgt)
        image_tensor = get_tensor(mean=self.mean, std=self.std)(img_tgt)
        image_tensor_resize = T.Resize([self.image_size, self.image_size])(image_tensor)
        mask_tensor_resize = T.Resize([self.image_size, self.image_size])(mask_tensor)

        if task == 2:
            inpaint_tensor_resize = image_tensor_resize
        else:
            # Zero out the preserved-label region of the target.
            inpaint_tensor_resize = image_tensor_resize * mask_tensor_resize
        # Flip so "inpaint_mask" is 1 on the preserved-label region.
        mask_tensor_resize = 1 - mask_tensor_resize

        # ---- reference image ----
        with Image.open(path_img_2_path_mask(path_ref)) as im:
            sm_mask_ref = np.array(im.convert("L"))
        img_ref = cv2.imread(str(path_ref))
        if img_ref is None:
            raise FileNotFoundError(f"Could not read reference image: {path_ref}")
        img_ref = cv2.cvtColor(img_ref, cv2.COLOR_BGR2RGB)
        img_ref = resize_A(img_ref, self.name)

        if task != 2:
            ref_image_tensor, _ = self.get_img4clip(
                img_ref, sm_mask_ref, preserve, for_clip=True, add_semantic_head=0
            )
            if task == 3:
                # Same label set as task 0.
                ref_image_faceOnly_tensor, _ = self.get_img4clip(
                    img_ref,
                    sm_mask_ref,
                    list(_TASK_PRESERVE_LABELS[0]),
                    for_clip=False,
                    add_semantic_head=0,
                )
        else:
            ref_image_tensor = inpaint_tensor_resize

        ret = {
            "inpaint_image": inpaint_tensor_resize,
            "inpaint_mask": mask_tensor_resize,
            "ref_imgs": ref_image_tensor,
            "task": self.task,
        }

        # ---- per-task encoder inputs ----
        if task == 0:
            ret["enInputs"] = {
                "face_ID-in": ref_image_tensor,
                "face-clip-in": ref_image_tensor,
            }
        elif task == 1:
            ret["enInputs"] = {
                "hair-clip-in": ref_image_tensor,
            }
        elif task == 2:
            tgt_nonBg_tensor, _ = self.get_img4clip(img_tgt, sm_mask_tgt, preserve)
            ret["enInputs"] = {
                "face_ID-in": tgt_nonBg_tensor,
                "head-clip-in": tgt_nonBg_tensor,
            }
        elif task == 3:
            ret["enInputs"] = {
                "face_ID-in": ref_image_faceOnly_tensor,
                "head-clip-in": ref_image_tensor,
            }

        # ---- optional reference inputs for the UNet (REFNET / CH14 from confs) ----
        if (REFNET.ENABLE and REFNET.task2layerNum[task] > 0) or CH14:
            if task != 2:
                ref_imgs_4unet, ref_mask_4unet = self.get_img4clip(
                    img_ref, sm_mask_ref, preserve, for_clip=False, add_semantic_head=0
                )
            else:
                ref_imgs_4unet, ref_mask_4unet = self.get_img4clip(
                    img_tgt,
                    sm_mask_tgt,
                    "any",
                    for_clip=False,
                    add_semantic_head=0,
                    mask_after_npisin=np.ones_like(sm_mask_tgt).astype(bool),
                )
            ret["ref_imgs_4unet"] = T.Resize([self.image_size, self.image_size])(ref_imgs_4unet)
            ret["ref_mask_512"] = T.Resize([self.image_size, self.image_size])(ref_mask_4unet)

        # ---- cached Mediapipe landmarks ----
        if self.READ_mediapipe_result_from_cache:
            # In test state, task 2 reads landmarks of the reference image.
            use_ref = self.state == "test" and task == 2
            lmk_path = path_ref if use_ref else path_tgt
            ret["mediapipe_lmkAll"] = self.mediapipe_Result_Cache.get(lmk_path)
            if ret["mediapipe_lmkAll"] is None:
                raise RuntimeError(
                    f"Missing Mediapipe cache for input image: {lmk_path}. "
                    "Precompute landmarks and ensure cache exists before inference."
                )

        if self.state == "test":
            # Placeholder string kept as-is; presumably expected by downstream code.
            prior_image_tensor = "None"
            out_stem = f"{Path(path_tgt).stem}-{Path(path_ref).stem}"
            if task == 2:
                ref512, _ = self.get_img4clip(
                    img_ref, sm_mask_ref, preserve, for_clip=False, add_semantic_head=0
                )
                ret["ref512"] = T.Resize([self.image_size, self.image_size])(ref512)
            ret = (image_tensor_resize, prior_image_tensor, ret, out_stem)
        return ret

    def __len__(self):
        """Number of (target, reference) pairs."""
        return len(self.paths_tgt)