| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | """Layers for defining NCSN++. |
| | """ |
| | from . import layers |
| | from . import up_or_down_sampling |
| | import torch.nn as nn |
| | import torch |
| | import torch.nn.functional as F |
| | import numpy as np |
| |
|
# Short aliases for building blocks reused from the DDPM-style `layers` module.
conv1x1 = layers.ddpm_conv1x1  # 1x1 convolution factory
conv3x3 = layers.ddpm_conv3x3  # 3x3 convolution factory
NIN = layers.NIN  # "network-in-network" per-pixel linear layer
default_init = layers.default_init  # variance-scaling weight initializer
| |
|
| |
|
class GaussianFourierProjection(nn.Module):
  """Gaussian Fourier feature embedding for noise levels.

  A fixed random frequency vector ``W`` maps each scalar input to
  ``[sin(2*pi*x*W), cos(2*pi*x*W)]``, producing a ``2 * embedding_size``
  dimensional feature per example.
  """

  def __init__(self, embedding_size=256, scale=1.0):
    super().__init__()
    # Frequencies are sampled once and frozen. Storing them as a
    # non-trainable Parameter keeps them in the state dict and moves
    # them with the module across devices.
    self.W = nn.Parameter(torch.randn(embedding_size) * scale,
                          requires_grad=False)

  def forward(self, x):
    angles = x[:, None] * self.W[None, :] * 2 * np.pi
    return torch.cat((torch.sin(angles), torch.cos(angles)), dim=-1)
| |
|
| |
|
class Combine(nn.Module):
  """Merge a skip connection with the current feature map."""

  def __init__(self, dim1, dim2, method='cat'):
    super().__init__()
    # 1x1 conv projects the incoming features to `dim2` channels before merging.
    self.Conv_0 = conv1x1(dim1, dim2)
    self.method = method

  def forward(self, x, y):
    projected = self.Conv_0(x)
    if self.method == 'cat':
      return torch.cat([projected, y], dim=1)
    if self.method == 'sum':
      return projected + y
    raise ValueError(f'Method {self.method} not recognized.')
| |
|
| |
|
class AttnBlockpp(nn.Module):
  """Channel-wise self-attention block. Modified from DDPM.

  Computes dot-product attention between every pair of spatial locations,
  then adds the result back to the input (optionally rescaled by 1/sqrt(2)).

  Args:
    channels: number of input/output channels.
    skip_rescale: if True, divide the residual sum by sqrt(2).
    init_scale: init scale of the output projection (0 => zero-init).
  """

  def __init__(self, channels, skip_rescale=False, init_scale=0.):
    super().__init__()
    # max(..., 1) guards channels < 4, where `channels // 4` would yield
    # num_groups == 0 and the GroupNorm would be invalid.
    self.GroupNorm_0 = nn.GroupNorm(num_groups=max(min(channels // 4, 32), 1),
                                    num_channels=channels, eps=1e-6)
    self.NIN_0 = NIN(channels, channels)  # query projection
    self.NIN_1 = NIN(channels, channels)  # key projection
    self.NIN_2 = NIN(channels, channels)  # value projection
    self.NIN_3 = NIN(channels, channels, init_scale=init_scale)  # output projection
    self.skip_rescale = skip_rescale

  def forward(self, x):
    B, C, H, W = x.shape
    h = self.GroupNorm_0(x)
    q = self.NIN_0(h)
    k = self.NIN_1(h)
    v = self.NIN_2(h)

    # Attention logits between every spatial location pair, scaled by
    # 1/sqrt(C) as in standard dot-product attention.
    w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5))
    w = torch.reshape(w, (B, H, W, H * W))
    w = F.softmax(w, dim=-1)
    w = torch.reshape(w, (B, H, W, H, W))
    h = torch.einsum('bhwij,bcij->bchw', w, v)
    h = self.NIN_3(h)
    if not self.skip_rescale:
      return x + h
    else:
      # Keep residual variance roughly constant.
      return (x + h) / np.sqrt(2.)
| |
|
| |
|
class Upsample(nn.Module):
  """Spatial 2x upsampling, optionally combined with a convolution.

  Args:
    in_ch: number of input channels.
    out_ch: number of output channels (defaults to in_ch).
    with_conv: if True, apply a learned convolution after (or fused with)
      the upsampling.
    fir: if True, use the FIR resampling kernel from `up_or_down_sampling`;
      otherwise use nearest-neighbor interpolation.
    fir_kernel: FIR filter taps used when `fir` is True.
  """

  def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False,
               fir_kernel=(1, 3, 3, 1)):
    super().__init__()
    out_ch = out_ch if out_ch else in_ch
    if not fir:
      if with_conv:
        self.Conv_0 = conv3x3(in_ch, out_ch)
    else:
      if with_conv:
        # Fused upsample + conv using the FIR resampling kernel.
        self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch,
                                                   kernel=3, up=True,
                                                   resample_kernel=fir_kernel,
                                                   use_bias=True,
                                                   kernel_init=default_init())
    self.fir = fir
    self.with_conv = with_conv
    self.fir_kernel = fir_kernel
    self.out_ch = out_ch

  def forward(self, x):
    B, C, H, W = x.shape
    if not self.fir:
      # Bug fix: `mode` must be passed by keyword. The third positional
      # argument of F.interpolate is `scale_factor`, so the original
      # F.interpolate(x, (H * 2, W * 2), 'nearest') raised at runtime.
      h = F.interpolate(x, size=(H * 2, W * 2), mode='nearest')
      if self.with_conv:
        h = self.Conv_0(h)
    else:
      if not self.with_conv:
        h = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2)
      else:
        h = self.Conv2d_0(x)

    return h
| |
|
| |
|
class Downsample(nn.Module):
  """Spatial 2x downsampling, optionally combined with a convolution.

  Args:
    in_ch: number of input channels.
    out_ch: number of output channels (defaults to in_ch).
    with_conv: if True, downsample with a learned strided convolution.
    fir: if True, use the FIR resampling kernel from `up_or_down_sampling`;
      otherwise use strided conv / average pooling.
    fir_kernel: FIR filter taps used when `fir` is True.
  """

  def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False,
               fir_kernel=(1, 3, 3, 1)):
    super().__init__()
    out_ch = out_ch if out_ch else in_ch
    if fir:
      if with_conv:
        # Fused downsample + conv using the FIR resampling kernel.
        self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch,
                                                   kernel=3, down=True,
                                                   resample_kernel=fir_kernel,
                                                   use_bias=True,
                                                   kernel_init=default_init())
    elif with_conv:
      self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0)
    self.fir = fir
    self.fir_kernel = fir_kernel
    self.with_conv = with_conv
    self.out_ch = out_ch

  def forward(self, x):
    if self.fir:
      if self.with_conv:
        return self.Conv2d_0(x)
      return up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2)
    if self.with_conv:
      # Pad only on the right/bottom so the strided conv halves the size.
      return self.Conv_0(F.pad(x, (0, 1, 0, 1)))
    return F.avg_pool2d(x, 2, stride=2)
| |
|
| |
|
class ResnetBlockDDPMpp(nn.Module):
  """ResBlock adapted from DDPM.

  GroupNorm -> act -> conv twice, with an optional time-embedding shift
  between the convs, and a learned shortcut (conv3x3 or NIN) when the
  channel count changes.

  Args:
    act: activation callable.
    in_ch: input channels.
    out_ch: output channels (defaults to in_ch).
    temb_dim: dimensionality of the time embedding; None disables it.
    conv_shortcut: use a 3x3 conv (instead of a 1x1 NIN) for the shortcut.
    dropout: dropout probability between the two convs.
    skip_rescale: if True, divide the residual sum by sqrt(2).
    init_scale: init scale of the last conv (0 => zero-init).
  """

  def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False,
               dropout=0.1, skip_rescale=False, init_scale=0.):
    super().__init__()
    out_ch = out_ch if out_ch else in_ch
    # max(..., 1) guards channels < 4, where `in_ch // 4` would yield
    # num_groups == 0 and the GroupNorm would be invalid.
    self.GroupNorm_0 = nn.GroupNorm(num_groups=max(min(in_ch // 4, 32), 1),
                                    num_channels=in_ch, eps=1e-6)
    self.Conv_0 = conv3x3(in_ch, out_ch)
    if temb_dim is not None:
      self.Dense_0 = nn.Linear(temb_dim, out_ch)
      self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
      nn.init.zeros_(self.Dense_0.bias)
    self.GroupNorm_1 = nn.GroupNorm(num_groups=max(min(out_ch // 4, 32), 1),
                                    num_channels=out_ch, eps=1e-6)
    self.Dropout_0 = nn.Dropout(dropout)
    self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
    if in_ch != out_ch:
      if conv_shortcut:
        self.Conv_2 = conv3x3(in_ch, out_ch)
      else:
        self.NIN_0 = NIN(in_ch, out_ch)

    self.skip_rescale = skip_rescale
    self.act = act
    self.out_ch = out_ch
    self.conv_shortcut = conv_shortcut

  def forward(self, x, temb=None):
    """Apply the block; `temb` is an optional per-example time embedding."""
    h = self.act(self.GroupNorm_0(x))
    h = self.Conv_0(h)
    if temb is not None:
      # Broadcast the per-example embedding over spatial dimensions.
      h += self.Dense_0(self.act(temb))[:, :, None, None]
    h = self.act(self.GroupNorm_1(h))
    h = self.Dropout_0(h)
    h = self.Conv_1(h)
    if x.shape[1] != self.out_ch:
      if self.conv_shortcut:
        x = self.Conv_2(x)
      else:
        x = self.NIN_0(x)
    if not self.skip_rescale:
      return x + h
    else:
      # Keep residual variance roughly constant.
      return (x + h) / np.sqrt(2.)
| |
|
| |
|
class ResnetBlockBigGANpp(nn.Module):
  """BigGAN-style ResBlock with optional in-block up/downsampling.

  Args:
    act: activation callable.
    in_ch: input channels.
    out_ch: output channels (defaults to in_ch).
    temb_dim: time-embedding dimensionality; None disables the shift.
    up: resample the feature map up by 2x inside the block.
    down: resample the feature map down by 2x inside the block.
    dropout: dropout probability between the two convs.
    fir: if True, use the FIR resampling kernel instead of naive
      nearest-neighbor / average-pool resampling.
    fir_kernel: FIR filter taps used when `fir` is True.
    skip_rescale: if True, divide the residual sum by sqrt(2).
    init_scale: init scale of the last conv (0 => zero-init).
  """

  def __init__(self, act, in_ch, out_ch=None, temb_dim=None, up=False, down=False,
               dropout=0.1, fir=False, fir_kernel=(1, 3, 3, 1),
               skip_rescale=True, init_scale=0.):
    super().__init__()

    out_ch = out_ch if out_ch else in_ch
    # max(..., 1) guards channels < 4, where `in_ch // 4` would yield
    # num_groups == 0 and the GroupNorm would be invalid.
    self.GroupNorm_0 = nn.GroupNorm(num_groups=max(min(in_ch // 4, 32), 1),
                                    num_channels=in_ch, eps=1e-6)
    self.up = up
    self.down = down
    self.fir = fir
    self.fir_kernel = fir_kernel

    self.Conv_0 = conv3x3(in_ch, out_ch)
    if temb_dim is not None:
      self.Dense_0 = nn.Linear(temb_dim, out_ch)
      self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
      nn.init.zeros_(self.Dense_0.bias)

    self.GroupNorm_1 = nn.GroupNorm(num_groups=max(min(out_ch // 4, 32), 1),
                                    num_channels=out_ch, eps=1e-6)
    self.Dropout_0 = nn.Dropout(dropout)
    self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
    if in_ch != out_ch or up or down:
      # A 1x1 shortcut is needed whenever the residual path changes shape.
      self.Conv_2 = conv1x1(in_ch, out_ch)

    self.skip_rescale = skip_rescale
    self.act = act
    self.in_ch = in_ch
    self.out_ch = out_ch

  def forward(self, x, temb=None):
    """Apply the block; `temb` is an optional per-example time embedding."""
    h = self.act(self.GroupNorm_0(x))

    # Resample both the residual branch and the shortcut identically.
    if self.up:
      if self.fir:
        h = up_or_down_sampling.upsample_2d(h, self.fir_kernel, factor=2)
        x = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2)
      else:
        h = up_or_down_sampling.naive_upsample_2d(h, factor=2)
        x = up_or_down_sampling.naive_upsample_2d(x, factor=2)
    elif self.down:
      if self.fir:
        h = up_or_down_sampling.downsample_2d(h, self.fir_kernel, factor=2)
        x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2)
      else:
        h = up_or_down_sampling.naive_downsample_2d(h, factor=2)
        x = up_or_down_sampling.naive_downsample_2d(x, factor=2)

    h = self.Conv_0(h)
    if temb is not None:
      # Broadcast the per-example embedding over spatial dimensions.
      h += self.Dense_0(self.act(temb))[:, :, None, None]
    h = self.act(self.GroupNorm_1(h))
    h = self.Dropout_0(h)
    h = self.Conv_1(h)

    if self.in_ch != self.out_ch or self.up or self.down:
      x = self.Conv_2(x)

    if not self.skip_rescale:
      return x + h
    else:
      # Keep residual variance roughly constant.
      return (x + h) / np.sqrt(2.)
| |
|