File size: 2,059 Bytes
3dabe4a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
import modules.scripts as scripts
from modules import extra_networks
from modules.processing import StableDiffusionProcessing
import gradio as gr
from loractl.lib import utils, plot, lora_ctl_network, network_patch
class LoraCtlScript(scripts.Script):
    """WebUI script that enables dynamic (per-step) LoRA weights.

    When enabled, it swaps the registered "lora" extra network for a
    LoraCtlNetwork (which understands step-dependent weight syntax) and
    restores the stock network when disabled.
    """

    def __init__(self):
        # The stock "lora" network that was registered before we replaced
        # it; kept so it can be restored when the feature is turned off.
        self.original_network = None
        super().__init__()

    def title(self):
        return "Dynamic Lora Weights"

    def show(self, is_img2img):
        # Visible in both txt2img and img2img tabs.
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the collapsed accordion UI; returns the option components."""
        with gr.Group():
            with gr.Accordion("Dynamic Lora Weights", open=False):
                opt_enable = gr.Checkbox(
                    value=True, label="Enable Dynamic Lora Weights")
                opt_plot_lora_weight = gr.Checkbox(
                    value=False, label="Plot the LoRA weight in all steps")
        return [opt_enable, opt_plot_lora_weight]

    def process(self, p: StableDiffusionProcessing, opt_enable=True, opt_plot_lora_weight=False, **kwargs):
        """Swap in (or restore) the lora network and reset per-image state.

        Runs before each generation; opt_* arguments come from ui().
        """
        current = extra_networks.extra_network_registry["lora"]
        stock_cls = lora_ctl_network.LoraCtlNetwork.__bases__[0]
        if opt_enable and not isinstance(current, lora_ctl_network.LoraCtlNetwork):
            # Replace the stock network with the dynamic-weight version and
            # remember the original so it can be restored later.
            self.original_network = current
            network = lora_ctl_network.LoraCtlNetwork()
            extra_networks.register_extra_network(network)
            extra_networks.register_extra_network_alias(network, "loractl")
        elif (not opt_enable
              and type(current) is not stock_cls
              and self.original_network is not None):
            # Feature disabled while our network is still registered:
            # restore the stock network. The None guard prevents
            # registering None if we never replaced it in the first place.
            extra_networks.register_extra_network(self.original_network)
            self.original_network = None
        network_patch.apply()
        utils.set_hires(False)
        utils.set_active(opt_enable)
        # Clear any weight schedules / plot data left from the previous image.
        lora_ctl_network.reset_weights()
        plot.reset_plot()

    def before_hr(self, p, *args):
        # Subsequent steps belong to the hires-fix pass; weight schedules
        # may differ between the base and hires passes.
        utils.set_hires(True)

    def postprocess(self, p, processed, opt_enable=True, opt_plot_lora_weight=False, **kwargs):
        """Optionally append the LoRA weight-over-steps plot to the outputs."""
        if opt_plot_lora_weight and opt_enable:
            processed.images.extend([plot.make_plot()])
|