# LoRA network module for CompVis (Stable Diffusion) models.
# Applies LoRA to the text encoder and U-Net by hooking the forward methods of target modules.

import copy
import math
import re
from typing import NamedTuple
import torch


class LoRAInfo(NamedTuple):
    lora_name: str
    module_name: str
    module: torch.nn.Module
    multiplier: float
    dim: int
    alpha: float


class LoRAModule(torch.nn.Module):
    """
    replaces the forward method of the original Linear (or Conv2d), instead of replacing the original module.
    """

    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
        """if alpha == 0 or None, alpha is rank (no scaling)."""
        super().__init__()
        self.lora_name = lora_name
        self.lora_dim = lora_dim

        if org_module.__class__.__name__ == "Conv2d":
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels

            kernel_size = org_module.kernel_size
            stride = org_module.stride
            padding = org_module.padding
            self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
            self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features
            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)

        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().float().numpy()  # numpy() cannot handle bfloat16, so cast to float first
        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))

        # lora_down gets a standard init, lora_up starts at zero so the LoRA delta is initially a no-op
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_up.weight)

        self.multiplier = multiplier
        self.org_forward = org_module.forward
        self.org_module = org_module  # removed when the module is applied
        self.mask_dic = None
        self.mask = None
        self.mask_area = -1

    def apply_to(self):
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        del self.org_module

    def set_mask_dic(self, mask_dic):
        # called via LoRANetworkCompvis.set_mask when a regional mask is (re)set

        # masks only make sense for modules whose activations have a spatial (h, w) layout;
        # context (attn2 k/v) and time-embedding LoRAs are left unmasked
        if "attn2_to_k" in self.lora_name or "attn2_to_v" in self.lora_name or "emb_layers" in self.lora_name:
            self.mask_dic = None
        else:
            self.mask_dic = mask_dic

        self.mask = None

    def forward(self, x):
        """
        may be cascaded.
        """
        if self.mask_dic is None:
            return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale

        # regional LoRA: scale the LoRA delta by a mask matching the current feature map

        # calculate the LoRA output and get its spatial size
        lx = self.lora_up(self.lora_down(x))

        if len(lx.size()) == 4:  # conv features: (b, c, h, w)
            area = lx.size()[2] * lx.size()[3]
        else:  # linear features: (b, seq_len, dim)
            area = lx.size()[1]

        if self.mask is None or self.mask_area != area:
            # look up the mask resized for this area
            mask = self.mask_dic[area]
            if len(lx.size()) == 3:
                mask = torch.reshape(mask, (1, -1, 1))
            self.mask = mask
            self.mask_area = area

        return self.org_forward(x) + lx * self.multiplier * self.scale * self.mask


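# Editor's note: a minimal sketch (not used by this module) of the effective Linear weight that
# LoRAModule.forward realizes, assuming the same multiplier / alpha / dim semantics as above.
# The helper name and the direct-merge form are illustrative, not part of the original API.
def _example_effective_linear_weight(org_weight, lora_up_weight, lora_down_weight, multiplier=1.0, alpha=None):
    # org_weight: (out_features, in_features), lora_up_weight: (out_features, dim), lora_down_weight: (dim, in_features)
    dim = lora_down_weight.size(0)
    scale = (alpha if alpha else dim) / dim  # alpha None/0 falls back to dim, i.e. scale 1.0
    return org_weight + multiplier * scale * (lora_up_weight @ lora_down_weight)

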
def create_network_and_apply_compvis(du_state_dict, multiplier_tenc, multiplier_unet, text_encoder, unet, **kwargs):
    # get dtype from a U-Net Linear layer
    for module in unet.modules():
        if module.__class__.__name__ == "Linear":
            param: torch.nn.Parameter = module.weight
            dtype = param.dtype
            break

    # get dim (rank) and alpha for each LoRA module from the state dict
    modules_dim = {}
    modules_alpha = {}
    for key, value in du_state_dict.items():
        if "." not in key:
            continue

        lora_name = key.split(".")[0]
        if "alpha" in key:
            modules_alpha[lora_name] = float(value.detach().to(torch.float).cpu().numpy())
        elif "lora_down" in key:
            dim = value.size()[0]
            modules_dim[lora_name] = dim

    # support old LoRA checkpoints without alpha (alpha defaults to dim)
    for key in modules_dim.keys():
        if key not in modules_alpha:
            modules_alpha[key] = modules_dim[key]

    print(
        f"dimension: {set(modules_dim.values())}, alpha: {set(modules_alpha.values())}, multiplier_unet: {multiplier_unet}, multiplier_tenc: {multiplier_tenc}"
    )

    # create the network, apply the LoRA modules and load the weights
    network = LoRANetworkCompvis(text_encoder, unet, multiplier_tenc, multiplier_unet, modules_dim, modules_alpha)
    state_dict = network.apply_lora_modules(du_state_dict)
    network.to(dtype)
    info = network.load_state_dict(state_dict, strict=False)

    # collapse the warning about missing alpha keys into a single summary line
    if len(info.missing_keys) > 4:
        missing_keys = []
        alpha_count = 0
        for key in info.missing_keys:
            if "alpha" not in key:
                missing_keys.append(key)
            else:
                if alpha_count == 0:
                    missing_keys.append(key)
                alpha_count += 1
        if alpha_count > 1:
            missing_keys.append(
                f"... and {alpha_count-1} alphas. The model doesn't have alpha, use dim (rank) as alpha. You can ignore this message."
            )

        info = torch.nn.modules.module._IncompatibleKeys(missing_keys, info.unexpected_keys)

    return network, info


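# Editor's note: hedged usage sketch. The checkpoint path and the way `text_encoder` / `unet`
# are obtained are assumptions for illustration; they are the CompVis submodules of a loaded
# Stable Diffusion model.
#
#   du_state_dict = torch.load("lora_weights.pt", map_location="cpu")  # hypothetical file
#   network, info = create_network_and_apply_compvis(du_state_dict, 0.8, 0.8, text_encoder, unet)
#   print(info)                           # missing/unexpected keys, already summarized above
#   network.restore(text_encoder, unet)   # undo the hooks / merged weights when done

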
class LoRANetworkCompvis(torch.nn.Module):
    """LoRA network applied to a CompVis (Stable Diffusion) text encoder and U-Net."""

    UNET_TARGET_REPLACE_MODULE = ["SpatialTransformer", "ResBlock", "Downsample", "Upsample"]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["ResidualAttentionBlock", "CLIPAttention", "CLIPMLP"]

    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"

    @classmethod
    def convert_diffusers_name_to_compvis(cls, v2, du_name):
        """
        convert a diffusers LoRA module name to the corresponding CompVis name
        """
        cv_name = None
        if "lora_unet_" in du_name:
            m = re.search(r"_down_blocks_(\d+)_attentions_(\d+)_(.+)", du_name)
            if m:
                du_block_index = int(m.group(1))
                du_attn_index = int(m.group(2))
                du_suffix = m.group(3)

                cv_index = 1 + du_block_index * 3 + du_attn_index
                cv_name = f"lora_unet_input_blocks_{cv_index}_1_{du_suffix}"
                return cv_name

            m = re.search(r"_mid_block_attentions_(\d+)_(.+)", du_name)
            if m:
                du_suffix = m.group(2)
                cv_name = f"lora_unet_middle_block_1_{du_suffix}"
                return cv_name

            m = re.search(r"_up_blocks_(\d+)_attentions_(\d+)_(.+)", du_name)
            if m:
                du_block_index = int(m.group(1))
                du_attn_index = int(m.group(2))
                du_suffix = m.group(3)

                cv_index = du_block_index * 3 + du_attn_index
                cv_name = f"lora_unet_output_blocks_{cv_index}_1_{du_suffix}"
                return cv_name

            m = re.search(r"_down_blocks_(\d+)_resnets_(\d+)_(.+)", du_name)
            if m:
                du_block_index = int(m.group(1))
                du_res_index = int(m.group(2))
                du_suffix = m.group(3)
                cv_suffix = {
                    "conv1": "in_layers_2",
                    "conv2": "out_layers_3",
                    "time_emb_proj": "emb_layers_1",
                    "conv_shortcut": "skip_connection",
                }[du_suffix]

                cv_index = 1 + du_block_index * 3 + du_res_index
                cv_name = f"lora_unet_input_blocks_{cv_index}_0_{cv_suffix}"
                return cv_name

            m = re.search(r"_down_blocks_(\d+)_downsamplers_0_conv", du_name)
            if m:
                block_index = int(m.group(1))
                cv_index = 3 + block_index * 3
                cv_name = f"lora_unet_input_blocks_{cv_index}_0_op"
                return cv_name

            m = re.search(r"_mid_block_resnets_(\d+)_(.+)", du_name)
            if m:
                index = int(m.group(1))
                du_suffix = m.group(2)
                cv_suffix = {
                    "conv1": "in_layers_2",
                    "conv2": "out_layers_3",
                    "time_emb_proj": "emb_layers_1",
                    "conv_shortcut": "skip_connection",
                }[du_suffix]
                cv_name = f"lora_unet_middle_block_{index*2}_{cv_suffix}"
                return cv_name

            m = re.search(r"_up_blocks_(\d+)_resnets_(\d+)_(.+)", du_name)
            if m:
                du_block_index = int(m.group(1))
                du_res_index = int(m.group(2))
                du_suffix = m.group(3)
                cv_suffix = {
                    "conv1": "in_layers_2",
                    "conv2": "out_layers_3",
                    "time_emb_proj": "emb_layers_1",
                    "conv_shortcut": "skip_connection",
                }[du_suffix]

                cv_index = du_block_index * 3 + du_res_index
                cv_name = f"lora_unet_output_blocks_{cv_index}_0_{cv_suffix}"
                return cv_name

            m = re.search(r"_up_blocks_(\d+)_upsamplers_0_conv", du_name)
            if m:
                block_index = int(m.group(1))
                cv_index = block_index * 3 + 2
                cv_name = f"lora_unet_output_blocks_{cv_index}_{bool(block_index)+1}_conv"
                return cv_name

        elif "lora_te_" in du_name:
            m = re.search(r"_model_encoder_layers_(\d+)_(.+)", du_name)
            if m:
                du_block_index = int(m.group(1))
                du_suffix = m.group(2)

                cv_index = du_block_index
                if v2:
                    if "mlp_fc1" in du_suffix:
                        cv_name = (
                            f"lora_te_wrapped_model_transformer_resblocks_{cv_index}_{du_suffix.replace('mlp_fc1', 'mlp_c_fc')}"
                        )
                    elif "mlp_fc2" in du_suffix:
                        cv_name = (
                            f"lora_te_wrapped_model_transformer_resblocks_{cv_index}_{du_suffix.replace('mlp_fc2', 'mlp_c_proj')}"
                        )
                    elif "self_attn" in du_suffix:
                        # q/k/v/out projections are merged into MultiheadAttention later in apply_lora_modules
                        cv_name = f"lora_te_wrapped_model_transformer_resblocks_{cv_index}_{du_suffix.replace('self_attn', 'attn')}"
                else:
                    cv_name = f"lora_te_wrapped_transformer_text_model_encoder_layers_{cv_index}_{du_suffix}"

        assert cv_name is not None, f"conversion failed: {du_name}. the model may not be trained by `sd-scripts`."
        return cv_name

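    # Editor's note: worked example of the conversion above (the name is illustrative): with v2=False,
    #   "lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_q"
    # matches the first U-Net pattern with du_block_index=1 and du_attn_index=0, so
    #   cv_index = 1 + 1 * 3 + 0 = 4
    # and the converted name is
    #   "lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_q".
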
    @classmethod
    def convert_state_dict_name_to_compvis(cls, v2, state_dict):
        """
        convert the keys in a state dict so it can be loaded with load_state_dict
        """
        new_sd = {}
        for key, value in state_dict.items():
            tokens = key.split(".")
            compvis_name = LoRANetworkCompvis.convert_diffusers_name_to_compvis(v2, tokens[0])
            new_key = compvis_name + "." + ".".join(tokens[1:])

            new_sd[new_key] = value

        return new_sd

    def __init__(self, text_encoder, unet, multiplier_tenc=1.0, multiplier_unet=1.0, modules_dim=None, modules_alpha=None) -> None:
        super().__init__()
        self.multiplier_unet = multiplier_unet
        self.multiplier_tenc = multiplier_tenc
        self.latest_mask_info = None

        # detect an SD2.x text encoder by the presence of MultiheadAttention
        self.v2 = False
        for _, module in text_encoder.named_modules():
            for _, child_module in module.named_modules():
                if child_module.__class__.__name__ == "MultiheadAttention":
                    self.v2 = True
                    break
            if self.v2:
                break

        # convert the diffusers LoRA names to CompVis and keep (dim, alpha) per module
        comp_vis_loras_dim_alpha = {}
        for du_lora_name in modules_dim.keys():
            dim = modules_dim[du_lora_name]
            alpha = modules_alpha[du_lora_name]
            comp_vis_lora_name = LoRANetworkCompvis.convert_diffusers_name_to_compvis(self.v2, du_lora_name)
            comp_vis_loras_dim_alpha[comp_vis_lora_name] = (dim, alpha)

        # create LoRA modules (or LoRAInfo for MultiheadAttention) for the target modules
        def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules, multiplier):
            loras = []
            replaced_modules = []
            for name, module in root_module.named_modules():
                if module.__class__.__name__ in target_replace_modules:
                    for child_name, child_module in module.named_modules():
                        # Linear / Conv2d can be hooked directly
                        if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
                            lora_name = prefix + "." + name + "." + child_name
                            lora_name = lora_name.replace(".", "_")
                            if "_resblocks_23_" in lora_name:  # the last SD2.x text encoder block is skipped
                                break
                            if lora_name not in comp_vis_loras_dim_alpha:
                                continue

                            dim, alpha = comp_vis_loras_dim_alpha[lora_name]
                            lora = LoRAModule(lora_name, child_module, multiplier, dim, alpha)
                            loras.append(lora)

                            replaced_modules.append(child_module)
                        elif child_module.__class__.__name__ == "MultiheadAttention":
                            # MultiheadAttention is not hooked; its weights are merged later in apply_lora_modules
                            for suffix in ["q_proj", "k_proj", "v_proj", "out_proj"]:
                                module_name = prefix + "." + name + "." + child_name
                                module_name = module_name.replace(".", "_")
                                if "_resblocks_23_" in module_name:  # the last SD2.x text encoder block is skipped
                                    break

                                lora_name = module_name + "_" + suffix
                                if lora_name not in comp_vis_loras_dim_alpha:
                                    continue
                                dim, alpha = comp_vis_loras_dim_alpha[lora_name]
                                lora_info = LoRAInfo(lora_name, module_name, child_module, multiplier, dim, alpha)
                                loras.append(lora_info)

                                replaced_modules.append(child_module)
            return loras, replaced_modules

        self.text_encoder_loras, te_rep_modules = create_modules(
            LoRANetworkCompvis.LORA_PREFIX_TEXT_ENCODER,
            text_encoder,
            LoRANetworkCompvis.TEXT_ENCODER_TARGET_REPLACE_MODULE,
            self.multiplier_tenc,
        )
        print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")

        self.unet_loras, unet_rep_modules = create_modules(
            LoRANetworkCompvis.LORA_PREFIX_UNET, unet, LoRANetworkCompvis.UNET_TARGET_REPLACE_MODULE, self.multiplier_unet
        )
        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

        # back up the original forward methods / weights so they can be restored later
        backed_up = False
        for rep_module in te_rep_modules + unet_rep_modules:
            if rep_module.__class__.__name__ == "MultiheadAttention":
                if not hasattr(rep_module, "_lora_org_weights"):
                    # MultiheadAttention weights are modified in place, so back up the whole state dict
                    rep_module._lora_org_weights = copy.deepcopy(rep_module.state_dict())
                    backed_up = True
            elif not hasattr(rep_module, "_lora_org_forward"):
                rep_module._lora_org_forward = rep_module.forward
                backed_up = True
        if backed_up:
            print("original forward/weights is backed up.")

        # sanity check: LoRA names must be unique
        names = set()
        for lora in self.text_encoder_loras + self.unet_loras:
            assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
            names.add(lora.lora_name)

    def restore(self, text_encoder, unet):
        # restore the original forward methods and weights
        restored = False
        modules = []
        modules.extend(text_encoder.modules())
        modules.extend(unet.modules())
        for module in modules:
            if hasattr(module, "_lora_org_forward"):
                module.forward = module._lora_org_forward
                del module._lora_org_forward
                restored = True
            if hasattr(module, "_lora_org_weights"):
                module.load_state_dict(module._lora_org_weights)
                del module._lora_org_weights
                restored = True

        if restored:
            print("original forward/weights is restored.")

    def apply_lora_modules(self, du_state_dict):
        # convert the state dict keys to CompVis naming
        state_dict = LoRANetworkCompvis.convert_state_dict_name_to_compvis(self.v2, du_state_dict)

        # check which parts (text encoder / U-Net) the weights cover
        weights_has_text_encoder = weights_has_unet = False
        for key in state_dict.keys():
            if key.startswith(LoRANetworkCompvis.LORA_PREFIX_TEXT_ENCODER):
                weights_has_text_encoder = True
            elif key.startswith(LoRANetworkCompvis.LORA_PREFIX_UNET):
                weights_has_unet = True
            if weights_has_text_encoder and weights_has_unet:
                break

        apply_text_encoder = weights_has_text_encoder
        apply_unet = weights_has_unet

        if apply_text_encoder:
            print("enable LoRA for text encoder")
        else:
            self.text_encoder_loras = []

        if apply_unet:
            print("enable LoRA for U-Net")
        else:
            self.unet_loras = []

        # hook Linear/Conv2d LoRA modules; collect MultiheadAttention LoRAInfo until all four projections are seen
        mha_loras = {}
        for lora in self.text_encoder_loras + self.unet_loras:
            if type(lora) == LoRAModule:
                lora.apply_to()
                self.add_module(lora.lora_name, lora)
            else:
                # LoRAInfo for MultiheadAttention: merge the weights directly instead of hooking forward
                lora_info: LoRAInfo = lora
                if lora_info.module_name not in mha_loras:
                    mha_loras[lora_info.module_name] = {}

                lora_dic = mha_loras[lora_info.module_name]
                lora_dic[lora_info.lora_name] = lora_info
                if len(lora_dic) == 4:
                    # all four projections (q, k, v, out) are available: merge them into the attention weights
                    module = lora_info.module
                    module_name = lora_info.module_name
                    w_q_dw = state_dict.get(module_name + "_q_proj.lora_down.weight")
                    if w_q_dw is not None:  # corresponding weights exist in the state dict
                        w_q_up = state_dict[module_name + "_q_proj.lora_up.weight"]
                        w_k_dw = state_dict[module_name + "_k_proj.lora_down.weight"]
                        w_k_up = state_dict[module_name + "_k_proj.lora_up.weight"]
                        w_v_dw = state_dict[module_name + "_v_proj.lora_down.weight"]
                        w_v_up = state_dict[module_name + "_v_proj.lora_up.weight"]
                        w_out_dw = state_dict[module_name + "_out_proj.lora_down.weight"]
                        w_out_up = state_dict[module_name + "_out_proj.lora_up.weight"]
                        q_lora_info = lora_dic[module_name + "_q_proj"]
                        k_lora_info = lora_dic[module_name + "_k_proj"]
                        v_lora_info = lora_dic[module_name + "_v_proj"]
                        out_lora_info = lora_dic[module_name + "_out_proj"]

                        sd = module.state_dict()
                        qkv_weight = sd["in_proj_weight"]
                        out_weight = sd["out_proj.weight"]
                        dev = qkv_weight.device

                        def merge_weights(l_info, weight, up_weight, down_weight):
                            # calculate in float to avoid precision issues
                            scale = l_info.alpha / l_info.dim
                            dtype = weight.dtype
                            weight = (
                                weight.float()
                                + l_info.multiplier
                                * (up_weight.to(dev, dtype=torch.float) @ down_weight.to(dev, dtype=torch.float))
                                * scale
                            )
                            weight = weight.to(dtype)
                            return weight
                        q_weight, k_weight, v_weight = torch.chunk(qkv_weight, 3)
                        if q_weight.size()[1] == w_q_up.size()[0]:
                            q_weight = merge_weights(q_lora_info, q_weight, w_q_up, w_q_dw)
                            k_weight = merge_weights(k_lora_info, k_weight, w_k_up, w_k_dw)
                            v_weight = merge_weights(v_lora_info, v_weight, w_v_up, w_v_dw)
                            qkv_weight = torch.cat([q_weight, k_weight, v_weight])

                            out_weight = merge_weights(out_lora_info, out_weight, w_out_up, w_out_dw)

                            sd["in_proj_weight"] = qkv_weight.to(dev)
                            sd["out_proj.weight"] = out_weight.to(dev)

                            lora_info.module.load_state_dict(sd)
                        else:
                            # dimensions do not match, probably a different SD version
                            print(f"shape of weight is different: {module_name}. SD version may be different")

                        # remove the q/k/v/out LoRA entries from the state dict; they are handled here, not by load_state_dict
                        for t in ["q", "k", "v", "out"]:
                            del state_dict[f"{module_name}_{t}_proj.lora_down.weight"]
                            del state_dict[f"{module_name}_{t}_proj.lora_up.weight"]
                            alpha_key = f"{module_name}_{t}_proj.alpha"
                            if alpha_key in state_dict:
                                del state_dict[alpha_key]
                else:
                    # not all four projections collected yet; keep waiting
                    pass

        # convert the shapes of the remaining weights if they differ (Linear vs 1x1 Conv2d)
        state_dict = self.convert_state_dict_shape_to_compvis(state_dict)

        return state_dict

    def convert_state_dict_shape_to_compvis(self, state_dict):
        # convert between Linear and 1x1 Conv2d weight shapes where the current model expects the other form
        current_sd = self.state_dict()
        wrapped = False
        count = 0
        for key in list(state_dict.keys()):
            if key not in current_sd:
                continue
            if "wrapped" in key:
                wrapped = True

            value: torch.Tensor = state_dict[key]
            if value.size() != current_sd[key].size():
                # conversion: (out, in, 1, 1) <-> (out, in)
                count += 1
                if len(value.size()) == 4:
                    value = value.squeeze(3).squeeze(2)
                else:
                    value = value.unsqueeze(2).unsqueeze(3)
                state_dict[key] = value
            if tuple(value.size()) != tuple(current_sd[key].size()):
                print(
                    f"weight's shape is different: {key} expected {current_sd[key].size()} found {value.size()}. SD version may be different"
                )
                del state_dict[key]
        print(f"shapes for {count} weights are converted.")

        # if the current model does not use 'wrapped' in its keys, remove it from the incoming keys
        if not wrapped:
            print("remove 'wrapped' from keys")
            for key in list(state_dict.keys()):
                if "_wrapped_" in key:
                    new_key = key.replace("_wrapped_", "_")
                    state_dict[new_key] = state_dict[key]
                    del state_dict[key]

        return state_dict

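    # Editor's note: small illustration of the shape conversion above. A LoRA trained against a
    # 1x1 Conv2d stores e.g. lora_down.weight with shape (dim, in_channels, 1, 1); if the current
    # model implements the same projection as a Linear layer, squeezing the trailing 1x1 dims gives
    # the (dim, in_features) shape load_state_dict expects, and unsqueezing covers the reverse case.
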
    def set_mask(self, mask, height=None, width=None, hr_height=None, hr_width=None):
        if mask is None:
            # clear the masks: LoRA is applied everywhere again
            self.latest_mask_info = None
            for lora in self.unet_loras:
                lora.set_mask_dic(None)
            return

        # skip the resizing if the mask and sizes are unchanged
        if (
            self.latest_mask_info is not None
            and torch.equal(mask, self.latest_mask_info[0])
            and (height, width, hr_height, hr_width) == self.latest_mask_info[1:]
        ):
            return

        self.latest_mask_info = (mask, height, width, hr_height, hr_width)
        org_dtype = mask.dtype
        if mask.dtype == torch.bfloat16:
            mask = mask.to(torch.float)  # interpolate does not always support bfloat16

        mask_dic = {}
        mask = mask.unsqueeze(0).unsqueeze(1)  # add batch and channel dims for interpolate

        def resize_add(mh, mw):
            # resize the mask to (mh, mw) and store it keyed by the feature-map area mh * mw
            m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear")
            m = m.to(org_dtype)
            mask_dic[mh * mw] = m

        for h, w in [(height, width), (hr_height, hr_width)]:
            h = h // 8
            w = w // 8
            for i in range(4):
                resize_add(h, w)
                if h % 2 == 1 or w % 2 == 1:  # also store a mask for the rounded-up even size
                    resize_add(h + h % 2, w + w % 2)
                h = (h + 1) // 2
                w = (w + 1) // 2

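        # Editor's note: worked example of the keys stored above, assuming height=width=512 (and the
        # same hires sizes): the latent is 64x64, so masks are stored under areas 64*64=4096,
        # 32*32=1024, 16*16=256 and 8*8=64, matching the h*w of the U-Net feature maps that the
        # masked LoRA modules see at each downsampling level.
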
        for lora in self.unet_loras:
            lora.set_mask_dic(mask_dic)
        return

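# Editor's note: hedged sketch of regional-LoRA usage, assuming `network` was returned by
# create_network_and_apply_compvis and `mask` is a 2D float tensor in [0, 1] on the model's device:
#
#   network.set_mask(mask, height=512, width=512, hr_height=1024, hr_width=1024)
#   ...generate; spatial U-Net LoRA deltas are scaled by the mask resized to each feature map...
#   network.set_mask(None)  # disable regional masking again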