| import logging
|
| from typing import Optional
|
|
|
| import torch
|
| import comfy.model_management
|
| from .base import WeightAdapterBase, weight_decompose
|
|
|
|
|
class GLoRAAdapter(WeightAdapterBase):
    """Weight adapter for GLoRA patches.

    A patch consists of four tensors read from ``<prefix>.a1.weight``,
    ``<prefix>.a2.weight``, ``<prefix>.b1.weight`` and ``<prefix>.b2.weight``,
    plus an optional ``alpha`` and an optional ``dora_scale``.

    ``calculate_weight`` builds a diff of the form ``W @ A + B`` from those
    tensors. ``W`` is the weight being patched. How ``A`` and ``B`` are formed
    depends on which of two tensor layouts the patch uses (see
    ``_layout_and_rank``).
    """

    name = "glora"

    def __init__(self, loaded_keys, weights):
        # State-dict keys consumed by this adapter.
        self.loaded_keys = loaded_keys
        # Tuple (a1, a2, b1, b2, alpha, dora_scale) as assembled by ``load``.
        self.weights = weights

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: Optional[float],
        dora_scale: Optional[torch.Tensor],
        loaded_keys: Optional[set[str]] = None,
    ) -> Optional["GLoRAAdapter"]:
        """Create an adapter for key prefix ``x`` if ``lora`` holds GLoRA tensors.

        Args:
            x: Key prefix of the patched layer inside ``lora``.
            lora: State dict of the patch file.
            alpha: Stored as-is. When None, ``calculate_weight`` uses a scale
                of 1.0.
            dora_scale: Stored as-is. When not None, ``calculate_weight``
                delegates to ``weight_decompose``.
            loaded_keys: Set that receives the four consumed key names. A new
                set is created when omitted. The same set object is kept on the
                returned adapter.

        Returns:
            A ``GLoRAAdapter`` when ``"<x>.a1.weight"`` is present, else None.

        Raises:
            KeyError: If ``a1`` is present but ``a2``, ``b1`` or ``b2`` is
                missing. ``loaded_keys`` is not modified in that case.
        """
        if loaded_keys is None:
            loaded_keys = set()
        a1_name = "{}.a1.weight".format(x)
        a2_name = "{}.a2.weight".format(x)
        b1_name = "{}.b1.weight".format(x)
        b2_name = "{}.b2.weight".format(x)
        if a1_name not in lora:
            return None

        # Build the tuple first, so a missing key raises before loaded_keys
        # is touched.
        weights = (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale)
        loaded_keys.add(a1_name)
        loaded_keys.add(a2_name)
        loaded_keys.add(b1_name)
        loaded_keys.add(b2_name)
        return cls(loaded_keys, weights)

    @staticmethod
    def _layout_and_rank(v, weight):
        """Detect the tensor layout of ``v[0:4]`` and return ``(old_glora, rank)``.

        Legacy layout (``old_glora=True``):
            diff = b2 @ b1 + W @ a2 @ a1
            The shared (rank) dimension is a1.shape[0].
            NOTE(review): presumably the older LyCORIS GLoRA export; confirm.

        Current layout (``old_glora=False``):
            diff = W @ a1 @ a2 + b1 @ b2
            The shared (rank) dimension is a2.shape[0].

        ``rank`` is None when the shapes match neither layout.
        """
        a1, a2, b1, b2 = v[0], v[1], v[2], v[3]
        old_glora = False
        rank = None

        if b2.shape[1] == b1.shape[0] == a1.shape[0] == a2.shape[1]:
            rank = a1.shape[0]
            old_glora = True

        if b2.shape[0] == b1.shape[1] == a1.shape[1] == a2.shape[0]:
            # Square tensors can satisfy both checks. Keep the legacy
            # interpretation only if it also lines up with a square weight;
            # otherwise prefer the current layout.
            keep_old = old_glora and a2.shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]
            if not keep_old:
                old_glora = False
                rank = a2.shape[0]

        return old_glora, rank

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        """Apply this GLoRA patch to ``weight`` and return the patched tensor.

        Args:
            weight: Tensor being patched. In the non-DoRA path it is updated
                in place via ``+=``.
            key: Only used in error log messages.
            strength: Multiplier applied to the scaled diff.
            strength_model, offset, original_weight: Not used by this adapter.
                Presumably they exist for the shared adapter interface.
            function: Callable applied to the scaled diff before it is added.
                It is also passed through to ``weight_decompose``.
            intermediate_dtype: Dtype used to compute the diff.

        Returns:
            The patched weight. Returns the input ``weight`` unchanged (after
            logging an error) if the patch cannot be applied.
        """
        v = self.weights
        dora_scale = v[5]

        old_glora, rank = self._layout_and_rank(v, weight)

        if v[4] is None:
            alpha = 1.0
        elif rank is None:
            # The shapes match neither layout, so there is no rank to divide
            # alpha by. Previously this raised UnboundLocalError outside the
            # try below and aborted patching. Log and skip, like other
            # per-key failures.
            logging.error("ERROR {} {} {}".format(self.name, key, "unrecognized GLoRA tensor shapes, cannot determine rank"))
            return weight
        else:
            alpha = v[4] / rank

        # Flatten trailing dims to 2-D and cast to weight.device / intermediate_dtype.
        a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
        a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
        b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
        b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)

        try:
            if old_glora:
                lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape)
            else:
                if weight.dim() > 2:
                    # Contract axis 1 of the weight with a1, then with a2;
                    # trailing axes ("...") pass through unchanged.
                    lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
                else:
                    lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
                lora_diff += torch.mm(b1, b2).reshape(weight.shape)

            if dora_scale is not None:
                weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
            else:
                weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
        except Exception as e:
            # Deliberate best-effort: a failing key is logged, and the weight
            # is returned as-is so patching of other keys can continue.
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight
|
|
|