YashNagraj75 commited on
Commit
e1d97a8
·
1 Parent(s): 426ee66

Add UpBlock

Browse files
Files changed (1) hide show
  1. model_blocks/blocks.py +302 -5
model_blocks/blocks.py CHANGED
@@ -1,6 +1,10 @@
 
 
1
  import torch
2
  import torch.nn as nn
3
 
 
 
4
 
5
  def get_time_embedding(time_steps, temb_dim):
6
  r"""
@@ -35,7 +39,7 @@ class DownBlock(nn.Module):
35
  1) Resnet Block :- [Norm-> Silu -> Conv] x num_layers
36
  2) Self Attention :- [Norm -> SA]
37
  3) Cross Attention :- [Norm -> CA]
38
- b) DownSample : DownSample the dimnension
39
  """
40
 
41
  def __init__(
@@ -170,15 +174,29 @@ class DownBlock(nn.Module):
170
  out = x
171
  for i in range(self.num_layers):
172
  # Input x to Resnet Block of the Encoder of the Unet
 
173
  resnet_input = out
174
  out = self.resnet_one[i](out)
175
- if t_emb is not None:
 
 
 
 
 
 
176
  out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
177
  out = self.resnet_two[i](out)
 
 
 
178
  out = out + self.resnet_in[i](resnet_input)
 
 
 
179
 
180
  if self.attn:
181
  # Now Passing through the Self Attention blocks
 
182
  batch_size, channels, h, w = out.shape
183
  in_attn = out.reshape(batch_size, channels, h * w)
184
  in_attn = self.attention_norms[i](in_attn)
@@ -186,11 +204,17 @@ class DownBlock(nn.Module):
186
  out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
187
  out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
188
  out = out + out_attn
 
 
 
189
 
190
  if self.cross_attn:
191
  assert context is not None, (
192
  "context cannot be None if cross attention layers are used"
193
  )
 
 
 
194
  batch_size, channels, h, w = out.shape
195
  in_attn = out.reshape(batch_size, channels, h * w)
196
  in_attn = self.cross_attn_norms[i](in_attn)
@@ -199,19 +223,40 @@ class DownBlock(nn.Module):
199
  context.shape[0] == x.shape[0]
200
  and context.shape[-1] == self.context_dim
201
  )
 
 
 
202
  context_proj = self.context_proj[i](context)
203
  out_attn, _ = self.cross_attentions[i](
204
  in_attn, context_proj, context_proj
205
  )
206
  out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
207
  out = out + out_attn
 
 
 
208
 
209
  # DownSample to x2 smaller dimension
210
  out = self.down_sample_conv(out)
 
211
  return out
212
 
213
 
214
  class MidBlock(nn.Module):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
  def __init__(
216
  self,
217
  num_heads,
@@ -253,7 +298,7 @@ class MidBlock(nn.Module):
253
  padding=1,
254
  ),
255
  )
256
- for i in range(self.num_layers)
257
  ]
258
  )
259
 
@@ -261,7 +306,7 @@ class MidBlock(nn.Module):
261
  self.t_emb_layers = nn.ModuleList(
262
  [
263
  nn.Sequential(nn.SiLU(), nn.Linear(self.t_emb_dim, self.output_dim))
264
- for _ in range(self.num_layers)
265
  ]
266
  )
267
 
@@ -281,7 +326,7 @@ class MidBlock(nn.Module):
281
  padding=1,
282
  ),
283
  )
284
- for _ in range(self.num_layers)
285
  ]
286
  )
287
 
@@ -323,3 +368,255 @@ class MidBlock(nn.Module):
323
  for _ in range(self.num_layers)
324
  ]
325
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
  import torch
4
  import torch.nn as nn
5
 
6
+ logger = logging.getLogger(__name__)
7
+
8
 
9
  def get_time_embedding(time_steps, temb_dim):
10
  r"""
 
39
  1) Resnet Block :- [Norm-> Silu -> Conv] x num_layers
40
  2) Self Attention :- [Norm -> SA]
41
  3) Cross Attention :- [Norm -> CA]
42
+ b) DownSample : DownSample the dimension
43
  """
44
 
45
  def __init__(
 
174
  out = x
175
  for i in range(self.num_layers):
176
  # Input x to Resnet Block of the Encoder of the Unet
177
+ logger.debug(f"Input to Resnet Block in Down Block Layer {i} : {out.shape}")
178
  resnet_input = out
179
  out = self.resnet_one[i](out)
180
+ logger.debug(
181
+ f"Output of Resnet Sub Block 1 of Down Block Layer {i} : {out.shape}"
182
+ )
183
+ if self.t_emb_dim is not None:
184
+ logger.debug(
185
+ f"Adding t_emb of shape {self.t_emb_dim} to output of shape: {out.shape} of Down Block Layer {i}"
186
+ )
187
  out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
188
  out = self.resnet_two[i](out)
189
+ logger.debug(
190
+ f"Output of Resnet Sub Block 2 of Down Block Layer: {i} with output_shape:{out.shape}"
191
+ )
192
  out = out + self.resnet_in[i](resnet_input)
193
+ logger.debug(
194
+ f"Residual connection of the input to out : {out.shape} in Down Block Layer {i}"
195
+ )
196
 
197
  if self.attn:
198
  # Now Passing through the Self Attention blocks
199
+ logger.debug(f"Going into the attention Block in Down Block Layer {i}")
200
  batch_size, channels, h, w = out.shape
201
  in_attn = out.reshape(batch_size, channels, h * w)
202
  in_attn = self.attention_norms[i](in_attn)
 
204
  out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
205
  out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
206
  out = out + out_attn
207
+ logger.debug(
208
+ f"Out of the Self Attention Block with out : {out.shape} in Down Block Layer {i}"
209
+ )
210
 
211
  if self.cross_attn:
212
  assert context is not None, (
213
  "context cannot be None if cross attention layers are used"
214
  )
215
+ logger.debug(
216
+ f"Going into the Cross Attention Block in Down Block Layer {i}"
217
+ )
218
  batch_size, channels, h, w = out.shape
219
  in_attn = out.reshape(batch_size, channels, h * w)
220
  in_attn = self.cross_attn_norms[i](in_attn)
 
223
  context.shape[0] == x.shape[0]
224
  and context.shape[-1] == self.context_dim
225
  )
226
+ logger.debug(
227
+ f"Calculating context projection for Cross Attn in Down Block Layer : {i}"
228
+ )
229
  context_proj = self.context_proj[i](context)
230
  out_attn, _ = self.cross_attentions[i](
231
  in_attn, context_proj, context_proj
232
  )
233
  out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
234
  out = out + out_attn
235
+ logger.debug(
236
+ f"Out of the Cross Attention Block with out : {out.shape} in Down Block Layer {i}"
237
+ )
238
 
239
  # DownSample to x2 smaller dimension
240
  out = self.down_sample_conv(out)
241
+ logger.debug(f"Down Sampling out to : {out.shape} in Down Block Layer {i} ")
242
  return out
243
 
244
 
245
  class MidBlock(nn.Module):
246
+ r"""
247
+
248
+ MidBlock for Diffusion model:
249
+ Time embedding -> [Silu -> FC]
250
+
251
+ 1) Resnet Block :- [Norm-> Silu -> Conv] x num_layers
252
+ 2) Self Attention :- [Norm -> SA]
253
+ 3) Cross Attention :- [Norm -> CA]
254
+ Time embedding -> [Silu -> FC]
255
+
256
+ 4) Resnet Block :- [Norm-> Silu -> Conv] x num_layers
257
+
258
+ """
259
+
260
  def __init__(
261
  self,
262
  num_heads,
 
298
  padding=1,
299
  ),
300
  )
301
+ for i in range(self.num_layers + 1)
302
  ]
303
  )
304
 
 
306
  self.t_emb_layers = nn.ModuleList(
307
  [
308
  nn.Sequential(nn.SiLU(), nn.Linear(self.t_emb_dim, self.output_dim))
309
+ for _ in range(self.num_layers + 1)
310
  ]
311
  )
312
 
 
326
  padding=1,
327
  ),
328
  )
329
+ for _ in range(self.num_layers + 1)
330
  ]
331
  )
332
 
 
368
  for _ in range(self.num_layers)
369
  ]
370
  )
371
+
372
+ self.resnet_in = nn.ModuleList(
373
+ [
374
+ nn.Conv2d(
375
+ self.input_dim if i == 0 else self.output_dim,
376
+ self.output_dim,
377
+ kernel_size=1,
378
+ )
379
+ for i in range(self.num_layers + 1)
380
+ ]
381
+ )
382
+
383
+ def forward(self, x, t_emb=None, context=None):
384
+ out = x
385
+
386
+ # Input Resnet Block
387
+ logger.debug("Input to First Resnet Block in Mid Block")
388
+ resnet_input = out
389
+ out = self.resnet_one[0](out)
390
+ logger.debug(f"Output of Resnet Sub Block 1 of Mid Block Layer: {out.shape}")
391
+ if self.t_emb_dim is not None:
392
+ out = out + self.t_emb_layers[0](t_emb)[:, :, None, None]
393
+ logger.debug(
394
+ f"Adding t_emb of shape {self.t_emb_dim} to output of shape: {out.shape}"
395
+ )
396
+ out = self.resnet_two[0](out)
397
+ logger.debug(f"Output of Resnet Sub Block 2 with output_shape:{out.shape}")
398
+ out = out + self.resnet_in[0](resnet_input)
399
+ logger.debug(
400
+ f"Residual connection of the input to out : {out.shape} in Mid Block"
401
+ )
402
+
403
+ for i in range(self.num_layers):
404
+ logger.debug(f"Going into the attention Block in Mid Block Layer {i}")
405
+ batch_size, channels, h, w = out.shape
406
+ in_attn = out.reshape(batch_size, channels, h * w)
407
+ in_attn = self.attention_norms[i](in_attn)
408
+ in_attn = in_attn.transpose(1, 2)
409
+ out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
410
+ out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
411
+ out = out + out_attn
412
+ logger.debug(
413
+ f"Out of the Self Attention Block with out : {out.shape} in Mid Block Layer {i}"
414
+ )
415
+
416
+ if self.cross_attn:
417
+ assert context is not None, (
418
+ "context cannot be None if cross attention layers are used"
419
+ )
420
+ logger.debug(
421
+ f"Going into the Cross Attention Block in Mid Block Layer {i}"
422
+ )
423
+ batch_size, channels, h, w = out.shape
424
+ in_attn = out.reshape(batch_size, channels, h * w)
425
+ in_attn = self.cross_attn_norms[i](in_attn)
426
+ in_attn = in_attn.transpose(1, 2)
427
+ assert (
428
+ context.shape[0] == x.shape[0]
429
+ and context.shape[-1] == self.context_dim
430
+ )
431
+ logger.debug(
432
+ f"Calculating context projection for Cross Attn in Mid Block Layer : {i}"
433
+ )
434
+ context_proj = self.context_proj[i](context)
435
+ out_attn, _ = self.cross_attentions[i](
436
+ in_attn, context_proj, context_proj
437
+ )
438
+ out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
439
+ out = out + out_attn
440
+ logger.debug(
441
+ f"Out of the Cross Attention Block with out : {out.shape} in Mid Block Layer {i}"
442
+ )
443
+ logger.debug(
444
+ f"Last Resnet Block input : {out.shape} of Mid Block Layer {i}"
445
+ )
446
+ resnet_input = out
447
+ out = self.resnet_one[0](out)
448
+ logger.debug(
449
+ f"Output of Resnet Sub Block 1 of Mid Block Layer {i} of shape : {out.shape}"
450
+ )
451
+ if self.t_emb_dim is not None:
452
+ out = out + self.t_emb_layers[0](t_emb)[:, :, None, None]
453
+ logger.debug(
454
+ f"Adding t_emb of shape {self.t_emb_dim} to output of shape: {out.shape} of Mid Block Layer {i}"
455
+ )
456
+ out = self.resnet_two[0](out)
457
+ logger.debug(
458
+ f"Output of Resnet Sub Block 2 with output_shape:{out.shape} of Mid Block Layer {i}"
459
+ )
460
+ out = out + self.resnet_in[0](resnet_input)
461
+ logger.debug(
462
+ f"Residual connection of the input to out : {out.shape} in Mid Block Layer {i}"
463
+ )
464
+
465
+ return out
466
+
467
+
468
class UpBlockUnet(nn.Module):
    r"""
    Up conv block with attention.

    Sequence of operations per forward pass:
    1. Upsample the input (transposed conv) when enabled.
    2. Concatenate the matching down-block output along channels.
    3. num_layers x [resnet block with time embedding -> self attention
       -> optional cross attention].
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        t_emb_dim,
        up_sample,
        num_heads,
        num_layers,
        norm_channels,
        cross_attn=False,
        context_dim=None,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.up_sample = up_sample
        self.t_emb_dim = t_emb_dim
        self.cross_attn = cross_attn
        self.context_dim = context_dim

        # First half of each resnet stage: Norm -> SiLU -> 3x3 conv.
        # Only stage 0 sees the concatenated in_channels.
        first_convs = []
        for stage in range(num_layers):
            stage_in = in_channels if stage == 0 else out_channels
            first_convs.append(
                nn.Sequential(
                    nn.GroupNorm(norm_channels, stage_in),
                    nn.SiLU(),
                    nn.Conv2d(
                        stage_in,
                        out_channels,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                    ),
                )
            )
        self.resnet_conv_first = nn.ModuleList(first_convs)

        # Per-stage time-embedding projection: SiLU -> Linear.
        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList(
                nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
                for _ in range(num_layers)
            )

        # Second half of each resnet stage.
        self.resnet_conv_second = nn.ModuleList(
            nn.Sequential(
                nn.GroupNorm(norm_channels, out_channels),
                nn.SiLU(),
                nn.Conv2d(
                    out_channels, out_channels, kernel_size=3, stride=1, padding=1
                ),
            )
            for _ in range(num_layers)
        )

        # Self-attention norm + MHA per stage.
        self.attention_norms = nn.ModuleList(
            nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)
        )
        self.attentions = nn.ModuleList(
            nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
            for _ in range(num_layers)
        )

        if self.cross_attn:
            assert context_dim is not None, (
                "Context Dimension must be passed for cross attention"
            )
            self.cross_attention_norms = nn.ModuleList(
                nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)
            )
            self.cross_attentions = nn.ModuleList(
                nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                for _ in range(num_layers)
            )
            self.context_proj = nn.ModuleList(
                nn.Linear(context_dim, out_channels) for _ in range(num_layers)
            )

        # 1x1 conv on the residual path (channel match for stage 0).
        self.residual_input_conv = nn.ModuleList(
            nn.Conv2d(
                in_channels if stage == 0 else out_channels,
                out_channels,
                kernel_size=1,
            )
            for stage in range(num_layers)
        )

        # Upsampling happens before concatenation, so it operates on
        # in_channels // 2 channels (the pre-concat feature map).
        if self.up_sample:
            self.up_sample_conv = nn.ConvTranspose2d(
                in_channels // 2, in_channels // 2, 4, 2, 1
            )
        else:
            self.up_sample_conv = nn.Identity()

    def forward(self, x, out_down=None, t_emb=None, context=None):
        r"""
        Args:
            x: feature map (B, in_channels // 2, H, W) when a skip tensor is
                concatenated, otherwise (B, in_channels, H, W).
            out_down: optional skip connection from the matching down block.
            t_emb: optional time embedding (B, t_emb_dim).
            context: optional conditioning sequence (B, L, context_dim),
                required when cross attention is enabled.

        Returns:
            Feature map with out_channels channels.
        """
        x = self.up_sample_conv(x)
        if out_down is not None:
            x = torch.cat([x, out_down], dim=1)

        out = x
        for stage in range(self.num_layers):
            # Resnet stage with optional time-embedding injection.
            skip = out
            out = self.resnet_conv_first[stage](out)
            if self.t_emb_dim is not None:
                out = out + self.t_emb_layers[stage](t_emb)[:, :, None, None]
            out = self.resnet_conv_second[stage](out)
            out = out + self.residual_input_conv[stage](skip)

            # Self attention over flattened spatial positions.
            b, c, h, w = out.shape
            tokens = out.reshape(b, c, h * w)
            tokens = self.attention_norms[stage](tokens).transpose(1, 2)
            attn_out, _ = self.attentions[stage](tokens, tokens, tokens)
            out = out + attn_out.transpose(1, 2).reshape(b, c, h, w)

            # Cross attention against the projected context sequence.
            if self.cross_attn:
                assert context is not None, (
                    "context cannot be None if cross attention layers are used"
                )
                b, c, h, w = out.shape
                tokens = out.reshape(b, c, h * w)
                tokens = self.cross_attention_norms[stage](tokens).transpose(1, 2)
                assert len(context.shape) == 3, (
                    "Context shape does not match B,_,CONTEXT_DIM"
                )
                assert (
                    context.shape[0] == x.shape[0]
                    and context.shape[-1] == self.context_dim
                ), "Context shape does not match B,_,CONTEXT_DIM"
                ctx = self.context_proj[stage](context)
                attn_out, _ = self.cross_attentions[stage](tokens, ctx, ctx)
                out = out + attn_out.transpose(1, 2).reshape(b, c, h, w)

        return out