from typing import Optional, Dict
from functools import partial
import math

import torch
import torch.nn as nn


def get_activation(activation: str = "lrelu"):
    """Return a constructor for the activation layer named ``activation``.

    The returned callable builds an ``nn.Module`` and accepts the usual
    activation keyword arguments (``ConvLayer`` passes ``inplace=True``).
    """
    actv_layers = {
        "relu": nn.ReLU,
        # Leaky ReLU with its negative slope fixed at 0.2.
        "lrelu": partial(nn.LeakyReLU, 0.2),
    }
    assert activation in actv_layers, f"activation [{activation}] not implemented"
    return actv_layers[activation]


def get_normalization(normalization: str = "batch_norm"):
    """Return a constructor for the 2-D normalization layer ``normalization``.

    Every returned constructor takes the number of channels as its first
    positional argument, so callers can build any supported normalization the
    same way, e.g. ``get_normalization(name)(64)``. The group-based variants
    also still accept the channel count as the ``num_channels=`` keyword.
    """
    norm_layers = {
        "instance_norm": nn.InstanceNorm2d,
        "batch_norm": nn.BatchNorm2d,
        # ``num_groups`` is bound positionally so the channel count can follow
        # positionally as well; binding it by keyword would make a positional
        # channel count collide with ``num_groups``.
        "group_norm": partial(nn.GroupNorm, 8),
        # A single group normalizes over (C, H, W) per sample, i.e. layer norm.
        "layer_norm": partial(nn.GroupNorm, 1),
    }
    assert normalization in norm_layers, f"normalization [{normalization}] not implemented"
    return norm_layers[normalization]


class ConvLayer(nn.Sequential):
    """2-D (transposed) convolution with optional normalization and activation.

    With ``pre_activate=False`` the layer order is conv -> norm -> activation.
    With ``pre_activate=True`` it is norm -> activation -> conv, and the
    normalization then runs over ``in_channels`` instead of ``out_channels``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Zero padding; ``None`` selects ``kernel_size // 2``.
        padding_mode: Padding mode for ``nn.Conv2d`` (ignored when transposed).
        groups: Convolution groups.
        bias: Whether the convolution has a bias term.
        transposed: Use ``nn.ConvTranspose2d`` with ``output_padding=stride - 1``.
            With an odd kernel and ``padding == kernel_size // 2`` this multiplies
            the spatial size by ``stride``. ``padding_mode`` is forced to
            ``"zeros"``, the only mode transposed convolutions support.
        normalization: Key accepted by ``get_normalization``, or ``None``.
        activation: Key accepted by ``get_activation``, or ``None``.
        pre_activate: Put normalization/activation before the convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: Optional[int] = 1,
        padding_mode: str = "zeros",
        groups: int = 1,
        bias: bool = True,
        transposed: bool = False,
        normalization: Optional[str] = None,
        activation: Optional[str] = "lrelu",
        pre_activate: bool = False,
    ):
        if padding is None:
            # "Same"-style padding; passing None to the conv itself would crash.
            padding = kernel_size // 2

        if transposed:
            conv = partial(nn.ConvTranspose2d, output_padding=stride - 1)
            padding_mode = "zeros"
        else:
            conv = nn.Conv2d

        conv_layer = conv(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            padding_mode=padding_mode,
            groups=groups,
            bias=bias,
        )

        norm_actv = []
        if normalization is not None:
            norm_channels = in_channels if pre_activate else out_channels
            # Channel count is passed positionally: BatchNorm2d/InstanceNorm2d
            # call it ``num_features`` while GroupNorm calls it ``num_channels``.
            norm_actv.append(get_normalization(normalization)(norm_channels))
        if activation is not None:
            norm_actv.append(get_activation(activation)(inplace=True))

        if pre_activate:
            layers = norm_actv + [conv_layer]
        else:
            layers = [conv_layer] + norm_actv
        super().__init__(*layers)


class SubspaceLayer(nn.Module):
    """Learned linear subspace: maps coefficients ``z`` to ``(L * z) @ U + mu``.

    ``U`` holds ``n_basis`` basis vectors of length ``dim`` as rows (initialized
    orthogonal). ``L`` is a per-basis scale initialized to
    ``[3 * n_basis, ..., 6, 3]``. ``mu`` is a learned offset, initialized to zero.
    """

    def __init__(
        self,
        dim: int,
        n_basis: int,
    ):
        super().__init__()
        self.U = nn.Parameter(torch.empty(n_basis, dim))
        nn.init.orthogonal_(self.U)
        # Decreasing importance scales: 3 * n_basis, ..., 3 (float32).
        self.L = nn.Parameter(torch.arange(n_basis, 0, -1, dtype=torch.float32) * 3)
        self.mu = nn.Parameter(torch.zeros(dim))

    def forward(self, z):
        """Project coefficients ``z`` (last dim ``n_basis``) to vectors of size ``dim``."""
        return (self.L * z) @ self.U + self.mu


class EigenBlock(nn.Module):
    """Upsampling block that injects a subspace projection of ``z`` into features.

    The projection of ``z`` is reshaped to ``h.shape``, so ``h`` must contain
    ``width * height * in_channels`` elements per sample. It is added to ``h``
    through a 1x1 transposed conv before the stride-2 feature conv, and again
    (upsampled by a stride-2 transposed conv) before the second feature conv.
    The output has ``out_channels`` channels and twice the spatial size of ``h``.
    """

    def __init__(
        self,
        width: int,
        height: int,
        in_channels: int,
        out_channels: int,
        n_basis: int,
    ):
        super().__init__()

        self.projection = SubspaceLayer(dim=width * height * in_channels, n_basis=n_basis)

        self.subspace_conv1 = ConvLayer(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            transposed=True,
            activation=None,
            normalization=None,
        )
        self.subspace_conv2 = ConvLayer(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            transposed=True,
            activation=None,
            normalization=None,
        )

        self.feature_conv1 = ConvLayer(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            transposed=True,
            pre_activate=True,
        )
        self.feature_conv2 = ConvLayer(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            transposed=True,
            pre_activate=True,
        )

    def forward(self, z, h):
        """Combine subspace coefficients ``z`` with feature map ``h`` and upsample."""
        phi = self.projection(z).view(h.shape)
        h = self.feature_conv1(h + self.subspace_conv1(phi))
        h = self.feature_conv2(h + self.subspace_conv2(phi))
        return h


class ConditionalGenerator(nn.Module):
    """Conditional generator.

    It generates images from a label vector ``y`` plus noise ``eps`` sampled
    from N(0, 1), with an explorable per-block ``z`` injection space.

    Based on EigenGAN.

    Args:
        size: Output image side length; must be a power of 2. The generator
            starts from a 4x4 map and uses ``log2(size) - 2`` upsampling blocks.
        y_size: Length of the label vector ``y``.
        z_size: Length of the noise vector ``eps``.
        out_channels: Number of image channels produced.
        n_basis: Subspace dimensionality of each ``EigenBlock``.
        noise_dim: Width of the fused label/noise vector.
        base_channels: Channel count at the highest resolution (capped by
            ``max_channels``); doubles for each lower resolution.
        max_channels: Upper bound on any block's channel count.
        y_type: One of ``'one_hot'``, ``'multi_label'``, ``'mixed'``, ``'real'``;
            validated and stored on the instance.

    Raises:
        ValueError: If ``y_type`` is not supported.
    """

    def __init__(self, size: int, y_size: int, z_size: int, out_channels: int = 3,
                 n_basis: int = 6, noise_dim: int = 512, base_channels: int = 16,
                 max_channels: int = 512, y_type: str = 'one_hot'):

        if y_type not in ['one_hot', 'multi_label', 'mixed', 'real']:
            raise ValueError('Unsupported `y_type`')

        super(ConditionalGenerator, self).__init__()

        assert (size & (size - 1) == 0) and size != 0, "img size should be a power of 2"

        self.y_type = y_type
        self.y_size = y_size
        self.eps_size = z_size
        self.noise_dim = noise_dim
        self.n_basis = n_basis
        # log2 is exact for powers of two, unlike math.log(size, 2), whose
        # float division can land just below an integer and truncate wrongly.
        self.n_blocks = int(math.log2(size)) - 2

        def get_channels(i_block):
            return min(max_channels, base_channels * (2 ** (self.n_blocks - i_block)))

        self.y_fc = nn.Linear(self.y_size, self.y_size)
        self.concat_fc = nn.Linear(self.y_size + self.eps_size, self.noise_dim)
        self.fc = nn.Linear(self.noise_dim, 4 * 4 * get_channels(0))

        self.blocks = nn.ModuleList()
        for i in range(self.n_blocks):
            self.blocks.append(
                EigenBlock(
                    width=4 * (2 ** i),
                    height=4 * (2 ** i),
                    in_channels=get_channels(i),
                    out_channels=get_channels(i + 1),
                    n_basis=self.n_basis,
                )
            )

        # Input width must match the last block's output width; this equals
        # base_channels only when base_channels <= max_channels.
        self.out = nn.Sequential(
            ConvLayer(get_channels(self.n_blocks), out_channels, kernel_size=7, stride=1,
                      padding=3, pre_activate=True),
            nn.Tanh(),
        )

    def forward(self, y: torch.Tensor, eps: Optional[torch.Tensor] = None,
                zs: Optional[torch.Tensor] = None, return_eps: bool = False):
        """Generate images in [-1, 1] (Tanh output) of shape ``(batch, out_channels, size, size)``.

        Args:
            y: Label tensor of shape ``(batch, y_size)``.
            eps: Noise of shape ``(batch, z_size)``; sampled if ``None``.
            zs: Subspace coefficients of shape ``(batch, n_blocks, n_basis)``;
                sampled if ``None``.
            return_eps: If true, also return the fused label/noise vector
                (output of ``concat_fc``, shape ``(batch, noise_dim)``) — note
                this is not the raw ``eps``.
        """
        bs = y.size(0)

        if eps is None:
            eps = self.sample_eps(bs)

        if zs is None:
            zs = self.sample_zs(bs)

        y_out = self.y_fc(y)
        concat = torch.cat((y_out, eps), dim=1)
        concat = self.concat_fc(concat)

        out = self.fc(concat).view(len(eps), -1, 4, 4)
        # Iterate per block: (batch, n_blocks, n_basis) -> (n_blocks, batch, n_basis).
        for block, z in zip(self.blocks, zs.permute(1, 0, 2)):
            out = block(z, out)
        out = self.out(out)

        if return_eps:
            return out, concat

        return out

    def sample_zs(self, batch: int, truncation: float = 1.):
        """Sample subspace coefficients of shape ``(batch, n_blocks, n_basis)``.

        ``truncation < 1`` shrinks the samples toward zero by that factor.
        """
        device = self.get_device()
        zs = torch.randn(batch, self.n_blocks, self.n_basis, device=device)

        if truncation < 1.:
            # Linear interpolation toward the zero mean.
            zs = zs * truncation
        return zs

    def sample_eps(self, batch: int, truncation: float = 1.):
        """Sample noise of shape ``(batch, z_size)``; ``truncation < 1`` shrinks it toward zero."""
        device = self.get_device()
        eps = torch.randn(batch, self.eps_size, device=device)

        if truncation < 1.:
            # Linear interpolation toward the zero mean.
            eps = eps * truncation
        return eps

    def get_device(self):
        """Device holding the generator's parameters."""
        return self.fc.weight.device

    def orthogonal_regularizer(self):
        """Mean squared deviation of ``U @ U.T`` from identity, averaged over subspace layers.

        Returns a zero scalar tensor when the model has no ``SubspaceLayer``
        (e.g. ``size == 4``) instead of dividing by zero.
        """
        reg = []
        for layer in self.modules():
            if isinstance(layer, SubspaceLayer):
                UUT = layer.U @ layer.U.t()
                reg.append(
                    ((UUT - torch.eye(UUT.shape[0], device=UUT.device)) ** 2).mean()
                )
        if not reg:
            return torch.zeros((), device=self.get_device())
        return sum(reg) / len(reg)