Spaces:
Runtime error
Runtime error
| # Residual Dense Network for Image Super-Resolution | |
| # https://arxiv.org/abs/1802.08797 | |
| # modified from: https://github.com/thstkdgus35/EDSR-PyTorch | |
| from argparse import Namespace | |
| import torch | |
| import torch.nn as nn | |
| from models import register | |
class RDB_Conv(nn.Module):
    """One dense layer: convolution + ReLU whose output is appended to its input.

    Args:
        inChannels: channel count of the incoming feature map.
        growRate: channel count produced by the convolution, i.e. how many
            channels ``forward`` adds on top of its input.
        kSize: square kernel size; the padding used is ``(kSize - 1) // 2``.
    """

    def __init__(self, inChannels, growRate, kSize=3):
        super().__init__()
        pad = (kSize - 1) // 2
        # Kept as a two-entry Sequential so the convolution's parameters stay
        # registered under ``conv.0.*``.
        self.conv = nn.Sequential(
            nn.Conv2d(inChannels, growRate, kSize, padding=pad, stride=1),
            nn.ReLU(),
        )

    def forward(self, x):
        """Return ``x`` concatenated along dim 1 with the activated conv output."""
        grown = self.conv(x)
        return torch.cat((x, grown), 1)
class RDB(nn.Module):
    """Residual dense block: chained dense layers, 1x1 fusion and a local skip.

    Args:
        growRate0: channel count of the block's input and output (G0).
        growRate: channels added by each dense layer (G).
        nConvLayers: number of dense layers in the block (C).
        kSize: kernel size handed to every dense layer.
    """

    def __init__(self, growRate0, growRate, nConvLayers, kSize=3):
        super(RDB, self).__init__()
        G0 = growRate0
        G = growRate
        C = nConvLayers

        # Dense layer c receives the block input plus the outputs of the c
        # layers before it, hence G0 + c*G input channels.
        # kSize is forwarded explicitly: previously it was accepted but
        # dropped, so the dense layers always used the 3x3 default.
        self.convs = nn.Sequential(*[
            RDB_Conv(G0 + c * G, G, kSize=kSize) for c in range(C)
        ])

        # Local Feature Fusion: 1x1 conv squeezing G0 + C*G channels back to G0
        self.LFF = nn.Conv2d(G0 + C * G, G0, 1, padding=0, stride=1)

    def forward(self, x):
        """Return the fused dense features added to the block input ``x``."""
        return self.LFF(self.convs(x)) + x
class RDN(nn.Module):
    """Residual Dense Network assembled from an ``args`` configuration object.

    Args:
        args: object exposing the attributes read here:
            ``scale`` (indexable; only element 0 is used), ``G0``,
            ``RDNkSize``, ``RDNconfig`` (``'A'`` or ``'B'``), ``n_colors``
            and ``no_upsampling``.

    Attributes:
        out_dim: ``G0`` when ``args.no_upsampling`` is true, otherwise
            ``args.n_colors``.

    Raises:
        KeyError: if ``args.RDNconfig`` is neither ``'A'`` nor ``'B'``.
        ValueError: if up-sampling is enabled and the scale is not 2, 3 or 4.
    """

    def __init__(self, args):
        super(RDN, self).__init__()
        self.args = args
        r = args.scale[0]
        G0 = args.G0
        kSize = args.RDNkSize
        pad = (kSize - 1) // 2

        # number of RDB blocks (D), conv layers per block (C), growth rate (G)
        self.D, C, G = {
            'A': (20, 6, 32),
            'B': (16, 8, 64),
        }[args.RDNconfig]

        # Shallow feature extraction net
        self.SFENet1 = nn.Conv2d(args.n_colors, G0, kSize, padding=pad, stride=1)
        self.SFENet2 = nn.Conv2d(G0, G0, kSize, padding=pad, stride=1)

        # Residual dense blocks and dense feature fusion.
        # kSize is passed down so the configured kernel size also applies
        # inside the blocks; before, only the convolutions outside them
        # followed args.RDNkSize.
        self.RDBs = nn.ModuleList([
            RDB(growRate0=G0, growRate=G, nConvLayers=C, kSize=kSize)
            for _ in range(self.D)
        ])

        # Global Feature Fusion: 1x1 conv over the concatenated outputs of all
        # D blocks, followed by a kSize conv
        self.GFF = nn.Sequential(*[
            nn.Conv2d(self.D * G0, G0, 1, padding=0, stride=1),
            nn.Conv2d(G0, G0, kSize, padding=pad, stride=1)
        ])

        if args.no_upsampling:
            self.out_dim = G0
        else:
            self.out_dim = args.n_colors
            # Up-sampling net
            if r == 2 or r == 3:
                self.UPNet = nn.Sequential(*[
                    nn.Conv2d(G0, G * r * r, kSize, padding=pad, stride=1),
                    nn.PixelShuffle(r),
                    nn.Conv2d(G, args.n_colors, kSize, padding=pad, stride=1)
                ])
            elif r == 4:
                # x4 is built from two consecutive x2 pixel-shuffle stages
                self.UPNet = nn.Sequential(*[
                    nn.Conv2d(G0, G * 4, kSize, padding=pad, stride=1),
                    nn.PixelShuffle(2),
                    nn.Conv2d(G, G * 4, kSize, padding=pad, stride=1),
                    nn.PixelShuffle(2),
                    nn.Conv2d(G, args.n_colors, kSize, padding=pad, stride=1)
                ])
            else:
                raise ValueError("scale must be 2 or 3 or 4.")

    def forward(self, x):
        """Run the network on ``x``.

        Returns the fused ``G0``-channel feature map when
        ``args.no_upsampling`` is true, otherwise the output of the
        up-sampling net.
        """
        f__1 = self.SFENet1(x)
        x = self.SFENet2(f__1)

        # Keep every block's output; all of them feed the global fusion.
        RDBs_out = []
        for block in self.RDBs:
            x = block(x)
            RDBs_out.append(x)

        x = self.GFF(torch.cat(RDBs_out, 1))
        # Global residual connection back to the first shallow feature map
        x += f__1

        if self.args.no_upsampling:
            return x
        return self.UPNet(x)
def make_rdn(G0=64, RDNkSize=3, RDNconfig='B',
             scale=2, no_upsampling=False):
    """Build an ``RDN`` from keyword settings instead of an args object.

    The settings are packed into a ``Namespace``; ``scale`` is wrapped in a
    one-element list and ``n_colors`` is fixed to 3.

    Args:
        G0: stored as ``args.G0``.
        RDNkSize: stored as ``args.RDNkSize``.
        RDNconfig: stored as ``args.RDNconfig``.
        scale: stored as ``args.scale = [scale]``.
        no_upsampling: stored as ``args.no_upsampling``.

    Returns:
        The constructed ``RDN`` instance.
    """
    # NOTE(review): ``register`` is imported by this file but not applied to
    # this factory in the visible source -- confirm whether a decorator is
    # expected here.
    args = Namespace(
        G0=G0,
        RDNkSize=RDNkSize,
        RDNconfig=RDNconfig,
        scale=[scale],
        no_upsampling=no_upsampling,
        n_colors=3,
    )
    return RDN(args)