import functools

import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
import torch.nn.functional as F

import src.models.big.layers as layers
from src.models.parameter import labels_dim, parameter
from src.models.neuralnetwork import NeuralNetwork


# Architectures for G
# Attention is passed in the format '32_64', meaning an attention block is
# applied at both the 32x32 and 64x64 resolutions; '64' alone applies it only
# at 64x64.
def G_arch(ch=64, attention='64', ksize='333333', dilation='111111'):
  arch = {}
  arch[512] = {'in_channels' : [ch * item for item in [16, 16, 8, 8, 4, 2, 1]],
               'out_channels' : [ch * item for item in [16, 8, 8, 4, 2, 1, 1]],
               'upsample' : [True] * 7,
               'resolution' : [8, 16, 32, 64, 128, 256, 512],
               'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')])
                              for i in range(3, 10)}}
  arch[256] = {'in_channels' : [ch * item for item in [16, 16, 8, 8, 4, 2]],
               'out_channels' : [ch * item for item in [16, 8, 8, 4, 2, 1]],
               'upsample' : [True] * 6,
               'resolution' : [8, 16, 32, 64, 128, 256],
               'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')])
                              for i in range(3, 9)}}
  arch[128] = {'in_channels' : [ch * item for item in [16, 16, 8, 4, 2]],
               'out_channels' : [ch * item for item in [16, 8, 4, 2, 1]],
               'upsample' : [True] * 5,
               'resolution' : [8, 16, 32, 64, 128],
               'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')])
                              for i in range(3, 8)}}
  arch[64]  = {'in_channels' : [ch * item for item in [16, 16, 8, 4]],
               'out_channels' : [ch * item for item in [16, 8, 4, 2]],
               'upsample' : [True] * 4,
               'resolution' : [8, 16, 32, 64],
               'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')])
                              for i in range(3, 7)}}
  arch[32]  = {'in_channels' : [ch * item for item in [4, 4, 4]],
               'out_channels' : [ch * item for item in [4, 4, 4]],
               'upsample' : [True] * 3,
               'resolution' : [8, 16, 32],
               'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')])
                              for i in range(3, 6)}}
  return arch
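
# Illustrative example (not part of the original module): for ch=64 and
# attention='32_64', G_arch(...)[128] would contain
#   in_channels  = [1024, 1024, 512, 256, 128]
#   out_channels = [1024, 512, 256, 128, 64]
#   upsample     = [True, True, True, True, True]
#   resolution   = [8, 16, 32, 64, 128]
#   attention    = {8: False, 16: False, 32: True, 64: True, 128: False}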


class Generator(NeuralNetwork):
  def __init__(self, G_ch=64, dim_z=128, bottom_width=4, resolution=64, labels_dim=labels_dim,
               G_kernel_size=3, G_attn='64', n_classes=1,
               num_G_SVs=1, num_G_SV_itrs=1,
               G_shared=True, shared_dim=0, hier=False,
               cross_replica=False, mybn=False,
               G_activation=nn.ReLU(inplace=False),
               G_lr=5e-5, G_B1=0.0, G_B2=0.999, adam_eps=1e-8,
               BN_eps=1e-5, SN_eps=1e-12, G_mixed_precision=False, G_fp16=False,
               G_init='ortho', skip_init=False, no_optim=False,
               G_param='SN', norm_style='bn',
               **kwargs):
    super(Generator, self).__init__()
    # Channel width multiplier
    self.ch = G_ch
    # Dimensionality of the latent space
    self.dim_z = dim_z
    # The initial spatial dimensions
    self.bottom_width = bottom_width
    # Resolution of the output
    self.resolution = resolution
    # Kernel size?
    self.kernel_size = G_kernel_size
    # Attention?
    self.attention = G_attn
    # Number of classes, for use in categorical conditional generation
    self.n_classes = n_classes
    # Use shared embeddings?
    self.G_shared = G_shared
    # Dimensionality of the shared embedding? Unused if not using G_shared
    self.shared_dim = shared_dim if shared_dim > 0 else dim_z
    # Hierarchical latent space?
    self.hier = hier
    # Cross replica batchnorm?
    self.cross_replica = cross_replica
    # Use my batchnorm?
    self.mybn = mybn
    # Nonlinearity for residual blocks
    self.activation = G_activation
    # Initialization style
    self.init = G_init
    # Parameterization style
    self.G_param = G_param
    # Normalization style
    self.norm_style = norm_style
    # Epsilon for BatchNorm?
    self.BN_eps = BN_eps
    # Epsilon for Spectral Norm?
    self.SN_eps = SN_eps
    # fp16?
    self.fp16 = G_fp16
    # Architecture dict
    self.arch = G_arch(self.ch, self.attention)[resolution]

    # If using hierarchical latents, adjust z
    if self.hier:
      # Number of places z slots into
      self.num_slots = len(self.arch['in_channels']) + 1
      self.z_chunk_size = (self.dim_z // self.num_slots)
      # Recalculate latent dimensionality for even splitting into chunks
      self.dim_z = self.z_chunk_size * self.num_slots
    else:
      self.num_slots = 1
      self.z_chunk_size = 0
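    # Worked example (illustrative, not from the original code): with hier=True,
    # dim_z=128 and the default resolution=64 (4 G blocks), num_slots = 5,
    # z_chunk_size = 128 // 5 = 25, and dim_z is recomputed as 125.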
    # Which convs, batchnorms, and linear layers to use
    if self.G_param == 'SN':
      self.which_conv = functools.partial(layers.SNConv2d,
                                          kernel_size=3, padding=1,
                                          num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                                          eps=self.SN_eps)
      self.which_linear = functools.partial(layers.SNLinear,
                                            num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                                            eps=self.SN_eps)
    else:
      self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
      self.which_linear = nn.Linear

    # We use a non-spectral-normed embedding here regardless;
    # for some reason applying SN to G's embedding seems to randomly cripple G
    self.which_embedding = nn.Embedding
    bn_linear = (functools.partial(self.which_linear, bias=False) if self.G_shared
                 else self.which_embedding)
    self.which_bn = functools.partial(layers.ccbn,
                                      which_linear=bn_linear,
                                      cross_replica=self.cross_replica,
                                      mybn=self.mybn,
                                      input_size=(self.shared_dim + self.z_chunk_size if self.G_shared
                                                  else self.n_classes),
                                      norm_style=self.norm_style,
                                      eps=self.BN_eps)

    # Prepare model
    # Prepare label input
    self.transform_label_layer = torch.nn.Linear(labels_dim, 128)
    # If not using shared embeddings, self.shared is just a passthrough
    self.shared = (self.which_embedding(n_classes, self.shared_dim) if G_shared
                   else layers.identity())
    # First linear layer
    self.linear = self.which_linear(self.dim_z // self.num_slots,
                                    self.arch['in_channels'][0] * (self.bottom_width ** 2))

    # self.blocks is a doubly-nested list of modules, the outer loop intended
    # to be over blocks at a given resolution (resblocks and/or self-attention)
    # while the inner loop is over a given block
    self.blocks = []
    for index in range(len(self.arch['out_channels'])):
      self.blocks += [[layers.GBlock(in_channels=self.arch['in_channels'][index],
                                     out_channels=self.arch['out_channels'][index],
                                     which_conv=self.which_conv,
                                     which_bn=self.which_bn,
                                     activation=self.activation,
                                     upsample=(functools.partial(F.interpolate, scale_factor=2)
                                               if self.arch['upsample'][index] else None))]]
      # If attention on this block, attach it to the end
      if self.arch['attention'][self.arch['resolution'][index]]:
        print('Adding attention layer in G at resolution %d' % self.arch['resolution'][index])
        self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)]

    # Turn self.blocks into a ModuleList so that it's all properly registered.
    self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
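    # Resulting structure (illustrative, for the default resolution=64 and
    # G_attn='64'): self.blocks is roughly
    #   [[GBlock 4x4 -> 8x8], [GBlock -> 16x16], [GBlock -> 32x32],
    #    [GBlock -> 64x64, Attention]]
    # i.e. one GBlock per inner list, plus an Attention module at the
    # resolutions named in G_attn.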
    # Output layer: batchnorm-relu-conv.
    # Consider using a non-spectral conv here
    self.output_layer = nn.Sequential(layers.bn(self.arch['out_channels'][-1],
                                                cross_replica=self.cross_replica,
                                                mybn=self.mybn),
                                      self.activation,
                                      self.which_conv(self.arch['out_channels'][-1], 3))

    # Initialize weights. Optionally skip init for testing.
    if not skip_init:
      self.init_weights()

    # Set up optimizer
    # If this is an EMA copy, no need for an optim, so just return now
    if no_optim:
      return
    self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps
    if G_mixed_precision:
      print('Using fp16 adam in G...')
      import utils
      self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
                                betas=(self.B1, self.B2), weight_decay=0,
                                eps=self.adam_eps)
    else:
      self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
                              betas=(self.B1, self.B2), weight_decay=0,
                              eps=self.adam_eps)

    # LR scheduling, left here for forward compatibility
    # self.lr_sched = {'itr' : 0}# if self.progressive else {}
    # self.j = 0

    self.set_optimizer(parameter.optimizer, lr=parameter.learning_rate, betas=parameter.betas)

  # Initialize
  def init_weights(self):
    self.param_count = 0
    for module in self.modules():
      if (isinstance(module, nn.Conv2d)
          or isinstance(module, nn.Linear)
          or isinstance(module, nn.Embedding)):
        if self.init == 'ortho':
          init.orthogonal_(module.weight)
        elif self.init == 'N02':
          init.normal_(module.weight, 0, 0.02)
        elif self.init in ['glorot', 'xavier']:
          init.xavier_uniform_(module.weight)
        else:
          print('Init style not recognized...')
        self.param_count += sum([p.data.nelement() for p in module.parameters()])
    print("Param count for G's initialized parameters: %d" % self.param_count)

  def transform_labels(self, labels):
    """Prepare labels for input to the generator."""
    return self.transform_label_layer(labels)

  # Note on this forward function: we pass in a y vector which has
  # already been passed through G.shared to enable easy class-wise
  # interpolation later. If we passed in the one-hot and then ran it through
  # G.shared in this forward function, it would be harder to handle.
  def forward(self, z, y):
    # If hierarchical, concatenate zs and ys
    y = self.transform_labels(y)
    if self.hier:
      zs = torch.split(z, self.z_chunk_size, 1)
      z = zs[0]
      ys = [torch.cat([y, item], 1) for item in zs[1:]]
    else:
      ys = [y] * len(self.blocks)
    # First linear layer
    h = self.linear(z)
    # Reshape
    h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width)
    # Loop over blocks
    for index, blocklist in enumerate(self.blocks):
      # Second inner loop in case block has multiple layers
      for block in blocklist:
        h = block(h, ys[index])
    # Apply batchnorm-relu-conv-sigmoid at output
    return torch.sigmoid(self.output_layer(h))
    # return torch.tanh(self.output_layer(h))
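
# Shape sketch (illustrative): for the default Generator(resolution=64),
# z of shape (B, 128) and a label tensor y of shape (B, labels_dim) produce an
# image batch of shape (B, 3, 64, 64) with values in (0, 1) because of the
# sigmoid on the output layer. See the __main__ smoke test at the bottom of
# this file for a runnable version of this.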


# Discriminator architecture, same paradigm as G's above
def D_arch(ch=64, attention='64', ksize='333333', dilation='111111'):
  arch = {}
  arch[256] = {'in_channels' : [3] + [ch * item for item in [1, 2, 4, 8, 8, 16]],
               'out_channels' : [item * ch for item in [1, 2, 4, 8, 8, 16, 16]],
               'downsample' : [True] * 6 + [False],
               'resolution' : [128, 64, 32, 16, 8, 4, 4],
               'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')]
                              for i in range(2, 8)}}
  arch[128] = {'in_channels' : [3] + [ch * item for item in [1, 2, 4, 8, 16]],
               'out_channels' : [item * ch for item in [1, 2, 4, 8, 16, 16]],
               'downsample' : [True] * 5 + [False],
               'resolution' : [64, 32, 16, 8, 4, 4],
               'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')]
                              for i in range(2, 8)}}
  arch[64]  = {'in_channels' : [3] + [ch * item for item in [1, 2, 4, 8]],
               'out_channels' : [item * ch for item in [1, 2, 4, 8, 16]],
               'downsample' : [True] * 4 + [False],
               'resolution' : [32, 16, 8, 4, 4],
               'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')]
                              for i in range(2, 7)}}
  arch[32]  = {'in_channels' : [3] + [item * ch for item in [4, 4, 4]],
               'out_channels' : [item * ch for item in [4, 4, 4, 4]],
               'downsample' : [True, True, False, False],
               'resolution' : [16, 16, 16, 16],
               'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')]
                              for i in range(2, 6)}}
  return arch
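
# Illustrative example (not part of the original module): D_arch(64, '64')[64]
# describes a 5-block discriminator with
#   in_channels  = [3, 64, 128, 256, 512]
#   out_channels = [64, 128, 256, 512, 1024]
#   downsample   = [True, True, True, True, False]
# and an attention dict built from the same '_'-separated string as in G_arch.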


class Discriminator(NeuralNetwork):
  def __init__(self, D_ch=64, D_wide=True, resolution=64, labels_dim=labels_dim,
               D_kernel_size=3, D_attn='64', n_classes=1,
               num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False),
               D_lr=2e-4, D_B1=0.0, D_B2=0.999, adam_eps=1e-8,
               SN_eps=1e-12, output_dim=1, D_mixed_precision=False, D_fp16=False,
               D_init='ortho', skip_init=False, D_param='SN', **kwargs):
    super(Discriminator, self).__init__()
    # Width multiplier
    self.ch = D_ch
    # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
    self.D_wide = D_wide
    # Resolution
    self.resolution = resolution
    # Kernel size
    self.kernel_size = D_kernel_size
    # Attention?
    self.attention = D_attn
    # Number of classes
    self.n_classes = n_classes
    # Activation
    self.activation = D_activation
    # Initialization style
    self.init = D_init
    # Parameterization style
    self.D_param = D_param
    # Epsilon for Spectral Norm?
    self.SN_eps = SN_eps
    # fp16?
    self.fp16 = D_fp16
    # Architecture
    self.arch = D_arch(self.ch, self.attention)[resolution]

    # Which convs, batchnorms, and linear layers to use
    # No option to turn off SN in D right now
    if self.D_param == 'SN':
      self.which_conv = functools.partial(layers.SNConv2d,
                                          kernel_size=3, padding=1,
                                          num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                          eps=self.SN_eps)
      self.which_linear = functools.partial(layers.SNLinear,
                                            num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                            eps=self.SN_eps)
      self.which_embedding = functools.partial(layers.SNEmbedding,
                                               num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                               eps=self.SN_eps)

    # Prepare model
    # Prepare label input
    self.transform_label_layer = torch.nn.Linear(labels_dim, 1024)
    # self.blocks is a doubly-nested list of modules, the outer loop intended
    # to be over blocks at a given resolution (resblocks and/or self-attention)
    self.blocks = []
    for index in range(len(self.arch['out_channels'])):
      self.blocks += [[layers.DBlock(in_channels=self.arch['in_channels'][index],
                                     out_channels=self.arch['out_channels'][index],
                                     which_conv=self.which_conv,
                                     wide=self.D_wide,
                                     activation=self.activation,
                                     preactivation=(index > 0),
                                     downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))]]
      # If attention on this block, attach it to the end
      if self.arch['attention'][self.arch['resolution'][index]]:
        print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index])
        self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index],
                                             self.which_conv)]
    # Turn self.blocks into a ModuleList so that it's all properly registered.
    self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

    # Linear output layer. The output dimension is typically 1, but may be
    # larger if we're e.g. turning this into a VAE with an inference output
    self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim)
    # Embedding for projection discrimination
    self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1])

    # Initialize weights
    if not skip_init:
      self.init_weights()

    # Set up optimizer
    self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps
    if D_mixed_precision:
      print('Using fp16 adam in D...')
      import utils
      self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
                                betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps)
    else:
      self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
                              betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps)

    # LR scheduling, left here for forward compatibility
    # self.lr_sched = {'itr' : 0}# if self.progressive else {}
    # self.j = 0

    self.set_optimizer(parameter.optimizer, lr=parameter.learning_rate * 3, betas=parameter.betas)

  # Initialize
  def init_weights(self):
    self.param_count = 0
    for module in self.modules():
      if (isinstance(module, nn.Conv2d)
          or isinstance(module, nn.Linear)
          or isinstance(module, nn.Embedding)):
        if self.init == 'ortho':
          init.orthogonal_(module.weight)
        elif self.init == 'N02':
          init.normal_(module.weight, 0, 0.02)
        elif self.init in ['glorot', 'xavier']:
          init.xavier_uniform_(module.weight)
        else:
          print('Init style not recognized...')
        self.param_count += sum([p.data.nelement() for p in module.parameters()])
    print("Param count for D's initialized parameters: %d" % self.param_count)

  def transform_labels(self, labels):
    """Prepare labels for input to the discriminator."""
    return self.transform_label_layer(labels)

  def forward(self, x, y=None):
    # Stick x into h for cleaner for loops without flow control
    h = x
    # Loop over blocks
    for index, blocklist in enumerate(self.blocks):
      for block in blocklist:
        h = block(h)
    # Apply global sum pooling as in SN-GAN
    h = torch.sum(self.activation(h), [2, 3])
    # Get initial class-unconditional output
    out = self.linear(h)
    # Get projection of final featureset onto class vectors and add to evidence
    y = self.transform_labels(y)
    out = out + torch.sum(y * h, 1, keepdim=True)
    # out = out + torch.sum(self.embed(y) * h, 1, keepdim=True)  # use y = torch.tensor(0)
    return out
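
# Shape sketch (illustrative): for Discriminator(resolution=64), an input batch
# x of shape (B, 3, 64, 64) is pooled to h of shape (B, 1024); the label tensor
# y of shape (B, labels_dim) is mapped to (B, 1024) by transform_label_layer,
# and the projection term sum(y * h) is added to the unconditional score,
# giving an output of shape (B, output_dim).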


# Parallelized G_D to minimize cross-gpu communication
# Without this, Generator outputs would get all-gathered and then rebroadcast.
class G_D(nn.Module):
  def __init__(self, G, D):
    super(G_D, self).__init__()
    self.G = G
    self.D = D

  def forward(self, z, gy, x=None, dy=None, train_G=False, return_G_z=False,
              split_D=False):
    # If training G, enable grad tape
    with torch.set_grad_enabled(train_G):
      # Get Generator output given noise
      G_z = self.G(z, self.G.shared(gy))
      # Cast as necessary
      if self.G.fp16 and not self.D.fp16:
        G_z = G_z.float()
      if self.D.fp16 and not self.G.fp16:
        G_z = G_z.half()
    # Split_D means to run D once with real data and once with fake,
    # rather than concatenating along the batch dimension.
    if split_D:
      D_fake = self.D(G_z, gy)
      if x is not None:
        D_real = self.D(x, dy)
        return D_fake, D_real
      else:
        if return_G_z:
          return D_fake, G_z
        else:
          return D_fake
    # If real data is provided, concatenate it with the Generator's output
    # along the batch dimension for improved efficiency.
    else:
      D_input = torch.cat([G_z, x], 0) if x is not None else G_z
      D_class = torch.cat([gy, dy], 0) if dy is not None else gy
      # Get Discriminator output
      D_out = self.D(D_input, D_class)
      if x is not None:
        return torch.split(D_out, [G_z.shape[0], x.shape[0]])  # D_fake, D_real
      else:
        if return_G_z:
          return D_out, G_z
        else:
          return D_out
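

# Minimal smoke-test sketch (not part of the original module). It assumes the
# sibling modules imported above (src.models.big.layers, src.models.parameter,
# src.models.neuralnetwork) are importable and that labels_dim matches the
# label tensors created here; shapes follow the default 64x64 configuration.
if __name__ == '__main__':
  G = Generator(resolution=64)
  D = Discriminator(resolution=64)
  z = torch.randn(4, G.dim_z)        # latent batch
  y = torch.randn(4, labels_dim)     # continuous label vectors
  fake = G(z, y)                     # expected shape: (4, 3, 64, 64)
  score = D(fake, y)                 # expected shape: (4, 1)
  print(fake.shape, score.shape)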