File size: 3,929 Bytes
8d7ec14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import torch

from modules_forge.supported_preprocessor import Preprocessor, PreprocessorParameter
from modules_forge.shared import add_supported_preprocessor


def blur(x, k):
    y = torch.nn.functional.pad(x, (k, k, k, k), mode='replicate')
    y = torch.nn.functional.avg_pool2d(y, (k * 2 + 1, k * 2 + 1), stride=(1, 1))
    return y


class PreprocessorTile(Preprocessor):
    def __init__(self):
        super().__init__()
        self.name = 'tile_resample'
        self.tags = ['Tile']
        self.model_filename_filters = ['tile']
        self.slider_resolution = PreprocessorParameter(visible=False)
        self.latent = None

    def register_latent(self, process, cond):
        vae = process.sd_model.forge_objects.vae
        # This is a powerful VAE with integrated memory management, bf16, and tiled fallback.

        latent_image = vae.encode(cond.movedim(1, -1))
        latent_image = process.sd_model.forge_objects.unet.model.latent_format.process_in(latent_image)
        self.latent = latent_image
        return self.latent


class PreprocessorTileColorFix(PreprocessorTile):
    def __init__(self):
        super().__init__()
        self.name = 'tile_colorfix'
        self.slider_1 = PreprocessorParameter(label='Variation', value=8.0, minimum=3.0, maximum=32.0, step=1.0, visible=True)
        self.variation = 8
        self.sharpness = None

    def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
        self.variation = int(kwargs['unit'].threshold_a)

        latent = self.register_latent(process, cond)

        unet = process.sd_model.forge_objects.unet.clone()
        sigma_data = process.sd_model.forge_objects.unet.model.model_sampling.sigma_data

        if getattr(process, 'is_hr_pass', False):
            k = int(self.variation * 2)
        else:
            k = int(self.variation)

        def block_proc(h, flag, transformer_options):
            location, block_id = transformer_options['block']
            cond_mark = transformer_options['cond_mark'][:, None, None, None]  # cond is 0

            if location == 'input' and block_id == 0 and flag == 'before':
                sigma = transformer_options['sigmas'].to(h)
                self.x_input = h[:, :4]  # Inpaint fix
                self.x_input_sigma_space = self.x_input * (sigma ** 2 + sigma_data ** 2) ** 0.5

            if location == 'last' and block_id == 0 and flag == 'after':
                sigma = transformer_options['sigmas'].to(h)
                eps_estimation = h[:, :4]
                denoised = self.x_input_sigma_space - eps_estimation * sigma

                denoised = denoised - blur(denoised, k) + blur(latent.to(denoised), k)

                if isinstance(self.sharpness, float):
                    detail_weight = float(self.sharpness) * 0.01
                    neg = detail_weight * blur(denoised, k) + (1 - detail_weight) * denoised
                    denoised = (1 - cond_mark) * denoised + cond_mark * neg

                eps_modified = (self.x_input_sigma_space - denoised) / sigma

                return eps_modified

            return h

        unet.add_block_modifier(block_proc)

        process.sd_model.forge_objects.unet = unet

        return cond, mask


class PreprocessorTileColorFixSharp(PreprocessorTileColorFix):
    def __init__(self):
        super().__init__()
        self.name = 'tile_colorfix+sharp'
        self.slider_2 = PreprocessorParameter(label='Sharpness', value=1.0, minimum=0.0, maximum=2.0, step=0.01, visible=True)

    def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
        self.sharpness = float(kwargs['unit'].threshold_b)
        return super().process_before_every_sampling(process, cond, mask, *args, **kwargs)


add_supported_preprocessor(PreprocessorTile())

add_supported_preprocessor(PreprocessorTileColorFix())

add_supported_preprocessor(PreprocessorTileColorFixSharp())