Yash Nagraj commited on
Commit ·
38d054a
1
Parent(s): c2857b5
Add residual convolution to the mid block
Browse files- models/blocks.py +18 -0
models/blocks.py
CHANGED
|
@@ -200,3 +200,21 @@ class MidBlock(nn.Module):
|
|
| 200 |
[nn.GroupNorm(norm_dim, out_channels)
|
| 201 |
for _ in range(num_layers)]
|
| 202 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
[nn.GroupNorm(norm_dim, out_channels)
|
| 201 |
for _ in range(num_layers)]
|
| 202 |
)
|
| 203 |
+
|
| 204 |
+
self.cross_attn = nn.ModuleList(
|
| 205 |
+
[nn.MultiheadAttention(
|
| 206 |
+
out_channels, num_heads=num_heads, batch_first=True) for _ in range(num_layers)]
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
self.context_proj = nn.ModuleList([
|
| 210 |
+
nn.Conv2d(in_channels if i == 0 else out_channels,
|
| 211 |
+
out_channels, kernel_size=1)
|
| 212 |
+
for i in range(num_layers + 1)
|
| 213 |
+
])
|
| 214 |
+
|
| 215 |
+
self.residual_input_conv = nn.ModuleList([
|
| 216 |
+
nn.Conv2d(in_channels if i == 0 else out_channels,
|
| 217 |
+
out_channels, kernel_size=1)
|
| 218 |
+
for i in range(num_layers + 1)
|
| 219 |
+
|
| 220 |
+
])
|