| import torch |
| from torch import nn |
| import Deraining.model.common as common |
| import torch.nn.functional as F |
| from Deraining.moco.builder import MoCo |
|
|
|
|
def make_model(args):
    """Factory returning a freshly constructed ``BlindSR`` model.

    ``args`` is currently unused: ``BlindSR.__init__`` accepts no configuration
    arguments, so forwarding ``args`` to it raised ``TypeError``. The parameter
    is kept so existing ``make_model(args)`` callers keep working.
    """
    return BlindSR()
|
|
|
|
class DA_conv(nn.Module):
    """Degradation-aware convolution.

    A small MLP maps the degradation representation ``x[1]`` to one
    ``kernel_size x kernel_size`` kernel per (sample, input channel). Those
    kernels are applied to the feature map ``x[0]`` as a grouped convolution,
    followed by LeakyReLU(0.1) and a 1x1 conv. A channel-attention branch,
    also driven by ``x[1]``, is added to the result.

    Args:
        channels_in: channels of the feature map ``x[0]``.
        channels_out: output channels of the 1x1 conv / attention branch. The
            final sum ``conv_out + x[0] * att`` needs this to equal
            ``channels_in`` (or broadcast with it).
        kernel_size: spatial size of the predicted kernels. Padding is
            ``(kernel_size - 1) // 2``, so only odd sizes keep H x W.
        reduction: bottleneck reduction factor of the attention branch.
        deg_dim: width of the degradation representation ``x[1]``. Defaults to
            64, the width the original implementation hard-coded.
    """

    def __init__(self, channels_in, channels_out, kernel_size, reduction, deg_dim=64):
        super(DA_conv, self).__init__()
        self.channels_out = channels_out
        self.channels_in = channels_in
        self.kernel_size = kernel_size

        # Predict channels_in kernels per sample. The grouped conv in forward
        # uses groups=b*c, so it needs exactly one k x k kernel per channel.
        # The previous hard-coded 64 was only correct for 64-channel inputs.
        self.kernel = nn.Sequential(
            nn.Linear(deg_dim, deg_dim, bias=False),
            nn.LeakyReLU(0.1, True),
            nn.Linear(deg_dim, channels_in * kernel_size * kernel_size, bias=False),
        )
        self.conv = common.default_conv(channels_in, channels_out, 1)
        # The attention branch consumes x[1] as well, so its input width is deg_dim.
        self.ca = CA_layer(deg_dim, channels_out, reduction)

        self.relu = nn.LeakyReLU(0.1, True)

    def forward(self, x):
        '''
        :param x[0]: feature map: B * C * H * W
        :param x[1]: degradation representation: B * deg_dim
        '''
        b, c, h, w = x[0].size()

        # Fold the batch into the channel axis so a single grouped conv applies
        # each sample's own per-channel kernels (groups = b * c).
        kernel = self.kernel(x[1]).view(-1, 1, self.kernel_size, self.kernel_size)
        # reshape (not view) so non-contiguous feature maps are accepted.
        out = self.relu(
            F.conv2d(
                x[0].reshape(1, -1, h, w),
                kernel,
                groups=b * c,
                padding=(self.kernel_size - 1) // 2,
            )
        )
        out = self.conv(out.reshape(b, -1, h, w))

        # Add the channel-attention-modulated input features.
        out = out + self.ca(x)

        return out
|
|
|
|
class CA_layer(nn.Module):
    """Channel attention driven by the degradation representation.

    The representation goes through a 1x1-conv bottleneck
    (``channels_in -> channels_in // reduction -> channels_out``) with a
    LeakyReLU(0.1) in between. A sigmoid turns the result into per-channel
    gates, which rescale the feature map.
    """

    def __init__(self, channels_in, channels_out, reduction):
        super().__init__()
        hidden = channels_in // reduction
        self.conv_du = nn.Sequential(
            nn.Conv2d(channels_in, hidden, 1, 1, 0, bias=False),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(hidden, channels_out, 1, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        '''
        :param x[0]: feature map: B * C * H * W
        :param x[1]: degradation representation: B * C
        '''
        feat = x[0]
        deg = x[1]
        # Lift the B x C vector to B x C x 1 x 1 so the 1x1 convs apply to it.
        gate = self.conv_du(deg[:, :, None, None])
        return feat * gate
|
|
|
|
class DAB(nn.Module):
    """Degradation-aware residual block.

    Runs ``DA_conv -> conv -> DA_conv -> conv`` with LeakyReLU(0.1) after each
    stage except the last, then adds the input feature map back (identity skip).
    """

    def __init__(self, conv, n_feat, kernel_size, reduction):
        super().__init__()
        # Submodule names form the state_dict keys; keep them stable.
        self.da_conv1 = DA_conv(n_feat, n_feat, kernel_size, reduction)
        self.da_conv2 = DA_conv(n_feat, n_feat, kernel_size, reduction)
        self.conv1 = conv(n_feat, n_feat, kernel_size)
        self.conv2 = conv(n_feat, n_feat, kernel_size)
        self.relu = nn.LeakyReLU(0.1, True)

    def forward(self, x):
        '''
        :param x[0]: feature map: B * C * H * W
        :param x[1]: degradation representation: B * C
        :return: conv2(...) + x[0]
        '''
        skip, deg = x[0], x[1]
        y = self.relu(self.da_conv1(x))
        y = self.relu(self.conv1(y))
        y = self.relu(self.da_conv2([y, deg]))
        return self.conv2(y) + skip
|
|
|
|
class DAG(nn.Module):
    """Degradation-aware residual group.

    Holds ``n_blocks`` DAB blocks followed by one plain conv in ``self.body``.
    The group output is ``body[-1](blocks(x[0])) + x[0]``.
    """

    def __init__(self, conv, n_feat, kernel_size, reduction, n_blocks):
        super().__init__()
        self.n_blocks = n_blocks
        blocks = [DAB(conv, n_feat, kernel_size, reduction) for _ in range(n_blocks)]
        # body[:n_blocks] are the DABs; body[-1] is the closing conv.
        self.body = nn.Sequential(*blocks, conv(n_feat, n_feat, kernel_size))

    def forward(self, x):
        '''
        :param x[0]: feature map: B * C * H * W
        :param x[1]: degradation representation: B * C
        '''
        feat, deg = x[0], x[1]
        out = feat
        # Each DAB takes the [features, representation] pair.
        for block in self.body[:self.n_blocks]:
            out = block([out, deg])
        return self.body[-1](out) + feat
|
|
|
|
class DASR(nn.Module):
    """Degradation-aware restoration network.

    Pipeline: ``sub_mean`` -> head conv (3 -> 64 channels) -> ``n_groups``
    DAG residual groups plus a closing conv, wrapped in a global skip ->
    tail (``common.Upsampler`` with scale 1, then a conv back to 3 channels)
    -> ``add_mean``. The conditioning vector ``k_v`` is compressed from 256
    to 64 dims and passed to every group.

    Args:
        conv: conv factory called as ``conv(in_ch, out_ch, kernel_size)``. It is
            used for the head, the residual groups, the closing body conv and
            the tail.
        n_groups: number of DAG residual groups (default 5, as before).
        n_blocks: number of DAB blocks per group (default 5, as before).
    """

    def __init__(self, conv=common.default_conv, n_groups=5, n_blocks=5):
        super(DASR, self).__init__()

        self.n_groups = n_groups
        n_feats = 64
        kernel_size = 3
        reduction = 8
        scale = 1

        # RGB mean shift with rgb_range 255. add_mean passes an extra 1
        # (presumably the sign) to invert sub_mean -- confirm in common.MeanShift.
        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        self.sub_mean = common.MeanShift(255.0, rgb_mean, rgb_std)
        self.add_mean = common.MeanShift(255.0, rgb_mean, rgb_std, 1)

        # Shallow feature extraction: 3 -> n_feats channels.
        modules_head = [conv(3, n_feats, kernel_size)]
        self.head = nn.Sequential(*modules_head)

        # Compress the 256-d conditioning vector k_v to 64 dims.
        self.compress = nn.Sequential(
            nn.Linear(256, 64, bias=False),
            nn.LeakyReLU(0.1, True),
        )

        # Residual groups plus a closing conv. The groups now use the injected
        # ``conv`` factory like the rest of the network; they previously
        # hard-coded common.default_conv and ignored a custom ``conv``.
        modules_body = [
            DAG(conv, n_feats, kernel_size, reduction, n_blocks)
            for _ in range(self.n_groups)
        ]
        modules_body.append(conv(n_feats, n_feats, kernel_size))
        self.body = nn.Sequential(*modules_body)

        # Reconstruction: Upsampler at scale 1, then n_feats -> 3 channels.
        modules_tail = [
            common.Upsampler(conv, scale, n_feats, act=False),
            conv(n_feats, 3, kernel_size),
        ]
        self.tail = nn.Sequential(*modules_tail)

    def forward(self, x, k_v):
        """Restore ``x`` conditioned on ``k_v``.

        :param x: image batch with 3 channels (input of the head conv).
        :param k_v: conditioning vector whose last dim is 256 (input of
            ``compress``); assumed B * 256 -- confirm against the caller.
        :return: restored image batch (output of ``add_mean``).
        """
        k_v = self.compress(k_v)

        x = self.sub_mean(x)

        x = self.head(x)

        # Residual groups, each conditioned on the compressed k_v.
        res = x
        for i in range(self.n_groups):
            res = self.body[i]([res, k_v])
        res = self.body[-1](res)
        # Global skip connection around the whole body.
        res = res + x

        x = self.tail(res)

        x = self.add_mean(x)

        return x
|
|
|
|
class Encoder(nn.Module):
    """Convolutional encoder producing a pooled feature and its MLP projection.

    Six ``3x3 conv -> BatchNorm2d -> LeakyReLU(0.1)`` stages (the 3rd and 5th
    with stride 2) take 3 input channels to 256. Global average pooling then
    reduces each sample to one 256-d vector, and a two-layer MLP projects it.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, stride) of each conv stage, in order.
        stages = (
            (3, 64, 1),
            (64, 64, 1),
            (64, 128, 2),
            (128, 128, 1),
            (128, 256, 2),
            (256, 256, 1),
        )
        layers = []
        for c_in, c_out, stride in stages:
            layers.extend([
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.1, True),
            ])
        layers.append(nn.AdaptiveAvgPool2d(1))
        self.E = nn.Sequential(*layers)

        self.mlp = nn.Sequential(
            nn.Linear(256, 256),
            nn.LeakyReLU(0.1, True),
            nn.Linear(256, 256),
        )

    def forward(self, x):
        """Return ``(fea, out)``: the pooled B x 256 feature and its MLP projection."""
        # AdaptiveAvgPool2d(1) leaves B x 256 x 1 x 1; drop the trailing singleton dims.
        fea = self.E(x).flatten(1)
        projected = self.mlp(fea)
        return fea, projected
|
|
|
|
class BlindSR(nn.Module):
    """Degradation-representation learner wrapping a MoCo model over ``Encoder``.

    In training mode ``x`` is the query batch and ``y`` the key batch; MoCo's
    ``(fea, logits, labels)`` result is returned as a tuple. In eval mode ``y``
    is ignored, ``x`` is passed as both query and key, and MoCo's result is
    returned unchanged.
    """

    def __init__(self):
        super().__init__()
        self.E = MoCo(base_encoder=Encoder)

    def forward(self, x, y):
        if not self.training:
            return self.E(x, x)

        fea, logits, labels = self.E(x, y)
        return fea, logits, labels
class Super(nn.Module):
    """Thin wrapper exposing the ``DASR`` restoration network as ``self.G``."""

    def __init__(self):
        super().__init__()
        self.G = DASR()

    def forward(self, x, fea):
        """Restore ``x`` conditioned on ``fea``, which is passed to DASR as ``k_v``."""
        return self.G(x, fea)