File size: 11,205 Bytes
c336648
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
import gradio
import torch

import modules.devices as devices
import modules.scripts as scripts
import modules.script_callbacks as script_callbacks
import modules.sd_unet as sd_unet
import modules.shared as shared

from ldm.modules.diffusionmodules.util import timestep_embedding as timestep_embedding

class DeepShrinkHiresFixAction:
    """One Deep Shrink rule, built from a single row of the script's UI.

    The custom UNet applies a rule when ``enable`` is truthy, the current block
    depth equals ``depth`` and the current timestep is greater than
    ``timestep``. On the way down it resizes the features by ``1 / scale``; on
    the way back up (at the same depth) it resizes them by ``scale``.
    """

    def __init__(self, enable: bool, timestep: float, depth: int, scale: float):
        self.enable = enable
        self.timestep = timestep
        self.depth = depth
        self.scale = scale

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(enable={self.enable!r}, timestep={self.timestep!r}, "
            f"depth={self.depth!r}, scale={self.scale!r})"
        )

class DeepShrinkHiresFix(scripts.Script):
    """Always-visible script that configures the Deep Shrink Hires.fix UNet.

    The values gathered by :meth:`process` are stored on class attributes,
    which :class:`DeepShrinkHiresFixUNet` reads on every forward pass, so this
    state is shared by every instance of the script.
    """

    # Number of configurable shrink rules (UI rows). process() expects four
    # values (enable, timestep, depth, scale) per rule, followed by the three
    # experimental-mode values.
    ACTION_COUNT = 8
    # Rules collected from the UI rows; read by DeepShrinkHiresFixUNet.forward.
    deepShrinkHiresFixActions: list[DeepShrinkHiresFixAction] = []
    # Experimental mode: per-block resolution scales, applied while the current
    # timestep is >= experimentalTimestep.
    enableExperimental: bool = False
    experimentalTimestep: float = 900
    experimentalScales: list[float] = []

    def __init__(self):
        pass

    def title(self):
        """Return the name shown for this script in the web UI."""
        return "Deep Shrink Hires.fix"

    def show(self, is_img2img):
        """Return ``scripts.AlwaysVisible`` on both txt2img and img2img."""
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the settings UI.

        Returns the components in the order :meth:`process` consumes them:
        ``ACTION_COUNT`` groups of (enable, timestep, depth, scale), then the
        experimental enable flag, timestep and scale-factor list.
        """
        components = []

        def action_row(index):
            # Rows 1-2 are enabled by default; default timesteps alternate
            # 900 (odd rows) / 650 (even rows).
            with gradio.Row():
                components.extend([
                    gradio.Checkbox(value=index <= 2, label=f"Enable {index}"),
                    gradio.Number(value=900 if index % 2 else 650, label=f"Timestep {index}"),
                    gradio.Number(value=3, label=f"Block Depth {index}", precision=0),
                    gradio.Number(value=2, label=f"Scale factor {index}"),
                ])

        with gradio.Accordion(label="Deep Shrink Hires.fix", open=False):
            for index in (1, 2):
                action_row(index)
            with gradio.Accordion(label="Advanced Settings", open=False):
                for index in range(3, DeepShrinkHiresFix.ACTION_COUNT + 1):
                    action_row(index)
            with gradio.Accordion(label="Experimental Settings", open=False):
                with gradio.Row():
                    components.extend([
                        gradio.Checkbox(value=False, label="Enable Experimental Mode"),
                        gradio.Number(value=900, label="Timestep"),
                        gradio.Textbox(value="1,1,1, 1,1,1, 1,1,1, 1,1,1, 2, 1,1,1, 1,1,1, 1,1,1, 1,1,1", label="Scale Factor List"),
                    ])
        return components

    def process(self, p, *args):
        """Store the UI values on the class so the custom UNet can read them.

        ``args`` follows the component order returned by :meth:`ui`.

        Raises:
            ValueError: if experimental mode is enabled and the scale-factor
                list contains a non-numeric or non-positive entry.
        """
        count = DeepShrinkHiresFix.ACTION_COUNT
        # Update the shared class-level lists in place.
        DeepShrinkHiresFix.deepShrinkHiresFixActions[:] = [
            DeepShrinkHiresFixAction(*args[i * 4:i * 4 + 4]) for i in range(count)
        ]
        DeepShrinkHiresFix.enableExperimental = args[count * 4]
        DeepShrinkHiresFix.experimentalTimestep = args[count * 4 + 1]
        # The UNet reads the scale list only in experimental mode, so a
        # malformed list must not break ordinary generations.
        if DeepShrinkHiresFix.enableExperimental:
            scales = DeepShrinkHiresFix._parse_scale_factors(args[count * 4 + 2])
        else:
            scales = []
        DeepShrinkHiresFix.experimentalScales[:] = scales

    @staticmethod
    def _parse_scale_factors(text):
        """Parse a comma-separated list of positive, finite scale factors.

        Surrounding whitespace and empty entries (e.g. from a trailing comma)
        are ignored.

        Raises:
            ValueError: on a non-numeric, non-positive or non-finite entry.
        """
        factors = []
        for token in (text or "").split(","):
            token = token.strip()
            if not token:
                continue
            try:
                factor = float(token)
            except ValueError as err:
                raise ValueError(f"Deep Shrink Hires.fix: invalid experimental scale factor {token!r}") from err
            # Factors are used as divisors in the UNet; reject 0, negatives, NaN and inf.
            if not 0 < factor < float("inf"):
                raise ValueError(f"Deep Shrink Hires.fix: experimental scale factors must be positive finite numbers, got {token!r}")
            factors.append(factor)
        return factors

    class DeepShrinkHiresFixUNet(sd_unet.SdUnet):
        """UNet wrapper that re-runs the wrapped model's blocks with Deep Shrink resizing."""

        def __init__(self, _model):
            super().__init__()
            self.model = _model.to(devices.device)

        @staticmethod
        def _resize(tensor, factor, dtype=None):
            """Bicubically rescale ``tensor`` by ``factor`` (computed in float32).

            The result is cast to ``dtype``, or to ``tensor.dtype`` if not given.
            """
            resized = torch.nn.functional.interpolate(tensor.float(), scale_factor=factor, mode="bicubic", align_corners=False)
            return resized.to(tensor.dtype if dtype is None else dtype)

        @staticmethod
        def _find_action(depth, timestep):
            """Return the first enabled rule for ``depth`` whose threshold is below ``timestep``, else None."""
            for action in DeepShrinkHiresFix.deepShrinkHiresFixActions:
                if action.enable and action.depth == depth and action.timestep < timestep:
                    return action
            return None

        def forward(self, x, timesteps, context, y=None, **kwargs):
            """Run the wrapped UNet, applying Deep Shrink and experimental-mode rescaling.

            Only ``timesteps[0]`` is consulted when deciding which rules apply.

            Raises:
                ValueError: if experimental mode is active and fewer scale
                    factors were supplied than the model has blocks.
            """
            model = self.model
            assert (y is not None) == (
                model.num_classes is not None
            ), "must specify y if and only if the model is class-conditional"
            t_emb = timestep_embedding(timesteps, model.model_channels, repeat_only=False)
            emb = model.time_embed(t_emb)

            if model.num_classes is not None:
                assert y.shape[0] == x.shape[0]
                emb = emb + model.label_emb(y)

            t = timesteps[0]
            scales = DeepShrinkHiresFix.experimentalScales
            experimental = bool(DeepShrinkHiresFix.enableExperimental and t >= DeepShrinkHiresFix.experimentalTimestep)
            if experimental:
                # One factor per input block, one for the middle block, one per output block.
                required = len(model.input_blocks) + 1 + len(model.output_blocks)
                if len(scales) < required:
                    raise ValueError(
                        f"Deep Shrink Hires.fix: experimental mode needs {required} scale factors "
                        f"(input blocks, middle block, output blocks), got {len(scales)}"
                    )

            h = x.type(model.dtype)
            hs = []  # input-block outputs, consumed as skip connections
            ss = []  # experimental-mode scale each entry of hs was computed at
            depth = 0
            block = 0
            scale = 1

            for module in model.input_blocks:
                action = self._find_action(depth, t)
                if action is not None:
                    h = self._resize(h, 1 / action.scale)
                if experimental:
                    h = self._resize(h, scale / scales[block])
                    scale = scales[block]
                    ss.append(scale)
                h = module(h, emb, context)
                hs.append(h)
                depth += 1
                block += 1

            # The middle block is shrunk before and restored after by the same rule.
            middle_action = self._find_action(depth, t)
            if middle_action is not None:
                h = self._resize(h, 1 / middle_action.scale)
            if experimental:
                h = self._resize(h, scale / scales[block])
                scale = scales[block]

            h = model.middle_block(h, emb, context)

            if middle_action is not None:
                h = self._resize(h, middle_action.scale)
            block += 1

            for module in model.output_blocks:
                depth -= 1
                skip = hs.pop()
                if experimental:
                    target = scales[block]
                    # Bring both the running features and the skip connection to the target scale.
                    h = torch.cat([self._resize(h, scale / target),
                                   self._resize(skip, ss.pop() / target, dtype=h.dtype)], dim=1)
                    scale = target
                else:
                    h = torch.cat([h, skip], dim=1)
                h = module(h, emb, context)
                action = self._find_action(depth, t)
                if action is not None:
                    h = self._resize(h, action.scale)
                block += 1

            if experimental:
                # Undo the last experimental scale so the output matches the input size.
                h = self._resize(h, scale)
            h = h.type(x.dtype)
            if model.predict_codebook_ids:
                return model.id_predictor(h)
            return model.out(h)

    DeepShrinkHiresFixUNetOption = sd_unet.SdUnetOption()
    DeepShrinkHiresFixUNetOption.label = "Deep Shrink Hires.fix"
    DeepShrinkHiresFixUNetOption.create_unet = lambda: DeepShrinkHiresFix.DeepShrinkHiresFixUNet(shared.sd_model.model.diffusion_model)

def _register_deep_shrink_unet(unets):
    """Append the Deep Shrink Hires.fix option to the list of selectable UNets."""
    unets.append(DeepShrinkHiresFix.DeepShrinkHiresFixUNetOption)


script_callbacks.on_list_unets(_register_deep_shrink_unet)