# NOTE: extraction residue ("Spaces: Sleeping" Hugging Face page chrome) removed.
| from torch import nn | |
| import torch | |
| from torch.nn import functional as F | |
| from typing import Optional | |
| import math | |
class WSLinear(nn.Module):
    '''
    Equalized-learning-rate linear layer (weight-scaled linear).

    Weights are drawn from a unit normal and the input is multiplied by a
    He-style constant at runtime, equalizing the effective learning rate
    across layers. The bias is detached from the wrapped nn.Linear so it is
    added after the scaled projection instead of inside it.

    Args:
        in_features (int): The number of input features.
        out_features (int): The number of output features.
    '''

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # He-style runtime constant: sqrt(2 / fan_in).
        self.scale = (2 / in_features) ** 0.5
        self.linear = nn.Linear(in_features, out_features)
        # Keep the bias outside the wrapped layer so it is applied AFTER
        # the scaled matmul.
        self.bias = self.linear.bias
        self.linear.bias = None
        self._init_weights()

    def _init_weights(self) -> None:
        # Unit-normal weights; zero bias.
        nn.init.normal_(self.linear.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scaled = x * self.scale
        return self.linear(scaled) + self.bias
class WSConv2d(nn.Module):
    '''
    Weight-scaled Conv2d layer for equalized learning rate.

    The convolution weight is drawn from a unit normal and the input is
    multiplied by sqrt(gain / fan_in) at runtime; the bias is removed from
    the wrapped conv and added after the scaled convolution.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int, optional): Size of the convolving kernel. Default: 3.
        stride (int, optional): Stride of the convolution. Default: 1.
        padding (int, optional): Padding added to all sides of the input. Default: 1.
        gain (float, optional): Gain factor for weight initialization. Default: 2.
    '''

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        gain: float = 2
    ) -> None:
        super(WSConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        # Runtime constant: sqrt(gain / fan_in), fan_in = in_channels * k * k.
        self.scale = (gain / (in_channels * kernel_size ** 2)) ** 0.5
        self.bias = self.conv.bias
        self.conv.bias = None  # Remove bias to apply it after scaling
        self._init_weights()

    def _init_weights(self) -> None:
        # Factored out for consistency with WSLinear: unit-normal weights,
        # zero bias.
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Args:
            x (torch.Tensor): Input tensor of shape (b, c, h, w).
        Returns:
            torch.Tensor: Output tensor of shape (b, out_channels, h', w').
        '''
        return self.conv(x * self.scale) + self.bias.view(1, self.bias.shape[0], 1, 1)
class Mapping(nn.Module):
    '''
    Mapping network: an MLP that transforms latent codes into intermediate
    latent vectors of the same dimensionality.

    Args:
        features (int): Number of features in the input and output.
        num_styles (int): Number of styles to generate.
        num_layers (int, optional): Number of layers in the feed forward network. Default: 8.
    '''

    def __init__(
        self,
        features: int,
        num_styles: int,
        num_layers: int = 8,
    ) -> None:
        super().__init__()
        self.features = features
        self.num_styles = num_styles  # NOTE(review): stored but not used here — confirm intent.
        self.num_layers = num_layers
        # num_layers repetitions of (equalized linear -> LeakyReLU).
        blocks = []
        for _ in range(num_layers):
            blocks += [WSLinear(features, features), nn.LeakyReLU(0.2)]
        self.fc = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Args:
            x (torch.Tensor): Input tensor of shape (b, l).
        Returns:
            torch.Tensor: Output tensor with the same shape as input.
        '''
        return self.fc(x)
class AdaIN(nn.Module):
    '''
    Adaptive Instance Normalization (AdaIN).
    AdaIN(x_i, y) = y_s,i * (x_i - mean(x_i)) / std(x_i) + y_b,i

    Args:
        eps (float, optional): Small value to avoid division by zero. Default value is 0.00001.
    '''

    def __init__(self, eps: float = 1e-5) -> None:
        super(AdaIN, self).__init__()
        self.eps = eps

    def forward(
        self,
        x: torch.Tensor,
        scale: torch.Tensor,
        shift: torch.Tensor
    ) -> torch.Tensor:
        '''
        Args:
            x (torch.Tensor): Input tensor of shape (b, c, h, w).
            scale (torch.Tensor): Scale tensor of shape (b, c).
            shift (torch.Tensor): Shift tensor of shape (b, c).
        Returns:
            torch.Tensor: Output tensor of shape (b, c, h, w).
        '''
        b, c, *_ = x.shape
        mean = x.mean(dim=(2, 3), keepdim=True)  # (b, c, 1, 1)
        std = x.std(dim=(2, 3), keepdim=True)    # (b, c, 1, 1)
        x_norm = (x - mean) / (std ** 2 + self.eps) ** .5
        scale = scale.view(b, c, 1, 1)  # (b, c, 1, 1)
        # BUG FIX: the original reshaped `scale` again here, silently
        # discarding the `shift` argument. Use `shift` as documented.
        shift = shift.view(b, c, 1, 1)  # (b, c, 1, 1)
        outputs = scale * x_norm + shift  # (b, c, h, w)
        return outputs
class SynthesisLayer(nn.Module):
    '''
    One layer of the synthesis network:
    - optional 3x3 weight-scaled convolution + LeakyReLU,
    - noise injection with a learned per-channel strength,
    - style-conditioned AdaIN (scale/shift produced by affine maps of w).

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        latent_features (int): The number of latent features.
        use_conv (bool, optional): Whether to use convolution or not. Default value is True.
    '''

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        latent_features: int,
        use_conv: bool = True
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.latent_features = latent_features
        self.use_conv = use_conv
        if self.use_conv:
            self.conv = nn.Sequential(
                WSConv2d(self.in_channels, self.out_channels, kernel_size=3, padding=1),
                nn.LeakyReLU(0.2)
            )
        else:
            self.conv = nn.Identity()
        self.norm = AdaIN()
        # Affine transforms mapping w to per-channel AdaIN scale / shift.
        self.scale_transform = WSLinear(self.latent_features, self.out_channels)
        self.shift_transform = WSLinear(self.latent_features, self.out_channels)
        # Learned per-channel noise strength, starting at zero (no noise).
        self.noise_factor = nn.Parameter(torch.zeros(1, self.out_channels, 1, 1))
        self._init_weights()

    def _init_weights(self) -> None:
        # Re-initialize every raw conv/linear weight with a unit normal;
        # zero the biases where present (the WS wrappers hold theirs outside
        # the wrapped module, so those inner biases are None).
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.normal_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
        # AdaIN scale affine starts at identity (bias = 1).
        nn.init.ones_(self.scale_transform.bias)

    def forward(
        self,
        x: torch.Tensor,
        w: torch.Tensor,
        noise: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        '''
        Args:
            x (torch.Tensor): Input tensor of shape (b, c, h, w).
            w (torch.Tensor): Latent space vector of shape (b, l).
            noise (torch.Tensor, optional): Noise tensor of shape (b, 1, h, w). Default value is None.
        Returns:
            torch.Tensor: Output tensor of shape (b, out_channels, h, w).
        '''
        batch, _, height, width = x.shape
        x = self.conv(x)  # (b, out_c, h, w)
        if noise is None:
            # Fresh single-channel noise, broadcast across channels.
            noise = torch.randn(batch, 1, height, width, device=x.device)
        x += self.noise_factor * noise      # (b, out_c, h, w)
        y_scale = self.scale_transform(w)   # (b, out_c)
        y_shift = self.shift_transform(w)   # (b, out_c)
        return self.norm(x, y_scale, y_shift)  # (b, out_c, h, w)
class SynthesisBlock(nn.Module):
    '''
    Synthesis network block which consist of:
    - Optional upsampling.
    - 2 Synthesis Layers.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        latent_features (int): The number of latent features.
        use_conv (bool, optional): Whether to use convolution or not. Default value is True.
        upsample (bool, optional): Whether to use upsampling or not. Default value is True.
    '''

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        latent_features: int,
        *,
        use_conv: bool = True,
        upsample: bool = True
    ) -> None:
        super(SynthesisBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.latent_features = latent_features
        self.use_conv = use_conv
        # Fix: the original stored the `upsample` bool in `self.upsample`
        # and immediately overwrote it with a module (a confusing dead
        # store). Store the module directly; the final value is unchanged.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear') if upsample else nn.Identity()
        self.layers = nn.ModuleList([
            SynthesisLayer(self.in_channels, self.in_channels, self.latent_features, use_conv=self.use_conv),
            SynthesisLayer(self.in_channels, self.out_channels, self.latent_features)
        ])

    def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        '''
        Args:
            x (torch.Tensor): Input tensor of shape (b, c, h, w).
            w (torch.Tensor): Latent vector of shape (b, l).
        Returns:
            torch.Tensor: Output tensor of shape (b, out_c, h, w) if not upsampling else (b, out_c, 2h, 2w).
        '''
        x = self.upsample(x)
        for layer in self.layers:
            x = layer(x, w)
        return x
class Synthesis(nn.Module):
    '''
    Synthesis network which consist of:
    - Constant tensor.
    - Synthesis blocks.
    - ToRGB convolutions.

    Args:
        resolution (int): The resolution of the image.
        const_channels (int): The number of channels in the constant tensor. Default value is 512.
    '''

    def __init__(self, resolution: int, const_channels: int = 512) -> None:
        super().__init__()
        self.resolution = resolution
        self.const_channels = const_channels
        # One block per doubling from the 4x4 base up to `resolution`.
        self.resolution_levels = int(math.log2(resolution) - 1)
        # Learned constant input, broadcast over the batch at forward time.
        self.constant = nn.Parameter(torch.ones(1, self.const_channels, 4, 4))
        blocks = [SynthesisBlock(self.const_channels, self.const_channels,
                                 self.const_channels, use_conv=False, upsample=False)]
        to_rgb = [WSConv2d(self.const_channels, 3, kernel_size=1, padding=0)]
        channels = self.const_channels
        for _ in range(self.resolution_levels - 1):
            # Each subsequent block upsamples x2 and halves the channels.
            blocks.append(SynthesisBlock(channels, channels // 2, self.const_channels))
            to_rgb.append(WSConv2d(channels // 2, 3, kernel_size=1, padding=0))
            channels //= 2
        self.blocks = nn.ModuleList(blocks)
        self.to_rgb = nn.ModuleList(to_rgb)

    def forward(self, w: torch.Tensor, alpha: float, steps: int) -> torch.Tensor:
        '''
        Args:
            w (torch.Tensor): Latent space vector of shape (b, l).
            alpha (float): Fade in alpha value.
            steps (int): The number of steps starting from 0.
        Returns:
            torch.Tensor: Output tensor of shape (b, 3, h, w).
        '''
        batch = w.size(0)
        x = self.constant.expand(batch, -1, -1, -1).clone()  # (b, c, 4, 4)
        if steps == 0:
            # Base resolution: nothing to fade in.
            return self.to_rgb[0](self.blocks[0](x, w))
        # Run every block before the newest (fading-in) one.
        for block in self.blocks[:steps]:
            x = block(x, w)
        old_rgb = self.to_rgb[steps - 1](x)  # (b, 3, h/2, w/2)
        x = self.blocks[steps](x, w)         # newest block, upsamples x2
        new_rgb = self.to_rgb[steps](x)      # (b, 3, h, w)
        old_rgb = F.interpolate(old_rgb, scale_factor=2, mode='bilinear',
                                align_corners=False)  # (b, 3, h, w)
        # Progressive-growing fade-in between the two resolutions.
        return (1 - alpha) * old_rgb + alpha * new_rgb
class StyleGAN(nn.Module):
    '''
    StyleGAN generator: mapping network followed by the synthesis network.

    Args:
        num_features (int): The number of features in the latent space vector.
        resolution (int): The resolution of the image.
        num_blocks (int, optional): The number of blocks in the synthesis network. Default value is 10.
    '''

    def __init__(self, num_features: int, resolution: int, num_blocks: int = 10):
        super().__init__()
        self.num_features = num_features
        self.resolution = resolution
        self.num_blocks = num_blocks
        # NOTE(review): `num_blocks` is forwarded to Mapping's `num_styles`
        # argument, which Mapping stores but never uses — confirm intent.
        self.mapping = Mapping(self.num_features, self.num_blocks)
        self.synthesis = Synthesis(self.resolution, self.num_features)

    def forward(self, x: torch.Tensor, alpha: float, steps: int) -> torch.Tensor:
        '''
        Args:
            x (torch.Tensor): Random input tensor of shape (b, l).
            alpha (float): Fade in alpha value.
            steps (int): The number of steps starting from 0.
        Returns:
            torch.Tensor: Output tensor of shape (b, c, h, w).
        '''
        w = self.mapping(x)                       # (b, l)
        return self.synthesis(w, alpha, steps)    # (b, 3, h, w)