Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| ########################################################################## | |
def conv(in_channels, out_channels, kernel_size, bias=False, stride=1):
    """Build a 2-D convolution padded by half the kernel size.

    Args:
        in_channels: Number of input feature channels.
        out_channels: Number of output feature channels.
        kernel_size: Side length of the square kernel (an int; it is
            floor-divided by two to obtain the padding).
        bias: Whether the convolution has a learnable bias term.
        stride: Convolution stride.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    half_kernel = kernel_size // 2
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        padding=half_kernel,
        bias=bias,
        stride=stride,
    )
def conv3x3(in_chn, out_chn, bias=True):
    """Build a 3x3, stride-1 convolution with one pixel of padding.

    Args:
        in_chn: Number of input feature channels.
        out_chn: Number of output feature channels.
        bias: Whether the convolution has a learnable bias term.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    return nn.Conv2d(
        in_chn,
        out_chn,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=bias,
    )
def conv_down(in_chn, out_chn, bias=False):
    """Build a 4x4 convolution with stride 2 and one pixel of padding.

    Args:
        in_chn: Number of input feature channels.
        out_chn: Number of output feature channels.
        bias: Whether the convolution has a learnable bias term.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    return nn.Conv2d(
        in_chn,
        out_chn,
        kernel_size=4,
        stride=2,
        padding=1,
        bias=bias,
    )
| ########################################################################## | |
| ## Supervised Attention Module (SAM) | |
class SAM(nn.Module):
    """Gate features with an attention map derived from a 3-channel image.

    Args:
        n_feat: Number of feature channels of the input ``x``.
        kernel_size: Kernel size shared by the three convolutions.
        bias: Whether the convolutions have a learnable bias term.
    """

    def __init__(self, n_feat, kernel_size, bias):
        super().__init__()
        self.conv1 = conv(n_feat, n_feat, kernel_size, bias=bias)
        # Projects the features down to a 3-channel image-like tensor.
        self.conv2 = conv(n_feat, 3, kernel_size, bias=bias)
        # Lifts the 3-channel tensor back to n_feat channels for gating.
        self.conv3 = conv(3, n_feat, kernel_size, bias=bias)

    def forward(self, x, x_img):
        """Return the attention-refined features and the 3-channel image.

        Args:
            x: Input feature tensor with ``n_feat`` channels.
            x_img: Tensor added to the 3-channel projection of ``x``
                (presumably the input image -- confirm against callers).

        Returns:
            A ``(features, img)`` tuple: the gated features with the input
            added back as a residual, and the 3-channel tensor ``img``.
        """
        img = self.conv2(x) + x_img
        attention = torch.sigmoid(self.conv3(img))
        refined = self.conv1(x) * attention
        refined = refined + x
        return refined, img
| ########################################################################## | |
| ## Spatial Attention | |
class SALayer(nn.Module):
    """Spatial attention: rescale every pixel by a single learned gate.

    Args:
        kernel_size: Kernel size of the convolution that turns the pooled
            maps into the attention map.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        # Two input maps (channel-wise mean and max) -> one attention map.
        self.conv1 = nn.Conv2d(
            2, 1, kernel_size, padding=kernel_size // 2, bias=False
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` multiplied by its spatial attention map."""
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        pooled = torch.cat([channel_mean, channel_max], dim=1)
        gate = self.sigmoid(self.conv1(pooled))
        # The single-channel gate broadcasts across the channel dimension.
        return x * gate
| # Spatial Attention Block (SAB) | |
class SAB(nn.Module):
    """Residual block: conv -> act -> conv, spatial attention, skip add.

    Args:
        n_feat: Number of feature channels (input and output).
        kernel_size: Kernel size of the two body convolutions.
        reduction: Not used by this block; the parameter is accepted but
            never read.
        bias: Whether the body convolutions have a learnable bias term.
        act: Activation module placed between the two convolutions.
    """

    def __init__(self, n_feat, kernel_size, reduction, bias, act):
        super().__init__()
        self.body = nn.Sequential(
            conv(n_feat, n_feat, kernel_size, bias=bias),
            act,
            conv(n_feat, n_feat, kernel_size, bias=bias),
        )
        self.SA = SALayer(kernel_size=7)

    def forward(self, x):
        """Return the attended body output with ``x`` added as a residual."""
        out = self.SA(self.body(x))
        # In-place add, as in the original; ``out`` is a fresh tensor here.
        out += x
        return out
| ########################################################################## | |
| ## Pixel Attention | |
class PALayer(nn.Module):
    """Pixel attention: a per-pixel, per-channel gate from two 1x1 convs.

    Args:
        channel: Number of input (and output) feature channels.
        reduction: Divisor applied to ``channel`` to get the bottleneck
            width.
        bias: Whether the 1x1 convolutions have a learnable bias term.
    """

    def __init__(self, channel, reduction=16, bias=False):
        super().__init__()
        # Clamp to one channel: when channel < reduction the floor division
        # yields 0, which would create an empty bottleneck.
        hidden = max(1, channel // reduction)
        self.pa = nn.Sequential(
            nn.Conv2d(channel, hidden, 1, padding=0, bias=bias),
            nn.ReLU(inplace=True),
            # The gate keeps ``channel`` outputs (one weight per channel per
            # pixel); the original note "channel <-> 1" marks a single-map
            # gate as the alternative.
            nn.Conv2d(hidden, channel, 1, padding=0, bias=bias),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled element-wise by its pixel attention gate."""
        gate = self.pa(x)
        return x * gate
| ## Pixel Attention Block (PAB) | |
class PAB(nn.Module):
    """Residual block: conv -> act -> conv, pixel attention, skip add.

    Args:
        n_feat: Number of feature channels (input and output).
        kernel_size: Kernel size of the two body convolutions.
        reduction: Channel reduction factor passed to ``PALayer``.
        bias: Whether the convolutions have a learnable bias term.
        act: Activation module placed between the two convolutions.
    """

    def __init__(self, n_feat, kernel_size, reduction, bias, act):
        super().__init__()
        # Build the body layers first, then the attention layer, so the
        # submodules are created and registered in the established order.
        body_layers = (
            conv(n_feat, n_feat, kernel_size, bias=bias),
            act,
            conv(n_feat, n_feat, kernel_size, bias=bias),
        )
        self.PA = PALayer(n_feat, reduction, bias=bias)
        self.body = nn.Sequential(*body_layers)

    def forward(self, x):
        """Return the attended body output with ``x`` added as a residual."""
        out = self.PA(self.body(x))
        # In-place add, as in the original; ``out`` is a fresh tensor here.
        out += x
        return out
| ########################################################################## | |
| ## Channel Attention Layer | |
class CALayer(nn.Module):
    """Channel attention: rescale each channel by a globally pooled gate.

    Args:
        channel: Number of input (and output) feature channels.
        reduction: Divisor applied to ``channel`` to get the bottleneck
            width.
        bias: Whether the 1x1 convolutions have a learnable bias term.
    """

    def __init__(self, channel, reduction=16, bias=False):
        super().__init__()
        # Global average pooling: each feature map collapses to one value.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Clamp to one channel: when channel < reduction the floor division
        # yields 0, which would create an empty bottleneck.
        hidden = max(1, channel // reduction)
        # Channel squeeze and expansion producing one weight per channel.
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, hidden, 1, padding=0, bias=bias),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channel, 1, padding=0, bias=bias),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled by its per-channel attention weights."""
        weights = self.conv_du(self.avg_pool(x))
        # The 1x1 spatial weights broadcast over height and width.
        return x * weights
| ## Channel Attention Block (CAB) | |
class CAB(nn.Module):
    """Residual block: conv -> act -> conv, channel attention, skip add.

    Args:
        n_feat: Number of feature channels (input and output).
        kernel_size: Kernel size of the two body convolutions.
        reduction: Channel reduction factor passed to ``CALayer``.
        bias: Whether the convolutions have a learnable bias term.
        act: Activation module placed between the two convolutions.
    """

    def __init__(self, n_feat, kernel_size, reduction, bias, act):
        super().__init__()
        # Build the body layers first, then the attention layer, so the
        # submodules are created and registered in the established order.
        body_layers = (
            conv(n_feat, n_feat, kernel_size, bias=bias),
            act,
            conv(n_feat, n_feat, kernel_size, bias=bias),
        )
        self.CA = CALayer(n_feat, reduction, bias=bias)
        self.body = nn.Sequential(*body_layers)

    def forward(self, x):
        """Return the attended body output with ``x`` added as a residual."""
        out = self.CA(self.body(x))
        # In-place add, as in the original; ``out`` is a fresh tensor here.
        out += x
        return out
if __name__ == "__main__":
    # Smoke test: print the module tree of one attention block, then report
    # its parameter count, operation count and forward-pass latency.
    import time

    from thop import profile

    # Swap PAB for CAB or SAB (same constructor arguments) to profile the
    # other attention blocks.
    layer = PAB(64, 3, 4, False, nn.PReLU())
    for idx, m in enumerate(layer.modules()):
        print(idx, "-", m)

    rgb = torch.ones(1, 64, 256, 256, dtype=torch.float, requires_grad=False)

    # Time only the forward pass: the clock starts after the input is
    # allocated and stops before profiling, which runs its own forward pass.
    start = time.perf_counter()
    with torch.no_grad():
        layer(rgb)
    elapsed_ms = (time.perf_counter() - start) * 1000  # seconds -> ms

    flops, params = profile(layer, inputs=(rgb,))
    print('parameters:', params)
    print('flops', flops)
    print('time: {:.4f}ms'.format(elapsed_ms))