xyxingx committed on
Commit 6c63884 · verified · 1 Parent(s): ce5b71a

Update cldm/LumiNet.py

Files changed (1)
  1. cldm/LumiNet.py +19 -125
cldm/LumiNet.py CHANGED
@@ -2,9 +2,6 @@ import einops
 import torch
 import torch as th
 import torch.nn as nn
-# from cldm.latent_intrinsic_new import LatentIntrinsc
-# from cldm.latent_intrinsic import LatentIntrinsc
-# from cldm.latent_intrinsic_crossattn import LatentIntrinsc
 from cldm.latent_intrinsic import LatentIntrinsc
 from ldm.modules.diffusionmodules.util import (
     conv_nd,
@@ -50,7 +47,6 @@ class ControlledUnetModel(UNetModel):
                 h = torch.cat([h, hs.pop() + control.pop()], dim=1)
             h = module(h, emb, context)
         # print("fool! that is the shape of the context! at output",context.shape)
-        # exit(0)
 
         h = h.type(x.dtype)
         return self.out(h)
@@ -173,60 +169,6 @@ class ControlNet(nn.Module):
             zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
         )
 
-        # self.input_latent_hint_block = TimestepEmbedSequential(
-        # # conv_nd(dims, hint_channels, 16, 3, padding=1),
-        # # nn.SiLU(),
-        # # conv_nd(dims, 16, 16, 3, padding=1),
-        # # nn.SiLU(),
-        # # conv_nd(dims, 16, 32, 3, padding=1, stride=2),
-        # # nn.SiLU(),
-        # # conv_nd(dims, 32, 32, 3, padding=1),
-        # # nn.SiLU(),
-        # # conv_nd(dims, 32, 96, 3, padding=1, stride=2),
-        # # nn.SiLU(),
-        # # conv_nd(dims, 96, 96, 3, padding=1),
-        # # nn.SiLU(),
-        # conv_nd(dims, 128, 256, 3, padding=1, stride=1),
-        # nn.SiLU(),
-        # zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
-        # )
-
-        # self.input_latent_hint_cat_block = TimestepEmbedSequential(
-        # # conv_nd(dims, hint_channels, 16, 3, padding=1),
-        # # nn.SiLU(),
-        # # conv_nd(dims, 16, 16, 3, padding=1),
-        # # nn.SiLU(),
-        # # conv_nd(dims, 16, 32, 3, padding=1, stride=2),
-        # # nn.SiLU(),
-        # # conv_nd(dims, 144, 144, 3, padding=1),
-        # # nn.SiLU(),
-        # # conv_nd(dims, 144, 256, 3, padding=1, stride=2),
-        # # nn.SiLU(),
-        # conv_nd(dims, 144, 144, 3, padding=1),
-        # nn.SiLU(),
-        # conv_nd(dims, 144, 256, 3, padding=1, stride=1),
-        # nn.SiLU(),
-        # zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
-        # )
-
-        # self.input_latent_hint_cat_block = TimestepEmbedSequential(
-        # # conv_nd(dims, hint_channels, 16, 3, padding=1),
-        # # nn.SiLU(),
-        # # conv_nd(dims, 16, 16, 3, padding=1),
-        # # nn.SiLU(),
-        # # conv_nd(dims, 16, 32, 3, padding=1, stride=2),
-        # # nn.SiLU(),
-        # conv_nd(dims, 144, 144, 3, padding=1),
-        # nn.SiLU(),
-        # conv_nd(dims, 144, 256, 3, padding=1, stride=2),
-        # nn.SiLU(),
-        # conv_nd(dims, 256, 256, 3, padding=1),
-        # nn.SiLU(),
-        # conv_nd(dims, 256, 256, 3, padding=1, stride=1),
-        # nn.SiLU(),
-        # zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
-        # )
-
         self.input_latent_hint_cat_atten_block = TimestepEmbedSequential(
             # conv_nd(dims, hint_channels, 16, 3, padding=1),
             # nn.SiLU(),
@@ -245,41 +187,6 @@
             zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
         )
 
-        # self.input_latent_hint_crossattn_block = TimestepEmbedSequential(
-        # # conv_nd(dims, hint_channels, 16, 3, padding=1),
-        # # nn.SiLU(),
-        # # conv_nd(dims, 16, 16, 3, padding=1),
-        # # nn.SiLU(),
-        # # conv_nd(dims, 16, 32, 3, padding=1, stride=2),
-        # # nn.SiLU(),
-        # conv_nd(dims, 128, 128, 3, padding=1),
-        # nn.SiLU(),
-        # conv_nd(dims, 128, 256, 3, padding=1, stride=2),
-        # nn.SiLU(),
-        # conv_nd(dims, 256, 256, 3, padding=1),
-        # nn.SiLU(),
-        # conv_nd(dims, 256, 256, 3, padding=1, stride=1),
-        # nn.SiLU(),
-        # zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
-        # )
-
-        # self.input_latent_hint_cat_eq_block = TimestepEmbedSequential(
-        # # conv_nd(dims, hint_channels, 16, 3, padding=1),
-        # # nn.SiLU(),
-        # # conv_nd(dims, 16, 16, 3, padding=1),
-        # # nn.SiLU(),
-        # # conv_nd(dims, 16, 32, 3, padding=1, stride=2),
-        # # nn.SiLU(),
-        # # conv_nd(dims, 32, 32, 3, padding=1),
-        # # nn.SiLU(),
-        # # conv_nd(dims, 32, 96, 3, padding=1, stride=2),
-        # # nn.SiLU(),
-        # conv_nd(dims, 256, 512, 3, padding=1),
-        # nn.SiLU(),
-        # conv_nd(dims, 512, 256, 3, padding=1, stride=1),
-        # nn.SiLU(),
-        # zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
-        # )
 
         self._feature_size = model_channels
         input_block_chans = [model_channels]
@@ -398,15 +305,7 @@ class ControlNet(nn.Module):
         self._feature_size += ch
         self.latent_iid = True
         self.concat = True
-        # Old pure ControlNet
-        # if self.latent_iid:
-        # # print(hint.shape)
-        # if self.concat:
-        # self.input_hint_block = self.input_latent_hint_cat_block
-        # else:
-        # self.input_hint_block = self.input_latent_hint_block
-
-        # New w. the crossattn version #comment before train Uncomment for test
+
        if self.latent_iid:
            self.input_hint_block = self.input_latent_hint_cat_atten_block
    def make_zero_conv(self, channels):
@@ -414,31 +313,20 @@
     #our modification for the latent intrinsic
     def add_latent_prior(self):
         self.prior_extracter = LatentIntrinsc()
-        # send the extracted information to control net encoder (original image)
-        # self.input_hint_block[0] = conv_nd(2, 19, 16, 3, padding=1)
+
 
     def forward(self, x, hint, timesteps, context, **kwargs):
         t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
         emb = self.time_embed(t_emb)
         self.latent_iid = True
-        # self.concat = False
         # adding a cross-attention version
         if self.latent_iid:
             hint, hint_lighting = self.prior_extracter(hint)
             self.new_context = torch.cat([context,hint_lighting],1)
-            # self.new_context = context*hint_lighting
             context = self.new_context
-            # print("hint_context",context.shape)
-            # exit(0)
+
             self.input_hint_block = self.input_latent_hint_cat_atten_block
-            # pure controlnet with latent guidance
-            # if self.latent_iid:
-            # hint = self.prior_extracter(hint)
-            # # print(hint.shape)
-            # if self.concat:
-            # self.input_hint_block = self.input_latent_hint_cat_block
-            # else:
-            # self.input_hint_block = self.input_latent_hint_block
+
             guided_hint = self.input_hint_block(hint, emb, context)
         else:
             guided_hint = self.input_hint_block(hint, emb, context)
@@ -470,6 +358,19 @@ class ControlLDM(LatentDiffusion):
         self.only_mid_control = only_mid_control
         self.control_scales = [1.0] * 13
 
+    # load bypass decoder
+    @torch.no_grad()
+    def change_first_stage(self, checkpoint_file,og=False):
+        del self.first_stage_model
+        from modi_vae.autoencoder import AutoencoderKL
+        self.first_stage_model = AutoencoderKL(load_checkpoint=False)
+        state_dict = torch.load(checkpoint_file, map_location=torch.device("cpu"))["state_dict"]
+        new_state_dict = {}
+        for s in state_dict:
+            new_state_dict[s]=state_dict[s]
+
+        self.first_stage_model.load_state_dict(new_state_dict)
+        print("Successfully load new auto-encoder")
 
     @torch.no_grad()
     def add_new_layers(self):
@@ -491,8 +392,6 @@ class ControlLDM(LatentDiffusion):
         diffusion_model = self.model.diffusion_model
 
         cond_txt = torch.cat(cond['c_crossattn'], 1)
-        # print('attention_shape:',cond)
-        # exit(0)
         if cond['c_concat'] is None:
             eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
         else:
@@ -504,13 +403,10 @@ class ControlLDM(LatentDiffusion):
             eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
         # print('unet_context',cond_txt.shape)
         return eps
-    # @torch.no_grad()
-    # def get_attn_intrinsic(self):
-    # return self.control_model.new_context
     @torch.no_grad()
     def get_unconditional_conditioning(self, N):
         return self.get_learned_conditioning([""] * N)
-
+
     @torch.no_grad()
     def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
                    quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
@@ -583,9 +479,7 @@ class ControlLDM(LatentDiffusion):
     def configure_optimizers(self):
         lr = self.learning_rate
         params = list(self.control_model.parameters())
-        # if not self.sd_locked:
-        # params += list(self.model.diffusion_model.output_blocks.parameters())
-        # params += list(self.model.diffusion_model.out.parameters())
+
         if self.crossattn_start: #here we also train the cross-attan in the input layer if has any
             for block in self.model.diffusion_model.input_blocks:
                 for layer in block:
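
A quick shape sketch of the forward() change above: the lighting code from LatentIntrinsc is concatenated to the text tokens along the sequence dimension, so the U-Net's cross-attention attends to extra conditioning tokens. The sizes below are illustrative assumptions (SD 1.x text context is (B, 77, 768); the lighting token count is invented for this sketch):

import torch

B = 4
context = torch.randn(B, 77, 768)        # CLIP text tokens (assumed SD 1.x shape)
hint_lighting = torch.randn(B, 1, 768)   # lighting latent as one 768-d token (assumption)
new_context = torch.cat([context, hint_lighting], 1)
print(new_context.shape)                 # torch.Size([4, 78, 768])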
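The main addition is change_first_stage, which swaps the frozen first-stage VAE for the modi_vae bypass decoder after the rest of the model is built. A minimal usage sketch, assuming the upstream ControlNet helpers create_model/load_state_dict and placeholder paths (none of these names are confirmed by this commit):

from cldm.model import create_model, load_state_dict

model = create_model("./models/cldm_v15.yaml").cpu()                               # assumed config path
model.load_state_dict(load_state_dict("./models/luminet.ckpt", location="cpu"))    # assumed checkpoint
model.change_first_stage("./models/bypass_vae.ckpt")                               # new method from this commit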
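Finally, configure_optimizers keeps only the ControlNet branch trainable by default; the commented-out sd_locked path that also trained the SD decoder blocks is removed. A sketch of the resulting setup, assuming AdamW as in upstream ControlNet (the crossattn_start branch is cut off in this hunk, so it is omitted here too):

import torch

def configure_optimizers(self):
    lr = self.learning_rate
    params = list(self.control_model.parameters())  # control branch only; the SD U-Net stays frozen
    return torch.optim.AdamW(params, lr=lr)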