Yash Nagraj committed on
Commit ·
de5e356
1
Parent(s): 38d054a
Add forward function to the midblock
Browse files- models/blocks.py +40 -0
models/blocks.py
CHANGED
|
@@ -157,6 +157,7 @@ class MidBlock(nn.Module):
|
|
| 157 |
self.t_emb_dim = t_emb_dim
|
| 158 |
self.cross_attn = cross_attn
|
| 159 |
self.context_dim = context_dim
|
|
|
|
| 160 |
self.resnet_conv_one = nn.ModuleList([
|
| 161 |
nn.Sequential(
|
| 162 |
nn.GroupNorm(norm_dim, in_channels if i ==
|
|
@@ -218,3 +219,42 @@ class MidBlock(nn.Module):
|
|
| 218 |
for i in range(num_layers + 1)
|
| 219 |
|
| 220 |
])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
self.t_emb_dim = t_emb_dim
|
| 158 |
self.cross_attn = cross_attn
|
| 159 |
self.context_dim = context_dim
|
| 160 |
+
self.num_layers = num_layers
|
| 161 |
self.resnet_conv_one = nn.ModuleList([
|
| 162 |
nn.Sequential(
|
| 163 |
nn.GroupNorm(norm_dim, in_channels if i ==
|
|
|
|
| 219 |
for i in range(num_layers + 1)
|
| 220 |
|
| 221 |
])
|
| 222 |
+
|
| 223 |
+
def forward(self, x, t_emb=None, context=None):
    """Run the mid-block: resnet -> [self-attn -> (cross-attn) -> resnet] x num_layers.

    Args:
        x: input feature map of shape (batch, channels, h, w).
        t_emb: optional time embedding; required when ``self.t_emb_dim`` is set.
        context: conditioning sequence of shape (batch, seq_len, context_dim);
            required when ``self.cross_attn`` is enabled.

    Returns:
        Feature map of shape (batch, channels, h, w) after the final resnet block.
    """
    out = x

    # First resnet block (conv -> optional time embedding -> conv -> skip).
    resnet_input = out
    out = self.resnet_conv_one[0](out)
    if self.t_emb_dim is not None:
        out = out + self.time_emb_layers[0](t_emb)[:, :, None, None]
    out = self.resnet_conv_two[0](out)
    out = out + self.residual_input_conv[0](resnet_input)

    for i in range(self.num_layers):
        # Self-attention over spatial positions: flatten (h, w) into a sequence.
        batch_size, channels, h, w = out.shape
        in_attn = out.reshape(batch_size, channels, h * w)
        in_attn = self.attention_norms[i](in_attn)
        in_attn = in_attn.transpose(1, 2)
        out_attn, _ = self.attention_heads[i](in_attn, in_attn, in_attn)
        # FIX: transpose back to (batch, channels, h*w) before reshaping;
        # reshaping (batch, h*w, channels) directly scrambles the feature map
        # (the cross-attn branch below already does this correctly).
        out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
        out = out + out_attn

        if self.cross_attn:
            assert context is not None, "Context needed when using cross attn"
            batch_size, channels, h, w = out.shape
            in_attn = out.reshape(batch_size, channels, h * w)
            in_attn = self.cross_attn_norms[i](in_attn)
            in_attn = in_attn.transpose(1, 2)
            assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
            context_proj = self.context_proj[i](context)
            # NOTE(review): __init__ assigns the boolean flag to
            # ``self.cross_attn``, yet it is indexed here as a ModuleList —
            # confirm the cross-attention modules' attribute name; as written
            # this would fail when cross attention is enabled.
            out_attn, _ = self.cross_attn[i](
                in_attn, context_proj, context_proj)
            out_attn = out_attn.transpose(1, 2).reshape(
                batch_size, channels, h, w)
            out = out + out_attn

        # Following resnet block, mirroring the first one.
        resnet_input = out
        out = self.resnet_conv_one[i + 1](out)
        if self.t_emb_dim is not None:
            out = out + self.time_emb_layers[i + 1](t_emb)[:, :, None, None]
        # FIX: apply the second conv directly (was ``out = out + conv(out)``,
        # inconsistent with the first resnet block and double-counting ``out``).
        out = self.resnet_conv_two[i + 1](out)
        out = out + self.residual_input_conv[i + 1](resnet_input)

    # FIX: the committed version fell off the end and implicitly returned None.
    return out
|