Spaces:
Running
on
T4
Running
on
T4
| """ | |
| EDSR common.py | |
| Since a lot of models are developed on top of EDSR, here we include some common functions from EDSR. | |
| In this repository, the common functions is used by edsr_esa.py and ipt.py | |
| """ | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| def default_conv(in_channels, out_channels, kernel_size, bias=True): | |
| return nn.Conv2d( | |
| in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias | |
| ) | |
| class MeanShift(nn.Conv2d): | |
| def __init__( | |
| self, | |
| rgb_range, | |
| rgb_mean=(0.4488, 0.4371, 0.4040), | |
| rgb_std=(1.0, 1.0, 1.0), | |
| sign=-1, | |
| ): | |
| super(MeanShift, self).__init__(3, 3, kernel_size=1) | |
| std = torch.Tensor(rgb_std) | |
| self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) | |
| self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std | |
| for p in self.parameters(): | |
| p.requires_grad = False | |
| class BasicBlock(nn.Sequential): | |
| def __init__( | |
| self, | |
| conv, | |
| in_channels, | |
| out_channels, | |
| kernel_size, | |
| stride=1, | |
| bias=False, | |
| bn=True, | |
| act=nn.ReLU(True), | |
| ): | |
| m = [conv(in_channels, out_channels, kernel_size, bias=bias)] | |
| if bn: | |
| m.append(nn.BatchNorm2d(out_channels)) | |
| if act is not None: | |
| m.append(act) | |
| super(BasicBlock, self).__init__(*m) | |
| class ESA(nn.Module): | |
| def __init__(self, esa_channels, n_feats): | |
| super(ESA, self).__init__() | |
| f = esa_channels | |
| self.conv1 = nn.Conv2d(n_feats, f, kernel_size=1) | |
| self.conv_f = nn.Conv2d(f, f, kernel_size=1) | |
| # self.conv_max = conv(f, f, kernel_size=3, padding=1) | |
| self.conv2 = nn.Conv2d(f, f, kernel_size=3, stride=2, padding=0) | |
| self.conv3 = nn.Conv2d(f, f, kernel_size=3, padding=1) | |
| # self.conv3_ = conv(f, f, kernel_size=3, padding=1) | |
| self.conv4 = nn.Conv2d(f, n_feats, kernel_size=1) | |
| self.sigmoid = nn.Sigmoid() | |
| # self.relu = nn.ReLU(inplace=True) | |
| def forward(self, x): | |
| c1_ = self.conv1(x) | |
| c1 = self.conv2(c1_) | |
| v_max = F.max_pool2d(c1, kernel_size=7, stride=3) | |
| c3 = self.conv3(v_max) | |
| # v_range = self.relu(self.conv_max(v_max)) | |
| # c3 = self.relu(self.conv3(v_range)) | |
| # c3 = self.conv3_(c3) | |
| c3 = F.interpolate( | |
| c3, (x.size(2), x.size(3)), mode="bilinear", align_corners=False | |
| ) | |
| cf = self.conv_f(c1_) | |
| c4 = self.conv4(c3 + cf) | |
| m = self.sigmoid(c4) | |
| return x * m | |
| # class ESA(nn.Module): | |
| # def __init__(self, esa_channels, n_feats, conv=nn.Conv2d): | |
| # super(ESA, self).__init__() | |
| # f = n_feats // 4 | |
| # self.conv1 = conv(n_feats, f, kernel_size=1) | |
| # self.conv_f = conv(f, f, kernel_size=1) | |
| # self.conv_max = conv(f, f, kernel_size=3, padding=1) | |
| # self.conv2 = conv(f, f, kernel_size=3, stride=2, padding=0) | |
| # self.conv3 = conv(f, f, kernel_size=3, padding=1) | |
| # self.conv3_ = conv(f, f, kernel_size=3, padding=1) | |
| # self.conv4 = conv(f, n_feats, kernel_size=1) | |
| # self.sigmoid = nn.Sigmoid() | |
| # self.relu = nn.ReLU(inplace=True) | |
| # | |
| # def forward(self, x): | |
| # c1_ = (self.conv1(x)) | |
| # c1 = self.conv2(c1_) | |
| # v_max = F.max_pool2d(c1, kernel_size=7, stride=3) | |
| # v_range = self.relu(self.conv_max(v_max)) | |
| # c3 = self.relu(self.conv3(v_range)) | |
| # c3 = self.conv3_(c3) | |
| # c3 = F.interpolate(c3, (x.size(2), x.size(3)), mode='bilinear', align_corners=False) | |
| # cf = self.conv_f(c1_) | |
| # c4 = self.conv4(c3 + cf) | |
| # m = self.sigmoid(c4) | |
| # | |
| # return x * m | |
| class ResBlock(nn.Module): | |
| def __init__( | |
| self, | |
| conv, | |
| n_feats, | |
| kernel_size, | |
| bias=True, | |
| bn=False, | |
| act=nn.ReLU(True), | |
| res_scale=1, | |
| esa_block=True, | |
| depth_wise_kernel=7, | |
| ): | |
| super(ResBlock, self).__init__() | |
| m = [] | |
| for i in range(2): | |
| m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) | |
| if bn: | |
| m.append(nn.BatchNorm2d(n_feats)) | |
| if i == 0: | |
| m.append(act) | |
| self.body = nn.Sequential(*m) | |
| self.esa_block = esa_block | |
| if self.esa_block: | |
| esa_channels = 16 | |
| self.c5 = nn.Conv2d( | |
| n_feats, | |
| n_feats, | |
| depth_wise_kernel, | |
| padding=depth_wise_kernel // 2, | |
| groups=n_feats, | |
| bias=True, | |
| ) | |
| self.esa = ESA(esa_channels, n_feats) | |
| self.res_scale = res_scale | |
| def forward(self, x): | |
| res = self.body(x).mul(self.res_scale) | |
| res += x | |
| if self.esa_block: | |
| res = self.esa(self.c5(res)) | |
| return res | |
| class Upsampler(nn.Sequential): | |
| def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): | |
| m = [] | |
| if (scale & (scale - 1)) == 0: # Is scale = 2^n? | |
| for _ in range(int(math.log(scale, 2))): | |
| m.append(conv(n_feats, 4 * n_feats, 3, bias)) | |
| m.append(nn.PixelShuffle(2)) | |
| if bn: | |
| m.append(nn.BatchNorm2d(n_feats)) | |
| if act == "relu": | |
| m.append(nn.ReLU(True)) | |
| elif act == "prelu": | |
| m.append(nn.PReLU(n_feats)) | |
| elif scale == 3: | |
| m.append(conv(n_feats, 9 * n_feats, 3, bias)) | |
| m.append(nn.PixelShuffle(3)) | |
| if bn: | |
| m.append(nn.BatchNorm2d(n_feats)) | |
| if act == "relu": | |
| m.append(nn.ReLU(True)) | |
| elif act == "prelu": | |
| m.append(nn.PReLU(n_feats)) | |
| else: | |
| raise NotImplementedError | |
| super(Upsampler, self).__init__(*m) | |
| class LiteUpsampler(nn.Sequential): | |
| def __init__(self, conv, scale, n_feats, n_out=3, bn=False, act=False, bias=True): | |
| m = [] | |
| m.append(conv(n_feats, n_out * (scale**2), 3, bias)) | |
| m.append(nn.PixelShuffle(scale)) | |
| # if (scale & (scale - 1)) == 0: # Is scale = 2^n? | |
| # for _ in range(int(math.log(scale, 2))): | |
| # m.append(conv(n_feats, 4 * n_out, 3, bias)) | |
| # m.append(nn.PixelShuffle(2)) | |
| # if bn: | |
| # m.append(nn.BatchNorm2d(n_out)) | |
| # if act == 'relu': | |
| # m.append(nn.ReLU(True)) | |
| # elif act == 'prelu': | |
| # m.append(nn.PReLU(n_out)) | |
| # elif scale == 3: | |
| # m.append(conv(n_feats, 9 * n_out, 3, bias)) | |
| # m.append(nn.PixelShuffle(3)) | |
| # if bn: | |
| # m.append(nn.BatchNorm2d(n_out)) | |
| # if act == 'relu': | |
| # m.append(nn.ReLU(True)) | |
| # elif act == 'prelu': | |
| # m.append(nn.PReLU(n_out)) | |
| # else: | |
| # raise NotImplementedError | |
| super(LiteUpsampler, self).__init__(*m) | |