from tile_utils.utils import *


class AbstractDiffusion:

    def __init__(self, p: Processing, sampler: Sampler):
        self.method = self.__class__.__name__
        self.p: Processing = p
        self.pbar = None

        # sampler
        self.sampler_name = p.sampler_name
        self.sampler_raw = sampler
        self.sampler = sampler

        # fix: kdiff 'AND' prompt support and image editing (instruct-pix2pix) model support
        if self.is_kdiff and not hasattr(self, 'is_edit_model'):
            self.is_edit_model = (shared.sd_model.cond_stage_key == "edit"      # normally "txt"
                and self.sampler.model_wrap_cfg.image_cfg_scale is not None
                and self.sampler.model_wrap_cfg.image_cfg_scale != 1.0)

        # cache. final result of current sampling step, [B, C=4, H//8, W//8]
        # avoiding overhead of creating new tensors and weight summing
        self.x_buffer: Tensor = None
        self.w: int = int(self.p.width  // opt_f)       # latent size
        self.h: int = int(self.p.height // opt_f)
        # weights for background & grid bboxes
        self.weights: Tensor = torch.zeros((1, 1, self.h, self.w), device=devices.device, dtype=torch.float32)

        # FIXME: step counting is still not fully correct
        self.step_count = 0
        self.inner_loop_count = 0
        self.kdiff_step = -1

        # ext. Grid tiling painting (grid bbox)
        self.enable_grid_bbox: bool = False
        self.tile_w: int = None
        self.tile_h: int = None
        self.tile_bs: int = None
        self.num_tiles: int = None
        self.num_batches: int = None
        self.batched_bboxes: List[List[BBox]] = []

        # ext. Region Prompt Control (custom bbox)
        self.enable_custom_bbox: bool = False
        self.custom_bboxes: List[CustomBBox] = []
        self.cond_basis: Cond = None
        self.uncond_basis: Uncond = None
        self.draw_background: bool = True       # by default we draw major prompts in grid tiles
        self.causal_layers: bool = None

        # ext. Noise Inversion (noise inversion)
        self.noise_inverse_enabled: bool = False
        self.noise_inverse_steps: int = 0
        self.noise_inverse_retouch: float = None
        self.noise_inverse_renoise_strength: float = None
        self.noise_inverse_renoise_kernel: int = None
        self.noise_inverse_get_cache = None
        self.noise_inverse_set_cache = None
        self.sample_img2img_original = None

        # ext. ControlNet
        self.enable_controlnet: bool = False
        self.controlnet_script: ModuleType = None
        self.control_tensor_batch: List[List[Tensor]] = []
        self.control_params: Dict[str, Tensor] = {}
        self.control_tensor_cpu: bool = None
        self.control_tensor_custom: List[List[Tensor]] = []

        # ext. StableSR
        self.enable_stablesr: bool = False
        self.stablesr_script: ModuleType = None
        self.stablesr_tensor: Tensor = None
        self.stablesr_tensor_batch: List[Tensor] = []
        self.stablesr_tensor_custom: List[Tensor] = []

    @property
    def is_kdiff(self):
        return isinstance(self.sampler_raw, KDiffusionSampler)

    @property
    def is_ddim(self):
        return isinstance(self.sampler_raw, CompVisSampler)

    def update_pbar(self):
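        # one sampling step spans multiple bbox batches: advance the bar once per inner batch,
        # and reset the inner counter whenever the global sampling step moves on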
        if self.pbar.n >= self.pbar.total:
            self.pbar.close()
        else:
            if self.step_count == state.sampling_step:
                self.inner_loop_count += 1
                if self.inner_loop_count < self.total_bboxes:
                    self.pbar.update()
            else:
                self.step_count = state.sampling_step
                self.inner_loop_count = 0

    def reset_buffer(self, x_in:Tensor):
        # allocate a new buffer when the shape of x_in changes; otherwise just zero the existing one in place
        if self.x_buffer is None or self.x_buffer.shape != x_in.shape:
            self.x_buffer = torch.zeros_like(x_in, device=x_in.device, dtype=x_in.dtype)
        else:
            self.x_buffer.zero_()

    def init_done(self):
        '''
          Call this after all `init_*` settings are done; it performs:
            - settings sanity check
            - pre-computations, cache init
            - anything else needed before denoising starts
        '''

        self.total_bboxes = 0
        if self.enable_grid_bbox:   self.total_bboxes += self.num_batches
        if self.enable_custom_bbox: self.total_bboxes += len(self.custom_bboxes)
        assert self.total_bboxes > 0, "Nothing to paint! No background to draw and no custom bboxes were provided."

        self.pbar = tqdm(total=self.total_bboxes * state.sampling_steps, desc=f"{self.method} Sampling: ")

    ''' ↓↓↓ cond_dict utils ↓↓↓ '''

    def _tcond_key(self, cond_dict:CondDict) -> str:
        return 'crossattn' if 'crossattn' in cond_dict else 'c_crossattn'

    def get_tcond(self, cond_dict:CondDict) -> Tensor:
        tcond = cond_dict[self._tcond_key(cond_dict)]
        if isinstance(tcond, list): tcond = tcond[0]
        return tcond

    def set_tcond(self, cond_dict:CondDict, tcond:Tensor):
        key = self._tcond_key(cond_dict)
        if isinstance(cond_dict[key], list): tcond = [tcond]
        cond_dict[key] = tcond

    def _icond_key(self, cond_dict:CondDict) -> str:
        return 'c_adm' if shared.sd_model.model.conditioning_key in ['crossattn-adm', 'adm'] else 'c_concat'

    def get_icond(self, cond_dict:CondDict) -> Tensor:
        ''' icond differs for different models (inpaint/unclip model) '''
        key = self._icond_key(cond_dict)
        icond = cond_dict[key]
        if isinstance(icond, list): icond = icond[0]
        return icond

    def set_icond(self, cond_dict:CondDict, icond:Tensor):
        key = self._icond_key(cond_dict)
        if isinstance(cond_dict[key], list): icond = [icond]
        cond_dict[key] = icond

    def _vcond_key(self, cond_dict:CondDict) -> Optional[str]:
        return 'vector' if 'vector' in cond_dict else None

    def get_vcond(self, cond_dict:CondDict) -> Optional[Tensor]:
        ''' vector for SDXL '''
        key = self._vcond_key(cond_dict)
        return cond_dict.get(key)

    def set_vcond(self, cond_dict:CondDict, vcond:Optional[Tensor]):
        key = self._vcond_key(cond_dict)
        if key is not None:
            cond_dict[key] = vcond

    def make_cond_dict(self, cond_in:CondDict, tcond:Tensor, icond:Tensor, vcond:Tensor=None) -> CondDict:
        ''' copy & replace the content, returns a new object '''
        cond_out = cond_in.copy()
        self.set_tcond(cond_out, tcond)
        self.set_icond(cond_out, icond)
        self.set_vcond(cond_out, vcond)
        return cond_out
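
    # For reference, the cond_dict layouts these helpers handle (cf. the dummy dicts in
    # find_noise_for_image_sigma_adjustment below):
    #   SD1/SD2: {'c_crossattn': [Tensor], 'c_concat': [Tensor]}              (lists of tensors)
    #   SDXL:    {'crossattn': Tensor, 'vector': Tensor, 'c_concat': [Tensor]}
    # inpaint models carry their image cond in 'c_concat'; unCLIP models use 'c_adm'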

    ''' ↓↓↓ extensive functionality ↓↓↓ '''

    @grid_bbox
    def init_grid_bbox(self, tile_w:int, tile_h:int, overlap:int, tile_bs:int):
        self.enable_grid_bbox = True

        self.tile_w = min(tile_w, self.w)
        self.tile_h = min(tile_h, self.h)
        overlap = max(0, min(overlap, min(self.tile_w, self.tile_h) - 4))
        # split the latent into overlapping tiles, then batch them;
        # the weights record how many times each latent pixel gets painted
        bboxes, weights = split_bboxes(self.w, self.h, self.tile_w, self.tile_h, overlap, self.get_tile_weights())
        self.weights += weights
        self.num_tiles = len(bboxes)
        self.num_batches = math.ceil(self.num_tiles / tile_bs)
        self.tile_bs = math.ceil(self.num_tiles / self.num_batches)      # rebalanced (optimal) batch size
        self.batched_bboxes = [bboxes[i*self.tile_bs:(i+1)*self.tile_bs] for i in range(self.num_batches)]
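
        # Illustration (hypothetical numbers, assuming split_bboxes strides by tile - overlap):
        # a 1024x1024 image gives a 128x128 latent; with tile_w = tile_h = 96 and overlap = 48
        # each axis gets ceil((128 - 96) / 48) + 1 = 2 tiles, i.e. 4 tiles total;
        # with tile_bs = 4 that makes num_batches = 1 and a rebalanced tile_bs of 4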

    @grid_bbox
    def get_tile_weights(self) -> Union[Tensor, float]:
        return 1.0


    @custom_bbox
    def init_custom_bbox(self, bbox_settings:Dict[int,BBoxSettings], draw_background:bool, causal_layers:bool):
        self.enable_custom_bbox = True

        self.causal_layers = causal_layers
        self.draw_background = draw_background
        if not draw_background:
            self.enable_grid_bbox = False
            self.weights.zero_()

        self.custom_bboxes: List[CustomBBox] = []
        for bbox_setting in bbox_settings.values():
            e, x, y, w, h, p, n, blend_mode, feather_ratio, seed = bbox_setting
            if not e or x > 1.0 or y > 1.0 or w <= 0.0 or h <= 0.0: continue
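            # e.g. (hypothetical numbers) on a 1024x768 image the latent is 128x96;
            # a bbox of x=0.5, y=0.25, w=0.25, h=0.5 maps to x=64, y=24, w=32, h=48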
            x = int(x * self.w)
            y = int(y * self.h)
            w = math.ceil(w * self.w)
            h = math.ceil(h * self.h)
            x = max(0, x)
            y = max(0, y)
            w = min(self.w - x, w)
            h = min(self.h - y, h)
            self.custom_bboxes.append(CustomBBox(x, y, w, h, p, n, blend_mode, feather_ratio, seed))

        if len(self.custom_bboxes) == 0:
            self.enable_custom_bbox = False
            return

        # prepare cond
        p = self.p
        prompts = p.all_prompts[:p.batch_size]
        neg_prompts = p.all_negative_prompts[:p.batch_size]
        for bbox in self.custom_bboxes:
            bbox.cond, bbox.extra_network_data = Condition.get_custom_cond(prompts, bbox.prompt, p.steps, p.styles)
            bbox.uncond = Condition.get_uncond(Prompt.append_prompt(neg_prompts, bbox.neg_prompt), p.steps, p.styles)
        self.cond_basis = Condition.get_cond(prompts, p.steps)
        self.uncond_basis = Condition.get_uncond(neg_prompts, p.steps)

    @custom_bbox
    def reconstruct_custom_cond(self, org_cond:CondDict, custom_cond:Cond, custom_uncond:Uncond, bbox:CustomBBox) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
        image_conditioning = None
        if isinstance(org_cond, dict):
            icond = self.get_icond(org_cond)
            if icond.shape[2:] == (self.h, self.w):    # img2img
                icond = icond[bbox.slicer]
            image_conditioning = icond

        sampler_step = self.sampler.model_wrap_cfg.step
        tensor = Condition.reconstruct_cond(custom_cond, sampler_step)
        custom_uncond = Condition.reconstruct_uncond(custom_uncond, sampler_step)
        return tensor, custom_uncond, image_conditioning

    @custom_bbox
    def kdiff_custom_forward(self, x_tile:Tensor, sigma_in:Tensor, original_cond:CondDict, bbox_id:int, bbox:CustomBBox, forward_func:Callable) -> Tensor:
        '''
        The inner kdiff noise prediction is usually batched across cond & uncond.
        We need to unwrap the inner loop to simulate that batched behavior.
        This can be extremely tricky.
        '''
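
        # Outline (mirroring CFGDenoiser's batching), a sketch of the cases handled below:
        #   1. cond & uncond have equal length and fit one batch -> a single forward_func call on the full x_tile
        #   2. their lengths differ -> run cond and uncond through the UNet in separate passes
        #   3. --lowvram/--medvram -> kdiff feeds sub-batches [a:b]; self.a tracks per-bbox progress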

        sampler_step = self.sampler.model_wrap_cfg.step
        if self.kdiff_step != sampler_step:
            self.kdiff_step = sampler_step
            self.kdiff_step_bbox = [-1 for _ in range(len(self.custom_bboxes))]
            self.tensor = {}        # {int: Tensor[cond]}
            self.uncond = {}        # {int: Tensor[cond]}
            self.image_cond_in = {}
            # initialize the global prompts, just to estimate kdiff's batching behavior
            self.real_tensor = Condition.reconstruct_cond(self.cond_basis, sampler_step)
            self.real_uncond = Condition.reconstruct_uncond(self.uncond_basis, sampler_step)
            # reset the progress for all bboxes
            self.a = [0 for _ in range(len(self.custom_bboxes))]

        if self.kdiff_step_bbox[bbox_id] != sampler_step:
            # When a new step starts for a bbox, determine whether the tensors can be batched.
            self.kdiff_step_bbox[bbox_id] = sampler_step

            tensor, uncond, image_cond_in = self.reconstruct_custom_cond(original_cond, bbox.cond, bbox.uncond, bbox)

            if self.real_tensor.shape[1] == self.real_uncond.shape[1]:
                if shared.batch_cond_uncond:
                    # when the real cond & uncond have equal length, all information is contained in x_tile;
                    # we simulate the batched behavior and compute all the tensors in one go
                    if tensor.shape[1] == uncond.shape[1]:
                        # when our prompt tensors also have equal length, we can directly reuse kdiff's code
                        if not self.is_edit_model:
                            cond = torch.cat([tensor, uncond])
                        else:
                            cond = torch.cat([tensor, uncond, uncond])
                        self.set_custom_controlnet_tensors(bbox_id, x_tile.shape[0])
                        self.set_custom_stablesr_tensors(bbox_id)
                        return forward_func(
                            x_tile, 
                            sigma_in, 
                            cond=self.make_cond_dict(original_cond, cond, image_cond_in),
                        )
                    else:
                        # otherwise, pass cond and uncond through the UNet separately
                        x_out = torch.zeros_like(x_tile)
                        cond_size = tensor.shape[0]
                        self.set_custom_controlnet_tensors(bbox_id, cond_size)
                        self.set_custom_stablesr_tensors(bbox_id)
                        cond_out = forward_func(
                            x_tile  [:cond_size], 
                            sigma_in[:cond_size], 
                            cond=self.make_cond_dict(original_cond, tensor, image_cond_in[:cond_size]),
                        )
                        uncond_size = uncond.shape[0]
                        self.set_custom_controlnet_tensors(bbox_id, uncond_size)
                        self.set_custom_stablesr_tensors(bbox_id)
                        uncond_out = forward_func(
                            x_tile  [cond_size:cond_size+uncond_size], 
                            sigma_in[cond_size:cond_size+uncond_size], 
                            cond=self.make_cond_dict(original_cond, uncond, image_cond_in[cond_size:cond_size+uncond_size]),
                        )
                        x_out[:cond_size] = cond_out
                        x_out[cond_size:cond_size+uncond_size] = uncond_out
                        if self.is_edit_model:
                            x_out[cond_size+uncond_size:] = uncond_out
                        return x_out
                
            # otherwise, x_tile is only a partial batch; we have to denoise in separate runs,
            # so store the prompt & neg_prompt tensors for the current bbox
            self.tensor[bbox_id] = tensor
            self.uncond[bbox_id] = uncond
            self.image_cond_in[bbox_id] = image_cond_in

        # Now we get current batch of prompt and neg_prompt tensors
        tensor: Tensor = self.tensor[bbox_id]
        uncond: Tensor = self.uncond[bbox_id]
        batch_size = x_tile.shape[0]
        # get the start and end index of the current batch
        a = self.a[bbox_id]
        b = a + batch_size
        self.a[bbox_id] += batch_size

        if self.real_tensor.shape[1] == self.real_uncond.shape[1]:
            # When using --lowvram or --medvram, kdiff slices the cond and uncond with [a:b],
            # so we must slice our tensor and uncond with the same indices as the original kdiff code.
            
            # --- original code in kdiff ---
            # if not self.is_edit_model:
            #     cond = torch.cat([tensor, uncond])
            # else:
            #     cond = torch.cat([tensor, uncond, uncond])
            # cond = cond[a:b]
            # ------------------------------
            
            # The original kdiff code concatenates and then slices, but that cannot be applied to
            # our custom prompt tensors when tensor.shape[1] != uncond.shape[1], so we adapt it.
            cond_in, uncond_in = None, None
            # Slice the [prompt, neg prompt, (possibly) neg prompt] with [a:b]
            if not self.is_edit_model:
                if   b <= tensor.shape[0]: cond_in = tensor[a:b]
                elif a >= tensor.shape[0]: cond_in = uncond[a-tensor.shape[0]:b-tensor.shape[0]]
                else:
                    cond_in = tensor[a:]
                    uncond_in = uncond[:b-tensor.shape[0]]
            else:
                if b <= tensor.shape[0]:
                    cond_in = tensor[a:b]
                elif b > tensor.shape[0] and b <= tensor.shape[0] + uncond.shape[0]:
                    if a >= tensor.shape[0]:
                        cond_in = uncond[a-tensor.shape[0]:b-tensor.shape[0]]
                    else:
                        cond_in = tensor[a:]
                        uncond_in = uncond[:b-tensor.shape[0]]
                else:
                    if a >= tensor.shape[0] + uncond.shape[0]:
                        cond_in = uncond[a-tensor.shape[0]-uncond.shape[0]:b-tensor.shape[0]-uncond.shape[0]]
                    elif a >= tensor.shape[0]:
                        cond_in = torch.cat([uncond[a-tensor.shape[0]:], uncond[:b-tensor.shape[0]-uncond.shape[0]]])
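
            # e.g. (hypothetical numbers) tensor and uncond have 2 rows each; a sub-batch
            # [a, b) = [1, 3) straddles the boundary, giving cond_in = tensor[1:] and uncond_in = uncond[:1]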

            if uncond_in is None or tensor.shape[1] == uncond.shape[1]:
                # If the tensor can be passed to UNet in one go, do it.
                if uncond_in is not None:
                    cond_in = torch.cat([cond_in, uncond_in])
                self.set_custom_controlnet_tensors(bbox_id, x_tile.shape[0])
                self.set_custom_stablesr_tensors(bbox_id)
                return forward_func(
                    x_tile, 
                    sigma_in, 
                    cond=self.make_cond_dict(original_cond, cond_in, self.image_cond_in[bbox_id]),
                )
            else:
                # If not, we need to pass the tensor to UNet separately.
                x_out = torch.zeros_like(x_tile)
                cond_size = cond_in.shape[0]
                self.set_custom_controlnet_tensors(bbox_id, cond_size)
                self.set_custom_stablesr_tensors(bbox_id)
                cond_out = forward_func(
                    x_tile  [:cond_size], 
                    sigma_in[:cond_size], 
                    cond=self.make_cond_dict(original_cond, cond_in, self.image_cond_in[bbox_id])
                )
                self.set_custom_controlnet_tensors(bbox_id, uncond_in.shape[0])
                self.set_custom_stablesr_tensors(bbox_id)
                uncond_out = forward_func(
                    x_tile  [cond_size:], 
                    sigma_in[cond_size:], 
                    cond=self.make_cond_dict(original_cond, uncond_in, self.image_cond_in[bbox_id])
                )
                x_out[:cond_size] = cond_out
                x_out[cond_size:] = uncond_out
                return x_out

        # If the original cond and uncond have different lengths, kdiff processes them
        # in separate runs, so we also handle our tensor and uncond separately.

        if a < tensor.shape[0]:
            # Deal with custom prompt tensor
            if not self.is_edit_model:
                c_crossattn = tensor[a:b]
            else:
                c_crossattn = torch.cat([tensor[a:b], uncond])
            self.set_custom_controlnet_tensors(bbox_id, x_tile.shape[0])
            self.set_custom_stablesr_tensors(bbox_id)
            # complete this batch.
            return forward_func(
                x_tile, 
                sigma_in, 
                cond=self.make_cond_dict(original_cond, c_crossattn, self.image_cond_in[bbox_id])
            )
        else:
            # if the cond is finished, we need to process the uncond.
            self.set_custom_controlnet_tensors(bbox_id, uncond.shape[0])
            self.set_custom_stablesr_tensors(bbox_id)
            return forward_func(
                x_tile, 
                sigma_in, 
                cond=self.make_cond_dict(original_cond, uncond, self.image_cond_in[bbox_id])
            )

    @custom_bbox
    def ddim_custom_forward(self, x:Tensor, cond_in:CondDict, bbox:CustomBBox, ts:Tensor, forward_func:Callable, *args, **kwargs) -> Tensor:
        ''' draw custom bbox '''

        tensor, uncond, image_conditioning = self.reconstruct_custom_cond(cond_in, bbox.cond, bbox.uncond, bbox)

        cond = tensor
        # for DDIM the batch is never sliced, so we don't need the kdiff tricks above;
        # only the token lengths of cond & uncond may differ, which we pad / truncate here
        if uncond.shape[1] < cond.shape[1]:
            last_vector = uncond[:, -1:]
            last_vector_repeated = last_vector.repeat([1, cond.shape[1] - uncond.shape[1], 1])
            uncond = torch.hstack([uncond, last_vector_repeated])
        elif uncond.shape[1] > cond.shape[1]:
            uncond = uncond[:, :cond.shape[1]]
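        # e.g. (hypothetical shapes) cond [1, 154, 768] vs uncond [1, 77, 768]:
        # the last uncond token vector is repeated 77 times to pad uncond up to [1, 154, 768]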

        # Wrap the image conditioning back up since the DDIM code can accept the dict directly.
        # Note that they need to be lists because it just concatenates them later.
        if image_conditioning is not None:
            cond   = self.make_cond_dict(cond_in, cond,   image_conditioning)
            uncond = self.make_cond_dict(cond_in, uncond, image_conditioning)
        
        # We cannot determine the batch size here for different methods, so we delay it to the forward_func.
        return forward_func(x, cond, ts, unconditional_conditioning=uncond, *args, **kwargs)


    @controlnet
    def init_controlnet(self, controlnet_script:ModuleType, control_tensor_cpu:bool):
        self.enable_controlnet = True

        self.controlnet_script = controlnet_script
        self.control_tensor_cpu = control_tensor_cpu
        self.control_tensor_batch = None
        self.control_params = None
        self.control_tensor_custom = []
        
        self.prepare_controlnet_tensors()

    @controlnet
    def reset_controlnet_tensors(self):
        if not self.enable_controlnet: return
        if self.control_tensor_batch is None: return

        for param_id in range(len(self.control_params)):
            self.control_params[param_id].hint_cond = self.org_control_tensor_batch[param_id]

    @controlnet
    def prepare_controlnet_tensors(self, refresh:bool=False):
        ''' Crop the control tensor into tiles and cache them '''

        if not refresh:
            if self.control_tensor_batch is not None or self.control_params is not None: return

        if not self.enable_controlnet or self.controlnet_script is None: return

        latest_network = self.controlnet_script.latest_network
        if latest_network is None or not hasattr(latest_network, 'control_params'): return

        self.control_params = latest_network.control_params
        tensors = [param.hint_cond for param in latest_network.control_params]
        self.org_control_tensor_batch = tensors

        if len(tensors) == 0: return

        self.control_tensor_batch = []
        for i in range(len(tensors)):
            control_tile_list = []
            control_tensor = tensors[i]
            for bboxes in self.batched_bboxes:
                single_batch_tensors = []
                for bbox in bboxes:
                    if len(control_tensor.shape) == 3:
                        control_tensor.unsqueeze_(0)
                    control_tile = control_tensor[:, :, bbox[1]*opt_f:bbox[3]*opt_f, bbox[0]*opt_f:bbox[2]*opt_f]
                    single_batch_tensors.append(control_tile)
                control_tile = torch.cat(single_batch_tensors, dim=0)
                if self.control_tensor_cpu:
                    control_tile = control_tile.cpu()
                control_tile_list.append(control_tile)
            self.control_tensor_batch.append(control_tile_list)

            if len(self.custom_bboxes) > 0:
                custom_control_tile_list = []
                for bbox in self.custom_bboxes:
                    if len(control_tensor.shape) == 3:
                        control_tensor.unsqueeze_(0)
                    control_tile = control_tensor[:, :, bbox[1]*opt_f:bbox[3]*opt_f, bbox[0]*opt_f:bbox[2]*opt_f]
                    if self.control_tensor_cpu:
                        control_tile = control_tile.cpu()
                    custom_control_tile_list.append(control_tile)
                self.control_tensor_custom.append(custom_control_tile_list)

    @controlnet
    def switch_controlnet_tensors(self, batch_id:int, x_batch_size:int, tile_batch_size:int, is_denoise=False):
        if not self.enable_controlnet: return
        if self.control_tensor_batch is None: return

        for param_id in range(len(self.control_params)):
            control_tile = self.control_tensor_batch[param_id][batch_id]
            if self.is_kdiff:
                all_control_tile = []
                for i in range(tile_batch_size):
                    this_control_tile = [control_tile[i].unsqueeze(0)] * x_batch_size
                    all_control_tile.append(torch.cat(this_control_tile, dim=0))
                control_tile = torch.cat(all_control_tile, dim=0)                                           
            else:
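                # CompVis samplers concat cond & uncond into one batch, so the hint is
                # presumably doubled except in the pure-denoise pass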
                control_tile = control_tile.repeat([x_batch_size if is_denoise else x_batch_size * 2, 1, 1, 1])
            self.control_params[param_id].hint_cond = control_tile.to(devices.device)

    @controlnet
    def set_custom_controlnet_tensors(self, bbox_id:int, repeat_size:int):
        if not self.enable_controlnet: return
        if not len(self.control_tensor_custom): return
        
        for param_id in range(len(self.control_params)):
            control_tensor = self.control_tensor_custom[param_id][bbox_id].to(devices.device)
            self.control_params[param_id].hint_cond = control_tensor.repeat((repeat_size, 1, 1, 1))


    @stablesr
    def init_stablesr(self, stablesr_script:ModuleType):
        if stablesr_script.stablesr_model is None: return
        self.stablesr_script = stablesr_script
        def set_image_hook(latent_image):
            self.enable_stablesr = True
            self.stablesr_tensor = latent_image
            self.stablesr_tensor_batch = []
            for bboxes in self.batched_bboxes:
                single_batch_tensors = []
                for bbox in bboxes:
                    stablesr_tile = self.stablesr_tensor[:, :, bbox[1]:bbox[3], bbox[0]:bbox[2]]
                    single_batch_tensors.append(stablesr_tile)
                stablesr_tile = torch.cat(single_batch_tensors, dim=0)
                self.stablesr_tensor_batch.append(stablesr_tile)
            if len(self.custom_bboxes) > 0:
                self.stablesr_tensor_custom = []
                for bbox in self.custom_bboxes:
                    stablesr_tile = self.stablesr_tensor[:, :, bbox[1]:bbox[3], bbox[0]:bbox[2]]
                    self.stablesr_tensor_custom.append(stablesr_tile)

        stablesr_script.stablesr_model.set_image_hooks['TiledDiffusion'] = set_image_hook

    @stablesr
    def reset_stablesr_tensors(self):
        if not self.enable_stablesr: return
        if self.stablesr_script.stablesr_model is None: return
        self.stablesr_script.stablesr_model.latent_image = self.stablesr_tensor

    @stablesr
    def switch_stablesr_tensors(self, batch_id:int):
        if not self.enable_stablesr: return
        if self.stablesr_script.stablesr_model is None: return
        if self.stablesr_tensor_batch is None: return
        self.stablesr_script.stablesr_model.latent_image = self.stablesr_tensor_batch[batch_id]

    @stablesr
    def set_custom_stablesr_tensors(self, bbox_id:int):
        if not self.enable_stablesr: return
        if self.stablesr_script.stablesr_model is None: return
        if not len(self.stablesr_tensor_custom): return
        self.stablesr_script.stablesr_model.latent_image = self.stablesr_tensor_custom[bbox_id]


    @noise_inverse
    def init_noise_inverse(self, steps:int, retouch:float, get_cache_callback, set_cache_callback, renoise_strength:float, renoise_kernel:int):
        self.noise_inverse_enabled = True
        self.noise_inverse_steps = steps
        self.noise_inverse_retouch = float(retouch)
        self.noise_inverse_renoise_strength = float(renoise_strength)
        self.noise_inverse_renoise_kernel = int(renoise_kernel)
        if self.sample_img2img_original is None:
            self.sample_img2img_original = self.sampler_raw.sample_img2img
        self.sampler_raw.sample_img2img = MethodType(self.sample_img2img, self.sampler_raw)
        self.noise_inverse_set_cache = set_cache_callback
        self.noise_inverse_get_cache = get_cache_callback

    @noise_inverse
    @keep_signature
    def sample_img2img(self, sampler: KDiffusionSampler, p:ProcessingImg2Img, 
                       x:Tensor, noise:Tensor, conditioning, unconditional_conditioning,
                       steps=None, image_conditioning=None):
        # noise inverse sampling - renoise mask
        import torch.nn.functional as F
        renoise_mask = None
        if self.noise_inverse_renoise_strength > 0:
            image = p.init_images[0]
            # convert to grayscale with PIL
            image = image.convert('L')
            np_mask = get_retouch_mask(np.asarray(image), self.noise_inverse_renoise_kernel)
            renoise_mask = torch.from_numpy(np_mask).to(noise.device)
            # resize retouch mask to match noise size
            renoise_mask = 1 - F.interpolate(renoise_mask.unsqueeze(0).unsqueeze(0), size=noise.shape[-2:], mode='bilinear').squeeze(0).squeeze(0)
            renoise_mask *= self.noise_inverse_renoise_strength
            renoise_mask = torch.clamp(renoise_mask, 0, 1)

        prompts = p.all_prompts[:p.batch_size]
        
        latent = None
        # try to use the cached latent, which saves a huge amount of time
        cached_latent: NoiseInverseCache = self.noise_inverse_get_cache()
        if cached_latent is not None and \
            cached_latent.model_hash == p.sd_model.sd_model_hash and \
            cached_latent.noise_inversion_steps == self.noise_inverse_steps and \
            len(cached_latent.prompts) == len(prompts) and \
            all([cached_latent.prompts[i] == prompts[i] for i in range(len(prompts))]) and \
            abs(cached_latent.retouch - self.noise_inverse_retouch) < 0.01 and \
            cached_latent.x0.shape == p.init_latent.shape and \
            torch.abs(cached_latent.x0.to(p.init_latent.device) - p.init_latent).sum() < 100: # the 100 is an arbitrary threshold copy-pasted from the img2img alt code
                # use cached noise
                print('[Tiled Diffusion] Your checkpoint, image, prompts, inverse steps, and retouch params are all unchanged.')
                print('[Tiled Diffusion] Noise Inversion will use the cached noise from the previous run. To clear the cache, click the Free GPU button.')
                latent = cached_latent.xt.to(noise.device)
        if latent is None:
            # run noise inversion
            shared.state.job_count += 1
            latent = self.find_noise_for_image_sigma_adjustment(sampler.model_wrap, self.noise_inverse_steps, prompts)
            shared.state.nextjob()
            self.noise_inverse_set_cache(p.init_latent.clone().cpu(), latent.clone().cpu(), prompts)
            # The cache is only 1 latent image and is very small (16 MB for 8192 * 8192 image), so we don't need to worry about memory leakage.

        # calculate sampling steps
        adjusted_steps, _ = sd_samplers_common.setup_img2img_steps(p, steps)
        sigmas = sampler.get_sigmas(p, adjusted_steps)
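        # find_noise_for_image_sigma_adjustment returns roughly x_T / sigma_max; since x_T ≈ x_0 + sigma_max * eps,
        # subtracting p.init_latent / sigmas[0] (the largest sigma) leaves an estimate of the pure noise eps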
        inverse_noise = latent - (p.init_latent / sigmas[0])

        # inject noise to high-frequency area so that the details won't lose too much
        if renoise_mask is not None:
            # If the background is not drawn, filter out the un-drawn pixels and reweight the
            # foreground with the feather mask; this enables the renoise mask for regional inpainting.
            if not self.enable_grid_bbox:
                background_count  = torch.zeros((1, 1, noise.shape[2], noise.shape[3]), device=noise.device)
                foreground_noise  = torch.zeros_like(noise)
                foreground_weight = torch.zeros((1, 1, noise.shape[2], noise.shape[3]), device=noise.device)
                foreground_count  = torch.zeros((1, 1, noise.shape[2], noise.shape[3]), device=noise.device)
                for bbox in self.custom_bboxes:
                    if bbox.blend_mode == BlendMode.BACKGROUND:
                        background_count[bbox.slicer] += 1
                    elif bbox.blend_mode == BlendMode.FOREGROUND:
                        foreground_noise [bbox.slicer] += noise[bbox.slicer]
                        foreground_weight[bbox.slicer] += bbox.feather_mask
                        foreground_count [bbox.slicer] += 1
                background_noise  = torch.where(background_count > 0, noise, 0)
                foreground_noise  = torch.where(foreground_count > 0, foreground_noise  / foreground_count, 0)
                foreground_weight = torch.where(foreground_count > 0, foreground_weight / foreground_count, 0)
                noise = background_noise * (1 - foreground_weight) + foreground_noise * foreground_weight
                del background_noise, foreground_noise, foreground_weight, background_count, foreground_count
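            # blend the two (approximately unit-variance) noises; dividing by sqrt(m^2 + (1-m)^2)
            # keeps the mix at unit variance for independent Gaussian inputs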
            combined_noise = ((1 - renoise_mask) * inverse_noise + renoise_mask * noise) / ((renoise_mask**2 + (1 - renoise_mask)**2) ** 0.5)
        else:
            combined_noise = inverse_noise

        # use the estimated noise for the original img2img sampling
        return self.sample_img2img_original(p, x, combined_noise, conditioning, unconditional_conditioning, steps, image_conditioning)

    @noise_inverse
    @torch.no_grad()
    def find_noise_for_image_sigma_adjustment(self, dnw, steps, prompts:List[str]) -> Tensor:
        '''
        Migrated from the built-in script img2imgalt.py.
        Tiled noise inversion for better image upscaling.
        '''
        import k_diffusion as K
        assert self.p.sampler_name == 'Euler'

        x = self.p.init_latent
        s_in = x.new_ones([x.shape[0]])
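        # for v-prediction models, get_scalings returns [c_skip, c_out, c_in]; drop c_skip (mirrors img2imgalt.py)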
        skip = 1 if shared.sd_model.parameterization == "v" else 0
        sigmas = dnw.get_sigmas(steps).flip(0)
    
        cond = self.p.sd_model.get_learned_conditioning(prompts)
        if isinstance(cond, Tensor):    # SD1/SD2
            cond_dict_dummy = {
                'c_crossattn': [],      # List[Tensor]
                'c_concat': [],         # List[Tensor]
            }
            cond_in = self.make_cond_dict(cond_dict_dummy, cond, self.p.image_conditioning)
        else:                           # SDXL
            cond_dict_dummy = {
                'crossattn': None,      # Tensor
                'vector': None,         # Tensor
                'c_concat': [],         # List[Tensor]
            }
            cond_in = self.make_cond_dict(cond_dict_dummy, cond['crossattn'], self.p.image_conditioning, cond['vector'])

        state.sampling_steps = steps
        pbar = tqdm(total=steps, desc='Noise Inversion')
        for i in range(1, len(sigmas)):
            if state.interrupted: return x

            state.sampling_step += 1

            x_in = x
            sigma_in = torch.cat([sigmas[i] * s_in])
            c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)[skip:]]

            t = dnw.sigma_to_t(sigma_in)
            t = t / self.noise_inverse_retouch

            eps = self.get_noise(x_in * c_in, t, cond_in, steps - i)
            denoised = x_in + eps * c_out

            # Euler method:
            d = (x - denoised) / sigmas[i]
            dt = sigmas[i] - sigmas[i - 1]
            x = x + d * dt
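            # sigmas were flipped to ascend, so dt > 0: each Euler step moves x toward
            # higher noise, i.e. the sampler's Euler update run in reverse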

            sd_samplers_common.store_latent(x)

            # this is necessary to free memory before the next iteration
            del x_in, sigma_in, c_out, c_in, t
            del eps, denoised, d, dt

            pbar.update(1)
        pbar.close()

        return x / sigmas[-1]
    
    @noise_inverse
    @torch.no_grad()
    def get_noise(self, x_in: Tensor, sigma_in:Tensor, cond_in:Dict[str, Tensor], step:int) -> Tensor:
        raise NotImplementedError