import logging
from typing import Optional

import torch
import comfy.model_management
from .base import (
    WeightAdapterBase,
    WeightAdapterTrainBase,
    weight_decompose,
    pad_tensor_to_shape,
    tucker_weight_from_conv,
)


class LoraDiff(WeightAdapterTrainBase):
    def __init__(self, weights):
        super().__init__()
        # dora_scale and reshape are carried in the tuple but not used by the
        # training adapter.
        mat1, mat2, alpha, mid, dora_scale, reshape = weights
        out_dim, rank = mat1.shape[0], mat1.shape[1]
        rank, in_dim = mat2.shape[0], mat2.shape[1]
        if mid is not None:
            # Tucker-decomposed conv LoRA: up/down are 1x1 convs and the mid
            # tensor carries the spatial kernel, so the conv constructors need
            # explicit kernel sizes.
            convdim = mid.ndim - 2
            layer = (
                torch.nn.Conv1d,
                torch.nn.Conv2d,
                torch.nn.Conv3d,
            )[convdim]
            self.lora_up = layer(rank, out_dim, 1, bias=False)
            self.lora_down = layer(in_dim, rank, 1, bias=False)
            self.lora_mid = layer(rank, rank, mid.shape[2:], bias=False)
            self.lora_mid.weight.data.copy_(mid)
        else:
            self.lora_up = torch.nn.Linear(rank, out_dim, bias=False)
            self.lora_down = torch.nn.Linear(in_dim, rank, bias=False)
            self.lora_mid = None
        self.lora_up.weight.data.copy_(mat1)
        self.lora_down.weight.data.copy_(mat2)
        self.rank = rank
        self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)

    def __call__(self, w):
        org_dtype = w.dtype
        if self.lora_mid is None:
            diff = self.lora_up.weight @ self.lora_down.weight
        else:
            # Rebuild the full conv delta from the tucker factors.
            diff = tucker_weight_from_conv(
                self.lora_up.weight, self.lora_down.weight, self.lora_mid.weight
            )
        # Standard LoRA scaling: the applied delta is (alpha / rank) * up @ down.
        scale = self.alpha / self.rank
        weight = w + scale * diff.reshape(w.shape)
        return weight.to(org_dtype)

    def passive_memory_usage(self):
        return sum(param.numel() * param.element_size() for param in self.parameters())
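
# A minimal sketch of the merge math (hypothetical tensors, for illustration
# only): for a 2D weight, __call__ computes w + (alpha / rank) * (up @ down).
#
#   up = torch.randn(32, 4)     # mat1: (out_dim, rank)
#   down = torch.randn(4, 64)   # mat2: (rank, in_dim)
#   diff = LoraDiff((up, down, 4.0, None, None, None))
#   w = torch.randn(32, 64)
#   merged = diff(w)            # == w + (4.0 / 4) * up @ down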


class LoRAAdapter(WeightAdapterBase):
    name = "lora"

    def __init__(self, loaded_keys, weights):
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def create_train(cls, weight, rank=1, alpha=1.0):
        out_dim = weight.shape[0]
        in_dim = weight.shape[1:].numel()
        mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
        mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
        # Kaiming-init the up matrix and zero the down matrix so the initial
        # low-rank delta is exactly zero.
        torch.nn.init.kaiming_uniform_(mat1, a=5**0.5)
        torch.nn.init.constant_(mat2, 0.0)
        return LoraDiff((mat1, mat2, alpha, None, None, None))
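
    # Hedged usage sketch (shapes are illustrative): because the down matrix is
    # zeroed, a freshly created trainable adapter is a no-op until optimized.
    #
    #   adapter = LoRAAdapter.create_train(torch.randn(32, 64), rank=4, alpha=4.0)
    #   w = torch.ones(32, 64)
    #   assert torch.allclose(adapter(w), w)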

    def to_train(self):
        return LoraDiff(self.weights)

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: torch.Tensor,
        loaded_keys: Optional[set[str]] = None,
    ) -> Optional["LoRAAdapter"]:
        if loaded_keys is None:
            loaded_keys = set()

        # Key name templates for the LoRA checkpoint formats in the wild.
        reshape_name = "{}.reshape_weight".format(x)
        regular_lora = "{}.lora_up.weight".format(x)
        diffusers_lora = "{}_lora.up.weight".format(x)
        diffusers2_lora = "{}.lora_B.weight".format(x)
        diffusers3_lora = "{}.lora.up.weight".format(x)
        mochi_lora = "{}.lora_B".format(x)
        transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
        A_name = None

        if regular_lora in lora.keys():
            A_name = regular_lora
            B_name = "{}.lora_down.weight".format(x)
            mid_name = "{}.lora_mid.weight".format(x)
        elif diffusers_lora in lora.keys():
            A_name = diffusers_lora
            B_name = "{}_lora.down.weight".format(x)
            mid_name = None
        elif diffusers2_lora in lora.keys():
            A_name = diffusers2_lora
            B_name = "{}.lora_A.weight".format(x)
            mid_name = None
        elif diffusers3_lora in lora.keys():
            A_name = diffusers3_lora
            B_name = "{}.lora.down.weight".format(x)
            mid_name = None
        elif mochi_lora in lora.keys():
            A_name = mochi_lora
            B_name = "{}.lora_A".format(x)
            mid_name = None
        elif transformers_lora in lora.keys():
            A_name = transformers_lora
            B_name = "{}.lora_linear_layer.down.weight".format(x)
            mid_name = None

        if A_name is not None:
            mid = None
            if mid_name is not None and mid_name in lora.keys():
                mid = lora[mid_name]
                loaded_keys.add(mid_name)
            reshape = None
            if reshape_name in lora.keys():
                try:
                    reshape = lora[reshape_name].tolist()
                    loaded_keys.add(reshape_name)
                except Exception:
                    pass
            weights = (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape)
            loaded_keys.add(A_name)
            loaded_keys.add(B_name)
            return cls(loaded_keys, weights)
        else:
            return None
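
    # Illustrative load() call (hypothetical state-dict keys, kohya-style
    # "lora_up/lora_down" layout):
    #
    #   sd = {"m.lora_up.weight": torch.randn(32, 4),
    #         "m.lora_down.weight": torch.randn(4, 64)}
    #   adapter = LoRAAdapter.load("m", sd, alpha=4.0, dora_scale=None)
    #   adapter.loaded_keys  # {"m.lora_up.weight", "m.lora_down.weight"}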

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        v = self.weights
        mat1 = comfy.model_management.cast_to_device(
            v[0], weight.device, intermediate_dtype
        )
        mat2 = comfy.model_management.cast_to_device(
            v[1], weight.device, intermediate_dtype
        )
        dora_scale = v[4]
        reshape = v[5]

        if reshape is not None:
            weight = pad_tensor_to_shape(weight, reshape)

        if v[2] is not None:
            alpha = v[2] / mat2.shape[0]
        else:
            alpha = 1.0

        if v[3] is not None:
            # LoCon mid (tucker) weights: fold the mid kernel into mat2 so the
            # final product below stays a single matmul.
            mat3 = comfy.model_management.cast_to_device(
                v[3], weight.device, intermediate_dtype
            )
            final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
            mat2 = (
                torch.mm(
                    mat2.transpose(0, 1).flatten(start_dim=1),
                    mat3.transpose(0, 1).flatten(start_dim=1),
                )
                .reshape(final_shape)
                .transpose(0, 1)
            )
        try:
            lora_diff = torch.mm(
                mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
            ).reshape(weight.shape)
            if dora_scale is not None:
                weight = weight_decompose(
                    dora_scale,
                    weight,
                    lora_diff,
                    alpha,
                    strength,
                    intermediate_dtype,
                    function,
                )
            else:
                weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
        except Exception as e:
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight
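
    # Hedged call sketch (argument values are illustrative): merge the adapter
    # into a weight at full strength, with an identity post-processing function.
    #
    #   merged = adapter.calculate_weight(
    #       weight.clone(), "some.key", 1.0, 1.0, None, lambda t: t
    #   )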