import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def sub_mean(x):
    # Subtract the per-image, per-channel mean and return it so it can be re-added later.
    mean = x.mean(2, keepdim=True).mean(3, keepdim=True)
    x -= mean
    return x, mean


def InOutPaddings(x):
    # Build a reflection-padding / cropping pair that rounds H and W up to multiples of 128.
    w, h = x.size(3), x.size(2)
    padding_width, padding_height = 0, 0
    if w != ((w >> 7) << 7):
        padding_width = (((w >> 7) + 1) << 7) - w
    if h != ((h >> 7) << 7):
        padding_height = (((h >> 7) + 1) << 7) - h
    paddingInput = nn.ReflectionPad2d(padding=[padding_width // 2, padding_width - padding_width // 2,
                                               padding_height // 2, padding_height - padding_height // 2])
    paddingOutput = nn.ReflectionPad2d(padding=[0 - padding_width // 2, padding_width // 2 - padding_width,
                                                0 - padding_height // 2, padding_height // 2 - padding_height])
    return paddingInput, paddingOutput


class ConvNorm(nn.Module):
    def __init__(self, in_feat, out_feat, kernel_size, stride=1, norm=False):
        super(ConvNorm, self).__init__()

        reflection_padding = kernel_size // 2
        self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
        self.conv = nn.Conv2d(in_feat, out_feat, stride=stride,
                              kernel_size=kernel_size, bias=True)

        self.norm = norm
        if norm == 'IN':
            self.norm = nn.InstanceNorm2d(out_feat, track_running_stats=True)
        elif norm == 'BN':
            self.norm = nn.BatchNorm2d(out_feat)

    def forward(self, x):
        out = self.reflection_pad(x)
        out = self.conv(out)
        if self.norm:
            out = self.norm(out)
        return out


class UpConvNorm(nn.Module):
    def __init__(self, in_channels, out_channels, mode='transpose', norm=False):
        super(UpConvNorm, self).__init__()

        if mode == 'transpose':
            self.upconv = nn.ConvTranspose2d(in_channels, out_channels,
                                             kernel_size=4, stride=2, padding=1)
        elif mode == 'shuffle':
            self.upconv = nn.Sequential(
                ConvNorm(in_channels, 4 * out_channels, kernel_size=3, stride=1, norm=norm),
                PixelShuffle(2))
        else:
            # out_channels is always going to be the same as in_channels
            self.upconv = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2, align_corners=False),
                ConvNorm(in_channels, out_channels, kernel_size=1, stride=1, norm=norm))

    def forward(self, x):
        out = self.upconv(x)
        return out


class meanShift(nn.Module):
    def __init__(self, rgbRange, rgbMean, sign, nChannel=3):
        super(meanShift, self).__init__()
        if nChannel == 1:
            l = rgbMean[0] * rgbRange * float(sign)

            self.shifter = nn.Conv2d(1, 1, kernel_size=1, stride=1, padding=0)
            self.shifter.weight.data = torch.eye(1).view(1, 1, 1, 1)
            self.shifter.bias.data = torch.Tensor([l])
        elif nChannel == 3:
            r = rgbMean[0] * rgbRange * float(sign)
            g = rgbMean[1] * rgbRange * float(sign)
            b = rgbMean[2] * rgbRange * float(sign)

            self.shifter = nn.Conv2d(3, 3, kernel_size=1, stride=1, padding=0)
            self.shifter.weight.data = torch.eye(3).view(3, 3, 1, 1)
            self.shifter.bias.data = torch.Tensor([r, g, b])
        else:
            r = rgbMean[0] * rgbRange * float(sign)
            g = rgbMean[1] * rgbRange * float(sign)
            b = rgbMean[2] * rgbRange * float(sign)

            self.shifter = nn.Conv2d(6, 6, kernel_size=1, stride=1, padding=0)
            self.shifter.weight.data = torch.eye(6).view(6, 6, 1, 1)
            self.shifter.bias.data = torch.Tensor([r, g, b, r, g, b])

        # Freeze the meanShift layer
        for params in self.shifter.parameters():
            params.requires_grad = False

    def forward(self, x):
        x = self.shifter(x)
        return x
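

# --- Usage sketch (not part of the original module) ---------------------------
# The helpers above are typically combined like this: remove the per-image mean,
# reflect-pad H and W up to multiples of 128 so strided encoders divide evenly,
# run an arbitrary network, then crop back and restore the mean. `net` is a
# placeholder for any nn.Module; the function name is illustrative only.
def _padding_usage_example(net, x):
    x, mean = sub_mean(x)                           # zero-center each image
    paddingInput, paddingOutput = InOutPaddings(x)  # matching pad / crop pair
    x = paddingInput(x)                             # pad up to multiples of 128
    out = net(x)
    out = paddingOutput(out)                        # crop back to the input size
    return out + mean                               # restore the removed mean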


""" CONV - (BN) - RELU - CONV - (BN) """
class ResBlock(nn.Module):
    def __init__(self, in_feat, out_feat, kernel_size=3, reduction=False, bias=True,  # 'reduction' is just a placeholder
                 norm=False, act=nn.ReLU(True), downscale=False):
        super(ResBlock, self).__init__()

        self.body = nn.Sequential(
            ConvNorm(in_feat, out_feat, kernel_size=kernel_size, stride=2 if downscale else 1),
            act,
            ConvNorm(out_feat, out_feat, kernel_size=kernel_size, stride=1)
        )

        self.downscale = None
        if downscale:
            self.downscale = nn.Conv2d(in_feat, out_feat, kernel_size=1, stride=2)

    def forward(self, x):
        res = x
        out = self.body(x)
        if self.downscale is not None:
            res = self.downscale(res)
        out += res

        return out


## Channel Attention (CA) Layer
class CALayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(CALayer, self).__init__()
        # global average pooling: feature --> point
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # feature channel downscale and upscale --> channel weight
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv_du(y)
        return x * y, y


## Residual Channel Attention Block (RCAB)
class RCAB(nn.Module):
    def __init__(self, in_feat, out_feat, kernel_size, reduction, bias=True,
                 norm=False, act=nn.ReLU(True), downscale=False, return_ca=False):
        super(RCAB, self).__init__()

        self.body = nn.Sequential(
            ConvNorm(in_feat, out_feat, kernel_size, stride=2 if downscale else 1, norm=norm),
            act,
            ConvNorm(out_feat, out_feat, kernel_size, stride=1, norm=norm),
            CALayer(out_feat, reduction)
        )
        self.downscale = downscale
        if downscale:
            self.downConv = nn.Conv2d(in_feat, out_feat, kernel_size=3, stride=2, padding=1)
        self.return_ca = return_ca

    def forward(self, x):
        res = x
        out, ca = self.body(x)
        if self.downscale:
            res = self.downConv(res)
        out += res

        if self.return_ca:
            return out, ca
        else:
            return out


## Residual Group (RG)
class ResidualGroup(nn.Module):
    def __init__(self, Block, n_resblocks, n_feat, kernel_size, reduction, act, norm=False):
        super(ResidualGroup, self).__init__()

        modules_body = [Block(n_feat, n_feat, kernel_size, reduction, bias=True, norm=norm, act=act)
                        for _ in range(n_resblocks)]
        modules_body.append(ConvNorm(n_feat, n_feat, kernel_size, stride=1, norm=norm))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        res += x
        return res


def pixel_shuffle(input, scale_factor):
    batch_size, channels, in_height, in_width = input.size()

    out_channels = int(int(channels / scale_factor) / scale_factor)
    out_height = int(in_height * scale_factor)
    out_width = int(in_width * scale_factor)

    if scale_factor >= 1:
        input_view = input.contiguous().view(batch_size, out_channels, scale_factor, scale_factor,
                                             in_height, in_width)
        shuffle_out = input_view.permute(0, 1, 4, 2, 5, 3).contiguous()
    else:
        block_size = int(1 / scale_factor)
        input_view = input.contiguous().view(batch_size, channels, out_height, block_size,
                                             out_width, block_size)
        shuffle_out = input_view.permute(0, 1, 3, 5, 2, 4).contiguous()

    return shuffle_out.view(batch_size, out_channels, out_height, out_width)


class PixelShuffle(nn.Module):
    def __init__(self, scale_factor):
        super(PixelShuffle, self).__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        return pixel_shuffle(x, self.scale_factor)

    def extra_repr(self):
        return 'scale_factor={}'.format(self.scale_factor)


def conv(in_channels, out_channels, kernel_size, stride=1, bias=True, groups=1):
    # Pass the requested stride through instead of hard-coding stride=1.
    return nn.Conv2d(
        in_channels, out_channels, kernel_size=kernel_size,
        padding=kernel_size // 2, stride=stride, bias=bias, groups=groups)


def conv1x1(in_channels, out_channels, stride=1, bias=True, groups=1):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size=1, stride=stride,
        bias=bias, groups=groups)
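

# --- Sanity-check sketch (not part of the original module) --------------------
# For integer scale factors, pixel_shuffle above should agree with the built-in
# torch.nn.functional.pixel_shuffle; this helper only illustrates that check.
def _pixel_shuffle_check():
    x = torch.randn(2, 16, 8, 8)
    assert torch.equal(pixel_shuffle(x, 2), F.pixel_shuffle(x, 2))
    print(pixel_shuffle(x, 2).shape)  # torch.Size([2, 4, 16, 16])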


def conv3x3(in_channels, out_channels, stride=1, padding=1, bias=True, groups=1):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size=3, stride=stride,
        padding=padding, bias=bias, groups=groups)


def conv5x5(in_channels, out_channels, stride=1, padding=2, bias=True, groups=1):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size=5, stride=stride,
        padding=padding, bias=bias, groups=groups)


def conv7x7(in_channels, out_channels, stride=1, padding=3, bias=True, groups=1):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size=7, stride=stride,
        padding=padding, bias=bias, groups=groups)


def upconv2x2(in_channels, out_channels, mode='shuffle'):
    if mode == 'transpose':
        return nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=4, stride=2, padding=1)
    elif mode == 'shuffle':
        return nn.Sequential(
            conv3x3(in_channels, 4 * out_channels),
            PixelShuffle(2))
    else:
        # out_channels is always going to be the same as in_channels
        return nn.Sequential(
            nn.Upsample(mode='bilinear', scale_factor=2, align_corners=False),
            conv1x1(in_channels, out_channels))


class Interpolation(nn.Module):
    def __init__(self, n_resgroups, n_resblocks, n_feats,
                 reduction=16, act=nn.LeakyReLU(0.2, True), norm=False):
        super(Interpolation, self).__init__()

        # define modules: head, body, tail
        self.headConv = conv3x3(n_feats * 2, n_feats)

        modules_body = [
            ResidualGroup(
                RCAB, n_resblocks=n_resblocks, n_feat=n_feats, kernel_size=3,
                reduction=reduction, act=act, norm=norm)
            for _ in range(n_resgroups)]
        self.body = nn.Sequential(*modules_body)

        self.tailConv = conv3x3(n_feats, n_feats)

    def forward(self, x0, x1):
        # Build input tensor
        x = torch.cat([x0, x1], dim=1)
        x = self.headConv(x)

        res = self.body(x)
        res += x

        out = self.tailConv(res)
        return out


class Interpolation_res(nn.Module):
    def __init__(self, n_resgroups, n_resblocks, n_feats,
                 act=nn.LeakyReLU(0.2, True), norm=False):
        super(Interpolation_res, self).__init__()

        # define modules: head, body, tail (reduces concatenated inputs to n_feat)
        self.headConv = conv3x3(n_feats * 2, n_feats)

        modules_body = [ResidualGroup(ResBlock, n_resblocks=n_resblocks, n_feat=n_feats,
                                      kernel_size=3, reduction=0, act=act, norm=norm)
                        for _ in range(n_resgroups)]
        self.body = nn.Sequential(*modules_body)

        self.tailConv = conv3x3(n_feats, n_feats)

    def forward(self, x0, x1):
        # Build input tensor
        x = torch.cat([x0, x1], dim=1)
        x = self.headConv(x)

        res = x
        for m in self.body:
            res = m(res)
        res += x

        x = self.tailConv(res)
        return x
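

# --- Smoke-test sketch (not part of the original module) ----------------------
# Illustrative shapes only: Interpolation fuses two n_feats-channel feature maps
# (e.g. encoder features of two input frames) into one n_feats-channel map of
# the same spatial size. The hyperparameters below are arbitrary examples.
if __name__ == '__main__':
    feat0 = torch.randn(1, 64, 32, 32)
    feat1 = torch.randn(1, 64, 32, 32)
    fusion = Interpolation(n_resgroups=2, n_resblocks=2, n_feats=64)
    out = fusion(feat0, feat1)
    print(out.shape)  # expected: torch.Size([1, 64, 32, 32])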