Yash Nagraj committed on
Commit ·
6801d5a
1
Parent(s): de5e356
Add all the layers in the UpBlock
Browse files — models/blocks.py +71 -0
models/blocks.py
CHANGED
|
@@ -152,6 +152,7 @@ class MidBlock(nn.Module):
|
|
| 152 |
"""
|
| 153 |
|
| 154 |
def __init__(self, in_channels, out_channels, t_emb_dim, num_heads, num_layers, norm_dim, cross_attn=None, context_dim=None):
|
|
|
|
| 155 |
self.in_channels = in_channels
|
| 156 |
self.out_channels = out_channels
|
| 157 |
self.t_emb_dim = t_emb_dim
|
|
@@ -258,3 +259,73 @@ class MidBlock(nn.Module):
|
|
| 258 |
out = out + self.time_emb_layers[i+1](t_emb)[:, :, None, None]
|
| 259 |
out = out + self.resnet_conv_two[i+1](out)
|
| 260 |
out = out + self.residual_input_conv[i+1](resnet_input)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
"""
|
| 153 |
|
| 154 |
def __init__(self, in_channels, out_channels, t_emb_dim, num_heads, num_layers, norm_dim, cross_attn=None, context_dim=None):
|
| 155 |
+
super().__init__()
|
| 156 |
self.in_channels = in_channels
|
| 157 |
self.out_channels = out_channels
|
| 158 |
self.t_emb_dim = t_emb_dim
|
|
|
|
| 259 |
out = out + self.time_emb_layers[i+1](t_emb)[:, :, None, None]
|
| 260 |
out = out + self.resnet_conv_two[i+1](out)
|
| 261 |
out = out + self.residual_input_conv[i+1](resnet_input)
|
| 262 |
+
|
| 263 |
+
return out
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
class UpBlock(nn.Module):
|
| 267 |
+
"""
|
| 268 |
+
Up Block that upsamples the image, flows like this:
|
| 269 |
+
1) UpSample
|
| 270 |
+
2) Concat down block output
|
| 271 |
+
3) Resnet block with time embedding
|
| 272 |
+
4) Attention Block
|
| 273 |
+
"""
|
| 274 |
+
|
| 275 |
+
def __init__(self, in_channels, out_channels, t_emb_dim, up_sample, num_layers, attn, norm_channels, num_heads):
|
| 276 |
+
super().__init__()
|
| 277 |
+
self.num_layers = num_layers
|
| 278 |
+
self.attn = attn
|
| 279 |
+
self.norm_channels = norm_channels
|
| 280 |
+
self.resnet_conv_one = nn.ModuleList([
|
| 281 |
+
nn.Sequential(
|
| 282 |
+
nn.GroupNorm(norm_channels, in_channels if i ==
|
| 283 |
+
0 else out_channels),
|
| 284 |
+
nn.SiLU(),
|
| 285 |
+
nn.Conv2d(in_channels if i == 0 else out_channels,
|
| 286 |
+
out_channels, 3, 1, 1)
|
| 287 |
+
|
| 288 |
+
) for i in range(num_layers)
|
| 289 |
+
|
| 290 |
+
])
|
| 291 |
+
|
| 292 |
+
if t_emb_dim is not None:
|
| 293 |
+
self.time_emb_layers = nn.ModuleList(
|
| 294 |
+
[
|
| 295 |
+
nn.Sequential(
|
| 296 |
+
nn.SiLU(),
|
| 297 |
+
nn.Linear(t_emb_dim, out_channels)
|
| 298 |
+
) for _ in range(num_layers)
|
| 299 |
+
|
| 300 |
+
]
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
self.resnet_conv_two = nn.ModuleList([
|
| 304 |
+
nn.Sequential(
|
| 305 |
+
nn.GroupNorm(norm_channels, out_channels),
|
| 306 |
+
nn.SiLU(),
|
| 307 |
+
nn.Conv2d(out_channels, out_channels, 3, 1, 1)
|
| 308 |
+
|
| 309 |
+
) for _ in range(num_layers)
|
| 310 |
+
|
| 311 |
+
])
|
| 312 |
+
|
| 313 |
+
if self.attn:
|
| 314 |
+
self.attention_norms = nn.ModuleList([
|
| 315 |
+
nn.GroupNorm(norm_channels, out_channels)
|
| 316 |
+
for _ in range(num_layers)
|
| 317 |
+
])
|
| 318 |
+
|
| 319 |
+
self.attention_heads = nn.ModuleList(
|
| 320 |
+
[nn.MultiheadAttention(
|
| 321 |
+
out_channels, num_heads, batch_first=True) for _ in range(num_layers)]
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
self.resnet_input_conv = nn.ModuleList([
|
| 325 |
+
nn.Conv2d(in_channels if i == 0 else out_channels,
|
| 326 |
+
out_channels, 3, 1, 1)
|
| 327 |
+
for i in range(num_layers)
|
| 328 |
+
])
|
| 329 |
+
|
| 330 |
+
self.upsample = nn.ConvTranspose2d(
|
| 331 |
+
in_channels, in_channels, 4, 2, 1) if self.upsample else nn.Identity()
|