|
|
import logging |
|
|
from os import environ |
|
|
import modules.scripts as scripts |
|
|
import gradio as gr |
|
|
import scipy.stats as stats |
|
|
|
|
|
from modules import script_callbacks, prompt_parser |
|
|
from modules.script_callbacks import CFGDenoiserParams |
|
|
from modules.prompt_parser import reconstruct_multicond_batch |
|
|
from modules.processing import StableDiffusionProcessing |
|
|
|
|
|
from modules.sd_samplers_cfg_denoiser import pad_cond |
|
|
from modules import shared |
|
|
|
|
|
import torch |
|
|
|
|
|
logger = logging.getLogger(__name__)
# Resolve SD_WEBUI_LOG_LEVEL case-insensitively and fall back to INFO when the
# value is not a known level name. Logger.setLevel raises ValueError for an
# unrecognised string (including lowercase names such as "debug"), which would
# otherwise abort the import of this module.
_requested_level = getattr(logging, str(environ.get("SD_WEBUI_LOG_LEVEL", "INFO")).upper(), None)
logger.setLevel(_requested_level if isinstance(_requested_level, int) else logging.INFO)
|
|
|
|
|
""" |
|
|
|
|
|
An unofficial implementation of SEGA: Instructing Text-to-Image Models using Semantic Guidance for Automatic1111 WebUI |
|
|
|
|
|
@misc{brack2023sega, |
|
|
title={SEGA: Instructing Text-to-Image Models using Semantic Guidance}, |
|
|
author={Manuel Brack and Felix Friedrich and Dominik Hintersdorf and Lukas Struppek and Patrick Schramowski and Kristian Kersting}, |
|
|
year={2023}, |
|
|
eprint={2301.12247}, |
|
|
archivePrefix={arXiv}, |
|
|
primaryClass={cs.CV} |
|
|
} |
|
|
|
|
|
Author: v0xie |
|
|
GitHub URL: https://github.com/v0xie/sd-webui-semantic-guidance |
|
|
|
|
|
""" |
|
|
|
|
|
class SegaStateParams:
    """Per-concept settings and running state for semantic guidance.

    Every instance starts from the same defaults; the creator overwrites the
    fields it needs. ``v`` is a fresh dict per instance and is used as the
    momentum store, keyed by conditioning key.
    """

    def __init__(self):
        # Identification / running state.
        self.concept_name = ''
        self.v = {}
        # Weight of this concept (negative concepts carry a negated weight).
        self.strength = 1.0
        # Guidance schedule and scaling.
        self.warmup_period: int = 10
        self.edit_guidance_scale: float = 1
        self.tail_percentage_threshold: float = 0.05
        # Momentum configuration.
        self.momentum_scale: float = 0.3
        self.momentum_beta: float = 0.6
|
|
|
|
|
class SegaExtensionScript(scripts.Script):
    """Always-visible WebUI script that adds semantic guidance to generation.

    For every batch the script encodes the user's concept prompts, registers a
    CFG-denoiser callback, and in that callback adds a thresholded, momentum
    smoothed edit direction (concept conditioning minus unconditional
    conditioning) to ``params.text_cond``.
    """

    def __init__(self):
        # Cache slot handed to ``p.get_conds_with_caching`` in process_batch.
        self.cached_c = [None, None]

    def title(self):
        """Return the display name of the script."""
        return "Semantic Guidance"

    def show(self, is_img2img):
        """Return ``scripts.AlwaysVisible`` regardless of ``is_img2img``."""
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the accordion UI and return its components.

        The returned list order matches the positional parameters of
        ``process_batch`` (active, prompt, neg_prompt, warmup,
        edit_guidance_scale, tail_percentage_threshold, momentum_scale,
        momentum_beta).
        """
        with gr.Accordion('Semantic Guidance', open=False):
            active = gr.Checkbox(value=False, default=False, label="Active", elem_id='sega_active')
            with gr.Row():
                prompt = gr.Textbox(lines=2, label="Prompt", elem_id='sega_prompt', elem_classes=["prompt"])
            with gr.Row():
                neg_prompt = gr.Textbox(lines=2, label="Negative Prompt", elem_id='sega_neg_prompt', elem_classes=["prompt"])
            with gr.Row():
                warmup = gr.Slider(value=10, minimum=0, maximum=30, step=1, label="Warmup Period", elem_id='sega_warmup', info="How many steps to wait before applying semantic guidance, default 10")
                edit_guidance_scale = gr.Slider(value=1.0, minimum=0.0, maximum=20.0, step=0.01, label="Edit Guidance Scale", elem_id='sega_edit_guidance_scale', info="Scale of edit guidance, default 1.0")
                tail_percentage_threshold = gr.Slider(value=0.05, minimum=0.0, maximum=1.0, step=0.01, label="Tail Percentage Threshold", elem_id='sega_tail_percentage_threshold', info="The percentage of latents to modify, default 0.05")
                momentum_scale = gr.Slider(value=0.3, minimum=0.0, maximum=1.0, step=0.01, label="Momentum Scale", elem_id='sega_momentum_scale', info="Scale of momentum, default 0.3")
                momentum_beta = gr.Slider(value=0.6, minimum=0.0, maximum=0.999, step=0.01, label="Momentum Beta", elem_id='sega_momentum_beta', info="Beta for momentum, default 0.6")

        components = [active, prompt, neg_prompt, warmup, edit_guidance_scale, tail_percentage_threshold, momentum_scale, momentum_beta]
        for component in components:
            component.do_not_save_to_config = True

        # Mapping from UI components to the infotext keys written by
        # process_batch, so pasted generation parameters restore the UI.
        self.infotext_fields = [
            (active, lambda d: gr.Checkbox.update(value='SEGA Active' in d)),
            (prompt, 'SEGA Prompt'),
            (neg_prompt, 'SEGA Negative Prompt'),
            (warmup, 'SEGA Warmup Period'),
            (edit_guidance_scale, 'SEGA Edit Guidance Scale'),
            (tail_percentage_threshold, 'SEGA Tail Percentage Threshold'),
            (momentum_scale, 'SEGA Momentum Scale'),
            (momentum_beta, 'SEGA Momentum Beta'),
        ]
        self.paste_field_names = [
            'sega_active',
            'sega_prompt',
            'sega_neg_prompt',
            'sega_warmup',
            'sega_edit_guidance_scale',
            'sega_tail_percentage_threshold',
            'sega_momentum_scale',
            'sega_momentum_beta'
        ]
        return components

    def process_batch(self, p: StableDiffusionProcessing, active, prompt, neg_prompt, warmup, edit_guidance_scale, tail_percentage_threshold, momentum_scale, momentum_beta, *args, **kwargs):
        """Encode the concept prompts and hook the denoiser callback for this batch.

        Each UI value may be overridden by a ``sega_*`` attribute set on ``p``
        (this is how the XYZ-grid axis options inject values). Does nothing
        when ``active`` is ``False`` or when no concept remains after parsing.
        """
        active = getattr(p, "sega_active", active)
        if active is False:
            return
        prompt = getattr(p, "sega_prompt", prompt)
        neg_prompt = getattr(p, "sega_neg_prompt", neg_prompt)
        warmup = getattr(p, "sega_warmup", warmup)
        edit_guidance_scale = getattr(p, "sega_edit_guidance_scale", edit_guidance_scale)
        tail_percentage_threshold = getattr(p, "sega_tail_percentage_threshold", tail_percentage_threshold)
        momentum_scale = getattr(p, "sega_momentum_scale", momentum_scale)
        momentum_beta = getattr(p, "sega_momentum_beta", momentum_beta)

        p.extra_generation_params.update({
            "SEGA Active": active,
            "SEGA Prompt": prompt,
            "SEGA Negative Prompt": neg_prompt,
            "SEGA Warmup Period": warmup,
            "SEGA Edit Guidance Scale": edit_guidance_scale,
            "SEGA Tail Percentage Threshold": tail_percentage_threshold,
            "SEGA Momentum Scale": momentum_scale,
            "SEGA Momentum Beta": momentum_beta,
        })

        # Only the first (text, weight) segment returned by
        # parse_prompt_attention is used for each concept.
        concept_prompts = [
            prompt_parser.parse_prompt_attention(concept)[0]
            for concept in self.parse_concept_prompt(prompt)
        ]
        # Negative concepts are the same thing with the weight negated.
        for neg_concept in self.parse_concept_prompt(neg_prompt):
            concept, strength = prompt_parser.parse_prompt_attention(neg_concept)[0]
            concept_prompts.append([concept, -strength])

        if not concept_prompts:
            # Without this guard the denoiser callback would index an empty
            # parameter list on every sampling step.
            logger.warning("Semantic Guidance is active but no concepts were given; skipping")
            return

        concept_conds = []
        for concept, strength in concept_prompts:
            prompt_list = [concept] * p.batch_size
            prompts = prompt_parser.SdConditioning(prompt_list, width=p.width, height=p.height)
            c = p.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps, [self.cached_c], p.extra_network_data)
            concept_conds.append([c, strength])

        self.create_hook(p, active, concept_conds, None, warmup, edit_guidance_scale, tail_percentage_threshold, momentum_scale, momentum_beta)

    def parse_concept_prompt(self, prompt: str) -> list[str]:
        """
        Separate prompt by comma into a list of concepts.

        Segments that are empty after stripping (e.g. from a trailing comma or
        a whitespace-only prompt) are dropped so they do not become concepts.

        TODO: parse prompt into a list of concepts using A1111 functions
        >>> g = lambda prompt: self.parse_concept_prompt(prompt)
        >>> g("")
        []
        >>> g("apples")
        ['apples']
        >>> g("apple, banana, carrot")
        ['apple', 'banana', 'carrot']
        >>> g("apple, , banana,")
        ['apple', 'banana']
        """
        if not prompt:
            return []
        stripped = (segment.strip() for segment in prompt.split(","))
        return [segment for segment in stripped if segment]

    def create_hook(self, p, active, concept_conds, concept_conds_neg, warmup, edit_guidance_scale, tail_percentage_threshold, momentum_scale, momentum_beta, *args, **kwargs):
        """Create one ``SegaStateParams`` per concept and register the callbacks.

        ``concept_conds`` is a list of ``[conditioning, strength]`` pairs.
        ``p``, ``active`` and ``concept_conds_neg`` are accepted but not used
        here.
        """
        concepts_sega_params = []
        for _, strength in concept_conds:
            sega_params = SegaStateParams()
            sega_params.warmup_period = warmup
            sega_params.edit_guidance_scale = edit_guidance_scale
            sega_params.tail_percentage_threshold = tail_percentage_threshold
            sega_params.momentum_scale = momentum_scale
            sega_params.momentum_beta = momentum_beta
            sega_params.strength = strength
            concepts_sega_params.append(sega_params)

        def denoiser_callback(params):
            # Closure binding this batch's conditionings and per-concept state.
            return self.on_cfg_denoiser_callback(params, concept_conds, concepts_sega_params)

        logger.debug('Hooked callbacks')
        script_callbacks.on_cfg_denoiser(denoiser_callback)
        script_callbacks.on_script_unloaded(self.unhook_callbacks)

    def postprocess_batch(self, p, active, neg_text, *args, **kwargs):
        """Remove the callbacks registered for this batch when the script is active."""
        active = getattr(p, "sega_active", active)
        if active is False:
            return
        self.unhook_callbacks()

    def unhook_callbacks(self):
        """Remove every callback registered by this script."""
        logger.debug('Unhooked callbacks')
        script_callbacks.remove_current_script_callbacks()

    def on_cfg_denoiser_callback(self, params: CFGDenoiserParams, concept_conds, sega_params: list[SegaStateParams]):
        """Rebuild the concept conditionings for this step and apply guidance.

        Stacks the per-concept conditioning tensors along a new leading
        dimension (one entry per concept) and hands them to
        ``sega_routine_batch``, which modifies ``params.text_cond``.
        """
        if not sega_params:
            # Nothing to guide with; sega_routine_batch would index an empty list.
            return

        sampling_step = params.sampling_step
        text_cond = params.text_cond
        text_uncond = params.text_uncond

        # Pad the shorter of cond / uncond with empty-prompt embeddings so
        # their lengths along dim 1 agree.
        if text_cond.shape[1] != text_uncond.shape[1]:
            empty = shared.sd_model.cond_stage_model_empty_prompt
            num_repeats = (text_cond.shape[1] - text_uncond.shape[1]) // empty.shape[1]
            if num_repeats < 0:
                text_cond = pad_cond(text_cond, -num_repeats, empty)
            elif num_repeats > 0:
                text_uncond = pad_cond(text_uncond, num_repeats, empty)

        batch_conds_list = []
        batch_tensor = {}

        # Normalise plain tensors to the dict form so both layouts share one
        # code path below.
        if isinstance(text_cond, torch.Tensor):
            text_cond = {'crossattn': text_cond}
        if isinstance(text_uncond, torch.Tensor):
            text_uncond = {'crossattn': text_uncond}

        for i, _ in enumerate(sega_params):
            concept_cond, _ = concept_conds[i]
            conds_list, tensor_dict = reconstruct_multicond_batch(concept_cond, sampling_step)

            if isinstance(tensor_dict, torch.Tensor):
                tensor_dict = {'crossattn': tensor_dict}

            for key, tensor in tensor_dict.items():
                uncond_len = text_uncond[key].shape[1]
                if tensor.shape[1] != uncond_len:
                    empty = shared.sd_model.cond_stage_model_empty_prompt
                    # The length is always read from the entry for this key;
                    # the original read ``.shape`` off the container itself for
                    # non-crossattn keys, which fails for a plain dict.
                    num_repeats = (tensor.shape[1] - uncond_len) // empty.shape[1]
                    if num_repeats < 0:
                        tensor = pad_cond(tensor, -num_repeats, empty)
                    # NOTE(review): a concept longer than the uncond is left
                    # unpadded here, as in the original; confirm whether that
                    # case needs handling.

                # New leading dimension indexes the concept.
                tensor = tensor.unsqueeze(0)
                if key not in batch_tensor:
                    batch_tensor[key] = tensor
                else:
                    batch_tensor[key] = torch.cat((batch_tensor[key], tensor), dim=0)
            batch_conds_list.append(conds_list)

        self.sega_routine_batch(params, batch_conds_list, batch_tensor, sega_params, text_cond, text_uncond)

    def make_tuple_dim(self, dim):
        """Return a view shape ``(-1, 1, ..., 1)`` with ``dim`` entries.

        ``dim`` may be an int or a tensor, in which case its number of
        dimensions is used.
        """
        if isinstance(dim, torch.Tensor):
            dim = dim.dim()
        return (-1,) + (1,) * (dim - 1)

    def sega_routine_batch(self, params: CFGDenoiserParams, batch_conds_list, batch_tensor, sega_params: list[SegaStateParams], text_cond, text_uncond):
        """Compute the edit direction for every concept and add it to ``params.text_cond``.

        ``batch_tensor`` maps a conditioning key to a tensor whose leading
        dimension indexes the concepts (as stacked by
        ``on_cfg_denoiser_callback``). ``text_uncond`` must be indexable by the
        same keys. ``batch_conds_list`` and ``text_cond`` are accepted but not
        used. The shared settings are read from ``sega_params[0]``; only
        ``strength`` and the momentum state ``v`` are per concept.
        """
        if not sega_params:
            return

        first = sega_params[0]
        warmup_period = first.warmup_period
        edit_guidance_scale = first.edit_guidance_scale
        tail_percentage_threshold = first.tail_percentage_threshold
        momentum_scale = first.momentum_scale
        momentum_beta = first.momentum_beta

        sampling_step = params.sampling_step
        past_warmup = sampling_step >= warmup_period
        # During warmup the edit direction is zeroed out entirely.
        guidance_strength = 1.0 if past_warmup else 0.0
        # z-score of the upper tail of a standard normal; loop-invariant.
        upper_z = float(stats.norm.ppf(1.0 - tail_percentage_threshold))
        strengths = [sega_param.strength for sega_param in sega_params]

        edit_dir_dict = {}
        for key, concept_cond in batch_tensor.items():
            # Shape (-1, 1, ..., 1) so per-concept scalars broadcast.
            new_shape = self.make_tuple_dim(concept_cond)
            strength = torch.tensor(strengths, dtype=concept_cond.dtype, device=concept_cond.device).view(new_shape)

            # Per-concept mean / std over every dimension except the first.
            inside_dim = tuple(range(-concept_cond.dim() + 1, 0))
            cond_mean = torch.mean(concept_cond, dim=inside_dim)
            cond_std = torch.std(concept_cond, dim=inside_dim)

            # Direction from the unconditional towards the concept, weighted by
            # the (possibly negative) concept strength.
            edit_dir = strength * (concept_cond - text_uncond[key])

            upper_threshold = (cond_mean + (upper_z * cond_std)).view(new_shape)

            # Keep (and scale) only elements whose magnitude exceeds the
            # threshold; everything else is zeroed.
            zero_tensor = torch.zeros_like(concept_cond)
            scale_tensor = torch.ones_like(concept_cond) * edit_guidance_scale
            scale_tensor = torch.where(edit_dir.abs() > upper_threshold, scale_tensor, zero_tensor)

            edit_dir_dict[key] = guidance_strength * (scale_tensor * edit_dir)

        for i, sega_param in enumerate(sega_params):
            for key, edit_dirs in edit_dir_dict.items():
                if key not in sega_param.v:
                    # Momentum state has the shape of one concept's slice.
                    slice_idx = 1 - edit_dirs.dim()
                    sega_param.v[key] = torch.zeros(edit_dirs.shape[slice_idx:], dtype=edit_dirs.dtype, device=edit_dirs.device)

                v_t = sega_param.v[key]
                edit_dirs[i] = edit_dirs[i] + momentum_scale * v_t

                # Exponential moving average of the guidance. The previous
                # expression multiplied the terms together, so a zero-initialised
                # momentum stayed zero and never contributed.
                v_t_1 = momentum_beta * v_t + (1 - momentum_beta) * edit_dirs[i]

                if past_warmup:
                    if isinstance(params.text_cond, dict):
                        params.text_cond[key] = params.text_cond[key] + edit_dirs[i]
                    else:
                        params.text_cond = params.text_cond + edit_dirs[i]

                sega_param.v[key] = v_t_1
|
|
|
|
|
|
|
|
|
|
|
def sega_apply_override(field, boolean: bool = False):
    """Build an XYZ-grid apply function that stores the axis value on ``p``.

    The returned callable sets attribute ``field`` on ``p`` to ``x``. With
    ``boolean=True`` the value is first converted to a bool that is ``True``
    only when ``x`` equals "true" ignoring case.
    """
    def fun(p, x, xs):
        value = x
        if boolean:
            value = value.lower() == "true"
        setattr(p, field, value)

    return fun
|
|
|
|
|
def sega_apply_field(field):
    """Build an XYZ-grid apply function that sets ``field`` on ``p``.

    The returned callable also sets ``p.sega_active`` to ``True`` when ``p``
    has no such attribute yet; an existing value is left untouched.
    """
    def fun(p, x, xs):
        already_set = hasattr(p, "sega_active")
        if not already_set:
            p.sega_active = True
        setattr(p, field, x)

    return fun
|
|
|
|
|
def make_axis_options():
    """Append the Semantic Guidance axis options to the XYZ-grid script.

    Looks up the XYZ-grid script module in ``scripts.scripts_data`` and extends
    its ``axis_options`` unless an option labelled "[Semantic Guidance]" is
    already present.

    Raises:
        IndexError: if no XYZ-grid script is found in ``scripts.scripts_data``.
    """
    xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ in ("xyz_grid.py", "scripts.xyz_grid")][0].module
    # A list (not a set) so the options are appended in this exact order.
    extra_axis_options = [
        xyz_grid.AxisOption("[Semantic Guidance] Active", str, sega_apply_override('sega_active', boolean=True), choices=xyz_grid.boolean_choice(reverse=True)),
        xyz_grid.AxisOption("[Semantic Guidance] Prompt", str, sega_apply_field("sega_prompt")),
        xyz_grid.AxisOption("[Semantic Guidance] Negative Prompt", str, sega_apply_field("sega_neg_prompt")),
        xyz_grid.AxisOption("[Semantic Guidance] Warmup Steps", int, sega_apply_field("sega_warmup")),
        xyz_grid.AxisOption("[Semantic Guidance] Guidance Scale", float, sega_apply_field("sega_edit_guidance_scale")),
        xyz_grid.AxisOption("[Semantic Guidance] Tail Percentage Threshold", float, sega_apply_field("sega_tail_percentage_threshold")),
        xyz_grid.AxisOption("[Semantic Guidance] Momentum Scale", float, sega_apply_field("sega_momentum_scale")),
        xyz_grid.AxisOption("[Semantic Guidance] Momentum Beta", float, sega_apply_field("sega_momentum_beta")),
    ]
    # Guard against registering the options twice.
    if not any("[Semantic Guidance]" in x.label for x in xyz_grid.axis_options):
        xyz_grid.axis_options.extend(extra_axis_options)
|
|
|
|
|
def callback_before_ui():
    """Register the XYZ-grid axis options, logging instead of raising on failure.

    Only ``Exception`` subclasses are caught, so ``KeyboardInterrupt`` and
    ``SystemExit`` still propagate.
    """
    try:
        make_axis_options()
    except Exception:
        logger.exception("Semantic Guidance: Error while making axis options")


# Run the axis-option registration before the UI is built.
script_callbacks.on_before_ui(callback_before_ui)
|
|
|