# NOTE: scraped Hugging Face Space page header — status banner read "Runtime error"
| import os | |
| from torch.hub import download_url_to_file | |
| import torch.nn as nn | |
| import torch | |
| from diffusers import UNet2DConditionModel | |
| from diffusers.configuration_utils import FrozenDict | |
def patch_transvae_sd(model, state_dict):
    """Return a copy of `state_dict` with every key prefixed by 'model.'.

    `model` is accepted for signature compatibility but not used.
    """
    remapped = {}
    for key, tensor in state_dict.items():
        remapped['model.' + key] = tensor
    return remapped
def module_dtype(self):
    """Dtype of the module's first parameter (used as a proxy for the whole module)."""
    first_param = next(iter(self.parameters()))
    return first_param.dtype
def module_device(self):
    """Device of the module's first parameter (used as a proxy for the whole module)."""
    first_param = next(iter(self.parameters()))
    return first_param.device
def conv_add_channels(new_c: int, conv: nn.Conv2d, prepend=False):
    """Return a copy of `conv` with `new_c` extra input channels.

    The weights for the new channels are zero, so feeding the original input
    (plus anything in the new channels) reproduces the original output.

    Fixes vs. previous version:
    - zero block is created with the weight's dtype/device, so torch.cat no
      longer crashes for fp16 or CUDA convs;
    - supports non-square kernels (previously kernel_size[0] was used for both dims);
    - `padding_mode` is preserved when rebuilding the conv.

    NOTE(review): assumes conv.groups == 1 — for grouped convs the weight's
    channel dimension is in_channels // groups and this concat would be wrong.
    """
    new_conv = nn.Conv2d(
        new_c + conv.in_channels, conv.out_channels, conv.kernel_size,
        conv.stride, conv.padding, conv.dilation, conv.groups,
        conv.bias is not None, padding_mode=conv.padding_mode)
    sd = conv.state_dict()
    w = sd['weight']
    zeros = torch.zeros((conv.out_channels, new_c, *conv.kernel_size),
                        dtype=w.dtype, device=w.device)
    sd['weight'] = torch.cat([zeros, w] if prepend else [w, zeros], dim=1)
    new_conv.load_state_dict(sd, strict=True)
    # move the fresh conv onto the original's device/dtype
    ref = next(conv.parameters())
    new_conv.to(device=ref.device, dtype=ref.dtype)
    return new_conv
def update_net_config(net: UNet2DConditionModel, key: str, value):
    """Overwrite one entry of the (frozen) diffusers config on `net`.

    Configs are immutable, so rebuild a plain dict with the new value and
    swap in a fresh FrozenDict via the private `_internal_dict` slot.
    """
    updated = {**net.config, key: value}
    net._internal_dict = FrozenDict(updated)
def patch_unet_convin(unet: UNet2DConditionModel, target_in_channels, prepend=False):
    """Grow `unet.conv_in` to `target_in_channels` input channels.

    Newly added channels get zero weights; the config's `in_channels` is
    updated to match. No-op when the unet already has enough channels.
    """
    extra_channels = target_in_channels - unet.config.in_channels
    if extra_channels < 1:
        return
    patched = conv_add_channels(extra_channels, unet.conv_in, prepend=prepend)
    del unet.conv_in
    unet.conv_in = patched
    update_net_config(unet, "in_channels", patched.in_channels)
def download_model(url, local_path):
    """Download `url` to `local_path` unless it already exists; return `local_path`.

    Downloads to a temporary name first so a partial download never shadows
    the real file, then moves it into place.

    Fixes vs. previous version:
    - os.replace instead of os.rename (os.rename raises on Windows when the
      destination exists);
    - a failed download no longer leaves a stale '.tmp' file behind.
    """
    if os.path.exists(local_path):
        return local_path
    temp_path = local_path + '.tmp'
    try:
        download_url_to_file(url=url, dst=temp_path)
        os.replace(temp_path, local_path)  # atomic within the same filesystem
    finally:
        # clean up a partial download; after a successful replace this is a no-op
        if os.path.exists(temp_path):
            os.remove(temp_path)
    return local_path
def load_frozen_patcher(filename, state_dict, strength):
    """Convert a flat LoRA state dict into patch tuples and register them.

    Keys look like '<model_key>::<patch_type>::<weight_index>'; weights for
    the same (model_key, patch_type) pair are gathered into a fixed 16-slot
    list positioned by weight_index.
    """
    patch_dict = {}
    for k, w in state_dict.items():
        model_key, patch_type, weight_index = k.split('::')
        if model_key not in patch_dict:
            patch_dict[model_key] = {}
        if patch_type not in patch_dict[model_key]:
            # 16 slots: presumably an upper bound on weights per patch — TODO confirm against producer
            patch_dict[model_key][patch_type] = [None] * 16
        patch_dict[model_key][patch_type][int(weight_index)] = w
    patch_flat = {}
    for model_key, v in patch_dict.items():
        for patch_type, weight_list in v.items():
            # NOTE: if one model_key carries several patch_types, later ones overwrite earlier
            patch_flat[model_key] = (patch_type, weight_list)
    # NOTE(review): add_patches below is defined with a `self` parameter, but is
    # called here unbound — this raises TypeError at runtime (likely the Space's
    # "Runtime error"). Verify the intended receiver (a patcher/model wrapper).
    add_patches(filename=filename, patches=patch_flat, strength_patch=float(strength), strength_model=1.0)
    return
def add_patches(self, *, filename, patches, strength_patch=1.0, strength_model=1.0, online_mode=False):
    """Register LoRA patch entries keyed by (filename, strengths, online_mode).

    Each entry of `patches` is keyed either by a plain parameter name or by a
    tuple (name, offset[, function]). Only keys matching an actual model
    parameter are kept. Returns the set of accepted (raw) patch keys.
    """
    identifier = (filename, strength_patch, strength_model, online_mode)
    valid_names = {name for name, _ in self.model.named_parameters()}
    accepted = set()
    collected = {}
    for raw_key in patches:
        offset = None
        function = None
        if isinstance(raw_key, str):
            key = raw_key
        else:
            # tuple form: (name, offset) or (name, offset, function)
            key, offset = raw_key[0], raw_key[1]
            if len(raw_key) > 2:
                function = raw_key[2]
        if key not in valid_names:
            continue
        accepted.add(raw_key)
        collected.setdefault(key, []).append(
            [strength_patch, patches[raw_key], strength_model, offset, function])
    self.lora_patches[identifier] = collected
    return accepted
| # class LoraLoader: | |
| # def __init__(self, model): | |
| # self.model = model | |
| # self.backup = {} | |
| # self.online_backup = [] | |
| # self.loaded_hash = str([]) | |
| # @torch.inference_mode() | |
| # def refresh(self, lora_patches, offload_device=torch.device('cpu'), force_refresh=False): | |
| # hashes = str(list(lora_patches.keys())) | |
| # if hashes == self.loaded_hash and not force_refresh: | |
| # return | |
| # # Merge Patches | |
| # all_patches = {} | |
| # for (_, _, _, online_mode), patches in lora_patches.items(): | |
| # for key, current_patches in patches.items(): | |
| # all_patches[(key, online_mode)] = all_patches.get((key, online_mode), []) + current_patches | |
| # # Initialize | |
| # memory_management.signal_empty_cache = True | |
| # parameter_devices = get_parameter_devices(self.model) | |
| # # Restore | |
| # for m in set(self.online_backup): | |
| # del m.forge_online_loras | |
| # self.online_backup = [] | |
| # for k, w in self.backup.items(): | |
| # if not isinstance(w, torch.nn.Parameter): | |
| # # In very few cases | |
| # w = torch.nn.Parameter(w, requires_grad=False) | |
| # utils.set_attr_raw(self.model, k, w) | |
| # self.backup = {} | |
| # set_parameter_devices(self.model, parameter_devices=parameter_devices) | |
| # # Patch | |
| # for (key, online_mode), current_patches in all_patches.items(): | |
| # try: | |
| # parent_layer, child_key, weight = utils.get_attr_with_parent(self.model, key) | |
| # assert isinstance(weight, torch.nn.Parameter) | |
| # except: | |
| # raise ValueError(f"Wrong LoRA Key: {key}") | |
| # if online_mode: | |
| # if not hasattr(parent_layer, 'forge_online_loras'): | |
| # parent_layer.forge_online_loras = {} | |
| # parent_layer.forge_online_loras[child_key] = current_patches | |
| # self.online_backup.append(parent_layer) | |
| # continue | |
| # if key not in self.backup: | |
| # self.backup[key] = weight.to(device=offload_device) | |
| # bnb_layer = None | |
| # if hasattr(weight, 'bnb_quantized') and operations.bnb_avaliable: | |
| # bnb_layer = parent_layer | |
| # from backend.operations_bnb import functional_dequantize_4bit | |
| # weight = functional_dequantize_4bit(weight) | |
| # gguf_cls = getattr(weight, 'gguf_cls', None) | |
| # gguf_parameter = None | |
| # if gguf_cls is not None: | |
| # gguf_parameter = weight | |
| # from backend.operations_gguf import dequantize_tensor | |
| # weight = dequantize_tensor(weight) | |
| # try: | |
| # weight = merge_lora_to_weight(current_patches, weight, key, computation_dtype=torch.float32) | |
| # except: | |
| # print('Patching LoRA weights out of memory. Retrying by offloading models.') | |
| # set_parameter_devices(self.model, parameter_devices={k: offload_device for k in parameter_devices.keys()}) | |
| # memory_management.soft_empty_cache() | |
| # weight = merge_lora_to_weight(current_patches, weight, key, computation_dtype=torch.float32) | |
| # if bnb_layer is not None: | |
| # bnb_layer.reload_weight(weight) | |
| # continue | |
| # if gguf_cls is not None: | |
| # gguf_cls.quantize_pytorch(weight, gguf_parameter) | |
| # continue | |
| # utils.set_attr_raw(self.model, key, torch.nn.Parameter(weight, requires_grad=False)) | |
| # # End | |
| # set_parameter_devices(self.model, parameter_devices=parameter_devices) | |
| # self.loaded_hash = hashes | |
| # return | |