| from functools import partial |
| import torch |
| import torch.nn as nn |
| import torchinfo |
| from timm.models.layers import to_2tuple, trunc_normal_, DropPath |
| from timm.models.registry import register_model |
| from timm.models.vision_transformer import _cfg |
| from einops.layers.torch import Rearrange |
| import torch.nn.functional as F |
| from timm.models.vision_transformer import PatchEmbed, Block |
|
|
| from spikingjelly.clock_driven import layer |
| import copy |
| from torchvision import transforms |
| import matplotlib.pyplot as plt |
|
|
| import models.encoder as encoder |
| from .util.pos_embed import get_2d_sincos_pos_embed |
|
|
| import torch |
|
|
| |
# Default number of discrete spike levels: used as the default clamp ceiling in
# `multispike.forward` and as the default divisor in `Multispike`.
T = 4
|
|
|
|
class multispike(torch.autograd.Function):
    """Multi-level spike activation with a rectangular surrogate gradient.

    Forward clamps the input to ``[0, lens]`` and rounds it half-up to an
    integer level. Backward passes the incoming gradient through unchanged
    where the input lies strictly inside ``(0, lens)`` and zeroes it elsewhere.
    """

    @staticmethod
    def forward(ctx, input, lens=T):
        # The raw input is needed in backward to rebuild the pass-through window.
        ctx.save_for_backward(input)
        ctx.lens = lens
        clipped = torch.clamp(input, 0, lens)
        # floor(x + 0.5) rounds half-up to the nearest integer level.
        return torch.floor(clipped + 0.5)

    @staticmethod
    def backward(ctx, grad_output):
        (saved,) = ctx.saved_tensors
        inside_window = (saved > 0) & (saved < ctx.lens)
        # Second return value is the (non-existent) gradient for `lens`.
        return grad_output.clone() * inside_window.float(), None
|
|
|
|
class Multispike(nn.Module):
    """Module wrapper around a multi-level spike function.

    Args:
        spike: ``torch.autograd.Function`` subclass whose ``apply`` takes
            ``(inputs, lens)``; defaults to ``multispike``.
        norm: Number of spike levels. It is used both as the clamp ceiling
            handed to ``spike`` and as the divisor applied to its output, so
            with the default spike function the result lies in ``[0, 1]``.
    """

    def __init__(self, spike=multispike, norm=T):
        super().__init__()
        self.lens = norm
        self.spike = spike
        self.norm = norm

    def forward(self, inputs):
        # Pass the configured level count explicitly. Previously `lens` was
        # stored but never forwarded, so a non-default `norm` divided by
        # `norm` while the clamp still used the module-level default.
        return self.spike.apply(inputs, self.lens) / self.norm
|
|
|
|
|
|
|
|
def MS_conv_unit(in_channels, out_channels, kernel_size=1, padding=0, groups=1):
    """Build a sparse conv + sparse batch-norm stage.

    The two layers are wrapped in ``layer.SeqToANNContainer`` (presumably so
    they can consume input with a leading time axis -- confirm against
    spikingjelly) and that container is returned inside an ``nn.Sequential``.
    """
    conv = encoder.SparseConv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        padding=padding,
        groups=groups,
        bias=True,
    )
    bn = encoder.SparseBatchNorm2d(out_channels)
    return nn.Sequential(layer.SeqToANNContainer(conv, bn))
class MS_ConvBlock(nn.Module):
    """Residual conv block: spike -> 3x3 conv/BN (expand) -> spike -> 3x3 conv/BN (project).

    Args:
        dim: Number of input and output channels.
        mlp_ratio: Expansion factor for the hidden channel count.
    """

    def __init__(self, dim,
                 mlp_ratio=4.0):
        super().__init__()

        # Channel counts must be integers. The default ratio is a float, so
        # cast here (the same way MS_Block computes its hidden width).
        hidden_dim = int(dim * mlp_ratio)

        self.neuron1 = Multispike()
        self.conv1 = MS_conv_unit(dim, hidden_dim, 3, 1)

        self.neuron2 = Multispike()
        self.conv2 = MS_conv_unit(hidden_dim, dim, 3, 1)

    def forward(self, x, mask=None):
        # `mask` is accepted for interface compatibility but is not used here.
        short_cut = x
        x = self.neuron1(x)
        x = self.conv1(x)
        x = self.neuron2(x)
        x = self.conv2(x)
        return x + short_cut
|
|
class MS_MLP(nn.Module):
    """Two-layer pointwise MLP over tokens: (spike -> 1x1 Conv1d -> BN) twice.

    Args:
        in_features: Input channel count.
        hidden_features: Hidden channel count (defaults to ``in_features``).
        out_features: Output channel count (defaults to ``in_features``).
        drop: Accepted for interface compatibility; unused.
        layer: Accepted for interface compatibility; unused.

    ``forward`` takes a 4-D tensor unpacked as ``(T, B, C, N)`` and returns
    ``(T, B, out_features, N)``.
    """

    def __init__(
        self, in_features, hidden_features=None, out_features=None, drop=0.0, layer=0
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1_conv = nn.Conv1d(in_features, hidden_features, kernel_size=1, stride=1)
        self.fc1_bn = nn.BatchNorm1d(hidden_features)
        self.fc1_lif = Multispike()

        self.fc2_conv = nn.Conv1d(
            hidden_features, out_features, kernel_size=1, stride=1
        )
        self.fc2_bn = nn.BatchNorm1d(out_features)
        self.fc2_lif = Multispike()

        self.c_hidden = hidden_features
        self.c_output = out_features

    def forward(self, x):
        steps, batch, _, tokens = x.shape

        # Conv1d/BatchNorm1d are 3-D ops, so time and batch are merged for
        # each layer and split apart again afterwards.
        x = self.fc1_lif(x)
        x = self.fc1_conv(x.flatten(0, 1))
        x = self.fc1_bn(x).reshape(steps, batch, self.c_hidden, tokens).contiguous()

        x = self.fc2_lif(x)
        x = self.fc2_conv(x.flatten(0, 1))
        # Reshape with the configured output width. Using the input channel
        # count here broke whenever out_features != in_features.
        x = self.fc2_bn(x).reshape(steps, batch, self.c_output, tokens).contiguous()

        return x
|
|
class RepConv(nn.Module):
    """Two stacked bias-free 1x1 Conv1d + BatchNorm1d stages.

    The first stage widens ``in_channel`` to ``int(in_channel * 1.5)``; the
    second maps that to ``out_channel``. ``bias`` is accepted but unused:
    both convolutions are always created without bias.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        bias=False,
    ):
        super().__init__()
        hidden_channel = int(in_channel * 1.5)

        self.conv1 = nn.Sequential(
            nn.Conv1d(in_channel, hidden_channel, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm1d(hidden_channel),
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(hidden_channel, out_channel, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm1d(out_channel),
        )

    def forward(self, x):
        widened = self.conv1(x)
        return self.conv2(widened)
class RepConv2(nn.Module):
    """Two stacked bias-free 1x1 Conv1d + BatchNorm1d stages.

    Channel flow is ``in_channel -> int(in_channel * 1.5) -> out_channel``.
    ``bias`` is accepted but unused: both convolutions are bias-free.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        bias=False,
    ):
        super().__init__()
        mid_channel = int(in_channel * 1.5)

        expand = [
            nn.Conv1d(in_channel, mid_channel, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm1d(mid_channel),
        ]
        self.conv1 = nn.Sequential(*expand)

        project = [
            nn.Conv1d(mid_channel, out_channel, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm1d(out_channel),
        ]
        self.conv2 = nn.Sequential(*project)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x
|
|
class MS_Attention_Conv_qkv_id(nn.Module):
    """Multi-head spike attention with RepConv q/k/v/output projections.

    ``forward`` computes ``q @ (k^T @ v)``: keys and values are contracted
    over the token axis first, so no token-by-token attention matrix is formed
    and no softmax is applied.

    Args:
        dim: Channel count of the input; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Accepted for interface compatibility; unused.
        qk_scale: Optional multiplier applied to the attention output. When
            ``None`` the fixed value 0.125 is used.
        attn_drop: Accepted for interface compatibility; unused.
        proj_drop: Accepted for interface compatibility; unused.
        sr_ratio: Width multiplier of the value branch (``dim * sr_ratio``).
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        self.dim = dim
        self.num_heads = num_heads
        # Honour an explicit qk_scale; previously the argument was ignored.
        self.scale = 0.125 if qk_scale is None else qk_scale
        self.sr_ratio = sr_ratio

        self.head_lif = Multispike()

        self.q_conv = nn.Sequential(RepConv(dim, dim), nn.BatchNorm1d(dim))
        self.k_conv = nn.Sequential(RepConv(dim, dim), nn.BatchNorm1d(dim))
        self.v_conv = nn.Sequential(RepConv(dim, dim * sr_ratio), nn.BatchNorm1d(dim * sr_ratio))

        self.q_lif = Multispike()
        self.k_lif = Multispike()
        self.v_lif = Multispike()
        self.attn_lif = Multispike()

        self.proj_conv = nn.Sequential(RepConv(sr_ratio * dim, dim), nn.BatchNorm1d(dim))

    def _split_heads(self, x, channels):
        """Reshape (T, B, channels, N) into (T, B, heads, N, channels // heads)."""
        T, B, _, N = x.shape
        x = x.transpose(-1, -2).reshape(T, B, N, self.num_heads, channels // self.num_heads)
        return x.permute(0, 1, 3, 2, 4)

    def forward(self, x):
        T, B, C, N = x.shape
        v_channels = self.sr_ratio * C

        x = self.head_lif(x)

        # The Conv1d/BatchNorm1d projections are 3-D ops: merge time and batch.
        x_for_qkv = x.flatten(0, 1)

        q = self.q_lif(self.q_conv(x_for_qkv).reshape(T, B, C, N))
        q = self._split_heads(q, C)

        k = self.k_lif(self.k_conv(x_for_qkv).reshape(T, B, C, N))
        k = self._split_heads(k, C)

        v = self.v_lif(self.v_conv(x_for_qkv).reshape(T, B, v_channels, N))
        v = self._split_heads(v, v_channels)

        # (T, B, h, d, N) @ (T, B, h, N, d_v) -> per-head (d, d_v) summary.
        x = k.transpose(-2, -1) @ v
        x = (q @ x) * self.scale
        # Back to channel-major layout: (T, B, h, N, d_v) -> (T, B, h * d_v, N).
        x = x.transpose(3, 4).reshape(T, B, v_channels, N)
        x = self.attn_lif(x)

        x = self.proj_conv(x.flatten(0, 1)).reshape(T, B, C, N)
        return x
|
|
|
|
|
|
|
|
class MS_DownSampling(nn.Module):
    """Strided sparse conv + sparse batch-norm stage, optionally preceded by a spike.

    When ``first_layer`` is False a ``Multispike`` activation is created and
    applied to the input before the convolution. ``forward`` expects a 5-D
    tensor whose first two axes are unpacked as ``(T, B)`` and returns a 5-D
    tensor with the same leading axes.
    """

    def __init__(
        self,
        in_channels=2,
        embed_dims=256,
        kernel_size=3,
        stride=2,
        padding=1,
        first_layer=True,
    ):
        super().__init__()

        conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding)
        self.encode_conv = encoder.SparseConv2d(in_channels, embed_dims, **conv_kwargs)
        self.encode_bn = encoder.SparseBatchNorm2d(embed_dims)

        self.first_layer = first_layer
        if not first_layer:
            # Only non-first stages own a spike activation.
            self.encode_spike = Multispike()

    def forward(self, x):
        steps, batch, _, _, _ = x.shape

        if hasattr(self, "encode_spike"):
            x = self.encode_spike(x)

        # Merge time and batch for the 2-D conv, then restore them afterwards.
        feat = self.encode_conv(x.flatten(0, 1))
        _, _, height, width = feat.shape
        feat = self.encode_bn(feat)

        return feat.reshape(steps, batch, -1, height, width)
|
|
class MS_Block(nn.Module):
    """Transformer-style block: optional RepConv residual, then attention and MLP residuals.

    Args:
        dim: Channel count.
        choice: Variant selector. ``"base"`` adds a spike + ``RepConv2``
            residual branch in front of the attention.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width multiplier of the MLP.
        qkv_bias, qk_scale, attn_drop, drop, sr_ratio: Forwarded to the
            attention / MLP sub-modules.
        drop_path: Stochastic-depth rate; only applied on the fine-tune path.
        norm_layer: Accepted for interface compatibility; unused.
        init_values: Initial value of the per-channel layer scales.
        finetune: When True, residual branches are multiplied by learnable
            per-channel scales and wrapped in ``drop_path``.

    ``forward`` expects a 4-D tensor unpacked as ``(T, B, C, N)``.
    """

    def __init__(
        self,
        dim,
        choice,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        sr_ratio=1, init_values=1e-6, finetune=False,
    ):
        super().__init__()
        self.model = choice
        if self.model == "base":
            self.rep_conv = RepConv2(dim, dim)
            self.lif = Multispike()
        self.attn = MS_Attention_Conv_qkv_id(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            sr_ratio=sr_ratio,
        )
        self.finetune = finetune
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MS_MLP(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop)

        if self.finetune:
            self.layer_scale1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
            self.layer_scale2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)

    def forward(self, x):
        if self.model == "base":
            # Unpack the shape locally. These names were previously undefined
            # here (B, C, N raised NameError; T fell through to the module
            # constant), so the "base" variant could not run.
            T, B, C, N = x.shape
            x = x + self.rep_conv(self.lif(x).flatten(0, 1)).reshape(T, B, C, N)

        if self.finetune:
            # Broadcast the (C,) scales over (T, B, C, N).
            scale1 = self.layer_scale1.unsqueeze(0).unsqueeze(0).unsqueeze(-1)
            scale2 = self.layer_scale2.unsqueeze(0).unsqueeze(0).unsqueeze(-1)
            x = x + self.drop_path(self.attn(x) * scale1)
            x = x + self.drop_path(self.mlp(x) * scale2)
        else:
            x = x + self.attn(x)
            x = x + self.mlp(x)
        return x
|
|
class Spikmae(nn.Module):
    """Masked-autoencoder pre-training model with a spiking conv/attention encoder.

    The encoder is four stride-2 sparse down-sampling stages (overall stride
    16) interleaved with ``MS_ConvBlock``s, followed by ``depths``
    ``MS_Block``s. The decoder is a stack of timm ``Block``s that predicts
    per-patch pixels.

    Args:
        T: Number of time steps the input image is repeated for.
        choice: Variant selector forwarded to every ``MS_Block``.
        img_size_h: Input height; with the fixed stride of 16 it determines
            the patch grid. Square inputs are assumed (``patchify`` asserts
            it).
        img_size_w: Accepted for interface compatibility; unused.
        patch_size: Side of a reconstruction patch. It must match the encoder
            stride (16) for token and patch counts to agree.
        embed_dim: Channel widths of the three encoder stages.
        num_heads, mlp_ratios, qk_scale, drop_rate, attn_drop_rate,
        drop_path_rate, qkv_bias, sr_ratios, depths: Encoder block settings.
        in_channels: Number of image channels.
        norm_layer: Normalisation used by the decoder.
        decoder_embed_dim, decoder_depth, decoder_num_heads, mlp_ratio:
            Decoder settings.
        num_classes, norm_pix_loss, nb_classes: Accepted for interface
            compatibility; ``num_classes`` is stored, the others are unused.
    """

    def __init__(self, T=1, choice=None,
                 img_size_h=224,
                 img_size_w=224,
                 patch_size=16,
                 embed_dim=[128, 256, 512],
                 num_heads=8,
                 mlp_ratios=4,
                 in_channels=3,
                 qk_scale=None,
                 drop_rate=0.0,
                 attn_drop_rate=0.0,
                 drop_path_rate=0.0,
                 num_classes=1000,
                 qkv_bias=False,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6),
                 depths=8,
                 sr_ratios=1,
                 decoder_embed_dim=768,
                 decoder_depth=4,
                 decoder_num_heads=16,
                 mlp_ratio=4.,
                 norm_pix_loss=False, nb_classes=1000):
        super().__init__()

        self.num_classes = num_classes
        self.depths = depths
        # Use the requested number of time steps (it used to be pinned to 1
        # regardless of the argument).
        self.T = T
        # Kept so patchify/unpatchify agree with the width of `decoder_pred`.
        self.patch_size = patch_size
        self.in_channels = in_channels

        # Stochastic-depth rates, linearly increasing over the MS_Blocks.
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depths)
        ]

        self.downsample1_1 = MS_DownSampling(
            in_channels=in_channels,
            embed_dims=embed_dim[0] // 2,
            kernel_size=7,
            stride=2,
            padding=3,
            first_layer=True,
        )

        self.ConvBlock1_1 = nn.ModuleList(
            [MS_ConvBlock(dim=embed_dim[0] // 2, mlp_ratio=mlp_ratios)]
        )

        self.downsample1_2 = MS_DownSampling(
            in_channels=embed_dim[0] // 2,
            embed_dims=embed_dim[0],
            kernel_size=3,
            stride=2,
            padding=1,
            first_layer=False,
        )

        self.ConvBlock1_2 = nn.ModuleList(
            [MS_ConvBlock(dim=embed_dim[0], mlp_ratio=mlp_ratios)]
        )

        self.downsample2 = MS_DownSampling(
            in_channels=embed_dim[0],
            embed_dims=embed_dim[1],
            kernel_size=3,
            stride=2,
            padding=1,
            first_layer=False,
        )

        self.ConvBlock2_1 = nn.ModuleList(
            [MS_ConvBlock(dim=embed_dim[1], mlp_ratio=mlp_ratios)]
        )

        self.ConvBlock2_2 = nn.ModuleList(
            [MS_ConvBlock(dim=embed_dim[1], mlp_ratio=mlp_ratios)]
        )

        self.downsample3 = MS_DownSampling(
            in_channels=embed_dim[1],
            embed_dims=embed_dim[2],
            kernel_size=3,
            stride=2,
            padding=1,
            first_layer=False,
        )

        self.block3 = nn.ModuleList(
            [
                MS_Block(
                    dim=embed_dim[2],
                    choice=choice,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratios,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[j],
                    norm_layer=norm_layer,
                    sr_ratio=sr_ratios,
                    finetune=False,
                )
                for j in range(depths)
            ]
        )

        self.norm = nn.BatchNorm1d(embed_dim[-1])
        # Four stride-2 stages above -> overall spatial stride of 16.
        self.downsample_raito = 16

        # Patch grid derived from the image size (14 x 14 = 196 for 224 px).
        self.grid_size = img_size_h // self.downsample_raito
        self.num_patches = self.grid_size ** 2
        num_patches = self.num_patches

        # Fixed (non-trainable) sin-cos embedding, stored channel-first.
        self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim[-1], num_patches), requires_grad=False)

        self.decoder_embed = nn.Linear(embed_dim[-1], decoder_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches, decoder_embed_dim), requires_grad=False)
        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=False, norm_layer=norm_layer)
            for i in range(decoder_depth)])
        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_channels, bias=True)
        self.initialize_weights()

    def initialize_weights(self):
        """Fill both positional embeddings with 2-D sin-cos values and initialise the weights."""
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[1], self.grid_size,
                                            cls_token=False)
        # The helper's output is transposed to match the channel-first buffer.
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed.transpose(1, 0)).float().unsqueeze(0))

        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1],
                                                    self.grid_size, cls_token=False)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        torch.nn.init.normal_(self.mask_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init for Linear layers; unit/zero init for LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def random_masking(self, x, mask_ratio):
        """
        Draw a per-sample random patch mask by arg-sorting uniform noise.

        x: 5-D tensor unpacked as [T, B, ...]; only the batch size and the
            device are read.
        mask_ratio: fraction of patches to hide.

        Returns (ids_keep, active, ids_restore):
            ids_keep:    [B, len_keep] indices of the visible patches.
            active:      [B, L] float mask, 1 = visible, 0 = hidden.
            ids_restore: [B, L] rank of every patch in the shuffled order.
        """
        _, N, _, _, _ = x.shape
        L = self.num_patches
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)

        # Ascending sort: the len_keep smallest noise values are kept.
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        ids_keep = ids_shuffle[:, :len_keep]

        # Build the mask in shuffled order, then un-shuffle it to patch order.
        active = torch.ones([N, L], device=x.device)
        active[:, len_keep:] = 0
        active = torch.gather(active, dim=1, index=ids_restore)

        return ids_keep, active, ids_restore

    def forward_encoder(self, x, mask_ratio=1.0):
        """Mask the image and run the encoder.

        x: [B, C, H, W] images.
        Returns (latent [B, L, C_enc], active [B, L], ids_restore [B, L],
        active_hw [1, B, 1, H, W] pixel-level visibility mask).
        """
        x = (x.unsqueeze(0)).repeat(self.T, 1, 1, 1, 1)

        ids_keep, active, ids_restore = self.random_masking(x, mask_ratio)
        B, N = active.shape
        active_b1ff = active.reshape(B, 1, self.grid_size, self.grid_size)

        # Module-level state on `encoder`; presumably read by the sparse
        # conv / norm layers -- confirm in models/encoder.py.
        encoder._cur_active = active_b1ff
        # Upsample the patch mask to pixel resolution and add a time axis.
        active_hw = active_b1ff.repeat_interleave(self.downsample_raito, 2).repeat_interleave(self.downsample_raito, 3)
        active_hw = active_hw.unsqueeze(0)
        x = x * active_hw

        x = self.downsample1_1(x)
        for blk in self.ConvBlock1_1:
            x = blk(x)
        x = self.downsample1_2(x)
        for blk in self.ConvBlock1_2:
            x = blk(x)

        x = self.downsample2(x)
        for blk in self.ConvBlock2_1:
            x = blk(x)
        for blk in self.ConvBlock2_2:
            x = blk(x)

        x = self.downsample3(x)
        x = x.flatten(3)  # [T, B, C, H, W] -> [T, B, C, H*W]
        for blk in self.block3:
            x = blk(x)

        x = x.mean(0)  # average over time steps
        x = self.norm(x).transpose(-1, -2).contiguous()
        return x, active, ids_restore, active_hw

    def forward_decoder(self, x, ids_restore):
        """Decode latent tokens [B, N, C] into per-patch pixel predictions [B, L, p*p*C_img]."""
        B, N, C = x.shape
        x = self.decoder_embed(x)

        # NOTE(review): forward_encoder flattens the full feature grid, so x
        # appears to arrive with all L tokens already in spatial order. In
        # that case zero mask tokens are appended and the gather below merely
        # permutes tokens by ids_restore. Left unchanged because a fix needs
        # the visibility mask, which this signature does not receive -- please
        # confirm the intended behaviour.
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1)
        x_ = torch.cat([x[:, :, :], mask_tokens], dim=1)
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
        x = x_

        x = x + self.decoder_pos_embed

        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)
        x = self.decoder_pred(x)

        return x

    def patchify(self, imgs):
        """
        imgs: (N, C, H, W) with H == W divisible by the patch size
        x: (N, L, patch_size**2 * C)
        """
        p = self.patch_size
        c = self.in_channels
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], c, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * c))
        return x

    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 * C) with L a perfect square
        imgs: (N, C, H, W)
        """
        p = self.patch_size
        c = self.in_channels
        h = w = int(x.shape[1] ** .5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
        return imgs

    def forward_loss(self, imgs, pred, mask):
        """
        Mean squared error on per-patch normalised pixels, averaged over the
        patches where ``mask`` is 0.

        imgs: [N, C, H, W]
        pred: [N, L, p*p*C]
        mask: [N, L], non-zero = visible (excluded), 0 = hidden (scored).
            ``forward`` passes the ``active`` mask here.

        Returns (loss, mean, std) where mean/std are the per-patch statistics
        used for normalisation (std includes the 1e-6 variance epsilon).
        """
        inp, rec = self.patchify(imgs), pred
        mean = inp.mean(dim=-1, keepdim=True)
        std = (inp.var(dim=-1, keepdim=True) + 1e-6) ** .5
        inp = (inp - mean) / std
        l2_loss = ((rec - inp) ** 2).mean(dim=2, keepdim=False)
        non_active = mask.logical_not().int().view(mask.shape[0], -1)
        # The epsilon guards against division by zero when nothing is hidden.
        recon_loss = (l2_loss * non_active).sum() / (non_active.sum() + 1e-8)
        return recon_loss, mean, std

    def forward(self, imgs, mask_ratio=0.5, vis=False):
        """Run masking, encoding, decoding and the loss.

        Returns the scalar reconstruction loss, or, when ``vis`` is True, the
        tuple (imgs, masked images, reconstruction pasted into hidden areas).
        """
        latent, active, ids_restore, active_hw = self.forward_encoder(imgs, mask_ratio)
        rec = self.forward_decoder(latent, ids_restore)
        recon_loss, mean, var = self.forward_loss(imgs, rec, active)
        if vis:
            # Drop the leading singleton time axis of the pixel mask.
            pixel_mask = active_hw.flatten(0, 1)
            masked_bchw = imgs * pixel_mask
            # Undo the per-patch normalisation before rebuilding the image.
            rec_bchw = self.unpatchify(rec * var + mean)
            rec_or_inp = torch.where(pixel_mask.bool(), imgs, rec_bchw)
            return imgs, masked_bchw, rec_or_inp
        else:
            return recon_loss
|
|
|
|
def spikmae_12_512(**kwargs):
    """Build the 12-block ``Spikmae`` with stage widths 128/256/512 (``choice="base"``).

    Extra keyword arguments are forwarded to ``Spikmae``; repeating one of the
    settings fixed below raises ``TypeError``.
    """
    config = dict(
        T=1,
        choice="base",
        img_size_h=224,
        img_size_w=224,
        patch_size=16,
        embed_dim=[128, 256, 512],
        num_heads=8,
        mlp_ratios=4,
        in_channels=3,
        num_classes=1000,
        qkv_bias=False,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        depths=12,
        sr_ratios=1,
        decoder_embed_dim=256,
        decoder_depth=4,
        decoder_num_heads=4,
    )
    return Spikmae(**config, **kwargs)
def spikmae_12_768(**kwargs):
    """Build the 12-block ``Spikmae`` with stage widths 192/384/768 (``choice="large"``).

    Extra keyword arguments are forwarded to ``Spikmae``; repeating one of the
    settings fixed below raises ``TypeError``.
    """
    config = dict(
        T=1,
        choice="large",
        img_size_h=224,
        img_size_w=224,
        patch_size=16,
        embed_dim=[192, 384, 768],
        num_heads=8,
        mlp_ratios=4,
        in_channels=3,
        num_classes=1000,
        qkv_bias=False,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        depths=12,
        sr_ratios=1,
        decoder_embed_dim=256,
        decoder_depth=4,
        decoder_num_heads=4,
    )
    return Spikmae(**config, **kwargs)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the 768-wide variant, run one masked forward pass on
    # random data, then print a layer summary and the trainable-parameter count.
    input_shape = (1, 3, 224, 224)
    model = spikmae_12_768()
    x = torch.randn(*input_shape)
    loss = model(x, mask_ratio=0.50)
    print('loss', loss)
    torchinfo.summary(model, input_shape)
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"number of params: {trainable}")
|
|