Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| # -*- coding:utf-8 -*- | |
| ############################################################# | |
| # File: esa.py | |
| # Created Date: Tuesday April 28th 2022 | |
| # Author: Chen Xuanhong | |
| # Email: chenxuanhongzju@outlook.com | |
| # Last Modified: Thursday, 20th April 2023 9:28:06 am | |
| # Modified By: Chen Xuanhong | |
| # Copyright (c) 2020 Shanghai Jiao Tong University | |
| ############################################################# | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from .layernorm import LayerNorm2d | |
def moment(x, dim=(2, 3), k=2):
    """Return the k-th central moment of ``x`` over the axes in ``dim``.

    The mean over ``dim`` is subtracted from ``x``, the difference is raised
    to the power ``k``, summed over ``dim`` and divided by the number of
    reduced elements (so ``k=2`` is the biased variance).

    Args:
        x: 4-D tensor.
        dim: Axis index or sequence of axis indices to reduce over.
            Defaults to the last two axes, ``(2, 3)``.
        k: Order of the moment.

    Returns:
        Tensor with the axes in ``dim`` removed; for the default ``dim``
        the shape is ``(x.size(0), x.size(1))``.

    Raises:
        AssertionError: if ``x`` is not 4-D.
    """
    assert x.dim() == 4, "moment() expects a 4-D tensor"
    dims = (dim,) if isinstance(dim, int) else tuple(dim)
    # keepdim=True keeps the mean broadcastable against x for any choice of
    # axes, not only the trailing two.
    centered = x - torch.mean(x, dim=dims, keepdim=True)
    # Normalise by the element count of the axes actually reduced.
    count = 1
    for axis in dims:
        count *= x.size(axis)
    return (1 / count) * torch.sum(torch.pow(centered, k), dim=dims)
class ESA(nn.Module):
    """Spatial attention gate that rescales its input element-wise.

    Modification of Enhanced Spatial Attention (ESA), proposed in
    `Residual Feature Aggregation Network for Image Super-Resolution`.
    Note: `conv_max` and `conv3_` are NOT used here, so the corresponding
    code has been removed.

    Args:
        esa_channels: Channel width used inside the attention branch.
        n_feats: Channel count of the input (and of the returned tensor).
        conv: Factory used to build every convolution; called as
            ``conv(in_channels, out_channels, kernel_size=..., ...)``.
    """

    def __init__(self, esa_channels, n_feats, conv=nn.Conv2d):
        super().__init__()
        width = esa_channels
        # Submodules are created in a fixed order: conv1, conv_f, conv2,
        # conv3, conv4.
        self.conv1 = conv(n_feats, width, kernel_size=1)
        self.conv_f = conv(width, width, kernel_size=1)
        self.conv2 = conv(width, width, kernel_size=3, stride=2, padding=0)
        self.conv3 = conv(width, width, kernel_size=3, padding=1)
        self.conv4 = conv(width, n_feats, kernel_size=1)
        self.sigmoid = nn.Sigmoid()
        # Registered but not called in forward().
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``x`` multiplied by a sigmoid mask computed from ``x``."""
        out_size = (x.shape[2], x.shape[3])
        squeezed = self.conv1(x)

        # Coarse branch: strided conv, then 7x7 / stride-3 max pooling,
        # a 3x3 conv, and bilinear resize back to the input resolution.
        coarse = F.max_pool2d(self.conv2(squeezed), kernel_size=7, stride=3)
        coarse = F.interpolate(
            self.conv3(coarse), out_size, mode="bilinear", align_corners=False
        )

        # Full-resolution 1x1 branch is added before the final projection.
        gate = self.sigmoid(self.conv4(coarse + self.conv_f(squeezed)))
        return x * gate
class LK_ESA(nn.Module):
    """Attention gate built from long, thin grouped convolutions.

    The input is reduced to ``esa_channels`` by a 1x1 conv, passed through a
    (1, 17) + (1, 3) pair of grouped convs, then a (17, 1) + (3, 1) pair,
    merged with a 1x1 shortcut, projected back to ``n_feats`` and squashed
    with a sigmoid. The result multiplies the input element-wise.

    Args:
        esa_channels: Channel width inside the attention branch. The strip
            convs use ``groups=2``.
        n_feats: Channel count of the input (and of the returned tensor).
        conv: Factory used for the 1x1 convolutions.
        kernel_expand: Multiplier applied to the channel count of the strip
            convs only.
        bias: Whether the strip convs carry a bias term.

    NOTE(review): ``conv1`` emits ``esa_channels`` channels while the strip
    convs are built for ``esa_channels * kernel_expand``, so any
    ``kernel_expand`` other than 1 looks like it would hit a channel
    mismatch in ``forward`` -- confirm intended usage.
    """

    def __init__(
        self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
    ):
        super().__init__()
        reduced = esa_channels
        long_k = 17
        long_pad = long_k // 2
        strip_width = reduced * kernel_expand

        def strip_conv(kernel, pad):
            # Two-group conv whose padding keeps the spatial size unchanged.
            return nn.Conv2d(
                strip_width,
                strip_width,
                kernel_size=kernel,
                padding=pad,
                groups=2,
                bias=bias,
            )

        # Submodules are created in a fixed order.
        self.conv1 = conv(n_feats, reduced, kernel_size=1)
        self.conv_f = conv(reduced, reduced, kernel_size=1)
        self.vec_conv = strip_conv((1, long_k), (0, long_pad))
        self.vec_conv3x1 = strip_conv((1, 3), (0, 1))
        self.hor_conv = strip_conv((long_k, 1), (long_pad, 0))
        self.hor_conv1x3 = strip_conv((3, 1), (1, 0))
        self.conv4 = conv(reduced, n_feats, kernel_size=1)
        self.sigmoid = nn.Sigmoid()
        # Registered but not called in forward().
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``x`` multiplied by a sigmoid mask computed from ``x``."""
        squeezed = self.conv1(x)
        first_pass = self.vec_conv(squeezed) + self.vec_conv3x1(squeezed)
        second_pass = self.hor_conv(first_pass) + self.hor_conv1x3(first_pass)
        gate = self.sigmoid(self.conv4(second_pass + self.conv_f(squeezed)))
        return x * gate
class LK_ESA_LN(nn.Module):
    """Long-kernel attention gate with a ``LayerNorm2d`` in front.

    The input is normalised by ``LayerNorm2d(n_feats)``, reduced to
    ``esa_channels`` by a 1x1 conv, passed through a (1, 17) + (1, 3) pair
    of grouped convs, then a (17, 1) + (3, 1) pair, merged with a 1x1
    shortcut, projected back to ``n_feats`` and squashed with a sigmoid.
    The mask multiplies the *un-normalised* input.

    Args:
        esa_channels: Channel width inside the attention branch. The strip
            convs use ``groups=2``.
        n_feats: Channel count of the input (and of the returned tensor).
        conv: Factory used for the 1x1 convolutions.
        kernel_expand: Multiplier applied to the channel count of the strip
            convs only.
        bias: Whether the strip convs carry a bias term.

    NOTE(review): ``conv1`` emits ``esa_channels`` channels while the strip
    convs are built for ``esa_channels * kernel_expand``, so any
    ``kernel_expand`` other than 1 looks like it would hit a channel
    mismatch in ``forward`` -- confirm intended usage.
    """

    def __init__(
        self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
    ):
        super().__init__()
        reduced = esa_channels
        long_k = 17
        long_pad = long_k // 2
        strip_width = reduced * kernel_expand

        def strip_conv(kernel, pad):
            # Two-group conv whose padding keeps the spatial size unchanged.
            return nn.Conv2d(
                strip_width,
                strip_width,
                kernel_size=kernel,
                padding=pad,
                groups=2,
                bias=bias,
            )

        # Submodules are created in a fixed order.
        self.conv1 = conv(n_feats, reduced, kernel_size=1)
        self.conv_f = conv(reduced, reduced, kernel_size=1)
        self.norm = LayerNorm2d(n_feats)
        self.vec_conv = strip_conv((1, long_k), (0, long_pad))
        self.vec_conv3x1 = strip_conv((1, 3), (0, 1))
        self.hor_conv = strip_conv((long_k, 1), (long_pad, 0))
        self.hor_conv1x3 = strip_conv((3, 1), (1, 0))
        self.conv4 = conv(reduced, n_feats, kernel_size=1)
        self.sigmoid = nn.Sigmoid()
        # Registered but not called in forward().
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``x`` multiplied by a mask computed from normalised ``x``."""
        squeezed = self.conv1(self.norm(x))
        first_pass = self.vec_conv(squeezed) + self.vec_conv3x1(squeezed)
        second_pass = self.hor_conv(first_pass) + self.hor_conv1x3(first_pass)
        gate = self.sigmoid(self.conv4(second_pass + self.conv_f(squeezed)))
        return x * gate
class AdaGuidedFilter(nn.Module):
    """Gate the input with a locally adaptive, guided-filter-style response.

    For every position a local mean and variance are measured with a
    ``(2 * r + 1)`` square box filter (``r = 5``). With
    ``A = var / (var + epsilon)`` and ``b = (1 - A) * mean`` the module
    returns ``x * (A * x + b)``.

    Args:
        esa_channels: Accepted for signature parity with the other
            attention blocks in this file; not used.
        n_feats: Input channel count of the ``fc`` layer.
        conv: Not used.
        kernel_expand: Not used.
        bias: Not used.
    """

    def __init__(
        self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
    ):
        super().__init__()
        # gap / fc belong to a learnable-epsilon variant that forward() does
        # not currently run. They stay registered so the module's
        # state_dict keys (fc.weight, fc.bias) are unchanged.
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(
            in_channels=n_feats,
            out_channels=1,
            kernel_size=1,
            padding=0,
            stride=1,
            groups=1,
            bias=True,
        )
        self.r = 5

    def box_filter(self, x, r):
        """Average ``x`` per channel over a ``(2 * r + 1)`` square window.

        Zero padding of ``r`` keeps the spatial size, so values near the
        border are attenuated; ``forward`` compensates by dividing by the
        filtered all-ones map.

        Args:
            x: Tensor of shape ``(batch, channel, height, width)``.
            r: Window radius.

        Returns:
            Tensor with the same shape, dtype and device as ``x``.
        """
        channel = x.shape[1]
        kernel_size = 2 * r + 1
        weight = 1.0 / (kernel_size**2)
        # The kernel must share x's dtype: F.conv2d rejects mixed dtypes, so
        # a fixed float32 kernel would break float64 / half inputs.
        box_kernel = torch.full(
            (channel, 1, kernel_size, kernel_size),
            weight,
            dtype=x.dtype,
            device=x.device,
        )
        return F.conv2d(x, weight=box_kernel, stride=1, padding=r, groups=channel)

    def forward(self, x):
        """Return ``x * (A * x + b)`` using local statistics of ``x``."""
        _, _, H, W = x.shape
        # Fraction of each window that lies inside the image; dividing by it
        # undoes the attenuation caused by zero padding at the borders.
        N = self.box_filter(
            torch.ones((1, 1, H, W), dtype=x.dtype, device=x.device), self.r
        )
        # Fixed regulariser. The disabled alternative derived it from the
        # input as torch.pow(self.fc(self.gap(x)), 2).
        epsilon = 1e-2
        mean_x = self.box_filter(x, self.r) / N
        var_x = self.box_filter(x * x, self.r) / N - mean_x * mean_x
        A = var_x / (var_x + epsilon)
        b = (1 - A) * mean_x
        # A and b are applied per pixel. The disabled alternative box-filtered
        # A and b (each divided by N) before applying them to x.
        m = A * x + b
        return x * m
class AdaConvGuidedFilter(nn.Module):
    """Gate the input with a mask derived from depthwise strip convolutions.

    ``x`` goes through a (1, 17) then a (17, 1) depthwise conv. The squared
    response ``s`` and a per-channel offset ``e`` (1x1 conv on the globally
    average-pooled response) give ``w = s / (s + e)``, and the module
    returns ``x * (w * x + (1 - w))``.

    Args:
        esa_channels: Channel count of every layer here; ``forward`` feeds
            ``x`` straight into these layers.
        n_feats: Not used.
        conv: Factory used for ``conv_f``.
        kernel_expand: Not used.
        bias: Whether the two strip convs carry a bias term.

    NOTE(review): the offset ``e`` is an unconstrained conv output, so
    ``s + e`` can presumably reach zero or go negative -- confirm whether a
    non-negativity constraint was intended.
    """

    def __init__(
        self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
    ):
        super().__init__()
        channels = esa_channels
        long_k = 17
        long_pad = long_k // 2

        # Registered but not called in forward().
        self.conv_f = conv(channels, channels, kernel_size=1)
        # Depthwise (groups == channels) strip convs; padding keeps H and W.
        self.vec_conv = nn.Conv2d(
            channels,
            channels,
            kernel_size=(1, long_k),
            padding=(0, long_pad),
            groups=channels,
            bias=bias,
        )
        self.hor_conv = nn.Conv2d(
            channels,
            channels,
            kernel_size=(long_k, 1),
            padding=(long_pad, 0),
            groups=channels,
            bias=bias,
        )
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(
            channels,
            channels,
            kernel_size=1,
            padding=0,
            stride=1,
            groups=1,
            bias=True,
        )

    def forward(self, x):
        """Return ``x * (w * x + (1 - w))`` with ``w`` computed from ``x``."""
        response = self.hor_conv(self.vec_conv(x))
        energy = torch.pow(response, 2)
        offset = self.fc(self.gap(response))
        blend = energy / (energy + offset)
        return x * (blend * x + (1 - blend))