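"""Convolutional building blocks: a swish activation, strided-conv downsampling,
nearest-neighbour upsampling, a GroupNorm ResNet block with optional per-channel
scale/shift conditioning, and a single-head spatial self-attention block."""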
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor


def swish(x: Tensor) -> Tensor:
    return x * torch.sigmoid(x)


class Downsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        # no padding here; an asymmetric pad is applied manually in forward
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: Tensor) -> Tensor:
        # pad one pixel on the right and bottom so the stride-2 conv exactly halves H and W
        x = nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(x)


class Upsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        # nearest-neighbour 2x upsampling followed by a 3x3 conv to smooth the result
        x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(x)

class ResnetBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: Optional[int] = None, cond_dim: Optional[int] = None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels if out_channels is not None else in_channels
        self.cond_dim = cond_dim

        self.norm1 = nn.GroupNorm(32, in_channels, eps=1e-6, affine=True)
        self.conv1 = nn.Conv2d(in_channels, self.out_channels, 3, stride=1, padding=1)

        if self.cond_dim is not None:
            # project the conditioning embedding to per-channel (scale, shift) for norm2
            self.emb_proj = nn.Linear(cond_dim, self.out_channels * 2)
            # initialise to an identity modulation: scale = 1, shift = 0
            nn.init.zeros_(self.emb_proj.weight)
            nn.init.zeros_(self.emb_proj.bias)
            with torch.no_grad():
                self.emb_proj.bias[: self.out_channels] = 1.0

        self.norm2 = nn.GroupNorm(32, self.out_channels, eps=1e-6, affine=True)
        self.conv2 = nn.Conv2d(self.out_channels, self.out_channels, 3, stride=1, padding=1)
        self.nin_shortcut = (
            nn.Conv2d(in_channels, self.out_channels, 1, stride=1, padding=0)
            if in_channels != self.out_channels
            else nn.Identity()
        )

    def forward(self, x: Tensor, emb: Optional[Tensor] = None) -> Tensor:
        h = self.norm1(x)
        h = swish(h)
        h = self.conv1(h)

        if self.cond_dim is not None and emb is not None:
            # (B, 2*C) -> (B, 2*C, 1, 1), then split into per-channel scale and shift
            style = self.emb_proj(emb).unsqueeze(-1).unsqueeze(-1)
            scale, shift = style.chunk(2, dim=1)
            h = self.norm2(h) * scale + shift
        else:
            h = self.norm2(h)

        h = swish(h)
        h = self.conv2(h)
        return h + self.nin_shortcut(x)

class AttnBlock(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.norm = nn.GroupNorm(32, in_channels, eps=1e-6, affine=True)
        self.q = nn.Conv2d(in_channels, in_channels, 1)
        self.k = nn.Conv2d(in_channels, in_channels, 1)
        self.v = nn.Conv2d(in_channels, in_channels, 1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, 1)

    def forward(self, x: Tensor) -> Tensor:
        h_ = self.norm(x)
        q, k, v = self.q(h_), self.k(h_), self.v(h_)
        b, c, h, w = q.shape
        # (B, C, H, W) -> (B, 1, H*W, C): single-head attention over all spatial positions
        q = q.flatten(2).transpose(1, 2).unsqueeze(1)
        k = k.flatten(2).transpose(1, 2).unsqueeze(1)
        v = v.flatten(2).transpose(1, 2).unsqueeze(1)
        h_ = nn.functional.scaled_dot_product_attention(q, k, v)
        # back to (B, C, H, W); reshape handles the non-contiguous layout after transpose
        h_ = h_.squeeze(1).transpose(1, 2).reshape(b, c, h, w)
        return x + self.proj_out(h_)
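

# Illustrative shape check for the blocks above. The channel count (64), spatial size
# (32x32), and conditioning dimension (256) are arbitrary example values, not taken
# from any particular model configuration.
if __name__ == "__main__":
    x = torch.randn(1, 64, 32, 32)

    down = Downsample(64)
    up = Upsample(64)
    res = ResnetBlock(64, out_channels=128, cond_dim=256)
    attn = AttnBlock(64)

    print(down(x).shape)                      # torch.Size([1, 64, 16, 16])
    print(up(x).shape)                        # torch.Size([1, 64, 64, 64])
    print(res(x, torch.randn(1, 256)).shape)  # torch.Size([1, 128, 32, 32])
    print(attn(x).shape)                      # torch.Size([1, 64, 32, 32])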