File size: 4,174 Bytes
aecc64d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SpatialAttention(nn.Module):
  """Single-head self-attention across the spatial positions of a feature map.

  The input is group-normalised, projected to queries, keys and values with
  1x1 convolutions, attended over all ``h * w`` positions, projected again and
  added back to the input (residual connection).

  Args:
    in_c: number of input (and output) channels. ``nn.GroupNorm`` is built
      with 32 groups, so ``in_c`` has to be divisible by 32.
  """

  def __init__(self, in_c):
    super().__init__()
    self.norm = nn.GroupNorm(num_groups=32, num_channels=in_c, eps=1e-6, affine=True)
    self.Q = nn.Conv2d(in_c, in_c, kernel_size=1, stride=1, padding=0)
    self.K = nn.Conv2d(in_c, in_c, kernel_size=1, stride=1, padding=0)
    self.V = nn.Conv2d(in_c, in_c, kernel_size=1, stride=1, padding=0)
    self.proj = nn.Conv2d(in_c, in_c, kernel_size=1, stride=1, padding=0)

  def forward(self, x):
    """Apply attention to ``x`` of shape (b, c, h, w); the output has the same shape."""
    b, c, h, w = x.shape
    normed = self.norm(x)

    # Flatten the spatial grid so every pixel becomes one token.
    q = self.Q(normed).reshape(b, c, h * w).permute(0, 2, 1)  # (b, hw, c)
    k = self.K(normed).reshape(b, c, h * w)                   # (b, c, hw)
    v = self.V(normed).reshape(b, c, h * w)                   # (b, c, hw)

    # scores[b, i, j] = <q_i, k_j> / sqrt(c); softmax normalises over keys j.
    scores = torch.bmm(q, k) * (1.0 / math.sqrt(c))
    attn = F.softmax(scores, dim=2)

    # The value product must contract over the key axis j (the axis the
    # softmax normalised), so the attention matrix is transposed first:
    # out[b, :, i] = sum_j v[b, :, j] * attn[b, i, j].
    out = torch.bmm(v, attn.permute(0, 2, 1))
    out = out.reshape(b, c, h, w)
    return self.proj(out) + x

class ResBlock(nn.Module):
  """Residual block made of two (GroupNorm -> x * sigmoid(x) -> 3x3 conv) stages.

  When ``in_c`` differs from ``out_c`` the input first goes through an extra
  3x3 convolution (``conv_reshape``) that changes the channel count; the
  residual branch is taken *after* that convolution.

  Args:
    in_c: channels of the incoming tensor.
    out_c: channels of the returned tensor. ``nn.GroupNorm`` is built with
      32 groups, so ``out_c`` has to be divisible by 32.
  """

  def __init__(self, in_c, out_c):
    super().__init__()

    def group_norm():
      return nn.GroupNorm(num_groups=32, num_channels=out_c, eps=1e-6, affine=True)

    def conv3x3(channels_in):
      return nn.Conv2d(channels_in, out_c, kernel_size=3, stride=1, padding=1)

    # Flag read by forward(); the projection layer only exists when needed.
    self.reshape = in_c != out_c
    if self.reshape:
      self.conv_reshape = conv3x3(in_c)
    self.norm1 = group_norm()
    self.conv1 = conv3x3(out_c)
    self.norm2 = group_norm()
    self.conv2 = conv3x3(out_c)

  def forward(self, x):
    """Return the block output; spatial size is unchanged, channels become ``out_c``."""
    skip = self.conv_reshape(x) if self.reshape else x
    h = skip
    for norm, conv in ((self.norm1, self.conv1), (self.norm2, self.conv2)):
      h = norm(h)
      h = conv(h * torch.sigmoid(h))
    return h + skip

class Model(nn.Module):
  """U-Net style network conditioned on a timestep through a sinusoidal embedding.

  The timestep embedding is broadcast over the spatial grid and concatenated
  to the input channels before a 1x1 input convolution. The encoder applies
  ``len(filters) - 1`` stages of residual blocks followed by a stride-2
  convolution, the bottleneck is ResBlock -> SpatialAttention -> ResBlock, and
  the decoder mirrors the encoder with bilinear upsampling and skip
  connections concatenated along the channel axis.

  Args:
    T: scale of the sinusoidal embedding; the frequencies run from 1 down
      towards 1 / T (presumably the number of diffusion steps - confirm
      against the training code).
    filters: channel width of each resolution level. Every entry has to be
      divisible by 32 because the residual blocks use 32 normalisation groups.
    depth: number of residual blocks per encoder/decoder stage.
    t_dim: width of the timestep embedding; must be even because it is built
      from equally sized cosine and sine halves.
    LDM: when true the network reads and writes 4 channels instead of 3
      (presumably latent-space inputs - confirm with the caller).

  Raises:
    ValueError: if ``t_dim`` is odd.
  """

  def __init__(self, T=1000, filters=(32, 64, 96, 128), depth=2, t_dim=512, LDM=False):
    super().__init__()
    if t_dim % 2 != 0:
      # An odd width would yield a (t_dim - 1)-wide embedding and only fail
      # later, inside forward(), with an opaque shape error.
      raise ValueError(f"t_dim must be even, got {t_dim}")
    filters = list(filters)
    self.t_dim = t_dim
    self.T = T
    self.conv_in = nn.Conv2d(4 + self.t_dim if LDM else 3 + self.t_dim, filters[0], kernel_size=1)

    # Encoder: each stage widens the channels, then halves the resolution.
    self.down = nn.ModuleList([])
    for i in range(1, len(filters)):
      block = nn.Module()
      block.Blocks = nn.ModuleList([ResBlock(filters[i-1], filters[i])])
      for _ in range(1, depth):
        block.Blocks.append(ResBlock(filters[i], filters[i]))
      block.DownSample = nn.Conv2d(filters[i], filters[i], kernel_size=3, stride=2, padding=1)
      self.down.append(block)

    self.mid = nn.Sequential(ResBlock(filters[-1], filters[-1]),
                             SpatialAttention(filters[-1]),
                             ResBlock(filters[-1], filters[-1]))

    # Decoder: the first block of each stage sees twice the channels because
    # the upsampled features are concatenated with the matching skip tensor.
    self.up = nn.ModuleList([])
    filters = filters[::-1]
    for i in range(1, len(filters)):
      block = nn.Module()
      block.Blocks = nn.ModuleList([ResBlock(filters[i-1]*2, filters[i])])
      for _ in range(1, depth):
        block.Blocks.append(ResBlock(filters[i], filters[i]))
      block.UpSample = nn.Upsample(scale_factor=2, mode="bilinear")
      self.up.append(block)
    self.conv_out = nn.Conv2d(filters[-1], 4 if LDM else 3, kernel_size=3, padding=1)

  def get_sinusoidal_emb(self, t):
    """Return sinusoidal embeddings for the timesteps in ``t``.

    Args:
      t: tensor of timesteps. Any shape is accepted and flattened, so
        ``(B,)``, ``(B, 1)`` and 0-dim tensors all work.

    Returns:
      Float tensor of shape ``(t.numel(), t_dim)``: cosine terms in the first
      half of the last axis, sine terms in the second half.
    """
    half = self.t_dim // 2
    steps = torch.arange(start=0, end=half, dtype=torch.float32)
    freqs = torch.exp(-math.log(self.T) * steps / half).to(device=t.device)
    # Flatten first so a (B, 1) input does not grow an extra axis.
    args = t.reshape(-1, 1).float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

  def forward(self, x, t):
    """Run the network.

    Args:
      x: input of shape (B, C, H, W) with C = 4 if ``LDM`` else 3.
      t: timesteps, one per batch element (``(B,)`` or ``(B, 1)``) or a single
        value that is broadcast over the batch.

    Returns:
      Tensor of shape (B, C, H, W).
    """
    B, _, H, W = x.shape

    # Broadcast the embedding over the spatial grid and feed it as extra
    # input channels; match x's dtype so the concatenation does not promote
    # half-precision inputs to float32.
    t_emb = self.get_sinusoidal_emb(t).to(dtype=x.dtype)
    t_emb = t_emb[:, :, None, None].expand(B, self.t_dim, H, W)
    x = self.conv_in(torch.cat((x, t_emb), 1))

    skips = []
    for block in self.down:
      for resblock in block.Blocks:
        x = resblock(x)
      # No copy needed: nothing downstream modifies this tensor in place.
      skips.append(x)
      x = block.DownSample(x)

    x = self.mid(x)

    for block in self.up:
      skip = skips.pop()
      x = block.UpSample(x)
      if x.shape[-2:] != skip.shape[-2:]:
        # The stride-2 convolution rounds odd sizes up, so doubling can
        # overshoot by one pixel; resize to the skip tensor's exact size.
        x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
      x = torch.cat((x, skip), 1)
      for resblock in block.Blocks:
        x = resblock(x)

    return self.conv_out(x)