File size: 4,174 Bytes
aecc64d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SpatialAttention(nn.Module):
    """Single-head self-attention across the spatial positions of a feature map.

    The input is group-normalised, projected to queries, keys and values with
    1x1 convolutions, attended over all ``h * w`` positions, projected again
    and added back to the input (residual connection).

    Args:
        in_c: Number of input (and output) channels. ``nn.GroupNorm`` is
            built with 32 groups, so ``in_c`` must be divisible by 32.

    NOTE: earlier revisions multiplied the values by the *untransposed*
    attention matrix, i.e. they summed over the query index although the
    softmax normalises over the key index. Weights trained against that
    behaviour will produce different outputs with this corrected version.
    """

    def __init__(self, in_c):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_c, eps=1e-6, affine=True)
        self.Q = nn.Conv2d(in_c, in_c, kernel_size=1, stride=1, padding=0)
        self.K = nn.Conv2d(in_c, in_c, kernel_size=1, stride=1, padding=0)
        self.V = nn.Conv2d(in_c, in_c, kernel_size=1, stride=1, padding=0)
        self.proj = nn.Conv2d(in_c, in_c, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``(b, c, h, w)``.

        Returns:
            A tensor with the same shape as ``x``.
        """
        b, c, h, w = x.shape
        normed = self.norm(x)

        q = self.Q(normed).reshape(b, c, h * w).permute(0, 2, 1)  # (b, hw, c)
        k = self.K(normed).reshape(b, c, h * w)                   # (b, c, hw)
        v = self.V(normed).reshape(b, c, h * w)                   # (b, c, hw)

        # scores[:, i, j] = <q_i, k_j> / sqrt(c)
        scores = torch.bmm(q, k) * (1.0 / math.sqrt(c))
        # Each query row i becomes a distribution over key positions j.
        weights = F.softmax(scores, dim=2)

        # out[:, :, i] = sum_j v[:, :, j] * weights[:, i, j]; the transpose
        # makes bmm contract over the key index that the softmax normalised.
        out = torch.bmm(v, weights.permute(0, 2, 1))
        out = out.reshape(b, c, h, w)
        return self.proj(out) + x
class ResBlock(nn.Module):
    """Residual block: two (GroupNorm -> x * sigmoid(x) -> 3x3 conv) stages.

    When ``in_c`` differs from ``out_c`` a 3x3 convolution first maps the
    input to ``out_c`` channels; the residual branch starts from that mapped
    tensor, so the skip connection always has ``out_c`` channels.

    Args:
        in_c: Number of channels of the incoming tensor.
        out_c: Number of channels produced by the block. ``nn.GroupNorm`` is
            built with 32 groups, so ``out_c`` must be divisible by 32.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        # Flag read by forward(); the channel-matching conv exists only when set.
        self.reshape = in_c != out_c
        if self.reshape:
            self.conv_reshape = nn.Conv2d(in_c, out_c, kernel_size=3, stride=1, padding=1)
        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=out_c, eps=1e-6, affine=True)
        self.conv1 = nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_c, eps=1e-6, affine=True)
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1)

    @staticmethod
    def _swish(tensor):
        """Return ``tensor * sigmoid(tensor)``."""
        return tensor * torch.sigmoid(tensor)

    def forward(self, x):
        """Run the block on ``x`` and return a tensor with ``out_c`` channels."""
        skip = self.conv_reshape(x) if self.reshape else x
        hidden = self.conv1(self._swish(self.norm1(skip)))
        hidden = self.conv2(self._swish(self.norm2(hidden)))
        return hidden + skip
class Model(nn.Module):
    """U-Net style network conditioned on a timestep.

    The timestep is turned into a sinusoidal embedding, broadcast over the
    spatial grid and concatenated to the input channels. The network then
    runs a stack of down blocks (ResBlocks followed by a stride-2 conv), a
    middle stage (ResBlock, SpatialAttention, ResBlock) and a mirrored stack
    of up blocks (bilinear upsampling, concatenation with the matching skip
    tensor, ResBlocks).

    Args:
        T: Value whose logarithm scales the embedding frequencies.
        filters: Channel widths; ``filters[0]`` is the width after the input
            conv, every further entry adds one down/up level.
        depth: Number of ResBlocks per down/up level.
        t_dim: Number of timestep-embedding channels.
        LDM: If true the network reads and writes 4 channels instead of 3.
    """

    def __init__(self, T=1000, filters=(32, 64, 96, 128), depth=2, t_dim=512, LDM=False):
        super().__init__()
        # Accept any sequence; a tuple default avoids a shared mutable default.
        filters = list(filters)
        self.t_dim = t_dim
        self.T = T
        data_c = 4 if LDM else 3

        self.conv_in = nn.Conv2d(data_c + self.t_dim, filters[0], kernel_size=1)

        self.down = nn.ModuleList([])
        for i in range(1, len(filters)):
            block = nn.Module()
            block.Blocks = nn.ModuleList([ResBlock(filters[i - 1], filters[i])])
            for _ in range(1, depth):
                block.Blocks.append(ResBlock(filters[i], filters[i]))
            block.DownSample = nn.Conv2d(filters[i], filters[i], kernel_size=3, stride=2, padding=1)
            self.down.append(block)

        self.mid = nn.Sequential(ResBlock(filters[-1], filters[-1]),
                                 SpatialAttention(filters[-1]),
                                 ResBlock(filters[-1], filters[-1]))

        self.up = nn.ModuleList([])
        filters = filters[::-1]
        for i in range(1, len(filters)):
            block = nn.Module()
            # The first ResBlock sees the upsampled map concatenated with the
            # skip tensor, hence twice the channels of the previous level.
            block.Blocks = nn.ModuleList([ResBlock(filters[i - 1] * 2, filters[i])])
            for _ in range(1, depth):
                block.Blocks.append(ResBlock(filters[i], filters[i]))
            block.UpSample = nn.Upsample(scale_factor=2, mode="bilinear")
            self.up.append(block)

        self.conv_out = nn.Conv2d(filters[-1], data_c, kernel_size=3, padding=1)

    def get_sinusoidal_emb(self, t):
        """Return sinusoidal embeddings of shape ``(B, t_dim)`` for timesteps ``t``.

        ``t`` is a tensor of scalar timesteps; it is flattened first, so
        shapes ``(B,)``, ``(B, 1)`` and a 0-dim scalar are all accepted. The
        first ``t_dim // 2`` columns are cosines and the next ``t_dim // 2``
        are sines; for an odd ``t_dim`` a trailing zero column pads the
        result to ``t_dim``.
        """
        half = self.t_dim // 2
        freqs = torch.exp(
            -math.log(self.T)
            * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
            / half
        )
        # (B, 1) * (1, half) -> (B, half)
        args = t.reshape(-1, 1).float() * freqs[None]
        emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.t_dim % 2:
            emb = F.pad(emb, (0, 1))
        return emb

    def forward(self, x, t):
        """Run the network on ``x`` of shape ``(B, C, H, W)`` at timesteps ``t``.

        Returns:
            A tensor with 4 (``LDM``) or 3 channels and the spatial size of ``x``.
        """
        B, _, H, W = x.shape
        t_emb = self.get_sinusoidal_emb(t)
        # Broadcast the per-sample embedding over every spatial position.
        t_emb = t_emb[:, :, None, None].expand(B, self.t_dim, H, W)
        x = torch.cat((x, t_emb), 1)
        x = self.conv_in(x)

        cache = []
        for block in self.down:
            for resblock in block.Blocks:
                x = resblock(x)
            # No clone needed: nothing downstream modifies x in place.
            cache.append(x)
            x = block.DownSample(x)

        x = self.mid(x)

        for block in self.up:
            skip = cache.pop()
            x = block.UpSample(x)
            if x.shape[-2:] != skip.shape[-2:]:
                # Odd sizes are rounded up by the stride-2 conv, so doubling
                # overshoots by one pixel; resize to the skip tensor's size.
                x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
            x = torch.cat((x, skip), 1)
            for resblock in block.Blocks:
                x = resblock(x)

        return self.conv_out(x)