YashNagraj75 committed on
Commit
76a0a2e
·
1 Parent(s): ec73463

Add Unet and other blocks

Browse files
Files changed (1) hide show
  1. model_blocks/unet_base.py +504 -1
model_blocks/unet_base.py CHANGED
@@ -1,4 +1,5 @@
1
  import logging
 
2
 
3
  import torch
4
  import torch.nn as nn
@@ -31,4 +32,506 @@ def get_time_embedding(time_steps, temb_dim):
31
  return t_emb
32
 
33
 
34
- class
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import logging
2
+ from os import wait
3
 
4
  import torch
5
  import torch.nn as nn
 
32
  return t_emb
33
 
34
 
35
class DownBlock(nn.Module):
    r"""
    DownBlock for Diffusion model (encoder stage).

    For each of ``num_layers`` iterations:
      1) Resnet Block :- [Norm -> SiLU -> Conv] x 2, with the time embedding
         (projected through [SiLU -> FC]) added after the first conv and a
         1x1-conv shortcut from the layer input added at the end.
      2) Self Attention :- [Norm -> SA] over the flattened spatial grid,
         added residually.
    b) DownSample : halves the spatial dimension with a strided conv when
       ``down_sample`` is True.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        t_emb_dim,
        down_sample=True,
        num_heads=4,
        num_layers=1,
    ) -> None:
        """
        Args:
            input_dim: channels of the incoming feature map.
            output_dim: channels produced by every layer of this block.
            t_emb_dim: dimensionality of the time-step embedding vector.
            down_sample: when True, halve H and W at the end of the block.
            num_heads: attention heads per self-attention layer.
            num_layers: number of resnet + attention layers.
        """
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.down_sample = down_sample
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.t_emb_dim = t_emb_dim

        # First half of each resnet block; only layer 0 sees input_dim channels.
        self.resnet_one = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(8, self.input_dim if i == 0 else self.output_dim),
                    nn.SiLU(),
                    nn.Conv2d(
                        self.input_dim if i == 0 else self.output_dim,
                        self.output_dim,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                    ),
                )
                for i in range(self.num_layers)
            ]
        )

        # Per-layer projection of the time embedding onto the channel dim.
        self.t_emb_layers = nn.ModuleList(
            [
                nn.Sequential(nn.SiLU(), nn.Linear(self.t_emb_dim, self.output_dim))
                for _ in range(self.num_layers)
            ]
        )

        # Second half of each resnet block (channel count is fixed by now).
        self.resnet_two = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(8, self.output_dim),
                    nn.SiLU(),
                    nn.Conv2d(
                        self.output_dim,
                        self.output_dim,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                    ),
                )
                for _ in range(self.num_layers)
            ]
        )

        self.attention_norms = nn.ModuleList(
            [nn.GroupNorm(8, self.output_dim) for _ in range(self.num_layers)]
        )

        # BUG FIX: create one attention module per layer. The original built a
        # single MultiheadAttention regardless of num_layers, so
        # self.attentions[i] raised IndexError in forward() for num_layers > 1.
        self.attentions = nn.ModuleList(
            [
                nn.MultiheadAttention(self.output_dim, self.num_heads, batch_first=True)
                for _ in range(self.num_layers)
            ]
        )

        # 1x1-conv shortcut so the residual matches output_dim channels.
        self.resnet_in = nn.ModuleList(
            [
                nn.Conv2d(
                    self.input_dim if i == 0 else self.output_dim,
                    self.output_dim,
                    kernel_size=1,
                )
                for i in range(self.num_layers)
            ]
        )

        # k=4, s=2, p=1 halves H and W; identity when down-sampling is off.
        self.down_sample_conv = (
            nn.Conv2d(self.output_dim, self.output_dim, 4, 2, 1)
            if self.down_sample
            else nn.Identity()
        )

    def forward(
        self,
        x,
        t_emb,
    ):
        """
        Args:
            x: input feature map, shape (batch, input_dim, H, W).
            t_emb: time embedding, shape (batch, t_emb_dim).

        Returns:
            Tensor of shape (batch, output_dim, H', W') where H'/W' are halved
            when ``down_sample`` is True.
        """
        out = x
        logger.debug(f"Input of shape: {out.shape} to Down Block ")

        for i in range(self.num_layers):
            # --- Resnet block with time embedding ---
            resnet_input = out
            logger.debug(f"Input to Resnet Block : {resnet_input.shape} ")
            out = self.resnet_one[i](out)
            # Broadcast the (batch, output_dim) embedding over H and W.
            out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
            logger.debug(
                f"Concatenated time embeddings to Resnet Sub Block 1: {out.shape} of Down Block Layer {i}"
            )
            out = self.resnet_two[i](out)
            out = out + self.resnet_in[i](resnet_input)
            logger.debug(
                f"Adding Residual connection : {out.shape} to Down Block Layer {i}"
            )

            # --- Self attention over the flattened spatial grid ---
            batch_size, channels, h, w = out.shape
            in_attn = out.reshape(batch_size, channels, h * w)
            in_attn = self.attention_norms[i](in_attn)
            logger.debug(f"Attention Norm: {in_attn.shape} in Down Block Layer : {i}")
            # MultiheadAttention(batch_first=True) expects (batch, seq, dim).
            in_attn = in_attn.transpose(1, 2)
            logger.debug(
                f"Passing Norm : {in_attn.shape} to Attention Layer in Down Block Layer : {i}"
            )
            out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
            out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
            out = out + out_attn
            logger.debug(
                f"Added Attention score to output: {out.shape} in Down Block Layer {i}"
            )

        out = self.down_sample_conv(out)
        logger.debug(f"Down sampled to : {out.shape}")
        return out
166
+
167
+
168
class MidBlock(nn.Module):
    r"""
    Bottleneck block of the diffusion U-Net.

    Layout:
        Resnet block 0
        repeated num_layers times:
            Self Attention : [Norm -> SA], added residually
            Resnet block (i + 1)

    Each resnet block is [Norm -> SiLU -> Conv] twice, with the time
    embedding (projected through [SiLU -> FC]) added after the first conv
    and a 1x1-conv shortcut from the block input added at the end.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        t_emb_dim,
        num_heads=4,
        num_layers=1,
    ) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.t_emb_dim = t_emb_dim

        # num_layers + 1 resnet blocks sandwich the num_layers attention
        # blocks, hence the extra iteration here.
        resnet_one = nn.ModuleList()
        t_emb_layers = nn.ModuleList()
        resnet_two = nn.ModuleList()
        resnet_in = nn.ModuleList()
        for idx in range(self.num_layers + 1):
            # Only the very first resnet block receives input_dim channels.
            in_ch = self.input_dim if idx == 0 else self.output_dim
            resnet_one.append(
                nn.Sequential(
                    nn.GroupNorm(8, in_ch),
                    nn.SiLU(),
                    nn.Conv2d(
                        in_ch,
                        self.output_dim,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                    ),
                )
            )
            t_emb_layers.append(
                nn.Sequential(nn.SiLU(), nn.Linear(self.t_emb_dim, self.output_dim))
            )
            resnet_two.append(
                nn.Sequential(
                    nn.GroupNorm(8, self.output_dim),
                    nn.SiLU(),
                    nn.Conv2d(
                        self.output_dim,
                        self.output_dim,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                    ),
                )
            )
            # 1x1-conv shortcut so the residual matches output_dim channels.
            resnet_in.append(nn.Conv2d(in_ch, self.output_dim, kernel_size=1))

        # Register in the same attribute order as before.
        self.resnet_one = resnet_one
        self.t_emb_layers = t_emb_layers
        self.resnet_two = resnet_two

        attention_norms = nn.ModuleList()
        attentions = nn.ModuleList()
        for _ in range(self.num_layers):
            attention_norms.append(nn.GroupNorm(8, self.output_dim))
            attentions.append(
                nn.MultiheadAttention(self.output_dim, self.num_heads, batch_first=True)
            )
        self.attention_norms = attention_norms
        self.attentions = attentions
        self.resnet_in = resnet_in

    def forward(self, x, t_emb):
        """
        Args:
            x: input feature map, shape (batch, input_dim, H, W).
            t_emb: time embedding, shape (batch, t_emb_dim).

        Returns:
            Tensor of shape (batch, output_dim, H, W).
        """
        out = x
        logger.debug(f"Input of shape: {out.shape} to Mid Block ")

        # Leading resnet block (index 0).
        resnet_input = out
        logger.debug(
            f"Input to Resnet Block : {resnet_input.shape} in Mid Block Layer 0"
        )
        # Broadcast the (batch, output_dim) embedding over H and W.
        out = self.resnet_one[0](out) + self.t_emb_layers[0](t_emb)[:, :, None, None]
        logger.debug(
            f"Concatenated time embeddings to Resnet Sub Block 1: {out.shape} of Mid Block Layer 0"
        )
        out = self.resnet_two[0](out) + self.resnet_in[0](resnet_input)
        logger.debug(f"Adding Residual connection : {out.shape} to Mid Block Layer 0")

        for i in range(self.num_layers):
            # Self attention over the flattened spatial grid.
            b, c, height, width = out.shape
            in_attn = out.reshape(b, c, height * width)
            in_attn = self.attention_norms[i](in_attn)
            logger.debug(f"Attention Norm: {in_attn.shape} in Mid Block Layer : {i} ")
            # MultiheadAttention(batch_first=True) expects (batch, seq, dim).
            in_attn = in_attn.transpose(1, 2)
            out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
            out = out + out_attn.transpose(1, 2).reshape(b, c, height, width)
            logger.debug(
                f"Added Attention score to output: {out.shape} in Mid Block Layer {i}"
            )

            # Trailing resnet block (index i + 1).
            resnet_input = out
            logger.debug(
                f"Input to Resnet Block : {resnet_input.shape} in Mid Block Layer {i}"
            )
            out = self.resnet_one[i + 1](out) + self.t_emb_layers[i + 1](t_emb)[
                :, :, None, None
            ]
            logger.debug(
                f"Concatenated time embeddings to Resnet Sub Block 1: {out.shape} of Mid Block Layer {i}"
            )
            out = self.resnet_two[i + 1](out) + self.resnet_in[i + 1](resnet_input)
            logger.debug(
                f"Adding Residual connection : {out.shape} to Mid Block Layer {i}"
            )

        return out
305
+
306
+
307
class UpBlock(nn.Module):
    r"""
    UpBlock for Diffusion model (decoder stage):
    1. Upsample (transposed conv, doubles H and W)
    2. Concatenate the matching Down block output (skip connection)
    3. Resnet block with time embedding
    4. Attention Block
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        t_emb_dim,
        up_sample=True,
        num_heads=4,
        num_layers=1,
    ) -> None:
        """
        Args:
            input_dim: channels seen by the first resnet block AFTER the skip
                concatenation; x itself carries input_dim // 2 channels and the
                skip connection contributes the other input_dim // 2.
            output_dim: channels produced by every layer of this block.
            t_emb_dim: dimensionality of the time-step embedding vector.
            up_sample: when True, double H and W before concatenating.
            num_heads: attention heads per self-attention layer.
            num_layers: number of resnet + attention layers.
        """
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.up_sample = up_sample
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.t_emb_dim = t_emb_dim

        # First half of each resnet block; only layer 0 sees input_dim channels.
        self.resnet_one = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(8, self.input_dim if i == 0 else self.output_dim),
                    nn.SiLU(),
                    nn.Conv2d(
                        self.input_dim if i == 0 else self.output_dim,
                        self.output_dim,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                    ),
                )
                for i in range(self.num_layers)
            ]
        )
        # Per-layer projection of the time embedding onto the channel dim.
        self.t_emb_layers = nn.ModuleList(
            [
                nn.Sequential(nn.SiLU(), nn.Linear(self.t_emb_dim, self.output_dim))
                for _ in range(self.num_layers)
            ]
        )
        # Second half of each resnet block.
        self.resnet_two = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(8, self.output_dim),
                    nn.SiLU(),
                    nn.Conv2d(
                        self.output_dim,
                        self.output_dim,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                    ),
                )
                for _ in range(self.num_layers)
            ]
        )
        self.attention_norms = nn.ModuleList(
            [nn.GroupNorm(8, self.output_dim) for _ in range(self.num_layers)]
        )
        self.attentions = nn.ModuleList(
            [
                nn.MultiheadAttention(self.output_dim, self.num_heads, batch_first=True)
                for _ in range(self.num_layers)
            ]
        )
        # 1x1-conv shortcut so the residual matches output_dim channels.
        self.resnet_in = nn.ModuleList(
            [
                nn.Conv2d(
                    self.input_dim if i == 0 else self.output_dim,
                    self.output_dim,
                    kernel_size=1,
                )
                for i in range(self.num_layers)
            ]
        )

        # Doubles H and W (k=4, s=2, p=1) while PRESERVING the channel count
        # of x (input_dim // 2); concatenating the skip connection (another
        # input_dim // 2 channels) then yields exactly input_dim channels for
        # resnet_one[0].
        # BUG FIX: the original produced output_dim // 2 channels here, which
        # made the concatenated tensor mismatch resnet_one[0]'s input_dim.
        self.up_sample_conv = (
            nn.ConvTranspose2d(self.input_dim // 2, self.input_dim // 2, 4, 2, 1)
            if self.up_sample
            else nn.Identity()
        )

    def forward(self, x, out_down, t_emb):
        """
        Args:
            x: decoder feature map, shape (batch, input_dim // 2, H, W).
            out_down: skip connection from the matching down block, expected
                to have input_dim // 2 channels at the upsampled resolution.
            t_emb: time embedding, shape (batch, t_emb_dim).

        Returns:
            Tensor of shape (batch, output_dim, H', W') where H'/W' are
            doubled when ``up_sample`` is True.
        """
        logger.debug(f"Input of shape: {x.shape} to Up Block ")
        out = x
        out = self.up_sample_conv(out)
        logger.debug(f"Up sampled to : {out.shape}")

        # Concatenate Down Block output
        out = torch.cat([out, out_down], dim=1)
        logger.debug(f"Concatenated Down Block output: {out.shape}")

        for i in range(self.num_layers):
            # --- Resnet block with time embedding ---
            resnet_input = out
            logger.debug(
                f"Input to Resnet Block : {resnet_input.shape} in Up Block Layer {i}"
            )
            out = self.resnet_one[i](out)
            # Broadcast the (batch, output_dim) embedding over H and W.
            out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
            logger.debug(
                f"Concatenated time embeddings to Resnet Sub Block 1: {out.shape} of Up Block Layer {i}"
            )
            out = self.resnet_two[i](out)
            out = out + self.resnet_in[i](resnet_input)
            logger.debug(
                f"Adding Residual connection : {out.shape} to Up Block Layer {i}"
            )

            # --- Attention Block ---
            batch_size, channels, h, w = out.shape
            in_attn = out.reshape(batch_size, channels, h * w)
            in_attn = self.attention_norms[i](in_attn)
            logger.debug(f"Attention Norm: {in_attn.shape} in Up Block Layer : {i}")
            # MultiheadAttention(batch_first=True) expects (batch, seq, dim).
            in_attn = in_attn.transpose(1, 2)
            logger.debug(
                f"Passing Norm : {in_attn.shape} to Attention Layer in Up Block Layer : {i}"
            )
            out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
            out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
            out = out + out_attn
            logger.debug(
                f"Added Attention score to output: {out.shape} in Up Block Layer {i}"
            )

        return out
440
+
441
+
442
class UNet(nn.Module):
    r"""
    Unet Backbone consisting of:
    Down Blocks, Mid Blocks, Up Blocks

    The encoder (downs) stores its inputs as skip connections, the
    bottleneck (mids) transforms at the lowest resolution, and the decoder
    (ups) consumes the skips in reverse order.
    """

    def __init__(self, model_config, use_up=True):
        """
        Args:
            model_config: dict with keys ``im_channels``, ``down_channels``,
                ``mid_channels``, ``t_emb_dim``, ``down_sample``,
                ``num_down_layers``, ``num_mid_layers``, ``num_up_layers``.
            use_up: when False, no UpBlocks are created (self.ups stays an
                empty ModuleList).
        """
        super().__init__()
        im_channels = model_config["im_channels"]
        self.down_channels = model_config["down_channels"]
        self.mid_channels = model_config["mid_channels"]
        self.t_emb_dim = model_config["t_emb_dim"]
        self.down_sample = model_config["down_sample"]
        self.num_down_layers = model_config["num_down_layers"]
        self.num_mid_layers = model_config["num_mid_layers"]
        self.num_up_layers = model_config["num_up_layers"]

        # Channel bookkeeping: the bottleneck must start where the encoder
        # ends and hand back the channel count the decoder expects.
        assert self.mid_channels[0] == self.down_channels[-1]
        assert self.mid_channels[-1] == self.down_channels[-2]
        assert len(self.down_sample) == len(self.down_channels) - 1

        # MLP that refines the sinusoidal time embedding.
        self.t_proj = nn.Sequential(
            nn.Linear(self.t_emb_dim, self.t_emb_dim),
            nn.SiLU(),
            nn.Linear(self.t_emb_dim, self.t_emb_dim),
        )

        self.up_sample = list(reversed(self.down_sample))
        self.conv_in = nn.Conv2d(
            im_channels, self.down_channels[0], kernel_size=3, padding=1
        )
        self.downs = nn.ModuleList([])
        for i in range(len(self.down_channels) - 1):
            self.downs.append(
                DownBlock(
                    self.down_channels[i],
                    self.down_channels[i + 1],
                    self.t_emb_dim,
                    down_sample=self.down_sample[i],
                    num_layers=self.num_down_layers,
                )
            )

        self.mids = nn.ModuleList([])
        for i in range(len(self.mid_channels) - 1):
            self.mids.append(
                MidBlock(
                    self.mid_channels[i],
                    self.mid_channels[i + 1],
                    self.t_emb_dim,
                    num_layers=self.num_mid_layers,
                )
            )

        # BUG FIX: always register `ups` so forward() can iterate it; the
        # original only created the attribute when use_up was True, which
        # made forward() raise AttributeError for use_up=False.
        self.ups = nn.ModuleList([])
        if use_up:
            for i in reversed(range(len(self.down_channels) - 1)):
                self.ups.append(
                    UpBlock(
                        # x and the skip each carry down_channels[i] channels,
                        # so the block's post-concat input is twice that.
                        self.down_channels[i] * 2,
                        self.down_channels[i - 1] if i != 0 else 16,
                        self.t_emb_dim,
                        up_sample=self.down_sample[i],
                        num_layers=self.num_up_layers,
                    )
                )

        # The last UpBlock emits 16 channels (see `else 16` above).
        self.norm_out = nn.GroupNorm(8, 16)
        self.conv_out = nn.Conv2d(16, im_channels, kernel_size=3, padding=1)

    def forward(self, x, t):
        """
        Args:
            x: input image batch, shape (batch, im_channels, H, W).
            t: diffusion time steps; anything `torch.as_tensor` accepts
               (converted to long for the embedding lookup).

        Returns:
            Tensor of shape (batch, im_channels, H, W).
        """
        t_emb = get_time_embedding(torch.as_tensor(t).long(), self.t_emb_dim)
        t_emb = self.t_proj(t_emb)
        logger.debug(f"Time embedding shape: {t_emb.shape} to UNet")

        out = self.conv_in(x)
        logger.debug(f"Ouput for conv : {out.shape} to UNet")
        down_outs = []

        # Encoder: stash each block's INPUT as the skip connection.
        for idx, down in enumerate(self.downs):
            down_outs.append(out)
            out = down(out, t_emb)
            logger.debug(f"Output of Down Block {idx} : {out.shape} in UNet")

        for idx, mid in enumerate(self.mids):
            out = mid(out, t_emb)
            logger.debug(f"Output of Mid Block {idx} : {out.shape} in UNet")

        # Decoder: consume skips in reverse (LIFO) order.
        for idx, up in enumerate(self.ups):
            out = up(out, down_outs.pop(), t_emb)
            logger.debug(f"Output of Up Block {idx} : {out.shape} in UNet")

        out = self.norm_out(out)
        out = self.conv_out(out)
        logger.debug(f"Output of UNet : {out.shape}")
        return out