import torch import torch.nn as nn import torch.nn.functional as F class _conv(nn.Conv2d): def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias): super(_conv, self).__init__(in_channels = in_channels, out_channels = out_channels, kernel_size = kernel_size, stride = stride, padding = (kernel_size) // 2, bias = True) self.weight.data = torch.normal(torch.zeros((out_channels, in_channels, kernel_size, kernel_size)), 0.02) self.bias.data = torch.zeros((out_channels)) for p in self.parameters(): p.requires_grad = True class conv(nn.Module): def __init__(self, in_channel, out_channel, kernel_size, BN = False, act = None, stride = 1, bias = True): super(conv, self).__init__() m = [] m.append(_conv(in_channels = in_channel, out_channels = out_channel, kernel_size = kernel_size, stride = stride, padding = (kernel_size) // 2, bias = True)) if BN: m.append(nn.BatchNorm2d(num_features = out_channel)) if act is not None: m.append(act) self.body = nn.Sequential(*m) def forward(self, x): out = self.body(x) return out class ResBlock(nn.Module): def __init__(self, channels, kernel_size, act = nn.ReLU(inplace = True), bias = True): super(ResBlock, self).__init__() m = [] m.append(conv(channels, channels, kernel_size, BN = True, act = act)) m.append(conv(channels, channels, kernel_size, BN = True, act = None)) self.body = nn.Sequential(*m) def forward(self, x): res = self.body(x) res += x return res class BasicBlock(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, num_res_block, act = nn.ReLU(inplace = True)): super(BasicBlock, self).__init__() m = [] self.conv = conv(in_channels, out_channels, kernel_size, BN = False, act = act) for i in range(num_res_block): m.append(ResBlock(out_channels, kernel_size, act)) m.append(conv(out_channels, out_channels, kernel_size, BN = True, act = None)) self.body = nn.Sequential(*m) def forward(self, x): res = self.conv(x) out = self.body(res) out += res return out class Upsampler(nn.Module): def __init__(self, channel, kernel_size, scale, act = nn.ReLU(inplace = True)): super(Upsampler, self).__init__() m = [] m.append(conv(channel, channel * scale * scale, kernel_size)) m.append(nn.PixelShuffle(scale)) if act is not None: m.append(act) self.body = nn.Sequential(*m) def forward(self, x): out = self.body(x) return out class discrim_block(nn.Module): def __init__(self, in_feats, out_feats, kernel_size, act = nn.LeakyReLU(inplace = True)): super(discrim_block, self).__init__() m = [] m.append(conv(in_feats, out_feats, kernel_size, BN = True, act = act)) m.append(conv(out_feats, out_feats, kernel_size, BN = True, act = act, stride = 2)) self.body = nn.Sequential(*m) def forward(self, x): out = self.body(x) return out class Generator(nn.Module): def __init__(self, img_feat = 3, n_feats = 64, kernel_size = 3, num_block = 16, act = nn.PReLU(), scale=4): super(Generator, self).__init__() self.conv01 = conv(in_channel = img_feat, out_channel = n_feats, kernel_size = 9, BN = False, act = act) resblocks = [ResBlock(channels = n_feats, kernel_size = 3, act = act) for _ in range(num_block)] self.body = nn.Sequential(*resblocks) self.conv02 = conv(in_channel = n_feats, out_channel = n_feats, kernel_size = 3, BN = True, act = None) if(scale == 4): upsample_blocks = [Upsampler(channel = n_feats, kernel_size = 3, scale = 2, act = act) for _ in range(2)] else: upsample_blocks = [Upsampler(channel = n_feats, kernel_size = 3, scale = scale, act = act)] self.tail = nn.Sequential(*upsample_blocks) self.last_conv = conv(in_channel = n_feats, out_channel = img_feat, kernel_size = 3, BN = False, act = nn.Tanh()) def forward(self, x): x = self.conv01(x) _skip_connection = x x = self.body(x) x = self.conv02(x) feat = x + _skip_connection x = self.tail(feat) x = self.last_conv(x) return x, feat class Discriminator(nn.Module): def __init__(self, img_feat = 3, n_feats = 64, kernel_size = 3, act = nn.LeakyReLU(inplace = True), num_of_block = 3, patch_size = 96): super(Discriminator, self).__init__() self.act = act self.conv01 = conv(in_channel = img_feat, out_channel = n_feats, kernel_size = 3, BN = False, act = self.act) self.conv02 = conv(in_channel = n_feats, out_channel = n_feats, kernel_size = 3, BN = False, act = self.act, stride = 2) body = [discrim_block(in_feats = n_feats * (2 ** i), out_feats = n_feats * (2 ** (i + 1)), kernel_size = 3, act = self.act) for i in range(num_of_block)] self.body = nn.Sequential(*body) self.linear_size = ((patch_size // (2 ** (num_of_block + 1))) ** 2) * (n_feats * (2 ** num_of_block)) tail = [] tail.append(nn.Linear(self.linear_size, 1024)) tail.append(self.act) tail.append(nn.Linear(1024, 1)) tail.append(nn.Sigmoid()) self.tail = nn.Sequential(*tail) def forward(self, x): x = self.conv01(x) x = self.conv02(x) x = self.body(x) x = x.view(-1, self.linear_size) x = self.tail(x) return x