import math

import cv2
import torch
import torch.nn as nn
class eca_block(nn.Module):
    """Efficient Channel Attention: a 1-D conv over pooled channel descriptors."""

    def __init__(self, channel, b=1, gamma=2):
        super(eca_block, self).__init__()
        # Adaptive kernel size derived from the channel count, forced odd.
        kernel_size = int(abs((math.log(channel, 2) + b) / gamma))
        kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size,
                              padding=(kernel_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # (B, C, H, W) -> (B, C, 1, 1) per-channel descriptors.
        y = self.avg_pool(x)
        # Cross-channel interaction via a 1-D conv, then restore the shape.
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        y = self.sigmoid(y)
        return x * y.expand_as(x)
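
# Added sketch, not part of the original Space: a minimal shape check for
# eca_block. ECA pools to (B, C, 1, 1), convolves across the channel axis,
# and rescales the input, so the output shape equals the input shape.
# The helper name _demo_eca_block is ours.
def _demo_eca_block():
    eca = eca_block(channel=64)      # kernel_size works out to 3 here
    x = torch.randn(2, 64, 32, 32)
    y = eca(x)
    assert y.shape == x.shape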
class DilatedConvNet(nn.Module):
    """A single dilated conv + ReLU; padding == dilation preserves H and W."""

    def __init__(self, in_channels, out_channels, dilation, padding, kernel_size):
        super(DilatedConvNet, self).__init__()
        self.dilated_conv = nn.Conv2d(in_channels, out_channels,
                                      kernel_size=kernel_size, stride=1,
                                      padding=padding, dilation=dilation)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x = self.dilated_conv(x)
        x = self.relu(x)
        return x
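
# Added sketch: with kernel_size=3 and padding == dilation, the receptive
# field grows with the dilation rate while spatial size is preserved, which
# is what lets the parallel RFEM branches below be concatenated.
# The helper name is ours.
def _demo_dilated_conv():
    block = DilatedConvNet(64, 16, dilation=2, padding=2, kernel_size=3)
    x = torch.randn(1, 64, 32, 32)
    assert block(x).shape == (1, 16, 32, 32)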
class LAM(nn.Module):
    """Lightweight attention module: ECA followed by a 6->3 channel fuse.

    The input is two RGB tensors concatenated on the channel axis (6
    channels); `ch` only feeds the ECA kernel-size heuristic.
    """

    def __init__(self, ch=16):
        super().__init__()
        self.eca = eca_block(ch)
        self.conv1 = nn.Conv2d(6, 3, 3, padding=1)

    def forward(self, x):
        x = self.eca(x)
        x = self.conv1(x)
        return x
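
# Added sketch: LAM consumes a 6-channel tensor and fuses it down to 3
# channels; the ECA conv is channel-count-agnostic, so the mismatch between
# ch=16 and the 6 runtime channels only affects its kernel size.
# The helper name is ours.
def _demo_lam():
    lam = LAM(ch=16)
    x = torch.randn(1, 6, 32, 32)
    assert lam(x).shape == (1, 3, 32, 32)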
class RFEM(nn.Module):
    """Receptive-field expansion module for the low-frequency pyramid band."""

    def __init__(self, ch_blocks=64, ch_mask=16):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            # LeakyReLU(True) set negative_slope=1 (an identity);
            # inplace=True is the likely intent.
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(16, ch_blocks, 3, padding=1),
            nn.LeakyReLU(inplace=True))
        # Four parallel branches with growing receptive fields; each emits
        # ch_blocks // 4 channels so the concatenation restores ch_blocks.
        self.dconv1 = DilatedConvNet(ch_blocks, ch_blocks // 4,
                                     kernel_size=3, padding=1, dilation=1)
        self.dconv2 = DilatedConvNet(ch_blocks, ch_blocks // 4,
                                     kernel_size=3, padding=2, dilation=2)
        self.dconv3 = DilatedConvNet(ch_blocks, ch_blocks // 4,
                                     kernel_size=3, padding=3, dilation=3)
        self.dconv4 = nn.Conv2d(ch_blocks, ch_blocks // 4,
                                kernel_size=7, padding=3)
        self.decoder = nn.Sequential(
            nn.Conv2d(ch_blocks, 16, 3, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(16, 3, 3, padding=1),
            nn.LeakyReLU(inplace=True))
        self.lam = LAM(ch_mask)

    def forward(self, x):
        x1 = self.encoder(x)
        x1_1 = self.dconv1(x1)
        x1_2 = self.dconv2(x1)
        x1_3 = self.dconv3(x1)
        x1_4 = self.dconv4(x1)
        x1 = torch.cat([x1_1, x1_2, x1_3, x1_4], dim=1)
        x1 = self.decoder(x1)
        # Residual refinement of the input band, plus a guidance tag
        # predicted from the (input, output) pair.
        out = torch.relu(x + x1)
        mask = self.lam(torch.cat([x, out], dim=1))
        return out, mask
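
# Added sketch: RFEM maps an RGB band to an enhanced band of the same shape
# plus a same-shape guidance tag. ch_blocks must be divisible by 4 so the
# four branch outputs concatenate back to ch_blocks. The helper name is ours.
def _demo_rfem():
    rfem = RFEM(ch_blocks=64, ch_mask=16)
    x = torch.randn(1, 3, 64, 64)
    out, mask = rfem(x)
    assert out.shape == x.shape and mask.shape == x.shape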
class ATEM(nn.Module):
    """Adaptive texture enhancement: scale-and-shift (FiLM-style) modulation
    of the high-frequency band, conditioned on the upsampled tag."""

    def __init__(self, in_ch=3, inter_ch=32, out_ch=3, kernel_size=3):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_ch, inter_ch, kernel_size, padding=kernel_size // 2),
            nn.LeakyReLU(inplace=True),
        )
        self.shift_conv = nn.Sequential(
            nn.Conv2d(in_ch, inter_ch, kernel_size, padding=kernel_size // 2))
        self.scale_conv = nn.Sequential(
            nn.Conv2d(in_ch, inter_ch, kernel_size, padding=kernel_size // 2))
        self.decoder = nn.Sequential(
            nn.Conv2d(inter_ch, out_ch, kernel_size, padding=kernel_size // 2))

    def forward(self, x, tag):
        x = self.encoder(x)
        scale = self.scale_conv(tag)
        shift = self.shift_conv(tag)
        # Residual feature-wise modulation: x + (x * scale + shift).
        x = x + (x * scale + shift)
        x = self.decoder(x)
        return x
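
# Added sketch: ATEM modulates encoded features with a scale and shift
# predicted from the tag, so x and tag must share spatial size (their
# channel counts are both fixed by in_ch). The helper name is ours.
def _demo_atem():
    atem = ATEM(in_ch=3, inter_ch=32, out_ch=3)
    x = torch.randn(1, 3, 32, 32)
    tag = torch.randn(1, 3, 32, 32)
    assert atem(x, tag).shape == (1, 3, 32, 32)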
class Trans_high(nn.Module):
    """Residual wrapper around ATEM for one high-frequency pyramid level."""

    def __init__(self, in_ch=3, inter_ch=16, out_ch=3, kernel_size=3):
        super().__init__()
        self.atem = ATEM(in_ch, inter_ch, out_ch, kernel_size)

    def forward(self, x, tag):
        x = x + self.atem(x, tag)
        return x
class Up_tag(nn.Module):
    """Bilinear 2x upsampling of the tag, followed by a pointwise conv."""

    def __init__(self, kernel_size=1, ch=3):
        super().__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(ch, ch, kernel_size, stride=1,
                      padding=kernel_size // 2, bias=False))

    def forward(self, x):
        return self.up(x)
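
# Added sketch: Up_tag doubles the tag's spatial resolution so it matches
# the next-finer pyramid level. The helper name is ours.
def _demo_up_tag():
    up = Up_tag(kernel_size=1, ch=3)
    assert up(torch.randn(1, 3, 16, 16)).shape == (1, 3, 32, 32)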
class Lap_Pyramid_Conv(nn.Module):
    """Laplacian pyramid decomposition/reconstruction with a fixed
    (non-trainable) Gaussian kernel."""

    def __init__(self, num_high=3, kernel_size=5, channels=3):
        super().__init__()
        self.num_high = num_high
        self.kernel = self.gauss_kernel(kernel_size, channels)

    def gauss_kernel(self, kernel_size, channels):
        # Outer product of two 1-D Gaussian kernels, replicated per channel
        # for use as a depthwise convolution.
        kernel = cv2.getGaussianKernel(kernel_size, 0).dot(
            cv2.getGaussianKernel(kernel_size, 0).T)
        kernel = torch.FloatTensor(kernel).unsqueeze(0).repeat(
            channels, 1, 1, 1)
        kernel = torch.nn.Parameter(data=kernel, requires_grad=False)
        return kernel

    def conv_gauss(self, x, kernel):
        n_channels, _, kw, kh = kernel.shape
        x = torch.nn.functional.pad(x, (kw // 2, kh // 2, kw // 2, kh // 2),
                                    mode='reflect')
        return torch.nn.functional.conv2d(x, kernel, groups=n_channels)

    def downsample(self, x):
        return x[:, :, ::2, ::2]

    def pyramid_down(self, x):
        return self.downsample(self.conv_gauss(x, self.kernel))

    def upsample(self, x):
        # Zero-stuff, scale by 4 to preserve brightness, then blur.
        up = torch.zeros((x.size(0), x.size(1), x.size(2) * 2, x.size(3) * 2),
                         device=x.device)
        up[:, :, ::2, ::2] = x * 4
        return self.conv_gauss(up, self.kernel)

    def pyramid_decom(self, img):
        self.kernel = self.kernel.to(img.device)
        current = img
        pyr = []
        for _ in range(self.num_high):
            down = self.pyramid_down(current)
            up = self.upsample(down)
            diff = current - up
            pyr.append(diff)
            current = down
        pyr.append(current)
        # [high-freq residuals (finest first), ..., low-freq base]
        return pyr

    def pyramid_recons(self, pyr):
        # Expects the low-freq base first, then residuals coarse-to-fine.
        image = pyr[0]
        for level in pyr[1:]:
            up = self.upsample(image)
            image = up + level
        return image
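
# Added sketch: the pyramid is numerically (near-)lossless. pyramid_decom
# returns [finest residual, ..., low-freq base] while pyramid_recons expects
# the base first, hence the reversal. The helper name is ours.
def _demo_lap_pyramid():
    pyramid = Lap_Pyramid_Conv(num_high=2, kernel_size=5)
    x = torch.randn(1, 3, 64, 64)    # H, W divisible by 2 ** num_high
    pyr = pyramid.pyramid_decom(x)
    recon = pyramid.pyramid_recons(pyr[::-1])
    assert torch.allclose(recon, x, atol=1e-5)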
class FAENet(nn.Module):
    def __init__(self,
                 num_high=1,
                 ch_blocks=32,
                 up_ksize=1,
                 high_ch=32,
                 high_ksize=3,
                 ch_mask=32,
                 gauss_kernel=7):
        super().__init__()
        self.num_high = num_high
        self.lap_pyramid = Lap_Pyramid_Conv(num_high, gauss_kernel)
        self.rfem = RFEM(ch_blocks, ch_mask)
        # One tag upsampler and one high-frequency translator per level.
        for i in range(self.num_high):
            setattr(self, 'up_tag_layer_{}'.format(i), Up_tag(up_ksize, ch=3))
            setattr(self, 'trans_high_layer_{}'.format(i),
                    Trans_high(3, high_ch, 3, high_ksize))

    def forward(self, x):
        pyrs = self.lap_pyramid.pyramid_decom(img=x)
        trans_pyrs = []
        # Enhance the low-frequency base and predict the guidance tag.
        trans_pyr, tag = self.rfem(pyrs[-1])
        trans_pyrs.append(trans_pyr)
        # Upsample the tag once per level to match each band's resolution.
        common_tag = []
        for i in range(self.num_high):
            tag = getattr(self, 'up_tag_layer_{}'.format(i))(tag)
            common_tag.append(tag)
        # Translate each high-frequency band, coarse to fine.
        for i in range(self.num_high):
            trans_pyr = getattr(self, 'trans_high_layer_{}'.format(i))(
                pyrs[-2 - i], common_tag[i])
            trans_pyrs.append(trans_pyr)
        out = self.lap_pyramid.pyramid_recons(trans_pyrs)
        return out
if __name__ == "__main__":
    faenet = FAENet()
    num_params = sum(p.numel() for p in faenet.parameters())
    # Parameter count, plus the approximate checkpoint size at float32
    # (4 bytes per parameter); the original labeled the raw count as "MB".
    print("FAENet parameters: {:.2f}K ({:.2f} MB as float32)".format(
        num_params / 1e3, num_params * 4 / 1024 ** 2))
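    # Added sketch, not in the original file: a quick forward-pass check.
    # Input H and W should be divisible by 2 ** num_high so the pyramid
    # halves and re-doubles cleanly; 64x64 is an arbitrary choice.
    x = torch.randn(1, 3, 64, 64)
    with torch.no_grad():
        y = faenet(x)
    print("input:", tuple(x.shape), "-> output:", tuple(y.shape))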