InPeerReview's picture
Upload 161 files
226675b verified
import torch
import torch.nn as nn
import torch.nn.functional as F
class ChannelAttention(nn.Module):
    """Channel-wise gating of a 4-D feature map.

    The input is average-pooled over its two trailing dimensions to one value
    per channel, passed through a two-layer bottleneck
    (``in_channels -> in_channels // reduction -> in_channels``) with a ReLU
    in between and a sigmoid at the end, and the resulting per-channel gate in
    ``(0, 1)`` is multiplied back onto the input.

    Args:
        in_channels: Size of dimension 1 of the tensors passed to ``forward``.
        reduction: Divisor applied to ``in_channels`` to obtain the bottleneck
            width. Defaults to 4. Note that the width is computed with integer
            division, so ``in_channels < reduction`` gives a zero-width
            bottleneck.

    Raises:
        ValueError: If ``reduction`` is not a positive integer.
    """

    def __init__(self, in_channels: int, reduction: int = 4) -> None:
        super().__init__()
        if reduction < 1:
            raise ValueError(
                f"reduction must be a positive integer, got {reduction}"
            )
        self.in_channels = in_channels
        self.reduction = reduction
        hidden_channels = in_channels // reduction
        # Attribute names are kept as-is so state-dict keys do not change.
        self.linear_1 = nn.Linear(in_channels, hidden_channels)
        self.linear_2 = nn.Linear(hidden_channels, in_channels)

    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        """Scale every channel of ``input_`` by its learned gate.

        Args:
            input_: Tensor with exactly four dimensions; dimension 1 must have
                size ``in_channels``.

        Returns:
            Tensor with the same shape as ``input_``.
        """
        # The 4-way unpack rejects inputs that are not 4-D.
        n_b, n_c, _, _ = input_.size()

        # Squeeze: one averaged value per (sample, channel).
        feats = F.adaptive_avg_pool2d(input_, (1, 1)).view(n_b, n_c)

        # Excite: bottleneck followed by a sigmoid gate in (0, 1).
        feats = F.relu(self.linear_1(feats))
        feats = torch.sigmoid(self.linear_2(feats))

        # Broadcasting the (N, C, 1, 1) gate over the spatial dimensions
        # avoids allocating a full-size copy of the attention map.
        return input_ * feats.view(n_b, n_c, 1, 1)