import torch
import torch.nn as nn


def _attention_branch(channels: int, inter_channels: int, *, pooled: bool) -> nn.Sequential:
    """Build one channel-attention branch (bottleneck of two 1x1 convs).

    The branch is conv(channels -> inter_channels) -> BN -> ReLU ->
    conv(inter_channels -> channels) -> BN. When ``pooled`` is True, an
    ``AdaptiveAvgPool2d(1)`` is prepended, so the branch yields one weight per
    channel (global context) instead of one per pixel (local context).
    The layer order matches the original hand-written ``nn.Sequential`` blocks,
    so state-dict keys (e.g. ``global_att.5.bias``) are unchanged.
    """
    layers = [nn.AdaptiveAvgPool2d(1)] if pooled else []
    layers += [
        nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
        nn.BatchNorm2d(inter_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
        nn.BatchNorm2d(channels),
    ]
    return nn.Sequential(*layers)


class DAF(nn.Module):
    """Direct additive fusion (DirectAddFuse): returns ``x + residual``."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
        return x + residual


class iAFF(nn.Module):
    """Iterative attentional feature fusion (iAFF) of two feature maps.

    Stage 1 derives per-channel/per-pixel weights from ``x + residual`` using a
    local and a global attention branch, then blends ``x`` and ``residual``.
    Stage 2 derives new weights from that first blend with a second, separate
    pair of branches, and blends the original ``x`` and ``residual`` again.

    Args:
        channels: number of channels C of both inputs.
        r: bottleneck reduction ratio; the hidden width is ``channels // r``.
    """

    def __init__(self, channels=64, r=4):
        super().__init__()
        inter_channels = int(channels // r)

        # Construction order is kept (local, global, local2, global2) so that
        # seeded parameter initialisation matches the original layout.
        self.local_att = _attention_branch(channels, inter_channels, pooled=False)
        self.global_att = _attention_branch(channels, inter_channels, pooled=True)
        self.local_att2 = _attention_branch(channels, inter_channels, pooled=False)
        self.global_att2 = _attention_branch(channels, inter_channels, pooled=True)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
        """Fuse ``x`` and ``residual`` (4-D N x C x H x W maps, C == channels)."""
        # Stage 1: initial integration.
        xa = x + residual
        # The global branch returns N x C x 1 x 1 and broadcasts over H x W.
        wei = self.sigmoid(self.local_att(xa) + self.global_att(xa))
        xi = x * wei + residual * (1 - wei)

        # Stage 2: refine the weights from the stage-1 blend. Both branches of
        # this stage must be the *second* pair; previously the stage-1
        # global_att was reused here and global_att2 was dead weight.
        wei2 = self.sigmoid(self.local_att2(xi) + self.global_att2(xi))
        return x * wei2 + residual * (1 - wei2)


class AFF(nn.Module):
    """Attentional feature fusion (AFF) of two feature maps.

    Weights are derived from ``x + residual`` via a local and a global
    attention branch and used to blend ``x`` and ``residual``.

    Args:
        channels: number of channels C of both inputs.
        r: bottleneck reduction ratio; the hidden width is ``channels // r``.
    """

    def __init__(self, channels=64, r=4):
        super().__init__()
        inter_channels = int(channels // r)

        self.local_att = _attention_branch(channels, inter_channels, pooled=False)
        self.global_att = _attention_branch(channels, inter_channels, pooled=True)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
        """Fuse ``x`` and ``residual`` (4-D N x C x H x W maps, C == channels)."""
        xa = x + residual
        wei = self.sigmoid(self.local_att(xa) + self.global_att(xa))
        # The factor 2 means a neutral weight of 0.5 reproduces x + residual;
        # kept as-is because trained weights depend on this scaling.
        return 2 * x * wei + 2 * residual * (1 - wei)


class MS_CAM(nn.Module):
    """Multi-scale channel attention on a single feature map (similar to SE).

    Reweights ``x`` by ``sigmoid(local_att(x) + global_att(x))``, so every
    output element has magnitude at most that of the input element.

    Args:
        channels: number of channels C of the input.
        r: bottleneck reduction ratio; the hidden width is ``channels // r``.
    """

    def __init__(self, channels=64, r=4):
        super().__init__()
        inter_channels = int(channels // r)

        self.local_att = _attention_branch(channels, inter_channels, pooled=False)
        self.global_att = _attention_branch(channels, inter_channels, pooled=True)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        wei = self.sigmoid(self.local_att(x) + self.global_att(x))
        return x * wei


if __name__ == '__main__':
    # Smoke test: run every module on random inputs and print output shapes.
    device = torch.device("cpu")
    channels = 64
    # Batch size > 1 is required in train mode: the global branch's BatchNorm
    # sees a 1x1 map, and BatchNorm needs more than one value per channel.
    x = torch.randn(2, channels, 32, 32, device=device)
    residual = torch.randn(2, channels, 32, 32, device=device)

    for fuse in (DAF(), AFF(channels=channels), iAFF(channels=channels)):
        fuse = fuse.to(device).train()
        print(type(fuse).__name__, tuple(fuse(x, residual).shape))

    cam = MS_CAM(channels=channels).to(device).train()
    print(type(cam).__name__, tuple(cam(x).shape))