YashNagraj75 committed on
Commit
3b3d382
·
1 Parent(s): 763c0f9

Almost finished with Down Block

Browse files
Files changed (1) hide show
  1. model_blocks/blocks.py +185 -0
model_blocks/blocks.py CHANGED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
def get_time_embedding(time_steps, temb_dim):
    r"""
    Convert a batch of diffusion time steps into sinusoidal embeddings.

    Uses the transformer-style sinusoidal formula: half the channels are
    sines, the other half cosines, over geometrically spaced frequencies
    ``factor[i] = 10000^(i / (temb_dim // 2))``.

    :param time_steps: 1D tensor of length B (one time step per sample).
    :param temb_dim: Dimension D of the embedding; must be even.
    :return: (B, D) tensor ``[sin(t/f_0)..sin(t/f_{D/2-1}), cos(t/f_0)..cos(t/f_{D/2-1})]``.
    :raises AssertionError: if ``temb_dim`` is odd.
    """
    assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"

    half_dim = temb_dim // 2
    # factor[i] = 10000^(i / half_dim), built on the same device as the input.
    factor = 10000 ** (
        torch.arange(start=0, end=half_dim, dtype=torch.float32, device=time_steps.device)
        / half_dim
    )

    # (B, 1) / (half_dim,) broadcasts to (B, half_dim); this replaces the
    # original .repeat(1, half_dim), which materialized an identical but
    # needless copy before the division.
    t_emb = time_steps[:, None] / factor
    return torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
28
+
29
+
30
class DownBlock(nn.Module):
    r"""
    Encoder (down) block of a diffusion U-Net.

    Each of the ``num_layers`` layers applies:
      1) a ResNet unit: two [GroupNorm -> SiLU -> Conv3x3] halves with a
         1x1-conv skip connection, and a [SiLU -> Linear] time embedding
         added between the halves;
      2) optional self-attention: [GroupNorm -> MultiheadAttention];
      3) optional cross-attention on a conditioning context:
         [GroupNorm -> MultiheadAttention] plus a linear context projection.
    Finally an optional strided conv halves the spatial dimension.
    """

    def __init__(
        self,
        num_heads,
        num_layers,
        cross_attn,
        input_dim,
        output_dim,
        t_emb_dim,
        cond_dim,
        norm_channels,
        self_attn,
        down_sample,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.cross_attn = cross_attn
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.cond_dim = cond_dim
        self.norm_channels = norm_channels
        self.t_emb_dim = t_emb_dim
        self.attn = self_attn
        self.down_sample = down_sample

        # Channels entering layer i: the block input for the first layer,
        # the block output for every subsequent one.
        def _in_ch(i):
            return self.input_dim if i == 0 else self.output_dim

        # 1x1 convs that project the residual skip to output_dim channels.
        self.resnet_in = nn.ModuleList(
            [
                nn.Conv2d(_in_ch(i), self.output_dim, kernel_size=1)
                for i in range(self.num_layers)
            ]
        )

        # First half of each ResNet unit: Norm -> SiLU -> Conv3x3.
        first_halves = []
        for i in range(self.num_layers):
            first_halves.append(
                nn.Sequential(
                    nn.GroupNorm(self.norm_channels, _in_ch(i)),
                    nn.SiLU(),
                    nn.Conv2d(
                        _in_ch(i),
                        self.output_dim,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                    ),
                )
            )
        self.resnet_one = nn.ModuleList(first_halves)

        # Per-layer time-embedding projection, only when a time dim is given.
        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList(
                [
                    nn.Sequential(nn.SiLU(), nn.Linear(self.t_emb_dim, self.output_dim))
                    for _ in range(self.num_layers)
                ]
            )

        # Second half of each ResNet unit (channels already at output_dim).
        second_halves = []
        for _ in range(self.num_layers):
            second_halves.append(
                nn.Sequential(
                    nn.GroupNorm(self.norm_channels, self.output_dim),
                    nn.SiLU(),
                    nn.Conv2d(
                        self.output_dim,
                        self.output_dim,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                    ),
                )
            )
        self.resnet_two = nn.ModuleList(second_halves)

        # Self-attention stack: one norm + one MHA per layer.
        if self.attn:
            self.attention_norms = nn.ModuleList(
                [nn.GroupNorm(self.norm_channels, self.output_dim) for _ in range(num_layers)]
            )
            self.attentions = nn.ModuleList(
                [
                    nn.MultiheadAttention(self.output_dim, self.num_heads, batch_first=True)
                    for _ in range(self.num_layers)
                ]
            )

        # Cross-attention stack with a projection from cond_dim to output_dim.
        if self.cross_attn:
            self.cross_attn_norms = nn.ModuleList(
                [nn.GroupNorm(self.norm_channels, self.output_dim) for _ in range(self.num_layers)]
            )
            self.cross_attentions = nn.ModuleList(
                [
                    nn.MultiheadAttention(self.output_dim, self.num_heads, batch_first=True)
                    for _ in range(self.num_layers)
                ]
            )
            self.context_proj = nn.ModuleList(
                [nn.Linear(self.cond_dim, self.output_dim) for _ in range(self.num_layers)]
            )

        # Strided 4x4 conv halves H and W; identity when not downsampling.
        if self.down_sample:
            self.down_sample_conv = nn.Conv2d(self.output_dim, self.output_dim, 4, 2, 1)
        else:
            self.down_sample_conv = nn.Identity()
168
+
169
+ def forward(self, x, t_emb=None, context=None):
170
+ out = x
171
+ for i in range(self.num_layers):
172
+ # Input x to Resnet Block of the Encoder of the Unet
173
+ resnet_input = out
174
+ out = self.resnet_one[i](out)
175
+ if t_emb is not None:
176
+ out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
177
+ out = self.resnet_two[i](out)
178
+ out = out + self.resnet_in[i](resnet_input)
179
+
180
+ if self.attn:
181
+ # Now Passing through the Self Attention blocks
182
+ batch_size, channels, h, w = out.shape
183
+ in_attn = out.reshape(batch_size, channels, h * w)
184
+ in_attn = self.attention_norms[i](in_attn)
185
+ in_attn = in_attn.transpose(1, 2)