import torch
import torch.nn as nn
from accelerate import init_empty_weights
from comfy.ops import cast_bias_weight


def _replace_linear(model, compute_dtype, state_dict, prefix="", patches=None, scale_weights=None):
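    """Recursively swap every nn.Linear in `model` for a CustomLinear.

    Layer shapes are read from `state_dict`, so the swap can run under
    `init_empty_weights` before the real weights are loaded. `patches` is
    only threaded through the recursion here; LoRA patches themselves are
    attached later by `set_lora_params`.
    """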
    has_children = list(model.children())
    if not has_children:
        return
    for name, module in model.named_children():
        module_prefix = prefix + name + "."
        # Recurse first so Linear layers nested deeper in this module are
        # also replaced.
        _replace_linear(module, compute_dtype, state_dict, module_prefix, patches, scale_weights)

        if isinstance(module, nn.Linear) and "loras" not in module_prefix:
            # Read the layer shape from the state dict rather than from the
            # (possibly meta) module parameters.
            in_features = state_dict[module_prefix + "weight"].shape[1]
            out_features = state_dict[module_prefix + "weight"].shape[0]
            scale_key = f"{module_prefix}scale_weight"

            # Allocate the replacement on the meta device; the real weights
            # are loaded later.
            with init_empty_weights():
                model._modules[name] = CustomLinear(
                    in_features,
                    out_features,
                    module.bias is not None,
                    compute_dtype=compute_dtype,
                    scale_weight=scale_weights.get(scale_key) if scale_weights else None,
                )
            model._modules[name].source_cls = type(module)
            model._modules[name].requires_grad_(False)
    return model


def set_lora_params(module, patches, module_prefix=""):
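    """Walk `module` recursively and attach matching LoRA patches to each
    CustomLinear, looked up in `patches` as `diffusion_model.<prefix>weight`.
    """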
    for name, child in module.named_children():
        child_prefix = f"{module_prefix}{name}."
        set_lora_params(child, patches, child_prefix)
    if isinstance(module, CustomLinear):
        key = f"diffusion_model.{module_prefix}weight"
        patch = patches.get(key, [])
        if len(patch) != 0:
            lora_diffs = []
            lora_strengths = []
            for p in patch:
                lora_obj = p[1]
                if "head" in key:
                    # Output-head layers are left unpatched.
                    continue
                elif hasattr(lora_obj, "weights"):
                    lora_diffs.append(lora_obj.weights)
                elif isinstance(lora_obj, tuple) and lora_obj[0] == "diff":
                    lora_diffs.append(lora_obj[1])
                else:
                    continue
                # Append the strength only for patches that were kept, so the
                # diff and strength lists stay aligned when entries are skipped.
                lora_strengths.append(p[0])
            module.lora = (lora_diffs, lora_strengths)
            module.step = 0


class CustomLinear(nn.Linear):
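    """nn.Linear variant with optional weight scaling and runtime LoRA.

    LoRA diffs are kept separate from the base weight and merged on the fly
    in `forward`, so they can be detached again without reloading weights.
    """
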
    def __init__(
        self,
        in_features,
        out_features,
        bias=False,
        compute_dtype=None,
        device=None,
        scale_weight=None,
    ) -> None:
        super().__init__(in_features, out_features, bias, device=device)
        self.compute_dtype = compute_dtype
        self.lora = None  # (lora_diffs, lora_strengths) once patches attach
        self.step = 0  # index into per-step strength schedules
        self.scale_weight = scale_weight
        # Empty hook lists expected by comfy.ops.cast_bias_weight.
        self.bias_function = []
        self.weight_function = []

    def forward(self, input):
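        # Cast weight/bias to match the input via ComfyUI's helper.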
        weight, bias = cast_bias_weight(self, input)

        if self.scale_weight is not None:
            # Multiply whichever tensor is smaller, to do less work.
            if weight.numel() < input.numel():
                weight = weight * self.scale_weight
            else:
                input = input * self.scale_weight

        if self.lora is not None:
            weight = self.apply_lora(weight).to(self.compute_dtype)

        return torch.nn.functional.linear(input, weight, bias)

    @torch.compiler.disable()
    def apply_lora(self, weight):
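        """Return `weight` with every attached LoRA diff merged in:
        W' = W + strength * (alpha / rank) * (up @ down)."""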
        for lora_diff, lora_strength in zip(self.lora[0], self.lora[1]):
            # A list strength is a per-step schedule indexed by self.step.
            if isinstance(lora_strength, list):
                lora_strength = lora_strength[self.step]
            if lora_strength == 0.0:
                continue
            patch_diff = torch.mm(
                lora_diff[0].flatten(start_dim=1).to(weight.device),
                lora_diff[1].flatten(start_dim=1).to(weight.device),
            ).reshape(weight.shape)
            # Standard LoRA scaling: alpha divided by the rank of the diff.
            alpha = lora_diff[2] / lora_diff[1].shape[0] if lora_diff[2] is not None else 1.0
            scale = lora_strength * alpha
            weight = weight.add(patch_diff, alpha=scale)
        return weight


def remove_lora_from_module(module):
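    """Detach LoRA state from every submodule; base weights are untouched."""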
    for submodule in module.modules():
        submodule.lora = None
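

# Usage sketch (illustrative only): `model`, `sd`, and `lora_patches` below
# are hypothetical stand-ins for a diffusion model, its state dict, and a
# ComfyUI-style patch dict.
#
#   model = _replace_linear(model, torch.bfloat16, sd)
#   set_lora_params(model, lora_patches)   # attach LoRA diffs + strengths
#   out = model(x)                         # LoRA merged on the fly in forward
#   remove_lora_from_module(model)         # detach LoRAs afterwards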