import comfy.model_management
import comfy.utils
import safetensors.torch
import torch, os, comfy, json

# ATTRIBUTION: This code is a mix of code from kohya-ss, comfy, and Swarm. It would be annoying to disentangle, but it's all FOSS and relatively short, so it's fine.

# Extracted LoRA factor values are clamped to +/- this quantile of their pooled distribution, to suppress outliers.
CLAMP_QUANTILE = 0.99
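
# Overview of the approach: for each weight W_base in the base model and W_other in
# the finetuned model, the delta D = W_other - W_base is factored by truncated SVD as
# D ~= (U @ diag(S)) @ Vh, and the two factors are stored as the LoRA "up" and "down"
# weights at the requested rank.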
def extract_lora(diff, rank):
    conv2d = (len(diff.shape) == 4)
    kernel_size = None if not conv2d else diff.size()[2:4]
    conv2d_3x3 = conv2d and kernel_size != (1, 1)
    out_dim, in_dim = diff.size()[0:2]
    rank = min(rank, in_dim, out_dim)
    # Flatten conv kernels so the SVD operates on a 2D matrix.
    if conv2d:
        if conv2d_3x3:
            diff = diff.flatten(start_dim=1)
        else:
            diff = diff.squeeze()
    # Truncated SVD: keep the top-`rank` singular values, folding them into U.
    U, S, Vh = torch.linalg.svd(diff.float())
    U = U[:, :rank]
    S = S[:rank]
    U = U @ torch.diag(S)
    Vh = Vh[:rank, :]
    # Clamp outliers to the CLAMP_QUANTILE quantile of the pooled factor values.
    dist = torch.cat([U.flatten(), Vh.flatten()])
    hi_val = torch.quantile(dist, CLAMP_QUANTILE)
    low_val = -hi_val
    U = U.clamp(low_val, hi_val)
    Vh = Vh.clamp(low_val, hi_val)
    # Restore conv shapes: U becomes a 1x1 conv, Vh carries the original kernel.
    if conv2d:
        U = U.reshape(out_dim, rank, 1, 1)
        Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
    return (U, Vh)
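
# Illustrative sanity check (assumed example, not part of the node): for a plain 2D
# weight delta, the returned factors multiply back to a low-rank approximation:
#   diff = torch.randn(64, 32)
#   U, Vh = extract_lora(diff, 8)
#   approx = U @ Vh  # shape (64, 32), rank-8 approximation of diff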
def do_lora_handle(base_data, other_data, rank, prefix, require, do_bias, callback):
    out_data = {}
    device = comfy.model_management.get_torch_device()
    for key in base_data.keys():
        callback()
        if key not in other_data:
            continue
        base_tensor = base_data[key].float()
        other_tensor = other_data[key].float()
        # Remap SDXL dual-CLIP prefixes to the numeric form checked against `require` below.
        if key.startswith("clip_g"):
            key = "1." + key[len("clip_g."):]
        elif key.startswith("clip_l"):
            key = "0." + key[len("clip_l."):]
        if require:
            if not key.startswith(require):
                print(f"Ignore unmatched key {key} (doesn't match {require})")
                continue
            key = key[len(require):]
        if base_tensor.shape != other_tensor.shape:
            continue
        diff = other_tensor.to(device) - base_tensor.to(device)
        other_tensor = other_tensor.cpu()
        base_tensor = base_tensor.cpu()
        # Skip keys the finetune didn't meaningfully change.
        max_diff = float(diff.abs().max())
        if max_diff < 1e-5:
            print(f"discard unaltered key {key} ({max_diff})")
            continue
        if key.endswith(".weight"):
            fixed_key = key[:-len(".weight")].replace('.', '_')
            name = f"lora_{prefix}_{fixed_key}"
            if len(base_tensor.shape) >= 2:
                print(f"extract key {name} ({max_diff})")
                out = extract_lora(diff, rank)
                out_data[f"{name}.lora_up.weight"] = out[0].contiguous().half().cpu()
                out_data[f"{name}.lora_down.weight"] = out[1].contiguous().half().cpu()
            else:
                print(f"ignore valid raw pass-through key {name} ({max_diff})")
                #out_data[name] = other_tensor.contiguous().half().cpu()
        elif key.endswith(".bias") and do_bias:
            fixed_key = key[:-len(".bias")].replace('.', '_')
            name = f"lora_{prefix}_{fixed_key}"
            print(f"extract bias key {name} ({max_diff})")
            out_data[f"{name}.diff_b"] = diff.contiguous().half().cpu()
    return out_data
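
# Illustrative key mapping (example names assumed): with prefix="unet" and
# require="diffusion_model.", a matched 2D key such as
#   "diffusion_model.input_blocks.0.0.weight"
# would be emitted as the pair
#   "lora_unet_input_blocks_0_0.lora_up.weight"
#   "lora_unet_input_blocks_0_0.lora_down.weight"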
class SwarmExtractLora:
    def __init__(self):
        self.loaded_lora = None

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "base_model": ("MODEL", ),
                "base_model_clip": ("CLIP", ),
                "other_model": ("MODEL", ),
                "other_model_clip": ("CLIP", ),
                "rank": ("INT", {"default": 16, "min": 1, "max": 320}),
                "save_rawpath": ("STRING", {"multiline": False}),
                "save_filename": ("STRING", {"multiline": False}),
                "save_clip": ("BOOLEAN", {"default": True}),
                "metadata": ("STRING", {"multiline": True}),
            }
        }

    CATEGORY = "SwarmUI/models"
    RETURN_TYPES = ()
    FUNCTION = "extract_lora"
    OUTPUT_NODE = True
    DESCRIPTION = "Internal node, do not use directly - extracts a LoRA from the difference between two models. This is used by the SwarmUI Utilities tab."
    def extract_lora(self, base_model, base_model_clip, other_model, other_model_clip, rank, save_rawpath, save_filename, save_clip, metadata):
        base_data = base_model.model_state_dict()
        other_data = other_model.model_state_dict()
        key_count = len(base_data.keys())
        if save_clip:
            key_count += len(base_model_clip.get_sd().keys())
        pbar = comfy.utils.ProgressBar(key_count)
        # Tiny helper to tick the progress bar once per processed key.
        class Helper:
            steps = 0
            def callback(self):
                self.steps += 1
                pbar.update_absolute(self.steps, key_count, None)
        helper = Helper()
        out_data = do_lora_handle(base_data, other_data, rank, "unet", "diffusion_model.", True, lambda: helper.callback())
        if save_clip:
            # TODO: CLIP keys get wonky, this probably doesn't work? Model-arch-dependent.
            out_clip = do_lora_handle(base_model_clip.get_sd(), other_model_clip.get_sd(), rank, "te_text_model_encoder_layers", "0.transformer.text_model.encoder.layers.", False, lambda: helper.callback())
            out_data.update(out_clip)
            out_clip = do_lora_handle(base_model_clip.get_sd(), other_model_clip.get_sd(), rank, "te2_text_model_encoder_layers", "1.transformer.text_model.encoder.layers.", False, lambda: helper.callback())
            out_data.update(out_clip)
        # Can't easily autodetect all the correct modelspec info, but at least supply some basics
        out_metadata = {
            "modelspec.title": f"(Extracted LoRA) {save_filename}",
            "modelspec.description": "LoRA extracted in SwarmUI"
        }
        if metadata:
            out_metadata.update(json.loads(metadata))
        path = f"{save_rawpath}{save_filename}.safetensors"
        print(f"saving to path {path}")
        safetensors.torch.save_file(out_data, path, metadata=out_metadata)
        return ()
NODE_CLASS_MAPPINGS = {
    "SwarmExtractLora": SwarmExtractLora,
}
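
# Quick inspection of a saved extraction (illustrative snippet; the path is hypothetical):
#   from safetensors import safe_open
#   with safe_open("/some/path/my_lora.safetensors", framework="pt") as f:
#       print(f.metadata())
#       print(list(f.keys())[:5])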