| """ | |
| Copyright (c) Meta Platforms, Inc. and affiliates. | |
| All rights reserved. | |
| This source code is licensed under the license found in the | |
| LICENSE file in the root directory of this source tree. | |
| """ | |
| import logging | |
| from turtle import forward | |
| import visualize.ca_body.nn.layers as la | |
| from visualize.ca_body.nn.layers import weight_norm_wrapper | |
| import numpy as np | |
| import torch as th | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| logger = logging.getLogger(__name__) | |


# pyre-ignore
def weights_initializer(lrelu_slope=0.2):
    # pyre-ignore
    def init_fn(m):
        if isinstance(
            m,
            (
                nn.Conv2d,
                nn.Conv1d,
                nn.ConvTranspose2d,
                nn.Linear,
            ),
        ):
            gain = nn.init.calculate_gain("leaky_relu", lrelu_slope)
            nn.init.kaiming_uniform_(m.weight.data, a=gain)
            if hasattr(m, "bias") and m.bias is not None:
                nn.init.zeros_(m.bias.data)
        else:
            logger.debug(f"skipping initialization for {m}")

    return init_fn


# pyre-ignore
def WeightNorm(x, dim=0):
    return nn.utils.weight_norm(x, dim=dim)


# pyre-ignore
def np_warp_bias(uv_size):
    xgrid, ygrid = np.meshgrid(np.linspace(-1.0, 1.0, uv_size), np.linspace(-1.0, 1.0, uv_size))
    grid = np.concatenate((xgrid[None, :, :], ygrid[None, :, :]), axis=0)[None, ...].astype(
        np.float32
    )
    return grid
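
# Usage sketch (illustrative, not part of the original module): np_warp_bias produces a
# normalized xy coordinate grid; a plausible use is biasing a predicted warp field.
#
#   grid = np_warp_bias(4)        # np.ndarray of shape [1, 2, 4, 4], values in [-1, 1]
#   bias = th.as_tensor(grid)     # broadcastable against a [B, 2, 4, 4] warp tensor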


class Conv2dBias(nn.Conv2d):
    __annotations__ = {"bias": th.Tensor}

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        size,
        stride=1,
        padding=1,
        bias=True,
        *args,
        **kwargs,
    ):
        super().__init__(
            in_channels,
            out_channels,
            bias=False,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            *args,
            **kwargs,
        )
        if not bias:
            logger.warning("ignoring bias=False")
        self.bias = nn.Parameter(th.zeros(out_channels, size, size))

    def forward(self, x):
        bias = self.bias.clone()
        return (
            # pyre-ignore
            th.conv2d(
                x,
                self.weight,
                bias=None,
                stride=self.stride,
                # pyre-ignore
                padding=self.padding,
                dilation=self.dilation,
                groups=self.groups,
            )
            + bias[np.newaxis]
        )
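
# Usage sketch (illustrative): Conv2dBias learns an *untied* bias, i.e. one bias value per
# output channel and per pixel, so `size` must match the spatial size of the output map.
#
#   conv = Conv2dBias(in_channels=8, out_channels=16, kernel_size=3, size=32)  # padding=1 keeps 32x32
#   y = conv(th.randn(2, 8, 32, 32))  # -> [2, 16, 32, 32]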


class Conv1dBias(nn.Conv1d):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        size,
        stride=1,
        padding=0,
        bias=True,
        *args,
        **kwargs,
    ):
        super().__init__(
            in_channels,
            out_channels,
            bias=False,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            *args,
            **kwargs,
        )
        if not bias:
            logger.warning("ignoring bias=False")
        self.bias = nn.Parameter(th.zeros(out_channels, size))

    def forward(self, x):
        return (
            # pyre-ignore
            th.conv1d(
                x,
                self.weight,
                bias=None,
                stride=self.stride,
                # pyre-ignore
                padding=self.padding,
                dilation=self.dilation,
                groups=self.groups,
            )
            + self.bias
        )


class UpConvBlock(nn.Module):
    # pyre-ignore
    def __init__(self, in_channels, out_channels, size, lrelu_slope=0.2):
        super().__init__()
        # Integration note: `upsample` did not exist in the upstream GitHub version; we
        # assume the same bilinear upsampling as in the other upsampling blocks.
        self.upsample = nn.UpsamplingBilinear2d(size)
        self.conv_resize = la.Conv2dWN(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1
        )
        self.conv1 = la.Conv2dWNUB(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            height=size,
            width=size,
            padding=1,
        )
        self.lrelu1 = nn.LeakyReLU(lrelu_slope)
        # self.conv2 = nn.utils.weight_norm(
        #     Conv2dBias(in_channels, out_channels, kernel_size=3, size=size), dim=None,
        # )
        # self.lrelu2 = nn.LeakyReLU(lrelu_slope)

    # pyre-ignore
    def forward(self, x):
        x_up = self.upsample(x)
        x_skip = self.conv_resize(x_up)
        x = self.conv1(x_up)
        x = self.lrelu1(x)
        return x + x_skip
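
# Usage sketch (illustrative): UpConvBlock bilinearly upsamples to `size` (the *output*
# resolution) and adds a 1x1-conv skip branch.
#
#   up = UpConvBlock(in_channels=32, out_channels=16, size=64)
#   y = up(th.randn(2, 32, 32, 32))  # -> [2, 16, 64, 64]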


class ConvBlock1d(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        size,
        lrelu_slope=0.2,
        kernel_size=3,
        padding=1,
        wnorm_dim=0,
    ):
        super().__init__()
        self.conv_resize = WeightNorm(
            nn.Conv1d(in_channels, out_channels, kernel_size=1), dim=wnorm_dim
        )
        self.conv1 = WeightNorm(
            Conv1dBias(
                in_channels,
                in_channels,
                kernel_size=kernel_size,
                padding=padding,
                size=size,
            ),
            dim=wnorm_dim,
        )
        self.lrelu1 = nn.LeakyReLU(lrelu_slope)
        self.conv2 = WeightNorm(
            Conv1dBias(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                padding=padding,
                size=size,
            ),
            dim=wnorm_dim,
        )
        self.lrelu2 = nn.LeakyReLU(lrelu_slope)

    def forward(self, x):
        x_skip = self.conv_resize(x)
        x = self.conv1(x)
        x = self.lrelu1(x)
        x = self.conv2(x)
        x = self.lrelu2(x)
        return x + x_skip


class ConvBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        size,
        lrelu_slope=0.2,
        kernel_size=3,
        padding=1,
        wnorm_dim=0,
    ):
        super().__init__()
        Conv2dWNUB = weight_norm_wrapper(la.Conv2dUB, "Conv2dWNUB", g_dim=wnorm_dim, v_dim=None)
        Conv2dWN = weight_norm_wrapper(th.nn.Conv2d, "Conv2dWN", g_dim=wnorm_dim, v_dim=None)

        # TODO: do we really need this?
        self.conv_resize = Conv2dWN(in_channels, out_channels, kernel_size=1)
        self.conv1 = Conv2dWNUB(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            padding=padding,
            height=size,
            width=size,
        )
        self.lrelu1 = nn.LeakyReLU(lrelu_slope)
        self.conv2 = Conv2dWNUB(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=padding,
            height=size,
            width=size,
        )
        self.lrelu2 = nn.LeakyReLU(lrelu_slope)

    def forward(self, x):
        x_skip = self.conv_resize(x)
        x = self.conv1(x)
        x = self.lrelu1(x)
        x = self.conv2(x)
        x = self.lrelu2(x)
        return x + x_skip
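
# Usage sketch (illustrative): ConvBlock keeps the spatial resolution (`size`) and maps
# in_channels -> out_channels, with a weight-normalized 1x1 conv as the skip branch.
#
#   block = ConvBlock(in_channels=16, out_channels=32, size=64)
#   y = block(th.randn(2, 16, 64, 64))  # -> [2, 32, 64, 64]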


class ConvBlockNoSkip(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        size,
        lrelu_slope=0.2,
        kernel_size=3,
        padding=1,
        wnorm_dim=0,
    ):
        super().__init__()
        self.conv1 = WeightNorm(
            Conv2dBias(
                in_channels,
                in_channels,
                kernel_size=kernel_size,
                padding=padding,
                size=size,
            ),
            dim=wnorm_dim,
        )
        self.lrelu1 = nn.LeakyReLU(lrelu_slope)
        self.conv2 = WeightNorm(
            Conv2dBias(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                padding=padding,
                size=size,
            ),
            dim=wnorm_dim,
        )
        self.lrelu2 = nn.LeakyReLU(lrelu_slope)

    def forward(self, x):
        x = self.conv1(x)
        x = self.lrelu1(x)
        x = self.conv2(x)
        x = self.lrelu2(x)
        return x


class ConvDownBlock(nn.Module):
    def __init__(self, in_channels, out_channels, size, lrelu_slope=0.2, groups=1, wnorm_dim=0):
        """Constructor.

        Args:
            in_channels: int, # of input channels
            out_channels: int, # of output channels
            size: the *input* spatial size; the block downsamples by a factor of 2
        """
        super().__init__()
        Conv2dWNUB = weight_norm_wrapper(la.Conv2dUB, "Conv2dWNUB", g_dim=wnorm_dim, v_dim=None)
        Conv2dWN = weight_norm_wrapper(th.nn.Conv2d, "Conv2dWN", g_dim=wnorm_dim, v_dim=None)

        self.conv_resize = Conv2dWN(
            in_channels, out_channels, kernel_size=1, stride=2, groups=groups
        )
        self.conv1 = Conv2dWNUB(
            in_channels,
            in_channels,
            kernel_size=3,
            height=size,
            width=size,
            groups=groups,
            padding=1,
        )
        self.lrelu1 = nn.LeakyReLU(lrelu_slope)
        self.conv2 = Conv2dWNUB(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            height=size // 2,
            width=size // 2,
            groups=groups,
            padding=1,
        )
        self.lrelu2 = nn.LeakyReLU(lrelu_slope)

    def forward(self, x):
        x_skip = self.conv_resize(x)
        x = self.conv1(x)
        x = self.lrelu1(x)
        x = self.conv2(x)
        x = self.lrelu2(x)
        return x + x_skip
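
# Usage sketch (illustrative): ConvDownBlock halves the resolution; `size` is the *input*
# resolution, and both the strided conv2 and the 1x1 skip conv use stride=2.
#
#   down = ConvDownBlock(in_channels=16, out_channels=32, size=64)
#   y = down(th.randn(2, 16, 64, 64))  # -> [2, 32, 32, 32]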


class UpConvBlockDeep(nn.Module):
    def __init__(self, in_channels, out_channels, size, lrelu_slope=0.2, wnorm_dim=0, groups=1):
        super().__init__()
        self.upsample = nn.UpsamplingBilinear2d(size)
        Conv2dWNUB = weight_norm_wrapper(la.Conv2dUB, "Conv2dWNUB", g_dim=wnorm_dim, v_dim=None)
        Conv2dWN = weight_norm_wrapper(th.nn.Conv2d, "Conv2dWN", g_dim=wnorm_dim, v_dim=None)

        # NOTE: the old one normalizes only across one dimension
        self.conv_resize = Conv2dWN(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            groups=groups,
        )
        self.conv1 = Conv2dWNUB(
            in_channels,
            in_channels,
            kernel_size=3,
            height=size,
            width=size,
            padding=1,
            groups=groups,
        )
        self.lrelu1 = nn.LeakyReLU(lrelu_slope)
        self.conv2 = Conv2dWNUB(
            in_channels,
            out_channels,
            kernel_size=3,
            height=size,
            width=size,
            padding=1,
            groups=groups,
        )
        self.lrelu2 = nn.LeakyReLU(lrelu_slope)

    def forward(self, x):
        x_up = self.upsample(x)
        x_skip = self.conv_resize(x_up)
        x = x_up
        x = self.conv1(x)
        x = self.lrelu1(x)
        x = self.conv2(x)
        x = self.lrelu2(x)
        return x + x_skip
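
# Usage sketch (illustrative): UpConvBlockDeep is the two-conv counterpart of UpConvBlock;
# `size` is the output resolution after bilinear upsampling.
#
#   up = UpConvBlockDeep(in_channels=32, out_channels=16, size=64)
#   y = up(th.randn(2, 32, 32, 32))  # -> [2, 16, 64, 64]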


class ConvBlockPositional(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        pos_map,
        lrelu_slope=0.2,
        kernel_size=3,
        padding=1,
        wnorm_dim=0,
    ):
        """Block with positional encoding.

        Args:
            in_channels: # of input channels (not counting the positional encoding)
            out_channels: # of output channels
            pos_map: tensor [P, size, size]
        """
        super().__init__()
        assert len(pos_map.shape) == 3 and pos_map.shape[1] == pos_map.shape[2]
        self.register_buffer("pos_map", pos_map)

        self.conv_resize = WeightNorm(nn.Conv2d(in_channels, out_channels, 1), dim=wnorm_dim)
        self.conv1 = WeightNorm(
            nn.Conv2d(
                in_channels + pos_map.shape[0],
                in_channels,
                kernel_size=3,
                padding=padding,
            ),
            dim=wnorm_dim,
        )
        self.lrelu1 = nn.LeakyReLU(lrelu_slope)
        self.conv2 = WeightNorm(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=padding),
            dim=wnorm_dim,
        )
        self.lrelu2 = nn.LeakyReLU(lrelu_slope)

    def forward(self, x):
        B = x.shape[0]
        x_skip = self.conv_resize(x)
        pos = self.pos_map[np.newaxis].expand(B, -1, -1, -1)
        x = th.cat([x, pos], dim=1)
        x = self.conv1(x)
        x = self.lrelu1(x)
        x = self.conv2(x)
        x = self.lrelu2(x)
        return x + x_skip
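
# Usage sketch (illustrative): ConvBlockPositional concatenates a fixed positional map to
# the input before the first conv; the map's spatial size must match the input's.
#
#   pos = th.randn(2, 64, 64)                   # e.g. a 2-channel coordinate map
#   block = ConvBlockPositional(in_channels=16, out_channels=32, pos_map=pos)
#   y = block(th.randn(4, 16, 64, 64))          # -> [4, 32, 64, 64]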


class UpConvBlockPositional(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        pos_map,
        lrelu_slope=0.2,
        wnorm_dim=0,
    ):
        """Block with positional encoding.

        Args:
            in_channels: # of input channels (not counting the positional encoding)
            out_channels: # of output channels
            pos_map: tensor [P, size, size]
        """
        super().__init__()
        assert len(pos_map.shape) == 3 and pos_map.shape[1] == pos_map.shape[2]
        self.register_buffer("pos_map", pos_map)
        size = pos_map.shape[1]
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.upsample = nn.UpsamplingBilinear2d(size)

        if in_channels != out_channels:
            self.conv_resize = WeightNorm(nn.Conv2d(in_channels, out_channels, 1), dim=wnorm_dim)

        self.conv1 = WeightNorm(
            nn.Conv2d(
                in_channels + pos_map.shape[0],
                in_channels,
                kernel_size=3,
                padding=1,
            ),
            dim=wnorm_dim,
        )
        self.lrelu1 = nn.LeakyReLU(lrelu_slope)
        self.conv2 = WeightNorm(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            dim=wnorm_dim,
        )
        self.lrelu2 = nn.LeakyReLU(lrelu_slope)

    def forward(self, x):
        B = x.shape[0]
        x_up = self.upsample(x)
        x_skip = x_up
        if self.in_channels != self.out_channels:
            x_skip = self.conv_resize(x_up)
        pos = self.pos_map[np.newaxis].expand(B, -1, -1, -1)
        x = th.cat([x_up, pos], dim=1)
        x = self.conv1(x)
        x = self.lrelu1(x)
        x = self.conv2(x)
        x = self.lrelu2(x)
        return x + x_skip


class UpConvBlockDeepNoBias(nn.Module):
    def __init__(self, in_channels, out_channels, size, lrelu_slope=0.2, wnorm_dim=0, groups=1):
        super().__init__()
        self.upsample = nn.UpsamplingBilinear2d(size)
        # NOTE: the old one normalizes only across one dimension
        self.conv_resize = WeightNorm(
            nn.Conv2d(in_channels, out_channels, 1, groups=groups), dim=wnorm_dim
        )
        self.conv1 = WeightNorm(
            nn.Conv2d(in_channels, in_channels, padding=1, kernel_size=3, groups=groups),
            dim=wnorm_dim,
        )
        self.lrelu1 = nn.LeakyReLU(lrelu_slope)
        self.conv2 = WeightNorm(
            nn.Conv2d(in_channels, out_channels, padding=1, kernel_size=3, groups=groups),
            dim=wnorm_dim,
        )
        self.lrelu2 = nn.LeakyReLU(lrelu_slope)

    def forward(self, x):
        x_up = self.upsample(x)
        x_skip = self.conv_resize(x_up)
        x = x_up
        x = self.conv1(x)
        x = self.lrelu1(x)
        x = self.conv2(x)
        x = self.lrelu2(x)
        return x + x_skip


class UpConvBlockXDeep(nn.Module):
    def __init__(self, in_channels, out_channels, size, lrelu_slope=0.2, wnorm_dim=0):
        super().__init__()
        self.upsample = nn.UpsamplingBilinear2d(size)
        # TODO: see if this is necessary
        self.conv_resize = WeightNorm(nn.Conv2d(in_channels, out_channels, 1), dim=wnorm_dim)
        self.conv1 = WeightNorm(
            Conv2dBias(in_channels, in_channels // 2, kernel_size=3, size=size),
            dim=wnorm_dim,
        )
        self.lrelu1 = nn.LeakyReLU(lrelu_slope)
        self.conv2 = WeightNorm(
            Conv2dBias(in_channels // 2, in_channels // 2, kernel_size=3, size=size),
            dim=wnorm_dim,
        )
        self.lrelu2 = nn.LeakyReLU(lrelu_slope)
        self.conv3 = WeightNorm(
            Conv2dBias(in_channels // 2, out_channels, kernel_size=3, size=size),
            dim=wnorm_dim,
        )
        self.lrelu3 = nn.LeakyReLU(lrelu_slope)

    def forward(self, x):
        x_up = self.upsample(x)
        x_skip = self.conv_resize(x_up)
        x = x_up
        x = self.conv1(x)
        x = self.lrelu1(x)
        x = self.conv2(x)
        x = self.lrelu2(x)
        x = self.conv3(x)
        x = self.lrelu3(x)
        return x + x_skip


class UpConvCondBlock(nn.Module):
    def __init__(self, in_channels, out_channels, size, cond_channels, lrelu_slope=0.2):
        super().__init__()
        self.upsample = nn.UpsamplingBilinear2d(size)
        self.conv_resize = nn.utils.weight_norm(nn.Conv2d(in_channels, out_channels, 1), dim=None)
        self.conv1 = WeightNorm(
            Conv2dBias(in_channels + cond_channels, in_channels, kernel_size=3, size=size),
        )
        self.lrelu1 = nn.LeakyReLU(lrelu_slope)
        self.conv2 = WeightNorm(
            Conv2dBias(in_channels, out_channels, kernel_size=3, size=size),
        )
        self.lrelu2 = nn.LeakyReLU(lrelu_slope)

    def forward(self, x, cond):
        x_up = self.upsample(x)
        x_skip = self.conv_resize(x_up)
        x = x_up
        x = th.cat([x, cond], dim=1)
        x = self.conv1(x)
        x = self.lrelu1(x)
        x = self.conv2(x)
        x = self.lrelu2(x)
        return x + x_skip


class UpConvBlockPS(nn.Module):
    # pyre-ignore
    def __init__(self, n_in, n_out, size, kernel_size=3, padding=1):
        super().__init__()
        self.conv1 = la.Conv2dWNUB(
            n_in,
            n_out * 4,
            size,
            size,
            kernel_size=kernel_size,
            padding=padding,
        )
        self.lrelu = nn.LeakyReLU(0.2, inplace=True)
        self.ps = nn.PixelShuffle(2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.lrelu(x)
        return self.ps(x)


# pyre-ignore
def apply_crop(
    image,
    ymin,
    ymax,
    xmin,
    xmax,
):
    """Crops a region from an image."""
    # NOTE: here we are expecting one of [H, W] [H, W, C] [B, H, W, C]
    if len(image.shape) == 2:
        return image[ymin:ymax, xmin:xmax]
    elif len(image.shape) == 3:
        return image[ymin:ymax, xmin:xmax, :]
    elif len(image.shape) == 4:
        return image[:, ymin:ymax, xmin:xmax, :]
    else:
        raise ValueError("provide a batch of images or a single image")


def tile1d(x, size):
    """Tile a given set of features into a convolutional map.

    Args:
        x: float tensor of shape [N, F]
        size: int, the length to tile to

    Returns:
        a feature map of shape [N, F, size]
    """
    return x[:, :, np.newaxis].expand(-1, -1, size)


def tile2d(x, size: int):
    """Tile a given set of features into a convolutional map.

    Args:
        x: float tensor of shape [N, F]
        size: int, the spatial size to tile to

    Returns:
        a feature map of shape [N, F, size, size]
    """
    # NOTE: expecting only int here (!!!)
    return x[:, :, np.newaxis, np.newaxis].expand(-1, -1, size, size)
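
# Usage sketch (illustrative): tile1d/tile2d broadcast per-sample feature vectors into
# constant feature maps, e.g. to concatenate a global embedding with a conv feature map.
#
#   z = th.randn(2, 8)
#   tile1d(z, 16).shape  # -> [2, 8, 16]
#   tile2d(z, 16).shape  # -> [2, 8, 16, 16]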


def sample_negative_idxs(size, *args, **kwargs):
    idxs = th.randperm(size, *args, **kwargs)
    if th.all(idxs == th.arange(size, dtype=idxs.dtype, device=idxs.device)):
        return th.flip(idxs, (0,))
    return idxs


def icnr_init(x, scale=2, init=nn.init.kaiming_normal_):
    ni, nf, h, w = x.shape
    ni2 = int(ni / (scale**2))
    k = init(x.new_zeros([ni2, nf, h, w])).transpose(0, 1)
    k = k.contiguous().view(ni2, nf, -1)
    k = k.repeat(1, 1, scale**2)
    return k.contiguous().view([nf, ni, h, w]).transpose(0, 1)
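
# Usage sketch (illustrative): ICNR initializes the weights of a conv feeding a PixelShuffle
# so that the shuffled output starts out equivalent to nearest-neighbour upsampling, which
# helps avoid checkerboard artifacts.
#
#   w = icnr_init(th.empty(16, 8, 3, 3), scale=2)  # 16 = 4 output channels * 2**2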


class PixelShuffleWN(nn.Module):
    """PixelShuffle with the right initialization.

    NOTE: make sure to create this one
    """

    def __init__(self, n_in, n_out, upscale_factor=2):
        super().__init__()
        self.upscale_factor = upscale_factor
        self.n_in = n_in
        self.n_out = n_out
        self.conv = la.Conv2dWN(n_in, n_out * (upscale_factor**2), kernel_size=1, padding=0)
        # NOTE: the bias is 2K?
        self.ps = nn.PixelShuffle(upscale_factor)
        self._init_icnr()

    def _init_icnr(self):
        self.conv.weight_v.data.copy_(icnr_init(self.conv.weight_v.data))
        self.conv.weight_g.data.copy_(
            ((self.conv.weight_v.data**2).sum(dim=[1, 2, 3]) ** 0.5)[:, None, None, None]
        )

    def forward(self, x):
        x = self.conv(x)
        return self.ps(x)


class UpscaleNet(nn.Module):
    def __init__(self, in_channels, out_channels=3, n_ftrs=16, size=1024, upscale_factor=2):
        super().__init__()

        self.conv_block = nn.Sequential(
            la.Conv2dWNUB(in_channels, n_ftrs, size, size, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            la.Conv2dWNUB(n_ftrs, n_ftrs, size, size, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.out_block = la.Conv2dWNUB(
            n_ftrs,
            out_channels * upscale_factor**2,
            size,
            size,
            kernel_size=1,
            padding=0,
        )

        self.pixel_shuffle = nn.PixelShuffle(upscale_factor=upscale_factor)

        self.apply(lambda x: la.glorot(x, 0.2))
        self.out_block.apply(weights_initializer(1.0))

    def forward(self, x):
        x = self.conv_block(x)
        x = self.out_block(x)
        return self.pixel_shuffle(x)
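

if __name__ == "__main__":
    # Minimal smoke test (not part of the original module): run a down/up pair on random
    # data and check that the shapes round-trip as expected.
    x = th.randn(2, 16, 32, 32)
    down = ConvDownBlock(in_channels=16, out_channels=32, size=32)
    up = UpConvBlockDeep(in_channels=32, out_channels=16, size=32)
    y = up(down(x))
    print(y.shape)  # expected: torch.Size([2, 16, 32, 32])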