# refer to the code from VAN, Thanks!
# https://github.com/Visual-Attention-Network/VAN-Classification
import math

import torch
import torch.nn as nn
from timm.layers import DropPath, trunc_normal_


class DWConv(nn.Module):
    """Depth-wise 3x3 convolution (one filter per channel).

    Args:
        dim: Number of input and output channels.
    """

    def __init__(self, dim=768):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x):
        # kernel 3, stride 1, padding 1 -> spatial size is preserved.
        return self.dwconv(x)


class MixMlp(nn.Module):
    """Convolutional feed-forward network (CFF).

    Pipeline: 1x1 conv -> 3x3 depth-wise conv -> activation -> dropout
    -> 1x1 conv -> dropout. Every layer is a 2-D convolution, so the input
    is a spatial feature map (presumably (N, C, H, W) — TODO confirm callers).

    Args:
        in_features: Number of input channels.
        hidden_features: Number of hidden channels; defaults to ``in_features``.
        out_features: Number of output channels; defaults to ``in_features``.
        act_layer: Activation class, instantiated with no arguments.
        drop: Dropout probability, applied after the activation and after ``fc2``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)  # 1x1
        self.dwconv = DWConv(hidden_features)  # CFF: Convolutional feed-forward network
        self.act = act_layer()  # GELU by default
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)  # 1x1
        self.drop = nn.Dropout(drop)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise the parameters of sub-module ``m`` (called via ``self.apply``).

        The Linear/LayerNorm branches are kept for parity with VAN; the layers
        this module builds itself are all ``nn.Conv2d``.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            # Kaiming-style normal init on fan-out; for grouped (depth-wise)
            # convolutions the fan-out is divided by the number of groups.
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        x = self.fc1(x)
        x = self.dwconv(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class LKA(nn.Module):
    """Large Kernel Attention.

    A 5x5 depth-wise conv, a 7x7 depth-wise conv with dilation 3 and a 1x1
    point-wise conv produce an attention map that multiplies the input
    element-wise. All convolutions preserve the spatial size.

    Args:
        dim: Number of channels.
    """

    def __init__(self, dim):
        super().__init__()
        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        # Dilated 7x7 kernel (dilation 3) spans 19x19; padding 9 keeps the size.
        self.conv_spatial = nn.Conv2d(
            dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3)
        self.conv1 = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        # No layer below modifies ``x`` in place, so the original
        # ``x.clone()`` only cost an extra activation copy; ``conv0`` already
        # saves ``x`` for backward, so outputs and gradients are unchanged.
        attn = self.conv0(x)
        attn = self.conv_spatial(attn)
        attn = self.conv1(attn)
        return x * attn


class Attention(nn.Module):
    """VAN attention: 1x1 conv -> GELU -> LKA -> 1x1 conv, optional residual.

    Args:
        d_model: Number of channels.
        attn_shortcut: If True, the module's input is added to its output.
    """

    def __init__(self, d_model, attn_shortcut=True):
        super().__init__()
        self.proj_1 = nn.Conv2d(d_model, d_model, 1)
        self.activation = nn.GELU()
        self.spatial_gating_unit = LKA(d_model)
        self.proj_2 = nn.Conv2d(d_model, d_model, 1)
        self.attn_shortcut = attn_shortcut

    def forward(self, x):
        # A plain reference is enough: ``x`` is never modified in place and is
        # only consumed by the residual addition below.
        shortcut = x
        x = self.proj_1(x)
        x = self.activation(x)
        x = self.spatial_gating_unit(x)
        x = self.proj_2(x)
        if self.attn_shortcut:
            x = x + shortcut
        return x


class VANBlock(nn.Module):
    """VAN block: two pre-norm residual branches (attention, then conv MLP).

    Each branch output is scaled per channel by a learnable layer-scale
    vector and passed through stochastic depth (``DropPath``) before being
    added back to the input. Note that with ``attn_shortcut=True`` the
    attention branch already contains its own residual of the normalised input.

    Args:
        dim: Number of channels.
        mlp_ratio: Hidden-channel multiplier for the MLP branch.
        drop: Dropout probability inside the MLP.
        drop_path: Stochastic-depth rate; ``nn.Identity`` is used when 0.
        init_value: Initial value of both layer-scale vectors.
        act_layer: Activation class used by the MLP.
        attn_shortcut: Forwarded to :class:`Attention`.
    """

    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., init_value=1e-2,
                 act_layer=nn.GELU, attn_shortcut=True):
        super().__init__()
        self.norm1 = nn.BatchNorm2d(dim)
        self.attn = Attention(dim, attn_shortcut=attn_shortcut)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = nn.BatchNorm2d(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MixMlp(
            in_features=dim, hidden_features=mlp_hidden_dim,
            act_layer=act_layer, drop=drop)

        self.layer_scale_1 = nn.Parameter(
            init_value * torch.ones(dim), requires_grad=True)
        self.layer_scale_2 = nn.Parameter(
            init_value * torch.ones(dim), requires_grad=True)

    def forward(self, x):
        # (C,) layer-scale -> (C, 1, 1) so it broadcasts over the spatial dims.
        x = x + self.drop_path(
            self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.attn(self.norm1(x)))
        x = x + self.drop_path(
            self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x)))
        return x