ControlNet / model_blocks /blocks.py
YashNagraj75's picture
Add UpBlock
e1d97a8
import logging
import torch
import torch.nn as nn
# Module-level logger named after this module; DownBlock and MidBlock use it
# for DEBUG-level shape tracing inside their forward passes.
logger = logging.getLogger(__name__)
def get_time_embedding(time_steps, temb_dim):
    r"""
    Map a batch of scalar time steps to sinusoidal embeddings.

    With ``half = temb_dim // 2`` and ``k`` in ``[0, half)``, the k-th
    divisor is ``10000 ** (k / half)``. The first ``half`` output columns
    hold ``sin(t / divisor)`` and the last ``half`` hold ``cos(t / divisor)``.

    :param time_steps: 1D tensor of length batch size
    :param temb_dim: Dimension of the embedding (must be even)
    :return: BxD embedding representation of B time steps
    :raises AssertionError: if ``temb_dim`` is odd
    """
    assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"
    half_dim = temb_dim // 2
    exponents = (
        torch.arange(
            start=0, end=half_dim, dtype=torch.float32, device=time_steps.device
        )
        / half_dim
    )
    divisors = 10000 ** exponents
    # (B,) -> (B, 1), broadcast against (half_dim,) -> (B, half_dim)
    angles = time_steps[:, None] / divisors
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
class DownBlock(nn.Module):
    r"""
    DownBlock for the diffusion UNet encoder.

    Each of ``num_layers`` layers applies, in order:
      a) Resnet block :- [Norm -> SiLU -> Conv3x3], plus the time embedding
         projected by [SiLU -> FC] (when ``t_emb_dim`` is set), then
         [Norm -> SiLU -> Conv3x3], plus a 1x1-conv skip of the layer input.
      b) Self attention :- [Norm -> SA] with residual (when ``self_attn``).
      c) Cross attention :- [Norm -> CA over projected ``context``] with
         residual (when ``cross_attn``).
    After the last layer, a 4x4 stride-2 conv downsamples the spatial
    dimensions when ``down_sample`` is true (identity otherwise).

    :param num_heads: attention heads for both self and cross attention
    :param num_layers: number of resnet(+attention) layers
    :param cross_attn: build and apply cross attention layers
    :param input_dim: channels of the block input
    :param output_dim: channels produced by every layer
    :param t_emb_dim: size of the time embedding, or None to skip it
    :param cond_dim: last-dimension size of ``context`` (cross attention only)
    :param norm_channels: number of GroupNorm groups
    :param self_attn: build and apply self attention layers
    :param down_sample: downsample the spatial dimensions at the end
    """

    def __init__(
        self,
        num_heads,
        num_layers,
        cross_attn,
        input_dim,
        output_dim,
        t_emb_dim,
        cond_dim,
        norm_channels,
        self_attn,
        down_sample,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.cross_attn = cross_attn
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.cond_dim = cond_dim
        self.norm_channels = norm_channels
        self.t_emb_dim = t_emb_dim
        self.attn = self_attn
        self.down_sample = down_sample
        # NOTE: submodules are created in a fixed order; keep it so that
        # state_dict keys and seeded initialization stay reproducible.
        # 1x1 conv projecting each layer's input onto output_dim for the skip.
        self.resnet_in = nn.ModuleList(
            [
                nn.Conv2d(
                    self.input_dim if i == 0 else self.output_dim,
                    self.output_dim,
                    kernel_size=1,
                )
                for i in range(self.num_layers)
            ]
        )
        self.resnet_one = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(
                        self.norm_channels,
                        self.input_dim if i == 0 else self.output_dim,
                    ),
                    nn.SiLU(),
                    nn.Conv2d(
                        self.input_dim if i == 0 else self.output_dim,
                        self.output_dim,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                    ),
                )
                for i in range(self.num_layers)
            ]
        )
        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList(
                [
                    nn.Sequential(nn.SiLU(), nn.Linear(self.t_emb_dim, self.output_dim))
                    for _ in range(self.num_layers)
                ]
            )
        self.resnet_two = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(
                        self.norm_channels,
                        self.output_dim,
                    ),
                    nn.SiLU(),
                    nn.Conv2d(
                        self.output_dim,
                        self.output_dim,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                    ),
                )
                for _ in range(self.num_layers)
            ]
        )
        if self.attn:
            self.attention_norms = nn.ModuleList(
                [
                    nn.GroupNorm(self.norm_channels, self.output_dim)
                    for _ in range(self.num_layers)
                ]
            )
            self.attentions = nn.ModuleList(
                [
                    nn.MultiheadAttention(
                        self.output_dim, self.num_heads, batch_first=True
                    )
                    for _ in range(self.num_layers)
                ]
            )
        if self.cross_attn:
            assert self.cond_dim is not None, (
                "Context Dimension must be passed for cross attention"
            )
            self.cross_attn_norms = nn.ModuleList(
                [
                    nn.GroupNorm(self.norm_channels, self.output_dim)
                    for _ in range(self.num_layers)
                ]
            )
            self.cross_attentions = nn.ModuleList(
                [
                    nn.MultiheadAttention(
                        self.output_dim, self.num_heads, batch_first=True
                    )
                    for _ in range(self.num_layers)
                ]
            )
            # Projects context features (cond_dim) to the attention width.
            self.context_proj = nn.ModuleList(
                [
                    nn.Linear(self.cond_dim, self.output_dim)
                    for _ in range(self.num_layers)
                ]
            )
        self.down_sample_conv = (
            nn.Conv2d(self.output_dim, self.output_dim, 4, 2, 1)
            if self.down_sample
            else nn.Identity()
        )

    def forward(self, x, t_emb=None, context=None):
        r"""
        :param x: input feature map of shape (B, input_dim, H, W)
        :param t_emb: time embedding of shape (B, t_emb_dim); required when
            ``t_emb_dim`` is set
        :param context: conditioning tensor of shape (B, seq, cond_dim);
            required when ``cross_attn`` is true
        :return: (B, output_dim, H // 2, W // 2) when ``down_sample`` is true,
            otherwise (B, output_dim, H, W)
        :raises AssertionError: if cross attention is enabled and ``context``
            is missing or mis-shaped
        """
        # Logging uses lazy %-args so messages are only formatted when DEBUG
        # logging is actually enabled.
        out = x
        for i in range(self.num_layers):
            # Resnet block of the UNet encoder
            logger.debug("Input to Resnet Block in Down Block Layer %s : %s", i, out.shape)
            resnet_input = out
            out = self.resnet_one[i](out)
            logger.debug(
                "Output of Resnet Sub Block 1 of Down Block Layer %s : %s", i, out.shape
            )
            if self.t_emb_dim is not None:
                logger.debug(
                    "Adding t_emb of shape %s to output of shape: %s of Down Block Layer %s",
                    self.t_emb_dim,
                    out.shape,
                    i,
                )
                # (B, output_dim) -> (B, output_dim, 1, 1) to broadcast over H, W
                out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
            out = self.resnet_two[i](out)
            logger.debug(
                "Output of Resnet Sub Block 2 of Down Block Layer: %s with output_shape:%s",
                i,
                out.shape,
            )
            out = out + self.resnet_in[i](resnet_input)
            logger.debug(
                "Residual connection of the input to out : %s in Down Block Layer %s",
                out.shape,
                i,
            )

            if self.attn:
                # Self attention over the H*W spatial positions as tokens
                logger.debug("Going into the attention Block in Down Block Layer %s", i)
                batch_size, channels, h, w = out.shape
                in_attn = out.reshape(batch_size, channels, h * w)
                in_attn = self.attention_norms[i](in_attn)
                in_attn = in_attn.transpose(1, 2)
                out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
                out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
                out = out + out_attn
                logger.debug(
                    "Out of the Self Attention Block with out : %s in Down Block Layer %s",
                    out.shape,
                    i,
                )

            if self.cross_attn:
                assert context is not None, (
                    "context cannot be None if cross attention layers are used"
                )
                logger.debug(
                    "Going into the Cross Attention Block in Down Block Layer %s", i
                )
                batch_size, channels, h, w = out.shape
                in_attn = out.reshape(batch_size, channels, h * w)
                in_attn = self.cross_attn_norms[i](in_attn)
                in_attn = in_attn.transpose(1, 2)
                assert len(context.shape) == 3, (
                    "Context shape does not match B,_,CONTEXT_DIM"
                )
                # The width is stored as cond_dim (there is no context_dim here).
                assert (
                    context.shape[0] == x.shape[0]
                    and context.shape[-1] == self.cond_dim
                ), "Context shape does not match B,_,CONTEXT_DIM"
                logger.debug(
                    "Calculating context projection for Cross Attn in Down Block Layer : %s",
                    i,
                )
                context_proj = self.context_proj[i](context)
                out_attn, _ = self.cross_attentions[i](
                    in_attn, context_proj, context_proj
                )
                out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
                out = out + out_attn
                logger.debug(
                    "Out of the Cross Attention Block with out : %s in Down Block Layer %s",
                    out.shape,
                    i,
                )

        # Downsample the spatial dimensions (stride-2 conv) or pass through.
        out = self.down_sample_conv(out)
        logger.debug("Down Sampling out to : %s in Down Block Layer %s ", out.shape, i)
        return out
class MidBlock(nn.Module):
    r"""
    MidBlock (bottleneck) for the diffusion UNet.

    Layout, with ``num_layers + 1`` resnet blocks interleaved with attention:
        resnet[0]
        for i in range(num_layers):
            self attention[i]   :- [Norm -> SA] with residual (if ``self_attn``)
            cross attention[i]  :- [Norm -> CA] with residual (if ``cross_attn``)
            resnet[i + 1]
    Each resnet block is [Norm -> SiLU -> Conv3x3], an optional time
    embedding add ([SiLU -> FC]), [Norm -> SiLU -> Conv3x3], and a 1x1-conv
    skip of its input. Spatial size is preserved. ``down_sample`` is stored
    (the signature matches DownBlock) but is not used by this block.

    :param num_heads: attention heads for both self and cross attention
    :param num_layers: number of attention stages (resnet blocks = num_layers + 1)
    :param cross_attn: build and apply cross attention layers
    :param input_dim: channels of the block input
    :param output_dim: channels produced by every resnet block
    :param t_emb_dim: size of the time embedding, or None to skip it
    :param cond_dim: last-dimension size of ``context`` (cross attention only)
    :param norm_channels: number of GroupNorm groups
    :param self_attn: build and apply self attention layers
    :param down_sample: unused here
    """

    def __init__(
        self,
        num_heads,
        num_layers,
        cross_attn,
        input_dim,
        output_dim,
        t_emb_dim,
        cond_dim,
        norm_channels,
        self_attn,
        down_sample,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.cross_attn = cross_attn
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.cond_dim = cond_dim
        self.norm_channels = norm_channels
        self.t_emb_dim = t_emb_dim
        self.attn = self_attn
        self.down_sample = down_sample
        # NOTE: submodules are created in a fixed order; keep it so that
        # state_dict keys and seeded initialization stay reproducible.
        self.resnet_one = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(
                        self.norm_channels,
                        self.input_dim if i == 0 else self.output_dim,
                    ),
                    nn.SiLU(),
                    nn.Conv2d(
                        self.input_dim if i == 0 else self.output_dim,
                        self.output_dim,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                    ),
                )
                for i in range(self.num_layers + 1)
            ]
        )
        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList(
                [
                    nn.Sequential(nn.SiLU(), nn.Linear(self.t_emb_dim, self.output_dim))
                    for _ in range(self.num_layers + 1)
                ]
            )
        self.resnet_two = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(
                        self.norm_channels,
                        self.output_dim,
                    ),
                    nn.SiLU(),
                    nn.Conv2d(
                        self.output_dim,
                        self.output_dim,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                    ),
                )
                for _ in range(self.num_layers + 1)
            ]
        )
        if self.attn:
            self.attention_norms = nn.ModuleList(
                [
                    nn.GroupNorm(self.norm_channels, self.output_dim)
                    for _ in range(self.num_layers)
                ]
            )
            self.attentions = nn.ModuleList(
                [
                    nn.MultiheadAttention(
                        self.output_dim, self.num_heads, batch_first=True
                    )
                    for _ in range(self.num_layers)
                ]
            )
        if self.cross_attn:
            assert self.cond_dim is not None, (
                "Context Dimension must be passed for cross attention"
            )
            self.cross_attn_norms = nn.ModuleList(
                [
                    nn.GroupNorm(self.norm_channels, self.output_dim)
                    for _ in range(self.num_layers)
                ]
            )
            self.cross_attentions = nn.ModuleList(
                [
                    nn.MultiheadAttention(
                        self.output_dim, self.num_heads, batch_first=True
                    )
                    for _ in range(self.num_layers)
                ]
            )
            # Projects context features (cond_dim) to the attention width.
            self.context_proj = nn.ModuleList(
                [
                    nn.Linear(self.cond_dim, self.output_dim)
                    for _ in range(self.num_layers)
                ]
            )
        # 1x1 conv projecting each resnet block's input onto output_dim.
        self.resnet_in = nn.ModuleList(
            [
                nn.Conv2d(
                    self.input_dim if i == 0 else self.output_dim,
                    self.output_dim,
                    kernel_size=1,
                )
                for i in range(self.num_layers + 1)
            ]
        )

    def forward(self, x, t_emb=None, context=None):
        r"""
        :param x: input feature map of shape (B, input_dim, H, W)
        :param t_emb: time embedding of shape (B, t_emb_dim); required when
            ``t_emb_dim`` is set
        :param context: conditioning tensor of shape (B, seq, cond_dim);
            required when ``cross_attn`` is true
        :return: feature map of shape (B, output_dim, H, W)
        :raises AssertionError: if cross attention is enabled and ``context``
            is missing or mis-shaped
        """
        # Logging uses lazy %-args so messages are only formatted when DEBUG
        # logging is actually enabled.
        out = x
        # Input Resnet Block (stage 0)
        logger.debug("Input to First Resnet Block in Mid Block")
        resnet_input = out
        out = self.resnet_one[0](out)
        logger.debug("Output of Resnet Sub Block 1 of Mid Block Layer: %s", out.shape)
        if self.t_emb_dim is not None:
            # (B, output_dim) -> (B, output_dim, 1, 1) to broadcast over H, W
            out = out + self.t_emb_layers[0](t_emb)[:, :, None, None]
            logger.debug(
                "Adding t_emb of shape %s to output of shape: %s",
                self.t_emb_dim,
                out.shape,
            )
        out = self.resnet_two[0](out)
        logger.debug("Output of Resnet Sub Block 2 with output_shape:%s", out.shape)
        out = out + self.resnet_in[0](resnet_input)
        logger.debug(
            "Residual connection of the input to out : %s in Mid Block", out.shape
        )

        for i in range(self.num_layers):
            # Attention modules only exist when self_attn was requested.
            if self.attn:
                logger.debug("Going into the attention Block in Mid Block Layer %s", i)
                batch_size, channels, h, w = out.shape
                in_attn = out.reshape(batch_size, channels, h * w)
                in_attn = self.attention_norms[i](in_attn)
                in_attn = in_attn.transpose(1, 2)
                out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
                out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
                out = out + out_attn
                logger.debug(
                    "Out of the Self Attention Block with out : %s in Mid Block Layer %s",
                    out.shape,
                    i,
                )

            if self.cross_attn:
                assert context is not None, (
                    "context cannot be None if cross attention layers are used"
                )
                logger.debug(
                    "Going into the Cross Attention Block in Mid Block Layer %s", i
                )
                batch_size, channels, h, w = out.shape
                in_attn = out.reshape(batch_size, channels, h * w)
                in_attn = self.cross_attn_norms[i](in_attn)
                in_attn = in_attn.transpose(1, 2)
                assert len(context.shape) == 3, (
                    "Context shape does not match B,_,CONTEXT_DIM"
                )
                # The width is stored as cond_dim (there is no context_dim here).
                assert (
                    context.shape[0] == x.shape[0]
                    and context.shape[-1] == self.cond_dim
                ), "Context shape does not match B,_,CONTEXT_DIM"
                logger.debug(
                    "Calculating context projection for Cross Attn in Mid Block Layer : %s",
                    i,
                )
                context_proj = self.context_proj[i](context)
                out_attn, _ = self.cross_attentions[i](
                    in_attn, context_proj, context_proj
                )
                out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
                out = out + out_attn
                logger.debug(
                    "Out of the Cross Attention Block with out : %s in Mid Block Layer %s",
                    out.shape,
                    i,
                )

            # Resnet stage i + 1: stage 0 was consumed above, and it expects
            # input_dim channels, so it must not be reused here.
            stage = i + 1
            logger.debug(
                "Last Resnet Block input : %s of Mid Block Layer %s", out.shape, i
            )
            resnet_input = out
            out = self.resnet_one[stage](out)
            logger.debug(
                "Output of Resnet Sub Block 1 of Mid Block Layer %s of shape : %s",
                i,
                out.shape,
            )
            if self.t_emb_dim is not None:
                out = out + self.t_emb_layers[stage](t_emb)[:, :, None, None]
                logger.debug(
                    "Adding t_emb of shape %s to output of shape: %s of Mid Block Layer %s",
                    self.t_emb_dim,
                    out.shape,
                    i,
                )
            out = self.resnet_two[stage](out)
            logger.debug(
                "Output of Resnet Sub Block 2 with output_shape:%s of Mid Block Layer %s",
                out.shape,
                i,
            )
            out = out + self.resnet_in[stage](resnet_input)
            logger.debug(
                "Residual connection of the input to out : %s in Mid Block Layer %s",
                out.shape,
                i,
            )
        return out
class UpBlockUnet(nn.Module):
    r"""
    Up (decoder) conv block with attention.

    Forward pipeline:
      1. Upsample ``x`` with a 4x4 stride-2 transposed conv (identity when
         ``up_sample`` is false). The transposed conv is built for
         ``in_channels // 2`` channels.
      2. Concatenate the matching down-block output along channels, if given.
      3. For each of ``num_layers`` layers:
         a. Resnet block with optional time-embedding injection and a 1x1-conv
            skip of the layer input.
         b. Self attention ([GroupNorm -> MHA]) with residual.
         c. Cross attention against the projected ``context`` with residual,
            only when ``cross_attn`` is true.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        t_emb_dim,
        up_sample,
        num_heads,
        num_layers,
        norm_channels,
        cross_attn=False,
        context_dim=None,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.up_sample = up_sample
        self.t_emb_dim = t_emb_dim
        self.cross_attn = cross_attn
        self.context_dim = context_dim

        def layer_in(idx):
            # Only the first layer sees the (concatenated) block input width.
            return in_channels if idx == 0 else out_channels

        layers = range(num_layers)

        # NOTE: submodules are created in a fixed order; keep it so that
        # state_dict keys and seeded initialization stay reproducible.
        self.resnet_conv_first = nn.ModuleList(
            nn.Sequential(
                nn.GroupNorm(norm_channels, layer_in(idx)),
                nn.SiLU(),
                nn.Conv2d(
                    layer_in(idx), out_channels, kernel_size=3, stride=1, padding=1
                ),
            )
            for idx in layers
        )
        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList(
                nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
                for _ in layers
            )
        self.resnet_conv_second = nn.ModuleList(
            nn.Sequential(
                nn.GroupNorm(norm_channels, out_channels),
                nn.SiLU(),
                nn.Conv2d(
                    out_channels, out_channels, kernel_size=3, stride=1, padding=1
                ),
            )
            for _ in layers
        )
        self.attention_norms = nn.ModuleList(
            nn.GroupNorm(norm_channels, out_channels) for _ in layers
        )
        self.attentions = nn.ModuleList(
            nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
            for _ in layers
        )
        if self.cross_attn:
            assert context_dim is not None, (
                "Context Dimension must be passed for cross attention"
            )
            self.cross_attention_norms = nn.ModuleList(
                nn.GroupNorm(norm_channels, out_channels) for _ in layers
            )
            self.cross_attentions = nn.ModuleList(
                nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                for _ in layers
            )
            self.context_proj = nn.ModuleList(
                nn.Linear(context_dim, out_channels) for _ in layers
            )
        self.residual_input_conv = nn.ModuleList(
            nn.Conv2d(layer_in(idx), out_channels, kernel_size=1) for idx in layers
        )
        if self.up_sample:
            self.up_sample_conv = nn.ConvTranspose2d(
                in_channels // 2, in_channels // 2, 4, 2, 1
            )
        else:
            self.up_sample_conv = nn.Identity()

    def _resnet(self, idx, feats, t_emb):
        """Apply resnet layer ``idx`` (two conv stages plus 1x1 skip) to ``feats``."""
        hidden = self.resnet_conv_first[idx](feats)
        if self.t_emb_dim is not None:
            # (B, out_channels) -> (B, out_channels, 1, 1) to broadcast over H, W
            hidden = hidden + self.t_emb_layers[idx](t_emb)[:, :, None, None]
        hidden = self.resnet_conv_second[idx](hidden)
        return hidden + self.residual_input_conv[idx](feats)

    @staticmethod
    def _attend(norm, attention, feats, key_value=None):
        """Normalize ``feats`` over flattened pixels, attend, and add the residual.

        Pixels of the (B, C, H, W) map become H*W tokens of width C. A
        ``key_value`` of None means self attention (keys/values are the
        normalized tokens themselves).
        """
        b, c, height, width = feats.shape
        tokens = norm(feats.reshape(b, c, height * width)).transpose(1, 2)
        kv = tokens if key_value is None else key_value
        attended, _ = attention(tokens, kv, kv)
        return feats + attended.transpose(1, 2).reshape(b, c, height, width)

    def forward(self, x, out_down=None, t_emb=None, context=None):
        r"""
        :param x: feature map from the previous decoder stage
        :param out_down: optional skip tensor concatenated to ``x`` on dim 1
        :param t_emb: time embedding (B, t_emb_dim); required when
            ``t_emb_dim`` is set
        :param context: conditioning tensor (B, seq, context_dim); required
            when ``cross_attn`` is true
        :return: feature map with ``out_channels`` channels
        :raises AssertionError: if cross attention is enabled and ``context``
            is missing or mis-shaped
        """
        x = self.up_sample_conv(x)
        if out_down is not None:
            x = torch.cat([x, out_down], dim=1)

        out = x
        for idx in range(self.num_layers):
            out = self._resnet(idx, out, t_emb)
            out = self._attend(self.attention_norms[idx], self.attentions[idx], out)
            if not self.cross_attn:
                continue
            assert context is not None, (
                "context cannot be None if cross attention layers are used"
            )
            assert len(context.shape) == 3, (
                "Context shape does not match B,_,CONTEXT_DIM"
            )
            assert (
                context.shape[0] == x.shape[0]
                and context.shape[-1] == self.context_dim
            ), "Context shape does not match B,_,CONTEXT_DIM"
            projected = self.context_proj[idx](context)
            out = self._attend(
                self.cross_attention_norms[idx],
                self.cross_attentions[idx],
                out,
                projected,
            )
        return out