# Code is borrowed from https://github.com/fudan-zvg/SeaFormer
import math
import torch
from torch import nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.cnn import build_norm_layer
from timm.models.registry import register_model
def _make_divisible(v, divisor, min_value=None):
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
:param v:
:param divisor:
:param min_value:
:return:
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
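# Illustrative example (not from the original repo): _make_divisible(22 * 0.75, 8) == 16,
# i.e. 16.5 is rounded to the nearest multiple of 8 without dropping below 90% of the input.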
def drop_path(x, drop_prob: float = 0., training: bool = False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
def get_shape(tensor):
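    """Return tensor.shape, materialising each dimension as a concrete value during
    ONNX export, where shape entries can be traced tensors rather than plain ints."""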
shape = tensor.shape
if torch.onnx.is_in_onnx_export():
shape = [i.cpu().numpy() for i in shape]
return shape
class Conv2d_BN(nn.Sequential):
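    """Bias-free Conv2d followed by BatchNorm, packaged as an nn.Sequential.
    Setting bn_weight_init=0 lets a residual branch start as a (near-)identity mapping."""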
def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
groups=1, bn_weight_init=1, bias=False,
norm_cfg=dict(type='BN', requires_grad=True)):
super().__init__()
self.inp_channel = a
self.out_channel = b
self.ks = ks
self.pad = pad
self.stride = stride
self.dilation = dilation
self.groups = groups
# self.bias = bias
self.add_module('c', nn.Conv2d(
a, b, ks, stride, pad, dilation, groups, bias=bias))
bn = build_norm_layer(norm_cfg, b)[1]
nn.init.constant_(bn.weight, bn_weight_init)
nn.init.constant_(bn.bias, 0)
self.add_module('bn', bn)
class Mlp(nn.Module):
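    """Convolutional feed-forward block: 1x1 Conv2d_BN -> depthwise 3x3 conv -> activation
    -> dropout -> 1x1 Conv2d_BN -> dropout."""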
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, drop=0.,
norm_cfg=dict(type='BN', requires_grad=True)):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = Conv2d_BN(in_features, hidden_features, norm_cfg=norm_cfg)
self.dwconv = nn.Conv2d(hidden_features, hidden_features, 3, 1, 1, bias=True, groups=hidden_features)
self.act = act_layer()
self.fc2 = Conv2d_BN(hidden_features, out_features, norm_cfg=norm_cfg)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.dwconv(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class InvertedResidual(nn.Module):
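    """MobileNetV2-style inverted residual block: optional 1x1 expansion, depthwise conv,
    and 1x1 linear projection, with a skip connection when stride == 1 and inp == oup."""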
def __init__(
self,
inp: int,
oup: int,
ks: int,
stride: int,
expand_ratio: int,
activations=None,
norm_cfg=dict(type='BN', requires_grad=True)
) -> None:
super(InvertedResidual, self).__init__()
self.stride = stride
self.expand_ratio = expand_ratio
assert stride in [1, 2]
if activations is None:
activations = nn.ReLU
hidden_dim = int(round(inp * expand_ratio))
self.use_res_connect = self.stride == 1 and inp == oup
layers = []
if expand_ratio != 1:
# pw
layers.append(Conv2d_BN(inp, hidden_dim, ks=1, norm_cfg=norm_cfg))
layers.append(activations())
layers.extend([
# dw
Conv2d_BN(hidden_dim, hidden_dim, ks=ks, stride=stride, pad=ks // 2, groups=hidden_dim, norm_cfg=norm_cfg),
activations(),
# pw-linear
Conv2d_BN(hidden_dim, oup, ks=1, norm_cfg=norm_cfg)
])
self.conv = nn.Sequential(*layers)
self.out_channels = oup
self._is_cn = stride > 1
def forward(self, x):
if self.use_res_connect:
return x + self.conv(x)
else:
return self.conv(x)
class StackedMV2Block(nn.Module):
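    """A stack of InvertedResidual blocks built from `cfgs` entries (k, t, c, s),
    optionally preceded by a stride-2 stem convolution."""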
def __init__(
self,
cfgs,
stem,
inp_channel=16,
activation=nn.ReLU,
norm_cfg=dict(type='BN', requires_grad=True),
width_mult=1.):
super().__init__()
self.stem = stem
if stem:
self.stem_block = nn.Sequential(
Conv2d_BN(3, inp_channel, 3, 2, 1, norm_cfg=norm_cfg),
activation()
)
self.cfgs = cfgs
self.layers = []
for i, (k, t, c, s) in enumerate(cfgs):
output_channel = _make_divisible(c * width_mult, 8)
exp_size = t * inp_channel
exp_size = _make_divisible(exp_size * width_mult, 8)
layer_name = 'layer{}'.format(i + 1)
layer = InvertedResidual(inp_channel, output_channel, ks=k, stride=s, expand_ratio=t, norm_cfg=norm_cfg,
activations=activation)
self.add_module(layer_name, layer)
inp_channel = output_channel
self.layers.append(layer_name)
def forward(self, x):
if self.stem:
x = self.stem_block(x)
for i, layer_name in enumerate(self.layers):
layer = getattr(self, layer_name)
x = layer(x)
return x
class SqueezeAxialPositionalEmbedding(nn.Module):
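    """Learnable positional embedding of fixed length `shape`, linearly interpolated
    to the length of the squeezed axis at runtime."""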
def __init__(self, dim, shape):
super().__init__()
self.pos_embed = nn.Parameter(torch.randn([1, dim, shape]))
def forward(self, x):
B, C, N = x.shape
        x = x + F.interpolate(self.pos_embed, size=N, mode='linear', align_corners=False)
return x
class Sea_Attention(torch.nn.Module):
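    """Squeeze-enhanced Axial attention (SeaFormer): attention is computed along the
    H and W axes separately on mean-squeezed queries/keys/values, while a convolutional
    branch over the concatenated q/k/v provides local detail enhancement."""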
def __init__(self, dim, key_dim, num_heads,
attn_ratio=2,
activation=None,
norm_cfg=dict(type='BN', requires_grad=True), ):
super().__init__()
self.num_heads = num_heads
self.scale = key_dim ** -0.5
self.key_dim = key_dim
self.nh_kd = nh_kd = key_dim * num_heads # num_head key_dim
self.d = int(attn_ratio * key_dim)
self.dh = int(attn_ratio * key_dim) * num_heads
self.attn_ratio = attn_ratio
self.to_q = Conv2d_BN(dim, nh_kd, 1, norm_cfg=norm_cfg)
self.to_k = Conv2d_BN(dim, nh_kd, 1, norm_cfg=norm_cfg)
self.to_v = Conv2d_BN(dim, self.dh, 1, norm_cfg=norm_cfg)
self.proj = torch.nn.Sequential(activation(), Conv2d_BN(
self.dh, dim, bn_weight_init=0, norm_cfg=norm_cfg))
self.proj_encode_row = torch.nn.Sequential(activation(), Conv2d_BN(
self.dh, self.dh, bn_weight_init=0, norm_cfg=norm_cfg))
self.pos_emb_rowq = SqueezeAxialPositionalEmbedding(nh_kd, 16)
self.pos_emb_rowk = SqueezeAxialPositionalEmbedding(nh_kd, 16)
self.proj_encode_column = torch.nn.Sequential(activation(), Conv2d_BN(
self.dh, self.dh, bn_weight_init=0, norm_cfg=norm_cfg))
self.pos_emb_columnq = SqueezeAxialPositionalEmbedding(nh_kd, 16)
self.pos_emb_columnk = SqueezeAxialPositionalEmbedding(nh_kd, 16)
self.dwconv = Conv2d_BN(self.dh + 2 * self.nh_kd, 2 * self.nh_kd + self.dh, ks=3, stride=1, pad=1, dilation=1,
groups=2 * self.nh_kd + self.dh, norm_cfg=norm_cfg)
self.act = activation()
self.pwconv = Conv2d_BN(2 * self.nh_kd + self.dh, dim, ks=1, norm_cfg=norm_cfg)
self.sigmoid = h_sigmoid()
def forward(self, x):
B, C, H, W = x.shape
q = self.to_q(x)
k = self.to_k(x)
v = self.to_v(x)
# detail enhance
qkv = torch.cat([q, k, v], dim=1)
qkv = self.act(self.dwconv(qkv))
qkv = self.pwconv(qkv)
# squeeze axial attention
## squeeze row
qrow = self.pos_emb_rowq(q.mean(-1)).reshape(B, self.num_heads, -1, H).permute(0, 1, 3, 2)
krow = self.pos_emb_rowk(k.mean(-1)).reshape(B, self.num_heads, -1, H)
vrow = v.mean(-1).reshape(B, self.num_heads, -1, H).permute(0, 1, 3, 2)
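        # qrow: (B, num_heads, H, key_dim); krow: (B, num_heads, key_dim, H); vrow: (B, num_heads, H, d)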
attn_row = torch.matmul(qrow, krow) * self.scale
attn_row = attn_row.softmax(dim=-1)
xx_row = torch.matmul(attn_row, vrow) # B nH H C
xx_row = self.proj_encode_row(xx_row.permute(0, 1, 3, 2).reshape(B, self.dh, H, 1))
## squeeze column
qcolumn = self.pos_emb_columnq(q.mean(-2)).reshape(B, self.num_heads, -1, W).permute(0, 1, 3, 2)
kcolumn = self.pos_emb_columnk(k.mean(-2)).reshape(B, self.num_heads, -1, W)
vcolumn = v.mean(-2).reshape(B, self.num_heads, -1, W).permute(0, 1, 3, 2)
attn_column = torch.matmul(qcolumn, kcolumn) * self.scale
attn_column = attn_column.softmax(dim=-1)
xx_column = torch.matmul(attn_column, vcolumn) # B nH W C
xx_column = self.proj_encode_column(xx_column.permute(0, 1, 3, 2).reshape(B, self.dh, 1, W))
xx = xx_row.add(xx_column)
xx = v.add(xx)
xx = self.proj(xx)
xx = self.sigmoid(xx) * qkv
return xx
class Block(nn.Module):
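    """Transformer block: Sea_Attention followed by a convolutional Mlp,
    each wrapped in a residual connection with stochastic depth (DropPath)."""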
def __init__(self, dim, key_dim, num_heads, mlp_ratio=4., attn_ratio=2., drop=0.,
drop_path=0., act_layer=nn.ReLU, norm_cfg=dict(type='BN2d', requires_grad=True)):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.mlp_ratio = mlp_ratio
self.attn = Sea_Attention(dim, key_dim=key_dim, num_heads=num_heads, attn_ratio=attn_ratio,
activation=act_layer, norm_cfg=norm_cfg)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, norm_cfg=norm_cfg)
def forward(self, x1):
x1 = x1 + self.drop_path(self.attn(x1))
x1 = x1 + self.drop_path(self.mlp(x1))
return x1
class BasicLayer(nn.Module):
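    """A stage of `block_num` consecutive transformer Blocks sharing the same width."""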
def __init__(self, block_num, embedding_dim, key_dim, num_heads,
mlp_ratio=4., attn_ratio=2., drop=0., attn_drop=0., drop_path=0.,
norm_cfg=dict(type='BN2d', requires_grad=True),
act_layer=None):
super().__init__()
self.block_num = block_num
self.transformer_blocks = nn.ModuleList()
for i in range(self.block_num):
self.transformer_blocks.append(Block(
embedding_dim, key_dim=key_dim, num_heads=num_heads,
mlp_ratio=mlp_ratio, attn_ratio=attn_ratio,
drop=drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_cfg=norm_cfg,
act_layer=act_layer))
def forward(self, x):
# token * N
for i in range(self.block_num):
x = self.transformer_blocks[i](x)
return x
class h_sigmoid(nn.Module):
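    """Hard sigmoid: ReLU6(x + 3) / 6, as used in MobileNetV3."""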
def __init__(self, inplace=True):
super(h_sigmoid, self).__init__()
self.relu = nn.ReLU6(inplace=inplace)
def forward(self, x):
return self.relu(x + 3) / 6
class SeaFormer(nn.Module):
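    """SeaFormer backbone: StackedMV2Block stages followed, for the last `len(depths)`
    stages, by transformer stages (BasicLayer). forward() returns the collected
    multi-scale feature maps rather than the classification logits."""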
def __init__(self, cfgs,
channels,
emb_dims,
key_dims,
depths=[2,2],
num_heads=4,
attn_ratios=2,
mlp_ratios=[2, 4],
drop_path_rate=0.,
norm_cfg=dict(type='BN', requires_grad=True),
act_layer=nn.ReLU6,
init_cfg=None,
num_classes=1000):
super().__init__()
self.num_classes = num_classes
self.channels = channels
self.depths = depths
self.cfgs = cfgs
self.norm_cfg = norm_cfg
self.init_cfg = init_cfg
if self.init_cfg is not None:
self.pretrained = self.init_cfg['checkpoint']
for i in range(len(cfgs)):
smb = StackedMV2Block(cfgs=cfgs[i], stem=True if i == 0 else False, inp_channel=channels[i], norm_cfg=norm_cfg)
setattr(self, f"smb{i + 1}", smb)
for i in range(len(depths)):
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depths[i])] # stochastic depth decay rule
trans = BasicLayer(
block_num=depths[i],
embedding_dim=emb_dims[i],
key_dim=key_dims[i],
num_heads=num_heads,
mlp_ratio=mlp_ratios[i],
attn_ratio=attn_ratios,
drop=0, attn_drop=0,
drop_path=dpr,
norm_cfg=norm_cfg,
act_layer=act_layer)
setattr(self, f"trans{i + 1}", trans)
        self.linear = nn.Linear(channels[-1], num_classes)  # classification head; follows num_classes instead of a hard-coded 1000
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
n //= m.groups
m.weight.data.normal_(0, math.sqrt(2. / n))
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.weight.data.normal_(0, 0.01)
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x):
num_smb_stage = len(self.cfgs)
num_trans_stage = len(self.depths)
res = []
for i in range(num_smb_stage):
smb = getattr(self, f"smb{i + 1}")
x = smb(x)
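            # attach a transformer stage only to the last `num_trans_stage` MV2 stages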
if num_trans_stage + i >= num_smb_stage:
trans = getattr(self, f"trans{i + num_trans_stage - num_smb_stage + 1}")
x = trans(x)
res.append(x)
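        # The classification head below is evaluated, but only the multi-scale features in `res` are returned.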
out = self.avgpool(x).view(-1, x.shape[1])
out = self.linear(out)
return res
@register_model
def SeaFormer_T(pretrained=False, **kwargs):
model_cfgs = dict(
cfg1=[
# k, t, c, s
[3, 1, 16, 1],
[3, 4, 16, 2],
[3, 3, 16, 1]],
cfg2=[
[5, 3, 32, 2],
[5, 3, 32, 1]],
cfg3=[
[3, 3, 64, 2],
[3, 3, 64, 1]],
cfg4=[
[5, 3, 128, 2]],
cfg5=[
[3, 6, 160, 2]],
channels=[16, 16, 32, 64, 128, 160],
num_heads=4,
depths=[2, 2],
emb_dims=[128, 160],
key_dims=[16, 24],
drop_path_rate=0.1,
attn_ratios=2,
mlp_ratios=[2, 4])
return SeaFormer(
cfgs=[model_cfgs['cfg1'], model_cfgs['cfg2'], model_cfgs['cfg3'], model_cfgs['cfg4'], model_cfgs['cfg5']],
channels=model_cfgs['channels'],
emb_dims=model_cfgs['emb_dims'],
key_dims=model_cfgs['key_dims'],
depths=model_cfgs['depths'],
attn_ratios=model_cfgs['attn_ratios'],
mlp_ratios=model_cfgs['mlp_ratios'],
num_heads=model_cfgs['num_heads'],
drop_path_rate=model_cfgs['drop_path_rate'])
@register_model
def SeaFormer_S(pretrained=False, **kwargs):
model_cfgs = dict(
cfg1=[
# k, t, c, s
[3, 1, 16, 1],
[3, 4, 24, 2],
[3, 3, 24, 1]],
cfg2=[
[5, 3, 48, 2],
[5, 3, 48, 1]],
cfg3=[
[3, 3, 96, 2],
[3, 3, 96, 1]],
cfg4=[
[5, 4, 160, 2]],
cfg5=[
[3, 6, 192, 2]],
channels=[16, 24, 48, 96, 160, 192],
num_heads=6,
depths=[3, 3],
key_dims=[16, 24],
emb_dims=[160, 192],
drop_path_rate=0.1,
attn_ratios=2,
mlp_ratios=[2, 4])
return SeaFormer(
cfgs=[model_cfgs['cfg1'], model_cfgs['cfg2'], model_cfgs['cfg3'], model_cfgs['cfg4'], model_cfgs['cfg5']],
channels=model_cfgs['channels'],
emb_dims=model_cfgs['emb_dims'],
key_dims=model_cfgs['key_dims'],
depths=model_cfgs['depths'],
attn_ratios=model_cfgs['attn_ratios'],
mlp_ratios=model_cfgs['mlp_ratios'],
num_heads=model_cfgs['num_heads'],
drop_path_rate=model_cfgs['drop_path_rate'])
@register_model
def SeaFormer_B(pretrained=False, **kwargs):
model_cfgs = dict(
cfg1=[
# k, t, c, s
[3, 1, 16, 1],
[3, 4, 32, 2],
[3, 3, 32, 1]],
cfg2=[
[5, 3, 64, 2],
[5, 3, 64, 1]],
cfg3=[
[3, 3, 128, 2],
[3, 3, 128, 1]],
cfg4=[
[5, 4, 192, 2]],
cfg5=[
[3, 6, 256, 2]],
channels=[16, 32, 64, 128, 192, 256],
num_heads=8,
depths=[4, 4],
key_dims=[16, 24],
emb_dims=[192, 256],
drop_path_rate=0.1,
attn_ratios=2,
mlp_ratios=[2, 4])
return SeaFormer(
cfgs=[model_cfgs['cfg1'], model_cfgs['cfg2'], model_cfgs['cfg3'], model_cfgs['cfg4'], model_cfgs['cfg5']],
channels=model_cfgs['channels'],
emb_dims=model_cfgs['emb_dims'],
key_dims=model_cfgs['key_dims'],
depths=model_cfgs['depths'],
attn_ratios=model_cfgs['attn_ratios'],
mlp_ratios=model_cfgs['mlp_ratios'],
num_heads=model_cfgs['num_heads'],
drop_path_rate=model_cfgs['drop_path_rate'])
# download link of the pretrained backbone weight
# https://drive.google.com/drive/folders/1BrZU0339JAFpKsQf4kdS0EpeeFgrBvBJ?usp=drive_link
@register_model
def SeaFormer_L(pretrained=False, weights='rscd/models/backbones/review_pretrain/SeaFormer_L_cls_79.9.pth.tar', **kwargs):
model_cfgs = dict(
cfg1=[
# k, t, c, s
[3, 3, 32, 1],
[3, 4, 64, 2],
[3, 4, 64, 1]],
cfg2=[
[5, 4, 128, 2],
[5, 4, 128, 1]],
cfg3=[
[3, 4, 192, 2],
[3, 4, 192, 1]],
cfg4=[
[5, 4, 256, 2]],
cfg5=[
[3, 6, 320, 2]],
channels=[32, 64, 128, 192, 256, 320],
num_heads=8,
depths=[3, 3, 3],
key_dims=[16, 20, 24],
emb_dims=[192, 256, 320],
drop_path_rate=0.1,
attn_ratios=2,
mlp_ratios=[2, 4, 6])
model = SeaFormer(
cfgs=[model_cfgs['cfg1'], model_cfgs['cfg2'], model_cfgs['cfg3'], model_cfgs['cfg4'], model_cfgs['cfg5']],
channels=model_cfgs['channels'],
emb_dims=model_cfgs['emb_dims'],
key_dims=model_cfgs['key_dims'],
depths=model_cfgs['depths'],
attn_ratios=model_cfgs['attn_ratios'],
mlp_ratios=model_cfgs['mlp_ratios'],
num_heads=model_cfgs['num_heads'],
drop_path_rate=model_cfgs['drop_path_rate'])
if pretrained:
        # map_location='cpu' so the checkpoint can be loaded on a CPU-only machine
        model_weight = torch.load(weights, map_location='cpu')
        model.load_state_dict(model_weight['state_dict'])
return model
if __name__ == '__main__':
model = SeaFormer_L(pretrained=True)
# ck = torch.load('model.pth.tar', map_location='cpu')
# model.load_state_dict(ck['state_dict_ema'])
    dummy_input = torch.rand((1, 3, 512, 512))  # avoid shadowing the built-in `input`
    print(model)
    from fvcore.nn import FlopCountAnalysis, flop_count_table
    model.eval()
    flops = FlopCountAnalysis(model, dummy_input)
    print(flop_count_table(flops))
    res = model(dummy_input)
    for i in res:
        print(i.shape)