Commit 7d28380
Parent: b0afe49
add cond stage to trainable parameters

ControlNet/ControlNet.ipynb  CHANGED

The diff for this file is too large to render. See raw diff.

ControlNet/cldm/cldm.py  CHANGED

@@ -2,6 +2,8 @@ import einops
 import torch
 import torch as th
 import torch.nn as nn
+import random
+import bitsandbytes as bnb
 from torchvision.transforms import Resize
 
 from ldm.modules.diffusionmodules.util import (
@@ -305,12 +307,15 @@ class ControlNet(nn.Module):
 
 class ControlInpaintLDM(LatentDiffusion):
 
-    def __init__(self, control_stage_config, control_key, only_mid_control, *args, **kwargs):
+    def __init__(self, control_stage_config, control_key, u_cond_percent, only_mid_control, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.control_model = instantiate_from_config(control_stage_config)
         self.control_key = control_key
         self.only_mid_control = only_mid_control
         self.control_scales = [1.0] * 13
+        self.learnable_vector = nn.Parameter(torch.randn((1, 1, 768)), requires_grad=True)
+        self.proj_out = nn.Linear(1024, 768)
+        self.u_cond_percent = u_cond_percent
 
     @torch.no_grad()
     def get_input(self, batch, k, bs=None, *args, **kwargs):
@@ -380,6 +385,7 @@ class ControlInpaintLDM(LatentDiffusion):
 
         if self.cond_stage_trainable:
             c = self.get_learned_conditioning(c)
+            c = self.proj_out(c)
 
         if sample:
             # get denoise row
@@ -412,15 +418,38 @@ class ControlInpaintLDM(LatentDiffusion):
         shape = (self.channels, h // 8, w // 8)
         samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
         return samples, intermediates
-
+
     def configure_optimizers(self):
         lr = self.learning_rate
         params = list(self.control_model.parameters())
+        if self.cond_stage_trainable:
+            print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
+            params = params + list(self.cond_stage_model.final_ln.parameters()) + list(self.cond_stage_model.mapper.parameters()) + list(self.proj_out.parameters())
+        self.params = params
+        self.params_with_white = params + list(self.learnable_vector)
         if not self.sd_locked:
             params += list(self.model.diffusion_model.output_blocks.parameters())
             params += list(self.model.diffusion_model.out.parameters())
-        opt = torch.optim.AdamW(params, lr=lr)
+        # opt = torch.optim.AdamW(params, lr=lr)
+        opt = bnb.optim.Adam8bit(params, lr=lr)
+        self.opt = opt
         return opt
+
+    def forward(self, x, c, *args, **kwargs):
+        self.opt.params = self.params
+        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
+        if self.model.conditioning_key is not None:
+            assert c is not None
+            if self.cond_stage_trainable:
+                c['c_crossattn'][0] = self.get_learned_conditioning(c['c_crossattn'][0])
+                c['c_crossattn'][0] = self.proj_out(c['c_crossattn'][0])
+            u_cond_prop = random.uniform(0, 1)
+            if u_cond_prop < self.u_cond_percent:
+                self.opt.params = self.params_with_white
+                c['c_crossattn'][0] = self.learnable_vector.repeat(x.shape[0], 1, 1)
+                return self.p_losses(x, c, t, *args, **kwargs)
+        return self.p_losses(x, c, t, *args, **kwargs)
+
 
     def low_vram_shift(self, is_diffusing):
         if is_diffusing:
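
The configure_optimizers hunk swaps the stock AdamW for bitsandbytes' 8-bit Adam and widens the trainable set to the conditioner's mapper, final_ln, and the new proj_out. Below is a minimal sketch of the same parameter-collection pattern with stand-in modules (the Linear layers are assumptions, not the repo's classes). Note that a single nn.Parameter joins a parameter list as [p]; list(p) iterates the tensor into non-leaf views, which torch optimizers reject.

import torch
import torch.nn as nn

# Stand-ins for the real modules (assumptions, for illustration only).
control_model = nn.Linear(4, 4)                           # the ControlNet branch
proj_out = nn.Linear(1024, 768)                           # CLIP-image width -> cross-attn width
learnable_vector = nn.Parameter(torch.randn(1, 1, 768))   # learned "null" conditioning

params = list(control_model.parameters()) + list(proj_out.parameters())
# Wrap a bare Parameter in a list literal; list(learnable_vector) would
# yield non-leaf slices and AdamW/Adam8bit would raise
# "can't optimize a non-leaf Tensor".
params_with_null = params + [learnable_vector]

opt = torch.optim.AdamW(params_with_null, lr=1e-5)  # bnb.optim.Adam8bit takes the same arguments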
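The new forward() implements classifier-free-guidance-style conditioning dropout: with probability u_cond_percent the projected CLIP embedding is replaced by the learned unconditional vector. A hedged, self-contained sketch of that one step (the 0.2 rate and the tensor shapes are assumptions):

import random
import torch
import torch.nn as nn

u_cond_percent = 0.2                                      # assumed dropout rate
learnable_vector = nn.Parameter(torch.randn(1, 1, 768))

def maybe_drop_condition(c_crossattn):
    # With probability u_cond_percent, swap the conditioning for the
    # learned unconditional token, broadcast over the batch.
    if random.uniform(0, 1) < u_cond_percent:
        return learnable_vector.repeat(c_crossattn.shape[0], 1, 1)
    return c_crossattn

cond = torch.randn(4, 1, 768)                             # projected CLIP image embedding
cond = maybe_drop_condition(cond)
assert cond.shape == (4, 1, 768)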

ControlNet/environment.yaml  CHANGED

@@ -1,12 +1,13 @@
 name: control
 channels:
   - pytorch
+  - anaconda
   - defaults
 dependencies:
   - python=3.8.5
   - pip=20.3
   - cudatoolkit=11.3
-  - pytorch=1.12.1
+  - pytorch=1.13.1
   - torchvision=0.13.1
   - numpy=1.23.1
   - pip:
@@ -36,4 +37,5 @@ dependencies:
     - ipdb==0.13.11
     - ipython==8.11.0
    - ipykernel==6.21.2
+    - bitsandbytes==0.37.1
 
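The environment now pins pytorch=1.13.1 and pulls in bitsandbytes 0.37.1 for the 8-bit optimizer used in cldm.py. A quick sanity check, assuming a CUDA build is available (bitsandbytes' 8-bit optimizer states only exist for GPU tensors):

import torch
import bitsandbytes as bnb

# 8-bit optimizers require the parameters to live on the GPU.
p = torch.nn.Parameter(torch.randn(8, 8, device="cuda"))
opt = bnb.optim.Adam8bit([p], lr=1e-4)
(p ** 2).sum().backward()
opt.step()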

ControlNet/ldm/models/diffusion/ddpm.py  CHANGED

@@ -552,8 +552,6 @@ class LatentDiffusion(DDPM):
         reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False)
         ignore_keys = kwargs.pop("ignore_keys", [])
         super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
-        self.learnable_vector = nn.Parameter(torch.randn((1, 1, 768)), requires_grad=True)
-        self.u_cond_percent = u_cond_percent
         self.concat_mode = concat_mode
         self.cond_stage_trainable = cond_stage_trainable
         self.cond_stage_key = cond_stage_key
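This removal deduplicates state that the commit now keeps on ControlInpaintLDM: an nn.Parameter is registered on whichever module owns the attribute, so defining learnable_vector only on the subclass keeps it (and its checkpoint entry) out of the base LatentDiffusion. A toy illustration, not repo code:

import torch
import torch.nn as nn

class Base(nn.Module):
    def __init__(self):
        super().__init__()

class Child(Base):
    def __init__(self):
        super().__init__()
        self.learnable_vector = nn.Parameter(torch.randn(1, 1, 768))

print("learnable_vector" in dict(Base().named_parameters()))   # False
print("learnable_vector" in dict(Child().named_parameters()))  # True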

ControlNet/ldm/modules/encoders/modules.py  CHANGED

@@ -137,7 +137,6 @@ class FrozenCLIPImageEmbedder(AbstractEncoder):
         super().__init__()
         self.transformer = CLIPVisionModel.from_pretrained(version)
         self.final_ln = LayerNorm(1024)
-        self.proj_out = nn.Linear(1024, 768)
         self.mapper = Transformer(
             1,
             1024,
@@ -162,7 +161,6 @@ class FrozenCLIPImageEmbedder(AbstractEncoder):
         z = z.unsqueeze(1)
         z = self.mapper(z)
         z = self.final_ln(z)
-        z = self.proj_out(z)
         return z
 
     def encode(self, image):
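With proj_out gone from FrozenCLIPImageEmbedder, the image encoder stops at the 1024-dim CLIP width and the trainable 1024 -> 768 projection now happens inside ControlInpaintLDM. A sketch of the resulting shape flow with stand-in layers (assumptions, not the repo's Transformer/LayerNorm):

import torch
import torch.nn as nn

final_ln = nn.LayerNorm(1024)
proj_out = nn.Linear(1024, 768)    # now owned by the LDM, per this commit

z = torch.randn(4, 1, 1024)        # mapper output: one token per image
z = final_ln(z)                    # encoder output stays 1024-dim
c = proj_out(z)                    # projected to the 768-dim cross-attention width
assert c.shape == (4, 1, 768)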