dikdimon's picture
Upload extensions using SD-Hub extension
3dabe4a verified
from modules import scripts
import re
from scripts.couple_mapping import empty_tensor, basic_mapping, advanced_mapping
from scripts.couple_ui import couple_UI, validata_mapping, parse_mapping
from scripts.attention_couple import AttentionCouple
# Module-level AttentionCouple instance; ForgeCouple calls its patch_unet()
# to produce the UNet that replaces the model's original one.
forgeAttentionCouple = AttentionCouple()
# Version string appended to the script title when the UI is built.
VERSION = "1.3.0"
class ForgeCouple(scripts.Script):
    """Always-visible script that splits a prompt into per-region "couples".

    ``after_extra_networks_activate`` splits and validates the prompt and
    stores the accepted chunks in ``self.couples``.
    ``process_before_every_sampling`` then turns those chunks into region
    arguments and replaces the model's UNet with a patched one.
    """

    # Matches any ``<...>`` tag (non-greedy, single line); compiled once
    # instead of on every ``parse_networks`` call.
    _NETWORK_PATTERN = re.compile(r"<.*?>")

    def __init__(self):
        # Prompt chunks accepted by the most recent validation pass.
        # ``None`` means nothing valid is pending and sampling is untouched.
        self.couples: list | None = None

    def title(self):
        """Return the display name of the script."""
        return "Forge Couple"

    def show(self, is_img2img):
        """Return ``scripts.AlwaysVisible`` regardless of *is_img2img*."""
        return scripts.AlwaysVisible

    def ui(self, *args, **kwargs):
        """Build the UI through ``couple_UI``, titled with name and version."""
        return couple_UI(self, f"{self.title()} {VERSION}")

    def parse_networks(self, prompt: str) -> str:
        """Return *prompt* with every ``<...>`` tag removed.

        Per the original author's note, LoRAs are already parsed by the time
        this runs, so the tags only need to be dropped from the text.
        """
        return self._NETWORK_PATTERN.sub("", prompt)

    def _split_couples(self, prompt: str, separator: str) -> list:
        """Split *prompt* on *separator*, strip tags, and drop blank chunks."""
        cleaned = (
            self.parse_networks(chunk).strip()
            for chunk in prompt.split(separator)
        )
        return [chunk for chunk in cleaned if chunk]

    def after_extra_networks_activate(
        self,
        p,
        enable: bool,
        direction: str,
        background: str,
        separator: str,
        mode: str,
        mapping: list,
        *args,
        **kwargs,
    ):
        """Split the first prompt into couples and validate them.

        On success the chunks are stored in ``self.couples``; on any failure
        ``self.couples`` is left as ``None``. Nothing happens when *enable*
        is false.

        Args:
            p: Processing object (unused here).
            enable: Whether the extension is active.
            direction: Unused here; consumed at sampling time.
            background: ``"None"`` or a background option; a non-``"None"``
                value raises the minimum number of chunks from 2 to 3.
            separator: Chunk delimiter; blank/whitespace falls back to a
                newline.
            mode: ``"Advanced"`` additionally checks *mapping*.
            mapping: Region mapping, validated only in Advanced mode.
            **kwargs: ``prompts`` is read; only its first entry is used.
        """
        if not enable:
            return

        # Reset up front so that a rejected prompt, or an exception raised
        # part-way through validation, cannot leave a previous run's couples
        # behind for the sampling hook to pick up.
        self.couples = None

        if not separator.strip():
            separator = "\n"

        # A missing or empty prompt list yields zero chunks and is rejected
        # by the line-count check below rather than raising.
        prompts = kwargs.get("prompts") or [""]
        couples = self._split_couples(prompts[0], separator)

        # NOTE(review): this minimum depends on ``background`` even in
        # Advanced mode, although Advanced mode never passes ``background``
        # to ``advanced_mapping`` -- confirm whether that is intended.
        minimum_lines = 3 if background != "None" else 2
        if len(couples) < minimum_lines:
            print("\n\n[Couple] Not Enough Lines in Prompt...\n\n")
            return

        if mode == "Advanced":
            if len(couples) != len(parse_mapping(mapping)):
                print("\n\n[Couple] Number of Couples and Mapping is not the same...\n\n")
                return
            if not validata_mapping(mapping):
                return

        self.couples = couples

    def process_before_every_sampling(
        self,
        p,
        enable: bool,
        direction: str,
        background: str,
        separator: str,
        mode: str,
        mapping: list,
        *args,
        **kwargs,
    ):
        """Record infotext and patch the UNet using the stored couples.

        Does nothing when *enable* is false or when no couples were accepted
        by ``after_extra_networks_activate``.

        Args:
            p: Processing object; its ``extra_generation_params``, ``width``,
                ``height`` and ``sd_model.forge_objects.unet`` are used.
            enable: Whether the extension is active.
            direction: ``"Horizontal"`` tiles along the width, anything else
                along the height (Basic mode only).
            background: ``"None"`` or a background option (Basic mode only).
            separator: Recorded in the infotext; blank becomes a newline.
            mode: ``"Basic"`` uses ``basic_mapping``; any other value uses
                ``advanced_mapping``.
            mapping: Region mapping forwarded to ``advanced_mapping``.
        """
        if not enable or not self.couples:
            return

        is_basic = mode == "Basic"

        # ===== Infotext =====
        p.extra_generation_params["forge_couple"] = True
        p.extra_generation_params["forge_couple_separator"] = (
            separator if separator.strip() else "\n"
        )
        p.extra_generation_params["forge_couple_mode"] = mode
        if is_basic:
            p.extra_generation_params["forge_couple_direction"] = direction
            p.extra_generation_params["forge_couple_background"] = background
        else:
            p.extra_generation_params["forge_couple_mapping"] = mapping
        # ===== Infotext =====

        # ===== Init =====
        unet = p.sd_model.forge_objects.unet
        width: int = p.width
        height: int = p.height
        line_count = len(self.couples)
        # ===== Init =====

        # ===== Tiles =====
        if is_basic:
            has_background = background != "None"
            is_horizontal = direction == "Horizontal"

            # One of the lines is the background, so it does not get a tile.
            tile_count = line_count - int(has_background)
            tile_weight = 1.0 if has_background else 1.25
            bg_weight = 0.5 if has_background else 0.0

            # Ceiling division: tiles must cover the whole axis even when it
            # is not evenly divisible by the tile count.
            axis_length = width if is_horizontal else height
            tile_size = (axis_length - 1) // tile_count + 1

            region_args = basic_mapping(
                p.sd_model,
                self.couples,
                width,
                height,
                line_count,
                is_horizontal,
                background,
                tile_size,
                tile_weight,
                bg_weight,
            )
        else:
            region_args = advanced_mapping(
                p.sd_model, self.couples, width, height, mapping
            )
        # ===== Tiles =====

        # The mapping helpers are expected to yield two entries per couple.
        assert len(region_args.keys()) // 2 == line_count, (
            "[Couple] mapping produced an unexpected number of entries"
        )

        base_mask = empty_tensor(height, width)
        patched_unet = forgeAttentionCouple.patch_unet(unet, base_mask, region_args)
        p.sd_model.forge_objects.unet = patched_unet