|
|
|
|
|
|
|
|
|
|
|
import math |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from timm.layers import DropPath, trunc_normal_ |
|
|
|
|
|
|
|
|
class DWConv(nn.Module):
    """Depthwise 3x3 convolution that keeps channel count and spatial size.

    Args:
        dim: number of channels. It is also used as the group count, so
            each channel is convolved with its own 3x3 filter.
    """

    def __init__(self, dim=768):
        super().__init__()
        # groups == channels -> depthwise; stride 1 with padding 1 keeps H and W.
        self.dwconv = nn.Conv2d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
            groups=dim,
        )

    def forward(self, x):
        """Return the depthwise convolution of ``x``."""
        return self.dwconv(x)
|
|
|
|
|
|
|
|
class MixMlp(nn.Module):
    """Convolutional MLP: 1x1 conv -> depthwise 3x3 conv -> activation -> 1x1 conv.

    Args:
        in_features: number of input channels.
        hidden_features: hidden width; falls back to ``in_features`` when falsy.
        out_features: number of output channels; falls back to ``in_features``
            when falsy.
        act_layer: activation module class, instantiated with no arguments.
        drop: dropout probability, applied after the activation and after ``fc2``.
    """

    def __init__(self,
                 in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise a single submodule ``m``; called for every submodule via ``self.apply``.

        - ``nn.Linear``: truncated-normal weight (std 0.02), zero bias.
        - ``nn.LayerNorm``: weight 1, bias 0. Missing affine parameters are skipped.
        - ``nn.Conv2d``: normal weight with std ``sqrt(2 / fan_out)``, where
          ``fan_out = kh * kw * out_channels // groups``; zero bias.
        """
        # Single type dispatch. The bias check is nested under the Linear branch,
        # so the LayerNorm/Conv2d branches do not hang off an unrelated condition.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # LayerNorm(elementwise_affine=False) has weight/bias set to None.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            # Divide by groups so depthwise/grouped convs use the per-group fan-out.
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups
            nn.init.normal_(m.weight, 0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Apply fc1 -> dwconv -> act -> dropout -> fc2 -> dropout.

        ``x`` must be a valid input for ``nn.Conv2d``. Presumably this is
        (B, C, H, W) with C == ``in_features``; TODO confirm against callers.
        """
        x = self.fc1(x)
        x = self.dwconv(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
|
|
|
|
|
|
|
|
class LKA(nn.Module):
    """Large-kernel attention: a decomposed large-kernel conv used as a multiplicative gate.

    The attention map comes from three stacked convolutions:
    - a 5x5 depthwise conv;
    - a 7x7 depthwise conv with dilation 3;
    - a 1x1 pointwise conv.

    The map is then multiplied elementwise with the input. Output has the
    same shape as the input.

    Args:
        dim: number of input (and output) channels.
    """

    def __init__(self, dim):
        super().__init__()
        # 5x5 depthwise conv; padding=2 preserves the spatial size.
        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        # 7x7 depthwise conv with dilation 3 spans 3 * (7 - 1) + 1 = 19 pixels;
        # padding=9 = (19 - 1) / 2 preserves the spatial size.
        self.conv_spatial = nn.Conv2d(
            dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3)
        # 1x1 conv mixes information across channels.
        self.conv1 = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        """Return ``x`` gated elementwise by the conv-derived attention map."""
        attn = self.conv0(x)
        attn = self.conv_spatial(attn)
        attn = self.conv1(attn)
        # None of the convs above write to ``x``, so gate the input directly.
        # The former ``x.clone()`` only cost an extra copy of the activation.
        return x * attn
|
|
|
|
|
|
|
|
class Attention(nn.Module):
    """Spatial attention sub-block: 1x1 conv -> GELU -> LKA gate -> 1x1 conv.

    Args:
        d_model: number of channels in and out.
        attn_shortcut: when True, the block input is added to the output
            (an inner residual connection).
    """

    def __init__(self, d_model, attn_shortcut=True):
        super().__init__()
        self.proj_1 = nn.Conv2d(d_model, d_model, 1)
        self.activation = nn.GELU()
        self.spatial_gating_unit = LKA(d_model)
        self.proj_2 = nn.Conv2d(d_model, d_model, 1)
        self.attn_shortcut = attn_shortcut

    def forward(self, x):
        """Apply the projection/gating pipeline, optionally adding ``x`` back."""
        # Each layer below returns a new tensor and none writes into ``x``.
        # A plain reference is therefore enough for the residual; the previous
        # ``x.clone()`` only added an extra activation-sized copy.
        shortcut = x
        x = self.proj_1(x)
        x = self.activation(x)
        x = self.spatial_gating_unit(x)
        x = self.proj_2(x)
        if self.attn_shortcut:
            x = x + shortcut
        return x
|
|
|
|
|
|
|
|
class VANBlock(nn.Module):
    """Pre-norm residual block: attention branch followed by a conv-MLP branch.

    Each branch output is normalised on the way in, scaled by a learnable
    per-channel factor, passed through stochastic depth, and added to the
    running tensor.

    Args:
        dim: number of channels.
        mlp_ratio: hidden width of the MLP as a multiple of ``dim``.
        drop: dropout probability inside the MLP.
        drop_path: stochastic-depth rate; ``nn.Identity`` is used unless > 0.
        init_value: initial value of both per-channel layer-scale vectors.
        act_layer: activation class forwarded to the MLP.
        attn_shortcut: forwarded to ``Attention``.
    """

    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., init_value=1e-2, act_layer=nn.GELU, attn_shortcut=True):
        super().__init__()
        # Registration order is kept as-is so parameter order (and thus
        # optimizer/state_dict layouts) stays the same.
        self.norm1 = nn.BatchNorm2d(dim)
        self.attn = Attention(dim, attn_shortcut=attn_shortcut)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        self.norm2 = nn.BatchNorm2d(dim)
        hidden_width = int(dim * mlp_ratio)
        self.mlp = MixMlp(
            in_features=dim,
            hidden_features=hidden_width,
            act_layer=act_layer,
            drop=drop,
        )

        scale_init = init_value * torch.ones(dim)
        self.layer_scale_1 = nn.Parameter(scale_init.clone(), requires_grad=True)
        self.layer_scale_2 = nn.Parameter(init_value * torch.ones(dim), requires_grad=True)

    def forward(self, x):
        """Run the attention residual, then the MLP residual."""
        # [:, None, None] turns each (C,) scale into (C, 1, 1) so it broadcasts over H and W.
        branch = self.attn(self.norm1(x))
        x = x + self.drop_path(self.layer_scale_1[:, None, None] * branch)

        branch = self.mlp(self.norm2(x))
        x = x + self.drop_path(self.layer_scale_2[:, None, None] * branch)
        return x
|
|
|