| | import torch |
| | import torch.nn as nn |
| |
|
# Prefer the GPU when one is available.
# NOTE(review): `device` is not referenced anywhere in this chunk —
# presumably consumed by code elsewhere in the file/project; confirm.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
| |
|
class Res_Net(nn.Module):
    """Residual block with CBAM attention.

    Two conv -> BatchNorm stages, CBAM channel/spatial attention, then an
    identity skip connection. Every layer preserves both the channel count
    and the spatial size, so the shortcut needs no projection.

    Args:
        input_cha: number of input (and output) feature channels.
    """

    def __init__(self, input_cha):
        super(Res_Net, self).__init__()
        # First conv stage: 3x3, padding=1 keeps H and W unchanged.
        self.conv1 = nn.Conv2d(input_cha, input_cha, 3, padding=1)
        # Second conv stage: 5x5, padding=2 keeps H and W unchanged.
        # NOTE(review): the original forward() reused conv1/bn1 for BOTH
        # stages, leaving conv2 unused and mixing one BatchNorm's running
        # statistics across two different activation distributions; each
        # stage now has its own conv/BN pair.
        self.conv2 = nn.Conv2d(input_cha, input_cha, 5, padding=2)
        # NOTE(review): conv3 is constructed but never used in forward();
        # kept so existing state dicts still load — confirm it can be removed.
        self.conv3 = nn.Conv2d(input_cha, input_cha, 7, padding=3)

        self.cbamBlock = CBAMBlock(input_cha)

        self.bn1 = nn.BatchNorm2d(input_cha)
        # New BN for the second stage (adds new state-dict keys; old
        # checkpoints need strict=False or a key migration).
        self.bn2 = nn.BatchNorm2d(input_cha)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.LeakyReLU()

    def forward(self, x):
        """Return LeakyReLU(CBAM(conv/BN x2 of x) + x); shape preserved."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu2(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Apply attention before the residual addition (standard
        # CBAM-ResNet placement); cbamBlock was previously constructed
        # but never called.
        out = self.cbamBlock(out)

        # Out-of-place add: safer for autograd than the original `out +=`.
        out = out + identity
        out = self.relu2(out)

        return out
| |
|
class CBAMBlock(nn.Module):
    """Convolutional Block Attention Module (CBAM).

    Applies channel attention (a shared bottleneck MLP over the average-
    and max-pooled channel descriptors) followed by spatial attention
    (a 7x7 conv over the channel-wise average/max maps). The output has
    the same shape as the input.

    Args:
        channel: number of input feature channels.
        reduction: bottleneck reduction ratio for the channel MLP. The
            hidden width is clamped to at least 1, so `channel < reduction`
            no longer produces a degenerate zero-width nn.Linear.
    """

    def __init__(self, channel, reduction=16):
        super(CBAMBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Guard: the original `channel // reduction` yields 0 when
        # channel < reduction, creating Linear(channel, 0) and silently
        # zeroing the channel-attention branch.
        hidden = max(1, channel // reduction)
        self.channel_excitation = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

        # 7x7 conv maps the 2-channel [avg, max] spatial descriptor to a
        # single-channel attention map; padding=3 keeps H and W unchanged.
        self.spatial_excitation = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=7, stride=1, padding=3, bias=False),
        )

    def forward(self, x):
        """Return x re-weighted by channel then spatial attention.

        Args:
            x: feature tensor of shape (B, C, H, W).

        Returns:
            Tensor of shape (B, C, H, W).
        """
        batch, channels, _, _ = x.size()

        # Channel attention: shared MLP applied to both pooled descriptors,
        # summed, then squashed to (0, 1) gates per channel.
        chn_avg = self.avg_pool(x).view(batch, channels)
        chn_avg = self.channel_excitation(chn_avg).view(batch, channels, 1, 1)
        chn_max = self.max_pool(x).view(batch, channels)
        chn_max = self.channel_excitation(chn_max).view(batch, channels, 1, 1)
        chn_gate = self.sigmoid(chn_avg + chn_max)
        chn_cbam = torch.mul(x, chn_gate)

        # Spatial attention: per-pixel avg/max across channels, 7x7 conv,
        # sigmoid gate broadcast back over all channels.
        avg_out = torch.mean(chn_cbam, dim=1, keepdim=True)
        max_out, _ = torch.max(chn_cbam, dim=1, keepdim=True)
        spa_gate = self.sigmoid(
            self.spatial_excitation(torch.cat([avg_out, max_out], dim=1))
        )

        return torch.mul(chn_cbam, spa_gate)