import enum
import logging
import os
from re import UNICODE
import torch
import torch.nn as nn
from unet_base import UNet, get_time_embedding
logger = logging.getLogger(__name__)
def make_zero_module(module):
    """Zero out every parameter of *module* in place and return the module.

    Used to build the "zero convolutions" of ControlNet so that the control
    branch contributes nothing at the start of training.
    """
    for param in module.parameters():
        nn.init.zeros_(param)
    return module
class ControlNet(nn.Module):
    r"""
    ControlNet for a trained DDPM.

    Holds a locked, pre-trained DDPM ``UNet`` (``self.model``) plus a
    trainable encoder-only copy of it (``self.control_copy``). A hint image
    is encoded by ``self.hint_block`` and added to the copy's input; the
    copy's down/mid activations are injected back into the locked model
    through zero-initialised 1x1 convolutions, so at initialisation the
    module behaves exactly like the original DDPM.
    """

    def __init__(
        self, device, model_config, model_ckpt=None, model_locked=True
    ) -> None:
        """
        Args:
            device: ``map_location`` for ``torch.load``; checkpoints are only
                loaded when both ``model_ckpt`` and ``device`` are not None.
            model_config: config dict forwarded to ``UNet``; must contain a
                ``hint_channels`` entry for the hint encoder.
            model_ckpt: optional path to the trained DDPM state dict.
            model_locked: when True the base model is intended to stay
                frozen; :meth:`get_params` returns only ControlNet params.
        """
        super().__init__()
        # Trained DDPM (the locked copy)
        self.model = UNet(model_config)
        self.model_locked = model_locked
        # Load weights for the trained model
        if model_ckpt is not None and device is not None:
            # Consistency fix: use the module logger instead of print()
            logger.info("Loading Trained Diffusion Model")
            self.model.load_state_dict(
                torch.load(model_ckpt, map_location=device), strict=True
            )
        # ControlNet copy of the trained DDPM.
        # use_up=False removes the up blocks (decoder layers) from the DDPM
        # UNet, hence strict=False below: decoder weights have no destination.
        self.control_copy = UNet(model_config, use_up=False)
        if model_ckpt is not None and device is not None:
            logger.info("Loading Control Diffusion Model")
            self.control_copy.load_state_dict(
                torch.load(model_ckpt, map_location=device), strict=False
            )
        # Hint block: Conv/SiLU stack ending in a zero convolution, so the
        # hint contributes nothing at the start of training.
        self.hint_block = nn.Sequential(
            nn.Conv2d(model_config["hint_channels"], 64, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(128, self.model.down_channels[0], kernel_size=3, padding=1),
            nn.SiLU(),
            make_zero_module(
                nn.Conv2d(
                    self.model.down_channels[0],
                    self.model.down_channels[0],
                    kernel_size=1,
                    padding=0,
                )
            ),
        )
        # One zero conv per down block, applied to the activation *before*
        # that block (channel count down_channels[i]).
        self.control_copy_down_blocks = nn.ModuleList(
            [
                make_zero_module(
                    nn.Conv2d(
                        self.model.down_channels[i],
                        self.model.down_channels[i],
                        kernel_size=1,
                        padding=0,
                    )
                )
                for i in range(len(self.model.down_channels) - 1)
            ]
        )
        # Zero convs for the mid-block outputs.
        # NOTE(review): this builds len(mid_channels) - 2 convs, while
        # forward() indexes one conv per entry of control_copy.mids — confirm
        # the two counts match for the UNet configuration in use.
        self.control_copy_mid_blocks = nn.ModuleList(
            [
                make_zero_module(
                    nn.Conv2d(
                        self.model.mid_channels[i],
                        self.model.mid_channels[i],
                        kernel_size=1,
                        padding=0,
                    )
                )
                for i in range(1, len(self.model.mid_channels) - 1)
            ]
        )

    def get_params(self):
        """Return only the trainable ControlNet-side parameters (control
        copy, hint block and all zero convolutions); the locked base model
        is excluded."""
        params = list(self.control_copy.parameters())
        params += list(self.hint_block.parameters())
        params += list(self.control_copy_down_blocks.parameters())
        params += list(self.control_copy_mid_blocks.parameters())
        return params

    def forward(self, x, t, hint):
        """
        Denoise ``x`` at timestep(s) ``t`` conditioned on ``hint``.

        Args:
            x: noisy input batch, NCHW (required by the Conv2d layers).
            t: timestep value(s); anything ``torch.as_tensor`` accepts.
            hint: conditioning image with ``hint_channels`` channels; its
                spatial size must match the output of ``conv_in(x)`` so the
                two can be added.

        Returns:
            The locked model's predicted output (same shape as its usual
            UNet output), with ControlNet features injected.
        """
        # ---- Locked model encoder (no gradients) ----
        time_embedding = get_time_embedding(
            torch.as_tensor(t).long(), self.model.t_emb_dim
        )
        time_embedding = self.model.t_proj(time_embedding)
        logger.debug(f"Got Time embeddings for Original Copy : {time_embedding.shape}")
        model_down_outs = []
        with torch.no_grad():
            model_out = self.model.conv_in(x)
            for idx, down in enumerate(self.model.downs):
                # Save the pre-down activation for the decoder skip
                # connections. Bug fix: the original appended the undefined
                # name ``model_in`` (NameError on first call).
                model_down_outs.append(model_out)
                model_out = down(model_out, time_embedding)
                logger.debug(
                    f"Getting output of Down Layer {idx} from the original copy : {model_out.shape}"
                )
        # ---- ControlNet branch (trainable) ----
        logger.debug("Passing into ControlNet")
        controlnet_time_embedding = get_time_embedding(
            torch.as_tensor(t).long(), self.control_copy.t_emb_dim
        )
        controlnet_time_embedding = self.control_copy.t_proj(controlnet_time_embedding)
        logger.debug(
            f"Got Time embedding for ControlNet : {controlnet_time_embedding.shape}"
        )
        # Hint layer output here
        controlnet_hint_output = self.hint_block(hint)
        logger.debug(
            f"Getting output of the Hint Block into the ControlNet : {controlnet_hint_output.shape}"
        )
        controlnet_out = self.control_copy.conv_in(x)
        logger.debug(
            f"Getting output of the Input Conv of ControlNet: {controlnet_out.shape}"
        )
        # The hint enters through a zero conv, so this add is a no-op at init
        controlnet_out += controlnet_hint_output
        logger.debug(f"Added Hint to the Conv Input: {controlnet_out.shape}")
        controlnet_down_outs = []
        # Get all the outputs of the controlnet down blocks
        for idx, down in enumerate(self.control_copy.downs):
            # Zero-conv of the pre-down activation feeds the decoder skips
            down_out = self.control_copy_down_blocks[idx](controlnet_out)
            controlnet_down_outs.append(down_out)
            logger.debug(
                f"Got output of the {idx} Down Block of the ControlNet: {down_out.shape}"
            )
            # Bug fix: the original never ran the control copy's down block,
            # so ``controlnet_out`` kept down_channels[0] channels and every
            # zero conv past idx 0 would fail with a channel mismatch.
            controlnet_out = down(controlnet_out, controlnet_time_embedding)
        # Now get the midblocks and then give to original copy
        for idx in range(len(self.control_copy.mids)):
            controlnet_out = self.control_copy.mids[idx](
                controlnet_out, controlnet_time_embedding
            )
            logger.debug(
                f"Got the output of the mid block {idx} in controlnet : {controlnet_out.shape}"
            )
            model_out = self.model.mids[idx](model_out, time_embedding)
            logger.debug(
                f"Got the output of Mid Block {idx} from original model : {model_out.shape}"
            )
            # Inject the zero-conv'd control features into the locked model
            model_out += self.control_copy_mid_blocks[idx](controlnet_out)
            logger.debug(
                f"Concatinating the ControlNet Mid Block {idx} output :{model_out.shape} to original copy"
            )
        # ---- Locked model decoder, with control-augmented skips ----
        for idx, up in enumerate(self.model.ups):
            model_down_out = model_down_outs.pop()
            logger.debug(
                f"Got the output from the down blocks from original model : {model_down_out.shape}"
            )
            controlnet_down_out = controlnet_down_outs.pop()
            logger.debug(
                f"Got the output from the down blocks from controlnet copy : {controlnet_down_out.shape}"
            )
            # Skip connection = original skip + zero-conv'd control skip
            model_out = up(
                model_out, controlnet_down_out + model_down_out, time_embedding
            )
        model_out = self.model.norm_out(model_out)
        model_out = nn.SiLU()(model_out)
        model_out = self.model.conv_out(model_out)
        return model_out
|