import logging
from os import environ
import modules.scripts as scripts
import gradio as gr
import scipy.stats as stats
from modules import script_callbacks, prompt_parser
from modules.script_callbacks import CFGDenoiserParams
from modules.prompt_parser import reconstruct_multicond_batch
from modules.processing import StableDiffusionProcessing
#from modules.shared import sd_model, opts
from modules.sd_samplers_cfg_denoiser import pad_cond
from modules import shared
import torch
logger = logging.getLogger(__name__)
logger.setLevel(environ.get("SD_WEBUI_LOG_LEVEL", logging.INFO))
"""
An unofficial implementation of SEGA: Instructing Text-to-Image Models using Semantic Guidance for Automatic1111 WebUI
@misc{brack2023sega,
title={SEGA: Instructing Text-to-Image Models using Semantic Guidance},
author={Manuel Brack and Felix Friedrich and Dominik Hintersdorf and Lukas Struppek and Patrick Schramowski and Kristian Kersting},
year={2023},
eprint={2301.12247},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
Author: v0xie
GitHub URL: https://github.com/v0xie/sd-webui-semantic-guidance
"""
class SegaStateParams:
    """Per-concept mutable state used by the SEGA denoiser hook.

    One instance is created per concept; the velocity dict accumulates
    momentum across sampling steps.
    """
    def __init__(self):
        # Human-readable concept label (currently informational only).
        self.concept_name: str = ''
        # Momentum velocity tensors, keyed by cond name (e.g. 'crossattn'); filled lazily.
        self.v: dict = {}
        # Steps to wait before guidance is added to the conditioning. Range [0, 20].
        self.warmup_period: int = 10
        # Multiplier applied to the surviving (tail) edit direction. UI default 1.
        self.edit_guidance_scale: float = 1
        # Fraction of the distribution treated as "tail": values of the edit
        # direction whose magnitude falls below the (1 - threshold) quantile
        # of the concept cond are zeroed out. Range [0., 1.].
        self.tail_percentage_threshold: float = 0.05
        # Scale applied to the velocity term when it is added to the direction. Range [0., 1.].
        self.momentum_scale: float = 0.3
        # Momentum beta; larger values mean less volatile velocity changes. Range [0., 1.).
        self.momentum_beta: float = 0.6
        # Signed per-concept strength; negative pushes away from the concept.
        self.strength = 1.0
class SegaExtensionScript(scripts.Script):
    """A1111 script implementing SEGA (Semantic Guidance, arXiv:2301.12247).

    Hooks the CFG denoiser and nudges the text conditioning toward (positive
    prompt) or away from (negative prompt) user-supplied concepts during
    sampling.
    """

    def __init__(self):
        # Cache slots consumed by p.get_conds_with_caching for concept conds.
        self.cached_c = [None, None]

    def title(self):
        """Extension title shown in the UI."""
        return "Semantic Guidance"

    def show(self, is_img2img):
        """Show the accordion in both txt2img and img2img."""
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the accordion UI and register infotext / paste fields.

        Returns the component list in the order process_batch expects.
        """
        with gr.Accordion('Semantic Guidance', open=False):
            active = gr.Checkbox(value=False, default=False, label="Active", elem_id='sega_active')
            with gr.Row():
                prompt = gr.Textbox(lines=2, label="Prompt", elem_id='sega_prompt', elem_classes=["prompt"])
            with gr.Row():
                neg_prompt = gr.Textbox(lines=2, label="Negative Prompt", elem_id='sega_neg_prompt', elem_classes=["prompt"])
            with gr.Row():
                warmup = gr.Slider(value=10, minimum=0, maximum=30, step=1, label="Warmup Period", elem_id='sega_warmup', info="How many steps to wait before applying semantic guidance, default 10")
                edit_guidance_scale = gr.Slider(value=1.0, minimum=0.0, maximum=20.0, step=0.01, label="Edit Guidance Scale", elem_id='sega_edit_guidance_scale', info="Scale of edit guidance, default 1.0")
                tail_percentage_threshold = gr.Slider(value=0.05, minimum=0.0, maximum=1.0, step=0.01, label="Tail Percentage Threshold", elem_id='sega_tail_percentage_threshold', info="The percentage of latents to modify, default 0.05")
                momentum_scale = gr.Slider(value=0.3, minimum=0.0, maximum=1.0, step=0.01, label="Momentum Scale", elem_id='sega_momentum_scale', info="Scale of momentum, default 0.3")
                momentum_beta = gr.Slider(value=0.6, minimum=0.0, maximum=0.999, step=0.01, label="Momentum Beta", elem_id='sega_momentum_beta', info="Beta for momentum, default 0.6")
        components = [active, prompt, neg_prompt, warmup, edit_guidance_scale,
                      tail_percentage_threshold, momentum_scale, momentum_beta]
        # These values travel through infotext, not the saved UI config.
        for component in components:
            component.do_not_save_to_config = True
        self.infotext_fields = [
            (active, lambda d: gr.Checkbox.update(value='SEGA Active' in d)),
            (prompt, 'SEGA Prompt'),
            (neg_prompt, 'SEGA Negative Prompt'),
            (warmup, 'SEGA Warmup Period'),
            (edit_guidance_scale, 'SEGA Edit Guidance Scale'),
            (tail_percentage_threshold, 'SEGA Tail Percentage Threshold'),
            (momentum_scale, 'SEGA Momentum Scale'),
            (momentum_beta, 'SEGA Momentum Beta'),
        ]
        self.paste_field_names = [
            'sega_active',
            'sega_prompt',
            'sega_neg_prompt',
            'sega_warmup',
            'sega_edit_guidance_scale',
            'sega_tail_percentage_threshold',
            'sega_momentum_scale',
            'sega_momentum_beta',
        ]
        return components

    def process_batch(self, p: StableDiffusionProcessing, active, prompt, neg_prompt, warmup, edit_guidance_scale, tail_percentage_threshold, momentum_scale, momentum_beta, *args, **kwargs):
        """Resolve effective settings (XYZ-grid overrides on `p` win over UI
        values), record infotext, build per-concept conditionings and hook
        the CFG denoiser callback.
        """
        active = getattr(p, "sega_active", active)
        if active is False:
            return
        prompt = getattr(p, "sega_prompt", prompt)
        neg_prompt = getattr(p, "sega_neg_prompt", neg_prompt)
        warmup = getattr(p, "sega_warmup", warmup)
        edit_guidance_scale = getattr(p, "sega_edit_guidance_scale", edit_guidance_scale)
        tail_percentage_threshold = getattr(p, "sega_tail_percentage_threshold", tail_percentage_threshold)
        momentum_scale = getattr(p, "sega_momentum_scale", momentum_scale)
        momentum_beta = getattr(p, "sega_momentum_beta", momentum_beta)
        p.extra_generation_params.update({
            "SEGA Active": active,
            "SEGA Prompt": prompt,
            "SEGA Negative Prompt": neg_prompt,
            "SEGA Warmup Period": warmup,
            "SEGA Edit Guidance Scale": edit_guidance_scale,
            "SEGA Tail Percentage Threshold": tail_percentage_threshold,
            "SEGA Momentum Scale": momentum_scale,
            "SEGA Momentum Beta": momentum_beta,
        })
        # Concepts are comma-separated; negative-prompt concepts get a
        # negated strength so they push away from the concept.
        concept_prompts = self.parse_concept_prompt(prompt)
        concept_prompts_neg = self.parse_concept_prompt(neg_prompt)
        # parse_prompt_attention returns [[text, weight], ...]; the first
        # chunk carries the whole concept and its attention weight.
        concept_prompts = [prompt_parser.parse_prompt_attention(concept)[0] for concept in concept_prompts]
        concept_prompts_neg = [prompt_parser.parse_prompt_attention(neg_concept)[0] for neg_concept in concept_prompts_neg]
        concept_prompts.extend([concept, -strength] for concept, strength in concept_prompts_neg)
        concept_conds = []
        for concept, strength in concept_prompts:
            prompt_list = [concept] * p.batch_size
            prompts = prompt_parser.SdConditioning(prompt_list, width=p.width, height=p.height)
            c = p.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps, [self.cached_c], p.extra_network_data)
            concept_conds.append([c, strength])
        self.create_hook(p, active, concept_conds, None, warmup, edit_guidance_scale, tail_percentage_threshold, momentum_scale, momentum_beta)

    def parse_concept_prompt(self, prompt: str) -> list[str]:
        """Split a comma-separated prompt into stripped, non-empty concepts.

        Accepts None / empty input (returns []), and drops blank segments so
        "a,,b" cannot produce an empty concept conditioning.
        >>> g = lambda prompt: self.parse_concept_prompt(prompt)
        >>> g("")
        []
        >>> g("apples")
        ['apples']
        >>> g("apple, banana, carrot")
        ['apple', 'banana', 'carrot']
        >>> g("apple,, carrot")
        ['apple', 'carrot']
        """
        if not prompt:
            return []
        return [x.strip() for x in prompt.split(",") if x.strip()]

    def create_hook(self, p, active, concept_conds, concept_conds_neg, warmup, edit_guidance_scale, tail_percentage_threshold, momentum_scale, momentum_beta, *args, **kwargs):
        """Create per-concept state and register the denoiser callback.

        `concept_conds_neg` is unused (negatives are folded into
        `concept_conds` with negated strength); kept for signature stability.
        """
        concepts_sega_params = []
        for _, strength in concept_conds:
            sega_params = SegaStateParams()
            sega_params.warmup_period = warmup
            sega_params.edit_guidance_scale = edit_guidance_scale
            sega_params.tail_percentage_threshold = tail_percentage_threshold
            sega_params.momentum_scale = momentum_scale
            sega_params.momentum_beta = momentum_beta
            sega_params.strength = strength
            concepts_sega_params.append(sega_params)
        # Close over the parameters instead of using globals.
        y = lambda params: self.on_cfg_denoiser_callback(params, concept_conds, concepts_sega_params)
        logger.debug('Hooked callbacks')
        script_callbacks.on_cfg_denoiser(y)
        script_callbacks.on_script_unloaded(self.unhook_callbacks)

    def postprocess_batch(self, p, active, *args, **kwargs):
        """Remove the denoiser hook after the batch finishes (if it was active)."""
        active = getattr(p, "sega_active", active)
        if active is False:
            return
        self.unhook_callbacks()

    def unhook_callbacks(self):
        """Unregister every callback this script registered."""
        logger.debug('Unhooked callbacks')
        script_callbacks.remove_current_script_callbacks()

    def on_cfg_denoiser_callback(self, params: CFGDenoiserParams, concept_conds, sega_params: list[SegaStateParams]):
        """Reconstruct each concept's conditioning for this step, pad to a
        common token length, stack them per cond key, and apply SEGA.
        """
        # TODO: add option to opt out of batching for performance
        sampling_step = params.sampling_step
        text_cond = params.text_cond
        text_uncond = params.text_uncond
        # Pad text_cond or text_uncond to match the longest prompt. It would
        # be preferable to let sd_samplers_cfg_denoiser.py handle padding,
        # but no callback exposes the padded conds.
        if text_cond.shape[1] != text_uncond.shape[1]:
            empty = shared.sd_model.cond_stage_model_empty_prompt
            num_repeats = (text_cond.shape[1] - text_uncond.shape[1]) // empty.shape[1]
            if num_repeats < 0:
                text_cond = pad_cond(text_cond, -num_repeats, empty)
            elif num_repeats > 0:
                text_uncond = pad_cond(text_uncond, num_repeats, empty)
        batch_conds_list = []
        batch_tensor = {}
        # SD 1.5 conds are plain tensors; normalize to the SDXL dict form.
        if isinstance(text_cond, torch.Tensor):
            text_cond = {'crossattn': text_cond}
        if isinstance(text_uncond, torch.Tensor):
            text_uncond = {'crossattn': text_uncond}
        for i, _ in enumerate(sega_params):
            concept_cond, _ = concept_conds[i]
            conds_list, tensor_dict = reconstruct_multicond_batch(concept_cond, sampling_step)
            if isinstance(tensor_dict, torch.Tensor):
                tensor_dict = {'crossattn': tensor_dict}
            for key, tensor in tensor_dict.items():
                if tensor.shape[1] != text_uncond[key].shape[1]:
                    empty = shared.sd_model.cond_stage_model_empty_prompt
                    # BUGFIX: text_uncond is always a dict here, so it must be
                    # indexed by key; the old SDXL branch did
                    # `text_uncond.shape[1]` and raised AttributeError. With
                    # the key lookup both SD 1.5 and SDXL use the same math.
                    num_repeats = (tensor.shape[1] - text_uncond[key].shape[1]) // empty.shape[1]
                    # Only the concept tensor is padded; a longer concept
                    # tensor (num_repeats > 0) is left as-is, matching the
                    # original behavior.
                    if num_repeats < 0:
                        tensor = pad_cond(tensor, -num_repeats, empty)
                # Prepend a concept dimension and stack across concepts.
                tensor = tensor.unsqueeze(0)
                if key not in batch_tensor:
                    batch_tensor[key] = tensor
                else:
                    batch_tensor[key] = torch.cat((batch_tensor[key], tensor), dim=0)
            batch_conds_list.append(conds_list)
        self.sega_routine_batch(params, batch_conds_list, batch_tensor, sega_params, text_cond, text_uncond)

    def make_tuple_dim(self, dim):
        """Return a view shape (-1, 1, 1, ...) with `dim` entries, suitable
        for broadcasting a per-concept vector against a stacked cond tensor.
        Accepts either an int rank or a tensor (SD 1.5 passes tensors).
        """
        if isinstance(dim, torch.Tensor):
            dim = dim.dim()
        return (-1,) + (1,) * (dim - 1)

    def sega_routine_batch(self, params: CFGDenoiserParams, batch_conds_list, batch_tensor, sega_params: list[SegaStateParams], text_cond, text_uncond):
        """Compute the batched SEGA edit direction and add it to the cond.

        batch_tensor[key] is assumed to be
        [num_concepts, batch_size, tokens, embed_dim] — TODO confirm for all
        cond keys. `batch_conds_list` is currently unused.
        """
        # FIXME: these parameters should be specific to each concept
        warmup_period = sega_params[0].warmup_period
        edit_guidance_scale = sega_params[0].edit_guidance_scale
        tail_percentage_threshold = sega_params[0].tail_percentage_threshold
        momentum_scale = sega_params[0].momentum_scale
        momentum_beta = sega_params[0].momentum_beta
        sampling_step = params.sampling_step
        edit_dir_dict = {}
        # --- Edit direction per cond key -------------------------------
        for key, concept_cond in batch_tensor.items():
            broadcast_shape = self.make_tuple_dim(concept_cond)
            strength = torch.Tensor([sp.strength for sp in sega_params]).to(dtype=concept_cond.dtype, device=concept_cond.device)
            strength = strength.view(broadcast_shape)
            if key not in edit_dir_dict:
                edit_dir_dict[key] = torch.zeros_like(concept_cond, dtype=concept_cond.dtype, device=concept_cond.device)
            # Mean/std per concept over all remaining dims.
            # FIXME: does this take into account image batch size?, i.e. dim 1
            inside_dim = tuple(range(-concept_cond.dim() + 1, 0))
            cond_mean, cond_std = torch.mean(concept_cond, dim=inside_dim), torch.std(concept_cond, dim=inside_dim)
            # Difference between concept-conditioned and unconditioned
            # estimates (broadcast over the concept dim); signed strength
            # selects toward/away from the concept.
            edit_dir = concept_cond - text_uncond[key]
            edit_dir = torch.mul(strength, edit_dir)
            # Keep only the upper tail: zero out |edit_dir| values below the
            # (1 - threshold) quantile threshold of the concept cond.
            upper_z = stats.norm.ppf(1.0 - tail_percentage_threshold)
            upper_threshold = cond_mean + (upper_z * cond_std)
            upper_threshold_reshaped = upper_threshold.view(broadcast_shape)
            zero_tensor = torch.zeros_like(concept_cond, dtype=concept_cond.dtype, device=concept_cond.device)
            scale_tensor = torch.ones_like(concept_cond, dtype=concept_cond.dtype, device=concept_cond.device) * edit_guidance_scale
            scale_tensor = torch.where(edit_dir.abs() > upper_threshold_reshaped, scale_tensor, zero_tensor)
            # During warmup the direction contributes nothing.
            guidance_strength = 0.0 if sampling_step < warmup_period else 1.0  # FIXME: Use appropriate guidance strength
            edit_dir = torch.mul(scale_tensor, edit_dir)
            edit_dir_dict[key] = edit_dir_dict[key] + guidance_strength * edit_dir
        # --- Momentum + application per concept ------------------------
        # TODO: batch this
        for i, sega_param in enumerate(sega_params):
            for key, direction in edit_dir_dict.items():
                if key not in sega_param.v:
                    # Velocity matches a single concept's slice shape.
                    slice_idx = 1 - direction.dim()  # negative; -3 for dim=4
                    sega_param.v[key] = torch.zeros(direction.shape[slice_idx:], dtype=direction.dtype, device=direction.device)
                v_t = sega_param.v[key]
                # Add scaled momentum to this concept's direction.
                direction[i] = direction[i] + torch.mul(momentum_scale, v_t)
                # BUGFIX: the momentum update is an exponential moving average,
                # v_{t+1} = beta * v_t + (1 - beta) * dir (SEGA paper, momentum
                # term); the previous code multiplied all three factors
                # together, which drove the velocity toward zero.
                v_t_1 = momentum_beta * v_t + (1 - momentum_beta) * direction[i]
                # Apply to the cond only after warmup. For SD 1.5 we must
                # write through params.text_cond because the local text_cond
                # was rewrapped into a dict.
                if sampling_step >= warmup_period:
                    if isinstance(params.text_cond, dict):
                        params.text_cond[key] = params.text_cond[key] + direction[i]
                    else:
                        params.text_cond = params.text_cond + direction[i]
                sega_param.v[key] = v_t_1
# XYZ Plot
# Based on @mcmonkey4eva's XYZ Plot implementation here: https://github.com/mcmonkeyprojects/sd-dynamic-thresholding/blob/master/scripts/dynamic_thresholding.py
def sega_apply_override(field, boolean: bool = False):
    """Build an XYZ-grid apply function that writes an override attribute.

    The returned function stores the axis value `x` on the processing object
    under `field`; when `boolean` is set, the string value is interpreted as
    True only for (case-insensitive) "true".
    """
    def fun(p, x, xs):
        value = x
        if boolean:
            value = (x.lower() == "true")
        setattr(p, field, value)
    return fun
def sega_apply_field(field):
    """Build an XYZ-grid apply function for a SEGA setting.

    Besides storing the axis value under `field`, it switches the extension
    on (sega_active = True) unless the flag was already set by another axis.
    """
    def fun(p, x, xs):
        if not hasattr(p, "sega_active"):
            p.sega_active = True
        setattr(p, field, x)
    return fun
def make_axis_options():
    """Register SEGA axis options with the XYZ Grid script.

    Looks up the xyz_grid module among the loaded scripts and appends our
    options once. BUGFIX: the options were collected in a set literal, which
    made their display order nondeterministic (and relied on AxisOption being
    hashable); a list keeps the declared order.
    """
    xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ in ("xyz_grid.py", "scripts.xyz_grid")][0].module
    extra_axis_options = [
        xyz_grid.AxisOption("[Semantic Guidance] Active", str, sega_apply_override('sega_active', boolean=True), choices=xyz_grid.boolean_choice(reverse=True)),
        xyz_grid.AxisOption("[Semantic Guidance] Prompt", str, sega_apply_field("sega_prompt")),
        xyz_grid.AxisOption("[Semantic Guidance] Negative Prompt", str, sega_apply_field("sega_neg_prompt")),
        xyz_grid.AxisOption("[Semantic Guidance] Warmup Steps", int, sega_apply_field("sega_warmup")),
        xyz_grid.AxisOption("[Semantic Guidance] Guidance Scale", float, sega_apply_field("sega_edit_guidance_scale")),
        xyz_grid.AxisOption("[Semantic Guidance] Tail Percentage Threshold", float, sega_apply_field("sega_tail_percentage_threshold")),
        xyz_grid.AxisOption("[Semantic Guidance] Momentum Scale", float, sega_apply_field("sega_momentum_scale")),
        xyz_grid.AxisOption("[Semantic Guidance] Momentum Beta", float, sega_apply_field("sega_momentum_beta")),
    ]
    # Register only once — this runs on every UI (re)build.
    if not any("[Semantic Guidance]" in x.label for x in xyz_grid.axis_options):
        xyz_grid.axis_options.extend(extra_axis_options)
def callback_before_ui():
    """before-ui hook: register XYZ-grid options, never blocking UI startup.

    BUGFIX: the original bare `except:` also swallowed SystemExit and
    KeyboardInterrupt; only unexpected regular exceptions should be logged
    and suppressed here.
    """
    try:
        make_axis_options()
    except Exception:
        logger.exception("Semantic Guidance: Error while making axis options")

script_callbacks.on_before_ui(callback_before_ui)