import os
from copy import deepcopy
from typing import Optional

import torch
from diffusers.models.activations import GEGLU, GELU

from cross_attn_hook import CrossAttentionExtractionHook
from ffn_hooker import FeedForwardHooker
from norm_attn_hook import NormHooker


class SkipConnection(torch.nn.Module):
    """Identity stand-in for a block whose mask prunes every unit."""

    def forward(self, *args, **kwargs):
        # Return the first positional input unchanged.
        return args[0]


def calculate_mask_sparsity(hooker, threshold: Optional[float] = None):
    """Return (total lambdas, active lambdas, active ratio) for a hooker's masks."""
    total_num_lambs = 0
    num_activate_lambs = 0
    binary = getattr(hooker, "binary", None)
    for lamb in hooker.lambs:
        total_num_lambs += lamb.size(0)
        if binary:
            assert threshold is None, "threshold should be None for binary mask"
            num_activate_lambs += lamb.sum().item()
        else:
            assert threshold is not None, "threshold must be provided for non-binary mask"
            num_activate_lambs += (lamb >= threshold).sum().item()
    return total_num_lambs, num_activate_lambs, num_activate_lambs / total_num_lambs


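# Usage sketch (hypothetical; assumes `ff_hooker` was built with
# FeedForwardHooker(..., binary=True) and its lambdas are materialized):
#
#   total, active, ratio = calculate_mask_sparsity(ff_hooker)
#   print(f"{active}/{total} lambdas active ({ratio:.2%})")

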
def create_pipeline(
    pipe,
    model_id,
    device,
    torch_dtype,
    save_pt=None,
    lambda_threshold: float = 1.0,
    binary=True,
    epsilon=0.0,
    masking="binary",
    attn_name="attn",
    return_hooker=False,
    scope=None,
    ratio=None,
):
    """
    Create the pipeline and optionally load the saved masks.
    """
    pipe.to(device)
    pipe.vae.requires_grad_(False)
    if hasattr(pipe, "unet"):
        pipe.unet.requires_grad_(False)
    else:
        pipe.transformer.requires_grad_(False)

    hookers = []  # stays empty when no mask checkpoint is given
    if save_pt and ("ff.pt" in save_pt or "attn.pt" in save_pt):
        save_pts = get_save_pts(save_pt)

        cross_attn_hooker = CrossAttentionExtractionHook(
            pipe,
            model_name=model_id,
            regex=".*",
            dtype=torch_dtype,
            head_num_filter=1,
            masking=masking,
            dst=save_pts["attn"],
            epsilon=epsilon,
            attn_name=attn_name,
            binary=binary,
        )
        cross_attn_hooker.add_hooks(init_value=1)

        ff_hooker = FeedForwardHooker(
            pipe,
            regex=".*",
            dtype=torch_dtype,
            masking=masking,
            dst=save_pts["ff"],
            epsilon=epsilon,
            binary=binary,
        )
        ff_hooker.add_hooks(init_value=1)

        if os.path.exists(save_pts["norm"]):
            norm_hooker = NormHooker(
                pipe,
                regex=".*",
                dtype=torch_dtype,
                masking=masking,
                dst=save_pts["norm"],
                epsilon=epsilon,
                binary=binary,
            )
            norm_hooker.add_hooks(init_value=1)
        else:
            norm_hooker = None

        # Dummy forward pass so the hooks register their lambdas before loading.
        _ = pipe("abc", num_inference_steps=1)
        cross_attn_hooker.load(device=device, threshold=lambda_threshold)
        ff_hooker.load(device=device, threshold=lambda_threshold)
        if norm_hooker:
            norm_hooker.load(device=device, threshold=lambda_threshold)

        if scope in ("local", "global"):
            if isinstance(ratio, float):
                attn_hooker_ratio = ratio
                ff_hooker_ratio = ratio
            else:
                attn_hooker_ratio, ff_hooker_ratio = ratio[0], ratio[1]

            if norm_hooker:
                if isinstance(ratio, float) or len(ratio) < 3:
                    raise ValueError("Need to provide a ratio for the norm layer")
                norm_hooker_ratio = ratio[2]

            cross_attn_hooker.binarize(scope, attn_hooker_ratio)
            ff_hooker.binarize(scope, ff_hooker_ratio)
            if norm_hooker:
                norm_hooker.binarize(scope, norm_hooker_ratio)

        hookers = [cross_attn_hooker, ff_hooker]
        if norm_hooker:
            hookers.append(norm_hooker)

    if return_hooker:
        return pipe, hookers
    return pipe


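# Usage sketch (hypothetical paths and model id; assumes a diffusers
# StableDiffusionPipeline and a saved mask checkpoint ending in "ff.pt"):
#
#   pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
#   pipe, hookers = create_pipeline(
#       pipe, model_id, "cuda", torch.float16,
#       save_pt="masks/ff.pt", binary=True, return_hooker=True,
#   )

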
def linear_layer_pruning(module, lamb):
    heads_to_keep = torch.nonzero(lamb).squeeze()
    if heads_to_keep.dim() == 0:
        # Only one head survives; keep a 1-D index tensor.
        heads_to_keep = heads_to_keep.unsqueeze(0)

    modules_to_remove = [module.to_k, module.to_q, module.to_v]
    new_heads = int(lamb.sum().item())

    if new_heads == 0:
        return SkipConnection()

    for module_to_remove in modules_to_remove:
        # Per-head width of the projection output.
        inner_dim = module_to_remove.out_features // module.heads
        rows_to_keep = torch.zeros(
            module_to_remove.out_features,
            dtype=torch.bool,
            device=module_to_remove.weight.device,
        )
        for idx in heads_to_keep:
            rows_to_keep[idx * inner_dim : (idx + 1) * inner_dim] = True

        module_to_remove.weight.data = module_to_remove.weight.data[rows_to_keep, :]
        if module_to_remove.bias is not None:
            module_to_remove.bias.data = module_to_remove.bias.data[rows_to_keep]
        module_to_remove.out_features = int(rows_to_keep.sum().item())

    # Prune the matching input columns of the output projection.
    if getattr(module, "to_out", None) is not None:
        module.to_out[0].weight.data = module.to_out[0].weight.data[:, rows_to_keep]
        module.to_out[0].in_features = int(rows_to_keep.sum().item())

    # Update the bookkeeping attributes to reflect the remaining heads.
    module.inner_dim = module.inner_dim // module.heads * new_heads
    try:
        module.query_dim = module.query_dim // module.heads * new_heads
        module.inner_kv_dim = module.inner_kv_dim // module.heads * new_heads
    except AttributeError:
        pass
    module.cross_attention_dim = module.cross_attention_dim // module.heads * new_heads
    module.heads = new_heads
    return module


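# Usage sketch (hypothetical; `attn` is a diffusers Attention module and
# `lamb` a 0/1 head mask of length attn.heads):
#
#   lamb = torch.tensor([1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0])
#   attn = linear_layer_pruning(attn, lamb)  # keeps 6 of 8 heads

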
def ffn_linear_layer_pruning(module, lamb):
    lambda_to_keep = torch.nonzero(lamb).squeeze()
    if lambda_to_keep.dim() == 0:
        # Only one hidden unit survives; keep a 1-D index tensor.
        lambda_to_keep = lambda_to_keep.unsqueeze(0)
    if len(lambda_to_keep) == 0:
        return SkipConnection()

    num_lambda = len(lambda_to_keep)

    if isinstance(module.net[0], GELU):
        # Keep only the selected rows of the up-projection.
        module.net[0].proj.weight.data = module.net[0].proj.weight.data[
            lambda_to_keep, :
        ]
        module.net[0].proj.out_features = num_lambda
        if module.net[0].proj.bias is not None:
            module.net[0].proj.bias.data = module.net[0].proj.bias.data[lambda_to_keep]

        update_act = GELU(module.net[0].proj.in_features, num_lambda)
        update_act.proj = module.net[0].proj
        module.net[0] = update_act
    elif isinstance(module.net[0], GEGLU):
        # GEGLU stacks the hidden and gate projections; prune both halves.
        output_feature = module.net[0].proj.out_features
        module.net[0].proj.weight.data = torch.cat(
            [
                module.net[0].proj.weight.data[: output_feature // 2, :][
                    lambda_to_keep, :
                ],
                module.net[0].proj.weight.data[output_feature // 2 :][
                    lambda_to_keep, :
                ],
            ],
            dim=0,
        )
        module.net[0].proj.out_features = num_lambda * 2
        if module.net[0].proj.bias is not None:
            module.net[0].proj.bias.data = torch.cat(
                [
                    module.net[0].proj.bias.data[: output_feature // 2][lambda_to_keep],
                    module.net[0].proj.bias.data[output_feature // 2 :][lambda_to_keep],
                ]
            )

        update_act = GEGLU(module.net[0].proj.in_features, num_lambda * 2)
        update_act.proj = module.net[0].proj
        module.net[0] = update_act

    # Prune the matching input columns of the down-projection.
    module.net[2].weight.data = module.net[2].weight.data[:, lambda_to_keep]
    module.net[2].in_features = num_lambda

    return module


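# Usage sketch (hypothetical; `ff` is a diffusers FeedForward block whose
# net[0] is GELU or GEGLU and `lamb` masks its hidden units):
#
#   lamb = (torch.rand(ff.net[2].in_features) > 0.5).float()
#   ff = ffn_linear_layer_pruning(ff, lamb)

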
class SparsityLinear(torch.nn.Module):
    """Linear layer that scatters its outputs into a zero-padded full-width
    tensor, so downstream shapes are unchanged after pruning."""

    def __init__(self, in_features, out_features, lambda_to_keep, num_lambda):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, num_lambda)
        self.out_features = out_features
        self.lambda_to_keep = lambda_to_keep

    def forward(self, x):
        x = self.linear(x)
        # Scatter the kept channels back into a full-width zero tensor.
        output = torch.zeros(
            x.size(0), self.out_features, device=x.device, dtype=x.dtype
        )
        output[:, self.lambda_to_keep] = x
        return output


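# Shape sketch (hypothetical sizes): SparsityLinear(64, 128, idx, len(idx))
# maps (batch, 64) -> (batch, 128), zero everywhere except the kept positions:
#
#   idx = torch.tensor([0, 3, 7])
#   layer = SparsityLinear(64, 128, idx, 3)
#   out = layer(torch.randn(2, 64))  # out[:, idx] nonzero, rest zeros

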
def norm_layer_pruning(module, lamb):
    """
    Prune the layer-normalization projection of the FLUX model.
    """
    lambda_to_keep = torch.nonzero(lamb).squeeze()
    if lambda_to_keep.dim() == 0:
        # Only one channel survives; keep a 1-D index tensor.
        lambda_to_keep = lambda_to_keep.unsqueeze(0)
    if len(lambda_to_keep) == 0:
        return SkipConnection()

    num_lambda = len(lambda_to_keep)

    # Swap the dense projection for a sparse one that keeps the original
    # output width.
    in_features = module.linear.in_features
    out_features = module.linear.out_features

    linear = SparsityLinear(in_features, out_features, lambda_to_keep, num_lambda)
    linear.linear.weight.data = module.linear.weight.data[lambda_to_keep]
    linear.linear.bias.data = module.linear.bias.data[lambda_to_keep]
    module.linear = linear
    return module


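# Usage sketch (hypothetical; `norm` is a FLUX AdaLayerNorm-style module
# exposing a `.linear` projection, and `lamb` masks its output channels):
#
#   lamb = (torch.rand(norm.linear.out_features) > 0.5).float()
#   norm = norm_layer_pruning(norm, lamb)

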
def get_save_pts(save_pt):
    if "ff.pt" in save_pt:
        ff_save_pt = deepcopy(save_pt)
        attn_save_pt = save_pt.split(os.sep)
        attn_save_pt[-1] = attn_save_pt[-1].replace("ff", "attn")
        attn_save_pt_output = os.sep.join(attn_save_pt)
        attn_save_pt[-1] = attn_save_pt[-1].replace("attn", "norm")
        norm_save_pt = os.sep.join(attn_save_pt)

        return {
            "ff": ff_save_pt,
            "attn": attn_save_pt_output,
            "norm": norm_save_pt,
        }
    else:
        attn_save_pt = deepcopy(save_pt)
        ff_save_pt = save_pt.split(os.sep)
        ff_save_pt[-1] = ff_save_pt[-1].replace("attn", "ff")
        ff_save_pt_output = os.sep.join(ff_save_pt)
        ff_save_pt[-1] = ff_save_pt[-1].replace("ff", "norm")
        # Derive the norm path from the mutated ff path parts.
        norm_save_pt = os.sep.join(ff_save_pt)

        return {
            "ff": ff_save_pt_output,
            "attn": attn_save_pt,
            "norm": norm_save_pt,
        }


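# Path sketch ("runs/exp1" is a hypothetical directory):
#
#   get_save_pts("runs/exp1/ff.pt") returns
#   {"ff": "runs/exp1/ff.pt",
#    "attn": "runs/exp1/attn.pt",
#    "norm": "runs/exp1/norm.pt"}

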
def save_img(pipe, g_cpu, steps, prompt, save_path):
    result = pipe(prompt, generator=g_cpu, num_inference_steps=steps)
    result["images"][0].save(save_path)