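# NOTE(review): the six lines below look like file-viewer page residue captured
# with the source, not part of the program; kept verbatim as comments.
# dikdimon's picture
# Upload extensions using SD-Hub extension
# 3dabe4a verified
# raw
# history blame
# 6.49 kB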
from modules import scripts
from json import dumps
import re
from lib_couple.mapping import (
empty_tensor,
basic_mapping,
advanced_mapping,
mask_mapping,
)
from lib_couple.ui import couple_UI
from lib_couple.ui_funcs import validate_mapping
from lib_couple.attention_couple import AttentionCouple
# Module-level AttentionCouple instance; ForgeCouple.process_before_every_sampling
# calls its patch_unet() to obtain the patched UNet.
forgeAttentionCouple = AttentionCouple()
# Extension version; interpolated into the title string that ForgeCouple.ui()
# passes to couple_UI.
VERSION = "2.0.3"
from lib_couple.gr_version import js
class ForgeCouple(scripts.Script):
def __init__(self):
self.couples: list = None
def title(self):
return "Forge Couple"
def show(self, is_img2img):
return scripts.AlwaysVisible
def ui(self, is_img2img):
return couple_UI(self, is_img2img, f"{self.title()} v{VERSION}")
def after_component(self, component, **kwargs):
if not (elem_id := kwargs.get("elem_id", None)):
return
if elem_id in ("txt2img_width", "txt2img_height"):
component.change(None, **js('() => { ForgeCouple.preview("t2i"); }'))
elif elem_id in ("img2img_width", "img2img_height"):
component.change(None, **js('() => { ForgeCouple.preview("i2i"); }'))
@staticmethod
def strip_networks(prompt: str) -> str:
"""LoRAs are already parsed thus no longer needed"""
pattern = re.compile(r"<.*?>")
cleaned = re.sub(pattern, "", prompt)
return cleaned
def after_extra_networks_activate(
self,
p,
enable: bool,
mode: str,
separator: str,
direction: str,
background: str,
background_weight: float,
mapping: list,
*args,
**kwargs,
):
if not enable:
return
separator = "\n" if not separator.strip() else separator.strip()
couples = []
chunks = kwargs["prompts"][0].split(separator)
for chunk in chunks:
prompt = self.strip_networks(chunk).strip()
if not prompt.strip():
# Skip Empty Lines
continue
couples.append(prompt)
match mode:
case "Basic":
if len(couples) < (3 - int(background == "None")):
print("\n[Couple] Not Enough Lines in Prompt...")
print(f"\t[{len(couples)} / {3 - int(background == 'None')}]\n")
self.couples = None
return
case "Mask":
if not mapping:
print("\n[Couple] No Mapping...?\n")
self.couples = None
return
if len(couples) != len(mapping) + int(background != "None"):
print(
f"""\n[Couple] Number of Couples and Masks is not the same...
\t[{len(couples)} / {len(mapping) + int( background != 'None')}]\n"""
)
self.couples = None
return
case "Advanced":
if not mapping:
print("\n[Couple] No Mapping...?\n")
self.couples = None
return
if not validate_mapping(mapping):
self.couples = None
return
if len(couples) != len(mapping):
print("\n[Couple] Number of Couples and Mapping is not the same...")
print(f"[{len(couples)} / {len(mapping)}]\n")
self.couples = None
return
# ===== Infotext =====
p.extra_generation_params["forge_couple"] = True
p.extra_generation_params["forge_couple_separator"] = separator
p.extra_generation_params["forge_couple_mode"] = mode
if mode == "Basic":
p.extra_generation_params.update(
{
"forge_couple_direction": direction,
"forge_couple_background": background,
"forge_couple_background_weight": background_weight,
}
)
elif mode == "Advanced":
p.extra_generation_params["forge_couple_mapping"] = dumps(mapping)
# ===== Infotext =====
self.couples = couples
def process_before_every_sampling(
self,
p,
enable: bool,
mode: str,
separator: str,
direction: str,
background: str,
background_weight: float,
mapping: list,
*args,
**kwargs,
):
if not enable or not self.couples:
return
# ===== Init =====
unet = p.sd_model.forge_objects.unet
WIDTH: int = p.width
HEIGHT: int = p.height
IS_HORIZONTAL: bool = direction == "Horizontal"
NO_BACKGROUND: bool = background == "None"
LINE_COUNT: int = len(self.couples)
if mode != "Advanced":
BG_WEIGHT: float = 0.0 if NO_BACKGROUND else max(0.1, background_weight)
if mode == "Basic":
TILE_COUNT: int = LINE_COUNT - int(not NO_BACKGROUND)
TILE_WEIGHT: float = 1.25 if NO_BACKGROUND else 1.0
TILE_SIZE: int = (
(WIDTH if IS_HORIZONTAL else HEIGHT) - 1
) // TILE_COUNT + 1
# ===== Init =====
# ===== Tiles =====
match mode:
case "Basic":
ARGs = basic_mapping(
p.sd_model,
self.couples,
WIDTH,
HEIGHT,
LINE_COUNT,
IS_HORIZONTAL,
background,
TILE_SIZE,
TILE_WEIGHT,
BG_WEIGHT,
)
case "Mask":
ARGs = mask_mapping(
p.sd_model,
self.couples,
WIDTH,
HEIGHT,
LINE_COUNT,
mapping,
background,
BG_WEIGHT,
)
case "Advanced":
ARGs = advanced_mapping(
p.sd_model, self.couples, WIDTH, HEIGHT, mapping
)
# ===== Tiles =====
assert len(ARGs.keys()) // 2 == LINE_COUNT
base_mask = empty_tensor(HEIGHT, WIDTH)
patched_unet = forgeAttentionCouple.patch_unet(unet, base_mask, ARGs)
p.sd_model.forge_objects.unet = patched_unet