import math

import torch
import torch.nn as nn
from timm.layers import DropPath, trunc_normal_


class Mlp(nn.Module):
    """Token-wise MLP (two linear layers) used inside the self-attention blocks."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class CMlp(nn.Module):
    """Convolutional MLP (two 1x1 convolutions) used inside the convolutional blocks."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    """Standard multi-head self-attention over a (B, N, C) token sequence."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # (B, N, 3C) -> (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class CBlock(nn.Module):
    """Convolutional block: depthwise conv positional encoding, a depthwise 5x5
    conv in place of attention, and a convolutional MLP, each with a residual path."""

    def __init__(self, dim, num_heads=4, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.norm1 = nn.BatchNorm2d(dim)
        self.conv1 = nn.Conv2d(dim, dim, 1)
        self.conv2 = nn.Conv2d(dim, dim, 1)
        self.attn = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = nn.BatchNorm2d(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = CMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    @torch.jit.ignore
    def no_weight_decay(self):
        return {}

    def forward(self, x):
        x = x + self.pos_embed(x)
        x = x + self.drop_path(self.conv2(self.attn(self.conv1(self.norm1(x)))))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class SABlock(nn.Module):
    """Self-attention block: depthwise conv positional encoding, multi-head
    attention, an MLP, and per-channel layer-scale on both residual branches."""

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., init_value=1e-6, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        # Layer-scale parameters (one per channel) for the attention and MLP branches.
        self.gamma_1 = nn.Parameter(init_value * torch.ones(dim), requires_grad=True)
        self.gamma_2 = nn.Parameter(init_value * torch.ones(dim), requires_grad=True)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'gamma_1', 'gamma_2'}

    def forward(self, x):
        x = x + self.pos_embed(x)
        B, C, H, W = x.shape
        # (B, C, H, W) -> (B, H*W, C) token sequence for attention.
        x = x.flatten(2).transpose(1, 2)
        x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
        x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        # Back to the spatial layout.
        x = x.transpose(1, 2).reshape(B, C, H, W)
        return x
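

# Minimal smoke test (a sketch, not part of any model definition): the tensor
# sizes below are arbitrary assumptions chosen only to illustrate usage. It
# checks that CBlock and SABlock both preserve the (B, C, H, W) input shape.
if __name__ == "__main__":
    dummy = torch.randn(2, 64, 14, 14)  # (batch, channels, height, width)
    conv_block = CBlock(dim=64, num_heads=4, mlp_ratio=4.)
    sa_block = SABlock(dim=64, num_heads=4, mlp_ratio=4.)
    out = sa_block(conv_block(dummy))
    assert out.shape == dummy.shape
    print(out.shape)  # expected: torch.Size([2, 64, 14, 14])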