|
|
import torch.nn as nn |
|
|
import torch |
|
|
|
|
|
class ChannelAttention(nn.Module):
    """Channel attention gate built from a shared 1x1-conv MLP.

    The input is reduced by global average pooling and by global max pooling.
    Both pooled descriptors go through the same bottleneck
    (``fc1`` -> ReLU -> ``fc2``). The two results are summed and passed
    through a sigmoid, giving one weight in (0, 1) per input channel.

    Args:
        in_channels: Number of channels of the input. The output weights have
            the same channel count.
        ratio: Reduction factor for the bottleneck. The hidden width is
            ``in_channels // ratio``, clamped to at least 1.
    """

    def __init__(self, in_channels, ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        hidden_channels = in_channels // ratio
        if hidden_channels == 0:
            # When in_channels < ratio the floor division yields 0, which would
            # build degenerate zero-width convolutions (the gate could not
            # depend on the input). Keep at least one hidden channel instead.
            hidden_channels = 1

        # Construction order matches the original so parameter init consumes
        # the RNG identically.
        self.fc1 = nn.Conv2d(in_channels, hidden_channels, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(hidden_channels, in_channels, 1, bias=False)
        # Attribute name (sic) kept unchanged so external references and the
        # module's repr / named_modules() stay compatible.
        self.sigmod = nn.Sigmoid()

    def forward(self, x):
        """Return per-channel attention weights for ``x``.

        Args:
            x: Feature map with ``in_channels`` channels. Presumably shaped
                (N, C, H, W); TODO confirm for callers using other layouts.

        Returns:
            Sigmoid weights whose last two dimensions are pooled to 1, e.g.
            (N, in_channels, 1, 1) for a 4-D input.
        """
        avg_out = self._shared_mlp(self.avg_pool(x))
        max_out = self._shared_mlp(self.max_pool(x))
        return self.sigmod(avg_out + max_out)

    def _shared_mlp(self, pooled):
        """Apply the bottleneck (fc1 -> ReLU -> fc2) shared by both pooling paths."""
        return self.fc2(self.relu1(self.fc1(pooled)))
|
|
|
|
|
class ECAM(nn.Module):
    """Fuse four 32-channel feature maps with two channel-attention gates.

    Steps:

    * The four inputs are concatenated on dim 1 into a 128-channel map.
    * An inter-branch gate (``ca``) weights that concatenated map.
    * The inputs are also summed element-wise. An intra-branch gate (``ca1``)
      is computed on the sum and tiled 4x along dim 1.
    * The tiled intra weights are added to the concatenated map, and the
      result is multiplied by the inter weights.
    * A 1x1 convolution (``conv_final``) projects the result to ``out_ch``
      channels.

    Args:
        out_ch: Number of output channels of the final 1x1 convolution.
    """

    def __init__(self, out_ch=2):
        super().__init__()
        # NOTE(review): this sets a class attribute on *every* nn.Module, not
        # just ECAM, as a side effect of construction. Kept for behavioral
        # parity; presumably meant for loading whole pickled models whose
        # source changed. Confirm whether it is still needed.
        torch.nn.Module.dump_patches = True

        branch_width = 32              # channels of each individual input
        fused_width = branch_width * 4  # channels after concatenating 4 inputs
        inter_ratio = 16
        intra_ratio = inter_ratio // 4  # == 4

        # Submodules are created in the original order (ca, ca1, conv_final)
        # so parameter initialisation consumes the RNG identically.
        self.ca = ChannelAttention(fused_width, ratio=inter_ratio)
        self.ca1 = ChannelAttention(branch_width, ratio=intra_ratio)
        self.conv_final = nn.Conv2d(fused_width, out_ch, kernel_size=1)

    def forward(self, x):
        """Fuse the first four entries of ``x`` into an ``out_ch``-channel map.

        Args:
            x: Indexable collection whose first four entries are tensors with
                32 channels on dim 1 and otherwise identical shapes. Presumably
                each is (N, 32, H, W); TODO confirm. Entries beyond index 3
                are ignored.

        Returns:
            Output of ``conv_final``, with ``out_ch`` channels.
        """
        first, second, third, fourth = x[0], x[1], x[2], x[3]
        branches = [first, second, third, fourth]

        ensemble = torch.cat(branches, 1)
        intra_sum = torch.stack(tuple(branches)).sum(dim=0)

        intra_weights = self.ca1(intra_sum)
        # Tile the 32-channel intra weights to line up with the 128-channel
        # concatenation.
        tiled_intra = intra_weights.repeat(1, 4, 1, 1)
        inter_weights = self.ca(ensemble)

        refined = inter_weights * (ensemble + tiled_intra)
        return self.conv_final(refined)
|
|
|