dikdimon's picture
Upload extensions using SD-Hub extension
3dabe4a verified
import logging
from os import environ
import modules.scripts as scripts
import gradio as gr
import scipy.stats as stats
from modules import script_callbacks, prompt_parser
from modules.script_callbacks import CFGDenoiserParams
from modules.prompt_parser import reconstruct_multicond_batch
from modules.processing import StableDiffusionProcessing
#from modules.shared import sd_model, opts
from modules.sd_samplers_cfg_denoiser import pad_cond
from modules import shared
import torch
logger = logging.getLogger(__name__)
logger.setLevel(environ.get("SD_WEBUI_LOG_LEVEL", logging.INFO))
"""
An unofficial implementation of SEGA: Instructing Text-to-Image Models using Semantic Guidance for Automatic1111 WebUI
@misc{brack2023sega,
title={SEGA: Instructing Text-to-Image Models using Semantic Guidance},
author={Manuel Brack and Felix Friedrich and Dominik Hintersdorf and Lukas Struppek and Patrick Schramowski and Kristian Kersting},
year={2023},
eprint={2301.12247},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
Author: v0xie
GitHub URL: https://github.com/v0xie/sd-webui-semantic-guidance
"""
class SegaStateParams:
def __init__(self):
self.concept_name = ''
self.v = {} # velocity
self.warmup_period: int = 10 # [0, 20]
self.edit_guidance_scale: float = 1 # [0., 1.]
self.tail_percentage_threshold: float = 0.05 # [0., 1.] if abs value of difference between uncodition and concept-conditioned is less than this, then zero out the concept-conditioned values less than this
self.momentum_scale: float = 0.3 # [0., 1.]
self.momentum_beta: float = 0.6 # [0., 1.) # larger bm is less volatile changes in momentum
self.strength = 1.0
class SegaExtensionScript(scripts.Script):
def __init__(self):
self.cached_c = [None, None]
# Extension title in menu UI
def title(self):
return "Semantic Guidance"
# Decide to show menu in txt2img or img2img
def show(self, is_img2img):
return scripts.AlwaysVisible
# Setup menu ui detail
def ui(self, is_img2img):
with gr.Accordion('Semantic Guidance', open=False):
active = gr.Checkbox(value=False, default=False, label="Active", elem_id='sega_active')
with gr.Row():
prompt = gr.Textbox(lines=2, label="Prompt", elem_id = 'sega_prompt', elem_classes=["prompt"])
with gr.Row():
neg_prompt = gr.Textbox(lines=2, label="Negative Prompt", elem_id = 'sega_neg_prompt', elem_classes=["prompt"])
with gr.Row():
warmup = gr.Slider(value = 10, minimum = 0, maximum = 30, step = 1, label="Warmup Period", elem_id = 'sega_warmup', info="How many steps to wait before applying semantic guidance, default 10")
edit_guidance_scale = gr.Slider(value = 1.0, minimum = 0.0, maximum = 20.0, step = 0.01, label="Edit Guidance Scale", elem_id = 'sega_edit_guidance_scale', info="Scale of edit guidance, default 1.0")
tail_percentage_threshold = gr.Slider(value = 0.05, minimum = 0.0, maximum = 1.0, step = 0.01, label="Tail Percentage Threshold", elem_id = 'sega_tail_percentage_threshold', info="The percentage of latents to modify, default 0.05")
momentum_scale = gr.Slider(value = 0.3, minimum = 0.0, maximum = 1.0, step = 0.01, label="Momentum Scale", elem_id = 'sega_momentum_scale', info="Scale of momentum, default 0.3")
momentum_beta = gr.Slider(value = 0.6, minimum = 0.0, maximum = 0.999, step = 0.01, label="Momentum Beta", elem_id = 'sega_momentum_beta', info="Beta for momentum, default 0.6")
active.do_not_save_to_config = True
prompt.do_not_save_to_config = True
neg_prompt.do_not_save_to_config = True
warmup.do_not_save_to_config = True
edit_guidance_scale.do_not_save_to_config = True
tail_percentage_threshold.do_not_save_to_config = True
momentum_scale.do_not_save_to_config = True
momentum_beta.do_not_save_to_config = True
self.infotext_fields = [
(active, lambda d: gr.Checkbox.update(value='SEGA Active' in d)),
(prompt, 'SEGA Prompt'),
(neg_prompt, 'SEGA Negative Prompt'),
(warmup, 'SEGA Warmup Period'),
(edit_guidance_scale, 'SEGA Edit Guidance Scale'),
(tail_percentage_threshold, 'SEGA Tail Percentage Threshold'),
(momentum_scale, 'SEGA Momentum Scale'),
(momentum_beta, 'SEGA Momentum Beta'),
]
self.paste_field_names = [
'sega_active',
'sega_prompt',
'sega_neg_prompt',
'sega_warmup',
'sega_edit_guidance_scale',
'sega_tail_percentage_threshold',
'sega_momentum_scale',
'sega_momentum_beta'
]
return [active, prompt, neg_prompt, warmup, edit_guidance_scale, tail_percentage_threshold, momentum_scale, momentum_beta]
def process_batch(self, p: StableDiffusionProcessing, active, prompt, neg_prompt, warmup, edit_guidance_scale, tail_percentage_threshold, momentum_scale, momentum_beta, *args, **kwargs):
active = getattr(p, "sega_active", active)
if active is False:
return
prompt = getattr(p, "sega_prompt", prompt)
neg_prompt = getattr(p, "sega_neg_prompt", neg_prompt)
warmup = getattr(p, "sega_warmup", warmup)
edit_guidance_scale = getattr(p, "sega_edit_guidance_scale", edit_guidance_scale)
tail_percentage_threshold = getattr(p, "sega_tail_percentage_threshold", tail_percentage_threshold)
momentum_scale = getattr(p, "sega_momentum_scale", momentum_scale)
momentum_beta = getattr(p, "sega_momentum_beta", momentum_beta)
# FIXME: must have some prompt
#if prompt is None:
# return
#if len(prompt) == 0:
# return
p.extra_generation_params.update({
"SEGA Active": active,
"SEGA Prompt": prompt,
"SEGA Negative Prompt": neg_prompt,
"SEGA Warmup Period": warmup,
"SEGA Edit Guidance Scale": edit_guidance_scale,
"SEGA Tail Percentage Threshold": tail_percentage_threshold,
"SEGA Momentum Scale": momentum_scale,
"SEGA Momentum Beta": momentum_beta,
})
# separate concepts by comma
concept_prompts = self.parse_concept_prompt(prompt)
concept_prompts_neg = self.parse_concept_prompt(neg_prompt)
# [[concept_1, strength_1], ...]
concept_prompts = [prompt_parser.parse_prompt_attention(concept)[0] for concept in concept_prompts]
concept_prompts_neg = [prompt_parser.parse_prompt_attention(neg_concept)[0] for neg_concept in concept_prompts_neg]
concept_prompts_neg = [[concept, -strength] for concept, strength in concept_prompts_neg]
concept_prompts.extend(concept_prompts_neg)
concept_conds = []
for concept, strength in concept_prompts:
prompt_list = [concept] * p.batch_size
prompts = prompt_parser.SdConditioning(prompt_list, width=p.width, height=p.height)
c = p.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps, [self.cached_c], p.extra_network_data)
concept_conds.append([c, strength])
self.create_hook(p, active, concept_conds, None, warmup, edit_guidance_scale, tail_percentage_threshold, momentum_scale, momentum_beta)
def parse_concept_prompt(self, prompt:str) -> list[str]:
"""
Separate prompt by comma into a list of concepts
TODO: parse prompt into a list of concepts using A1111 functions
>>> g = lambda prompt: self.parse_concept_prompt(prompt)
>>> g("")
[]
>>> g("apples")
['apples']
>>> g("apple, banana, carrot")
['apple', 'banana', 'carrot']
"""
if len(prompt) == 0:
return []
return [x.strip() for x in prompt.split(",")]
def create_hook(self, p, active, concept_conds, concept_conds_neg, warmup, edit_guidance_scale, tail_percentage_threshold, momentum_scale, momentum_beta, *args, **kwargs):
# Create a list of parameters for each concept
concepts_sega_params = []
for _, strength in concept_conds:
sega_params = SegaStateParams()
sega_params.warmup_period = warmup
sega_params.edit_guidance_scale = edit_guidance_scale
sega_params.tail_percentage_threshold = tail_percentage_threshold
sega_params.momentum_scale = momentum_scale
sega_params.momentum_beta = momentum_beta
sega_params.strength = strength
concepts_sega_params.append(sega_params)
# Use lambda to call the callback function with the parameters to avoid global variables
y = lambda params: self.on_cfg_denoiser_callback(params, concept_conds, concepts_sega_params)
logger.debug('Hooked callbacks')
script_callbacks.on_cfg_denoiser(y)
script_callbacks.on_script_unloaded(self.unhook_callbacks)
def postprocess_batch(self, p, active, neg_text, *args, **kwargs):
active = getattr(p, "sega_active", active)
if active is False:
return
self.unhook_callbacks()
def unhook_callbacks(self):
logger.debug('Unhooked callbacks')
script_callbacks.remove_current_script_callbacks()
def on_cfg_denoiser_callback(self, params: CFGDenoiserParams, concept_conds, sega_params: list[SegaStateParams]):
# TODO: add option to opt out of batching for performance
sampling_step = params.sampling_step
text_cond = params.text_cond
text_uncond = params.text_uncond
# pad text_cond or text_uncond to match the length of the longest prompt
# i would prefer to let sd_samplers_cfg_denoiser.py handle the padding, but
# there isn't a callback that returns the padded conds
if text_cond.shape[1] != text_uncond.shape[1]:
empty = shared.sd_model.cond_stage_model_empty_prompt
num_repeats = (text_cond.shape[1] - text_uncond.shape[1]) // empty.shape[1]
if num_repeats < 0:
text_cond = pad_cond(text_cond, -num_repeats, empty)
elif num_repeats > 0:
text_uncond = pad_cond(text_uncond, num_repeats, empty)
batch_conds_list = []
batch_tensor = {}
# sd 1.5 support
if isinstance(text_cond, torch.Tensor):
text_cond = {'crossattn': text_cond}
if isinstance(text_uncond, torch.Tensor):
text_uncond = {'crossattn': text_uncond}
for i, _ in enumerate(sega_params):
concept_cond, _ = concept_conds[i]
conds_list, tensor_dict = reconstruct_multicond_batch(concept_cond, sampling_step)
# sd 1.5 support
if isinstance(tensor_dict, torch.Tensor):
tensor_dict = {'crossattn': tensor_dict}
# initialize here because we don't know the shape/dtype of the tensor until we reconstruct it
for key, tensor in tensor_dict.items():
if tensor.shape[1] != text_uncond[key].shape[1]:
empty = shared.sd_model.cond_stage_model_empty_prompt
# sd 1.5
if key == "crossattn":
num_repeats = (tensor.shape[1] - text_uncond[key].shape[1]) // empty.shape[1]
# sdxl
else:
num_repeats = (tensor.shape[1] - text_uncond.shape[1]) // empty.shape[1]
if num_repeats < 0:
tensor = pad_cond(tensor, -num_repeats, empty)
tensor = tensor.unsqueeze(0)
if key not in batch_tensor.keys():
batch_tensor[key] = tensor
else:
batch_tensor[key] = torch.cat((batch_tensor[key], tensor), dim=0)
batch_conds_list.append(conds_list)
self.sega_routine_batch(params, batch_conds_list, batch_tensor, sega_params, text_cond, text_uncond)
def make_tuple_dim(self, dim):
# sd 1.5 support
if isinstance(dim, torch.Tensor):
dim = dim.dim()
return (-1,) + (1,) * (dim - 1)
def sega_routine_batch(self, params: CFGDenoiserParams, batch_conds_list, batch_tensor, sega_params: list[SegaStateParams], text_cond, text_uncond):
# FIXME: these parameters should be specific to each concept
warmup_period = sega_params[0].warmup_period
edit_guidance_scale = sega_params[0].edit_guidance_scale
tail_percentage_threshold = sega_params[0].tail_percentage_threshold
momentum_scale = sega_params[0].momentum_scale
momentum_beta = sega_params[0].momentum_beta
sampling_step = params.sampling_step
# Semantic Guidance
edit_dir_dict = {}
# batch_tensor: [num_concepts, batch_size, tokens(77, 154, etc.), 2048]
# Calculate edit direction
for key, concept_cond in batch_tensor.items():
new_shape = self.make_tuple_dim(concept_cond)
strength = torch.Tensor([params.strength for params in sega_params]).to(dtype=concept_cond.dtype, device=concept_cond.device)
strength = strength.view(new_shape)
if key not in edit_dir_dict.keys():
edit_dir_dict[key] = torch.zeros_like(concept_cond, dtype=concept_cond.dtype, device=concept_cond.device)
# filter out values in-between tails
# FIXME: does this take into account image batch size?, i.e. dim 1
inside_dim = tuple(range(-concept_cond.dim() + 1, 0)) # for tensor of dim 4, returns (-3, -2, -1), for tensor of dim 3, returns (-2, -1)
cond_mean, cond_std = torch.mean(concept_cond, dim=inside_dim), torch.std(concept_cond, dim=inside_dim)
# broadcast element-wise subtraction
edit_dir = concept_cond - text_uncond[key]
# multiply by strength for positive / negative direction
edit_dir = torch.mul(strength, edit_dir)
# z-scores for tails
upper_z = stats.norm.ppf(1.0 - tail_percentage_threshold)
# numerical thresholds
# FIXME: does this take into account image batch size?, i.e. dim 1
upper_threshold = cond_mean + (upper_z * cond_std)
# reshape to be able to broadcast / use torch.where to filter out values for each concept
#new_shape = (-1,) + (1,) * (concept_cond.dim() - 1)
new_shape = self.make_tuple_dim(concept_cond)
upper_threshold_reshaped = upper_threshold.view(new_shape)
# zero out values in-between tails
# elementwise multiplication between scale tensor and edit direction
zero_tensor = torch.zeros_like(concept_cond, dtype=concept_cond.dtype, device=concept_cond.device)
scale_tensor = torch.ones_like(concept_cond, dtype=concept_cond.dtype, device=concept_cond.device) * edit_guidance_scale
edit_dir_abs = edit_dir.abs()
scale_tensor = torch.where((edit_dir_abs > upper_threshold_reshaped), scale_tensor, zero_tensor)
# update edit direction with the edit dir for this concept
guidance_strength = 0.0 if sampling_step < warmup_period else 1.0 # FIXME: Use appropriate guidance strength
edit_dir = torch.mul(scale_tensor, edit_dir)
edit_dir_dict[key] = edit_dir_dict[key] + guidance_strength * edit_dir
# TODO: batch this
for i, sega_param in enumerate(sega_params):
for key, dir in edit_dir_dict.items():
# calculate momentum scale and velocity
if key not in sega_param.v.keys():
slice_idx = 1 - dir.dim() # should be negative, for dim=4, slice_idx = -3
sega_param.v[key] = torch.zeros(dir.shape[slice_idx:], dtype=dir.dtype, device=dir.device)
# add to text condition
v_t = sega_param.v[key]
dir[i] = dir[i] + torch.mul(momentum_scale, v_t)
# calculate v_t+1 and update state
v_t_1 = momentum_beta * ((1 - momentum_beta) * v_t) * dir[i]
# add to cond after warmup elapsed
# for sd 1.5, we must add to the original params.text_cond because we reassigned text_cond
if sampling_step >= warmup_period:
if isinstance(params.text_cond, dict):
params.text_cond[key] = params.text_cond[key] + dir[i]
else:
params.text_cond = params.text_cond + dir[i]
# update velocity
sega_param.v[key] = v_t_1
# XYZ Plot
# Based on @mcmonkey4eva's XYZ Plot implementation here: https://github.com/mcmonkeyprojects/sd-dynamic-thresholding/blob/master/scripts/dynamic_thresholding.py
def sega_apply_override(field, boolean: bool = False):
def fun(p, x, xs):
if boolean:
x = True if x.lower() == "true" else False
setattr(p, field, x)
return fun
def sega_apply_field(field):
def fun(p, x, xs):
if not hasattr(p, "sega_active"):
setattr(p, "sega_active", True)
setattr(p, field, x)
return fun
def make_axis_options():
xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ in ("xyz_grid.py", "scripts.xyz_grid")][0].module
extra_axis_options = {
xyz_grid.AxisOption("[Semantic Guidance] Active", str, sega_apply_override('sega_active', boolean=True), choices=xyz_grid.boolean_choice(reverse=True)),
xyz_grid.AxisOption("[Semantic Guidance] Prompt", str, sega_apply_field("sega_prompt")),
xyz_grid.AxisOption("[Semantic Guidance] Negative Prompt", str, sega_apply_field("sega_neg_prompt")),
xyz_grid.AxisOption("[Semantic Guidance] Warmup Steps", int, sega_apply_field("sega_warmup")),
xyz_grid.AxisOption("[Semantic Guidance] Guidance Scale", float, sega_apply_field("sega_edit_guidance_scale")),
xyz_grid.AxisOption("[Semantic Guidance] Tail Percentage Threshold", float, sega_apply_field("sega_tail_percentage_threshold")),
xyz_grid.AxisOption("[Semantic Guidance] Momentum Scale", float, sega_apply_field("sega_momentum_scale")),
xyz_grid.AxisOption("[Semantic Guidance] Momentum Beta", float, sega_apply_field("sega_momentum_beta")),
}
if not any("[Semantic Guidance]" in x.label for x in xyz_grid.axis_options):
xyz_grid.axis_options.extend(extra_axis_options)
def callback_before_ui():
try:
make_axis_options()
except:
logger.exception("Semantic Guidance: Error while making axis options")
script_callbacks.on_before_ui(callback_before_ui)