# ControlNet / model_blocks / blocks.py
# Source: YashNagraj75 — "Almost finished with Down Block" (commit 3b3d382)
import torch
import torch.nn as nn
def get_time_embedding(time_steps, temb_dim):
    r"""
    Build sinusoidal time-step embeddings.

    Each scalar time step is mapped to a ``temb_dim``-dimensional vector
    whose first half holds sines and second half cosines of the time step
    scaled by geometrically spaced frequencies (base 10000), the classic
    Transformer / DDPM positional-encoding scheme.

    :param time_steps: 1D tensor of length batch size
    :param temb_dim: Dimension of the embedding (must be even)
    :return: (B, temb_dim) embedding representation of B time steps
    """
    assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"
    half_dim = temb_dim // 2
    # Frequency scale: 10000^(i / half_dim) for i = 0 .. half_dim - 1.
    exponents = (
        torch.arange(
            start=0, end=half_dim, dtype=torch.float32, device=time_steps.device
        )
        / half_dim
    )
    freqs = 10000 ** exponents
    # (B, 1) / (half_dim,) broadcasts to (B, half_dim) — same result as
    # explicitly repeating each time step across the frequency axis.
    angles = time_steps[:, None] / freqs
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
class DownBlock(nn.Module):
r"""
DownBlock for Diffusion model:
a) Block Time embedding -> [Silu -> FC]
1) Resnet Block :- [Norm-> Silu -> Conv] x num_layers
2) Self Attention :- [Norm -> SA]
3) Cross Attention :- [Norm -> CA]
b) DownSample : downsample the spatial dimension
"""
def __init__(
self,
num_heads,
num_layers,
cross_attn,
input_dim,
output_dim,
t_emb_dim,
cond_dim,
norm_channels,
self_attn,
down_sample,
) -> None:
super().__init__()
self.num_heads = num_heads
self.num_layers = num_layers
self.cross_attn = cross_attn
self.input_dim = input_dim
self.output_dim = output_dim
self.cond_dim = cond_dim
self.norm_channels = norm_channels
self.t_emb_dim = t_emb_dim
self.attn = self_attn
self.down_sample = down_sample
self.resnet_in = nn.ModuleList(
[
nn.Conv2d(
self.input_dim if i == 0 else self.output_dim,
self.output_dim,
kernel_size=1,
)
for i in range(self.num_layers)
]
)
self.resnet_one = nn.ModuleList(
[
nn.Sequential(
nn.GroupNorm(
self.norm_channels,
self.input_dim if i == 0 else self.output_dim,
),
nn.SiLU(),
nn.Conv2d(
self.input_dim if i == 0 else self.output_dim,
self.output_dim,
kernel_size=3,
stride=1,
padding=1,
),
)
for i in range(self.num_layers)
]
)
if self.t_emb_dim is not None:
self.t_emb_layers = nn.ModuleList(
[
nn.Sequential(nn.SiLU(), nn.Linear(self.t_emb_dim, self.output_dim))
for _ in range(self.num_layers)
]
)
self.resnet_two = nn.ModuleList(
[
nn.Sequential(
nn.GroupNorm(
self.norm_channels,
self.output_dim,
),
nn.SiLU(),
nn.Conv2d(
self.output_dim,
self.output_dim,
kernel_size=3,
stride=1,
padding=1,
),
)
for _ in range(self.num_layers)
]
)
if self.attn:
self.attention_norms = nn.ModuleList(
[
nn.GroupNorm(self.norm_channels, self.output_dim)
for _ in range(num_layers)
]
)
self.attentions = nn.ModuleList(
[
nn.MultiheadAttention(
self.output_dim, self.num_heads, batch_first=True
)
for _ in range(self.num_layers)
]
)
if self.cross_attn:
self.cross_attn_norms = nn.ModuleList(
[
nn.GroupNorm(self.norm_channels, self.output_dim)
for _ in range(self.num_layers)
]
)
self.cross_attentions = nn.ModuleList(
[
nn.MultiheadAttention(
self.output_dim, self.num_heads, batch_first=True
)
for _ in range(self.num_layers)
]
)
self.context_proj = nn.ModuleList(
[
nn.Linear(self.cond_dim, self.output_dim)
for _ in range(self.num_layers)
]
)
self.down_sample_conv = (
nn.Conv2d(self.output_dim, self.output_dim, 4, 2, 1)
if self.down_sample
else nn.Identity()
)
def forward(self, x, t_emb=None, context=None):
out = x
for i in range(self.num_layers):
# Input x to Resnet Block of the Encoder of the Unet
resnet_input = out
out = self.resnet_one[i](out)
if t_emb is not None:
out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
out = self.resnet_two[i](out)
out = out + self.resnet_in[i](resnet_input)
if self.attn:
# Now Passing through the Self Attention blocks
batch_size, channels, h, w = out.shape
in_attn = out.reshape(batch_size, channels, h * w)
in_attn = self.attention_norms[i](in_attn)
in_attn = in_attn.transpose(1, 2)