Spaces:
Sleeping
Sleeping
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from typing import Optional | |
def default_conv(in_channels, out_channels, kernel_size, bias=True):
    """Build a stride-1 Conv2d whose padding keeps spatial size for odd kernels."""
    padding = kernel_size // 2
    return nn.Conv2d(in_channels, out_channels, kernel_size,
                     padding=padding, bias=bias)
class MeanShift(nn.Conv2d):
    """Frozen 1x1 conv that shifts RGB values by the dataset mean and scales by std.

    With sign=-1 it subtracts ``rgb_range * rgb_mean`` (normalize); with
    sign=+1 it adds the mean back (denormalize). Both weight and bias are
    divided by ``rgb_std`` per channel.
    """
    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        # Identity mixing between channels, scaled per-channel by 1/std.
        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data.div_(std.view(3, 1, 1, 1))
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
        self.bias.data.div_(std)
        # Bug fix: the original `self.requires_grad = False` only set a plain
        # Python attribute on the module and did NOT freeze the conv's
        # parameters. Freeze them explicitly.
        for p in self.parameters():
            p.requires_grad = False
class BasicBlock(nn.Sequential):
    """Conv2d -> optional BatchNorm2d -> optional activation, as a Sequential."""
    def __init__(
            self, in_channels, out_channels, kernel_size, stride=1, bias=False,
            bn=True, act=nn.ReLU(True)):
        layers = [nn.Conv2d(
            in_channels, out_channels, kernel_size,
            padding=kernel_size // 2, stride=stride, bias=bias)]
        if bn:
            layers.append(nn.BatchNorm2d(out_channels))
        if act is not None:
            layers.append(act)
        super(BasicBlock, self).__init__(*layers)
class ResBlock(nn.Module):
    """Residual block: conv -> act -> conv (optionally with BN), with the
    body output scaled by ``res_scale`` and added to the input."""
    def __init__(
            self, conv, n_feat, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(ResBlock, self).__init__()
        layers = []
        for idx in range(2):
            layers.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn:
                layers.append(nn.BatchNorm2d(n_feat))
            # Activation only after the first conv.
            if idx == 0:
                layers.append(act)
        self.body = nn.Sequential(*layers)
        self.res_scale = res_scale

    def forward(self, x):
        return x + self.body(x).mul(self.res_scale)
class Upsampler(nn.Sequential):
    """Upscale features by ``scale`` with PixelShuffle stages.

    Supported scales: any power of two (log2(scale) x2 stages) or 3
    (a single x3 stage); anything else raises NotImplementedError.
    """
    def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True):
        layers = []
        if (scale & (scale - 1)) == 0:  # power of two -> stack x2 stages
            for _ in range(int(math.log(scale, 2))):
                layers.extend([conv(n_feat, 4 * n_feat, 3, bias),
                               nn.PixelShuffle(2)])
                if bn:
                    layers.append(nn.BatchNorm2d(n_feat))
                if act:
                    layers.append(act())
        elif scale == 3:
            layers.extend([conv(n_feat, 9 * n_feat, 3, bias),
                           nn.PixelShuffle(3)])
            if bn:
                layers.append(nn.BatchNorm2d(n_feat))
            if act:
                layers.append(act())
        else:
            raise NotImplementedError
        super(Upsampler, self).__init__(*layers)
# add NonLocalBlock2D
# reference: https://github.com/AlexHex7/Non-local_pytorch/blob/master/lib/non_local_simple_version.py
class NonLocalBlock2D(nn.Module):
    """Non-local (self-attention) block over all spatial positions.

    theta/phi/g are 1x1 convs projecting to ``inter_channels``; W projects
    back to ``in_channels``. W is zero-initialized so the block starts as an
    identity mapping (z == x at initialization).
    """
    def __init__(self, in_channels, inter_channels):
        super(NonLocalBlock2D, self).__init__()
        self.in_channels = in_channels
        self.inter_channels = inter_channels
        self.g = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0)
        self.W = nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0)
        # Bug fix: nn.init.constant is deprecated and has been removed in
        # modern PyTorch; use the in-place variant nn.init.constant_.
        nn.init.constant_(self.W.weight, 0)
        nn.init.constant_(self.W.bias, 0)
        self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0)
        self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        batch_size = x.size(0)
        # Flatten spatial dims: (B, C', H*W)
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        # Pairwise affinity between all positions: (B, HW, HW)
        f = torch.matmul(theta_x, phi_x)
        # NOTE(review): the reference implementation normalizes over the last
        # dim (dim=-1); dim=1 is kept here to preserve existing behavior and
        # pretrained-weight compatibility — confirm intentional.
        f_div_C = F.softmax(f, dim=1)
        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        # Residual connection; with W zeroed this is the identity.
        z = W_y + x
        return z
## define trunk branch
class TrunkBranch(nn.Module):
    """Trunk branch of the attention module: two stacked residual blocks."""
    def __init__(
            self, conv, n_feat, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(TrunkBranch, self).__init__()
        # NOTE(review): bias/bn/act arguments are accepted but the ResBlocks
        # are built with hard-coded bias=True, bn=False, act=ReLU — kept as-is.
        blocks = [
            ResBlock(conv, n_feat, kernel_size, bias=True, bn=False,
                     act=nn.ReLU(True), res_scale=res_scale)
            for _ in range(2)
        ]
        self.body = nn.Sequential(*blocks)

    def forward(self, x):
        return self.body(x)
## define mask branch
class MaskBranchDownUp(nn.Module):
    """Mask branch: ResBlock -> strided-conv downsample -> 2 ResBlocks ->
    transposed-conv upsample (+ skip from before the downsample) -> ResBlock
    -> 1x1 conv -> sigmoid, producing an attention mask in (0, 1)."""
    def __init__(
            self, conv, n_feat, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(MaskBranchDownUp, self).__init__()

        # Local factory so every residual block is configured identically.
        def _res_block():
            return ResBlock(conv, n_feat, kernel_size, bias=True, bn=False,
                            act=nn.ReLU(True), res_scale=res_scale)

        self.MB_RB1 = nn.Sequential(_res_block())
        self.MB_Down = nn.Sequential(
            nn.Conv2d(n_feat, n_feat, 3, stride=2, padding=1))
        self.MB_RB2 = nn.Sequential(_res_block(), _res_block())
        self.MB_Up = nn.Sequential(
            nn.ConvTranspose2d(n_feat, n_feat, 6, stride=2, padding=2))
        self.MB_RB3 = nn.Sequential(_res_block())
        self.MB_1x1conv = nn.Sequential(
            nn.Conv2d(n_feat, n_feat, 1, padding=0, bias=True))
        self.MB_sigmoid = nn.Sequential(nn.Sigmoid())

    def forward(self, x):
        x_rb1 = self.MB_RB1(x)
        down = self.MB_Down(x_rb1)
        up = self.MB_Up(self.MB_RB2(down))
        merged = x_rb1 + up  # skip connection across the down/up path
        mask = self.MB_RB3(merged)
        return self.MB_sigmoid(self.MB_1x1conv(mask))
## define nonlocal mask branch
class NLMaskBranchDownUp(nn.Module):
    """Mask branch with a non-local block at the front; otherwise structured
    like MaskBranchDownUp (down/up path ending in a sigmoid mask)."""
    def __init__(
            self, conv, n_feat, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(NLMaskBranchDownUp, self).__init__()

        # Local factory so every residual block is configured identically.
        def _res_block():
            return ResBlock(conv, n_feat, kernel_size, bias=True, bn=False,
                            act=nn.ReLU(True), res_scale=res_scale)

        self.MB_RB1 = nn.Sequential(
            NonLocalBlock2D(n_feat, n_feat // 2), _res_block())
        self.MB_Down = nn.Sequential(
            nn.Conv2d(n_feat, n_feat, 3, stride=2, padding=1))
        self.MB_RB2 = nn.Sequential(_res_block(), _res_block())
        self.MB_Up = nn.Sequential(
            nn.ConvTranspose2d(n_feat, n_feat, 6, stride=2, padding=2))
        self.MB_RB3 = nn.Sequential(_res_block())
        self.MB_1x1conv = nn.Sequential(
            nn.Conv2d(n_feat, n_feat, 1, padding=0, bias=True))
        self.MB_sigmoid = nn.Sequential(nn.Sigmoid())

    def forward(self, x):
        x_rb1 = self.MB_RB1(x)
        up = self.MB_Up(self.MB_RB2(self.MB_Down(x_rb1)))
        merged = x_rb1 + up  # skip connection across the down/up path
        mask = self.MB_RB3(merged)
        return self.MB_sigmoid(self.MB_1x1conv(mask))
## define residual attention module
class ResAttModuleDownUpPlus(nn.Module):
    """Residual attention module: trunk features gated by the mask branch,
    plus an identity skip, then two tail residual blocks."""
    def __init__(
            self, conv, n_feat, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(ResAttModuleDownUpPlus, self).__init__()
        self.RA_RB1 = nn.Sequential(
            ResBlock(conv, n_feat, kernel_size, bias=True, bn=False,
                     act=nn.ReLU(True), res_scale=res_scale))
        self.RA_TB = nn.Sequential(
            TrunkBranch(conv, n_feat, kernel_size, bias=True, bn=False,
                        act=nn.ReLU(True), res_scale=res_scale))
        self.RA_MB = nn.Sequential(
            MaskBranchDownUp(conv, n_feat, kernel_size, bias=True, bn=False,
                             act=nn.ReLU(True), res_scale=res_scale))
        self.RA_tail = nn.Sequential(*[
            ResBlock(conv, n_feat, kernel_size, bias=True, bn=False,
                     act=nn.ReLU(True), res_scale=res_scale)
            for _ in range(2)])

    def forward(self, input):
        head = self.RA_RB1(input)
        # trunk * mask + identity ("plus" residual form)
        attended = self.RA_TB(head) * self.RA_MB(head) + head
        return self.RA_tail(attended)
## define nonlocal residual attention module
class NLResAttModuleDownUpPlus(nn.Module):
    """Residual attention module whose mask branch contains a non-local block
    (NLMaskBranchDownUp); otherwise identical to ResAttModuleDownUpPlus."""
    def __init__(
            self, conv, n_feat, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(NLResAttModuleDownUpPlus, self).__init__()
        self.RA_RB1 = nn.Sequential(
            ResBlock(conv, n_feat, kernel_size, bias=True, bn=False,
                     act=nn.ReLU(True), res_scale=res_scale))
        self.RA_TB = nn.Sequential(
            TrunkBranch(conv, n_feat, kernel_size, bias=True, bn=False,
                        act=nn.ReLU(True), res_scale=res_scale))
        self.RA_MB = nn.Sequential(
            NLMaskBranchDownUp(conv, n_feat, kernel_size, bias=True, bn=False,
                               act=nn.ReLU(True), res_scale=res_scale))
        self.RA_tail = nn.Sequential(*[
            ResBlock(conv, n_feat, kernel_size, bias=True, bn=False,
                     act=nn.ReLU(True), res_scale=res_scale)
            for _ in range(2)])

    def forward(self, input):
        head = self.RA_RB1(input)
        # trunk * mask + identity ("plus" residual form)
        attended = self.RA_TB(head) * self.RA_MB(head) + head
        return self.RA_tail(attended)
def make_model(args, parent=False):
    """Factory hook used by the training framework to build RNAN from parsed args.

    Bug fix: the original `RNAN(args)` bound `args` to RNAN's first positional
    parameter `scale_factor`, so the namespace never reached the `args`
    parameter (and the internal `assert scale in [2, 4, 8]` would fail).
    Pass it by keyword instead.
    """
    return RNAN(args=args)
### RNAN
### residual attention + downscale upscale + denoising
class _ResGroup(nn.Module):
    """Local residual attention group: one ResAttModuleDownUpPlus followed by
    a trailing conv."""
    def __init__(self, conv, n_feats, kernel_size, act, res_scale):
        super(_ResGroup, self).__init__()
        self.body = nn.Sequential(
            ResAttModuleDownUpPlus(conv, n_feats, kernel_size, bias=True,
                                   bn=False, act=nn.ReLU(True),
                                   res_scale=res_scale),
            conv(n_feats, n_feats, kernel_size))

    def forward(self, x):
        return self.body(x)
### nonlocal residual attention + downscale upscale + denoising
class _NLResGroup(nn.Module):
    """Non-local residual attention group: one NLResAttModuleDownUpPlus
    followed by a trailing conv. No group-level residual is applied."""
    def __init__(self, conv, n_feats, kernel_size, act, res_scale):
        super(_NLResGroup, self).__init__()
        # If the group residual is ever disabled, keep the trailing conv.
        self.body = nn.Sequential(
            NLResAttModuleDownUpPlus(conv, n_feats, kernel_size, bias=True,
                                     bn=False, act=nn.ReLU(True),
                                     res_scale=res_scale),
            conv(n_feats, n_feats, kernel_size))

    def forward(self, x):
        # Group residual intentionally omitted (was commented out upstream).
        return self.body(x)
class RNAN(nn.Module):
    """RNAN: Residual Non-local Attention Network for image restoration.

    https://arxiv.org/pdf/1903.10082

    Structure: head conv -> non-local group (low) -> (n_resgroup - 2) local
    residual attention groups -> non-local group (high) -> long skip from the
    shallow features -> upsampler tail.

    Parameters
    ---
    :param scale_factor: super-resolution factor (2, 4 or 8); used only when
        ``args`` is None.
    :param args: optional namespace with fields ``n_resgroups``,
        ``n_resblocks``, ``n_feats``, ``reduction``, ``scale`` (a sequence;
        first entry is used) and ``n_colors`` (input channel dim, e.g. C=3
        for RGB, 1 for grayscale).
    :param conv: convolution factory with signature (in_ch, out_ch, k).
    """
    def __init__(self, scale_factor: Optional[int] = 8, args: Optional[dict] = None, conv=default_conv):
        super(RNAN, self).__init__()
        # Idiom fix: compare against None with `is not`, not `!=`.
        if args is not None:
            n_resgroup = args.n_resgroups
            n_resblock = args.n_resblocks
            n_feats = args.n_feats
            reduction = args.reduction
            scale = args.scale[0]
            n_colors = args.n_colors
        else:
            # Defaults from the original paper; grayscale input.
            n_colors = 1
            n_resgroup = 10
            # set to 2; unused below, kept for parity with the args branch
            n_resblock = 2
            n_feats = 64
            reduction = ...  # unused below; placeholder kept for parity
            # assuming this is a standard SR factor
            scale = scale_factor
        assert scale in [2, 4, 8]
        kernel_size = 3
        act = nn.ReLU(True)

        # define head module
        modules_head = [conv(n_colors, n_feats, kernel_size)]

        # define body module: one non-local group on each side of the local
        # groups (two NL blocks are hard-coded). The loop creates N-2 local
        # groups, so n_resgroup=10 yields the paper's 8 local blocks.
        # NOTE(review): res_scale receives the SR `scale` (2/4/8) rather than
        # a residual scaling factor (typically 1) — confirm intentional.
        modules_body_nl_low = [
            _NLResGroup(
                conv, n_feats, kernel_size, act=act, res_scale=scale)]
        modules_body = [
            _ResGroup(
                conv, n_feats, kernel_size, act=act, res_scale=scale)
            for _ in range(n_resgroup - 2)]
        modules_body_nl_high = [
            _NLResGroup(
                conv, n_feats, kernel_size, act=act, res_scale=scale)]
        modules_body.append(conv(n_feats, n_feats, kernel_size))

        # define tail module
        modules_tail = [
            Upsampler(conv, scale, n_feats, act=False),
            conv(n_feats, n_colors, kernel_size)]

        self.head = nn.Sequential(*modules_head)
        self.body_nl_low = nn.Sequential(*modules_body_nl_low)
        self.body = nn.Sequential(*modules_body)
        self.body_nl_high = nn.Sequential(*modules_body_nl_high)
        self.tail = nn.Sequential(*modules_tail)

    def forward(self, x: torch.Tensor):
        """Restore/upscale ``x``; accepts [B, H, W] or [B, C, H, W]."""
        # [B, H, W] -> [B, 1, H, W]
        if len(x.shape) == 3:
            x = x.unsqueeze(1)
        feats_shallow = self.head(x)
        res = self.body_nl_low(feats_shallow)
        res = self.body(res)
        res = self.body_nl_high(res)
        res += feats_shallow  # long skip connection around the whole body
        res_main = self.tail(res)
        return res_main

    def load_state_dict(self, state_dict, strict=False):
        """Load weights, tolerating a mismatched ``tail`` (e.g. another SR scale).

        Unlike nn.Module.load_state_dict, ``strict`` defaults to False; size
        mismatches or unexpected keys under ``tail`` are skipped with a
        message instead of raising.
        """
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                if isinstance(param, nn.Parameter):
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception:
                    if name.find('tail') >= 0:
                        print('Replace pre-trained upsampler to new one...')
                    else:
                        raise RuntimeError('While copying the parameter named {}, '
                                           'whose dimensions in the model are {} and '
                                           'whose dimensions in the checkpoint are {}.'
                                           .format(name, own_state[name].size(), param.size()))
            elif strict:
                # Unexpected keys are tolerated only for the tail.
                if name.find('tail') == -1:
                    raise KeyError('unexpected key "{}" in state_dict'
                                   .format(name))
        if strict:
            missing = set(own_state.keys()) - set(state_dict.keys())
            if len(missing) > 0:
                raise KeyError('missing keys in state_dict: "{}"'.format(missing))
if __name__ == "__main__":
    # Smoke test: default RNAN (grayscale, x8) on a small random input.
    # Bug fix: removed a leftover `breakpoint()` debug call that dropped into
    # the debugger on every run.
    model = RNAN()
    x = torch.rand((1, 1, 64, 64))
    model(x)