Spaces:
Runtime error
Runtime error
| import gc | |
| import os | |
| from typing import Any, Callable, List, Literal, Union, Dict, Tuple | |
| import logging | |
| from safetensors.torch import load_file | |
| from safetensors import safe_open | |
| import torch | |
| from torch import nn | |
| from diffusers.models.controlnet import ControlNetModel | |
| from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel | |
| from diffusers.pipelines.pipeline_utils import DiffusionPipeline | |
| from .convert_from_ckpt import ( | |
| convert_ldm_unet_checkpoint, | |
| convert_ldm_vae_checkpoint, | |
| convert_ldm_clip_checkpoint, | |
| ) | |
| from .convert_lora_safetensor_to_diffusers import convert_motion_lora_ckpt_to_diffusers | |
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)
def update_pipeline_model_parameters(
    pipeline: "DiffusionPipeline",
    model_path: str = None,
    lora_dict: Dict[str, Dict] = None,
    text_model_path: str = None,
    device="cuda",
    need_unload: bool = False,
):
    """Optionally swap the base checkpoint and/or merge LoRA weights into ``pipeline``.

    Args:
        pipeline: Pipeline whose weights are updated.
        model_path: Base checkpoint path forwarded to ``update_pipeline_basemodel``.
            Skipped when ``None``.
        lora_dict: Mapping of LoRA path -> options, forwarded to
            ``update_pipeline_lora_models``. Skipped when ``None``.
        text_model_path: Forwarded to ``update_pipeline_basemodel`` as
            ``text_sd_model_path``.
        device: Device the weights are loaded onto.
        need_unload: When True, also return the LoRA undo records.

    Returns:
        ``pipeline``, or ``(pipeline, unload_dict)`` when ``need_unload`` is True.
        ``unload_dict`` is an empty list when no LoRA was applied.
    """
    # Initialise up front: with need_unload=True and lora_dict=None this name
    # used to be unbound at the return below (UnboundLocalError).
    unload_dict = []
    if model_path is not None:
        pipeline = update_pipeline_basemodel(
            pipeline, model_path, text_sd_model_path=text_model_path, device=device
        )
    if lora_dict is not None:
        pipeline, unload_dict = update_pipeline_lora_models(
            pipeline,
            lora_dict,
            device=device,
            need_unload=need_unload,
        )
    if need_unload:
        return pipeline, unload_dict
    return pipeline
def update_pipeline_basemodel(
    pipeline: "DiffusionPipeline",
    model_path: str,
    text_sd_model_path: str,
    device: str = "cuda",
):
    """Replace the base weights of ``pipeline`` with those stored at ``model_path``.

    * ``.ckpt``: loaded with ``torch.load`` and applied to ``pipeline.unet``
      only (strict ``load_state_dict``).
    * ``.safetensors``: converted with the LDM converters; the VAE and UNet are
      loaded in place and ``pipeline.text_encoder`` is replaced by the result
      of ``convert_ldm_clip_checkpoint``.
    * Any other extension: nothing is loaded and a warning is logged.

    In every case the pipeline is then moved to ``device``.

    Args:
        pipeline: Pipeline to update (modified in place).
        model_path: Path to the base checkpoint.
        text_sd_model_path: Forwarded to ``convert_ldm_clip_checkpoint``.
        device: Device the weights are loaded onto and the pipeline moved to.

    Returns:
        The same ``pipeline`` object.

    Raises:
        AssertionError: If a ``.safetensors`` file has only LoRA keys.
    """
    if model_path.endswith(".ckpt"):
        # NOTE(review): torch.load unpickles arbitrary objects; only use it with
        # trusted checkpoint files.
        state_dict = torch.load(model_path, map_location=device)
        pipeline.unet.load_state_dict(state_dict)
        print("update sd_model", model_path)
    elif model_path.endswith(".safetensors"):
        base_state_dict = load_file(model_path, device=device)
        # A file in which every key mentions "lora" is an adapter, not a base model.
        is_lora = all("lora" in k for k in base_state_dict)
        assert not is_lora, "Base model cannot be LoRA: {}".format(model_path)
        # vae
        converted_vae_checkpoint = convert_ldm_vae_checkpoint(
            base_state_dict, pipeline.vae.config
        )
        pipeline.vae.load_state_dict(converted_vae_checkpoint)
        # unet (strict=False: keys that don't match the UNet are ignored
        # instead of raising)
        converted_unet_checkpoint = convert_ldm_unet_checkpoint(
            base_state_dict, pipeline.unet.config
        )
        pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
        # text encoder: replaced wholesale rather than loaded in place
        pipeline.text_encoder = convert_ldm_clip_checkpoint(
            base_state_dict, text_sd_model_path
        )
        print("update sd_model", model_path)
    else:
        # Previously an unknown extension was silently ignored.
        logger.warning(
            "Unsupported base model format, weights not updated: %s", model_path
        )
    pipeline.to(device)
    return pipeline
# Per-block LoRA strength presets.
# ref https://git.woa.com/innovative_tech/GenerationGroup/VirtualIdol/VidolImageDraw/blob/master/cfg.yaml
# Each preset holds 17 multipliers. Index 0 scales text-encoder LoRA deltas.
# Indices 1..16 scale the UNet blocks, in the order of the ``lora_unet_layers``
# argument of ``update_pipeline_lora_model``.
LORA_BLOCK_WEIGHT_MAP = dict(
    FACE=[1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0],
    DEFACE=[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1],
    ALL=[1] * 17,
    MIDD=[1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
    OUTALL=[1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
)
def _resolve_lora_layer(root: nn.Module, name_parts: List[str]) -> nn.Module:
    """Return the submodule of ``root`` addressed by an underscore-split LoRA key.

    Module names can themselves contain underscores (``down_blocks``, ``to_k``).
    Parts are therefore greedily re-joined with ``_`` until an attribute lookup
    succeeds, and the walk then continues from the module found.

    Raises:
        AttributeError: If the parts run out before a module is found.
    """
    parts = list(name_parts)
    curr_layer = root
    temp_name = parts.pop(0)
    while True:
        try:
            curr_layer = curr_layer.__getattr__(temp_name)
        except AttributeError as err:
            if not parts:
                raise AttributeError(
                    "Cannot resolve LoRA target {!r} in {}".format(
                        "_".join(name_parts), type(root).__name__
                    )
                ) from err
            # Lookup failed: the module name continues past this underscore.
            temp_name = temp_name + "_" + parts.pop(0) if temp_name else parts.pop(0)
            continue
        if not parts:
            return curr_layer
        temp_name = parts.pop(0)


def _lora_delta_weight(
    weight_up: torch.Tensor, weight_down: torch.Tensor, scale: float
) -> torch.Tensor:
    """Return ``scale * (weight_up @ weight_down)`` computed in float32.

    Non-4-D weights are multiplied directly. For 4-D (conv) weights the
    trailing singleton dims are squeezed first. If ``weight_down`` still has a
    spatial kernel afterwards, the product keeps that kernel (``einsum``).
    Otherwise the result is restored to a 1x1 conv kernel.
    """
    if weight_up.dim() != 4:
        return scale * torch.mm(
            weight_up.to(torch.float32), weight_down.to(torch.float32)
        )
    # Assumes the up projection is (out, rank, 1, 1) so it squeezes to 2-D.
    up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
    down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
    if up.dim() == down.dim():
        return scale * torch.mm(up, down).unsqueeze(2).unsqueeze(3)
    return scale * torch.einsum("a b, b c h w -> a c h w", up, down)


# ref https://git.woa.com/innovative_tech/GenerationGroup/VirtualIdol/VidolImageDraw/blob/master/pipeline/draw_pipe.py
def update_pipeline_lora_model(
    pipeline: "DiffusionPipeline",
    lora: Union[str, Dict],
    alpha: float = 0.75,
    device: str = "cuda",
    lora_prefix_unet: str = "lora_unet",
    lora_prefix_text_encoder: str = "lora_te",
    lora_unet_layers=[
        "lora_unet_down_blocks_0_attentions_0",
        "lora_unet_down_blocks_0_attentions_1",
        "lora_unet_down_blocks_1_attentions_0",
        "lora_unet_down_blocks_1_attentions_1",
        "lora_unet_down_blocks_2_attentions_0",
        "lora_unet_down_blocks_2_attentions_1",
        "lora_unet_mid_block_attentions_0",
        "lora_unet_up_blocks_1_attentions_0",
        "lora_unet_up_blocks_1_attentions_1",
        "lora_unet_up_blocks_1_attentions_2",
        "lora_unet_up_blocks_2_attentions_0",
        "lora_unet_up_blocks_2_attentions_1",
        "lora_unet_up_blocks_2_attentions_2",
        "lora_unet_up_blocks_3_attentions_0",
        "lora_unet_up_blocks_3_attentions_1",
        "lora_unet_up_blocks_3_attentions_2",
    ],
    lora_block_weight_str: Literal["FACE", "DEFACE", "ALL", "MIDD", "OUTALL"] = "ALL",
    need_unload: bool = False,
):
    """Merge one LoRA into ``pipeline`` by adding its weight deltas in place.

    For every ``lora_up``/``lora_down`` pair the delta
    ``alpha * (key_alpha / rank) * up @ down`` is computed. The factor
    ``key_alpha / rank`` is 1.0 when no ``.alpha`` entry exists.

    The delta is optionally scaled by a per-block multiplier. It is then added
    to the ``weight`` of the matching module: ``pipeline.text_encoder`` for
    keys containing ``"text"``, ``pipeline.unet`` for all other keys.

    Args:
        pipeline: Pipeline whose ``unet`` / ``text_encoder`` are modified in place.
        lora: Path to a ``.safetensors`` LoRA file, or an already-loaded state
            dict. A dict argument is copied, never mutated.
        alpha: Global strength multiplier.
        device: Device the LoRA tensors are loaded onto.
        lora_prefix_unet: Key prefix stripped from UNet LoRA keys.
        lora_prefix_text_encoder: Key prefix stripped from text-encoder LoRA keys.
        lora_unet_layers: UNet block name substrings (16 by default). A key
            containing entry ``i`` uses block multiplier ``i + 1``, while index 0
            is the text encoder. The list is only read.
        lora_block_weight_str: Case-insensitive preset name from
            ``LORA_BLOCK_WEIGHT_MAP``. ``None`` disables per-block scaling.
        need_unload: When True, also return the undo records.

    Returns:
        ``pipeline``, or ``(pipeline, unload_dict)`` when ``need_unload`` is True.
        ``unload_dict`` is a list of ``{"layer": module, "added_weight": delta}``.
    """
    # ref https://git.woa.com/innovative_tech/GenerationGroup/VirtualIdol/VidolImageDraw/blob/master/pipeline/tool.py#L20
    # Default first: this name used to be unbound when lora_block_weight_str was None.
    lora_block_weight = None
    if lora_block_weight_str is not None:
        lora_block_weight = LORA_BLOCK_WEIGHT_MAP[lora_block_weight_str.upper()]
    if lora_block_weight:
        assert len(lora_block_weight) == 17

    # load lora weight
    if isinstance(lora, str):
        state_dict = load_file(lora, device=device)
    else:
        # New dict so the caller's mapping is left untouched.
        state_dict = {k: v.to(device) for k, v in lora.items()}

    visited = set()
    unload_dict = []
    for key in state_dict:
        # Keys look like
        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight".
        # Alpha entries are read together with their weight pair, and each pair
        # is processed once (from whichever half is seen first).
        if ".alpha" in key or key in visited:
            continue

        is_text_key = "text" in key
        if is_text_key:
            prefix, root = lora_prefix_text_encoder, pipeline.text_encoder
        else:
            prefix, root = lora_prefix_unet, pipeline.unet
        name_parts = key.split(".")[0].split(prefix + "_")[-1].split("_")
        curr_layer = _resolve_lora_layer(root, name_parts)

        if "lora_down" in key:
            up_key, down_key = key.replace("lora_down", "lora_up"), key
            alpha_key = key.replace("lora_down.weight", "alpha")
        else:
            up_key, down_key = key, key.replace("lora_up", "lora_down")
            alpha_key = key.replace("lora_up.weight", "alpha")

        weight_up = state_dict[up_key]
        if alpha_key in state_dict:
            # Normalise by the rank, i.e. dim 1 of the up projection.
            weight_scale = state_dict[alpha_key].item() / weight_up.shape[1]
        else:
            weight_scale = 1.0
        adding_weight = _lora_delta_weight(
            weight_up, state_dict[down_key], alpha * weight_scale
        )

        # Match the target parameter instead of hard-casting to float16. fp32
        # models keep full precision, and the unload record is in the same
        # dtype/device as the weight it is subtracted from.
        target = curr_layer.weight
        adding_weight = adding_weight.to(device=target.device, dtype=target.dtype)

        if lora_block_weight:
            if is_text_key:
                adding_weight *= lora_block_weight[0]
            else:
                for idx, layer_name in enumerate(lora_unet_layers):
                    if layer_name in key:
                        adding_weight *= lora_block_weight[idx + 1]
                        break

        target.data += adding_weight
        unload_dict.append({"layer": curr_layer, "added_weight": adding_weight})
        visited.update((up_key, down_key))

    if need_unload:
        return pipeline, unload_dict
    return pipeline
# ref https://git.woa.com/innovative_tech/GenerationGroup/VirtualIdol/VidolImageDraw/blob/master/pipeline/draw_pipe.py
def update_pipeline_lora_model_old(
    pipeline: "DiffusionPipeline",
    lora: Union[str, Dict],
    alpha: float = 0.75,
    device: str = "cuda",
    lora_prefix_unet: str = "lora_unet",
    lora_prefix_text_encoder: str = "lora_te",
    lora_unet_layers=[
        "lora_unet_down_blocks_0_attentions_0",
        "lora_unet_down_blocks_0_attentions_1",
        "lora_unet_down_blocks_1_attentions_0",
        "lora_unet_down_blocks_1_attentions_1",
        "lora_unet_down_blocks_2_attentions_0",
        "lora_unet_down_blocks_2_attentions_1",
        "lora_unet_mid_block_attentions_0",
        "lora_unet_up_blocks_1_attentions_0",
        "lora_unet_up_blocks_1_attentions_1",
        "lora_unet_up_blocks_1_attentions_2",
        "lora_unet_up_blocks_2_attentions_0",
        "lora_unet_up_blocks_2_attentions_1",
        "lora_unet_up_blocks_2_attentions_2",
        "lora_unet_up_blocks_3_attentions_0",
        "lora_unet_up_blocks_3_attentions_1",
        "lora_unet_up_blocks_3_attentions_2",
    ],
    lora_block_weight_str: Literal["FACE", "DEFACE", "ALL", "MIDD", "OUTALL"] = "ALL",
    need_unload: bool = False,
):
    """Legacy LoRA merge: adds ``alpha * up @ down`` (float32) to each target weight.

    Unlike ``update_pipeline_lora_model``, the per-key ``.alpha`` entries are
    skipped and never read, so there is no alpha/rank scaling. Spatial conv
    kernels are not handled specially either.

    Args:
        pipeline: Pipeline whose ``unet`` / ``text_encoder`` are modified in place.
        lora: Path to a ``.safetensors`` LoRA file, or an already-loaded state
            dict. A dict argument is copied, never mutated.
        alpha: Global strength multiplier.
        device: Device the LoRA tensors are loaded onto.
        lora_prefix_unet: Key prefix stripped from UNet LoRA keys.
        lora_prefix_text_encoder: Key prefix stripped from text-encoder LoRA keys.
        lora_unet_layers: UNet block name substrings. Entry ``i`` uses block
            multiplier ``i + 1``. The list is only read.
        lora_block_weight_str: Case-insensitive preset name from
            ``LORA_BLOCK_WEIGHT_MAP``. ``None`` disables per-block scaling.
        need_unload: When True, also return the undo records.

    Returns:
        ``pipeline``, or ``(pipeline, unload_dict)`` when ``need_unload`` is True.
    """
    # ref https://git.woa.com/innovative_tech/GenerationGroup/VirtualIdol/VidolImageDraw/blob/master/pipeline/tool.py#L20
    # Default first: this name used to be unbound when lora_block_weight_str was None.
    lora_block_weight = None
    if lora_block_weight_str is not None:
        lora_block_weight = LORA_BLOCK_WEIGHT_MAP[lora_block_weight_str.upper()]
    if lora_block_weight:
        assert len(lora_block_weight) == 17

    # load lora weight
    if isinstance(lora, str):
        state_dict = load_file(lora, device=device)
    else:
        # New dict so the caller's mapping is left untouched.
        state_dict = {k: v.to(device) for k, v in lora.items()}

    visited = set()
    unload_dict = []
    for key in state_dict:
        # Keys look like
        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight".
        # Alpha entries are ignored by this legacy variant. Each up/down pair
        # is processed once.
        if ".alpha" in key or key in visited:
            continue
        if "text" in key:
            layer_infos = (
                key.split(".")[0].split(lora_prefix_text_encoder + "_")[-1].split("_")
            )
            curr_layer = pipeline.text_encoder
        else:
            layer_infos = key.split(".")[0].split(lora_prefix_unet + "_")[-1].split("_")
            curr_layer = pipeline.unet

        # Find the target layer. Module names may contain underscores, so the
        # parts are re-joined with "_" until an attribute lookup succeeds.
        temp_name = layer_infos.pop(0)
        while True:
            try:
                curr_layer = curr_layer.__getattr__(temp_name)
            except AttributeError:
                if len(temp_name) > 0:
                    temp_name += "_" + layer_infos.pop(0)
                else:
                    temp_name = layer_infos.pop(0)
                continue
            if not layer_infos:
                break
            temp_name = layer_infos.pop(0)

        if "lora_down" in key:
            pair_keys = [key.replace("lora_down", "lora_up"), key]
        else:
            pair_keys = [key, key.replace("lora_up", "lora_down")]

        # update weight
        if len(state_dict[pair_keys[0]].shape) == 4:
            weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
            weight_down = (
                state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
            )
            adding_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(
                2
            ).unsqueeze(3)
        else:
            weight_up = state_dict[pair_keys[0]].to(torch.float32)
            weight_down = state_dict[pair_keys[1]].to(torch.float32)
            adding_weight = alpha * torch.mm(weight_up, weight_down)

        if lora_block_weight:
            if "text" in key:
                adding_weight *= lora_block_weight[0]
            else:
                for idx, layer in enumerate(lora_unet_layers):
                    if layer in key:
                        adding_weight *= lora_block_weight[idx + 1]
                        break

        curr_layer_unload_data = {"layer": curr_layer, "added_weight": adding_weight}
        curr_layer.weight.data += adding_weight
        unload_dict.append(curr_layer_unload_data)
        # update visited list
        visited.update(pair_keys)

    if need_unload:
        return pipeline, unload_dict
    return pipeline
def update_pipeline_lora_models(
    pipeline: "DiffusionPipeline",
    lora_dict: Dict[str, Dict],
    device: str = "cuda",
    need_unload: bool = True,
    lora_prefix_unet: str = "lora_unet",
    lora_prefix_text_encoder: str = "lora_te",
    lora_unet_layers=[
        "lora_unet_down_blocks_0_attentions_0",
        "lora_unet_down_blocks_0_attentions_1",
        "lora_unet_down_blocks_1_attentions_0",
        "lora_unet_down_blocks_1_attentions_1",
        "lora_unet_down_blocks_2_attentions_0",
        "lora_unet_down_blocks_2_attentions_1",
        "lora_unet_mid_block_attentions_0",
        "lora_unet_up_blocks_1_attentions_0",
        "lora_unet_up_blocks_1_attentions_1",
        "lora_unet_up_blocks_1_attentions_2",
        "lora_unet_up_blocks_2_attentions_0",
        "lora_unet_up_blocks_2_attentions_1",
        "lora_unet_up_blocks_2_attentions_2",
        "lora_unet_up_blocks_3_attentions_0",
        "lora_unet_up_blocks_3_attentions_1",
        "lora_unet_up_blocks_3_attentions_2",
    ],
):
    """Merge several LoRA ``.safetensors`` files into ``pipeline``, one after another.

    Args:
        pipeline: Pipeline whose weights are modified in place.
        lora_dict: Mapping of LoRA file path -> options dict. Recognised options:
            ``strength`` (default 1.0) plus ``strength_offset`` (default 0.0)
            give the merge alpha. ``lora_block_weight`` (default ``"ALL"``) is
            the preset name passed to ``update_pipeline_lora_model``.
        device: Device each LoRA file is loaded onto.
        need_unload: Accepted for API symmetry but ignored. This function always
            returns ``(pipeline, unload_dicts)``, and
            ``update_pipeline_model_parameters`` unpacks that unconditionally.
        lora_prefix_unet: Forwarded to ``update_pipeline_lora_model``.
        lora_prefix_text_encoder: Forwarded to ``update_pipeline_lora_model``.
        lora_unet_layers: Forwarded to ``update_pipeline_lora_model`` (only read).

    Returns:
        ``(pipeline, unload_dicts)``, where ``unload_dicts`` concatenates the undo
        records of every merged LoRA in merge order.
    """
    unload_dicts = []
    for lora_path, options in lora_dict.items():
        lora_name = os.path.basename(lora_path).replace(".safetensors", "")
        alpha = options.get("strength", 1.0) + options.get("strength_offset", 0.0)
        lora_weight_str = options.get("lora_block_weight", "ALL")
        # Load straight onto the target device rather than to CPU followed by
        # a per-tensor copy. Keeps the loop variable (the path) unshadowed.
        lora_state_dict = load_file(lora_path, device=device)
        pipeline, unload_dict = update_pipeline_lora_model(
            pipeline,
            lora=lora_state_dict,
            device=device,
            alpha=alpha,
            lora_prefix_unet=lora_prefix_unet,
            lora_prefix_text_encoder=lora_prefix_text_encoder,
            lora_unet_layers=lora_unet_layers,
            lora_block_weight_str=lora_weight_str,
            need_unload=True,
        )
        print(
            "Update LoRA {} with alpha {} and weight {}".format(
                lora_name, alpha, lora_weight_str
            )
        )
        unload_dicts.extend(unload_dict)
    return pipeline, unload_dicts
def unload_lora(unload_dict: List[Dict[str, Any]]):
    """Undo LoRA merges recorded by ``update_pipeline_lora_model(s)``.

    Args:
        unload_dict: List of records ``{"layer": module, "added_weight": tensor}``.
            Each tensor is subtracted in place from ``module.weight``. Values
            are a module and a tensor, hence ``Any`` rather than ``nn.Module``.

    Afterwards Python garbage collection runs and ``torch.cuda.empty_cache()``
    is called to release unused cached CUDA memory.
    """
    for record in unload_dict:
        record["layer"].weight.data -= record["added_weight"]
    gc.collect()
    torch.cuda.empty_cache()
def load_motion_lora_weights(
    animation_pipeline,
    motion_module_lora_configs=None,
):
    """Apply motion-module LoRA checkpoints to ``animation_pipeline`` in order.

    Args:
        animation_pipeline: Pipeline passed to (and replaced by the result of)
            ``convert_motion_lora_ckpt_to_diffusers``.
        motion_module_lora_configs: Iterable of dicts with ``"path"`` and
            ``"alpha"`` keys. Each checkpoint is loaded with ``torch.load`` on
            CPU, and weights nested under a ``"state_dict"`` key are unwrapped.
            ``None`` (the default) or an empty iterable applies nothing.

    Returns:
        The pipeline returned by the last conversion, or ``animation_pipeline``
        unchanged when no configs are given.
    """
    # None instead of a shared mutable [] default.
    if motion_module_lora_configs is None:
        motion_module_lora_configs = []
    for motion_module_lora_config in motion_module_lora_configs:
        path, alpha = (
            motion_module_lora_config["path"],
            motion_module_lora_config["alpha"],
        )
        print(f"load motion LoRA from {path}")
        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # trusted checkpoint files.
        motion_lora_state_dict = torch.load(path, map_location="cpu")
        if "state_dict" in motion_lora_state_dict:
            motion_lora_state_dict = motion_lora_state_dict["state_dict"]
        animation_pipeline = convert_motion_lora_ckpt_to_diffusers(
            animation_pipeline, motion_lora_state_dict, alpha
        )
    return animation_pipeline