# UniBioTransfer / Dataset_custom.py
from confs import *
from pathlib import Path
import numpy as np
import cv2
from PIL import Image
import torch
import torch.utils.data as data
import torchvision.transforms as T
from einops import rearrange
import albumentations
from util_face import *
from util_4dataset import *
from util_cv2 import cv2_resize_auto_interpolation
from Mediapipe_Result_Cache import Mediapipe_Result_Cache
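# NOTE: the wildcard imports above (confs, util_face, util_4dataset) are repo-local
# modules; they are assumed to supply the helpers used below without explicit
# imports, e.g. get_tensor, get_tensor_clip, path_img_2_mask, path_img_2_path_mask,
# get_forehead_mask, add_colorSM, mask_after_npisin__2__tensor, REFNET, and CH14.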
def resize_A(img, dataset_name, size=(512, 512), interpolation=None):
    """Resize a PIL or numpy image to `size`, preserving the input type."""
    is_pil = isinstance(img, Image.Image)
    if is_pil:
        img = np.array(img)
    if img.shape[:2] != tuple(size):  # compare against the requested size, not a hard-coded one
        img = cv2_resize_auto_interpolation(img, size, interpolation=interpolation)
    if is_pil:
        img = Image.fromarray(img)
    return img
def un_norm_clip(x1):
    """Undo CLIP image normalization on a CHW or NCHW tensor."""
    x = x1 * 1.0
    reduce = False
    if len(x.shape) == 3:
        x = x.unsqueeze(0)
        reduce = True
    # Per-channel std/mean from the standard OpenAI CLIP preprocessing.
    x[:, 0, :, :] = x[:, 0, :, :] * 0.26862954 + 0.48145466
    x[:, 1, :, :] = x[:, 1, :, :] * 0.26130258 + 0.4578275
    x[:, 2, :, :] = x[:, 2, :, :] * 0.27577711 + 0.40821073
    if reduce:
        x = x.squeeze(0)
    return x
def un_norm(x):
    """Map a [-1, 1]-normalized tensor back to [0, 1]."""
    return (x + 1.0) / 2.0
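# Quick sanity check (illustrative):
#   >>> un_norm(torch.tensor([-1.0, 0.0, 1.0]))
#   tensor([0.0000, 0.5000, 1.0000])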
def _dilate(_mask, kernel_size, iterations):
    """Dilate a boolean mask with a square kernel and return it as bool again."""
    _mask = _mask.astype(np.uint8)
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    _mask = cv2.dilate(_mask, kernel, iterations=iterations)
    return _mask.astype(bool)
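# Illustrative: a single True pixel dilated with a 3x3 kernel grows to a 3x3 block.
#   >>> m = np.zeros((5, 5), dtype=bool); m[2, 2] = True
#   >>> _dilate(m, 3, 1).sum()
#   9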
def dilate_4_task0(sm_mask):
    """Alternative task-0 mask: dilate three semantic-label groups with different
    strengths and union them. Disabled in the active task paths (kept there as a
    commented-out option), which use a plain np.isin mask instead."""
    sm_mask = np.array(sm_mask)
    preserve1 = [2, 3, 10, 5]
    mask1 = _dilate(np.isin(sm_mask, preserve1), 7, 1)
    preserve2 = [3, 10]
    mask2 = _dilate(np.isin(sm_mask, preserve2), 10, 3)
    preserve3 = [1]
    mask3 = _dilate(np.isin(sm_mask, preserve3), 7, 2)
    return mask1 | mask2 | mask3
class Dataset_custom(data.Dataset):
    # Normalization statistics for non-CLIP tensors ([-1, 1] range).
    mean = (0.5, 0.5, 0.5)
    std = (0.5, 0.5, 0.5)
    def get_img4clip(
        self,
        img,
        sm_mask,
        preserve,
        for_clip=True,
        add_semantic_head=False,
        mask_after_npisin=None,
        for_inpaint512=False,
    ):
        """Mask `img` down to the semantic labels in `preserve`, then return it as a
        CLIP-normalized 224x224 tensor (for_clip=True) or a [-1, 1] 512x512 tensor,
        together with the binary mask tensor."""
        sm_mask = np.array(sm_mask)
        if mask_after_npisin is None:
            # Disabled alternative: mask = dilate_4_task0(sm_mask)
            mask = np.isin(sm_mask, preserve)
            if self.task == 0 and for_inpaint512:
                # Task 0: exclude the forehead from the preserved region.
                forehead_mask = get_forehead_mask(sm_mask)
                mask = mask & ~forehead_mask
        else:
            mask = mask_after_npisin
        if isinstance(img, np.ndarray):
            img = Image.fromarray(img)
        if add_semantic_head:
            mask_before_colorSM = mask
            img, mask = add_colorSM(img, sm_mask, preserve, None)
        mask = mask_after_npisin__2__tensor(mask)
        if for_clip:
            image_tensor = get_tensor_clip()(img)
        else:
            image_tensor = get_tensor(mean=self.mean, std=self.std)(img)
        image_tensor = T.Resize([512, 512])(image_tensor)
        image_tensor = image_tensor * mask
        # Un-normalize to uint8 HWC, resize to the target resolution, then
        # re-normalize and re-apply the mask.
        if for_clip:
            image_tensor = 255.0 * rearrange(un_norm_clip(image_tensor), "c h w -> h w c").cpu().numpy()
            _size = 224
        else:
            image_tensor = 255.0 * rearrange(un_norm(image_tensor), "c h w -> h w c").cpu().numpy()
            _size = 512
        image_tensor = albumentations.Resize(height=_size, width=_size)(image=image_tensor)
        image_tensor = Image.fromarray(image_tensor["image"].astype(np.uint8))
        if for_clip:
            image_tensor = get_tensor_clip()(image_tensor)
        else:
            image_tensor = get_tensor(mean=self.mean, std=self.std)(image_tensor)
        image_tensor = image_tensor * mask
        if add_semantic_head:
            # Report the pre-colorSM mask to the caller.
            mask = mask_after_npisin__2__tensor(mask_before_colorSM)
        return image_tensor, mask
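    # Illustrative call (the label IDs here are examples, not a canonical list):
    #   clip_tensor, mask = dataset.get_img4clip(img_ref, sm_mask_ref, [1, 2, 3], for_clip=True)
    # returns the 224x224 CLIP-normalized masked crop plus its binary mask tensor.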
    def __init__(
        self,
        state,
        task,
        paths_tgt,
        paths_ref,
        name="custom",
    ):
        # Per-task preprocessing flags; only the Mediapipe-cache flag is consumed here.
        if task == 0:
            USE_filter_mediapipe_fail_swap = 1
            USE_pts = 1
            READ_mediapipe_result_from_cache = 1
        elif task == 1:
            USE_filter_mediapipe_fail_swap = 0
            USE_pts = 0
            READ_mediapipe_result_from_cache = 1
        elif task == 2:
            USE_filter_mediapipe_fail_swap = 1
            USE_pts = 1
            READ_mediapipe_result_from_cache = 1
        elif task == 3:
            USE_filter_mediapipe_fail_swap = 0
            USE_pts = 1
            READ_mediapipe_result_from_cache = 1
        else:
            raise ValueError(f"Unknown task: {task}")
        self.READ_mediapipe_result_from_cache = READ_mediapipe_result_from_cache
        assert state == "test"
        self.state = state
        self.image_size = 512
        self.kernel = np.ones((1, 1), np.uint8)
        self.name = name
        assert paths_tgt is not None and paths_ref is not None, "paths_tgt and paths_ref are required"
        assert len(paths_tgt) == len(paths_ref), "paths_tgt and paths_ref must be the same length"
        self.paths_tgt = list(paths_tgt)
        self.paths_ref = list(paths_ref)
        if READ_mediapipe_result_from_cache:
            self.mediapipe_Result_Cache = Mediapipe_Result_Cache()
        self.task = task
    def __getitem__(self, index):
        task = self.task
        path_tgt = self.paths_tgt[index]
        path_ref = self.paths_ref[index]
        img_tgt = Image.open(path_tgt).convert("RGB")
        img_tgt = resize_A(img_tgt, self.name)
        mask_path = path_img_2_path_mask(path_tgt)
        # Build the target-side preserve mask for the current task.
        if self.task == 0:
            preserve = [1, 2, 3, 10, 5, 6, 7, 9]
            # Disabled alternative: preserve = [1, 2, 3, 10, 5]
            sm_mask_tgt = Image.open(mask_path).convert("L")
            sm_mask_tgt = np.array(sm_mask_tgt)
            # Disabled alternative: mask_tgt = dilate_4_task0(sm_mask_tgt)
            mask_tgt = np.isin(sm_mask_tgt, preserve)
            # Exclude the forehead from the preserved region.
            forehead_mask = get_forehead_mask(sm_mask_tgt)
            mask_tgt = mask_tgt & ~forehead_mask
        elif self.task == 1:
            preserve = [4]
            mask_tgt = path_img_2_mask(path_tgt, preserve)
        elif self.task == 3:
            preserve = [1, 2, 3, 10, 4, 5, 6, 7, 9]
            mask_tgt = path_img_2_mask(path_tgt, preserve)
        elif self.task == 2:
            preserve = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 20, 21]
            sm_mask_tgt = Image.open(mask_path).convert("L")
            sm_mask_tgt = np.array(sm_mask_tgt)
            mask_tgt = np.isin(sm_mask_tgt, preserve)
        # 255 inside the preserved region, 0 elsewhere (uint8 so PIL can read it).
        converted_mask = np.zeros_like(mask_tgt, dtype=np.uint8)
        converted_mask[mask_tgt] = 255
        mask_tgt = Image.fromarray(converted_mask).convert("L")
        mask_tensor = 1 - get_tensor(normalize=False, toTensor=True)(mask_tgt)
        image_tensor = get_tensor(mean=self.mean, std=self.std)(img_tgt)
        image_tensor_resize = T.Resize([self.image_size, self.image_size])(image_tensor)
        mask_tensor_resize = T.Resize([self.image_size, self.image_size])(mask_tensor)
        if task == 2:
            inpaint_tensor_resize = image_tensor_resize
        else:
            inpaint_tensor_resize = image_tensor_resize * mask_tensor_resize
        # Flip so that 1 marks the region to inpaint.
        mask_tensor_resize = 1 - mask_tensor_resize
        # Reference-side inputs.
        mask_path_ref = path_img_2_path_mask(path_ref)
        sm_mask_ref = Image.open(mask_path_ref).convert("L")
        sm_mask_ref = np.array(sm_mask_ref)
        img_ref = cv2.imread(str(path_ref))
        img_ref = cv2.cvtColor(img_ref, cv2.COLOR_BGR2RGB)
        img_ref = resize_A(img_ref, self.name)
        if task != 2:
            ref_image_tensor, ref_mask_tensor = self.get_img4clip(
                img_ref, sm_mask_ref, preserve, for_clip=True, add_semantic_head=0
            )
            if task == 3:
                # Face-only reference (drops label 4) for the ID branch.
                ref_image_faceOnly_tensor, _ = self.get_img4clip(
                    img_ref,
                    sm_mask_ref,
                    [1, 2, 3, 10, 5, 6, 7, 9],
                    for_clip=False,
                    add_semantic_head=0,
                )
        else:
            ref_image_tensor = inpaint_tensor_resize
        ret = {
            "inpaint_image": inpaint_tensor_resize,
            "inpaint_mask": mask_tensor_resize,
            "ref_imgs": ref_image_tensor,
            "task": self.task,
        }
        # Per-task encoder inputs.
        if self.task == 0:
            ret["enInputs"] = {
                "face_ID-in": ref_image_tensor,
                "face-clip-in": ref_image_tensor,
            }
        elif self.task == 1:
            ret["enInputs"] = {
                "hair-clip-in": ref_image_tensor,
            }
        elif self.task == 2:
            tgt_nonBg_tensor, _ = self.get_img4clip(img_tgt, sm_mask_tgt, preserve)
            ret["enInputs"] = {
                "face_ID-in": tgt_nonBg_tensor,
                "head-clip-in": tgt_nonBg_tensor,
            }
        elif self.task == 3:
            ret["enInputs"] = {
                "face_ID-in": ref_image_faceOnly_tensor,
                "head-clip-in": ref_image_tensor,
            }
        if (REFNET.ENABLE and REFNET.task2layerNum[task] > 0) or CH14:
            if task != 2:
                ref_imgs_4unet, ref_mask_4unet = self.get_img4clip(
                    img_ref, sm_mask_ref, preserve, for_clip=False, add_semantic_head=0
                )
            else:
                # Task 2: feed the whole target image (all-True mask) to the U-Net branch.
                ref_imgs_4unet, ref_mask_4unet = self.get_img4clip(
                    img_tgt,
                    sm_mask_tgt,
                    "any",
                    for_clip=False,
                    add_semantic_head=0,
                    mask_after_npisin=np.ones_like(sm_mask_tgt).astype(bool),
                )
            ref_imgs_4unet = T.Resize([self.image_size, self.image_size])(ref_imgs_4unet)
            ref_mask_512 = T.Resize([self.image_size, self.image_size])(ref_mask_4unet)
            ret["ref_imgs_4unet"] = ref_imgs_4unet
            ret["ref_mask_512"] = ref_mask_512
        if self.READ_mediapipe_result_from_cache:
            # state is always "test" here; task 2 takes landmarks from the reference image.
            _p_lmk = path_ref if task == 2 else path_tgt
            ret["mediapipe_lmkAll"] = self.mediapipe_Result_Cache.get(_p_lmk)
            if ret["mediapipe_lmkAll"] is None:
                raise RuntimeError(
                    f"Missing Mediapipe cache for input image: {_p_lmk}. "
                    "Precompute landmarks and ensure cache exists before inference."
                )
        if self.state == "test":
            prior_image_tensor = "None"  # string sentinel; no prior image at test time
            out_stem = f"{Path(path_tgt).stem}-{Path(path_ref).stem}"
            if task == 2:
                ref512, _ = self.get_img4clip(
                    img_ref, sm_mask_ref, preserve, for_clip=False, add_semantic_head=0
                )
                ref512 = T.Resize([self.image_size, self.image_size])(ref512)
                ret["ref512"] = ref512
            ret = (image_tensor_resize, prior_image_tensor, ret, out_stem)
        return ret
    def __len__(self):
        return len(self.paths_tgt)
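# Minimal usage sketch (illustrative; the image paths below are placeholders and
# assume matching semantic masks plus a precomputed Mediapipe landmark cache):
if __name__ == "__main__":
    dataset = Dataset_custom(
        state="test",
        task=0,  # one of the four task IDs handled above
        paths_tgt=["/path/to/target.png"],
        paths_ref=["/path/to/reference.png"],
    )
    image_tensor, prior, batch, out_stem = dataset[0]
    print(out_stem, batch["inpaint_image"].shape, batch["inpaint_mask"].shape)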