from functools import partial
import math
from typing import List, Optional, Tuple

import torch
from torch import Tensor
from torch import nn
from torch.nn import functional as F

# Settings for GroupNorm and Attention
GN_GROUP_SIZE = 32
GN_EPS = 1e-5
ATTN_HEAD_DIM = 8

# Convs
Conv1x1 = partial(nn.Conv2d, kernel_size=1, stride=1, padding=0)
Conv3x3 = partial(nn.Conv2d, kernel_size=3, stride=1, padding=1)


# GroupNorm and conditional GroupNorm


class GroupNorm(nn.Module):
    def __init__(self, in_channels: int) -> None:
        super().__init__()
        num_groups = max(1, in_channels // GN_GROUP_SIZE)
        self.norm = nn.GroupNorm(num_groups, in_channels, eps=GN_EPS)

    def forward(self, x: Tensor) -> Tensor:
        return self.norm(x)


class AdaGroupNorm(nn.Module):
    def __init__(self, in_channels: int, cond_channels: int) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.num_groups = max(1, in_channels // GN_GROUP_SIZE)
        self.linear = nn.Linear(cond_channels, in_channels * 2)

    def forward(self, x: Tensor, cond: Tensor) -> Tensor:
        assert x.size(1) == self.in_channels
        # Normalize without affine parameters, then apply a condition-dependent scale and shift.
        x = F.group_norm(x, self.num_groups, eps=GN_EPS)
        scale, shift = self.linear(cond)[:, :, None, None].chunk(2, dim=1)
        return x * (1 + scale) + shift


# Self Attention


class SelfAttention2d(nn.Module):
    def __init__(self, in_channels: int, head_dim: int = ATTN_HEAD_DIM) -> None:
        super().__init__()
        self.n_head = max(1, in_channels // head_dim)
        assert in_channels % self.n_head == 0
        self.norm = GroupNorm(in_channels)
        self.qkv_proj = Conv1x1(in_channels, in_channels * 3)
        self.out_proj = Conv1x1(in_channels, in_channels)
        # Zero-init the output projection so the attention branch contributes nothing at the start of training.
        nn.init.zeros_(self.out_proj.weight)
        nn.init.zeros_(self.out_proj.bias)

    def forward(self, x: Tensor) -> Tensor:
        n, c, h, w = x.shape
        x = self.norm(x)
        qkv = self.qkv_proj(x)
        qkv = qkv.view(n, self.n_head * 3, c // self.n_head, h * w).transpose(2, 3).contiguous()
        q, k, v = [x for x in qkv.chunk(3, dim=1)]
        att = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))
        att = F.softmax(att, dim=-1)
        y = att @ v
        y = y.transpose(2, 3).reshape(n, c, h, w)
        return x + self.out_proj(y)


# Embedding of the noise level


class FourierFeatures(nn.Module):
    def __init__(self, cond_channels: int) -> None:
        super().__init__()
        assert cond_channels % 2 == 0
        self.register_buffer("weight", torch.randn(1, cond_channels // 2))

    def forward(self, input: Tensor) -> Tensor:
        assert input.ndim == 1
        # Random Fourier features: project the scalar noise levels, then concatenate cos and sin.
        f = 2 * math.pi * input.unsqueeze(1) @ self.weight
        return torch.cat([f.cos(), f.sin()], dim=-1)


# [Down|Up]sampling


class Downsample(nn.Module):
    def __init__(self, in_channels: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=1)
        nn.init.orthogonal_(self.conv.weight)

    def forward(self, x: Tensor) -> Tensor:
        return self.conv(x)


class Upsample(nn.Module):
    def __init__(self, in_channels: int) -> None:
        super().__init__()
        self.conv = Conv3x3(in_channels, in_channels)

    def forward(self, x: Tensor) -> Tensor:
        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(x)


# Small Residual block


class SmallResBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.f = nn.Sequential(GroupNorm(in_channels), nn.SiLU(inplace=True), Conv3x3(in_channels, out_channels))
        self.skip_projection = nn.Identity() if in_channels == out_channels else Conv1x1(in_channels, out_channels)

    def forward(self, x: Tensor) -> Tensor:
        return self.skip_projection(x) + self.f(x)


# Residual block (conditioning with AdaGroupNorm, no [down|up]sampling, optional self-attention)


class ResBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, cond_channels: int, attn: bool) -> None:
        super().__init__()
        should_proj = in_channels != out_channels
        self.proj = Conv1x1(in_channels, out_channels) if should_proj else nn.Identity()
        self.norm1 = AdaGroupNorm(in_channels, cond_channels)
        self.conv1 = Conv3x3(in_channels, out_channels)
        self.norm2 = AdaGroupNorm(out_channels, cond_channels)
        self.conv2 = Conv3x3(out_channels, out_channels)
        self.attn = SelfAttention2d(out_channels) if attn else nn.Identity()
        # Zero-init the last conv so the residual branch starts at zero.
        nn.init.zeros_(self.conv2.weight)

    def forward(self, x: Tensor, cond: Tensor) -> Tensor:
        r = self.proj(x)
        x = self.conv1(F.silu(self.norm1(x, cond)))
        x = self.conv2(F.silu(self.norm2(x, cond)))
        x = x + r
        x = self.attn(x)
        return x


# Sequence of residual blocks (in_channels -> mid_channels -> ... -> mid_channels -> out_channels)


class ResBlocks(nn.Module):
    def __init__(
        self,
        list_in_channels: List[int],
        list_out_channels: List[int],
        cond_channels: int,
        attn: bool,
    ) -> None:
        super().__init__()
        assert len(list_in_channels) == len(list_out_channels)
        self.in_channels = list_in_channels[0]
        self.resblocks = nn.ModuleList(
            [
                ResBlock(in_ch, out_ch, cond_channels, attn)
                for (in_ch, out_ch) in zip(list_in_channels, list_out_channels)
            ]
        )

    def forward(self, x: Tensor, cond: Tensor, to_cat: Optional[List[Tensor]] = None) -> Tuple[Tensor, List[Tensor]]:
        outputs = []
        for i, resblock in enumerate(self.resblocks):
            x = x if to_cat is None else torch.cat((x, to_cat[i]), dim=1)
            x = resblock(x, cond)
            outputs.append(x)
        return x, outputs


# UNet


class UNet(nn.Module):
    def __init__(self, cond_channels: int, depths: List[int], channels: List[int], attn_depths: List[int]) -> None:
        super().__init__()
        assert len(depths) == len(channels) == len(attn_depths)
        self._num_down = len(channels) - 1

        d_blocks, u_blocks = [], []
        for i, n in enumerate(depths):
            c1 = channels[max(0, i - 1)]
            c2 = channels[i]
            d_blocks.append(
                ResBlocks(
                    list_in_channels=[c1] + [c2] * (n - 1),
                    list_out_channels=[c2] * n,
                    cond_channels=cond_channels,
                    attn=attn_depths[i],
                )
            )
            u_blocks.append(
                ResBlocks(
                    list_in_channels=[2 * c2] * n + [c1 + c2],
                    list_out_channels=[c2] * n + [c1],
                    cond_channels=cond_channels,
                    attn=attn_depths[i],
                )
            )
        self.d_blocks = nn.ModuleList(d_blocks)
        self.u_blocks = nn.ModuleList(reversed(u_blocks))

        self.mid_blocks = ResBlocks(
            list_in_channels=[channels[-1]] * 2,
            list_out_channels=[channels[-1]] * 2,
            cond_channels=cond_channels,
            attn=True,
        )

        downsamples = [nn.Identity()] + [Downsample(c) for c in channels[:-1]]
        upsamples = [nn.Identity()] + [Upsample(c) for c in reversed(channels[:-1])]
        self.downsamples = nn.ModuleList(downsamples)
        self.upsamples = nn.ModuleList(upsamples)

    def forward(self, x: Tensor, cond: Tensor) -> Tuple[Tensor, List[Tuple[Tensor, ...]], List[Tuple[Tensor, ...]]]:
        *_, h, w = x.size()
        n = self._num_down
        # Pad H and W up to multiples of 2 ** num_down, then crop back after the decoder.
        padding_h = math.ceil(h / 2 ** n) * 2 ** n - h
        padding_w = math.ceil(w / 2 ** n) * 2 ** n - w
        x = F.pad(x, (0, padding_w, 0, padding_h))

        d_outputs = []
        for block, down in zip(self.d_blocks, self.downsamples):
            x_down = down(x)
            x, block_outputs = block(x_down, cond)
            d_outputs.append((x_down, *block_outputs))

        x, _ = self.mid_blocks(x, cond)

        u_outputs = []
        for block, up, skip in zip(self.u_blocks, self.upsamples, reversed(d_outputs)):
            x_up = up(x)
            x, block_outputs = block(x_up, cond, skip[::-1])
            u_outputs.append((x_up, *block_outputs))

        x = x[..., :h, :w]
        return x, d_outputs, u_outputs
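

# Illustrative smoke test (not part of the original module): a minimal sketch of how the
# blocks above fit together. The depth/channel/conditioning values below are assumptions
# chosen only to keep the example small, not configurations taken from the source.
if __name__ == "__main__":
    cond_channels = 256
    unet = UNet(cond_channels=cond_channels, depths=[2, 2, 2], channels=[64, 64, 128], attn_depths=[0, 0, 1])
    noise_emb = FourierFeatures(cond_channels)

    x = torch.randn(2, 64, 32, 32)  # (batch, channels[0], H, W)
    sigma = torch.rand(2)           # one scalar noise level per batch element
    cond = noise_emb(sigma)         # (batch, cond_channels) conditioning vector

    out, d_outputs, u_outputs = unet(x, cond)
    # The decoder output is cropped back to the input resolution, with channels[0] channels.
    assert out.shape == x.shape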