import torch
from torch import nn
import torch.nn.functional as F

import Deraining.model.common as common
from Deraining.moco.builder import MoCo


def make_model(args):
    """Build the degradation-representation network (``BlindSR``).

    ``args`` is kept so the signature stays the same for existing callers,
    but it is not used: ``BlindSR.__init__`` takes no configuration.
    """
    # BlindSR.__init__ accepts no arguments; passing ``args`` raised TypeError.
    return BlindSR()


class DA_conv(nn.Module):
    """Degradation-aware convolution.

    Output = branch 1 + branch 2, both conditioned on the degradation
    representation ``x[1]``:

    1. a per-sample depthwise convolution whose ``kernel_size x kernel_size``
       kernels are predicted from ``x[1]`` by ``self.kernel``, followed by
       LeakyReLU and a ``common.default_conv`` with kernel size 1;
    2. channel attention (``CA_layer``) applied to the feature map ``x[0]``.

    NOTE(review): the kernel predictor is hard-coded to a 64-dim input and
    64 kernels per sample, and ``F.conv2d`` uses ``groups=b*c``. This appears
    to require 64-channel features (``channels_in == 64`` and a 64-dim
    ``x[1]``). Confirm before using other widths.
    """

    def __init__(self, channels_in, channels_out, kernel_size, reduction):
        super().__init__()
        self.channels_out = channels_out
        self.channels_in = channels_in
        self.kernel_size = kernel_size

        # Maps the degradation representation to 64 kernels of size k x k.
        self.kernel = nn.Sequential(
            nn.Linear(64, 64, bias=False),
            nn.LeakyReLU(0.1, True),
            nn.Linear(64, 64 * self.kernel_size * self.kernel_size, bias=False),
        )
        self.conv = common.default_conv(channels_in, channels_out, 1)
        self.ca = CA_layer(channels_in, channels_out, reduction)

        self.relu = nn.LeakyReLU(0.1, True)

    def forward(self, x):
        '''
        :param x[0]: feature map: B * C * H * W
        :param x[1]: degradation representation: B * C
        '''
        b, c, h, w = x[0].size()

        # Branch 1: one predicted kernel per (sample, channel) pair. The batch
        # is folded into the channel dimension so that a single grouped conv
        # applies each sample's own kernels to its own channels.
        kernel = self.kernel(x[1]).view(-1, 1, self.kernel_size, self.kernel_size)
        out = self.relu(
            F.conv2d(
                x[0].view(1, -1, h, w),
                kernel,
                groups=b * c,
                padding=(self.kernel_size - 1) // 2,
            )
        )
        out = self.conv(out.view(b, -1, h, w))

        # Branch 2: channel attention driven by the degradation representation.
        out = out + self.ca(x)

        return out


class CA_layer(nn.Module):
    """Channel attention computed from the degradation representation.

    ``x[1]`` is reshaped to ``B * C * 1 * 1`` and passed through two
    1x1 convolutions: a reduction by ``reduction``, then a LeakyReLU, then
    an expansion to ``channels_out``, then a sigmoid. The resulting weights
    rescale the feature map ``x[0]``.
    """

    def __init__(self, channels_in, channels_out, reduction):
        super().__init__()
        self.conv_du = nn.Sequential(
            nn.Conv2d(channels_in, channels_in // reduction, 1, 1, 0, bias=False),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(channels_in // reduction, channels_out, 1, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        '''
        :param x[0]: feature map: B * C * H * W
        :param x[1]: degradation representation: B * C
        '''
        att = self.conv_du(x[1][:, :, None, None])
        return x[0] * att


class DAB(nn.Module):
    """Degradation-aware block.

    Computes DA_conv, conv, DA_conv, conv, with LeakyReLU after each of the
    first three stages, and adds an identity skip from the input feature map.
    """

    def __init__(self, conv, n_feat, kernel_size, reduction):
        super().__init__()

        self.da_conv1 = DA_conv(n_feat, n_feat, kernel_size, reduction)
        self.da_conv2 = DA_conv(n_feat, n_feat, kernel_size, reduction)
        self.conv1 = conv(n_feat, n_feat, kernel_size)
        self.conv2 = conv(n_feat, n_feat, kernel_size)

        self.relu = nn.LeakyReLU(0.1, True)

    def forward(self, x):
        '''
        :param x[0]: feature map: B * C * H * W
        :param x[1]: degradation representation: B * C
        '''
        out = self.relu(self.da_conv1(x))
        out = self.relu(self.conv1(out))
        out = self.relu(self.da_conv2([out, x[1]]))
        out = self.conv2(out) + x[0]

        return out


class DAG(nn.Module):
    """Degradation-aware group.

    Runs ``n_blocks`` DABs and then one conv, and adds a skip from the input
    feature map.
    """

    def __init__(self, conv, n_feat, kernel_size, reduction, n_blocks):
        super().__init__()
        self.n_blocks = n_blocks
        modules_body = [
            DAB(conv, n_feat, kernel_size, reduction)
            for _ in range(n_blocks)
        ]
        modules_body.append(conv(n_feat, n_feat, kernel_size))

        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        '''
        :param x[0]: feature map: B * C * H * W
        :param x[1]: degradation representation: B * C
        '''
        res = x[0]
        # The first n_blocks entries are DABs, which take [features, repr];
        # the last entry is a plain conv, which takes features only.
        for i in range(self.n_blocks):
            res = self.body[i]([res, x[1]])
        res = self.body[-1](res)
        res = res + x[0]

        return res


class DASR(nn.Module):
    """Restoration network conditioned on a degradation representation.

    The forward pass applies, in order:

    1. mean subtraction;
    2. a 3 -> 64 head conv;
    3. ``n_groups`` DAGs (each of ``n_blocks`` DABs), a trailing conv and a
       long skip from the head output;
    4. the tail (``common.Upsampler`` with ``scale = 1``, then a 64 -> 3 conv);
    5. adding the mean back.

    ``conv`` is the convolution factory used throughout the network.
    """

    def __init__(self, conv=common.default_conv):
        super().__init__()

        self.n_groups = 5
        n_blocks = 5
        n_feats = 64
        kernel_size = 3
        reduction = 8
        # NOTE(review): presumably scale 1 keeps the output at input resolution
        # (restoration rather than SR); confirm against common.Upsampler.
        scale = 1

        # RGB mean for DIV2K
        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        self.sub_mean = common.MeanShift(255.0, rgb_mean, rgb_std)
        self.add_mean = common.MeanShift(255.0, rgb_mean, rgb_std, 1)

        # head module
        modules_head = [conv(3, n_feats, kernel_size)]
        self.head = nn.Sequential(*modules_head)

        # compress: 256-dim representation -> 64 dims, the width DA_conv and
        # CA_layer consume. Presumably the 256-dim ``fea`` from Encoder.
        self.compress = nn.Sequential(
            nn.Linear(256, 64, bias=False),
            nn.LeakyReLU(0.1, True),
        )

        # body
        # Use the configured ``conv`` factory (previously hard-coded to
        # common.default_conv, which ignored the constructor argument).
        modules_body = [
            DAG(conv, n_feats, kernel_size, reduction, n_blocks)
            for _ in range(self.n_groups)
        ]
        modules_body.append(conv(n_feats, n_feats, kernel_size))
        self.body = nn.Sequential(*modules_body)

        # tail
        modules_tail = [
            common.Upsampler(conv, scale, n_feats, act=False),
            conv(n_feats, 3, kernel_size),
        ]
        self.tail = nn.Sequential(*modules_tail)

    def forward(self, x, k_v):
        """Restore ``x`` conditioned on the degradation representation ``k_v``.

        :param x: image batch with 3 channels: B * 3 * H * W
        :param k_v: degradation representation: B * 256 (compressed to B * 64)
        """
        k_v = self.compress(k_v)

        # sub mean
        x = self.sub_mean(x)

        # head 3-64
        x = self.head(x)

        # body: n_groups DAGs, then the trailing conv, plus a long skip
        res = x
        for i in range(self.n_groups):
            res = self.body[i]([res, k_v])
        res = self.body[-1](res)
        res = res + x

        # tail
        x = self.tail(res)

        # add mean
        x = self.add_mean(x)

        return x


class Encoder(nn.Module):
    """Convolutional encoder for degradation representations.

    Six conv + BatchNorm + LeakyReLU stages (3 -> 64 -> 64 -> 128 -> 128 ->
    256 -> 256 channels, with two stride-2 downsamplings) are followed by
    global average pooling and a two-layer MLP.

    :return: ``(fea, out)``, where ``fea`` is the pooled B * 256 feature and
             ``out`` is its B * 256 MLP projection.
    """

    def __init__(self):
        super().__init__()

        self.E = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.1, True),
            nn.AdaptiveAvgPool2d(1),
        )
        self.mlp = nn.Sequential(
            nn.Linear(256, 256),
            nn.LeakyReLU(0.1, True),
            nn.Linear(256, 256),
        )

    def forward(self, x):
        # AdaptiveAvgPool2d(1) yields B * 256 * 1 * 1; drop the two spatial dims.
        fea = self.E(x).squeeze(-1).squeeze(-1)
        out = self.mlp(fea)

        return fea, out


class BlindSR(nn.Module):
    """Degradation-representation learner: ``Encoder`` wrapped in ``MoCo``."""

    def __init__(self):
        super().__init__()

        self.E = MoCo(base_encoder=Encoder)

    def forward(self, x, y):
        """Run the MoCo-wrapped encoder.

        In training mode, ``x`` is the query batch and ``y`` the key batch,
        and the result is ``(fea, logits, labels)`` from MoCo. In eval mode,
        ``x`` is passed as both query and key, ``y`` is ignored, and MoCo's
        result is returned unchanged.
        """
        if self.training:
            x_query = x
            x_key = y

            fea, logits, labels = self.E(x_query, x_key)

            return fea, logits, labels
        else:
            # degradation-aware representation learning
            fea = self.E(x, x)

            return fea


class Super(nn.Module):
    """Thin wrapper that runs ``DASR`` on an image and a representation."""

    def __init__(self):
        super().__init__()
        self.G = DASR()

    def forward(self, x, fea):
        sr = self.G(x, fea)
        return sr