YashNagraj75 commited on
Commit
426ee66
·
1 Parent(s): 3b3d382

Add MidBlock

Browse files
Files changed (1) hide show
  1. model_blocks/blocks.py +140 -0
model_blocks/blocks.py CHANGED
@@ -183,3 +183,143 @@ class DownBlock(nn.Module):
183
  in_attn = out.reshape(batch_size, channels, h * w)
184
  in_attn = self.attention_norms[i](in_attn)
185
  in_attn = in_attn.transpose(1, 2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  in_attn = out.reshape(batch_size, channels, h * w)
184
  in_attn = self.attention_norms[i](in_attn)
185
  in_attn = in_attn.transpose(1, 2)
186
+ out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
187
+ out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
188
+ out = out + out_attn
189
+
190
+ if self.cross_attn:
191
+ assert context is not None, (
192
+ "context cannot be None if cross attention layers are used"
193
+ )
194
+ batch_size, channels, h, w = out.shape
195
+ in_attn = out.reshape(batch_size, channels, h * w)
196
+ in_attn = self.cross_attn_norms[i](in_attn)
197
+ in_attn = in_attn.transpose(1, 2)
198
+ assert (
199
+ context.shape[0] == x.shape[0]
200
+ and context.shape[-1] == self.context_dim
201
+ )
202
+ context_proj = self.context_proj[i](context)
203
+ out_attn, _ = self.cross_attentions[i](
204
+ in_attn, context_proj, context_proj
205
+ )
206
+ out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
207
+ out = out + out_attn
208
+
209
+ # DownSample to x2 smaller dimension
210
+ out = self.down_sample_conv(out)
211
+ return out
212
+
213
+
214
+ class MidBlock(nn.Module):
215
+ def __init__(
216
+ self,
217
+ num_heads,
218
+ num_layers,
219
+ cross_attn,
220
+ input_dim,
221
+ output_dim,
222
+ t_emb_dim,
223
+ cond_dim,
224
+ norm_channels,
225
+ self_attn,
226
+ down_sample,
227
+ ) -> None:
228
+ super().__init__()
229
+ self.num_heads = num_heads
230
+ self.num_layers = num_layers
231
+ self.cross_attn = cross_attn
232
+ self.input_dim = input_dim
233
+ self.output_dim = output_dim
234
+ self.cond_dim = cond_dim
235
+ self.norm_channels = norm_channels
236
+ self.t_emb_dim = t_emb_dim
237
+ self.attn = self_attn
238
+ self.down_sample = down_sample
239
+
240
+ self.resnet_one = nn.ModuleList(
241
+ [
242
+ nn.Sequential(
243
+ nn.GroupNorm(
244
+ self.norm_channels,
245
+ self.input_dim if i == 0 else self.output_dim,
246
+ ),
247
+ nn.SiLU(),
248
+ nn.Conv2d(
249
+ self.input_dim if i == 0 else self.output_dim,
250
+ self.output_dim,
251
+ kernel_size=3,
252
+ stride=1,
253
+ padding=1,
254
+ ),
255
+ )
256
+ for i in range(self.num_layers)
257
+ ]
258
+ )
259
+
260
+ if self.t_emb_dim is not None:
261
+ self.t_emb_layers = nn.ModuleList(
262
+ [
263
+ nn.Sequential(nn.SiLU(), nn.Linear(self.t_emb_dim, self.output_dim))
264
+ for _ in range(self.num_layers)
265
+ ]
266
+ )
267
+
268
+ self.resnet_two = nn.ModuleList(
269
+ [
270
+ nn.Sequential(
271
+ nn.GroupNorm(
272
+ self.norm_channels,
273
+ self.output_dim,
274
+ ),
275
+ nn.SiLU(),
276
+ nn.Conv2d(
277
+ self.output_dim,
278
+ self.output_dim,
279
+ kernel_size=3,
280
+ stride=1,
281
+ padding=1,
282
+ ),
283
+ )
284
+ for _ in range(self.num_layers)
285
+ ]
286
+ )
287
+
288
+ if self.attn:
289
+ self.attention_norms = nn.ModuleList(
290
+ [
291
+ nn.GroupNorm(self.norm_channels, self.output_dim)
292
+ for _ in range(num_layers)
293
+ ]
294
+ )
295
+ self.attentions = nn.ModuleList(
296
+ [
297
+ nn.MultiheadAttention(
298
+ self.output_dim, self.num_heads, batch_first=True
299
+ )
300
+ for _ in range(self.num_layers)
301
+ ]
302
+ )
303
+
304
+ if self.cross_attn:
305
+ self.cross_attn_norms = nn.ModuleList(
306
+ [
307
+ nn.GroupNorm(self.norm_channels, self.output_dim)
308
+ for _ in range(self.num_layers)
309
+ ]
310
+ )
311
+ self.cross_attentions = nn.ModuleList(
312
+ [
313
+ nn.MultiheadAttention(
314
+ self.output_dim, self.num_heads, batch_first=True
315
+ )
316
+ for _ in range(self.num_layers)
317
+ ]
318
+ )
319
+
320
+ self.context_proj = nn.ModuleList(
321
+ [
322
+ nn.Linear(self.cond_dim, self.output_dim)
323
+ for _ in range(self.num_layers)
324
+ ]
325
+ )