Spaces:
Runtime error
Runtime error
| # modified from: https://github.com/thstkdgus35/EDSR-PyTorch | |
| import math | |
| from argparse import Namespace | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from models import register | |
def default_conv(in_channels, out_channels, kernel_size, bias=True):
    """Return an ``nn.Conv2d`` padded by ``kernel_size // 2`` on each side.

    With the default stride of 1 and an odd ``kernel_size`` this preserves the
    spatial size of the input.
    """
    half_width = kernel_size // 2
    return nn.Conv2d(in_channels, out_channels, kernel_size,
                     padding=half_width, bias=bias)
class MeanShift(nn.Conv2d):
    """Frozen 1x1 convolution that shifts each of 3 channels by a scaled mean.

    For channel ``c`` the output is
    ``(x_c + sign * rgb_range * rgb_mean[c]) / rgb_std[c]``.

    Args:
        rgb_range: factor the per-channel mean is multiplied by.
        rgb_mean: per-channel mean (3 values).
        rgb_std: per-channel divisor (3 values).
        sign: ``-1`` subtracts the mean, ``+1`` adds it back.
    """

    def __init__(
            self, rgb_range,
            rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
        super().__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        identity_kernel = torch.eye(3).view(3, 3, 1, 1)
        # Diagonal kernel: each output channel only sees its own input channel.
        self.weight.data = identity_kernel / std.view(3, 1, 1, 1)
        mean = torch.Tensor(rgb_mean)
        self.bias.data = sign * rgb_range * mean / std
        # Fixed preprocessing layer: never trained.
        self.requires_grad_(False)
class ResBlock(nn.Module):
    """Residual block computing ``x + res_scale * body(x)``.

    The body is ``conv -> [bn] -> act -> conv -> [bn]``.

    Args:
        conv: factory called as ``conv(in_ch, out_ch, kernel_size, bias=...)``.
        n_feats: channel count of both input and output.
        kernel_size: kernel size passed to ``conv``.
        bias: forwarded to ``conv``.
        bn: insert ``nn.BatchNorm2d`` after each convolution when true.
        act: activation module placed after the first convolution. Defaults
            to a new ``nn.ReLU(inplace=True)`` created for this block.
        res_scale: multiplier applied to the body output before the skip add.
    """

    def __init__(
            self, conv, n_feats, kernel_size,
            bias=True, bn=False, act=None, res_scale=1):
        super(ResBlock, self).__init__()
        if act is None:
            # Build a fresh module per block; a module instance used as the
            # default argument would be shared by every ResBlock built
            # without `act` (so e.g. hooks on it would fire for all of them).
            act = nn.ReLU(True)
        m = []
        for i in range(2):
            m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
            if bn:
                m.append(nn.BatchNorm2d(n_feats))
            if i == 0:
                m.append(act)
        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x).mul(self.res_scale)
        res += x  # skip connection
        return res
class Upsampler(nn.Sequential):
    """Sub-pixel convolution upsampler built from ``conv -> PixelShuffle`` stages.

    A power-of-two ``scale`` uses log2(scale) x2 stages; ``scale == 3`` uses a
    single x3 stage. Each stage maps ``n_feats`` channels to
    ``factor**2 * n_feats`` with a 3x3 ``conv`` and pixel-shuffles back to
    ``n_feats`` channels, optionally followed by BatchNorm and an activation.

    Args:
        conv: factory called as ``conv(in_ch, out_ch, kernel_size, bias)``.
        scale: integer upscaling factor (``1`` gives an empty Sequential).
        n_feats: channel count in and out of every stage.
        bn: append ``nn.BatchNorm2d`` after each shuffle when true.
        act: ``'relu'`` or ``'prelu'`` appends that activation per stage;
            any other value adds none.
        bias: forwarded to ``conv``.

    Raises:
        NotImplementedError: ``scale`` is neither 3 nor a power of two >= 1.
    """

    def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
        if scale >= 1 and (scale & (scale - 1)) == 0:  # scale == 2**n
            # bit_length() - 1 is the exact integer log2, unlike a float
            # math.log() result truncated with int().
            factor, n_stages = 2, int(scale).bit_length() - 1
        elif scale == 3:
            factor, n_stages = 3, 1
        else:
            raise NotImplementedError

        m = []
        for _ in range(n_stages):
            m.append(conv(n_feats, factor * factor * n_feats, 3, bias))
            m.append(nn.PixelShuffle(factor))
            if bn:
                m.append(nn.BatchNorm2d(n_feats))
            if act == 'relu':
                m.append(nn.ReLU(True))
            elif act == 'prelu':
                m.append(nn.PReLU(n_feats))
        super(Upsampler, self).__init__(*m)
# Official EDSR checkpoint download URLs, keyed by
# 'r<n_resblocks>f<n_feats>x<scale>'.
url = dict(
    # EDSR-baseline: 16 residual blocks, 64 features
    r16f64x2='https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x2-1bc95232.pt',
    r16f64x3='https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x3-abf2a44e.pt',
    r16f64x4='https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x4-6b446fab.pt',
    # EDSR: 32 residual blocks, 256 features
    r32f256x2='https://cv.snu.ac.kr/research/EDSR/models/edsr_x2-0edfb8a3.pt',
    r32f256x3='https://cv.snu.ac.kr/research/EDSR/models/edsr_x3-ea3ef2c6.pt',
    r32f256x4='https://cv.snu.ac.kr/research/EDSR/models/edsr_x4-4f62e9ef.pt',
)
class EDSR(nn.Module):
    """EDSR network: conv head -> residual body -> optional upsampling tail.

    Args:
        args: namespace providing ``n_resblocks``, ``n_feats``, ``scale``
            (a sequence; only ``scale[0]`` is used), ``res_scale``,
            ``rgb_range``, ``n_colors`` and ``no_upsampling``.
        conv: convolution factory called as ``conv(in_ch, out_ch, kernel_size)``.

    Attributes:
        url: URL from the module-level ``url`` table for
            ``(n_resblocks, n_feats, scale)``, or ``None`` if not listed.
        out_dim: channels produced by ``forward`` -- ``n_feats`` when
            ``no_upsampling`` is set, otherwise ``n_colors``.

    When ``url`` is known, ``__init__`` loads weights from
    ``pretrained/<file name at the end of url>`` relative to the working
    directory. The file is not downloaded here; loading fails if it is absent.
    """

    def __init__(self, args, conv=default_conv):
        super(EDSR, self).__init__()
        self.args = args
        n_resblocks = args.n_resblocks
        n_feats = args.n_feats
        kernel_size = 3
        scale = args.scale[0]

        self.url = url.get('r{}f{}x{}'.format(n_resblocks, n_feats, scale))

        # Registered as submodules (their frozen weights are part of the
        # state_dict) even though forward() does not apply them.
        self.sub_mean = MeanShift(args.rgb_range)
        self.add_mean = MeanShift(args.rgb_range, sign=1)

        # head: n_colors -> n_feats
        self.head = nn.Sequential(conv(args.n_colors, n_feats, kernel_size))

        # body: residual blocks plus a final conv; forward() adds the global skip
        m_body = [
            ResBlock(conv, n_feats, kernel_size,
                     act=nn.ReLU(True), res_scale=args.res_scale)
            for _ in range(n_resblocks)
        ]
        m_body.append(conv(n_feats, n_feats, kernel_size))
        self.body = nn.Sequential(*m_body)

        if args.no_upsampling:
            self.out_dim = n_feats
        else:
            self.out_dim = args.n_colors
            # tail: upsample by `scale`, then project back to n_colors.
            # NOTE(review): no tail exists when no_upsampling is set (as in
            # upstream EDSR/LIIF); confirm saved checkpoints agree.
            self.tail = nn.Sequential(
                Upsampler(conv, scale, n_feats, act=False),
                conv(n_feats, args.n_colors, kernel_size),
            )

        # Only configurations listed in `url` have a pretrained checkpoint;
        # previously a missing entry crashed on `None.split`.
        if self.url is not None:
            self.load_state_dict('pretrained/' + self.url.split('/')[-1])

    def forward(self, x):
        # Mean shift (sub_mean / add_mean) is intentionally not applied.
        x = self.head(x)
        res = self.body(x)
        res += x  # global residual connection around the body
        if self.args.no_upsampling:
            return res
        return self.tail(res)

    def load_state_dict(self, state_dict, strict=True):
        """Copy checkpoint tensors into this model, tolerating tail mismatches.

        Args:
            state_dict: a path / file-like object readable by ``torch.load``,
                or an already-loaded mapping of parameter name -> tensor.
            strict: when true, a checkpoint key this model lacks raises
                ``KeyError`` unless the key contains ``'tail'``.

        Raises:
            RuntimeError: a non-tail tensor could not be copied
                (e.g. shape mismatch).
            KeyError: unexpected non-tail key while ``strict`` is true.
        """
        if not hasattr(state_dict, 'items'):
            # torch.load unpickles the file: only load trusted checkpoints.
            state_dict = torch.load(state_dict, map_location='cpu')
        own_state = self.state_dict()
        print('loading pretrain model')
        for name, param in state_dict.items():
            if name in own_state:
                if isinstance(param, nn.Parameter):
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception as err:
                    # Tail (upsampler) shapes depend on the scale, so a failed
                    # copy there is skipped rather than treated as an error.
                    if 'tail' not in name:
                        raise RuntimeError('While copying the parameter named {}, '
                                           'whose dimensions in the model are {} and '
                                           'whose dimensions in the checkpoint are {}.'
                                           .format(name, own_state[name].size(), param.size())) from err
            elif strict and 'tail' not in name:
                raise KeyError('unexpected key "{}" in state_dict'
                               .format(name))
def make_edsr_baseline(n_resblocks=16, n_feats=64, res_scale=1,
                       scale=2, no_upsampling=False, rgb_range=1):
    """Build an ``EDSR`` with the baseline defaults (16 blocks, 64 features).

    The keyword arguments are packed into an ``argparse.Namespace``; ``scale``
    is wrapped in a one-element list and ``n_colors`` is fixed to 3.
    """
    config = Namespace(
        n_resblocks=n_resblocks,
        n_feats=n_feats,
        res_scale=res_scale,
        scale=[scale],
        no_upsampling=no_upsampling,
        rgb_range=rgb_range,
        n_colors=3,
    )
    return EDSR(config)
def make_edsr(n_resblocks=32, n_feats=256, res_scale=0.1,
              scale=2, no_upsampling=False, rgb_range=1):
    """Build an ``EDSR`` with the full-size defaults (32 blocks, 256 features,
    residual scaling 0.1).

    The keyword arguments are packed into an ``argparse.Namespace``; ``scale``
    is wrapped in a one-element list and ``n_colors`` is fixed to 3.
    """
    config = Namespace(
        n_resblocks=n_resblocks,
        n_feats=n_feats,
        res_scale=res_scale,
        scale=[scale],
        no_upsampling=no_upsampling,
        rgb_range=rgb_range,
        n_colors=3,
    )
    return EDSR(config)