Spaces:
Runtime error
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.autograd import Variable | |
def default_conv(in_channels, out_channels, kernel_size, bias=True):
    """Return a stride-1 ``nn.Conv2d`` padded by ``kernel_size // 2``.

    For odd ``kernel_size`` this keeps the spatial size unchanged.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: square kernel size (int).
        bias: whether the convolution has a learnable bias.
    """
    half_kernel = kernel_size // 2
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        padding=half_kernel,
        bias=bias,
    )
class MeanShift(nn.Conv2d):
    """Fixed (non-trainable) 1x1 convolution that shifts and rescales RGB.

    For each channel ``i`` the output is
    ``(x_i + sign * rgb_range * rgb_mean[i]) / rgb_std[i]``;
    ``sign=-1`` subtracts the mean, ``sign=+1`` adds it back.

    Args:
        rgb_range: scale applied to ``rgb_mean``.
        rgb_mean: three per-channel means.
        rgb_std: three per-channel standard deviations (divisors).
        sign: -1 to subtract the mean, +1 to add it.
    """

    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        # Diagonal weight: output channel i only sees input channel i.
        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data.div_(std.view(3, 1, 1, 1))
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
        self.bias.data.div_(std)
        # Freeze the parameters themselves. The previous
        # ``self.requires_grad = False`` only created a plain attribute on the
        # module and left weight/bias trainable.
        for p in self.parameters():
            p.requires_grad = False
class BasicBlock(nn.Sequential):
    """``conv -> [BatchNorm2d] -> [act]`` packaged as one ``nn.Sequential``.

    Args:
        conv: factory called as
            ``conv(in_channels, out_channels, kernel_size, bias=bias)``.
        in_channels: input channels passed to ``conv``.
        out_channels: output channels of ``conv`` (and of the optional BN).
        kernel_size: kernel size passed to ``conv``.
        stride: accepted but NOT forwarded to ``conv``.
            NOTE(review): any value other than 1 is silently ignored.
        bias: forwarded to ``conv``.
        bn: append ``nn.BatchNorm2d(out_channels)`` after the conv.
        act: module appended last; ``None`` omits it. The default ReLU
            instance is created once and shared by every block that relies
            on the default.
    """

    def __init__(
            self, conv, in_channels, out_channels, kernel_size, stride=1, bias=True,
            bn=False, act=nn.ReLU(True)):
        head = conv(in_channels, out_channels, kernel_size, bias=bias)
        norm = (nn.BatchNorm2d(out_channels),) if bn else ()
        tail = () if act is None else (act,)
        super(BasicBlock, self).__init__(head, *norm, *tail)
class ResBlock(nn.Module):
    """Residual block: ``x + res_scale * body(x)``.

    ``body`` is ``conv -> [bn] -> act -> conv -> [bn]``. Both convs map
    ``n_feat`` channels to ``n_feat`` channels.

    Args:
        conv: factory called as ``conv(n_feat, n_feat, kernel_size, bias=bias)``.
        n_feat: channel count kept constant through the block.
        kernel_size: kernel size passed to ``conv``.
        bias: forwarded to ``conv``.
        bn: insert ``nn.BatchNorm2d(n_feat)`` after each conv.
        act: activation placed between the two convs. The default ReLU
            instance is created once and shared by all blocks using it.
        res_scale: multiplier applied to the body output before the skip add.
    """

    def __init__(
            self, conv, n_feat, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(ResBlock, self).__init__()

        def conv_unit():
            # One conv, optionally followed by batch norm.
            unit = [conv(n_feat, n_feat, kernel_size, bias=bias)]
            if bn:
                unit.append(nn.BatchNorm2d(n_feat))
            return unit

        # Arguments evaluate left to right, so module creation order
        # (conv1, bn1, conv2, bn2) is unchanged.
        self.body = nn.Sequential(*conv_unit(), act, *conv_unit())
        self.res_scale = res_scale

    def forward(self, x):
        out = self.body(x)
        out = out.mul(self.res_scale)
        out += x
        return out
class Upsampler(nn.Sequential):
    """Sub-pixel upsampling head built from conv + ``nn.PixelShuffle`` stages.

    A power-of-two ``scale`` uses log2(scale) stages of x2. ``scale == 3``
    uses one x3 stage. ``scale == 1`` yields an empty (identity) Sequential.
    Each stage is ``conv(n_feat -> r*r*n_feat) -> PixelShuffle(r) -> [bn] -> [act]``.

    Args:
        conv: factory called positionally as
            ``conv(in_channels, out_channels, kernel_size, bias)``.
        scale: integer upsampling factor (1, 2, 4, 8, ... or 3).
        n_feat: channel count entering and leaving every stage.
        bn: append ``nn.BatchNorm2d(n_feat)`` after each PixelShuffle.
        act: if truthy, it is called with no arguments and the result is
            appended after each stage (e.g. pass the class ``nn.ReLU``).
        bias: forwarded to ``conv``.

    Raises:
        NotImplementedError: for any other ``scale``, including 0 and
            negative values. Previously 0 failed inside ``math.log`` with a
            math domain ``ValueError``.
    """

    def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True):
        if scale > 0 and (scale & (scale - 1)) == 0:  # Is scale = 2^n?
            # int.bit_length gives an exact integer log2, with no dependence
            # on floating-point rounding in math.log(scale, 2).
            factors = [2] * (int(scale).bit_length() - 1)
        elif scale == 3:
            factors = [3]
        else:
            raise NotImplementedError

        m = []
        for r in factors:
            m.append(conv(n_feat, r * r * n_feat, 3, bias))
            m.append(nn.PixelShuffle(r))
            if bn:
                m.append(nn.BatchNorm2d(n_feat))
            if act:
                m.append(act())
        super(Upsampler, self).__init__(*m)
class DownBlock(nn.Module):
    """Space-to-depth: fold every ``scale x scale`` spatial patch into channels.

    Maps ``(N, C, H, W)`` to ``(N, C * scale**2, H // scale, W // scale)``.
    Output channel index is ``(row_offset * scale + col_offset) * C + c``:
    the in-patch offset varies slowest and the original channel fastest.
    H and W must be divisible by ``scale``; otherwise the first ``view``
    raises because the element counts no longer match.
    """

    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, x):
        s = self.scale
        batch, channels, height, width = x.size()
        out_h, out_w = height // s, width // s
        # Split H -> (out_h, s) and W -> (out_w, s) ...
        patches = x.view(batch, channels, out_h, s, out_w, s)
        # ... then move both in-patch offsets in front of the channel axis.
        patches = patches.permute(0, 3, 5, 1, 2, 4).contiguous()
        return patches.view(batch, channels * (s ** 2), out_h, out_w)
# NonLocalBlock2D
# ref: https://github.com/AlexHex7/Non-local_pytorch/blob/master/Non-Local_pytorch_0.4.1_to_1.1.0/lib/non_local_dot_product.py
# ref: https://github.com/yulunzhang/RNAN/blob/master/SR/code/model/common.py
class NonLocalBlock2D(nn.Module):
    """Residual non-local block over all spatial positions of a feature map.

    Every output position aggregates ``g(x)`` from every input position,
    weighted by the pairwise affinity between ``theta(x)`` and ``phi(x)``.
    The result is projected back to ``in_channels`` by ``W`` and added to
    the input:

        z = W(A @ g(x)) + x,   A = normalise(theta(x)^T @ phi(x))

    ``W`` is zero-initialised, so a freshly built block returns its input
    unchanged. The affinity matrix ``A`` has ``(H*W)**2`` entries per sample.

    Args:
        in_channels: channels of the input and of the output.
        inter_channels: channels of the theta / phi / g embeddings.
        mode: affinity normalisation.
            ``'embedded_gaussian'`` (default, original behaviour) applies a
            softmax over the last axis. ``'dot_product'`` divides by the
            number of spatial positions.

    Raises:
        ValueError: if ``mode`` is not one of the supported values.
    """

    _MODES = ('embedded_gaussian', 'dot_product')

    def __init__(self, in_channels, inter_channels, mode='embedded_gaussian'):
        super(NonLocalBlock2D, self).__init__()
        if mode not in self._MODES:
            raise ValueError(
                "mode must be one of %s, got %r" % (self._MODES, mode))
        self.mode = mode
        self.in_channels = in_channels
        self.inter_channels = inter_channels
        # Sub-module creation order (g, W, theta, phi) is kept so seeded
        # initialisation matches the original.
        self.g = nn.Conv2d(in_channels=in_channels, out_channels=inter_channels,
                           kernel_size=1, stride=1, padding=0)
        self.W = nn.Conv2d(in_channels=inter_channels, out_channels=in_channels,
                           kernel_size=1, stride=1, padding=0)
        # Zero-init: the block starts out as an identity mapping.
        nn.init.constant_(self.W.weight, 0)
        nn.init.constant_(self.W.bias, 0)
        self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                               kernel_size=1, stride=1, padding=0)
        self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        batch_size = x.size(0)
        # (B, C', H, W) -> (B, H*W, C')
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        # (B, C', H*W)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        # Pairwise affinities between all positions: (B, H*W, H*W).
        f = torch.matmul(theta_x, phi_x)
        if self.mode == 'dot_product':
            f_div_C = f / f.size(-1)
        else:  # 'embedded_gaussian'
            f_div_C = F.softmax(f, dim=-1)
        # (B, H*W, C') -> (B, C', H, W)
        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x
        return z