from typing import Dict, Tuple, Union

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.linalg as linalg

from tqdm import tqdm

def make_sparse(t: torch.Tensor, sparsity=0.95):
    """Zero out all but the largest-magnitude (1 - sparsity) fraction of entries."""
    abs_t = torch.abs(t)
    np_array = abs_t.detach().cpu().numpy()
    quan = float(np.quantile(np_array, sparsity))
    sparse_t = t.masked_fill(abs_t < quan, 0)
    return sparse_t

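# Illustrative usage (not part of the original module): with sparsity=0.95 only
# the ~5% largest-magnitude entries survive; everything below that quantile of
# |t| is zeroed.
#
#     t = torch.randn(4, 4)
#     s = make_sparse(t, sparsity=0.75)     # keeps roughly the top 25% by |value|
#     print((s != 0).float().mean())        # ~0.25
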
def extract_conv(
    weight: Union[torch.Tensor, nn.Parameter],
    mode='fixed',
    mode_param=0,
    device='cpu',
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """SVD-decompose a conv weight into a (down, up) pair plus the residual."""
    weight = weight.to(device)
    out_ch, in_ch, kernel_size, _ = weight.shape

    U, S, Vh = linalg.svd(weight.reshape(out_ch, -1))

    # Pick the LoRA rank according to the requested mode.
    if mode == 'fixed':
        lora_rank = mode_param
    elif mode == 'threshold':
        assert mode_param >= 0
        lora_rank = torch.sum(S > mode_param)
    elif mode == 'ratio':
        assert 1 >= mode_param >= 0
        min_s = torch.max(S) * mode_param
        lora_rank = torch.sum(S > min_s)
    elif mode == 'quantile' or mode == 'percentile':
        assert 1 >= mode_param >= 0
        s_cum = torch.cumsum(S, dim=0)
        min_cum_sum = mode_param * torch.sum(S)
        lora_rank = torch.sum(s_cum < min_cum_sum)
    else:
        raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"')
    lora_rank = max(1, int(lora_rank))
    lora_rank = min(out_ch, in_ch, lora_rank)

    # Truncate to the chosen rank and fold the singular values into U.
    U = U[:, :lora_rank]
    S = S[:lora_rank]
    U = U @ torch.diag(S)
    Vh = Vh[:lora_rank, :]

    diff = (weight - (U @ Vh).reshape(out_ch, in_ch, kernel_size, kernel_size)).detach()
    extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).detach()
    extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).detach()
    del U, S, Vh, weight
    return extract_weight_A, extract_weight_B, diff

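# Sketch of how the four rank-selection modes above behave, on the illustrative
# (made-up) singular values S = [4.0, 2.0, 1.0, 0.1]:
#   mode='fixed',     mode_param=2    -> rank 2 (always exactly mode_param)
#   mode='threshold', mode_param=0.5  -> rank 3 (entries with S > 0.5)
#   mode='ratio',     mode_param=0.3  -> rank 2 (entries with S > 0.3 * max(S) = 1.2)
#   mode='quantile',  mode_param=0.99 -> rank 3 (cumsum(S) < 0.99 * sum(S) = 7.029)
# The result is then clamped to the range [1, min(out_ch, in_ch)].
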
def merge_conv(
    weight_a: Union[torch.Tensor, nn.Parameter],
    weight_b: Union[torch.Tensor, nn.Parameter],
    device='cpu',
):
    """Recompose a conv weight from its (down, up) low-rank factors."""
    rank, in_ch, kernel_size, k_ = weight_a.shape
    out_ch, rank_, _, _ = weight_b.shape
    assert rank == rank_ and kernel_size == k_

    wa = weight_a.to(device)
    wb = weight_b.to(device)
    if device == 'cpu':
        wa = wa.float()
        wb = wb.float()

    merged = wb.reshape(out_ch, -1) @ wa.reshape(rank, -1)
    weight = merged.reshape(out_ch, in_ch, kernel_size, kernel_size)
    del wb, wa
    return weight

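# Round-trip sketch (illustrative; the shapes are arbitrary assumptions):
# merge_conv recomposes exactly the truncated-SVD product that extract_conv
# factored out, so adding back diff recovers the original weight.
#
#     w = torch.randn(8, 4, 3, 3)
#     a, b, diff = extract_conv(w, mode='fixed', mode_param=2)
#     assert torch.allclose(merge_conv(a, b) + diff, w, atol=1e-5)
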
def extract_linear(
    weight: Union[torch.Tensor, nn.Parameter],
    mode='fixed',
    mode_param=0,
    device='cpu',
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """SVD-decompose a linear weight into a (down, up) pair plus the residual."""
    weight = weight.to(device)
    out_ch, in_ch = weight.shape

    U, S, Vh = linalg.svd(weight)

    # Pick the LoRA rank according to the requested mode (same logic as extract_conv).
    if mode == 'fixed':
        lora_rank = mode_param
    elif mode == 'threshold':
        assert mode_param >= 0
        lora_rank = torch.sum(S > mode_param)
    elif mode == 'ratio':
        assert 1 >= mode_param >= 0
        min_s = torch.max(S) * mode_param
        lora_rank = torch.sum(S > min_s)
    elif mode == 'quantile' or mode == 'percentile':
        assert 1 >= mode_param >= 0
        s_cum = torch.cumsum(S, dim=0)
        min_cum_sum = mode_param * torch.sum(S)
        lora_rank = torch.sum(s_cum < min_cum_sum)
    else:
        raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"')
    lora_rank = max(1, int(lora_rank))
    lora_rank = min(out_ch, in_ch, lora_rank)

    # Truncate to the chosen rank and fold the singular values into U.
    U = U[:, :lora_rank]
    S = S[:lora_rank]
    U = U @ torch.diag(S)
    Vh = Vh[:lora_rank, :]

    diff = (weight - U @ Vh).detach()
    extract_weight_A = Vh.reshape(lora_rank, in_ch).detach()
    extract_weight_B = U.reshape(out_ch, lora_rank).detach()
    del U, S, Vh, weight
    return extract_weight_A, extract_weight_B, diff

def merge_linear(
    weight_a: Union[torch.Tensor, nn.Parameter],
    weight_b: Union[torch.Tensor, nn.Parameter],
    device='cpu',
):
    """Recompose a linear weight from its (down, up) low-rank factors."""
    rank, in_ch = weight_a.shape
    out_ch, rank_ = weight_b.shape
    assert rank == rank_

    wa = weight_a.to(device)
    wb = weight_b.to(device)
    if device == 'cpu':
        wa = wa.float()
        wb = wb.float()

    weight = wb @ wa
    del wb, wa
    return weight

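# The linear pair behaves the same way (illustrative sketch): lora_up @ lora_down
# is the rank-r approximation of the weight and diff holds the SVD residual.
#
#     w = torch.randn(16, 32)
#     a, b, diff = extract_linear(w, mode='fixed', mode_param=4)
#     assert torch.allclose(merge_linear(a, b) + diff, w, atol=1e-5)
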
def extract_diff(
    base_model,
    db_model,
    mode='fixed',
    linear_mode_param=0,
    conv_mode_param=0,
    extract_device='cpu',
    use_bias=False,
    sparsity=0.98,
    small_conv=True,
):
    """Extract a LoCon state dict from the weight difference db_model - base_model."""
    UNET_TARGET_REPLACE_MODULE = [
        "Transformer2DModel",
        "Attention",
        "ResnetBlock2D",
        "Downsample2D",
        "Upsample2D",
    ]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    LORA_PREFIX_UNET = 'lora_unet'
    LORA_PREFIX_TEXT_ENCODER = 'lora_te'

    def make_state_dict(
        prefix,
        root_module: torch.nn.Module,
        target_module: torch.nn.Module,
        target_replace_modules,
    ):
        loras = {}
        temp = {}

        # Collect the base weights of every Linear/Conv2d child of the target blocks.
        for name, module in root_module.named_modules():
            if module.__class__.__name__ in target_replace_modules:
                temp[name] = {}
                for child_name, child_module in module.named_modules():
                    if child_module.__class__.__name__ not in {'Linear', 'Conv2d'}:
                        continue
                    temp[name][child_name] = child_module.weight

        # Walk the tuned model and factorize each weight difference.
        for name, module in tqdm(list(target_module.named_modules())):
            if name in temp:
                weights = temp[name]
                for child_name, child_module in module.named_modules():
                    lora_name = prefix + '.' + name + '.' + child_name
                    lora_name = lora_name.replace('.', '_')
                    layer = child_module.__class__.__name__

                    if layer == 'Linear':
                        extract_a, extract_b, diff = extract_linear(
                            (child_module.weight - weights[child_name]),
                            mode,
                            linear_mode_param,
                            device=extract_device,
                        )
                    elif layer == 'Conv2d':
                        is_linear = (child_module.weight.shape[2] == 1
                                     and child_module.weight.shape[3] == 1)
                        extract_a, extract_b, diff = extract_conv(
                            (child_module.weight - weights[child_name]),
                            mode,
                            linear_mode_param if is_linear else conv_mode_param,
                            device=extract_device,
                        )
                        if small_conv and not is_linear:
                            # Split the k×k down-weight once more into a 1×1 down
                            # plus a k×k mid block (CP-style decomposition).
                            dim = extract_a.size(0)
                            extract_c, extract_a, _ = extract_conv(
                                extract_a.transpose(0, 1),
                                'fixed', dim,
                                extract_device,
                            )
                            extract_a = extract_a.transpose(0, 1)
                            extract_c = extract_c.transpose(0, 1)
                            loras[f'{lora_name}.lora_mid.weight'] = extract_c.detach().cpu().contiguous().half()
                            # Recompute the residual against the full down-mid-up product.
                            diff = child_module.weight - torch.einsum(
                                'i j k l, j r, p i -> p r k l',
                                extract_c, extract_a.flatten(1, -1), extract_b.flatten(1, -1)
                            ).detach().cpu().contiguous()
                            del extract_c
                    else:
                        continue

                    loras[f'{lora_name}.lora_down.weight'] = extract_a.detach().cpu().contiguous().half()
                    loras[f'{lora_name}.lora_up.weight'] = extract_b.detach().cpu().contiguous().half()
                    loras[f'{lora_name}.alpha'] = torch.Tensor([extract_a.shape[0]]).half()

                    if use_bias:
                        # Store the sparsified residual as a COO "bias" term.
                        diff = diff.detach().cpu().reshape(extract_b.size(0), -1)
                        sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce()
                        indices = sparse_diff.indices().to(torch.int16)
                        values = sparse_diff.values().half()
                        loras[f'{lora_name}.bias_indices'] = indices
                        loras[f'{lora_name}.bias_values'] = values
                        loras[f'{lora_name}.bias_size'] = torch.tensor(diff.shape).to(torch.int16)
                    del extract_a, extract_b, diff
        return loras

    text_encoder_loras = make_state_dict(
        LORA_PREFIX_TEXT_ENCODER,
        base_model[0], db_model[0],
        TEXT_ENCODER_TARGET_REPLACE_MODULE
    )
    unet_loras = make_state_dict(
        LORA_PREFIX_UNET,
        base_model[2], db_model[2],
        UNET_TARGET_REPLACE_MODULE
    )
    print(len(text_encoder_loras), len(unet_loras))
    return text_encoder_loras | unet_loras

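# Hypothetical usage sketch. The 3-tuple layout (index 0 = text encoder,
# index 2 = UNet) is an assumption inferred from the base_model[0] /
# base_model[2] accesses above; building the tuples is up to the caller.
#
#     base = (base_text_encoder, base_vae, base_unet)
#     db = (tuned_text_encoder, tuned_vae, tuned_unet)
#     state_dict = extract_diff(
#         base, db,
#         mode='fixed', linear_mode_param=64, conv_mode_param=32,
#         extract_device='cuda',
#     )
#     # from safetensors.torch import save_file
#     # save_file(state_dict, 'extracted_locon.safetensors')
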
def merge_locon(
    base_model,
    locon_state_dict: Dict[str, torch.Tensor],
    scale: float = 1.0,
    device='cpu',
):
    """Merge a LoCon state dict into base_model's weights in place."""
    UNET_TARGET_REPLACE_MODULE = [
        "Transformer2DModel",
        "Attention",
        "ResnetBlock2D",
        "Downsample2D",
        "Upsample2D",
    ]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    LORA_PREFIX_UNET = 'lora_unet'
    LORA_PREFIX_TEXT_ENCODER = 'lora_te'

    def merge(
        prefix,
        root_module: torch.nn.Module,
        target_replace_modules,
    ):
        for name, module in tqdm(list(root_module.named_modules())):
            if module.__class__.__name__ in target_replace_modules:
                for child_name, child_module in module.named_modules():
                    layer = child_module.__class__.__name__
                    if layer not in {'Linear', 'Conv2d'}:
                        continue
                    lora_name = prefix + '.' + name + '.' + child_name
                    lora_name = lora_name.replace('.', '_')

                    down = locon_state_dict[f'{lora_name}.lora_down.weight'].float()
                    up = locon_state_dict[f'{lora_name}.lora_up.weight'].float()
                    alpha = locon_state_dict[f'{lora_name}.alpha'].float()
                    rank = down.shape[0]

                    if layer == 'Conv2d':
                        delta = merge_conv(down, up, device)
                    else:  # Linear
                        delta = merge_linear(down, up, device)
                    child_module.weight.requires_grad_(False)
                    # W += (alpha / rank) * scale * (up @ down)
                    child_module.weight += (alpha.to(device) / rank * scale * delta).cpu()
                    del delta

    merge(
        LORA_PREFIX_TEXT_ENCODER,
        base_model[0],
        TEXT_ENCODER_TARGET_REPLACE_MODULE
    )
    merge(
        LORA_PREFIX_UNET,
        base_model[2],
        UNET_TARGET_REPLACE_MODULE
    )

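# Hypothetical usage sketch (same 3-tuple model layout as extract_diff; the
# state dict can come from extract_diff or a LoCon trainer):
#
#     # from safetensors.torch import load_file
#     # locon_sd = load_file('extracted_locon.safetensors')
#     merge_locon(base, locon_sd, scale=0.8, device='cuda')
#     # base's text encoder and UNet weights are now updated in place.
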
def merge_loha(
    base_model,
    loha_state_dict: Dict[str, torch.Tensor],
    scale: float = 1.0,
    device='cpu',
):
    """Merge a LoHa (Hadamard-product) state dict into base_model's weights in place."""
    UNET_TARGET_REPLACE_MODULE = [
        "Transformer2DModel",
        "Attention",
        "ResnetBlock2D",
        "Downsample2D",
        "Upsample2D",
    ]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    LORA_PREFIX_UNET = 'lora_unet'
    LORA_PREFIX_TEXT_ENCODER = 'lora_te'

    def merge(
        prefix,
        root_module: torch.nn.Module,
        target_replace_modules,
    ):
        for name, module in tqdm(list(root_module.named_modules())):
            if module.__class__.__name__ in target_replace_modules:
                for child_name, child_module in module.named_modules():
                    layer = child_module.__class__.__name__
                    if layer not in {'Linear', 'Conv2d'}:
                        continue
                    lora_name = prefix + '.' + name + '.' + child_name
                    lora_name = lora_name.replace('.', '_')

                    w1a = loha_state_dict[f'{lora_name}.hada_w1_a'].float().to(device)
                    w1b = loha_state_dict[f'{lora_name}.hada_w1_b'].float().to(device)
                    w2a = loha_state_dict[f'{lora_name}.hada_w2_a'].float().to(device)
                    w2b = loha_state_dict[f'{lora_name}.hada_w2_b'].float().to(device)
                    alpha = loha_state_dict[f'{lora_name}.alpha'].float().to(device)
                    dim = w1b.shape[0]

                    # LoHa delta: elementwise product of two low-rank factors.
                    delta = (w1a @ w1b) * (w2a @ w2b)
                    delta = delta.reshape(child_module.weight.shape)

                    child_module.weight.requires_grad_(False)
                    child_module.weight += (alpha / dim * scale * delta).cpu()
                    del delta

    merge(
        LORA_PREFIX_TEXT_ENCODER,
        base_model[0],
        TEXT_ENCODER_TARGET_REPLACE_MODULE
    )
    merge(
        LORA_PREFIX_UNET,
        base_model[2],
        UNET_TARGET_REPLACE_MODULE
    )

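# Hypothetical usage sketch, analogous to merge_locon: applies the Hadamard
# update W += alpha/dim * scale * (w1a @ w1b) * (w2a @ w2b) in place.
#
#     # loha_sd = load_file('trained_loha.safetensors')
#     merge_loha(base, loha_sd, scale=1.0, device='cuda')
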