Yash Nagraj committed on
Commit
de5e356
·
1 Parent(s): 38d054a

Add forward function to the midblock

Browse files
Files changed (1) hide show
  1. models/blocks.py +40 -0
models/blocks.py CHANGED
@@ -157,6 +157,7 @@ class MidBlock(nn.Module):
157
  self.t_emb_dim = t_emb_dim
158
  self.cross_attn = cross_attn
159
  self.context_dim = context_dim
 
160
  self.resnet_conv_one = nn.ModuleList([
161
  nn.Sequential(
162
  nn.GroupNorm(norm_dim, in_channels if i ==
@@ -218,3 +219,42 @@ class MidBlock(nn.Module):
218
  for i in range(num_layers + 1)
219
 
220
  ])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  self.t_emb_dim = t_emb_dim
158
  self.cross_attn = cross_attn
159
  self.context_dim = context_dim
160
+ self.num_layers = num_layers
161
  self.resnet_conv_one = nn.ModuleList([
162
  nn.Sequential(
163
  nn.GroupNorm(norm_dim, in_channels if i ==
 
219
  for i in range(num_layers + 1)
220
 
221
  ])
222
+
223
+ def forward(self, x, t_emb=None, context=None):
224
+ out = x
225
+ resnet_input = out
226
+ out = self.resnet_conv_one[0](out)
227
+ if self.t_emb_dim is not None:
228
+ out = out + self.time_emb_layers[0](t_emb)[:, :, None, None]
229
+ out = self.resnet_conv_two[0](out)
230
+ out = out + self.residual_input_conv[0](resnet_input)
231
+
232
+ for i in range(self.num_layers):
233
+ batch_size, channels, h, w = out.shape
234
+ in_attn = out.reshape(batch_size, channels, h*w)
235
+ in_attn = self.attention_norms[i](in_attn)
236
+ in_attn = in_attn.transpose(1, 2)
237
+ out_attn, _ = self.attention_heads[i](in_attn, in_attn, in_attn)
238
+ out_attn = out_attn.reshape(batch_size, channels, h, w)
239
+ out = out + out_attn
240
+
241
+ if self.cross_attn:
242
+ assert context is not None, "Context needed when using cross attn"
243
+ batch_size, channels, h, w = out.shape
244
+ in_attn = out.reshape(batch_size, channels, h*w)
245
+ in_attn = self.cross_attn_norms[i](in_attn)
246
+ in_attn = in_attn.transpose(1, 2)
247
+ assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
248
+ context_proj = self.context_proj[i](context)
249
+ out_attn, _ = self.cross_attn[i](
250
+ in_attn, context_proj, context_proj)
251
+ out_attn = out_attn.transpose(1, 2).reshape(
252
+ batch_size, channels, h, w)
253
+ out = out + out_attn
254
+
255
+ resnet_input = out
256
+ out = self.resnet_conv_one[i+1](out)
257
+ if self.t_emb_dim is not None:
258
+ out = out + self.time_emb_layers[i+1](t_emb)[:, :, None, None]
259
+ out = out + self.resnet_conv_two[i+1](out)
260
+ out = out + self.residual_input_conv[i+1](resnet_input)