File size: 10,172 Bytes
074c857 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 |
from torch import nn
from k_diffusion import utils as k_utils
import torch
from k_diffusion.external import CompVisDenoiser
from torchvision.utils import make_grid
from IPython import display
from torchvision.transforms.functional import to_pil_image
class CFGDenoiser(nn.Module):
    """Classifier-free-guidance wrapper around a denoising model.

    Runs the wrapped model on the unconditional and conditional inputs in a
    single batched call, then blends the two denoised predictions using
    cond_scale.
    """
    def __init__(self, model):
        super().__init__()
        self.inner_model = model
    def forward(self, x, sigma, uncond, cond, cond_scale):
        # Duplicate latent and sigma so both conditionings share one model call.
        batched_x = torch.cat([x, x])
        batched_sigma = torch.cat([sigma, sigma])
        batched_cond = torch.cat([uncond, cond])
        denoised = self.inner_model(batched_x, batched_sigma, cond=batched_cond)
        uncond_out, cond_out = denoised.chunk(2)
        # Standard CFG blend: extrapolate from uncond toward cond by cond_scale.
        return uncond_out + (cond_out - uncond_out) * cond_scale
class CFGDenoiserWithGrad(CompVisDenoiser):
    """CFG denoiser that can steer sampling with extra loss-based gradients.

    On scheduled steps, each active loss function is evaluated on the decoded
    denoised prediction and its (negative) gradient is folded back into the
    sample / prediction, optionally clamped and consolidated by user hooks.
    """
    def __init__(self, model,
                 loss_fns_scales,              # List of [loss_fn, scale] pairs
                 clamp_func=None,              # Gradient clamping function, clamp_func(grad, sigma)
                 gradient_wrt=None,            # Calculate gradient with respect to "x" or "x0_pred"
                 gradient_add_to=None,         # Add gradient to "cond", "uncond", or "both"
                 cond_uncond_sync=True,        # Calculate cond and uncond simultaneously (one batched call)
                 decode_method=None,           # Latent decode used during gradient calculation
                 grad_inject_timing_fn=None,   # Predicate fn(sigma) -> bool: use grad on this step?
                 grad_consolidate_fn=None,     # fn(img, grad, sigma) combining grad into the image
                 verbose=False):
        super().__init__(model.inner_model)
        self.inner_model = model
        self.cond_uncond_sync = cond_uncond_sync

        # Initialize gradient calculation variables
        self.clamp_func = clamp_func
        self.gradient_add_to = gradient_add_to
        # BUG FIX: the original set the 'x' default and then unconditionally
        # overwrote it with None, so the default never took effect (and
        # make_cond_fn raised for any active loss when gradient_wrt was None).
        self.gradient_wrt = 'x' if gradient_wrt is None else gradient_wrt

        # Choose how the latent is decoded while evaluating loss functions.
        if decode_method is None:
            decode_fn = lambda x: x  # operate directly on the latent
        elif decode_method == "autoencoder":
            decode_fn = model.inner_model.differentiable_decode_first_stage
        elif decode_method == "linear":
            decode_fn = model.inner_model.linear_decode
        else:
            # Previously an unknown method fell through to a NameError on
            # decode_fn below; fail loudly with a clear message instead.
            raise ValueError(f"Unrecognised option for decode_method: {decode_method}")
        self.decode_fn = decode_fn

        # Parse loss function-scale pairs; a zero scale disables that loss.
        self.cond_fns = [self.make_cond_fn(loss_fn, scale) if scale != 0 else None
                         for loss_fn, scale in loss_fns_scales]

        # Default: inject the gradient on every step.
        if grad_inject_timing_fn is None:
            self.grad_inject_timing_fn = lambda sigma: True
        else:
            self.grad_inject_timing_fn = grad_inject_timing_fn

        # Default consolidation: scale the gradient by sigma and add to the image.
        if grad_consolidate_fn is None:
            self.grad_consolidate_fn = lambda img, grad, sigma: img + grad * sigma
        else:
            self.grad_consolidate_fn = grad_consolidate_fn

        self.verbose = verbose
        self.verbose_print = print if self.verbose else lambda *args, **kwargs: None
# General denoising model with gradient conditioning
def cond_model_fn_(self, x, sigma, inner_model=None, **kwargs):
    """Denoise x and nudge the result with the configured conditioning gradients.

    Accumulates the gradient from every active cond function, clamps it, then
    folds it back into the sample (gradient_wrt == 'x') or the denoised
    prediction (gradient_wrt == 'x0_pred') via self.grad_consolidate_fn.
    NOTE(review): mutates x in place via copy_() — callers see the updated x.
    """
    # inner_model: optionally use a different inner_model function or a wrapper function around inner_model, see self.forward._cfg_model
    if inner_model is None:
        inner_model = self.inner_model
    total_cond_grad = torch.zeros_like(x)
    for cond_fn in self.cond_fns:
        if cond_fn is None: continue
        # Gradient with respect to x
        if self.gradient_wrt == 'x':
            with torch.enable_grad():
                # Re-leaf x so autograd tracks it through the model call.
                x = x.detach().requires_grad_()
                denoised = inner_model(x, sigma, **kwargs)
                cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach()
        # Gradient wrt x0_pred, so save some compute: don't record grad until after denoised is calculated
        elif self.gradient_wrt == 'x0_pred':
            with torch.no_grad():
                denoised = inner_model(x, sigma, **kwargs)
            with torch.enable_grad():
                cond_grad = cond_fn(x, sigma, denoised=denoised.detach().requires_grad_(), **kwargs).detach()
        total_cond_grad += cond_grad
    # Guard against NaNs from the loss/grad computation (inf is passed through).
    total_cond_grad = torch.nan_to_num(total_cond_grad, nan=0.0, posinf=float('inf'), neginf=-float('inf'))
    # Clamp the gradient
    total_cond_grad = self.clamp_grad_verbose(total_cond_grad, sigma)
    # Add gradient to the image
    if self.gradient_wrt == 'x':
        # Consolidate grad into x in place, then re-denoise the updated sample.
        x.copy_(self.grad_consolidate_fn(x.detach(), total_cond_grad, k_utils.append_dims(sigma, x.ndim)))
        cond_denoised = inner_model(x, sigma, **kwargs)
    elif self.gradient_wrt == 'x0_pred':
        # Update x as well, then apply the grad directly to the last prediction.
        x.copy_(self.grad_consolidate_fn(x.detach(), total_cond_grad, k_utils.append_dims(sigma, x.ndim)))
        cond_denoised = self.grad_consolidate_fn(denoised.detach(), total_cond_grad, k_utils.append_dims(sigma, x.ndim))
    return cond_denoised
def forward(self, x, sigma, uncond, cond, cond_scale):
    """Denoise x at noise level sigma with classifier-free guidance, injecting
    conditioning gradients on steps selected by check_conditioning_schedule.

    x: noisy latent; uncond/cond: conditioning weights; cond_scale: CFG scale.
    Returns the guided denoised prediction x0.
    """
    def _cfg_model(x, sigma, cond, **kwargs):
        # Wrapper to add denoised cond and uncond as in a cfg model
        # input "cond" is both cond and uncond weights: torch.cat([uncond, cond])
        x_in = torch.cat([x] * 2)
        sigma_in = torch.cat([sigma] * 2)
        denoised = self.inner_model(x_in, sigma_in, cond=cond, **kwargs)
        uncond_x0, cond_x0 = denoised.chunk(2)
        # CFG blend (closes over cond_scale from the enclosing forward call).
        x0_pred = uncond_x0 + (cond_x0 - uncond_x0) * cond_scale
        return x0_pred
    # Conditioning
    if self.check_conditioning_schedule(sigma):
        # Apply the conditioning gradient to the completed denoised (after both cond and uncond are combined into the diffused image)
        if self.cond_uncond_sync:
            # x0 = self.cfg_cond_model_fn_(x, sigma, uncond=uncond, cond=cond, cond_scale=cond_scale)
            cond_in = torch.cat([uncond, cond])
            x0 = self.cond_model_fn_(x, sigma, cond=cond_in, inner_model=_cfg_model)
        # Calculate cond and uncond separately
        else:
            if self.gradient_add_to == "uncond":
                # Gradient only nudges the unconditional branch.
                uncond = self.cond_model_fn_(x, sigma, cond=uncond)
                cond = self.inner_model(x, sigma, cond=cond)
                x0 = uncond + (cond - uncond) * cond_scale
            elif self.gradient_add_to == "cond":
                # Gradient only nudges the conditional branch.
                uncond = self.inner_model(x, sigma, cond=uncond)
                cond = self.cond_model_fn_(x, sigma, cond=cond)
                x0 = uncond + (cond - uncond) * cond_scale
            elif self.gradient_add_to == "both":
                # Gradient nudges both branches independently.
                uncond = self.cond_model_fn_(x, sigma, cond=uncond)
                cond = self.cond_model_fn_(x, sigma, cond=cond)
                x0 = uncond + (cond - uncond) * cond_scale
            else:
                raise Exception(f"Unrecognised option for gradient_add_to: {self.gradient_add_to}")
    # No conditioning
    else:
        # calculate cond and uncond simultaneously
        if self.cond_uncond_sync:
            cond_in = torch.cat([uncond, cond])
            x0 = _cfg_model(x, sigma, cond=cond_in)
        else:
            uncond = self.inner_model(x, sigma, cond=uncond)
            cond = self.inner_model(x, sigma, cond=cond)
            x0 = uncond + (cond - uncond) * cond_scale
    return x0
def make_cond_fn(self, loss_fn, scale):
    """Wrap a loss function into a conditioning-gradient function.

    loss_fn(decoded_sample, sigma, **kwargs) -> scalar; scale weights it.
    Returns a function (x, sigma, denoised, **kwargs) -> gradient tensor,
    differentiated w.r.t. x or the denoised prediction per self.gradient_wrt.
    """
    # Gradient taken with respect to the noisy latent x.
    def cond_fn(x, sigma, denoised, **kwargs):
        with torch.enable_grad():
            decoded = self.decode_fn(denoised).requires_grad_()
            loss = loss_fn(decoded, sigma, **kwargs) * scale
            # Negative gradient: descending the loss means adding grad to the image.
            grad = -torch.autograd.grad(loss, x)[0]
            self.verbose_print('Loss:', loss.item())
        return grad
    # Gradient taken with respect to the denoised prediction x0.
    def cond_fn_pred(x, sigma, denoised, **kwargs):
        with torch.enable_grad():
            decoded = self.decode_fn(denoised).requires_grad_()
            loss = loss_fn(decoded, sigma, **kwargs) * scale
            grad = -torch.autograd.grad(loss, denoised)[0]
            self.verbose_print('Loss:', loss.item())
        return grad
    variants = {'x': cond_fn, 'x0_pred': cond_fn_pred}
    chosen = variants.get(self.gradient_wrt)
    if chosen is None:
        raise Exception(f"Variable gradient_wrt == {self.gradient_wrt} not recognised.")
    return chosen
def clamp_grad_verbose(self, grad, sigma):
    """Apply the configured gradient clamp (if any); in verbose mode, render
    the gradient before and after clamping."""
    def _show(label):
        print(label)
        # Map |grad| into roughly [-1, 1] for display_samples.
        self.display_samples(torch.abs(grad * 2.0) - 1.0)
    if self.clamp_func is not None:
        if self.verbose:
            _show("Grad before clamping:")
        grad = self.clamp_func(grad, sigma)
    if self.verbose:
        _show("Conditioning gradient")
    return grad
def check_conditioning_schedule(self, sigma):
    """Return True when at least one cond function is active and the timing
    function selects this sigma for gradient injection."""
    # No active loss functions -> conditioning strength is 0, never a cond step.
    has_active_fn = (self.cond_fns is not None
                     and any(fn is not None for fn in self.cond_fns))
    if not has_active_fn:
        return False
    # Timing function decides whether this particular step gets the gradient.
    if not self.grad_inject_timing_fn(sigma):
        return False
    if self.verbose:
        print(f"Conditioning step for sigma={sigma}")
    return True
def display_samples(self, images):
    """Render a batch of images (values roughly in [-1, 1]) as a 4-wide grid
    in the notebook.

    images: tensor batch; remapped to [0, 1] before gridding.
    """
    # detach() guards against tensors that still carry autograd history
    # (the original's .numpy() round-trip would raise on such tensors, and
    # torch.tensor(images.numpy()) was a redundant extra copy — removed).
    images = images.detach().double().cpu().add(1).div(2).clamp(0, 1)
    grid = make_grid(images, 4).cpu()
    display.display(to_pil_image(grid))
|