| import torch.nn as nn |
| import torch |
|
|
class ChannelAttention(nn.Module):
    """Channel attention gate computed from globally pooled features.

    The average-pooled and max-pooled descriptors of the input each pass
    through the same two-layer 1x1-conv bottleneck (shared weights). The two
    results are summed and squashed with a sigmoid, giving one weight per
    channel.

    Args:
        in_channels: Number of channels of the input feature map.
        ratio: Reduction factor for the bottleneck width, which is
            ``in_channels // ratio`` clamped to at least 1.

    Shape:
        Input ``(N, in_channels, H, W)``; output ``(N, in_channels, 1, 1)``
        with values in ``(0, 1)``.
    """

    def __init__(self, in_channels, ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Without the clamp, in_channels < ratio gives a zero-channel
        # bottleneck (empty conv weights), so the gate could never learn.
        hidden = max(1, in_channels // ratio)
        self.fc1 = nn.Conv2d(in_channels, hidden, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(hidden, in_channels, 1, bias=False)
        # The (misspelled) attribute name is kept so existing references to
        # `.sigmod` keep working; Sigmoid has no parameters either way.
        self.sigmod = nn.Sigmoid()

    def forward(self, x):
        """Return per-channel attention weights in (0, 1) for ``x``."""
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        return self.sigmod(avg_out + max_out)
|
|
class ECAM(nn.Module):
    """Ensemble Channel Attention Module fusing four feature maps.

    The four inputs are concatenated along dim 1. A channel-attention gate
    computed from their element-wise sum is tiled and added to that
    concatenation. The result is then reweighted by a second gate computed
    from the concatenation itself. A final 1x1 convolution maps it to
    ``out_ch`` channels.

    Args:
        out_ch: Number of output channels of the final 1x1 convolution.
    """

    def __init__(self, out_ch=2):
        super().__init__()
        # NOTE(review): this assigns a class attribute on torch.nn.Module
        # itself, so it flips the flag for every module in the process and is
        # re-applied on each ECAM construction. torch's serialization consults
        # it when unpickling whole modules whose source changed (presumably to
        # write .patch files instead of only warning) — confirm this is wanted.
        torch.nn.Module.dump_patches = True

        branch_ch = 32            # channels expected on each input map
        fused_ch = branch_ch * 4  # channels after concatenating four maps

        # Attribute names are the state_dict keys; creation order (ca, ca1,
        # conv_final) is kept so random initialisation draws identically.
        self.ca = ChannelAttention(fused_ch, ratio=16)
        self.ca1 = ChannelAttention(branch_ch, ratio=4)
        self.conv_final = nn.Conv2d(fused_ch, out_ch, kernel_size=1)

    def forward(self, x):
        """Fuse ``x[0]``..``x[3]`` and project them to ``out_ch`` channels.

        ``x`` is indexed as a sequence, and only its first four entries are
        used. The entries must share one shape so they can be stacked and
        summed. Each needs 32 channels on dim 1 to match ``self.ca1`` (128 once
        concatenated, matching ``self.ca``), so presumably each is
        ``(N, 32, H, W)``; TODO confirm. The output comes from the 1x1
        ``conv_final``, i.e. ``out_ch`` channels at the input resolution.
        """
        branches = (x[0], x[1], x[2], x[3])
        fused = torch.cat(branches, 1)
        intra = torch.stack(branches).sum(dim=0)
        intra_weights = self.ca1(intra)
        # NOTE(review): ca1's sigmoid weights are *added* to the fused features
        # (tiled once per branch) rather than multiplied. This is kept exactly
        # as written; confirm it is the intended formulation.
        tiled = intra_weights.repeat(1, len(branches), 1, 1)
        gated = self.ca(fused) * (fused + tiled)
        return self.conv_final(gated)
|
|