import functools

import torch
import torch.nn as nn


class ActNorm(nn.Module):
    """Per-channel affine layer with data-dependent initialization (ActNorm, as in Glow)."""

    def __init__(self, num_features, logdet=False, affine=True,
                 allow_reverse_init=False):
        assert affine
        super().__init__()
        self.logdet = logdet
        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.allow_reverse_init = allow_reverse_init

        self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))

    def initialize(self, input):
        # Data-dependent init: set loc/scale from the per-channel statistics of the
        # first batch so that the initial output is roughly zero-mean, unit-variance.
        with torch.no_grad():
            flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
            mean = (
                flatten.mean(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )
            std = (
                flatten.std(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )

            self.loc.data.copy_(-mean)
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input, reverse=False):
        if reverse:
            return self.reverse(input)
        if len(input.shape) == 2:
            # expand (B, C) inputs to (B, C, 1, 1) and squeeze back afterwards
            input = input[:, :, None, None]
            squeeze = True
        else:
            squeeze = False

        _, _, height, width = input.shape

        if self.training and self.initialized.item() == 0:
            self.initialize(input)
            self.initialized.fill_(1)

        h = self.scale * (input + self.loc)

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)

        if self.logdet:
            log_abs = torch.log(torch.abs(self.scale))
            logdet = height * width * torch.sum(log_abs)
            logdet = logdet * torch.ones(input.shape[0]).to(input)
            return h, logdet

        return h

    def reverse(self, output):
        if self.training and self.initialized.item() == 0:
            if not self.allow_reverse_init:
                raise RuntimeError(
                    "Initializing ActNorm in reverse direction is "
                    "disabled by default. Use allow_reverse_init=True to enable."
                )
            else:
                self.initialize(output)
                self.initialized.fill_(1)

        if len(output.shape) == 2:
            output = output[:, :, None, None]
            squeeze = True
        else:
            squeeze = False

        h = output / self.scale - self.loc

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)
        return h
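
# Usage sketch (illustrative, not part of the original file): ActNorm initializes
# itself from the statistics of the first batch it sees, and `reverse` inverts the
# affine map, so a forward/reverse round-trip recovers the input:
#
#   actnorm = ActNorm(num_features=16, logdet=True)
#   x = torch.randn(4, 16, 8, 8)
#   y, logdet = actnorm(x)            # first call triggers data-dependent init
#   x_rec = actnorm(y, reverse=True)  # computes y / scale - loc
#   assert torch.allclose(x, x_rec, atol=1e-4)
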
def weights_init(m):
    # Standard DCGAN-style init: N(0, 0.02) for conv weights,
    # N(1, 0.02) for norm-layer scales with zero bias.
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator as in Pix2Pix
        --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
    """
    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int)     -- the number of channels in input images
            ndf (int)          -- the number of filters in the first conv layer
            n_layers (int)     -- the number of conv layers in the discriminator
            use_actnorm (bool) -- whether to use ActNorm instead of BatchNorm
        """
        super(NLayerDiscriminator, self).__init__()
        if not use_actnorm:
            norm_layer = nn.BatchNorm2d
        else:
            norm_layer = ActNorm
        # no conv bias needed when the norm layer has its own affine shift
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # 1-channel prediction map
        self.main = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.main(input)
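
# Note (added): the discriminator returns an overlapping-patch logit map rather than
# a single scalar per image. With the defaults (ndf=64, n_layers=3), a 256x256 input
# yields a (B, 1, 30, 30) map, each element scoring one ~70x70 receptive-field patch.
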
class NLayerDiscriminator1dFeats(NLayerDiscriminator):
    """Defines a 1-D PatchGAN discriminator for feature sequences, following Pix2Pix
        --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
    """
    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int)     -- the number of channels in input feats
            ndf (int)          -- unused here; channel widths are derived from input_nc
            n_layers (int)     -- the number of conv layers in the discriminator
            use_actnorm (bool) -- whether to use ActNorm instead of BatchNorm
        """
        super().__init__(input_nc=input_nc, ndf=ndf, n_layers=n_layers, use_actnorm=use_actnorm)

        # rebuild self.main with 1-D convolutions
        if not use_actnorm:
            norm_layer = nn.BatchNorm1d
        else:
            norm_layer = ActNorm
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func != nn.BatchNorm1d
        else:
            use_bias = norm_layer != nn.BatchNorm1d

        kw = 4
        padw = 1
        sequence = [nn.Conv1d(input_nc, input_nc//2, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = input_nc//2
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually decrease the channel width
            nf_mult_prev = nf_mult
            nf_mult = max(nf_mult_prev // (2 ** n), 8)
            sequence += [
                nn.Conv1d(nf_mult_prev, nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        # NB: `n` below keeps its last value from the loop above, so n_layers >= 2 is required
        nf_mult_prev = nf_mult
        nf_mult = max(nf_mult_prev // (2 ** n), 8)
        sequence += [
            nn.Conv1d(nf_mult_prev, nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(nf_mult),
            nn.LeakyReLU(0.2, True)
        ]
        nf_mult_prev = nf_mult
        nf_mult = max(nf_mult_prev // (2 ** n), 8)
        sequence += [
            nn.Conv1d(nf_mult_prev, nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv1d(nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # 1-channel prediction map
        self.main = nn.Sequential(*sequence)


class NLayerDiscriminator1dSpecs(NLayerDiscriminator):
    """Defines a 1-D PatchGAN discriminator for spectrograms, following Pix2Pix
        --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
    """
    def __init__(self, input_nc=80, ndf=64, n_layers=3, use_actnorm=False):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int)     -- the number of channels in input specs
            ndf (int)          -- the number of filters in the first conv layer
            n_layers (int)     -- the number of conv layers in the discriminator
            use_actnorm (bool) -- whether to use ActNorm instead of BatchNorm
        """
        super().__init__(input_nc=input_nc, ndf=ndf, n_layers=n_layers, use_actnorm=use_actnorm)

        # rebuild self.main with 1-D convolutions
        if not use_actnorm:
            norm_layer = nn.BatchNorm1d
        else:
            norm_layer = ActNorm
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func != nn.BatchNorm1d
        else:
            use_bias = norm_layer != nn.BatchNorm1d

        kw = 4
        padw = 1
        sequence = [nn.Conv1d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv1d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv1d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv1d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # 1-channel prediction map
        self.main = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward: expects (B, 1, C, T) spectrograms; the singleton
        channel dim is squeezed away before the 1-D convolutions."""
        input = input.squeeze(1)
        input = self.main(input)
        return input


if __name__ == '__main__':
    # Quick shape checks for the discriminators above

    # 1-D discriminator on a batch of feature sequences (B, C, T)
    disc_in_channels = 2048
    disc_num_layers = 2
    use_actnorm = False
    disc_ndf = 64
    discriminator = NLayerDiscriminator1dFeats(input_nc=disc_in_channels, n_layers=disc_num_layers,
                                               use_actnorm=use_actnorm, ndf=disc_ndf).apply(weights_init)
    inputs = torch.rand((6, 2048, 212))
    outputs = discriminator(inputs)
    print(outputs.shape)

    # 2-D discriminator on a spectrogram-shaped input (B, 1, F, T)
    disc_in_channels = 1
    disc_num_layers = 3
    use_actnorm = False
    disc_ndf = 64
    discriminator = NLayerDiscriminator(input_nc=disc_in_channels, n_layers=disc_num_layers,
                                        use_actnorm=use_actnorm, ndf=disc_ndf).apply(weights_init)
    inputs = torch.rand((6, 1, 80, 848))
    outputs = discriminator(inputs)
    print(outputs.shape)

    # 2-D discriminator on RGB images (B, 3, H, W)
    disc_in_channels = 3
    disc_num_layers = 3
    use_actnorm = False
    disc_ndf = 64
    discriminator = NLayerDiscriminator(input_nc=disc_in_channels, n_layers=disc_num_layers,
                                        use_actnorm=use_actnorm, ndf=disc_ndf).apply(weights_init)
    inputs = torch.rand((6, 3, 256, 256))
    outputs = discriminator(inputs)
    print(outputs.shape)
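
    # Added sketch (not in the original tests): exercise NLayerDiscriminator1dSpecs,
    # assuming a mel-spectrogram-shaped input (B, 1, input_nc, T); forward squeezes
    # the singleton channel dim before the 1-D convolutions.
    disc_in_channels = 80
    disc_num_layers = 3
    discriminator = NLayerDiscriminator1dSpecs(input_nc=disc_in_channels, n_layers=disc_num_layers,
                                               use_actnorm=False, ndf=64).apply(weights_init)
    inputs = torch.rand((6, 1, 80, 848))
    outputs = discriminator(inputs)
    print(outputs.shape)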