# ControlNet/model_blocks/unet_base.py
import logging
import torch
import torch.nn as nn
logger = logging.getLogger(__name__)
def get_time_embedding(time_steps, temb_dim):
r"""
Convert time steps tensor into an embedding using the
sinusoidal time embedding formula
:param time_steps: 1D tensor of length batch size
:param temb_dim: Dimension of the embedding
:return: BxD embedding representation of B time steps
"""
assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"
# factor = 10000^(2i/d_model)
factor = 10000 ** (
torch.arange(
start=0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device
)
/ (temb_dim // 2)
)
# pos / factor
# timesteps B -> B, 1 -> B, temb_dim
t_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factor
t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
return t_emb
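
# Quick sanity check (doctest-style sketch; the batch size and embedding
# dimension below are illustrative assumptions, not values from this repo):
# >>> t = torch.tensor([1, 10, 100, 1000])
# >>> get_time_embedding(t, 128).shape
# torch.Size([4, 128])
# The first temb_dim // 2 columns are the sin components, the rest the cos.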
class DownBlock(nn.Module):
r"""
DownBlock for Diffusion model:
a) Block Time embedding -> [Silu -> FC]
1) Resnet Block :- [Norm-> Silu -> Conv] x num_layers
2) Self Attention :- [Norm -> SA]
b) DownSample : DownSample the dimnension
"""
def __init__(
self,
input_dim,
output_dim,
t_emb_dim,
down_sample=True,
num_heads=4,
num_layers=1,
) -> None:
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.down_sample = down_sample
self.num_heads = num_heads
self.num_layers = num_layers
self.t_emb_dim = t_emb_dim
self.resnet_one = nn.ModuleList(
[
nn.Sequential(
nn.GroupNorm(8, self.input_dim if i == 0 else self.output_dim),
nn.SiLU(),
nn.Conv2d(
self.input_dim if i == 0 else self.output_dim,
self.output_dim,
kernel_size=3,
stride=1,
padding=1,
),
)
for i in range(self.num_layers)
]
)
self.t_emb_layers = nn.ModuleList(
[
nn.Sequential(nn.SiLU(), nn.Linear(self.t_emb_dim, self.output_dim))
for _ in range(self.num_layers)
]
)
self.resnet_two = nn.ModuleList(
[
nn.Sequential(
nn.GroupNorm(8, self.output_dim),
nn.SiLU(),
nn.Conv2d(
self.output_dim,
self.output_dim,
kernel_size=3,
stride=1,
padding=1,
),
)
for _ in range(self.num_layers)
]
)
self.attention_norms = nn.ModuleList(
[nn.GroupNorm(8, self.output_dim) for _ in range(self.num_layers)]
)
        # one self-attention layer per ResNet layer
        self.attentions = nn.ModuleList(
            [
                nn.MultiheadAttention(self.output_dim, self.num_heads, batch_first=True)
                for _ in range(self.num_layers)
            ]
        )
self.resnet_in = nn.ModuleList(
[
nn.Conv2d(
self.input_dim if i == 0 else self.output_dim,
self.output_dim,
kernel_size=1,
)
for i in range(self.num_layers)
]
)
self.down_sample_conv = (
nn.Conv2d(self.output_dim, self.output_dim, 4, 2, 1)
if self.down_sample
else nn.Identity()
)
def forward(
self,
x,
t_emb,
):
out = x
logger.debug(f"Input of shape: {out.shape} to Down Block ")
for i in range(self.num_layers):
resnet_input = out
logger.debug(f"Input to Resnet Block : {resnet_input.shape} ")
out = self.resnet_one[i](out)
out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
            logger.debug(
                f"Added time embedding to Resnet Sub Block 1: {out.shape} of Down Block Layer {i}"
            )
            out = self.resnet_two[i](out)
            out = out + self.resnet_in[i](resnet_input)
            logger.debug(
                f"Added residual connection: {out.shape} in Down Block Layer {i}"
            )
batch_size, channels, h, w = out.shape
in_attn = out.reshape(batch_size, channels, h * w)
in_attn = self.attention_norms[i](in_attn)
logger.debug(f"Attention Norm: {in_attn.shape} in Down Block Layer : {i}")
in_attn = in_attn.transpose(1, 2)
logger.debug(
f"Passing Norm : {in_attn.shape} to Attention Layer in Down Block Layer : {i}"
)
out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
out = out + out_attn
logger.debug(
f"Added Attention score to output: {out.shape} in Down Block Layer {i}"
)
out = self.down_sample_conv(out)
logger.debug(f"Down sampled to : {out.shape}")
return out
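
# Usage sketch for DownBlock (shapes below are assumed for illustration):
# with down_sample=True the spatial size is halved while channels map
# input_dim -> output_dim.
# >>> block = DownBlock(32, 64, t_emb_dim=128)
# >>> block(torch.randn(2, 32, 16, 16), torch.randn(2, 128)).shape
# torch.Size([2, 64, 8, 8])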
class MidBlock(nn.Module):
r"""
MidBlock for Diffusion model:
Time embedding -> [Silu -> FC]
1) Resnet Block :- [Norm-> Silu -> Conv] x num_layers
2) Self Attention :- [Norm -> SA]
Time embedding -> [Silu -> FC]
3) Resnet Block :- [Norm-> Silu -> Conv] x num_layers
"""
def __init__(
self,
input_dim,
output_dim,
t_emb_dim,
num_heads=4,
num_layers=1,
) -> None:
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.num_heads = num_heads
self.num_layers = num_layers
self.t_emb_dim = t_emb_dim
self.resnet_one = nn.ModuleList(
[
nn.Sequential(
nn.GroupNorm(8, self.input_dim if i == 0 else self.output_dim),
nn.SiLU(),
nn.Conv2d(
self.input_dim if i == 0 else self.output_dim,
self.output_dim,
kernel_size=3,
stride=1,
padding=1,
),
)
for i in range(self.num_layers + 1)
]
)
self.t_emb_layers = nn.ModuleList(
[
nn.Sequential(nn.SiLU(), nn.Linear(self.t_emb_dim, self.output_dim))
for _ in range(self.num_layers + 1)
]
)
self.resnet_two = nn.ModuleList(
[
nn.Sequential(
nn.GroupNorm(8, self.output_dim),
nn.SiLU(),
nn.Conv2d(
self.output_dim,
self.output_dim,
kernel_size=3,
stride=1,
padding=1,
),
)
for _ in range(self.num_layers + 1)
]
)
self.attention_norms = nn.ModuleList(
[nn.GroupNorm(8, self.output_dim) for _ in range(self.num_layers)]
)
self.attentions = nn.ModuleList(
[
nn.MultiheadAttention(self.output_dim, self.num_heads, batch_first=True)
for _ in range(self.num_layers)
]
)
self.resnet_in = nn.ModuleList(
[
nn.Conv2d(
self.input_dim if i == 0 else self.output_dim,
self.output_dim,
kernel_size=1,
)
for i in range(self.num_layers + 1)
]
)
def forward(self, x, t_emb):
out = x
logger.debug(f"Input of shape: {out.shape} to Mid Block ")
# First Resnet Block
resnet_input = out
logger.debug(
f"Input to Resnet Block : {resnet_input.shape} in Mid Block Layer 0"
)
out = self.resnet_one[0](out)
out = out + self.t_emb_layers[0](t_emb)[:, :, None, None]
        logger.debug(
            f"Added time embedding to Resnet Sub Block 1: {out.shape} of Mid Block Layer 0"
        )
        out = self.resnet_two[0](out)
        out = out + self.resnet_in[0](resnet_input)
        logger.debug(f"Added residual connection: {out.shape} in Mid Block Layer 0")
for i in range(self.num_layers):
# Attention Block
batch_size, channels, h, w = out.shape
in_attn = out.reshape(batch_size, channels, h * w)
in_attn = self.attention_norms[i](in_attn)
logger.debug(f"Attention Norm: {in_attn.shape} in Mid Block Layer : {i} ")
in_attn = in_attn.transpose(1, 2)
out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
out = out + out_attn
logger.debug(
f"Added Attention score to output: {out.shape} in Mid Block Layer {i}"
)
# Resnet Block
resnet_input = out
logger.debug(
f"Input to Resnet Block : {resnet_input.shape} in Mid Block Layer {i}"
)
out = self.resnet_one[i + 1](out)
out = out + self.t_emb_layers[i + 1](t_emb)[:, :, None, None]
            logger.debug(
                f"Added time embedding to Resnet Sub Block 1: {out.shape} of Mid Block Layer {i}"
            )
            out = self.resnet_two[i + 1](out)
            out = out + self.resnet_in[i + 1](resnet_input)
            logger.debug(
                f"Added residual connection: {out.shape} in Mid Block Layer {i}"
            )
return out
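
# Usage sketch for MidBlock (assumed shapes): the spatial size is preserved;
# channels map input_dim -> output_dim in the first ResNet block.
# >>> block = MidBlock(256, 128, t_emb_dim=128)
# >>> block(torch.randn(2, 256, 8, 8), torch.randn(2, 128)).shape
# torch.Size([2, 128, 8, 8])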
class UpBlock(nn.Module):
r"""
UpBlock for Diffusion model:
1. Upsample
1. Concatenate Down block output
2. Resnet block with time embedding
3. Attention Block
"""
def __init__(
self,
input_dim,
output_dim,
t_emb_dim,
up_sample=True,
num_heads=4,
num_layers=1,
) -> None:
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.up_sample = up_sample
self.num_heads = num_heads
self.num_layers = num_layers
self.t_emb_dim = t_emb_dim
self.resnet_one = nn.ModuleList(
[
nn.Sequential(
nn.GroupNorm(8, self.input_dim if i == 0 else self.output_dim),
nn.SiLU(),
nn.Conv2d(
self.input_dim if i == 0 else self.output_dim,
self.output_dim,
kernel_size=3,
stride=1,
padding=1,
),
)
for i in range(self.num_layers)
]
)
self.t_emb_layers = nn.ModuleList(
[
nn.Sequential(nn.SiLU(), nn.Linear(self.t_emb_dim, self.output_dim))
for _ in range(self.num_layers)
]
)
self.resnet_two = nn.ModuleList(
[
nn.Sequential(
nn.GroupNorm(8, self.output_dim),
nn.SiLU(),
nn.Conv2d(
self.output_dim,
self.output_dim,
kernel_size=3,
stride=1,
padding=1,
),
)
for _ in range(self.num_layers)
]
)
self.attention_norms = nn.ModuleList(
[nn.GroupNorm(8, self.output_dim) for _ in range(self.num_layers)]
)
self.attentions = nn.ModuleList(
[
nn.MultiheadAttention(self.output_dim, self.num_heads, batch_first=True)
for _ in range(self.num_layers)
]
)
self.resnet_in = nn.ModuleList(
[
nn.Conv2d(
self.input_dim if i == 0 else self.output_dim,
self.output_dim,
kernel_size=1,
)
for i in range(self.num_layers)
]
)
        # x enters forward() with input_dim // 2 channels; the skip connection
        # supplies the other half after concatenation, so upsampling must
        # preserve the channel count.
        self.up_sample_conv = (
            nn.ConvTranspose2d(self.input_dim // 2, self.input_dim // 2, 4, 2, 1)
            if self.up_sample
            else nn.Identity()
        )
def forward(self, x, out_down, t_emb):
logger.debug(f"Input of shape: {x.shape} to Up Block ")
out = x
out = self.up_sample_conv(out)
logger.debug(f"Up sampled to : {out.shape}")
# Concatenate Down Block output
out = torch.cat([out, out_down], dim=1)
logger.debug(f"Concatenated Down Block output: {out.shape}")
for i in range(self.num_layers):
resnet_input = out
logger.debug(
f"Input to Resnet Block : {resnet_input.shape} in Up Block Layer {i}"
)
out = self.resnet_one[i](out)
out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
            logger.debug(
                f"Added time embedding to Resnet Sub Block 1: {out.shape} of Up Block Layer {i}"
            )
            out = self.resnet_two[i](out)
            out = out + self.resnet_in[i](resnet_input)
            logger.debug(
                f"Added residual connection: {out.shape} in Up Block Layer {i}"
            )
# Attention Block
batch_size, channels, h, w = out.shape
in_attn = out.reshape(batch_size, channels, h * w)
in_attn = self.attention_norms[i](in_attn)
logger.debug(f"Attention Norm: {in_attn.shape} in Up Block Layer : {i}")
in_attn = in_attn.transpose(1, 2)
logger.debug(
f"Passing Norm : {in_attn.shape} to Attention Layer in Up Block Layer : {i}"
)
out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
out = out + out_attn
logger.debug(
f"Added Attention score to output: {out.shape} in Up Block Layer {i}"
)
return out
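
# Usage sketch for UpBlock (assumed shapes): x carries input_dim // 2 channels;
# after upsampling it is concatenated with a matching skip tensor to restore
# input_dim channels, which the ResNet layers then map to output_dim.
# >>> block = UpBlock(128, 32, t_emb_dim=128)
# >>> x, skip = torch.randn(2, 64, 8, 8), torch.randn(2, 64, 16, 16)
# >>> block(x, skip, torch.randn(2, 128)).shape
# torch.Size([2, 32, 16, 16])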
class UNet(nn.Module):
r"""
Unet Backbone consisting:
Down Blocks, Mid Blocks, UpBlocks
"""
def __init__(self, model_config, use_up=True):
super().__init__()
im_channels = model_config["im_channels"]
self.down_channels = model_config["down_channels"]
self.mid_channels = model_config["mid_channels"]
self.t_emb_dim = model_config["t_emb_dim"]
self.down_sample = model_config["down_sample"]
self.num_down_layers = model_config["num_down_layers"]
self.num_mid_layers = model_config["num_mid_layers"]
self.num_up_layers = model_config["num_up_layers"]
        # channel-compatibility checks between the down path and the mid blocks
        assert self.mid_channels[0] == self.down_channels[-1]
        assert self.mid_channels[-1] == self.down_channels[-2]
        assert len(self.down_sample) == len(self.down_channels) - 1
self.t_proj = nn.Sequential(
nn.Linear(self.t_emb_dim, self.t_emb_dim),
nn.SiLU(),
nn.Linear(self.t_emb_dim, self.t_emb_dim),
)
self.up_sample = list(reversed(self.down_sample))
self.conv_in = nn.Conv2d(
im_channels, self.down_channels[0], kernel_size=3, padding=1
)
self.downs = nn.ModuleList([])
for i in range(len(self.down_channels) - 1):
self.downs.append(
DownBlock(
self.down_channels[i],
self.down_channels[i + 1],
self.t_emb_dim,
down_sample=self.down_sample[i],
num_layers=self.num_down_layers,
)
)
self.mids = nn.ModuleList([])
for i in range(len(self.mid_channels) - 1):
self.mids.append(
MidBlock(
self.mid_channels[i],
self.mid_channels[i + 1],
self.t_emb_dim,
num_layers=self.num_mid_layers,
)
)
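        # use_up=False presumably lets an encoder-only variant (e.g. a
        # ControlNet trunk) skip building the decoder; note that forward()
        # as written still expects self.ups to exist.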
if use_up:
self.ups = nn.ModuleList([])
for i in reversed(range(len(self.down_channels) - 1)):
self.ups.append(
UpBlock(
self.down_channels[i] * 2,
self.down_channels[i - 1] if i != 0 else 16,
self.t_emb_dim,
up_sample=self.down_sample[i],
num_layers=self.num_up_layers,
)
)
self.norm_out = nn.GroupNorm(8, 16)
self.conv_out = nn.Conv2d(16, im_channels, kernel_size=3, padding=1)
def forward(self, x, t):
        t_emb = get_time_embedding(
            torch.as_tensor(t, device=x.device).long(), self.t_emb_dim
        )
t_emb = self.t_proj(t_emb)
logger.debug(f"Time embedding shape: {t_emb.shape} to UNet")
out = self.conv_in(x)
logger.debug(f"Ouput for conv : {out.shape} to UNet")
down_outs = []
for idx, down in enumerate(self.downs):
down_outs.append(out)
out = down(out, t_emb)
logger.debug(f"Output of Down Block {idx} : {out.shape} in UNet")
for idx, mid in enumerate(self.mids):
out = mid(out, t_emb)
logger.debug(f"Output of Mid Block {idx} : {out.shape} in UNet")
for idx, up in enumerate(self.ups):
out = up(out, down_outs.pop(), t_emb)
logger.debug(f"Output of Up Block {idx} : {out.shape} in UNet")
out = self.norm_out(out)
out = self.conv_out(out)
logger.debug(f"Output of UNet : {out.shape}")
return out
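

# Minimal smoke test: a sketch with an assumed model_config (the channel
# layout below is illustrative, not the repo's actual training config).
# It checks that a noisy batch round-trips through the UNet at its input shape.
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    config = {
        "im_channels": 3,
        "down_channels": [32, 64, 128, 256],
        "mid_channels": [256, 128],
        "t_emb_dim": 128,
        "down_sample": [True, True, False],
        "num_down_layers": 1,
        "num_mid_layers": 1,
        "num_up_layers": 1,
    }
    model = UNet(config)
    x = torch.randn(2, 3, 32, 32)
    t = torch.randint(0, 1000, (2,))
    out = model(x, t)
    assert out.shape == x.shape, f"unexpected output shape {out.shape}"
    print(f"UNet output shape: {out.shape}")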