import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SpatialAttention(nn.Module):
    """Single-head self-attention over the h*w spatial positions, each treated as a c-dim token."""

    def __init__(self, in_c):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_c, eps=1e-6, affine=True)
        self.Q = nn.Conv2d(in_c, in_c, kernel_size=1, stride=1, padding=0)
        self.K = nn.Conv2d(in_c, in_c, kernel_size=1, stride=1, padding=0)
        self.V = nn.Conv2d(in_c, in_c, kernel_size=1, stride=1, padding=0)
        self.proj = nn.Conv2d(in_c, in_c, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        b, c, h, w = x.shape
        R = self.norm(x)
        q, k, v = self.Q(R), self.K(R), self.V(R)
        # Flatten the spatial grid: (b, c, h, w) -> (b, c, h*w)
        q, k, v = q.reshape(b, c, h * w), k.reshape(b, c, h * w), v.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)                       # (b, h*w, c)
        R = torch.bmm(q, k) * (1.0 / math.sqrt(c))   # (b, h*w, h*w), R[b, i, j] = q_i . k_j / sqrt(c)
        R = F.softmax(R, dim=2)                      # normalize over the key index j
        # Attend to the values; transpose the weights so the summation runs over the key index.
        R = torch.bmm(v, R.permute(0, 2, 1))         # (b, c, h*w)
        R = R.reshape(b, c, h, w)
        return self.proj(R) + x
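
# A minimal usage sketch, not part of the model itself (the helper name below is illustrative
# only): SpatialAttention is shape-preserving, so a quick assertion is enough to sanity-check
# the reshape/permute bookkeeping above.
def _spatial_attention_shape_check():
    attn = SpatialAttention(in_c=64)      # 64 channels divide evenly into the 32 GroupNorm groups
    x = torch.randn(2, 64, 16, 16)
    y = attn(x)
    assert y.shape == x.shape, "attention must preserve (B, C, H, W)"
    return y.shape
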
class ResBlock(nn.Module):
    """Two 3x3 convs with GroupNorm + swish and a residual connection.
    If the channel count changes, the input is first projected with a 3x3 conv."""

    def __init__(self, in_c, out_c):
        super().__init__()
        self.reshape = False
        if in_c != out_c:
            self.reshape = True
            self.conv_reshape = nn.Conv2d(in_c, out_c, kernel_size=3, stride=1, padding=1)
        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=out_c, eps=1e-6, affine=True)
        self.conv1 = nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_c, eps=1e-6, affine=True)
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        if self.reshape:
            x = self.conv_reshape(x)
        res = x
        x = self.norm1(x)
        x = x * torch.sigmoid(x)    # swish / SiLU
        x = self.conv1(x)
        x = self.norm2(x)
        x = x * torch.sigmoid(x)
        x = self.conv2(x)
        x = x + res
        return x
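
# Another small sketch (the helper name is hypothetical): when in_c != out_c the input is
# projected first, so the output always has out_c channels while the spatial size is unchanged.
def _resblock_channel_check():
    block = ResBlock(in_c=32, out_c=64)
    x = torch.randn(2, 32, 16, 16)
    y = block(x)
    assert y.shape == (2, 64, 16, 16)
    return y.shape
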
class Model(nn.Module):
    """U-Net with skip connections. The sinusoidal timestep embedding is broadcast over the
    spatial grid and concatenated to the input as extra channels before the first conv."""

    def __init__(self, T=1000, filters=[32, 64, 96, 128], depth=2, t_dim=512, LDM=False):
        super().__init__()
        self.t_dim = t_dim
        self.T = T
        # 4 latent channels in LDM mode, 3 RGB channels otherwise, plus t_dim embedding channels.
        self.conv_in = nn.Conv2d((4 if LDM else 3) + self.t_dim, filters[0], kernel_size=1)
        # Encoder: at each level, `depth` ResBlocks followed by a strided-conv downsample.
        self.down = nn.ModuleList([])
        for i in range(1, len(filters)):
            block = nn.Module()
            block.Blocks = nn.ModuleList([ResBlock(filters[i-1], filters[i])])
            for _ in range(1, depth):
                block.Blocks.append(ResBlock(filters[i], filters[i]))
            block.DownSample = nn.Conv2d(filters[i], filters[i], kernel_size=3, stride=2, padding=1)
            self.down.append(block)
        # Bottleneck: ResBlock -> spatial self-attention -> ResBlock.
        self.mid = nn.Sequential(ResBlock(filters[-1], filters[-1]),
                                 SpatialAttention(filters[-1]),
                                 ResBlock(filters[-1], filters[-1]))
        # Decoder: mirror of the encoder; the first ResBlock of each level takes the upsampled
        # features concatenated with the skip connection, hence the *2 on the input channels.
        self.up = nn.ModuleList([])
        filters = filters[::-1]
        for i in range(1, len(filters)):
            block = nn.Module()
            block.Blocks = nn.ModuleList([ResBlock(filters[i-1]*2, filters[i])])
            for _ in range(1, depth):
                block.Blocks.append(ResBlock(filters[i], filters[i]))
            block.UpSample = nn.Upsample(scale_factor=2, mode="bilinear")
            self.up.append(block)
        self.conv_out = nn.Conv2d(filters[-1], 4 if LDM else 3, kernel_size=3, padding=1)

    def get_sinusoidal_emb(self, t):
        """Receives a (B,)-shaped tensor of scalar timesteps, returns (B, t_dim) embeddings."""
        freqs = torch.exp(-math.log(self.T) * torch.arange(start=0, end=self.t_dim // 2, dtype=torch.float32) / (self.t_dim // 2)).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    def forward(self, x, t):
        # Broadcast the timestep embedding over the spatial grid and append it as channels.
        t_emb = self.get_sinusoidal_emb(t)
        B, C, H, W = x.shape
        t_emb = t_emb.unsqueeze(-1).unsqueeze(-1).expand(B, self.t_dim, H, W)
        x = torch.cat((x, t_emb), 1)
        x = self.conv_in(x)
        cache = []
        for block in self.down:
            for resblock in block.Blocks:
                x = resblock(x)
            # The decoder consumes one skip connection per resolution level, so cache the
            # output of the last ResBlock of each level before downsampling.
            cache.append(x)
            x = block.DownSample(x)
        x = self.mid(x)
        for block in self.up:
            x = block.UpSample(x)
            x = torch.cat((x, cache.pop()), 1)
            for resblock in block.Blocks:
                x = resblock(x)
        return self.conv_out(x)
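
# End-to-end smoke test sketch (assumptions: default hyperparameters, pixel-space mode
# LDM=False, and a 32x32 input so the three stride-2 downsamples divide evenly).
# The output has the same shape as the input image batch.
if __name__ == "__main__":
    model = Model(T=1000, filters=[32, 64, 96, 128], depth=2, t_dim=512, LDM=False)
    x = torch.randn(2, 3, 32, 32)                     # batch of noisy images
    t = torch.randint(low=0, high=1000, size=(2,))    # one scalar timestep per sample
    out = model(x, t)
    print(out.shape)                                  # expected: torch.Size([2, 3, 32, 32])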