# ControlNet / model_blocks / controlnet.py
# Author: YashNagraj75
# Add loading ckpt logic from state_dict and not directly from model class
# Commit: 2bd73d5
import enum
import logging
import os
from re import UNICODE
import torch
import torch.nn as nn
from unet_base import UNet, get_time_embedding
logger = logging.getLogger(__name__)
def make_zero_module(module):
    """Zero-initialize every parameter of *module* in place and return it.

    Used for ControlNet's "zero convolutions": layers that start as the
    identity-to-zero mapping so training begins without perturbing the
    frozen base model.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
class ControlNet(nn.Module):
    r"""
    ControlNet wrapper around a trained DDPM UNet.

    Holds two copies of the UNet:
      * ``self.model``        -- the trained (optionally frozen) diffusion model.
      * ``self.control_copy`` -- an encoder-only copy (``use_up=False``) whose
        down/mid features are injected into ``self.model``'s decoder through
        zero-initialized 1x1 convolutions, so training starts as a no-op.

    Args:
        device: map_location used when loading ``model_ckpt``.
        model_config: dict passed to ``UNet``; must contain ``hint_channels``.
        model_ckpt: optional path to a state_dict for the trained UNet.
        model_locked: when True, ``get_params`` exposes only ControlNet-side
            parameters (the base model stays frozen by the caller).
    """

    def __init__(
        self, device, model_config, model_ckpt=None, model_locked=True
    ) -> None:
        super().__init__()
        # Trained DDPM (the "locked" copy in ControlNet terminology).
        self.model = UNet(model_config)
        self.model_locked = model_locked
        # Load weights for the trained model (strict: full checkpoint expected).
        if model_ckpt is not None and device is not None:
            print("Loading Trained Diffusion Model")
            self.model.load_state_dict(
                torch.load(model_ckpt, map_location=device), strict=True
            )
        # ControlNet copy of the trained DDPM.
        # use_up=False removes the up blocks (decoder layers) from the UNet,
        # so strict=False below tolerates the missing decoder keys.
        self.control_copy = UNet(model_config, use_up=False)
        if model_ckpt is not None and device is not None:
            print("Loading Control Diffusion Model")
            self.control_copy.load_state_dict(
                torch.load(model_ckpt, map_location=device), strict=False
            )
        # Hint block: Conv/SiLU stack that maps the conditioning image (hint)
        # to the UNet's input feature width, ending in a zero convolution so
        # the hint contributes nothing at initialization.
        self.hint_block = nn.Sequential(
            nn.Conv2d(model_config["hint_channels"], 64, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(128, self.model.down_channels[0], kernel_size=3, padding=1),
            nn.SiLU(),
            make_zero_module(
                nn.Conv2d(
                    self.model.down_channels[0],
                    self.model.down_channels[0],
                    kernel_size=1,
                    padding=0,
                )
            ),
        )
        # One zero conv per down block, sized to the features *entering*
        # down block i (down_channels[i] -> down_channels[i]).
        self.control_copy_down_blocks = nn.ModuleList(
            [
                make_zero_module(
                    nn.Conv2d(
                        self.model.down_channels[i],
                        self.model.down_channels[i],
                        kernel_size=1,
                        padding=0,
                    )
                )
                for i in range(len(self.model.down_channels) - 1)
            ]
        )
        # Zero convs for the mid-block outputs.
        # NOTE(review): this builds len(mid_channels) - 2 modules but
        # forward() indexes them with the mid-block index starting at 0 —
        # whether the counts line up depends on UNet's mid-block layout,
        # which is not visible here. TODO: confirm against unet_base.UNet.
        self.control_copy_mid_blocks = nn.ModuleList(
            [
                make_zero_module(
                    nn.Conv2d(
                        self.model.mid_channels[i],
                        self.model.mid_channels[i],
                        kernel_size=1,
                        padding=0,
                    )
                )
                for i in range(1, len(self.model.mid_channels) - 1)
            ]
        )

    def get_params(self):
        """Return the trainable ControlNet-side parameters (base model excluded)."""
        params = list(self.control_copy.parameters())
        params += list(self.hint_block.parameters())
        params += list(self.control_copy_down_blocks.parameters())
        params += list(self.control_copy_mid_blocks.parameters())
        return params

    def forward(self, x, t, hint):
        """Denoise ``x`` at timestep ``t`` conditioned on the ``hint`` image.

        Runs the frozen encoder (no grad), runs the ControlNet encoder on
        ``x`` + hint features, and merges the zero-conv outputs into the
        frozen model's mid and skip connections before decoding.
        """
        # ---- Frozen model: time embedding + encoder pass (no gradients) ----
        time_embedding = get_time_embedding(
            torch.as_tensor(t).long(), self.model.t_emb_dim
        )
        time_embedding = self.model.t_proj(time_embedding)
        logger.debug(f"Got Time embeddings for Original Copy : {time_embedding.shape}")
        model_down_outs = []
        with torch.no_grad():
            model_out = self.model.conv_in(x)
            for idx, down in enumerate(self.model.downs):
                # BUGFIX: original appended undefined name `model_in`
                # (NameError). The skip connection is the feature map
                # *entering* each down block, so record before applying it.
                model_down_outs.append(model_out)
                model_out = down(model_out, time_embedding)
                logger.debug(
                    f"Getting output of Down Layer {idx} from the original copy : {model_out.shape}"
                )
        logger.debug("Passing into ControlNet")
        # ---- ControlNet copy: its own time embedding + hint injection ----
        controlnet_time_embedding = get_time_embedding(
            torch.as_tensor(t).long(), self.control_copy.t_emb_dim
        )
        controlnet_time_embedding = self.control_copy.t_proj(controlnet_time_embedding)
        logger.debug(
            f"Got Time embedding for ControlNet : {controlnet_time_embedding.shape}"
        )
        controlnet_hint_output = self.hint_block(hint)
        logger.debug(
            f"Getting output of the Hint Block into the ControlNet : {controlnet_hint_output.shape}"
        )
        controlnet_out = self.control_copy.conv_in(x)
        logger.debug(
            f"Getting output of the Input Conv of ControlNet: {controlnet_out.shape}"
        )
        # Out-of-place add (original used `+=`): safer under autograd.
        controlnet_out = controlnet_out + controlnet_hint_output
        logger.debug(f"Added Hint to the Conv Input: {controlnet_out.shape}")
        # ---- ControlNet encoder: record zero-conv skips, then downsample ----
        controlnet_down_outs = []
        for idx, down in enumerate(self.control_copy.downs):
            down_out = self.control_copy_down_blocks[idx](controlnet_out)
            controlnet_down_outs.append(down_out)
            logger.debug(
                f"Got output of the {idx} Down Block of the ControlNet: {down_out.shape}"
            )
            # BUGFIX: original never applied the down block, so the
            # ControlNet feature map kept its input resolution/width and
            # the mid blocks would have received wrong-shaped input.
            controlnet_out = down(controlnet_out, controlnet_time_embedding)
        # ---- Mid blocks: inject ControlNet mid features into the model ----
        for idx in range(len(self.control_copy.mids)):
            controlnet_out = self.control_copy.mids[idx](
                controlnet_out, controlnet_time_embedding
            )
            logger.debug(
                f"Got the output of the mid block {idx} in controlnet : {controlnet_out.shape}"
            )
            model_out = self.model.mids[idx](model_out, time_embedding)
            logger.debug(
                f"Got the output of Mid Block {idx} from original model : {model_out.shape}"
            )
            # Out-of-place add (original used `+=`): avoids in-place
            # modification of a tensor autograd may need for backward.
            model_out = model_out + self.control_copy_mid_blocks[idx](controlnet_out)
            logger.debug(
                f"Concatinating the ControlNet Mid Block {idx} output :{model_out.shape} to original copy"
            )
        # ---- Decoder: frozen up blocks with summed skip connections ----
        for idx, up in enumerate(self.model.ups):
            model_down_out = model_down_outs.pop()
            logger.debug(
                f"Got the output from the down blocks from original model : {model_down_out.shape}"
            )
            controlnet_down_out = controlnet_down_outs.pop()
            logger.debug(
                f"Got the output from the down blocks from controlnet copy : {controlnet_down_out.shape}"
            )
            model_out = up(
                model_out, controlnet_down_out + model_down_out, time_embedding
            )
        model_out = self.model.norm_out(model_out)
        model_out = nn.SiLU()(model_out)
        model_out = self.model.conv_out(model_out)
        return model_out