# Source: uploaded by InPeerReview ("Upload 161 files", commit 226675b, verified).
import torch.nn as nn
import torch
class ChannelAttention(nn.Module):
    """Per-channel attention gate with values in (0, 1).

    The input is squeezed spatially by both global average pooling and global
    max pooling. Each pooled descriptor goes through the same shared two-layer
    1x1-conv bottleneck (``fc1`` -> ReLU -> ``fc2``). The two results are
    summed and passed through a sigmoid.

    Args:
        in_channels: Number of channels of the input feature map.
        ratio: Channel reduction factor of the bottleneck. The hidden width is
            ``in_channels // ratio``, clamped to at least 1.
    """

    def __init__(self, in_channels, ratio=16):
        super().__init__()
        # Without the clamp, in_channels < ratio gives a zero-width bottleneck
        # (a zero-channel conv with no learnable weights). The result is
        # identical to the unclamped value whenever in_channels >= ratio.
        hidden = max(1, in_channels // ratio)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Layers are created in the original order so that seeded weight
        # initialisation and state_dict keys are unchanged.
        self.fc1 = nn.Conv2d(in_channels, hidden, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(hidden, in_channels, 1, bias=False)
        # Attribute name kept as-is (sic) so external references keep working.
        self.sigmod = nn.Sigmoid()

    def forward(self, x):
        """Return channel weights of shape (N, in_channels, 1, 1).

        Assumes ``x`` is a 4-D (N, in_channels, H, W) feature map.
        """
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        return self.sigmod(avg_out + max_out)
class ECAM(nn.Module):
    """Ensemble channel attention head (SNUNet-CD with ECAM).

    Fuses four same-shaped 32-channel feature maps:

    1. They are concatenated along the channel axis into 128 channels.
    2. The concatenation is re-weighted using channel attention computed on it
       (``ca``) and on the element-wise sum of the four inputs (``ca1``).
    3. The result is projected to ``out_ch`` channels by a 1x1 convolution.

    Args:
        out_ch: Number of output channels of the final 1x1 convolution.
    """

    def __init__(self, out_ch=2):
        super(ECAM, self).__init__()
        # NOTE(review): this assigns a class attribute on torch.nn.Module
        # itself, so it affects every module in the process, not only ECAM.
        torch.nn.Module.dump_patches = True
        per_input = 32          # channels of each of the four input maps
        fused = per_input * 4   # channels after concatenating the four maps
        # Construction order (ca, ca1, conv_final) is significant: it fixes the
        # order of seeded weight initialisation and of state_dict entries.
        self.ca = ChannelAttention(fused, ratio=16)
        self.ca1 = ChannelAttention(per_input, ratio=4)  # 16 // 4
        self.conv_final = nn.Conv2d(fused, out_ch, kernel_size=1)

    def forward(self, x):
        """Fuse the first four entries of ``x`` into an ``out_ch``-channel map.

        Args:
            x: Indexable sequence whose first four items are tensors of equal
               shape with 32 channels (assumed (N, 32, H, W)). Any further
               items are ignored.

        Returns:
            Output of the 1x1 ``conv_final`` projection (``out_ch`` channels).
        """
        maps = (x[0], x[1], x[2], x[3])
        stacked_channels = torch.cat(maps, 1)
        summed = torch.stack(maps).sum(dim=0)
        # Tile the 32-channel intra-group weights to align with the 128
        # concatenated channels.
        intra_weights = self.ca1(summed).repeat(1, 4, 1, 1)
        gated = self.ca(stacked_channels) * (stacked_channels + intra_weights)
        return self.conv_final(gated)