import os
import logging

import numpy as np
import torch
from tqdm import tqdm

import comfy.model_management
import comfy.utils
import folder_paths

device = comfy.model_management.get_torch_device()

# Symmetrically clamp extracted U/Vh values to this quantile of their pooled
# distribution, to suppress outliers introduced by the SVD truncation.
CLAMP_QUANTILE = 0.99


def extract_lora(diff, key, rank, algorithm, lora_type, lowrank_iters=7, adaptive_param=1.0, clamp_quantile=True):
    """
    Extracts LoRA factors from a weight-difference tensor via SVD.

    The difference is factored as diff ≈ U @ Vh (after reshaping for
    convolutions), where U has shape (out_dim, lora_rank) and Vh has shape
    (lora_rank, in_dim). For the adaptive lora_types the rank is chosen
    per-layer from the singular value spectrum, with `rank` acting as an
    upper bound.
    """
    conv2d = (len(diff.shape) == 4)
    kernel_size = None if not conv2d else diff.size()[2:4]
    conv2d_3x3 = conv2d and kernel_size != (1, 1)
    out_dim, in_dim = diff.size()[0:2]

    if conv2d:
        if conv2d_3x3:
            # Flatten (out, in, kh, kw) -> (out, in * kh * kw) for the SVD.
            diff = diff.flatten(start_dim=1)
        else:
            # 1x1 convolutions reduce to plain linear layers.
            diff = diff.squeeze()

    diff_float = diff.float()
    if algorithm == "svd_lowrank":
        # Randomized SVD: faster, approximate; only the top-q factors are computed.
        U, S, V = torch.svd_lowrank(diff_float, q=min(rank, in_dim, out_dim), niter=lowrank_iters)
        U = U @ torch.diag(S)
        Vh = V.t()
    else:
        # Full SVD: exact, but slower and more memory hungry.
        U, S, Vh = torch.linalg.svd(diff_float)

    if "adaptive" in lora_type:
        if lora_type == "adaptive_ratio":
            # Keep singular values above a fixed fraction of the largest one.
            min_s = torch.max(S) * adaptive_param
            lora_rank = torch.sum(S > min_s).item()
        elif lora_type == "adaptive_energy":
            # Keep enough values to retain the requested share of sum(S^2).
            energy = torch.cumsum(S**2, dim=0)
            total_energy = torch.sum(S**2)
            threshold = adaptive_param * total_energy
            lora_rank = torch.sum(energy < threshold).item() + 1
        elif lora_type == "adaptive_quantile":
            # Keep enough values to retain the requested share of sum(S).
            s_cum = torch.cumsum(S, dim=0)
            min_cum_sum = adaptive_param * torch.sum(S)
            lora_rank = torch.sum(s_cum < min_cum_sum).item()
        elif lora_type == "adaptive_fro":
            # Keep enough values to retain the requested Frobenius norm ratio.
            S_squared = S.pow(2)
            S_fro_sq = float(torch.sum(S_squared))
            sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq
            lora_rank = int(torch.searchsorted(sum_S_squared, adaptive_param**2)) + 1
            lora_rank = max(1, min(lora_rank, len(S)))
        else:
            # Unknown adaptive mode: fall back to the fixed rank.
            lora_rank = rank

        # `rank` acts as an upper bound for all adaptive modes.
        lora_rank = min(lora_rank, rank)

        if lora_type == "adaptive_fro":
            S_squared = S.pow(2)
            s_fro = torch.sqrt(torch.sum(S_squared))
            s_red_fro = torch.sqrt(torch.sum(S_squared[:lora_rank]))
            fro_percent = float(s_red_fro / s_fro)
            print(f"{key} Extracted LoRA rank: {lora_rank}, Frobenius retained: {fro_percent:.1%}")
        else:
            print(f"{key} Extracted LoRA rank: {lora_rank}")
    else:
        lora_rank = rank

    # Final safety clamps: at least rank 1, at most what the matrix allows.
    lora_rank = max(1, lora_rank)
    lora_rank = min(out_dim, in_dim, lora_rank)

    # Truncate to the chosen rank and fold the singular values into U.
    U = U[:, :lora_rank]
    S = S[:lora_rank]
    U = U @ torch.diag(S)
    Vh = Vh[:lora_rank, :]

    if clamp_quantile:
        dist = torch.cat([U.flatten(), Vh.flatten()])
        if dist.numel() > 100_000:
            # torch.quantile struggles with very large inputs, so estimate
            # the clamp threshold from a random subsample.
            idx = torch.randperm(dist.numel(), device=dist.device)[:100_000]
            dist_sample = dist[idx]
            hi_val = torch.quantile(dist_sample, CLAMP_QUANTILE)
        else:
            hi_val = torch.quantile(dist, CLAMP_QUANTILE)
        low_val = -hi_val

        U = U.clamp(low_val, hi_val)
        Vh = Vh.clamp(low_val, hi_val)
    if conv2d:
        U = U.reshape(out_dim, lora_rank, 1, 1)
        Vh = Vh.reshape(lora_rank, in_dim, kernel_size[0], kernel_size[1])
    return (U, Vh)
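
# A minimal usage sketch (illustrative only, not part of the node graph): given
# the difference between a finetuned and an original weight matrix, extract_lora
# returns the low-rank pair that approximately reconstructs it. The tensors and
# key name below are hypothetical.
#
#   w_orig = torch.randn(640, 320)
#   w_tuned = w_orig + 0.01 * torch.randn(640, 320)
#   up, down = extract_lora(w_tuned - w_orig, "demo.key", rank=16,
#                           algorithm="svd_linalg", lora_type="standard")
#   # up: (640, 16), down: (16, 320); up @ down approximates the difference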


def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, algorithm, lowrank_iters, out_dtype, bias_diff=False, adaptive_param=1.0, clamp_quantile=True):
    # Materialize the patched (difference) weights, then free the model itself.
    comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
    model_diff.model.diffusion_model.cpu()
    sd = model_diff.model_state_dict(filter_prefix=prefix_model)
    del model_diff
    comfy.model_management.soft_empty_cache()
    for k, v in sd.items():
        if isinstance(v, torch.Tensor):
            sd[k] = v.cpu()

    total_keys = len([k for k in sd if k.endswith(".weight") or (bias_diff and k.endswith(".bias"))])

    progress_bar = tqdm(total=total_keys, desc=f"Extracting LoRA ({prefix_lora.strip('.')})")
    comfy_pbar = comfy.utils.ProgressBar(total_keys)

    for k in sd:
        if k.endswith(".weight"):
            weight_diff = sd[k]
            if weight_diff.ndim == 5:
                logging.info(f"Skipping 5D tensor for key {k}")
                progress_bar.update(1)
                comfy_pbar.update(1)
                continue
            if lora_type != "full":
                if weight_diff.ndim < 2:
                    # 1D weights (norms etc.) cannot be factored; store the raw diff.
                    if bias_diff:
                        output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().to(out_dtype).cpu()
                    progress_bar.update(1)
                    comfy_pbar.update(1)
                    continue
                try:
                    out = extract_lora(weight_diff.to(device), k, rank, algorithm, lora_type, lowrank_iters=lowrank_iters, adaptive_param=adaptive_param, clamp_quantile=clamp_quantile)
                    output_sd["{}{}.lora_up.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[0].contiguous().to(out_dtype).cpu()
                    output_sd["{}{}.lora_down.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[1].contiguous().to(out_dtype).cpu()
                except Exception as e:
                    logging.warning(f"Could not generate lora weights for key {k}, error {e}")
            else:
                # "full" mode stores the complete weight difference instead of factors.
                output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().to(out_dtype).cpu()

            progress_bar.update(1)
            comfy_pbar.update(1)

        elif bias_diff and k.endswith(".bias"):
            output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().to(out_dtype).cpu()
            progress_bar.update(1)
            comfy_pbar.update(1)
    progress_bar.close()
    return output_sd
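
# For reference, the resulting state dict uses kohya-style key names; with a
# hypothetical layer name it looks like:
#   diffusion_model.blocks.0.attn.qkv.lora_up.weight    -> (out_dim, lora_rank)
#   diffusion_model.blocks.0.attn.qkv.lora_down.weight  -> (lora_rank, in_dim)
#   diffusion_model.blocks.0.attn.qkv.diff_b            -> bias diff (if bias_diff)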


class LoraExtractKJ:
    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                {
                    "finetuned_model": ("MODEL",),
                    "original_model": ("MODEL",),
                    "filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
                    "rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1, "tooltip": "The rank to use for standard LoRA, or the maximum rank limit for adaptive methods."}),
                    "lora_type": (["standard", "full", "adaptive_ratio", "adaptive_quantile", "adaptive_energy", "adaptive_fro"],),
                    "algorithm": (["svd_linalg", "svd_lowrank"], {"default": "svd_linalg", "tooltip": "SVD algorithm to use; svd_lowrank is faster but less accurate."}),
                    "lowrank_iters": ("INT", {"default": 7, "min": 1, "max": 100, "step": 1, "tooltip": "The number of subspace iterations for the lowrank SVD algorithm."}),
                    "output_dtype": (["fp16", "bf16", "fp32"], {"default": "fp16"}),
                    "bias_diff": ("BOOLEAN", {"default": True}),
                    "adaptive_param": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "For ratio mode, this is the ratio of the maximum singular value. For quantile mode, this is the quantile of the singular values. For fro mode, this is the Frobenius norm retention ratio."}),
                    "clamp_quantile": ("BOOLEAN", {"default": True}),
                },
                }

    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True

    CATEGORY = "KJNodes/lora"

    def save(self, finetuned_model, original_model, filename_prefix, rank, lora_type, algorithm, lowrank_iters, output_dtype, bias_diff, adaptive_param, clamp_quantile):
        if algorithm == "svd_lowrank" and lora_type != "standard":
            raise ValueError("svd_lowrank algorithm is only supported for standard LoRA extraction.")

        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[output_dtype]
        # Subtract the original weights from the finetuned model as patches,
        # so model_diff holds finetuned - original.
        m = finetuned_model.clone()
        kp = original_model.get_key_patches("diffusion_model.")
        for k in kp:
            m.add_patches({k: kp[k]}, -1.0, 1.0)
        model_diff = m

        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)

        output_sd = {}
        if model_diff is not None:
            output_sd = calc_lora_model(model_diff, rank, "diffusion_model.", "diffusion_model.", output_sd, lora_type, algorithm, lowrank_iters, dtype, bias_diff=bias_diff, adaptive_param=adaptive_param, clamp_quantile=clamp_quantile)
        if "adaptive" in lora_type:
            rank_str = f"{lora_type}_{adaptive_param:.2f}"
        else:
            rank_str = rank
        output_checkpoint = f"{filename}_rank_{rank_str}_{output_dtype}_{counter:05}_.safetensors"
        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)

        comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
        return {}


class LoraReduceRank:
    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                {
                    "lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}),
                    "new_rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1, "tooltip": "The new rank to resize the LoRA to. Acts as the max rank when using a dynamic_method."}),
                    "dynamic_method": (["disabled", "sv_ratio", "sv_cumulative", "sv_fro"], {"default": "disabled", "tooltip": "Method to use for dynamically determining new alphas and dims."}),
                    "dynamic_param": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Threshold for the chosen dynamic method: singular value ratio for sv_ratio, cumulative sum fraction for sv_cumulative, Frobenius norm retention for sv_fro."}),
                    "output_dtype": (["match_original", "fp16", "bf16", "fp32"], {"default": "match_original", "tooltip": "Data type to save the LoRA as."}),
                    "verbose": ("BOOLEAN", {"default": True}),
                },
                }

    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True
    EXPERIMENTAL = True
    DESCRIPTION = "Resize a LoRA model by reducing its rank. Based on kohya's sd-scripts: https://github.com/kohya-ss/sd-scripts/blob/main/networks/resize_lora.py"

    CATEGORY = "KJNodes/lora"

    def save(self, lora_name, new_rank, output_dtype, dynamic_method, dynamic_param, verbose):
        lora_path = folder_paths.get_full_path("loras", lora_name)
        lora_sd, metadata = comfy.utils.load_torch_file(lora_path, return_metadata=True)

        if output_dtype == "fp16":
            save_dtype = torch.float16
        elif output_dtype == "bf16":
            save_dtype = torch.bfloat16
        elif output_dtype == "fp32":
            save_dtype = torch.float32
        elif output_dtype == "match_original":
            # Use the dtype of the first floating point weight tensor found.
            first_weight_key = next(k for k in lora_sd if k.endswith(".weight") and isinstance(lora_sd[k], torch.Tensor))
            save_dtype = lora_sd[first_weight_key].dtype

        # Strip ".default" adapter infixes (as produced by PEFT-style LoRAs).
        new_lora_sd = {}
        for k, v in lora_sd.items():
            new_lora_sd[k.replace(".default", "")] = v
        del lora_sd
        print("Resizing Lora...")
        output_sd, old_dim, new_alpha, rank_list = resize_lora_model(new_lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose)

        if metadata is None:
            metadata = {}

        comment = metadata.get("ss_training_comment", "")

        if dynamic_method == "disabled":
            metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {new_rank}; {comment}"
            metadata["ss_network_dim"] = str(new_rank)
            metadata["ss_network_alpha"] = str(new_alpha)
        else:
            metadata["ss_training_comment"] = f"Dynamic resize with {dynamic_method}: {dynamic_param} from {old_dim}; {comment}"
            metadata["ss_network_dim"] = "Dynamic"
            metadata["ss_network_alpha"] = "Dynamic"

        for key in list(output_sd.keys()):
            value = output_sd[key]
            if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype:
                output_sd[key] = value.to(save_dtype)

        output_filename_prefix = "loras/" + lora_name

        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(output_filename_prefix, self.output_dir)
        output_dtype_str = f"_{output_dtype}" if output_dtype != "match_original" else ""
        average_rank = str(int(np.mean(rank_list)))
        rank_str = new_rank if dynamic_method == "disabled" else f"dynamic_{average_rank}"
        output_checkpoint = f"{filename.replace('.safetensors', '')}_resized_from_{old_dim}_to_{rank_str}{output_dtype_str}_{counter:05}_.safetensors"
        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
        print(f"Saving resized LoRA to {output_checkpoint}")

        comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=metadata)
        return {}


NODE_CLASS_MAPPINGS = {
    "LoraExtractKJ": LoraExtractKJ,
    "LoraReduceRank": LoraReduceRank,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "LoraExtractKJ": "LoraExtractKJ",
    "LoraReduceRank": "LoraReduceRank",
}

# Singular values at or below this threshold are treated as numerically zero.
MIN_SV = 1e-6

# Recognized (down, up) key-name pairs, e.g. kohya-style "*.lora_down.weight" /
# "*.lora_up.weight" or PEFT-style "*.lora_A.weight" / "*.lora_B.weight".
LORA_DOWN_UP_FORMATS = [
    ("lora_down", "lora_up"),
    ("lora_A", "lora_B"),
    ("down", "up"),
]


def index_sv_cumulative(S, target):
    # Smallest index whose normalized cumulative sum of singular values
    # reaches `target` (a fraction in (0, 1]).
    original_sum = float(torch.sum(S))
    cumulative_sums = torch.cumsum(S, dim=0) / original_sum
    index = int(torch.searchsorted(cumulative_sums, target)) + 1
    index = max(1, min(index, len(S) - 1))
    return index


def index_sv_fro(S, target):
    # Smallest index that retains `target` of the Frobenius norm, i.e. the
    # cumulative sum of S^2 reaching target^2 of the total.
    S_squared = S.pow(2)
    S_fro_sq = float(torch.sum(S_squared))
    sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq
    index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
    index = max(1, min(index, len(S) - 1))
    return index


def index_sv_ratio(S, target):
    # Number of singular values within a factor of `target` of the largest.
    # Note: with `target` below 1 this keeps only the top singular value.
    max_sv = S[0]
    min_sv = max_sv / target
    index = int(torch.sum(S > min_sv).item())
    index = max(1, min(index, len(S) - 1))
    return index
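
# A small worked example (illustrative): for S = [4.0, 2.0, 1.0, 0.5],
# sum(S) = 7.5 and the normalized cumulative sums are
# [0.533, 0.800, 0.933, 1.000], so index_sv_cumulative(S, 0.9) returns 3
# (searchsorted lands at index 2, plus one). index_sv_fro(S, 0.95) compares
# the cumulative S^2 fractions against 0.95**2 = 0.9025 and returns 2.
# index_sv_ratio(S, 2.0) keeps values above 4.0 / 2.0 = 2.0, which is just
# S[0], so it returns 1.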


def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
    out_size, in_size, kernel_size, _ = weight.size()
    if weight.dtype != torch.float32:
        weight = weight.to(torch.float32)
    # Flatten the conv kernel into a matrix before taking the SVD.
    U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))

    param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
    lora_rank = param_dict["new_rank"]

    U = U[:, :lora_rank]
    S = S[:lora_rank]
    U = U @ torch.diag(S)
    Vh = Vh[:lora_rank, :]

    param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu()
    param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu()
    del U, S, Vh, weight
    return param_dict


def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
    out_size, in_size = weight.size()

    if weight.dtype != torch.float32:
        weight = weight.to(torch.float32)
    U, S, Vh = torch.linalg.svd(weight.to(device))

    param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
    lora_rank = param_dict["new_rank"]

    U = U[:, :lora_rank]
    S = S[:lora_rank]
    U = U @ torch.diag(S)
    Vh = Vh[:lora_rank, :]

    param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu()
    param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu()
    del U, S, Vh, weight
    return param_dict


def merge_conv(lora_down, lora_up, device):
    in_rank, in_size, kernel_size, k_ = lora_down.shape
    out_size, out_rank, _, _ = lora_up.shape
    assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch"

    lora_down = lora_down.to(device)
    lora_up = lora_up.to(device)

    # Multiply the flattened factors back into a full conv weight.
    merged = lora_up.reshape(out_size, -1) @ lora_down.reshape(in_rank, -1)
    weight = merged.reshape(out_size, in_size, kernel_size, kernel_size)
    del lora_up, lora_down
    return weight


def merge_linear(lora_down, lora_up, device):
    in_rank, in_size = lora_down.shape
    out_size, out_rank = lora_up.shape
    assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch"

    lora_down = lora_down.to(device)
    lora_up = lora_up.to(device)

    weight = lora_up @ lora_down
    del lora_up, lora_down
    return weight
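
# Illustrative round trip (hypothetical shapes): merging a rank-16 pair back
# into a full matrix and re-extracting it at a lower rank is exactly what
# resize_lora_model does per layer.
#
#   down = torch.randn(16, 320)                       # (rank, in)
#   up = torch.randn(640, 16)                         # (out, rank)
#   full = merge_linear(down, up, device)             # (640, 320)
#   params = extract_linear(full, 8, "disabled", 0.0, device)
#   # params["lora_down"]: (8, 320), params["lora_up"]: (640, 8)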


def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
    param_dict = {}

    if dynamic_method == "sv_ratio":
        # Keep singular values within a dynamic_param ratio of the largest.
        new_rank = index_sv_ratio(S, dynamic_param) + 1
        new_alpha = float(scale * new_rank)

    elif dynamic_method == "sv_cumulative":
        # Keep enough singular values to cover dynamic_param of sum(S).
        new_rank = index_sv_cumulative(S, dynamic_param) + 1
        new_alpha = float(scale * new_rank)

    elif dynamic_method == "sv_fro":
        # Keep enough singular values to retain dynamic_param of the Frobenius norm.
        new_rank = index_sv_fro(S, dynamic_param) + 1
        new_alpha = float(scale * new_rank)
    else:
        new_rank = rank
        new_alpha = float(scale * new_rank)

    # Degenerate layer (largest singular value ~ 0): collapse to rank 1.
    if S[0] <= MIN_SV:
        new_rank = 1
        new_alpha = float(scale * new_rank)
    elif new_rank > rank:
        # `rank` is the hard upper bound for the dynamic methods as well.
        new_rank = rank
        new_alpha = float(scale * new_rank)

    # Diagnostics: how much of the spectrum the reduced rank retains.
    s_sum = torch.sum(torch.abs(S))
    s_rank = torch.sum(torch.abs(S[:new_rank]))

    S_squared = S.pow(2)
    s_fro = torch.sqrt(torch.sum(S_squared))
    s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
    fro_percent = float(s_red_fro / s_fro)

    param_dict["new_rank"] = new_rank
    param_dict["new_alpha"] = new_alpha
    param_dict["sum_retained"] = s_rank / s_sum
    param_dict["fro_retained"] = fro_percent
    param_dict["max_ratio"] = S[0] / S[new_rank - 1]

    return param_dict
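
# Sketch of what rank_resize reports (illustrative values): with
# S = torch.tensor([4.0, 2.0, 1.0, 0.5]) and rank_resize(S, 2, "disabled", 0.0),
# new_rank is 2, sum_retained is 6.0 / 7.5 = 0.8, fro_retained is
# sqrt(20 / 21.25) ≈ 0.970, and max_ratio is 4.0 / 2.0 = 2.0.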


def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
    max_old_rank = None
    new_alpha = None
    verbose_str = "\n"
    fro_list = []
    rank_list = []

    if dynamic_method != "disabled":
        print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")

    lora_down_weight = None
    lora_up_weight = None

    o_lora_sd = lora_sd.copy()
    block_down_name = None
    block_up_name = None

    # One progress step per processed down/up pair, i.e. per down-weight key.
    total_keys = len([k for k in lora_sd if k.endswith(".weight") and any(k.split(".")[-2] == fmt[0] for fmt in LORA_DOWN_UP_FORMATS)])

    pbar = comfy.utils.ProgressBar(total_keys)
    for key, value in tqdm(lora_sd.items(), leave=True, desc="Resizing LoRA weights"):
        key_parts = key.split(".")
        block_down_name = None
        for _format in LORA_DOWN_UP_FORMATS:
            # Match either "block.lora_down.weight" or a bare "block.lora_down" key.
            if len(key_parts) >= 2 and _format[0] == key_parts[-2]:
                block_down_name = ".".join(key_parts[:-2])
                lora_down_name = "." + _format[0]
                lora_up_name = "." + _format[1]
                weight_name = "." + key_parts[-1]
                break
            if len(key_parts) >= 1 and _format[0] == key_parts[-1]:
                block_down_name = ".".join(key_parts[:-1])
                lora_down_name = "." + _format[0]
                lora_up_name = "." + _format[1]
                weight_name = ""
                break

        if block_down_name is None:
            # Not a down-projection key; up weights and alphas are picked up
            # alongside their down counterpart.
            continue

        block_up_name = block_down_name
        lora_down_weight = value
        lora_up_weight = lora_sd.get(block_up_name + lora_up_name + weight_name, None)
        lora_alpha = lora_sd.get(block_down_name + ".alpha", None)

        weights_loaded = lora_down_weight is not None and lora_up_weight is not None

        if weights_loaded:
            conv2d = len(lora_down_weight.size()) == 4
            old_rank = lora_down_weight.size()[0]
            max_old_rank = max(max_old_rank or 0, old_rank)

            if lora_alpha is None:
                scale = 1.0
            else:
                scale = lora_alpha / old_rank

            # Merge the pair into a full weight, then re-extract at the new rank.
            if conv2d:
                full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
                param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
            else:
                full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
                param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)

            if verbose:
                max_ratio = param_dict["max_ratio"]
                sum_retained = param_dict["sum_retained"]
                fro_retained = param_dict["fro_retained"]
                if not np.isnan(fro_retained):
                    fro_list.append(float(fro_retained))
                log_str = f"{block_down_name:75} | sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}, new dim: {param_dict['new_rank']}"
                tqdm.write(log_str)
                verbose_str += log_str

            if verbose and dynamic_method != "disabled":
                verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
            else:
                verbose_str += "\n"

            new_alpha = param_dict["new_alpha"]
            o_lora_sd[block_down_name + lora_down_name + weight_name] = param_dict["lora_down"].to(save_dtype).contiguous()
            o_lora_sd[block_up_name + lora_up_name + weight_name] = param_dict["lora_up"].to(save_dtype).contiguous()
            o_lora_sd[block_down_name + ".alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype)

            block_down_name = None
            block_up_name = None
            lora_down_weight = None
            lora_up_weight = None
            weights_loaded = False
            rank_list.append(param_dict["new_rank"])
            del param_dict
            pbar.update(1)

    if verbose:
        print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
    return o_lora_sd, max_old_rank, new_alpha, rank_list