File size: 10,172 Bytes
074c857
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
from torch import nn
from k_diffusion import utils as k_utils
import torch
from k_diffusion.external import CompVisDenoiser
from torchvision.utils import make_grid
from IPython import display
from torchvision.transforms.functional import to_pil_image

class CFGDenoiser(nn.Module):
    """Apply classifier-free guidance (CFG) around a denoising model.

    The wrapped ``model`` is called once on a doubled batch: the first half is
    conditioned on ``uncond`` and the second half on ``cond``. The two halves
    of the output are then blended as ``uncond + (cond - uncond) * cond_scale``.
    """

    def __init__(self, model):
        super().__init__()
        # The denoiser being guided; it is called as model(x, sigma, cond=...).
        self.inner_model = model

    def forward(self, x, sigma, uncond, cond, cond_scale):
        """Return the CFG-combined denoised prediction for ``x``.

        Args:
            x: Noisy input batch; duplicated along dim 0 for the batched call.
            sigma: Noise level(s); duplicated along dim 0 alongside ``x``.
            uncond: Conditioning used for the unconditional half of the batch.
            cond: Conditioning used for the conditional half of the batch.
            cond_scale: Guidance scale applied to ``cond - uncond``.

        Returns:
            ``uncond_pred + (cond_pred - uncond_pred) * cond_scale``.
        """
        doubled_x = torch.cat((x, x))
        doubled_sigma = torch.cat((sigma, sigma))
        # Order matters: the first half of the output is read back as uncond.
        stacked_cond = torch.cat((uncond, cond))

        prediction = self.inner_model(doubled_x, doubled_sigma, cond=stacked_cond)
        uncond_pred, cond_pred = torch.chunk(prediction, 2)
        return uncond_pred + (cond_pred - uncond_pred) * cond_scale

class CFGDenoiserWithGrad(CompVisDenoiser):
    """CFG denoiser that can steer sampling with gradients of user loss functions.

    On conditioning steps (see ``check_conditioning_schedule``) the gradients
    of all active losses are summed, NaN-sanitised, optionally clamped, and
    folded into the noisy latent ``x`` (in place) and into the returned
    denoised prediction via ``grad_consolidate_fn``. On other steps it behaves
    like a plain classifier-free-guidance denoiser.

    Args:
        model: CFG-style wrapper. ``model.inner_model`` is handed to
            ``CompVisDenoiser`` and supplies the decode methods, while ``model``
            itself is called as ``model(x, sigma, cond=...)``.
        loss_fns_scales: Iterable of ``(loss_fn, scale)`` pairs. ``loss_fn`` is
            called as ``loss_fn(decoded_sample, sigma, **kwargs)``; pairs whose
            scale is 0 are ignored.
        clamp_func: Optional ``clamp_func(grad, sigma)`` applied to the summed
            gradient.
        gradient_wrt: ``"x"`` (used when None) or ``"x0_pred"``: what the loss
            gradient is taken with respect to.
        gradient_add_to: ``"cond"``, ``"uncond"`` or ``"both"``; which CFG
            branch receives the gradient when ``cond_uncond_sync`` is False.
        cond_uncond_sync: If True, cond and uncond are evaluated in one batched
            call and the gradient is applied to the combined CFG prediction.
        decode_method: None (identity), ``"autoencoder"`` or ``"linear"``; how
            the latent is decoded before the loss is evaluated.
        grad_inject_timing_fn: ``fn(sigma) -> bool`` selecting the steps that
            receive the gradient; defaults to every step.
        grad_consolidate_fn: ``fn(img, grad, sigma)`` combining a tensor with
            the gradient; defaults to ``img + grad * sigma``.
        verbose: Print losses / conditioning steps and display gradients.

    Raises:
        ValueError: for an unrecognised ``decode_method``, or (when any scale
            is non-zero) an unrecognised ``gradient_wrt``.
    """

    def __init__(self, model, 
                       loss_fns_scales, # List of [cond_function, scale] pairs
                       clamp_func=None,  # Gradient clamping function, clamp_func(grad, sigma)
                       gradient_wrt=None, # Calculate gradient with respect to ["x", "x0_pred"]; None means "x"
                       gradient_add_to=None, # Add gradient to ["cond", "uncond", "both"]
                       cond_uncond_sync=True, # Calculates the cond and uncond simultaneously
                       decode_method=None, # Function used to decode the latent during gradient calculation
                       grad_inject_timing_fn=None, # Option to use grad in only a few of the steps
                       grad_consolidate_fn=None, # Function to add grad to image fn(img, grad, sigma)
                       verbose=False):
        super().__init__(model.inner_model)
        # Replace the CompVis model stored by the base class with the wrapper,
        # which is what gets called as (x, sigma, cond=...).
        self.inner_model = model
        self.cond_uncond_sync = cond_uncond_sync

        self.verbose = verbose
        self.verbose_print = print if self.verbose else lambda *args, **kwargs: None

        # Gradient calculation settings
        self.clamp_func = clamp_func
        self.gradient_add_to = gradient_add_to
        # None selects the default of differentiating with respect to x.
        self.gradient_wrt = 'x' if gradient_wrt is None else gradient_wrt
        self.decode_fn = self._resolve_decode_fn(model, decode_method)

        # One entry per (loss_fn, scale) pair; None marks a disabled (scale 0) loss.
        self.cond_fns = [
            self.make_cond_fn(loss_fn, scale) if scale != 0 else None
            for loss_fn, scale in loss_fns_scales
        ]

        if grad_inject_timing_fn is None:
            self.grad_inject_timing_fn = lambda sigma: True
        else:
            self.grad_inject_timing_fn = grad_inject_timing_fn
        if grad_consolidate_fn is None:
            self.grad_consolidate_fn = lambda img, grad, sigma: img + grad * sigma
        else:
            self.grad_consolidate_fn = grad_consolidate_fn

    @staticmethod
    def _resolve_decode_fn(model, decode_method):
        """Return the function that decodes latents before losses are evaluated."""
        if decode_method is None:
            return lambda x: x
        if decode_method == "autoencoder":
            return model.inner_model.differentiable_decode_first_stage
        if decode_method == "linear":
            return model.inner_model.linear_decode
        raise ValueError(f"Unrecognised option for decode_method: {decode_method}")

    # General denoising model with gradient conditioning
    def cond_model_fn_(self, x, sigma, inner_model=None, **kwargs):
        """Denoise ``x`` with the summed loss gradients folded in.

        Side effect: ``x`` is overwritten in place with
        ``grad_consolidate_fn(x, grad, sigma)``, so the caller's latent is
        updated too.

        Args:
            x: Noisy latent batch.
            sigma: Noise level(s) for ``x``.
            inner_model: Optional replacement for ``self.inner_model`` (e.g.
                the batched CFG wrapper built in ``forward``).
            **kwargs: Forwarded to the model and to every cond function
                (``forward`` passes ``cond=...``).

        Returns:
            The denoised prediction after applying the gradient.

        Raises:
            ValueError: if ``self.gradient_wrt`` is not ``'x'`` or ``'x0_pred'``.
        """
        if inner_model is None:
            inner_model = self.inner_model

        total_cond_grad = torch.zeros_like(x)
        active_cond_fns = [fn for fn in self.cond_fns if fn is not None]

        if self.gradient_wrt == 'x':
            for cond_fn in active_cond_fns:
                # autograd.grad inside cond_fn frees the graph (retain_graph
                # defaults to False), so the model is re-run for every loss.
                with torch.enable_grad():
                    x = x.detach().requires_grad_()
                    denoised = inner_model(x, sigma, **kwargs)
                    cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach()
                total_cond_grad += cond_grad
        elif self.gradient_wrt == 'x0_pred':
            # The prediction is the same for every loss: compute it once and
            # without a graph; each loss differentiates w.r.t. a fresh leaf.
            with torch.no_grad():
                denoised = inner_model(x, sigma, **kwargs)
            for cond_fn in active_cond_fns:
                with torch.enable_grad():
                    cond_grad = cond_fn(x, sigma, denoised=denoised.detach().requires_grad_(), **kwargs).detach()
                total_cond_grad += cond_grad
        else:
            raise ValueError(f"Variable gradient_wrt == {self.gradient_wrt} not recognised.")

        # NaNs become 0; infinities are passed through unchanged (posinf/neginf
        # are set to +/-inf instead of nan_to_num's finite-max default).
        total_cond_grad = torch.nan_to_num(total_cond_grad, nan=0.0, posinf=float('inf'), neginf=-float('inf'))

        # Clamp the gradient
        total_cond_grad = self.clamp_grad_verbose(total_cond_grad, sigma)

        sigma_b = k_utils.append_dims(sigma, x.ndim)
        # Add the gradient to the latent in place. no_grad keeps this legal even
        # if grad mode is on and x is the leaf created above.
        with torch.no_grad():
            x.copy_(self.grad_consolidate_fn(x.detach(), total_cond_grad, sigma_b))

        if self.gradient_wrt == 'x':
            return inner_model(x, sigma, **kwargs)
        return self.grad_consolidate_fn(denoised.detach(), total_cond_grad, sigma_b)

    def forward(self, x, sigma, uncond, cond, cond_scale):
        """Return the CFG denoised prediction, applying loss gradients on conditioning steps.

        Args:
            x: Noisy latent batch (updated in place on conditioning steps).
            sigma: Noise level(s) for ``x``.
            uncond: Unconditional conditioning.
            cond: Conditional conditioning.
            cond_scale: Guidance scale applied to ``cond - uncond``.

        Raises:
            ValueError: on a conditioning step with ``cond_uncond_sync=False``
                and an unrecognised ``gradient_add_to``.
        """

        def _cfg_model(x, sigma, cond, **kwargs):
            # Batched CFG: ``cond`` is torch.cat([uncond, cond]), so the first
            # half of the output is the uncond prediction, the second the cond one.
            x_in = torch.cat([x] * 2)
            sigma_in = torch.cat([sigma] * 2)
            denoised = self.inner_model(x_in, sigma_in, cond=cond, **kwargs)
            uncond_x0, cond_x0 = denoised.chunk(2)
            return uncond_x0 + (cond_x0 - uncond_x0) * cond_scale

        if self.check_conditioning_schedule(sigma):
            if self.cond_uncond_sync:
                # Gradient is applied to the combined (already guided) prediction.
                cond_in = torch.cat([uncond, cond])
                return self.cond_model_fn_(x, sigma, cond=cond_in, inner_model=_cfg_model)

            # Evaluate cond and uncond separately; gradient_add_to picks which
            # branch(es) receive the gradient. Evaluation order (uncond, then
            # cond) is significant because cond_model_fn_ updates x in place.
            add_to = self.gradient_add_to
            if add_to not in ("uncond", "cond", "both"):
                raise ValueError(f"Unrecognised option for gradient_add_to: {self.gradient_add_to}")
            uncond_model = self.cond_model_fn_ if add_to in ("uncond", "both") else self.inner_model
            cond_model = self.cond_model_fn_ if add_to in ("cond", "both") else self.inner_model
            uncond_x0 = uncond_model(x, sigma, cond=uncond)
            cond_x0 = cond_model(x, sigma, cond=cond)
            return uncond_x0 + (cond_x0 - uncond_x0) * cond_scale

        # No conditioning on this step
        if self.cond_uncond_sync:
            return _cfg_model(x, sigma, cond=torch.cat([uncond, cond]))
        uncond_x0 = self.inner_model(x, sigma, cond=uncond)
        cond_x0 = self.inner_model(x, sigma, cond=cond)
        return uncond_x0 + (cond_x0 - uncond_x0) * cond_scale

    def make_cond_fn(self, loss_fn, scale):
        """Turn a loss function into a cond function returning ``-d(scale * loss)/d(target)``.

        The loss is evaluated on ``self.decode_fn(denoised)``. The gradient
        target is ``x`` when ``self.gradient_wrt == 'x'`` and ``denoised`` when
        it is ``'x0_pred'``. The returned function has the signature
        ``fn(x, sigma, denoised, **kwargs)``.

        Args:
            loss_fn: ``loss_fn(decoded_sample, sigma, **kwargs)``; must return a
                scalar tensor (``autograd.grad`` is called without grad_outputs).
            scale: Multiplier applied to the loss.

        Raises:
            ValueError: if ``self.gradient_wrt`` is not ``'x'`` or ``'x0_pred'``.
        """

        def _negative_loss_grad(target, denoised, sigma, kwargs):
            # Shared body: decode, score, and differentiate w.r.t. ``target``.
            with torch.enable_grad():
                denoised_sample = self.decode_fn(denoised).requires_grad_()
                loss = loss_fn(denoised_sample, sigma, **kwargs) * scale
                grad = -torch.autograd.grad(loss, target)[0]
            # Guarded so .item() (which copies the value to the host) only
            # runs when the loss is actually printed.
            if self.verbose:
                self.verbose_print('Loss:', loss.item())
            return grad

        # Cond function with respect to x
        def cond_fn(x, sigma, denoised, **kwargs):
            return _negative_loss_grad(x, denoised, sigma, kwargs)

        # Cond function with respect to x0_pred
        def cond_fn_pred(x, sigma, denoised, **kwargs):
            return _negative_loss_grad(denoised, denoised, sigma, kwargs)

        if self.gradient_wrt == 'x':
            return cond_fn
        if self.gradient_wrt == 'x0_pred':
            return cond_fn_pred
        raise ValueError(f"Variable gradient_wrt == {self.gradient_wrt} not recognised.")

    def clamp_grad_verbose(self, grad, sigma):
        """Apply ``self.clamp_func(grad, sigma)`` if set; in verbose mode, display the gradient.

        The gradient is passed to ``display_samples`` as ``|2*grad| - 1``, so
        the displayed value is ``|grad|`` clamped to [0, 1]. It is shown
        before clamping (only when a clamp_func is set) and after.
        """
        if self.clamp_func is not None:
            if self.verbose:
                print("Grad before clamping:")
                self.display_samples(torch.abs(grad*2.0) - 1.0)
            grad = self.clamp_func(grad, sigma)
        if self.verbose:
            print("Conditioning gradient")
            self.display_samples(torch.abs(grad*2.0) - 1.0)
        return grad

    def check_conditioning_schedule(self, sigma):
        """Return True if at least one loss is active and ``grad_inject_timing_fn(sigma)`` is truthy."""
        has_active_cond = self.cond_fns is not None and any(fn is not None for fn in self.cond_fns)
        if not has_active_cond or not self.grad_inject_timing_fn(sigma):
            return False
        if self.verbose:
            print(f"Conditioning step for sigma={sigma}")
        return True

    def display_samples(self, images):
        """Show ``images`` (values in [-1, 1]) as a 4-per-row grid in the notebook."""
        images = images.detach().double().cpu().add(1).div(2).clamp(0, 1)
        grid = make_grid(images, 4)
        display.display(to_pil_image(grid))