File size: 11,124 Bytes
3dabe4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
from tile_methods.abstractdiffusion import AbstractDiffusion
from tile_utils.utils import *


class MultiDiffusion(AbstractDiffusion):
    """
        Multi-Diffusion Implementation
        https://arxiv.org/abs/2302.08113

        Splits the whole U-Net latent into tiles (grid tiles plus optional
        user-defined regions), denoises each tile with the hijacked sampler
        forward, accumulates the results into a buffer canvas, and averages
        overlapping contributions on every sampler step.
    """

    def __init__(self, p:Processing, *args, **kwargs):
        super().__init__(p, *args, **kwargs)
        # NOTE(review): `assert` is stripped under `python -O`; an explicit
        # raise would guard this in optimized runs too
        assert p.sampler_name != 'UniPC', 'MultiDiffusion is not compatible with UniPC!'

    def hook(self):
        '''
        Hijack the inner denoiser model: save its original `forward` in
        `self.sampler_forward` and replace it with our tiled forward
        (`kdiff_forward` for K-Diffusion samplers, `ddim_forward` otherwise).
        '''
        if self.is_kdiff:
            # For K-Diffusion sampler with uniform prompt, we hijack into the inner model for simplicity
            # Otherwise, the masked-redraw will break due to the init_latent
            # (the bare annotations below are no-op type declarations, kept for readability)
            self.sampler: KDiffusionSampler
            self.sampler.model_wrap_cfg: CFGDenoiserKDiffusion
            self.sampler.model_wrap_cfg.inner_model: Union[CompVisDenoiser, CompVisVDenoiser]
            self.sampler_forward = self.sampler.model_wrap_cfg.inner_model.forward
            self.sampler.model_wrap_cfg.inner_model.forward = self.kdiff_forward
        else:
            self.sampler: CompVisSampler
            self.sampler.model_wrap_cfg: CFGDenoiserTimesteps
            self.sampler.model_wrap_cfg.inner_model: Union[CompVisTimestepsDenoiser, CompVisTimestepsVDenoiser]
            self.sampler_forward = self.sampler.model_wrap_cfg.inner_model.forward
            self.sampler.model_wrap_cfg.inner_model.forward = self.ddim_forward

    @staticmethod
    def unhook():
        # no need to unhook MultiDiffusion as it only hooks the sampler,
        # which will be destroyed after the painting is done
        pass

    def reset_buffer(self, x_in:Tensor):
        ''' Clear the accumulation canvas; pure pass-through to the base class. '''
        super().reset_buffer(x_in)

    @custom_bbox
    def init_custom_bbox(self, *args):
        '''
        Initialize user-drawn regions via the base class, then pre-count every
        BACKGROUND-blended region into the per-pixel averaging weights.
        '''
        super().init_custom_bbox(*args)

        for bbox in self.custom_bboxes:
            if bbox.blend_mode == BlendMode.BACKGROUND:
                # background-blended regions are summed into x_buffer in
                # sample_one_step, so they must also raise the divisor
                self.weights[bbox.slicer] += 1.0

    ''' ↓↓↓ kernel hijacks ↓↓↓ '''

    @torch.no_grad()
    @keep_signature
    def kdiff_forward(self, x_in:Tensor, sigma_in:Tensor, cond:CondDict) -> Tensor:
        '''
        Tiled replacement for the K-Diffusion inner model forward.
        Builds the three per-tile strategies and delegates to sample_one_step.
        '''
        assert CompVisDenoiser.forward
        assert CompVisVDenoiser.forward

        def org_func(x:Tensor) -> Tensor:
            # untiled fallback: run the original forward on the whole latent
            return self.sampler_forward(x, sigma_in, cond=cond)

        def repeat_func(x_tile:Tensor, bboxes:List[CustomBBox]) -> Tensor:
            # For kdiff sampler, the dim 0 of input x_in is:
            #   = batch_size * (num_AND + 1)   if not an edit model
            #   = batch_size * (num_AND + 2)   otherwise
            sigma_tile = self.repeat_tensor(sigma_in, len(bboxes))
            cond_tile = self.repeat_cond_dict(cond, bboxes)
            return self.sampler_forward(x_tile, sigma_tile, cond=cond_tile)

        def custom_func(x:Tensor, bbox_id:int, bbox:CustomBBox) -> Tensor:
            # per-region denoise using the region's own conditioning (base-class helper)
            return self.kdiff_custom_forward(x, sigma_in, cond, bbox_id, bbox, self.sampler_forward)

        return self.sample_one_step(x_in, org_func, repeat_func, custom_func)

    @torch.no_grad()
    @keep_signature
    def ddim_forward(self, x_in:Tensor, ts_in:Tensor, cond:Union[CondDict, Tensor]) -> Tensor:
        '''
        Tiled replacement for the timesteps-based (DDIM-style) inner model
        forward. Mirrors kdiff_forward but `cond` may be a bare tensor.
        '''
        assert CompVisTimestepsDenoiser.forward
        assert CompVisTimestepsVDenoiser.forward

        def org_func(x:Tensor) -> Tensor:
            # untiled fallback: run the original forward on the whole latent
            return self.sampler_forward(x, ts_in, cond=cond)

        def repeat_func(x_tile:Tensor, bboxes:List[CustomBBox]) -> Tensor:
            n_rep = len(bboxes)
            ts_tile = self.repeat_tensor(ts_in, n_rep)
            if isinstance(cond, dict):   # FIXME: when will enter this branch?
                cond_tile = self.repeat_cond_dict(cond, bboxes)
            else:
                cond_tile = self.repeat_tensor(cond, n_rep)
            return self.sampler_forward(x_tile, ts_tile, cond=cond_tile)

        def custom_func(x:Tensor, bbox_id:int, bbox:CustomBBox) -> Tensor:
            # before the final forward, we can set the control tensor
            def forward_func(x, *args, **kwargs):
                # 2*batch: presumably cond + uncond halves — TODO confirm against controlnet hook
                self.set_custom_controlnet_tensors(bbox_id, 2*x.shape[0])
                self.set_custom_stablesr_tensors(bbox_id)
                return self.sampler_forward(x, *args, **kwargs)
            return self.ddim_custom_forward(x, cond, bbox, ts_in, forward_func)

        return self.sample_one_step(x_in, org_func, repeat_func, custom_func)

    def repeat_tensor(self, x:Tensor, n:int) -> Tensor:
        ''' repeat the tensor on its first dim, returning a view when possible '''
        if n == 1: return x
        B = x.shape[0]
        r_dims = len(x.shape) - 1
        if B == 1:      # batch_size = 1 (not `tile_batch_size`)
            shape = [n] + [-1] * r_dims     # [N, -1, ...]
            return x.expand(shape)          # `expand` is much lighter than `tile`
        else:
            shape = [n] + [1] * r_dims      # [N, 1, ...]
            return x.repeat(shape)

    def repeat_cond_dict(self, cond_in:CondDict, bboxes:List[CustomBBox]) -> CondDict:
        ''' repeat all tensors in cond_dict on its first dim (for a batch of tiles), returns a new object '''
        # n_repeat
        n_rep = len(bboxes)
        # txt cond
        tcond = self.get_tcond(cond_in)           # [B=1, L, D] => [B*N, L, D]
        tcond = self.repeat_tensor(tcond, n_rep)
        # img cond: full-latent image conds are cropped per-bbox, scalar-ish ones just repeated
        icond = self.get_icond(cond_in)
        if icond.shape[2:] == (self.h, self.w):   # img2img, [B=1, C, H, W]
            icond = torch.cat([icond[bbox.slicer] for bbox in bboxes], dim=0)
        else:                                     # txt2img, [B=1, C=5, H=1, W=1]
            icond = self.repeat_tensor(icond, n_rep)
        # vec cond (SDXL)
        vcond = self.get_vcond(cond_in)           # [B=1, D]
        if vcond is not None:
            vcond = self.repeat_tensor(vcond, n_rep)  # [B*N, D]
        return self.make_cond_dict(cond_in, tcond, icond, vcond)

    def sample_one_step(self, x_in:Tensor, org_func:Callable, repeat_func:Callable, custom_func:Callable) -> Tensor:
        '''
        this method splits the whole latent and process in tiles
            - x_in: current whole U-Net latent
            - org_func: original forward function, when use highres
            - repeat_func: one step denoiser for grid tile
            - custom_func: one step denoiser for custom tile

        Accumulates every tile's output into `self.x_buffer`, then returns the
        per-pixel weighted average; FOREGROUND-blended regions are feathered
        on top of that average afterwards.
        '''

        N, C, H, W = x_in.shape
        if (H, W) != (self.h, self.w):
            # We don't tile highres, let's just use the original org_func
            self.reset_controlnet_tensors()
            return org_func(x_in)

        # clear buffer canvas
        self.reset_buffer(x_in)

        # Background sampling (grid bbox)
        if self.draw_background:
            for batch_id, bboxes in enumerate(self.batched_bboxes):
                # bail out unchanged if the user interrupted the generation
                if state.interrupted: return x_in

                # batching
                x_tile = torch.cat([x_in[bbox.slicer] for bbox in bboxes], dim=0)   # [TB, C, TH, TW]

                # controlnet tiling
                # FIXME: is_denoise is default to False, however it is set to True in case of MixtureOfDiffusers, why?
                self.switch_controlnet_tensors(batch_id, N, len(bboxes))

                # stablesr tiling
                self.switch_stablesr_tensors(batch_id)

                # compute tiles, then scatter each tile's N-sized batch back to its bbox
                x_tile_out = repeat_func(x_tile, bboxes)
                for i, bbox in enumerate(bboxes):
                    self.x_buffer[bbox.slicer] += x_tile_out[i*N:(i+1)*N, :, :, :]

                # update progress bar
                self.update_pbar()

        # Custom region sampling (custom bbox)
        # feather buffers are lazily allocated only if a FOREGROUND region exists
        x_feather_buffer = None
        x_feather_mask   = None
        x_feather_count  = None
        if len(self.custom_bboxes) > 0:
            for bbox_id, bbox in enumerate(self.custom_bboxes):
                if state.interrupted: return x_in

                # activate the region's own LoRA/extra networks for this forward
                if not self.p.disable_extra_networks:
                    with devices.autocast():
                        extra_networks.activate(self.p, bbox.extra_network_data)

                x_tile = x_in[bbox.slicer]

                # retrieve original x_in from constructed input
                x_tile_out = custom_func(x_tile, bbox_id, bbox)

                if bbox.blend_mode == BlendMode.BACKGROUND:
                    # plain accumulation; weight was pre-counted in init_custom_bbox
                    self.x_buffer[bbox.slicer] += x_tile_out
                elif bbox.blend_mode == BlendMode.FOREGROUND:
                    if x_feather_buffer is None:
                        x_feather_buffer = torch.zeros_like(self.x_buffer)
                        x_feather_mask   = torch.zeros((1, 1, H, W), device=x_in.device)
                        x_feather_count  = torch.zeros((1, 1, H, W), device=x_in.device)
                    x_feather_buffer[bbox.slicer] += x_tile_out
                    x_feather_mask  [bbox.slicer] += bbox.feather_mask
                    x_feather_count [bbox.slicer] += 1

                if not self.p.disable_extra_networks:
                    with devices.autocast():
                        extra_networks.deactivate(self.p, bbox.extra_network_data)

                # update progress bar
                self.update_pbar()

        # Averaging background buffer (only pixels covered by >1 tile need dividing)
        x_out = torch.where(self.weights > 1, self.x_buffer / self.weights, self.x_buffer)

        # Foreground Feather blending
        if x_feather_buffer is not None:
            # Average overlapping feathered regions
            x_feather_buffer = torch.where(x_feather_count > 1, x_feather_buffer / x_feather_count, x_feather_buffer)
            x_feather_mask   = torch.where(x_feather_count > 1, x_feather_mask   / x_feather_count, x_feather_mask)
            # Weighted average with original x_buffer
            x_out = torch.where(x_feather_count > 0, x_out * (1 - x_feather_mask) + x_feather_buffer * x_feather_mask, x_out)

        return x_out

    def get_noise(self, x_in:Tensor, sigma_in:Tensor, cond_in:Dict[str, Tensor], step:int) -> Tensor:
        '''
        Tiled `apply_model` pass used for noise inversion; reuses the
        sample_one_step tiling machinery with raw sd_model calls.
        '''
        # NOTE: The following code is analytically wrong but aesthetically beautiful
        # shallow copy so the closures below see a stable cond dict
        cond_in_original = cond_in.copy()

        def org_func(x:Tensor):
            return shared.sd_model.apply_model(x, sigma_in, cond=cond_in_original)

        def repeat_func(x_tile:Tensor, bboxes:List[CustomBBox]):
            sigma_in_tile = sigma_in.repeat(len(bboxes))
            cond_out = self.repeat_cond_dict(cond_in_original, bboxes)
            x_tile_out = shared.sd_model.apply_model(x_tile, sigma_in_tile, cond=cond_out)
            return x_tile_out

        def custom_func(x:Tensor, bbox_id:int, bbox:CustomBBox):
            # The negative prompt in custom bbox should not be used for noise inversion
            # otherwise the result will be astonishingly bad.
            tcond = Condition.reconstruct_cond(bbox.cond, step).unsqueeze_(0)
            icond = self.get_icond(cond_in_original)
            if icond.shape[2:] == (self.h, self.w):
                icond = icond[bbox.slicer]
            cond_out = self.make_cond_dict(cond_in, tcond, icond)
            return shared.sd_model.apply_model(x, sigma_in, cond=cond_out)

        return self.sample_one_step(x_in, org_func, repeat_func, custom_func)