import os
from torch.hub import download_url_to_file
import torch.nn as nn
import torch
from diffusers import UNet2DConditionModel
from diffusers.configuration_utils import FrozenDict
def patch_transvae_sd(model, state_dict):
    """Return a copy of *state_dict* with every key prefixed by 'model.'.

    The *model* argument is accepted for interface compatibility but unused.
    """
    prefixed = {}
    for key, value in state_dict.items():
        prefixed['model.' + key] = value
    return prefixed
def module_dtype(self):
    """Return the dtype of the module's first parameter.

    Raises StopIteration if the module has no parameters.
    """
    params = self.parameters()
    first = next(params)
    return first.dtype
def module_device(self):
    """Return the device of the module's first parameter.

    Raises StopIteration if the module has no parameters.
    """
    params = self.parameters()
    first = next(params)
    return first.device
def conv_add_channels(new_c: int, conv: nn.Conv2d, prepend=False):
    """Return a copy of *conv* widened by *new_c* extra input channels.

    The added channels' weights are zero-initialized, so the new conv produces
    exactly the same output as *conv* whenever the extra channels are fed zeros.

    Args:
        new_c: number of input channels to add.
        conv: source convolution. The weight concat along dim 1 is only
            correct for groups == 1 (the visible callers use plain convs).
        prepend: if True, the new channels come first; otherwise appended.

    Fixes vs. previous version:
      - zero pad is built with ``*conv.kernel_size`` instead of a square
        ``(ks, ks)``, so non-square kernels no longer shape-mismatch in cat;
      - the pad matches the weight's dtype/device, so fp16/CUDA convs no
        longer fail inside ``torch.cat``.
    """
    new_conv = nn.Conv2d(
        new_c + conv.in_channels,
        conv.out_channels,
        conv.kernel_size,
        conv.stride,
        conv.padding,
        conv.dilation,
        conv.groups,
        conv.bias is not None,
    )
    sd = conv.state_dict()
    weight = sd['weight']
    zeros = torch.zeros(
        (conv.out_channels, new_c, *conv.kernel_size),
        dtype=weight.dtype,
        device=weight.device,
    )
    if prepend:
        sd['weight'] = torch.cat([zeros, weight], dim=1)
    else:
        sd['weight'] = torch.cat([weight, zeros], dim=1)
    new_conv.load_state_dict(sd, strict=True)
    # Keep the replacement on the same device/dtype as the original weights.
    new_conv.to(device=weight.device, dtype=weight.dtype)
    return new_conv
def update_net_config(net: UNet2DConditionModel, key: str, value):
    """Overwrite one entry of the frozen diffusers config attached to *net*.

    FrozenDict is immutable, so the config is copied, mutated, and the frozen
    wrapper re-created before being stored back on the model.
    """
    cfg = {**net.config}
    cfg[key] = value
    net._internal_dict = FrozenDict(cfg)
def patch_unet_convin(unet: UNet2DConditionModel, target_in_channels, prepend=False):
    """Grow ``unet.conv_in`` to accept *target_in_channels* input channels.

    The extra channels' weights are zero-initialized (via conv_add_channels),
    so existing behavior is unchanged when they are fed zeros. No-op when the
    unet already has at least that many input channels. The unet's frozen
    config is updated to reflect the new channel count.
    """
    extra = target_in_channels - unet.config.in_channels
    if extra <= 0:
        return
    replacement = conv_add_channels(extra, unet.conv_in, prepend=prepend)
    del unet.conv_in
    unet.conv_in = replacement
    update_net_config(unet, "in_channels", replacement.in_channels)
def download_model(url, local_path):
    """Ensure *local_path* exists, downloading it from *url* if needed.

    The download goes to a '.tmp' sibling first and is renamed into place,
    so a partially-written file never appears at *local_path*.

    Returns:
        local_path, whether freshly downloaded or already present.
    """
    if not os.path.exists(local_path):
        tmp = local_path + '.tmp'
        download_url_to_file(url=url, dst=tmp)
        os.rename(tmp, local_path)
    return local_path
def load_frozen_patcher(filename, state_dict, strength):
    """Re-assemble a flattened LoRA patch state_dict and register it.

    Keys in *state_dict* are expected to look like
    'model_key::patch_type::weight_index'; weight_index selects a slot in a
    fixed-size list of up to 16 weights per (model_key, patch_type).

    NOTE(review): ``add_patches`` is defined below with ``self`` as its first
    parameter, yet it is called here without one — presumably both were
    originally methods of the same class; confirm against the full project.
    """
    patch_dict = {}
    for k, w in state_dict.items():
        model_key, patch_type, weight_index = k.split('::')
        if model_key not in patch_dict:
            patch_dict[model_key] = {}
        if patch_type not in patch_dict[model_key]:
            # Fixed 16 weight slots per patch type; unused slots stay None.
            patch_dict[model_key][patch_type] = [None] * 16
        patch_dict[model_key][patch_type][int(weight_index)] = w
    patch_flat = {}
    for model_key, v in patch_dict.items():
        for patch_type, weight_list in v.items():
            # NOTE(review): if a model_key carries several patch types, later
            # iterations overwrite earlier ones — only the last type survives.
            patch_flat[model_key] = (patch_type, weight_list)
    add_patches(filename=filename, patches=patch_flat, strength_patch=float(strength), strength_model=1.0)
    return
def add_patches(self, *, filename, patches, strength_patch=1.0, strength_model=1.0, online_mode=False):
    """Register LoRA patches on ``self.lora_patches``.

    The entry is keyed by (filename, strength_patch, strength_model,
    online_mode). Patch keys may be plain parameter-name strings or tuples of
    (name, offset[, function]). Only patches whose name matches an actual
    parameter of ``self.model`` are kept.

    Returns:
        The set of accepted patch keys (as given in *patches*).
    """
    lora_identifier = (filename, strength_patch, strength_model, online_mode)
    known_params = {name for name, _ in self.model.named_parameters()}
    accepted = set()
    grouped = {}
    for raw_key in patches:
        if isinstance(raw_key, str):
            key, offset, function = raw_key, None, None
        else:
            key = raw_key[0]
            offset = raw_key[1]
            function = raw_key[2] if len(raw_key) > 2 else None
        if key not in known_params:
            # Silently skip patches that do not target a real parameter.
            continue
        accepted.add(raw_key)
        grouped.setdefault(key, []).append(
            [strength_patch, patches[raw_key], strength_model, offset, function]
        )
    self.lora_patches[lora_identifier] = grouped
    return accepted
# class LoraLoader:
# def __init__(self, model):
# self.model = model
# self.backup = {}
# self.online_backup = []
# self.loaded_hash = str([])
# @torch.inference_mode()
# def refresh(self, lora_patches, offload_device=torch.device('cpu'), force_refresh=False):
# hashes = str(list(lora_patches.keys()))
# if hashes == self.loaded_hash and not force_refresh:
# return
# # Merge Patches
# all_patches = {}
# for (_, _, _, online_mode), patches in lora_patches.items():
# for key, current_patches in patches.items():
# all_patches[(key, online_mode)] = all_patches.get((key, online_mode), []) + current_patches
# # Initialize
# memory_management.signal_empty_cache = True
# parameter_devices = get_parameter_devices(self.model)
# # Restore
# for m in set(self.online_backup):
# del m.forge_online_loras
# self.online_backup = []
# for k, w in self.backup.items():
# if not isinstance(w, torch.nn.Parameter):
# # In very few cases
# w = torch.nn.Parameter(w, requires_grad=False)
# utils.set_attr_raw(self.model, k, w)
# self.backup = {}
# set_parameter_devices(self.model, parameter_devices=parameter_devices)
# # Patch
# for (key, online_mode), current_patches in all_patches.items():
# try:
# parent_layer, child_key, weight = utils.get_attr_with_parent(self.model, key)
# assert isinstance(weight, torch.nn.Parameter)
# except:
# raise ValueError(f"Wrong LoRA Key: {key}")
# if online_mode:
# if not hasattr(parent_layer, 'forge_online_loras'):
# parent_layer.forge_online_loras = {}
# parent_layer.forge_online_loras[child_key] = current_patches
# self.online_backup.append(parent_layer)
# continue
# if key not in self.backup:
# self.backup[key] = weight.to(device=offload_device)
# bnb_layer = None
# if hasattr(weight, 'bnb_quantized') and operations.bnb_avaliable:
# bnb_layer = parent_layer
# from backend.operations_bnb import functional_dequantize_4bit
# weight = functional_dequantize_4bit(weight)
# gguf_cls = getattr(weight, 'gguf_cls', None)
# gguf_parameter = None
# if gguf_cls is not None:
# gguf_parameter = weight
# from backend.operations_gguf import dequantize_tensor
# weight = dequantize_tensor(weight)
# try:
# weight = merge_lora_to_weight(current_patches, weight, key, computation_dtype=torch.float32)
# except:
# print('Patching LoRA weights out of memory. Retrying by offloading models.')
# set_parameter_devices(self.model, parameter_devices={k: offload_device for k in parameter_devices.keys()})
# memory_management.soft_empty_cache()
# weight = merge_lora_to_weight(current_patches, weight, key, computation_dtype=torch.float32)
# if bnb_layer is not None:
# bnb_layer.reload_weight(weight)
# continue
# if gguf_cls is not None:
# gguf_cls.quantize_pytorch(weight, gguf_parameter)
# continue
# utils.set_attr_raw(self.model, key, torch.nn.Parameter(weight, requires_grad=False))
# # End
# set_parameter_devices(self.model, parameter_devices=parameter_devices)
# self.loaded_hash = hashes
# return