Yash Nagraj commited on
Commit
38d054a
·
1 Parent(s): c2857b5

Add residual convolution to the mid block

Browse files
Files changed (1) hide show
  1. models/blocks.py +18 -0
models/blocks.py CHANGED
@@ -200,3 +200,21 @@ class MidBlock(nn.Module):
200
  [nn.GroupNorm(norm_dim, out_channels)
201
  for _ in range(num_layers)]
202
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  [nn.GroupNorm(norm_dim, out_channels)
201
  for _ in range(num_layers)]
202
  )
203
+
204
+ self.cross_attn = nn.ModuleList(
205
+ [nn.MultiheadAttention(
206
+ out_channels, num_heads=num_heads, batch_first=True) for _ in range(num_layers)]
207
+ )
208
+
209
+ self.context_proj = nn.ModuleList([
210
+ nn.Conv2d(in_channels if i == 0 else out_channels,
211
+ out_channels, kernel_size=1)
212
+ for i in range(num_layers + 1)
213
+ ])
214
+
215
+ self.residual_input_conv = nn.ModuleList([
216
+ nn.Conv2d(in_channels if i == 0 else out_channels,
217
+ out_channels, kernel_size=1)
218
+ for i in range(num_layers + 1)
219
+
220
+ ])