|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
class RD_block(nn.Module):
    """Residual dense block: five densely connected 3x3 convolutions plus a scaled skip.

    Each of the first four convolutions receives the block input concatenated
    (along dim 1) with the outputs of all earlier convolutions and emits
    ``growth_channels`` maps followed by LeakyReLU(0.2). The fifth convolution
    maps the full concatenation back to ``channels`` maps; that result is
    multiplied by ``residual_beta`` and added to the block input.

    Args:
        channels: Channel count of the block input and of its output.
        growth_channels: Channels produced by each of the first four convolutions.
        residual_beta: Scale applied to the dense branch before the skip addition.
    """

    def __init__(self, channels, growth_channels, residual_beta):
        super(RD_block, self).__init__()
        self.residual_beta = residual_beta

        # conv1..conv4 grow the feature stack by `growth_channels` each; conv5
        # projects the whole stack back to `channels`. Built in index order so the
        # default parameter initialisation consumes the RNG in a fixed sequence.
        for idx in range(5):
            out_c = channels if idx == 4 else growth_channels
            setattr(
                self,
                f"conv{idx + 1}",
                nn.Conv2d(channels + growth_channels * idx, out_c, kernel_size=3, stride=1, padding=1),
            )

        self.activation = nn.LeakyReLU(0.2, inplace=True)
        self.identity = nn.Identity()

    def forward(self, x):
        """Return ``x + residual_beta * dense_branch(x)``.

        Args:
            x: Input tensor with ``channels`` entries along dim 1. The code
                concatenates along dim 1, so it presumably expects batched
                (N, C, H, W) input — TODO confirm with callers.
        """
        stack = [x]
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            # Each conv sees every feature map produced so far (dense connectivity).
            stack.append(self.activation(conv(torch.cat(stack, 1))))
        branch = self.identity(self.conv5(torch.cat(stack, 1)))
        return branch * self.residual_beta + x
|
|
|
|
|
|
|
|
class RRD_block(nn.Module):
    """Residual-in-residual block: three chained ``RD_block`` modules with a scaled outer skip.

    The input passes through ``block1`` -> ``block2`` -> ``block3``. That result
    is multiplied by ``residual_beta`` and added back to the original input.

    Args:
        channels: Channel count of the block input and of its output.
        growth_channels: Growth width forwarded to each inner ``RD_block``.
        residual_beta: Scale used both by the inner blocks and for the outer skip.
    """

    def __init__(self, channels, growth_channels, residual_beta):
        # nn.Module.__init__ must run before any attribute is assigned on the
        # instance; the original assigned `residual_beta` first.
        super(RRD_block, self).__init__()
        self.residual_beta = residual_beta

        self.block1 = RD_block(channels, growth_channels, residual_beta)
        self.block2 = RD_block(channels, growth_channels, residual_beta)
        self.block3 = RD_block(channels, growth_channels, residual_beta)

    def forward(self, x):
        """Return ``x + residual_beta * block3(block2(block1(x)))``."""
        out = self.block3(self.block2(self.block1(x)))
        return torch.add(torch.mul(out, self.residual_beta), x)
|
|
|
|
|
|
|
|
class UpsampleBlock(nn.Module):
    """Nearest-neighbour upsampling, then a 3x3 convolution and LeakyReLU(0.2).

    The channel count is unchanged. The spatial dimensions are scaled by
    ``upscale_factor`` through ``nn.Upsample``.

    Args:
        in_c: Number of input (and output) channels.
        upscale_factor: ``scale_factor`` handed to ``nn.Upsample``.
    """

    def __init__(self, in_c, upscale_factor):
        super(UpsampleBlock, self).__init__()
        self.upsample = nn.Upsample(scale_factor=upscale_factor, mode="nearest")
        self.conv = nn.Conv2d(
            in_channels=in_c,
            out_channels=in_c,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
        )
        self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Upsample ``x``, convolve, and apply the activation."""
        enlarged = self.upsample(x)
        features = self.conv(enlarged)
        return self.act(features)
|
|
|
|
|
|
|
|
class DRRRDBNet(nn.Module):
    """Image-upscaling network built from 23 ``RRD_block`` modules in four groups.

    Pipeline as constructed below:
      conv1 -> res_block (6 RRD_block) -> dropout -> res_block2 (6) -> dropout
      -> res_block3 (6) -> dropout -> res_block4 (5) -> conv2, added to the conv1
      output -> two UpsampleBlocks -> conv3 (conv + LeakyReLU) -> conv4
      -> clamped in place to [0, 1].

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        channels: Width of the feature trunk.
        growth_channels: Growth width forwarded to every ``RRD_block``.
        upscale_factor: Scale factor of EACH of the two ``UpsampleBlock`` stages,
            so the overall spatial scale is ``upscale_factor ** 2``.
            NOTE(review): confirm callers pass the per-stage factor, not the total.
        residual_beta: Residual scale forwarded to every ``RRD_block``.
    """

    def __init__(self, in_channels, out_channels, channels, growth_channels, upscale_factor, residual_beta):
        super(DRRRDBNet, self).__init__()

        def rrdb_group(count):
            # A run of `count` identically configured RRD_blocks.
            return nn.Sequential(*(RRD_block(channels, growth_channels, residual_beta) for _ in range(count)))

        # Submodules are created in this fixed order so default initialisation
        # draws from the RNG in the same sequence.
        self.conv1 = nn.Conv2d(in_channels, channels, 3, 1, 1)
        self.res_block = rrdb_group(6)
        self.res_block2 = rrdb_group(6)
        self.res_block3 = rrdb_group(6)
        self.res_block4 = rrdb_group(5)
        self.dropout = nn.Dropout(p=0.1)
        self.conv2 = nn.Conv2d(channels, channels, 3, 1, 1)
        self.upsample = nn.Sequential(*(UpsampleBlock(channels, upscale_factor) for _ in range(2)))
        self.conv3 = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )
        self.conv4 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Map ``x`` to an upscaled output whose values are clamped to [0, 1]."""
        shallow = self.conv1(x)

        trunk = shallow
        # Dropout follows each of the first three groups, but not the fourth.
        for group in (self.res_block, self.res_block2, self.res_block3):
            trunk = self.dropout(group(trunk))
        trunk = self.conv2(self.res_block4(trunk))

        # Long skip: trunk output plus the shallow features from conv1.
        upscaled = self.upsample(trunk + shallow)
        out = self.conv4(self.conv3(upscaled))
        return out.clamp_(0.0, 1.0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Discriminator(nn.Module):
    """Convolutional classifier producing one unbounded score per sample.

    ``features`` is a 3x3 conv (3 -> 64) + LeakyReLU, followed by eleven
    conv/BatchNorm/LeakyReLU triples. Six of the convs are 4x4 with stride 2, and
    each of those roughly halves H and W. ``classifier`` expects
    ``512 * 4 * 4`` flattened features, so the input spatial size must reduce to
    4x4 (256x256 inputs do). No sigmoid is applied to the output.
    """

    def __init__(self) -> None:
        super(Discriminator, self).__init__()

        def conv_bn_lrelu(in_c, out_c, kernel, stride):
            # Bias-free conv (each is followed by BatchNorm) -> BatchNorm -> LeakyReLU(0.2).
            return [
                nn.Conv2d(in_c, out_c, (kernel, kernel), (stride, stride), (1, 1), bias=False),
                nn.BatchNorm2d(out_c),
                nn.LeakyReLU(0.2, True),
            ]

        # (in_channels, out_channels, kernel, stride) for every conv after the first.
        specs = (
            (64, 64, 4, 2), (64, 128, 3, 1),
            (128, 128, 4, 2), (128, 256, 3, 1),
            (256, 256, 4, 2), (256, 512, 3, 1),
            (512, 512, 4, 2), (512, 512, 3, 1),
            (512, 512, 4, 2), (512, 512, 3, 1),
            (512, 512, 4, 2),
        )
        layers = [nn.Conv2d(3, 64, (3, 3), (1, 1), (1, 1), bias=True), nn.LeakyReLU(0.2, True)]
        for in_c, out_c, kernel, stride in specs:
            layers.extend(conv_bn_lrelu(in_c, out_c, kernel, stride))
        self.features = nn.Sequential(*layers)

        self.classifier = nn.Sequential(
            nn.Linear(512 * 4 * 4, 100),
            nn.LeakyReLU(0.2, True),
            nn.Linear(100, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a (N, 1) tensor of scores for the batch ``x``."""
        flat = self.features(x).flatten(1)
        return self.classifier(flat)
|
|
|
|
|
|
|
|
|
|
|
def weights_init(m, scale=0.1):
    """Kaiming-normal initialise a ``Conv2d``'s weight, scale it, and zero its bias.

    Designed to be passed to ``nn.Module.apply`` (which calls it with a single
    module). Modules that are not ``nn.Conv2d`` are left untouched.

    Args:
        m: Module to initialise.
        scale: Multiplier applied to the Kaiming-normal weights. Defaults to 0.1,
            the value previously hard-coded.
    """
    if not isinstance(m, nn.Conv2d):
        return

    nn.init.kaiming_normal_(m.weight)
    # Scale in place under no_grad rather than through the legacy `.data`
    # attribute, which bypasses autograd's in-place tracking.
    with torch.no_grad():
        m.weight.mul_(scale)

    if m.bias is not None:
        nn.init.constant_(m.bias, 0)
|
|
|