| import enum |
| import logging |
| import os |
| from re import UNICODE |
|
|
| import torch |
| import torch.nn as nn |
| from unet_base import UNet, get_time_embedding |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
def make_zero_module(module):
    """Zero every parameter of *module* in place and return the module.

    Used for ControlNet's "zero convolutions": layers that start as an
    identity-free contribution so training begins from the trained model's
    behavior.
    """
    with torch.no_grad():
        for weight in module.parameters():
            weight.zero_()
    return module
|
|
|
|
class ControlNet(nn.Module):
    r"""
    ControlNet for a trained DDPM UNet.

    Wraps a (typically frozen) trained :class:`UNet` together with a
    trainable "control" copy of its encoder (down + mid blocks, no up
    blocks). A conditioning hint image is injected into the control copy,
    and the control copy's intermediate activations are merged back into
    the trained model through zero-initialized 1x1 convolutions, so at the
    start of training the combined model behaves exactly like the original
    DDPM.
    """

    def __init__(
        self, device, model_config, model_ckpt=None, model_locked=True
    ) -> None:
        """
        Args:
            device: Map location used when loading ``model_ckpt``; loading
                is skipped if either ``model_ckpt`` or ``device`` is None.
            model_config: Config dict for :class:`UNet`; must also contain
                ``hint_channels`` (channel count of the hint image).
            model_ckpt: Optional checkpoint path; loaded strictly into the
                trained model and non-strictly into the control copy.
            model_locked: Whether the trained model is frozen.
                NOTE(review): currently only stored; :meth:`get_params`
                always returns just the ControlNet-specific parameters.
        """
        super().__init__()

        # Trained DDPM UNet (kept frozen during ControlNet training).
        self.model = UNet(model_config)
        self.model_locked = model_locked

        if model_ckpt is not None and device is not None:
            print("Loading Trained Diffusion Model")
            self.model.load_state_dict(
                torch.load(model_ckpt, map_location=device), strict=True
            )

        # Trainable encoder-only copy (no up blocks) of the same UNet.
        self.control_copy = UNet(model_config, use_up=False)

        if model_ckpt is not None and device is not None:
            print("Loading Control Diffusion Model")
            # strict=False: the control copy has no up blocks, so the
            # checkpoint's decoder weights must be ignorable.
            self.control_copy.load_state_dict(
                torch.load(model_ckpt, map_location=device), strict=False
            )

        # Encodes the hint image up to the channel count of conv_in's
        # output; the final conv is zero-initialized so the hint initially
        # contributes nothing.
        self.hint_block = nn.Sequential(
            nn.Conv2d(model_config["hint_channels"], 64, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(128, self.model.down_channels[0], kernel_size=3, padding=1),
            nn.SiLU(),
            make_zero_module(
                nn.Conv2d(
                    self.model.down_channels[0],
                    self.model.down_channels[0],
                    kernel_size=1,
                    padding=0,
                )
            ),
        )

        # One zero conv per down block, applied to the block's *input*
        # (which has down_channels[i] channels) to build the control skip
        # connections.
        self.control_copy_down_blocks = nn.ModuleList(
            [
                make_zero_module(
                    nn.Conv2d(
                        self.model.down_channels[i],
                        self.model.down_channels[i],
                        kernel_size=1,
                        padding=0,
                    )
                )
                for i in range(len(self.model.down_channels) - 1)
            ]
        )

        # One zero conv per mid block, applied to the block's *output*.
        # Assumes mid block i maps mid_channels[i] -> mid_channels[i + 1]
        # (TODO: confirm against unet_base), so conv idx uses
        # mid_channels[idx + 1].
        # BUGFIX: the range previously stopped at len(mid_channels) - 1,
        # leaving one fewer zero conv than mid blocks and causing an
        # IndexError on the last mid block in forward().
        self.control_copy_mid_blocks = nn.ModuleList(
            [
                make_zero_module(
                    nn.Conv2d(
                        self.model.mid_channels[i],
                        self.model.mid_channels[i],
                        kernel_size=1,
                        padding=0,
                    )
                )
                for i in range(1, len(self.model.mid_channels))
            ]
        )

    def get_params(self):
        """Return only the trainable ControlNet parameters.

        Includes the control copy, the hint block, and all zero
        convolutions; the trained model's parameters are deliberately
        excluded so the optimizer never touches them.
        """
        params = list(self.control_copy.parameters())
        params += list(self.hint_block.parameters())
        params += list(self.control_copy_down_blocks.parameters())
        params += list(self.control_copy_mid_blocks.parameters())
        return params

    def forward(self, x, t, hint):
        """
        Args:
            x: Noisy image batch, fed to both UNet copies.
            t: Diffusion timestep(s); cast to long for the sinusoidal
                embedding.
            hint: Conditioning image with ``hint_channels`` channels,
                spatially compatible with ``x``.

        Returns:
            The trained UNet's prediction with the control branch's
            contributions added at the mid blocks and skip connections.
        """
        time_embedding = get_time_embedding(
            torch.as_tensor(t).long(), self.model.t_emb_dim
        )
        time_embedding = self.model.t_proj(time_embedding)
        logger.debug(f"Got Time embeddings for Original Copy : {time_embedding.shape}")

        model_down_outs = []

        # The trained model's encoder needs no gradients; only the control
        # branch (merged in below through the zero convs) is optimized.
        with torch.no_grad():
            model_out = self.model.conv_in(x)
            for idx, down in enumerate(self.model.downs):
                # BUGFIX: was `model_down_outs.append(model_in)` with
                # `model_in` undefined (NameError). Store the down block's
                # input for the decoder skip connections.
                model_down_outs.append(model_out)
                model_out = down(model_out, time_embedding)
                logger.debug(
                    f"Getting output of Down Layer {idx} from the original copy : {model_out.shape}"
                )

        logger.debug("Passing into ControlNet")

        controlnet_time_embedding = get_time_embedding(
            torch.as_tensor(t).long(), self.control_copy.t_emb_dim
        )
        controlnet_time_embedding = self.control_copy.t_proj(controlnet_time_embedding)
        logger.debug(
            f"Got Time embedding for ControlNet : {controlnet_time_embedding.shape}"
        )

        controlnet_hint_output = self.hint_block(hint)
        logger.debug(
            f"Getting output of the Hint Block into the ControlNet : {controlnet_hint_output.shape}"
        )

        controlnet_out = self.control_copy.conv_in(x)
        logger.debug(
            f"Getting output of the Input Conv of ControlNet: {controlnet_out.shape}"
        )

        # Out-of-place add (was `+=`): avoids an in-place op on an autograd
        # non-leaf tensor that backward may still need.
        controlnet_out = controlnet_out + controlnet_hint_output
        logger.debug(f"Added Hint to the Conv Input: {controlnet_out.shape}")

        controlnet_down_outs = []
        for idx, down in enumerate(self.control_copy.downs):
            # Zero-conv the block's *input* for the control skip
            # connection (mirrors the trained-model loop above, which also
            # records the input before running the block).
            down_out = self.control_copy_down_blocks[idx](controlnet_out)
            controlnet_down_outs.append(down_out)
            logger.debug(
                f"Got output of the {idx} Down Block of the ControlNet: {down_out.shape}"
            )
            # BUGFIX: the down block itself was never applied, so the
            # control branch's activations never changed resolution or
            # channels and the mid blocks received the wrong input.
            controlnet_out = down(controlnet_out, controlnet_time_embedding)

        for idx in range(len(self.control_copy.mids)):
            controlnet_out = self.control_copy.mids[idx](
                controlnet_out, controlnet_time_embedding
            )
            logger.debug(
                f"Got the output of the mid block {idx} in controlnet : {controlnet_out.shape}"
            )

            model_out = self.model.mids[idx](model_out, time_embedding)
            logger.debug(
                f"Got the output of Mid Block {idx} from original model : {model_out.shape}"
            )

            # Merge the control branch into the trained model through its
            # zero conv (out-of-place add, see note above).
            model_out = model_out + self.control_copy_mid_blocks[idx](controlnet_out)
            logger.debug(
                f"Concatinating the ControlNet Mid Block {idx} output :{model_out.shape} to original copy"
            )

        # Decoder of the trained model; each up block receives the sum of
        # the trained and control skip connections.
        for idx, up in enumerate(self.model.ups):
            model_down_out = model_down_outs.pop()
            logger.debug(
                f"Got the output from the down blocks from original model : {model_down_out.shape}"
            )
            controlnet_down_out = controlnet_down_outs.pop()
            logger.debug(
                f"Got the output from the down blocks from controlnet copy : {controlnet_down_out.shape}"
            )

            model_out = up(
                model_out, controlnet_down_out + model_down_out, time_embedding
            )

        model_out = self.model.norm_out(model_out)
        model_out = nn.SiLU()(model_out)
        model_out = self.model.conv_out(model_out)

        return model_out
|
|