| """ |
| This file is part of ComfyUI. |
| Copyright (C) 2024 Comfy |
| |
| This program is free software: you can redistribute it and/or modify |
| it under the terms of the GNU General Public License as published by |
| the Free Software Foundation, either version 3 of the License, or |
| (at your option) any later version. |
| |
| This program is distributed in the hope that it will be useful, |
| but WITHOUT ANY WARRANTY; without even the implied warranty of |
| MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
| GNU General Public License for more details. |
| |
| You should have received a copy of the GNU General Public License |
| along with this program. If not, see <https://www.gnu.org/licenses/>. |
| """ |
|
|
| from __future__ import annotations |
| import comfy.utils |
| import comfy.model_management |
| import comfy.model_base |
| import logging |
| import torch |
|
|
# Maps the dotted sub-module path of a CLIP text-encoder layer weight (as it
# appears in model state-dict keys) to the underscore-joined spelling used
# inside LoRA key names.
LORA_CLIP_MAP = {
    module_path: module_path.replace(".", "_")
    for module_path in (
        "mlp.fc1",
        "mlp.fc2",
        "self_attn.k_proj",
        "self_attn.q_proj",
        "self_attn.v_proj",
        "self_attn.out_proj",
    )
}
|
|
|
|
def load_lora(lora, to_load, log_missing=True):
    """Translate a LoRA-style state dict into a dict of weight patches.

    For every prefix in ``to_load`` the function looks for the key layouts
    of the supported adapter formats and records a patch for the mapped
    model key. When several formats are present for one prefix, the one
    handled last wins, in this order: lora, loha, lokr, glora,
    w_norm/b_norm, diff, diff_b, set_weight.

    Args:
        lora: Mapping from adapter key names to values. ``<prefix>.alpha``
            values must support ``.item()`` and ``<prefix>.reshape_weight``
            values must support ``.tolist()``; all other values are stored
            in the patch unchanged.
        to_load: Mapping from an adapter key prefix to the model weight
            key that the patch targets. Bias patches derive their key by
            stripping a trailing ``".weight"`` from that target.
        log_missing: If true, emit a warning for every key of ``lora``
            that no patch consumed.

    Returns:
        dict: ``{model_key: (patch_type, payload_tuple)}`` where
        ``patch_type`` is one of ``"lora"``, ``"loha"``, ``"lokr"``,
        ``"glora"``, ``"diff"`` or ``"set"``.

    Raises:
        KeyError: If a format's leading key is present but one of its
            mandatory companion keys is missing.
    """
    patch_dict = {}
    loaded_keys = set()

    def optional(name):
        # Return lora[name] and mark it consumed, or None when it is absent.
        if name not in lora:
            return None
        loaded_keys.add(name)
        return lora[name]

    # (up, down, mid) key templates of the plain LoRA naming schemes, tried
    # in order; only the first scheme has an optional "mid" weight.
    lora_layouts = (
        ("{}.lora_up.weight", "{}.lora_down.weight", "{}.lora_mid.weight"),
        ("{}_lora.up.weight", "{}_lora.down.weight", None),
        ("{}.lora_B.weight", "{}.lora_A.weight", None),
        ("{}.lora.up.weight", "{}.lora.down.weight", None),
        ("{}.lora_B", "{}.lora_A", None),
        ("{}.lora_linear_layer.up.weight", "{}.lora_linear_layer.down.weight", None),
    )

    for x in to_load:
        alpha = None
        alpha_name = "{}.alpha".format(x)
        if alpha_name in lora:
            alpha = lora[alpha_name].item()
            loaded_keys.add(alpha_name)

        dora_scale = optional("{}.dora_scale".format(x))

        reshape = None
        reshape_name = "{}.reshape_weight".format(x)
        if reshape_name in lora:
            # Best effort: an unreadable reshape entry is skipped, but only
            # ordinary errors are tolerated (not KeyboardInterrupt/SystemExit).
            try:
                reshape = lora[reshape_name].tolist()
            except Exception as e:
                logging.warning("could not read lora reshape key {}: {}".format(reshape_name, e))
            else:
                loaded_keys.add(reshape_name)

        # Plain LoRA: the first naming scheme whose "up" key exists is used.
        for up_fmt, down_fmt, mid_fmt in lora_layouts:
            A_name = up_fmt.format(x)
            if A_name not in lora:
                continue
            B_name = down_fmt.format(x)
            mid = optional(mid_fmt.format(x)) if mid_fmt is not None else None
            patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape))
            loaded_keys.add(A_name)
            loaded_keys.add(B_name)
            break

        # LoHa
        hada_w1_a_name = "{}.hada_w1_a".format(x)
        hada_w1_b_name = "{}.hada_w1_b".format(x)
        hada_w2_a_name = "{}.hada_w2_a".format(x)
        hada_w2_b_name = "{}.hada_w2_b".format(x)
        hada_t1_name = "{}.hada_t1".format(x)
        hada_t2_name = "{}.hada_t2".format(x)
        if hada_w1_a_name in lora:
            hada_t1 = None
            hada_t2 = None
            if hada_t1_name in lora:
                # t1 and t2 are expected together; t2 is mandatory once t1 exists.
                hada_t1 = lora[hada_t1_name]
                hada_t2 = lora[hada_t2_name]
                loaded_keys.add(hada_t1_name)
                loaded_keys.add(hada_t2_name)

            patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale))
            loaded_keys.add(hada_w1_a_name)
            loaded_keys.add(hada_w1_b_name)
            loaded_keys.add(hada_w2_a_name)
            loaded_keys.add(hada_w2_b_name)

        # LoKr: every component is optional and consumed when present.
        lokr_w1 = optional("{}.lokr_w1".format(x))
        lokr_w2 = optional("{}.lokr_w2".format(x))
        lokr_w1_a = optional("{}.lokr_w1_a".format(x))
        lokr_w1_b = optional("{}.lokr_w1_b".format(x))
        lokr_w2_a = optional("{}.lokr_w2_a".format(x))
        lokr_w2_b = optional("{}.lokr_w2_b".format(x))
        lokr_t2 = optional("{}.lokr_t2".format(x))

        if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
            patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale))

        # GLoRA
        a1_name = "{}.a1.weight".format(x)
        a2_name = "{}.a2.weight".format(x)
        b1_name = "{}.b1.weight".format(x)
        b2_name = "{}.b2.weight".format(x)
        if a1_name in lora:
            patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale))
            loaded_keys.add(a1_name)
            loaded_keys.add(a2_name)
            loaded_keys.add(b1_name)
            loaded_keys.add(b2_name)

        # Norm diffs: b_norm is only honoured when w_norm is present.
        w_norm_name = "{}.w_norm".format(x)
        b_norm_name = "{}.b_norm".format(x)
        w_norm = lora.get(w_norm_name, None)
        b_norm = lora.get(b_norm_name, None)

        if w_norm is not None:
            loaded_keys.add(w_norm_name)
            patch_dict[to_load[x]] = ("diff", (w_norm,))
            if b_norm is not None:
                loaded_keys.add(b_norm_name)
                patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,))

        diff_name = "{}.diff".format(x)
        diff_weight = lora.get(diff_name, None)
        if diff_weight is not None:
            patch_dict[to_load[x]] = ("diff", (diff_weight,))
            loaded_keys.add(diff_name)

        diff_bias_name = "{}.diff_b".format(x)
        diff_bias = lora.get(diff_bias_name, None)
        if diff_bias is not None:
            patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
            loaded_keys.add(diff_bias_name)

        set_weight_name = "{}.set_weight".format(x)
        set_weight = lora.get(set_weight_name, None)
        if set_weight is not None:
            patch_dict[to_load[x]] = ("set", (set_weight,))
            loaded_keys.add(set_weight_name)

    if log_missing:
        for name in lora.keys():
            if name not in loaded_keys:
                logging.warning("lora key not loaded: {}".format(name))

    return patch_dict
|
|
def model_lora_keys_clip(model, key_map=None):
    """Build the mapping from LoRA key names to text-encoder state-dict keys.

    Args:
        model: Object whose ``state_dict()`` returns a mapping keyed by
            parameter name.
        key_map: Optional dict to extend in place. A fresh dict is created
            when omitted, so separate calls never share entries.

    Returns:
        dict: ``key_map`` with one entry per recognised LoRA key spelling,
        each pointing at the matching state-dict key.
    """
    # A ``{}`` default would be created once and shared by every call that
    # omits key_map, leaking mappings from one model into the next.
    if key_map is None:
        key_map = {}

    sdk = model.state_dict().keys()
    for k in sdk:
        if k.endswith(".weight"):
            # generic format: "text_encoders." + state-dict key without ".weight"
            key_map["text_encoders.{}".format(k[:-len(".weight")])] = k

    text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
    clip_l_present = False
    clip_g_present = False
    for b in range(32):  # layer indices 0..31 are probed
        for c in LORA_CLIP_MAP:
            k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = k
                lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = k
                lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c)
                key_map[lora_key] = k

            k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = k
                lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = k
                clip_l_present = True
                lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c)
                key_map[lora_key] = k

            k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                clip_g_present = True
                if clip_l_present:
                    # clip_l already claimed the te/te1 names, so clip_g is "te2"
                    lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
                    key_map[lora_key] = k
                    lora_key = "text_encoder_2.text_model.encoder.layers.{}.{}".format(b, c)
                    key_map[lora_key] = k
                else:
                    lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
                    key_map[lora_key] = k
                    lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c)
                    key_map[lora_key] = k
                    lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
                    key_map[lora_key] = k

    for k in sdk:
        if k.endswith(".weight"):
            if k.startswith("t5xxl.transformer."):
                l_key = k[len("t5xxl.transformer."):-len(".weight")]
                # The t5 encoder index follows the CLIP encoders that are present.
                t5_index = 1
                if clip_g_present:
                    t5_index += 1
                if clip_l_present:
                    t5_index += 1
                    if t5_index == 2:
                        # clip_l + t5 without clip_g: register both te2 and te3
                        key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k
                        t5_index += 1

                key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k
            elif k.startswith("hydit_clip.transformer.bert."):
                l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
                lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
                key_map[lora_key] = k

    k = "clip_g.transformer.text_projection.weight"
    if k in sdk:
        key_map["lora_prior_te_text_projection"] = k
        key_map["lora_te2_text_projection"] = k

    k = "clip_l.transformer.text_projection.weight"
    if k in sdk:
        key_map["lora_te1_text_projection"] = k

    return key_map
|
|
def model_lora_keys_unet(model, key_map=None):
    """Build the mapping from LoRA key names to diffusion-model weight keys.

    Generic entries are derived from every ``diffusion_model.*`` state-dict
    key and from the diffusers-name table returned by
    ``comfy.utils.unet_to_diffusers``; additional spellings are added
    depending on the model class (StableCascade_C, SD3, AuraFlow, PixArt,
    HunyuanDiT, Flux, GenmoMochi, HunyuanVideo).

    Args:
        model: Model object exposing ``state_dict()`` and
            ``model_config.unet_config``.
        key_map: Optional dict to extend in place. A fresh dict is created
            when omitted, so separate calls never share entries.

    Returns:
        dict: ``key_map`` mapping each recognised LoRA key spelling to the
        target that the corresponding conversion table or state dict
        provides.
    """
    # A ``{}`` default would be created once and shared by every call that
    # omits key_map, leaking mappings from one model into the next.
    if key_map is None:
        key_map = {}

    sd = model.state_dict()
    sdk = sd.keys()

    for k in sdk:
        if k.startswith("diffusion_model."):
            if k.endswith(".weight"):
                key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
                key_map["lora_unet_{}".format(key_lora)] = k
                # generic format: the state-dict key without ".weight"
                key_map["{}".format(k[:-len(".weight")])] = k
            else:
                # non-weight parameters map to themselves
                key_map["{}".format(k)] = k

    diffusers_lora_prefix = ("", "unet.")
    diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config)
    for k in diffusers_keys:
        if k.endswith(".weight"):
            unet_key = "diffusion_model.{}".format(diffusers_keys[k])
            key_lora = k[:-len(".weight")].replace(".", "_")
            key_map["lora_unet_{}".format(key_lora)] = unet_key
            key_map["lycoris_{}".format(key_lora)] = unet_key

            for p in diffusers_lora_prefix:
                diffusers_lora_key = "{}{}".format(p, k[:-len(".weight")].replace(".to_", ".processor.to_"))
                if diffusers_lora_key.endswith(".to_out.0"):
                    # drop the trailing ".0" of the output projection
                    diffusers_lora_key = diffusers_lora_key[:-2]
                key_map[diffusers_lora_key] = unet_key

    if isinstance(model, comfy.model_base.StableCascade_C):
        for k in sdk:
            if k.startswith("diffusion_model."):
                if k.endswith(".weight"):
                    key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
                    key_map["lora_prior_unet_{}".format(key_lora)] = k

    if isinstance(model, comfy.model_base.SD3):
        diffusers_keys = comfy.utils.mmdit_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
        for k in diffusers_keys:
            if k.endswith(".weight"):
                to = diffusers_keys[k]
                key_lora = "transformer.{}".format(k[:-len(".weight")])
                key_map[key_lora] = to

                key_lora = "base_model.model.{}".format(k[:-len(".weight")])
                key_map[key_lora] = to

                key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))
                key_map[key_lora] = to

                key_lora = "lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))
                key_map[key_lora] = to

    if isinstance(model, comfy.model_base.AuraFlow):
        diffusers_keys = comfy.utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
        for k in diffusers_keys:
            if k.endswith(".weight"):
                to = diffusers_keys[k]
                key_lora = "transformer.{}".format(k[:-len(".weight")])
                key_map[key_lora] = to

    if isinstance(model, comfy.model_base.PixArt):
        diffusers_keys = comfy.utils.pixart_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
        for k in diffusers_keys:
            if k.endswith(".weight"):
                to = diffusers_keys[k]
                key_lora = "transformer.{}".format(k[:-len(".weight")])
                key_map[key_lora] = to

                key_lora = "base_model.model.{}".format(k[:-len(".weight")])
                key_map[key_lora] = to

                key_lora = "unet.base_model.model.{}".format(k[:-len(".weight")])
                key_map[key_lora] = to

    if isinstance(model, comfy.model_base.HunyuanDiT):
        for k in sdk:
            if k.startswith("diffusion_model.") and k.endswith(".weight"):
                key_lora = k[len("diffusion_model."):-len(".weight")]
                key_map["base_model.model.{}".format(key_lora)] = k

    if isinstance(model, comfy.model_base.Flux):
        diffusers_keys = comfy.utils.flux_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
        for k in diffusers_keys:
            if k.endswith(".weight"):
                to = diffusers_keys[k]
                key_map["transformer.{}".format(k[:-len(".weight")])] = to
                key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to
                key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to

    if isinstance(model, comfy.model_base.GenmoMochi):
        for k in sdk:
            if k.startswith("diffusion_model.") and k.endswith(".weight"):
                key_lora = k[len("diffusion_model."):-len(".weight")]
                key_map["{}".format(key_lora)] = k

    if isinstance(model, comfy.model_base.HunyuanVideo):
        for k in sdk:
            if k.startswith("diffusion_model.") and k.endswith(".weight"):
                # Rewrite this model's module names into the alternate
                # spelling used by the LoRA keys registered below.
                key_lora = k
                key_lora = key_lora.replace("_mod.lin.", "_mod.linear.").replace("_attn.qkv.", "_attn_qkv.").replace("_attn.proj.", "_attn_proj.")
                key_lora = key_lora.replace("mlp.0.", "mlp.fc1.").replace("mlp.2.", "mlp.fc2.")
                key_lora = key_lora.replace(".modulation.lin.", ".modulation.linear.")
                key_lora = key_lora[len("diffusion_model."):-len(".weight")]
                key_map["transformer.{}".format(key_lora)] = k
                key_map["diffusion_model.{}".format(key_lora)] = k

    return key_map
|
|
|
|
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
    """Merge ``lora_diff`` into ``weight`` with a per-column norm rescale.

    The candidate ``weight + function(lora_diff * alpha)`` is multiplied by
    ``dora_scale / norm``, where ``norm`` is taken over everything except
    dimension 1 of the candidate. With ``strength == 1.0`` the result
    overwrites ``weight``; otherwise ``strength`` times the change relative
    to the old weight is added.

    Note that ``lora_diff`` is scaled by ``alpha`` in place and ``weight``
    is updated in place.

    Args:
        dora_scale: Scale tensor, cast to ``weight.device`` and
            ``intermediate_dtype`` before use.
        weight: Tensor to update in place.
        lora_diff: Delta tensor (mutated: multiplied by ``alpha``).
        alpha: Scalar multiplier applied to ``lora_diff``.
        strength: Blend factor for the update.
        intermediate_dtype: dtype used when casting ``dora_scale``.
        function: Callable applied to the scaled ``lora_diff``.

    Returns:
        The updated ``weight`` tensor (same object that was passed in).
    """
    dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
    lora_diff *= alpha
    weight_calc = weight + function(lora_diff).type(weight.dtype)

    # Norm of each slice along dim 1, reshaped so it broadcasts against
    # weight_calc: (1, shape[1], 1, ...).
    columns = weight_calc.shape[1]
    singleton_dims = [1] * (weight_calc.dim() - 1)
    flattened = weight_calc.transpose(0, 1).reshape(columns, -1)
    column_norm = flattened.norm(dim=1, keepdim=True)
    weight_norm = column_norm.reshape(columns, *singleton_dims).transpose(0, 1)

    weight_calc *= (dora_scale / weight_norm).type(weight.dtype)

    if strength == 1.0:
        weight[:] = weight_calc
        return weight

    # Partial strength: apply only a fraction of the change.
    weight_calc -= weight
    weight += strength * (weight_calc)
    return weight
|
|
def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
    """
    Pad a tensor to a new shape with zeros.

    The original values occupy the leading corner of the result (indices
    ``[0:d]`` in every dimension); all remaining entries are zero.

    Args:
        tensor (torch.Tensor): The original tensor to be padded.
        new_shape (List[int]): The desired shape of the padded tensor.

    Returns:
        torch.Tensor: A new tensor padded with zeros to the specified shape,
        with the dtype and device of ``tensor``.

    Raises:
        ValueError: If ``new_shape`` has a different number of dimensions
            than ``tensor``, or is smaller than ``tensor`` in any dimension
            (truncation is not supported).
    """
    # Check the rank first so that a longer new_shape raises the documented
    # ValueError instead of an IndexError from the size comparison.
    if len(new_shape) != len(tensor.shape):
        raise ValueError("The new shape must have the same number of dimensions as the original tensor")

    if any(new_dim < old_dim for new_dim, old_dim in zip(new_shape, tensor.shape)):
        raise ValueError("The new shape must be larger than the original tensor in all dimensions")

    # Create a new tensor filled with zeros
    padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)

    # Copy the original values into the leading corner of the new tensor
    region = tuple(slice(0, dim) for dim in tensor.shape)
    padded_tensor[region] = tensor[region]

    return padded_tensor
|
|
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None):
    """Apply a sequence of patches to ``weight`` and return the result.

    Args:
        patches: Iterable of 5-element sequences
            ``(strength, v, strength_model, offset, function)``. ``v`` is
            either a 1-tuple (treated as a "diff" patch), a
            ``(patch_type, payload)`` pair, or a list that is resolved
            recursively before being treated as a diff.
        weight: Tensor to patch. Most branches update it in place; the
            padding branches rebind it to a new, larger tensor.
        key: Name of the weight; used in log messages and to index
            ``original_weights``.
        intermediate_dtype: dtype the patch tensors are cast to for the
            intermediate math.
        original_weights: Only read by the "model_as_lora" branch, as
            ``original_weights[key][0][0]``.

    Returns:
        The patched weight tensor.
    """
    for p in patches:
        strength = p[0]
        v = p[1]
        strength_model = p[2]
        offset = p[3]
        function = p[4]
        if function is None:
            # no post-processing requested: use the identity
            function = lambda a: a

        old_weight = None
        if offset is not None:
            # Patch only a slice; offset is (dim, start, length) for narrow().
            old_weight = weight
            weight = weight.narrow(offset[0], offset[1], offset[2])

        if strength_model != 1.0:
            weight *= strength_model

        if isinstance(v, list):
            # v[0] is (tensor, callable): the callable is applied to a copy of
            # the tensor and the remaining entries are applied to it as
            # patches; the outcome becomes a plain diff.
            v = (calculate_weight(v[1:], v[0][1](comfy.model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype), )

        # NOTE(review): if len(v) is neither 1 nor 2, patch_type keeps the
        # value from a previous iteration (or is unbound on the first one).
        if len(v) == 1:
            patch_type = "diff"
        elif len(v) == 2:
            patch_type = v[0]
            v = v[1]

        if patch_type == "diff":
            diff: torch.Tensor = v[0]
            # Optional second element: a dict whose 'pad_weight' flag asks for
            # the weight to be zero-padded up to the diff's shape.
            do_pad_weight = len(v) > 1 and v[1]['pad_weight']
            if do_pad_weight and diff.shape != weight.shape:
                logging.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape))
                weight = pad_tensor_to_shape(weight, diff.shape)

            if strength != 0.0:
                if diff.shape != weight.shape:
                    logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
                else:
                    weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
        elif patch_type == "set":
            # overwrite the weight's contents in place
            weight.copy_(v[0])
        elif patch_type == "model_as_lora":
            # the diff is (target weight - original weight)
            target_weight: torch.Tensor = v[0]
            diff_weight = comfy.model_management.cast_to_device(target_weight, weight.device, intermediate_dtype) - \
                          comfy.model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype)
            weight += function(strength * comfy.model_management.cast_to_device(diff_weight, weight.device, weight.dtype))
        elif patch_type == "lora":
            # payload layout as built by load_lora:
            # (up, down, alpha, mid, dora_scale, reshape)
            mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
            mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
            dora_scale = v[4]
            reshape = v[5]

            if reshape is not None:
                weight = pad_tensor_to_shape(weight, reshape)

            if v[2] is not None:
                # stored alpha is divided by the first dimension of the down matrix
                alpha = v[2] / mat2.shape[0]
            else:
                alpha = 1.0

            if v[3] is not None:
                # fold the optional mid tensor into mat2
                mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype)
                final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
                mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
            try:
                lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
                if dora_scale is not None:
                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:
                # failures are logged and this patch is skipped
                logging.error("ERROR {} {} {}".format(patch_type, key, e))
        elif patch_type == "lokr":
            # payload layout as built by load_lora:
            # (w1, w2, alpha, w1_a, w1_b, w2_a, w2_b, t2, dora_scale)
            w1 = v[0]
            w2 = v[1]
            w1_a = v[3]
            w1_b = v[4]
            w2_a = v[5]
            w2_b = v[6]
            t2 = v[7]
            dora_scale = v[8]
            dim = None

            if w1 is None:
                # w1 is stored factorised: rebuild it as w1_a @ w1_b
                dim = w1_b.shape[0]
                w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
                              comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
            else:
                w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)

            if w2 is None:
                # w2 is stored factorised, optionally with a 4-d core tensor t2
                dim = w2_b.shape[0]
                if t2 is None:
                    w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
                                  comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
                else:
                    w2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                      comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
                                      comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
                                      comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
            else:
                w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)

            if len(w2.shape) == 4:
                # give w1 two trailing singleton dims so kron works with a 4-d w2
                w1 = w1.unsqueeze(2).unsqueeze(2)
            if v[2] is not None and dim is not None:
                alpha = v[2] / dim
            else:
                alpha = 1.0

            try:
                # the delta is the Kronecker product of the two factors
                lora_diff = torch.kron(w1, w2).reshape(weight.shape)
                if dora_scale is not None:
                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:
                logging.error("ERROR {} {} {}".format(patch_type, key, e))
        elif patch_type == "loha":
            # payload layout as built by load_lora:
            # (w1a, w1b, alpha, w2a, w2b, t1, t2, dora_scale)
            w1a = v[0]
            w1b = v[1]
            if v[2] is not None:
                alpha = v[2] / w1b.shape[0]
            else:
                alpha = 1.0

            w2a = v[3]
            w2b = v[4]
            dora_scale = v[7]
            if v[5] is not None:
                # variant with 4-d core tensors t1/t2
                t1 = v[5]
                t2 = v[6]
                m1 = torch.einsum('i j k l, j r, i p -> p r k l',
                                  comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
                                  comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
                                  comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))

                m2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                  comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
                                  comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
                                  comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
            else:
                m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
                              comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
                m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
                              comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))

            try:
                # the delta is the element-wise product of the two factors
                lora_diff = (m1 * m2).reshape(weight.shape)
                if dora_scale is not None:
                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:
                logging.error("ERROR {} {} {}".format(patch_type, key, e))
        elif patch_type == "glora":
            # payload layout as built by load_lora:
            # (a1, a2, b1, b2, alpha, dora_scale)
            dora_scale = v[5]

            # Two storage layouts are told apart by which dimensions of the four
            # matrices agree; the second check overrides the first unless the
            # shapes are ambiguous (square weight matching v[1].shape[0]).
            # NOTE(review): rank stays unbound if neither shape test matches and
            # v[4] is set; that would raise outside the try block below.
            old_glora = False
            if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
                rank = v[0].shape[0]
                old_glora = True

            if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
                if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
                    pass
                else:
                    old_glora = False
                    rank = v[1].shape[0]

            a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
            a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
            b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
            b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)

            if v[4] is not None:
                alpha = v[4] / rank
            else:
                alpha = 1.0

            try:
                if old_glora:
                    lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape)
                else:
                    if weight.dim() > 2:
                        lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
                    else:
                        lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
                    lora_diff += torch.mm(b1, b2).reshape(weight.shape)

                if dora_scale is not None:
                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:
                logging.error("ERROR {} {} {}".format(patch_type, key, e))
        else:
            logging.warning("patch type not recognized {} {}".format(patch_type, key))

        if old_weight is not None:
            # Restore the full tensor; in-place edits made through the narrowed
            # view are visible in it.
            # NOTE(review): a padding branch that rebound `weight` to a new
            # tensor would have its result dropped here; confirm this is intended.
            weight = old_weight

    return weight
|
|