File size: 1,892 Bytes
ecc4278
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import logging
import gradio as gr
from modules import scripts
from RescaleCFG.nodes_RescaleCFG import RescaleCFG

class RescaleCFGScript(scripts.Script):
    """Script that installs a RescaleCFG-patched unet clone before sampling.

    The UI exposes an enable checkbox and a multiplier slider; their values
    are read back as the first two positional arguments of
    ``process_before_every_sampling``.
    """

    # NOTE(review): presumably read by the host to order scripts — confirm.
    sorting_priority = 15

    def __init__(self):
        # Initial values shown in the UI; overwritten on every sampling call.
        self.multiplier = 0.7
        self.enabled = False

    def title(self):
        """Return the display name of this script."""
        return "RescaleCFG for reForge"

    def show(self, is_img2img):
        """Return ``scripts.AlwaysVisible`` regardless of ``is_img2img``."""
        return scripts.AlwaysVisible

    def ui(self, *args, **kwargs):
        """Build the settings accordion and return its input components.

        Returns:
            A ``(enabled, multiplier)`` tuple of the checkbox and slider
            components, in that order.
        """
        accordion_label = self.title()
        with gr.Accordion(open=False, label=accordion_label):
            gr.HTML("<p><i>Adjust the settings for RescaleCFG.</i></p>")
            enabled = gr.Checkbox(
                label="Enable RescaleCFG",
                value=self.enabled,
            )
            multiplier = gr.Slider(
                label="RescaleCFG Multiplier",
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=self.multiplier,
            )

        def _on_enabled_change(checked):
            # Mirror the checkbox state onto the script instance.
            return self.update_enabled(checked)

        enabled.change(_on_enabled_change, inputs=[enabled])

        return (enabled, multiplier)

    def update_enabled(self, value):
        """Store ``value`` as the current enabled state."""
        self.enabled = value

    def process_before_every_sampling(self, p, *args, **kwargs):
        """Replace the model's unet with a clone, patched when enabled.

        Args:
            p: Processing object; ``p.sd_model.forge_objects.unet`` is read
                and reassigned, and ``p.extra_generation_params`` is updated
                when the patch is applied.
            *args: The first two values are taken as ``enabled`` and
                ``multiplier``. With fewer than two, a warning is logged and
                nothing is changed.
            **kwargs: Ignored.
        """
        if len(args) < 2:
            logging.warning("Not enough arguments provided to process_before_every_sampling")
            return

        self.enabled = args[0]
        self.multiplier = args[1]

        # Work on a clone so the unet object currently installed is not
        # itself modified by the patch call below.
        forge_objects = p.sd_model.forge_objects
        working_unet = forge_objects.unet.clone()

        if self.enabled:
            working_unet = RescaleCFG().patch(working_unet, self.multiplier)[0]
            forge_objects.unet = working_unet

            # Record the settings alongside the other generation parameters.
            generation_info = {
                "rescale_cfg_enabled": True,
                "rescale_cfg_multiplier": self.multiplier,
            }
            p.extra_generation_params.update(generation_info)

            logging.debug(f"RescaleCFG: Enabled: {self.enabled}, Multiplier: {self.multiplier}")
        else:
            # Disabled: install the unpatched clone and leave everything else alone.
            forge_objects.unet = working_unet