Update cldm/LumiNet.py
Browse files- cldm/LumiNet.py +19 -125
cldm/LumiNet.py
CHANGED
|
@@ -2,9 +2,6 @@ import einops
|
|
| 2 |
import torch
|
| 3 |
import torch as th
|
| 4 |
import torch.nn as nn
|
| 5 |
-
# from cldm.latent_intrinsic_new import LatentIntrinsc
|
| 6 |
-
# from cldm.latent_intrinsic import LatentIntrinsc
|
| 7 |
-
# from cldm.latent_intrinsic_crossattn import LatentIntrinsc
|
| 8 |
from cldm.latent_intrinsic import LatentIntrinsc
|
| 9 |
from ldm.modules.diffusionmodules.util import (
|
| 10 |
conv_nd,
|
|
@@ -50,7 +47,6 @@ class ControlledUnetModel(UNetModel):
|
|
| 50 |
h = torch.cat([h, hs.pop() + control.pop()], dim=1)
|
| 51 |
h = module(h, emb, context)
|
| 52 |
# print("fool! that is the shape of the context! at output",context.shape)
|
| 53 |
-
# exit(0)
|
| 54 |
|
| 55 |
h = h.type(x.dtype)
|
| 56 |
return self.out(h)
|
|
@@ -173,60 +169,6 @@ class ControlNet(nn.Module):
|
|
| 173 |
zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
|
| 174 |
)
|
| 175 |
|
| 176 |
-
# self.input_latent_hint_block = TimestepEmbedSequential(
|
| 177 |
-
# # conv_nd(dims, hint_channels, 16, 3, padding=1),
|
| 178 |
-
# # nn.SiLU(),
|
| 179 |
-
# # conv_nd(dims, 16, 16, 3, padding=1),
|
| 180 |
-
# # nn.SiLU(),
|
| 181 |
-
# # conv_nd(dims, 16, 32, 3, padding=1, stride=2),
|
| 182 |
-
# # nn.SiLU(),
|
| 183 |
-
# # conv_nd(dims, 32, 32, 3, padding=1),
|
| 184 |
-
# # nn.SiLU(),
|
| 185 |
-
# # conv_nd(dims, 32, 96, 3, padding=1, stride=2),
|
| 186 |
-
# # nn.SiLU(),
|
| 187 |
-
# # conv_nd(dims, 96, 96, 3, padding=1),
|
| 188 |
-
# # nn.SiLU(),
|
| 189 |
-
# conv_nd(dims, 128, 256, 3, padding=1, stride=1),
|
| 190 |
-
# nn.SiLU(),
|
| 191 |
-
# zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
|
| 192 |
-
# )
|
| 193 |
-
|
| 194 |
-
# self.input_latent_hint_cat_block = TimestepEmbedSequential(
|
| 195 |
-
# # conv_nd(dims, hint_channels, 16, 3, padding=1),
|
| 196 |
-
# # nn.SiLU(),
|
| 197 |
-
# # conv_nd(dims, 16, 16, 3, padding=1),
|
| 198 |
-
# # nn.SiLU(),
|
| 199 |
-
# # conv_nd(dims, 16, 32, 3, padding=1, stride=2),
|
| 200 |
-
# # nn.SiLU(),
|
| 201 |
-
# # conv_nd(dims, 144, 144, 3, padding=1),
|
| 202 |
-
# # nn.SiLU(),
|
| 203 |
-
# # conv_nd(dims, 144, 256, 3, padding=1, stride=2),
|
| 204 |
-
# # nn.SiLU(),
|
| 205 |
-
# conv_nd(dims, 144, 144, 3, padding=1),
|
| 206 |
-
# nn.SiLU(),
|
| 207 |
-
# conv_nd(dims, 144, 256, 3, padding=1, stride=1),
|
| 208 |
-
# nn.SiLU(),
|
| 209 |
-
# zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
|
| 210 |
-
# )
|
| 211 |
-
|
| 212 |
-
# self.input_latent_hint_cat_block = TimestepEmbedSequential(
|
| 213 |
-
# # conv_nd(dims, hint_channels, 16, 3, padding=1),
|
| 214 |
-
# # nn.SiLU(),
|
| 215 |
-
# # conv_nd(dims, 16, 16, 3, padding=1),
|
| 216 |
-
# # nn.SiLU(),
|
| 217 |
-
# # conv_nd(dims, 16, 32, 3, padding=1, stride=2),
|
| 218 |
-
# # nn.SiLU(),
|
| 219 |
-
# conv_nd(dims, 144, 144, 3, padding=1),
|
| 220 |
-
# nn.SiLU(),
|
| 221 |
-
# conv_nd(dims, 144, 256, 3, padding=1, stride=2),
|
| 222 |
-
# nn.SiLU(),
|
| 223 |
-
# conv_nd(dims, 256, 256, 3, padding=1),
|
| 224 |
-
# nn.SiLU(),
|
| 225 |
-
# conv_nd(dims, 256, 256, 3, padding=1, stride=1),
|
| 226 |
-
# nn.SiLU(),
|
| 227 |
-
# zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
|
| 228 |
-
# )
|
| 229 |
-
|
| 230 |
self.input_latent_hint_cat_atten_block = TimestepEmbedSequential(
|
| 231 |
# conv_nd(dims, hint_channels, 16, 3, padding=1),
|
| 232 |
# nn.SiLU(),
|
|
@@ -245,41 +187,6 @@ class ControlNet(nn.Module):
|
|
| 245 |
zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
|
| 246 |
)
|
| 247 |
|
| 248 |
-
# self.input_latent_hint_crossattn_block = TimestepEmbedSequential(
|
| 249 |
-
# # conv_nd(dims, hint_channels, 16, 3, padding=1),
|
| 250 |
-
# # nn.SiLU(),
|
| 251 |
-
# # conv_nd(dims, 16, 16, 3, padding=1),
|
| 252 |
-
# # nn.SiLU(),
|
| 253 |
-
# # conv_nd(dims, 16, 32, 3, padding=1, stride=2),
|
| 254 |
-
# # nn.SiLU(),
|
| 255 |
-
# conv_nd(dims, 128, 128, 3, padding=1),
|
| 256 |
-
# nn.SiLU(),
|
| 257 |
-
# conv_nd(dims, 128, 256, 3, padding=1, stride=2),
|
| 258 |
-
# nn.SiLU(),
|
| 259 |
-
# conv_nd(dims, 256, 256, 3, padding=1),
|
| 260 |
-
# nn.SiLU(),
|
| 261 |
-
# conv_nd(dims, 256, 256, 3, padding=1, stride=1),
|
| 262 |
-
# nn.SiLU(),
|
| 263 |
-
# zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
|
| 264 |
-
# )
|
| 265 |
-
|
| 266 |
-
# self.input_latent_hint_cat_eq_block = TimestepEmbedSequential(
|
| 267 |
-
# # conv_nd(dims, hint_channels, 16, 3, padding=1),
|
| 268 |
-
# # nn.SiLU(),
|
| 269 |
-
# # conv_nd(dims, 16, 16, 3, padding=1),
|
| 270 |
-
# # nn.SiLU(),
|
| 271 |
-
# # conv_nd(dims, 16, 32, 3, padding=1, stride=2),
|
| 272 |
-
# # nn.SiLU(),
|
| 273 |
-
# # conv_nd(dims, 32, 32, 3, padding=1),
|
| 274 |
-
# # nn.SiLU(),
|
| 275 |
-
# # conv_nd(dims, 32, 96, 3, padding=1, stride=2),
|
| 276 |
-
# # nn.SiLU(),
|
| 277 |
-
# conv_nd(dims, 256, 512, 3, padding=1),
|
| 278 |
-
# nn.SiLU(),
|
| 279 |
-
# conv_nd(dims, 512, 256, 3, padding=1, stride=1),
|
| 280 |
-
# nn.SiLU(),
|
| 281 |
-
# zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
|
| 282 |
-
# )
|
| 283 |
|
| 284 |
self._feature_size = model_channels
|
| 285 |
input_block_chans = [model_channels]
|
|
@@ -398,15 +305,7 @@ class ControlNet(nn.Module):
|
|
| 398 |
self._feature_size += ch
|
| 399 |
self.latent_iid = True
|
| 400 |
self.concat = True
|
| 401 |
-
|
| 402 |
-
# if self.latent_iid:
|
| 403 |
-
# # print(hint.shape)
|
| 404 |
-
# if self.concat:
|
| 405 |
-
# self.input_hint_block = self.input_latent_hint_cat_block
|
| 406 |
-
# else:
|
| 407 |
-
# self.input_hint_block = self.input_latent_hint_block
|
| 408 |
-
|
| 409 |
-
# New w. the crossattn version #comment before train Uncomment for test
|
| 410 |
if self.latent_iid:
|
| 411 |
self.input_hint_block = self.input_latent_hint_cat_atten_block
|
| 412 |
def make_zero_conv(self, channels):
|
|
@@ -414,31 +313,20 @@ class ControlNet(nn.Module):
|
|
| 414 |
#our modification for the latent intrinsic
|
| 415 |
def add_latent_prior(self):
    """Create a ``LatentIntrinsc`` instance and store it as ``self.prior_extracter``."""
    self.prior_extracter = LatentIntrinsc()
|
| 417 |
-
|
| 418 |
-
# self.input_hint_block[0] = conv_nd(2, 19, 16, 3, padding=1)
|
| 419 |
|
| 420 |
def forward(self, x, hint, timesteps, context, **kwargs):
|
| 421 |
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
| 422 |
emb = self.time_embed(t_emb)
|
| 423 |
self.latent_iid = True
|
| 424 |
-
# self.concat = False
|
| 425 |
# adding a cross-attention version
|
| 426 |
if self.latent_iid:
|
| 427 |
hint, hint_lighting = self.prior_extracter(hint)
|
| 428 |
self.new_context = torch.cat([context,hint_lighting],1)
|
| 429 |
-
# self.new_context = context*hint_lighting
|
| 430 |
context = self.new_context
|
| 431 |
-
|
| 432 |
-
# exit(0)
|
| 433 |
self.input_hint_block = self.input_latent_hint_cat_atten_block
|
| 434 |
-
|
| 435 |
-
# if self.latent_iid:
|
| 436 |
-
# hint = self.prior_extracter(hint)
|
| 437 |
-
# # print(hint.shape)
|
| 438 |
-
# if self.concat:
|
| 439 |
-
# self.input_hint_block = self.input_latent_hint_cat_block
|
| 440 |
-
# else:
|
| 441 |
-
# self.input_hint_block = self.input_latent_hint_block
|
| 442 |
guided_hint = self.input_hint_block(hint, emb, context)
|
| 443 |
else:
|
| 444 |
guided_hint = self.input_hint_block(hint, emb, context)
|
|
@@ -470,6 +358,19 @@ class ControlLDM(LatentDiffusion):
|
|
| 470 |
self.only_mid_control = only_mid_control
|
| 471 |
self.control_scales = [1.0] * 13
|
| 472 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 473 |
|
| 474 |
@torch.no_grad()
|
| 475 |
def add_new_layers(self):
|
|
@@ -491,8 +392,6 @@ class ControlLDM(LatentDiffusion):
|
|
| 491 |
diffusion_model = self.model.diffusion_model
|
| 492 |
|
| 493 |
cond_txt = torch.cat(cond['c_crossattn'], 1)
|
| 494 |
-
# print('attention_shape:',cond)
|
| 495 |
-
# exit(0)
|
| 496 |
if cond['c_concat'] is None:
|
| 497 |
eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
|
| 498 |
else:
|
|
@@ -504,13 +403,10 @@ class ControlLDM(LatentDiffusion):
|
|
| 504 |
eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
|
| 505 |
# print('unet_context',cond_txt.shape)
|
| 506 |
return eps
|
| 507 |
-
# @torch.no_grad()
|
| 508 |
-
# def get_attn_intrinsic(self):
|
| 509 |
-
# return self.control_model.new_context
|
| 510 |
@torch.no_grad()
def get_unconditional_conditioning(self, N):
    """Return the learned conditioning for ``N`` empty-string prompts.

    Args:
        N: Number of empty prompts to pass to ``self.get_learned_conditioning``.

    Returns:
        Whatever ``self.get_learned_conditioning`` returns for ``[""] * N``.
    """
    empty_prompts = [""] * N
    return self.get_learned_conditioning(empty_prompts)
|
| 513 |
-
|
| 514 |
@torch.no_grad()
|
| 515 |
def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
|
| 516 |
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
|
|
@@ -583,9 +479,7 @@ class ControlLDM(LatentDiffusion):
|
|
| 583 |
def configure_optimizers(self):
|
| 584 |
lr = self.learning_rate
|
| 585 |
params = list(self.control_model.parameters())
|
| 586 |
-
|
| 587 |
-
# params += list(self.model.diffusion_model.output_blocks.parameters())
|
| 588 |
-
# params += list(self.model.diffusion_model.out.parameters())
|
| 589 |
if self.crossattn_start: #here we also train the cross-attan in the input layer if has any
|
| 590 |
for block in self.model.diffusion_model.input_blocks:
|
| 591 |
for layer in block:
|
|
|
|
| 2 |
import torch
|
| 3 |
import torch as th
|
| 4 |
import torch.nn as nn
|
|
|
|
|
|
|
|
|
|
| 5 |
from cldm.latent_intrinsic import LatentIntrinsc
|
| 6 |
from ldm.modules.diffusionmodules.util import (
|
| 7 |
conv_nd,
|
|
|
|
| 47 |
h = torch.cat([h, hs.pop() + control.pop()], dim=1)
|
| 48 |
h = module(h, emb, context)
|
| 49 |
# print("fool! that is the shape of the context! at output",context.shape)
|
|
|
|
| 50 |
|
| 51 |
h = h.type(x.dtype)
|
| 52 |
return self.out(h)
|
|
|
|
| 169 |
zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
|
| 170 |
)
|
| 171 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
self.input_latent_hint_cat_atten_block = TimestepEmbedSequential(
|
| 173 |
# conv_nd(dims, hint_channels, 16, 3, padding=1),
|
| 174 |
# nn.SiLU(),
|
|
|
|
| 187 |
zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
|
| 188 |
)
|
| 189 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
|
| 191 |
self._feature_size = model_channels
|
| 192 |
input_block_chans = [model_channels]
|
|
|
|
| 305 |
self._feature_size += ch
|
| 306 |
self.latent_iid = True
|
| 307 |
self.concat = True
|
| 308 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
if self.latent_iid:
|
| 310 |
self.input_hint_block = self.input_latent_hint_cat_atten_block
|
| 311 |
def make_zero_conv(self, channels):
|
|
|
|
| 313 |
#our modification for the latent intrinsic
|
| 314 |
def add_latent_prior(self):
    """Create a ``LatentIntrinsc`` instance and store it as ``self.prior_extracter``."""
    self.prior_extracter = LatentIntrinsc()
|
| 316 |
+
|
|
|
|
| 317 |
|
| 318 |
def forward(self, x, hint, timesteps, context, **kwargs):
|
| 319 |
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
| 320 |
emb = self.time_embed(t_emb)
|
| 321 |
self.latent_iid = True
|
|
|
|
| 322 |
# adding a cross-attention version
|
| 323 |
if self.latent_iid:
|
| 324 |
hint, hint_lighting = self.prior_extracter(hint)
|
| 325 |
self.new_context = torch.cat([context,hint_lighting],1)
|
|
|
|
| 326 |
context = self.new_context
|
| 327 |
+
|
|
|
|
| 328 |
self.input_hint_block = self.input_latent_hint_cat_atten_block
|
| 329 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
guided_hint = self.input_hint_block(hint, emb, context)
|
| 331 |
else:
|
| 332 |
guided_hint = self.input_hint_block(hint, emb, context)
|
|
|
|
| 358 |
self.only_mid_control = only_mid_control
|
| 359 |
self.control_scales = [1.0] * 13
|
| 360 |
|
| 361 |
+
# load bypass decoder
|
| 362 |
+
@torch.no_grad()
def change_first_stage(self, checkpoint_file, og=False):
    """Replace ``self.first_stage_model`` with an ``AutoencoderKL`` loaded from a checkpoint.

    Args:
        checkpoint_file: Passed straight to ``torch.load``; the loaded object
            must have a ``"state_dict"`` entry whose keys match the new
            auto-encoder (``load_state_dict`` is called with default arguments).
        og: Unused by this method; kept so existing callers keep working.
    """
    # Function-scope import kept from the original block.
    from modi_vae.autoencoder import AutoencoderKL

    # NOTE(review): torch.load unpickles the file -- only pass trusted checkpoints.
    checkpoint = torch.load(checkpoint_file, map_location=torch.device("cpu"))
    state_dict = checkpoint["state_dict"]

    # Build and populate the replacement before touching the current model, so a
    # bad path or mismatching checkpoint does not leave `self` without a first
    # stage. The original deleted the old model up front; this ordering keeps
    # both alive until loading has succeeded.
    replacement = AutoencoderKL(load_checkpoint=False)
    replacement.load_state_dict(state_dict)

    del self.first_stage_model
    self.first_stage_model = replacement
    print("Successfully load new auto-encoder")
|
| 374 |
|
| 375 |
@torch.no_grad()
|
| 376 |
def add_new_layers(self):
|
|
|
|
| 392 |
diffusion_model = self.model.diffusion_model
|
| 393 |
|
| 394 |
cond_txt = torch.cat(cond['c_crossattn'], 1)
|
|
|
|
|
|
|
| 395 |
if cond['c_concat'] is None:
|
| 396 |
eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
|
| 397 |
else:
|
|
|
|
| 403 |
eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
|
| 404 |
# print('unet_context',cond_txt.shape)
|
| 405 |
return eps
|
|
|
|
|
|
|
|
|
|
| 406 |
@torch.no_grad()
def get_unconditional_conditioning(self, N):
    """Return the learned conditioning for ``N`` empty-string prompts.

    Args:
        N: Number of empty prompts to pass to ``self.get_learned_conditioning``.

    Returns:
        Whatever ``self.get_learned_conditioning`` returns for ``[""] * N``.
    """
    empty_prompts = [""] * N
    return self.get_learned_conditioning(empty_prompts)
|
| 409 |
+
|
| 410 |
@torch.no_grad()
|
| 411 |
def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
|
| 412 |
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
|
|
|
|
| 479 |
def configure_optimizers(self):
|
| 480 |
lr = self.learning_rate
|
| 481 |
params = list(self.control_model.parameters())
|
| 482 |
+
|
|
|
|
|
|
|
| 483 |
if self.crossattn_start: #here we also train the cross-attan in the input layer if has any
|
| 484 |
for block in self.model.diffusion_model.input_blocks:
|
| 485 |
for layer in block:
|