File size: 2,946 Bytes
359fa44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from typing_extensions import override
import comfy.utils
from comfy_api.latest import ComfyExtension, io


class PatchModelAddDownscale(io.ComfyNode):
    UPSCALE_METHODS = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="PatchModelAddDownscale",
            display_name="PatchModelAddDownscale (Kohya Deep Shrink)",
            category="model_patches/unet",
            inputs=[
                io.Model.Input("model"),
                io.Int.Input("block_number", default=3, min=1, max=32, step=1),
                io.Float.Input("downscale_factor", default=2.0, min=0.1, max=9.0, step=0.001),
                io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
                io.Float.Input("end_percent", default=0.35, min=0.0, max=1.0, step=0.001),
                io.Boolean.Input("downscale_after_skip", default=True),
                io.Combo.Input("downscale_method", options=cls.UPSCALE_METHODS),
                io.Combo.Input("upscale_method", options=cls.UPSCALE_METHODS),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    @classmethod
    def execute(cls, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method) -> io.NodeOutput:
        model_sampling = model.get_model_object("model_sampling")
        sigma_start = model_sampling.percent_to_sigma(start_percent)
        sigma_end = model_sampling.percent_to_sigma(end_percent)

        def input_block_patch(h, transformer_options):
            if transformer_options["block"][1] == block_number:
                sigma = transformer_options["sigmas"][0].item()
                if sigma <= sigma_start and sigma >= sigma_end:
                    h = comfy.utils.common_upscale(h, round(h.shape[-1] * (1.0 / downscale_factor)), round(h.shape[-2] * (1.0 / downscale_factor)), downscale_method, "disabled")
            return h

        def output_block_patch(h, hsp, transformer_options):
            if h.shape[2] != hsp.shape[2]:
                h = comfy.utils.common_upscale(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled")
            return h, hsp

        m = model.clone()
        if downscale_after_skip:
            m.set_model_input_block_patch_after_skip(input_block_patch)
        else:
            m.set_model_input_block_patch(input_block_patch)
        m.set_model_output_block_patch(output_block_patch)
        return io.NodeOutput(m)


NODE_DISPLAY_NAME_MAPPINGS = {
    # Sampling
    "PatchModelAddDownscale": "",
}

class ModelDownscaleExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            PatchModelAddDownscale,
        ]


async def comfy_entrypoint() -> ModelDownscaleExtension:
    return ModelDownscaleExtension()