from typing import Dict, Tuple, Union

import torch
from torch import nn

from diffusers.utils import logging
from mmcm.vision.feature_extractor import clip_vision_extractor
from mmcm.vision.feature_extractor.clip_vision_extractor import (
    ImageClipVisionFeatureExtractor,
    ImageClipVisionFeatureExtractorV2,
    VerstailSDLastHiddenState2ImageEmb,
)
from ip_adapter.resampler import Resampler
from ip_adapter.ip_adapter import ImageProjModel

from .unet_loader import update_unet_with_sd
from .unet_3d_condition import UNet3DConditionModel

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

def load_vision_clip_encoder_by_name(
    ip_image_encoder: Union[str, nn.Module] = None,
    dtype: torch.dtype = torch.float16,
    device: str = "cuda",
    vision_clip_extractor_class_name: str = None,
) -> nn.Module:
    """Instantiate the extractor class named by ``vision_clip_extractor_class_name``
    from ``mmcm.vision.feature_extractor.clip_vision_extractor``; return None when
    no class name is given."""
    if vision_clip_extractor_class_name is not None:
        vision_clip_extractor = getattr(
            clip_vision_extractor, vision_clip_extractor_class_name
        )(
            pretrained_model_name_or_path=ip_image_encoder,
            dtype=dtype,
            device=device,
        )
    else:
        vision_clip_extractor = None
    return vision_clip_extractor
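
# Illustrative usage (the encoder path is a placeholder; any extractor class
# exported by mmcm.vision.feature_extractor.clip_vision_extractor can be named):
#   vision_clip_extractor = load_vision_clip_encoder_by_name(
#       ip_image_encoder="./checkpoints/IP-Adapter/image_encoder",
#       vision_clip_extractor_class_name="ImageClipVisionFeatureExtractor",
#   )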

def load_ip_adapter_image_proj_by_name(
    model_name: str,
    ip_ckpt: Union[str, nn.Module] = None,
    cross_attention_dim: int = 768,
    clip_embeddings_dim: int = 1024,
    clip_extra_context_tokens: int = 4,
    ip_scale: float = 0.0,
    dtype: torch.dtype = torch.float16,
    device: str = "cuda",
    unet: nn.Module = None,
    vision_clip_extractor_class_name: str = None,
    ip_image_encoder: Union[str, nn.Module] = None,
) -> nn.Module:
    """Build the IP-Adapter image-projection module for ``model_name`` and, when
    ``ip_ckpt`` is given, load its weights (and, if present, the unet ip_adapter
    cross-attention weights) from the checkpoint."""
    if model_name in [
        "IPAdapter",
        "musev_referencenet",
        "musev_referencenet_pose",
    ]:
        # linear projection of the CLIP image embedding into extra context tokens
        ip_adapter_image_proj = ImageProjModel(
            cross_attention_dim=cross_attention_dim,
            clip_embeddings_dim=clip_embeddings_dim,
            clip_extra_context_tokens=clip_extra_context_tokens,
        )
    elif model_name == "IPAdapterPlus":
        vision_clip_extractor = ImageClipVisionFeatureExtractorV2(
            pretrained_model_name_or_path=ip_image_encoder,
            dtype=dtype,
            device=device,
        )
        # Perceiver-style resampler over the encoder's last hidden states
        ip_adapter_image_proj = Resampler(
            dim=cross_attention_dim,
            depth=4,
            dim_head=64,
            heads=12,
            num_queries=clip_extra_context_tokens,
            embedding_dim=vision_clip_extractor.image_encoder.config.hidden_size,
            output_dim=cross_attention_dim,
            ff_mult=4,
        )
    elif model_name in [
        "VerstailSDLastHiddenState2ImageEmb",
        "OriginLastHiddenState2ImageEmbd",
        "OriginLastHiddenState2Poolout",
    ]:
        ip_adapter_image_proj = getattr(
            clip_vision_extractor, model_name
        ).from_pretrained(ip_image_encoder)
    else:
        raise ValueError(
            f"unsupported model_name={model_name}; supported: IPAdapter, "
            "IPAdapterPlus, VerstailSDLastHiddenState2ImageEmb"
        )
    if ip_ckpt is not None:
        ip_adapter_state_dict = torch.load(
            ip_ckpt,
            map_location="cpu",
        )
        ip_adapter_image_proj.load_state_dict(ip_adapter_state_dict["image_proj"])
        if (
            unet is not None
            and unet.ip_adapter_cross_attn
            and "ip_adapter" in ip_adapter_state_dict
        ):
            update_unet_ip_adapter_cross_attn_param(
                unet, ip_adapter_state_dict["ip_adapter"]
            )
            logger.info(
                f"update unet.spatial_cross_attn_ip_adapter parameter with {ip_ckpt}"
            )
    return ip_adapter_image_proj
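
# Illustrative usage (placeholder path; the checkpoint is expected to store the
# projection weights under an "image_proj" entry, as loaded above):
#   image_proj = load_ip_adapter_image_proj_by_name(
#       model_name="musev_referencenet",
#       ip_ckpt="./checkpoints/ip_adapter.bin",
#       unet=unet,
#   )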

def load_ip_adapter_vision_clip_encoder_by_name(
    model_name: str,
    ip_ckpt: Union[str, nn.Module],
    ip_image_encoder: Union[str, nn.Module] = None,
    cross_attention_dim: int = 768,
    clip_embeddings_dim: int = 1024,
    clip_extra_context_tokens: int = 4,
    ip_scale: float = 0.0,
    dtype: torch.dtype = torch.float16,
    device: str = "cuda",
    unet: nn.Module = None,
    vision_clip_extractor_class_name: str = None,
) -> Tuple[nn.Module, nn.Module]:
    """Build the vision CLIP extractor and the matching image-projection module
    for ``model_name``, then load the projection (and, if present, the unet
    ip_adapter cross-attention) weights from ``ip_ckpt``."""
    if vision_clip_extractor_class_name is not None:
        vision_clip_extractor = getattr(
            clip_vision_extractor, vision_clip_extractor_class_name
        )(
            pretrained_model_name_or_path=ip_image_encoder,
            dtype=dtype,
            device=device,
        )
    else:
        vision_clip_extractor = None
    if model_name in [
        "IPAdapter",
        "musev_referencenet",
    ]:
        if ip_image_encoder is not None:
            if vision_clip_extractor_class_name is None:
                vision_clip_extractor = ImageClipVisionFeatureExtractor(
                    pretrained_model_name_or_path=ip_image_encoder,
                    dtype=dtype,
                    device=device,
                )
        else:
            # without an encoder checkpoint there is nothing to extract from
            vision_clip_extractor = None
        ip_adapter_image_proj = ImageProjModel(
            cross_attention_dim=cross_attention_dim,
            clip_embeddings_dim=clip_embeddings_dim,
            clip_extra_context_tokens=clip_extra_context_tokens,
        )
    elif model_name == "IPAdapterPlus":
        if ip_image_encoder is not None:
            if vision_clip_extractor_class_name is None:
                vision_clip_extractor = ImageClipVisionFeatureExtractorV2(
                    pretrained_model_name_or_path=ip_image_encoder,
                    dtype=dtype,
                    device=device,
                )
        else:
            vision_clip_extractor = None
        ip_adapter_image_proj = Resampler(
            dim=cross_attention_dim,
            depth=4,
            dim_head=64,
            heads=12,
            num_queries=clip_extra_context_tokens,
            embedding_dim=vision_clip_extractor.image_encoder.config.hidden_size,
            output_dim=cross_attention_dim,
            ff_mult=4,
        ).to(dtype=dtype)  # cast to the requested dtype
    else:
        raise ValueError(
            f"unsupported model_name={model_name}; supported: IPAdapter, IPAdapterPlus"
        )
    ip_adapter_state_dict = torch.load(
        ip_ckpt,
        map_location="cpu",
    )
    ip_adapter_image_proj.load_state_dict(ip_adapter_state_dict["image_proj"])
    if (
        unet is not None
        and unet.ip_adapter_cross_attn
        and "ip_adapter" in ip_adapter_state_dict
    ):
        update_unet_ip_adapter_cross_attn_param(
            unet, ip_adapter_state_dict["ip_adapter"]
        )
        logger.info(
            f"update unet.spatial_cross_attn_ip_adapter parameter with {ip_ckpt}"
        )
    return (
        vision_clip_extractor,
        ip_adapter_image_proj,
    )
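
# Illustrative usage (placeholder paths):
#   vision_clip_extractor, image_proj = load_ip_adapter_vision_clip_encoder_by_name(
#       model_name="IPAdapterPlus",
#       ip_ckpt="./checkpoints/ip-adapter-plus_sd15.bin",
#       ip_image_encoder="./checkpoints/IP-Adapter/image_encoder",
#   )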

# Mapping between unet attn2 processor parameter names and IP-Adapter checkpoint keys;
# see https://github.com/tencent-ailab/IP-Adapter/issues/168#issuecomment-1846771651
unet_keys_list = [
    "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight",
    "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight",
    "down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight",
    "down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight",
    "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight",
    "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight",
    "down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight",
    "down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight",
    "down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight",
    "down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight",
    "down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight",
    "down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight",
    "up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight",
    "up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight",
    "up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight",
    "up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight",
    "up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.to_k_ip.weight",
    "up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.to_v_ip.weight",
    "up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight",
    "up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight",
    "up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight",
    "up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight",
    "up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.to_k_ip.weight",
    "up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.to_v_ip.weight",
    "up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight",
    "up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight",
    "up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight",
    "up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight",
    "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.to_k_ip.weight",
    "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.to_v_ip.weight",
    "mid_block.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight",
    "mid_block.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight",
]

ip_adapter_keys_list = [
    "1.to_k_ip.weight",
    "1.to_v_ip.weight",
    "3.to_k_ip.weight",
    "3.to_v_ip.weight",
    "5.to_k_ip.weight",
    "5.to_v_ip.weight",
    "7.to_k_ip.weight",
    "7.to_v_ip.weight",
    "9.to_k_ip.weight",
    "9.to_v_ip.weight",
    "11.to_k_ip.weight",
    "11.to_v_ip.weight",
    "13.to_k_ip.weight",
    "13.to_v_ip.weight",
    "15.to_k_ip.weight",
    "15.to_v_ip.weight",
    "17.to_k_ip.weight",
    "17.to_v_ip.weight",
    "19.to_k_ip.weight",
    "19.to_v_ip.weight",
    "21.to_k_ip.weight",
    "21.to_v_ip.weight",
    "23.to_k_ip.weight",
    "23.to_v_ip.weight",
    "25.to_k_ip.weight",
    "25.to_v_ip.weight",
    "27.to_k_ip.weight",
    "27.to_v_ip.weight",
    "29.to_k_ip.weight",
    "29.to_v_ip.weight",
    "31.to_k_ip.weight",
    "31.to_v_ip.weight",
]

UNET2IPAadapter_Keys_MAPIING = {
    k: v for k, v in zip(unet_keys_list, ip_adapter_keys_list)
}
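
# The zip above pairs keys positionally, e.g. UNET2IPAadapter_Keys_MAPIING maps
# "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight"
# to the checkpoint key "1.to_k_ip.weight".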

def update_unet_ip_adapter_cross_attn_param(
    unet: UNet3DConditionModel, ip_adapter_state_dict: Dict
) -> None:
    """Copy the to_k_ip / to_v_ip weights of the independent IP-Adapter attention
    into the unet's spatial cross-attention processors.

    Args:
        unet (UNet3DConditionModel): unet whose spatial cross-attention
            processors carry to_k_ip / to_v_ip parameters.
        ip_adapter_state_dict (Dict): dict whose keys look like
            ['1.to_k_ip.weight', '1.to_v_ip.weight', '3.to_k_ip.weight', ...]
    """
    unet_spatial_cross_attns = unet.spatial_cross_attns[0]
    unet_spatial_cross_attns_dct = {k: v for k, v in unet_spatial_cross_attns}
    for i, (unet_key_more, ip_adapter_key) in enumerate(
        UNET2IPAadapter_Keys_MAPIING.items()
    ):
        ip_adapter_value = ip_adapter_state_dict[ip_adapter_key]
        # split "...attn2.processor.to_k_ip.weight" into the processor name and
        # the trailing "processor.to_{k,v}_ip.weight" suffix
        unet_key_more_split = unet_key_more.split(".")
        unet_key = ".".join(unet_key_more_split[:-3])
        suffix = ".".join(unet_key_more_split[-3:])
        logger.debug(
            f"{i}: unet_key_more={unet_key_more}, unet_key={unet_key}, suffix={suffix}",
        )
        if "to_k" in suffix:
            with torch.no_grad():
                unet_spatial_cross_attns_dct[unet_key].to_k_ip.weight.copy_(
                    ip_adapter_value.data
                )
        else:
            with torch.no_grad():
                unet_spatial_cross_attns_dct[unet_key].to_v_ip.weight.copy_(
                    ip_adapter_value.data
                )
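
# Illustrative usage (placeholder checkpoint path; the checkpoint must carry an
# "ip_adapter" entry, as checked in the loaders above):
#   state_dict = torch.load("./checkpoints/ip_adapter.bin", map_location="cpu")
#   update_unet_ip_adapter_cross_attn_param(unet, state_dict["ip_adapter"])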