import enum
import logging
import os
from re import UNICODE

import torch
import torch.nn as nn

from unet_base import UNet, get_time_embedding

logger = logging.getLogger(__name__)


def make_zero_module(module):
    """Zero all parameters of ``module`` in place and return it.

    ControlNet initializes every injection layer to zero so that, before any
    training, the control branch contributes nothing and the combined model
    reproduces the frozen pretrained DDPM exactly.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


class ControlNet(nn.Module):
    r"""
    ControlNet for a trained DDPM.

    Wraps a pretrained :class:`UNet` (the "locked" copy) together with a
    trainable, encoder-only copy of the same UNet.  The trainable copy sees
    the noisy input plus a hint image; its features are injected into the
    locked copy's mid/up path through zero-initialized 1x1 convolutions.
    """

    def __init__(
        self, device, model_config, model_ckpt=None, model_locked=True
    ) -> None:
        """
        Args:
            device: ``map_location`` used when loading ``model_ckpt``.
            model_config: config dict forwarded to :class:`UNet`; must also
                contain ``"hint_channels"`` for the hint block.
            model_ckpt: optional path to the trained DDPM state dict; loaded
                into both the locked model and the control copy.
            model_locked: stored flag only.  Freezing is effectively realized
                by ``get_params`` returning just the control-branch params.
        """
        super().__init__()
        # Trained DDPM (the locked copy).
        self.model = UNet(model_config)
        self.model_locked = model_locked

        # Load weights for the trained model.
        if model_ckpt is not None and device is not None:
            print("Loading Trained Diffusion Model")
            self.model.load_state_dict(
                torch.load(model_ckpt, map_location=device), strict=True
            )

        # ControlNet copy of the trained DDPM.
        # use_up=False removes the up blocks (decoder layers) from the UNet,
        # which is why strict=False is required when loading the checkpoint.
        self.control_copy = UNet(model_config, use_up=False)
        if model_ckpt is not None and device is not None:
            print("Loading Control Diffusion Model")
            self.control_copy.load_state_dict(
                torch.load(model_ckpt, map_location=device), strict=False
            )

        # Hint block: conv/activation stack ending in a zero convolution so
        # the hint contributes nothing at initialization.
        self.hint_block = nn.Sequential(
            nn.Conv2d(model_config["hint_channels"], 64, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(128, self.model.down_channels[0], kernel_size=3, padding=1),
            nn.SiLU(),
            make_zero_module(
                nn.Conv2d(
                    self.model.down_channels[0],
                    self.model.down_channels[0],
                    kernel_size=1,
                    padding=0,
                )
            ),
        )

        # Zero convs applied to the control copy's pre-down features before
        # they are stored as skip connections (one per down block).
        self.control_copy_down_blocks = nn.ModuleList(
            [
                make_zero_module(
                    nn.Conv2d(
                        self.model.down_channels[i],
                        self.model.down_channels[i],
                        kernel_size=1,
                        padding=0,
                    )
                )
                for i in range(len(self.model.down_channels) - 1)
            ]
        )

        # Zero convs applied to each mid block's OUTPUT before it is added to
        # the locked model's features.  Mid block ``idx`` presumably maps
        # mid_channels[idx] -> mid_channels[idx + 1], so zero conv ``idx``
        # needs mid_channels[idx + 1] channels.
        # BUGFIX: the original ``range(1, len(...) - 1)`` built one fewer
        # zero conv than there are mid blocks, so forward() indexed
        # ``control_copy_mid_blocks`` out of range on the last mid block.
        # NOTE(review): assumes len(mids) == len(mid_channels) - 1, the
        # standard UNet layout -- confirm against unet_base.
        self.control_copy_mid_blocks = nn.ModuleList(
            [
                make_zero_module(
                    nn.Conv2d(
                        self.model.mid_channels[i],
                        self.model.mid_channels[i],
                        kernel_size=1,
                        padding=0,
                    )
                )
                for i in range(1, len(self.model.mid_channels))
            ]
        )

    def get_params(self):
        """Return only the trainable (control-branch) parameters.

        The locked DDPM's parameters are deliberately excluded so an
        optimizer built from this list never updates them.
        """
        params = list(self.control_copy.parameters())
        params += list(self.hint_block.parameters())
        params += list(self.control_copy_down_blocks.parameters())
        params += list(self.control_copy_mid_blocks.parameters())
        return params

    def forward(self, x, t, hint):
        """
        Args:
            x: noisy input batch.
            t: timestep(s); converted to a long tensor for the embedding.
            hint: conditioning image fed through the hint block.

        Returns:
            The locked UNet's denoising prediction, with the control branch's
            zero-conv features added at the mid blocks and skip connections.
        """
        # ---- Locked copy: encoder, run without gradients. ----
        time_embedding = get_time_embedding(
            torch.as_tensor(t).long(), self.model.t_emb_dim
        )
        time_embedding = self.model.t_proj(time_embedding)
        logger.debug(f"Got Time embeddings for Original Copy : {time_embedding.shape}")

        model_down_outs = []
        with torch.no_grad():
            model_out = self.model.conv_in(x)
            for idx, down in enumerate(self.model.downs):
                # BUGFIX: was ``model_down_outs.append(model_in)`` --
                # ``model_in`` is undefined (NameError); the pre-down feature
                # map is ``model_out``.
                model_down_outs.append(model_out)
                model_out = down(model_out, time_embedding)
                logger.debug(
                    f"Getting output of Down Layer {idx} from the original copy : {model_out.shape}"
                )

        # ---- Control copy: trainable encoder. ----
        logger.debug("Passing into ControlNet")
        controlnet_time_embedding = get_time_embedding(
            torch.as_tensor(t).long(), self.control_copy.t_emb_dim
        )
        controlnet_time_embedding = self.control_copy.t_proj(controlnet_time_embedding)
        logger.debug(
            f"Got Time embedding for ControlNet : {controlnet_time_embedding.shape}"
        )

        # Hint layer output here
        controlnet_hint_output = self.hint_block(hint)
        logger.debug(
            f"Getting output of the Hint Block into the ControlNet : {controlnet_hint_output.shape}"
        )

        controlnet_out = self.control_copy.conv_in(x)
        logger.debug(
            f"Getting output of the Input Conv of ControlNet: {controlnet_out.shape}"
        )
        # Out-of-place add (was ``+=``) so autograd never sees an in-place
        # mutation of a tracked activation.
        controlnet_out = controlnet_out + controlnet_hint_output
        logger.debug(f"Added Hint to the Conv Input: {controlnet_out.shape}")

        controlnet_down_outs = []
        # Store the zero-conv'd pre-down features (mirrors the locked-model
        # loop above), then advance through the down block.
        # BUGFIX: the original loop never called ``down`` at all, so the
        # control branch was never downsampled and every stored "skip" was a
        # zero conv of the same conv_in-resolution tensor.
        for idx, down in enumerate(self.control_copy.downs):
            down_out = self.control_copy_down_blocks[idx](controlnet_out)
            controlnet_down_outs.append(down_out)
            controlnet_out = down(controlnet_out, controlnet_time_embedding)
            logger.debug(
                f"Got output of the {idx} Down Block of the ControlNet: {down_out.shape}"
            )

        # ---- Mid blocks: run both copies, inject zero-conv'd control features. ----
        for idx in range(len(self.control_copy.mids)):
            controlnet_out = self.control_copy.mids[idx](
                controlnet_out, controlnet_time_embedding
            )
            logger.debug(
                f"Got the output of the mid block {idx} in controlnet : {controlnet_out.shape}"
            )
            model_out = self.model.mids[idx](model_out, time_embedding)
            logger.debug(
                f"Got the output of Mid Block {idx} from original model : {model_out.shape}"
            )
            # Out-of-place add (was ``+=``): the mid block's output may be
            # needed intact for its own backward pass.
            model_out = model_out + self.control_copy_mid_blocks[idx](controlnet_out)
            logger.debug(
                f"Concatinating the ControlNet Mid Block {idx} output :{model_out.shape} to original copy"
            )

        # ---- Decoder of the locked copy, with control skips added. ----
        for idx, up in enumerate(self.model.ups):
            model_down_out = model_down_outs.pop()
            logger.debug(
                f"Got the output from the down blocks from original model : {model_down_out.shape}"
            )
            controlnet_down_out = controlnet_down_outs.pop()
            logger.debug(
                f"Got the output from the down blocks from controlnet copy : {controlnet_down_out.shape}"
            )
            model_out = up(
                model_out, controlnet_down_out + model_down_out, time_embedding
            )

        model_out = self.model.norm_out(model_out)
        model_out = nn.SiLU()(model_out)
        model_out = self.model.conv_out(model_out)
        return model_out