| import copy |
| from typing import Any, Callable, Dict, Iterable, Union |
| import PIL |
| import cv2 |
| import torch |
| import argparse |
| import datetime |
| import logging |
| import inspect |
| import math |
| import os |
| import shutil |
| from typing import Dict, List, Optional, Tuple |
| from pprint import pprint |
| from collections import OrderedDict |
| from dataclasses import dataclass |
| import gc |
| import time |
|
|
| import numpy as np |
| from omegaconf import OmegaConf |
| from omegaconf import SCMode |
| import torch |
| from torch import nn |
| import torch.nn.functional as F |
| import torch.utils.checkpoint |
| from einops import rearrange, repeat |
| import pandas as pd |
| import h5py |
| from diffusers.models.modeling_utils import load_state_dict |
| from diffusers.utils import ( |
| logging, |
| ) |
| from diffusers.utils.import_utils import is_xformers_available |
|
|
| from mmcm.vision.feature_extractor.clip_vision_extractor import ( |
| ImageClipVisionFeatureExtractor, |
| ImageClipVisionFeatureExtractorV2, |
| ) |
| from mmcm.vision.feature_extractor.insight_face_extractor import InsightFaceExtractor |
|
|
| from ip_adapter.resampler import Resampler |
| from ip_adapter.ip_adapter import ImageProjModel |
|
|
| from .unet_loader import update_unet_with_sd |
| from .unet_3d_condition import UNet3DConditionModel |
| from .ip_adapter_loader import ip_adapter_keys_list |
|
|
# NOTE(review): `logging` here is diffusers' wrapper (imported at top of file),
# which shadows the stdlib `logging` module — get_logger comes from diffusers.
logger = logging.get_logger(__name__)
|
|
|
|
| |
# Attention sites in the UNet that carry FaceIn cross-attention processors,
# listed in the order the original checkpoint expects:
# down_blocks 0-2 (2 attentions each), up_blocks 1-3 (3 attentions each), mid.
_FACEIN_ATTN_SITES = (
    [f"down_blocks.{block}.attentions.{attn}" for block in range(3) for attn in range(2)]
    + [f"up_blocks.{block}.attentions.{attn}" for block in range(1, 4) for attn in range(3)]
    + ["mid_block.attentions.0"]
)

# Fully-qualified UNet parameter names for the FaceIn to_k / to_v projection
# weights — one (k, v) pair per site, 32 keys in total.
unet_keys_list = [
    f"{site}.transformer_blocks.0.attn2.processor.facein_to_{proj}_ip.weight"
    for site in _FACEIN_ATTN_SITES
    for proj in ("k", "v")
]
|
|
|
|
# Positional mapping from UNet facein parameter names to the corresponding
# IP-Adapter checkpoint keys (relies on both lists sharing the same order).
# NOTE(review): name is misspelled ("Aadapter"/"MAPIING") but kept as-is —
# it is the module's public identifier and renaming would break callers.
UNET2IPAadapter_Keys_MAPIING = dict(zip(unet_keys_list, ip_adapter_keys_list))
|
|
|
|
def load_facein_extractor_and_proj_by_name(
    model_name: str,
    ip_ckpt: Tuple[str, nn.Module],
    ip_image_encoder: Optional[Tuple[str, nn.Module]] = None,
    cross_attention_dim: int = 768,
    clip_embeddings_dim: int = 512,
    clip_extra_context_tokens: int = 1,
    ip_scale: float = 0.0,
    dtype: torch.dtype = torch.float16,
    device: str = "cuda",
    unet: Optional[nn.Module] = None,
) -> Optional[nn.Module]:
    """Load a FaceIn face-feature extractor and its projection model by name.

    NOTE(review): this is an unimplemented stub — it currently returns
    ``None`` regardless of arguments (hence the ``Optional`` return
    annotation). Intended contract, judging by the signature and the sibling
    ip_adapter loaders: resolve *model_name* to an extractor (e.g. an
    InsightFace-based one), load *ip_ckpt* into a projection module sized by
    the ``*_dim`` / ``clip_extra_context_tokens`` arguments, move it to
    *device*/*dtype*, and optionally wire it into *unet* — TODO confirm.

    Args:
        model_name (str): identifier selecting which extractor/proj to build.
        ip_ckpt (Tuple[str, nn.Module]): checkpoint path or an already-loaded
            module for the IP-Adapter weights.
        ip_image_encoder (Optional[Tuple[str, nn.Module]]): image encoder path
            or module; optional.
        cross_attention_dim (int): UNet cross-attention dim (default 768).
        clip_embeddings_dim (int): face/CLIP embedding dim (default 512).
        clip_extra_context_tokens (int): number of context tokens to project to.
        ip_scale (float): strength of the adapter contribution.
        dtype (torch.dtype): target dtype for loaded weights.
        device (str): target device.
        unet (Optional[nn.Module]): UNet to update in place, if provided.

    Returns:
        Optional[nn.Module]: currently always ``None`` (stub).
    """
    pass
|
|
|
|
def update_unet_facein_cross_attn_param(
    unet: UNet3DConditionModel, ip_adapter_state_dict: Dict
) -> None:
    """Use the independent ``to_k`` / ``to_v`` weights from an IP-Adapter
    attention state dict inside *unet*.

    NOTE(review): unimplemented stub — currently a no-op. The intended
    behavior (per the original docstring) is to copy the IP-Adapter
    cross-attention projections into the UNet's ``facein_to_k_ip`` /
    ``facein_to_v_ip`` processor parameters, presumably via
    ``UNET2IPAadapter_Keys_MAPIING`` — TODO confirm.

    Args:
        unet (UNet3DConditionModel): target UNet whose facein processor
            weights are to be overwritten in place.
        ip_adapter_state_dict (Dict): IP-Adapter state dict with keys like
            ``['1.to_k_ip.weight', '1.to_v_ip.weight', '3.to_k_ip.weight']``.
    """
    pass
|
|