import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


def default_conv(in_channels, out_channels, kernel_size, bias=True):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size // 2), bias=bias)


class MeanShift(nn.Conv2d):
    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data.div_(std.view(3, 1, 1, 1))
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
        self.bias.data.div_(std)
        # freeze the parameters themselves; setting requires_grad on the
        # module object has no effect
        for p in self.parameters():
            p.requires_grad = False


class BasicBlock(nn.Sequential):
    def __init__(
            self, in_channels, out_channels, kernel_size, stride=1,
            bias=False, bn=True, act=nn.ReLU(True)):
        m = [nn.Conv2d(
            in_channels, out_channels, kernel_size,
            padding=(kernel_size // 2), stride=stride, bias=bias)]
        if bn:
            m.append(nn.BatchNorm2d(out_channels))
        if act is not None:
            m.append(act)
        super(BasicBlock, self).__init__(*m)


class ResBlock(nn.Module):
    def __init__(
            self, conv, n_feat, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(ResBlock, self).__init__()
        m = []
        for i in range(2):
            m.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn:
                m.append(nn.BatchNorm2d(n_feat))
            if i == 0:
                m.append(act)
        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x).mul(self.res_scale)
        res += x
        return res


class Upsampler(nn.Sequential):
    def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True):
        m = []
        if (scale & (scale - 1)) == 0:  # is scale = 2^n?
            for _ in range(int(math.log(scale, 2))):
                m.append(conv(n_feat, 4 * n_feat, 3, bias))
                m.append(nn.PixelShuffle(2))
                if bn:
                    m.append(nn.BatchNorm2d(n_feat))
                if act:
                    m.append(act())
        elif scale == 3:
            m.append(conv(n_feat, 9 * n_feat, 3, bias))
            m.append(nn.PixelShuffle(3))
            if bn:
                m.append(nn.BatchNorm2d(n_feat))
            if act:
                m.append(act())
        else:
            raise NotImplementedError
        super(Upsampler, self).__init__(*m)


# add NonLocalBlock2D
# reference: https://github.com/AlexHex7/Non-local_pytorch/blob/master/lib/non_local_simple_version.py
class NonLocalBlock2D(nn.Module):
    def __init__(self, in_channels, inter_channels):
        super(NonLocalBlock2D, self).__init__()

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        self.g = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)
        self.W = nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
                           kernel_size=1, stride=1, padding=0)
        # zero-init W so the block starts out as an identity mapping
        nn.init.constant_(self.W.weight, 0)
        nn.init.constant_(self.W.bias, 0)
        self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                               kernel_size=1, stride=1, padding=0)
        self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        batch_size = x.size(0)

        # embed and flatten spatial dims: [B, C', H*W] -> [B, H*W, C']
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)

        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)

        # pairwise affinities over all positions: [B, H*W, H*W]
        f = torch.matmul(theta_x, phi_x)
        f_div_C = F.softmax(f, dim=1)

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x
        return z
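

# Illustrative sanity check (not part of the original module; the function
# name is ours). NonLocalBlock2D is shape-preserving because W projects back
# to in_channels, a property the attention modules below rely on. Call it
# manually to verify the shapes.
def _check_nonlocal_shapes():
    block = NonLocalBlock2D(in_channels=64, inter_channels=32)
    x = torch.rand(2, 64, 16, 16)
    # the pairwise affinity matrix inside is [B, H*W, H*W] = [2, 256, 256]
    assert block(x).shape == x.shape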


## define trunk branch
class TrunkBranch(nn.Module):
    def __init__(
            self, conv, n_feat, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(TrunkBranch, self).__init__()
        modules_body = []
        for _ in range(2):
            modules_body.append(
                ResBlock(conv, n_feat, kernel_size, bias=True, bn=False,
                         act=nn.ReLU(True), res_scale=res_scale))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        tx = self.body(x)
        return tx


## define mask branch
class MaskBranchDownUp(nn.Module):
    def __init__(
            self, conv, n_feat, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(MaskBranchDownUp, self).__init__()

        MB_RB1 = []
        MB_RB1.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False,
                               act=nn.ReLU(True), res_scale=res_scale))

        MB_Down = []
        MB_Down.append(nn.Conv2d(n_feat, n_feat, 3, stride=2, padding=1))

        MB_RB2 = []
        for _ in range(2):
            MB_RB2.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False,
                                   act=nn.ReLU(True), res_scale=res_scale))

        MB_Up = []
        MB_Up.append(nn.ConvTranspose2d(n_feat, n_feat, 6, stride=2, padding=2))

        MB_RB3 = []
        MB_RB3.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False,
                               act=nn.ReLU(True), res_scale=res_scale))

        MB_1x1conv = []
        MB_1x1conv.append(nn.Conv2d(n_feat, n_feat, 1, padding=0, bias=True))

        MB_sigmoid = []
        MB_sigmoid.append(nn.Sigmoid())

        self.MB_RB1 = nn.Sequential(*MB_RB1)
        self.MB_Down = nn.Sequential(*MB_Down)
        self.MB_RB2 = nn.Sequential(*MB_RB2)
        self.MB_Up = nn.Sequential(*MB_Up)
        self.MB_RB3 = nn.Sequential(*MB_RB3)
        self.MB_1x1conv = nn.Sequential(*MB_1x1conv)
        self.MB_sigmoid = nn.Sequential(*MB_sigmoid)

    def forward(self, x):
        x_RB1 = self.MB_RB1(x)
        x_Down = self.MB_Down(x_RB1)
        x_RB2 = self.MB_RB2(x_Down)
        x_Up = self.MB_Up(x_RB2)
        x_preRB3 = x_RB1 + x_Up
        x_RB3 = self.MB_RB3(x_preRB3)
        x_1x1 = self.MB_1x1conv(x_RB3)
        mx = self.MB_sigmoid(x_1x1)
        return mx


## define nonlocal mask branch
class NLMaskBranchDownUp(nn.Module):
    def __init__(
            self, conv, n_feat, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(NLMaskBranchDownUp, self).__init__()

        MB_RB1 = []
        MB_RB1.append(NonLocalBlock2D(n_feat, n_feat // 2))
        MB_RB1.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False,
                               act=nn.ReLU(True), res_scale=res_scale))

        MB_Down = []
        MB_Down.append(nn.Conv2d(n_feat, n_feat, 3, stride=2, padding=1))

        MB_RB2 = []
        for _ in range(2):
            MB_RB2.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False,
                                   act=nn.ReLU(True), res_scale=res_scale))

        MB_Up = []
        MB_Up.append(nn.ConvTranspose2d(n_feat, n_feat, 6, stride=2, padding=2))

        MB_RB3 = []
        MB_RB3.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False,
                               act=nn.ReLU(True), res_scale=res_scale))

        MB_1x1conv = []
        MB_1x1conv.append(nn.Conv2d(n_feat, n_feat, 1, padding=0, bias=True))

        MB_sigmoid = []
        MB_sigmoid.append(nn.Sigmoid())

        self.MB_RB1 = nn.Sequential(*MB_RB1)
        self.MB_Down = nn.Sequential(*MB_Down)
        self.MB_RB2 = nn.Sequential(*MB_RB2)
        self.MB_Up = nn.Sequential(*MB_Up)
        self.MB_RB3 = nn.Sequential(*MB_RB3)
        self.MB_1x1conv = nn.Sequential(*MB_1x1conv)
        self.MB_sigmoid = nn.Sequential(*MB_sigmoid)

    def forward(self, x):
        x_RB1 = self.MB_RB1(x)
        x_Down = self.MB_Down(x_RB1)
        x_RB2 = self.MB_RB2(x_Down)
        x_Up = self.MB_Up(x_RB2)
        x_preRB3 = x_RB1 + x_Up
        x_RB3 = self.MB_RB3(x_preRB3)
        x_1x1 = self.MB_1x1conv(x_RB3)
        mx = self.MB_sigmoid(x_1x1)
        return mx
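

# Illustrative sanity check (ours, not from the original code): the stride-2
# conv followed by the kernel-6/stride-2 transposed conv in the mask branch
# only round-trips spatial sizes that are even, so the skip connection
# `x_RB1 + x_Up` assumes even H and W. This makes that assumption explicit.
def _check_mask_branch_shapes():
    branch = MaskBranchDownUp(default_conv, n_feat=64, kernel_size=3)
    x = torch.rand(1, 64, 32, 32)  # even spatial dims round-trip exactly
    assert branch(x).shape == x.shape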


## define residual attention module
class ResAttModuleDownUpPlus(nn.Module):
    def __init__(
            self, conv, n_feat, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(ResAttModuleDownUpPlus, self).__init__()

        RA_RB1 = []
        RA_RB1.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False,
                               act=nn.ReLU(True), res_scale=res_scale))

        RA_TB = []
        RA_TB.append(TrunkBranch(conv, n_feat, kernel_size, bias=True, bn=False,
                                 act=nn.ReLU(True), res_scale=res_scale))

        RA_MB = []
        RA_MB.append(MaskBranchDownUp(conv, n_feat, kernel_size, bias=True, bn=False,
                                      act=nn.ReLU(True), res_scale=res_scale))

        RA_tail = []
        for _ in range(2):
            RA_tail.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False,
                                    act=nn.ReLU(True), res_scale=res_scale))

        self.RA_RB1 = nn.Sequential(*RA_RB1)
        self.RA_TB = nn.Sequential(*RA_TB)
        self.RA_MB = nn.Sequential(*RA_MB)
        self.RA_tail = nn.Sequential(*RA_tail)

    def forward(self, x):
        RA_RB1_x = self.RA_RB1(x)
        tx = self.RA_TB(RA_RB1_x)
        mx = self.RA_MB(RA_RB1_x)
        # gate the trunk features with the sigmoid mask, then add back the
        # pre-branch features
        txmx = tx * mx
        hx = txmx + RA_RB1_x
        hx = self.RA_tail(hx)
        return hx


## define nonlocal residual attention module
class NLResAttModuleDownUpPlus(nn.Module):
    def __init__(
            self, conv, n_feat, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(NLResAttModuleDownUpPlus, self).__init__()

        RA_RB1 = []
        RA_RB1.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False,
                               act=nn.ReLU(True), res_scale=res_scale))

        RA_TB = []
        RA_TB.append(TrunkBranch(conv, n_feat, kernel_size, bias=True, bn=False,
                                 act=nn.ReLU(True), res_scale=res_scale))

        RA_MB = []
        RA_MB.append(NLMaskBranchDownUp(conv, n_feat, kernel_size, bias=True, bn=False,
                                        act=nn.ReLU(True), res_scale=res_scale))

        RA_tail = []
        for _ in range(2):
            RA_tail.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False,
                                    act=nn.ReLU(True), res_scale=res_scale))

        self.RA_RB1 = nn.Sequential(*RA_RB1)
        self.RA_TB = nn.Sequential(*RA_TB)
        self.RA_MB = nn.Sequential(*RA_MB)
        self.RA_tail = nn.Sequential(*RA_tail)

    def forward(self, x):
        RA_RB1_x = self.RA_RB1(x)
        tx = self.RA_TB(RA_RB1_x)
        mx = self.RA_MB(RA_RB1_x)
        txmx = tx * mx
        hx = txmx + RA_RB1_x
        hx = self.RA_tail(hx)
        return hx


def make_model(args, parent=False):
    # pass args by keyword: the first positional parameter of RNAN is
    # scale_factor, not args
    return RNAN(args=args)


### RNAN
### residual attention + downscale upscale + denoising
class _ResGroup(nn.Module):
    def __init__(self, conv, n_feats, kernel_size, act, res_scale):
        super(_ResGroup, self).__init__()
        modules_body = []
        modules_body.append(ResAttModuleDownUpPlus(
            conv, n_feats, kernel_size, bias=True, bn=False,
            act=nn.ReLU(True), res_scale=res_scale))
        modules_body.append(conv(n_feats, n_feats, kernel_size))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        return res


### nonlocal residual attention + downscale upscale + denoising
class _NLResGroup(nn.Module):
    def __init__(self, conv, n_feats, kernel_size, act, res_scale):
        super(_NLResGroup, self).__init__()
        modules_body = []
        modules_body.append(NLResAttModuleDownUpPlus(
            conv, n_feats, kernel_size, bias=True, bn=False,
            act=nn.ReLU(True), res_scale=res_scale))
        # even without the group-level residual, do not remove this trailing conv
        modules_body.append(conv(n_feats, n_feats, kernel_size))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        # res += x
        return res
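

# Illustrative sanity check (ours): the attention module gates the trunk
# features with the sigmoid mask and adds the result back to the pre-branch
# features, so it is shape-preserving for even spatial sizes.
def _check_attention_module_shapes():
    module = ResAttModuleDownUpPlus(default_conv, n_feat=64, kernel_size=3)
    x = torch.rand(1, 64, 32, 32)
    assert module(x).shape == x.shape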


class RNAN(nn.Module):
    def __init__(self, scale_factor: int = 8,
                 args: Optional[object] = None, conv=default_conv):
        """RNAN with default hyper-parameters taken from the original paper:
        https://arxiv.org/pdf/1903.10082

        Parameters
        ----------
        scale_factor : int
            super-resolution factor (2, 4, or 8); used when `args` is None
        args : argparse-style namespace, optional
            expects n_resgroups, n_resblocks, n_feats, scale, and n_colors
            (the input channel dim, e.g. 3 for RGB, 1 for grayscale);
            overrides `scale_factor`
        """
        super(RNAN, self).__init__()

        if args is not None:
            n_resgroup = args.n_resgroups
            n_resblock = args.n_resblocks  # unused: the block counts are hard-coded
            n_feats = args.n_feats
            scale = args.scale[0]
            n_colors = args.n_colors
            res_scale = getattr(args, 'res_scale', 1)
        else:
            n_colors = 1          # input channel dim
            n_resgroup = 10       # yields 10 - 2 = 8 local res groups (see below)
            n_resblock = 2        # unused: the block counts are hard-coded
            n_feats = 64
            res_scale = 1         # EDSR-style residual scaling; distinct from the SR factor
            scale = scale_factor  # standard SR factor
        assert scale in [2, 4, 8]

        kernel_size = 3
        act = nn.ReLU(True)

        # define head module
        modules_head = [conv(n_colors, n_feats, kernel_size)]

        # define body module: two non-local groups are hard-coded, one before
        # and one after the stack of local residual groups
        modules_body_nl_low = [
            _NLResGroup(
                conv, n_feats, kernel_size, act=act, res_scale=res_scale)]
        # the authors use 8 local res groups in the paper; this loop creates
        # n_resgroup - 2 of them, so n_resgroup=10 gives 10 - 2 = 8
        modules_body = [
            _ResGroup(
                conv, n_feats, kernel_size, act=act, res_scale=res_scale)
            for _ in range(n_resgroup - 2)]
        modules_body_nl_high = [
            _NLResGroup(
                conv, n_feats, kernel_size, act=act, res_scale=res_scale)]
        modules_body.append(conv(n_feats, n_feats, kernel_size))

        # define tail module
        modules_tail = [
            Upsampler(conv, scale, n_feats, act=False),
            conv(n_feats, n_colors, kernel_size)]

        self.head = nn.Sequential(*modules_head)
        self.body_nl_low = nn.Sequential(*modules_body_nl_low)
        self.body = nn.Sequential(*modules_body)
        self.body_nl_high = nn.Sequential(*modules_body_nl_high)
        self.tail = nn.Sequential(*modules_tail)

    def forward(self, x: torch.Tensor):
        # [B, H, W] -> [B, 1, H, W]
        if len(x.shape) == 3:
            x = x.unsqueeze(1)

        feats_shallow = self.head(x)

        res = self.body_nl_low(feats_shallow)
        res = self.body(res)
        res = self.body_nl_high(res)
        # long skip connection around the whole body
        res += feats_shallow

        res_main = self.tail(res)
        return res_main

    def load_state_dict(self, state_dict, strict=False):
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                if isinstance(param, nn.Parameter):
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception:
                    if name.find('tail') >= 0:
                        print('Replacing pre-trained upsampler with a new one...')
                    else:
                        raise RuntimeError(
                            'While copying the parameter named {}, '
                            'whose dimensions in the model are {} and '
                            'whose dimensions in the checkpoint are {}.'
                            .format(name, own_state[name].size(), param.size()))
            elif strict:
                if name.find('tail') == -1:
                    raise KeyError('unexpected key "{}" in state_dict'
                                   .format(name))

        if strict:
            missing = set(own_state.keys()) - set(state_dict.keys())
            if len(missing) > 0:
                raise KeyError('missing keys in state_dict: "{}"'.format(missing))


if __name__ == "__main__":
    model = RNAN()
    x = torch.rand((1, 1, 64, 64))
    out = model(x)
    print(out.shape)  # torch.Size([1, 1, 512, 512]) at the default x8 scale
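

# Usage sketch (ours) for the tolerant load_state_dict above: `tail` parameters
# whose shapes differ (e.g., a checkpoint trained for another scale) are
# skipped instead of raising. The checkpoint filename is hypothetical.
#
#   model = RNAN(scale_factor=4)
#   state = torch.load("rnan_checkpoint.pt", map_location="cpu")
#   model.load_state_dict(state, strict=False)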