Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| from torch.nn import functional as F | |
| from torch.nn.utils import spectral_norm | |
| import models.basicblock as B | |
| import functools | |
| import numpy as np | |
| """ | |
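| # -------------------------------------------- | |
| # Discriminator_PatchGAN | |
| # Discriminator_UNet | |
| # Discriminator_VGG_96 | |
| # Discriminator_VGG_128 | |
| # Discriminator_VGG_192 | |
| # Discriminator_VGG_128_SN | |
| # -------------------------------------------- | |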
| """ | |
| # -------------------------------------------- | |
| # PatchGAN discriminator | |
| # If n_layers = 3, then the receptive field is 70x70 | |
| # -------------------------------------------- | |
class Discriminator_PatchGAN(nn.Module):
    """PatchGAN discriminator producing a single-channel score map.

    The network is a stack of 4x4 convolutions: ``n_layers`` stride-2
    convolutions, one stride-1 convolution and a final stride-1 convolution
    that maps to one output channel.  With ``n_layers = 3`` the receptive
    field is 70x70.
    """

    def __init__(self, input_nc=3, ndf=64, n_layers=3, norm_type='spectral'):
        """Build the convolutional stack and initialise its weights.

        Args:
            input_nc: number of input channels
            ndf: base channel number
            n_layers: number of conv layers with stride 2
            norm_type: 'batch', 'instance', 'spectral', 'batchspectral' or
                'instancespectral'.  Matching is by substring; a value that
                contains none of these keywords yields plain convolutions
                with no normalisation at all.
        """
        super(Discriminator_PatchGAN, self).__init__()
        self.n_layers = n_layers
        norm_layer = self.get_norm_layer(norm_type=norm_type)

        kw = 4
        padw = int(np.ceil((kw - 1.0) / 2))

        # First stage: stride-2 conv without a normalisation layer.
        sequence = [[
            self.use_spectral_norm(
                nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
                norm_type),
            nn.LeakyReLU(0.2, True),
        ]]

        # (n_layers - 1) further stride-2 stages followed by one stride-1
        # stage; every stage doubles the channel count, capped at 512.
        nf = ndf
        for stride in [2] * (n_layers - 1) + [1]:
            nf_prev, nf = nf, min(nf * 2, 512)
            sequence.append([
                self.use_spectral_norm(
                    nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=stride, padding=padw),
                    norm_type),
                norm_layer(nf),
                nn.LeakyReLU(0.2, True),
            ])

        # Final projection to a one-channel score map.
        sequence.append([
            self.use_spectral_norm(
                nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw),
                norm_type),
        ])

        # Stage names ('child0', 'child1', ...) are part of the state_dict
        # keys, so they must stay stable for checkpoint compatibility.
        self.model = nn.Sequential()
        for idx, stage in enumerate(sequence):
            self.model.add_module('child' + str(idx), nn.Sequential(*stage))

        self.model.apply(self.weights_init)

    def use_spectral_norm(self, module, norm_type='spectral'):
        """Wrap ``module`` with spectral norm when ``norm_type`` contains 'spectral'."""
        if 'spectral' in norm_type:
            return spectral_norm(module)
        return module

    def get_norm_layer(self, norm_type='instance'):
        """Return a factory ``f(num_features)`` for the normalisation layer.

        'batch' in ``norm_type`` selects affine BatchNorm2d, otherwise
        'instance' selects non-affine InstanceNorm2d, otherwise Identity.
        """
        if 'batch' in norm_type:
            return functools.partial(nn.BatchNorm2d, affine=True)
        if 'instance' in norm_type:
            return functools.partial(nn.InstanceNorm2d, affine=False)
        # nn.Identity accepts and ignores the channel-count argument.
        return functools.partial(nn.Identity)

    def weights_init(self, m):
        """Initialise conv weights to N(0, 0.02) and BatchNorm2d to N(1, 0.02) / zero bias."""
        classname = m.__class__.__name__
        if 'Conv' in classname:
            m.weight.data.normal_(0.0, 0.02)
        elif 'BatchNorm2d' in classname:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)

    def forward(self, x):
        """Return the patch score map for input batch ``x``."""
        return self.model(x)
class Discriminator_UNet(nn.Module):
    """Defines a U-Net discriminator with spectral normalization (SN).

    Three stride-2 convolutions downsample the input, three bilinear
    upsampling + convolution stages bring it back up with additive skip
    connections, and three further convolutions produce a one-channel score
    map with the same height and width as the input.

    Args:
        input_nc: number of input channels
        ndf: base channel number
    """

    def __init__(self, input_nc=3, ndf=64):
        super(Discriminator_UNet, self).__init__()
        norm = spectral_norm

        self.conv0 = nn.Conv2d(input_nc, ndf, kernel_size=3, stride=1, padding=1)

        # downsample
        self.conv1 = norm(nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False))
        self.conv2 = norm(nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False))
        self.conv3 = norm(nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False))
        # upsample
        self.conv4 = norm(nn.Conv2d(ndf * 8, ndf * 4, 3, 1, 1, bias=False))
        self.conv5 = norm(nn.Conv2d(ndf * 4, ndf * 2, 3, 1, 1, bias=False))
        self.conv6 = norm(nn.Conv2d(ndf * 2, ndf, 3, 1, 1, bias=False))
        # extra
        self.conv7 = norm(nn.Conv2d(ndf, ndf, 3, 1, 1, bias=False))
        self.conv8 = norm(nn.Conv2d(ndf, ndf, 3, 1, 1, bias=False))
        self.conv9 = nn.Conv2d(ndf, 1, 3, 1, 1)
        print('using the UNet discriminator')

    def forward(self, x):
        """Return a one-channel score map with the spatial size of ``x``."""
        x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
        x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
        x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
        x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)

        # upsample
        # Resize to the exact size of the skip tensor instead of a fixed 2x
        # factor: the stride-2 convs floor odd sizes, so doubling would not
        # match the skip tensor when H or W is not a multiple of 8.  For
        # multiples of 8 the target size is exactly 2x, as before.
        x3 = F.interpolate(x3, size=x2.shape[2:], mode='bilinear', align_corners=False)
        x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)
        x4 = x4 + x2

        x4 = F.interpolate(x4, size=x1.shape[2:], mode='bilinear', align_corners=False)
        x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)
        x5 = x5 + x1

        x5 = F.interpolate(x5, size=x0.shape[2:], mode='bilinear', align_corners=False)
        x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)
        x6 = x6 + x0

        # extra
        out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
        out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
        out = self.conv9(out)
        return out
| # -------------------------------------------- | |
| # VGG style Discriminator with 96x96 input | |
| # -------------------------------------------- | |
class Discriminator_VGG_96(nn.Module):
    """VGG-style discriminator for 96x96 inputs.

    Ten convolution blocks built with ``B.conv`` feed a two-layer linear
    classifier that outputs one score per sample.

    Args:
        in_nc: number of input channels
        base_nc: base channel number; the last feature stage has
            ``base_nc * 8`` channels
        ac_type: suffix appended to 'C' to form the ``mode`` string passed
            to ``B.conv`` (presumably selects norm/activation, e.g. 'BL';
            see models.basicblock -- TODO confirm)
    """

    def __init__(self, in_nc=3, base_nc=64, ac_type='BL'):
        super(Discriminator_VGG_96, self).__init__()
        mode = 'C' + ac_type

        # features
        # hxw, c  (sizes below are for a 96x96 input and base_nc = 64)
        # 96, 64
        conv0 = B.conv(in_nc, base_nc, kernel_size=3, mode='C')
        conv1 = B.conv(base_nc, base_nc, kernel_size=4, stride=2, mode=mode)
        # 48, 64
        conv2 = B.conv(base_nc, base_nc*2, kernel_size=3, stride=1, mode=mode)
        conv3 = B.conv(base_nc*2, base_nc*2, kernel_size=4, stride=2, mode=mode)
        # 24, 128
        conv4 = B.conv(base_nc*2, base_nc*4, kernel_size=3, stride=1, mode=mode)
        conv5 = B.conv(base_nc*4, base_nc*4, kernel_size=4, stride=2, mode=mode)
        # 12, 256
        conv6 = B.conv(base_nc*4, base_nc*8, kernel_size=3, stride=1, mode=mode)
        conv7 = B.conv(base_nc*8, base_nc*8, kernel_size=4, stride=2, mode=mode)
        # 6, 512
        conv8 = B.conv(base_nc*8, base_nc*8, kernel_size=3, stride=1, mode=mode)
        conv9 = B.conv(base_nc*8, base_nc*8, kernel_size=4, stride=2, mode=mode)
        # 3, 512
        self.features = B.sequential(conv0, conv1, conv2, conv3, conv4,
                                     conv5, conv6, conv7, conv8, conv9)

        # classifier
        # The flattened width follows base_nc (base_nc*8 channels on a 3x3
        # map) rather than a fixed 512, so non-default base_nc works too.
        self.classifier = nn.Sequential(
            nn.Linear(base_nc * 8 * 3 * 3, 100),
            nn.LeakyReLU(0.2, True),
            nn.Linear(100, 1))

    def forward(self, x):
        """Return one score per sample for input batch ``x``."""
        x = self.features(x)
        x = x.view(x.size(0), -1)  # flatten everything except the batch dim
        x = self.classifier(x)
        return x
| # -------------------------------------------- | |
| # VGG style Discriminator with 128x128 input | |
| # -------------------------------------------- | |
class Discriminator_VGG_128(nn.Module):
    """VGG-style discriminator for 128x128 inputs.

    Ten convolution blocks built with ``B.conv`` feed a two-layer linear
    classifier that outputs one score per sample.

    Args:
        in_nc: number of input channels
        base_nc: base channel number; the last feature stage has
            ``base_nc * 8`` channels
        ac_type: suffix appended to 'C' to form the ``mode`` string passed
            to ``B.conv`` (presumably selects norm/activation, e.g. 'BL';
            see models.basicblock -- TODO confirm)
    """

    def __init__(self, in_nc=3, base_nc=64, ac_type='BL'):
        super(Discriminator_VGG_128, self).__init__()
        mode = 'C' + ac_type

        # features
        # hxw, c  (sizes below are for a 128x128 input and base_nc = 64)
        # 128, 64
        conv0 = B.conv(in_nc, base_nc, kernel_size=3, mode='C')
        conv1 = B.conv(base_nc, base_nc, kernel_size=4, stride=2, mode=mode)
        # 64, 64
        conv2 = B.conv(base_nc, base_nc*2, kernel_size=3, stride=1, mode=mode)
        conv3 = B.conv(base_nc*2, base_nc*2, kernel_size=4, stride=2, mode=mode)
        # 32, 128
        conv4 = B.conv(base_nc*2, base_nc*4, kernel_size=3, stride=1, mode=mode)
        conv5 = B.conv(base_nc*4, base_nc*4, kernel_size=4, stride=2, mode=mode)
        # 16, 256
        conv6 = B.conv(base_nc*4, base_nc*8, kernel_size=3, stride=1, mode=mode)
        conv7 = B.conv(base_nc*8, base_nc*8, kernel_size=4, stride=2, mode=mode)
        # 8, 512
        conv8 = B.conv(base_nc*8, base_nc*8, kernel_size=3, stride=1, mode=mode)
        conv9 = B.conv(base_nc*8, base_nc*8, kernel_size=4, stride=2, mode=mode)
        # 4, 512
        self.features = B.sequential(conv0, conv1, conv2, conv3, conv4,
                                     conv5, conv6, conv7, conv8, conv9)

        # classifier
        # The flattened width follows base_nc (base_nc*8 channels on a 4x4
        # map) rather than a fixed 512, so non-default base_nc works too.
        self.classifier = nn.Sequential(nn.Linear(base_nc * 8 * 4 * 4, 100),
                                        nn.LeakyReLU(0.2, True),
                                        nn.Linear(100, 1))

    def forward(self, x):
        """Return one score per sample for input batch ``x``."""
        x = self.features(x)
        x = x.view(x.size(0), -1)  # flatten everything except the batch dim
        x = self.classifier(x)
        return x
| # -------------------------------------------- | |
| # VGG style Discriminator with 192x192 input | |
| # -------------------------------------------- | |
class Discriminator_VGG_192(nn.Module):
    """VGG-style discriminator for 192x192 inputs.

    Twelve convolution blocks built with ``B.conv`` feed a two-layer linear
    classifier that outputs one score per sample.

    Args:
        in_nc: number of input channels
        base_nc: base channel number; the last feature stage has
            ``base_nc * 8`` channels
        ac_type: suffix appended to 'C' to form the ``mode`` string passed
            to ``B.conv`` (presumably selects norm/activation, e.g. 'BL';
            see models.basicblock -- TODO confirm)
    """

    def __init__(self, in_nc=3, base_nc=64, ac_type='BL'):
        super(Discriminator_VGG_192, self).__init__()
        mode = 'C' + ac_type

        # features
        # hxw, c  (sizes below are for a 192x192 input and base_nc = 64)
        # 192, 64
        conv0 = B.conv(in_nc, base_nc, kernel_size=3, mode='C')
        conv1 = B.conv(base_nc, base_nc, kernel_size=4, stride=2, mode=mode)
        # 96, 64
        conv2 = B.conv(base_nc, base_nc*2, kernel_size=3, stride=1, mode=mode)
        conv3 = B.conv(base_nc*2, base_nc*2, kernel_size=4, stride=2, mode=mode)
        # 48, 128
        conv4 = B.conv(base_nc*2, base_nc*4, kernel_size=3, stride=1, mode=mode)
        conv5 = B.conv(base_nc*4, base_nc*4, kernel_size=4, stride=2, mode=mode)
        # 24, 256
        conv6 = B.conv(base_nc*4, base_nc*8, kernel_size=3, stride=1, mode=mode)
        conv7 = B.conv(base_nc*8, base_nc*8, kernel_size=4, stride=2, mode=mode)
        # 12, 512
        conv8 = B.conv(base_nc*8, base_nc*8, kernel_size=3, stride=1, mode=mode)
        conv9 = B.conv(base_nc*8, base_nc*8, kernel_size=4, stride=2, mode=mode)
        # 6, 512
        conv10 = B.conv(base_nc*8, base_nc*8, kernel_size=3, stride=1, mode=mode)
        conv11 = B.conv(base_nc*8, base_nc*8, kernel_size=4, stride=2, mode=mode)
        # 3, 512
        self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5,
                                     conv6, conv7, conv8, conv9, conv10, conv11)

        # classifier
        # The flattened width follows base_nc (base_nc*8 channels on a 3x3
        # map) rather than a fixed 512, so non-default base_nc works too.
        self.classifier = nn.Sequential(nn.Linear(base_nc * 8 * 3 * 3, 100),
                                        nn.LeakyReLU(0.2, True),
                                        nn.Linear(100, 1))

    def forward(self, x):
        """Return one score per sample for input batch ``x``."""
        x = self.features(x)
        x = x.view(x.size(0), -1)  # flatten everything except the batch dim
        x = self.classifier(x)
        return x
| # -------------------------------------------- | |
| # SN-VGG style Discriminator with 128x128 input | |
| # -------------------------------------------- | |
class Discriminator_VGG_128_SN(nn.Module):
    """Spectrally-normalised VGG-style discriminator for 128x128 inputs.

    Ten spectral-norm convolutions (alternating 3x3/stride-1 and
    4x4/stride-2) reduce a 128x128 input to a 4x4 map, which two
    spectral-norm linear layers turn into one score per sample.

    Args:
        in_nc: number of input channels (default 3, the previous fixed value)
        base_nc: base channel number (default 64, the previous fixed value);
            the last feature stage has ``base_nc * 8`` channels
    """

    def __init__(self, in_nc=3, base_nc=64):
        super(Discriminator_VGG_128_SN, self).__init__()
        # features
        # hxw, c  (sizes below are for a 128x128 input and base_nc = 64)
        # 128, 64
        self.lrelu = nn.LeakyReLU(0.2, True)

        self.conv0 = spectral_norm(nn.Conv2d(in_nc, base_nc, 3, 1, 1))
        self.conv1 = spectral_norm(nn.Conv2d(base_nc, base_nc, 4, 2, 1))
        # 64, 64
        self.conv2 = spectral_norm(nn.Conv2d(base_nc, base_nc * 2, 3, 1, 1))
        self.conv3 = spectral_norm(nn.Conv2d(base_nc * 2, base_nc * 2, 4, 2, 1))
        # 32, 128
        self.conv4 = spectral_norm(nn.Conv2d(base_nc * 2, base_nc * 4, 3, 1, 1))
        self.conv5 = spectral_norm(nn.Conv2d(base_nc * 4, base_nc * 4, 4, 2, 1))
        # 16, 256
        self.conv6 = spectral_norm(nn.Conv2d(base_nc * 4, base_nc * 8, 3, 1, 1))
        self.conv7 = spectral_norm(nn.Conv2d(base_nc * 8, base_nc * 8, 4, 2, 1))
        # 8, 512
        self.conv8 = spectral_norm(nn.Conv2d(base_nc * 8, base_nc * 8, 3, 1, 1))
        self.conv9 = spectral_norm(nn.Conv2d(base_nc * 8, base_nc * 8, 4, 2, 1))
        # 4, 512

        # classifier
        self.linear0 = spectral_norm(nn.Linear(base_nc * 8 * 4 * 4, 100))
        self.linear1 = spectral_norm(nn.Linear(100, 1))

    def forward(self, x):
        """Return one score per sample for input batch ``x``."""
        conv_stack = (self.conv0, self.conv1, self.conv2, self.conv3, self.conv4,
                      self.conv5, self.conv6, self.conv7, self.conv8, self.conv9)
        for conv in conv_stack:
            x = self.lrelu(conv(x))
        x = x.view(x.size(0), -1)  # flatten everything except the batch dim
        x = self.lrelu(self.linear0(x))
        x = self.linear1(x)
        return x
if __name__ == '__main__':
    # Smoke test: push one random image of the nominal resolution through
    # each VGG-style discriminator and print the output size.
    for side, net_class in ((96, Discriminator_VGG_96),
                            (128, Discriminator_VGG_128),
                            (192, Discriminator_VGG_192),
                            (128, Discriminator_VGG_128_SN)):
        x = torch.rand(1, 3, side, side)
        net = net_class()
        net.eval()
        with torch.no_grad():
            y = net(x)
        print(y.size())

    # run models/network_discriminator.py