# Adapted from the VAN (Visual Attention Network) classification code, thanks!
# https://github.com/Visual-Attention-Network/VAN-Classification
import math
import torch
import torch.nn as nn
from timm.layers import DropPath, trunc_normal_


class DWConv(nn.Module):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x):
        x = self.dwconv(x)
        return x


class MixMlp(nn.Module):
    def __init__(self,
                 in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)  # 1x1
        self.dwconv = DWConv(hidden_features)  # CFF: convolutional feed-forward network
        self.act = act_layer()  # GELU
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)  # 1x1
        self.drop = nn.Dropout(drop)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            # Kaiming-style init scaled by the fan-out of the (grouped) convolution
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        x = self.fc1(x)
        x = self.dwconv(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
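
# Note (not in the original source): MixMlp is the convolutional feed-forward (CFF)
# block from VAN. It operates directly on NCHW feature maps; the 1x1 convolutions
# play the role of per-pixel linear layers, and the 3x3 depthwise DWConv in between
# adds local spatial mixing before the GELU non-linearity.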


class LKA(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        self.conv_spatial = nn.Conv2d(
            dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3)
        self.conv1 = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        u = x.clone()
        attn = self.conv0(x)
        attn = self.conv_spatial(attn)
        attn = self.conv1(attn)
        return u * attn
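
# Note (not in the original source): LKA (Large Kernel Attention) splits a large-kernel
# spatial attention into three cheap steps: a 5x5 depthwise conv for local context,
# a 7x7 depthwise conv with dilation 3 for long-range context, and a 1x1 conv for
# channel mixing (per the VAN paper, this approximates a 21x21 convolution). The
# result is used as a multiplicative attention map over the untouched input u.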


class Attention(nn.Module):
    def __init__(self, d_model, attn_shortcut=True):
        super().__init__()
        self.proj_1 = nn.Conv2d(d_model, d_model, 1)
        self.activation = nn.GELU()
        self.spatial_gating_unit = LKA(d_model)
        self.proj_2 = nn.Conv2d(d_model, d_model, 1)
        self.attn_shortcut = attn_shortcut

    def forward(self, x):
        if self.attn_shortcut:
            shortcut = x.clone()
        x = self.proj_1(x)
        x = self.activation(x)
        x = self.spatial_gating_unit(x)
        x = self.proj_2(x)
        if self.attn_shortcut:
            x = x + shortcut
        return x


class VANBlock(nn.Module):
    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., init_value=1e-2,
                 act_layer=nn.GELU, attn_shortcut=True):
        super().__init__()
        self.norm1 = nn.BatchNorm2d(dim)
        self.attn = Attention(dim, attn_shortcut=attn_shortcut)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = nn.BatchNorm2d(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MixMlp(
            in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        self.layer_scale_1 = nn.Parameter(init_value * torch.ones((dim)), requires_grad=True)
        self.layer_scale_2 = nn.Parameter(init_value * torch.ones((dim)), requires_grad=True)

    def forward(self, x):
        x = x + self.drop_path(
            self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.attn(self.norm1(x)))
        x = x + self.drop_path(
            self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x)))
        return x
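

# Minimal smoke test (not part of the original VAN code); the sizes below are
# arbitrary example values. It builds one VANBlock and checks that a random
# NCHW feature map passes through with its shape preserved.
if __name__ == "__main__":
    block = VANBlock(dim=64, mlp_ratio=4., drop=0., drop_path=0.)
    x = torch.randn(2, 64, 32, 32)  # (batch, channels, height, width)
    y = block(x)
    print(y.shape)  # expected: torch.Size([2, 64, 32, 32])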