Yash Nagraj committed on
Commit
70a401a
·
1 Parent(s): 6801d5a

Add UpBlockUnet to concatenate the down block's output

Browse files
Files changed (1) hide show
  1. models/blocks.py +122 -53
models/blocks.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import torch
2
  import torch.nn as nn
3
 
@@ -263,69 +264,137 @@ class MidBlock(nn.Module):
263
  return out
264
 
265
 
266
- class UpBlock(nn.Module):
267
- """
268
- Up Block that upsamples the image, flows like this:
269
- 1) UpSample
270
- 2) Concat down block output
271
- 3) Resnet block with time embedding
272
- 4) Attention Block
 
273
  """
274
 
275
- def __init__(self, in_channels, out_channels, t_emb_dim, up_sample, num_layers, attn, norm_channels, num_heads):
 
276
  super().__init__()
277
  self.num_layers = num_layers
278
- self.attn = attn
279
- self.norm_channels = norm_channels
280
- self.resnet_conv_one = nn.ModuleList([
281
- nn.Sequential(
282
- nn.GroupNorm(norm_channels, in_channels if i ==
283
- 0 else out_channels),
284
- nn.SiLU(),
285
- nn.Conv2d(in_channels if i == 0 else out_channels,
286
- out_channels, 3, 1, 1)
287
-
288
- ) for i in range(num_layers)
289
-
290
- ])
291
-
292
- if t_emb_dim is not None:
293
- self.time_emb_layers = nn.ModuleList(
294
- [
295
- nn.Sequential(
296
- nn.SiLU(),
297
- nn.Linear(t_emb_dim, out_channels)
298
- ) for _ in range(num_layers)
299
-
300
- ]
301
- )
302
-
303
- self.resnet_conv_two = nn.ModuleList([
304
- nn.Sequential(
305
- nn.GroupNorm(norm_channels, out_channels),
306
- nn.SiLU(),
307
- nn.Conv2d(out_channels, out_channels, 3, 1, 1)
308
 
309
- ) for _ in range(num_layers)
 
 
 
 
 
 
 
310
 
311
- ])
 
 
 
 
 
 
 
 
 
 
312
 
313
- if self.attn:
314
- self.attention_norms = nn.ModuleList([
315
  nn.GroupNorm(norm_channels, out_channels)
316
  for _ in range(num_layers)
317
- ])
 
318
 
319
- self.attention_heads = nn.ModuleList(
320
- [nn.MultiheadAttention(
321
- out_channels, num_heads, batch_first=True) for _ in range(num_layers)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
323
 
324
- self.resnet_input_conv = nn.ModuleList([
325
- nn.Conv2d(in_channels if i == 0 else out_channels,
326
- out_channels, 3, 1, 1)
327
- for i in range(num_layers)
328
- ])
329
 
330
- self.upsample = nn.ConvTranspose2d(
331
- in_channels, in_channels, 4, 2, 1) if self.upsample else nn.Identity()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# NOTE(review): `from re import A` is unused in the visible code and looks like
# an accidental autocomplete import (re.A is a regex flag) — confirm nothing
# else in the file uses `A`, then remove it.
from re import A
import torch
import torch.nn as nn
4
 
 
264
  return out
265
 
266
 
267
class UpBlockUnet(nn.Module):
    r"""
    Up block of the UNet decoder, with self attention and optional cross
    attention.

    Per forward pass the block runs, in order:
      1. Upsample the input spatially (x2 via transposed conv) when
         ``up_sample`` is True.
      2. Concatenate the matching down-block output along the channel axis.
      3. For each of ``num_layers`` stages:
         a. ResNet block with optional time-embedding injection.
         b. Self-attention block.
         c. Cross-attention block against ``context`` (only if ``cross_attn``).

    Channel contract (follows from the layer definitions below): when
    ``up_sample`` is True the pre-concat input must carry
    ``in_channels // 2`` channels (the transposed conv preserves channel
    count); after the optional concat with ``out_down`` the tensor must
    carry ``in_channels`` channels, which the first ResNet stage maps to
    ``out_channels``.

    :param in_channels: channels entering the first ResNet stage (post-concat)
    :param out_channels: channels produced by every ResNet/attention stage
    :param t_emb_dim: time-embedding dimension, or None to disable it
    :param up_sample: whether to spatially upsample the input first
    :param num_heads: attention heads (must divide ``out_channels``)
    :param num_layers: number of ResNet(+attention) stages
    :param norm_channels: group count for every GroupNorm in the block
    :param cross_attn: enable the cross-attention sub-blocks
    :param context_dim: feature dim of the cross-attention context;
        required when ``cross_attn`` is True
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, up_sample,
                 num_heads, num_layers, norm_channels, cross_attn=False, context_dim=None):
        super().__init__()
        self.num_layers = num_layers
        self.up_sample = up_sample
        self.t_emb_dim = t_emb_dim
        self.cross_attn = cross_attn
        self.context_dim = context_dim

        # First half of each ResNet stage: norm -> SiLU -> 3x3 conv.
        # Stage 0 maps in_channels -> out_channels; later stages stay at
        # out_channels.
        self.resnet_conv_first = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(norm_channels, in_channels if i ==
                                 0 else out_channels),
                    nn.SiLU(),
                    nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
                              padding=1),
                )
                for i in range(num_layers)
            ]
        )

        # Projects the time embedding to out_channels so it can be added as a
        # per-channel bias between the two ResNet convs.
        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(
                    nn.SiLU(),
                    nn.Linear(t_emb_dim, out_channels)
                )
                for _ in range(num_layers)
            ])

        # Second half of each ResNet stage (out_channels -> out_channels).
        self.resnet_conv_second = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(norm_channels, out_channels),
                    nn.SiLU(),
                    nn.Conv2d(out_channels, out_channels,
                              kernel_size=3, stride=1, padding=1),
                )
                for _ in range(num_layers)
            ]
        )

        # Per-stage self-attention: GroupNorm over channels, then MHA with
        # batch_first token layout (B, H*W, C).
        self.attention_norms = nn.ModuleList(
            [
                nn.GroupNorm(norm_channels, out_channels)
                for _ in range(num_layers)
            ]
        )

        self.attentions = nn.ModuleList(
            [
                nn.MultiheadAttention(
                    out_channels, num_heads, batch_first=True)
                for _ in range(num_layers)
            ]
        )

        if self.cross_attn:
            assert context_dim is not None, "Context Dimension must be passed for cross attention"
            self.cross_attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels)
                 for _ in range(num_layers)]
            )
            self.cross_attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                 for _ in range(num_layers)]
            )
            # Maps context features to out_channels for use as keys/values.
            self.context_proj = nn.ModuleList(
                [nn.Linear(context_dim, out_channels)
                 for _ in range(num_layers)]
            )

        # 1x1 convs for the ResNet skip connections (match channel counts).
        self.residual_input_conv = nn.ModuleList(
            [
                nn.Conv2d(in_channels if i == 0 else out_channels,
                          out_channels, kernel_size=1)
                for i in range(num_layers)
            ]
        )

        # Doubles H and W. Operates on the pre-concat tensor, which carries
        # half of in_channels (the skip connection supplies the other half).
        self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2,
                                                 4, 2, 1) \
            if self.up_sample else nn.Identity()

    def forward(self, x, out_down=None, t_emb=None, context=None):
        """
        Run the up block.

        :param x: input feature map (B, C, H, W); C == in_channels // 2 when
            up-sampling with a skip connection (see class docstring)
        :param out_down: matching down-block output to concatenate, or None
        :param t_emb: time embedding (B, t_emb_dim); required iff t_emb_dim
            was given at construction
        :param context: cross-attention context (B, seq, context_dim);
            required iff cross_attn is enabled
        :return: feature map (B, out_channels, H', W') where H', W' are the
            (possibly upsampled) spatial dims
        """
        # 1) Upsample, then fuse the encoder skip connection channel-wise.
        x = self.up_sample_conv(x)
        if out_down is not None:
            x = torch.cat([x, out_down], dim=1)

        out = x
        for i in range(self.num_layers):
            # ResNet stage: conv halves with optional time-embedding bias,
            # plus a 1x1-projected skip from the stage input.
            resnet_input = out
            out = self.resnet_conv_first[i](out)
            if self.t_emb_dim is not None:
                out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
            out = self.resnet_conv_second[i](out)
            out = out + self.residual_input_conv[i](resnet_input)

            # Self attention over spatial positions: (B, C, H, W) ->
            # (B, H*W, C) tokens, residual add back.
            batch_size, channels, h, w = out.shape
            in_attn = out.reshape(batch_size, channels, h * w)
            in_attn = self.attention_norms[i](in_attn)
            in_attn = in_attn.transpose(1, 2)
            out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
            out_attn = out_attn.transpose(1, 2).reshape(
                batch_size, channels, h, w)
            out = out + out_attn

            # Cross attention: queries from the feature map, keys/values from
            # the projected context.
            if self.cross_attn:
                assert context is not None, "context cannot be None if cross attention layers are used"
                batch_size, channels, h, w = out.shape
                in_attn = out.reshape(batch_size, channels, h * w)
                in_attn = self.cross_attention_norms[i](in_attn)
                in_attn = in_attn.transpose(1, 2)
                assert len(context.shape) == 3, \
                    "Context shape does not match B,_,CONTEXT_DIM"
                assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim, \
                    "Context shape does not match B,_,CONTEXT_DIM"
                context_proj = self.context_proj[i](context)
                out_attn, _ = self.cross_attentions[i](
                    in_attn, context_proj, context_proj)
                out_attn = out_attn.transpose(1, 2).reshape(
                    batch_size, channels, h, w)
                out = out + out_attn

        return out