Yash Nagraj committed on
Commit ·
70a401a
1
Parent(s): 6801d5a
Add UpBlock Unet to concat downblock's output
Browse files- models/blocks.py +122 -53
models/blocks.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
|
|
@@ -263,69 +264,137 @@ class MidBlock(nn.Module):
|
|
| 263 |
return out
|
| 264 |
|
| 265 |
|
| 266 |
-
class
|
| 267 |
-
"""
|
| 268 |
-
Up
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
|
|
|
| 273 |
"""
|
| 274 |
|
| 275 |
-
def __init__(self, in_channels, out_channels, t_emb_dim, up_sample,
|
|
|
|
| 276 |
super().__init__()
|
| 277 |
self.num_layers = num_layers
|
| 278 |
-
self.
|
| 279 |
-
self.
|
| 280 |
-
self.
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
nn.
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
[
|
| 295 |
-
nn.Sequential(
|
| 296 |
-
nn.SiLU(),
|
| 297 |
-
nn.Linear(t_emb_dim, out_channels)
|
| 298 |
-
) for _ in range(num_layers)
|
| 299 |
-
|
| 300 |
-
]
|
| 301 |
-
)
|
| 302 |
-
|
| 303 |
-
self.resnet_conv_two = nn.ModuleList([
|
| 304 |
-
nn.Sequential(
|
| 305 |
-
nn.GroupNorm(norm_channels, out_channels),
|
| 306 |
-
nn.SiLU(),
|
| 307 |
-
nn.Conv2d(out_channels, out_channels, 3, 1, 1)
|
| 308 |
|
| 309 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
|
| 311 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
|
| 313 |
-
|
| 314 |
-
|
| 315 |
nn.GroupNorm(norm_channels, out_channels)
|
| 316 |
for _ in range(num_layers)
|
| 317 |
-
]
|
|
|
|
| 318 |
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 322 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 323 |
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
])
|
| 329 |
|
| 330 |
-
|
| 331 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from re import A
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
| 4 |
|
|
|
|
| 264 |
return out
|
| 265 |
|
| 266 |
|
| 267 |
+
class UpBlockUnet(nn.Module):
    r"""
    Up conv block with attention for the decoder side of a U-Net.

    Sequence of operations performed by ``forward``:
    1. Upsample (transposed convolution, or identity when ``up_sample`` is falsy)
    2. Concatenate the down block output along the channel axis (if given)
    3. For each of the ``num_layers`` layers:
       a. Resnet block, with the time embedding added when ``t_emb_dim`` is set
       b. Self-attention block
       c. Cross-attention block against ``context`` (only when ``cross_attn``)

    Args:
        in_channels: Channel count entering the first resnet layer, i.e. the
            channel count *after* concatenation. The transposed convolution
            itself operates on ``in_channels // 2`` channels.
        out_channels: Channel count produced by every layer of the block.
        t_emb_dim: Size of the time embedding, or ``None`` to build the block
            without time-embedding layers.
        up_sample: When truthy, a stride-2 transposed convolution is applied
            first; otherwise the input passes through unchanged.
        num_heads: Number of heads for every ``nn.MultiheadAttention``.
        num_layers: Number of resnet + attention layers.
        norm_channels: Number of groups passed to every ``nn.GroupNorm``.
        cross_attn: Whether to add cross-attention layers.
        context_dim: Last-dimension size of ``context``; required when
            ``cross_attn`` is true.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, up_sample,
                 num_heads, num_layers, norm_channels, cross_attn=False, context_dim=None):
        super().__init__()
        self.num_layers = num_layers
        self.up_sample = up_sample
        self.t_emb_dim = t_emb_dim
        self.cross_attn = cross_attn
        self.context_dim = context_dim

        # Only the first layer sees the concatenated input; every later layer
        # consumes the previous layer's out_channels.
        layer_in_channels = [
            in_channels if i == 0 else out_channels for i in range(num_layers)
        ]

        # NOTE: attribute names and creation order are kept stable so that
        # existing checkpoints (state_dict keys) keep loading.
        self.resnet_conv_first = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(norm_channels, channels),
                    nn.SiLU(),
                    nn.Conv2d(channels, out_channels,
                              kernel_size=3, stride=1, padding=1),
                )
                for channels in layer_in_channels
            ]
        )

        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList(
                [
                    nn.Sequential(
                        nn.SiLU(),
                        nn.Linear(t_emb_dim, out_channels),
                    )
                    for _ in range(num_layers)
                ]
            )

        self.resnet_conv_second = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(norm_channels, out_channels),
                    nn.SiLU(),
                    nn.Conv2d(out_channels, out_channels,
                              kernel_size=3, stride=1, padding=1),
                )
                for _ in range(num_layers)
            ]
        )

        self.attention_norms = nn.ModuleList(
            [
                nn.GroupNorm(norm_channels, out_channels)
                for _ in range(num_layers)
            ]
        )

        self.attentions = nn.ModuleList(
            [
                nn.MultiheadAttention(
                    out_channels, num_heads, batch_first=True)
                for _ in range(num_layers)
            ]
        )

        if self.cross_attn:
            assert context_dim is not None, "Context Dimension must be passed for cross attention"
            self.cross_attention_norms = nn.ModuleList(
                [
                    nn.GroupNorm(norm_channels, out_channels)
                    for _ in range(num_layers)
                ]
            )
            self.cross_attentions = nn.ModuleList(
                [
                    nn.MultiheadAttention(
                        out_channels, num_heads, batch_first=True)
                    for _ in range(num_layers)
                ]
            )
            # Projects the context to out_channels so it can serve as
            # key/value for the cross-attention layers.
            self.context_proj = nn.ModuleList(
                [
                    nn.Linear(context_dim, out_channels)
                    for _ in range(num_layers)
                ]
            )

        # 1x1 convolutions that bring the resnet input to out_channels so it
        # can be added back as the residual.
        self.residual_input_conv = nn.ModuleList(
            [
                nn.Conv2d(channels, out_channels, kernel_size=1)
                for channels in layer_in_channels
            ]
        )

        # Operates on half of in_channels: the other half is expected to come
        # from the concatenated down block output.
        if self.up_sample:
            self.up_sample_conv = nn.ConvTranspose2d(
                in_channels // 2, in_channels // 2, 4, 2, 1)
        else:
            self.up_sample_conv = nn.Identity()

    @staticmethod
    def _attend(norm, attention, feature_map, memory=None):
        """Apply one normalised attention layer with a residual connection.

        The ``(B, C, H, W)`` feature map is flattened to ``(B, H*W, C)``
        tokens, normalised, used as the attention query, and the result is
        reshaped back and added to the input. ``memory`` supplies key/value;
        when it is ``None`` the tokens attend to themselves.
        """
        batch_size, channels, h, w = feature_map.shape
        tokens = feature_map.reshape(batch_size, channels, h * w)
        tokens = norm(tokens).transpose(1, 2)
        if memory is None:
            memory = tokens
        attended, _ = attention(tokens, memory, memory)
        attended = attended.transpose(1, 2).reshape(
            batch_size, channels, h, w)
        return feature_map + attended

    def forward(self, x, out_down=None, t_emb=None, context=None):
        """Run the up block.

        Args:
            x: Input feature map, indexed as ``(B, C, H, W)``.
            out_down: Optional down block output concatenated to the
                upsampled ``x`` along dim 1.
            t_emb: Time embedding; required when the block was built with a
                ``t_emb_dim``.
            context: Conditioning tensor of shape ``(B, _, context_dim)``;
                required when the block was built with ``cross_attn=True``.

        Returns:
            The feature map after all layers.

        Raises:
            ValueError: If the block has time-embedding layers but ``t_emb``
                is ``None``.
            AssertionError: If cross attention is enabled and ``context`` is
                missing or has an unexpected shape.
        """
        # Fail early with a clear message instead of an opaque error from
        # inside the time-embedding layers.
        if self.t_emb_dim is not None and t_emb is None:
            raise ValueError(
                "t_emb cannot be None if the block was built with t_emb_dim")

        x = self.up_sample_conv(x)
        if out_down is not None:
            x = torch.cat([x, out_down], dim=1)

        # The context checks do not depend on the layer index, so they are
        # validated once up front rather than on every iteration.
        if self.cross_attn:
            assert context is not None, "context cannot be None if cross attention layers are used"
            assert len(context.shape) == 3, \
                "Context shape does not match B,_,CONTEXT_DIM"
            assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim, \
                "Context shape does not match B,_,CONTEXT_DIM"

        out = x
        for i in range(self.num_layers):
            # Resnet
            resnet_input = out
            out = self.resnet_conv_first[i](out)
            if self.t_emb_dim is not None:
                # Broadcast the per-channel embedding over the spatial dims.
                out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
            out = self.resnet_conv_second[i](out)
            out = out + self.residual_input_conv[i](resnet_input)

            # Self Attention
            out = self._attend(self.attention_norms[i], self.attentions[i], out)

            # Cross Attention
            if self.cross_attn:
                context_proj = self.context_proj[i](context)
                out = self._attend(
                    self.cross_attention_norms[i],
                    self.cross_attentions[i],
                    out,
                    context_proj,
                )

        return out
|