| import logging |
| from typing import Optional |
|
|
| import torch |
| import comfy.model_management |
| from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose, factorization |
|
|
|
|
class OFTDiff(WeightAdapterTrainBase):
    """Trainable OFT adapter that applies a block-wise rotation to a weight.

    Each block of ``oft_blocks`` is turned into a skew-symmetric matrix
    ``Q = B - B^T`` and then into an orthogonal matrix through the Cayley
    transform ``R = (I + Q)(I - Q)^-1``.  The first dimension of the target
    weight is split into ``block_num`` chunks of ``block_size`` rows and every
    chunk is mixed by its own ``R``.
    """

    def __init__(self, weights):
        """Create the adapter parameters from a weights tuple.

        Args:
            weights: 4-tuple ``(blocks, rescale, alpha, _)``.
                ``blocks`` is a 3-D tensor whose first two dimensions are
                read as ``(block_num, block_size)``.
                ``rescale`` is an optional tensor multiplied into the
                rotated weight; ``None`` disables rescaling.
                ``alpha`` is the upper bound on the norm of ``Q``; ``None``
                or ``0`` disables the bound.
                The fourth element is ignored.
        """
        super().__init__()
        blocks, rescale, alpha, _ = weights

        # A missing alpha arrives as None.  OFTAdapter.calculate_weight maps
        # None to 0 ("no constraint"); do the same here instead of crashing
        # in float(None) / torch.tensor(None).
        if alpha is None:
            alpha = 0.0

        self.oft_blocks = torch.nn.Parameter(blocks)
        if rescale is not None:
            self.rescale = torch.nn.Parameter(rescale)
            self.rescaled = True
        else:
            self.rescaled = False
        self.block_num, self.block_size, _ = blocks.shape
        self.constraint = float(alpha)
        # Kept as a frozen parameter so alpha travels with the module's
        # parameters without receiving gradients.
        self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)

    def __call__(self, w):
        """Return ``w`` with every row block rotated by its learned matrix.

        Args:
            w: Weight tensor whose first dimension equals
                ``block_num * block_size``; trailing dimensions are carried
                through unchanged.

        Returns:
            The transformed weight, cast back to the dtype of ``w``.
        """
        org_dtype = w.dtype
        I = torch.eye(self.block_size, device=self.oft_blocks.device)

        # Skew-symmetric generator: Q = B - B^T, so Q^T = -Q.
        q = self.oft_blocks - self.oft_blocks.transpose(1, 2)
        normed_q = q
        # Only a strictly positive constraint is meaningful; this matches the
        # `alpha > 0` check used by OFTAdapter.calculate_weight.
        if self.constraint > 0:
            q_norm = torch.norm(q) + 1e-8
            if q_norm > self.constraint:
                normed_q = q * self.constraint / q_norm

        # Cayley transform.  Both operands are forced to float32: the inverse
        # is taken in float32, and matmul needs matching operand dtypes.
        r = (I + normed_q).float() @ (I - normed_q).float().inverse()

        # Split the leading dimension into (block_num, block_size) and
        # contract each block with its matrix:
        # out[k, m, ...] = sum_n r[k, n, m] * w[k, n, ...].
        org_weight = w.to(dtype=r.dtype)
        org_weight = org_weight.unflatten(0, (self.block_num, self.block_size))
        weight = torch.einsum(
            "k n m, k n ... -> k m ...",
            r,
            org_weight,
        ).flatten(0, 1)
        if self.rescaled:
            weight = self.rescale * weight
        return weight.to(org_dtype)

    def passive_memory_usage(self):
        """Return the total size in bytes of all parameters of this module.

        The sum runs over ``self.parameters()``, so it includes the frozen
        ``alpha`` parameter as well as the trainable ones.
        """
        return sum(param.numel() * param.element_size() for param in self.parameters())
|
|
|
|
class OFTAdapter(WeightAdapterBase):
    """Weight adapter that applies loaded OFT blocks to a model weight.

    ``weights`` is the 4-tuple ``(blocks, rescale, alpha, dora_scale)``
    produced by :meth:`load`.
    """

    name = "oft"

    def __init__(self, loaded_keys, weights):
        """Store the consumed state-dict keys and the adapter weights tuple."""
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def create_train(cls, weight, rank=1, alpha=1.0):
        """Create a fresh trainable ``OFTDiff`` for ``weight``.

        Args:
            weight: Tensor to adapt; only ``shape[0]``, ``device`` and
                ``dtype`` are read.
            rank: Passed to ``factorization`` together with ``shape[0]``.
            alpha: Norm constraint forwarded to ``OFTDiff``.

        Returns:
            An ``OFTDiff`` whose blocks are all zeros, so the initial
            transform is the identity.
        """
        out_dim = weight.shape[0]
        # NOTE(review): assumes factorization returns (block_size, block_num)
        # with block_size * block_num == out_dim -- confirm against .base.
        block_size, block_num = factorization(out_dim, rank)
        block = torch.zeros(
            block_num,
            block_size,
            block_size,
            device=weight.device,
            dtype=weight.dtype,
        )
        return OFTDiff((block, None, alpha, None))

    def to_train(self):
        """Wrap the loaded weights in a trainable ``OFTDiff``.

        A missing alpha (``None``) is forwarded as ``0.0`` -- the same "no
        constraint" meaning :meth:`calculate_weight` gives it -- because
        ``OFTDiff`` converts alpha with ``float()``.
        """
        blocks, rescale, alpha, dora_scale = self.weights
        if alpha is None:
            alpha = 0.0
        return OFTDiff((blocks, rescale, alpha, dora_scale))

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: Optional[float],
        dora_scale: Optional[torch.Tensor],
        loaded_keys: Optional[set[str]] = None,
    ) -> Optional["OFTAdapter"]:
        """Build an adapter from the ``{x}.oft_blocks`` entry of ``lora``.

        Args:
            x: Key prefix of the layer inside ``lora``.
            lora: Mapping from key to tensor.
            alpha: Norm constraint stored with the adapter; may be ``None``.
            dora_scale: Optional DoRA scale stored with the adapter.
            loaded_keys: Set that receives every key consumed here; a new
                set is created when omitted.

        Returns:
            An ``OFTAdapter``, or ``None`` when ``{x}.oft_blocks`` is absent
            or is not a 3-D tensor.
        """
        if loaded_keys is None:
            loaded_keys = set()
        blocks_name = "{}.oft_blocks".format(x)
        rescale_name = "{}.rescale".format(x)

        if blocks_name not in lora:
            return None
        blocks = lora[blocks_name]
        if blocks.ndim != 3:
            # Not block-shaped; leave the key unclaimed.
            return None
        loaded_keys.add(blocks_name)

        rescale = None
        if rescale_name in lora:
            rescale = lora[rescale_name]
            loaded_keys.add(rescale_name)

        weights = (blocks, rescale, alpha, dora_scale)
        return cls(loaded_keys, weights)

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        """Apply the OFT transform to ``weight`` and return the result.

        Args:
            weight: Tensor to patch; its first dimension must equal
                ``block_num * block_size``.  It is updated in place when no
                DoRA scale is present.
            key: Name of the weight; used only in the error log message.
            strength: Scale applied to the computed delta.
            strength_model: Unused by this adapter.
            offset: Unused by this adapter.
            function: Callable applied to the delta before it is added.
            intermediate_dtype: Dtype the blocks are cast to for computation.
            original_weight: Unused by this adapter.

        Returns:
            The patched weight.  If anything raises while computing or
            applying the delta, the error is logged and ``weight`` is
            returned as it stands.
        """
        v = self.weights
        blocks = v[0]
        alpha = v[2]
        if alpha is None:
            alpha = 0
        dora_scale = v[3]
        # NOTE(review): v[1] (rescale) is stored by load() but is not applied
        # on this path; the device copy that used to be made here was never
        # read, so it has been dropped.

        blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
        block_num, block_size, *_ = blocks.shape

        try:
            I = torch.eye(block_size, device=blocks.device, dtype=blocks.dtype)

            # Skew-symmetric generator: Q = B - B^T.
            q = blocks - blocks.transpose(1, 2)
            normed_q = q
            if alpha > 0:
                q_norm = torch.norm(q) + 1e-8
                if q_norm > alpha:
                    normed_q = q * alpha / q_norm

            # Cayley transform.  The inverse is taken in float32, so the left
            # operand must be float32 too: matmul needs matching operand
            # dtypes, and a half-precision intermediate_dtype would otherwise
            # raise here and the patch would be skipped via the except below.
            r = (I + normed_q).float() @ (I - normed_q).float().inverse()
            r = r.to(weight)
            _, *shape = weight.shape
            # reshape (not view) so a non-contiguous weight is accepted.
            lora_diff = torch.einsum(
                "k n m, k n ... -> k m ...",
                (r * strength) - strength * I,
                weight.reshape(block_num, block_size, *shape),
            ).reshape(-1, *shape)
            # NOTE(review): strength already scales the delta above and is
            # applied again below (effective factor strength**2 when
            # strength != 1).  Kept as-is to preserve existing output --
            # confirm whether this is intended.
            if dora_scale is not None:
                weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
            else:
                weight += function((strength * lora_diff).type(weight.dtype))
        except Exception as e:
            # Best effort: a failing adapter must not abort model patching.
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight
|
|