Spaces:
Running
Running
| from copy import deepcopy | |
| from typing import Optional, Union | |
| import torch | |
| from torch import nn | |
| from einops import rearrange, repeat | |
| from einops.layers.torch import Rearrange | |
| from raw_vit import ViT, Attention, FeedForward | |
| from utils.dl.common.model import get_model_size, set_module | |
class KTakesAll(nn.Module):
    """Suppress the least-salient fraction ``k`` of channel scores.

    ``k`` is the sparsity ratio: the larger ``k`` is, the more channels are
    zeroed out (and the smaller the extracted surrogate model becomes).
    """

    def __init__(self, k):
        super(KTakesAll, self).__init__()
        self.k = k

    def forward(self, g: torch.Tensor):
        # How many channels (along dim 1) to suppress.
        num_suppressed = int(g.size(1) * self.k)
        # Indices of the `num_suppressed` smallest scores (= largest of -g).
        suppressed_idx = torch.topk(g.neg(), num_suppressed, dim=1).indices
        # Out-of-place scatter: zero the suppressed positions, keep the rest.
        return g.scatter(1, suppressed_idx, 0)
class Abs(nn.Module):
    """Element-wise absolute value, packaged as a module for ``nn.Sequential``."""

    def __init__(self):
        super(Abs, self).__init__()

    def forward(self, x):
        return torch.abs(x)
class SqueezeLast(nn.Module):
    """Remove the trailing dimension (if it has size 1), as a module."""

    def __init__(self):
        super(SqueezeLast, self).__init__()

    def forward(self, x):
        return torch.squeeze(x, dim=-1)
class Linear_WrappedWithFBS(nn.Module):
    """A linear layer augmented with a Feature Boosting and Suppression branch.

    A small side network (``self.fbs``) predicts a per-output-channel saliency
    from the input and sparsifies it with ``KTakesAll``; the wrapped linear's
    output is rescaled channel-wise by that saliency. The most recent saliency
    is cached so inference can later replay it instead of recomputing.
    """

    def __init__(self, linear: nn.Linear, r, k):
        super(Linear_WrappedWithFBS, self).__init__()
        self.linear = linear

        # Saliency predictor. Pooling collapses the token axis so the
        # bottleneck MLP sees one vector per sample:
        # for conv: (B, C_in, H, W) -> (B, C_in) -> (B, C_out)
        # for mlp in ViT: (B, #patches, D: dim of patches embedding) -> (B, D) -> (B, C_out)
        self.fbs = nn.Sequential(
            Rearrange('b n d -> b d n'),
            Abs(),
            nn.AdaptiveAvgPool1d(1),
            SqueezeLast(),
            nn.Linear(linear.in_features, linear.out_features // r),  # bottleneck by factor r
            nn.ReLU(),
            nn.Linear(linear.out_features // r, linear.out_features),
            nn.ReLU(),
            KTakesAll(k),  # zero the k fraction of least-salient channels
        )

        self.k = k
        self.cached_channel_attention = None  # (batch_size, out_features)
        self.use_cached_channel_attention = False

    def forward(self, x):
        replay = self.use_cached_channel_attention and self.cached_channel_attention is not None
        if replay:
            channel_attention = self.cached_channel_attention
        else:
            channel_attention = self.fbs(x)
            self.cached_channel_attention = channel_attention
        # Broadcast the (B, C_out) saliency over the token axis.
        return self.linear(x) * channel_attention.unsqueeze(1)
class ToQKV_WrappedWithFBS(nn.Module):
    """FBS-wrapped fused q/k/v projection of an attention block.

    This regards to_q/to_k/to_v as a whole (in fact it consists of multiple
    heads) and prunes it. Each third of the fused output (q, k, v) gets its
    own saliency predictor, so different channels of different heads may be
    pruned depending on the input. This is different from "removing some
    head" or "removing the same channels in each head".
    """

    def __init__(self, to_qkv: nn.Linear, r, k):
        super(ToQKV_WrappedWithFBS, self).__init__()
        self.to_qkv = to_qkv

        # One predictor per q/k/v chunk; each chunk is out_features // 3 wide.
        self.fbses = nn.ModuleList([
            nn.Sequential(
                Rearrange('b n d -> b d n'),
                Abs(),
                nn.AdaptiveAvgPool1d(1),
                SqueezeLast(),
                nn.Linear(to_qkv.in_features, to_qkv.out_features // 3 // r),
                nn.ReLU(),
                nn.Linear(to_qkv.out_features // 3 // r, to_qkv.out_features // 3),
                nn.ReLU(),
                KTakesAll(k),
            )
            for _ in range(3)
        ])

        self.k = k
        self.cached_channel_attention = None  # (batch, to_qkv.out_features)
        self.use_cached_channel_attention = False

    def forward(self, x):
        replay = self.use_cached_channel_attention and self.cached_channel_attention is not None
        if replay:
            channel_attention = self.cached_channel_attention
        else:
            # Concatenate the per-chunk saliencies back into fused qkv order.
            per_chunk = [fbs(x) for fbs in self.fbses]
            channel_attention = torch.cat(per_chunk, dim=1)
            self.cached_channel_attention = channel_attention
        return self.to_qkv(x) * channel_attention.unsqueeze(1)
def boost_raw_vit_by_fbs(raw_vit: ViT, r, k):
    """Return a deepcopy of ``raw_vit`` with FBS branches attached.

    Every transformer layer's fused qkv projection and first MLP linear are
    wrapped with saliency predictors (bottleneck ratio ``r``, sparsity ``k``).
    The input model itself is left untouched.
    """
    raw_vit = deepcopy(raw_vit)
    raw_vit_model_size = get_model_size(raw_vit, True)

    # NOTE(review): patch-embedding boosting is deliberately left disabled:
    # set_module(raw_vit.to_patch_embedding, '2', Linear_WrappedWithFBS(raw_vit.to_patch_embedding[2], r, k))

    for attn, ff in raw_vit.transformer.layers:
        # each entry wraps the real module behind `.fn` — unwrap first
        attn, ff = attn.fn, ff.fn
        set_module(attn, 'to_qkv', ToQKV_WrappedWithFBS(attn.to_qkv, r, k))
        set_module(ff.net, '0', Linear_WrappedWithFBS(ff.net[0], r, k))

    boosted_vit_model_size = get_model_size(raw_vit, True)
    print(f'boost_raw_vit_by_fbs() | model size from {raw_vit_model_size:.3f}MB to {boosted_vit_model_size:.3f}MB '
          f'(↑ {((boosted_vit_model_size - raw_vit_model_size) / raw_vit_model_size * 100):.2f}%)')
    return raw_vit
def set_boosted_vit_sparsity(boosted_vit: ViT, sparsity: float):
    """Push a (possibly adjusted) sparsity ratio into every FBS branch.

    The number of unpruned q/k/v channels must be divisible by the head
    count so that pruning behaves like shrinking dim_head of each head;
    when it is not, sparsity is tuned to the nearest ratio that satisfies
    this, and the tuned value is used for all layers.
    """
    for attn, ff in boosted_vit.transformer.layers:
        attn, ff = attn.fn, ff.fn
        q_features = attn.to_qkv.to_qkv.out_features // 3

        kept = q_features - int(q_features * sparsity)
        if kept % attn.heads != 0:
            # round the kept-channel count down to a multiple of num_heads
            tuned_sparsity = 1. - int(kept / attn.heads) * attn.heads / q_features
            print(f'set_boosted_vit_sparsity() | tune sparsity from {sparsity} to {tuned_sparsity}')
            sparsity = tuned_sparsity

        # write the ratio into the wrappers and into each KTakesAll (fbs[-1])
        attn.to_qkv.k = sparsity
        for fbs in attn.to_qkv.fbses:
            fbs[-1].k = sparsity
        ff.net[0].k = sparsity
        ff.net[0].fbs[-1].k = sparsity
def set_boosted_vit_inference_via_cached_channel_attentions(boosted_vit: ViT):
    """Switch every FBS branch to replay its cached channel attention.

    Requires that a forward pass has already populated the caches
    (asserted below), typically with one representative sample.
    """
    for attn, ff in boosted_vit.transformer.layers:
        qkv_wrapper = attn.fn.to_qkv
        mlp_wrapper = ff.fn.net[0]
        assert qkv_wrapper.cached_channel_attention is not None
        assert mlp_wrapper.cached_channel_attention is not None
        qkv_wrapper.use_cached_channel_attention = True
        mlp_wrapper.use_cached_channel_attention = True
def set_boosted_vit_dynamic_inference(boosted_vit: ViT):
    """Make every FBS branch recompute its channel attention per input."""
    for attn, ff in boosted_vit.transformer.layers:
        attn.fn.to_qkv.use_cached_channel_attention = False
        ff.fn.net[0].use_cached_channel_attention = False
class StaticFBS(nn.Module):
    """Rescale features by a fixed, input-independent channel attention.

    Used in the extracted surrogate model, where the attention produced by a
    representative sample is frozen. It is stored as a non-trainable
    ``nn.Parameter`` so it travels with ``state_dict`` and ``.to(device)``.
    """

    def __init__(self, static_channel_attention):
        super(StaticFBS, self).__init__()
        # expected shape: (1, dim)
        assert static_channel_attention.dim() == 2 and static_channel_attention.size(0) == 1
        self.static_channel_attention = nn.Parameter(static_channel_attention, requires_grad=False)

    def forward(self, x):
        # (B, N, dim) * (1, 1, dim): broadcast over batch and token axes
        scale = self.static_channel_attention.unsqueeze(1)
        return x * scale
def extract_surrogate_vit_via_cached_channel_attn(boosted_vit: ViT):
    """Materialize a smaller surrogate ViT from the cached channel attentions.

    For each transformer layer, output channels whose cached attention is
    exactly zero are physically removed from the fused qkv projection and the
    first MLP linear (as rows), and the matching input columns are removed
    from the downstream linears; the surviving attention values are baked in
    via a ``StaticFBS`` rescale so outputs match the boosted model.

    The input is deep-copied, so the boosted model is left intact. Requires
    that every wrapper's ``cached_channel_attention`` was populated by a
    forward pass on ONE representative sample (batch size 1).
    """
    boosted_vit = deepcopy(boosted_vit)
    raw_vit_model_size = get_model_size(boosted_vit, True)

    def get_unpruned_indexes_from_channel_attn(channel_attn: torch.Tensor, k):
        # NOTE(review): `k` is currently unused — survivors are simply the
        # channels with non-zero cached attention.
        assert channel_attn.size(0) == 1, 'use A representative sample to generate channel attentions'
        res = channel_attn[0].nonzero(as_tuple=True)[0]  # should be one-dim
        return res

    for attn, ff in boosted_vit.transformer.layers:
        attn = attn.fn
        ff_w_norm = ff
        ff = ff_w_norm.fn

        # prune to_qkv
        # The fused cached attention is three equal chunks (q | k | v); find
        # survivors per chunk, then re-offset them into the fused index space.
        to_qkv = attn.to_qkv
        to_q_unpruned_indexes = get_unpruned_indexes_from_channel_attn(
            to_qkv.cached_channel_attention[:, 0: to_qkv.cached_channel_attention.size(1) // 3],
            to_qkv.k
        )
        to_q_unpruned_indexes_w_offset = to_q_unpruned_indexes
        to_k_unpruned_indexes = get_unpruned_indexes_from_channel_attn(
            to_qkv.cached_channel_attention[:, to_qkv.cached_channel_attention.size(1) // 3: to_qkv.cached_channel_attention.size(1) // 3 * 2],
            to_qkv.k
        )
        to_k_unpruned_indexes_w_offset = to_k_unpruned_indexes + to_qkv.cached_channel_attention.size(1) // 3
        to_v_unpruned_indexes = get_unpruned_indexes_from_channel_attn(
            to_qkv.cached_channel_attention[:, to_qkv.cached_channel_attention.size(1) // 3 * 2: ],
            to_qkv.k
        )
        to_v_unpruned_indexes_w_offset = to_v_unpruned_indexes + to_qkv.cached_channel_attention.size(1) // 3 * 2
        # q/k/v must keep equal widths so the attention matmuls stay well-formed
        assert to_q_unpruned_indexes.size(0) == to_k_unpruned_indexes.size(0) == to_v_unpruned_indexes.size(0)
        to_qkv_unpruned_indexes = torch.cat([to_q_unpruned_indexes_w_offset, to_k_unpruned_indexes_w_offset, to_v_unpruned_indexes_w_offset])

        # new fused projection keeps only the surviving output rows
        new_to_qkv = nn.Linear(to_qkv.to_qkv.in_features, to_qkv_unpruned_indexes.size(0), to_qkv.to_qkv.bias is not None)
        new_to_qkv.weight.data.copy_(to_qkv.to_qkv.weight.data[to_qkv_unpruned_indexes])
        if to_qkv.to_qkv.bias is not None:
            new_to_qkv.bias.data.copy_(to_qkv.to_qkv.bias.data[to_qkv_unpruned_indexes])
        # bake the surviving attention values in as a static rescale
        set_module(attn, 'to_qkv', nn.Sequential(new_to_qkv, StaticFBS(to_qkv.cached_channel_attention[:, to_qkv_unpruned_indexes])))

        # prune to_out
        # only v's survivors feed to_out, so drop the matching input columns
        to_out = attn.to_out[0]
        new_to_out = nn.Linear(to_v_unpruned_indexes.size(0), to_out.out_features, to_out.bias is not None)
        new_to_out.weight.data.copy_(to_out.weight.data[:, to_v_unpruned_indexes])
        if to_out.bias is not None:
            new_to_out.bias.data.copy_(to_out.bias.data)
        # NOTE(review): this replaces the whole to_out container (indexed at
        # [0] above, so it appears to hold more than the linear) with the bare
        # linear — harmless in eval mode / dropout p=0, but confirm intended.
        set_module(attn, 'to_out', new_to_out)

        # prune the MLP: the first linear loses its pruned output rows ...
        ff_0 = ff.net[0]
        ff_0_unpruned_indexes = get_unpruned_indexes_from_channel_attn(ff_0.cached_channel_attention, ff_0.k)
        new_ff_0 = nn.Linear(ff_0.linear.in_features, ff_0_unpruned_indexes.size(0), ff_0.linear.bias is not None)
        new_ff_0.weight.data.copy_(ff_0.linear.weight.data[ff_0_unpruned_indexes])
        if ff_0.linear.bias is not None:
            new_ff_0.bias.data.copy_(ff_0.linear.bias.data[ff_0_unpruned_indexes])
        set_module(ff.net, '0', nn.Sequential(new_ff_0, StaticFBS(ff_0.cached_channel_attention[:, ff_0_unpruned_indexes])))

        # ... and the second linear (net[3]) loses the matching input columns
        ff_1 = ff.net[3]
        new_ff_1 = nn.Linear(ff_0_unpruned_indexes.size(0), ff_1.out_features, ff_1.bias is not None)
        new_ff_1.weight.data.copy_(ff_1.weight.data[:, ff_0_unpruned_indexes])
        if ff_1.bias is not None:
            new_ff_1.bias.data.copy_(ff_1.bias.data)
        set_module(ff.net, '3', new_ff_1)

    pruned_vit_model_size = get_model_size(boosted_vit, True)
    print(f'extract_surrogate_vit_via_cached_channel_attn() | model size from {raw_vit_model_size:.3f}MB to {pruned_vit_model_size:.3f}MB '
          f'({(pruned_vit_model_size / raw_vit_model_size * 100):.2f}%)')
    return boosted_vit
if __name__ == '__main__':
    from utils.dl.common.env import set_random_seed
    set_random_seed(1)

    def verify(vit, sparsity=0.8):
        """Smoke-test the boost -> sparsify -> extract pipeline on `vit`."""
        vit.eval()

        # 1. the raw model runs
        with torch.no_grad():
            sample = torch.rand((1, 3, 224, 224))
            print(vit(sample).size())

        # 2. attach FBS branches and set the target sparsity, then run once
        boosted_vit = boost_raw_vit_by_fbs(vit, r=32, k=sparsity)
        set_boosted_vit_sparsity(boosted_vit, sparsity)
        with torch.no_grad():
            sample = torch.rand((1, 3, 224, 224))
            print(boosted_vit(sample).size())

        # 3. run once more so the caches hold the attention for `sample`,
        #    extract the surrogate, and check it reproduces the boosted output
        # set_boosted_vit_inference_via_cached_channel_attentions(boosted_vit)
        sample = torch.rand((1, 3, 224, 224))
        boosted_vit.eval()
        with torch.no_grad():
            o1 = boosted_vit(sample)
        pruned_vit = extract_surrogate_vit_via_cached_channel_attn(boosted_vit)
        pruned_vit.eval()
        with torch.no_grad():
            o2 = pruned_vit(sample)
        print('output diff (should be tiny): ', ((o1 - o2) ** 2).sum())

    # alternative ViT-B/16 configuration (disabled):
    # vit_b_16 = ViT(
    #     image_size = 224,
    #     patch_size = 16,
    #     num_classes = 1000,
    #     dim = 768,       # encoder layer/attention input/output size (Hidden Size D in the paper)
    #     depth = 12,
    #     heads = 12,      # (Heads in the paper)
    #     dim_head = 64,   # attention hidden size (seems be default, never change this)
    #     mlp_dim = 3072,  # mlp layer hidden size (MLP size in the paper)
    #     dropout = 0.,
    #     emb_dropout = 0.
    # )
    # verify(vit_b_16)

    vit_l_16 = ViT(
        image_size = 224,
        patch_size = 16,
        num_classes = 1000,
        dim = 1024,      # encoder layer/attention input/output size (Hidden Size D in the paper)
        depth = 24,
        heads = 16,      # (Heads in the paper)
        dim_head = 64,   # attention hidden size (seems be default, never change this)
        mlp_dim = 4096,  # mlp layer hidden size (MLP size in the paper)
        dropout = 0.,
        emb_dropout = 0.
    )
    verify(vit_l_16, 0.98)

    # alternative ViT-H/16 configuration (disabled):
    # vit_h_16 = ViT(
    #     image_size = 224,
    #     patch_size = 16,
    #     num_classes = 1000,
    #     dim = 1280,      # encoder layer/attention input/output size (Hidden Size D in the paper)
    #     depth = 32,
    #     heads = 16,      # (Heads in the paper)
    #     dim_head = 64,   # attention hidden size (seems be default, never change this)
    #     mlp_dim = 5120,  # mlp layer hidden size (MLP size in the paper)
    #     dropout = 0.,
    #     emb_dropout = 0.
    # )
    # verify(vit_h_16)