import torch
import torch.nn as nn
# Select GPU when available. NOTE(review): `device` is not referenced anywhere
# in the visible code — confirm whether callers elsewhere rely on it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Res_Net(nn.Module):
    """Residual block: two conv->BatchNorm stages with a LeakyReLU between
    them, an identity skip connection, and a final LeakyReLU.

    Args:
        input_cha: number of input (and output) channels. Spatial size is
            preserved by the "same" padding of every conv, so the skip
            addition is always shape-compatible.
    """

    def __init__(self, input_cha):
        super(Res_Net, self).__init__()
        self.conv1 = nn.Conv2d(input_cha, input_cha, 3, padding=1)
        self.conv2 = nn.Conv2d(input_cha, input_cha, 5, padding=2)
        # NOTE(review): conv3 and cbamBlock are constructed but never applied
        # in forward() — kept so existing attribute references and state-dict
        # keys survive; confirm whether they were meant to be used.
        self.conv3 = nn.Conv2d(input_cha, input_cha, 7, padding=3)
        self.cbamBlock = CBAMBlock(input_cha)
        self.bn1 = nn.BatchNorm2d(input_cha)
        # BUG FIX: the second conv/BN stage previously re-applied conv1/bn1,
        # silently sharing weights between both stages. Use the (previously
        # unused) conv2 for the second stage and give it its own BatchNorm.
        self.bn2 = nn.BatchNorm2d(input_cha)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.LeakyReLU()

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu2(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = out + identity  # residual skip connection
        out = self.relu2(out)
        return out
class CBAMBlock(nn.Module):
    """CBAM attention block: channel attention (a shared MLP over avg- and
    max-pooled channel descriptors) followed by spatial attention (a 7x7 conv
    over channel-wise avg/max maps). Output shape equals input shape.

    Args:
        channel: number of input channels.
        reduction: bottleneck reduction ratio for the channel-attention MLP.
    """

    def __init__(self, channel, reduction=16):
        super(CBAMBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # ROBUSTNESS FIX: clamp the bottleneck width to >= 1. The original
        # `channel // reduction` is 0 whenever channel < reduction, which
        # builds a Linear(channel, 0) bottleneck whose output is always empty,
        # collapsing channel attention to a constant 0.5 gate.
        hidden = max(1, channel // reduction)
        self.channel_excitation = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
        )
        self.sigmoid = nn.Sigmoid()
        # 2 input maps (channel-wise avg and max) -> 1 spatial attention map.
        self.spatial_excitation = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=7, stride=1, padding=3, bias=False),
        )

    def forward(self, x):
        batch, chs, _, _ = x.size()
        # --- channel attention: shared MLP on avg- and max-pooled vectors ---
        chn_avg = self.channel_excitation(self.avg_pool(x).view(batch, chs))
        chn_max = self.channel_excitation(self.max_pool(x).view(batch, chs))
        chn_att = self.sigmoid(chn_avg + chn_max).view(batch, chs, 1, 1)
        chn_cbam = torch.mul(x, chn_att)
        # --- spatial attention: 7x7 conv over [avg, max] channel statistics ---
        avg_out = torch.mean(chn_cbam, dim=1, keepdim=True)
        max_out, _ = torch.max(chn_cbam, dim=1, keepdim=True)
        spa_att = self.sigmoid(self.spatial_excitation(torch.cat([avg_out, max_out], dim=1)))
        return torch.mul(chn_cbam, spa_att)