| """Layers used for up-sampling or down-sampling images.
|
|
|
| Many functions are ported from https://github.com/NVlabs/stylegan2.
|
| """
|
|
|
| import torch.nn as nn
|
| import torch
|
| import torch.nn.functional as F
|
| import numpy as np
|
| from .op import upfirdn2d
|
|
|
|
|
|
|
def get_weight(module,
               shape,
               weight_var='weight',
               kernel_init=None):
    """Get or create the weight tensor for a convolution or fully-connected layer.

    The original implementation only called ``module.param(weight_var,
    kernel_init, shape)``, which matches the Flax ``Module.param`` call shape.
    ``torch.nn.Module`` has no ``param`` method, so that call always failed on
    PyTorch modules. Objects exposing a callable ``param`` are still delegated
    to exactly as before; every other module takes a PyTorch path:

    * an existing attribute named ``weight_var`` is returned unchanged;
    * otherwise a new ``nn.Parameter`` is created, registered on ``module``
      under ``weight_var`` and returned.

    Args:
      module: A Flax-style module exposing ``param`` or a ``torch.nn.Module``.
      shape: Shape of the weight tensor.
      weight_var: Name of the weight variable / parameter.
      kernel_init: Initializer. On the Flax path it is forwarded to
        ``module.param`` unchanged. On the PyTorch path it is called as
        ``kernel_init(shape)`` (the same convention ``Conv2d`` uses); when it
        is None the new weight is all zeros.

    Returns:
      The weight tensor (an ``nn.Parameter`` on the PyTorch path).
    """
    param_fn = getattr(module, 'param', None)
    if callable(param_fn):
        # Flax-style module: keep the original behaviour.
        return param_fn(weight_var, kernel_init, shape)

    # PyTorch module: reuse a weight that is already registered under this name.
    existing = getattr(module, weight_var, None)
    if existing is not None:
        return existing

    if kernel_init is None:
        init = torch.zeros(shape)
    else:
        init = torch.as_tensor(kernel_init(shape))
    weight = nn.Parameter(init)
    module.register_parameter(weight_var, weight)
    return weight
|
|
|
|
|
class Conv2d(nn.Module):
    """Conv2d layer with optimal upsampling and downsampling (StyleGAN2).

    Operates on NCHW tensors. With neither ``up`` nor ``down`` set, this is a
    stride-1 convolution with symmetric "same" padding of ``kernel // 2``.
    With ``up`` it delegates to ``upsample_conv_2d``, and with ``down`` to
    ``conv_downsample_2d``. Both use their default resampling factor of 2 and
    ``resample_kernel`` as the FIR filter.

    Args:
      in_ch: Number of input channels.
      out_ch: Number of output channels.
      kernel: Odd, positive spatial size of the square convolution kernel.
      up: Use the fused upsample-then-convolve path.
      down: Use the fused convolve-then-downsample path (exclusive with ``up``).
      resample_kernel: FIR taps passed as ``k`` to the fused resampling ops.
      use_bias: Add a learnable per-channel bias, initialised to zeros.
      kernel_init: Optional callable receiving the weight shape
        ``(out_ch, in_ch, kernel, kernel)`` and returning the initial weight
        tensor. When omitted, the weight stays all zeros.
    """

    def __init__(self, in_ch, out_ch, kernel, up=False, down=False,
                 resample_kernel=(1, 3, 3, 1),
                 use_bias=True,
                 kernel_init=None):
        super().__init__()
        assert not (up and down)
        assert kernel >= 1 and kernel % 2 == 1

        # The weight must be registered before the bias so that parameter
        # order (and therefore state_dict order) is weight, then bias.
        weight_shape = (out_ch, in_ch, kernel, kernel)
        self.weight = nn.Parameter(torch.zeros(*weight_shape))
        if kernel_init is not None:
            self.weight.data = kernel_init(self.weight.data.shape)
        if use_bias:
            self.bias = nn.Parameter(torch.zeros(out_ch))

        self.kernel = kernel
        self.up, self.down = up, down
        self.resample_kernel = resample_kernel
        self.use_bias = use_bias

    def forward(self, x):
        if self.up:
            out = upsample_conv_2d(x, self.weight, k=self.resample_kernel)
        elif self.down:
            out = conv_downsample_2d(x, self.weight, k=self.resample_kernel)
        else:
            out = F.conv2d(x, self.weight, stride=1, padding=self.kernel // 2)

        if not self.use_bias:
            return out
        # Broadcast the per-channel bias across batch and spatial dimensions.
        return out + self.bias[None, :, None, None]
|
|
|
|
|
def naive_upsample_2d(x, factor=2):
    """Nearest-neighbour upsample a 4-D NCHW tensor by an integer factor.

    Each input pixel is replicated into a ``factor x factor`` block, giving an
    output of shape ``[N, C, H * factor, W * factor]``.
    """
    _, channels, height, width = x.shape
    # Insert singleton axes after H and W, broadcast them to `factor`, then
    # fold each (H, factor) / (W, factor) pair back into one spatial axis.
    tiled = x[:, :, :, None, :, None].expand(-1, -1, -1, factor, -1, factor)
    return tiled.reshape(-1, channels, height * factor, width * factor)
|
|
|
|
|
def naive_downsample_2d(x, factor=2):
    """Average-pool a 4-D NCHW tensor over non-overlapping ``factor x factor`` blocks.

    Args:
      x: Tensor of shape ``[N, C, H, W]``; ``H`` and ``W`` must be divisible
        by ``factor``.
      factor: Integer downsampling factor (default: 2).

    Returns:
      Tensor of shape ``[N, C, H // factor, W // factor]``.

    Raises:
      ValueError: If ``H`` or ``W`` is not divisible by ``factor``.
    """
    n, c, h, w = x.shape
    # An explicit check is needed: the previous reshape used -1 for the batch
    # axis, which could silently absorb a mismatch. For example, a
    # (2, 1, 3, 4) input became a batch of 3.
    if h % factor or w % factor:
        raise ValueError(
            f'spatial size ({h}, {w}) is not divisible by factor {factor}')
    x = torch.reshape(x, (n, c, h // factor, factor, w // factor, factor))
    return torch.mean(x, dim=(3, 5))
|
|
|
|
|
def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
    """Fused `upsample_2d()` followed by a 2-D convolution with `w`.

    Padding is performed only once at the beginning, not between the
    operations. The fused op does the upsampling with a strided transposed
    convolution and then applies the FIR filter `k` via `upfirdn2d`.

    Args:
      x: Input tensor of shape `[N, C, H, W]` (NCHW only).
      w: Convolution weight of shape `[outC, inC, convH, convW]` (the
        `F.conv2d` layout) with `convH == convW`. Grouped convolution is
        performed when `C == num_groups * inC`; `outC` must then be
        divisible by `num_groups`.
      k: FIR filter of shape `[firH, firW]` or `[firN]` (separable). The
        default is `[1] * factor`, which corresponds to nearest-neighbor
        upsampling.
      factor: Integer upsampling factor (default: 2).
      gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
      Tensor of shape `[N, outC, H * factor, W * factor]`.
    """
    assert isinstance(factor, int) and factor >= 1

    # Check weight shape.
    assert len(w.shape) == 4
    convH = w.shape[2]
    convW = w.shape[3]
    inC = w.shape[1]
    assert convW == convH
    assert _shape(x, 1) % inC == 0
    num_groups = _shape(x, 1) // inC

    # Setup filter kernel.
    if k is None:
        k = [1] * factor
    k = _setup_kernel(k) * (gain * (factor ** 2))
    p = (k.shape[0] - factor) - (convW - 1)

    # Rearrange the conv weight [outC, inC, kH, kW] into the grouped
    # transposed-conv layout [num_groups * inC, outC // num_groups, kH, kW],
    # flipped spatially so the transposed convolution matches "upsample, then
    # cross-correlate with w". PyTorch rejects negative-step slicing such as
    # w[..., ::-1, ::-1], hence torch.flip.
    w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
    w = torch.flip(w, dims=(3, 4)).permute(0, 2, 1, 3, 4)
    w = torch.reshape(w, (num_groups * inC, -1, convH, convW))

    # With stride == factor the transposed-conv output is exactly
    # ((H - 1) * factor + convH, (W - 1) * factor + convW), so no
    # output_padding is required.
    x = F.conv_transpose2d(x, w, stride=(factor, factor), padding=0,
                           groups=num_groups)

    # Match the FIR kernel to floating-point inputs (e.g. half/double) so the
    # kernel and activations share a dtype inside upfirdn2d.
    kernel = torch.tensor(
        k, device=x.device,
        dtype=x.dtype if x.is_floating_point() else None)
    # The trailing pad is `factor - 1`; the former hard-coded `+ 1` only
    # produced H * factor outputs for factor == 2 (identical in that case).
    return upfirdn2d(kernel=kernel, input=x,
                     pad=((p + 1) // 2 + factor - 1, p // 2 + factor - 1)) \
        if False else upfirdn2d(x, kernel,
                                pad=((p + 1) // 2 + factor - 1,
                                     p // 2 + factor - 1))
|
|
|
|
|
def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
    """Fused 2-D convolution with `w` followed by `downsample_2d()`.

    Padding is performed only once at the beginning, not between the
    operations. The FIR filter `k` is applied first via `upfirdn2d`, followed
    by a strided convolution with `w`.

    Args:
      x: Input tensor of shape `[N, C, H, W]` (NCHW only).
      w: Convolution weight of shape `[outC, inC, convH, convW]` (the
        `F.conv2d` layout) with `convH == convW`. Grouped convolution is
        performed when `C == num_groups * inC`.
      k: FIR filter of shape `[firH, firW]` or `[firN]` (separable). The
        default is `[1] * factor`, which corresponds to average pooling.
      factor: Integer downsampling factor (default: 2).
      gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
      Tensor of shape `[N, outC, H // factor, W // factor]`.
    """
    assert isinstance(factor, int) and factor >= 1
    _outC, inC, convH, convW = w.shape
    assert convW == convH
    assert x.shape[1] % inC == 0
    if k is None:
        k = [1] * factor
    k = _setup_kernel(k) * gain
    p = (k.shape[0] - factor) + (convW - 1)
    s = [factor, factor]
    # Match the FIR kernel to floating-point inputs (e.g. half/double) so the
    # kernel and activations share a dtype inside upfirdn2d.
    kernel = torch.tensor(
        k, device=x.device,
        dtype=x.dtype if x.is_floating_point() else None)
    x = upfirdn2d(x, kernel, pad=((p + 1) // 2, p // 2))
    # F.conv2d only does grouped convolution when `groups` is passed
    # explicitly; it is 1 in the ungrouped case.
    return F.conv2d(x, w, stride=s, padding=0, groups=x.shape[1] // inC)
|
|
|
|
|
| def _setup_kernel(k):
|
| k = np.asarray(k, dtype=np.float32)
|
| if k.ndim == 1:
|
| k = np.outer(k, k)
|
| k /= np.sum(k)
|
| assert k.ndim == 2
|
| assert k.shape[0] == k.shape[1]
|
| return k
|
|
|
|
|
| def _shape(x, dim):
|
| return x.shape[dim]
|
|
|
|
|
def upsample_2d(x, k=None, factor=2, gain=1):
    r"""Upsample a batch of 2D images with the given filter.

    Upsamples each image in `x` by `factor` and applies the FIR filter `k`
    in a single `upfirdn2d` call. The filter is normalized so that constant
    input pixels are scaled by `gain`. As in the StyleGAN2 reference, pixels
    outside the image are treated as zero.

    Unlike the TensorFlow original, there is no `data_format` argument: `x`
    is passed straight to `upfirdn2d`, which is assumed to expect NCHW (TODO
    confirm against `.op.upfirdn2d`).

    Args:
      x: Input tensor of shape `[N, C, H, W]`.
      k: FIR filter of shape `[firH, firW]` or `[firN]` (separable). The
        default is `[1] * factor`, which corresponds to nearest-neighbor
        upsampling.
      factor: Integer upsampling factor (default: 2).
      gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
      Tensor of shape `[N, C, H * factor, W * factor]`.
    """
    assert isinstance(factor, int) and factor >= 1
    if k is None:
        k = [1] * factor
    # Multiply by factor**2 to compensate for the zeros inserted when
    # upsampling, so that constant inputs are scaled by exactly `gain`.
    k = _setup_kernel(k) * (gain * (factor ** 2))
    p = k.shape[0] - factor
    # Match the FIR kernel to floating-point inputs (e.g. half/double) so the
    # kernel and activations share a dtype inside upfirdn2d.
    kernel = torch.tensor(
        k, device=x.device,
        dtype=x.dtype if x.is_floating_point() else None)
    return upfirdn2d(x, kernel,
                     up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
|
|
|
|
|
def downsample_2d(x, k=None, factor=2, gain=1):
    r"""Downsample a batch of 2D images with the given filter.

    Applies the FIR filter `k` to each image in `x` and downsamples by
    `factor` in a single `upfirdn2d` call. The filter is normalized so that
    constant input pixels are scaled by `gain`. As in the StyleGAN2
    reference, pixels outside the image are treated as zero.

    Unlike the TensorFlow original, there is no `data_format` argument: `x`
    is passed straight to `upfirdn2d`, which is assumed to expect NCHW (TODO
    confirm against `.op.upfirdn2d`).

    Args:
      x: Input tensor of shape `[N, C, H, W]`.
      k: FIR filter of shape `[firH, firW]` or `[firN]` (separable). The
        default is `[1] * factor`, which corresponds to average pooling.
      factor: Integer downsampling factor (default: 2).
      gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
      Tensor of shape `[N, C, H // factor, W // factor]`.
    """
    assert isinstance(factor, int) and factor >= 1
    if k is None:
        k = [1] * factor
    k = _setup_kernel(k) * gain
    p = k.shape[0] - factor
    # Match the FIR kernel to floating-point inputs (e.g. half/double) so the
    # kernel and activations share a dtype inside upfirdn2d.
    kernel = torch.tensor(
        k, device=x.device,
        dtype=x.dtype if x.is_floating_point() else None)
    return upfirdn2d(x, kernel,
                     down=factor, pad=((p + 1) // 2, p // 2))
|
|
|