File size: 4,974 Bytes
9f88559 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, utils
from torchvision.transforms import ToTensor
from torchvision.transforms import transforms
class BasicBlock(nn.Module):
    """Two-convolution residual block (the ResNet-18/34 "basic" block).

    Computes ``relu(layer(x) + shortcut(x))`` where ``layer`` is
    conv3x3-BN-ReLU-conv3x3-BN and ``shortcut`` is either the identity or a
    1x1 conv + BN projection.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by both 3x3 convolutions.
        stride: Pair ``(s0, s1)`` of strides for the first and second 3x3
            convolution. Only ``stride[0]`` is applied by the projection
            shortcut.
        padding: Padding of both 3x3 convolutions.
    """

    def __init__(self, in_channels, out_channels, stride=(1, 1), padding=1) -> None:
        super().__init__()
        # Residual branch. Convolutions carry no bias because the BatchNorm
        # right after each one has its own learnable shift.
        self.layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3,
                      stride=stride[0], padding=padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),  # in-place to save activation memory
            nn.Conv2d(out_channels, out_channels, kernel_size=3,
                      stride=stride[1], padding=padding, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        # Shortcut branch: identity when channels and the first stride leave
        # the shape unchanged, otherwise a 1x1 projection to match them.
        # ``stride`` is a sequence, so its first element must be compared;
        # comparing the whole sequence to 1 is always true and would force a
        # projection onto every block.
        self.shortcut = nn.Sequential()
        if stride[0] != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride[0], bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        """Return ``relu(layer(x) + shortcut(x))``."""
        out = self.layer(x)
        out += self.shortcut(x)
        return F.relu(out)
# In networks that use BatchNorm, convolution layers are built without a bias
# term: the BatchNorm that follows each convolution supplies its own shift.
class ResNet18(nn.Module):
    """ResNet-18 with a squeeze-and-excitation (SE) module after each stage.

    Args:
        BasicBlock: Residual block class used for every stage. It is called as
            ``BasicBlock(in_channels, out_channels, stride_pair)``.
        num_classes: Number of output logits of the final linear layer.
        verbose: Keyword-only. If True, ``forward`` prints the feature-map
            shape after each stage (debug aid; off by default).
    """

    def __init__(self, BasicBlock, num_classes=10, *, verbose=False) -> None:
        super().__init__()
        self.verbose = verbose
        # Channel count fed to the next stage; advanced by _make_layer.
        self.in_channels = 64
        # Stem, built separately because it contains no residual block:
        # 7x7/2 conv -> BN -> ReLU -> 3x3/2 max-pool, as in the reference
        # ResNet stem. ReLU and MaxPool hold no parameters, so BN stays at
        # index 1 and state_dict keys are unaffected by the ReLU.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        # conv2_x .. conv5_x: two blocks per stage; the first block of
        # stages 3-5 uses stride 2 to downsample.
        self.conv2 = self._make_layer(BasicBlock, 64, [[1, 1], [1, 1]])
        self.conv3 = self._make_layer(BasicBlock, 128, [[2, 1], [1, 1]])
        self.conv4 = self._make_layer(BasicBlock, 256, [[2, 1], [1, 1]])
        self.conv5 = self._make_layer(BasicBlock, 512, [[2, 1], [1, 1]])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)
        # One SE module per stage, sized to that stage's output channels.
        self.senet64 = SEAttention(64)
        self.senet128 = SEAttention(128)
        self.senet256 = SEAttention(256)
        self.senet512 = SEAttention(512)

    def _make_layer(self, block, out_channels, strides):
        """Build one stage: one ``block`` per stride pair in ``strides``.

        Advances ``self.in_channels`` to ``out_channels`` so each block (and
        the next stage) receives the correct input width.
        """
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for ``x``."""
        out = self.conv1(x)
        stages = (
            (self.conv2, self.senet64),
            (self.conv3, self.senet128),
            (self.conv4, self.senet256),
            (self.conv5, self.senet512),
        )
        for idx, (stage, se) in enumerate(stages, start=1):
            out = se(stage(out))
            if self.verbose:
                print(f"out{idx}.shape:", out.shape)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        return self.fc(out)
class SEAttention(nn.Module):
    """Squeeze-and-excitation channel attention.

    Each channel of the input is multiplied by a gate in [0, 1] computed from
    that sample's per-channel global averages by a small sigmoid-gated MLP.

    Args:
        in_channels: Channels of the input feature map.
        reduction_ratio: Bottleneck factor for the hidden layer. The hidden
            width is clamped to at least 1 so that ``in_channels`` smaller
            than ``reduction_ratio`` does not yield a zero-width layer.
    """

    def __init__(self, in_channels, reduction_ratio=16):
        super().__init__()
        hidden = max(1, in_channels // reduction_ratio)
        # Squeeze: global average pooling to one value per channel.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Excitation: bottleneck MLP ending in a sigmoid gate per channel.
        self.fc = nn.Sequential(
            nn.Linear(in_channels, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, in_channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` with every channel scaled by its learned gate.

        Args:
            x: Feature map of shape ``(N, C, H, W)`` with ``C == in_channels``.
        """
        n, c = x.shape[:2]
        gate = torch.flatten(self.avg_pool(x), start_dim=1)
        gate = self.fc(gate).view(n, c, 1, 1)
        # The (N, C, 1, 1) gate broadcasts over the spatial dimensions.
        return x * gate
if __name__ == "__main__":
    # Smoke test: push a random batch through the model and report shapes.
    # Guarded so that importing this module does not run a large forward pass.
    x = torch.rand(4, 3, 512, 512)
    print("x.shape:", x.shape)
    model = ResNet18(BasicBlock)
    with torch.no_grad():  # no backward pass is needed just to inspect shapes
        y = model(x)
    print("y.shape:", y.shape)
|