diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..0f3678cc0cda886e0cbb1654c0c0f76f525eae9f 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,12 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +Other_dependencies/mp_models/face_landmarker_v2_with_blendshapes.task filter=lfs diff=lfs merge=lfs -text +examples/face/ref.png filter=lfs diff=lfs merge=lfs -text +examples/face/tgt.png filter=lfs diff=lfs merge=lfs -text +examples/hair/ref.png filter=lfs diff=lfs merge=lfs -text +examples/hair/tgt.png filter=lfs diff=lfs merge=lfs -text +examples/head/ref.png filter=lfs diff=lfs merge=lfs -text +examples/head/tgt.png filter=lfs diff=lfs merge=lfs -text +examples/motion/ref.png filter=lfs diff=lfs merge=lfs -text +examples/motion/tgt.png filter=lfs diff=lfs merge=lfs -text diff --git a/Dataset_custom.py b/Dataset_custom.py new file mode 100644 index 0000000000000000000000000000000000000000..647f9dff607fe1f986b146ff7f26a6bf9b11085f --- /dev/null +++ b/Dataset_custom.py @@ -0,0 +1,317 @@ +from imports import * +from pathlib import Path +import numpy as np +import cv2 +from PIL import Image +import torch +import torch.utils.data as data +import torchvision.transforms as T +from einops import rearrange +import albumentations + +from util_face import * +from util_4dataset import * +from util_cv2 import cv2_resize_auto_interpolation +from Mediapipe_Result_Cache import Mediapipe_Result_Cache + + +def resize_A(img, dataset_name, size=(512, 512), interpolation=None): + is_pil = isinstance(img, Image.Image) + if is_pil: + img = np.array(img) + if img.shape[:2] != (512, 512): + img = cv2_resize_auto_interpolation(img, size, interpolation=interpolation) + if is_pil: + img = Image.fromarray(img) + return img + + +def un_norm_clip(x1): + x = x1 * 1.0 + reduce = False + if len(x.shape) == 3: + x = x.unsqueeze(0) + reduce = True + x[:, 0, :, :] = x[:, 0, :, :] * 0.26862954 + 0.48145466 + x[:, 1, :, :] = x[:, 1, :, :] * 0.26130258 + 0.4578275 + x[:, 2, :, :] = x[:, 2, :, :] * 0.27577711 + 0.40821073 + if reduce: + x = x.squeeze(0) + return x + + +def un_norm(x): + return (x + 1.0) / 2.0 + + +def _dilate(_mask, kernel_size, iterations): + _mask = _mask.astype(np.uint8) + kernel = np.ones((kernel_size, kernel_size), np.uint8) + _mask = cv2.dilate(_mask, kernel, iterations=iterations) + _mask = _mask.astype(bool) + return _mask + + +def dilate_4_task0(sm_mask): + sm_mask = np.array(sm_mask) + preserve1 = [2, 3, 10, 5] + mask1 = np.isin(sm_mask, preserve1) + mask1 = _dilate(mask1, 7, 1) + preserve2 = [3, 10] + mask2 = np.isin(sm_mask, preserve2) + mask2 = _dilate(mask2, 10, 3) + preserve3 = [1] + mask3 = np.isin(sm_mask, preserve3) + mask3 = _dilate(mask3, 7, 2) + mask = mask1 | mask2 | mask3 + return mask + + +class Dataset_custom(data.Dataset): + mean = (0.5, 0.5, 0.5) + std = (0.5, 0.5, 0.5) + + def get_img4clip( + self, + img, + sm_mask, + preserve, + for_clip=True, + add_semantic_head=False, + mask_after_npisin=None, + for_inpaint512=False, + ): + sm_mask = np.array(sm_mask) + if mask_after_npisin is None: + if self.task == 0 and 0: + mask = dilate_4_task0(sm_mask) + else: + mask = np.isin(sm_mask, preserve) + if self.task == 0 and 1 and for_inpaint512: + forehead_mask = get_forehead_mask(sm_mask) + mask = mask & ~forehead_mask + else: + mask = mask_after_npisin + + if isinstance(img, np.ndarray): + img = Image.fromarray(img) + if add_semantic_head: + mask_before_colorSM = mask + img, mask = add_colorSM(img, sm_mask, preserve, None) + mask = mask_after_npisin__2__tensor(mask) + + if for_clip: + image_tensor = get_tensor_clip()(img) + else: + image_tensor = get_tensor(mean=self.mean, std=self.std)(img) + image_tensor = T.Resize([512, 512])(image_tensor) + image_tensor = image_tensor * mask + if for_clip: + image_tensor = 255.0 * rearrange(un_norm_clip(image_tensor), "c h w -> h w c").cpu().numpy() + _size = 224 + else: + image_tensor = 255.0 * rearrange(un_norm(image_tensor), "c h w -> h w c").cpu().numpy() + _size = 512 + + image_tensor = albumentations.Resize(height=_size, width=_size)(image=image_tensor) + image_tensor = Image.fromarray(image_tensor["image"].astype(np.uint8)) + if for_clip: + image_tensor = get_tensor_clip()(image_tensor) + else: + image_tensor = get_tensor(mean=self.mean, std=self.std)(image_tensor) + image_tensor = image_tensor * mask + if add_semantic_head: + mask = mask_after_npisin__2__tensor(mask_before_colorSM) + return image_tensor, mask + + def __init__( + self, + state, + task, + paths_tgt, + paths_ref, + name="custom", + ): + if task == 0: + USE_filter_mediapipe_fail_swap = 1 + USE_pts = 1 + READ_mediapipe_result_from_cache = 1 + elif task == 1: + USE_filter_mediapipe_fail_swap = 0 + USE_pts = 0 + READ_mediapipe_result_from_cache = 1 + elif task == 2: + USE_filter_mediapipe_fail_swap = 1 + USE_pts = 1 + READ_mediapipe_result_from_cache = 1 + elif task == 3: + USE_filter_mediapipe_fail_swap = 0 + USE_pts = 1 + READ_mediapipe_result_from_cache = 1 + self.READ_mediapipe_result_from_cache = READ_mediapipe_result_from_cache + + assert state == "test" + self.state = state + self.image_size = 512 + self.kernel = np.ones((1, 1), np.uint8) + self.name = name + + assert paths_tgt is not None and paths_ref is not None, "paths_tgt and paths_ref are required" + assert len(paths_tgt) == len(paths_ref), "paths_tgt and paths_ref must be the same length" + self.paths_tgt = list(paths_tgt) + self.paths_ref = list(paths_ref) + + if READ_mediapipe_result_from_cache: + self.mediapipe_Result_Cache = Mediapipe_Result_Cache() + self.task = task + + def __getitem__(self, index): + task = self.task + path_tgt = self.paths_tgt[index] + path_ref = self.paths_ref[index] + + + img_tgt = Image.open(path_tgt).convert("RGB") + img_tgt = resize_A(img_tgt, self.name) + + mask_path = path_img_2_path_mask(path_tgt) + if self.task == 0: + preserve = [1, 2, 3, 10, 5, 6, 7, 9] + if 0: + preserve = [1, 2, 3, 10, 5] + sm_mask_tgt = Image.open(mask_path).convert("L") + sm_mask_tgt = np.array(sm_mask_tgt) + if 0: + mask_tgt = dilate_4_task0(sm_mask_tgt) + else: + mask_tgt = np.isin(sm_mask_tgt, preserve) + if self.task == 0 and 1: + forehead_mask = get_forehead_mask(sm_mask_tgt) + mask_tgt = mask_tgt & ~forehead_mask + elif self.task == 1: + preserve = [4] + mask_tgt = path_img_2_mask(path_tgt, preserve) + elif self.task == 3: + preserve = [1, 2, 3, 10, 4, 5, 6, 7, 9] + mask_tgt = path_img_2_mask(path_tgt, preserve) + elif self.task == 2: + preserve = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 20, 21] + sm_mask_tgt = Image.open(mask_path).convert("L") + sm_mask_tgt = np.array(sm_mask_tgt) + mask_tgt = np.isin(sm_mask_tgt, preserve) + + converted_mask = np.zeros_like(mask_tgt) + converted_mask[mask_tgt] = 255 + mask_tgt = Image.fromarray(converted_mask).convert("L") + mask_tensor = 1 - get_tensor(normalize=False, toTensor=True)(mask_tgt) + + image_tensor = get_tensor(mean=self.mean, std=self.std)(img_tgt) + image_tensor_resize = T.Resize([self.image_size, self.image_size])(image_tensor) + mask_tensor_resize = T.Resize([self.image_size, self.image_size])(mask_tensor) + + if task == 2: + inpaint_tensor_resize = image_tensor_resize + else: + inpaint_tensor_resize = image_tensor_resize * mask_tensor_resize + if 1: + mask_tensor_resize = 1 - mask_tensor_resize + + if 1: + mask_path_ref = path_img_2_path_mask(path_ref) + sm_mask_ref = Image.open(mask_path_ref).convert("L") + sm_mask_ref = np.array(sm_mask_ref) + img_ref = cv2.imread(str(path_ref)) + img_ref = cv2.cvtColor(img_ref, cv2.COLOR_BGR2RGB) + img_ref = resize_A(img_ref, self.name) + + if task != 2: + ref_image_tensor, ref_mask_tensor = self.get_img4clip( + img_ref, sm_mask_ref, preserve, for_clip=True, add_semantic_head=0 + ) + if task == 3: + ref_image_faceOnly_tensor, _ = self.get_img4clip( + img_ref, + sm_mask_ref, + [1, 2, 3, 10, 5, 6, 7, 9], + for_clip=False, + add_semantic_head=0, + ) + else: + ref_image_tensor = inpaint_tensor_resize + + ret = { + "inpaint_image": inpaint_tensor_resize, + "inpaint_mask": mask_tensor_resize, + "ref_imgs": ref_image_tensor, + "task": self.task, + } + + if self.task == 0: + ret["enInputs"] = { + "face_ID-in": ref_image_tensor, + "face-clip-in": ref_image_tensor, + } + elif self.task == 1: + ret["enInputs"] = { + "hair-clip-in": ref_image_tensor, + } + elif self.task == 2: + tgt_nonBg_tensor, _ = self.get_img4clip(img_tgt, sm_mask_tgt, preserve) + ret["enInputs"] = { + "face_ID-in": tgt_nonBg_tensor, + "head-clip-in": tgt_nonBg_tensor, + } + elif self.task == 3: + ret["enInputs"] = { + "face_ID-in": ref_image_faceOnly_tensor, + "head-clip-in": ref_image_tensor, + } + + if (REFNET.ENABLE and REFNET.task2layerNum[task] > 0) or CH14: + if task != 2: + ref_imgs_4unet, ref_mask_4unet = self.get_img4clip( + img_ref, sm_mask_ref, preserve, for_clip=False, add_semantic_head=0 + ) + else: + ref_imgs_4unet, ref_mask_4unet = self.get_img4clip( + img_tgt, + sm_mask_tgt, + "any", + for_clip=False, + add_semantic_head=0, + mask_after_npisin=np.ones_like(sm_mask_tgt).astype(bool), + ) + ref_imgs_4unet = T.Resize([self.image_size, self.image_size])(ref_imgs_4unet) + ref_mask_512 = T.Resize([self.image_size, self.image_size])(ref_mask_4unet) + ret["ref_imgs_4unet"] = ref_imgs_4unet + ret["ref_mask_512"] = ref_mask_512 + + if self.READ_mediapipe_result_from_cache: + if self.state == "test": + if task == 2: + _p_lmk = path_ref + else: + _p_lmk = path_tgt + else: + _p_lmk = path_tgt + ret["mediapipe_lmkAll"] = self.mediapipe_Result_Cache.get(_p_lmk) + if ret["mediapipe_lmkAll"] is None: + raise RuntimeError( + f"Missing Mediapipe cache for input image: {_p_lmk}. " + "Precompute landmarks and ensure cache exists before inference." + ) + + if self.state == "test": + prior_image_tensor = "None" + out_stem = f"{Path(path_tgt).stem}-{Path(path_ref).stem}" + if task == 2: + ref512, _ = self.get_img4clip( + img_ref, sm_mask_ref, preserve, for_clip=False, add_semantic_head=0 + ) + ref512 = T.Resize([self.image_size, self.image_size])(ref512) + ret["ref512"] = ref512 + ret = (image_tensor_resize, prior_image_tensor, ret, out_stem) + return ret + + def __len__(self): + return len(self.paths_tgt) diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..a2b53eca07f8ce1d0e74dea9900f0dd2b7216a17 --- /dev/null +++ b/LICENSE @@ -0,0 +1,23 @@ +MIT License + +Copyright (c) 2024 Sanoojan + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + + diff --git a/LatentDiffusion.yaml b/LatentDiffusion.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e13076acdb53c5d41daa82b778bd685ad19cfa53 --- /dev/null +++ b/LatentDiffusion.yaml @@ -0,0 +1,83 @@ +model: + base_learning_rate: 4.0e-04 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "inpaint" + cond_stage_key: "image" + image_size: 64 + channels: 4 + cond_stage_trainable: true # Note: different from the one we trained before + conditioning_key: crossattn + monitor: val/loss_simple_ema + u_cond_percent: 0.2 + scale_factor: 0.18215 + use_ema: False + + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 10000 ] + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-1 ] + f_max: [ 1. ] + f_min: [ 1. ] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + add_conv_in_front_of_unet: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder + other_params: + clip_weight: 1.0 + arcface_path: "Other_dependencies/arcface/model_ir_se50.pth" + multi_scale_ID: False # True was used for the previous training there is an issue + Additional_config: + Reconstruct_initial: False # scy: + Target_CLIP_feat: True + Source_CLIP_feat: True + Reconstruct_DDIM_steps: 4 + + \ No newline at end of file diff --git a/Mediapipe_Result_Cache.py b/Mediapipe_Result_Cache.py new file mode 100644 index 0000000000000000000000000000000000000000..fe2cb7756b60f5bcd61551ce7ccecc023d7b0566 --- /dev/null +++ b/Mediapipe_Result_Cache.py @@ -0,0 +1,36 @@ +from imports import * +import json,random,os +import numpy as np + + + +class Mediapipe_Result_Cache: + """ + Convention: when a cache entry exists, it must not be None. + In other words, None results should not be cached; get/set guard against historical None values. + """ + # DIR = Path('/inspurfs/group/mayuexin/suncy/mediapipe_result/A') + DIR = Path('data/mediapipe_result') + def __init__(self): + pass + def get_path(self, img_path): + img_path = Path(img_path) + str_img_folder = str(img_path.parent) + assert '|' not in str_img_folder + str_img_folder = str_img_folder.replace('/', '|') + lmk_folder = self.DIR / str_img_folder + lmk_folder.mkdir(parents=1, exist_ok=True) + ret= lmk_folder / (img_path.name+'.npy') + return ret + def get(self, img_path): + path = self.get_path(img_path) + # print(f"[get] {path=}") + if path.exists(): + ret = np.load(path) + assert ret is not None + return ret + def set(self, img_path, lmks): + assert lmks is not None + path = self.get_path(img_path) + np.save(path, lmks) + # print(f"{path=}") diff --git a/MoE.py b/MoE.py new file mode 100644 index 0000000000000000000000000000000000000000..7659ecc3c0011f14b98293ddeeb793862c4ba271 --- /dev/null +++ b/MoE.py @@ -0,0 +1,141 @@ +from imports import * +import global_ +import torch,copy +import torch.nn as nn +from ldm.modules.attention import FeedForward,CrossAttention +from ldm.modules.diffusionmodules.openaimodel import UNetModel,ResBlock,TimestepEmbedSequential +# import torch.nn.functional as F + +# ---------------- Configs ---------------- +CONV2D_PARAM_STATS = [] + +def average_module_weight(src_modules: list): + """Average the weights of multiple modules (similar to init_model.py).""" + if not src_modules: + return None + avg_state_dict = {} + first_state_dict = src_modules[0].state_dict() + for key in first_state_dict: + avg_state_dict[key] = torch.zeros_like(first_state_dict[key]) + for module in src_modules: + module_state_dict = module.state_dict() + for key in avg_state_dict: + avg_state_dict[key] += module_state_dict[key] + for key in avg_state_dict: + avg_state_dict[key] /= len(src_modules) + return avg_state_dict + +class ModuleDict_W(nn.Module): # Wrapper of ModuleDict + def __init__(self, modules: list, keys: list): + super().__init__() + assert len(keys) == len(modules), f"{len(keys)=} {len(modules)=}" + self._keys = [int(k) for k in keys] + self._moduleDict = nn.ModuleDict({str(int(k)): m for k, m in zip(self._keys, modules)}) + def __getitem__(self, k: int): + _k = str(int(k)) + return self._moduleDict[_k] + def keys(self): + return list(self._keys) + def forward(self, *args, **kwargs): + cur_task = global_.task + assert cur_task in self._keys, f"Current task {cur_task} not in available tasks {self._keys}" + return self._moduleDict[str(int(cur_task))](*args, **kwargs) + def offload_unused_tasks(self, unused_tasks, method: str): + for i in unused_tasks: + _k = str(int(i)) + if _k in self._moduleDict: + if method == 'del': + # self._moduleDict[_k] = None # should behave the same either way + del self._moduleDict[_k] + elif method == 'cpu': + self._moduleDict[_k].to('cpu') + else: + raise + +class TaskSpecific_MoE(nn.Module): + def __init__( + self, + module:nn.Module,# or list of Module + tasks:tuple, + ): + super().__init__() + self.cur_task = None + self.tasks = tasks + if isinstance(module, nn.Module): + modules = [copy.deepcopy(module) for _ in self.tasks] + elif isinstance(module, list): + assert len(module) == len(self.tasks), f"got {len(module)} and {len(self.tasks)}" + modules = module + else: + raise ValueError(f"got {type(module)}") + self.tasks_2_module = ModuleDict_W(modules, self.tasks) + + def forward(self, *args, **kwargs) -> torch.Tensor: + # cur_task = self.cur_task + cur_task = global_.task + assert cur_task in self.tasks, f"Current task {cur_task} not in available tasks {self.tasks}" + return self.tasks_2_module[cur_task](*args, **kwargs) + + def set_task(self, task): + assert 0, 'set_task is disabled for now; update to gg.task instead' + # assert task in self.tasks, f"Task {task} not in available tasks {self.tasks}" + self.cur_task = task + +def is_task_specific_(name:str): + is_task_specific = ( + ('._moduleDict.' in name) or + ('tasks_2_module' in name) or + ('task_ffn' in name) or + ('task_proj' in name) or + ('task_conv' in name) or + ('task_gate_mlps' in name) or + ('task_lora' in name) or + + ('encoder_clip_' in name) or + ('proj_out_source__' in name) or + ('ID_proj_out' in name) or + ('landmark_proj_out' in name) or + ('learnable_vector' in name) + ) + return is_task_specific +def tp_param_need_sync(name: str, p: torch.nn.Parameter): + if is_task_specific_(name): + return False, True + if 'first_stage_model' in name or 'face_ID_model' in name or 'encoder_clip_face.tokenizer' in name or 'encoder_clip_face.model' in name: + return False, False + if not p.requires_grad: + return False, False + return True, False +def offload_unused_tasks(parent: nn.Module, active_task: int, method: str, ): + unused_tasks = [_t for _t in TASKS if _t != active_task] # inactive tasks + for name, child in parent.named_children(): + if hasattr(child, '__class__') and child.__class__.__name__ in [ + 'TaskSpecific_MoE', + 'FFN_TaskSpecific_Plus_Shared', + 'Linear_TaskSpecific_Plus_Shared', + 'Conv_TaskSpecific_Plus_Shared', + 'FFN_Shared_Plus_TaskLoRA', + 'Linear_Shared_Plus_TaskLoRA', + 'Conv_Shared_Plus_TaskLoRA', + ]: + for attr_name in [ # normalize attribute handling to avoid repetition + 'tasks_2_module', + 'task_ffn', 'task_proj', 'task_conv', + 'task_lora_in', 'task_lora_out', 'task_lora', + ]: + if hasattr(child, attr_name): + ml = getattr(child, attr_name) + if isinstance(ml, nn.ModuleList): + for i in unused_tasks: # move or delete parameters for inactive tasks + if method == 'del': + ml[i] = None + elif method == 'cpu': + ml[i].to('cpu') + else: raise Exception + elif isinstance(ml, ModuleDict_W): + ml.offload_unused_tasks(unused_tasks,method) + # recurse(child) + else: offload_unused_tasks(child,active_task,method) +def offload_unused_tasks__LD(modelMOE, task_keep: int, method: str, ): + # Remove or offload inactive task-related parameters to save CUDA memory (method: del|cpu) + offload_unused_tasks(modelMOE, task_keep, method) diff --git a/Other_dependencies/arcface/add.txt b/Other_dependencies/arcface/add.txt new file mode 100644 index 0000000000000000000000000000000000000000..e99aeafae94e9bb0808dadf455ebb1145034e628 --- /dev/null +++ b/Other_dependencies/arcface/add.txt @@ -0,0 +1 @@ +Add arcface model \ No newline at end of file diff --git a/Other_dependencies/arcface/model_ir_se50.pth b/Other_dependencies/arcface/model_ir_se50.pth new file mode 100644 index 0000000000000000000000000000000000000000..d3a030dd9a353d94023d3fc3a5baa0991ca3873b --- /dev/null +++ b/Other_dependencies/arcface/model_ir_se50.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a035c768259b98ab1ce0e646312f48b9e1e218197a0f80ac6765e88f8b6ddf28 +size 175367323 diff --git a/Other_dependencies/face_parsing/79999_iter.pth b/Other_dependencies/face_parsing/79999_iter.pth new file mode 100644 index 0000000000000000000000000000000000000000..ca57f3257ca7715bc340d065764bc249d985c287 --- /dev/null +++ b/Other_dependencies/face_parsing/79999_iter.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:468e13ca13a9b43cc0881a9f99083a430e9c0a38abd935431d1c28ee94b26567 +size 53289463 diff --git a/Other_dependencies/face_parsing/add.txt b/Other_dependencies/face_parsing/add.txt new file mode 100644 index 0000000000000000000000000000000000000000..d5f7f26e0efd5d06d6cae36bd0ff5dd74d9c6960 --- /dev/null +++ b/Other_dependencies/face_parsing/add.txt @@ -0,0 +1 @@ +Add face parsing model \ No newline at end of file diff --git a/Other_dependencies/mp_models/blaze_face_short_range.tflite b/Other_dependencies/mp_models/blaze_face_short_range.tflite new file mode 100644 index 0000000000000000000000000000000000000000..2645898ee18d8bf53746df830303779c9deabc7d --- /dev/null +++ b/Other_dependencies/mp_models/blaze_face_short_range.tflite @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b4578f35940bf5a1a655214a1cce5cab13eba73c1297cd78e1a04c2380b0152f +size 229746 diff --git a/Other_dependencies/mp_models/face_landmarker_v2_with_blendshapes.task b/Other_dependencies/mp_models/face_landmarker_v2_with_blendshapes.task new file mode 100644 index 0000000000000000000000000000000000000000..fedb14de6d2b6708a56c04ae259783e23404c1aa --- /dev/null +++ b/Other_dependencies/mp_models/face_landmarker_v2_with_blendshapes.task @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:64184e229b263107bc2b804c6625db1341ff2bb731874b0bcc2fe6544e0bc9ff +size 3758596 diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..2474f8427cb75f32f539d0e77f4c5927bacdcded --- /dev/null +++ b/app.py @@ -0,0 +1,239 @@ +""" +Hugging Face Space demo for UniBioTransfer. +Gradio interface for face/hair/motion/head transfer. + +ZeroGPU Compatible: +- Model initialized on CPU (no GPU memory during startup) +- Inference wrapped with @spaces.GPU decorator +- Thread-safe global variable access with Lock +""" + +import threading +import torch +from PIL import Image +import numpy as np + +# ========================================== +# 兼容层:处理本地测试 vs HF ZeroGPU 环境 +# ========================================== +try: + import spaces + print("Detected spaces library (Hugging Face environment).") +except ImportError: + print("Local environment detected. Mocking spaces.GPU...") + class spaces: + @staticmethod + def GPU(func): + return func # 本地测试时,装饰器变为空壳,直接执行原函数 + +from infer_hf import UniBioTransferPipeline + +# 锁和全局单例 Pipeline +inference_lock = threading.Lock() +global_pipeline :UniBioTransferPipeline = None + + +def get_pipeline(task): + """ + 单例模式:全局只初始化一次模型(放在 CPU),后续只切换任务。 + 强制写死 CPU,保证 ZeroGPU 全局初始化时不碰显卡。 + """ + global global_pipeline + if global_pipeline is None: + print("Initializing pipeline once on CPU...") + # 强制写死 CPU,保证 ZeroGPU 全局初始化时不碰显卡 + global_pipeline = UniBioTransferPipeline.from_pretrained( + repo_id="scy639/UniBioTransfer", + task=task, + device="cpu", + ) + else: + # 如果模型已经在内存中,只需切换 task ID 即可 + print(f"Switching existing pipeline to task: {task}") + global_pipeline.set_task(task) + return global_pipeline + + +# 核心:将所有会用到 GPU 的前向推理逻辑包裹在这里 +@spaces.GPU +def run_gpu_inference(pipeline:UniBioTransferPipeline, tgt_pil, ref_pil, ddim_steps, scale, seed, num_images): + """ + 这里是 ZeroGPU 分配算力的地方。进入此函数时可以安全地 to("cuda")。 + 如果是在本地服务器,这个装饰器没用,但内部的 .to("cuda") 同样生效。 + """ + return pipeline( + tgt_pil, + ref_pil, + ddim_steps=ddim_steps, + scale=scale, + seed=seed, + num_images=num_images, + ) + + +def inference(task, tgt_img, ref_img, ddim_steps, seed, num_images): + """ + Run inference for the demo. + """ + if tgt_img is None or ref_img is None: + return None, "Please upload both target and reference images." + + try: + # 1. 拿模型 (此时模型在 CPU) + pipeline = get_pipeline(task) + + tgt_pil = Image.fromarray(tgt_img).convert("RGB") + ref_pil = Image.fromarray(ref_img).convert("RGB") + + # 2. 加锁,防止并发污染 global_.task,进入 GPU 推理 + with inference_lock: + results = run_gpu_inference( + pipeline, + tgt_pil, + ref_pil, + int(ddim_steps), + float(3), + int(seed), + int(num_images) + ) + + return results, f"Success! Task: {task} transfer completed." + + except Exception as e: + import traceback + error_msg = f"Error: {str(e)}\n{traceback.format_exc()}" + print(f"{error_msg}") + return None, error_msg + + +def create_demo(): + """Create Gradio demo interface.""" + import gradio as gr + + with gr.Blocks(title="UniBioTransfer") as demo: + gr.Markdown( + """ + # UniBioTransfer + + Perform face transfer, hair transfer, motion transfer (face reenactment), and head transfer. + + - **Face Transfer**: Transfer face identity from reference to target + - **Hair Transfer**: Transfer hairstyle from reference to target + - **Motion Transfer**: Transfer motion(expression+head pose) from reference to target + - **Head Transfer**: Transfer entire head from reference to target + + [Code](https://github.com/scy639/UniBioTransfer) + [Project Page](https://scy639.github.io/UniBioTransfer.github.io/) + [Paper](https://arxiv.org/abs/2603.19637) + """ + ) + + with gr.Row(): + with gr.Column(): + task_dropdown = gr.Dropdown( + choices=["face", "hair", "motion", "head"], + value="face", + label="Task", + info="Select the transfer type", + ) + + with gr.Row(): + tgt_image = gr.Image( + label="Target Image", + type="numpy", + height=300, + ) + ref_image = gr.Image( + label="Reference Image", + type="numpy", + height=300, + ) + + with gr.Row(): + ddim_steps = gr.Slider( + minimum=4, + maximum=50, + value=50, + step=1, + label="DDIM Steps", + info="More steps = better quality but slower", + ) + # scale = gr.Slider( + # minimum=1.0, + # maximum=10.0, + # value=3.0, + # step=0.5, + # label="CFG Scale", + # info="Guidance scale for conditioning", + # ) + + seed = gr.Number( + value=42, + label="Random Seed", + info="For reproducibility", + ) + + num_images = gr.Slider( + minimum=1, + maximum=32, + value=4, + step=1, + label="Number of output images", + info="Multi-output with different initial noise", + ) + + run_btn = gr.Button("Run Inference", variant="primary") + + with gr.Column(): + output_gallery = gr.Gallery( + label="Results", + height=800, + columns=2, + ) + status_text = gr.Textbox( + label="Status", + lines=3, + ) + + gr.Markdown( +""" +### Usage +1. Upload a **target image** (the person whose face/hair/motion/head will be modified) +2. Upload a **reference image** (the source of the attribute to transfer) +3. Select the **task** type +4. Click "Run Inference" + +### Requirements +- Works best when the heads in the two input images have similar sizes. +""" + ) + + run_btn.click( + fn=inference, + inputs=[task_dropdown, tgt_image, ref_image, ddim_steps, seed, num_images], + outputs=[output_gallery, status_text], + ) + + task_dropdown.change( + fn=lambda t: f"Task switched to: {t} transfer", + inputs=[task_dropdown], + outputs=[status_text], + ) + + gr.Examples( + examples=[ + ["face", "examples/face/tgt.png", "examples/face/ref.png", 20, 42, 4], + ["hair", "examples/hair/tgt.png", "examples/hair/ref.png", 20, 42, 4], + ["motion", "examples/motion/tgt.png", "examples/motion/ref.png", 20, 42, 4], + ["head", "examples/head/tgt.png", "examples/head/ref.png", 20, 42, 4], + ], + inputs=[task_dropdown, tgt_image, ref_image, ddim_steps, seed, num_images], + label="Examples", + ) + + return demo + + +if __name__ == "__main__": + demo = create_demo() + demo.launch() diff --git a/checkpoints/pretrained.json b/checkpoints/pretrained.json new file mode 100644 index 0000000000000000000000000000000000000000..31e99a9992de126677c05f152c625ec6b73c0977 --- /dev/null +++ b/checkpoints/pretrained.json @@ -0,0 +1,1072 @@ +{ + ".model.diffusion_model.input_blocks.0.0": [ + 4, + 4, + 4, + 4 + ], + ".model.diffusion_model.input_blocks.1.0.in_layers.2": [ + 5, + 4, + 8, + 4 + ], + ".model.diffusion_model.input_blocks.1.0.out_layers.3": [ + 7, + 4, + 12, + 4 + ], + ".model.diffusion_model.input_blocks.1.1.proj_in": [ + 4, + 4, + 6, + 4 + ], + ".model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff": [ + [ + 5, + 4, + 8, + 4 + ], + [ + 7, + 4, + 12, + 4 + ] + ], + ".model.diffusion_model.input_blocks.1.1.proj_out": [ + 4, + 4, + 8, + 4 + ], + ".model.diffusion_model.input_blocks.2.0.in_layers.2": [ + 14, + 5, + 19, + 4 + ], + ".model.diffusion_model.input_blocks.2.0.out_layers.3": [ + 16, + 4, + 15, + 4 + ], + ".model.diffusion_model.input_blocks.2.1.proj_in": [ + 9, + 4, + 11, + 4 + ], + ".model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff": [ + [ + 16, + 4, + 14, + 4 + ], + [ + 17, + 4, + 14, + 4 + ] + ], + ".model.diffusion_model.input_blocks.2.1.proj_out": [ + 13, + 4, + 11, + 4 + ], + ".model.diffusion_model.input_blocks.3.0.op": [ + 26, + 7, + 31, + 8 + ], + ".model.diffusion_model.input_blocks.4.0.in_layers.2": [ + 23, + 6, + 31, + 8 + ], + ".model.diffusion_model.input_blocks.4.0.out_layers.3": [ + 27, + 6, + 37, + 8 + ], + ".model.diffusion_model.input_blocks.4.0.skip_connection": [ + 20, + 6, + 22, + 6 + ], + ".model.diffusion_model.input_blocks.4.1.proj_in": [ + 20, + 6, + 28, + 7 + ], + ".model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff": [ + [ + 22, + 6, + 37, + 8 + ], + [ + 31, + 8, + 39, + 10 + ] + ], + ".model.diffusion_model.input_blocks.4.1.proj_out": [ + 26, + 8, + 37, + 10 + ], + ".model.diffusion_model.input_blocks.5.0.in_layers.2": [ + 27, + 10, + 46, + 11 + ], + ".model.diffusion_model.input_blocks.5.0.out_layers.3": [ + 18, + 6, + 36, + 7 + ], + ".model.diffusion_model.input_blocks.5.1.proj_in": [ + 20, + 7, + 29, + 7 + ], + ".model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff": [ + [ + 22, + 7, + 41, + 9 + ], + [ + 26, + 10, + 33, + 12 + ] + ], + ".model.diffusion_model.input_blocks.5.1.proj_out": [ + 24, + 9, + 33, + 10 + ], + ".model.diffusion_model.input_blocks.6.0.op": [ + 52, + 17, + 76, + 20 + ], + ".model.diffusion_model.input_blocks.7.0.in_layers.2": [ + 50, + 14, + 80, + 19 + ], + ".model.diffusion_model.input_blocks.7.0.out_layers.3": [ + 56, + 15, + 90, + 22 + ], + ".model.diffusion_model.input_blocks.7.0.skip_connection": [ + 40, + 13, + 59, + 16 + ], + ".model.diffusion_model.input_blocks.7.1.proj_in": [ + 33, + 12, + 55, + 14 + ], + ".model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff": [ + [ + 39, + 11, + 62, + 13 + ], + [ + 59, + 17, + 82, + 21 + ] + ], + ".model.diffusion_model.input_blocks.7.1.proj_out": [ + 55, + 17, + 80, + 22 + ], + ".model.diffusion_model.input_blocks.8.0.in_layers.2": [ + 73, + 20, + 108, + 27 + ], + ".model.diffusion_model.input_blocks.8.0.out_layers.3": [ + 65, + 15, + 95, + 21 + ], + ".model.diffusion_model.input_blocks.8.1.proj_in": [ + 43, + 13, + 69, + 18 + ], + ".model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff": [ + [ + 41, + 10, + 68, + 13 + ], + [ + 56, + 17, + 85, + 21 + ] + ], + ".model.diffusion_model.input_blocks.8.1.proj_out": [ + 52, + 16, + 78, + 20 + ], + ".model.diffusion_model.input_blocks.9.0.op": [ + 90, + 30, + 157, + 39 + ], + ".model.diffusion_model.input_blocks.10.0.in_layers.2": [ + 81, + 21, + 113, + 26 + ], + ".model.diffusion_model.input_blocks.10.0.out_layers.3": [ + 80, + 21, + 123, + 28 + ], + ".model.diffusion_model.input_blocks.11.0.in_layers.2": [ + 87, + 23, + 118, + 28 + ], + ".model.diffusion_model.input_blocks.11.0.out_layers.3": [ + 77, + 20, + 113, + 26 + ], + ".model.diffusion_model.middle_block.0.in_layers.2": [ + 84, + 22, + 113, + 26 + ], + ".model.diffusion_model.middle_block.0.out_layers.3": [ + 68, + 16, + 99, + 21 + ], + ".model.diffusion_model.middle_block.1.proj_in": [ + 36, + 10, + 59, + 13 + ], + ".model.diffusion_model.middle_block.1.transformer_blocks.0.ff": [ + [ + 31, + 5, + 45, + 6 + ], + [ + 55, + 15, + 69, + 17 + ] + ], + ".model.diffusion_model.middle_block.1.proj_out": [ + 39, + 10, + 61, + 14 + ], + ".model.diffusion_model.middle_block.2.in_layers.2": [ + 73, + 17, + 104, + 23 + ], + ".model.diffusion_model.middle_block.2.out_layers.3": [ + 62, + 15, + 88, + 20 + ], + ".model.diffusion_model.output_blocks.0.0.in_layers.2": [ + 96, + 25, + 135, + 32 + ], + ".model.diffusion_model.output_blocks.0.0.out_layers.3": [ + 86, + 21, + 120, + 28 + ], + ".model.diffusion_model.output_blocks.0.0.skip_connection": [ + 64, + 21, + 106, + 27 + ], + ".model.diffusion_model.output_blocks.1.0.in_layers.2": [ + 94, + 27, + 155, + 36 + ], + ".model.diffusion_model.output_blocks.1.0.out_layers.3": [ + 86, + 24, + 136, + 31 + ], + ".model.diffusion_model.output_blocks.1.0.skip_connection": [ + 72, + 23, + 115, + 29 + ], + ".model.diffusion_model.output_blocks.2.0.in_layers.2": [ + 84, + 31, + 164, + 39 + ], + ".model.diffusion_model.output_blocks.2.0.out_layers.3": [ + 42, + 19, + 123, + 29 + ], + ".model.diffusion_model.output_blocks.2.0.skip_connection": [ + 72, + 24, + 110, + 28 + ], + ".model.diffusion_model.output_blocks.2.1.conv": [ + 72, + 25, + 121, + 29 + ], + ".model.diffusion_model.output_blocks.3.0.in_layers.2": [ + 85, + 31, + 158, + 38 + ], + ".model.diffusion_model.output_blocks.3.0.out_layers.3": [ + 42, + 21, + 117, + 25 + ], + ".model.diffusion_model.output_blocks.3.0.skip_connection": [ + 71, + 23, + 111, + 28 + ], + ".model.diffusion_model.output_blocks.3.1.proj_in": [ + 42, + 14, + 73, + 18 + ], + ".model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff": [ + [ + 37, + 10, + 68, + 13 + ], + [ + 60, + 18, + 83, + 20 + ] + ], + ".model.diffusion_model.output_blocks.3.1.proj_out": [ + 51, + 18, + 79, + 21 + ], + ".model.diffusion_model.output_blocks.4.0.in_layers.2": [ + 104, + 32, + 159, + 40 + ], + ".model.diffusion_model.output_blocks.4.0.out_layers.3": [ + 83, + 24, + 125, + 29 + ], + ".model.diffusion_model.output_blocks.4.0.skip_connection": [ + 73, + 22, + 101, + 28 + ], + ".model.diffusion_model.output_blocks.4.1.proj_in": [ + 49, + 15, + 77, + 20 + ], + ".model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff": [ + [ + 38, + 11, + 70, + 14 + ], + [ + 63, + 16, + 85, + 20 + ] + ], + ".model.diffusion_model.output_blocks.4.1.proj_out": [ + 51, + 18, + 81, + 21 + ], + ".model.diffusion_model.output_blocks.5.0.in_layers.2": [ + 91, + 33, + 161, + 40 + ], + ".model.diffusion_model.output_blocks.5.0.out_layers.3": [ + 83, + 26, + 140, + 32 + ], + ".model.diffusion_model.output_blocks.5.0.skip_connection": [ + 81, + 24, + 116, + 30 + ], + ".model.diffusion_model.output_blocks.5.1.proj_in": [ + 48, + 16, + 82, + 21 + ], + ".model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff": [ + [ + 34, + 12, + 76, + 15 + ], + [ + 55, + 16, + 81, + 18 + ] + ], + ".model.diffusion_model.output_blocks.5.1.proj_out": [ + 57, + 19, + 85, + 22 + ], + ".model.diffusion_model.output_blocks.5.2.conv": [ + 108, + 34, + 159, + 41 + ], + ".model.diffusion_model.output_blocks.6.0.in_layers.2": [ + 55, + 18, + 87, + 22 + ], + ".model.diffusion_model.output_blocks.6.0.out_layers.3": [ + 32, + 13, + 54, + 15 + ], + ".model.diffusion_model.output_blocks.6.0.skip_connection": [ + 25, + 9, + 30, + 14 + ], + ".model.diffusion_model.output_blocks.6.1.proj_in": [ + 26, + 9, + 40, + 11 + ], + ".model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff": [ + [ + 25, + 8, + 47, + 12 + ], + [ + 36, + 11, + 47, + 13 + ] + ], + ".model.diffusion_model.output_blocks.6.1.proj_out": [ + 23, + 10, + 38, + 12 + ], + ".model.diffusion_model.output_blocks.7.0.in_layers.2": [ + 55, + 18, + 82, + 20 + ], + ".model.diffusion_model.output_blocks.7.0.out_layers.3": [ + 47, + 14, + 65, + 17 + ], + ".model.diffusion_model.output_blocks.7.0.skip_connection": [ + 40, + 11, + 40, + 12 + ], + ".model.diffusion_model.output_blocks.7.1.proj_in": [ + 27, + 9, + 41, + 11 + ], + ".model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff": [ + [ + 27, + 8, + 47, + 11 + ], + [ + 34, + 11, + 47, + 12 + ] + ], + ".model.diffusion_model.output_blocks.7.1.proj_out": [ + 33, + 9, + 39, + 12 + ], + ".model.diffusion_model.output_blocks.8.0.in_layers.2": [ + 58, + 17, + 82, + 20 + ], + ".model.diffusion_model.output_blocks.8.0.out_layers.3": [ + 56, + 15, + 75, + 18 + ], + ".model.diffusion_model.output_blocks.8.0.skip_connection": [ + 44, + 10, + 47, + 11 + ], + ".model.diffusion_model.output_blocks.8.1.proj_in": [ + 32, + 9, + 43, + 10 + ], + ".model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff": [ + [ + 28, + 7, + 47, + 8 + ], + [ + 35, + 8, + 45, + 8 + ] + ], + ".model.diffusion_model.output_blocks.8.1.proj_out": [ + 35, + 10, + 44, + 10 + ], + ".model.diffusion_model.output_blocks.8.2.conv": [ + 65, + 19, + 85, + 22 + ], + ".model.diffusion_model.output_blocks.9.0.in_layers.2": [ + 37, + 10, + 35, + 10 + ], + ".model.diffusion_model.output_blocks.9.0.out_layers.3": [ + 28, + 6, + 23, + 5 + ], + ".model.diffusion_model.output_blocks.9.0.skip_connection": [ + 15, + 4, + 4, + 4 + ], + ".model.diffusion_model.output_blocks.9.1.proj_in": [ + 16, + 4, + 6, + 4 + ], + ".model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff": [ + [ + 24, + 5, + 23, + 5 + ], + [ + 23, + 5, + 24, + 6 + ] + ], + ".model.diffusion_model.output_blocks.9.1.proj_out": [ + 16, + 4, + 14, + 4 + ], + ".model.diffusion_model.output_blocks.10.0.in_layers.2": [ + 31, + 9, + 38, + 10 + ], + ".model.diffusion_model.output_blocks.10.0.out_layers.3": [ + 20, + 4, + 24, + 4 + ], + ".model.diffusion_model.output_blocks.10.0.skip_connection": [ + 4, + 4, + 7, + 4 + ], + ".model.diffusion_model.output_blocks.10.1.proj_in": [ + 6, + 4, + 11, + 4 + ], + ".model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff": [ + [ + 17, + 4, + 21, + 4 + ], + [ + 17, + 5, + 21, + 5 + ] + ], + ".model.diffusion_model.output_blocks.10.1.proj_out": [ + 9, + 4, + 12, + 4 + ], + ".model.diffusion_model.output_blocks.11.0.in_layers.2": [ + 7, + 4, + 18, + 4 + ], + ".model.diffusion_model.output_blocks.11.0.out_layers.3": [ + 16, + 6, + 22, + 5 + ], + ".model.diffusion_model.output_blocks.11.0.skip_connection": [ + 4, + 4, + 4, + 4 + ], + ".model.diffusion_model.output_blocks.11.1.proj_in": [ + 9, + 4, + 13, + 4 + ], + ".model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff": [ + [ + 19, + 4, + 24, + 4 + ], + [ + 12, + 4, + 14, + 4 + ] + ], + ".model.diffusion_model.output_blocks.11.1.proj_out": [ + 7, + 4, + 10, + 4 + ], + ".model.diffusion_model.out.2": [ + 4, + 4, + 4, + 4 + ], + ".model.diffusion_model_refNet.input_blocks.0.0": [ + 4, + 4, + 4, + 4 + ], + ".model.diffusion_model_refNet.input_blocks.1.0.in_layers.2": [ + 17, + 8, + 26, + 8 + ], + ".model.diffusion_model_refNet.input_blocks.1.0.out_layers.3": [ + 21, + 14, + 37, + 12 + ], + ".model.diffusion_model_refNet.input_blocks.1.1.proj_in": [ + 11, + 8, + 19, + 6 + ], + ".model.diffusion_model_refNet.input_blocks.1.1.transformer_blocks.0.ff": [ + [ + 14, + 12, + 24, + 7 + ], + [ + 17, + 12, + 26, + 7 + ] + ], + ".model.diffusion_model_refNet.input_blocks.1.1.proj_out": [ + 11, + 7, + 20, + 5 + ], + ".model.diffusion_model_refNet.input_blocks.2.0.in_layers.2": [ + 27, + 15, + 40, + 13 + ], + ".model.diffusion_model_refNet.input_blocks.2.0.out_layers.3": [ + 26, + 15, + 38, + 12 + ], + ".model.diffusion_model_refNet.input_blocks.2.1.proj_in": [ + 15, + 7, + 21, + 6 + ], + ".model.diffusion_model_refNet.input_blocks.2.1.transformer_blocks.0.ff": [ + [ + 17, + 13, + 30, + 9 + ], + [ + 16, + 12, + 27, + 8 + ] + ], + ".model.diffusion_model_refNet.input_blocks.2.1.proj_out": [ + 12, + 7, + 18, + 6 + ], + ".model.diffusion_model_refNet.input_blocks.3.0.op": [ + 27, + 13, + 43, + 12 + ], + ".model.diffusion_model_refNet.input_blocks.4.0.in_layers.2": [ + 30, + 19, + 49, + 14 + ], + ".model.diffusion_model_refNet.input_blocks.4.0.out_layers.3": [ + 32, + 26, + 55, + 15 + ], + ".model.diffusion_model_refNet.input_blocks.4.0.skip_connection": [ + 22, + 10, + 30, + 9 + ], + ".model.diffusion_model_refNet.input_blocks.4.1.proj_in": [ + 22, + 14, + 35, + 10 + ], + ".model.diffusion_model_refNet.input_blocks.4.1.transformer_blocks.0.ff": [ + [ + 26, + 25, + 52, + 14 + ], + [ + 28, + 22, + 51, + 14 + ] + ], + ".model.diffusion_model_refNet.input_blocks.4.1.proj_out": [ + 24, + 15, + 40, + 11 + ], + ".model.diffusion_model_refNet.input_blocks.5.0.in_layers.2": [ + 44, + 30, + 78, + 22 + ], + ".model.diffusion_model_refNet.input_blocks.5.0.out_layers.3": [ + 28, + 29, + 56, + 15 + ], + ".model.diffusion_model_refNet.input_blocks.5.1.proj_in": [ + 20, + 13, + 34, + 9 + ], + ".model.diffusion_model_refNet.input_blocks.5.1.transformer_blocks.0.ff": [ + [ + 26, + 27, + 52, + 14 + ], + [ + 23, + 23, + 53, + 14 + ] + ], + ".model.diffusion_model_refNet.input_blocks.5.1.proj_out": [ + 17, + 14, + 36, + 10 + ], + ".model.diffusion_model_refNet.input_blocks.6.0.op": [ + 46, + 31, + 82, + 21 + ], + ".model.diffusion_model_refNet.input_blocks.7.0.in_layers.2": [ + 75, + 41, + 116, + 32 + ], + ".model.diffusion_model_refNet.input_blocks.7.0.out_layers.3": [ + 67, + 50, + 108, + 29 + ], + ".model.diffusion_model_refNet.input_blocks.7.0.skip_connection": [ + 31, + 19, + 59, + 15 + ], + ".model.diffusion_model_refNet.input_blocks.7.1.proj_in": [ + 36, + 29, + 73, + 19 + ], + ".model.diffusion_model_refNet.input_blocks.7.1.transformer_blocks.0.ff": [ + [ + 74, + 61, + 106, + 26 + ], + [ + 63, + 49, + 90, + 24 + ] + ], + ".model.diffusion_model_refNet.input_blocks.7.1.proj_out": [ + 34, + 29, + 68, + 18 + ], + ".model.diffusion_model_refNet.input_blocks.8.0.in_layers.2": [ + 92, + 56, + 128, + 36 + ], + ".model.diffusion_model_refNet.input_blocks.8.0.out_layers.3": [ + 43, + 51, + 66, + 16 + ], + ".model.diffusion_model_refNet.input_blocks.8.1.proj_in": [ + 26, + 28, + 59, + 15 + ], + ".model.diffusion_model_refNet.input_blocks.8.1.transformer_blocks.0.ff": [ + [ + 188, + 69, + 232, + 69 + ], + [ + 140, + 51, + 173, + 51 + ] + ], + ".model.diffusion_model_refNet.input_blocks.8.1.proj_out": [ + 91, + 33, + 113, + 33 + ] +} \ No newline at end of file diff --git a/download_checkpoints.py b/download_checkpoints.py new file mode 100644 index 0000000000000000000000000000000000000000..92908f318b574ad79b5805f727c7882ee803df85 --- /dev/null +++ b/download_checkpoints.py @@ -0,0 +1,29 @@ +from pathlib import Path +import os +from imports import * + + + +def _download(repo_id, filename, local_path: Path) -> Path: + local_path = Path(local_path) + from huggingface_hub import hf_hub_download + local_path.parent.mkdir(parents=True, exist_ok=True) + token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN") + print(f"downloading to {local_path}") + downloaded = hf_hub_download( + repo_id=repo_id, + filename=filename, + local_dir=str(local_path.parent), + local_dir_use_symlinks=False, + token=token, + ) + + + +_download("CompVis/stable-diffusion-v-1-4-original",SD14_filename, SD14_localpath) + +_download("scy639/UniBioTransfer",PRETRAIN_CKPT_PATH, ".") +_download("scy639/UniBioTransfer",PRETRAIN_JSON_PATH, ".") + +_download("scy639/UniBioTransfer","Other_dependencies/arcface/model_ir_se50.pth", ".") +_download("scy639/UniBioTransfer","Other_dependencies/face_parsing/79999_iter.pth", ".") diff --git a/eval_tool/lpips/__init__.py b/eval_tool/lpips/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/eval_tool/lpips/lpips.py b/eval_tool/lpips/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..f6ff35be6655a924659270772ab2f792246dc60b --- /dev/null +++ b/eval_tool/lpips/lpips.py @@ -0,0 +1,35 @@ +import torch +import torch.nn as nn + +from eval_tool.lpips.networks import get_network, LinLayers +from eval_tool.lpips.utils import get_state_dict + + +class LPIPS(nn.Module): + r"""Creates a criterion that measures + Learned Perceptual Image Patch Similarity (LPIPS). + Arguments: + net_type (str): the network type to compare the features: + 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. + version (str): the version of LPIPS. Default: 0.1. + """ + def __init__(self, net_type: str = 'alex', version: str = '0.1'): + + assert version in ['0.1'], 'v0.1 is only supported now' + + super(LPIPS, self).__init__() + + # pretrained network + self.net = get_network(net_type) + + # linear layers + self.lin = LinLayers(self.net.n_channels_list) + self.lin.load_state_dict(get_state_dict(net_type, version)) + + def forward(self, x: torch.Tensor, y: torch.Tensor): + feat_x, feat_y = self.net(x), self.net(y) + + diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] + res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] + + return torch.sum(torch.cat(res, 0)) / x.shape[0] diff --git a/eval_tool/lpips/networks.py b/eval_tool/lpips/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..c258acf9a7300b15a84fe7d98c4369cfab6e62aa --- /dev/null +++ b/eval_tool/lpips/networks.py @@ -0,0 +1,96 @@ +from typing import Sequence + +from itertools import chain + +import torch +import torch.nn as nn +from torchvision import models + +from eval_tool.lpips.utils import normalize_activation + + +def get_network(net_type: str): + if net_type == 'alex': + return AlexNet() + elif net_type == 'squeeze': + return SqueezeNet() + elif net_type == 'vgg': + return VGG16() + else: + raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') + + +class LinLayers(nn.ModuleList): + def __init__(self, n_channels_list: Sequence[int]): + super(LinLayers, self).__init__([ + nn.Sequential( + nn.Identity(), + nn.Conv2d(nc, 1, 1, 1, 0, bias=False) + ) for nc in n_channels_list + ]) + + for param in self.parameters(): + param.requires_grad = False + + +class BaseNet(nn.Module): + def __init__(self): + super(BaseNet, self).__init__() + + # register buffer + self.register_buffer( + 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer( + 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) + + def set_requires_grad(self, state: bool): + for param in chain(self.parameters(), self.buffers()): + param.requires_grad = state + + def z_score(self, x: torch.Tensor): + return (x - self.mean) / self.std + + def forward(self, x: torch.Tensor): + x = self.z_score(x) + + output = [] + for i, (_, layer) in enumerate(self.layers._modules.items(), 1): + x = layer(x) + if i in self.target_layers: + output.append(normalize_activation(x)) + if len(output) == len(self.target_layers): + break + return output + + +class SqueezeNet(BaseNet): + def __init__(self): + super(SqueezeNet, self).__init__() + + self.layers = models.squeezenet1_1(True).features + self.target_layers = [2, 5, 8, 10, 11, 12, 13] + self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] + + self.set_requires_grad(False) + + +class AlexNet(BaseNet): + def __init__(self): + super(AlexNet, self).__init__() + + self.layers = models.alexnet(True).features + self.target_layers = [2, 5, 8, 10, 12] + self.n_channels_list = [64, 192, 384, 256, 256] + + self.set_requires_grad(False) + + +class VGG16(BaseNet): + def __init__(self): + super(VGG16, self).__init__() + + self.layers = models.vgg16(True).features + self.target_layers = [4, 9, 16, 23, 30] + self.n_channels_list = [64, 128, 256, 512, 512] + + self.set_requires_grad(False) \ No newline at end of file diff --git a/eval_tool/lpips/utils.py b/eval_tool/lpips/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5a771d327c163cdb885b237b358a908255b4dfba --- /dev/null +++ b/eval_tool/lpips/utils.py @@ -0,0 +1,30 @@ +from collections import OrderedDict + +import torch + + +def normalize_activation(x, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)+1e-16) # + return x / (norm_factor + eps) + + +def get_state_dict(net_type: str = 'alex', version: str = '0.1'): + # build url + url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ + + f'master/lpips/weights/v{version}/{net_type}.pth' + + # download + old_state_dict = torch.hub.load_state_dict_from_url( + url, progress=True, + map_location=None if torch.cuda.is_available() else torch.device('cpu') + ) + + # rename keys + new_state_dict = OrderedDict() + for key, val in old_state_dict.items(): + new_key = key + new_key = new_key.replace('lin', '') + new_key = new_key.replace('model.', '') + new_state_dict[new_key] = val + + return new_state_dict diff --git a/examples/face/ref-semantic_mask.png b/examples/face/ref-semantic_mask.png new file mode 100644 index 0000000000000000000000000000000000000000..b6acb6f81f8000a95e3731357bd3052f2c58985c Binary files /dev/null and b/examples/face/ref-semantic_mask.png differ diff --git a/examples/face/ref.png b/examples/face/ref.png new file mode 100644 index 0000000000000000000000000000000000000000..47d94edc21e00c0d61ccff4e4057bb89f6c27d9e --- /dev/null +++ b/examples/face/ref.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a477d2f5928b4ab40046fdcd7a0b9d4f35d619822eccd4137396fc06dbb82b48 +size 398792 diff --git a/examples/face/tgt-semantic_mask.png b/examples/face/tgt-semantic_mask.png new file mode 100644 index 0000000000000000000000000000000000000000..4ed605e2b8a263dca0424f667bae28d9572e4977 Binary files /dev/null and b/examples/face/tgt-semantic_mask.png differ diff --git a/examples/face/tgt.png b/examples/face/tgt.png new file mode 100644 index 0000000000000000000000000000000000000000..c20f5739bbaeeae4697f4d83a5ae6984ecef5a00 --- /dev/null +++ b/examples/face/tgt.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dea3592ab41c766b8d1ba041eda3b545871f1684528bff5c40321a9fbd7c8546 +size 409790 diff --git a/examples/hair/ref-semantic_mask.png b/examples/hair/ref-semantic_mask.png new file mode 100644 index 0000000000000000000000000000000000000000..50b0721263fb58d9016fcd9910634b3ed8269096 Binary files /dev/null and b/examples/hair/ref-semantic_mask.png differ diff --git a/examples/hair/ref.png b/examples/hair/ref.png new file mode 100644 index 0000000000000000000000000000000000000000..b8e653fcac4bacafffd92b725aa26eccb9919b42 --- /dev/null +++ b/examples/hair/ref.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:946981b5a077df22a393d6e1ebb1bdef73c020f25e99339f732345777ae6565c +size 434529 diff --git a/examples/hair/tgt-semantic_mask.png b/examples/hair/tgt-semantic_mask.png new file mode 100644 index 0000000000000000000000000000000000000000..a588a92135b1796174829002450e37e9be78e867 Binary files /dev/null and b/examples/hair/tgt-semantic_mask.png differ diff --git a/examples/hair/tgt.png b/examples/hair/tgt.png new file mode 100644 index 0000000000000000000000000000000000000000..fadc4c3de0c1ee4ed486e9b364ae0e0ef67af6c7 --- /dev/null +++ b/examples/hair/tgt.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:daa1c69651861183fe113995abb20192fafe829a7b1a349c2ccc2713d7b057b4 +size 398893 diff --git a/examples/head/ref-semantic_mask.png b/examples/head/ref-semantic_mask.png new file mode 100644 index 0000000000000000000000000000000000000000..54980ae84c796668abbb108c18d4d1da96cdf8a8 Binary files /dev/null and b/examples/head/ref-semantic_mask.png differ diff --git a/examples/head/ref.png b/examples/head/ref.png new file mode 100644 index 0000000000000000000000000000000000000000..71e283932b47658ae7711dc2b016b024953f229e --- /dev/null +++ b/examples/head/ref.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ff89b38ec94ee110a8760c6bb6b316c8ad2f4502a14aec1d217305e0ca2dfa47 +size 439504 diff --git a/examples/head/tgt-semantic_mask.png b/examples/head/tgt-semantic_mask.png new file mode 100644 index 0000000000000000000000000000000000000000..cad2f7f9195983fd4addc9e4f91c430ae88d6dce Binary files /dev/null and b/examples/head/tgt-semantic_mask.png differ diff --git a/examples/head/tgt.png b/examples/head/tgt.png new file mode 100644 index 0000000000000000000000000000000000000000..152d4074aa8bfc5840078aab04fa109ac0081e50 --- /dev/null +++ b/examples/head/tgt.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9467c48978020761d76df2e133808f490f2eacb359f5fda61d08017a77b20151 +size 335945 diff --git a/examples/inputs.txt b/examples/inputs.txt new file mode 100644 index 0000000000000000000000000000000000000000..3f1acdfd40b72aa109ff5924dd53faaf0390cd34 --- /dev/null +++ b/examples/inputs.txt @@ -0,0 +1,5 @@ +target_path_1 reference_path_1 +target_path_2 reference_path_2 +target_path_3 reference_path_3 +target_path_4 reference_path_4 +target_path_5 reference_path_5 \ No newline at end of file diff --git a/examples/motion/ref-semantic_mask.png b/examples/motion/ref-semantic_mask.png new file mode 100644 index 0000000000000000000000000000000000000000..2303e7cc4c116bd7ec12604f32a6bab818ad0d5e Binary files /dev/null and b/examples/motion/ref-semantic_mask.png differ diff --git a/examples/motion/ref.png b/examples/motion/ref.png new file mode 100644 index 0000000000000000000000000000000000000000..a9040205d54d5dba74a0c957786a8410e741d23f --- /dev/null +++ b/examples/motion/ref.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8b58a80c13e5741072b6c603f7edd61ba3e3c9456536064b0d4746f4bab9c786 +size 424338 diff --git a/examples/motion/tgt-semantic_mask.png b/examples/motion/tgt-semantic_mask.png new file mode 100644 index 0000000000000000000000000000000000000000..f2832bc441c7414f1e881f969ec54f02b92e4461 Binary files /dev/null and b/examples/motion/tgt-semantic_mask.png differ diff --git a/examples/motion/tgt.png b/examples/motion/tgt.png new file mode 100644 index 0000000000000000000000000000000000000000..03ef82b0abfee8f396fbf34d6694d6c9413d7334 --- /dev/null +++ b/examples/motion/tgt.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6e527760e591e97ab36892ee683f91673a678a8b23b7603779d430cfbcc0e5f3 +size 426727 diff --git a/gen_lmk_and_mask.py b/gen_lmk_and_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..0427de4313c5fb1453f1920b7c376c4c72d15463 --- /dev/null +++ b/gen_lmk_and_mask.py @@ -0,0 +1,41 @@ +ENABLE_lmk_cache = False +ENABLE_mask_cache = False + + +import cv2 +from imports import * +from util_cv2 import cv2_resize_auto_interpolation +from Mediapipe_Result_Cache import Mediapipe_Result_Cache +from lmk_util.lmk_extractor import LandmarkExtractor + + +def gen_lmk_and_mask(img_paths, size=512, write_cache=True): + extractor = LandmarkExtractor() + cache = Mediapipe_Result_Cache() + seen = set() + for p in img_paths: + if not p: + continue + p = str(p) + if p in seen: + continue + seen.add(p) + + cache_path = cache.get_path(p) + if not ( cache_path.exists() and ENABLE_lmk_cache ): + img = cv2.imread(p) + if img is None: + print(f"cv2.imread failed: {p}") + raise + continue + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = cv2_resize_auto_interpolation(img, (size, size)) + lmks = extractor.extract_single(img) + if lmks is None: + print(f"no lmks: {p}") + raise + continue + if write_cache: + cache.set(p, lmks) + + path_img_2_path_mask(p, reuse_if_exists=ENABLE_mask_cache, label_mode="RF12_") diff --git a/gen_semantic_mask.py b/gen_semantic_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..e6c03008f15cfe018db4080841f525d9f19617a4 --- /dev/null +++ b/gen_semantic_mask.py @@ -0,0 +1,90 @@ +""" +def: + tgt: Target image to be edited (face swapped) + ref: Face ID source image (also called src in REFace) + swap: Swapped output image, using face ID from ref to replace face in tgt +""" +import os +from pathlib import Path +from tqdm import tqdm +from my_py_lib.image_util import print_image_statistics +import torch +import torchvision +from PIL import Image +import numpy as np +from einops import rearrange +from torchvision.transforms import Resize +from torchvision.utils import make_grid +from contextlib import nullcontext +from torch.cuda.amp import autocast +from omegaconf import OmegaConf +import cv2 + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# Sampling configs +DDIM_STEPS = 50 +GUIDANCE_SCALE = 3.0 +IMG_SIZE = 512 +LATENT_CHANNELS = 4 +DOWNSAMPLE_FACTOR = 8 +START_NOISE_T = 1000 +DDIM_ETA = 0.0 +PRECISION = "full" # or "autocast" +FIXED_CODE = False # whether to use fixed starting code +SAVE_INTERMEDIATES = False # whether to save intermediate results +LOG_EVERY_T = 100 # log frequency during sampling + + +class MaskModel_LazyLoader: + model = None + @classmethod + def get(cls): + faceParsing_ckpt = "Other_dependencies/face_parsing/79999_iter.pth" + if cls.model is None: + from pretrained.face_parsing.face_parsing_demo import init_faceParsing_pretrained_model + cls.model = init_faceParsing_pretrained_model( + 'default', + faceParsing_ckpt, + '' + ) + print(f"Initialized face parsing model from {faceParsing_ckpt}") + return cls.model + + +def gen_semantic_mask(path_img: Path, path_mask_to_save: Path, label_mode:str, path_vis: Path = None): + """Generate semantic mask for an image using face parsing model""" + pil_im = Image.open(path_img).convert("RGB") + w, h = pil_im.size + # print(f"{pil_im.size=}") # 512,512 + TMP_size = 1024 + if w != TMP_size or h != TMP_size: + pil_im = pil_im.resize((TMP_size, TMP_size), Image.BILINEAR) + + model = MaskModel_LazyLoader.get() + from pretrained.face_parsing.face_parsing_demo import faceParsing_demo, vis_parsing_maps + + # print(f"{pil_im.size=}") # 1024,1024 + # Generate mask with conversion to seg12 format + mask = faceParsing_demo( + model, + pil_im, + label_mode, + model_name='default' + ) + + try: + Image.fromarray(mask).save(path_mask_to_save) + except Exception as e: + print(f"{e=}") + print(f"{path_mask_to_save=}") + if path_mask_to_save.exists(): + path_mask_to_save.unlink() + print(f'path_mask_to_save.unlink()') + # print(f"Saved mask: {path_mask_to_save}") + # print(f"{mask.shape=}") # 512,512 + + if path_vis: + mask_vis = vis_parsing_maps(pil_im, mask) + Image.fromarray(mask_vis).save(path_vis) + print(f"Saved mask vis: {path_vis}") \ No newline at end of file diff --git a/get_mask.py b/get_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..eb9c79a559e17766b1871d3d9579262ccfaead1a --- /dev/null +++ b/get_mask.py @@ -0,0 +1,68 @@ +from util_and_constant import * +from pathlib import Path +from PIL import Image +import cv2 +import numpy as np + +def path_img_2_mask( + path_img, + preserve=(1, 2, 3, 5, 6, 7, 9, 10, 11, ), # int | list-liek. Default val represents face +): + """ + 0 bg, 1 mouth, 2 eyebrow, 3 eyes, 4 hair, 5 nose, 6 face (excluding facial parts), 7: ear, 8: neck, 9: tooth + 10: eye_glass, 11: ear_rings + """ + if isinstance(preserve,int): + preserve = (preserve,) + if 1: + assert isinstance(preserve,tuple) or isinstance(preserve,list) + assert all(isinstance(p, int) and 0 <= p <= 11 for p in preserve) + import numpy as np + from PIL import Image + mask_path = path_img_2_path_mask(path_img) + mask = Image.open(mask_path).convert('L') + mask = np.array(mask) + mask = np.isin(mask, preserve) + return mask + + + +def get_forehead_mask(sm_mask): + # return mask (np bool) where the forehead (face above eyebrows) is True + sm_mask = np.array(sm_mask) + # 6 is face (excluding facial parts); keep only the forehead part + # First get all face pixels + face_mask = (sm_mask == 6) + # Get eyebrow pixels to determine forehead boundary + # if 2 in sm, ; elif 3(eyes) in ; elif 10(eye_glass) in ; else + if 2 in sm_mask: + eyebrow_mask = (sm_mask == 2) + eyebrow_coords = np.where(eyebrow_mask) + eyebrow_top = np.min(eyebrow_coords[0]) + # Forehead is face region above eyebrows + forehead_mask = face_mask & (np.arange(sm_mask.shape[0])[:, None] < eyebrow_top) + elif 3 in sm_mask: + eye_mask = (sm_mask == 3) + eye_coords = np.where(eye_mask) + eye_top = np.min(eye_coords[0]) + # Estimate forehead as region above eyes with some margin + forehead_threshold = eye_top - 20 # 20 pixels above eyes as forehead + forehead_mask = face_mask & (np.arange(sm_mask.shape[0])[:, None] < forehead_threshold) + elif 10 in sm_mask: + glass_mask = (sm_mask == 10) + glass_coords = np.where(glass_mask) + glass_top = np.min(glass_coords[0]) + # Forehead is face region above glasses + forehead_mask = face_mask & (np.arange(sm_mask.shape[0])[:, None] < glass_top) + else: + # If no eyebrows detected, keep upper portion of face + face_coords = np.where(face_mask) + if len(face_coords[0]) > 0: + face_top = np.min(face_coords[0]) + face_height = np.max(face_coords[0]) - face_top + forehead_threshold = face_top + face_height * 0.15 # top 15% as forehead + forehead_mask = face_mask & (np.arange(sm_mask.shape[0])[:, None] < forehead_threshold) + else: + forehead_mask = np.zeros_like(face_mask, dtype=bool) + forehead_mask = forehead_mask & face_mask + return forehead_mask diff --git a/global_.py b/global_.py new file mode 100644 index 0000000000000000000000000000000000000000..c6508c969942b399ba2dd54c59075f58196a72cc --- /dev/null +++ b/global_.py @@ -0,0 +1,9 @@ +""" +some global variables +""" +task :int = None # current batch task id + +TP_enable:bool = None # None means not set yet. should be set in imports.py +rank_:int = None +moduleName_2_adaRank:dict = {} # adaptive rank for each shared+LoRA module + diff --git a/hf_model.py b/hf_model.py new file mode 100644 index 0000000000000000000000000000000000000000..2cdf46af40ffe674f35879fcae39444d0261c3b2 --- /dev/null +++ b/hf_model.py @@ -0,0 +1,247 @@ +""" +Hugging Face Hub compatible model wrapper for UniBioTransfer. +Provides from_pretrained() and push_to_hub() functionality via PyTorchModelHubMixin. +""" +from pathlib import Path +import torch +import json +import copy +import os +from huggingface_hub import PyTorchModelHubMixin, hf_hub_download + +import global_ +from ldm.models.diffusion.ddpm import LatentDiffusion, LandmarkExtractor +from ldm.util import instantiate_from_config +from omegaconf import OmegaConf +from pytorch_lightning import seed_everything +from MoE import offload_unused_tasks__LD +from multiTask_model import TaskSpecific_MoE, replace_modules_lossless +from my_py_lib.torch_util import cleanup_gpu_memory + +TASKS = (0, 1, 2, 3) +TASK_NAME2ID = {"face": 0, "hair": 1, "motion": 2, "head": 3} +TASK_ID2NAME = {v: k for k, v in TASK_NAME2ID.items()} + +SD14_FILENAME = "sd-v1-4.ckpt" +SD14_REPO = "CompVis/stable-diffusion-v-1-4-original" +PRETRAIN_REPO = "scy639/UniBioTransfer" + + +def _load_first_stage_from_sd14(model, sd14_path): + """Load first_stage_model (VAE) from SD v1.4 checkpoint.""" + print(f"Loading first_stage_model from {sd14_path}") + sd14 = torch.load(str(sd14_path), map_location="cpu") + if isinstance(sd14, dict) and "state_dict" in sd14: + sd14_sd = sd14["state_dict"] + else: + sd14_sd = sd14 + + prefixes = ["first_stage_model.", "model.first_stage_model."] + fs_sd = {} + for prefix in prefixes: + for k, v in sd14_sd.items(): + if k.startswith(prefix): + fs_sd[k[len(prefix):]] = v + if fs_sd: + break + + if not fs_sd: + raise RuntimeError("Could not find first_stage_model weights in SD v1-4 checkpoint.") + + model.first_stage_model.load_state_dict(fs_sd, strict=True) + + +class UniBioTransferModel(LatentDiffusion, PyTorchModelHubMixin): + """ + Hugging Face Hub compatible wrapper for UniBioTransfer. + + Inherits from LatentDiffusion and adds HF Hub integration via PyTorchModelHubMixin. + + Usage: + # Load model from HF Hub + model = UniBioTransferModel.from_pretrained("scy639/UniBioTransfer", task="face") + + # Push to HF Hub + model.push_to_hub("your-repo/UniBioTransfer") + + Args: + config: Model config dict (handled by PyTorchModelHubMixin) + task: Task name or ID (face/hair/motion/head) + **kwargs: Additional arguments passed to LatentDiffusion + """ + + def __init__(self, config=None, task="face", **kwargs): + self._task_name = task if isinstance(task, str) else TASK_ID2NAME.get(task, "face") + self._task_id = TASK_NAME2ID.get(self._task_name, 0) if isinstance(task, str) else task + + global_.task = self._task_id + + if config is None: + config = {} + + super().__init__(**config) + + self._hf_config = { + "task": self._task_name, + "task_id": self._task_id, + } + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path=None, + task="face", + device="cuda", + download_sd14=True, + download_deps=True, + cache_dir=None, + **kwargs, + ): + """ + Load model from Hugging Face Hub. + + Args: + pretrained_model_name_or_path: HF repo ID or local path. + Default: "scy639/UniBioTransfer" + task: Task name (face/hair/motion/head) or task ID (0/1/2/3) + device: Device to load model to ("cuda" or "cpu") + download_sd14: Whether to download SD v1.4 VAE weights + download_deps: Whether to download other dependencies (ArcFace, DLIB, face_parsing) + cache_dir: Cache directory for downloads + **kwargs: Additional arguments + + Returns: + UniBioTransferModel: Loaded model + """ + task_id = TASK_NAME2ID.get(task, task) if isinstance(task, str) else task + task_name = TASK_ID2NAME.get(task_id, "face") + + global_.task = task_id + + if pretrained_model_name_or_path is None: + pretrained_model_name_or_path = PRETRAIN_REPO + + repo_id = pretrained_model_name_or_path + + cache_dir = Path(cache_dir) if cache_dir else Path(".") + + ckpt_path = cache_dir / "checkpoints" / "pretrained.ckpt" + json_path = cache_dir / "checkpoints" / "pretrained.json" + sd14_path = cache_dir / "checkpoints" / SD14_FILENAME + arcface_path = cache_dir / "Other_dependencies" / "arcface" / "model_ir_se50.pth" + face_parsing_path = cache_dir / "Other_dependencies" / "face_parsing" / "79999_iter.pth" + + def _download_file(repo, filename, local_path): + local_path = Path(local_path) + local_path.parent.mkdir(parents=True, exist_ok=True) + print(f"Downloading {filename} from {repo}...") + token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN") + hf_hub_download( + repo_id=repo, + filename=filename, + local_dir=str(local_path.parent), + local_dir_use_symlinks=False, + token=token, + ) + + if not ckpt_path.exists(): + _download_file(repo_id, "checkpoints/pretrained.ckpt", ckpt_path) + if not json_path.exists(): + _download_file(repo_id, "checkpoints/pretrained.json", json_path) + + if download_sd14 and not sd14_path.exists(): + _download_file(SD14_REPO, SD14_FILENAME, sd14_path) + + if download_deps: + if not arcface_path.exists(): + _download_file(repo_id, "Other_dependencies/arcface/model_ir_se50.pth", arcface_path) + if not face_parsing_path.exists(): + _download_file(repo_id, "Other_dependencies/face_parsing/79999_iter.pth", face_parsing_path) + + seed_everything(42) + + cur_dir = Path(__file__).parent + yaml_path = cur_dir / "LatentDiffusion.yaml" + if not yaml_path.exists(): + yaml_path = Path("LatentDiffusion.yaml") + + model_config = OmegaConf.load(yaml_path).model + model = instantiate_from_config(model_config) + + with open(json_path, 'r') as f: + global_.moduleName_2_adaRank = json.load(f) + print(f"Loaded adaptive rank config from {json_path}") + + _src0 = copy.deepcopy(model.model.diffusion_model) + _src1 = copy.deepcopy(model.model.diffusion_model) + _src2 = copy.deepcopy(model.model.diffusion_model) + _src3 = copy.deepcopy(model.model.diffusion_model) + replace_modules_lossless( + model.model.diffusion_model, + [_src0, _src1, _src2, _src3], + [0, 1, 2, 3], + parent_name=".model.diffusion_model", + ) + + model.ID_proj_out = TaskSpecific_MoE([ + copy.deepcopy(model.ID_proj_out), + copy.deepcopy(model.ID_proj_out), + copy.deepcopy(model.ID_proj_out), + ], [0, 2, 3]) + model.landmark_proj_out = TaskSpecific_MoE([ + copy.deepcopy(model.landmark_proj_out), + copy.deepcopy(model.landmark_proj_out), + copy.deepcopy(model.landmark_proj_out), + ], [0, 2, 3]) + model.proj_out_source__head = TaskSpecific_MoE([ + copy.deepcopy(model.proj_out_source__head), + copy.deepcopy(model.proj_out_source__head), + ], [2, 3]) + + from util_and_constant import REFNET + if REFNET.ENABLE: + shared_ref = model.model.diffusion_model_refNet + src0 = shared_ref + src1 = copy.deepcopy(shared_ref) + src2 = copy.deepcopy(shared_ref) + src3 = copy.deepcopy(shared_ref) + replace_modules_lossless(shared_ref, [src0, src1, src2, src3], [0, 1, 2, 3], parent_name=".model.diffusion_model_refNet", for_refnet=True) + from ldm.models.diffusion.bank import Bank + model.model.bank = Bank( + reader=model.model.diffusion_model, + writer=model.model.diffusion_model_refNet + ) + + print(f"Loading model weights from {ckpt_path}") + pl_sd = torch.load(str(ckpt_path), map_location="cpu") + if isinstance(pl_sd, dict) and "state_dict" in pl_sd: + sd = pl_sd["state_dict"] + else: + sd = pl_sd + + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0: + print(f"Missing keys: {len(m)}") + if len(u) > 0: + print(f"Unexpected keys: {len(u)}") + + _load_first_stage_from_sd14(model, sd14_path) + + # offload_unused_tasks__LD(model, task_id, method="cpu") + + model.ptsM_Generator = LandmarkExtractor(include_visualizer=True, img_256_mode=False) + cleanup_gpu_memory() + + # ZeroGPU 兼容:只在 device 不是 "cpu" 且 CUDA 可用时才移动到 GPU + # 如果传入 device="cpu",保持模型在 CPU 上(ZeroGPU 初始化时不碰显卡) + if device != "cpu" and torch.cuda.is_available(): + model = model.to(torch.device(device)) + else: + model = model.to(torch.device("cpu")) + model.eval() + + model._task_id = task_id + model._task_name = task_name + model._hf_config = {"task": task_name, "task_id": task_id} + + return model diff --git a/imports.py b/imports.py new file mode 100644 index 0000000000000000000000000000000000000000..e27df894cb5242590eb1a96de51fc1db2e2c8287 --- /dev/null +++ b/imports.py @@ -0,0 +1,8 @@ + + + +#--------------------------------------------------------------------------------------------------------------------- +from util_and_constant import * +from get_mask import * +from util_cv2 import * + diff --git a/infer.py b/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..6a59c40ebfb265957cb672d6b0f82afbcdd1f386 --- /dev/null +++ b/infer.py @@ -0,0 +1,366 @@ +# --------------------------------------------------------- Config ------------------------------------------------- +num_workers :int = 1 +DDIM_STEPS = 50 +BATCH_SIZE = 1 +FIXED_CODE = False +# for vis +SAVE_INTERMEDIATES = True +NUM_grid_in_a_column = 5 +# ------------------------------------------------------------------------------------------------------------------------ +import argparse +parser = argparse.ArgumentParser(description="Custom inference for tgt/ref image pairs.") +parser.add_argument("--task-name", type=str, + default='face', + help="face|hair|motion|head") +parser.add_argument("--out-dir", type=str, default='examples/outputs', help="Output directory") +# option 1: pass 2 paths +parser.add_argument("--tgt", type=str, default=None, help="Path to target image. if None, will use paths read from --pair-list") +parser.add_argument("--ref", type=str, default=None, help="Path to reference image") +# option 2: pass a txt containing paths +parser.add_argument("--pair-list", type=str, default='examples/inputs.txt', help="white-space-separated list file: tgt_path ref_path") +args = parser.parse_args() + +#-----------------------------------------set TASK-------------------------------------------------------------------------- + +task_name :str = args.task_name +TASK :int = { + 'face': 0, + 'hair': 1, + 'motion': 2, + 'head': 3, +}[task_name] +print(f'task: {task_name} transfer (ID: {TASK})') +# ------------------------------------------------------------------------------------------------------------------------ + + +import sys +import os +from pathlib import Path + +cur_dir = os.path.dirname(os.path.abspath(__file__)) + +from imports import * +import torch +import numpy as np +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm +from einops import rearrange +from torchvision.utils import make_grid +from my_py_lib.image_util import imgs_2_grid_A,img_paths_2_grid_A +from pytorch_lightning import seed_everything +from torch import autocast +from contextlib import nullcontext +import torchvision + +from ldm.models.diffusion.ddpm import LatentDiffusion +from ldm.util import instantiate_from_config +from ldm.models.diffusion.ddim import DDIMSampler +from Dataset_custom import Dataset_custom +from MoE import offload_unused_tasks__LD +from ldm.models.diffusion.ddpm import LandmarkExtractor +from my_py_lib.torch_util import cleanup_gpu_memory +from gen_lmk_and_mask import gen_lmk_and_mask + + + + + + + + + +# ------------------------------------------------------------------------------------------------------------------------ +DDIM_ETA = 0.0 +SCALE = 3.0 +PRECISION = "full" # "full" or "autocast" +H = 512 +W = 512 +C = 4 +F = 8 +# ------------------------------------------------------------------------------------------------------------------------ + + +def load_first_stage_from_sd14(model: LatentDiffusion, sd14_path: Path) -> None: + print(f"Loading first_stage_model from {sd14_path}") + sd14 = torch.load(str(sd14_path), map_location="cpu") + if isinstance(sd14, dict) and "state_dict" in sd14: + sd14_sd = sd14["state_dict"] + else: + sd14_sd = sd14 + + prefixes = ["first_stage_model.", "model.first_stage_model."] + fs_sd = {} + for prefix in prefixes: + for k, v in sd14_sd.items(): + if k.startswith(prefix): + fs_sd[k[len(prefix):]] = v + if fs_sd: + break + + if not fs_sd: + raise RuntimeError("Could not find first_stage_model weights in SD v1-4 checkpoint.") + + model.first_stage_model.load_state_dict(fs_sd, strict=True) + + +def save_sample_by_decode(x, model, base_path, segment_id, intermediate_num): + x = model.decode_first_stage(x) + x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0) + x = x.cpu().permute(0, 2, 3, 1).numpy() + for i in range(len(x)): + img = Image.fromarray((x[i] * 255).astype(np.uint8)) + save_path = Path(base_path) / segment_id + save_path.mkdir(parents=True, exist_ok=True) + img.save(save_path / f"{intermediate_num}.png") + + +def get_tensor_clip(normalize=True, toTensor=True): + transform_list = [] + if toTensor: + transform_list += [torchvision.transforms.ToTensor()] + if normalize: + transform_list += [ + torchvision.transforms.Normalize( + (0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711), + ) + ] + return torchvision.transforms.Compose(transform_list) + + +def load_model_from_config(ckpt, verbose=1): + if 1: + ckpt = Path(ckpt) + print(f"Loading model from {ckpt}") + pl_sd = torch.load(str(ckpt), map_location="cpu") + if isinstance(pl_sd, dict) and "state_dict" in pl_sd: + sd = pl_sd["state_dict"] + else: + sd = pl_sd + else: + print("DEBUG_skip_load_ckpt") + if 1: + from init_model import get_moe + model: LatentDiffusion = get_moe() + model.ptsM_Generator = LandmarkExtractor(include_visualizer=True, img_256_mode=False) + cleanup_gpu_memory() + if 1: + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + pretty_print_torch_module_keys(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + pretty_print_torch_module_keys(u) + load_first_stage_from_sd14(model, SD14_localpath) + + offload_unused_tasks__LD(model, TASK, method="del") # for save cuda mem + model.cuda() + model.eval() + return model + + + + +def load_pairs(pair_list, tgt, ref): + if tgt and ref: + pairs = [(tgt, ref), ] + elif pair_list: + pairs = [] + with open(pair_list, "r") as f: + for line_num, line in enumerate(f, start=1): + line = line.strip() + if not line or line.startswith("#"): + continue + parts = line.split(" ") + if len(parts) != 2: + raise ValueError(f"Invalid pair list line {line_num}: expected white-space-separated tgt/ref. got {parts=}") + pairs.append((parts[0], parts[1])) + else: + raise ValueError("No input pairs provided. Use --tgt/--ref or --pair-list.") + print(f"{pairs=}") + return pairs + + +def un_norm(x): + return (x + 1.0) / 2.0 + + +def un_norm_clip(x1): + x = x1 * 1.0 + reduce = False + if len(x.shape) == 3: + x = x.unsqueeze(0) + reduce = True + x[:, 0, :, :] = x[:, 0, :, :] * 0.26862954 + 0.48145466 + x[:, 1, :, :] = x[:, 1, :, :] * 0.26130258 + 0.4578275 + x[:, 2, :, :] = x[:, 2, :, :] * 0.27577711 + 0.40821073 + if reduce: + x = x.squeeze(0) + return x + + +if __name__ == "__main__": + pairs = load_pairs(args.pair_list, args.tgt, args.ref) + + out_dir = Path(args.out_dir) + result_path = out_dir / "results" + grid_path = out_dir / "grid" + inter_path = out_dir / "intermediates" + inter_pred_path = inter_path / "pred_x0" + inter_noised_path = inter_path / "noised" + out_dir.mkdir(parents=False, exist_ok=True) + result_path.mkdir(parents=False, exist_ok=True) + grid_path.mkdir(parents=False, exist_ok=True) + inter_path.mkdir(parents=False, exist_ok=True) + if SAVE_INTERMEDIATES: + inter_pred_path.mkdir(parents=False, exist_ok=True) + inter_noised_path.mkdir(parents=False, exist_ok=True) + paths_tgt = [p[0] for p in pairs] + paths_ref = [p[1] for p in pairs] + gen_lmk_and_mask(paths_tgt + paths_ref) + + seed_everything(42) + + model: LatentDiffusion = load_model_from_config(PRETRAIN_CKPT_PATH, ) + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + sampler = DDIMSampler(model) + + dataset = Dataset_custom( + "test", + task=TASK, + paths_tgt=paths_tgt, + paths_ref=paths_ref, + ) + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=BATCH_SIZE, + num_workers=num_workers, + pin_memory=True, + shuffle=False, + drop_last=False, + ) + + start_code = None + if FIXED_CODE: + start_code = torch.randn([BATCH_SIZE, C, H // F, W // F], device=device) + + precision_scope = autocast if PRECISION == "autocast" else nullcontext + grids = [] + grid_stems = [] + + with torch.no_grad(): + with precision_scope("cuda"): + with model.ema_scope(): + for test_batch, prior, test_model_kwargs, out_stem_batch in tqdm(dataloader): + model.set_task(test_model_kwargs) + bs = test_batch.shape[0] + + batch_ = { + **test_model_kwargs, + "GT": torch.zeros_like(test_model_kwargs["inpaint_image"]), + } + batch_, c = model.get_input_and_conditioning(batch_, device=device) + z_inpaint = batch_["z4_inpaint"] + z_inpaint_mask = batch_["tgt_mask_64"] + z_ref = batch_["z_ref"] + z9 = batch_["z9"] + + uc = None + if SCALE != 1.0: + uc = model.learnable_vector[TASK].repeat(bs, 1, 1) + + shape = [C, H // F, W // F] + local_start_code = start_code + if FIXED_CODE and (local_start_code is None or local_start_code.shape[0] != bs): + local_start_code = torch.randn([bs, C, H // F, W // F], device=device) + samples_ddim, intermediates = sampler.sample( + S=DDIM_STEPS, + conditioning=c, + batch_size=bs, + shape=shape, + verbose=False, + unconditional_guidance_scale=SCALE, + unconditional_conditioning=uc, + eta=DDIM_ETA, + x_T=local_start_code, + log_every_t=100, + z_inpaint=z_inpaint, + z_inpaint_mask=z_inpaint_mask, + z_ref=z_ref, + z9=z9, + ) + + if SAVE_INTERMEDIATES: + intermediate_pred_x0 = intermediates["pred_x0"] + intermediate_noised = intermediates["x_inter"] + for i in range(len(intermediate_pred_x0)): + for j in range(bs): + stem = f"{out_stem_batch[j]}" + save_sample_by_decode( + intermediate_pred_x0[i][j : j + 1], + model, + inter_pred_path, + stem, + i, + ) + save_sample_by_decode( + intermediate_noised[i][j : j + 1], + model, + inter_noised_path, + stem, + i, + ) + + x_samples_ddim = model.decode_first_stage(samples_ddim) + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy() + + x_checked_image_torch = torch.from_numpy(x_samples_ddim).permute(0, 3, 1, 2) + for i, x_sample in enumerate(x_checked_image_torch): + stem = f"{out_stem_batch[i]}" + out_path = result_path / f"{stem}.png" + img = Image.fromarray((x_sample.permute(1, 2, 0).numpy() * 255).astype(np.uint8)) + img.save(out_path) + print(f"{out_path=}") + + for i, x_sample in enumerate(x_checked_image_torch): + all_img = [] + all_img.append(un_norm(test_batch[i]).cpu()) + if TASK != 2: + ref_img = test_model_kwargs["ref_imgs"].squeeze(1) + ref_img = torchvision.transforms.Resize([512, 512])(ref_img) + ref_img = un_norm_clip(ref_img[i]).cpu() + else: + ref_img = un_norm(test_model_kwargs["ref512"].squeeze(1)[i]).cpu() + all_img.append(ref_img) + all_img.append(x_sample) + + grid = torch.stack(all_img, 0) + grid = make_grid(grid) + grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy() + img = Image.fromarray(grid.astype(np.uint8)) + stem = f"{out_stem_batch[i]}" + path_save_img = grid_path / f"grid-{stem}.jpg" + img.save(path_save_img) + print(f"{path_save_img=}") + grids.append(img) + grid_stems.append(stem) + if len(grids) >= NUM_grid_in_a_column: + stem_start = grid_stems[0] + stem_end = grid_stems[-1] + grid_column = imgs_2_grid_A( + grids, + grid_layout='column', + grid_path=os.path.join(grid_path, f"{stem_start}--{stem_end}.jpg"), + ) + grids = [] + grid_stems = [] + + model.unset_task() + + print(f"Your samples are ready and waiting for you here: {out_dir}") + + diff --git a/infer_hf.py b/infer_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..f22a0fa5427108391be02285655a5fe0b3767c21 --- /dev/null +++ b/infer_hf.py @@ -0,0 +1,279 @@ +""" +High-level inference pipeline for UniBioTransfer. +Designed for easy use in Hugging Face Spaces and other applications. + +ZeroGPU Compatible: +- Supports CPU initialization (device="cpu") +- Dynamically switches to CUDA during inference when called from @spaces.GPU +""" +from pathlib import Path +import torch +import numpy as np +from PIL import Image +import cv2 + +import global_ +from hf_model import UniBioTransferModel, TASK_NAME2ID, TASK_ID2NAME +from ldm.models.diffusion.ddim import DDIMSampler +from pytorch_lightning import seed_everything + +DDIM_STEPS_DEFAULT = 50 +SCALE_DEFAULT = 3.0 + + +H, W, C, F = 512, 512, 4, 8 +class UniBioTransferPipeline: + """ + High-level pipeline for UniBioTransfer inference. + """ + + def __init__(self, model, task="face", device="cpu"): + """ + Initialize pipeline with a loaded model. + """ + self.model = model + self.task = task + self.task_id = TASK_NAME2ID.get(task, task) if isinstance(task, str) else task + self._init_device = device + + global_.task = self.task_id + self.model.task = self.task_id + + self.sampler = DDIMSampler(model) + + @classmethod + def from_pretrained( + cls, + repo_id="scy639/UniBioTransfer", + task="face", + device="cpu", + cache_dir=None, + **kwargs, + ): + """ + Load pipeline from Hugging Face Hub. + """ + model = UniBioTransferModel.from_pretrained( + pretrained_model_name_or_path=repo_id, + task=task, + device=device, + cache_dir=cache_dir, + **kwargs, + ) + return cls(model, task=task, device=device) + + def set_task(self, task): + """Switch to a different task.""" + self.task = task + self.task_id = TASK_NAME2ID.get(task, task) if isinstance(task, str) else task + global_.task = self.task_id + self.model.task = self.task_id + + def __call__( + self, + tgt_image, + ref_image, + ddim_steps=DDIM_STEPS_DEFAULT, + scale=SCALE_DEFAULT, + seed=42, + num_images=1, + ): + """ + Run inference on a pair of images. + """ + seed_everything(seed) + + tgt_img = self._load_image(tgt_image) + ref_img = self._load_image(ref_image) + + tgt_img = self._resize_image(tgt_img, (H, W)) + ref_img = self._resize_image(ref_img, (H, W)) + + result_tensors = self._run_inference(tgt_img, ref_img, ddim_steps, scale, num_images) + + result_imgs = [self._postprocess(result_tensors[i]) for i in range(result_tensors.shape[0])] + return result_imgs + + def _load_image(self, img): + """Load image from various formats.""" + if isinstance(img, Image.Image): + return img.convert("RGB") + elif isinstance(img, np.ndarray): + return Image.fromarray(img).convert("RGB") + elif isinstance(img, (str, Path)): + return Image.open(img).convert("RGB") + else: + raise ValueError(f"Unsupported image type: {type(img)}") + + def _resize_image(self, img, size): + """Resize image to target size.""" + if img.size != size: + img = img.resize(size, Image.LANCZOS) + return img + + def _run_inference(self, tgt_img, ref_img, ddim_steps, scale, num_images): + """ + Run diffusion sampling. + 完全复用 infer.py 的逻辑,使用 dataloader。 + """ + from Dataset_custom import Dataset_custom + from gen_lmk_and_mask import gen_lmk_and_mask + import tempfile + + with tempfile.TemporaryDirectory() as tmpdir: + tgt_path = Path(tmpdir) / "tgt.png" + ref_path = Path(tmpdir) / "ref.png" + tgt_img.save(tgt_path) + ref_img.save(ref_path) + + gen_lmk_and_mask([str(tgt_path), str(ref_path)], write_cache=True) + + dataset = Dataset_custom( + "test", + task=self.task_id, + paths_tgt=[str(tgt_path)], + paths_ref=[str(ref_path)], + ) + + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=1, + num_workers=1, + pin_memory=True, + shuffle=False, + drop_last=False, + ) + + run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.model = self.model.to(run_device) + + with torch.no_grad(): + for test_batch, prior, test_model_kwargs, out_stem_batch in dataloader: + test_batch = test_batch.to(run_device) + if test_batch.shape[0] == 1: + test_batch = test_batch.repeat(num_images, 1, 1, 1) + if isinstance(prior, torch.Tensor): + prior = prior.to(run_device) + if prior.shape[0] == 1: + prior = prior.repeat(num_images, 1, 1, 1) + for k, v in test_model_kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.to(run_device) + if v.shape[0] == 1: + repeats = [num_images] + [1] * (v.ndim - 1) + v = v.repeat(*repeats) + test_model_kwargs[k] = v + elif isinstance(v, dict): + new_v = {} + for kk, vv in v.items(): + if isinstance(vv, torch.Tensor): + vv = vv.to(run_device) + if vv.shape[0] == 1: + repeats = [num_images] + [1] * (vv.ndim - 1) + vv = vv.repeat(*repeats) + new_v[kk] = vv + else: + new_v[kk] = vv + test_model_kwargs[k] = new_v + elif isinstance(v, list): + test_model_kwargs[k] = v * num_images + + self.model.set_task(test_model_kwargs) + bs = num_images + + batch_ = { + **test_model_kwargs, + "GT": torch.zeros(num_images, *test_model_kwargs["inpaint_image"].shape[1:], device=run_device), + } + batch_, c = self.model.get_input_and_conditioning(batch_, device=run_device) + + z_inpaint = batch_["z4_inpaint"] + z_inpaint_mask = batch_["tgt_mask_64"] + z_ref = batch_["z_ref"] + z9 = batch_["z9"] + + uc = None + if scale != 1.0: + uc = self.model.learnable_vector[self.task_id].repeat(bs, 1, 1) + + shape = [C, H // F, W // F] + start_code = None + + samples_ddim, _ = self.sampler.sample( + S=ddim_steps, + conditioning=c, + batch_size=bs, + shape=shape, + verbose=False, + unconditional_guidance_scale=scale, + unconditional_conditioning=uc, + eta=0.0, + x_T=start_code, + log_every_t=100, + z_inpaint=z_inpaint, + z_inpaint_mask=z_inpaint_mask, + z_ref=z_ref, + z9=z9, + ) + + x_samples_ddim = self.model.decode_first_stage(samples_ddim) + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + + self.model.unset_task() + + return x_samples_ddim + + def _postprocess(self, tensor): + """Convert model output tensor to PIL Image.""" + img_array = tensor.cpu().permute(1, 2, 0).numpy() + img_array = (img_array * 255).astype(np.uint8) + return Image.fromarray(img_array) + + +def infer_single( + tgt_path, + ref_path, + task="face", + output_path=None, + ddim_steps=DDIM_STEPS_DEFAULT, + scale=SCALE_DEFAULT, + device="cuda", +): + """ + Convenience function for single inference. + """ + pipeline = UniBioTransferPipeline.from_pretrained(task=task, device=device) + result = pipeline(tgt_path, ref_path, ddim_steps=ddim_steps, scale=scale) + + if output_path is not None: + result.save(output_path) + print(f"Saved result to {output_path}") + + return result + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="UniBioTransfer inference") + parser.add_argument("--task", type=str, default="face", choices=["face", "hair", "motion", "head"]) + parser.add_argument("--tgt", type=str, required=True, help="Path to target image") + parser.add_argument("--ref", type=str, required=True, help="Path to reference image") + parser.add_argument("--out", type=str, default="result.png", help="Output path") + parser.add_argument("--ddim-steps", type=int, default=50) + parser.add_argument("--scale", type=float, default=3.0) + parser.add_argument("--device", type=str, default="cuda") + + args = parser.parse_args() + + result = infer_single( + args.tgt, + args.ref, + task=args.task, + output_path=args.out, + ddim_steps=args.ddim_steps, + scale=args.scale, + device=args.device, + ) + + print(f"Inference complete. Result shape: {result.size}") diff --git a/init_model.py b/init_model.py new file mode 100644 index 0000000000000000000000000000000000000000..ed3ace7798c2620bd3cda5fe6b97b05a60c75455 --- /dev/null +++ b/init_model.py @@ -0,0 +1,178 @@ +import sys,os +cur_dir = os.path.dirname(os.path.abspath(__file__)) +if __name__=='__main__': sys.path.append(os.path.abspath(os.path.join(cur_dir, '..'))) + +from imports import * +import json +import argparse, os, sys, glob +import cv2 +import torch +import numpy as np +from MoE import * +from multiTask_model import * +from lora_layers import * +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm, trange +from itertools import islice +from einops import rearrange +from torchvision.utils import make_grid +from my_py_lib.image_util import imgs_2_grid_A,img_paths_2_grid_A +import time +import copy +from pytorch_lightning import seed_everything +from torch import autocast +from contextlib import contextmanager, nullcontext +import torchvision +from ldm.models.diffusion.ddpm import LatentDiffusion +from ldm.models.diffusion.bank import Bank +from ldm.util import instantiate_from_config + +from ldm.models.diffusion.ddim import DDIMSampler + +from transformers import AutoFeatureExtractor + +# import clip +from torchvision.transforms import Resize +from fnmatch import fnmatch + + +from PIL import Image +from torchvision.transforms import PILToTensor +#---------------------------------------------------------------------------- + + +def get_moe(): + if 1: + seed_everything(42) + # torch.cuda.set_device(opt.device_ID) + model :LatentDiffusion = instantiate_from_config(OmegaConf.load(f"LatentDiffusion.yaml").model,) + if REFNET.ENABLE: + assert model.model.diffusion_model_refNet.is_refNet + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + device = torch.device("cpu") + model = model.to(device) + if FOR_upcycle_ckpt_GEN_or_USE: + del model.ptsM_Generator + + def average_module_weight( + src_modules: list, + ): + """Average the weights of multiple modules""" + if not src_modules: + return None + # Get the state dict of the first module as template + avg_state_dict = {} + first_state_dict = src_modules[0].state_dict() + # Initialize with zeros + for key in first_state_dict: + avg_state_dict[key] = torch.zeros_like(first_state_dict[key]) + # Sum + for module in src_modules: + module_state_dict = module.state_dict() + for key in avg_state_dict: + avg_state_dict[key] += module_state_dict[key] + # Average + for key in avg_state_dict: + avg_state_dict[key] /= len(src_modules) + return avg_state_dict + def recursive_average_module_weight( + tgt_module: nn.Module, + src_modules: list, + cb, + ): + """ + Recursively find modules and replace with averaged weights based on callback + """ + for name, child in tgt_module.named_children(): + if 1: # Get corresponding modules from source models + src_child_modules = [] + for src_module in src_modules: + src_child = getattr(src_module, name) + assert src_child is not None,name + src_child_modules.append(src_child) + # assert not isinstance(child, TaskSpecific_MoE) + if cb(child, name, tgt_module): + print(f"[recursive_average_module_weight] {name=} child: {repr(child)[:50]} tgt_module: {repr(tgt_module)[:50]}") + # Average & load + avg_weights = average_module_weight(src_child_modules) + child.load_state_dict(avg_weights) + else: + recursive_average_module_weight(child, src_child_modules, cb) + return tgt_module + + def replace_module_with_TaskSpecific( + tgt_module: nn.Module,# tgt module + src_modules: list, + cb, + parent_name: str = "", + depth :int = 0, + ): + for name, child in tgt_module.named_children(): + if 1: # Get corresponding modules from source models + src_child_modules = [] + for src_module in src_modules: + src_child = getattr(src_module, name) + assert src_child is not None,name + src_child_modules.append(src_child) + assert not isinstance(child, TaskSpecific_MoE) + full_name = f"{parent_name}.{name}" + if cb(child, name, full_name, tgt_module): + print(f"[replace_module_with_TaskSpecific] {name=} child: {repr(child)[:50]} tgt_module: {repr(tgt_module)[:50]}") + setattr(tgt_module, name, TaskSpecific_MoE(src_child_modules,TASKS)) + else: + if depth<=0: + replace_module_with_TaskSpecific(child, src_child_modules,cb,parent_name=full_name,depth=depth+1) + return tgt_module + + if not FOR_upcycle_ckpt_GEN_or_USE: + modelMOE :LatentDiffusion = model + del model + if 1: # ensure distinct module instances per task (avoid shared identities) + with open(PRETRAIN_JSON_PATH, 'r') as f: global_.moduleName_2_adaRank = json.load(f) + print(f"loaded from {PRETRAIN_JSON_PATH=}") + _src0 = copy.deepcopy(modelMOE.model.diffusion_model) + _src1 = copy.deepcopy(modelMOE.model.diffusion_model) + _src2 = copy.deepcopy(modelMOE.model.diffusion_model) + _src3 = copy.deepcopy(modelMOE.model.diffusion_model) + replace_modules_lossless( + modelMOE.model.diffusion_model, + [ _src0, _src1, _src2, _src3 ], + [0,1,2,3], + parent_name=".model.diffusion_model", + ) + # Build-time dummy wrapping for task-specific heads so that ckpt keys match + modelMOE.ID_proj_out = TaskSpecific_MoE([ + copy.deepcopy(modelMOE.ID_proj_out), + copy.deepcopy(modelMOE.ID_proj_out), + copy.deepcopy(modelMOE.ID_proj_out), + ], [0,2,3]) + modelMOE.landmark_proj_out = TaskSpecific_MoE([ + copy.deepcopy(modelMOE.landmark_proj_out), + copy.deepcopy(modelMOE.landmark_proj_out), + copy.deepcopy(modelMOE.landmark_proj_out), + ], [0,2,3]) + modelMOE.proj_out_source__head = TaskSpecific_MoE([ + copy.deepcopy(modelMOE.proj_out_source__head), + copy.deepcopy(modelMOE.proj_out_source__head), + ], [2,3]) + # Upcycle single refNet using three source refNets, and keep only one + if REFNET.ENABLE: + shared_ref = modelMOE.model.diffusion_model_refNet + src0 = shared_ref + src1 = copy.deepcopy(shared_ref) + src2 = copy.deepcopy(shared_ref) + src3 = copy.deepcopy(shared_ref) + replace_modules_lossless(shared_ref, [src0, src1, src2, src3],[0,1,2,3], parent_name=".model.diffusion_model_refNet", for_refnet=True) + # load from ./modelMOE.ckpt + time.sleep(20*rank_) + print(f"ckpt load over. m,u:") + # Initialize bank here (after model structure is finalized) + if REFNET.ENABLE : + modelMOE.model.bank = Bank(reader=modelMOE.model.diffusion_model,writer=modelMOE.model.diffusion_model_refNet) + if __name__=='__main__': + for key in sorted( get_representative_moduleNames(modelMOE.state_dict().keys()) ): + print(f" - {key}") + return modelMOE + diff --git a/ldm/lr_scheduler.py b/ldm/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..37baa26300ddf4d739f9c30dd65e87df5976a616 --- /dev/null +++ b/ldm/lr_scheduler.py @@ -0,0 +1,99 @@ +import numpy as np + + +class LambdaWarmUpCosineScheduler: + """ + note: use with a base_lr of 1.0 + """ + def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): + self.lr_warm_up_steps = warm_up_steps + self.lr_start = lr_start + self.lr_min = lr_min + self.lr_max = lr_max + self.lr_max_decay_steps = max_decay_steps + self.last_lr = 0. + self.verbosity_interval = verbosity_interval + + def schedule(self, n, **kwargs): + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if n < self.lr_warm_up_steps: + lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start + self.last_lr = lr + return lr + else: + t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) + t = min(t, 1.0) + lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( + 1 + np.cos(t * np.pi)) + self.last_lr = lr + return lr + + def __call__(self, n, **kwargs): + return self.schedule(n,**kwargs) + + +class LambdaWarmUpCosineScheduler2: + """ + supports repeated iterations, configurable via lists + note: use with a base_lr of 1.0. + """ + def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): + assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) + self.lr_warm_up_steps = warm_up_steps + self.f_start = f_start + self.f_min = f_min + self.f_max = f_max + self.cycle_lengths = cycle_lengths + self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) + self.last_f = 0. + self.verbosity_interval = verbosity_interval + + def find_in_interval(self, n): + interval = 0 + for cl in self.cum_cycles[1:]: + if n <= cl: + return interval + interval += 1 + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) + t = min(t, 1.0) + f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( + 1 + np.cos(t * np.pi)) + self.last_f = f + return f + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): + + def schedule(self, n, **kwargs):# n is the step index + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + # print(f"0 {n=} {f=}") + return f + else: + f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) + self.last_f = f + # print(f"1 {n=} {f=}") + return f diff --git a/ldm/models/autoencoder.py b/ldm/models/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..6a9c4f45498561953b8085981609b2a3298a5473 --- /dev/null +++ b/ldm/models/autoencoder.py @@ -0,0 +1,443 @@ +import torch +import pytorch_lightning as pl +import torch.nn.functional as F +from contextlib import contextmanager + +from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer + +from ldm.modules.diffusionmodules.model import Encoder, Decoder +from ldm.modules.distributions.distributions import DiagonalGaussianDistribution + +from ldm.util import instantiate_from_config + + +class VQModel(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + batch_resize_range=None, + scheduler_config=None, + lr_g_factor=1.0, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + use_ema=False + ): + super().__init__() + self.embed_dim = embed_dim + self.n_embed = n_embed + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, + remap=remap, + sane_index_shape=sane_index_shape) + self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + self.batch_resize_range = batch_resize_range + if self.batch_resize_range is not None: + print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.") + + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.scheduler_config = scheduler_config + self.lr_g_factor = lr_g_factor + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + print(f"Unexpected Keys: {unexpected}") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def encode_to_prequant(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b): + quant_b = self.quantize.embed_code(code_b) + dec = self.decode(quant_b) + return dec + + def forward(self, input, return_pred_indices=False): + quant, diff, (_,_,ind) = self.encode(input) + dec = self.decode(quant) + if return_pred_indices: + return dec, diff, ind + return dec, diff + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + if self.batch_resize_range is not None: + lower_size = self.batch_resize_range[0] + upper_size = self.batch_resize_range[1] + if self.global_step <= 4: + # do the first few batches with max size to avoid later oom + new_resize = upper_size + else: + new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16)) + if new_resize != x.shape[2]: + x = F.interpolate(x, size=new_resize, mode="bicubic") + x = x.detach() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + # https://github.com/pytorch/pytorch/issues/37142 + # try not to fool the heuristics + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + + if optimizer_idx == 0: + # autoencode + aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train", + predicted_indices=ind) + + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return aeloss + + if optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return discloss + + def validation_step(self, batch, batch_idx): + log_dict = self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") + return log_dict + + def _validation_step(self, batch, batch_idx, suffix=""): + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, + self.global_step, + last_layer=self.get_last_layer(), + split="val"+suffix, + predicted_indices=ind + ) + + discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, + self.global_step, + last_layer=self.get_last_layer(), + split="val"+suffix, + predicted_indices=ind + ) + rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] + self.log(f"val{suffix}/rec_loss", rec_loss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + self.log(f"val{suffix}/aeloss", aeloss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + if version.parse(pl.__version__) >= version.parse('1.4.0'): + del log_dict_ae[f"val{suffix}/rec_loss"] + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr_d = self.learning_rate + lr_g = self.lr_g_factor*self.learning_rate + print("lr_d", lr_d) + print("lr_g", lr_g) + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quantize.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr_g, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr_d, betas=(0.5, 0.9)) + + if self.scheduler_config is not None: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + { + 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + ] + return [opt_ae, opt_disc], scheduler + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if only_inputs: + log["inputs"] = x + return log + xrec, _ = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["inputs"] = x + log["reconstructions"] = xrec + if plot_ema: + with self.ema_scope(): + xrec_ema, _ = self(x) + if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) + log["reconstructions_ema"] = xrec_ema + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class VQModelInterface(VQModel): + def __init__(self, embed_dim, *args, **kwargs): + super().__init__(embed_dim=embed_dim, *args, **kwargs) + self.embed_dim = embed_dim + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, h, force_not_quantize=False): + # also go through quantization layer + if not force_not_quantize: + quant, emb_loss, info = self.quantize(h) + else: + quant = h + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + +class AutoencoderKL(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ): + super().__init__() + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return discloss + + def validation_step(self, batch, batch_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, + last_layer=self.get_last_layer(), split="val") + + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, + last_layer=self.get_last_layer(), split="val") + + self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x diff --git a/ldm/models/diffusion/__init__.py b/ldm/models/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/models/diffusion/bank.py b/ldm/models/diffusion/bank.py new file mode 100644 index 0000000000000000000000000000000000000000..0d17f68b755b17ae0cb6b59f1df21bf1e1233603 --- /dev/null +++ b/ldm/models/diffusion/bank.py @@ -0,0 +1,76 @@ +from .misc_4ddpm import * + + +from ldm.modules.attention import BasicTransformerBlock +class Bank: + def __init__(self,reader:nn.Module, writer:nn.Module) -> None: + """ + For the DFS model, mark every BasicTransformerBlock with name_4bank and isReader_4bank flags. + Similar logic applies for the writer while checking for BasicTransformerBlock instances. + """ + self.name2data = {} + self.name2count = {} # track how many times each name has been retrieved + self.WHEN_clear_a_field = 2 # clear the entry after this many gets + skip_names = [ + 'input_blocks.1.1.transformer_blocks.0', + 'input_blocks.2.1.transformer_blocks.0', + # 'input_blocks.4.1.transformer_blocks.0', + # 'input_blocks.5.1.transformer_blocks.0', + # 'input_blocks.7.1.transformer_blocks.0', + # 'input_blocks.8.1.transformer_blocks.0', + ##-----------all middle and output_blocks (everything outside input_blocks)---- + 'middle_block.1.transformer_blocks.0', + 'output_blocks.3.1.transformer_blocks.0', + 'output_blocks.4.1.transformer_blocks.0', + 'output_blocks.5.1.transformer_blocks.0', + 'output_blocks.6.1.transformer_blocks.0', + 'output_blocks.7.1.transformer_blocks.0', + 'output_blocks.8.1.transformer_blocks.0', + 'output_blocks.9.1.transformer_blocks.0', + 'output_blocks.10.1.transformer_blocks.0', + 'output_blocks.11.1.transformer_blocks.0', + ] + # print(f"{skip_names=}") + + l_name = [] + for name, _module in writer.named_modules(): + if isinstance(_module, BasicTransformerBlock): + if DEBUG: + print(f"{name=}") + if name in skip_names: + # print(f"skip {name=}") + continue + _module.bank = self + _module.name4bank = name + _module.isReader_4bank = False + l_name.append(name) + # print(f"{l_name=}") + + for name, _module in reader.named_modules(): + if isinstance(_module, BasicTransformerBlock): + if name not in l_name: + continue + _module.bank = self + _module.name4bank = name + _module.isReader_4bank = True + def set(self,name,data): + self.name2data[name] = data + # self.name2count[name] = 0 + def get(self,name): + printC('bank get', name) + if name in self.name2data: + if name not in self.name2count: + self.name2count[name] = 0 + self.name2count[name] += 1 + data = self.name2data[name] + if self.name2count[name] >= self.WHEN_clear_a_field: # once the max get count is reached, remove the entry + del self.name2data[name] + del self.name2count[name] + return data + raise Exception(f"{name}\n{list(self.name2data.keys())}") + return None + def clear(self,): + printC('clear') + printC('mean ct:', sum( self.name2count.values() ) / len( self.name2count.values() ) if len( self.name2count.values() )>0 else 'null' ) + self.name2data.clear() + self.name2count.clear() diff --git a/ldm/models/diffusion/classifier.py b/ldm/models/diffusion/classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..67e98b9d8ffb96a150b517497ace0a242d7163ef --- /dev/null +++ b/ldm/models/diffusion/classifier.py @@ -0,0 +1,267 @@ +import os +import torch +import pytorch_lightning as pl +from omegaconf import OmegaConf +from torch.nn import functional as F +from torch.optim import AdamW +from torch.optim.lr_scheduler import LambdaLR +from copy import deepcopy +from einops import rearrange +from glob import glob +from natsort import natsorted + +from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel +from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config + +__models__ = { + 'class_label': EncoderUNetModel, + 'segmentation': UNetModel +} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class NoisyLatentImageClassifier(pl.LightningModule): + + def __init__(self, + diffusion_path, + num_classes, + ckpt_path=None, + pool='attention', + label_key=None, + diffusion_ckpt_path=None, + scheduler_config=None, + weight_decay=1.e-2, + log_steps=10, + monitor='val/loss', + *args, + **kwargs): + super().__init__(*args, **kwargs) + self.num_classes = num_classes + # get latest config of diffusion model + diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] + self.diffusion_config = OmegaConf.load(diffusion_config).model + self.diffusion_config.params.ckpt_path = diffusion_ckpt_path + self.load_diffusion() + + self.monitor = monitor + self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 + self.log_time_interval = self.diffusion_model.num_timesteps // log_steps + self.log_steps = log_steps + + self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ + else self.diffusion_model.cond_stage_key + + assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' + + if self.label_key not in __models__: + raise NotImplementedError() + + self.load_classifier(ckpt_path, pool) + + self.scheduler_config = scheduler_config + self.use_scheduler = self.scheduler_config is not None + self.weight_decay = weight_decay + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def load_diffusion(self): + model = instantiate_from_config(self.diffusion_config) + self.diffusion_model = model.eval() + self.diffusion_model.train = disabled_train + for param in self.diffusion_model.parameters(): + param.requires_grad = False + + def load_classifier(self, ckpt_path, pool): + model_config = deepcopy(self.diffusion_config.params.unet_config.params) + model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels + model_config.out_channels = self.num_classes + if self.label_key == 'class_label': + model_config.pool = pool + + self.model = __models__[self.label_key](**model_config) + if ckpt_path is not None: + print('#####################################################################') + print(f'load from ckpt "{ckpt_path}"') + print('#####################################################################') + self.init_from_ckpt(ckpt_path) + + @torch.no_grad() + def get_x_noisy(self, x, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x)) + continuous_sqrt_alpha_cumprod = None + if self.diffusion_model.use_continuous_noise: + continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) + # todo: make sure t+1 is correct here + + return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, + continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) + + def forward(self, x_noisy, t, *args, **kwargs): + return self.model(x_noisy, t) + + @torch.no_grad() + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, 'b h w c -> b c h w') + x = x.to(memory_format=torch.contiguous_format).float() + return x + + @torch.no_grad() + def get_conditioning(self, batch, k=None): + if k is None: + k = self.label_key + assert k is not None, 'Needs to provide label key' + + targets = batch[k].to(self.device) + + if self.label_key == 'segmentation': + targets = rearrange(targets, 'b h w c -> b c h w') + for down in range(self.numd): + h, w = targets.shape[-2:] + targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') + + # targets = rearrange(targets,'b c h w -> b h w c') + + return targets + + def compute_top_k(self, logits, labels, k, reduction="mean"): + _, top_ks = torch.topk(logits, k, dim=1) + if reduction == "mean": + return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() + elif reduction == "none": + return (top_ks == labels[:, None]).float().sum(dim=-1) + + def on_train_epoch_start(self): + # save some memory + self.diffusion_model.model.to('cpu') + + @torch.no_grad() + def write_logs(self, loss, logits, targets): + log_prefix = 'train' if self.training else 'val' + log = {} + log[f"{log_prefix}/loss"] = loss.mean() + log[f"{log_prefix}/acc@1"] = self.compute_top_k( + logits, targets, k=1, reduction="mean" + ) + log[f"{log_prefix}/acc@5"] = self.compute_top_k( + logits, targets, k=5, reduction="mean" + ) + + self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) + self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) + self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) + + def shared_step(self, batch, t=None): + x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) + targets = self.get_conditioning(batch) + if targets.dim() == 4: + targets = targets.argmax(dim=1) + if t is None: + t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() + else: + t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() + x_noisy = self.get_x_noisy(x, t) + logits = self(x_noisy, t) + + loss = F.cross_entropy(logits, targets, reduction='none') + + self.write_logs(loss.detach(), logits.detach(), targets.detach()) + + loss = loss.mean() + return loss, logits, x_noisy, targets + + def training_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + return loss + + def reset_noise_accs(self): + self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in + range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} + + def on_validation_start(self): + self.reset_noise_accs() + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + + for t in self.noisy_acc: + _, logits, _, targets = self.shared_step(batch, t) + self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) + self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) + + return loss + + def configure_optimizers(self): + optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) + + if self.use_scheduler: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [optimizer], scheduler + + return optimizer + + @torch.no_grad() + def log_images(self, batch, N=8, *args, **kwargs): + log = dict() + x = self.get_input(batch, self.diffusion_model.first_stage_key) + log['inputs'] = x + + y = self.get_conditioning(batch) + + if self.label_key == 'class_label': + y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + log['labels'] = y + + if ismap(y): + log['labels'] = self.diffusion_model.to_rgb(y) + + for step in range(self.log_steps): + current_time = step * self.log_time_interval + + _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) + + log[f'inputs@t{current_time}'] = x_noisy + + pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) + pred = rearrange(pred, 'b h w c -> b c h w') + + log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) + + for key in log: + log[key] = log[key][:N] + + return log diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..d6749b83b6f00b004388a1cb392b51fbe9caf81b --- /dev/null +++ b/ldm/models/diffusion/ddim.py @@ -0,0 +1,540 @@ +"""SAMPLING ONLY.""" + +from imports import * +import torch +import numpy as np +from tqdm import tqdm +from functools import partial +from src.Face_models.encoders.model_irse import Backbone +import torch.nn as nn +import torchvision.transforms.functional as TF + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \ + extract_into_tensor + +def un_norm_clip(x1): + x = x1*1.0 # to avoid changing the original tensor or clone() can be used + reduce=False + if len(x.shape)==3: + x = x.unsqueeze(0) + reduce=True + x[:,0,:,:] = x[:,0,:,:] * 0.26862954 + 0.48145466 + x[:,1,:,:] = x[:,1,:,:] * 0.26130258 + 0.4578275 + x[:,2,:,:] = x[:,2,:,:] * 0.27577711 + 0.40821073 + + if reduce: + x = x.squeeze(0) + return x + +class IDLoss(nn.Module): + def __init__(self,path="Other_dependencies/arcface/model_ir_se50.pth",multiscale=False): + super(IDLoss, self).__init__() + print('Loading ResNet ArcFace') + + self.multiscale = multiscale + self.face_pool_1 = torch.nn.AdaptiveAvgPool2d((256, 256)) + self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') + # self.facenet=iresnet100(pretrained=False, fp16=False) # changed by sanoojan + + self.facenet.load_state_dict(torch.load(path)) + + self.face_pool_2 = torch.nn.AdaptiveAvgPool2d((112, 112)) + self.facenet.eval() + + self.set_requires_grad(False) + + def set_requires_grad(self, flag=True): + for p in self.parameters(): + p.requires_grad = flag + + def extract_feats(self, x,clip_img=True): + # breakpoint() + if clip_img: + x = un_norm_clip(x) + x = TF.normalize(x, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + x = self.face_pool_1(x) if x.shape[2]!=256 else x # (1) resize to 256 if needed + x = x[:, :, 35:223, 32:220] # (2) Crop interesting region + x = self.face_pool_2(x) # (3) resize to 112 to fit pre-trained model + # breakpoint() + x_feats = self.facenet(x, multi_scale=self.multiscale ) + + # x_feats = self.facenet(x) # changed by sanoojan + return x_feats + + + + def forward(self, y_hat, y,clip_img=True,return_seperate=False): + n_samples = y.shape[0] + y_feats_ms = self.extract_feats(y,clip_img=clip_img) # Otherwise use the feature from there + + y_hat_feats_ms = self.extract_feats(y_hat,clip_img=clip_img) + y_feats_ms = [y_f.detach() for y_f in y_feats_ms] + + loss_all = 0 + sim_improvement_all = 0 + seperate_losses=[] + for y_hat_feats, y_feats in zip(y_hat_feats_ms, y_feats_ms): + + loss = 0 + sim_improvement = 0 + count = 0 + + for i in range(n_samples): + sim_target = y_hat_feats[i].dot(y_feats[i]) + sim_views = y_feats[i].dot(y_feats[i]) + + seperate_losses.append(1-sim_target) + loss += 1 - sim_target # id loss + sim_improvement += float(sim_target) - float(sim_views) + count += 1 + + loss_all += loss / count + sim_improvement_all += sim_improvement / count + + return loss_all, sim_improvement_all, None + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + # self.ID_LOSS=IDLoss() + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + # print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + **kwargs + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, + z_ref=None, + **kwargs): + device = self.model.betas.device + + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + if z_ref is not None: + tensor_1c = torch.zeros((z_ref.shape[0], 1, z_ref.shape[2], z_ref.shape[3]), device=z_ref.device) + if REFNET.CH9: + z_ref = torch.cat([z_ref, z_ref, tensor_1c], dim=1) + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: # None + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + if z_ref is not None: + z_ref_noisy = self.model.q_sample(x_start=z_ref[:,:4], t=ts, ) + if REFNET.CH9: + z_ref[:,:4] = z_ref_noisy + # img and pred_x0 both B,4,64,64; cond/unconditional_conditioning tensors are B,1,768 + outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + z_ref=z_ref, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning,**kwargs) + img, pred_x0 = outs + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, + z_ref=None, + **kwargs): + """ + 0. input param is: (x, c, t, [z_ref] ) + 1. x=concat(x,inpaint,mask) + 2. apply_model(x, t, c, [z_ref] ) + ( similar to ddpm.py LatentDiffusion.p_losses() + """ + b, *_, device = *x.shape, x.device + if 1: + z_inpaint = kwargs['z_inpaint'] # B,4 + z_inpaint_mask = kwargs['z_inpaint_mask'] # B,1 + z9 = kwargs['z9'] # B,9or14 + # x = torch.cat([x, z_inpaint, z_inpaint_mask],dim=1) # B,9,... + x = torch.cat([x, z9[:,4:] ],dim=1) # B,9or14,... + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c, z_ref=z_ref,) + else: # check @ sanoojan + if MERGE_CFG_in_one_batch: + # b,... -> 2b,... + x_in = torch.cat([x] * 2) #x_in: 2,9,64,64 + t_in = torch.cat([t] * 2) + if z_ref is not None: + z_ref_in = torch.cat([z_ref] * 2) + else: + z_ref_in = None + batch_size = t.shape[0] + double_gg_lmk = batch_size>1 and hasattr(global_, 'lmk_') and global_.lmk_ is not None + if double_gg_lmk: + orig_lmk_ = global_.lmk_ + global_.lmk_ = torch.cat([orig_lmk_] * 2) + c_in = torch.cat([unconditional_conditioning, c]) #c_in: 2,1,768 + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in, z_ref=z_ref_in,).chunk(2) + if double_gg_lmk: + global_.lmk_ = orig_lmk_ + else: + # first infer unconditional then conditional (reduces peak CUDA memory) + e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning, z_ref=z_ref,) + e_t = self.model.apply_model(x, t, c, z_ref=z_ref,) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) #1,4,64,64 + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + if x.shape[1]!=4: + pred_x0 = (x[:,:4,:,:] - sqrt_one_minus_at * e_t) / a_t.sqrt() + else: + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(dir_xt.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + + def sample_train(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + t=None, + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + # print(f'Data shape for DDIM sampling is {size}, eta {eta}') + # for param in self.model.first_stage_model.parameters(): + # param.requires_grad = False + samples, intermediates = self.ddim_sampling_train(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T,ddim_num_steps=S, + curr_t=t, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + **kwargs + ) + return samples, intermediates + + + def ddim_sampling_train(self, cond, shape, + x_T=None, ddim_use_original_steps=False,ddim_num_steps=None, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100,curr_t=None, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None,**kwargs): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + kwargs['rest']=img[:,4:,:,:] + img=img[:,:4,:,:] + + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + curr_t=curr_t.cpu().numpy() + skip = (curr_t-1) // ddim_num_steps + # replace all 0s with 1s + skip[skip == 0] = 1 + if type(skip)!=int: + seq=[range(1, curr_t[n]-1, skip[n]) for n in range(len(curr_t))] + min_length = min(len(sublist) for sublist in seq) + min_length=min(min_length,ddim_num_steps) + # Create a new list of sublists by truncating each sublist to the minimum length + truncated_seq = [sublist[:min_length] for sublist in seq] + seq= np.array(truncated_seq) + + # seq=np.flip(seq) + #concatenate all sequences + # seq = np.concatenate(seq) + seq=torch.from_numpy(seq).to(device) + seq=torch.flip(seq,dims=[1]) + + + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + intermediates = {'x_inter': [], 'pred_x0': []} + # time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + # total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + # print(f"Running DDIM Sampling with {total_steps} timesteps") + + + # time_range=np.array([1]) + # iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + + total_steps=seq.shape[1] # 4 (ddim 4 steps) + for i in range(seq.shape[1]): + index = total_steps - i - 1 + # ts = torch.full((b,), step, device=device, dtype=torch.long) + ts=seq[:,i].long() + #make it toech long + # ts=ts.long() + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + outs = self.p_sample_ddim_train(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning,**kwargs) + img, pred_x0 = outs + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + # if index % log_every_t == 0 or index == total_steps - 1: + if i in [ total_steps - 1, ]: + # if 1: # len_inter 4 (5 if orig rf) => OOM + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + + def p_sample_ddim_train(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None,return_features=False,**kwargs): + b, *_, device = *x.shape, x.device + # if 'test_model_kwargs' in kwargs: + # kwargs=kwargs['test_model_kwargs'] + # x = torch.cat([x, kwargs['inpaint_image'], kwargs['inpaint_mask']],dim=1) + if 'rest' in kwargs: + x = torch.cat((x, kwargs['rest']), dim=1) + + + z_ref = kwargs.pop('z_ref',None) + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c,return_features=return_features,z_ref=z_ref) + else: # check @ sanoojan + assert 0 + x_in = torch.cat([x] * 2) #x_in: 2,9,64,64 + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) #c_in: 2,1,768 + if return_features: + e_t_uncond, e_t,features = self.model.apply_model(x_in, t_in, c_in,return_features=return_features).chunk(3) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) #1,4,64,64 + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + if x.shape[1]!=4: + pred_x0 = (x[:,:4,:,:] - sqrt_one_minus_at * e_t) / a_t.sqrt() + else: + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(dir_xt.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + @torch.no_grad() + def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, + use_original_steps=False): + assert 0 diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py new file mode 100644 index 0000000000000000000000000000000000000000..170a644af6cc1945423ac7bce010b312b95a2a30 --- /dev/null +++ b/ldm/models/diffusion/ddpm.py @@ -0,0 +1,1697 @@ + +from .misc_4ddpm import * +from lmk_util.lmk_extractor import lmkAll_2_lmkMain, get_lmkMain_indices + +class DDPM(pl.LightningModule): + # classic DDPM with Gaussian diffusion, in image space + def __init__(self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + image_size=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0., + v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1., + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + learn_logvar=False, + logvar_init=0., + u_cond_percent=0, + ): + super().__init__() + assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"' + self.parameterization = parameterization + print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + self.image_size = image_size + self.channels = channels + self.u_cond_percent=u_cond_percent + unet_config['params']['in_channels'] = 14 if CH14 else 9 + self.model = DiffusionWrapper(unet_config, conditioning_key) + count_params(self.model, verbose=True) + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + + self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, + linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) + + self.loss_type = loss_type + + self.learn_logvar = learn_logvar + self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + + + def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( + 1. - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + + if self.parameterization == "eps": + lvlb_weights = self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + elif self.parameterization == "x0": + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod)) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + assert 0 + print("[init_from_ckpt]") + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def get_loss(self, pred, target, mean=True): + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, pred, reduction='none') #--> + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + def p_losses(self, x_start, t, noise=None): + assert 0, 'This should not be called; subclasses override this method' + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_out = self.model(x_noisy, t) + + loss_dict = {} + if self.parameterization == "eps": + target = noise + elif self.parameterization == "x0": + target = x_start + else: + raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported") + + loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) + + # metrics.csv entries like 'train/...' and 'val/...' originate here + log_prefix = 'train' if self.training else 'val' + + loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) + loss_simple = loss.mean() * self.l_simple_weight + + loss_vlb = (self.lvlb_weights[t] * loss).mean() + loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) + + loss = loss_simple + self.original_elbo_weight * loss_vlb + + loss_dict.update({f'{log_prefix}/loss': loss}) + + return loss, loss_dict + + def forward(self, x, *args, **kwargs): + # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size + # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + return self.p_losses(x, t, *args, **kwargs) + + + def shared_step(self, batch): + assert 0 + + + def set_task(self, batch): + task = batch['task'][0].item() + printC('task',f"{task=}") + global_.task = task + assert all(batch['task'] == task), batch['task'] + self.task = task + if 1: + if (not USE_pts) or task==1: self.Landmark_cond=False + else: self.Landmark_cond=True + if 1: + if task in (0,2,3,): + self.Landmarks_weight=0.05 + else: + self.Landmarks_weight=0 + self.STACK_feat=True + return task + def unset_task(self): + global_.task = None + global_.lmk_ = None + del self.task + def training_step(self, batch, batch_idx): + task = batch['task'][0].item() + opt = self.optimizers() + + if not self.Reconstruct_initial:# only MSE loss(orig diffusion). -> shared_step -> forward -> p_losses + loss, loss_dict = self.shared_step(batch) # original + else: # added Multistep (DDIM) loss -> shared_step_face -> forward_face -> p_losses_face + loss, loss_dict = self.shared_step_face(batch) # changed by sanoojan : to add ID loss after reconstructing through DDIM + + step_or_accumulate = ( task==3 or TP_enable) + _ctx = nullcontext + if not step_or_accumulate and not TP_enable: + _ctx = self.trainer.model.no_sync # https://github.com/Lightning-AI/pytorch-lightning/discussions/10792 + with _ctx(): # https://zhuanlan.zhihu.com/p/250471767 + self.manual_backward(loss) + + if (REFNET.ENABLE and REFNET.task2layerNum[task]>0): + self.model.bank.clear() + self.unset_task() + + + total_step = len(self.trainer.train_dataloader) + if step_or_accumulate: + # Average grads of shared params across ranks (TaskParallel) + if dist.is_available() and dist.is_initialized(): + ws = dist.get_world_size() + shared_sync_cnt = 0; task_skip_cnt = 0 + for name, p in self.named_parameters(): + need_sync, is_task_specific_skip = tp_param_need_sync(name, p) + if not need_sync: + if is_task_specific_skip: + task_skip_cnt += 1 + continue + if p.grad is None: + p.grad = torch.zeros_like(p) # ensure collective call sequence remains consistent + dist.all_reduce(p.grad, op=dist.ReduceOp.SUM) + p.grad.div_(ws) + shared_sync_cnt += 1 + if gate_('[TP] shared sync counts'): + print(f"synced={shared_sync_cnt} skipped(task)={task_skip_cnt}") + torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0) + opt.step() + opt.zero_grad() + if self.use_scheduler: # handle LR schedulers + sch = self.lr_schedulers() + if isinstance(sch, list) and len(sch) > 0: # schedulers expressed as a list + for scheduler_config in sch: + if isinstance(scheduler_config, dict) and 'scheduler' in scheduler_config: + scheduler_config['scheduler'].step() + else: + scheduler_config.step() + elif hasattr(sch, 'step'): + sch.step() + self.log_dict(loss_dict, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + + self.log("global_step", self.global_step, + prog_bar=True, logger=True, on_step=True, on_epoch=False) + + if self.use_scheduler: + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) + + return loss + # manual optimization calls backward in training_step already, so this is skipped here + # def backward( + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + _, loss_dict_no_ema = self.shared_step(batch) + with self.ema_scope(): + _, loss_dict_ema = self.shared_step(batch) + loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} + self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + self.unset_task() + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + + + + +class LatentDiffusion(DDPM): + """main class""" + def __init__(self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + *args, **kwargs): + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs['timesteps'] + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = 'concat' if concat_mode else 'crossattn' + if cond_stage_config == '__is_unconditional__': + conditioning_key = None + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", []) + super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + self.automatic_optimization = False # disable automatic optimization to manage parameter updates manually + + + # self.learnable_vector = nn.Parameter(torch.randn((1,1,768)), requires_grad=True) + # breakpoint() + + + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + + #check if other_params is present in cond_stage_config + if hasattr(cond_stage_config, 'other_params'): + + self.clip_weight=cond_stage_config.other_params.clip_weight + # those three weights: 0 skips module init, >0 enables it and acts as weight when !STACK_feat + if set(TASKS) & {0,2,3}: self.ID_weight = 10.0 + else: self.ID_weight = 0 + if (not USE_pts) and TASKS==(1,): self.Landmark_cond=False + else: self.Landmark_cond=True + self.Landmarks_weight=0.05 + if hasattr(cond_stage_config.other_params, 'Additional_config'): + self.Reconstruct_initial=cond_stage_config.other_params.Additional_config.Reconstruct_initial + self.Reconstruct_DDIM_steps=cond_stage_config.other_params.Additional_config.Reconstruct_DDIM_steps + self.sampler=DDIMSampler(self) + if hasattr(cond_stage_config.other_params, 'multi_scale_ID'): + self.multi_scale_ID=cond_stage_config.other_params.multi_scale_ID # True has an issue + else: + self.multi_scale_ID=True #this has an issue obtaining earlier layer from ID + if hasattr(cond_stage_config.other_params, 'normalize'): + self.normalize=cond_stage_config.other_params.normalize # normalizes the combintaion of ID and LPIPS loss + else: + self.normalize=False + if 1: + self.lpips_loss = LPIPS(net_type='alex').to(self.device).eval() + if hasattr(cond_stage_config.other_params, 'partial_training'): + self.partial_training=cond_stage_config.other_params.partial_training + self.trainable_keys=cond_stage_config.other_params.trainable_keys + else: + self.partial_training=False + if hasattr(cond_stage_config.other_params.Additional_config, 'Same_image_reconstruct'): + self.Same_image_reconstruct=cond_stage_config.other_params.Additional_config.Same_image_reconstruct + else: + self.Same_image_reconstruct=False + if hasattr(cond_stage_config.other_params.Additional_config, 'Target_CLIP_feat'): + self.Target_CLIP_feat=cond_stage_config.other_params.Additional_config.Target_CLIP_feat + else: + self.Target_CLIP_feat=False + if hasattr(cond_stage_config.other_params.Additional_config, 'Source_CLIP_feat'): + self.Source_CLIP_feat=cond_stage_config.other_params.Additional_config.Source_CLIP_feat + else: + self.Source_CLIP_feat=False + if hasattr(cond_stage_config.other_params.Additional_config, 'use_3dmm'): + self.use_3dmm=cond_stage_config.other_params.Additional_config.use_3dmm + else: + self.use_3dmm=False + + else: + self.Reconstruct_initial=False + self.Reconstruct_DDIM_steps=0 + + self.update_weight=False + + else: + assert 0 + if 1: + self.learnable_vector = nn.ParameterList([ + nn.Parameter(torch.randn((1,259,768)), requires_grad=True), + nn.Parameter(torch.randn((1,257,768)), requires_grad=True), + nn.Parameter(torch.randn((1,259,768)), requires_grad=True), + nn.Parameter(torch.randn((1,259,768)), requires_grad=True), + ]) + if self.ID_weight>0: + if self.multi_scale_ID: + self.ID_proj_out=nn.Linear(200704, 768) + else: + self.ID_proj_out=nn.Linear(512, 768) # yes + self.instantiate_IDLoss(cond_stage_config) + + if self.Landmark_cond: + if USE_pts: + self.ptsM_Generator = LandmarkExtractor(include_visualizer=True,img_256_mode=False) + else: + raise + + if self.Landmarks_weight>0: + self.landmark_proj_out=nn.Linear(NUM_pts*2, 768) + self.total_steps_in_epoch=0 # will be calculated inside training_step. Not known for now + if 1: + assert cond_stage_config.target=="ldm.modules.encoders.modules.FrozenCLIPEmbedder" and self.Source_CLIP_feat and self.Target_CLIP_feat + self.USE_proj_out_source = 1 + if set(TASKS) & {0,}: + self.proj_out_source__face=nn.Linear(768, 768) + if set(TASKS) & {1,}: + self.proj_out_source__hair=nn.Linear(768, 768) + if set(TASKS) & {2,3,}: + self.proj_out_source__head=nn.Linear(768, 768) + if 0: # dummy, just for compa + self.proj_out_target=nn.Linear(768, 768) + self.proj_out=nn.Identity() + + try: + self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer('scale_factor', torch.tensor(scale_factor)) + self.instantiate_first_stage(first_stage_config) + self.instantiate_cond_stage(cond_stage_config) + + + self.cond_stage_forward = cond_stage_forward + self.clip_denoised = False + self.bbox_tokenizer = None + + self.restarted_from_ckpt = False + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys) + self.restarted_from_ckpt = True + + def get_lmk_for_router(self, batch: dict, x_tensor: torch.Tensor): + """ + Prepare global_.lmk_ (BS, L, 2) normalized to [0,1] for gating/router. + - Prefer cached Mediapipe landmarks if present in batch + - Convert 468/478 to main landmarks with face oval using get_lmkMain_indices(True) + - Fallback to zeros if not available + """ + b, _, H, W = x_tensor.shape + if READ_mediapipe_result_from_cache and ('mediapipe_lmkAll' in batch): + data_all = batch['mediapipe_lmkAll'] # tensor or ndarray + if isinstance(data_all, torch.Tensor): + lmks_all = data_all.to(x_tensor.device).to(x_tensor.dtype) + else: + lmks_all = torch.from_numpy(data_all).to(x_tensor.device).to(x_tensor.dtype) + # map to main indices with face oval (cached tensor indices on device) + idxs = getattr(global_, 'lmk_main_idx_tensor', None) + if (idxs is None) or (idxs.device != x_tensor.device): + idx_list = get_lmkMain_indices(include_face_oval=True) + idxs = torch.as_tensor(list(idx_list), dtype=torch.long, device=x_tensor.device) + global_.lmk_main_idx_tensor = idxs + lmk = torch.index_select(lmks_all, dim=1, index=idxs) + # normalize by current spatial size + if lmk.numel() > 0: + # print(f"0 {lmk[:,:5]=}") + lmk[..., 0] = lmk[..., 0] / float(W) + lmk[..., 1] = lmk[..., 1] / float(H) + # print(f"1 {lmk[:,:5]=}") + else: + assert 0 + lmk = torch.zeros((b, 0, 2), device=x_tensor.device, dtype=x_tensor.dtype) + return lmk + + def make_cond_schedule(self, ): + self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) + ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() + self.cond_ids[:self.num_timesteps_cond] = ids + + @rank_zero_only + @torch.no_grad() + def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + # only for very first batch + if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: + assert 0 + + def register_schedule(self, + given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_IDLoss(self, config): + # Need to modify @sanoojan + # if not self.cond_stage_trainable: + model = IDLoss(config,multiscale=self.multi_scale_ID) + self.face_ID_model = model.eval() + self.face_ID_model.train = disabled_train + for param in self.face_ID_model.parameters(): + param.requires_grad = False + + + + def instantiate_cond_stage(self, config): + if 1: + assert config != '__is_first_stage__' + assert config != '__is_unconditional__' + model: FrozenCLIPEmbedder = instantiate_from_config(config) #ldm.modules.encoders.modules.FrozenCLIPEmbedder + if 0 in TASKS: + self.encoder_clip_face :FrozenCLIPEmbedder = model + if 1 in TASKS: + self.encoder_clip_hair :FrozenCLIPEmbedder = copy.deepcopy(model) + del self.encoder_clip_hair.model + del self.encoder_clip_hair.tokenizer + if set(TASKS) & {2,}: + self.encoder_clip_head_t2 :FrozenCLIPEmbedder = copy.deepcopy(model) + del self.encoder_clip_head_t2.model + del self.encoder_clip_head_t2.tokenizer + if set(TASKS) & {3,}: + self.encoder_clip_head_t3 :FrozenCLIPEmbedder = copy.deepcopy(model) + del self.encoder_clip_head_t3.model + del self.encoder_clip_head_t3.tokenizer + + + def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): + denoise_row = [] + for zd in tqdm(samples, desc=desc): + denoise_row.append(self.decode_first_stage(zd.to(self.device), + force_not_quantize=force_no_decoder_quantization)) + n_imgs_per_row = len(denoise_row) + denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W + denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") + return self.scale_factor * z + + + def get_learned_conditioning(self, c): + raise Exception + def conditioning_with_feat(self,x,landmarks=None,enInputs:dict=None): + if gate_('vis LatentDiffusion.conditioning_with_feat'): + debug_dir = Path(f"4debug/conditioning_with_feat/{ID}"); debug_dir.mkdir(parents=0, exist_ok=True) + all_images = [ ('x', x), ] + for _name, _enInput in enInputs.items(): + all_images.append((_name, _enInput)) + vis_tensors_A(all_images, debug_dir / f"all-{str_t_pid()}.jpg", vis_batch_size= min(5, landmarks.shape[0]) ) + del x # (x is GT during training, ref_imgs during inference) + task = self.task + ID_weight = self.ID_weight + Landmarks_weight = self.Landmarks_weight + if self.task==0: + face_clip_weight = self.clip_weight + elif self.task==1: + hair_clip_weight = self.clip_weight + elif self.task==2: + head_clip_weight = self.clip_weight + elif self.task==3: + head_clip_weight = self.clip_weight + if 1: + cs = [] # conditionings + ws = [] # weights corresponding one-to-one with cs + def encode_face_ID(): + _c = enInputs['face_ID-in'] + _c=self.face_ID_model.extract_feats(_c)[0] + _c = self.ID_proj_out(_c) #-->c:[4,768] + _c = _c.unsqueeze(1) #-->c:[4,1,768] + if self.normalize: #normalize c2 + _c = _c*norm_coeff/F.normalize(_c, p=2, dim=2) + cs.append(_c); ws.append(ID_weight) + def encode_face_clip(_z=None):# _z: result of ViT forward pass + if _z is None: + _c = enInputs['face-clip-in'] + _c = self.encoder_clip_face.encode(_c) #b,3,224,224 --> b,1,768 + else: + assert 0 + _c = self.encoder_clip_face.encode_B(_z) + if hasattr(self,'USE_proj_out_source') and self.USE_proj_out_source: + _c = self.proj_out_source__face(_c) + cs.append(_c); ws.append(face_clip_weight) + def encode_hair_clip(_z=None): + if _z is None: + _c = enInputs['hair-clip-in'] + _c = self.encoder_clip_hair.encode(_c) #b,3,224,224 --> b,1,768 + else: + _c = self.encoder_clip_hair.encode_B(_z) + if hasattr(self,'USE_proj_out_source') and self.USE_proj_out_source: + _c = self.proj_out_source__hair(_c) + printC("hair _c.shape:",f"{_c.shape}") + cs.append(_c); ws.append(hair_clip_weight) + def encode_head_clip(_z=None): + if global_.task == 2: + encoder_clip_head = self.encoder_clip_head_t2 + elif global_.task == 3: + encoder_clip_head = self.encoder_clip_head_t3 + else: + raise ValueError(f"Task {global_.task} does not have encoder_clip_head") + if _z is None: + _c = enInputs['head-clip-in'] + _c = encoder_clip_head.encode(_c) #b,3,224,224 --> b,1,768 + else: + _c = encoder_clip_head.encode_B(_z) + if hasattr(self,'USE_proj_out_source') and self.USE_proj_out_source: + _c = self.proj_out_source__head(_c) + printC("head _c.shape:",f"{_c.shape}") + cs.append(_c); ws.append(head_clip_weight) + if task==0: + encode_face_ID() + encode_face_clip() + elif task==1: + _z = enInputs['hair-clip-in'] + _z = self.encoder_clip_face.forward_vit(_z) + encode_hair_clip(_z) + elif task==2: + encode_face_ID() + _z = enInputs['head-clip-in'] + _z = self.encoder_clip_face.forward_vit(_z) + encode_head_clip(_z) + elif task==3: + encode_face_ID() + _z = enInputs['head-clip-in'] + _z = self.encoder_clip_face.forward_vit(_z) + encode_head_clip(_z) + c=0 + + if Landmarks_weight > 0: + landmarks=landmarks.unsqueeze(1) if len(landmarks.shape)!=3 else landmarks + cs.append(landmarks); ws.append(Landmarks_weight) + if self.STACK_feat: # _Cc + # stack all features + conc=torch.cat(cs, dim=-2) + c = conc + else: + total_weight = sum(ws) + weighted_sum = sum(c * w for c, w in zip(cs, ws)) + c = weighted_sum / total_weight if total_weight > 0 else 0 + printC("[conditioning_with_feat return]",f"{custom_repr_v3(c)}") + # assert c.shape[1]==NUM_token, c.shape + return c + + + def get_landmarks(self,x, batch:dict): + + if (self.Landmark_cond) and x is not None: + # pass + # # Detect faces in an image + #convert to 8bit image + x=255.0*un_norm(x).permute(0,2,3,1).cpu().numpy() + x=x.astype(np.uint8) # B,512,512,3 + Landmarks_all=[] + if USE_pts: + l_lmkAll=[] + if READ_mediapipe_result_from_cache: + _l_lmkAll :np.ndarray = batch['mediapipe_lmkAll'].cpu().numpy() + bs = len(x) + for i in range(len(x)): + if USE_pts: + if READ_mediapipe_result_from_cache: + lmkAll :np.ndarray = _l_lmkAll[i] + else: + lmkAll :np.ndarray = self.ptsM_Generator.extract_single(x[i], only_main_lmk=False) + if lmkAll is None: lmkAll = np.zeros((478,2)) + l_lmkAll.append(lmkAll) + lm = lmkAll_2_lmkMain(lmkAll) # NUM_pts,2 + lm = lm.reshape(1, NUM_pts*2) # num of points * 2 coordinates + Landmarks_all.append(lm) + if 0: + from util_vis import visualize_landmarks + starter_stem = Path(sys.argv[0]).stem + path_vis_lmk = f'4debug/vis_lmk/{starter_stem}-{i}.png' + visualize_landmarks(x[i], lm[0], path_vis_lmk) + print(f"{path_vis_lmk=}") + Landmarks_all=np.concatenate(Landmarks_all,axis=0) + pts68 = Landmarks_all.reshape(bs, NUM_pts, 2, ) + if self.Landmarks_weight>0: + Landmarks_all=torch.tensor(Landmarks_all).float().to(self.device) + if self.Landmark_cond == False: + return Landmarks_all + with torch.enable_grad(): + Landmarks_all=self.landmark_proj_out(Landmarks_all) + # normalize Landmarks_all + + lmk_aux={} + if USE_pts: lmk_aux['l_lmkAll'] = l_lmkAll + return Landmarks_all,pts68,lmk_aux + + def meshgrid(self, h, w): + y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) + x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) + + arr = torch.cat([y, x], dim=-1) + return arr + + def delta_border(self, h, w): + """ + :param h: height + :param w: width + :return: normalized distance to image border, + wtith min distance = 0 at border and max dist = 0.5 at image center + """ + lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) + arr = self.meshgrid(h, w) / lower_right_corner + dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] + dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0] + edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] + return edge_dist + + def get_weighting(self, h, w, Ly, Lx, device): + weighting = self.delta_border(h, w) + weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"], + self.split_input_params["clip_max_weight"], ) + weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) + + if self.split_input_params["tie_braker"]: + L_weighting = self.delta_border(Ly, Lx) + L_weighting = torch.clip(L_weighting, + self.split_input_params["clip_min_tie_weight"], + self.split_input_params["clip_max_tie_weight"]) + + L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) + weighting = weighting * L_weighting + return weighting + + def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code + """ + :param x: img of size (bs, c, h, w) + :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) + """ + bs, nc, h, w = x.shape + + # number of crops in image + Ly = (h - kernel_size[0]) // stride[0] + 1 + Lx = (w - kernel_size[1]) // stride[1] + 1 + + if uf == 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) + + weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) + + elif uf > 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), + dilation=1, padding=0, + stride=(stride[0] * uf, stride[1] * uf)) + fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) + + elif df > 1 and uf == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df), + dilation=1, padding=0, + stride=(stride[0] // df, stride[1] // df)) + fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) + + else: + raise NotImplementedError + + return fold, unfold, normalization, weighting + + # returned x is the concatenated multi-channel tensor (mask, ref, lmk, ...); e.g. "x_start[:,8,:,:]" extracts the mask + @torch.no_grad() + def get_input_(self, batch, k, return_first_stage_outputs=False, + cond_key=None, bs=None, + get_referenceZ=False, # reference image latent tensor, dims B,4,64,64 + ): + if k == "inpaint": # yes + x = batch['GT'] + mask = batch['inpaint_mask'].clone() # b,1,512,512 + inpaint = batch['inpaint_image'].clone() # .clone so that batch['inpaint_image'] remains the original image without landmarks + # reference = batch['ref_imgs'] + reference = None + else: + assert 0 + if len(x.shape) == 3: + assert 0 + x = x[..., None] + if 1: + enInputs = batch['enInputs'] # encoder inputs (each self.encoder receives these raw tensors without preprocessing) + for k,v in enInputs.items(): + enInputs[k] = v.to(memory_format=torch.contiguous_format).float() + #-------------------------------------------------------------------------------- + ref_imgs_4unet = batch.get('ref_imgs_4unet', None) if get_referenceZ else None + + + #x : Original Image + #inpaint : Masked original image + #mask: mask + #reference: Transformed(Masked(original image)) + if bs is not None: + assert 0 + x = x.to(self.device) + + global_.lmk_ = self.get_lmk_for_router(batch, x) # for router/gate + if self.Landmark_cond: + landmarks, pts68, lmk_aux=self.get_landmarks(x,batch) + else: + landmarks=None + + if self.task in (0,2,3,) and USE_pts: + mask_np = mask.detach().cpu().numpy() + if 1: + #convert to 8bit image + x_unnorm=255.0*un_norm(x).permute(0,2,3,1).cpu().numpy() + x_unnorm=x_unnorm.astype(np.uint8) # B,512,512,3 + + batch_size = x.shape[0] + + VIS_pts= 0 + + for b in range(batch_size): + lmkAll = lmk_aux['l_lmkAll'][b] + inpaint[b] = torch.Tensor(self.ptsM_Generator.visualizer.visualize_landmarks(inpaint[b].permute(1,2,0).detach().cpu().numpy(), lmkAll, ) ).permute(2,0,1) + del lmkAll + + if self.training and gate_('vis LatentDiffusion.get_input'): + debug_dir = Path(f"4debug/LatentDiffusion.get_input/{ID}"); debug_dir.mkdir(parents=0, exist_ok=True) + vis_batch_size = min(5, x.shape[0]) # Show at most 4 samples + all_images = [ ('x', x), ('inpaint', inpaint), ('mask', mask), ('reference', reference), ('ref_imgs_4unet', ref_imgs_4unet) ] + for _name, _enInput in enInputs.items(): + all_images.append((_name, _enInput)) + all_path = debug_dir / f"all--after-pts-{str_t_pid()}.jpg" + vis_tensors_A(all_images, all_path, vis_batch_size) + + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + encoder_posterior_inpaint = self.encode_first_stage(inpaint) + z_inpaint = self.get_first_stage_encoding(encoder_posterior_inpaint).detach() + # tgt/ref_mask_64 + mask_resize = Resize([z.shape[-1],z.shape[-1]])(mask) + ref_mask_64 = Resize([z.shape[-1],z.shape[-1]])(batch['ref_mask_512']) if 'ref_mask_512' in batch else None + # z9 & z_ref + if not CH14: + z_new = torch.cat((z,z_inpaint,mask_resize),dim=1) # shape:[4,9,64,64] 9:4+4+1 + if get_referenceZ: + encoder_posterior_ref = self.encode_first_stage(ref_imgs_4unet) + z_ref = self.get_first_stage_encoding(encoder_posterior_ref).detach() # shape:[4,4,64,64] + else: + z_ref = None + if CH14: + z_new = torch.cat((z,z_inpaint,mask_resize, z_ref,ref_mask_64),dim=1) + assert z.shape[1:]==(4,64,64,) + if gate_(f'vis LatentDiffusion.get_input-before_return {self.training}'): + debug_dir = Path(f"4debug/LatentDiffusion.get_input-before_return/{ID}"); debug_dir.mkdir(parents=0, exist_ok=True) + vis_batch_size = min(5, x.shape[0]) + all_images = [ ('x', x), ('inpaint', inpaint), ('mask', mask), ('reference', reference), ('ref_imgs_4unet', ref_imgs_4unet), + ('z4_gt',z[:,:3]),('z4_inpaint', z_inpaint[:,:3]),('tgt_mask_64', mask_resize),('z_ref',None if z_ref is None else z_ref[:,:3]),('ref_mask_64',ref_mask_64),] + all_path = debug_dir / f"{str_t_pid()}.jpg" + vis_tensors_A(all_images, all_path, vis_batch_size) + + if 1: + assert self.model.conditioning_key is not None + assert self.first_stage_key=='inpaint' + assert self.cond_stage_key=='image' + return { + **batch, + 'z9': z_new,# b,9/14,... + 'z4_gt': z, + 'z4_inpaint': z_inpaint, + # + 'tgt_mask_64': mask_resize, + 'ref_mask_64': ref_mask_64, + # + 'z_ref': z_ref, # 'z_ref' is ambiguous but kept for legacy usage; hard-code the intended meaning + # + 'landmarks': landmarks, # projected features, not raw coordinates + } + + @torch.no_grad() + def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [self.first_stage_model.decode(z[:, :, :, :, i], + force_not_quantize=predict_cids or force_not_quantize) + for i in range(z.shape[-1])] + else: + + output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + if self.first_stage_key=='inpaint': + return self.first_stage_model.decode(z[:,:4,:,:]) + else: + return self.first_stage_model.decode(z) + + + + # same as above but without decorator + def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [self.first_stage_model.decode(z[:, :, :, :, i], + force_not_quantize=predict_cids or force_not_quantize) + for i in range(z.shape[-1])] + else: + + output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + @torch.no_grad() + def encode_first_stage(self, x): + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + df = self.split_input_params["vqf"] + self.split_input_params['original_image_size'] = x.shape[-2:] + bs, nc, h, w = x.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df) + z = unfold(x) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + output_list = [self.first_stage_model.encode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) + o = o * weighting + + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization + return decoded + + else: + return self.first_stage_model.encode(x) + else: + return self.first_stage_model.encode(x) + + def get_input_and_conditioning(self,batch, device=None): + if device is not None: batch = recursive_to(batch, device) + #------------------------from shared_step------------------------- + get_referenceZ=(REFNET.ENABLE and REFNET.task2layerNum[global_.task]>0) or CH14 + batch = self.get_input_(batch, self.first_stage_key,get_referenceZ=get_referenceZ) + #------------------------from shared_step -> forward------------------------- + assert ( self.model.conditioning_key is not None ) and self.cond_stage_trainable + c=self.conditioning_with_feat(batch['ref_imgs'],landmarks=batch['landmarks'],enInputs=batch['enInputs']) + return batch,c + def shared_step(self, batch, **kwargs): + task = self.set_task(batch) + if (REFNET.ENABLE and REFNET.task2layerNum[task]>0): + self.model.bank.clear() + batch, c = self.get_input_and_conditioning(batch) + z9 = batch['z9'] + z_ref = batch['z_ref'] + gt512 = batch['GT'] + gt256 = batch.get('GT256',None) + # del batch + loss = self(z9, c,z_ref=z_ref,gt512=gt512,gt256=gt256,task=task,batch=batch,) + return loss + + def forward(self, x, c, *args, **kwargs): + task = kwargs['task'] + # c is the reference tensor; target shares the same shape + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + self.u_cond_prop=random.uniform(0, 1) + if self.model.conditioning_key is not None: + # assert c is not None + if self.cond_stage_trainable: # yes + pass + + if self.shorten_cond_schedule: # TODO: drop this option + raise Exception + tc = self.cond_ids[t].to(self.device) + c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) + + if self.u_cond_propc_crossattn + cond = {key: cond} + + if hasattr(self, "split_input_params"): + assert 0,'This branch should not execute in practice' + assert len(cond) == 1 # todo can only deal with one conditioning atm + assert not return_ids + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + + h, w = x_noisy.shape[-2:] + + fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride) + + z = unfold(x_noisy) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])] + + if self.cond_stage_key in ["image", "LR_image", "segmentation", + 'bbox_img'] and self.model.conditioning_key: # todo check for completeness + c_key = next(iter(cond.keys())) # get key + c = next(iter(cond.values())) # get value + assert (len(c) == 1) # todo extend to list with more than one elem + c = c[0] # get element + + c = unfold(c) + c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])] + + elif self.cond_stage_key == 'coordinates_bbox': + assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size' + + # assuming padding of unfold is always 0 and its dilation is always 1 + n_patches_per_row = int((w - ks[0]) / stride[0] + 1) + full_img_h, full_img_w = self.split_input_params['original_image_size'] + # as we are operating on latents, we need the factor from the original image size to the + # spatial latent size to properly rescale the crops for regenerating the bbox annotations + num_downs = self.first_stage_model.encoder.num_resolutions - 1 + rescale_latent = 2 ** (num_downs) + + # get top left positions of patches as conforming for the bbbox tokenizer, therefore we + # need to rescale the tl patch coordinates to be in between (0,1) + tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w, + rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h) + for patch_nr in range(z.shape[-1])] + + # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w) + patch_limits = [(x_tl, y_tl, + rescale_latent * ks[0] / full_img_w, + rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates] + # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates] + + # tokenize crop coordinates for the bounding boxes of the respective patches + patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device) + for bbox in patch_limits] # list of length l with tensors of shape (1, 2) + print(patch_limits_tknzd[0].shape) + # cut tknzd crop position from conditioning + assert isinstance(cond, dict), 'cond must be dict to be fed into model' + cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device) + + adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd]) + adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n') + adapted_cond = self.get_learned_conditioning(adapted_cond) + adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1]) + + cond_list = [{'c_crossattn': [e]} for e in adapted_cond] + + else: + cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient + + # apply model by loop over crops + output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])] + assert not isinstance(output_list[0], + tuple) # todo cant deal with multiple model outputs check this never happens + + o = torch.stack(output_list, axis=-1) + o = o * weighting + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + x_recon = fold(o) / normalization + + else: + x_recon = self.model(x_noisy, t, **cond, return_features=return_features, z_ref=z_ref, + task=self.task, _trainer=self.trainer, + ) + if return_features: + return x_recon + if isinstance(x_recon, tuple) and not return_ids: + return x_recon[0] + else: + return x_recon + + + def p_losses(self, x_start, cond, t, noise=None, z_ref=None, gt512=None, gt256=None, task=None, + batch :dict = None, + ): + # def p_losses_face(self, x_start, cond, t, reference=None,noise=None,GT_tar=None,landmarks=None): + # initialize MoE auxiliary loss to 0 to allow unconditional accumulation later + global_.moe_aux_loss = torch.tensor(0.0, device=self.device) + if self.first_stage_key == 'inpaint': + # x_start=x_start[:,:4,:,:] + noise = default(noise, lambda: torch.randn_like(x_start[:,:4,:,:])) + if 1: + x_noisy = self.q_sample(x_start=x_start[:,:4,:,:], t=t, noise=noise) + x_noisy = torch.cat((x_noisy,x_start[:,4:,:,:]),dim=1) + else: + noise = default(noise, lambda: torch.randn_like(x_start)) + if 1: + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + if z_ref is not None: + assert self.first_stage_key == 'inpaint', 'Expected first_stage_key to be "inpaint"' + """ + z_ref: b,4,... + z_ref = concat [z_ref_noisy, z_ref, tensor_1c] + tensor_1c is temporarily set to all zeros + """ + z_ref_noisy = self.q_sample(x_start=z_ref, t=t, noise=torch.randn_like(z_ref)) + tensor_1c = torch.zeros((z_ref.shape[0], 1, z_ref.shape[2], z_ref.shape[3]), device=z_ref.device) + if REFNET.CH9: + z_ref = torch.cat([z_ref_noisy, z_ref, tensor_1c], dim=1) + if 1: + model_output = self.apply_model(x_noisy, t, cond, z_ref=z_ref, ) + + loss_dict = {} + prefix = 'train' if self.training else 'val' + if DDIM_losses: + ######################## + t_new = torch.randint(self.num_timesteps-1, self.num_timesteps, (x_start.shape[0],), device=self.device).long().to(self.device) + # t_new=torch.tensor(t_new).to(self.device) + # noise_rec = default(noise, lambda: torch.randn_like(x_start[:,:4,:,:])) + x_noisy_rec = self.q_sample(x_start=x_start[:,:4,:,:], t=t_new, noise=noise) + x_noisy_rec = torch.cat((x_noisy_rec,x_start[:,4:,:,:]),dim=1) + + + ddim_steps=self.Reconstruct_DDIM_steps + n_samples=x_noisy_rec.shape[0] + shape=(4,64,64) + scale=5 + ddim_eta=0.0 + start_code=x_noisy_rec + test_model_kwargs=None + # t=t + + samples_ddim, sample_intermediates = self.sampler.sample_train(S=ddim_steps, # 4 (from Reconstruct_DDIM_steps in trian.yaml) + conditioning=cond, + batch_size=n_samples, + shape=shape, + verbose=False, + unconditional_guidance_scale=scale, + unconditional_conditioning=None, + eta=ddim_eta, + x_T=start_code, + t=t_new, + z_ref=z_ref, + test_model_kwargs=test_model_kwargs) + + + + + # x_samples_ddim= self.differentiable_decode_first_stage(samples_ddim) + + other_pred_x_0=sample_intermediates['pred_x0'] + len_inter = len(other_pred_x_0) + printC("len_inter", len_inter ) + for i in range(len(other_pred_x_0)): + other_pred_x_0[i]=self.differentiable_decode_first_stage(other_pred_x_0[i]) + # x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + # x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy() + + + ########################################### + + ID_loss=0 + clip_loss=0 + loss_lpips=0 + loss_rec=0 + loss_landmark=0 + + # model_output=samples_ddim + if 1: + + # x_samples_ddim=TF.resize(x_samples_ddim,(256,256)) + if 0: + inpaint_mask_64 = x_start[:,8,:,:] # inpaint region is 1, background is 0; shape b,64,64 + masks=TF.resize(inpaint_mask_64,(other_pred_x_0[0].shape[2],other_pred_x_0[0].shape[3])) # b,512,512 + if not 1: + masks = 1 - masks + #mask x_samples_ddim + x_samples_ddim_masked=[x_samples_ddim_preds*masks.unsqueeze(1) for x_samples_ddim_preds in other_pred_x_0] + # x_samples_ddim_masked=un_norm_clip(x_samples_ddim_masked) + # x_samples_ddim_masked = TF.normalize(x_samples_ddim_masked, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + else: + x_samples_ddim_masked = other_pred_x_0 + Landmark_loss_weight = 0 + ID_loss_weight = [0.3, 0, 0.1, 0.2, ][task] + if ID_loss_weight > 0 : + ID_Losses=[] + for step,x_samples_ddim_preds in enumerate(x_samples_ddim_masked): + ID_loss,sim_imp,_=self.face_ID_model(x_samples_ddim_preds,gt512,clip_img=False) + ID_Losses.append(ID_loss) + loss_dict.update({f'{prefix}/ID_loss_{step}': ID_loss}) + + ID_loss=torch.mean(torch.stack(ID_Losses)) + loss_dict.update({f'{prefix}/ID_loss': ID_loss}) + loss_dict.update({f'{prefix}/sim_imp': sim_imp}) + + CLIP_loss_weight = [1.5/4, 0.8, 1, 0.5, ][task] + if CLIP_loss_weight > 0 : + def _loss(_img1,_img2): + _e1 = self.encoder_clip_face.forward_vit(_img1,resize=True) + _e2 = self.encoder_clip_face.forward_vit(_img2,resize=True) + return torch.nn.functional.mse_loss( _e1, _e2 ) + clip_Losses=[] + for step,x_samples_ddim_preds in enumerate(x_samples_ddim_masked): + clip_loss = _loss(x_samples_ddim_preds,gt512) + clip_Losses.append(clip_loss) + loss_dict.update({f'{prefix}/clip_loss_{step}': clip_loss}) + clip_loss=torch.mean(torch.stack(clip_Losses)) + loss_dict.update({f'{prefix}/clip_loss': clip_loss}) + + LPIPS_loss_weight = [0.05, 0.015, 0.015, 0.015, ][task] + if LPIPS_loss_weight>0: + if gt256 is not None: + _lpips_base_size = 256 + _gt_for_lpips = gt256 + else: + _lpips_base_size = 512 + _gt_for_lpips = gt512 + + for j in range(len(other_pred_x_0)): + for i in range(3): + _size = _lpips_base_size//2**i + _pred_for_lpips = F.adaptive_avg_pool2d(other_pred_x_0[j],(_size,_size)) + _gt_for_lpips_resized = F.adaptive_avg_pool2d(_gt_for_lpips,(_size,_size)) + loss_lpips_1 = self.lpips_loss( + _pred_for_lpips, + _gt_for_lpips_resized, + ) + loss_dict.update({f'{prefix}/loss_lpips_{j}_{i}': loss_lpips_1}) + printC(f"loss_lpips_1 at {j} {i} :", loss_lpips_1) + loss_lpips += loss_lpips_1 + loss_dict.update({f'{prefix}/loss_lpips': loss_lpips}) + + REC_loss_weight = [0.05, 0.01, 0.01, 0.01, ][task] + if REC_loss_weight > 0 : # rec loss + for j in range(len(other_pred_x_0)): + loss_rec_1 = torch.nn.functional.mse_loss( other_pred_x_0[j], gt512) + loss_dict.update({f'{prefix}/loss_rec_{j}': loss_rec_1}) + printC(f"loss_rec_1 at {j} :", loss_rec_1) + loss_rec += loss_rec_1 + loss_dict.update({f'{prefix}/loss_rec': loss_rec}) + if 1: + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + else: + raise NotImplementedError() + + # this should be an MSE loss + loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) + loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) + loss_dict.update({f'{prefix}/loss_simple-t{task}': loss_simple.mean()}) + + self.logvar = self.logvar.to(self.device) + logvar_t = self.logvar[t].to(self.device) + loss = loss_simple / torch.exp(logvar_t) + logvar_t + # loss = loss_simple / torch.exp(self.logvar) + self.logvar + if self.learn_logvar: + loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) + loss_dict.update({'logvar': self.logvar.data.mean()}) + + loss = self.l_simple_weight * loss.mean() + + loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) #?? + loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() + loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) + loss_dict.update({f'{prefix}/loss_vlb-t{task}': loss_vlb}) + loss += (self.original_elbo_weight * loss_vlb) + else: + loss = 0 + if DDIM_losses: + _item = lambda _a: _a.detach().cpu().item() if isinstance(_a,torch.Tensor) else _a + printC("orig, ID clip, lpips rec lmk:", + f"{_item(loss):.4f}, {_item(ID_loss):.4f} {_item(clip_loss):.4f}, {_item(loss_lpips):.4f} {_item(loss_rec):.4f} {_item(loss_landmark):.4f}", + f"{ID_Losses=}" if ID_loss_weight>0 else "", + f"{clip_Losses=}" if CLIP_loss_weight>0 else "", + ) + loss+=ID_loss_weight*ID_loss+LPIPS_loss_weight*loss_lpips+Landmark_loss_weight*loss_landmark+REC_loss_weight*loss_rec+CLIP_loss_weight*clip_loss + + # incorporate MoE auxiliary loss + moe_aux = global_.moe_aux_loss + if isinstance(moe_aux, torch.Tensor): + loss = loss + moe_aux + loss_dict.update({f'{prefix}/moe_aux_loss': moe_aux}) + loss_dict.update({f'{prefix}/loss': loss}) + loss_dict.update({f'{prefix}/loss-t{task}': loss}) + return loss, loss_dict + + + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + + if self.partial_training:# no + # if True: + print("Partial training.............................") + train_names=self.trainable_keys + train_names=[ 'attn2','norm2'] + params_train=[] + for name,param in self.model.named_parameters(): + if "diffusion_model" not in name and param.requires_grad: + print(name) + params_train.append(param) + + elif "diffusion_model" in name and any(train_name in name for train_name in train_names): + print(name) + params_train.append(param) + params=params_train + print("Setting up Adam optimizer.......................") + + if self.cond_stage_trainable:# yes + print(f"{self.__class__.__name__}: Also optimizing conditioner params!") + if hasattr(self,'encoder_clip_face'): + params += list(self.encoder_clip_face.final_ln2.parameters())+list(self.encoder_clip_face.mapper2.parameters()) + if self.USE_proj_out_source: + params += list(self.proj_out_source__face.parameters()) + if hasattr(self,'encoder_clip_hair'): + params += list(self.encoder_clip_hair.final_ln2.parameters())+list(self.encoder_clip_hair.mapper2.parameters()) + if self.USE_proj_out_source: + params += list(self.proj_out_source__hair.parameters()) + if hasattr(self,'encoder_clip_head_t2'): + params += list(self.encoder_clip_head_t2.final_ln2.parameters())+list(self.encoder_clip_head_t2.mapper2.parameters()) + if hasattr(self,'encoder_clip_head_t3'): + params += list(self.encoder_clip_head_t3.final_ln2.parameters())+list(self.encoder_clip_head_t3.mapper2.parameters()) + if hasattr(self,'encoder_clip_head_t2') or hasattr(self,'encoder_clip_head_t3'): + if self.USE_proj_out_source: + params += list(self.proj_out_source__head.parameters()) + if hasattr(self,'ID_proj_out'): + params += list(self.ID_proj_out.parameters()) + if hasattr(self,'landmark_proj_out'): # fixLmkProj + params += list(self.landmark_proj_out.parameters()) + if self.learn_logvar: + print('Diffusion model optimizing logvar') + params.append(self.logvar) + params.extend(self.learnable_vector) + params = [p for p in params if p.requires_grad] + + # Build param groups: MoE gate/expert use larger LR. + # Also apply per-task LR factor to all task-specific params. + # only match MoE-related parameter names generated by the UNet wrappers + moe_gate_ids = set() + moe_ep_ids = set() + for name, p in self.model.named_parameters(): + if not p.requires_grad: + continue + if ".moe_gate_mlp." in name: + moe_gate_ids.add(id(p)) + elif ".moe_experts_" in name: + moe_ep_ids.add(id(p)) + + params_ids = set(id(p) for p in params) + task_specific_ids = set() + for name, p in self.named_parameters(): + if not p.requires_grad: + continue + if id(p) not in params_ids: + continue + is_task_specific = is_task_specific_(name) + if rank_==0: print(f"{is_task_specific=} {name}") + if is_task_specific: + task_specific_ids.add(id(p)) + + base_params = [] + task_specific_params = [] + moe_gate_params = [] + moe_ep_params = [] + for p in params: + pid = id(p) + if pid in task_specific_ids: + task_specific_params.append(p) + elif pid in moe_gate_ids: + moe_gate_params.append(p) + elif pid in moe_ep_ids: + moe_ep_params.append(p) + else: + base_params.append(p) + + param_groups = [] + if base_params: + param_groups.append({"params": base_params, "lr": lr}) + if task_specific_params: + param_groups.append({"params": task_specific_params, "lr": lr * LR_factor}) + if moe_gate_params: + param_groups.append({"params": moe_gate_params, "lr": lr * MOE_GATE_LR_MULT}) + if moe_ep_params: + param_groups.append({"params": moe_ep_params, "lr": lr * MOE_EP_LR_MULT}) + if ZERO1_ENABLE: + zero_pg = None + if 1: + if dist.is_available() and dist.is_initialized(): + zero_pg = dist.new_group(backend='gloo') + opt = ZeroRedundancyOptimizer( + param_groups if (task_specific_params or moe_gate_params or moe_ep_params) else params, + optimizer_class=torch.optim.AdamW if ADAM_or_SGD else torch.optim.SGD, + lr=lr, + process_group=zero_pg, + ) + else: + if ADAM_or_SGD: + opt = torch.optim.AdamW(param_groups if (task_specific_params or moe_gate_params or moe_ep_params) else params, lr=lr) + else: + opt = torch.optim.SGD(param_groups if (task_specific_params or moe_gate_params or moe_ep_params) else params, lr=lr, momentum=0.9) + if gate_('LatentDiffusion.configure_optimizers params:'): + if (task_specific_params or moe_gate_params or moe_ep_params): + print(f"base/task_specific/ep/gate lens: {len(base_params)=} {len(task_specific_params)=} {len(moe_ep_params)=} {len(moe_gate_params)=}") + print(f"sum of .numel(): base={sum(p.numel() for p in base_params)} task_specific={sum(p.numel() for p in task_specific_params)} ep={sum(p.numel() for p in moe_ep_params)} gate={sum(p.numel() for p in moe_gate_params)}") + else: + print(f"{len(params)=}") + print(f"sum of .numel(): {sum(param.numel() for param in params)}") + if self.use_scheduler:# yes + assert 'target' in self.scheduler_config + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [opt], scheduler + return opt + + def on_train_epoch_start(self): + def _set_req_grad(p, flag): + if p.requires_grad != flag: + p.requires_grad = flag + return 1 + return 0 + return + if 0: + train_now = self.current_epoch < N_EPOCHS_TRAIN_REF_AND_MID + else: # alternating freezing + train_now = (self.current_epoch % 2 == 0) + ct_toggled = 0 + # 1) freeze all shared if not train_now; unfreeze when train_now + ct_shared = 0 + for name, p in self.model.diffusion_model.named_parameters(): + # target only the shared weights inside Shared+LoRA wrappers: FFN.shared_ffn.* and Conv.shared.* + is_shared = ('.shared_ffn.' in name) or ('.shared.' in name) + if is_shared: + ct_shared += _set_req_grad(p, train_now) + print(f"[freeze@epoch]{self.current_epoch=} {train_now=} {ct_toggled=} {ct_shared=}") + + @torch.no_grad() + def to_rgb(self, x): + x = x.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = nn.functional.conv2d(x, weight=self.colorize) + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. + return x + def __repr__(self): + if DEBUG: return 'LatentDiffusion.__repr__' + return super().__repr__() + @property + def model_size(self): + if DEBUG: return -1 + return super().model_size + + +from .bank import Bank +class DiffusionWrapper(pl.LightningModule): + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + diff_model_config['params']['is_refNet'] = False + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm'] + if REFNET.ENABLE: + diff_model_config_refNet = diff_model_config + print('instantiate / deepcopy diffusion_model_refNet ing...') + if 1: + diff_model_config_refNet['params']['in_channels'] = 9 if REFNET.CH9 else 4 + diff_model_config_refNet['params']['is_refNet'] = True + self.diffusion_model_refNet :UNetModel = instantiate_from_config(diff_model_config_refNet) + else: + self.diffusion_model_refNet :UNetModel = copy.deepcopy(self.diffusion_model) # faster than re-instantiating + self.diffusion_model_refNet.is_refNet = True + if 1: + # print(f"before del: {len(self.diffusion_model_refNet.input_blocks)=}") + if 1: + self.diffusion_model_refNet.input_blocks = self.diffusion_model_refNet.input_blocks[:9] + del self.diffusion_model_refNet.middle_block + del self.diffusion_model_refNet.output_blocks + del self.diffusion_model_refNet.out + print('over.') + # Keep only a single diffusion_model_refNet; no t-suffixed clones + + def forward(self, x, t, c_concat: list = None, c_crossattn: list = None,return_features=False, + z_ref=None, + task = None, + _trainer :pl.Trainer = None, + ): + _in_train_or_val = ( _trainer is not None ) and ( _trainer.validating or _trainer.sanity_checking ) # indicates train or validation state + assert self.conditioning_key == 'crossattn' + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == 'concat': + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t) + elif self.conditioning_key == 'crossattn': + cc = torch.cat(c_crossattn, 1) #-->cc.shape = (bs, 1, 768) ## adding return_features here only for testing + if (REFNET.ENABLE and REFNET.task2layerNum[task]>0): + if task in (0,2,3,): + cc_ref = cc[:,:-1, :] + else: + cc_ref = cc + printC("c for refNet",f"{custom_repr_v3(cc_ref)}") + self.diffusion_model_refNet(z_ref, t, context=cc_ref,return_features=False) + out = self.diffusion_model(x, t, context=cc,return_features=return_features) + if (REFNET.ENABLE and REFNET.task2layerNum[task]>0) and not (self.training or _in_train_or_val): + # if 1: + self.bank.clear() + elif self.conditioning_key == 'hybrid': + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif self.conditioning_key == 'adm': + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + else: + raise NotImplementedError() + + return out #-->out.shape = (bs, 4,64,64) diff --git a/ldm/models/diffusion/misc_4ddpm.py b/ldm/models/diffusion/misc_4ddpm.py new file mode 100644 index 0000000000000000000000000000000000000000..54de2448d523e15afafa8035ba4aff118daf88fd --- /dev/null +++ b/ldm/models/diffusion/misc_4ddpm.py @@ -0,0 +1,187 @@ +""" +wild mixture of +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/CompVis/taming-transformers +-- merci +""" +import global_ +import os +import sys +from pathlib import Path +import torch +import torch.nn as nn +import numpy as np +import pytorch_lightning as pl +from torch.optim.lr_scheduler import LambdaLR +from einops import rearrange, repeat +from contextlib import contextmanager +from functools import partial +from tqdm import tqdm +from torchvision.utils import make_grid +from pytorch_lightning.utilities.distributed import rank_zero_only +from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config +from ldm.modules.ema import LitEma +from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution +from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL +from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +from ldm.models.diffusion.ddim import DDIMSampler +from torchvision.transforms import Resize +import torchvision.transforms.functional as TF +import torch.nn.functional as F +import math +import time +import random +import copy +from torch.autograd import Variable +import torch.distributed as dist +from torch.distributed.optim import ZeroRedundancyOptimizer +from src.Face_models.encoders.model_irse import Backbone +from eval_tool.lpips.lpips import LPIPS +from PIL import Image +import argparse +from contextlib import nullcontext +from util_face import * +from util_vis import vis_tensors_A +from my_py_lib.image_util import save_any_A,imgs_2_grid_A +from my_py_lib.torch_util import recursive_to +from my_py_lib.torch_util import custom_repr_v3 +from imports import * +from lmk_util.lmk_extractor import LandmarkExtractor,lmkAll_2_lmkMain +from ldm.modules.encoders.modules import FrozenCLIPEmbedder +from ldm.modules.diffusionmodules.openaimodel import UNetModel +from MoE import * + + +__conditioning_keys__ = {'concat': 'c_concat', + 'crossattn': 'c_crossattn', + 'adm': 'y'} + + + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + +def un_norm_clip(x1): + x = x1*1.0 # to avoid changing the original tensor or clone() can be used + reduce=False + if len(x.shape)==3: + x = x.unsqueeze(0) + reduce=True + x[:,0,:,:] = x[:,0,:,:] * 0.26862954 + 0.48145466 + x[:,1,:,:] = x[:,1,:,:] * 0.26130258 + 0.4578275 + x[:,2,:,:] = x[:,2,:,:] * 0.27577711 + 0.40821073 + + if reduce: + x = x.squeeze(0) + return x + +def un_norm(x): + return (x+1.0)/2.0 + +def save_clip_img(img, path,clip=True): + if clip: + img=un_norm_clip(img) + else: + img=torch.clamp(un_norm(img), min=0.0, max=1.0) + img = img.cpu().numpy().transpose((1, 2, 0)) + img = (img * 255).astype(np.uint8) + img = Image.fromarray(img) + img.save(path) + # if clip: + # img=TF.normalize(img, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]) + # else: + # img=TF.normalize(img, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + + +class IDLoss(nn.Module): + def __init__(self,opts,multiscale=False): + super(IDLoss, self).__init__() + print('Loading ResNet ArcFace') + self.opts = opts + self.multiscale = multiscale + self.face_pool_1 = torch.nn.AdaptiveAvgPool2d((256, 256)) + self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') + # self.facenet=iresnet100(pretrained=False, fp16=False) # changed by sanoojan + + self.facenet.load_state_dict(torch.load(opts.other_params.arcface_path)) + + self.face_pool_2 = torch.nn.AdaptiveAvgPool2d((112, 112)) + self.facenet.eval() + + self.set_requires_grad(False) + + def set_requires_grad(self, flag=True): + for p in self.parameters(): + p.requires_grad = flag + + def extract_feats(self, x,clip_img=True): + # breakpoint() + if clip_img: + x = un_norm_clip(x) + x = TF.normalize(x, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + x = self.face_pool_1(x) if x.shape[2]!=256 else x # (1) resize to 256 if needed + x = x[:, :, 35:223, 32:220] # (2) Crop interesting region + x = self.face_pool_2(x) # (3) resize to 112 to fit pre-trained model + # breakpoint() + x_feats = self.facenet(x, multi_scale=self.multiscale ) + + # x_feats = self.facenet(x) # changed by sanoojan + return x_feats + + + + def forward(self, y_hat, y,clip_img=True,return_seperate=False): + n_samples = y.shape[0] + y_feats_ms = self.extract_feats(y,clip_img=clip_img) # Otherwise use the feature from there + + y_hat_feats_ms = self.extract_feats(y_hat,clip_img=clip_img) + y_feats_ms = [y_f.detach() for y_f in y_feats_ms] + + loss_all = 0 + sim_improvement_all = 0 + seperate_sim=[] + for y_hat_feats, y_feats in zip(y_hat_feats_ms, y_feats_ms): + + loss = 0 + sim_improvement = 0 + count = 0 + # lossess = [] + for i in range(n_samples): + sim_target = y_hat_feats[i].dot(y_feats[i]) + sim_views = y_feats[i].dot(y_feats[i]) + + seperate_sim.append(sim_target) + loss += 1 - sim_target # id loss + sim_improvement += float(sim_target) - float(sim_views) + count += 1 + + + loss_all += loss / count + sim_improvement_all += sim_improvement / count + if return_seperate: + return loss_all, sim_improvement_all, seperate_sim + return loss_all, sim_improvement_all, None + +def uniform_on_device(r1, r2, shape, device): + return (r1 - r2) * torch.rand(*shape, device=device) + r2 + +class LandmarkDetectionModel(nn.Module): + def __init__(self): + super(LandmarkDetectionModel, self).__init__() + self.features = nn.Sequential( + nn.Conv2d(640, 128, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(kernel_size=2, stride=2) + ) + self.landmark_predictor = nn.Linear(128 * 32 * 32, 68 * 2) # Adjust output size as needed + + def forward(self, x): + x = self.features(x) + x = torch.flatten(x, 1) + landmarks = self.landmark_predictor(x) + return landmarks diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..41d1b7e8680ea60dbdf5a8651fdea0d28d59dfc7 --- /dev/null +++ b/ldm/modules/attention.py @@ -0,0 +1,317 @@ +from inspect import isfunction +import global_ +import math +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat + +from ldm.modules.diffusionmodules.util import checkpoint +from typing import List, Tuple +from imports import * + + +def exists(val): + return val is not None + + +def uniq(arr): + return{el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., inner_dim=None): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.dim = dim + self.inner_dim = inner_dim + self.dim_out = dim_out + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x, token_pos=None): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = rearrange(q, 'b c h w -> b (h w) c') + k = rearrange(k, 'b c h w -> b c (h w)') + w_ = torch.einsum('bij,bjk->bik', q, k) + + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, 'b c h w -> b c (h w)') + w_ = rearrange(w_, 'b i j -> b j i') + h_ = torch.einsum('bij,bjk->bik', v, w_) + h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + h_ = self.proj_out(h_) + + return x+h_ + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.,sep_head_att=False): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads # 8 + self.dim_head=dim_head #40 + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + # self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + # self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + head_splits=[6,2] + self.head_splits=head_splits + # if sep_head_att: + # self.to_k = nn.ModuleList([nn.Linear(context_dim, dim_head*head_splits[i], bias=False) for i in range(len(head_splits))]) + # self.to_v = nn.ModuleList([nn.Linear(context_dim, dim_head*head_splits[i], bias=False) for i in range(len(head_splits))]) + # else: + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) # 2,4096,320 + context = default(context, x) #2,4096,320 + if context.shape[-1]==768*2: + # this is for different attention heads + context1,context2=torch.chunk(context,2,dim=-1) # clip/id context1, landmark context2 + k1=self.to_k(context1) + k2=self.to_k(context2) + v1=self.to_v(context1) + v2=self.to_v(context2) + + k=torch.cat([k1[:,:,:self.head_splits[0]*self.dim_head],k2[:,:,-self.head_splits[1]*self.dim_head:]],dim=-1) + v=torch.cat([v1[:,:,:self.head_splits[0]*self.dim_head],v2[:,:,-self.head_splits[1]*self.dim_head:]],dim=-1) + # head_splits=[6,2] + # k1 = self.to_k[0](context1) + # v1 = self.to_v[0](context1) + # k2 = self.to_k[1](context2) + # v2 = self.to_v[1](context2) + # k=torch.cat([k1,k2],dim=-1) + # v=torch.cat([v1,v2],dim=-1) + + else: + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,sep_head_att=False): + super().__init__() + self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,sep_head_att=False) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout,sep_head_att=sep_head_att) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None, token_pos=None): + inputs = (x, context, token_pos, ) + if hasattr(self,'name4bank') and REFNET.task2layerNum[global_.task]>0: + if self.isReader_4bank: + inputs = (x, context, token_pos, self.bank.get(self.name4bank) ) # x, context, x_refNet + else: + self.bank.set(self.name4bank, x) + return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint) + + def _forward(self, x, context=None, token_pos=None, x_refNet=None):# x, x_refNet: before LN + if x_refNet is None: + x = self.attn1(self.norm1(x)) + x + else: + x_norm = self.norm1(x) + x_norm_cat = torch.cat( [ x_norm, self.norm1(x_refNet) ] , dim=1 ) + x = self.attn1(x_norm, context=x_norm_cat) + x + del x_norm,x_norm_cat + x = self.attn2(self.norm2(x), context=context) + x + # This ff might be modified into an MoE module, so pass token_pos + x = self.ff(self.norm3(x), token_pos) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None,sep_head_att=False,head_splits=None): + super().__init__() + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim,sep_head_att=sep_head_att) + for d in range(depth)] + ) + + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c') + if 1: # set token position grid (normalized centers) for gating/router use + num_tokens = h * w + y_coords = torch.arange(h, device=x.device, dtype=x.dtype) + x_coords = torch.arange(w, device=x.device, dtype=x.dtype) + yy, xx = torch.meshgrid(y_coords, x_coords, indexing='ij') + pos = torch.stack([(xx + 0.5) / float(w), (yy + 0.5) / float(h)], dim=-1) # [h,w,2] + pos = pos.reshape(1, num_tokens, 2).expand(b, -1, -1).contiguous() # b, n, 2 + for block in self.transformer_blocks: + x = block(x, context=context, token_pos=pos) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) + x = self.proj_out(x) + return x + x_in diff --git a/ldm/modules/diffusionmodules/__init__.py b/ldm/modules/diffusionmodules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py new file mode 100644 index 0000000000000000000000000000000000000000..0142fa8920235b87f63e842a3251dd0e8a1d6bca --- /dev/null +++ b/ldm/modules/diffusionmodules/model.py @@ -0,0 +1,592 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange + +from ldm.util import instantiate_from_config +from ldm.modules.attention import LinearAttention + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Model(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, t=None, context=None): + #assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", + **ignore_kwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + attn_type="vanilla", **ignorekwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + print("Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=4, + stride=2, + padding=1) + + def forward(self, x, scale_factor=1.0): + if scale_factor==1.0: + return x + else: + x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) + return x diff --git a/ldm/modules/diffusionmodules/openaimodel.py b/ldm/modules/diffusionmodules/openaimodel.py new file mode 100644 index 0000000000000000000000000000000000000000..552230f4c024902dcaedc192aacc6dbdd8294bce --- /dev/null +++ b/ldm/modules/diffusionmodules/openaimodel.py @@ -0,0 +1,1119 @@ +from imports import * +from abc import abstractmethod +from functools import partial +import math +from typing import Iterable + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from ldm.modules.diffusionmodules.util import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from ldm.modules.attention import SpatialTransformer + + +# dummy replace +def convert_module_to_f16(x): + pass + +def convert_module_to_f32(x): + pass + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + +class TransposedUpsample(nn.Module): + 'Learned 2x upsampling without padding' + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) + + def forward(self,x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class My_ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, 4, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, 4, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + #return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + add_conv_in_front_of_unet=False, + sep_head_att=False, + land_mark_id_seperate_layers=False, + head_splits=None, + is_refNet=False, + ): + super().__init__() + self.is_refNet = is_refNet + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + self.add_conv_in_front_of_unet=add_conv_in_front_of_unet + + #Newly added for custom transformer support + self.sep_head_att=sep_head_att + self.land_mark_id_seperate_layers=land_mark_id_seperate_layers + self.head_splits=head_splits + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + if DBEUG_skip_most_in_Unet_constructor: + return + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + + if self.add_conv_in_front_of_unet: + assert 0 + + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,sep_head_att=sep_head_att,head_splits=head_splits + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,sep_head_att=sep_head_att,head_splits=head_splits + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,sep_head_att=sep_head_att,head_splits=head_splits + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), #model_channels=320, out_channels=3, dims=2 + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None,return_features=False,**kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + + dumper_var_prefix = 'refnet' if self.is_refNet else 'unet' + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + features=[] + if self.num_classes is not None: + assert y.shape == (x.shape[0],) + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + + if context.shape[-1]==768*2 and not self.sep_head_att and self.land_mark_id_seperate_layers: + # split the last dim into 2 + context1,context2=th.chunk(context,2,dim=-1) # clip/id context1, landmark context2 + else: + context1,context2=context,context + + if self.add_conv_in_front_of_unet: + for module in self.add_resbolck: + h = module(h, emb, context2) + + for module in self.input_blocks: + h = module(h, emb, context2) + # global_dumper.save(h, f'{dumper_var_prefix}/input_blocks/{idx}') + hs.append(h) + if self.is_refNet: + return + h = self.middle_block(h, emb, context1) # ([4, 1280, 8, 8]) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context1) + if return_features: + features.append(h) + if self.is_refNet: + return + h = h.type(x.dtype) + if self.predict_codebook_ids: + assert 0 + return self.id_predictor(h) + elif return_features: + return self.out(h), features + else: + return self.out(h) + + +class EncoderUNetModel(nn.Module): + """ + The half UNet model with attention and timestep embedding. + For usage, see UNet. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + *args, + **kwargs + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + AttentionPool2d( + (image_size // ds), ch, num_head_channels, out_channels + ), + ) + elif pool == "spatial": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(x.dtype) + return self.out(h) + diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py new file mode 100644 index 0000000000000000000000000000000000000000..bc4ba2ab4ba4170d7cdba517ef80515ca9f92e30 --- /dev/null +++ b/ldm/modules/diffusionmodules/util.py @@ -0,0 +1,273 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +from ldm.util import instantiate_from_config + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + # pass only parameters that require gradients to avoid backward on frozen params + # params are used only in backward's autograd.grad, so filtering them does not affect the forward computation + _params_list = list(params) + _params_list = [p for p in _params_list if p.requires_grad ] + args = tuple(inputs) + tuple(_params_list) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + # params do not participate in the forward computation and are only referenced in backward. + with torch.no_grad(): # disable gradient computation during forward to avoid storing intermediate activations + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): # recompute forward to retrieve intermediate activations, then compute gradients + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] # detach inputs and re-enable gradients + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] # use shallow copies to avoid in-place modification issues + output_tensors = ctx.run_function(*shallow_copies) + # ctx.input_params were filtered to only include requires_grad=True during checkpoint(), so frozen parameters won't raise errors here. + input_grads = torch.autograd.grad( + output_tensors, # tensors output by the forward pass + ctx.input_tensors + ctx.input_params, # inputs that require gradients + output_grads, # gradients of the output tensors (from backward) + allow_unused=True, # allow some inputs to not receive gradients + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() diff --git a/ldm/modules/distributions/__init__.py b/ldm/modules/distributions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/modules/distributions/distributions.py b/ldm/modules/distributions/distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..f2b8ef901130efc171aa69742ca0244d94d3f2e9 --- /dev/null +++ b/ldm/modules/distributions/distributions.py @@ -0,0 +1,92 @@ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/ldm/modules/ema.py b/ldm/modules/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..c8c75af43565f6e140287644aaaefa97dd6e67c5 --- /dev/null +++ b/ldm/modules/ema.py @@ -0,0 +1,76 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + + self.m_name2s_name = {} + self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) + self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates + else torch.tensor(-1,dtype=torch.int)) + + for name, p in model.named_parameters(): + if p.requires_grad: + #remove as '.'-character is not allowed in buffers + s_name = name.replace('.','') + self.m_name2s_name.update({name:s_name}) + self.register_buffer(s_name,p.clone().detach().data) + + self.collected_params = [] + + def forward(self,model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/ldm/modules/encoders/__init__.py b/ldm/modules/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..8636726ab684b7e4c6269216c1322ee3f8aee84f --- /dev/null +++ b/ldm/modules/encoders/modules.py @@ -0,0 +1,304 @@ +import torch +import torch.nn as nn +from functools import partial +import clip +from einops import rearrange, repeat +from transformers import CLIPTokenizer, CLIPTextModel,CLIPVisionModel,CLIPModel,CLIPProcessor +import kornia +from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test +from .xf import LayerNorm, Transformer +import math +from torchvision import transforms +local_files_only = 0 + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + + +class ClassEmbedder(nn.Module): + def __init__(self, embed_dim, n_classes=1000, key='class'): + super().__init__() + self.key = key + self.embedding = nn.Embedding(n_classes, embed_dim) + + def forward(self, batch, key=None): + if key is None: + key = self.key + # this is for use in crossattn + c = batch[key][:, None] + c = self.embedding(c) + return c + + +class TransformerEmbedder(AbstractEncoder): + """Some transformer encoder layers""" + def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): + super().__init__() + self.device = device + self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, + attn_layers=Encoder(dim=n_embed, depth=n_layer)) + + def forward(self, tokens): + tokens = tokens.to(self.device) # meh + z = self.transformer(tokens, return_embeddings=True) + return z + + def encode(self, x): + return self(x) + + +class BERTTokenizer(AbstractEncoder): + """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" + def __init__(self, device="cuda", vq_interface=True, max_length=77): + super().__init__() + from transformers import BertTokenizerFast # TODO: add to reuquirements + self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") + self.device = device + self.vq_interface = vq_interface + self.max_length = max_length + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + return tokens + + @torch.no_grad() + def encode(self, text): + tokens = self(text) + if not self.vq_interface: + return tokens + return None, None, [None, None, tokens] + + def decode(self, text): + return text + + +class BERTEmbedder(AbstractEncoder): + """Uses the BERT tokenizr model and add some transformer encoder layers""" + def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, + device="cuda",use_tokenizer=True, embedding_dropout=0.0): + super().__init__() + self.use_tknz_fn = use_tokenizer + if self.use_tknz_fn: + self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) + self.device = device + self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, + attn_layers=Encoder(dim=n_embed, depth=n_layer), + emb_dropout=embedding_dropout) + + def forward(self, text): + if self.use_tknz_fn: + tokens = self.tknz_fn(text)#.to(self.device) + else: + tokens = text + z = self.transformer(tokens, return_embeddings=True) + return z + + def encode(self, text): + # output of length 77 + return self(text) + + +class SpatialRescaler(nn.Module): + def __init__(self, + n_stages=1, + method='bilinear', + multiplier=0.5, + in_channels=3, + out_channels=None, + bias=False): + super().__init__() + self.n_stages = n_stages + assert self.n_stages >= 0 + assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] + self.multiplier = multiplier + self.interpolator = partial(torch.nn.functional.interpolate, mode=method) + self.remap_output = out_channels is not None + if self.remap_output: + print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') + self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) + + def forward(self,x): + for stage in range(self.n_stages): + x = self.interpolator(x, scale_factor=self.multiplier) + + + if self.remap_output: + x = self.channel_mapper(x) + return x + + def encode(self, x): + return self(x) + + +class FrozenCLIPImageEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from Hugging Face)""" + def __init__(self, version="openai/clip-vit-large-patch14"): + super().__init__() + self.transformer = CLIPVisionModel.from_pretrained(version) + + self.final_ln = LayerNorm(1024) + self.mapper = Transformer( + 1, + 1024, + 5, + 1, + ) + + + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + for param in self.parameters(): + param.requires_grad = False + for param in self.mapper.parameters(): + param.requires_grad = True + for param in self.final_ln.parameters(): + param.requires_grad = True + + + def forward(self, image): + outputs = self.transformer(pixel_values=image) + z = outputs.pooler_output + z = z.unsqueeze(1) + z = self.mapper(z) + + z = self.final_ln(z) + return z + + def encode(self, image): + return self(image) + + +class FrozenCLIPTextEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from Hugging Face)""" + def __init__(self, version="openai/clip-vit-large-patch14"): + super().__init__() + self.transformer = CLIPTextModel.from_pretrained(version) + self.tokenizer = CLIPTokenizer.from_pretrained(version) + # model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32") + # >>> tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32") + + # >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") + + # >>> outputs = model(**inputs) + # >>> last_hidden_state = outputs.last_hidden_state + # >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + # self.tokenizer = self.tokenizer.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + inputs= self.tokenizer(text, padding=True, return_tensors="pt") + inputs = {k: v.to(self.transformer.device) for k, v in inputs.items()} + z = self.transformer(**inputs) + return z + + def encode(self, text): + return self(text) + + + +class FrozenCLIPEmbedder(nn.Module): + def __init__(self, version="openai/clip-vit-large-patch14"): + super().__init__() + + self.model = CLIPModel.from_pretrained(version,local_files_only=local_files_only) + del self.model.text_model + self.resize = transforms.Resize((224, 224)) + # self.processor = CLIPProcessor.from_pretrained(version) + self.tokenizer = CLIPTokenizer.from_pretrained(version,local_files_only=local_files_only) + self.final_ln2=LayerNorm(768) + self.mapper2=Transformer( + 1, + 768, + 5, + 1, + ) + + self.projection_back=nn.Linear(768,1024) + + self.freeze() + + def freeze(self): + self.model = self.model.eval() + # self.processor = self.processor.eval() + for param in self.parameters(): + param.requires_grad = False + # for param in self.projection_back.parameters(): + # param.requires_grad = True + for param in self.mapper2.parameters(): + param.requires_grad = True + for param in self.final_ln2.parameters(): + param.requires_grad = True + + def forward_vit(self, image, resize=False): + if resize: + image = self.resize(image) + outputs = self.model.vision_model(pixel_values=image) + # z = outputs.pooler_output + z = outputs.last_hidden_state # B, 257, 1024 + z=self.model.visual_projection(z) + # z=self.projection_back(z) + # z = z.unsqueeze(1) + return z + def encode_B(self, z): + z = self.mapper2(z) + z = self.final_ln2(z) + return z + def forward(self, image, pooled_or_all_tokens = False): + outputs = self.model.vision_model(pixel_values=image) + if pooled_or_all_tokens: + z = outputs.pooler_output + z=self.model.visual_projection(z) + # z=self.projection_back(z) + z = z.unsqueeze(1) + else: + z = outputs.last_hidden_state # B, 257, 1024 + z=self.model.visual_projection(z) # B, 257, 768 + # z=self.projection_back(z) + # z = z.unsqueeze(1) + # shape unchanged: + z = self.mapper2(z) + z = self.final_ln2(z) + return z + + def encode(self, image, **kw): + return self(image, **kw) + + def forward_probabilities(self, text, image): + vision_outputs=self.model.vision_model(pixel_values=image) + image_embeds = vision_outputs[1] + image_embeds = self.model.visual_projection(image_embeds) + + inputs= self.tokenizer(text, padding=True, return_tensors="pt") + inputs = {k: v.to(self.model.device) for k, v in inputs.items()} + text_outputs = self.model.text_model(**inputs) + text_embeds = text_outputs[1] + text_embeds = self.text_projection(text_embeds) + # normalized features + image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale + logits_per_image = logits_per_text.T + + return logits_per_image + +if __name__ == "__main__": + from ldm.util import count_params + model = FrozenCLIPEmbedder() + count_params(model, verbose=True) diff --git a/ldm/modules/encoders/xf.py b/ldm/modules/encoders/xf.py new file mode 100644 index 0000000000000000000000000000000000000000..5dfff440b489f3cc3c62450dc28c2f35f692dd94 --- /dev/null +++ b/ldm/modules/encoders/xf.py @@ -0,0 +1,130 @@ +""" +Transformer implementation adapted from CLIP ViT: +https://github.com/openai/CLIP/blob/4c0275784d6d9da97ca1f47eaaee31de1867da91/clip/model.py +""" + +import math + +import torch as th +import torch.nn as nn + + +def convert_module_to_f16(l): + """ + Convert primitive modules to float16. + """ + if isinstance(l, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + +class LayerNorm(nn.LayerNorm): + """ + Implementation that supports fp16 inputs but fp32 gains/biases. + """ + + def forward(self, x: th.Tensor): + return super().forward(x.float()).to(x.dtype) + + +class MultiheadAttention(nn.Module): + def __init__(self, n_ctx, width, heads): + super().__init__() + self.n_ctx = n_ctx + self.width = width + self.heads = heads + self.c_qkv = nn.Linear(width, width * 3) + self.c_proj = nn.Linear(width, width) + self.attention = QKVMultiheadAttention(heads, n_ctx) + + def forward(self, x): + x = self.c_qkv(x) + x = self.attention(x) + x = self.c_proj(x) + return x + + +class MLP(nn.Module): + def __init__(self, width): + super().__init__() + self.width = width + self.c_fc = nn.Linear(width, width * 4) + self.c_proj = nn.Linear(width * 4, width) + self.gelu = nn.GELU() + + def forward(self, x): + return self.c_proj(self.gelu(self.c_fc(x))) + + +class QKVMultiheadAttention(nn.Module): + def __init__(self, n_heads: int, n_ctx: int): + super().__init__() + self.n_heads = n_heads + self.n_ctx = n_ctx + + def forward(self, qkv): + bs, n_ctx, width = qkv.shape + attn_ch = width // self.n_heads // 3 + scale = 1 / math.sqrt(math.sqrt(attn_ch)) + qkv = qkv.view(bs, n_ctx, self.n_heads, -1) + q, k, v = th.split(qkv, attn_ch, dim=-1) + weight = th.einsum( + "bthc,bshc->bhts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + wdtype = weight.dtype + weight = th.softmax(weight.float(), dim=-1).type(wdtype) + return th.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + n_ctx: int, + width: int, + heads: int, + ): + super().__init__() + + self.attn = MultiheadAttention( + n_ctx, + width, + heads, + ) + self.ln_1 = LayerNorm(width) + self.mlp = MLP(width) + self.ln_2 = LayerNorm(width) + + def forward(self, x: th.Tensor): + x = x + self.attn(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__( + self, + n_ctx: int, + width: int, + layers: int, + heads: int, + ): + super().__init__() + self.n_ctx = n_ctx + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList( + [ + ResidualAttentionBlock( + n_ctx, + width, + heads, + ) + for _ in range(layers) + ] + ) + + def forward(self, x: th.Tensor): + for block in self.resblocks: + x = block(x) + return x diff --git a/ldm/modules/losses/__init__.py b/ldm/modules/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..876d7c5bd6e3245ee77feb4c482b7a8143604ad5 --- /dev/null +++ b/ldm/modules/losses/__init__.py @@ -0,0 +1 @@ +from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator \ No newline at end of file diff --git a/ldm/modules/losses/contperceptual.py b/ldm/modules/losses/contperceptual.py new file mode 100644 index 0000000000000000000000000000000000000000..672c1e32a1389def02461c0781339681060c540e --- /dev/null +++ b/ldm/modules/losses/contperceptual.py @@ -0,0 +1,111 @@ +import torch +import torch.nn as nn + +from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? + + +class LPIPSWithDiscriminator(nn.Module): + def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, + disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, + perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, + disc_loss="hinge"): + + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + self.kl_weight = kl_weight + self.pixel_weight = pixelloss_weight + self.perceptual_loss = LPIPS().eval() + self.perceptual_weight = perceptual_weight + # output log variance + self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) + + self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, + n_layers=disc_num_layers, + use_actnorm=use_actnorm + ).apply(weights_init) + self.discriminator_iter_start = disc_start + self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward(self, inputs, reconstructions, posteriors, optimizer_idx, + global_step, last_layer=None, cond=None, split="train", + weights=None): + rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) + rec_loss = rec_loss + self.perceptual_weight * p_loss + + nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar + weighted_nll_loss = nll_loss + if weights is not None: + weighted_nll_loss = weights*nll_loss + weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] + nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + kl_loss = posteriors.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + + # now the GAN part + if optimizer_idx == 0: + # generator update + if cond is None: + assert not self.disc_conditional + logits_fake = self.discriminator(reconstructions.contiguous()) + else: + assert self.disc_conditional + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) + g_loss = -torch.mean(logits_fake) + + if self.disc_factor > 0.0: + try: + d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) + except RuntimeError: + assert not self.training + d_weight = torch.tensor(0.0) + else: + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss + + log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), + "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } + return loss, log + + if optimizer_idx == 1: + # second pass for discriminator update + if cond is None: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + else: + logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean() + } + return d_loss, log + diff --git a/ldm/modules/losses/vqperceptual.py b/ldm/modules/losses/vqperceptual.py new file mode 100644 index 0000000000000000000000000000000000000000..f69981769e4bd5462600458c4fcf26620f7e4306 --- /dev/null +++ b/ldm/modules/losses/vqperceptual.py @@ -0,0 +1,167 @@ +import torch +from torch import nn +import torch.nn.functional as F +from einops import repeat + +from taming.modules.discriminator.model import NLayerDiscriminator, weights_init +from taming.modules.losses.lpips import LPIPS +from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss + + +def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): + assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] + loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) + loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) + loss_real = (weights * loss_real).sum() / weights.sum() + loss_fake = (weights * loss_fake).sum() / weights.sum() + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + +def adopt_weight(weight, global_step, threshold=0, value=0.): + if global_step < threshold: + weight = value + return weight + + +def measure_perplexity(predicted_indices, n_embed): + # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py + # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally + encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) + avg_probs = encodings.mean(0) + perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() + cluster_use = torch.sum(avg_probs > 0) + return perplexity, cluster_use + +def l1(x, y): + return torch.abs(x-y) + + +def l2(x, y): + return torch.pow((x-y), 2) + + +class VQLPIPSWithDiscriminator(nn.Module): + def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, + disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, + perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, + disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", + pixel_loss="l1"): + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + assert perceptual_loss in ["lpips", "clips", "dists"] + assert pixel_loss in ["l1", "l2"] + self.codebook_weight = codebook_weight + self.pixel_weight = pixelloss_weight + if perceptual_loss == "lpips": + print(f"{self.__class__.__name__}: Running with LPIPS.") + self.perceptual_loss = LPIPS().eval() + else: + raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") + self.perceptual_weight = perceptual_weight + + if pixel_loss == "l1": + self.pixel_loss = l1 + else: + self.pixel_loss = l2 + + self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, + n_layers=disc_num_layers, + use_actnorm=use_actnorm, + ndf=disc_ndf + ).apply(weights_init) + self.discriminator_iter_start = disc_start + if disc_loss == "hinge": + self.disc_loss = hinge_d_loss + elif disc_loss == "vanilla": + self.disc_loss = vanilla_d_loss + else: + raise ValueError(f"Unknown GAN loss '{disc_loss}'.") + print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + self.n_classes = n_classes + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, + global_step, last_layer=None, cond=None, split="train", predicted_indices=None): + if not exists(codebook_loss): + codebook_loss = torch.tensor([0.]).to(inputs.device) + #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) + rec_loss = rec_loss + self.perceptual_weight * p_loss + else: + p_loss = torch.tensor([0.0]) + + nll_loss = rec_loss + #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + nll_loss = torch.mean(nll_loss) + + # now the GAN part + if optimizer_idx == 0: + # generator update + if cond is None: + assert not self.disc_conditional + logits_fake = self.discriminator(reconstructions.contiguous()) + else: + assert self.disc_conditional + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) + g_loss = -torch.mean(logits_fake) + + try: + d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) + except RuntimeError: + assert not self.training + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() + + log = {"{}/total_loss".format(split): loss.clone().detach().mean(), + "{}/quant_loss".format(split): codebook_loss.detach().mean(), + "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/p_loss".format(split): p_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } + if predicted_indices is not None: + assert self.n_classes is not None + with torch.no_grad(): + perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) + log[f"{split}/perplexity"] = perplexity + log[f"{split}/cluster_usage"] = cluster_usage + return loss, log + + if optimizer_idx == 1: + # second pass for discriminator update + if cond is None: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + else: + logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean() + } + return d_loss, log diff --git a/ldm/modules/x_transformer.py b/ldm/modules/x_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..5fc15bf9cfe0111a910e7de33d04ffdec3877576 --- /dev/null +++ b/ldm/modules/x_transformer.py @@ -0,0 +1,641 @@ +"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers""" +import torch +from torch import nn, einsum +import torch.nn.functional as F +from functools import partial +from inspect import isfunction +from collections import namedtuple +from einops import rearrange, repeat, reduce + +# constants + +DEFAULT_DIM_HEAD = 64 + +Intermediates = namedtuple('Intermediates', [ + 'pre_softmax_attn', + 'post_softmax_attn' +]) + +LayerIntermediates = namedtuple('Intermediates', [ + 'hiddens', + 'attn_intermediates' +]) + + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + self.emb = nn.Embedding(max_seq_len, dim) + self.init_() + + def init_(self): + nn.init.normal_(self.emb.weight, std=0.02) + + def forward(self, x): + n = torch.arange(x.shape[1], device=x.device) + return self.emb(n)[None, :, :] + + +class FixedPositionalEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + def forward(self, x, seq_dim=1, offset=0): + t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset + sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) + emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) + return emb[None, :, :] + + +# helpers + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def always(val): + def inner(*args, **kwargs): + return val + return inner + + +def not_equals(val): + def inner(x): + return x != val + return inner + + +def equals(val): + def inner(x): + return x == val + return inner + + +def max_neg_value(tensor): + return -torch.finfo(tensor.dtype).max + + +# keyword argument helpers + +def pick_and_pop(keys, d): + values = list(map(lambda key: d.pop(key), keys)) + return dict(zip(keys, values)) + + +def group_dict_by_key(cond, d): + return_val = [dict(), dict()] + for key in d.keys(): + match = bool(cond(key)) + ind = int(not match) + return_val[ind][key] = d[key] + return (*return_val,) + + +def string_begins_with(prefix, str): + return str.startswith(prefix) + + +def group_by_key_prefix(prefix, d): + return group_dict_by_key(partial(string_begins_with, prefix), d) + + +def groupby_prefix_and_trim(prefix, d): + kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) + kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) + return kwargs_without_prefix, kwargs + + +# classes +class Scale(nn.Module): + def __init__(self, value, fn): + super().__init__() + self.value = value + self.fn = fn + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.value, *rest) + + +class Rezero(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + self.g = nn.Parameter(torch.zeros(1)) + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.g, *rest) + + +class ScaleNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(1)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-8): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class Residual(nn.Module): + def forward(self, x, residual): + return x + residual + + +class GRUGating(nn.Module): + def __init__(self, dim): + super().__init__() + self.gru = nn.GRUCell(dim, dim) + + def forward(self, x, residual): + gated_output = self.gru( + rearrange(x, 'b n d -> (b n) d'), + rearrange(residual, 'b n d -> (b n) d') + ) + + return gated_output.reshape_as(x) + + +# feedforward + +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +# attention. +class Attention(nn.Module): + def __init__( + self, + dim, + dim_head=DEFAULT_DIM_HEAD, + heads=8, + causal=False, + mask=None, + talking_heads=False, + sparse_topk=None, + use_entmax15=False, + num_mem_kv=0, + dropout=0., + on_attn=False + ): + super().__init__() + if use_entmax15: + raise NotImplementedError("Check out entmax activation instead of softmax activation!") + self.scale = dim_head ** -0.5 + self.heads = heads + self.causal = causal + self.mask = mask + + inner_dim = dim_head * heads + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_k = nn.Linear(dim, inner_dim, bias=False) + self.to_v = nn.Linear(dim, inner_dim, bias=False) + self.dropout = nn.Dropout(dropout) + + # talking heads + self.talking_heads = talking_heads + if talking_heads: + self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + + # explicit topk sparse attention + self.sparse_topk = sparse_topk + + # entmax + #self.attn_fn = entmax15 if use_entmax15 else F.softmax + self.attn_fn = F.softmax + + # add memory key / values + self.num_mem_kv = num_mem_kv + if num_mem_kv > 0: + self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + + # attention on attention + self.attn_on_attn = on_attn + self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + rel_pos=None, + sinusoidal_emb=None, + prev_attn=None, + mem=None + ): + b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device + kv_input = default(context, x) + + q_input = x + k_input = kv_input + v_input = kv_input + + if exists(mem): + k_input = torch.cat((mem, k_input), dim=-2) + v_input = torch.cat((mem, v_input), dim=-2) + + if exists(sinusoidal_emb): + # in shortformer, the query would start at a position offset depending on the past cached memory + offset = k_input.shape[-2] - q_input.shape[-2] + q_input = q_input + sinusoidal_emb(q_input, offset=offset) + k_input = k_input + sinusoidal_emb(k_input) + + q = self.to_q(q_input) + k = self.to_k(k_input) + v = self.to_v(v_input) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + + input_mask = None + if any(map(exists, (mask, context_mask))): + q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) + k_mask = q_mask if not exists(context) else context_mask + k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) + q_mask = rearrange(q_mask, 'b i -> b () i ()') + k_mask = rearrange(k_mask, 'b j -> b () () j') + input_mask = q_mask * k_mask + + if self.num_mem_kv > 0: + mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) + k = torch.cat((mem_k, k), dim=-2) + v = torch.cat((mem_v, v), dim=-2) + if exists(input_mask): + input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) + + dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + mask_value = max_neg_value(dots) + + if exists(prev_attn): + dots = dots + prev_attn + + pre_softmax_attn = dots + + if talking_heads: + dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() + + if exists(rel_pos): + dots = rel_pos(dots) + + if exists(input_mask): + dots.masked_fill_(~input_mask, mask_value) + del input_mask + + if self.causal: + i, j = dots.shape[-2:] + r = torch.arange(i, device=device) + mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') + mask = F.pad(mask, (j - i, 0), value=False) + dots.masked_fill_(mask, mask_value) + del mask + + if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: + top, _ = dots.topk(self.sparse_topk, dim=-1) + vk = top[..., -1].unsqueeze(-1).expand_as(dots) + mask = dots < vk + dots.masked_fill_(mask, mask_value) + del mask + + attn = self.attn_fn(dots, dim=-1) + post_softmax_attn = attn + + attn = self.dropout(attn) + + if talking_heads: + attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + + intermediates = Intermediates( + pre_softmax_attn=pre_softmax_attn, + post_softmax_attn=post_softmax_attn + ) + + return self.to_out(out), intermediates + + +class AttentionLayers(nn.Module): + def __init__( + self, + dim, + depth, + heads=8, + causal=False, + cross_attend=False, + only_cross=False, + use_scalenorm=False, + use_rmsnorm=False, + use_rezero=False, + rel_pos_num_buckets=32, + rel_pos_max_distance=128, + position_infused_attn=False, + custom_layers=None, + sandwich_coef=None, + par_ratio=None, + residual_attn=False, + cross_residual_attn=False, + macaron=False, + pre_norm=True, + gate_residual=False, + **kwargs + ): + super().__init__() + ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) + attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs) + + dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) + + self.dim = dim + self.depth = depth + self.layers = nn.ModuleList([]) + + self.has_pos_emb = position_infused_attn + self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None + self.rotary_pos_emb = always(None) + + assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' + self.rel_pos = None + + self.pre_norm = pre_norm + + self.residual_attn = residual_attn + self.cross_residual_attn = cross_residual_attn + + norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm + norm_class = RMSNorm if use_rmsnorm else norm_class + norm_fn = partial(norm_class, dim) + + norm_fn = nn.Identity if use_rezero else norm_fn + branch_fn = Rezero if use_rezero else None + + if cross_attend and not only_cross: + default_block = ('a', 'c', 'f') + elif cross_attend and only_cross: + default_block = ('c', 'f') + else: + default_block = ('a', 'f') + + if macaron: + default_block = ('f',) + default_block + + if exists(custom_layers): + layer_types = custom_layers + elif exists(par_ratio): + par_depth = depth * len(default_block) + assert 1 < par_ratio <= par_depth, 'par ratio out of range' + default_block = tuple(filter(not_equals('f'), default_block)) + par_attn = par_depth // par_ratio + depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper + par_width = (depth_cut + depth_cut // par_attn) // par_attn + assert len(default_block) <= par_width, 'default block is too large for par_ratio' + par_block = default_block + ('f',) * (par_width - len(default_block)) + par_head = par_block * par_attn + layer_types = par_head + ('f',) * (par_depth - len(par_head)) + elif exists(sandwich_coef): + assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' + layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef + else: + layer_types = default_block * depth + + self.layer_types = layer_types + self.num_attn_layers = len(list(filter(equals('a'), layer_types))) + + for layer_type in self.layer_types: + if layer_type == 'a': + layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) + elif layer_type == 'c': + layer = Attention(dim, heads=heads, **attn_kwargs) + elif layer_type == 'f': + layer = FeedForward(dim, **ff_kwargs) + layer = layer if not macaron else Scale(0.5, layer) + else: + raise Exception(f'invalid layer type {layer_type}') + + if isinstance(layer, Attention) and exists(branch_fn): + layer = branch_fn(layer) + + if gate_residual: + residual_fn = GRUGating(dim) + else: + residual_fn = Residual() + + self.layers.append(nn.ModuleList([ + norm_fn(), + layer, + residual_fn + ])) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + mems=None, + return_hiddens=False + ): + hiddens = [] + intermediates = [] + prev_attn = None + prev_cross_attn = None + + mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers + + for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): + is_last = ind == (len(self.layers) - 1) + + if layer_type == 'a': + hiddens.append(x) + layer_mem = mems.pop(0) + + residual = x + + if self.pre_norm: + x = norm(x) + + if layer_type == 'a': + out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos, + prev_attn=prev_attn, mem=layer_mem) + elif layer_type == 'c': + out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn) + elif layer_type == 'f': + out = block(x) + + x = residual_fn(out, residual) + + if layer_type in ('a', 'c'): + intermediates.append(inter) + + if layer_type == 'a' and self.residual_attn: + prev_attn = inter.pre_softmax_attn + elif layer_type == 'c' and self.cross_residual_attn: + prev_cross_attn = inter.pre_softmax_attn + + if not self.pre_norm and not is_last: + x = norm(x) + + if return_hiddens: + intermediates = LayerIntermediates( + hiddens=hiddens, + attn_intermediates=intermediates + ) + + return x, intermediates + + return x + + +class Encoder(AttentionLayers): + def __init__(self, **kwargs): + assert 'causal' not in kwargs, 'cannot set causality on encoder' + super().__init__(causal=False, **kwargs) + + + +class TransformerWrapper(nn.Module): + def __init__( + self, + *, + num_tokens, + max_seq_len, + attn_layers, + emb_dim=None, + max_mem_len=0., + emb_dropout=0., + num_memory_tokens=None, + tie_embedding=False, + use_pos_emb=True + ): + super().__init__() + assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' + + dim = attn_layers.dim + emb_dim = default(emb_dim, dim) + + self.max_seq_len = max_seq_len + self.max_mem_len = max_mem_len + self.num_tokens = num_tokens + + self.token_emb = nn.Embedding(num_tokens, emb_dim) + self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( + use_pos_emb and not attn_layers.has_pos_emb) else always(0) + self.emb_dropout = nn.Dropout(emb_dropout) + + self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() + self.attn_layers = attn_layers + self.norm = nn.LayerNorm(dim) + + self.init_() + + self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() + + # memory tokens (like [cls]) from Memory Transformers paper + num_memory_tokens = default(num_memory_tokens, 0) + self.num_memory_tokens = num_memory_tokens + if num_memory_tokens > 0: + self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) + + # let funnel encoder know number of memory tokens, if specified + if hasattr(attn_layers, 'num_memory_tokens'): + attn_layers.num_memory_tokens = num_memory_tokens + + def init_(self): + nn.init.normal_(self.token_emb.weight, std=0.02) + + def forward( + self, + x, + return_embeddings=False, + mask=None, + return_mems=False, + return_attn=False, + mems=None, + **kwargs + ): + b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens + x = self.token_emb(x) + x += self.pos_emb(x) + x = self.emb_dropout(x) + + x = self.project_emb(x) + + if num_mem > 0: + mem = repeat(self.memory_tokens, 'n d -> b n d', b=b) + x = torch.cat((mem, x), dim=1) + + # auto-handle masking after appending memory tokens + if exists(mask): + mask = F.pad(mask, (num_mem, 0), value=True) + + x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) + x = self.norm(x) + + mem, x = x[:, :num_mem], x[:, num_mem:] + + out = self.to_logits(x) if not return_embeddings else x + + if return_mems: + hiddens = intermediates.hiddens + new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens + new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems)) + return out, new_mems + + if return_attn: + attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) + return out, attn_maps + + return out + diff --git a/ldm/util.py b/ldm/util.py new file mode 100644 index 0000000000000000000000000000000000000000..8ba38853e7a07228cc2c187742b5c45d7359b3f9 --- /dev/null +++ b/ldm/util.py @@ -0,0 +1,203 @@ +import importlib + +import torch +import numpy as np +from collections import abc +from einops import rearrange +from functools import partial + +import multiprocessing as mp +from threading import Thread +from queue import Queue + +from inspect import isfunction +from PIL import Image, ImageDraw, ImageFont + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) + nc = int(40 * (wh[0] / 256)) + lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Cant encode string for logging. Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): + # create dummy dataset instance + + # run prefetching + if idx_to_fn: + res = func(data, worker_id=idx) + else: + res = func(data) + Q.put([idx, res]) + Q.put("Done") + + +def parallel_data_prefetch( + func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False +): + # if target_data_type not in ["ndarray", "list"]: + # raise ValueError( + # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." + # ) + if isinstance(data, np.ndarray) and target_data_type == "list": + raise ValueError("list expected but function got ndarray.") + elif isinstance(data, abc.Iterable): + if isinstance(data, dict): + print( + f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' + ) + data = list(data.values()) + if target_data_type == "ndarray": + data = np.asarray(data) + else: + data = list(data) + else: + raise TypeError( + f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." + ) + + if cpu_intensive: + Q = mp.Queue(1000) + proc = mp.Process + else: + Q = Queue(1000) + proc = Thread + # spawn processes + if target_data_type == "ndarray": + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate(np.array_split(data, n_proc)) + ] + else: + step = ( + int(len(data) / n_proc + 1) + if len(data) % n_proc != 0 + else int(len(data) / n_proc) + ) + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate( + [data[i: i + step] for i in range(0, len(data), step)] + ) + ] + processes = [] + for i in range(n_proc): + p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) + processes += [p] + + # start processes + print(f"Start prefetching...") + import time + + start = time.time() + gather_res = [[] for _ in range(n_proc)] + try: + for p in processes: + p.start() + + k = 0 + while k < n_proc: + # get result + res = Q.get() + if res == "Done": + k += 1 + else: + gather_res[res[0]] = res[1] + + except Exception as e: + print("Exception: ", e) + for p in processes: + p.terminate() + + raise e + finally: + for p in processes: + p.join() + print(f"Prefetching complete. [{time.time() - start} sec.]") + + if target_data_type == 'ndarray': + if not isinstance(gather_res[0], np.ndarray): + return np.concatenate([np.asarray(r) for r in gather_res], axis=0) + + # order outputs + return np.concatenate(gather_res, axis=0) + elif target_data_type == 'list': + out = [] + for r in gather_res: + out.extend(r) + return out + else: + return gather_res diff --git a/lmk_util/draw_utils.py b/lmk_util/draw_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..db36ec81176355566b9248b8e8199a871664efe8 --- /dev/null +++ b/lmk_util/draw_utils.py @@ -0,0 +1,222 @@ +import cv2 +import mediapipe as mp +import numpy as np +from mediapipe.framework.formats import landmark_pb2 + +class FaceMeshVisualizer: + def __init__(self, + forehead_edge=False, + upface_only=False, + draw_eye=True, + draw_head=False, + draw_iris=True, + draw_eyebrow=True, + draw_mouse=True, + draw_nose=True, + draw_pupil=True, + line_thickness=2, + img_256_mode :bool=True, # if false, mode 0-1 + ): + self.mp_drawing = mp.solutions.drawing_utils + mp_face_mesh = mp.solutions.face_mesh + self.mp_face_mesh = mp_face_mesh + self.forehead_edge = forehead_edge + self.img_256_mode = img_256_mode + + DrawingSpec = mp.solutions.drawing_styles.DrawingSpec + f_thick = line_thickness + f_rad = 1 + + eye_color = (10, 200, 150) + eyebrow_color = (10, 220, 150) + iris_color = (150 ,120, 100) + # mouth_ob_color = (180, 10, 10) + mouth_ob_color = eye_color + mouth_ot_color = mouth_ob_color + mouth_ib_color = (10, 180, 20) + mouth_it_color = mouth_ib_color + + # Scale colors to 0-1 range if not in 256 mode + if not self.img_256_mode: + eye_color = tuple(c / 255.0 for c in eye_color) + eyebrow_color = tuple(c / 255.0 for c in eyebrow_color) + iris_color = tuple(c / 255.0 for c in iris_color) + mouth_ob_color = tuple(c / 255.0 for c in mouth_ob_color) + mouth_ot_color = tuple(c / 255.0 for c in mouth_ot_color) + mouth_ib_color = tuple(c / 255.0 for c in mouth_ib_color) + mouth_it_color = tuple(c / 255.0 for c in mouth_it_color) + + right_iris_draw = DrawingSpec(color=iris_color, thickness=f_thick, circle_radius=f_rad) + right_eye_draw = DrawingSpec(color=eye_color, thickness=f_thick, circle_radius=f_rad) + right_eyebrow_draw = DrawingSpec(color=eyebrow_color, thickness=f_thick, circle_radius=f_rad) + left_iris_draw = DrawingSpec(color=iris_color, thickness=f_thick, circle_radius=f_rad) + left_eye_draw = DrawingSpec(color=eye_color, thickness=f_thick, circle_radius=f_rad) + left_eyebrow_draw = DrawingSpec(color=eyebrow_color, thickness=f_thick, circle_radius=f_rad) + head_draw = DrawingSpec(color=(10, 200, 10) if self.img_256_mode else (10/255.0, 200/255.0, 10/255.0), thickness=f_thick, circle_radius=f_rad) + nose_draw = DrawingSpec(color=eyebrow_color, thickness=f_thick, circle_radius=f_rad) + + mouth_draw_obl = DrawingSpec(color=mouth_ob_color, thickness=f_thick, circle_radius=f_rad) + mouth_draw_obr = DrawingSpec(color=mouth_ob_color, thickness=f_thick, circle_radius=f_rad) + + mouth_draw_ibl = DrawingSpec(color=mouth_ib_color, thickness=f_thick, circle_radius=f_rad) + mouth_draw_ibr = DrawingSpec(color=mouth_ib_color, thickness=f_thick, circle_radius=f_rad) + + mouth_draw_otl = DrawingSpec(color=mouth_ot_color, thickness=f_thick, circle_radius=f_rad) + mouth_draw_otr = DrawingSpec(color=mouth_ot_color, thickness=f_thick, circle_radius=f_rad) + + mouth_draw_itl = DrawingSpec(color=mouth_it_color, thickness=f_thick, circle_radius=f_rad) + mouth_draw_itr = DrawingSpec(color=mouth_it_color, thickness=f_thick, circle_radius=f_rad) + + FACEMESH_LIPS_OUTER_BOTTOM_LEFT = [(61,146),(146,91),(91,181),(181,84),(84,17)] + FACEMESH_LIPS_OUTER_BOTTOM_RIGHT = [(17,314),(314,405),(405,321),(321,375),(375,291)] + + FACEMESH_LIPS_INNER_BOTTOM_LEFT = [(78,95),(95,88),(88,178),(178,87),(87,14)] + FACEMESH_LIPS_INNER_BOTTOM_RIGHT = [(14,317),(317,402),(402,318),(318,324),(324,308)] + + FACEMESH_LIPS_OUTER_TOP_LEFT = [(61,185),(185,40),(40,39),(39,37),(37,0)] + FACEMESH_LIPS_OUTER_TOP_RIGHT = [(0,267),(267,269),(269,270),(270,409),(409,291)] + + FACEMESH_LIPS_INNER_TOP_LEFT = [(78,191),(191,80),(80,81),(81,82),(82,13)] + FACEMESH_LIPS_INNER_TOP_RIGHT = [(13,312),(312,311),(311,310),(310,415),(415,308)] + + FACEMESH_CUSTOM_FACE_OVAL = [(176, 149), (150, 136), (356, 454), (58, 132), (152, 148), (361, 288), (251, 389), (132, 93), (389, 356), (400, 377), (136, 172), (377, 152), (323, 361), (172, 58), (454, 323), (365, 379), (379, 378), (148, 176), (93, 234), (397, 365), (149, 150), (288, 397), (234, 127), (378, 400), (127, 162), (162, 21)] + + # mp_face_mesh.FACEMESH_CONTOURS has all the items we care about. + face_connection_spec = {} + + #from IPython import embed + #embed() + if self.forehead_edge: + for edge in mp_face_mesh.FACEMESH_FACE_OVAL: + face_connection_spec[edge] = head_draw + else: + if draw_head: + FACEMESH_CUSTOM_FACE_OVAL_sorted = sorted(FACEMESH_CUSTOM_FACE_OVAL) + if upface_only: + for edge in [FACEMESH_CUSTOM_FACE_OVAL_sorted[edge_idx] for edge_idx in [1,2,9,12,13,16,22,25]]: + face_connection_spec[edge] = head_draw + else: + for edge in FACEMESH_CUSTOM_FACE_OVAL_sorted: + face_connection_spec[edge] = head_draw + + if draw_eye: + for edge in mp_face_mesh.FACEMESH_LEFT_EYE: + face_connection_spec[edge] = left_eye_draw + for edge in mp_face_mesh.FACEMESH_RIGHT_EYE: + face_connection_spec[edge] = right_eye_draw + + if draw_eyebrow: + for edge in mp_face_mesh.FACEMESH_LEFT_EYEBROW: + face_connection_spec[edge] = left_eyebrow_draw + + for edge in mp_face_mesh.FACEMESH_RIGHT_EYEBROW: + face_connection_spec[edge] = right_eyebrow_draw + + if draw_iris: + for edge in mp_face_mesh.FACEMESH_LEFT_IRIS: + face_connection_spec[edge] = left_iris_draw + for edge in mp_face_mesh.FACEMESH_RIGHT_IRIS: + face_connection_spec[edge] = right_iris_draw + + #for edge in mp_face_mesh.FACEMESH_RIGHT_EYEBROW: + # face_connection_spec[edge] = right_eyebrow_draw + # for edge in mp_face_mesh.FACEMESH_RIGHT_IRIS: + # face_connection_spec[edge] = right_iris_draw + + # for edge in mp_face_mesh.FACEMESH_LIPS: + # face_connection_spec[edge] = mouth_draw + + if draw_mouse: + for edge in FACEMESH_LIPS_OUTER_BOTTOM_LEFT: + face_connection_spec[edge] = mouth_draw_obl + for edge in FACEMESH_LIPS_OUTER_BOTTOM_RIGHT: + face_connection_spec[edge] = mouth_draw_obr + for edge in FACEMESH_LIPS_INNER_BOTTOM_LEFT: + face_connection_spec[edge] = mouth_draw_ibl + for edge in FACEMESH_LIPS_INNER_BOTTOM_RIGHT: + face_connection_spec[edge] = mouth_draw_ibr + for edge in FACEMESH_LIPS_OUTER_TOP_LEFT: + face_connection_spec[edge] = mouth_draw_otl + for edge in FACEMESH_LIPS_OUTER_TOP_RIGHT: + face_connection_spec[edge] = mouth_draw_otr + for edge in FACEMESH_LIPS_INNER_TOP_LEFT: + face_connection_spec[edge] = mouth_draw_itl + for edge in FACEMESH_LIPS_INNER_TOP_RIGHT: + face_connection_spec[edge] = mouth_draw_itr + + self.face_connection_spec = face_connection_spec + + self.pupil_landmark_spec = {468: right_iris_draw, 473: left_iris_draw} + self.nose_landmark_spec = {4: nose_draw} + + self.draw_pupil = draw_pupil + self.draw_nose = draw_nose + + def draw_points(self, image, landmark_list, drawing_spec, halfwidth: int = 2): + """We have a custom function to draw the pupils because the mp.draw_landmarks method requires a parameter for all + landmarks. Until our PR is merged into mediapipe, we need this separate method.""" + if len(image.shape) != 3: + raise ValueError("Input image must be H,W,C.") + image_rows, image_cols, image_channels = image.shape + if image_channels != 3: # BGR channels + raise ValueError('Input image must contain three channel bgr data.') + for idx, landmark in enumerate(landmark_list.landmark): + if idx not in drawing_spec: + continue + + if ( + (landmark.HasField('visibility') and landmark.visibility < 0.9) or + (landmark.HasField('presence') and landmark.presence < 0.5) + ): + continue + if landmark.x >= 1.0 or landmark.x < 0 or landmark.y >= 1.0 or landmark.y < 0: + continue + + image_x = int(image_cols * landmark.x) + image_y = int(image_rows * landmark.y) + + draw_color = drawing_spec[idx].color + if not self.img_256_mode: # if 0-1 mode, scale color values from 0-255 to 0-1 + if 0: # default: if not img_256_mode, we assume input is 0-1, so no conversion + draw_color = tuple(c / 255.0 for c in draw_color) + image[image_y - halfwidth : image_y + halfwidth, image_x - halfwidth : image_x + halfwidth, :] = draw_color + + + def draw_landmarks(self, image_size, keypoints, normed=False, image=None): + # print(f"{image_size=}") # 512, 512 + ini_size = [512, 512] + if image is None: + if self.img_256_mode: + image = np.zeros([ini_size[1], ini_size[0], 3], dtype=np.uint8) + else: + image = np.zeros([ini_size[1], ini_size[0], 3], dtype=np.float32) + + if keypoints is not None: + new_landmarks = landmark_pb2.NormalizedLandmarkList() + for i in range(keypoints.shape[0]): + landmark = new_landmarks.landmark.add() + if normed: + landmark.x = keypoints[i, 0] + landmark.y = keypoints[i, 1] + else: + landmark.x = keypoints[i, 0] / image_size[0] + landmark.y = keypoints[i, 1] / image_size[1] + landmark.z = 1.0 + + self.mp_drawing.draw_landmarks( + image=image, + landmark_list=new_landmarks, + connections=self.face_connection_spec.keys(), + landmark_drawing_spec=None, + connection_drawing_spec=self.face_connection_spec + ) + + if self.draw_pupil: + self.draw_points(image, new_landmarks, self.pupil_landmark_spec, 3) + + if self.draw_nose: + self.draw_points(image, new_landmarks, self.nose_landmark_spec, 3) + + image = cv2.resize(image, (image_size[0], image_size[1])) + + return image diff --git a/lmk_util/face_landmark.py b/lmk_util/face_landmark.py new file mode 100644 index 0000000000000000000000000000000000000000..b6580cb2cded9dcfeab46b0d50c8931ed6256669 --- /dev/null +++ b/lmk_util/face_landmark.py @@ -0,0 +1,3305 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MediaPipe face landmarker task.""" + +import dataclasses +import enum +from typing import Callable, Mapping, Optional, List + +import numpy as np + +from mediapipe.framework.formats import classification_pb2 +from mediapipe.framework.formats import landmark_pb2 +from mediapipe.framework.formats import matrix_data_pb2 +from mediapipe.python import packet_creator +from mediapipe.python import packet_getter +from mediapipe.python._framework_bindings import image as image_module +from mediapipe.python._framework_bindings import packet as packet_module +# pylint: disable=unused-import +from mediapipe.tasks.cc.vision.face_geometry.proto import face_geometry_pb2 +# pylint: enable=unused-import +from mediapipe.tasks.cc.vision.face_landmarker.proto import face_landmarker_graph_options_pb2 +from mediapipe.tasks.python.components.containers import category as category_module +from mediapipe.tasks.python.components.containers import landmark as landmark_module +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.core import task_info as task_info_module +from mediapipe.tasks.python.core.optional_dependencies import doc_controls +from mediapipe.tasks.python.vision.core import base_vision_task_api +from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module +from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module + +_BaseOptions = base_options_module.BaseOptions +_FaceLandmarkerGraphOptionsProto = ( + face_landmarker_graph_options_pb2.FaceLandmarkerGraphOptions +) +_LayoutEnum = matrix_data_pb2.MatrixData.Layout +_RunningMode = running_mode_module.VisionTaskRunningMode +_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions +_TaskInfo = task_info_module.TaskInfo + +_IMAGE_IN_STREAM_NAME = 'image_in' +_IMAGE_OUT_STREAM_NAME = 'image_out' +_IMAGE_TAG = 'IMAGE' +_NORM_RECT_STREAM_NAME = 'norm_rect_in' +_NORM_RECT_TAG = 'NORM_RECT' +_NORM_LANDMARKS_STREAM_NAME = 'norm_landmarks' +_NORM_LANDMARKS_TAG = 'NORM_LANDMARKS' +_BLENDSHAPES_STREAM_NAME = 'blendshapes' +_BLENDSHAPES_TAG = 'BLENDSHAPES' +_FACE_GEOMETRY_STREAM_NAME = 'face_geometry' +_FACE_GEOMETRY_TAG = 'FACE_GEOMETRY' +_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.face_landmarker.FaceLandmarkerGraph' +_MICRO_SECONDS_PER_MILLISECOND = 1000 + + +class Blendshapes(enum.IntEnum): + """The 52 blendshape coefficients.""" + + NEUTRAL = 0 + BROW_DOWN_LEFT = 1 + BROW_DOWN_RIGHT = 2 + BROW_INNER_UP = 3 + BROW_OUTER_UP_LEFT = 4 + BROW_OUTER_UP_RIGHT = 5 + CHEEK_PUFF = 6 + CHEEK_SQUINT_LEFT = 7 + CHEEK_SQUINT_RIGHT = 8 + EYE_BLINK_LEFT = 9 + EYE_BLINK_RIGHT = 10 + EYE_LOOK_DOWN_LEFT = 11 + EYE_LOOK_DOWN_RIGHT = 12 + EYE_LOOK_IN_LEFT = 13 + EYE_LOOK_IN_RIGHT = 14 + EYE_LOOK_OUT_LEFT = 15 + EYE_LOOK_OUT_RIGHT = 16 + EYE_LOOK_UP_LEFT = 17 + EYE_LOOK_UP_RIGHT = 18 + EYE_SQUINT_LEFT = 19 + EYE_SQUINT_RIGHT = 20 + EYE_WIDE_LEFT = 21 + EYE_WIDE_RIGHT = 22 + JAW_FORWARD = 23 + JAW_LEFT = 24 + JAW_OPEN = 25 + JAW_RIGHT = 26 + MOUTH_CLOSE = 27 + MOUTH_DIMPLE_LEFT = 28 + MOUTH_DIMPLE_RIGHT = 29 + MOUTH_FROWN_LEFT = 30 + MOUTH_FROWN_RIGHT = 31 + MOUTH_FUNNEL = 32 + MOUTH_LEFT = 33 + MOUTH_LOWER_DOWN_LEFT = 34 + MOUTH_LOWER_DOWN_RIGHT = 35 + MOUTH_PRESS_LEFT = 36 + MOUTH_PRESS_RIGHT = 37 + MOUTH_PUCKER = 38 + MOUTH_RIGHT = 39 + MOUTH_ROLL_LOWER = 40 + MOUTH_ROLL_UPPER = 41 + MOUTH_SHRUG_LOWER = 42 + MOUTH_SHRUG_UPPER = 43 + MOUTH_SMILE_LEFT = 44 + MOUTH_SMILE_RIGHT = 45 + MOUTH_STRETCH_LEFT = 46 + MOUTH_STRETCH_RIGHT = 47 + MOUTH_UPPER_UP_LEFT = 48 + MOUTH_UPPER_UP_RIGHT = 49 + NOSE_SNEER_LEFT = 50 + NOSE_SNEER_RIGHT = 51 + + +class FaceLandmarksConnections: + """The connections between face landmarks.""" + + @dataclasses.dataclass + class Connection: + """The connection class for face landmarks.""" + + start: int + end: int + + FACE_LANDMARKS_LIPS: List[Connection] = [ + Connection(61, 146), + Connection(146, 91), + Connection(91, 181), + Connection(181, 84), + Connection(84, 17), + Connection(17, 314), + Connection(314, 405), + Connection(405, 321), + Connection(321, 375), + Connection(375, 291), + Connection(61, 185), + Connection(185, 40), + Connection(40, 39), + Connection(39, 37), + Connection(37, 0), + Connection(0, 267), + Connection(267, 269), + Connection(269, 270), + Connection(270, 409), + Connection(409, 291), + Connection(78, 95), + Connection(95, 88), + Connection(88, 178), + Connection(178, 87), + Connection(87, 14), + Connection(14, 317), + Connection(317, 402), + Connection(402, 318), + Connection(318, 324), + Connection(324, 308), + Connection(78, 191), + Connection(191, 80), + Connection(80, 81), + Connection(81, 82), + Connection(82, 13), + Connection(13, 312), + Connection(312, 311), + Connection(311, 310), + Connection(310, 415), + Connection(415, 308), + ] + + FACE_LANDMARKS_LEFT_EYE: List[Connection] = [ + Connection(263, 249), + Connection(249, 390), + Connection(390, 373), + Connection(373, 374), + Connection(374, 380), + Connection(380, 381), + Connection(381, 382), + Connection(382, 362), + Connection(263, 466), + Connection(466, 388), + Connection(388, 387), + Connection(387, 386), + Connection(386, 385), + Connection(385, 384), + Connection(384, 398), + Connection(398, 362), + ] + + FACE_LANDMARKS_LEFT_EYEBROW: List[Connection] = [ + Connection(276, 283), + Connection(283, 282), + Connection(282, 295), + Connection(295, 285), + Connection(300, 293), + Connection(293, 334), + Connection(334, 296), + Connection(296, 336), + ] + + FACE_LANDMARKS_LEFT_IRIS: List[Connection] = [ + Connection(474, 475), + Connection(475, 476), + Connection(476, 477), + Connection(477, 474), + ] + + FACE_LANDMARKS_RIGHT_EYE: List[Connection] = [ + Connection(33, 7), + Connection(7, 163), + Connection(163, 144), + Connection(144, 145), + Connection(145, 153), + Connection(153, 154), + Connection(154, 155), + Connection(155, 133), + Connection(33, 246), + Connection(246, 161), + Connection(161, 160), + Connection(160, 159), + Connection(159, 158), + Connection(158, 157), + Connection(157, 173), + Connection(173, 133), + ] + + FACE_LANDMARKS_RIGHT_EYEBROW: List[Connection] = [ + Connection(46, 53), + Connection(53, 52), + Connection(52, 65), + Connection(65, 55), + Connection(70, 63), + Connection(63, 105), + Connection(105, 66), + Connection(66, 107), + ] + + FACE_LANDMARKS_RIGHT_IRIS: List[Connection] = [ + Connection(469, 470), + Connection(470, 471), + Connection(471, 472), + Connection(472, 469), + ] + + FACE_LANDMARKS_FACE_OVAL: List[Connection] = [ + Connection(10, 338), + Connection(338, 297), + Connection(297, 332), + Connection(332, 284), + Connection(284, 251), + Connection(251, 389), + Connection(389, 356), + Connection(356, 454), + Connection(454, 323), + Connection(323, 361), + Connection(361, 288), + Connection(288, 397), + Connection(397, 365), + Connection(365, 379), + Connection(379, 378), + Connection(378, 400), + Connection(400, 377), + Connection(377, 152), + Connection(152, 148), + Connection(148, 176), + Connection(176, 149), + Connection(149, 150), + Connection(150, 136), + Connection(136, 172), + Connection(172, 58), + Connection(58, 132), + Connection(132, 93), + Connection(93, 234), + Connection(234, 127), + Connection(127, 162), + Connection(162, 21), + Connection(21, 54), + Connection(54, 103), + Connection(103, 67), + Connection(67, 109), + Connection(109, 10), + ] + + FACE_LANDMARKS_CONTOURS: List[Connection] = ( + FACE_LANDMARKS_LIPS + + FACE_LANDMARKS_LEFT_EYE + + FACE_LANDMARKS_LEFT_EYEBROW + + FACE_LANDMARKS_RIGHT_EYE + + FACE_LANDMARKS_RIGHT_EYEBROW + + FACE_LANDMARKS_FACE_OVAL + ) + + FACE_LANDMARKS_TESSELATION: List[Connection] = [ + Connection(127, 34), + Connection(34, 139), + Connection(139, 127), + Connection(11, 0), + Connection(0, 37), + Connection(37, 11), + Connection(232, 231), + Connection(231, 120), + Connection(120, 232), + Connection(72, 37), + Connection(37, 39), + Connection(39, 72), + Connection(128, 121), + Connection(121, 47), + Connection(47, 128), + Connection(232, 121), + Connection(121, 128), + Connection(128, 232), + Connection(104, 69), + Connection(69, 67), + Connection(67, 104), + Connection(175, 171), + Connection(171, 148), + Connection(148, 175), + Connection(118, 50), + Connection(50, 101), + Connection(101, 118), + Connection(73, 39), + Connection(39, 40), + Connection(40, 73), + Connection(9, 151), + Connection(151, 108), + Connection(108, 9), + Connection(48, 115), + Connection(115, 131), + Connection(131, 48), + Connection(194, 204), + Connection(204, 211), + Connection(211, 194), + Connection(74, 40), + Connection(40, 185), + Connection(185, 74), + Connection(80, 42), + Connection(42, 183), + Connection(183, 80), + Connection(40, 92), + Connection(92, 186), + Connection(186, 40), + Connection(230, 229), + Connection(229, 118), + Connection(118, 230), + Connection(202, 212), + Connection(212, 214), + Connection(214, 202), + Connection(83, 18), + Connection(18, 17), + Connection(17, 83), + Connection(76, 61), + Connection(61, 146), + Connection(146, 76), + Connection(160, 29), + Connection(29, 30), + Connection(30, 160), + Connection(56, 157), + Connection(157, 173), + Connection(173, 56), + Connection(106, 204), + Connection(204, 194), + Connection(194, 106), + Connection(135, 214), + Connection(214, 192), + Connection(192, 135), + Connection(203, 165), + Connection(165, 98), + Connection(98, 203), + Connection(21, 71), + Connection(71, 68), + Connection(68, 21), + Connection(51, 45), + Connection(45, 4), + Connection(4, 51), + Connection(144, 24), + Connection(24, 23), + Connection(23, 144), + Connection(77, 146), + Connection(146, 91), + Connection(91, 77), + Connection(205, 50), + Connection(50, 187), + Connection(187, 205), + Connection(201, 200), + Connection(200, 18), + Connection(18, 201), + Connection(91, 106), + Connection(106, 182), + Connection(182, 91), + Connection(90, 91), + Connection(91, 181), + Connection(181, 90), + Connection(85, 84), + Connection(84, 17), + Connection(17, 85), + Connection(206, 203), + Connection(203, 36), + Connection(36, 206), + Connection(148, 171), + Connection(171, 140), + Connection(140, 148), + Connection(92, 40), + Connection(40, 39), + Connection(39, 92), + Connection(193, 189), + Connection(189, 244), + Connection(244, 193), + Connection(159, 158), + Connection(158, 28), + Connection(28, 159), + Connection(247, 246), + Connection(246, 161), + Connection(161, 247), + Connection(236, 3), + Connection(3, 196), + Connection(196, 236), + Connection(54, 68), + Connection(68, 104), + Connection(104, 54), + Connection(193, 168), + Connection(168, 8), + Connection(8, 193), + Connection(117, 228), + Connection(228, 31), + Connection(31, 117), + Connection(189, 193), + Connection(193, 55), + Connection(55, 189), + Connection(98, 97), + Connection(97, 99), + Connection(99, 98), + Connection(126, 47), + Connection(47, 100), + Connection(100, 126), + Connection(166, 79), + Connection(79, 218), + Connection(218, 166), + Connection(155, 154), + Connection(154, 26), + Connection(26, 155), + Connection(209, 49), + Connection(49, 131), + Connection(131, 209), + Connection(135, 136), + Connection(136, 150), + Connection(150, 135), + Connection(47, 126), + Connection(126, 217), + Connection(217, 47), + Connection(223, 52), + Connection(52, 53), + Connection(53, 223), + Connection(45, 51), + Connection(51, 134), + Connection(134, 45), + Connection(211, 170), + Connection(170, 140), + Connection(140, 211), + Connection(67, 69), + Connection(69, 108), + Connection(108, 67), + Connection(43, 106), + Connection(106, 91), + Connection(91, 43), + Connection(230, 119), + Connection(119, 120), + Connection(120, 230), + Connection(226, 130), + Connection(130, 247), + Connection(247, 226), + Connection(63, 53), + Connection(53, 52), + Connection(52, 63), + Connection(238, 20), + Connection(20, 242), + Connection(242, 238), + Connection(46, 70), + Connection(70, 156), + Connection(156, 46), + Connection(78, 62), + Connection(62, 96), + Connection(96, 78), + Connection(46, 53), + Connection(53, 63), + Connection(63, 46), + Connection(143, 34), + Connection(34, 227), + Connection(227, 143), + Connection(123, 117), + Connection(117, 111), + Connection(111, 123), + Connection(44, 125), + Connection(125, 19), + Connection(19, 44), + Connection(236, 134), + Connection(134, 51), + Connection(51, 236), + Connection(216, 206), + Connection(206, 205), + Connection(205, 216), + Connection(154, 153), + Connection(153, 22), + Connection(22, 154), + Connection(39, 37), + Connection(37, 167), + Connection(167, 39), + Connection(200, 201), + Connection(201, 208), + Connection(208, 200), + Connection(36, 142), + Connection(142, 100), + Connection(100, 36), + Connection(57, 212), + Connection(212, 202), + Connection(202, 57), + Connection(20, 60), + Connection(60, 99), + Connection(99, 20), + Connection(28, 158), + Connection(158, 157), + Connection(157, 28), + Connection(35, 226), + Connection(226, 113), + Connection(113, 35), + Connection(160, 159), + Connection(159, 27), + Connection(27, 160), + Connection(204, 202), + Connection(202, 210), + Connection(210, 204), + Connection(113, 225), + Connection(225, 46), + Connection(46, 113), + Connection(43, 202), + Connection(202, 204), + Connection(204, 43), + Connection(62, 76), + Connection(76, 77), + Connection(77, 62), + Connection(137, 123), + Connection(123, 116), + Connection(116, 137), + Connection(41, 38), + Connection(38, 72), + Connection(72, 41), + Connection(203, 129), + Connection(129, 142), + Connection(142, 203), + Connection(64, 98), + Connection(98, 240), + Connection(240, 64), + Connection(49, 102), + Connection(102, 64), + Connection(64, 49), + Connection(41, 73), + Connection(73, 74), + Connection(74, 41), + Connection(212, 216), + Connection(216, 207), + Connection(207, 212), + Connection(42, 74), + Connection(74, 184), + Connection(184, 42), + Connection(169, 170), + Connection(170, 211), + Connection(211, 169), + Connection(170, 149), + Connection(149, 176), + Connection(176, 170), + Connection(105, 66), + Connection(66, 69), + Connection(69, 105), + Connection(122, 6), + Connection(6, 168), + Connection(168, 122), + Connection(123, 147), + Connection(147, 187), + Connection(187, 123), + Connection(96, 77), + Connection(77, 90), + Connection(90, 96), + Connection(65, 55), + Connection(55, 107), + Connection(107, 65), + Connection(89, 90), + Connection(90, 180), + Connection(180, 89), + Connection(101, 100), + Connection(100, 120), + Connection(120, 101), + Connection(63, 105), + Connection(105, 104), + Connection(104, 63), + Connection(93, 137), + Connection(137, 227), + Connection(227, 93), + Connection(15, 86), + Connection(86, 85), + Connection(85, 15), + Connection(129, 102), + Connection(102, 49), + Connection(49, 129), + Connection(14, 87), + Connection(87, 86), + Connection(86, 14), + Connection(55, 8), + Connection(8, 9), + Connection(9, 55), + Connection(100, 47), + Connection(47, 121), + Connection(121, 100), + Connection(145, 23), + Connection(23, 22), + Connection(22, 145), + Connection(88, 89), + Connection(89, 179), + Connection(179, 88), + Connection(6, 122), + Connection(122, 196), + Connection(196, 6), + Connection(88, 95), + Connection(95, 96), + Connection(96, 88), + Connection(138, 172), + Connection(172, 136), + Connection(136, 138), + Connection(215, 58), + Connection(58, 172), + Connection(172, 215), + Connection(115, 48), + Connection(48, 219), + Connection(219, 115), + Connection(42, 80), + Connection(80, 81), + Connection(81, 42), + Connection(195, 3), + Connection(3, 51), + Connection(51, 195), + Connection(43, 146), + Connection(146, 61), + Connection(61, 43), + Connection(171, 175), + Connection(175, 199), + Connection(199, 171), + Connection(81, 82), + Connection(82, 38), + Connection(38, 81), + Connection(53, 46), + Connection(46, 225), + Connection(225, 53), + Connection(144, 163), + Connection(163, 110), + Connection(110, 144), + Connection(52, 65), + Connection(65, 66), + Connection(66, 52), + Connection(229, 228), + Connection(228, 117), + Connection(117, 229), + Connection(34, 127), + Connection(127, 234), + Connection(234, 34), + Connection(107, 108), + Connection(108, 69), + Connection(69, 107), + Connection(109, 108), + Connection(108, 151), + Connection(151, 109), + Connection(48, 64), + Connection(64, 235), + Connection(235, 48), + Connection(62, 78), + Connection(78, 191), + Connection(191, 62), + Connection(129, 209), + Connection(209, 126), + Connection(126, 129), + Connection(111, 35), + Connection(35, 143), + Connection(143, 111), + Connection(117, 123), + Connection(123, 50), + Connection(50, 117), + Connection(222, 65), + Connection(65, 52), + Connection(52, 222), + Connection(19, 125), + Connection(125, 141), + Connection(141, 19), + Connection(221, 55), + Connection(55, 65), + Connection(65, 221), + Connection(3, 195), + Connection(195, 197), + Connection(197, 3), + Connection(25, 7), + Connection(7, 33), + Connection(33, 25), + Connection(220, 237), + Connection(237, 44), + Connection(44, 220), + Connection(70, 71), + Connection(71, 139), + Connection(139, 70), + Connection(122, 193), + Connection(193, 245), + Connection(245, 122), + Connection(247, 130), + Connection(130, 33), + Connection(33, 247), + Connection(71, 21), + Connection(21, 162), + Connection(162, 71), + Connection(170, 169), + Connection(169, 150), + Connection(150, 170), + Connection(188, 174), + Connection(174, 196), + Connection(196, 188), + Connection(216, 186), + Connection(186, 92), + Connection(92, 216), + Connection(2, 97), + Connection(97, 167), + Connection(167, 2), + Connection(141, 125), + Connection(125, 241), + Connection(241, 141), + Connection(164, 167), + Connection(167, 37), + Connection(37, 164), + Connection(72, 38), + Connection(38, 12), + Connection(12, 72), + Connection(38, 82), + Connection(82, 13), + Connection(13, 38), + Connection(63, 68), + Connection(68, 71), + Connection(71, 63), + Connection(226, 35), + Connection(35, 111), + Connection(111, 226), + Connection(101, 50), + Connection(50, 205), + Connection(205, 101), + Connection(206, 92), + Connection(92, 165), + Connection(165, 206), + Connection(209, 198), + Connection(198, 217), + Connection(217, 209), + Connection(165, 167), + Connection(167, 97), + Connection(97, 165), + Connection(220, 115), + Connection(115, 218), + Connection(218, 220), + Connection(133, 112), + Connection(112, 243), + Connection(243, 133), + Connection(239, 238), + Connection(238, 241), + Connection(241, 239), + Connection(214, 135), + Connection(135, 169), + Connection(169, 214), + Connection(190, 173), + Connection(173, 133), + Connection(133, 190), + Connection(171, 208), + Connection(208, 32), + Connection(32, 171), + Connection(125, 44), + Connection(44, 237), + Connection(237, 125), + Connection(86, 87), + Connection(87, 178), + Connection(178, 86), + Connection(85, 86), + Connection(86, 179), + Connection(179, 85), + Connection(84, 85), + Connection(85, 180), + Connection(180, 84), + Connection(83, 84), + Connection(84, 181), + Connection(181, 83), + Connection(201, 83), + Connection(83, 182), + Connection(182, 201), + Connection(137, 93), + Connection(93, 132), + Connection(132, 137), + Connection(76, 62), + Connection(62, 183), + Connection(183, 76), + Connection(61, 76), + Connection(76, 184), + Connection(184, 61), + Connection(57, 61), + Connection(61, 185), + Connection(185, 57), + Connection(212, 57), + Connection(57, 186), + Connection(186, 212), + Connection(214, 207), + Connection(207, 187), + Connection(187, 214), + Connection(34, 143), + Connection(143, 156), + Connection(156, 34), + Connection(79, 239), + Connection(239, 237), + Connection(237, 79), + Connection(123, 137), + Connection(137, 177), + Connection(177, 123), + Connection(44, 1), + Connection(1, 4), + Connection(4, 44), + Connection(201, 194), + Connection(194, 32), + Connection(32, 201), + Connection(64, 102), + Connection(102, 129), + Connection(129, 64), + Connection(213, 215), + Connection(215, 138), + Connection(138, 213), + Connection(59, 166), + Connection(166, 219), + Connection(219, 59), + Connection(242, 99), + Connection(99, 97), + Connection(97, 242), + Connection(2, 94), + Connection(94, 141), + Connection(141, 2), + Connection(75, 59), + Connection(59, 235), + Connection(235, 75), + Connection(24, 110), + Connection(110, 228), + Connection(228, 24), + Connection(25, 130), + Connection(130, 226), + Connection(226, 25), + Connection(23, 24), + Connection(24, 229), + Connection(229, 23), + Connection(22, 23), + Connection(23, 230), + Connection(230, 22), + Connection(26, 22), + Connection(22, 231), + Connection(231, 26), + Connection(112, 26), + Connection(26, 232), + Connection(232, 112), + Connection(189, 190), + Connection(190, 243), + Connection(243, 189), + Connection(221, 56), + Connection(56, 190), + Connection(190, 221), + Connection(28, 56), + Connection(56, 221), + Connection(221, 28), + Connection(27, 28), + Connection(28, 222), + Connection(222, 27), + Connection(29, 27), + Connection(27, 223), + Connection(223, 29), + Connection(30, 29), + Connection(29, 224), + Connection(224, 30), + Connection(247, 30), + Connection(30, 225), + Connection(225, 247), + Connection(238, 79), + Connection(79, 20), + Connection(20, 238), + Connection(166, 59), + Connection(59, 75), + Connection(75, 166), + Connection(60, 75), + Connection(75, 240), + Connection(240, 60), + Connection(147, 177), + Connection(177, 215), + Connection(215, 147), + Connection(20, 79), + Connection(79, 166), + Connection(166, 20), + Connection(187, 147), + Connection(147, 213), + Connection(213, 187), + Connection(112, 233), + Connection(233, 244), + Connection(244, 112), + Connection(233, 128), + Connection(128, 245), + Connection(245, 233), + Connection(128, 114), + Connection(114, 188), + Connection(188, 128), + Connection(114, 217), + Connection(217, 174), + Connection(174, 114), + Connection(131, 115), + Connection(115, 220), + Connection(220, 131), + Connection(217, 198), + Connection(198, 236), + Connection(236, 217), + Connection(198, 131), + Connection(131, 134), + Connection(134, 198), + Connection(177, 132), + Connection(132, 58), + Connection(58, 177), + Connection(143, 35), + Connection(35, 124), + Connection(124, 143), + Connection(110, 163), + Connection(163, 7), + Connection(7, 110), + Connection(228, 110), + Connection(110, 25), + Connection(25, 228), + Connection(356, 389), + Connection(389, 368), + Connection(368, 356), + Connection(11, 302), + Connection(302, 267), + Connection(267, 11), + Connection(452, 350), + Connection(350, 349), + Connection(349, 452), + Connection(302, 303), + Connection(303, 269), + Connection(269, 302), + Connection(357, 343), + Connection(343, 277), + Connection(277, 357), + Connection(452, 453), + Connection(453, 357), + Connection(357, 452), + Connection(333, 332), + Connection(332, 297), + Connection(297, 333), + Connection(175, 152), + Connection(152, 377), + Connection(377, 175), + Connection(347, 348), + Connection(348, 330), + Connection(330, 347), + Connection(303, 304), + Connection(304, 270), + Connection(270, 303), + Connection(9, 336), + Connection(336, 337), + Connection(337, 9), + Connection(278, 279), + Connection(279, 360), + Connection(360, 278), + Connection(418, 262), + Connection(262, 431), + Connection(431, 418), + Connection(304, 408), + Connection(408, 409), + Connection(409, 304), + Connection(310, 415), + Connection(415, 407), + Connection(407, 310), + Connection(270, 409), + Connection(409, 410), + Connection(410, 270), + Connection(450, 348), + Connection(348, 347), + Connection(347, 450), + Connection(422, 430), + Connection(430, 434), + Connection(434, 422), + Connection(313, 314), + Connection(314, 17), + Connection(17, 313), + Connection(306, 307), + Connection(307, 375), + Connection(375, 306), + Connection(387, 388), + Connection(388, 260), + Connection(260, 387), + Connection(286, 414), + Connection(414, 398), + Connection(398, 286), + Connection(335, 406), + Connection(406, 418), + Connection(418, 335), + Connection(364, 367), + Connection(367, 416), + Connection(416, 364), + Connection(423, 358), + Connection(358, 327), + Connection(327, 423), + Connection(251, 284), + Connection(284, 298), + Connection(298, 251), + Connection(281, 5), + Connection(5, 4), + Connection(4, 281), + Connection(373, 374), + Connection(374, 253), + Connection(253, 373), + Connection(307, 320), + Connection(320, 321), + Connection(321, 307), + Connection(425, 427), + Connection(427, 411), + Connection(411, 425), + Connection(421, 313), + Connection(313, 18), + Connection(18, 421), + Connection(321, 405), + Connection(405, 406), + Connection(406, 321), + Connection(320, 404), + Connection(404, 405), + Connection(405, 320), + Connection(315, 16), + Connection(16, 17), + Connection(17, 315), + Connection(426, 425), + Connection(425, 266), + Connection(266, 426), + Connection(377, 400), + Connection(400, 369), + Connection(369, 377), + Connection(322, 391), + Connection(391, 269), + Connection(269, 322), + Connection(417, 465), + Connection(465, 464), + Connection(464, 417), + Connection(386, 257), + Connection(257, 258), + Connection(258, 386), + Connection(466, 260), + Connection(260, 388), + Connection(388, 466), + Connection(456, 399), + Connection(399, 419), + Connection(419, 456), + Connection(284, 332), + Connection(332, 333), + Connection(333, 284), + Connection(417, 285), + Connection(285, 8), + Connection(8, 417), + Connection(346, 340), + Connection(340, 261), + Connection(261, 346), + Connection(413, 441), + Connection(441, 285), + Connection(285, 413), + Connection(327, 460), + Connection(460, 328), + Connection(328, 327), + Connection(355, 371), + Connection(371, 329), + Connection(329, 355), + Connection(392, 439), + Connection(439, 438), + Connection(438, 392), + Connection(382, 341), + Connection(341, 256), + Connection(256, 382), + Connection(429, 420), + Connection(420, 360), + Connection(360, 429), + Connection(364, 394), + Connection(394, 379), + Connection(379, 364), + Connection(277, 343), + Connection(343, 437), + Connection(437, 277), + Connection(443, 444), + Connection(444, 283), + Connection(283, 443), + Connection(275, 440), + Connection(440, 363), + Connection(363, 275), + Connection(431, 262), + Connection(262, 369), + Connection(369, 431), + Connection(297, 338), + Connection(338, 337), + Connection(337, 297), + Connection(273, 375), + Connection(375, 321), + Connection(321, 273), + Connection(450, 451), + Connection(451, 349), + Connection(349, 450), + Connection(446, 342), + Connection(342, 467), + Connection(467, 446), + Connection(293, 334), + Connection(334, 282), + Connection(282, 293), + Connection(458, 461), + Connection(461, 462), + Connection(462, 458), + Connection(276, 353), + Connection(353, 383), + Connection(383, 276), + Connection(308, 324), + Connection(324, 325), + Connection(325, 308), + Connection(276, 300), + Connection(300, 293), + Connection(293, 276), + Connection(372, 345), + Connection(345, 447), + Connection(447, 372), + Connection(352, 345), + Connection(345, 340), + Connection(340, 352), + Connection(274, 1), + Connection(1, 19), + Connection(19, 274), + Connection(456, 248), + Connection(248, 281), + Connection(281, 456), + Connection(436, 427), + Connection(427, 425), + Connection(425, 436), + Connection(381, 256), + Connection(256, 252), + Connection(252, 381), + Connection(269, 391), + Connection(391, 393), + Connection(393, 269), + Connection(200, 199), + Connection(199, 428), + Connection(428, 200), + Connection(266, 330), + Connection(330, 329), + Connection(329, 266), + Connection(287, 273), + Connection(273, 422), + Connection(422, 287), + Connection(250, 462), + Connection(462, 328), + Connection(328, 250), + Connection(258, 286), + Connection(286, 384), + Connection(384, 258), + Connection(265, 353), + Connection(353, 342), + Connection(342, 265), + Connection(387, 259), + Connection(259, 257), + Connection(257, 387), + Connection(424, 431), + Connection(431, 430), + Connection(430, 424), + Connection(342, 353), + Connection(353, 276), + Connection(276, 342), + Connection(273, 335), + Connection(335, 424), + Connection(424, 273), + Connection(292, 325), + Connection(325, 307), + Connection(307, 292), + Connection(366, 447), + Connection(447, 345), + Connection(345, 366), + Connection(271, 303), + Connection(303, 302), + Connection(302, 271), + Connection(423, 266), + Connection(266, 371), + Connection(371, 423), + Connection(294, 455), + Connection(455, 460), + Connection(460, 294), + Connection(279, 278), + Connection(278, 294), + Connection(294, 279), + Connection(271, 272), + Connection(272, 304), + Connection(304, 271), + Connection(432, 434), + Connection(434, 427), + Connection(427, 432), + Connection(272, 407), + Connection(407, 408), + Connection(408, 272), + Connection(394, 430), + Connection(430, 431), + Connection(431, 394), + Connection(395, 369), + Connection(369, 400), + Connection(400, 395), + Connection(334, 333), + Connection(333, 299), + Connection(299, 334), + Connection(351, 417), + Connection(417, 168), + Connection(168, 351), + Connection(352, 280), + Connection(280, 411), + Connection(411, 352), + Connection(325, 319), + Connection(319, 320), + Connection(320, 325), + Connection(295, 296), + Connection(296, 336), + Connection(336, 295), + Connection(319, 403), + Connection(403, 404), + Connection(404, 319), + Connection(330, 348), + Connection(348, 349), + Connection(349, 330), + Connection(293, 298), + Connection(298, 333), + Connection(333, 293), + Connection(323, 454), + Connection(454, 447), + Connection(447, 323), + Connection(15, 16), + Connection(16, 315), + Connection(315, 15), + Connection(358, 429), + Connection(429, 279), + Connection(279, 358), + Connection(14, 15), + Connection(15, 316), + Connection(316, 14), + Connection(285, 336), + Connection(336, 9), + Connection(9, 285), + Connection(329, 349), + Connection(349, 350), + Connection(350, 329), + Connection(374, 380), + Connection(380, 252), + Connection(252, 374), + Connection(318, 402), + Connection(402, 403), + Connection(403, 318), + Connection(6, 197), + Connection(197, 419), + Connection(419, 6), + Connection(318, 319), + Connection(319, 325), + Connection(325, 318), + Connection(367, 364), + Connection(364, 365), + Connection(365, 367), + Connection(435, 367), + Connection(367, 397), + Connection(397, 435), + Connection(344, 438), + Connection(438, 439), + Connection(439, 344), + Connection(272, 271), + Connection(271, 311), + Connection(311, 272), + Connection(195, 5), + Connection(5, 281), + Connection(281, 195), + Connection(273, 287), + Connection(287, 291), + Connection(291, 273), + Connection(396, 428), + Connection(428, 199), + Connection(199, 396), + Connection(311, 271), + Connection(271, 268), + Connection(268, 311), + Connection(283, 444), + Connection(444, 445), + Connection(445, 283), + Connection(373, 254), + Connection(254, 339), + Connection(339, 373), + Connection(282, 334), + Connection(334, 296), + Connection(296, 282), + Connection(449, 347), + Connection(347, 346), + Connection(346, 449), + Connection(264, 447), + Connection(447, 454), + Connection(454, 264), + Connection(336, 296), + Connection(296, 299), + Connection(299, 336), + Connection(338, 10), + Connection(10, 151), + Connection(151, 338), + Connection(278, 439), + Connection(439, 455), + Connection(455, 278), + Connection(292, 407), + Connection(407, 415), + Connection(415, 292), + Connection(358, 371), + Connection(371, 355), + Connection(355, 358), + Connection(340, 345), + Connection(345, 372), + Connection(372, 340), + Connection(346, 347), + Connection(347, 280), + Connection(280, 346), + Connection(442, 443), + Connection(443, 282), + Connection(282, 442), + Connection(19, 94), + Connection(94, 370), + Connection(370, 19), + Connection(441, 442), + Connection(442, 295), + Connection(295, 441), + Connection(248, 419), + Connection(419, 197), + Connection(197, 248), + Connection(263, 255), + Connection(255, 359), + Connection(359, 263), + Connection(440, 275), + Connection(275, 274), + Connection(274, 440), + Connection(300, 383), + Connection(383, 368), + Connection(368, 300), + Connection(351, 412), + Connection(412, 465), + Connection(465, 351), + Connection(263, 467), + Connection(467, 466), + Connection(466, 263), + Connection(301, 368), + Connection(368, 389), + Connection(389, 301), + Connection(395, 378), + Connection(378, 379), + Connection(379, 395), + Connection(412, 351), + Connection(351, 419), + Connection(419, 412), + Connection(436, 426), + Connection(426, 322), + Connection(322, 436), + Connection(2, 164), + Connection(164, 393), + Connection(393, 2), + Connection(370, 462), + Connection(462, 461), + Connection(461, 370), + Connection(164, 0), + Connection(0, 267), + Connection(267, 164), + Connection(302, 11), + Connection(11, 12), + Connection(12, 302), + Connection(268, 12), + Connection(12, 13), + Connection(13, 268), + Connection(293, 300), + Connection(300, 301), + Connection(301, 293), + Connection(446, 261), + Connection(261, 340), + Connection(340, 446), + Connection(330, 266), + Connection(266, 425), + Connection(425, 330), + Connection(426, 423), + Connection(423, 391), + Connection(391, 426), + Connection(429, 355), + Connection(355, 437), + Connection(437, 429), + Connection(391, 327), + Connection(327, 326), + Connection(326, 391), + Connection(440, 457), + Connection(457, 438), + Connection(438, 440), + Connection(341, 382), + Connection(382, 362), + Connection(362, 341), + Connection(459, 457), + Connection(457, 461), + Connection(461, 459), + Connection(434, 430), + Connection(430, 394), + Connection(394, 434), + Connection(414, 463), + Connection(463, 362), + Connection(362, 414), + Connection(396, 369), + Connection(369, 262), + Connection(262, 396), + Connection(354, 461), + Connection(461, 457), + Connection(457, 354), + Connection(316, 403), + Connection(403, 402), + Connection(402, 316), + Connection(315, 404), + Connection(404, 403), + Connection(403, 315), + Connection(314, 405), + Connection(405, 404), + Connection(404, 314), + Connection(313, 406), + Connection(406, 405), + Connection(405, 313), + Connection(421, 418), + Connection(418, 406), + Connection(406, 421), + Connection(366, 401), + Connection(401, 361), + Connection(361, 366), + Connection(306, 408), + Connection(408, 407), + Connection(407, 306), + Connection(291, 409), + Connection(409, 408), + Connection(408, 291), + Connection(287, 410), + Connection(410, 409), + Connection(409, 287), + Connection(432, 436), + Connection(436, 410), + Connection(410, 432), + Connection(434, 416), + Connection(416, 411), + Connection(411, 434), + Connection(264, 368), + Connection(368, 383), + Connection(383, 264), + Connection(309, 438), + Connection(438, 457), + Connection(457, 309), + Connection(352, 376), + Connection(376, 401), + Connection(401, 352), + Connection(274, 275), + Connection(275, 4), + Connection(4, 274), + Connection(421, 428), + Connection(428, 262), + Connection(262, 421), + Connection(294, 327), + Connection(327, 358), + Connection(358, 294), + Connection(433, 416), + Connection(416, 367), + Connection(367, 433), + Connection(289, 455), + Connection(455, 439), + Connection(439, 289), + Connection(462, 370), + Connection(370, 326), + Connection(326, 462), + Connection(2, 326), + Connection(326, 370), + Connection(370, 2), + Connection(305, 460), + Connection(460, 455), + Connection(455, 305), + Connection(254, 449), + Connection(449, 448), + Connection(448, 254), + Connection(255, 261), + Connection(261, 446), + Connection(446, 255), + Connection(253, 450), + Connection(450, 449), + Connection(449, 253), + Connection(252, 451), + Connection(451, 450), + Connection(450, 252), + Connection(256, 452), + Connection(452, 451), + Connection(451, 256), + Connection(341, 453), + Connection(453, 452), + Connection(452, 341), + Connection(413, 464), + Connection(464, 463), + Connection(463, 413), + Connection(441, 413), + Connection(413, 414), + Connection(414, 441), + Connection(258, 442), + Connection(442, 441), + Connection(441, 258), + Connection(257, 443), + Connection(443, 442), + Connection(442, 257), + Connection(259, 444), + Connection(444, 443), + Connection(443, 259), + Connection(260, 445), + Connection(445, 444), + Connection(444, 260), + Connection(467, 342), + Connection(342, 445), + Connection(445, 467), + Connection(459, 458), + Connection(458, 250), + Connection(250, 459), + Connection(289, 392), + Connection(392, 290), + Connection(290, 289), + Connection(290, 328), + Connection(328, 460), + Connection(460, 290), + Connection(376, 433), + Connection(433, 435), + Connection(435, 376), + Connection(250, 290), + Connection(290, 392), + Connection(392, 250), + Connection(411, 416), + Connection(416, 433), + Connection(433, 411), + Connection(341, 463), + Connection(463, 464), + Connection(464, 341), + Connection(453, 464), + Connection(464, 465), + Connection(465, 453), + Connection(357, 465), + Connection(465, 412), + Connection(412, 357), + Connection(343, 412), + Connection(412, 399), + Connection(399, 343), + Connection(360, 363), + Connection(363, 440), + Connection(440, 360), + Connection(437, 399), + Connection(399, 456), + Connection(456, 437), + Connection(420, 456), + Connection(456, 363), + Connection(363, 420), + Connection(401, 435), + Connection(435, 288), + Connection(288, 401), + Connection(372, 383), + Connection(383, 353), + Connection(353, 372), + Connection(339, 255), + Connection(255, 249), + Connection(249, 339), + Connection(448, 261), + Connection(261, 255), + Connection(255, 448), + Connection(133, 243), + Connection(243, 190), + Connection(190, 133), + Connection(133, 155), + Connection(155, 112), + Connection(112, 133), + Connection(33, 246), + Connection(246, 247), + Connection(247, 33), + Connection(33, 130), + Connection(130, 25), + Connection(25, 33), + Connection(398, 384), + Connection(384, 286), + Connection(286, 398), + Connection(362, 398), + Connection(398, 414), + Connection(414, 362), + Connection(362, 463), + Connection(463, 341), + Connection(341, 362), + Connection(263, 359), + Connection(359, 467), + Connection(467, 263), + Connection(263, 249), + Connection(249, 255), + Connection(255, 263), + Connection(466, 467), + Connection(467, 260), + Connection(260, 466), + Connection(75, 60), + Connection(60, 166), + Connection(166, 75), + Connection(238, 239), + Connection(239, 79), + Connection(79, 238), + Connection(162, 127), + Connection(127, 139), + Connection(139, 162), + Connection(72, 11), + Connection(11, 37), + Connection(37, 72), + Connection(121, 232), + Connection(232, 120), + Connection(120, 121), + Connection(73, 72), + Connection(72, 39), + Connection(39, 73), + Connection(114, 128), + Connection(128, 47), + Connection(47, 114), + Connection(233, 232), + Connection(232, 128), + Connection(128, 233), + Connection(103, 104), + Connection(104, 67), + Connection(67, 103), + Connection(152, 175), + Connection(175, 148), + Connection(148, 152), + Connection(119, 118), + Connection(118, 101), + Connection(101, 119), + Connection(74, 73), + Connection(73, 40), + Connection(40, 74), + Connection(107, 9), + Connection(9, 108), + Connection(108, 107), + Connection(49, 48), + Connection(48, 131), + Connection(131, 49), + Connection(32, 194), + Connection(194, 211), + Connection(211, 32), + Connection(184, 74), + Connection(74, 185), + Connection(185, 184), + Connection(191, 80), + Connection(80, 183), + Connection(183, 191), + Connection(185, 40), + Connection(40, 186), + Connection(186, 185), + Connection(119, 230), + Connection(230, 118), + Connection(118, 119), + Connection(210, 202), + Connection(202, 214), + Connection(214, 210), + Connection(84, 83), + Connection(83, 17), + Connection(17, 84), + Connection(77, 76), + Connection(76, 146), + Connection(146, 77), + Connection(161, 160), + Connection(160, 30), + Connection(30, 161), + Connection(190, 56), + Connection(56, 173), + Connection(173, 190), + Connection(182, 106), + Connection(106, 194), + Connection(194, 182), + Connection(138, 135), + Connection(135, 192), + Connection(192, 138), + Connection(129, 203), + Connection(203, 98), + Connection(98, 129), + Connection(54, 21), + Connection(21, 68), + Connection(68, 54), + Connection(5, 51), + Connection(51, 4), + Connection(4, 5), + Connection(145, 144), + Connection(144, 23), + Connection(23, 145), + Connection(90, 77), + Connection(77, 91), + Connection(91, 90), + Connection(207, 205), + Connection(205, 187), + Connection(187, 207), + Connection(83, 201), + Connection(201, 18), + Connection(18, 83), + Connection(181, 91), + Connection(91, 182), + Connection(182, 181), + Connection(180, 90), + Connection(90, 181), + Connection(181, 180), + Connection(16, 85), + Connection(85, 17), + Connection(17, 16), + Connection(205, 206), + Connection(206, 36), + Connection(36, 205), + Connection(176, 148), + Connection(148, 140), + Connection(140, 176), + Connection(165, 92), + Connection(92, 39), + Connection(39, 165), + Connection(245, 193), + Connection(193, 244), + Connection(244, 245), + Connection(27, 159), + Connection(159, 28), + Connection(28, 27), + Connection(30, 247), + Connection(247, 161), + Connection(161, 30), + Connection(174, 236), + Connection(236, 196), + Connection(196, 174), + Connection(103, 54), + Connection(54, 104), + Connection(104, 103), + Connection(55, 193), + Connection(193, 8), + Connection(8, 55), + Connection(111, 117), + Connection(117, 31), + Connection(31, 111), + Connection(221, 189), + Connection(189, 55), + Connection(55, 221), + Connection(240, 98), + Connection(98, 99), + Connection(99, 240), + Connection(142, 126), + Connection(126, 100), + Connection(100, 142), + Connection(219, 166), + Connection(166, 218), + Connection(218, 219), + Connection(112, 155), + Connection(155, 26), + Connection(26, 112), + Connection(198, 209), + Connection(209, 131), + Connection(131, 198), + Connection(169, 135), + Connection(135, 150), + Connection(150, 169), + Connection(114, 47), + Connection(47, 217), + Connection(217, 114), + Connection(224, 223), + Connection(223, 53), + Connection(53, 224), + Connection(220, 45), + Connection(45, 134), + Connection(134, 220), + Connection(32, 211), + Connection(211, 140), + Connection(140, 32), + Connection(109, 67), + Connection(67, 108), + Connection(108, 109), + Connection(146, 43), + Connection(43, 91), + Connection(91, 146), + Connection(231, 230), + Connection(230, 120), + Connection(120, 231), + Connection(113, 226), + Connection(226, 247), + Connection(247, 113), + Connection(105, 63), + Connection(63, 52), + Connection(52, 105), + Connection(241, 238), + Connection(238, 242), + Connection(242, 241), + Connection(124, 46), + Connection(46, 156), + Connection(156, 124), + Connection(95, 78), + Connection(78, 96), + Connection(96, 95), + Connection(70, 46), + Connection(46, 63), + Connection(63, 70), + Connection(116, 143), + Connection(143, 227), + Connection(227, 116), + Connection(116, 123), + Connection(123, 111), + Connection(111, 116), + Connection(1, 44), + Connection(44, 19), + Connection(19, 1), + Connection(3, 236), + Connection(236, 51), + Connection(51, 3), + Connection(207, 216), + Connection(216, 205), + Connection(205, 207), + Connection(26, 154), + Connection(154, 22), + Connection(22, 26), + Connection(165, 39), + Connection(39, 167), + Connection(167, 165), + Connection(199, 200), + Connection(200, 208), + Connection(208, 199), + Connection(101, 36), + Connection(36, 100), + Connection(100, 101), + Connection(43, 57), + Connection(57, 202), + Connection(202, 43), + Connection(242, 20), + Connection(20, 99), + Connection(99, 242), + Connection(56, 28), + Connection(28, 157), + Connection(157, 56), + Connection(124, 35), + Connection(35, 113), + Connection(113, 124), + Connection(29, 160), + Connection(160, 27), + Connection(27, 29), + Connection(211, 204), + Connection(204, 210), + Connection(210, 211), + Connection(124, 113), + Connection(113, 46), + Connection(46, 124), + Connection(106, 43), + Connection(43, 204), + Connection(204, 106), + Connection(96, 62), + Connection(62, 77), + Connection(77, 96), + Connection(227, 137), + Connection(137, 116), + Connection(116, 227), + Connection(73, 41), + Connection(41, 72), + Connection(72, 73), + Connection(36, 203), + Connection(203, 142), + Connection(142, 36), + Connection(235, 64), + Connection(64, 240), + Connection(240, 235), + Connection(48, 49), + Connection(49, 64), + Connection(64, 48), + Connection(42, 41), + Connection(41, 74), + Connection(74, 42), + Connection(214, 212), + Connection(212, 207), + Connection(207, 214), + Connection(183, 42), + Connection(42, 184), + Connection(184, 183), + Connection(210, 169), + Connection(169, 211), + Connection(211, 210), + Connection(140, 170), + Connection(170, 176), + Connection(176, 140), + Connection(104, 105), + Connection(105, 69), + Connection(69, 104), + Connection(193, 122), + Connection(122, 168), + Connection(168, 193), + Connection(50, 123), + Connection(123, 187), + Connection(187, 50), + Connection(89, 96), + Connection(96, 90), + Connection(90, 89), + Connection(66, 65), + Connection(65, 107), + Connection(107, 66), + Connection(179, 89), + Connection(89, 180), + Connection(180, 179), + Connection(119, 101), + Connection(101, 120), + Connection(120, 119), + Connection(68, 63), + Connection(63, 104), + Connection(104, 68), + Connection(234, 93), + Connection(93, 227), + Connection(227, 234), + Connection(16, 15), + Connection(15, 85), + Connection(85, 16), + Connection(209, 129), + Connection(129, 49), + Connection(49, 209), + Connection(15, 14), + Connection(14, 86), + Connection(86, 15), + Connection(107, 55), + Connection(55, 9), + Connection(9, 107), + Connection(120, 100), + Connection(100, 121), + Connection(121, 120), + Connection(153, 145), + Connection(145, 22), + Connection(22, 153), + Connection(178, 88), + Connection(88, 179), + Connection(179, 178), + Connection(197, 6), + Connection(6, 196), + Connection(196, 197), + Connection(89, 88), + Connection(88, 96), + Connection(96, 89), + Connection(135, 138), + Connection(138, 136), + Connection(136, 135), + Connection(138, 215), + Connection(215, 172), + Connection(172, 138), + Connection(218, 115), + Connection(115, 219), + Connection(219, 218), + Connection(41, 42), + Connection(42, 81), + Connection(81, 41), + Connection(5, 195), + Connection(195, 51), + Connection(51, 5), + Connection(57, 43), + Connection(43, 61), + Connection(61, 57), + Connection(208, 171), + Connection(171, 199), + Connection(199, 208), + Connection(41, 81), + Connection(81, 38), + Connection(38, 41), + Connection(224, 53), + Connection(53, 225), + Connection(225, 224), + Connection(24, 144), + Connection(144, 110), + Connection(110, 24), + Connection(105, 52), + Connection(52, 66), + Connection(66, 105), + Connection(118, 229), + Connection(229, 117), + Connection(117, 118), + Connection(227, 34), + Connection(34, 234), + Connection(234, 227), + Connection(66, 107), + Connection(107, 69), + Connection(69, 66), + Connection(10, 109), + Connection(109, 151), + Connection(151, 10), + Connection(219, 48), + Connection(48, 235), + Connection(235, 219), + Connection(183, 62), + Connection(62, 191), + Connection(191, 183), + Connection(142, 129), + Connection(129, 126), + Connection(126, 142), + Connection(116, 111), + Connection(111, 143), + Connection(143, 116), + Connection(118, 117), + Connection(117, 50), + Connection(50, 118), + Connection(223, 222), + Connection(222, 52), + Connection(52, 223), + Connection(94, 19), + Connection(19, 141), + Connection(141, 94), + Connection(222, 221), + Connection(221, 65), + Connection(65, 222), + Connection(196, 3), + Connection(3, 197), + Connection(197, 196), + Connection(45, 220), + Connection(220, 44), + Connection(44, 45), + Connection(156, 70), + Connection(70, 139), + Connection(139, 156), + Connection(188, 122), + Connection(122, 245), + Connection(245, 188), + Connection(139, 71), + Connection(71, 162), + Connection(162, 139), + Connection(149, 170), + Connection(170, 150), + Connection(150, 149), + Connection(122, 188), + Connection(188, 196), + Connection(196, 122), + Connection(206, 216), + Connection(216, 92), + Connection(92, 206), + Connection(164, 2), + Connection(2, 167), + Connection(167, 164), + Connection(242, 141), + Connection(141, 241), + Connection(241, 242), + Connection(0, 164), + Connection(164, 37), + Connection(37, 0), + Connection(11, 72), + Connection(72, 12), + Connection(12, 11), + Connection(12, 38), + Connection(38, 13), + Connection(13, 12), + Connection(70, 63), + Connection(63, 71), + Connection(71, 70), + Connection(31, 226), + Connection(226, 111), + Connection(111, 31), + Connection(36, 101), + Connection(101, 205), + Connection(205, 36), + Connection(203, 206), + Connection(206, 165), + Connection(165, 203), + Connection(126, 209), + Connection(209, 217), + Connection(217, 126), + Connection(98, 165), + Connection(165, 97), + Connection(97, 98), + Connection(237, 220), + Connection(220, 218), + Connection(218, 237), + Connection(237, 239), + Connection(239, 241), + Connection(241, 237), + Connection(210, 214), + Connection(214, 169), + Connection(169, 210), + Connection(140, 171), + Connection(171, 32), + Connection(32, 140), + Connection(241, 125), + Connection(125, 237), + Connection(237, 241), + Connection(179, 86), + Connection(86, 178), + Connection(178, 179), + Connection(180, 85), + Connection(85, 179), + Connection(179, 180), + Connection(181, 84), + Connection(84, 180), + Connection(180, 181), + Connection(182, 83), + Connection(83, 181), + Connection(181, 182), + Connection(194, 201), + Connection(201, 182), + Connection(182, 194), + Connection(177, 137), + Connection(137, 132), + Connection(132, 177), + Connection(184, 76), + Connection(76, 183), + Connection(183, 184), + Connection(185, 61), + Connection(61, 184), + Connection(184, 185), + Connection(186, 57), + Connection(57, 185), + Connection(185, 186), + Connection(216, 212), + Connection(212, 186), + Connection(186, 216), + Connection(192, 214), + Connection(214, 187), + Connection(187, 192), + Connection(139, 34), + Connection(34, 156), + Connection(156, 139), + Connection(218, 79), + Connection(79, 237), + Connection(237, 218), + Connection(147, 123), + Connection(123, 177), + Connection(177, 147), + Connection(45, 44), + Connection(44, 4), + Connection(4, 45), + Connection(208, 201), + Connection(201, 32), + Connection(32, 208), + Connection(98, 64), + Connection(64, 129), + Connection(129, 98), + Connection(192, 213), + Connection(213, 138), + Connection(138, 192), + Connection(235, 59), + Connection(59, 219), + Connection(219, 235), + Connection(141, 242), + Connection(242, 97), + Connection(97, 141), + Connection(97, 2), + Connection(2, 141), + Connection(141, 97), + Connection(240, 75), + Connection(75, 235), + Connection(235, 240), + Connection(229, 24), + Connection(24, 228), + Connection(228, 229), + Connection(31, 25), + Connection(25, 226), + Connection(226, 31), + Connection(230, 23), + Connection(23, 229), + Connection(229, 230), + Connection(231, 22), + Connection(22, 230), + Connection(230, 231), + Connection(232, 26), + Connection(26, 231), + Connection(231, 232), + Connection(233, 112), + Connection(112, 232), + Connection(232, 233), + Connection(244, 189), + Connection(189, 243), + Connection(243, 244), + Connection(189, 221), + Connection(221, 190), + Connection(190, 189), + Connection(222, 28), + Connection(28, 221), + Connection(221, 222), + Connection(223, 27), + Connection(27, 222), + Connection(222, 223), + Connection(224, 29), + Connection(29, 223), + Connection(223, 224), + Connection(225, 30), + Connection(30, 224), + Connection(224, 225), + Connection(113, 247), + Connection(247, 225), + Connection(225, 113), + Connection(99, 60), + Connection(60, 240), + Connection(240, 99), + Connection(213, 147), + Connection(147, 215), + Connection(215, 213), + Connection(60, 20), + Connection(20, 166), + Connection(166, 60), + Connection(192, 187), + Connection(187, 213), + Connection(213, 192), + Connection(243, 112), + Connection(112, 244), + Connection(244, 243), + Connection(244, 233), + Connection(233, 245), + Connection(245, 244), + Connection(245, 128), + Connection(128, 188), + Connection(188, 245), + Connection(188, 114), + Connection(114, 174), + Connection(174, 188), + Connection(134, 131), + Connection(131, 220), + Connection(220, 134), + Connection(174, 217), + Connection(217, 236), + Connection(236, 174), + Connection(236, 198), + Connection(198, 134), + Connection(134, 236), + Connection(215, 177), + Connection(177, 58), + Connection(58, 215), + Connection(156, 143), + Connection(143, 124), + Connection(124, 156), + Connection(25, 110), + Connection(110, 7), + Connection(7, 25), + Connection(31, 228), + Connection(228, 25), + Connection(25, 31), + Connection(264, 356), + Connection(356, 368), + Connection(368, 264), + Connection(0, 11), + Connection(11, 267), + Connection(267, 0), + Connection(451, 452), + Connection(452, 349), + Connection(349, 451), + Connection(267, 302), + Connection(302, 269), + Connection(269, 267), + Connection(350, 357), + Connection(357, 277), + Connection(277, 350), + Connection(350, 452), + Connection(452, 357), + Connection(357, 350), + Connection(299, 333), + Connection(333, 297), + Connection(297, 299), + Connection(396, 175), + Connection(175, 377), + Connection(377, 396), + Connection(280, 347), + Connection(347, 330), + Connection(330, 280), + Connection(269, 303), + Connection(303, 270), + Connection(270, 269), + Connection(151, 9), + Connection(9, 337), + Connection(337, 151), + Connection(344, 278), + Connection(278, 360), + Connection(360, 344), + Connection(424, 418), + Connection(418, 431), + Connection(431, 424), + Connection(270, 304), + Connection(304, 409), + Connection(409, 270), + Connection(272, 310), + Connection(310, 407), + Connection(407, 272), + Connection(322, 270), + Connection(270, 410), + Connection(410, 322), + Connection(449, 450), + Connection(450, 347), + Connection(347, 449), + Connection(432, 422), + Connection(422, 434), + Connection(434, 432), + Connection(18, 313), + Connection(313, 17), + Connection(17, 18), + Connection(291, 306), + Connection(306, 375), + Connection(375, 291), + Connection(259, 387), + Connection(387, 260), + Connection(260, 259), + Connection(424, 335), + Connection(335, 418), + Connection(418, 424), + Connection(434, 364), + Connection(364, 416), + Connection(416, 434), + Connection(391, 423), + Connection(423, 327), + Connection(327, 391), + Connection(301, 251), + Connection(251, 298), + Connection(298, 301), + Connection(275, 281), + Connection(281, 4), + Connection(4, 275), + Connection(254, 373), + Connection(373, 253), + Connection(253, 254), + Connection(375, 307), + Connection(307, 321), + Connection(321, 375), + Connection(280, 425), + Connection(425, 411), + Connection(411, 280), + Connection(200, 421), + Connection(421, 18), + Connection(18, 200), + Connection(335, 321), + Connection(321, 406), + Connection(406, 335), + Connection(321, 320), + Connection(320, 405), + Connection(405, 321), + Connection(314, 315), + Connection(315, 17), + Connection(17, 314), + Connection(423, 426), + Connection(426, 266), + Connection(266, 423), + Connection(396, 377), + Connection(377, 369), + Connection(369, 396), + Connection(270, 322), + Connection(322, 269), + Connection(269, 270), + Connection(413, 417), + Connection(417, 464), + Connection(464, 413), + Connection(385, 386), + Connection(386, 258), + Connection(258, 385), + Connection(248, 456), + Connection(456, 419), + Connection(419, 248), + Connection(298, 284), + Connection(284, 333), + Connection(333, 298), + Connection(168, 417), + Connection(417, 8), + Connection(8, 168), + Connection(448, 346), + Connection(346, 261), + Connection(261, 448), + Connection(417, 413), + Connection(413, 285), + Connection(285, 417), + Connection(326, 327), + Connection(327, 328), + Connection(328, 326), + Connection(277, 355), + Connection(355, 329), + Connection(329, 277), + Connection(309, 392), + Connection(392, 438), + Connection(438, 309), + Connection(381, 382), + Connection(382, 256), + Connection(256, 381), + Connection(279, 429), + Connection(429, 360), + Connection(360, 279), + Connection(365, 364), + Connection(364, 379), + Connection(379, 365), + Connection(355, 277), + Connection(277, 437), + Connection(437, 355), + Connection(282, 443), + Connection(443, 283), + Connection(283, 282), + Connection(281, 275), + Connection(275, 363), + Connection(363, 281), + Connection(395, 431), + Connection(431, 369), + Connection(369, 395), + Connection(299, 297), + Connection(297, 337), + Connection(337, 299), + Connection(335, 273), + Connection(273, 321), + Connection(321, 335), + Connection(348, 450), + Connection(450, 349), + Connection(349, 348), + Connection(359, 446), + Connection(446, 467), + Connection(467, 359), + Connection(283, 293), + Connection(293, 282), + Connection(282, 283), + Connection(250, 458), + Connection(458, 462), + Connection(462, 250), + Connection(300, 276), + Connection(276, 383), + Connection(383, 300), + Connection(292, 308), + Connection(308, 325), + Connection(325, 292), + Connection(283, 276), + Connection(276, 293), + Connection(293, 283), + Connection(264, 372), + Connection(372, 447), + Connection(447, 264), + Connection(346, 352), + Connection(352, 340), + Connection(340, 346), + Connection(354, 274), + Connection(274, 19), + Connection(19, 354), + Connection(363, 456), + Connection(456, 281), + Connection(281, 363), + Connection(426, 436), + Connection(436, 425), + Connection(425, 426), + Connection(380, 381), + Connection(381, 252), + Connection(252, 380), + Connection(267, 269), + Connection(269, 393), + Connection(393, 267), + Connection(421, 200), + Connection(200, 428), + Connection(428, 421), + Connection(371, 266), + Connection(266, 329), + Connection(329, 371), + Connection(432, 287), + Connection(287, 422), + Connection(422, 432), + Connection(290, 250), + Connection(250, 328), + Connection(328, 290), + Connection(385, 258), + Connection(258, 384), + Connection(384, 385), + Connection(446, 265), + Connection(265, 342), + Connection(342, 446), + Connection(386, 387), + Connection(387, 257), + Connection(257, 386), + Connection(422, 424), + Connection(424, 430), + Connection(430, 422), + Connection(445, 342), + Connection(342, 276), + Connection(276, 445), + Connection(422, 273), + Connection(273, 424), + Connection(424, 422), + Connection(306, 292), + Connection(292, 307), + Connection(307, 306), + Connection(352, 366), + Connection(366, 345), + Connection(345, 352), + Connection(268, 271), + Connection(271, 302), + Connection(302, 268), + Connection(358, 423), + Connection(423, 371), + Connection(371, 358), + Connection(327, 294), + Connection(294, 460), + Connection(460, 327), + Connection(331, 279), + Connection(279, 294), + Connection(294, 331), + Connection(303, 271), + Connection(271, 304), + Connection(304, 303), + Connection(436, 432), + Connection(432, 427), + Connection(427, 436), + Connection(304, 272), + Connection(272, 408), + Connection(408, 304), + Connection(395, 394), + Connection(394, 431), + Connection(431, 395), + Connection(378, 395), + Connection(395, 400), + Connection(400, 378), + Connection(296, 334), + Connection(334, 299), + Connection(299, 296), + Connection(6, 351), + Connection(351, 168), + Connection(168, 6), + Connection(376, 352), + Connection(352, 411), + Connection(411, 376), + Connection(307, 325), + Connection(325, 320), + Connection(320, 307), + Connection(285, 295), + Connection(295, 336), + Connection(336, 285), + Connection(320, 319), + Connection(319, 404), + Connection(404, 320), + Connection(329, 330), + Connection(330, 349), + Connection(349, 329), + Connection(334, 293), + Connection(293, 333), + Connection(333, 334), + Connection(366, 323), + Connection(323, 447), + Connection(447, 366), + Connection(316, 15), + Connection(15, 315), + Connection(315, 316), + Connection(331, 358), + Connection(358, 279), + Connection(279, 331), + Connection(317, 14), + Connection(14, 316), + Connection(316, 317), + Connection(8, 285), + Connection(285, 9), + Connection(9, 8), + Connection(277, 329), + Connection(329, 350), + Connection(350, 277), + Connection(253, 374), + Connection(374, 252), + Connection(252, 253), + Connection(319, 318), + Connection(318, 403), + Connection(403, 319), + Connection(351, 6), + Connection(6, 419), + Connection(419, 351), + Connection(324, 318), + Connection(318, 325), + Connection(325, 324), + Connection(397, 367), + Connection(367, 365), + Connection(365, 397), + Connection(288, 435), + Connection(435, 397), + Connection(397, 288), + Connection(278, 344), + Connection(344, 439), + Connection(439, 278), + Connection(310, 272), + Connection(272, 311), + Connection(311, 310), + Connection(248, 195), + Connection(195, 281), + Connection(281, 248), + Connection(375, 273), + Connection(273, 291), + Connection(291, 375), + Connection(175, 396), + Connection(396, 199), + Connection(199, 175), + Connection(312, 311), + Connection(311, 268), + Connection(268, 312), + Connection(276, 283), + Connection(283, 445), + Connection(445, 276), + Connection(390, 373), + Connection(373, 339), + Connection(339, 390), + Connection(295, 282), + Connection(282, 296), + Connection(296, 295), + Connection(448, 449), + Connection(449, 346), + Connection(346, 448), + Connection(356, 264), + Connection(264, 454), + Connection(454, 356), + Connection(337, 336), + Connection(336, 299), + Connection(299, 337), + Connection(337, 338), + Connection(338, 151), + Connection(151, 337), + Connection(294, 278), + Connection(278, 455), + Connection(455, 294), + Connection(308, 292), + Connection(292, 415), + Connection(415, 308), + Connection(429, 358), + Connection(358, 355), + Connection(355, 429), + Connection(265, 340), + Connection(340, 372), + Connection(372, 265), + Connection(352, 346), + Connection(346, 280), + Connection(280, 352), + Connection(295, 442), + Connection(442, 282), + Connection(282, 295), + Connection(354, 19), + Connection(19, 370), + Connection(370, 354), + Connection(285, 441), + Connection(441, 295), + Connection(295, 285), + Connection(195, 248), + Connection(248, 197), + Connection(197, 195), + Connection(457, 440), + Connection(440, 274), + Connection(274, 457), + Connection(301, 300), + Connection(300, 368), + Connection(368, 301), + Connection(417, 351), + Connection(351, 465), + Connection(465, 417), + Connection(251, 301), + Connection(301, 389), + Connection(389, 251), + Connection(394, 395), + Connection(395, 379), + Connection(379, 394), + Connection(399, 412), + Connection(412, 419), + Connection(419, 399), + Connection(410, 436), + Connection(436, 322), + Connection(322, 410), + Connection(326, 2), + Connection(2, 393), + Connection(393, 326), + Connection(354, 370), + Connection(370, 461), + Connection(461, 354), + Connection(393, 164), + Connection(164, 267), + Connection(267, 393), + Connection(268, 302), + Connection(302, 12), + Connection(12, 268), + Connection(312, 268), + Connection(268, 13), + Connection(13, 312), + Connection(298, 293), + Connection(293, 301), + Connection(301, 298), + Connection(265, 446), + Connection(446, 340), + Connection(340, 265), + Connection(280, 330), + Connection(330, 425), + Connection(425, 280), + Connection(322, 426), + Connection(426, 391), + Connection(391, 322), + Connection(420, 429), + Connection(429, 437), + Connection(437, 420), + Connection(393, 391), + Connection(391, 326), + Connection(326, 393), + Connection(344, 440), + Connection(440, 438), + Connection(438, 344), + Connection(458, 459), + Connection(459, 461), + Connection(461, 458), + Connection(364, 434), + Connection(434, 394), + Connection(394, 364), + Connection(428, 396), + Connection(396, 262), + Connection(262, 428), + Connection(274, 354), + Connection(354, 457), + Connection(457, 274), + Connection(317, 316), + Connection(316, 402), + Connection(402, 317), + Connection(316, 315), + Connection(315, 403), + Connection(403, 316), + Connection(315, 314), + Connection(314, 404), + Connection(404, 315), + Connection(314, 313), + Connection(313, 405), + Connection(405, 314), + Connection(313, 421), + Connection(421, 406), + Connection(406, 313), + Connection(323, 366), + Connection(366, 361), + Connection(361, 323), + Connection(292, 306), + Connection(306, 407), + Connection(407, 292), + Connection(306, 291), + Connection(291, 408), + Connection(408, 306), + Connection(291, 287), + Connection(287, 409), + Connection(409, 291), + Connection(287, 432), + Connection(432, 410), + Connection(410, 287), + Connection(427, 434), + Connection(434, 411), + Connection(411, 427), + Connection(372, 264), + Connection(264, 383), + Connection(383, 372), + Connection(459, 309), + Connection(309, 457), + Connection(457, 459), + Connection(366, 352), + Connection(352, 401), + Connection(401, 366), + Connection(1, 274), + Connection(274, 4), + Connection(4, 1), + Connection(418, 421), + Connection(421, 262), + Connection(262, 418), + Connection(331, 294), + Connection(294, 358), + Connection(358, 331), + Connection(435, 433), + Connection(433, 367), + Connection(367, 435), + Connection(392, 289), + Connection(289, 439), + Connection(439, 392), + Connection(328, 462), + Connection(462, 326), + Connection(326, 328), + Connection(94, 2), + Connection(2, 370), + Connection(370, 94), + Connection(289, 305), + Connection(305, 455), + Connection(455, 289), + Connection(339, 254), + Connection(254, 448), + Connection(448, 339), + Connection(359, 255), + Connection(255, 446), + Connection(446, 359), + Connection(254, 253), + Connection(253, 449), + Connection(449, 254), + Connection(253, 252), + Connection(252, 450), + Connection(450, 253), + Connection(252, 256), + Connection(256, 451), + Connection(451, 252), + Connection(256, 341), + Connection(341, 452), + Connection(452, 256), + Connection(414, 413), + Connection(413, 463), + Connection(463, 414), + Connection(286, 441), + Connection(441, 414), + Connection(414, 286), + Connection(286, 258), + Connection(258, 441), + Connection(441, 286), + Connection(258, 257), + Connection(257, 442), + Connection(442, 258), + Connection(257, 259), + Connection(259, 443), + Connection(443, 257), + Connection(259, 260), + Connection(260, 444), + Connection(444, 259), + Connection(260, 467), + Connection(467, 445), + Connection(445, 260), + Connection(309, 459), + Connection(459, 250), + Connection(250, 309), + Connection(305, 289), + Connection(289, 290), + Connection(290, 305), + Connection(305, 290), + Connection(290, 460), + Connection(460, 305), + Connection(401, 376), + Connection(376, 435), + Connection(435, 401), + Connection(309, 250), + Connection(250, 392), + Connection(392, 309), + Connection(376, 411), + Connection(411, 433), + Connection(433, 376), + Connection(453, 341), + Connection(341, 464), + Connection(464, 453), + Connection(357, 453), + Connection(453, 465), + Connection(465, 357), + Connection(343, 357), + Connection(357, 412), + Connection(412, 343), + Connection(437, 343), + Connection(343, 399), + Connection(399, 437), + Connection(344, 360), + Connection(360, 440), + Connection(440, 344), + Connection(420, 437), + Connection(437, 456), + Connection(456, 420), + Connection(360, 420), + Connection(420, 363), + Connection(363, 360), + Connection(361, 401), + Connection(401, 288), + Connection(288, 361), + Connection(265, 372), + Connection(372, 353), + Connection(353, 265), + Connection(390, 339), + Connection(339, 249), + Connection(249, 390), + Connection(339, 448), + Connection(448, 255), + Connection(255, 339), + ] + + +@dataclasses.dataclass +class FaceLandmarkerResult: + """The face landmarks detection result from FaceLandmarker, where each vector element represents a single face detected in the image. + + Attributes: + face_landmarks: Detected face landmarks in normalized image coordinates. + face_blendshapes: Optional face blendshapes results. + facial_transformation_matrixes: Optional facial transformation matrix. + """ + + face_landmarks: List[List[landmark_module.NormalizedLandmark]] + face_blendshapes: List[List[category_module.Category]] + facial_transformation_matrixes: List[np.ndarray] + + +def _build_landmarker_result( + output_packets: Mapping[str, packet_module.Packet] +) -> FaceLandmarkerResult: + """Constructs a `FaceLandmarkerResult` from output packets.""" + face_landmarks_proto_list = packet_getter.get_proto_list( + output_packets[_NORM_LANDMARKS_STREAM_NAME] + ) + + face_landmarks_results = [] + for proto in face_landmarks_proto_list: + face_landmarks = landmark_pb2.NormalizedLandmarkList() + face_landmarks.MergeFrom(proto) + face_landmarks_list = [] + for face_landmark in face_landmarks.landmark: + face_landmarks_list.append( + landmark_module.NormalizedLandmark.create_from_pb2(face_landmark) + ) + face_landmarks_results.append(face_landmarks_list) + + face_blendshapes_results = [] + if _BLENDSHAPES_STREAM_NAME in output_packets: + face_blendshapes_proto_list = packet_getter.get_proto_list( + output_packets[_BLENDSHAPES_STREAM_NAME] + ) + for proto in face_blendshapes_proto_list: + face_blendshapes_categories = [] + face_blendshapes_classifications = classification_pb2.ClassificationList() + face_blendshapes_classifications.MergeFrom(proto) + for face_blendshapes in face_blendshapes_classifications.classification: + face_blendshapes_categories.append( + category_module.Category( + index=face_blendshapes.index, + score=face_blendshapes.score, + display_name=face_blendshapes.display_name, + category_name=face_blendshapes.label, + ) + ) + face_blendshapes_results.append(face_blendshapes_categories) + + facial_transformation_matrixes_results = [] + if _FACE_GEOMETRY_STREAM_NAME in output_packets: + facial_transformation_matrixes_proto_list = packet_getter.get_proto_list( + output_packets[_FACE_GEOMETRY_STREAM_NAME] + ) + for proto in facial_transformation_matrixes_proto_list: + if hasattr(proto, 'pose_transform_matrix'): + matrix_data = matrix_data_pb2.MatrixData() + matrix_data.MergeFrom(proto.pose_transform_matrix) + matrix = np.array(matrix_data.packed_data) + matrix = matrix.reshape((matrix_data.rows, matrix_data.cols)) + matrix = ( + matrix if matrix_data.layout == _LayoutEnum.ROW_MAJOR else matrix.T + ) + facial_transformation_matrixes_results.append(matrix) + + return FaceLandmarkerResult( + face_landmarks_results, + face_blendshapes_results, + facial_transformation_matrixes_results, + ) + +def _build_landmarker_result2( + output_packets: Mapping[str, packet_module.Packet] +) -> FaceLandmarkerResult: + """Constructs a `FaceLandmarkerResult` from output packets.""" + face_landmarks_proto_list = packet_getter.get_proto_list( + output_packets[_NORM_LANDMARKS_STREAM_NAME] + ) + + face_landmarks_results = [] + for proto in face_landmarks_proto_list: + face_landmarks = landmark_pb2.NormalizedLandmarkList() + face_landmarks.MergeFrom(proto) + face_landmarks_list = [] + for face_landmark in face_landmarks.landmark: + face_landmarks_list.append( + landmark_module.NormalizedLandmark.create_from_pb2(face_landmark) + ) + face_landmarks_results.append(face_landmarks_list) + + face_blendshapes_results = [] + if _BLENDSHAPES_STREAM_NAME in output_packets: + face_blendshapes_proto_list = packet_getter.get_proto_list( + output_packets[_BLENDSHAPES_STREAM_NAME] + ) + for proto in face_blendshapes_proto_list: + face_blendshapes_categories = [] + face_blendshapes_classifications = classification_pb2.ClassificationList() + face_blendshapes_classifications.MergeFrom(proto) + for face_blendshapes in face_blendshapes_classifications.classification: + face_blendshapes_categories.append( + category_module.Category( + index=face_blendshapes.index, + score=face_blendshapes.score, + display_name=face_blendshapes.display_name, + category_name=face_blendshapes.label, + ) + ) + face_blendshapes_results.append(face_blendshapes_categories) + + facial_transformation_matrixes_results = [] + if _FACE_GEOMETRY_STREAM_NAME in output_packets: + facial_transformation_matrixes_proto_list = packet_getter.get_proto_list( + output_packets[_FACE_GEOMETRY_STREAM_NAME] + ) + for proto in facial_transformation_matrixes_proto_list: + if hasattr(proto, 'pose_transform_matrix'): + matrix_data = matrix_data_pb2.MatrixData() + matrix_data.MergeFrom(proto.pose_transform_matrix) + matrix = np.array(matrix_data.packed_data) + matrix = matrix.reshape((matrix_data.rows, matrix_data.cols)) + matrix = ( + matrix if matrix_data.layout == _LayoutEnum.ROW_MAJOR else matrix.T + ) + facial_transformation_matrixes_results.append(matrix) + + return FaceLandmarkerResult( + face_landmarks_results, + face_blendshapes_results, + facial_transformation_matrixes_results, + ), facial_transformation_matrixes_proto_list[0].mesh + +@dataclasses.dataclass +class FaceLandmarkerOptions: + """Options for the face landmarker task. + + Attributes: + base_options: Base options for the face landmarker task. + running_mode: The running mode of the task. Default to the image mode. + FaceLandmarker has three running modes: 1) The image mode for detecting + face landmarks on single image inputs. 2) The video mode for detecting + face landmarks on the decoded frames of a video. 3) The live stream mode + for detecting face landmarks on the live stream of input data, such as + from camera. In this mode, the "result_callback" below must be specified + to receive the detection results asynchronously. + num_faces: The maximum number of faces that can be detected by the + FaceLandmarker. + min_face_detection_confidence: The minimum confidence score for the face + detection to be considered successful. + min_face_presence_confidence: The minimum confidence score of face presence + score in the face landmark detection. + min_tracking_confidence: The minimum confidence score for the face tracking + to be considered successful. + output_face_blendshapes: Whether FaceLandmarker outputs face blendshapes + classification. Face blendshapes are used for rendering the 3D face model. + output_facial_transformation_matrixes: Whether FaceLandmarker outputs facial + transformation_matrix. Facial transformation matrix is used to transform + the face landmarks in canonical face to the detected face, so that users + can apply face effects on the detected landmarks. + result_callback: The user-defined result callback for processing live stream + data. The result callback should only be specified when the running mode + is set to the live stream mode. + """ + + base_options: _BaseOptions + running_mode: _RunningMode = _RunningMode.IMAGE + num_faces: int = 1 + min_face_detection_confidence: float = 0.5 + min_face_presence_confidence: float = 0.5 + min_tracking_confidence: float = 0.5 + output_face_blendshapes: bool = False + output_facial_transformation_matrixes: bool = False + result_callback: Optional[ + Callable[[FaceLandmarkerResult, image_module.Image, int], None] + ] = None + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _FaceLandmarkerGraphOptionsProto: + """Generates an FaceLandmarkerGraphOptions protobuf object.""" + base_options_proto = self.base_options.to_pb2() + base_options_proto.use_stream_mode = ( + False if self.running_mode == _RunningMode.IMAGE else True + ) + + # Initialize the face landmarker options from base options. + face_landmarker_options_proto = _FaceLandmarkerGraphOptionsProto( + base_options=base_options_proto + ) + + # Configure face detector options. + face_landmarker_options_proto.face_detector_graph_options.num_faces = ( + self.num_faces + ) + face_landmarker_options_proto.face_detector_graph_options.min_detection_confidence = ( + self.min_face_detection_confidence + ) + + # Configure face landmark detector options. + face_landmarker_options_proto.min_tracking_confidence = ( + self.min_tracking_confidence + ) + face_landmarker_options_proto.face_landmarks_detector_graph_options.min_detection_confidence = ( + self.min_face_detection_confidence + ) + return face_landmarker_options_proto + + +class FaceLandmarker(base_vision_task_api.BaseVisionTaskApi): + """Class that performs face landmarks detection on images.""" + + @classmethod + def create_from_model_path(cls, model_path: str) -> 'FaceLandmarker': + """Creates an `FaceLandmarker` object from a TensorFlow Lite model and the default `FaceLandmarkerOptions`. + + Note that the created `FaceLandmarker` instance is in image mode, for + detecting face landmarks on single image inputs. + + Args: + model_path: Path to the model. + + Returns: + `FaceLandmarker` object that's created from the model file and the + default `FaceLandmarkerOptions`. + + Raises: + ValueError: If failed to create `FaceLandmarker` object from the + provided file such as invalid file path. + RuntimeError: If other types of error occurred. + """ + base_options = _BaseOptions(model_asset_path=model_path) + options = FaceLandmarkerOptions( + base_options=base_options, running_mode=_RunningMode.IMAGE + ) + return cls.create_from_options(options) + + @classmethod + def create_from_options( + cls, options: FaceLandmarkerOptions + ) -> 'FaceLandmarker': + """Creates the `FaceLandmarker` object from face landmarker options. + + Args: + options: Options for the face landmarker task. + + Returns: + `FaceLandmarker` object that's created from `options`. + + Raises: + ValueError: If failed to create `FaceLandmarker` object from + `FaceLandmarkerOptions` such as missing the model. + RuntimeError: If other types of error occurred. + """ + + def packets_callback(output_packets: Mapping[str, packet_module.Packet]): + if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty(): + return + + image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) + if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty(): + return + + if output_packets[_NORM_LANDMARKS_STREAM_NAME].is_empty(): + empty_packet = output_packets[_NORM_LANDMARKS_STREAM_NAME] + options.result_callback( + FaceLandmarkerResult([], [], []), + image, + empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND, + ) + return + + face_landmarks_result = _build_landmarker_result(output_packets) + timestamp = output_packets[_NORM_LANDMARKS_STREAM_NAME].timestamp + options.result_callback( + face_landmarks_result, + image, + timestamp.value // _MICRO_SECONDS_PER_MILLISECOND, + ) + + output_streams = [ + ':'.join([_NORM_LANDMARKS_TAG, _NORM_LANDMARKS_STREAM_NAME]), + ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]), + ] + + if options.output_face_blendshapes: + output_streams.append( + ':'.join([_BLENDSHAPES_TAG, _BLENDSHAPES_STREAM_NAME]) + ) + if options.output_facial_transformation_matrixes: + output_streams.append( + ':'.join([_FACE_GEOMETRY_TAG, _FACE_GEOMETRY_STREAM_NAME]) + ) + + task_info = _TaskInfo( + task_graph=_TASK_GRAPH_NAME, + input_streams=[ + ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]), + ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]), + ], + output_streams=output_streams, + task_options=options, + ) + return cls( + task_info.generate_graph_config( + enable_flow_limiting=options.running_mode + == _RunningMode.LIVE_STREAM + ), + options.running_mode, + packets_callback if options.result_callback else None, + ) + + def detect( + self, + image: image_module.Image, + image_processing_options: Optional[_ImageProcessingOptions] = None, + ) -> FaceLandmarkerResult: + """Performs face landmarks detection on the given image. + + Only use this method when the FaceLandmarker is created with the image + running mode. + + The image can be of any size with format RGB or RGBA. + TODO: Describes how the input image will be preprocessed after the yuv + support is implemented. + + Args: + image: MediaPipe Image. + image_processing_options: Options for image processing. + + Returns: + The face landmarks detection results. + + Raises: + ValueError: If any of the input arguments is invalid. + RuntimeError: If face landmarker detection failed to run. + """ + + normalized_rect = self.convert_to_normalized_rect( + image_processing_options, image, roi_allowed=False + ) + output_packets = self._process_image_data({ + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image), + _NORM_RECT_STREAM_NAME: packet_creator.create_proto( + normalized_rect.to_pb2() + ), + }) + + if output_packets[_NORM_LANDMARKS_STREAM_NAME].is_empty(): + return FaceLandmarkerResult([], [], []) + + return _build_landmarker_result2(output_packets) + + def detect_for_video( + self, + image: image_module.Image, + timestamp_ms: int, + image_processing_options: Optional[_ImageProcessingOptions] = None, + ): + """Performs face landmarks detection on the provided video frame. + + Only use this method when the FaceLandmarker is created with the video + running mode. + + Only use this method when the FaceLandmarker is created with the video + running mode. It's required to provide the video frame's timestamp (in + milliseconds) along with the video frame. The input timestamps should be + monotonically increasing for adjacent calls of this method. + + Args: + image: MediaPipe Image. + timestamp_ms: The timestamp of the input video frame in milliseconds. + image_processing_options: Options for image processing. + + Returns: + The face landmarks detection results. + + Raises: + ValueError: If any of the input arguments is invalid. + RuntimeError: If face landmarker detection failed to run. + """ + normalized_rect = self.convert_to_normalized_rect( + image_processing_options, image, roi_allowed=False + ) + output_packets = self._process_video_data({ + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND + ), + _NORM_RECT_STREAM_NAME: packet_creator.create_proto( + normalized_rect.to_pb2() + ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), + }) + + if output_packets[_NORM_LANDMARKS_STREAM_NAME].is_empty(): + return FaceLandmarkerResult([], [], []) + + return _build_landmarker_result2(output_packets) + + def detect_async( + self, + image: image_module.Image, + timestamp_ms: int, + image_processing_options: Optional[_ImageProcessingOptions] = None, + ) -> None: + """Sends live image data to perform face landmarks detection. + + The results will be available via the "result_callback" provided in the + FaceLandmarkerOptions. Only use this method when the FaceLandmarker is + created with the live stream running mode. + + Only use this method when the FaceLandmarker is created with the live + stream running mode. The input timestamps should be monotonically increasing + for adjacent calls of this method. This method will return immediately after + the input image is accepted. The results will be available via the + `result_callback` provided in the `FaceLandmarkerOptions`. The + `detect_async` method is designed to process live stream data such as + camera input. To lower the overall latency, face landmarker may drop the + input images if needed. In other words, it's not guaranteed to have output + per input image. + + The `result_callback` provides: + - The face landmarks detection results. + - The input image that the face landmarker runs on. + - The input timestamp in milliseconds. + + Args: + image: MediaPipe Image. + timestamp_ms: The timestamp of the input image in milliseconds. + image_processing_options: Options for image processing. + + Raises: + ValueError: If the current input timestamp is smaller than what the + face landmarker has already processed. + """ + normalized_rect = self.convert_to_normalized_rect( + image_processing_options, image, roi_allowed=False + ) + self._send_live_stream_data({ + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND + ), + _NORM_RECT_STREAM_NAME: packet_creator.create_proto( + normalized_rect.to_pb2() + ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), + }) \ No newline at end of file diff --git a/lmk_util/lmk_extractor.py b/lmk_util/lmk_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..0d20081aacf7226cb0d14545a928e9956f406896 --- /dev/null +++ b/lmk_util/lmk_extractor.py @@ -0,0 +1,336 @@ +if __name__=='__main__': + import sys,os; cur_dir = os.path.dirname(os.path.abspath(__file__)) + sys.path.append(os.path.abspath(os.path.join(cur_dir, '..'))) +from util_and_constant import * +from pathlib import Path +import cv2 +import numpy as np +from typing import Union, List, Optional, Dict, Any +from natsort import natsorted +import glob +from lmk_util.mp_utils import LMKExtractor +from lmk_util.draw_utils import FaceMeshVisualizer +from PIL import Image +from skimage.io import imsave +import torch + + +def pil_to_cv2(pil_img): + """Convert PIL image to OpenCV format.""" + return cv2.cvtColor(np.array(pil_img).astype(np.uint8), cv2.COLOR_RGB2BGR) +def cv2_to_pil(cv2_img): + """Convert OpenCV image to PIL format.""" + return Image.fromarray(cv2.cvtColor(cv2_img, cv2.COLOR_BGR2RGB).astype(np.uint8)) + +class LandmarkExtractor: + """ + A wrapper class for face landmark extraction. + + This class provides an interface to extract facial landmarks from images. + """ + + def __init__(self, include_visualizer: bool=False, fps: int = 25, **kw_of_vis): + """ + Initialize the landmark extractor. + + Args: + fps: Frames per second for video processing (default: 25) + """ + self.lmk_extractor = LMKExtractor(FPS=fps) + if include_visualizer: + self.visualizer = LandmarkVisualizer(**kw_of_vis) + + def extract_single(self, image: np.ndarray, only_main_lmk = False) -> np.ndarray: + """ + Extract landmarks from a single image. + + Args: + image: Input image as numpy array (h, w, 3) + + Returns: + Landmarks as numpy array (N, 2) with absolute coordinates, or None if detection failed + """ + if 0: + save_dir = Path("4debug/LandmarkExtractor-extract_single") + save_dir.mkdir(parents=0, exist_ok=True) + save_path = save_dir / f"{str_t_pid()}.jpg" + imsave(str(save_path), image) + print(f"{save_path=}") + # Extract landmarks + result = self.lmk_extractor(image) + + if result is None: + if 0: + save_dir = Path("4debug/LandmarkExtractor-extract_single--no-result") + save_dir.mkdir(parents=0, exist_ok=True) + save_path = save_dir / f"{str_t_pid()}.jpg" + imsave(str(save_path), image) + print(f"Landmark not detected: {save_path}") + # assert 0, (image.shape, save_path, image) + return None + + # Extract 2D landmarks and convert to absolute coordinates + lmks = result["lmks"] # Shape: (num_landmarks, 3), normalized coordinates + h, w = image.shape[:2] + + # Convert normalized coordinates to absolute pixel coordinates + absolute_coords = lmks[:, :2] * [w, h] # (N, 2) + if only_main_lmk: + absolute_coords = lmkAll_2_lmkMain(absolute_coords) + + return absolute_coords + +class LandmarkVisualizer: + """ + A class for visualizing facial landmarks on images. + """ + + def __init__(self,img_256_mode=True): + """Initialize the landmark visualizer.""" + self.visualizer = FaceMeshVisualizer( + draw_iris=False, + draw_mouse=True, + draw_eye=True, + draw_nose=True, + draw_eyebrow=True, + draw_pupil=True, + line_thickness=2, + img_256_mode=img_256_mode, + ) + + def visualize_landmarks(self, image: np.ndarray, landmarks: np.ndarray, + target_size: tuple = (512, 512), use_connections: bool = True) -> np.ndarray: + """ + Visualize landmarks on an image. + + Args: + image: Input image as numpy array (h, w, 3) + landmarks: Landmark coordinates as numpy array (N, 2) with absolute coordinates + target_size: Target image size for visualization + use_connections: Whether to use MediaPipe connections (only works with 468+ landmarks) + + Returns: + Image with landmarks drawn as numpy array (BGR format) + """ + image = image.copy() + img_resized = cv2.resize(image, target_size) + + if use_connections and landmarks.shape[0] >= 468: + # Use original MediaPipe visualizer with connections + # Convert absolute coordinates to normalized coordinates for visualizer + h, w = target_size + normalized_lmks = landmarks / [image.shape[1], image.shape[0]] # Normalize by original image size + + # Add dummy z coordinate + lmks_3d = np.column_stack([normalized_lmks, np.zeros(len(normalized_lmks))]) # (N, 3) + + # Draw landmarks + if 0: + landmark_img = self.visualizer.draw_landmarks(target_size, lmks_3d, normed=True) + combined = (img_resized * 0.5 + landmark_img * 0.5).clip(0, 255).astype(np.uint8) + else: + combined = self.visualizer.draw_landmarks(target_size, lmks_3d, normed=True, image=img_resized, ) + else: + # Draw simple points for smaller landmark sets + combined = img_resized.copy() + + # Convert coordinates to target size + scale_x = target_size[0] / image.shape[1] + scale_y = target_size[1] / image.shape[0] + scaled_landmarks = landmarks * [scale_x, scale_y] + + # Draw each landmark as a colored circle + for i, (x, y) in enumerate(scaled_landmarks): + x, y = int(x), int(y) + # Use different colors for different regions + if 0 <= x < target_size[0] and 0 <= y < target_size[1]: + cv2.circle(combined, (x, y), 2, (255, 0 , 0), -1) # red/blue (depends on RGB/BGR) points + + return combined + + def save_landmark_visualization(self, image: np.ndarray, landmarks: np.ndarray, + output_path: Union[str, Path], + target_size: tuple = (512, 512), use_connections: bool = True) -> None: + """ + Save landmark visualization to file. + + Args: + image: Input image as numpy array (h, w, 3) + landmarks: Landmark coordinates as numpy array (N, 2) with absolute coordinates + output_path: Output file path + target_size: Target image size for visualization + use_connections: Whether to use MediaPipe connections (only works with 468+ landmarks) + """ + vis_img = self.visualize_landmarks(image, landmarks, target_size, use_connections) + imsave(str(output_path), vis_img) + print(f"Saved visualization to: {output_path}") +from functools import lru_cache +@lru_cache(maxsize=None) +def get_lmkMain_indices(include_face_oval: bool, return_tensor: bool = False): + # Main landmark indices based on MediaPipe face mesh + # These indices are from FaceMeshVisualizer connections + + # Left eye landmarks (based on FACEMESH_LEFT_EYE connections) + left_eye_indices = [ + 263, 249, 390, 373, 374, 380, 381, 382, 362, # outer contour + 466, 388, 387, 386, 385, 384, 398 # inner contour + ] + + # Right eye landmarks (based on FACEMESH_RIGHT_EYE connections) + right_eye_indices = [ + 33, 7, 163, 144, 145, 153, 154, 155, 133, # outer contour + 246, 161, 160, 159, 158, 157, 173 # inner contour + ] + + # Left eyebrow landmarks (based on FACEMESH_LEFT_EYEBROW connections) + left_eyebrow_indices = [276, 283, 282, 295, 285, 300, 293, 334, 296, 336] + + # Right eyebrow landmarks (based on FACEMESH_RIGHT_EYEBROW connections) + right_eyebrow_indices = [46, 53, 52, 65, 55, 70, 63, 105, 66, 107] + + # Lip landmarks (based on LIPS definition in draw_utils.py) + lips_outer_bottom_left = [61, 146, 91, 181, 84, 17] + lips_outer_bottom_right = [17, 314, 405, 321, 375, 291] + lips_inner_bottom_left = [78, 95, 88, 178, 87, 14] + lips_inner_bottom_right = [14, 317, 402, 318, 324, 308] + lips_outer_top_left = [61, 185, 40, 39, 37, 0] + lips_outer_top_right = [0, 267, 269, 270, 409, 291] + lips_inner_top_left = [78, 191, 80, 81, 82, 13] + lips_inner_top_right = [13, 312, 311, 310, 415, 308] + + # Nose landmarks + nose_indices = [4] # nose tip, defined in draw_utils.py nose_landmark_spec + + # Pupil landmarks (gaze landmarks) + pupil_indices = [468, 473] # 468: right iris center, 473: left iris center + + # Face contour landmarks (based on MediaPipe FACEMESH_FACE_OVAL) + face_oval_indices = [10, 21, 54, 58, 67, 93, 103, 109, 127, 132, 136, 148, 149, 150, 152, 162, 172, 176, + 234, 251, 284, 288, 297, 323, 332, 338, 356, 361, 365, 377, 378, 379, 389, 397, 400, 454] + + # Merge all main landmark indices + main_indices = set() + + # Add eye landmarks + main_indices.update(left_eye_indices) + main_indices.update(right_eye_indices) + + # Add eyebrow landmarks + main_indices.update(left_eyebrow_indices) + main_indices.update(right_eyebrow_indices) + + # Add lip landmarks + main_indices.update(lips_outer_bottom_left) + main_indices.update(lips_outer_bottom_right) + main_indices.update(lips_inner_bottom_left) + main_indices.update(lips_inner_bottom_right) + main_indices.update(lips_outer_top_left) + main_indices.update(lips_outer_top_right) + main_indices.update(lips_inner_top_left) + main_indices.update(lips_inner_top_right) + + # Add nose landmarks + main_indices.update(nose_indices) + + # Add pupil landmarks (gaze landmarks) + main_indices.update(pupil_indices) + + # Add face contour landmarks if requested + if include_face_oval: + main_indices.update(face_oval_indices) + + indices = sorted(main_indices) + if return_tensor: + return torch.as_tensor(indices, dtype=torch.long) + return indices +def lmkAll_2_lmkMain(lmks468or478: np.ndarray, include_face_oval: bool = False) -> np.ndarray: + """ + Convert 468/478 landmarks to a main landmark set. + Based on MediaPipe visualization, extract key landmarks for eyes, eyebrows, lips, nose, pupils, etc. + + Args: + lmks468or478: 468/478 landmark coordinates, shape (468, 2) + include_face_oval: whether to include face contour landmarks (default: False) + + Returns: + Main landmark coordinates, shape (N, 2) + """ + if len(lmks468or478)<473: + raise Exception(lmks468or478.shape) + + main_indices = get_lmkMain_indices(include_face_oval) + # Filter indices out of range (e.g., iris points 468, 473 exist only in refined mode) + valid_indices = [idx for idx in main_indices if idx < lmks468or478.shape[0]] + + # Sort by index for consistency + valid_indices = sorted(valid_indices) + + # Extract the corresponding landmarks + main_landmarks = lmks468or478[valid_indices] + + return main_landmarks +if __name__=='__main__': + """Test the landmark extractor functionality.""" + print("Testing LandmarkExtractor...") + + # Initialize extractor and visualizer + extractor = LandmarkExtractor() + visualizer = LandmarkVisualizer() + + + if not test_images: + print(f"No test images found in {test_img_dir}") + exit(0) + + print(f"Found {len(test_images)} test images") + + # Create output directory + output_dir = Path("4debug/landmark_test_output") + output_dir.mkdir(exist_ok=True) + print(f"{output_dir=}") + + # Test single image extraction + print("\n=== Testing single image extraction ===") + for i, img_path in enumerate(test_images): + print(f"Processing {img_path}") + + # Load image + img_cv2 = cv2.imread(str(img_path)) + img_cv2 = cv2.cvtColor(img_cv2, cv2.COLOR_BGR2RGB) + if img_cv2 is None: + print(f" ❌ Could not load image from {img_path}") + continue + + # Extract landmarks + landmarks = extractor.extract_single(img_cv2) + print(f"{landmarks[:4]=}") + + if landmarks is None: + print(f" ❌ Failed to extract landmarks from {img_path}") + continue + + print(f" ✅ Extracted {landmarks.shape[0]} landmarks") + print(f" Landmark shape: {landmarks.shape}") # (N, 2) + + # Test main landmark extraction (without face contour) + main_landmarks = lmkAll_2_lmkMain(landmarks, include_face_oval=False) + print(f"Extracted {main_landmarks.shape[0]} main landmarks (without face oval) from {landmarks.shape[0]} total landmarks") + + # Test main landmark extraction (with face contour) + main_landmarks_with_oval = lmkAll_2_lmkMain(landmarks, include_face_oval=True) + print(f"Extracted {main_landmarks_with_oval.shape[0]} main landmarks (with face oval) from {landmarks.shape[0]} total landmarks") + + # Visualize and save original landmarks + output_path = output_dir / f"landmark_vis_{i+1}_{Path(img_path).stem}_all468.jpg" + visualizer.save_landmark_visualization(img_cv2, landmarks, output_path) + print(f" 📁 Saved all 468 landmarks visualization to {output_path}") + + # Visualize and save main landmarks only (use simple points, not connections) + output_path_main = output_dir / f"landmark_vis_{i+1}_{Path(img_path).stem}_main.jpg" + visualizer.save_landmark_visualization(img_cv2, main_landmarks, output_path_main, use_connections=False) + print(f" 📁 Saved main landmarks (without face oval) visualization to {output_path_main}") + + # Visualize and save main landmarks with face oval + output_path_main_oval = output_dir / f"landmark_vis_{i+1}_{Path(img_path).stem}_main_with_oval.jpg" + visualizer.save_landmark_visualization(img_cv2, main_landmarks_with_oval, output_path_main_oval, use_connections=False) + print(f" 📁 Saved main landmarks (with face oval) visualization to {output_path_main_oval}") diff --git a/lmk_util/mp_utils.py b/lmk_util/mp_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2fde0ce03d439da590114de692b0633066650eab --- /dev/null +++ b/lmk_util/mp_utils.py @@ -0,0 +1,95 @@ +import os +import numpy as np +import cv2 +import time +from tqdm import tqdm +import multiprocessing +import glob + +import mediapipe as mp +from mediapipe import solutions +from mediapipe.framework.formats import landmark_pb2 +from mediapipe.tasks import python +from mediapipe.tasks.python import vision +from . import face_landmark + +_CUR_DIR = 'Other_dependencies' + + +class LMKExtractor(): + def __init__(self, FPS=25): + # Create an FaceLandmarker object. + self.mode = mp.tasks.vision.FaceDetectorOptions.running_mode.IMAGE + base_options = python.BaseOptions(model_asset_path=os.path.join(_CUR_DIR, 'mp_models/face_landmarker_v2_with_blendshapes.task')) + base_options.delegate = mp.tasks.BaseOptions.Delegate.CPU + options = vision.FaceLandmarkerOptions(base_options=base_options, + running_mode=self.mode, + output_face_blendshapes=True, + output_facial_transformation_matrixes=True, + num_faces=1) + self.detector = face_landmark.FaceLandmarker.create_from_options(options) + self.last_ts = 0 + self.frame_ms = int(1000 / FPS) + + det_base_options = python.BaseOptions(model_asset_path=os.path.join(_CUR_DIR, 'mp_models/blaze_face_short_range.tflite')) + det_options = vision.FaceDetectorOptions(base_options=det_base_options) + self.det_detector = vision.FaceDetector.create_from_options(det_options) + + + def __call__(self, img): + frame = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + image = mp.Image(image_format=mp.ImageFormat.SRGB, data=frame) + t0 = time.time() + if self.mode == mp.tasks.vision.FaceDetectorOptions.running_mode.VIDEO: + det_result = self.det_detector.detect(image) + if len(det_result.detections) != 1: + return None + self.last_ts += self.frame_ms + try: + detection_result, mesh3d = self.detector.detect_for_video(image, timestamp_ms=self.last_ts) + except: + return None + elif self.mode == mp.tasks.vision.FaceDetectorOptions.running_mode.IMAGE: + # det_result = self.det_detector.detect(image) + + # if len(det_result.detections) != 1: + # return None + try: + detection_result, mesh3d = self.detector.detect(image) + except: + return None + + + bs_list = detection_result.face_blendshapes + if len(bs_list) == 1: + bs = bs_list[0] + bs_values = [] + for index in range(len(bs)): + bs_values.append(bs[index].score) + bs_values = bs_values[1:] # remove neutral + trans_mat = detection_result.facial_transformation_matrixes[0] + face_landmarks_list = detection_result.face_landmarks + face_landmarks = face_landmarks_list[0] + lmks = [] + for index in range(len(face_landmarks)): + x = face_landmarks[index].x + y = face_landmarks[index].y + z = face_landmarks[index].z + lmks.append([x, y, z]) + lmks = np.array(lmks) + + lmks3d = np.array(mesh3d.vertex_buffer) + lmks3d = lmks3d.reshape(-1, 5)[:, :3] + mp_tris = np.array(mesh3d.index_buffer).reshape(-1, 3) + 1 + + return { + "lmks": lmks, + 'lmks3d': lmks3d, + "trans_mat": trans_mat, + 'faces': mp_tris, + "bs": bs_values + } + else: + # print('multiple faces in the image: {}'.format(img_path)) + return None + \ No newline at end of file diff --git a/lora_layers.py b/lora_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..714dff440b329b173d5aa05eb88b5fd0911655db --- /dev/null +++ b/lora_layers.py @@ -0,0 +1,541 @@ +""" +LoRA (Low-Rank Adaptation) implementation for MLP layers. +Replaces qkv projections in attention and the FFN MLP layers. +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +USE_LORA :bool = 1 # enable LoRA replacement for MLP layers (and Conv2d) +if USE_LORA: + LORA_dropout :float = 0.0 # LoRA dropout rate + LORA_apply_to_conv :bool = 1 # also apply LoRA to Conv2d layers + LORA_freeze_base :bool = False + LORA_DEBUG :bool = 0 + FORCE_SAME_RANK_ACROSS_TASKS :bool = 0 + DONT_lora_if_dim_lt :int = 90 # 0: disable. increase for low-dim layers (e.g., in/out conv dim < 32) + DONT_lora_if_rankFrac_gt :float = 0.3 + +class LoRALinear(nn.Module): + """ + LoRA layer that wraps a frozen Linear layer with low-rank adaptation. + + Args: + original_linear: original nn.Linear layer that will be frozen + rank: LoRA rank (r) + dropout: dropout probability + """ + def __init__( + self, + original_linear: nn.Linear, + rank: int = 4, + dropout: float = 0.0, + freeze_base: bool = True, + ): + super().__init__() + + self.in_features = original_linear.in_features + self.out_features = original_linear.out_features + self.rank = rank + self.scaling = 2.0 + + # Freeze the original weights + self.original_linear = original_linear + if freeze_base: + for param in self.original_linear.parameters(): + param.requires_grad = False + + # LoRA low-rank decomposition: W = W_0 + B @ A, where B: out_features x rank, A: rank x in_features + self.lora_A = nn.Parameter(torch.zeros(rank, self.in_features)) + self.lora_B = nn.Parameter(torch.zeros(self.out_features, rank)) + + # Initialization + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) # initialize B to 0 so LoRA has no initial effect + + # Dropout + self.dropout = nn.Dropout(p=dropout) if dropout > 0 else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Output from the frozen original linear layer + result = self.original_linear(x) + + # LoRA low-rank update: x @ A^T @ B^T + # x: (..., in_features) + # lora_A: (rank, in_features) -> A^T: (in_features, rank) + # lora_B: (out_features, rank) -> B^T: (rank, out_features) + lora_out = self.dropout(x) @ self.lora_A.T @ self.lora_B.T + + return result + lora_out * self.scaling + + def __repr__(self): + return f"LoRALinear(in_features={self.in_features}, out_features={self.out_features}, rank={self.rank}, scaling={self.scaling})" + + +def replace_linear_with_lora( + module: nn.Module, + rank: int = 4, + dropout: float = 0.0, + target_modules: list = None, + verbose: bool = True, +): + """ + Recursively replace nn.Linear layers within a module with LoRALinear wrappers. + + Args: + module: module whose linear layers should be replaced + rank: LoRA rank + dropout: dropout probability + target_modules: specific module names to replace; None means all linears + e.g.: ['to_q', 'to_k', 'to_v', 'to_out'] for attention + ['net.0', 'net.2'] for FeedForward + verbose: whether to log replacements + + Returns: + the module with replacements applied + """ + replaced_count = 0 + + for name, child in module.named_children(): + # Skip modules not in the target list (if filtering is enabled) + if target_modules is not None and name not in target_modules: + # Continue recursing into child modules + replace_linear_with_lora(child, rank, dropout, target_modules, verbose) + continue + + if isinstance(child, nn.Linear): + # Replace with LoRALinear + lora_layer = LoRALinear(child, rank=rank, dropout=dropout, freeze_base=LORA_freeze_base) + setattr(module, name, lora_layer) + replaced_count += 1 + if verbose: + print(f"[LoRA] Replaced {name}: {child.in_features} -> {child.out_features} with rank={rank}") + elif isinstance(child, nn.Sequential): + # Handle Sequential containers (e.g., FeedForward nets) + new_sequential = nn.Sequential() + for idx, submodule in enumerate(child): + if isinstance(submodule, nn.Linear): + lora_layer = LoRALinear(submodule, rank=rank, dropout=dropout, freeze_base=LORA_freeze_base) + new_sequential.add_module(str(idx), lora_layer) + replaced_count += 1 + if verbose: + print(f"[LoRA] Replaced {name}.{idx}: {submodule.in_features} -> {submodule.out_features} with rank={rank}") + else: + new_sequential.add_module(str(idx), submodule) + setattr(module, name, new_sequential) + else: + # Recurse into the remaining submodules + replace_linear_with_lora(child, rank, dropout, target_modules, verbose) + + return module + + +def count_lora_parameters(module: nn.Module): + """ + Count LoRA parameters within a module. + + Returns: + dict: {'trainable': trainable params, 'frozen': frozen params, 'total': total params} + """ + trainable_params = 0 + frozen_params = 0 + + for name, param in module.named_parameters(): + num_params = param.numel() + if param.requires_grad: + trainable_params += num_params + else: + frozen_params += num_params + + total_params = trainable_params + frozen_params + + return { + 'trainable': trainable_params, + 'frozen': frozen_params, + 'total': total_params, + 'trainable_ratio': trainable_params / total_params if total_params > 0 else 0, + } + + +def print_lora_parameters(module: nn.Module, name: str = "Model"): + """Print LoRA parameter statistics.""" + stats = count_lora_parameters(module) + print(f"\n{'='*60}") + print(f"{name} Parameter Statistics:") + print(f"{'='*60}") + print(f"Trainable params: {stats['trainable']:,} ({stats['trainable_ratio']*100:.2f}%)") + print(f"Frozen params: {stats['frozen']:,} ({(1-stats['trainable_ratio'])*100:.2f}%)") + print(f"Total params: {stats['total']:,}") + print(f"{'='*60}\n") + + +class LoRAConv2d(nn.Module): + """ + LoRA layer for Conv2d. + + Treat Conv2d as a matrix multiplication: + - flatten kernel: (out_channels, in_channels, k, k) -> (out_channels, in_channels*k*k) + - apply low-rank decomposition: W = W_0 + B @ A + + Args: + original_conv: original nn.Conv2d layer that will be frozen + rank: LoRA rank (r) + dropout: dropout probability + """ + def __init__( + self, + original_conv: nn.Conv2d, + rank: int = 4, + dropout: float = 0.0, + freeze_base: bool = True, + ): + super().__init__() + + self.out_channels = original_conv.out_channels + self.in_channels = original_conv.in_channels + self.kernel_size = original_conv.kernel_size + self.stride = original_conv.stride + self.padding = original_conv.padding + self.dilation = original_conv.dilation + self.groups = original_conv.groups + + self.rank = rank + self.scaling = 2.0 + + # Freeze the original weights + self.original_conv = original_conv + if freeze_base: + for param in self.original_conv.parameters(): + param.requires_grad = False + + # LoRA low-rank decomposition + # lora_A: (rank, in_channels, kernel_size, kernel_size) + # lora_B: (out_channels, rank, 1, 1) - via 1x1 convolution + self.lora_A = nn.Parameter(torch.zeros( + rank, + self.in_channels // self.groups, + self.kernel_size[0], + self.kernel_size[1] + )) + self.lora_B = nn.Parameter(torch.zeros(self.out_channels, rank, 1, 1)) + + # Initialization + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) # initialize B to 0 + + # Dropout + self.dropout = nn.Dropout(p=dropout) if dropout > 0 else nn.Identity() + + print(f"param orig:lora (M) = {self.original_conv.weight.numel()/1024/1024}:{self.lora_A.numel()+self.lora_B.numel()/1024/1024}") + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Output from the frozen original convolution + # Use no_grad to avoid computing gradients for the base weights + result = self.original_conv(x) + + # LoRA low-rank update + # first apply lora_A (down projection) then lora_B (up projection) + x_dropped = self.dropout(x) + lora_out = F.conv2d( + x_dropped, + self.lora_A, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + groups=self.groups + ) + lora_out = F.conv2d(lora_out, self.lora_B) + + return result + lora_out * self.scaling + + def __repr__(self): + return (f"LoRAConv2d(in_channels={self.in_channels}, out_channels={self.out_channels}, " + f"kernel_size={self.kernel_size}, rank={self.rank}, scaling={self.scaling})") + + + + + +def _auto_lora_rank(in_features: int, out_features: int) -> int: + m = min(in_features, out_features) + r = max(LORA_rank_min, int(round(m / max(1.0, LORA_rank_ratio)))) + if (LORA_rank_max is not None) and (r > LORA_rank_max): + r = LORA_rank_max + return max(1, r) + +def _svd_low_rank(M: torch.Tensor, rank: int): + # M: [out, in] + orig_device = M.device + orig_dtype = M.dtype + if 1: + M = M.to(device=torch.device('cuda'), dtype=torch.float32) + U, S, Vh = torch.linalg.svd(M, full_matrices=False) + r = min(rank, U.shape[1], Vh.shape[0]) + U_r = U[:, :r] + S_r = S[:r] + Vh_r = Vh[:r, :] + S_root = torch.sqrt(torch.clamp(S_r, min=0)) + B = U_r @ torch.diag(S_root) # [out, r] + A = torch.diag(S_root) @ Vh_r # [r, in] + + B = B.to(device=orig_device, dtype=orig_dtype) + A = A.to(device=orig_device, dtype=orig_dtype) + S = S.to(device=orig_device, dtype=orig_dtype) + return B, A, S + + +def _svdvals_squared(M: torch.Tensor) -> torch.Tensor: + # Return squared singular values (energy), sorted in descending order; M: [out, in] + orig_device = M.device + orig_dtype = M.dtype + if 1: + M = M.to(device=torch.device('cuda'), dtype=torch.float32) + S = torch.linalg.svdvals(M) + S2 = (S.float() ** 2) + return S2.to(device=orig_device, dtype=torch.float32) + + +def _compute_adaptive_rank_from_S2_list( + list_S2: list, + avg_threshold: float = None, + min_threshold: float = None, + max_rank: int = None, +) -> int: + # list_S2: squared singular value vectors (descending) for each matrix + assert len(list_S2) > 0 + if avg_threshold is None: + avg_threshold = ADAPTIVE_RANK_AVG_ENERGY_THRESH + if min_threshold is None: + min_threshold = ADAPTIVE_RANK_MIN_ENERGY_THRESH + + totals = [] + lengths = [] + for s2 in list_S2: + assert s2.numel() > 0 + total = s2.sum() + # Quick fail: zero ΔW has zero energy, so thresholds can't be evaluated + assert float(total.item()) > 0.0, "Zero energy in weight_diff; cannot determine adaptive rank" + totals.append(total) + lengths.append(int(s2.shape[0])) + + R_cap = min(lengths) + if LORA_rank_max is not None: + R_cap = min(R_cap, int(LORA_rank_max)) + if max_rank is not None: + R_cap = min(R_cap, int(max_rank)) + R_cap = max(1, R_cap) + + # Iterate ranks r to see if both average and minimum energy ratios meet thresholds + for r in range(1, R_cap + 1): + ratios = [] + for s2, total in zip(list_S2, totals): + captured = s2[:r].sum() + ratios.append(float((captured / total).item())) + avg_ratio = sum(ratios) / len(ratios) + min_ratio = min(ratios) + if (avg_ratio >= avg_threshold) and (min_ratio >= min_threshold): + ret = min(int(R_cap), max(int(LORA_rank_min), int(r))) + return ret + # If no rank satisfies both thresholds, fail fast instead of silently degrading + raise AssertionError(f"No rank satisfies avg>={avg_threshold} and min>={min_threshold} up to R_cap={R_cap}") + +def _compute_per_task_ranks_from_S2_list( + list_S2: list, + min_threshold: float = None, + max_rank: int = None, +) -> list: + # Compute rank per matrix so its energy ratio >= min_threshold (uses min threshold only) + assert len(list_S2) > 0 + if min_threshold is None: + min_threshold = ADAPTIVE_RANK_MIN_ENERGY_THRESH + ret = [] + for i, s2 in enumerate(list_S2): + assert s2.numel() > 0 + total = s2.sum() + assert float(total.item()) > 0.0, "Zero energy in weight_diff; cannot determine adaptive rank" + R_cap = int(s2.shape[0]) + if LORA_rank_max is not None: + R_cap = min(R_cap, int(LORA_rank_max)) + if max_rank is not None: + R_cap = min(R_cap, int(max_rank)) + R_cap = max(1, R_cap) + found = R_cap + # Task-level threshold: when ranks are allowed to differ, use TASK_2_adaptive_rank_min_energy_thresh + thres_this = TASK_2_adaptive_rank_min_energy_thresh[i] if (not FORCE_SAME_RANK_ACROSS_TASKS) else min_threshold + for r in range(1, R_cap + 1): + ratio = s2[:r].sum() / total + if float(ratio.item()) >= float(thres_this): + found = r + break + ret.append(int(max(int(LORA_rank_min), int(found)))) + return ret + +def compute_adaptive_rank_for_linear_diffs( + weight_diffs: list, + avg_threshold: float = None, + min_threshold: float = None, + max_rank: int = None, + per_task: bool = None, +): + # weight_diffs: List[Tensor [out, in]] + assert isinstance(weight_diffs, (list, tuple)) and len(weight_diffs) > 0 + if per_task is None: + per_task = not FORCE_SAME_RANK_ACROSS_TASKS + list_S2 = [_svdvals_squared(M) for M in weight_diffs] + out0, in0 = weight_diffs[0].shape + if per_task: + ranks = _compute_per_task_ranks_from_S2_list(list_S2, min_threshold, max_rank) + print(f"[AdaptiveRank-Linear per-task] in={in0} out={out0} ranks={ranks}") + return ranks + else: + ret = _compute_adaptive_rank_from_S2_list(list_S2, None, min_threshold, max_rank) + print(f"[AdaptiveRank-Linear] in={in0} out={out0} rank={ret}") + return ret + + +def compute_adaptive_rank_for_conv_diffs( + weight_diffs: list, + avg_threshold: float = None, + min_threshold: float = None, + max_rank: int = None, + per_task: bool = None, +): + # weight_diffs: List[Tensor [out, in, kH, kW]] -> reshape to [out, in*k*k] + assert isinstance(weight_diffs, (list, tuple)) and len(weight_diffs) > 0 + if per_task is None: + per_task = not FORCE_SAME_RANK_ACROSS_TASKS + list_S2 = [] + for W in weight_diffs: + out_c, in_c, kH, kW = W.shape + M = W.reshape(out_c, in_c * kH * kW) + list_S2.append(_svdvals_squared(M)) + out0, in0, kH0, kW0 = weight_diffs[0].shape + if per_task: + ranks = _compute_per_task_ranks_from_S2_list(list_S2, min_threshold, max_rank) + print(f"[AdaptiveRank-Conv per-task] in_ch={in0} out_ch={out0} kernel=({kH0},{kW0}) ranks={ranks}") + return ranks + else: + ret = _compute_adaptive_rank_from_S2_list(list_S2, None, min_threshold, max_rank) + print(f"[AdaptiveRank-Conv] in_ch={in0} out_ch={out0} kernel=({kH0},{kW0}) rank={ret}") + return ret + + +class LoRAAdapterLinearOnly(nn.Module): + """ + Incremental LoRA (no base Linear) that returns x @ A^T @ B^T + bias_delta. + """ + def __init__(self, in_features: int, out_features: int, rank: int = None, dropout: float = 0.0, scaling: float = 1.0, use_bias_delta: bool = True): + super().__init__() + if rank is None: + rank = _auto_lora_rank(in_features, out_features) + self.in_features = in_features + self.out_features = out_features + self.rank = rank + self.scaling = scaling + self.dropout = nn.Dropout(p=dropout) if dropout > 0 else nn.Identity() + self.lora_A = nn.Parameter(torch.zeros(rank, in_features)) + self.lora_B = nn.Parameter(torch.zeros(out_features, rank)) + self.use_bias_delta = use_bias_delta + if use_bias_delta: + self.lora_bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter('lora_bias', None) + # init + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + + @torch.no_grad() + def init_from_diff(self, weight_diff: torch.Tensor, bias_diff: torch.Tensor = None): + # weight_diff: [out, in] + B, A, S = _svd_low_rank(weight_diff.float(), self.rank) + self.lora_A.copy_(A.to(self.lora_A.dtype).to(self.lora_A.device)) + self.lora_B.copy_(B.to(self.lora_B.dtype).to(self.lora_B.device)) + if self.use_bias_delta and (bias_diff is not None): + self.lora_bias.copy_(bias_diff) + if LORA_DEBUG: + energy_total = (S.float() ** 2).sum().item() + energy_top = (S[: self.rank].float() ** 2).sum().item() + energy_ratio = energy_top / max(1e-12, energy_total) + approx = (B @ A).to(weight_diff.device).to(weight_diff.dtype) + err = torch.linalg.norm((approx - weight_diff).float()).item() + base = torch.linalg.norm(weight_diff.float()).item() + rel_err = err / max(1e-12, base) + bias_norm = 0.0 if (bias_diff is None) else float(torch.linalg.norm(bias_diff.float()).item()) + print(f"[LoRA-Linear init] shape={tuple(weight_diff.shape)} rank={self.rank} energy={energy_ratio:.4f} rel_err={rel_err:.6f} bias_norm={bias_norm:.6f}") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + update = self.dropout(x) @ self.lora_A.T @ self.lora_B.T + if self.lora_bias is not None: + update = update + self.lora_bias + return update * self.scaling + + +class LoRAAdapterConv2dOnly(nn.Module): + """ + Incremental LoRA for Conv2d: convolve with A then 1x1 B, return the delta. + """ + def __init__(self, in_channels: int, out_channels: int, kernel_size: tuple, stride: tuple, padding: tuple, dilation: tuple, groups: int = 1, rank: int = None, dropout: float = 0.0, scaling: float = 1.0, use_bias_delta: bool = True): + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + if isinstance(stride, int): + stride = (stride, stride) + if isinstance(padding, int): + padding = (padding, padding) + if isinstance(dilation, int): + dilation = (dilation, dilation) + kH, kW = kernel_size + if rank is None: + # Estimate rank from the flattened in/out dimensions + rank = _auto_lora_rank(in_channels * kH * kW, out_channels) + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.rank = rank + self.scaling = scaling + self.dropout = nn.Dropout(p=dropout) if dropout > 0 else nn.Identity() + # A: [rank, in/groups, kH, kW] + self.lora_A = nn.Parameter(torch.zeros(rank, in_channels // groups, kH, kW)) + # B: [out, rank, 1, 1] + self.lora_B = nn.Parameter(torch.zeros(out_channels, rank, 1, 1)) + self.use_bias_delta = use_bias_delta + if use_bias_delta: + self.lora_bias = nn.Parameter(torch.zeros(out_channels)) + else: + self.register_parameter('lora_bias', None) + # init + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + + @torch.no_grad() + def init_from_diff(self, weight_diff: torch.Tensor, bias_diff: torch.Tensor = None): + # weight_diff: [out, in, kH, kW] + out_c, in_c, kH, kW = weight_diff.shape + M = weight_diff.reshape(out_c, in_c * kH * kW) + B, A, S = _svd_low_rank(M.float(), self.rank) # B:[out,r], A:[r,in*k*k] + A_reshaped = A.view(self.rank, in_c, kH, kW) + self.lora_A.copy_(A_reshaped) + self.lora_B.copy_(B.view(out_c, self.rank, 1, 1)) + if self.lora_bias is not None and (bias_diff is not None): + self.lora_bias.copy_(bias_diff) + if LORA_DEBUG: + energy_total = (S.float() ** 2).sum().item() + energy_top = (S[: self.rank].float() ** 2).sum().item() + energy_ratio = energy_top / max(1e-12, energy_total) + approx = (B @ A).to(M.device).to(M.dtype) + err = torch.linalg.norm((approx - M).float()).item() + base = torch.linalg.norm(M.float()).item() + rel_err = err / max(1e-12, base) + bias_norm = 0.0 if (bias_diff is None) else float(torch.linalg.norm(bias_diff.float()).item()) + print(f"[LoRA-Conv init] out_in_k=({out_c},{in_c},{kH}x{kW}) rank={self.rank} energy={energy_ratio:.4f} rel_err={rel_err:.6f} bias_norm={bias_norm:.6f}") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_d = self.dropout(x) + u = F.conv2d(x_d, self.lora_A, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups) + u = F.conv2d(u, self.lora_B) + if self.lora_bias is not None: + u = u + self.lora_bias.view(1, -1, 1, 1) + return u * self.scaling diff --git a/multiTask_model.py b/multiTask_model.py new file mode 100644 index 0000000000000000000000000000000000000000..08fdfc25a33bda16e699c5b00e0efdd2408d1d6c --- /dev/null +++ b/multiTask_model.py @@ -0,0 +1,712 @@ +from ldm.modules.attention import * +import global_ +import torch +import torch.nn as nn +import torch.nn.functional as F +from my_py_lib.torch_util import custom_repr_v3 +from imports import * +import cv2, numpy as np +from lmk_util.lmk_extractor import lmkAll_2_lmkMain, get_lmkMain_indices +from MoE import * +from lora_layers import * +import json +import copy + + + +""" +Global knobs for shared experts and routing (no argparse per user preference) +""" +NUM_SHARED_FFN = 8 +GATE_TOPK = 2 + +# Sparse MoE FFN for all FFN blocks (in addition to shared orig + task LoRA) +# default off to keep behavior unchanged; enable by setting EXTRA_MoE_enable to True +EXTRA_MoE_enable :bool = 1 +EXTRA_MoE_num_ep = 8 # number of sparse MoE experts (narrow FFN) +EXTRA_MoE_inner_divisor = 128 # each expert intermediate dim = original FFN intermediate dim * this ratio +EXTRA_MoE_topK = 2 # sparse routing selects top-k experts (k fixed to 2) +EXTRA_MoE_add_noise :bool = 1 # add random noise to routing scores for exploration +EXTRA_MoE_noise_std = 0.1 # noise strength (Gaussian standard deviation) +EXTRA_MoE_en_auxLoss :bool = 0 # compute load-balancing auxiliary loss +EXTRA_MoE_aux_coef = 1e-2 # coefficient for auxiliary loss when adding to total loss +EXTRA_MoE_routing_mode = 'sparse' # 'sparse' | 'dense' +LMK_PICK_IDX = None +NUM_lmk_pick = len(LMK_PICK_IDX) if LMK_PICK_IDX is not None else len(get_lmkMain_indices(include_face_oval=True)) +print(f"{NUM_lmk_pick=}") +IMAGE_SIZE_FOR_LMK_NORM = 512.0 + +def _log2(orig_modules, lora_modules): + """Calculate and log parameter statistics for original and LoRA modules""" + # Calculate original module stats + orig_params = sum(p.numel() for p in orig_modules.parameters()) + orig_size = sum(p.numel() * p.element_size() for p in orig_modules.parameters()) + # Calculate LoRA stats (handle both single module and tuple/list) + if isinstance(lora_modules, (list, tuple)): + lora_params = sum(p.numel() for m in lora_modules for p in m.parameters()) + lora_size = sum(p.numel() * p.element_size() for m in lora_modules for p in m.parameters()) + # Try to get rank from lora modules + ranks = [] + for m in lora_modules: + if hasattr(m, 'rank'): + ranks.append(m.rank) + if len(ranks) == 2: + rank_str = f" (rank_in={ranks[0]} rank_out={ranks[1]})" + elif len(ranks) == 1: + rank_str = f" (rank={ranks[0]})" + else: + rank_str = "" + else: + lora_params = sum(p.numel() for p in lora_modules.parameters()) + lora_size = sum(p.numel() * p.element_size() for p in lora_modules.parameters()) + # Try to get rank from lora module + if hasattr(lora_modules, 'rank'): + rank_str = f" (rank={lora_modules.rank})" + else: + rank_str = "" + msg1 = f"orig: {orig_params:,} params, {orig_size/1024/1024:.2f}MB" + msg2 = f"LoRA: {lora_params:,} params, {lora_size/1024/1024:.2f}MB{rank_str}" + for msg in [msg1, msg2]: + print(msg) + continue + with open(_verify_log_file, 'a') as f: + f.write(msg + '\n') +def _log1(msg: str): + """Print message and append to log file""" + print(msg) + return + with open(_verify_log_file, 'a') as f: + f.write(msg + '\n') + +def build_ffn_gate_input_common(x: torch.Tensor, token_pos_grid__cur, tasks: list): + """Build gate input for FFN routing (reusable across FFN classes).""" + b, n, d = x.shape + token_feat = x # token + avg_feat = x.mean(dim=1, keepdim=True).expand(-1, n, -1) # avg(all tokens) + len_task = len(tasks) # task one-hot + task_1h = x.new_zeros(b, len_task) + task_1h[:, global_.task] = 1 + task_1h = task_1h.unsqueeze(1).expand(-1, n, -1) + token_pos = token_pos_grid__cur # token-position from global_.token_pos_grid__cur + assert token_pos.shape[:2] == (b, n), (token_pos.shape, (b, n), ) + rel_flat = x.new_zeros(b, n, 2*NUM_lmk_pick) # token-relative-position to lmks + lmk = global_.lmk_ + if 1: + lmk = lmk.to(x.device).float()# TODO to check is it normed already? + if LMK_PICK_IDX is None: + assert NUM_lmk_pick==lmk.shape[1] + else: + lmk = lmk[:, LMK_PICK_IDX, :] + rel = token_pos.unsqueeze(2) - lmk.unsqueeze(1) # [b,n,L,2] + rel_flat = rel.reshape(b, n, -1) + token_pos = token_pos * 2.0 - 1.0 # [0,1] -> [-1,1] + gate_in = torch.cat([token_feat, avg_feat, task_1h, token_pos, rel_flat], dim=-1) + ctx = {'token_feat': token_feat, 'avg_feat': avg_feat, 'task_1h': task_1h, 'token_pos': token_pos, 'lmk': lmk, 'rel': rel, 'rel_flat': rel_flat} + return gate_in, ctx + +def replace_modules_lossless( + module: nn.Module, + src_modules: list, + l_task: list, + parent_name: str = "", + depth :int = 0, + for_refnet: bool = False, +): + """ + Apply policy: + - FFN: shared-plus-task (lossless upcycle) + - CrossAttention linear projections (to_q, to_k, to_v, to_out.0): shared-plus-task + - Conv2d: keep task-specific or wrap with shared-plus-task if desired + - Norms: keep task-specific (LayerNorm/GroupNorm) + """ + if depth==0: + CONV2D_PARAM_STATS.clear() + # Skip modules with no parameters + if len(list(module.parameters())) == 0: + # print(f'[replace_modules_lossless] Skipping module with no parameters: {module}') + return module + if len(list(module.named_children()))==0: + print('\n!!!! len(list(module.named_children()))==0',module) + assert 0 + for name, child in module.named_children(): + full_name = f"{parent_name}.{name}" if parent_name else f".{name}" + src_child_modules = [getattr(src_module, name) for src_module in src_modules] + if len({id(s) for s in src_child_modules}) < len(src_child_modules): + raise Exception('Duplicate source modules detected!') + # if sources are the same instance(s), clone to ensure distinct expert modules + src_child_modules = [copy.deepcopy(src_child_modules[0]) for _ in src_child_modules] + + if isinstance(child, FeedForward): + if 0: + setattr(module, name, TaskSpecific_MoE([s for s in src_child_modules], tasks=l_task)) + else: + # FFN -> shared average + per-task LoRA + setattr(module, name, upCycle_module(src_child_modules, l_task, module_name=full_name)) + continue + + if isinstance(child, CrossAttention): + # replace linear projections + # if for_refnet: + if 0: + for proj_name in ["to_q", "to_k", "to_v"]: + src_proj_list = [getattr(s, proj_name) for s in src_child_modules] + setattr(child, proj_name, upCycle_module(src_proj_list, l_task, module_name=f"{full_name}.{proj_name}")) + if hasattr(child.to_out, "__getitem__"): + src_linear0 = [s.to_out[0] for s in src_child_modules] + child.to_out[0] = upCycle_module(src_linear0, l_task, module_name=f"{full_name}.to_out.0") + else: + for proj_name in ["to_q", "to_k", "to_v"]: + src_proj_list = [getattr(s, proj_name) for s in src_child_modules] + setattr(child, proj_name, TaskSpecific_MoE([s for s in src_proj_list], tasks=l_task) ) + if hasattr(child.to_out, "__getitem__"): + src_linear0 = [s.to_out[0] for s in src_child_modules] + child.to_out[0] = TaskSpecific_MoE([s for s in src_linear0], tasks=l_task) + continue + + if isinstance(child, nn.Conv2d): + num_params = sum(p.numel() for p in child.parameters()) + CONV2D_PARAM_STATS.append((num_params, full_name)) + # if num_params > CONV2D_PARAM_MOE_THRES and (not any(full_name.startswith(p) for p in FORCE_TASKSPEC_PREFIXES)): + if 1: + printC(f"shared+LoRA Conv2d",f"{full_name}") + setattr(module, name, upCycle_module(src_child_modules, l_task, module_name=full_name)) + else: + setattr(module, name, TaskSpecific_MoE([s for s in src_child_modules], tasks=l_task)) + continue + elif isinstance(child, (nn.LayerNorm, nn.GroupNorm)): + setattr(module, name, TaskSpecific_MoE([s for s in src_child_modules], tasks=l_task)) + continue + elif isinstance(child, nn.Linear): + # default linear: task-specific + setattr(module, name, TaskSpecific_MoE([s for s in src_child_modules], tasks=l_task)) + continue + else: + replace_modules_lossless(child, src_child_modules, l_task, parent_name=full_name, depth=depth+1, for_refnet=for_refnet) + + if depth==0: + stats_sorted = sorted(CONV2D_PARAM_STATS, key=lambda x: x[0], reverse=True) + if gate_("[Conv2d param stats] count, name (sorted desc):"): + for cnt, n in stats_sorted: + print(f" {cnt:12d} {n}") + return module + +def upCycle_module(l_modules, l_task, module_name: str = None): + assert len( set( [type(m) for m in l_modules] ) ) == 1 + m0 = l_modules[0] + if isinstance(m0, FeedForward): + obj = FFN_Shared_Plus_TaskLoRA(l_modules, l_task, module_name=module_name) + elif isinstance(m0, nn.Linear): + obj = Linear_Shared_Plus_TaskLoRA(l_modules, l_task, module_name=module_name) + elif isinstance(m0, nn.Conv2d): + obj = Conv_Shared_Plus_TaskLoRA(l_modules, l_task, module_name=module_name) + else: + raise Exception(module_name,m0) + return TaskSpecific_MoE([s for s in l_modules], tasks=l_task) + if obj.dont_lora: + return TaskSpecific_MoE([s for s in l_modules], tasks=l_task) + return obj + + + + +class ResidualAdapterLinearOnly(nn.Module): + """ + Full-rank residual adapter returning the linear delta (orig - shared). + """ + def __init__(self, in_features: int, out_features: int, scaling: float = 1.0, use_bias_delta: bool = True): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.rank = min(in_features, out_features) + self.scaling = scaling + self.use_bias_delta = use_bias_delta + self.delta_weight = nn.Parameter(torch.zeros(out_features, in_features)) + if use_bias_delta: + self.delta_bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter('delta_bias', None) + @torch.no_grad() + def init_from_diff(self, weight_diff: torch.Tensor, bias_diff: torch.Tensor = None): + self.delta_weight.copy_(weight_diff) + if (self.delta_bias is not None) and (bias_diff is not None): + self.delta_bias.copy_(bias_diff) + def forward(self, x: torch.Tensor) -> torch.Tensor: + update = x @ self.delta_weight.T + if self.delta_bias is not None: + update = update + self.delta_bias + return update * self.scaling + +class ResidualAdapterConv2dOnly(nn.Module): + """ + Full-rank residual adapter for Conv2d, returning the convolutional delta (orig - shared). + """ + def __init__(self, in_channels: int, out_channels: int, kernel_size: tuple, stride: tuple, padding: tuple, dilation: tuple, groups: int = 1, scaling: float = 1.0, use_bias_delta: bool = True): + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + if isinstance(stride, int): + stride = (stride, stride) + if isinstance(padding, int): + padding = (padding, padding) + if isinstance(dilation, int): + dilation = (dilation, dilation) + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + kH, kW = kernel_size + self.rank = min(out_channels, in_channels * kH * kW) + self.scaling = scaling + self.use_bias_delta = use_bias_delta + self.delta_weight = nn.Parameter(torch.zeros(out_channels, in_channels // groups, kH, kW)) + if use_bias_delta: + self.delta_bias = nn.Parameter(torch.zeros(out_channels)) + else: + self.register_parameter('delta_bias', None) + @torch.no_grad() + def init_from_diff(self, weight_diff: torch.Tensor, bias_diff: torch.Tensor = None): + self.delta_weight.copy_(weight_diff) + if (self.delta_bias is not None) and (bias_diff is not None): + self.delta_bias.copy_(bias_diff) + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = F.conv2d(x, self.delta_weight, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups) + if self.delta_bias is not None: + u = u + self.delta_bias.view(1, -1, 1, 1) + return u * self.scaling + + + +class Linear_TaskSpecific_Plus_Shared(nn.Module): + def __init__(self, l_proj: list, l_task: list): + super().__init__() + assert len(l_proj) >= 1 + p0 = l_proj[0] + assert isinstance(p0, nn.Linear) + in_f, out_f = p0.in_features, p0.out_features + bias = p0.bias is not None + self.shared = nn.Linear(in_f, out_f, bias=bias) + self.shared = zero_module(self.shared) + self.tasks = l_task + self.task_proj = ModuleDict_W(l_proj, self.tasks) + + def forward(self, x): + t = global_.task + return self.task_proj[t](x) + self.shared(x) + + +class Conv_TaskSpecific_Plus_Shared(nn.Module): + def __init__(self, l_conv: list, l_task: list): + super().__init__() + assert len(l_conv) >= 1 + c0 = l_conv[0] + assert isinstance(c0, nn.Conv2d) + self.shared = nn.Conv2d(c0.in_channels, c0.out_channels, kernel_size=c0.kernel_size, stride=c0.stride, padding=c0.padding, dilation=c0.dilation, groups=c0.groups, bias=(c0.bias is not None), padding_mode=c0.padding_mode) + self.shared = zero_module(self.shared) + self.tasks = l_task + self.task_conv = ModuleDict_W(l_conv, self.tasks) + + def forward(self, x): + t = global_.task + return self.task_conv[t](x) + self.shared(x) + + + + +def _average_state_dict(modules: list): + assert len(modules) > 0 + sd0 = modules[0].state_dict() + avg = {k: torch.zeros_like(v) for k, v in sd0.items()} + for m in modules: + msd = m.state_dict() + for k in avg: + avg[k] += msd[k] + for k in avg: + avg[k] /= len(modules) + return avg + + +class FFN_Shared_Plus_TaskLoRA(nn.Module): + def __init__(self, l_ffn: list, l_task: list, module_name: str = None): + super().__init__() + self.module_name = module_name + # _log1(f"-------- {module_name} --------") + assert len(l_ffn) >= 1 + self.tasks = l_task + self.num_tasks = len(l_task) + self.dont_lora = False + f0: FeedForward = l_ffn[0] + # build shared from f0 and load avg + self.shared_ffn: FeedForward = copy.deepcopy(f0) + if FOR_upcycle_ckpt_GEN_or_USE: + avg_sd = _average_state_dict(l_ffn) + self.shared_ffn.load_state_dict(avg_sd) + # freeze shared + for p in self.shared_ffn.parameters(): + p.requires_grad = False + # discover inner layers + self.is_glu = isinstance(self.shared_ffn.net[0], GEGLU) + if self.is_glu: + in_linear: nn.Linear = self.shared_ffn.net[0].proj + else: + assert isinstance(self.shared_ffn.net[0], nn.Sequential) + in_linear: nn.Linear = self.shared_ffn.net[0][0] + out_linear: nn.Linear = self.shared_ffn.net[2] + self.in_features = in_linear.in_features + self.mid_features = in_linear.out_features + self.out_features = out_linear.out_features + if 1: # cal/read adaptive rank across tasks + if FOR_upcycle_ckpt_GEN_or_USE: + w_diff_in_list = [] + w_diff_out_list = [] + for f in l_ffn: + if self.is_glu: + tin: nn.Linear = f.net[0].proj + else: + tin: nn.Linear = f.net[0][0] + tout: nn.Linear = f.net[2] + w_diff_in_list.append(tin.weight.data - in_linear.weight.data) + w_diff_out_list.append(tout.weight.data - out_linear.weight.data) + if FORCE_SAME_RANK_ACROSS_TASKS: + rank_in = compute_adaptive_rank_for_linear_diffs(w_diff_in_list) + rank_out = compute_adaptive_rank_for_linear_diffs(w_diff_out_list) + global_.moduleName_2_adaRank[module_name] = [rank_in, rank_out] + else: + ranks_in = compute_adaptive_rank_for_linear_diffs(w_diff_in_list, per_task=True) + ranks_out = compute_adaptive_rank_for_linear_diffs(w_diff_out_list, per_task=True) + global_.moduleName_2_adaRank[module_name] = [ranks_in, ranks_out] + else: + r_info = global_.moduleName_2_adaRank[module_name] + if FORCE_SAME_RANK_ACROSS_TASKS: rank_in, rank_out = r_info + else: ranks_in, ranks_out = r_info + if 1: + # fallback decision: (1) tiny feature dims + min_dim_in = min(self.in_features, self.mid_features) + min_dim_out = min(self.mid_features, self.out_features) + if (min_dim_in < DONT_lora_if_dim_lt) or (min_dim_out < DONT_lora_if_dim_lt): + # print(f"[LoRA fallback][FFN] {module_name} {min_dim_in=} {min_dim_out=} {DONT_lora_if_dim_lt=}") + self.dont_lora = True; return + # per-task LoRA adapters + _l_in = [] + _l_out = [] + for idx, f in enumerate(l_ffn): + if self.is_glu: + tin: nn.Linear = f.net[0].proj + else: + tin: nn.Linear = f.net[0][0] + tout: nn.Linear = f.net[2] + if not FORCE_SAME_RANK_ACROSS_TASKS: + rank_in = ranks_in[idx] + rank_out = ranks_out[idx] + frac_in = float(rank_in) / min(self.in_features, self.mid_features) + frac_out = float(rank_out) / min(self.mid_features, self.out_features) + frac_avg = 0.5 * (frac_in + frac_out) + if frac_avg > DONT_lora_if_rankFrac_gt: + lora_in = ResidualAdapterLinearOnly(self.in_features, self.mid_features, scaling=1.0, use_bias_delta=True) + lora_out = ResidualAdapterLinearOnly(tout.in_features, tout.out_features, scaling=1.0, use_bias_delta=True) + else: + lora_in = LoRAAdapterLinearOnly(self.in_features, self.mid_features, rank=rank_in, dropout=0.0, scaling=1.0) + lora_out = LoRAAdapterLinearOnly(tout.in_features, tout.out_features, rank=rank_out, dropout=0.0, scaling=1.0) + # init from diffs + if FOR_upcycle_ckpt_GEN_or_USE: + with torch.no_grad(): + w_diff_in = tin.weight.data - in_linear.weight.data + b_diff_in = (tin.bias.data - in_linear.bias.data) if tin.bias is not None else None + lora_in.init_from_diff(w_diff_in, b_diff_in) + w_diff_out = tout.weight.data - out_linear.weight.data + b_diff_out = (tout.bias.data - out_linear.bias.data) if tout.bias is not None else None + lora_out.init_from_diff(w_diff_out, b_diff_out) + _l_in.append(lora_in) + _l_out.append(lora_out) + self.task_lora_in = ModuleDict_W(_l_in, self.tasks) + self.task_lora_out = ModuleDict_W(_l_out, self.tasks) + # reuse dropout and activation behavior + self.dropout_p = self.shared_ffn.net[1].p if isinstance(self.shared_ffn.net[1], nn.Dropout) else 0.0 + self.dropout = nn.Dropout(self.dropout_p) if self.dropout_p > 0 else nn.Identity() + + # Sparse/Dense MoE experts (small inner dim) + gate + if EXTRA_MoE_enable: + small_inner = self.mid_features // EXTRA_MoE_inner_divisor + self.num_moe_expert = EXTRA_MoE_num_ep + gate_in_dim = self.in_features + self.in_features + len(self.tasks) + 2 + 2*NUM_lmk_pick + self.moe_gate_mlp = nn.Linear(gate_in_dim, self.num_moe_expert) + + if EXTRA_MoE_routing_mode == 'dense': + self.moe_experts_batched = BatchedFeedForward( + dim=self.in_features, dim_out=self.out_features, + glu=self.is_glu, dropout=self.dropout_p, + inner_dim=small_inner, num_expert=self.num_moe_expert, + ) + else: + mult = small_inner / self.in_features + experts = [] + for _ in range(self.num_moe_expert): + expert = FeedForward(self.in_features, dim_out=self.out_features, mult=mult, glu=self.is_glu, dropout=self.dropout_p) + experts.append(expert) + self.moe_experts_list = nn.ModuleList(experts) + + if FOR_upcycle_ckpt_GEN_or_USE: + self.verify_approximation(orig_ffn_list=l_ffn) + + def forward(self, x: torch.Tensor, token_pos_grid__cur=None): + t = global_.task + # in linear + LoRA + if self.is_glu: + base = self.shared_ffn.net[0].proj(x) + delta = self.task_lora_in[t](x) + z = base + delta + v, gate = z.chunk(2, dim=-1) + h = v * F.gelu(gate) + else: + base = self.shared_ffn.net[0][0](x) + delta = self.task_lora_in[t](x) + h = F.gelu(base + delta) + h = self.dropout(h) + # out linear + LoRA + y_base = self.shared_ffn.net[2](h) + y_delta = self.task_lora_out[t](h) + y = y_base + y_delta + if EXTRA_MoE_enable: + # gate input + gate_in, _ = build_ffn_gate_input_common(x, token_pos_grid__cur, self.tasks) + scores = self.moe_gate_mlp(gate_in).to(dtype=x.dtype) # b,n,k + if EXTRA_MoE_add_noise and self.training: + scores = scores + torch.randn_like(scores) * EXTRA_MoE_noise_std + scores = torch.softmax(scores, dim=-1) + v_topk, idx_topk = scores.topk(k=EXTRA_MoE_topK, dim=-1) + + if EXTRA_MoE_routing_mode == 'dense': + raise Exception('not carefully checked yet') + else: # sparse: forward only the selected experts and aggregate by top-k weights + if 0: weights_topk = torch.softmax(v_topk, dim=-1) # b,n,topk + else: weights_topk = v_topk # b,n,topk. use top-k expert scores directly as weights + b, n, d = x.shape + dim_out = self.out_features + y_moe_flat = x.new_zeros(b*n, dim_out) # flattened tensor accumulating outputs from all experts (bs*N, D_out) + x_flat = x.reshape(b*n, d) # flatten input tensor (bs*N, D_in) + unique_experts = torch.unique(idx_topk) # set of expert IDs actually selected in this batch + for j in range(EXTRA_MoE_num_ep): # iterate only over active experts + mask_j = (idx_topk == j) # b,n,topk boolean mask indicating which tokens picked expert j + sel_token_mask = mask_j.any(dim=-1) # b,n boolean mask for tokens that selected expert j + flat_pos = sel_token_mask.view(-1).nonzero(as_tuple=False).squeeze(1) # T_j flattened indices of tokens assigned to expert j + if flat_pos.numel() == 0: + continue + x_sel = x_flat.index_select(0, flat_pos) # T_j,d select those tokens from flattened input + # run expert only on selected tokens (n = T_j) + y_sel = self.moe_experts_list[j](x_sel.view(1, x_sel.shape[0], d)).squeeze(0) # T_j,dim_out expert j handles only its tokens + w_tok = (weights_topk * mask_j).sum(dim=-1).view(-1).index_select(0, flat_pos).unsqueeze(-1) # T_j,1 weights for each token assigned to expert j + y_moe_flat.index_add_(0, flat_pos, w_tok * y_sel) # add weighted expert output back into flattened tensor (in-place) + y = y + y_moe_flat.view(b, n, dim_out) # reshape aggregated MoE output and add back to backbone output + if EXTRA_MoE_en_auxLoss and self.training: + raise Exception('not carefully checked yet') + importance = torch.zeros(self.num_moe_expert, device=scores.device, dtype=weights_topk.dtype) + importance = importance.scatter_add(0, idx_topk.reshape(-1), weights_topk.reshape(-1)) + load = torch.zeros(self.num_moe_expert, device=scores.device, dtype=weights_topk.dtype) + load = load.scatter_add(0, idx_topk.reshape(-1), torch.ones_like(weights_topk.reshape(-1))) + k = importance.shape[0] + target_imp = torch.full_like(importance, fill_value=importance.sum() / k) + target_load = torch.full_like(load, fill_value=load.sum() / k) + aux_imp = F.mse_loss(importance, target_imp) + aux_load = F.mse_loss(load, target_load) + aux = 0.5 * (aux_imp + aux_load) * EXTRA_MoE_aux_coef + global_.moe_aux_loss = aux # expose aux loss to the training loop for aggregation + return y + + @torch.no_grad() + def verify_approximation(self, num_tokens: int = 16, batch_size: int = 2, orig_ffn_list: list = None): + if EXTRA_MoE_enable: return + device = next(self.shared_ffn.parameters()).device + dtype = next(self.shared_ffn.parameters()).dtype + x = torch.randn(batch_size, num_tokens, self.in_features, device=device, dtype=dtype) + old_task = getattr(global_, 'task', None) + for i,t in enumerate(self.tasks): + _log2(orig_ffn_list[i], [self.task_lora_in[t], self.task_lora_out[t]]) + global_.task = t + y_lora = self.forward(x) + y_avg = self.shared_ffn(x) + assert orig_ffn_list is not None, "orig_ffn_list must be provided for verification" + y_orig = orig_ffn_list[i](x) + d_avg = torch.norm((y_avg - y_orig).float()).item() + d_lora = torch.norm((y_lora - y_orig).float()).item() + _log1(f"[FFN verify] task={t} rank_in={self.task_lora_in[t].rank} rank_out={self.task_lora_out[t].rank} L2(avg,orig)={d_avg:.6f} L2(lora,orig)={d_lora:.6f}") + global_.task = old_task + + +class Linear_Shared_Plus_TaskLoRA(nn.Module): + def __init__(self, l_proj: list, l_task: list, module_name: str = None): + super().__init__() + # _log1(f"-------- {module_name} --------") + assert len(l_proj) >= 1 + self.dont_lora = False + p0: nn.Linear = l_proj[0] + # build shared from p0 and load avg + self.shared: nn.Linear = copy.deepcopy(p0) + if FOR_upcycle_ckpt_GEN_or_USE: + avg_sd = _average_state_dict(l_proj) + self.shared.load_state_dict(avg_sd) + for p in self.shared.parameters(): + p.requires_grad = False + self.in_features = self.shared.in_features + self.out_features = self.shared.out_features + self.tasks = l_task + # cal/read adaptive rank across tasks + if 1: + if FOR_upcycle_ckpt_GEN_or_USE: + w_diff_list = [] + for lin in l_proj: + w_diff_list.append(lin.weight.data - self.shared.weight.data) + if FORCE_SAME_RANK_ACROSS_TASKS: + rank_lin = compute_adaptive_rank_for_linear_diffs(w_diff_list) + global_.moduleName_2_adaRank[module_name] = rank_lin + else: + ranks_lin = compute_adaptive_rank_for_linear_diffs(w_diff_list, per_task=True) + global_.moduleName_2_adaRank[module_name] = ranks_lin + else: + r_info = global_.moduleName_2_adaRank[module_name] + if FORCE_SAME_RANK_ACROSS_TASKS: rank_lin = r_info + else: ranks_lin = r_info + if 1: # fallback decision for Linear + min_dim = min(self.in_features, self.out_features) + if min_dim < DONT_lora_if_dim_lt: + # print(f"[LoRA fallback][Linear] {module_name} {min_dim=} < {DONT_lora_if_dim_lt}") + self.dont_lora = True; return + _l = [] # per-task LoRA adapters + for idx, lin in enumerate(l_proj): + if not FORCE_SAME_RANK_ACROSS_TASKS: + rank_lin = ranks_lin[idx] + frac = float(rank_lin) / min(self.in_features, self.out_features) + if frac > DONT_lora_if_rankFrac_gt: + lora = ResidualAdapterLinearOnly(self.in_features, self.out_features, scaling=1.0, use_bias_delta=True) + else: + lora = LoRAAdapterLinearOnly(self.in_features, self.out_features, rank=rank_lin, dropout=0.0, scaling=1.0) + if FOR_upcycle_ckpt_GEN_or_USE: + with torch.no_grad(): + w_diff = lin.weight.data - self.shared.weight.data + b_diff = (lin.bias.data - self.shared.bias.data) if (lin.bias is not None and self.shared.bias is not None) else None + lora.init_from_diff(w_diff, b_diff) + _l.append(lora) + self.task_lora = ModuleDict_W(_l, self.tasks) + if FOR_upcycle_ckpt_GEN_or_USE: + self.verify_approximation(orig_linear_list=l_proj) + def forward(self, x: torch.Tensor) -> torch.Tensor: + y = self.shared(x) + y = y + self.task_lora[global_.task](x) + return y + @torch.no_grad() + def verify_approximation(self, batch_size: int = 2, in_dim_override: int = None, orig_linear_list: list = None): + device = next(self.shared.parameters()).device + dtype = next(self.shared.parameters()).dtype + d_in = self.in_features if in_dim_override is None else in_dim_override + x = torch.randn(batch_size, d_in, device=device, dtype=dtype) + old_task = getattr(global_, 'task', None) + for i,t in enumerate(self.tasks): + _log2(orig_linear_list[i], self.task_lora[t]) + global_.task = t + y_lora = self.forward(x) + y_avg = self.shared(x) + assert orig_linear_list is not None, "orig_linear_list must be provided for verification" + y_orig = orig_linear_list[i](x) + d_avg = torch.norm((y_avg - y_orig).float()).item() + d_lora = torch.norm((y_lora - y_orig).float()).item() + _log1(f"[Linear verify] task={t} rank={self.task_lora[t].rank} L2(avg,orig)={d_avg:.6f} L2(lora,orig)={d_lora:.6f}") + global_.task = old_task + +class Conv_Shared_Plus_TaskLoRA(nn.Module): + def __init__(self, l_conv: list, l_task: list, module_name: str = None): + super().__init__() + # _log1(f"-------- {module_name} --------") + assert len(l_conv) >= 1 + self.dont_lora = False + c0: nn.Conv2d = l_conv[0] + # build shared conv + self.shared = nn.Conv2d( + c0.in_channels, c0.out_channels, + kernel_size=c0.kernel_size, stride=c0.stride, + padding=c0.padding, dilation=c0.dilation, + groups=c0.groups, bias=(c0.bias is not None), + padding_mode=c0.padding_mode, + ) + if FOR_upcycle_ckpt_GEN_or_USE: + avg_sd = _average_state_dict(l_conv) + self.shared.load_state_dict(avg_sd) + for p in self.shared.parameters(): + p.requires_grad = False + # per-task LoRA + self.tasks = l_task + _l = [] + # cal/read adaptive rank across tasks + if 1: + if FOR_upcycle_ckpt_GEN_or_USE: + w_diff_list = [] + for c in l_conv: + w_diff_list.append(c.weight.data - self.shared.weight.data) + if FORCE_SAME_RANK_ACROSS_TASKS: + rank_conv = compute_adaptive_rank_for_conv_diffs(w_diff_list) + global_.moduleName_2_adaRank[module_name] = rank_conv + else: + ranks_conv = compute_adaptive_rank_for_conv_diffs(w_diff_list, per_task=True) + global_.moduleName_2_adaRank[module_name] = ranks_conv + else: + r_info = global_.moduleName_2_adaRank[module_name] + if FORCE_SAME_RANK_ACROSS_TASKS: rank_conv = r_info + else: ranks_conv = r_info + if 1: # fallback decision for Conv + kH, kW = self.shared.kernel_size + min_dim = min(self.shared.out_channels, self.shared.in_channels * kH * kW ) + if min_dim < DONT_lora_if_dim_lt: + # print(f"[LoRA fallback][Conv] {module_name} {min_dim=} {DONT_lora_if_dim_lt=} (in={self.shared.in_channels}, out={self.shared.out_channels}, k=({kH},{kW}))") + self.dont_lora = True; return + for idx, c in enumerate(l_conv): + if not FORCE_SAME_RANK_ACROSS_TASKS: + rank_conv = ranks_conv[idx] + frac = float(rank_conv) / min(self.shared.out_channels, self.shared.in_channels * kH * kW) + if frac > DONT_lora_if_rankFrac_gt: + lora = ResidualAdapterConv2dOnly( + in_channels=c.in_channels, out_channels=c.out_channels, + kernel_size=c.kernel_size, stride=c.stride, + padding=c.padding, dilation=c.dilation, groups=c.groups, + scaling=1.0, use_bias_delta=True, + ) + else: + lora = LoRAAdapterConv2dOnly( + in_channels=c.in_channels, out_channels=c.out_channels, + kernel_size=c.kernel_size, stride=c.stride, + padding=c.padding, dilation=c.dilation, groups=c.groups, + rank=rank_conv, dropout=0.0, scaling=1.0, + ) + if FOR_upcycle_ckpt_GEN_or_USE: + with torch.no_grad(): + w_diff = c.weight.data - self.shared.weight.data + b_diff = (c.bias.data - self.shared.bias.data) if c.bias is not None and self.shared.bias is not None else None + lora.init_from_diff(w_diff, b_diff) + _l.append(lora) + self.task_lora = ModuleDict_W(_l, self.tasks) + + if FOR_upcycle_ckpt_GEN_or_USE: + self.verify_approximation(orig_conv_list=l_conv) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + y = self.shared(x) + y = y + self.task_lora[global_.task](x) + return y + + @torch.no_grad() + def verify_approximation(self, spatial_hw=(32, 32), batch_size: int = 2, orig_conv_list: list = None): + device = next(self.shared.parameters()).device + dtype = next(self.shared.parameters()).dtype + H, W = spatial_hw + x = torch.randn(batch_size, self.shared.in_channels, H, W, device=device, dtype=dtype) + old_task = getattr(global_, 'task', None) + for i,t in enumerate(self.tasks): + _log2(orig_conv_list[i], self.task_lora[t]) + global_.task = t + y_lora = self.forward(x) + y_avg = self.shared(x) + assert orig_conv_list is not None, "orig_conv_list must be provided for verification" + y_orig = orig_conv_list[i](x) + d_avg = torch.norm((y_avg - y_orig).float()).item() + d_lora = torch.norm((y_lora - y_orig).float()).item() + _log1(f"[Conv2d verify] task={t} rank={self.task_lora[t].rank} L2(avg,orig)={d_avg:.6f} L2(lora,orig)={d_lora:.6f}") + global_.task = old_task diff --git a/my_py_lib/.gitignore b/my_py_lib/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..076019af3b086696ceb327a7816f2e897a69b17a --- /dev/null +++ b/my_py_lib/.gitignore @@ -0,0 +1,23 @@ + + + + + + + + +# ---------------- below from https://stackoverflow.com/questions/11852558/gitignore-only-allow-certain-extensions-and-files +* +!.gitattributes +!.gitignore +!readme.md +!.gitkeep +!*.py +!*/ +#-------------------- + +ttt*.py +__pycache__ +.idea + +bin_recycle diff --git a/my_py_lib/cv2_util.py b/my_py_lib/cv2_util.py new file mode 100644 index 0000000000000000000000000000000000000000..6d032660bd997612d57c8247851388d96c4fc602 --- /dev/null +++ b/my_py_lib/cv2_util.py @@ -0,0 +1,181 @@ +import cv2 +import numpy as np +from typing import Tuple, Optional +from skimage.io import imread, imsave + +def add_text_to_image( + image_rgb: np.ndarray, + label: str, + top_left_xy: Tuple = (0, 0), + font_scale: float = 1, + font_thickness: float = 1, + font_face=cv2.FONT_HERSHEY_SIMPLEX, + font_color_rgb: Tuple = (0, 0, 255), + bg_color_rgb: Optional[Tuple] = None, + outline_color_rgb: Optional[Tuple] = None, + line_spacing: float = 1, +): + """ + from https://stackoverflow.com/questions/27647424/opencv-puttext-new-line-character + """ + """ + Adds text (including multi line text) to images. + You can also control background color, outline color, and line spacing. + + outline color and line spacing adopted from: https://gist.github.com/EricCousineau-TRI/596f04c83da9b82d0389d3ea1d782592 + """ + OUTLINE_FONT_THICKNESS = 3 * font_thickness + + im_h, im_w = image_rgb.shape[:2] + + for line in label.splitlines(): + x, y = top_left_xy + + # ====== get text size + if outline_color_rgb is None: + get_text_size_font_thickness = font_thickness + else: + get_text_size_font_thickness = OUTLINE_FONT_THICKNESS + + (line_width, line_height_no_baseline), baseline = cv2.getTextSize( + line, + font_face, + font_scale, + get_text_size_font_thickness, + ) + line_height = line_height_no_baseline + baseline + + if bg_color_rgb is not None and line: + # === get actual mask sizes with regard to image crop + if im_h - (y + line_height) <= 0: + sz_h = max(im_h - y, 0) + else: + sz_h = line_height + + if im_w - (x + line_width) <= 0: + sz_w = max(im_w - x, 0) + else: + sz_w = line_width + + # ==== add mask to image + if sz_h > 0 and sz_w > 0: + bg_mask = np.zeros((sz_h, sz_w, 3), np.uint8) + bg_mask[:, :] = np.array(bg_color_rgb) + image_rgb[ + y: y + sz_h, + x: x + sz_w, + ] = bg_mask + + # === add outline text to image + if outline_color_rgb is not None: + image_rgb = cv2.putText( + image_rgb, + line, + (x, y + line_height_no_baseline), # putText start bottom-left + font_face, + font_scale, + outline_color_rgb, + OUTLINE_FONT_THICKNESS, + cv2.LINE_AA, + ) + # === add text to image + image_rgb = cv2.putText( + image_rgb, + line, + (x, y + line_height_no_baseline), # putText start bottom-left + font_face, + font_scale, + font_color_rgb, + font_thickness, + cv2.LINE_AA, + ) + top_left_xy = (x, y + int(line_height * line_spacing)) + + return image_rgb + +def putText (img, text, org, fontFace, fontScale, color, thickness=1, lineType=None, bottomLeftOrigin=None): + line_spacing=1 + top_left_xy=org + """ + func that wrap cv2.putText (arg and ret keep the same), but support auto line break + """ + OUTLINE_FONT_THICKNESS = 3 * thickness + im_h, im_w = img.shape[:2] + for line in text.splitlines(): + x, y = top_left_xy + get_text_size_font_thickness = OUTLINE_FONT_THICKNESS + (line_width, line_height_no_baseline), baseline = cv2.getTextSize( + line, + fontFace, + fontScale, + get_text_size_font_thickness, + ) + line_height = line_height_no_baseline + baseline + # === add text to image + img = cv2.putText( + img, + line, + (x, y + line_height_no_baseline), # putText start bottom-left + fontFace, + fontScale, + color, + thickness=thickness, + # lineType=cv2.LINE_AA, + lineType=lineType, + bottomLeftOrigin=bottomLeftOrigin, + ) + top_left_xy = (x, y + int(line_height * line_spacing)) + return img +def putText_B(img, text, org=(5,5), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.6, color=(255, 100, 100), thickness=1, lineType=None, bottomLeftOrigin=None): + """ + provide many default args + """ + return putText( + img, + text, + org, fontFace, fontScale, color, thickness, + lineType, bottomLeftOrigin, + ) +""" +from gen6d +""" +def concat_images(img0, img1, vert=False): + if not vert: + h0, h1 = img0.shape[0], img1.shape[0], + if h0 < h1: img0 = cv2.copyMakeBorder(img0, 0, h1 - h0, 0, 0, borderType=cv2.BORDER_CONSTANT, value=0) + if h1 < h0: img1 = cv2.copyMakeBorder(img1, 0, h0 - h1, 0, 0, borderType=cv2.BORDER_CONSTANT, value=0) + img = np.concatenate([img0, img1], axis=1) + else: + w0, w1 = img0.shape[1], img1.shape[1] + if w0 < w1: img0 = cv2.copyMakeBorder(img0, 0, 0, 0, w1 - w0, borderType=cv2.BORDER_CONSTANT, value=0) + if w1 < w0: img1 = cv2.copyMakeBorder(img1, 0, 0, 0, w0 - w1, borderType=cv2.BORDER_CONSTANT, value=0) + img = np.concatenate([img0, img1], axis=0) + + return img + + +def concat_images_list(*args, vert=False,max_h=None,max_w=None,img_num_per_row=None): + if len(args) == 1: return args[0] + if img_num_per_row: + if not len(args)%img_num_per_row==0: + args=args+tuple([np.array([[[0,0,0]]])]*(img_num_per_row-len(args)%img_num_per_row)) + args=[concat_images_list(*args[i:i+img_num_per_row],vert=vert,max_h=max_h,max_w=max_w) for i in range(0,len(args),img_num_per_row)] + return concat_images_list(*args,vert=not vert,max_h=max_h,max_w=max_w) + if(max_h is not None): + args=[cv2.resize(img,(int(img.shape[1]*max_h/img.shape[0]),max_h)) if img.shape[0]>max_h else img for img in args] + if(max_w is not None): + args=[cv2.resize(img,(max_w,int(img.shape[0]*max_w/img.shape[1]))) if img.shape[1]>max_w else img for img in args] + img_out = args[0] + for img in args[1:]: + img_out = concat_images(img_out, img, vert) + return img_out +if(__name__=="__main__"): + img=np.zeros((100,100,3),np.uint8) + img=putText(img,"hello\nworld",(img.shape[1]-50,0),cv2.FONT_HERSHEY_SIMPLEX,1,(0,0,255)) + cv2.imshow("img",img) + cv2.waitKey(0) +if(__name__=="__main__"): + img=np.zeros((500,300,3),np.uint8) + img=putText(img,"hello\nworld",(0,0),cv2.FONT_HERSHEY_SIMPLEX,1,(0,0,255)) + cv2.imshow("img",img) + cv2.waitKey(0) \ No newline at end of file diff --git a/my_py_lib/image_util.py b/my_py_lib/image_util.py new file mode 100644 index 0000000000000000000000000000000000000000..f98171990051fda2497a53ec69f951a354dea595 --- /dev/null +++ b/my_py_lib/image_util.py @@ -0,0 +1,467 @@ +import numpy as np +import os,sys,cv2 +import PIL +from PIL import Image +from pathlib import Path + +def to__image_in_npArr(img): + """ + convert PIL/np.ndarray type image to np.ndarray + Equivalent to misc_util.to_ndarray + """ + if isinstance(img, np.ndarray): + return img + if isinstance(img, PIL.Image.Image): + return np.array(img) + import torch + if isinstance(img, torch.Tensor): + return img.detach().cpu().numpy() + raise TypeError("got {}".format(type(img))) +def imgArr_2_objXminYminXmaxYmax(imgArr, bg_color, THRES=5, coarse_bbox=None,diff_type='A'): + """ + param: + imgArr: np.array + bg_color: background color in the form of a tuple (R, G, B) + coarse_bbox: find bbox inside the coarse_bbox + return: + xmin,ymin,xmax,ymax (type= primitive int,NOT np int) + """ + img_array = imgArr + if coarse_bbox is not None: + xmin_coarse, ymin_coarse, xmax_coarse, ymax_coarse = coarse_bbox + img_array = img_array[ymin_coarse:ymax_coarse, xmin_coarse:xmax_coarse] + + if diff_type=='A': + # Extract pixels from the image that are different from the background color + diff_pixels = np.any(np.abs(img_array - np.array(bg_color)) > THRES, axis=2) + elif diff_type=='B': + # Extract pixels from the image that are different from the background color + diff_pixels =( np.sum(np.abs(img_array - np.array(bg_color)) , axis=2)> THRES) + + # Calculate the bounding box of the object + rows = np.any(diff_pixels, axis=1) + cols = np.any(diff_pixels, axis=0) + ymin, ymax = np.where(rows)[0][[0, -1]] + xmin, xmax = np.where(cols)[0][[0, -1]] + xmin=xmin.item() + ymin=ymin.item() + xmax=xmax.item() + ymax=ymax.item() + if coarse_bbox is not None: + xmin += xmin_coarse + ymin += ymin_coarse + xmax += xmin_coarse + ymax += ymin_coarse + + return xmin, ymin, xmax, ymax +def draw_bbox(img, bbox, color=None, thickness=2,bbox_type='x0y0wh'): + """ + xmin,ymin,xmax,ymax + """ + img = np.copy(img) + if color is not None: + color = [int(c) for c in color] + else: + color = (0, 255, 0) + if bbox_type=='x0y0wh': + left = int(round(bbox[0])) + top = int(round(bbox[1])) + width = int(round(bbox[2])) + height = int(round(bbox[3])) + elif bbox_type=='x0y0x1y1': + left,top,right,bottom=bbox + width = right-left + height = bottom-top + img = cv2.rectangle(img, (left, top), (left + width, top + height), color, thickness=thickness) + return img + + + + +def print_image_statistics( + image, + reduce_line:bool = 1, # reduce printed lines by condensing multi-line output + # + return_:bool = False, + print_:bool = True, +): + """ + Print image statistics: + type + dtype and shape + min, max, mean, median, unique values for each channel + """ + string = "----[statistics]----\n" + string += f"type = {type(image)}\n" + image = to__image_in_npArr(image) + string += f"dtype = {image.dtype}\n" + string += f"shape = {image.shape}\n" + + if image.shape[0]==3 or image.shape[0]==4 or image.shape[0]==1: + if image.shape[1] > 13: + print("Assuming the first axis is channel", end=' ') + if len(image.shape) == 2: + raise NotImplementedError + image = image.transpose(1, 2, 0) + print(f"transposed {image.shape=}") + else: + print("[warning] the first axis might be the channel dimension") + if len(image.shape) == 2: + channels = [image] + else: + # channels = np.split(image, image.shape[-1], axis=-1)#poe generated, I cannot understand easily + channels = [image[:, :, i] for i in range(image.shape[-1])] + + for i, channel in enumerate(channels): + uniques=np.unique(channel) + _N=6 + if len(uniques)>_N: + s_uniques = " ".join([f"{x:.3f}" for x in uniques[:_N//2]])# Format the first half with two decimals + s_uniques+=' .. ' + s_uniques += " ".join([f"{x:.3f}" for x in uniques[-_N//2:]]) + else: + s_uniques = " ".join([f"{x:.3f}" for x in uniques]) + if not reduce_line: + string += f"\nChannel {i }:\n" + string += f" Min: {np.min(channel)}\n" + string += f" Max: {np.max(channel)}\n" + string += f" Mean: {np.mean(channel)}\n" + string += f" Median: {np.median(channel)}\n" + string += f" Unique values: {s_uniques}\n" + else: + string += f"Channel {i}: Min={np.min(channel):<8.2f} Max={np.max(channel):<8.2f} Mean={np.mean(channel):<8.2f} Median={np.median(channel):<8.2f} Unique={s_uniques}\n" + if reduce_line: # remove the first few newline characters from string + def remove_first_n_char(text, char, n=3): + modified = text + for _ in range(n): + modified = modified.replace(char, '', 1) + return modified + string = remove_first_n_char(string,'\n') + string=string.replace('\n','\n|') + string += "----[statistics]over----\n" + if print_: + print(string) + if return_: + return string + +def pad_around_center(img, new_size, ): + """ + Pad image to a new size with fill color around image center. + pad with white (255) + """ + img = to__image_in_npArr(img) + assert len(img.shape) == 3 + assert len(new_size) == 2 + + # compute padding + height, width, _ = img.shape + new_height, new_width = new_size + assert new_height >= height + assert new_width >= width + pad_height = new_height - height + pad_width = new_width - width + pad_top = pad_height // 2 + pad_bottom = pad_height - pad_top + pad_left = pad_width // 2 + pad_right = pad_width - pad_left + + # pad image + img = np.pad( + img, + pad_width=((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), + mode="constant", + constant_values=255, + ) + return img + + + +def norm_min0max255_image_per_channel(image, ): + """ + norm to 0-255 for each channel (min=0, max=255 for each channel) + + Args: + image: file path (str) or PIL Image object + + Returns: + normalized PIL Image + """ + if isinstance(image, str): + img_pil = Image.open(image).convert('RGB') + else: + img_pil = image.convert('RGB') + + img_array = np.array(img_pil).astype(np.float32) + + for channel in range(3): + channel_data = img_array[:, :, channel] + c_min = np.min(channel_data) + c_max = np.max(channel_data) + + if c_max > c_min: + img_array[:, :, channel] = (channel_data - c_min) * (255.0 / (c_max - c_min)) + else: # fallback when all channel values are identical + pass + img_array = np.clip(img_array, 0, 255).astype(np.uint8) + if 1: + for c in range(3): + channel_data = img_array[:, :, c] + c_min = np.min(channel_data) + c_max = np.max(channel_data) + # allow up to ±3 absolute error + if abs(c_min-0)>3 or abs(c_max-255)>3: + print_image_statistics(img_array) + assert 0 + img_pil = Image.fromarray(img_array) + return img_pil + +def imgs_2_grid_A( + imgs, # list of RGB images (PIL or numpy arrays) + # any provided mask makes masked pixels lighter across images + masks=None, + # if provided, save grid + grid_path=None, + # other settings + downsample=1, # downsample factor for the grid + inv_mask:bool=False, + resize_mode:str=None, # None | 'mask_to_img' | 'img_to_mask' (resize img to match mask shape) + grid_layout:str="row", # 'row' | 'column' |'square' + auto_pad_if_not_same_size=True, + verbose :int = 1, +): + """ + Create a grid of images from paths, optionally with masks overlaid. + """ + from pathlib import Path + import PIL.Image + import numpy as np + import torchvision.utils as vutils + import torch + + images = [] + for i, img in enumerate(imgs): + if isinstance(img, PIL.Image.Image): + pass + else: + if verbose>0: + print(f"{img.shape=}") + img = to__image_in_npArr(img) + if isinstance(img, np.ndarray): + img = PIL.Image.fromarray(img) + # else: + # raise TypeError(f"Images must be PIL Image or numpy array{type(img)}") + + if not isinstance(img, PIL.Image.Image): + raise TypeError(f"Images must be PIL Image or numpy array{type(img)}") + + img_tensor = torch.tensor(np.array(img).transpose(2, 0, 1)) / 255.0 + + + if masks is not None: + mask = masks[i] + if isinstance(mask, np.ndarray): + mask = PIL.Image.fromarray(mask) + + if not mask.mode == 'L': + mask = mask.convert('L') + + if resize_mode is None: + pass + elif resize_mode == "img_to_mask": + img_tensor = torch.nn.functional.interpolate( + img_tensor.unsqueeze(0), + size=(mask.height, mask.width), + mode='bilinear', + align_corners=False + ).squeeze(0) + elif resize_mode == "mask_to_img": + mask = mask.resize((img_tensor.shape[2], img_tensor.shape[1]), PIL.Image.BILINEAR) + else: + raise NotImplementedError + + mask_np = np.array(mask) / 255.0 + mask_tensor = torch.tensor(mask_np).unsqueeze(0).repeat(3, 1, 1) + if inv_mask: + mask_tensor = 1 - mask_tensor + # make masked pixels lighter + img_tensor = img_tensor * 0.3 + 0.7 * mask_tensor + + # Apply auto padding if needed + if auto_pad_if_not_same_size and i > 0 and (img_tensor.shape[1] != images[0].shape[1] or img_tensor.shape[2] != images[0].shape[2]): + # Resize to match the first image dimensions + img_tensor = torch.nn.functional.interpolate( + img_tensor.unsqueeze(0), + size=(images[0].shape[1], images[0].shape[2]), + mode='bilinear', + align_corners=False + ).squeeze(0) + images.append(img_tensor) + + if grid_layout == "row": + grid_tensor = vutils.make_grid(images, nrow=len(images), ) + elif grid_layout == "column": + grid_tensor = vutils.make_grid(images, nrow=1, ) + elif grid_layout == "square": + grid_tensor = vutils.make_grid(images, nrow=int(np.sqrt(len(images))), ) + else: + raise NotImplementedError + + grid = grid_tensor.numpy().transpose(1, 2, 0) + grid = PIL.Image.fromarray((grid * 255).astype(np.uint8)) + + if downsample > 1: + original_size = grid.size + new_size = (original_size[0] // downsample, original_size[1] // downsample) + grid = grid.resize(new_size, PIL.Image.LANCZOS) + + if grid_path is not None: + grid_path = Path(grid_path) + grid_path.parent.mkdir(parents=False, exist_ok=True) + grid.save(grid_path) + if verbose>-1: + print(f"saved {grid_path}") + + return grid + +def img_paths_2_grid_A( + paths, # paths of rgb img + # any mask option makes masked pixels lighter per image + mask_paths=None, + path_img_2_path_mask=None, # callback to convert RGB image path to mask path + # if provided, save grid + grid_path=None, + # other settings + downsample=1, # downsample factor for the grid + inv_mask:bool=False, + resize_mode:str=None, # None | 'mask_to_img' | 'img_to_mask' (resize image to match mask shape) + grid_layout:str="row", # 'row' | 'column' |'square' + auto_pad_if_not_same_size=True, +): + """ + Create a grid of images from paths, optionally with masks overlaid. + """ + import PIL.Image + + # Load images from paths + imgs = [PIL.Image.open(path).convert('RGB') for path in paths] + + # Load masks if provided + masks = None + if mask_paths is not None: + masks = [PIL.Image.open(mask_path).convert('L') for mask_path in mask_paths] + elif path_img_2_path_mask is not None: + masks = [PIL.Image.open(path_img_2_path_mask(path)).convert('L') for path in paths] + + # Call the img_2_grid_A function + return imgs_2_grid_A( + imgs=imgs, + masks=masks, + grid_path=grid_path, + downsample=downsample, + inv_mask=inv_mask, + resize_mode=resize_mode, + grid_layout=grid_layout, + auto_pad_if_not_same_size=auto_pad_if_not_same_size, + ) + + + +def save_any_A( + a, + path=None, # only valid when !dont_save + dont_save = False, + # log + print_info :bool = True, + value_range: tuple = None, # (min, max) tuple to specify value range, if None then auto determine +): + """ + can auto determine or specify by param: + data shape mode: + ...,1/3/4,h,w ; ...,h,w,1/3/4 ; + value range: + 0-1 ; -1~1 ; 0-255 + + after scaling to 0-255, save a grid containing two images: + scaled image + contrast-adjusted scaled image via linear transform so min=0 and max=255 + """ + a:np.ndarray = to__image_in_npArr(a) + a = a.copy() + if print_info: + import torch; from .torch_util import custom_repr_v3 + print(custom_repr_v3(torch.Tensor(a))) + while(a.ndim>3): + a=a[0] + #-----------now a is chw | hwc -------------------------------------------------------- + if a.ndim > 2: + if a.shape[-3] <= 4: + if a.shape[-3] <= a.shape[-1] and a.shape[-3] <= a.shape[-2]: + # assume the -3 axis is the channel dimension; convert chw -> hwc + a = a.transpose(1, 2, 0) # chw -> hwc + else: # ndim==2 + a = np.expand_dims(a, axis=-1) # hw -> hwc + #-----------now a is hwc -------------------------------------------------------- + if value_range is None: # Auto determine + mean = np.mean(a) + std = np.std(a) + min_ = np.min(a) + max_ = np.max(a) + if a.dtype == np.uint8 or a.dtype == np.int32 or a.dtype == np.int64: + range_ = (0, 255) + elif a.dtype == bool: + range_ = (0, 1) + elif max_ > 100: + range_ = (0, 255) + elif mean > 1: + range_ = (0, 255) + elif min_ <= -1 or mean < 0 : # treat as range -1 to 1 + range_ = (-1, 1) + else: # treat as range 0 to 1 + range_ = (0, 1) + print(f"Auto determined {range_=}") + else: + range_ = value_range + range_min, range_max = range_ + if a.dtype == bool: + a = a.astype(np.uint8) * 255 # bool -> 0/255 + else: + if range_min == 0 and range_max == 255: + pass + else: + # Custom range, normalize to 0~255 + a = (a - range_min) / (range_max - range_min) * 255 + #-----------now a is hwc and scaled to 0~255 -------------------------------------------------------- + if a.shape[-1] == 1: + a = np.repeat(a, 3, axis=-1) + #-----------now a is hwc, 0~255, and channels==3/4 -------------------------------------------------------- + + if 1: # create contrast-adjusted version by linearly mapping min to 0 and max to 255 + a_contrast = a.copy().astype(np.float32) + current_min = np.min(a_contrast) + current_max = np.max(a_contrast) + if current_max > current_min: # avoid division by zero + a_contrast = (a_contrast - current_min) / (current_max - current_min) * 255 + a = np.clip(a, 0, 255).astype(np.uint8) + a_contrast = np.clip(a_contrast, 0, 255).astype(np.uint8) + if dont_save: + path = None + else: + if path is None: + save_dir = Path("/tmp/scy_auto_save") + save_dir.mkdir(exist_ok=True) + import time + timestamp = int(time.time() * 1000) # milliseconds for uniqueness + ext = "jpg" if a.shape[-1] <= 3 else "png" # Use jpg by default if num channels <= 3 + path = save_dir / f"auto_{timestamp}.{ext}" + else: + path = Path(path) + path.parent.mkdir(exist_ok=True) + path = str(path) + grid = imgs_2_grid_A( # create grid with 2 images: original scaled + contrast adjusted + imgs=[a, a_contrast], + grid_path=path, + grid_layout="row", + verbose = -1, + ) + if not dont_save: print(f"{path}") + return grid diff --git a/my_py_lib/misc_util.py b/my_py_lib/misc_util.py new file mode 100644 index 0000000000000000000000000000000000000000..dbe38b611a50042070e08cecf89003ad9862a517 --- /dev/null +++ b/my_py_lib/misc_util.py @@ -0,0 +1,263 @@ + +import os,time +import numpy as np +from pathlib import Path +class ch_cwd_to_this_file: + def __init__(self, _code_file_path): # _code_file_path typically receives __file__ + self._code_file_path = _code_file_path + def __enter__(self): + self._old_dir = os.getcwd() + cwd=os.path.dirname(os.path.abspath(self._code_file_path)) + os.chdir(cwd) + def __exit__(self, exc_type, exc_val, exc_tb): + os.chdir(self._old_dir) +# def img_2_img_full_path(img,format='jpg',original_name_or_path=''): +# """ +# thread safe +# """ +# assert isinstance(img,np.ndarray) +# assert img.shape[2]==3 or img.shape[2]==4 +# original_img_name_without_dir=os.path.basename(original_name_or_path) +# full_path = os.path.join(root_config.path_root, f'./tmp_images/[{root_config.DATASET}][{tmp_cate_or_obj}][{sequence_name}]{img_name_without_suffix}.jpg') +# if not os.path.exists(os.path.dirname(full_path)): +# os.makedirs(os.path.dirname(full_path)) +# print("get_data path:", full_path) +# img.save(full_path) +# return img_full_path + +import datetime +import pytz +def beijing_datetime()->datetime.datetime: + """ + Example: print(f'Current Beijing time = {beijing_time:%Y.%m.%d %H:%M:%S}') + """ + # get the local timezone + local_tz = datetime.datetime.now(datetime.timezone.utc).astimezone().tzinfo + # get Beijing timezone + beijing_tz = pytz.timezone('Asia/Shanghai') + # get the current time + now = datetime.datetime.now() + # convert current time to local timezone + local_time = now.astimezone(local_tz) + # convert local time to Beijing timezone + beijing_time:datetime.datetime = local_time.astimezone(beijing_tz) + return beijing_time +def beijing_str_A( os_is_windows=False)->str: + """ + print( beijing_str_A() ) + """ + ret= f"{beijing_datetime():%m.%d-%H:%M:%S}" + if os_is_windows: + ret=ret.replace(':',':') + return ret + + + + + +# convert numpy or tensor to json/dict +import json +import numpy +import PIL +import torch +from torch import Tensor + + +def to_list_to_primitive(obj): + if isinstance(obj, numpy.ndarray): + return obj.tolist() + if isinstance(obj, torch.Tensor): + return obj.cpu().data.numpy().tolist() + if isinstance(obj, list): + return [to_list_to_primitive(i) for i in obj] + # if isinstance(obj, DataFrame): + # return obj.values.tolist() + elif (isinstance(obj, numpy.int32) or + isinstance(obj, numpy.int64) or + isinstance(obj, numpy.float32) or + isinstance(obj, numpy.float64)): + return obj.item() + elif (isinstance(obj, int) or + isinstance(obj, float) + ): + return obj + else: + raise TypeError("got {}".format(type(obj))) +def to_ndarray(x): + if isinstance(x, numpy.ndarray): + return x + if isinstance(x, torch.Tensor): + return x.cpu().data.numpy() + if isinstance(x, list): + return numpy.array(x) + if isinstance(x, PIL.Image.Image): + return numpy.array(x) + # if isinstance(x, int) or isinstance(x, float): + # return numpy.array([x]) + raise TypeError("got {}".format(type(x))) + +def to_tensor(x): + if isinstance(x, numpy.ndarray): + return torch.from_numpy(x) + if isinstance(x, torch.Tensor): + return x + if isinstance(x, PIL.Image.Image): + return torch.from_numpy(numpy.array(x)) + if isinstance(x, list): + return torch.tensor(x) + # if isinstance(x, int) or isinstance(x, float): + # return torch.tensor([x]) + raise TypeError("got {}".format(type(x))) +def to_pil(x): + import torch + if isinstance(x, PIL.Image.Image): + return x + if isinstance(x, numpy.ndarray): + return PIL.Image.fromarray(x) + if isinstance(x, torch.Tensor): + return PIL.Image.fromarray(x.cpu().data.numpy()) + raise TypeError("got {}".format(type(x))) + + +class myJSONEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, numpy.ndarray): + return obj.tolist() + if isinstance(obj, Tensor): + return obj.cpu().data.numpy().tolist() + elif (isinstance(obj, numpy.int32) or + isinstance(obj, numpy.int64) or + isinstance(obj, numpy.float32) or + isinstance(obj, numpy.float64)): + return obj.item() + elif isinstance(obj,Path): + return str(obj) + return json.JSONEncoder.default(self, obj) + +if(__name__=="__main__"): + import torch + dic = {'x': torch.randn(2, 3), 'rec': numpy.array([[11, 22, 33], [44, 55, 66], [77, 88, 99]])} + s_dic=json.dumps(dic , cls=myJSONEncoder, + sort_keys=True, indent=2, + separators=(',', ': '), ensure_ascii=False) + with open('test.json', 'w', encoding='utf8') as f: + json.dump(dic,f, + # sort_keys=True, + sort_keys=False, + indent=2, separators=(',', ': '), ensure_ascii=False) + + + +def truncate_str(string:str,MAX_LEN:int,suffix_if_truncate="......")->str: + assert isinstance(string,str) + if len(string)> MAX_LEN: + string=string[:MAX_LEN]+suffix_if_truncate + return string +def map_string_to_int(string,MIN,MAX): + """ + Map strings evenly into [MIN, MAX] + """ + assert isinstance(MIN,int) + assert isinstance(MAX,int) + assert MAX-MIN>=2 + # compute ASCII sum + sum = 0 + for char in string: + sum += ord(char) + # print("sum", sum) + ret=2**sum + ret += sum # avoid producing only powers of two + ret=ret%(MAX-MIN) + ret+=MIN + return ret +if 0: + import pprint + def print_optimizer(optimizer): + state_dict=optimizer.state_dict() + param_groups=state_dict['param_groups'] + # for i,param_group in enumerate(param_groups): + pprint.pprint(param_groups) + + +def dic_key_str_2_int(dic: dict) -> dict: + ret = {} + for k, v in dic.items(): + if isinstance(k, str) and k.isdigit(): + k = int(k) + ret[k] = v + return ret +def dic_key_str_2_int__nested(dic: dict) -> dict: + ret = {} + for k, v in dic.items(): + if isinstance(k, str) and k.isdigit(): + k = int(k) + if isinstance(v, dict): + v = dic_key_str_2_int__nested(v) + ret[k] = v + return ret +def dic_list_2_tuple_nested(dic: dict) -> dict:#if k,v is list, to tuple + ret = {} + for k, v in dic.items(): + if isinstance(k, list): + k = tuple(k) + if isinstance(v, list): + v = tuple(v) + if isinstance(v, dict): + v = dic_list_2_tuple_nested(v) + ret[k] = v + return ret + + +import re + +def inverse_fstring(string:str,fmt:str,): + """ + Inverse of string format in python + from https://stackoverflow.com/questions/48536295/inverse-of-string-format-in-python + """ + reg_keys = '{([^{}:]+)[^{}]*}' + reg_fmts = '{[^{}:]+[^{}]*}' + pat_keys = re.compile(reg_keys) + pat_fmts = re.compile(reg_fmts) + + keys = pat_keys.findall(fmt) + lmts = pat_fmts.split(fmt) + temp = string + values = [] + for lmt in lmts: + if not len(lmt)==0: + value,temp = temp.split(lmt,1) + if len(value)>0: + values.append(value) + if len(temp)>0: + values.append(temp) + return dict(zip(keys,values)) +def sort_strings_asc_A(l:list,fmt:str)->list: + """ + fmt: eg. 'home/frame{d}.png' + """ + ret=sorted(l, key= lambda s:int( inverse_fstring(s, fmt )['d']) ) + return ret +from natsort import natsorted +def ls_natsort(folder,re_="*"): + folder = Path(folder) + files = list(folder.glob(re_)) + return natsorted(files ) + return natsorted(files, key=lambda x: x.name) + + + +if __name__=='__main__': + print( beijing_str_A() ) + if 1: + fmt = '{k1:}+{k2:}={k:3}' + res = '1+1=2' + print (inverse_fstring(res,fmt)) + + fmt = '{name:} {age:} {gender}' + res = 'Alice 10 F' + print (inverse_fstring(res,fmt)) + + fmt = 'Hi, {k1:}, this is {k2:}' + res = 'Hi, Alice, this is Bob' + print (inverse_fstring(res,fmt)) diff --git a/my_py_lib/print_util.py b/my_py_lib/print_util.py new file mode 100644 index 0000000000000000000000000000000000000000..700aae8b5a952efc8f0df9e9f1a5cf06a8ebc5c8 --- /dev/null +++ b/my_py_lib/print_util.py @@ -0,0 +1,38 @@ +import random +def print_randomly(a ,p=1): + """ + p is the probability of printing. + """ + if p<1: + if random.random()>=p: + return + print(a) + + +__printed_values = {} +def print_once( a, id_ ): + if id_ not in __printed_values: + print(a) + __printed_values[id_] = True + + +__printed_count = {} +def print_randomly_with_limit( + a, + id_, + p=1, + MAX_prints=5, +): + """ + p: the probability of printing + MAX_prints: the maximum number of times to allow printing of 'a' + """ + if id_ not in __printed_count: + __printed_count[id_] = 0 + if __printed_count[id_] >= MAX_prints: + return + if p < 1: + if random.random() >= p: + return + print(a) + __printed_count[id_] += 1 diff --git a/my_py_lib/torchModuleName_util.py b/my_py_lib/torchModuleName_util.py new file mode 100644 index 0000000000000000000000000000000000000000..3a30c0d63fc5948ddeef17814848d49fc176526b --- /dev/null +++ b/my_py_lib/torchModuleName_util.py @@ -0,0 +1,263 @@ + +from typing import List, Set +from natsort import natsorted +from pathlib import Path + +def pretty_print_torch_module_keys( + keys: list, + indent: int = 4, + # max_part_num: int = 3, + # max_examples: int = 2, + max_part_num: int = 2, + max_examples: int = 1, + show_counts: bool = True +) -> None: + """ + Pretty print PyTorch module keys with hierarchical grouping. + + Args: + keys: List of parameter/buffer keys from state_dict + max_part_num: Maximum number of dot-separated parts to show (0=no truncation) + indent: Number of spaces for indentation + max_examples: Maximum example keys to show per group + show_counts: Whether to show count of keys in each group + """ + # Group keys by their truncated prefix + from collections import defaultdict + groups = defaultdict(list) + for key in keys: + if max_part_num <= 0: # No truncation + groups[key].append(key) + else: + # Split into parts and rejoin the first N parts + parts = key.split('.') + prefix = '.'.join(parts[:max_part_num]) if len(parts) > max_part_num else key + groups[prefix].append(key) + + for prefix, members in sorted(groups.items()): + _s = f"{' ' * indent}{prefix}" + count_str = f" ({len(members)} keys)" if show_counts else "" + # _s += f"{count_str}:" + print(_s) + + # Show example keys (full paths) + examples = members[:max_examples] + for ex in examples: + # print(f"{' ' * (indent * 2)}- {ex[len(prefix):]}") + print(f"{' ' * (indent * 2)}{ex[len(prefix):]}") + + if len(members) > max_examples: + print(f"{' ' * (indent * 2)}... (and {len(members) - max_examples} more)") + + + + +def get_representative_moduleNames( + all_keys: List[str], + ignore_prefixes: tuple = tuple(), + keep_index: int = 0, treat_alpha_digit: bool = True) -> Set[str]: + """ + Filter state dict keys to keep only representative items (specific index in any numbered sequence). + Args: + all_keys: List of all keys from state_dict (all are leaf nodes) + eg. ['learnable_vector', 'model.diffusion_model.time_embed.0.weight', 'model.diffusion_model.time_embed.0.bias', + keep_index: Which index to keep when multiple numbered items exist (default 0 for first) + treat_alpha_digit: If True, also treat letter+digit combinations (e.g., 'attn1', 'attn2') as numbered sequences + Returns: + Set of filtered keys preserving only representative items + """ + import re + if ignore_prefixes: + all_keys = [k for k in all_keys if not any(k.startswith(p) for p in ignore_prefixes)] + num_pattern = re.compile(r'\.(\d+)\.') # Pattern to match numbers in paths (e.g., '.0.', '.1.', etc.) + # Group keys by their pattern (replace numbers with X for grouping) + from collections import defaultdict + groups = defaultdict(list) + + for key in all_keys: + # Create a pattern by replacing all numbers with 'X' + pattern = re.sub(r'\.(\d+)\.', '.X.', key) + # Also handle numbers at the end of the key + pattern = re.sub(r'\.(\d+)$', '.X', pattern) + + if treat_alpha_digit: + # Also replace letter+digit combinations (e.g., 'attn1' -> 'attnX') + pattern = re.sub(r'\.([a-zA-Z]+)(\d+)\.', r'.\1X.', pattern) + pattern = re.sub(r'\.([a-zA-Z]+)(\d+)$', r'.\1X', pattern) + + groups[pattern].append(key) + # print(f"Debug groups: {groups}") + + filtered_keys = [] + for pattern, keys_in_group in groups.items(): + if len(keys_in_group) == 1: + # Only one key in this pattern group - keep it + filtered_keys.extend(keys_in_group) + else: + # Multiple keys - find the one with the target index + def get_numeric_indices(key): + # Extract all numeric indices from the key (pure numbers) + matches = re.findall(r'\.(\d+)(?:\.|$)', key) + indices = [int(x) for x in matches] + + if treat_alpha_digit: + # Also extract indices from letter+digit combinations + alpha_digit_matches = re.findall(r'\.([a-zA-Z]+)(\d+)(?:\.|$)', key) + for _, digit in alpha_digit_matches: + indices.append(int(digit)) + + return tuple(indices) + + # Sort by numeric indices + keys_in_group.sort(key=get_numeric_indices) + + # Try to find the key with the desired index + target_found = False + for key in keys_in_group: + if treat_alpha_digit: + # For alpha+digit mode, check if any alpha+digit combination has the target index + alpha_digit_matches = re.findall(r'\.([a-zA-Z]+)(\d+)(?:\.|$)', key) + for prefix, digit in alpha_digit_matches: + if int(digit) == keep_index: + filtered_keys.append(key) + target_found = True + break + if target_found: + break + else: + # For normal mode, check pure numeric indices + indices = get_numeric_indices(key) + # Check if the first (primary) index matches keep_index + if indices and indices[0] == keep_index: + filtered_keys.append(key) + target_found = True + break + + # If target index not found, fall back to the first available + if not target_found: + filtered_keys.append(keys_in_group[0]) + + filtered_keys = natsorted(filtered_keys) + return filtered_keys + +def get_no_grad_and_has_grad_keys( + model, only_representative: bool = True, + ignore_prefixes: tuple = tuple(), + verbose: int = 1, # for print (not for file save. for save, we log all ) 0,1: only print at last, 2: print at each step + get_representative_moduleNames_at_first :bool = False, + save_path: str = None, # if not None, save detailed log to file +): + # don't use state_dict() (it lacks gradient information) + all_params = dict(model.named_parameters()) + keys = list(all_params.keys()) + + # For file logging, collect all messages + log_messages = [] + + def print_(*msg, verb=1): + if verbose >= verb: + print(*msg) + if save_path is not None: + log_messages.extend(msg) + + if only_representative and get_representative_moduleNames_at_first: + keys = get_representative_moduleNames(keys, ignore_prefixes=ignore_prefixes) + + k_has_grad = [] + k_no_grad = [] # dont require grad or .grad is 0 + + for name in keys: + if name not in all_params: + print_(f"{name} not found in named_parameters (might be buffer)", verb=3) + k_no_grad.append(name) + continue + + param = all_params[name] + if param.requires_grad: + if param.grad is None: + print_(f"{name} has grad but grad is None", verb=3) + k_no_grad.append(name) + elif param.grad.sum() == 0: + print_(f"{name} has grad but grad is 0", verb=3) + k_no_grad.append(name) + else: + print_(f"{name} has grad !=0", verb=4) + k_has_grad.append(name) + else: + k_no_grad.append(name) + if only_representative and not get_representative_moduleNames_at_first: + k_no_grad = get_representative_moduleNames(k_no_grad, ignore_prefixes=ignore_prefixes) + k_has_grad = get_representative_moduleNames(k_has_grad, ignore_prefixes=ignore_prefixes) + + print_("No grad:", verb=2) + for name in k_no_grad: + print_(f" - {name}", verb=2) + print_("Has grad:", verb=2) + if 0: + print_("", verb=2) + else: + for name in k_has_grad: + print_(f" - {name}", verb=2) + print_(f"Total: {len(k_no_grad) + len(k_has_grad)} {len(k_has_grad)=}", verb=1) + + if save_path is not None: + Path(save_path).write_text('\n'.join(log_messages), encoding='utf-8') # !diskW + print(f"> {save_path}") + + return k_has_grad, k_no_grad + + + + + + +if __name__=='__main__': + # Example usage: + all_keys = [ + 'face_ID_model.facenet.input_layer.0.weight', + 'face_ID_model.facenet.input_layer.1.weight', + 'face_ID_model.facenet.input_layer.1.bias', + 'face_ID_model.facenet.input_layer.1.running_mean', + 'face_ID_model.facenet.input_layer.1.running_var', + 'face_ID_model.facenet.input_layer.1.num_batches_tracked', + 'face_ID_model.facenet.input_layer.2.weight', + + 'learnable_vector', + 'model.diffusion_model_refNet.time_embed.0.weight', + 'model.diffusion_model_refNet.time_embed.0.weight.xxx', + 'model.diffusion_model_refNet.time_embed.0.bias', + 'model.diffusion_model_refNet.time_embed.0.xxxx.0', + 'model.diffusion_model_refNet.time_embed.0.xxxx.1', + 'model.diffusion_model_refNet.time_embed.0.xxxx.2', + + 'model.diffusion_model_refNet.time_embed.1.weight', + 'model.diffusion_model_refNet.time_embed.1.bias', + 'model.diffusion_model_refNet.time_embed.0.submodule.param', + 'model.diffusion_model_refNet.time_embed.1.submodule.param', + 'model.diffusion_model_refNet.input_blocks.0.weight', + 'model.diffusion_model_refNet.input_blocks.1.weight', + 'model.diffusion_model_refNet.middle_block.0.weight', + 'model.diffusion_model_refNet.output_blocks.0.bias', + 'model.diffusion_model_refNet.output_blocks.1.bias', + 'model.diffusion_model_refNet.output_blocks.2.bias', + + 'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_k.weight', + 'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.bias', + 'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.weight', + 'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_q.weight', + 'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_v.weight', + 'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight', + 'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.bias', + 'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q.weight', + 'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn3.xxxxx', + ] + + + import torch + sd = torch.load('checkpoints/pretrained.ckpt') + all_keys = sd['state_dict'].keys() + filtered = get_representative_moduleNames(all_keys) + print(f"Filtered representative keys (keep_index=0, default):") + for key in sorted(filtered): + print(f" - {key}") + diff --git a/my_py_lib/torch_util.py b/my_py_lib/torch_util.py new file mode 100644 index 0000000000000000000000000000000000000000..83b4a0177a216e8f4914d8ba28f4b0ee3be9b5aa --- /dev/null +++ b/my_py_lib/torch_util.py @@ -0,0 +1,85 @@ +import torch +def count_model_params(model, log=False)->int: + total_params = sum(p.numel() for p in model.parameters()) + if log: + print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") + return total_params +def list_layers(model): + """ + Lists each layer's name, type, and parameter size in a PyTorch model. + """ + layers = [] + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Sequential): + continue # Skip sequential layers + + layer_info = {} + layer_info["name"] = name + layer_info["type"] = str(type(module)) + + params = sum(p.numel() for p in module.parameters(recurse=False) if p.requires_grad) + layer_info["params"] = params + + layers.append(layer_info) + + return layers + +def recursive_to(data: dict, device: torch.device) -> dict: + """Recursively move all tensors in a nested structure to the target device.""" + for key, value in data.items(): + if isinstance(value, torch.Tensor): + data[key] = value.to(device, non_blocking=True) + elif isinstance(value, dict): + data[key] = recursive_to(value, device) + return data + +def cleanup_gpu_memory(): + import gc + if torch.cuda.is_available(): + gc.collect() # Force garbage collection + torch.cuda.empty_cache() + torch.cuda.synchronize() + gc.collect() + # Clear any remaining cached allocations + if hasattr(torch.cuda, 'reset_peak_memory_stats'): + torch.cuda.reset_peak_memory_stats() + print(f"GPU memory cleaned. Allocated: {torch.cuda.memory_allocated()/1024**3:.2f}GB, " + f"Cached: {torch.cuda.memory_reserved()/1024**3:.2f}GB") + +def custom_repr_v3(self): + stats = [] + if self.numel() > 0: + dtype_str = str(self.dtype).replace('torch.', '') + stats.append(dtype_str) + stats.append(f"μ={self.float().mean().item():.2f}") + stats.append(f"{self.min().item():.2f}~{self.max().item():.2f}") + stats.append(f"med={self.float().median().item():.2f}") + if 1 : + uniques = torch.unique(self.flatten()) + if len(uniques) <= 6: + stats.append(f"uniq={uniques.tolist()}") + else: + stats.append(f"uniq=[{uniques[0].item():.2f},...,{uniques[-1].item():.2f}]") + return f'' + +def to_device(obj, device, *args, **kwargs): + """ + Recursively moves tensors in a nested structure to the specified device, + + Args: + device: The target PyTorch device (e.g., 'cuda:0' or 'cpu'). + *args: + **kwargs: Keyword arguments to be passed to the tensor.to() method + (e.g., non_blocking=True). + + Returns: + The object with all tensors moved to the specified device. + """ + if torch.is_tensor(obj): # Pass the device and any additional arguments to the .to() method + return obj.to(device, *args, **kwargs) + elif isinstance(obj, dict): # Recursively call to_device on each value in the dictionary + return {k: to_device(v, device, *args, **kwargs) for k, v in obj.items()} + elif isinstance(obj, list): # Recursively call to_device on each element in the list + return [to_device(elem, device, *args, **kwargs) for elem in obj] + else: # Return the object unchanged if it's not a tensor, dict, or list + return obj \ No newline at end of file diff --git a/pretrained/face_parsing/__init__.py b/pretrained/face_parsing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pretrained/face_parsing/face_parsing_demo.py b/pretrained/face_parsing/face_parsing_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..51378400e40bc084c614ed275e0120d66b20d124 --- /dev/null +++ b/pretrained/face_parsing/face_parsing_demo.py @@ -0,0 +1,341 @@ +### Facial segmentation mask estimation + +import numpy as np +from PIL import Image +import torch +from torch import nn +import cv2 +import os +import torchvision +from torch.nn import functional as F + + +from pretrained.face_parsing.model import BiSeNet, seg_mean, seg_std + +def __celebAHQ_masks_to_faceParser_mask_detailed(celebA_mask): + """Convert the semantic image of CelebAMaskHQ to reduced categories (12-class). + + Args: + mask (PIL image): with shape [H,W] + Return: + aggrigated mask, with same shape [H,W] but the number of segmentation classes is less + """ + # 19 attributes in total, skin-1,nose-2,...cloth-18, background-0 + celelbAHQ_label_list = ['skin', 'nose', 'eye_g', 'l_eye', 'r_eye', + 'l_brow', 'r_brow', 'l_ear', 'r_ear', 'mouth', + 'u_lip', 'l_lip', 'hair', 'hat', 'ear_r', + 'neck_l', 'neck', 'cloth']# 12 attributes with left-right aggrigation + faceParser_label_list_detailed = ['background', 'lip', 'eyebrows', 'eyes', 'hair', + 'nose', 'skin', 'ears', 'belowface', 'mouth', + 'eye_glass', 'ear_rings'] + + converted_mask = np.zeros_like(celebA_mask) + + backgorund = np.equal(celebA_mask, 0) + converted_mask[backgorund] = 0 + + lip = np.logical_or(np.equal(celebA_mask, 11), np.equal(celebA_mask, 12)) + converted_mask[lip] = 1 + + eyebrows = np.logical_or(np.equal(celebA_mask, 6), + np.equal(celebA_mask, 7)) + converted_mask[eyebrows] = 2 + + eyes = np.logical_or(np.equal(celebA_mask, 4), np.equal(celebA_mask, 5)) + converted_mask[eyes] = 3 + + hair = np.equal(celebA_mask, 13) + converted_mask[hair] = 4 + + nose = np.equal(celebA_mask, 2) + converted_mask[nose] = 5 + + skin = np.equal(celebA_mask, 1) + # print('skin', np.sum(skin)) + converted_mask[skin] = 6 + + ears = np.logical_or(np.equal(celebA_mask, 8), np.equal(celebA_mask, 9)) + converted_mask[ears] = 7 + + belowface = np.equal(celebA_mask, 17) + converted_mask[belowface] = 8 + + mouth = np.equal(celebA_mask, 10) + converted_mask[mouth] = 9 + + eye_glass = np.equal(celebA_mask, 3) + converted_mask[eye_glass] = 10 + + ear_rings = np.equal(celebA_mask, 15) + converted_mask[ear_rings] = 11 + + return converted_mask + +def __ffhq_masks_to_faceParser_mask_detailed(mask, label_mode): + """Convert the esitimated semantic image by face-parsing.PyTorch to reduced categories (12-class). + + Args: + mask (PIL image): with shape [H,W] + Return: + aggrigated mask, with same shape [H,W] but the number of segmentation classes is less + """ + + converted_mask = np.zeros_like(mask) + + backgorund = np.equal(mask, 0) + converted_mask[backgorund] = 0 + + lip = np.logical_or(np.equal(mask, 12), np.equal(mask, 13)) + converted_mask[lip] = 1 + + eyebrows = np.logical_or(np.equal(mask, 2), + np.equal(mask, 3)) + converted_mask[eyebrows] = 2 + + eyes = np.logical_or(np.equal(mask, 4), np.equal(mask, 5)) + converted_mask[eyes] = 3 + + hair = np.equal(mask, 17) + converted_mask[hair] = 4 + + nose = np.equal(mask, 10) + converted_mask[nose] = 5 + + skin = np.equal(mask, 1) + converted_mask[skin] = 6 + + ears = np.logical_or(np.equal(mask, 7), np.equal(mask, 8)) + converted_mask[ears] = 7 + + belowface = np.equal(mask, 14) + converted_mask[belowface] = 8 + + mouth = np.equal(mask, 11) + converted_mask[mouth] = 9 + + eye_glass = np.equal(mask, 6) + converted_mask[eye_glass] = 10 + + ear_rings = np.equal(mask, 9) + converted_mask[ear_rings] = 11 + + if label_mode=="RF12_": # for RF12_ (non-RF12 modes count hat/cloth as background via `converted_mask = np.zeros_like(mask)`) + max_label = int(np.max(mask)) + # print(f"__ff {max_label=}") + # cloth -> 20, hat -> 21 + cloth = np.equal(mask, 16) + converted_mask[cloth] = 20 + hat = np.equal(mask, 18) + converted_mask[hat] = 21 + return converted_mask + +class BicubicDownSample(nn.Module): + def bicubic_kernel(self, x, a=-0.50): + """ + This equation is exactly copied from the website below: + https://clouard.users.greyc.fr/Pantheon/experiments/rescaling/index-en.html#bicubic + """ + abs_x = torch.abs(x) + if abs_x <= 1.: + return (a + 2.) * torch.pow(abs_x, 3.) - (a + 3.) * torch.pow(abs_x, 2.) + 1 + elif 1. < abs_x < 2.: + return a * torch.pow(abs_x, 3) - 5. * a * torch.pow(abs_x, 2.) + 8. * a * abs_x - 4. * a + else: + return 0.0 + + def __init__(self, factor=4, cuda=True, padding='reflect'): + super().__init__() + self.factor = factor + size = factor * 4 + k = torch.tensor([self.bicubic_kernel((i - torch.floor(torch.tensor(size / 2)) + 0.5) / factor) + for i in range(size)], dtype=torch.float32) + k = k / torch.sum(k) + # k = torch.einsum('i,j->ij', (k, k)) + k1 = torch.reshape(k, shape=(1, 1, size, 1)) + self.k1 = torch.cat([k1, k1, k1], dim=0) + k2 = torch.reshape(k, shape=(1, 1, 1, size)) + self.k2 = torch.cat([k2, k2, k2], dim=0) + self.cuda = '.cuda' if cuda else '' + self.padding = padding + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x, nhwc=False, clip_round=False, byte_output=False): + # x = torch.from_numpy(x).type('torch.FloatTensor') + filter_height = self.factor * 4 + filter_width = self.factor * 4 + stride = self.factor + + pad_along_height = max(filter_height - stride, 0) + pad_along_width = max(filter_width - stride, 0) + filters1 = self.k1.type('torch{}.FloatTensor'.format(self.cuda)) + filters2 = self.k2.type('torch{}.FloatTensor'.format(self.cuda)) + + # compute actual padding values for each side + pad_top = pad_along_height // 2 + pad_bottom = pad_along_height - pad_top + pad_left = pad_along_width // 2 + pad_right = pad_along_width - pad_left + + # apply mirror padding + if nhwc: + x = torch.transpose(torch.transpose( + x, 2, 3), 1, 2) # NHWC to NCHW + + # downscaling performed by 1-d convolution + x = F.pad(x, (0, 0, pad_top, pad_bottom), self.padding) + x = F.conv2d(input=x, weight=filters1, stride=(stride, 1), groups=3) + if clip_round: + x = torch.clamp(torch.round(x), 0.0, 255.) + + x = F.pad(x, (pad_left, pad_right, 0, 0), self.padding) + x = F.conv2d(input=x, weight=filters2, stride=(1, stride), groups=3) + if clip_round: + x = torch.clamp(torch.round(x), 0.0, 255.) + + if nhwc: + x = torch.transpose(torch.transpose(x, 1, 3), 1, 2) + if byte_output: + return x.type('torch.ByteTensor'.format(self.cuda)) + else: + return x + + +def vis_parsing_maps(image, parsing_anno, stride=1): + """ Visualize the seg map, along with the original RGB image + + args: + img (PIL.Image): [0, 255] PIL.Image + parsing_anno (np.array): seg map, size [512, 512] + return: + vis_im (np.array): visualization image, cv2 format + """ + # Colors for all 20 parts + part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], + [255, 0, 85], [255, 0, 170], + [0, 255, 0], [85, 255, 0], [170, 255, 0], + [0, 255, 85], [0, 255, 170], + [0, 0, 255], [85, 0, 255], [170, 0, 255], + [0, 85, 255], [0, 170, 255], + [255, 255, 0], [255, 255, 85], [255, 255, 170], + [255, 0, 255], [255, 85, 255], [255, 170, 255], + [0, 255, 255], [85, 255, 255], [170, 255, 255]] + + im = image.resize((parsing_anno.shape[0], parsing_anno.shape[1]), Image.BILINEAR) + im = np.array(im) + vis_im = im.copy().astype(np.uint8) + vis_parsing_anno = parsing_anno.copy().astype(np.uint8) + vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST) + vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255 + + num_of_class = np.max(vis_parsing_anno) + + if 0: + len_= num_of_class + 1 + delta_ = 256/len_ + part_colors = [ [round(i*delta_), round(i*delta_), round(i*delta_)] for i in range(len_) ] + print(f"{part_colors[0]=}") + print(f"{part_colors[-1]=}") + if 0: + print(f"{len(part_colors)=}") + print(f'{np.unique(vis_parsing_anno)=}') + for pi in range(1, num_of_class + 1): + index = np.where(vis_parsing_anno == pi) + vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi] + + vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8) + # print(vis_parsing_anno_color.shape, vis_im.shape) + vis_im = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0) + # vis_im = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.1, vis_parsing_anno_color, 0.9, 0) + + return vis_im + + +class FaceParser(nn.Module): + + def __init__(self, seg_ckpt, size = 1024, device="cuda"): + super(FaceParser, self).__init__() + self.seg_ckpt = seg_ckpt + self.size = size + self.device = device + + self.load_segmentation_network() + self.load_downsampling() + + def load_downsampling(self): + self.downsample = BicubicDownSample(factor=self.size // 512) + self.downsample_256 = BicubicDownSample(factor=self.size // 256) + + def load_segmentation_network(self): + self.seg = BiSeNet(n_classes=19) + self.seg.to(self.device) + + self.seg.load_state_dict(torch.load(self.seg_ckpt)) + for param in self.seg.parameters(): + param.requires_grad = False + self.seg.eval() + + def preprocess_img(self, img): + img_orig = img + if img_orig.size[0] >= 512: + im = torchvision.transforms.ToTensor()(img_orig)[:3].unsqueeze(0).to(self.device) + im = (self.downsample(im).clamp(0, 1) - seg_mean) / seg_std + else: + im = img_orig.resize((512, 512), Image.BILINEAR) + im = torchvision.transforms.ToTensor()(im)[:3].unsqueeze(0).to(self.device) + im = (im.clamp(0, 1) - seg_mean) / seg_std + return im + + def forward(self, img): + """To esitimate the facial mask for the given image + Args: + img (PIL.Image): [0, 255] PIL.Image + """ + im = self.preprocess_img(img) + down_seg, _, _ = self.seg(im) + seg = torch.argmax(down_seg, dim=1)[0].long() + + # print(np.unique(seg)) + # cv2.imwrite(os.path.join("./tmp", "mask"+os.path.basename(img_path)), seg.astype(np.uint8)) + # vis_parsing_maps(img_path, seg, stride=1, save_im=True, save_path=os.path.join("./tmp", os.path.basename(img_path))) + + return seg + + +# =============================================== +def init_faceParsing_pretrained_model(faceParser_name, ckpt_path, config_path = ""): + if faceParser_name == "default": + parser = FaceParser(seg_ckpt=ckpt_path) + elif faceParser_name == "segnext": + from mmseg.apis import init_segmentor + parser = init_segmentor(config_path, ckpt_path, device='cuda:0') + + return parser + +def faceParsing_demo(model, img, label_mode, model_name = "default"): + """ + args: + model (Object): Loaded pretrained model + img (PIL.Image): [0, 255] PIL.Image + """ + convert_to_seg12 = True + if label_mode in ("RF12","RF12_"): + assert model_name == "default" + else: + raise + with torch.no_grad(): + if model_name == "default": + seg = model(img).cpu().numpy().astype(np.uint8) + if convert_to_seg12: + seg = __ffhq_masks_to_faceParser_mask_detailed(seg,label_mode) + + elif model_name == "segnext": + from mmseg.apis import inference_segmentor + bgr_img = np.array(img)[:,:,::-1] + seg = inference_segmentor(model, bgr_img)[0] + seg = seg.astype(np.uint8) + + if convert_to_seg12: + seg = __celebAHQ_masks_to_faceParser_mask_detailed(seg) + + return seg diff --git a/pretrained/face_parsing/model.py b/pretrained/face_parsing/model.py new file mode 100644 index 0000000000000000000000000000000000000000..9730fa537b3490f5398460b086383ed9890c3af3 --- /dev/null +++ b/pretrained/face_parsing/model.py @@ -0,0 +1,290 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision +import numpy as np + +from pretrained.face_parsing.resnet import Resnet18 +# from modules.bn import InPlaceABNSync as BatchNorm2d + + +seg_mean = torch.from_numpy(np.array([[0.485, 0.456, 0.406]])).float().cuda().reshape(1,3,1,1) +seg_std = torch.from_numpy(np.array([[0.229, 0.224, 0.225]])).float().cuda().reshape(1,3,1,1) +seg_criterion = nn.CrossEntropyLoss() + + +class ConvBNReLU(nn.Module): + def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs): + super(ConvBNReLU, self).__init__() + self.conv = nn.Conv2d(in_chan, + out_chan, + kernel_size = ks, + stride = stride, + padding = padding, + bias = False) + self.bn = nn.BatchNorm2d(out_chan) + self.init_weight() + + def forward(self, x): + x = self.conv(x) + x = F.relu(self.bn(x)) + return x + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + +class BiSeNetOutput(nn.Module): + def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs): + super(BiSeNetOutput, self).__init__() + self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) + self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False) + self.init_weight() + + def forward(self, x): + x = self.conv(x) + x = self.conv_out(x) + return x + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +class AttentionRefinementModule(nn.Module): + def __init__(self, in_chan, out_chan, *args, **kwargs): + super(AttentionRefinementModule, self).__init__() + self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) + self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False) + self.bn_atten = nn.BatchNorm2d(out_chan) + self.sigmoid_atten = nn.Sigmoid() + self.init_weight() + + def forward(self, x): + feat = self.conv(x) + atten = F.avg_pool2d(feat, feat.size()[2:]) + atten = self.conv_atten(atten) + atten = self.bn_atten(atten) + atten = self.sigmoid_atten(atten) + out = torch.mul(feat, atten) + return out + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + +class ContextPath(nn.Module): + def __init__(self, *args, **kwargs): + super(ContextPath, self).__init__() + self.resnet = Resnet18() + self.arm16 = AttentionRefinementModule(256, 128) + self.arm32 = AttentionRefinementModule(512, 128) + self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) + self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) + self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) + + self.init_weight() + + def forward(self, x): + H0, W0 = x.size()[2:] + feat8, feat16, feat32 = self.resnet(x) + H8, W8 = feat8.size()[2:] + H16, W16 = feat16.size()[2:] + H32, W32 = feat32.size()[2:] + + avg = F.avg_pool2d(feat32, feat32.size()[2:]) + avg = self.conv_avg(avg) + avg_up = F.interpolate(avg, (H32, W32), mode='nearest') + + feat32_arm = self.arm32(feat32) + feat32_sum = feat32_arm + avg_up + feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest') + feat32_up = self.conv_head32(feat32_up) + + feat16_arm = self.arm16(feat16) + feat16_sum = feat16_arm + feat32_up + feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest') + feat16_up = self.conv_head16(feat16_up) + + return feat8, feat16_up, feat32_up # x8, x8, x16 + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, (nn.Linear, nn.Conv2d)): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +### This is not used, since I replace this with the resnet feature with the same size +class SpatialPath(nn.Module): + def __init__(self, *args, **kwargs): + super(SpatialPath, self).__init__() + self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3) + self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) + self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) + self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0) + self.init_weight() + + def forward(self, x): + feat = self.conv1(x) + feat = self.conv2(feat) + feat = self.conv3(feat) + feat = self.conv_out(feat) + return feat + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +class FeatureFusionModule(nn.Module): + def __init__(self, in_chan, out_chan, *args, **kwargs): + super(FeatureFusionModule, self).__init__() + self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) + self.conv1 = nn.Conv2d(out_chan, + out_chan//4, + kernel_size = 1, + stride = 1, + padding = 0, + bias = False) + self.conv2 = nn.Conv2d(out_chan//4, + out_chan, + kernel_size = 1, + stride = 1, + padding = 0, + bias = False) + self.relu = nn.ReLU(inplace=True) + self.sigmoid = nn.Sigmoid() + self.init_weight() + + def forward(self, fsp, fcp): + fcat = torch.cat([fsp, fcp], dim=1) + feat = self.convblk(fcat) + atten = F.avg_pool2d(feat, feat.size()[2:]) + atten = self.conv1(atten) + atten = self.relu(atten) + atten = self.conv2(atten) + atten = self.sigmoid(atten) + feat_atten = torch.mul(feat, atten) + feat_out = feat_atten + feat + return feat_out + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +class BiSeNet(nn.Module): + def __init__(self, n_classes, *args, **kwargs): + super(BiSeNet, self).__init__() + self.cp = ContextPath() + ## here self.sp is deleted + self.ffm = FeatureFusionModule(256, 256) + self.conv_out = BiSeNetOutput(256, 256, n_classes) + self.conv_out16 = BiSeNetOutput(128, 64, n_classes) + self.conv_out32 = BiSeNetOutput(128, 64, n_classes) + self.init_weight() + + def forward(self, x): + # breakpoint() + H, W = x.size()[2:] + feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature + feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature + feat_fuse = self.ffm(feat_sp, feat_cp8) + + feat_out = self.conv_out(feat_fuse) + feat_out16 = self.conv_out16(feat_cp8) + feat_out32 = self.conv_out32(feat_cp16) + + feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True) + feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True) + feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True) + return feat_out, feat_out16, feat_out32 + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], [] + for name, child in self.named_children(): + child_wd_params, child_nowd_params = child.get_params() + if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput): + lr_mul_wd_params += child_wd_params + lr_mul_nowd_params += child_nowd_params + else: + wd_params += child_wd_params + nowd_params += child_nowd_params + return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params + + +if __name__ == "__main__": + net = BiSeNet(16) + net.cuda() + net.eval() + in_ten = torch.randn(16, 3, 640, 480).cuda() + out, out16, out32 = net(in_ten) + print(out.shape) + + net.get_params() diff --git a/pretrained/face_parsing/resnet.py b/pretrained/face_parsing/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..aa2bf95130e9815ba378cb6f73207068b81a04b9 --- /dev/null +++ b/pretrained/face_parsing/resnet.py @@ -0,0 +1,109 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.model_zoo as modelzoo + +# from modules.bn import InPlaceABNSync as BatchNorm2d + +resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + def __init__(self, in_chan, out_chan, stride=1): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(in_chan, out_chan, stride) + self.bn1 = nn.BatchNorm2d(out_chan) + self.conv2 = conv3x3(out_chan, out_chan) + self.bn2 = nn.BatchNorm2d(out_chan) + self.relu = nn.ReLU(inplace=True) + self.downsample = None + if in_chan != out_chan or stride != 1: + self.downsample = nn.Sequential( + nn.Conv2d(in_chan, out_chan, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(out_chan), + ) + + def forward(self, x): + residual = self.conv1(x) + residual = F.relu(self.bn1(residual)) + residual = self.conv2(residual) + residual = self.bn2(residual) + + shortcut = x + if self.downsample is not None: + shortcut = self.downsample(x) + + out = shortcut + residual + out = self.relu(out) + return out + + +def create_layer_basic(in_chan, out_chan, bnum, stride=1): + layers = [BasicBlock(in_chan, out_chan, stride=stride)] + for i in range(bnum-1): + layers.append(BasicBlock(out_chan, out_chan, stride=1)) + return nn.Sequential(*layers) + + +class Resnet18(nn.Module): + def __init__(self): + super(Resnet18, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) + self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) + self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) + self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) + self.init_weight() + + def forward(self, x): + x = self.conv1(x) + x = F.relu(self.bn1(x)) + x = self.maxpool(x) + + x = self.layer1(x) + feat8 = self.layer2(x) # 1/8 + feat16 = self.layer3(feat8) # 1/16 + feat32 = self.layer4(feat16) # 1/32 + return feat8, feat16, feat32 + + def init_weight(self): + state_dict = modelzoo.load_url(resnet18_url) + self_state_dict = self.state_dict() + for k, v in state_dict.items(): + if 'fc' in k: continue + self_state_dict.update({k: v}) + self.load_state_dict(self_state_dict) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, (nn.Linear, nn.Conv2d)): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +if __name__ == "__main__": + net = Resnet18() + x = torch.randn(16, 3, 224, 224) + out = net(x) + print(out[0].size()) + print(out[1].size()) + print(out[2].size()) + net.get_params() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..0a17e15a08a18edf2b225acd9cb4c25fd13be948 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,37 @@ +pytorch_lightning==1.4.2 +cmake +transformers==4.19.2 +albumentations==1.1.0 +bezier==2023.7.28 +diffusers==0.30.3 +dift==0.0.7 +einops==0.4.1 +face_alignment==1.4.1 +ftfy==6.0.3 +glfw==2.7.0 +imageio==2.14.1 +kornia==0.6.0 +matplotlib==3.7.5 +more_itertools==10.5.0 +moviepy==1.0.3 +natsort==8.4.0 +numpy==1.24.3 +omegaconf==2.1.1 +opencv_python +packaging==24.1 +pandas==2.0.3 +Pillow==9.0.1 +proglog==0.1.10 +pudb==2019.2 +PyOpenGL==3.1.7 +PyYAML==6.0.2 +regex==2024.9.11 +Requests==2.32.3 +scipy==1.9.1 +setuptools==59.5.0 +scikit-image==0.20.0 +streamlit==0.73.1 +tqdm==4.66.5 +typing_extensions==4.12.2 +torchmetrics==0.6.0 +mediapipe==0.10.21 diff --git a/requirements_space.txt b/requirements_space.txt new file mode 100644 index 0000000000000000000000000000000000000000..76b76fabd36f2edd608596c3a0b4c4141050b45d --- /dev/null +++ b/requirements_space.txt @@ -0,0 +1,2 @@ +gradio>=4.0 +huggingface_hub diff --git a/setup.sh b/setup.sh new file mode 100644 index 0000000000000000000000000000000000000000..f9b8728ec61c008b4fba017b0dada4db630a61b0 --- /dev/null +++ b/setup.sh @@ -0,0 +1,6 @@ + +# install correct version of torch and torchvision according to your cuda version +pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117 +pip install -r requirements.txt +pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers +pip install -e git+https://github.com/openai/CLIP.git@main#egg=clip diff --git a/src/Face_models/encoders/__init__.py b/src/Face_models/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/Face_models/encoders/helpers.py b/src/Face_models/encoders/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..b6fce4005b9d4bdad75d74a6f4fcad248da447d8 --- /dev/null +++ b/src/Face_models/encoders/helpers.py @@ -0,0 +1,144 @@ +from collections import namedtuple +import torch +from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module, InstanceNorm2d + +""" +ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) +""" + + +class Flatten(Module): + def forward(self, input): + return input.view(input.size(0), -1) + + +def l2_norm(input, axis=1): + norm = torch.norm(input, 2, axis, True) + output = torch.div(input, norm) + return output + + +class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): + """ A named tuple describing a ResNet block. """ + + +def get_block(in_channel, depth, num_units, stride=2): + return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] + + +def get_blocks(num_layers): + if num_layers == 50: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=4), + get_block(in_channel=128, depth=256, num_units=14), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 100: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=13), + get_block(in_channel=128, depth=256, num_units=30), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 152: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=8), + get_block(in_channel=128, depth=256, num_units=36), + get_block(in_channel=256, depth=512, num_units=3) + ] + else: + raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers)) + return blocks + + +class SEModule(Module): + def __init__(self, channels, reduction): + super(SEModule, self).__init__() + self.avg_pool = AdaptiveAvgPool2d(1) + self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) + self.relu = ReLU(inplace=True) + self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) + self.sigmoid = Sigmoid() + + def forward(self, x): + module_input = x + x = self.avg_pool(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.sigmoid(x) + return module_input * x + + +class bottleneck_IR(Module): + def __init__(self, in_channel, depth, stride): + super(bottleneck_IR, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth) + ) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), + Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) + ) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut + + +class bottleneck_IR_SE(Module): + def __init__(self, in_channel, depth, stride): + super(bottleneck_IR_SE, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth) + ) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), + PReLU(depth), + Conv2d(depth, depth, (3, 3), stride, 1, bias=False), + BatchNorm2d(depth), + SEModule(depth, 16) + ) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut + + +class bottleneck_IR_SE_Ours(Module): + def __init__(self, in_channel, depth, stride): + super(bottleneck_IR_SE_Ours, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + InstanceNorm2d(depth) + ) + self.res_layer = Sequential( + InstanceNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), + PReLU(depth), + Conv2d(depth, depth, (3, 3), stride, 1, bias=False), + InstanceNorm2d(depth), + SEModule(depth, 16) + ) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut diff --git a/src/Face_models/encoders/model_irse.py b/src/Face_models/encoders/model_irse.py new file mode 100644 index 0000000000000000000000000000000000000000..67f9fa48c60176ee844a10f10c0a89c0b859f6f6 --- /dev/null +++ b/src/Face_models/encoders/model_irse.py @@ -0,0 +1,105 @@ +from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module +from src.Face_models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm + +""" +Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) +""" + + +class Backbone(Module): + def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): + super(Backbone, self).__init__() + assert input_size in [112, 224], "input_size should be 112 or 224" + assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" + assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = bottleneck_IR + elif mode == 'ir_se': + unit_module = bottleneck_IR_SE + self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), + BatchNorm2d(64), + PReLU(64)) + if input_size == 112: + self.output_layer = Sequential(BatchNorm2d(512), + Dropout(drop_ratio), + Flatten(), + Linear(512 * 7 * 7, 512), + BatchNorm1d(512, affine=affine)) + else: + self.output_layer = Sequential(BatchNorm2d(512), + Dropout(drop_ratio), + Flatten(), + Linear(512 * 14 * 14, 512), + BatchNorm1d(512, affine=affine)) + + modules = [] + for block in blocks: + for bottleneck in block: + modules.append(unit_module(bottleneck.in_channel, + bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + def forward(self, x, multi_scale=False): + x = self.input_layer(x) + + if multi_scale: + # multi-scale features + modulelist = list(self.body._modules.values()) + for i, l in enumerate(modulelist): + x = l(x) + if i == 2: + c1 = x.view(x.size(0), -1) + elif i==6: + c2 = x.view(x.size(0), -1) + elif i == 20: + c3 = x.view(x.size(0), -1) + elif i == 23: + c4 = x.view(x.size(0), -1) + else: + # single-scale processing + x = self.body(x) + + x = self.output_layer(x) + + if multi_scale: + return [l2_norm(c1),l2_norm(c2),l2_norm(c3),l2_norm(c4),l2_norm(x)] + else: + return [l2_norm(x)] + + +def IR_50(input_size): + """Constructs a ir-50 model.""" + model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_101(input_size): + """Constructs a ir-101 model.""" + model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_152(input_size): + """Constructs a ir-152 model.""" + model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_50(input_size): + """Constructs a ir_se-50 model.""" + model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_101(input_size): + """Constructs a ir_se-101 model.""" + model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_152(input_size): + """Constructs a ir_se-152 model.""" + model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) + return model diff --git a/src/Face_models/encoders/psp_encoders.py b/src/Face_models/encoders/psp_encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..baca715d404fb80f43432e1aa93c4694003a9f30 --- /dev/null +++ b/src/Face_models/encoders/psp_encoders.py @@ -0,0 +1,309 @@ +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import Linear, Conv2d, BatchNorm2d, PReLU, Sequential, Module, InstanceNorm2d + +from src.models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, bottleneck_IR_SE_Ours +from src.models.stylegan2.model import EqualLinear, EqualConv2d +from src.models.encoders.helpers import get_block + +class GradualStyleBlock(Module): + def __init__(self, in_c, out_c, spatial): + super(GradualStyleBlock, self).__init__() + self.out_c = out_c + self.spatial = spatial + num_pools = int(np.log2(spatial)) + modules = [] + modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1), + nn.LeakyReLU()] + for i in range(num_pools - 1): + modules += [ + Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1), + nn.LeakyReLU() + ] + self.convs = nn.Sequential(*modules) + self.linear = EqualLinear(out_c, out_c, lr_mul=1) + + def forward(self, x): + x = self.convs(x) + x = x.view(-1, self.out_c) + x = self.linear(x) + return x + + +class GradualStyleEncoder(Module): + def __init__(self, num_layers, mode='ir', opts=None): + super(GradualStyleEncoder, self).__init__() + assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' + assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = bottleneck_IR + elif mode == 'ir_se': + unit_module = bottleneck_IR_SE + self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), + BatchNorm2d(64), + PReLU(64)) + modules = [] + for block in blocks: + for bottleneck in block: + modules.append(unit_module(bottleneck.in_channel, + bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + self.styles = nn.ModuleList() + self.style_count = opts.n_styles + self.coarse_ind = 3 + self.middle_ind = 7 + for i in range(self.style_count): + if i < self.coarse_ind: + style = GradualStyleBlock(512, 512, 16) + elif i < self.middle_ind: + style = GradualStyleBlock(512, 512, 32) + else: + style = GradualStyleBlock(512, 512, 64) + self.styles.append(style) + self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0) + self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0) + + def _upsample_add(self, x, y): + '''Upsample and add two feature maps. + Args: + x: (Variable) top feature map to be upsampled. + y: (Variable) lateral feature map. + Returns: + (Variable) added feature map. + Note in PyTorch, when input size is odd, the upsampled feature map + with `F.upsample(..., scale_factor=2, mode='nearest')` + maybe not equal to the lateral feature map size. + e.g. + original input size: [N,_,15,15] -> + conv2d feature map size: [N,_,8,8] -> + upsampled feature map size: [N,_,16,16] + So we choose bilinear upsample which supports arbitrary output sizes. + ''' + _, _, H, W = y.size() + return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y + + def forward(self, x): + x = self.input_layer(x) + + latents = [] + modulelist = list(self.body._modules.values()) + for i, l in enumerate(modulelist): + x = l(x) + if i == 6: + c1 = x + elif i == 20: + c2 = x + elif i == 23: + c3 = x + + for j in range(self.coarse_ind): + latents.append(self.styles[j](c3)) + + p2 = self._upsample_add(c3, self.latlayer1(c2)) + for j in range(self.coarse_ind, self.middle_ind): + latents.append(self.styles[j](p2)) + + p1 = self._upsample_add(p2, self.latlayer2(c1)) + for j in range(self.middle_ind, self.style_count): + latents.append(self.styles[j](p1)) + + out = torch.stack(latents, dim=1) + return out + + +class BackboneEncoderUsingLastLayerIntoW(Module): + def __init__(self, num_layers, mode='ir', opts=None): + super(BackboneEncoderUsingLastLayerIntoW, self).__init__() + print('Using BackboneEncoderUsingLastLayerIntoW') + assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' + assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = bottleneck_IR + elif mode == 'ir_se': + unit_module = bottleneck_IR_SE + self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False), + BatchNorm2d(64), + PReLU(64)) + self.output_pool = torch.nn.AdaptiveAvgPool2d((1, 1)) + self.linear = EqualLinear(512, 512, lr_mul=1) + modules = [] + for block in blocks: + for bottleneck in block: + modules.append(unit_module(bottleneck.in_channel, + bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + def forward(self, x): + x = self.input_layer(x) + x = self.body(x) + x = self.output_pool(x) + x = x.view(-1, 512) + x = self.linear(x) + return x + + +class BackboneEncoderUsingLastLayerIntoWPlus(Module): + def __init__(self, num_layers, mode='ir', opts=None): + super(BackboneEncoderUsingLastLayerIntoWPlus, self).__init__() + print('Using BackboneEncoderUsingLastLayerIntoWPlus') + assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' + assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = bottleneck_IR + elif mode == 'ir_se': + unit_module = bottleneck_IR_SE + self.n_styles = opts.n_styles + self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), + BatchNorm2d(64), + PReLU(64)) + self.output_layer_2 = Sequential(BatchNorm2d(512), + torch.nn.AdaptiveAvgPool2d((7, 7)), + Flatten(), + Linear(512 * 7 * 7, 512)) + self.linear = EqualLinear(512, 512 * self.n_styles, lr_mul=1) + modules = [] + for block in blocks: + for bottleneck in block: + modules.append(unit_module(bottleneck.in_channel, + bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + def forward(self, x): + x = self.input_layer(x) + x = self.body(x) + x = self.output_layer_2(x) + x = self.linear(x) + x = x.view(-1, self.n_styles, 512) + return x + +class CustomBackboneEncoderUsingLastLayerIntoWPlus(Module): + def __init__(self, num_layers, mode='ir', opts=None): + super(CustomBackboneEncoderUsingLastLayerIntoWPlus, self).__init__() + # print('Using BackboneEncoderUsingLastLayerIntoWPlus') + assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' + assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = bottleneck_IR + elif mode == 'ir_se': + unit_module = bottleneck_IR_SE + self.n_styles = 11 + self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), + BatchNorm2d(64), + PReLU(64)) + self.output_layer_2 = Sequential(BatchNorm2d(512), + torch.nn.AdaptiveAvgPool2d((7, 7)), + Flatten(), + Linear(512 * 7 * 7, 512)) + self.linear = EqualLinear(512, 512 * self.n_styles, lr_mul=1) + + self.structure_linear = EqualConv2d(256, 512, 1) + + modules = [] + for block in blocks: + for bottleneck in block: + modules.append(unit_module(bottleneck.in_channel, + bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + def forward(self, x): + x = self.input_layer(x) + + latents = [] + modulelist = list(self.body._modules.values()) + for i, l in enumerate(modulelist): + x = l(x) + if i == 20: + structure_feats = x + + x = self.output_layer_2(x) + x = self.linear(x) + x = x.view(-1, self.n_styles, 512) + + structure_feats = self.structure_linear(structure_feats) + return x, structure_feats + + +# ================================================================ +class FSEncoder_PSP(Module): + def __init__(self, mode='ir_se', opts=None): + super(FSEncoder_PSP, self).__init__() + assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' + blocks = [ + get_block(in_channel=64, depth=128, num_units=3), + get_block(in_channel=128, depth=256, num_units=4), + get_block(in_channel=256, depth=512, num_units=14), + get_block(in_channel=512, depth=512, num_units=3) + ] + if mode == 'ir': + unit_module = bottleneck_IR + elif mode == 'ir_se': + unit_module = bottleneck_IR_SE_Ours + self.n_styles = 11 + self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), + InstanceNorm2d(64), + PReLU(64)) + modules = [] + for block in blocks: + for bottleneck in block: + modules.append(unit_module(bottleneck.in_channel, + bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + def get_per_comp_styleCode(self, style_feats, segmap): + segmap = F.interpolate(segmap, size=style_feats.size()[2:], mode='nearest') + + b_size = style_feats.shape[0] + f_size = style_feats.shape[1] + s_size = segmap.shape[1] # number of seg classes + + codes_vector = torch.zeros( + (b_size, s_size, f_size), dtype=style_feats.dtype, device=style_feats.device) + + for i in range(b_size): # for each sample + for j in range(s_size): # for each seg cls + component_mask_area = torch.sum(segmap.bool()[i, j]) + + if component_mask_area > 0: + codes_component_feature = style_feats[i].masked_select( + segmap.bool()[i, j]).reshape(f_size, component_mask_area).mean(1) + codes_vector[i][j] = codes_component_feature + + return codes_vector + + def forward(self, x, segmap): + x = self.input_layer(x) + + modulelist = list(self.body._modules.values()) + for i, l in enumerate(modulelist): + x = l(x) + if i == 6: + s1 = x + elif i==20: + s2 = x + elif i == 23: + s3 = x + + # + # + structure_feats = torch.zeros_like(x) + + # style code + code_vectors1 = self.get_per_comp_styleCode(s1,segmap) + code_vectors2 = self.get_per_comp_styleCode(s2,segmap) + code_vectors3 = self.get_per_comp_styleCode(s3,segmap) + + codes_vector = torch.cat([code_vectors1, code_vectors2, code_vectors3], dim=2) + + return codes_vector, structure_feats \ No newline at end of file diff --git a/util_4dataset.py b/util_4dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..2971ddc9cf27de304f9699355112362db6626e46 --- /dev/null +++ b/util_4dataset.py @@ -0,0 +1,174 @@ +import numpy as np +import random +from torchvision import transforms as T +from torch import Tensor +from PIL import Image +import torchvision, torch, cv2 + +def get_tensor(normalize=True, toTensor=True, + mean = (0.5, 0.5, 0.5), + std = (0.5, 0.5, 0.5), +): + transform_list = [] + if toTensor: + transform_list += [torchvision.transforms.ToTensor()] + + if normalize: + transform_list += [torchvision.transforms.Normalize(mean,std)] + return torchvision.transforms.Compose(transform_list) + +def get_tensor_clip(normalize=True, toTensor=True, + mean = (0.48145466, 0.4578275, 0.40821073), + std = (0.26862954, 0.26130258, 0.27577711), +): + transform_list = [] + if toTensor: + transform_list += [torchvision.transforms.ToTensor()] + + if normalize: + transform_list += [torchvision.transforms.Normalize(mean,std)] + return torchvision.transforms.Compose(transform_list) + +def mask_after_npisin__2__tensor(mask_after_npisin: np.ndarray) -> Tensor: + converted_mask = np.zeros_like(mask_after_npisin) + converted_mask[mask_after_npisin] = 255 + mask_tensor = Image.fromarray(converted_mask).convert('L') + mask_tensor = get_tensor(normalize=False, toTensor=True)(mask_tensor) + mask_tensor = T.Resize([512, 512])(mask_tensor) + return mask_tensor + +# Implement perspective warp for reference images +def apply_perspective_warp(img, mask, deg_x, deg_y, + # border_mode=cv2.BORDER_REPLICATE, interpolation=cv2.INTER_LINEAR, + border_mode=cv2.BORDER_CONSTANT, interpolation=cv2.INTER_CUBIC, + constant_border_value=(0,0,0), + fix_edge_artifacts=False, # no noticeable difference +): + """ + Apply a perspective warp transformation to an image and mask + + Args: + img: numpy array of shape (H, W, C) + mask: numpy array of shape (H, W) + max_deg: maximum rotation degree + border_mode: border handling mode (cv2.BORDER_REPLICATE, cv2.BORDER_CONSTANT, etc.) + interpolation: interpolation method (cv2.INTER_LINEAR, cv2.INTER_CUBIC, etc.) + constant_border_value: border color to use with BORDER_CONSTANT + fix_edge_artifacts: Whether to apply additional processing to fix edge artifacts + + Returns: + transformed_img, transformed_mask + """ + h, w = img.shape[:2] + assert img.shape[:2] == mask.shape[:2], f"img shape {img.shape[:2]} != mask shape {mask.shape[:2]}" + + # Convert degrees to radians + rad_x = np.deg2rad(deg_x) + rad_y = np.deg2rad(deg_y) + + # Calculate perspective transform matrix + d = np.sqrt(h**2 + w**2) + eye_to_center = d / (2 * np.tan(np.pi/8)) # approx distance from eye to image center + + # Define the transformation matrix + transform = np.eye(3) + + # Apply rotation around X axis (vertical) + transform = transform @ np.array([ + [1, 0, 0], + [0, np.cos(rad_x), -np.sin(rad_x)], + [0, np.sin(rad_x), np.cos(rad_x)] + ]) + + # Apply rotation around Y axis (horizontal) + transform = transform @ np.array([ + [np.cos(rad_y), 0, np.sin(rad_y)], + [0, 1, 0], + [-np.sin(rad_y), 0, np.cos(rad_y)] + ]) + + # Project 3D points onto 2D plane + pts_3d = np.array([ + [-w/2, -h/2, 0], + [w/2, -h/2, 0], + [w/2, h/2, 0], + [-w/2, h/2, 0] + ]) + + # Apply transformation + pts_3d_transformed = pts_3d @ transform.T + + # Project to 2D + pts_3d_transformed[:, 0] = pts_3d_transformed[:, 0] * eye_to_center / (eye_to_center + pts_3d_transformed[:, 2]) + w/2 + pts_3d_transformed[:, 1] = pts_3d_transformed[:, 1] * eye_to_center / (eye_to_center + pts_3d_transformed[:, 2]) + h/2 + + src_pts = np.array([ + [0, 0], + [w-1, 0], + [w-1, h-1], + [0, h-1] + ], dtype=np.float32) + + dst_pts = np.array([ + [pts_3d_transformed[0, 0], pts_3d_transformed[0, 1]], + [pts_3d_transformed[1, 0], pts_3d_transformed[1, 1]], + [pts_3d_transformed[2, 0], pts_3d_transformed[2, 1]], + [pts_3d_transformed[3, 0], pts_3d_transformed[3, 1]] + ], dtype=np.float32) + + # Get perspective transform matrix + M = cv2.getPerspectiveTransform(src_pts, dst_pts) + + # Apply perspective transformation with specified border mode and interpolation + transformed_img = cv2.warpPerspective(img, M, (w, h), flags=interpolation, + borderMode=border_mode, + borderValue=constant_border_value) + + # For mask, always use nearest neighbor interpolation + transformed_mask = cv2.warpPerspective(mask, M, (w, h), flags=cv2.INTER_NEAREST, + borderMode=border_mode, + borderValue=0) + + # Additional processing to fix edge artifacts + if fix_edge_artifacts: + # Calculate edge detection mask to find problematic areas + edge_mask = np.ones((h, w), dtype=np.uint8) + warped_edge_mask = cv2.warpPerspective(edge_mask, M, (w, h), flags=cv2.INTER_NEAREST, + borderMode=cv2.BORDER_CONSTANT, borderValue=0) + + # Create transition region mask with larger dilation for better handling of edge artifacts + kernel = np.ones((7, 7), np.uint8) + inner_edge = cv2.erode(warped_edge_mask, kernel) + transition_mask = warped_edge_mask - inner_edge + + # Only focus on vertical edges where artifacts are most common + left_margin = 20 + right_margin = 20 + vertical_edge_mask = np.zeros_like(transition_mask) + vertical_edge_mask[:, :left_margin] = transition_mask[:, :left_margin] + vertical_edge_mask[:, -right_margin:] = transition_mask[:, -right_margin:] + + # Apply stronger smoothing specifically to vertical edges + if len(transformed_img.shape) == 3: + # Create a smooth blend from interior to exterior + for i in range(3): # For each color channel + if np.sum(vertical_edge_mask) > 0: + # Apply a stronger blur to vertical edges + blurred = cv2.GaussianBlur(transformed_img, (9, 9), 0) + vertical_edge_mask_3d = np.stack([vertical_edge_mask] * 3, axis=2) / 255.0 + transformed_img = transformed_img * (1 - vertical_edge_mask_3d) + blurred * vertical_edge_mask_3d + + # Apply general edge blending as well + edge_blurred = cv2.GaussianBlur(transformed_img, (3, 3), 0) + transition_mask_3d = np.stack([transition_mask] * 3, axis=2) / 255.0 + transformed_img = transformed_img * (1 - transition_mask_3d) + edge_blurred * transition_mask_3d + + # Ensure the output image is uint8 + if transformed_img.dtype != np.uint8: + transformed_img = np.clip(transformed_img, 0, 255).astype(np.uint8) + + # Ensure the output mask is uint8 + if transformed_mask.dtype != np.uint8: + transformed_mask = np.clip(transformed_mask, 0, 255).astype(np.uint8) + + return transformed_img, transformed_mask diff --git a/util_and_constant.py b/util_and_constant.py new file mode 100644 index 0000000000000000000000000000000000000000..bac61cf803101af2ca960607eaebbd7fdb667365 --- /dev/null +++ b/util_and_constant.py @@ -0,0 +1,225 @@ +from pathlib import Path; import sys, os; from fnmatch import fnmatch +import global_ + +TASKS = (0,1,2,3,) +TP_enable :bool = 1 +world_size_ = int(os.environ.get("WORLD_SIZE", "1")) +rank_ = int(os.environ.get("RANK", "0")) +local_rank_ = int(os.environ.get("LOCAL_RANK", rank_ )) +assert world_size_ >= 1 and 0 <= rank_ < world_size_ + +USE_filter_mediapipe_fail_swap = 1 +CH14 :bool = False +class REFNET: + ENABLE :bool = 1 + CH9 :bool = 0 + task2layerNum = { # actually used as bool now + 0:9, + 1:9, + 2:9, + 3:9, + } +USE_pts :bool = 1 +READ_mediapipe_result_from_cache = 1 +ADAM_or_SGD :bool = False # 1 => AdamW ; 0 => sgd +N_EPOCHS_TRAIN_REF_AND_MID :int = 1 +# ZeRO-1 optimizer sharding (ZeroRedundancyOptimizer). avoid using FSDP, just ZeroRedundancyOptimizer +ZERO1_ENABLE :bool = 0 + + +NUM_token = 257 + + + +if 1: + SD14_filename = "sd-v1-4.ckpt" + SD14_localpath = Path("checkpoints") / SD14_filename + PRETRAIN_CKPT_PATH = f"checkpoints/pretrained.ckpt" + PRETRAIN_JSON_PATH = f"checkpoints/pretrained.json" + +#------------------------------------------- +assert isinstance(TASKS,tuple) +NUM_pts = 95 +global_.TP_enable = TP_enable +global_.rank_ = rank_ + + + +MERGE_CFG_in_one_batch :bool = 1 + +FOR_upcycle_ckpt_GEN_or_USE :bool = 0 + +DEBUG = 0 +DEBUG_skip_load_ckpt :bool = 0 +DBEUG_skip_most_in_Unet_constructor :bool = 0 +# import os; os.environ['CUDA_LAUNCH_BLOCKING'] = '1' +LOG_debug_level = 0 + + +_gate_total_runs = {} +_gate_total_calls = {} +_gate_k2tu = { # id 2 (max_run,interval,prob) + 'vis Dataset_vFrame perspectiveWarp' : ( 0, 1, None ), + 'vis LatentDiffusion.get_input' : ( 0, 5, None ), + 'vis LatentDiffusion.get_input-before_return True' : ( 0, 5, None ), + 'vis LatentDiffusion.get_input-before_return False' : ( 0, 1, None ), + 'vis LatentDiffusion.conditioning_with_feat' : ( 0, 2, None ), + 'vis LatentDiffusion.p_losses--after-apply_model' : ( 0, 2, None ), + 'statistics test_batch[0]' : ( 0, 2, None ),#-------------infer----------- + "Project config:" : ( 0, 1, None ),# ------------for printC (arg[0] as id)----------- + "Lightning config:" : ( 0, 1, None ), + "logger_cfg=" : ( 0, 1, None ), + "Merged modelckpt-cfg:" : ( 0, 1, None ), + 'bank get' : ( 0, 1, None ), + 'bank set' : ( 0, 1, None ), + 'clear' : ( 0, 1, None ), + 'mean ct:' : ( 0, 1, None ), + "[__iter__]" : ( 1, 1, None ), + "[_create_batches]" : ( 1, 1, None ), + '[set_task_for_MoE]' : ( 1, 1, None ), + 'len_inter' : ( 3, 5, None ), + 'non_paired' : ( 3, 5, None ), + 'ddim rec bg' : ( 4, 5, None ), + '[training step]' : ( 7, 1, None ), + 'LatentDiffusion.configure_optimizers params:' : ( 0, 1, None ), + 'c.shape' : ( 2, 6, None ), + '[conditioning_with_feat return]': ( 0, 6, None ), + 'c for refNet' : ( 0, 6, None ), + 'hair _c.shape:' : ( 0, 1, None ), + 'head _c.shape:' : ( 0, 1, None ), + 'task' : ( 9, 1, None ), + '_t_norm' : ( 9, 1, None ), + 'orig,ID clip,lpips rec lmk:' : (20, 2, None ),#-------------ddim_losses----------- + 'loss_lpips_1 at 0 0 :' : (10, 4, None ), + 'loss_lpips_1 at 0 1 :' : (10, 4, None ), + 'loss_lpips_1 at 0 2 :' : (10, 4, None ), + 'loss_lpips_1 at 1 0 :' : (10, 4, None ), + 'loss_lpips_1 at 1 1 :' : (10, 4, None ), + 'loss_lpips_1 at 1 2 :' : (10, 4, None ), + 'loss_rec_1 at 0 :' : (10, 4, None ), + 'loss_rec_1 at 1 :' : (10, 4, None ), + 'orig, ID clip, lpips rec lmk:' : (10, 4, None ), + 'c_ref True' : ( 3, 5, None ), + 'c_ref False' : ( 1, 1, None ), + 'ffn_gate_input' : ( 3, 3, None ),#-------------MoE----------- + 'vis-ffn_gate_input' : ( 3, 3, None ), + '[warning]: no param to sync' : (10,1, None ),#-------------TP----------- + '[TP] shared sync counts' : (10,1, None ), + '[Conv2d param stats] count, name (sorted desc):': (0,1, None ),#-------------upcycle----------- + 'avg full_name=' : ( 0, 1, None ), +} +def gate_(id_, *args, **kw, ): # gate for some vis or print behaviour, just for vis/debug + # return 0 + if 1 and not ( hasattr(global_,'TP_enable') and global_.TP_enable ): + import torch.distributed as dist + if dist.is_available() and dist.is_initialized(): + if dist.get_rank()!=0: return + + global _gate_total_runs, _gate_total_calls + tu = _gate_k2tu.get(id_, None) + if tu is None: + return 0 + max_run, interval, prob = tu + if max_run==0: + return 0 + if id_ not in _gate_total_runs: # Initialize counters for this ID if not present + _gate_total_runs[id_] = 0 + _gate_total_calls[id_] = 0 + if _gate_total_runs[id_] >= max_run: # Check if we've reached the maximum runs + return False + _gate_total_calls[id_] += 1 + if _gate_total_calls[id_] % interval != 0: + return False + if prob is not None: + import random + if random.random() > prob: + return False + _gate_total_runs[id_] += 1 + return True + +def str_t(): # eg. '0608-17.12.30' + from datetime import datetime + now = datetime.now() + month_day = f"{now.month:02d}{now.day:02d}" + hour_min_second = f"{now.hour:02d}.{now.minute:02d}.{now.second:02d}" + ret = f"{month_day}-{hour_min_second}" + return ret +def str_t_pid(): # eg. '0608-17.12.30-180165' + if hasattr(global_,'TP_enable') and global_.TP_enable: + _suffix = global_.rank_ + else: + _suffix = os.getpid() + return f"{str_t()}-{_suffix}" + +def printC(*args, **kw): # controled print + if gate_(args[0]): + return print(*args, **kw) + +#-------------------- + +os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0' # disable tf onednn-related warnings +# from skimage.io import imsave + + +def path_img_2_path_mask( path_img, check_mask_exists = 1 , reuse_if_exists = True, label_mode="RF12_"): + assert label_mode=="RF12_" + assert label_mode in ('RF12',"RF12_",), label_mode + assert 'semantic_mask' not in str(path_img), path_img + path_img = Path(path_img) + if 1: + _suffix = { + # "RF12" :'-semantic_mask', + "RF12_":'-semantic_mask', + }[label_mode] + path_mask = path_img.parent / f"{path_img.stem}{_suffix}.png" + if check_mask_exists or not reuse_if_exists: + if not path_mask.exists() or not reuse_if_exists: + from gen_semantic_mask import gen_semantic_mask + vis_path = None + # vis_path = person_folder.parent.parent / 'vis_semantic_mask' / f"{person_stem}--{path_img.stem}.png" + gen_semantic_mask(path_img, path_mask, label_mode, vis_path, ) + return path_mask + +from my_py_lib.torchModuleName_util import * +if 0: + #-------------------- terminal color (only for exceptions/logging/warnings) + import sys; from IPython.core.ultratb import ColorTB; sys.excepthook = ColorTB() + class _color: # ANSI escape + grey = "\x1b[90m"; green = "\x1b[92m"; yellow = "\x1b[93m" + red = "\x1b[91m"; orange = "\033[38;5;208m"; orange_light = "\033[38;5;214m" + if 1: + import logging + class _CustomFormatter(logging.Formatter): + # format = "%(asctime)s %(filename)s:%(lineno)d %(funcName)s [%(levelname)-8s] %(message)s" + format = "%(asctime)s | %(levelname)-5s | %(message)s" + reset = "\x1b[0m" + FORMATS = { + logging.DEBUG: _color.grey + format + reset, + logging.INFO: _color.green + format + reset, + logging.WARNING: _color.yellow + format + reset, + logging.ERROR: _color.red + format + reset, + logging.CRITICAL: _color.red + format + reset, + } + def format(self, record): + log_fmt = self.FORMATS.get(record.levelno) + formatter = logging.Formatter(log_fmt, datefmt='%H:%M:%S') # <= print only time, not date + return formatter.format(record) + def setup_colored_logging(): + logger = logging.getLogger(); ch = logging.StreamHandler() + if LOG_debug_level: logger.setLevel(logging.DEBUG); ch.setLevel(logging.DEBUG) + else: logger.setLevel(logging.INFO); ch.setLevel(logging.INFO) + ch.setFormatter(_CustomFormatter()); logger.addHandler(ch) + setup_colored_logging() + if 1: + import warnings + def _custom_showwarning(msg, category, filename, lineno, file=None, line=None): + reset = "\x1b[0m"; c_file_line = _color.grey; c_cate = _color.orange; c_msg = _color.yellow + if LOG_debug_level: + formatted_message=f"{c_cate}{category.__name__}{reset}: {c_msg}{msg}{reset} {c_file_line}{filename}:{lineno}{reset}" + else: formatted_message = f"{c_cate}{category.__name__}{reset}: {c_msg}{msg}{reset}" + print(formatted_message) + warnings.showwarning = _custom_showwarning + + if __name__=='__main__': + logging.warning("This is a warning message in yellow"); logging.error("This is an error message in red") + warnings.warn("This is a colored warning message") diff --git a/util_cv2.py b/util_cv2.py new file mode 100644 index 0000000000000000000000000000000000000000..93c42cbaad19bd2e536d77b7b102dc18aaf92ecb --- /dev/null +++ b/util_cv2.py @@ -0,0 +1,21 @@ +from util_and_constant import * +import cv2 +import numpy as np +def auto_interpolation(img:np.ndarray, dst_size:tuple): + if img.shape[0] > dst_size[0] and img.shape[1] > dst_size[1]: + interpolation = cv2.INTER_AREA # value:3 + else: + interpolation = cv2.INTER_LANCZOS4 # value:4 + return interpolation + +_DEBUG_interpolation = 0 # if 1, save before resize to 4debug/cv2_resize_auto_interpolation/{str_t_pid()}.png +def cv2_resize_auto_interpolation(src:np.ndarray, dsize:tuple, interpolation:int=None, **kwargs): + if interpolation is None: + interpolation = auto_interpolation(src, dsize) + ret= cv2.resize(src, dsize, interpolation=interpolation, **kwargs) + if _DEBUG_interpolation and src.shape[0]>1130: + _p = f"4debug/cv2_resize_auto_interpolation/{str_t_pid()}-before.png" + cv2.imwrite(_p, src); print(f"{_p=}") + _p = f"4debug/cv2_resize_auto_interpolation/{str_t_pid()}-after-{interpolation}.png" + cv2.imwrite(_p, ret); print(f"{_p=}") + return ret \ No newline at end of file diff --git a/util_face.py b/util_face.py new file mode 100644 index 0000000000000000000000000000000000000000..537ea227b2e881e5d5acc15e623264982d511a5b --- /dev/null +++ b/util_face.py @@ -0,0 +1,72 @@ +import numpy as np +import cv2 +from PIL import Image +import os +from pathlib import Path +import matplotlib.pyplot as plt +from util_and_constant import * + +def has_glasses(path_img): + mask_path = path_img_2_path_mask(path_img) + mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE) + # if 10 in mask: # slow + if (mask == 10).any(): # vectorized => clearly faster + return True + return False + +def has_hat(path_img): + mask_path = path_img_2_path_mask(path_img, label_mode="RF12_") + mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE) + if (mask == 21).any(): + return True + return False + +def draw_pts70_batch(pts68, gaze, warp_mat256_np, dst_size, im_list=None, return_pt=False): + import torch + import torchvision.transforms as transforms + + left_eye1 = pts68[:, 36] + left_eye2 = pts68[:, 39] + right_eye1 = pts68[:, 42] + right_eye2 = pts68[:, 45] + + right_eye_length = torch.sqrt(torch.sum((right_eye2-right_eye1)**2, dim=1, keepdim=True)) + left_eye_length = torch.sqrt(torch.sum((left_eye2-left_eye1)**2, dim=1, keepdim=True)) + right_eye_center = (right_eye2 + right_eye1) * 0.5 + left_eye_center = (left_eye2 + left_eye1) * 0.5 + + with torch.no_grad(): + left_gaze = gaze[:,:2] * left_eye_length + left_eye_center + right_gaze = gaze[:,2:] * right_eye_length + right_eye_center + pts70 = torch.cat([pts68, left_gaze.view(-1,1,2), right_gaze.view(-1,1,2)],dim=1) + landmarks = pts70.cpu().numpy().round().astype(int) + + colors = plt.get_cmap('rainbow')(np.linspace(0, 1, landmarks.shape[1])) + colors = (255 * colors).astype(int)[:, 0:3].tolist() + + im_pts70_list = [] + if im_list is None: + im_list = [np.zeros((256, 256, 3), dtype=np.uint8) for idx in range(landmarks.shape[0])] + else: + im_list = [np.array(x) for x in im_list] + for idx in range(landmarks.shape[0]): + image = im_list[idx] + + for i in range(landmarks.shape[1]): + x, y = landmarks[idx, i, :] + color = colors[i] + image = cv2.circle(image, (x, y), radius=2, color=(color[2],color[1],color[0]), thickness=-1) + + dst_image = cv2.warpAffine(image, warp_mat256_np[idx], (dst_size, dst_size), flags=(cv2.INTER_LINEAR | cv2.WARP_INVERSE_MAP), borderMode=cv2.BORDER_CONSTANT) + im_pts70_list.append(Image.fromarray(dst_image)) + + if return_pt: + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=0.5, std=0.5) + ]) + tensor_list = [transform(x).view(1,3,dst_size,dst_size) for x in im_pts70_list] + batch = torch.cat(tensor_list,dim=0) + return batch + else: + return im_pts70_list diff --git a/util_vis.py b/util_vis.py new file mode 100644 index 0000000000000000000000000000000000000000..5e334200ecbbf774d7d590a7cc5e4ee39dfe544c --- /dev/null +++ b/util_vis.py @@ -0,0 +1,163 @@ +from pathlib import Path +import cv2 +import numpy as np + +def vis_tensors_A(l_tensor_or_named_tensor, path_grid, vis_batch_size=4, layout='auto'): + """Visualize a list of tensors in a grid layout. + Args: + l_tensor_or_named_tensor: [tensor | (name, tensor), ..]. each tensor: B,(C,)H,W is in [-1,1] range + path_grid: Path object for saving the grid visualization + vis_batch_size: number of samples to visualize + layout: 'BxI' (batch x images) or 'IxB' (images x batch) or 'auto' + """ + import torch + from torchvision.utils import make_grid, save_image + path_grid = Path(path_grid) + path_grid.parent.mkdir(parents=0, exist_ok=True) + # Helper function to unnormalize and prepare images for saving + def prepare_for_vis(tensor, ): + if tensor is None: + return None + shape = tensor.shape + assert shape[1]<=3 + if len(shape)==3 or shape[1]==1: + is_mask = True + else: is_mask = False + if is_mask: + return tensor.repeat(1, 3, 1, 1).cpu() # Expand mask to 3 channels + else: + return (tensor * 0.5 + 0.5).cpu() # Unnormalize from [-1, 1] to [0, 1] + named_tensors = [] + for tensor_or_named_tensor in l_tensor_or_named_tensor: + if isinstance(tensor_or_named_tensor, tuple): + name, tensor = tensor_or_named_tensor + else: + name = "" + tensor = tensor_or_named_tensor + if tensor is not None: + named_tensors.append((name, prepare_for_vis(tensor.detach()[:vis_batch_size], ))) + # Make sure all tensors have the same spatial dimensions + all_shapes = [img.shape[2:] for _, img in named_tensors if img is not None] + if len(set(all_shapes)) > 1: # Pad images to match the largest dimensions + max_h = max(shape[0] for shape in all_shapes) + max_w = max(shape[1] for shape in all_shapes) + for i in range(len(named_tensors)): + name, img = named_tensors[i] + if img is None: + continue + if img.shape[2] == max_h and img.shape[3] == max_w: + continue + pad_h = max_h - img.shape[2] + pad_w = max_w - img.shape[3] + named_tensors[i] = (name, torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), value=0)) + tensors = [] + for _, (name, tensor) in enumerate(named_tensors): + tensor = tensor.detach() + if name: + for b in range(tensor.shape[0]): + # Convert tensor to numpy for OpenCV + img = tensor[b].permute(1, 2, 0).numpy() + img = (img * 255).astype(np.uint8).copy() # Make contiguous copy for OpenCV + # Add text + cv2.putText(img, name, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, + 0.7, (0, 0, 0), 2, cv2.LINE_AA) + cv2.putText(img, name, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, + 0.7, (255, 255, 255), 1, cv2.LINE_AA) + img_tensor = torch.from_numpy(img).permute(2, 0, 1) / 255.0 # Convert back to tensor + tensors.append(img_tensor) + else: + for b in range(tensor.shape[0]): + tensors.append(tensor[b]) + if tensors: # I*B,3,.. + all_images_flat = torch.stack(tensors) # I*B,3,.. + I = len(named_tensors) + B = vis_batch_size + if layout == 'auto': + if B/I > 0.8: + layout = 'IxB' + else: + layout = 'BxI' + if layout == 'BxI': + all_images_nonflat = all_images_flat.reshape(I, B, *all_images_flat.shape[1:]) + all_images_nonflat = all_images_nonflat.permute(1, 0, 2, 3, 4) + all_images_flat = all_images_nonflat.reshape(-1, *all_images_flat.shape[1:]) + nrow = I + else: # 'IxB' + nrow = B + save_image(make_grid(all_images_flat, nrow=nrow), path_grid) + print(f"{path_grid=}") + +def visualize_landmarks(image, landmarks, save_path): + """ + Draw landmarks on an image and save the result. + + Args: + image: Input image as a numpy array (H,W,3) with values in [0,255] + landmarks: Numpy array of shape (136,) or (68,2) containing 68 keypoint coordinates + save_path: Path where the annotated image should be written + """ + # Clone the image and ensure uint8 type + image = image.copy().astype(np.uint8) + + # Ensure the image buffer is contiguous + image = np.ascontiguousarray(image) + + # Reshape landmarks into (68,2) if needed + if landmarks.shape[0] == 136: + landmarks = landmarks.reshape(68, 2) + + # Draw each landmark point + for (x, y) in landmarks: + cv2.circle(image, (int(x), int(y)), 2, (0, 255, 0), -1) + + # Save the annotated image + cv2.imwrite(save_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR)) + +def visualize_headPose(img_path, yaw, pitch, roll, save_path): + """Visualize pose angles on image using arrows + Args: + img_path: Path to input image + yaw: Yaw angle in degrees + pitch: Pitch angle in degrees + roll: Roll angle in degrees + save_path: Path to save visualization + """ + import matplotlib.pyplot as plt + img = cv2.imread(str(img_path)) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + h, w = img.shape[:2] + center = (w//2, h//2) + + plt.figure(figsize=(10, 10)) + plt.imshow(img) + + # Yaw (left-right) + yaw_rad = np.radians(yaw) + yaw_end = (center[0] + int(100 * np.sin(yaw_rad)), + center[1] - int(100 * np.cos(yaw_rad))) + plt.arrow(center[0], center[1], yaw_end[0]-center[0], yaw_end[1]-center[1], + color='r', width=2, head_width=20, label=f'Yaw: {yaw:.1f}°') + + # Pitch (up-down) + pitch_rad = np.radians(pitch) + pitch_end = (center[0] + int(100 * np.sin(pitch_rad)), + center[1] - int(100 * np.cos(pitch_rad))) + plt.arrow(center[0], center[1], pitch_end[0]-center[0], pitch_end[1]-center[1], + color='g', width=2, head_width=20, label=f'Pitch: {pitch:.1f}°') + + # Roll (tilt) + roll_rad = np.radians(roll) + roll_end = (center[0] + int(100 * np.cos(roll_rad)), + center[1] + int(100 * np.sin(roll_rad))) + plt.arrow(center[0], center[1], roll_end[0]-center[0], roll_end[1]-center[1], + color='b', width=2, head_width=20, label=f'Roll: {roll:.1f}°') + + plt.legend() + plt.axis('off') + + # Save visualization + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + plt.savefig(save_path, bbox_inches='tight', pad_inches=0) + plt.close() + print(f"{save_path=}")