Yash Nagraj committed on
Commit
6801d5a
·
1 Parent(s): de5e356

Add all the layers in the UpBlock

Browse files
Files changed (1) hide show
  1. models/blocks.py +71 -0
models/blocks.py CHANGED
@@ -152,6 +152,7 @@ class MidBlock(nn.Module):
152
  """
153
 
154
  def __init__(self, in_channels, out_channels, t_emb_dim, num_heads, num_layers, norm_dim, cross_attn=None, context_dim=None):
 
155
  self.in_channels = in_channels
156
  self.out_channels = out_channels
157
  self.t_emb_dim = t_emb_dim
@@ -258,3 +259,73 @@ class MidBlock(nn.Module):
258
  out = out + self.time_emb_layers[i+1](t_emb)[:, :, None, None]
259
  out = out + self.resnet_conv_two[i+1](out)
260
  out = out + self.residual_input_conv[i+1](resnet_input)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  """
153
 
154
  def __init__(self, in_channels, out_channels, t_emb_dim, num_heads, num_layers, norm_dim, cross_attn=None, context_dim=None):
155
+ super().__init__()
156
  self.in_channels = in_channels
157
  self.out_channels = out_channels
158
  self.t_emb_dim = t_emb_dim
 
259
  out = out + self.time_emb_layers[i+1](t_emb)[:, :, None, None]
260
  out = out + self.resnet_conv_two[i+1](out)
261
  out = out + self.residual_input_conv[i+1](resnet_input)
262
+
263
+ return out
264
+
265
+
266
class UpBlock(nn.Module):
    """
    Up Block that upsamples the image, flows like this:
    1) UpSample
    2) Concat down block output
    3) Resnet block with time embedding
    4) Attention Block

    Args:
        in_channels: channels entering the block (after the down-block
            concatenation — assumed handled in forward(); TODO confirm).
        out_channels: channels produced by every resnet layer.
        t_emb_dim: time-embedding width, or None to skip time conditioning.
        up_sample: if True, spatially upsample 2x with a transposed conv;
            otherwise pass through unchanged.
        num_layers: number of resnet(+attention) layers.
        attn: if True, build a self-attention stage per layer.
        norm_channels: group count for every GroupNorm.
        num_heads: attention heads per MultiheadAttention.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, up_sample, num_layers, attn, norm_channels, num_heads):
        super().__init__()
        self.num_layers = num_layers
        self.attn = attn
        self.norm_channels = norm_channels

        # First conv of each resnet layer: layer 0 consumes in_channels,
        # subsequent layers consume the previous layer's out_channels.
        self.resnet_conv_one = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
                nn.SiLU(),
                nn.Conv2d(in_channels if i == 0 else out_channels,
                          out_channels, 3, 1, 1),
            ) for i in range(num_layers)
        ])

        # Per-layer projection of the time embedding into the channel dim.
        # Only built when time conditioning is requested.
        if t_emb_dim is not None:
            self.time_emb_layers = nn.ModuleList(
                [
                    nn.Sequential(
                        nn.SiLU(),
                        nn.Linear(t_emb_dim, out_channels),
                    ) for _ in range(num_layers)
                ]
            )

        # Second conv of each resnet layer (out_channels -> out_channels).
        self.resnet_conv_two = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(norm_channels, out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            ) for _ in range(num_layers)
        ])

        # Optional per-layer self-attention: norm followed by MHA.
        if self.attn:
            self.attention_norms = nn.ModuleList([
                nn.GroupNorm(norm_channels, out_channels)
                for _ in range(num_layers)
            ])

            self.attention_heads = nn.ModuleList(
                [nn.MultiheadAttention(
                    out_channels, num_heads, batch_first=True) for _ in range(num_layers)]
            )

        # Residual/skip projection for each resnet layer.
        # NOTE(review): MidBlock names the analogous module
        # `residual_input_conv` — confirm forward() uses `resnet_input_conv`
        # consistently with this name.
        self.resnet_input_conv = nn.ModuleList([
            nn.Conv2d(in_channels if i == 0 else out_channels,
                      out_channels, 3, 1, 1)
            for i in range(num_layers)
        ])

        # BUG FIX: the original guarded on `self.upsample`, which does not
        # exist yet at this point and raised AttributeError on every
        # construction. The intended condition is the `up_sample` argument.
        self.upsample = nn.ConvTranspose2d(
            in_channels, in_channels, 4, 2, 1) if up_sample else nn.Identity()