# RingMo-SAM / models/block.py
# AI-Cyber: "Upload 123 files" (8d7921b)
from __future__ import print_function
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
class MergeAndConv(nn.Module):
    """Two-layer 3x3 convolutional head that ends in a sigmoid.

    The pipeline is ``conv1 -> ReLU -> BatchNorm -> conv2 -> sigmoid``. Both
    convolutions use stride 1 and padding 1, so the spatial size of the input
    is preserved.

    Args:
        ic: number of input channels.
        oc: number of output channels.
        inner: number of channels between the two convolutions.
    """

    def __init__(self, ic, oc, inner=32):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(ic, inner, **conv_kwargs)
        self.bn = nn.BatchNorm2d(inner)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(inner, oc, **conv_kwargs)

    def forward(self, x):
        """Return the sigmoid-activated output map for ``x``."""
        hidden = self.conv1(x)
        # The activation is applied before the normalisation here.
        hidden = self.relu(hidden)
        hidden = self.bn(hidden)
        logits = self.conv2(hidden)
        return torch.sigmoid(logits)
class SideClassifer(nn.Module):
    """Bank of ``M`` independent convolutional heads applied to one input.

    Args:
        ic: number of input channels.
        n_class: number of output channels produced by every head.
        M: how many heads to create.
        kernel_size: kernel size shared by all heads (no padding is added).
    """

    def __init__(self, ic, n_class=1, M=2, kernel_size=1):
        super().__init__()
        self.sides = nn.ModuleList(
            [nn.Conv2d(ic, n_class, kernel_size=kernel_size) for _ in range(M)]
        )

    def forward(self, x):
        """Return a list holding each head's output, in creation order."""
        outputs = []
        for head in self.sides:
            outputs.append(head(x))
        return outputs
class UpsampleSKConv(nn.Module):
    """Upsample by 2x, squeeze channels, apply SKSPP, then project to ``oc``.

    Args:
        ic: number of input channels.
        oc: number of output channels.
        reduce: divisor applied to ``ic`` to obtain the inner channel width
            (``ic // reduce``).
    """

    def __init__(self, ic, oc, reduce=4):
        super().__init__()
        mid = ic // reduce  # channel width used between ``prev`` and ``next``
        self.relu = nn.ReLU(inplace=True)
        self.prev = nn.Conv2d(ic, mid, kernel_size=3, stride=1, padding=1)
        self.bn = nn.BatchNorm2d(mid)
        self.next = nn.Conv2d(mid, oc, kernel_size=1, stride=1)
        self.bn2 = nn.BatchNorm2d(oc)
        self.sk = SKSPP(mid, mid, M=4)

    def forward(self, x):
        """Return the processed feature map at twice the input resolution."""
        # Uses F.interpolate's default interpolation mode.
        upsampled = F.interpolate(x, scale_factor=2)
        # In both stages the ReLU runs before the batch normalisation.
        squeezed = self.bn(self.relu(self.prev(upsampled)))
        fused = self.sk(squeezed)
        projected = self.relu(self.next(fused))
        return self.bn2(projected)
class SKSPP(nn.Module):
    """Attention-weighted fusion of a cascade of dilated convolutions.

    The input itself is the first branch. Every further branch is produced by
    feeding the previous branch through a dilated ``Conv2d -> BatchNorm2d ->
    ReLU`` stage, so branches are chained rather than parallel. Per-branch,
    per-channel weights are then computed from the globally averaged sum of
    all branches, normalised with a softmax across branches, and used to blend
    the branches into one output with the same shape as the input.
    """

    def __init__(self, features, WH, M=2, G=1, r=16, stride=1, L=32):
        """Build the branch convolutions and the attention layers.

        Args:
            features: input (and output) channel dimensionality.
            WH: input spatial dimensionality. Not used by this module; kept
                so existing call sites keep working.
            M: total number of branches, counting the identity branch, so
                ``M - 1`` convolution stages are created.
            G: number of convolution groups.
            r: reduction ratio used to compute ``d``, the length of ``z``.
            stride: stride of the branch convolutions, default 1. The padding
                keeps the spatial size only for stride 1, and all branches
                must share one shape to be stacked in ``forward``.
            L: the minimum length of the vector ``z``, default 32.

        Raises:
            ValueError: if ``M`` is smaller than 1.
        """
        super(SKSPP, self).__init__()
        if M < 1:
            # Without at least the identity branch there is nothing to fuse;
            # fail here instead of inside forward().
            raise ValueError("SKSPP requires M >= 1, got {}".format(M))
        d = max(int(features / r), L)
        self.M = M
        self.features = features

        stages = []
        for i in range(1, M):
            kernel = 1 + i * 2  # 3, 5, 7, ...; the same value is the dilation
            stages.append(nn.Sequential(
                # Effective extent is kernel * (kernel - 1) + 1, so this
                # padding keeps the spatial size when stride is 1.
                nn.Conv2d(features, features, kernel_size=kernel, dilation=kernel, stride=stride,
                          padding=(kernel * (kernel - 1) + 1) // 2, groups=G),
                nn.BatchNorm2d(features),
                nn.ReLU(inplace=False),
            ))
        self.convs = nn.ModuleList(stages)

        self.fc = nn.Linear(features, d)
        # One projection per branch (identity branch included).
        self.fcs = nn.ModuleList([nn.Linear(d, features) for _ in range(M)])
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Fuse the cascaded branches of ``x`` with softmax attention.

        Args:
            x: 4-D feature map whose dim 1 has ``features`` channels.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Branch 0 is the raw input; each stage consumes the previous output.
        branches = [x]
        for conv in self.convs:
            x = conv(x)
            branches.append(x)
        # Stack once along a new branch axis instead of re-concatenating the
        # growing tensor on every iteration.
        feas = torch.stack(branches, dim=1)

        fea_U = feas.sum(dim=1)
        fea_s = fea_U.mean(-1).mean(-1)  # global average over the spatial dims
        fea_z = self.fc(fea_s)

        attention_vectors = torch.stack([fc(fea_z) for fc in self.fcs], dim=1)
        attention_vectors = self.softmax(attention_vectors)  # across branches
        attention_vectors = attention_vectors.unsqueeze(-1).unsqueeze(-1)
        return (feas * attention_vectors).sum(dim=1)