|
|
from modules import scripts |
|
|
from json import dumps |
|
|
import re |
|
|
|
|
|
from scripts.couple_mapping import ( |
|
|
empty_tensor, |
|
|
basic_mapping, |
|
|
advanced_mapping, |
|
|
mask_mapping, |
|
|
) |
|
|
|
|
|
from scripts.couple_ui import couple_UI |
|
|
from scripts.ui_funcs import validate_mapping, parse_mapping |
|
|
|
|
|
from scripts.attention_couple import AttentionCouple |
|
|
forgeAttentionCouple = AttentionCouple() |
|
|
|
|
|
VERSION = "1.6.2" |
|
|
|
|
|
from scripts.gr_version import js, is_gradio_4 |
|
|
|
|
|
|
|
|
class ForgeCouple(scripts.Script): |
|
|
|
|
|
def __init__(self): |
|
|
self.couples: list = None |
|
|
|
|
|
def title(self): |
|
|
return "Forge Couple" |
|
|
|
|
|
def show(self, is_img2img): |
|
|
return scripts.AlwaysVisible |
|
|
|
|
|
def ui(self, is_img2img): |
|
|
return couple_UI(self, is_img2img, f"{self.title()} v{VERSION}") |
|
|
|
|
|
def after_component(self, component, **kwargs): |
|
|
if is_gradio_4: |
|
|
return |
|
|
|
|
|
if kwargs.get("elem_id") in ( |
|
|
"img2img_image", |
|
|
"img2img_sketch", |
|
|
"inpaint_sketch", |
|
|
"img_inpaint_base", |
|
|
): |
|
|
component.change( |
|
|
None, component, None, **js("(img) => { ForgeCouple.preview(img); }") |
|
|
) |
|
|
|
|
|
def parse_networks(self, prompt: str) -> str: |
|
|
"""LoRAs are already parsed""" |
|
|
pattern = re.compile(r"<.*?>") |
|
|
cleaned = re.sub(pattern, "", prompt) |
|
|
|
|
|
return cleaned |
|
|
|
|
|
def after_extra_networks_activate( |
|
|
self, |
|
|
p, |
|
|
enable: bool, |
|
|
mode: str, |
|
|
separator: str, |
|
|
direction: str, |
|
|
background: str, |
|
|
background_weight: float, |
|
|
mapping: list, |
|
|
*args, |
|
|
**kwargs, |
|
|
): |
|
|
if not enable: |
|
|
return |
|
|
|
|
|
separator = "\n" if not separator.strip() else separator.strip() |
|
|
|
|
|
couples = [] |
|
|
|
|
|
chunks = kwargs["prompts"][0].split(separator) |
|
|
for chunk in chunks: |
|
|
prompt = self.parse_networks(chunk).strip() |
|
|
|
|
|
if not prompt.strip(): |
|
|
|
|
|
continue |
|
|
|
|
|
couples.append(prompt) |
|
|
|
|
|
match mode: |
|
|
case "Basic": |
|
|
if len(couples) < (3 if background != "None" else 2): |
|
|
print( |
|
|
f"\n\n[Couple] Not Enough Lines in Prompt...\nCurrent: {len(couples)} / Required: {3 if background != 'None' else 2}\n\n" |
|
|
) |
|
|
|
|
|
case "Mask": |
|
|
if not mapping or len(mapping) != len(couples) - int( |
|
|
background in ("First Line", "Last Line") |
|
|
): |
|
|
print( |
|
|
f"\n\n[Couple] Number of Couples and Masks is not the same...\nCurrent: {len(couples)} / Required: {len(mapping) + int(background in ('First Line', 'Last Line'))}\n\n" |
|
|
) |
|
|
self.couples = None |
|
|
return |
|
|
|
|
|
case "Advanced": |
|
|
if not validate_mapping(mapping): |
|
|
self.couples = None |
|
|
return |
|
|
|
|
|
if not mapping or (len(parse_mapping(mapping)) != len(couples)): |
|
|
print( |
|
|
f"\n\n[Couple] Number of Couples and Mapping is not the same...\nCurrent: {len(couples)} / Required: {len(parse_mapping(mapping))}\n\n" |
|
|
) |
|
|
self.couples = None |
|
|
return |
|
|
|
|
|
|
|
|
p.extra_generation_params["forge_couple"] = True |
|
|
p.extra_generation_params["forge_couple_separator"] = ( |
|
|
"\n" if not separator.strip() else separator.strip() |
|
|
) |
|
|
p.extra_generation_params["forge_couple_mode"] = mode |
|
|
if mode == "Basic": |
|
|
p.extra_generation_params["forge_couple_direction"] = direction |
|
|
p.extra_generation_params["forge_couple_background"] = background |
|
|
p.extra_generation_params["forge_couple_background_weight"] = ( |
|
|
background_weight |
|
|
) |
|
|
elif mode == "Advanced": |
|
|
p.extra_generation_params["forge_couple_mapping"] = dumps(mapping) |
|
|
|
|
|
|
|
|
self.couples = couples |
|
|
|
|
|
def process_before_every_sampling( |
|
|
self, |
|
|
p, |
|
|
enable: bool, |
|
|
mode: str, |
|
|
separator: str, |
|
|
direction: str, |
|
|
background: str, |
|
|
background_weight: float, |
|
|
mapping: list, |
|
|
*args, |
|
|
**kwargs, |
|
|
): |
|
|
|
|
|
if not enable or not self.couples: |
|
|
return |
|
|
|
|
|
|
|
|
unet = p.sd_model.forge_objects.unet |
|
|
|
|
|
WIDTH: int = p.width |
|
|
HEIGHT: int = p.height |
|
|
IS_HORIZONTAL: bool = direction == "Horizontal" |
|
|
|
|
|
LINE_COUNT: int = len(self.couples) |
|
|
|
|
|
if mode in ("Basic", "Mask"): |
|
|
BG_WEIGHT: float = ( |
|
|
0.0 |
|
|
if (background not in ("First Line", "Last Line")) |
|
|
else max(0.1, background_weight) |
|
|
) |
|
|
|
|
|
if mode == "Basic": |
|
|
TILE_COUNT: int = LINE_COUNT - int( |
|
|
background in ("First Line", "Last Line") |
|
|
) |
|
|
TILE_WEIGHT: float = ( |
|
|
1.25 if (background not in ("First Line", "Last Line")) else 1.0 |
|
|
) |
|
|
TILE_SIZE: int = ( |
|
|
(WIDTH if IS_HORIZONTAL else HEIGHT) - 1 |
|
|
) // TILE_COUNT + 1 |
|
|
|
|
|
|
|
|
|
|
|
match mode: |
|
|
case "Basic": |
|
|
ARGs = basic_mapping( |
|
|
p.sd_model, |
|
|
self.couples, |
|
|
WIDTH, |
|
|
HEIGHT, |
|
|
LINE_COUNT, |
|
|
IS_HORIZONTAL, |
|
|
background, |
|
|
TILE_SIZE, |
|
|
TILE_WEIGHT, |
|
|
BG_WEIGHT, |
|
|
) |
|
|
|
|
|
case "Mask": |
|
|
ARGs = mask_mapping( |
|
|
p.sd_model, |
|
|
self.couples, |
|
|
WIDTH, |
|
|
HEIGHT, |
|
|
LINE_COUNT, |
|
|
mapping, |
|
|
background, |
|
|
BG_WEIGHT, |
|
|
) |
|
|
|
|
|
case "Advanced": |
|
|
ARGs = advanced_mapping( |
|
|
p.sd_model, self.couples, WIDTH, HEIGHT, mapping |
|
|
) |
|
|
|
|
|
|
|
|
assert len(ARGs.keys()) // 2 == LINE_COUNT |
|
|
|
|
|
base_mask = empty_tensor(HEIGHT, WIDTH) |
|
|
patched_unet = forgeAttentionCouple.patch_unet(unet, base_mask, ARGs) |
|
|
p.sd_model.forge_objects.unet = patched_unet |
|
|
|