|
|
import gradio as gr
|
|
|
import torch
|
|
|
import re
|
|
|
import json
|
|
|
import ldm.modules.attention as atm
|
|
|
import modules.ui
|
|
|
import modules
|
|
|
from modules import prompt_parser
|
|
|
|
|
|
from modules import shared
|
|
|
from modules.script_callbacks import CFGDenoiserParams, on_cfg_denoiser, on_ui_settings
|
|
|
|
|
|
# Verbose logging of attention/conditioning shapes and step selection.
debug = False

# Secondary debug switch. NOTE(review): debug_p appears unused in this file.
debug_p = False
|
|
|
|
|
|
# Detect the Forge fork of the webui: it ships the `ldm_patched` package.
# Forge swaps the cond/uncond chunk order, which `Script.process_batch`
# compensates for via `self.rev`.
try:
    from ldm_patched.modules import model_management  # presence check only
    forge = True
except ImportError:
    # Stock webui (or any environment without Forge's patched ldm).
    # Narrowed from a bare `except:` so real errors (e.g. KeyboardInterrupt)
    # are not silently swallowed.
    forge = False
|
|
|
|
|
|
# shared.opts keys controlling global activation / UI visibility.
OPT_ACT = "negpip_active"
OPT_HIDE = "negpip_hide"

# ui-config.json keys holding each tab's saved "Active" checkbox default.
NEGPIP_T = "customscript/negpip.py/txt2img/Active/value"
NEGPIP_I = "customscript/negpip.py/img2img/Active/value"
# Path to the webui's ui-config.json.
CONFIG = shared.cmd_opts.ui_config_file
|
|
|
|
|
|
# Read the saved UI defaults once at import time.
with open(CONFIG, 'r', encoding="utf-8") as json_file:
    ui_config = json.load(json_file)

# Saved "Active" default per tab (None when the key is not yet in ui-config).
startup_t = ui_config[NEGPIP_T] if NEGPIP_T in ui_config else None
startup_i = ui_config[NEGPIP_I] if NEGPIP_I in ui_config else None
# Labels shown in the accordion title for each tab.
active_t = "Active" if startup_t else "Not Active"
active_i = "Active" if startup_i else "Not Active"
|
|
|
|
|
|
# Global option values, with fallbacks for when the settings section was
# never registered (first run before on_ui_settings fires).
opt_active = getattr(shared.opts,OPT_ACT, True)
opt_hideui = getattr(shared.opts,OPT_HIDE, False)

# Matches a parenthesized prompt term with a negative weight, e.g. "(text:-1.2)".
# Group 1 is the text (no parens/colons), the weight is an integer or decimal.
minusgetter = r'\(([^(:)]*):\s*-[\d]+(\.[\d]+)?(?:\s*)\)'
|
|
|
|
|
|
class Script(modules.scripts.Script):
|
|
|
def __init__(self):
|
|
|
self.active = False
|
|
|
self.conds = None
|
|
|
self.unconds = None
|
|
|
self.conlen = []
|
|
|
self.unlen = []
|
|
|
self.contokens = []
|
|
|
self.untokens = []
|
|
|
self.hr = False
|
|
|
self.x = None
|
|
|
|
|
|
self.ipa = None
|
|
|
|
|
|
self.enable_rp_latent = False
|
|
|
|
|
|
def title(self):
|
|
|
return "NegPiP"
|
|
|
|
|
|
def show(self, is_img2img):
|
|
|
return modules.scripts.AlwaysVisible
|
|
|
|
|
|
    # Filled in ui(); maps Gradio components to infotext keys for
    # "paste parameters" support.
    infotext_fields = None
    # NOTE(review): class-level mutable list — ui() appends via
    # self.paste_field_names, which mutates this shared list; this matches
    # the webui scripts convention but accumulates across instances.
    paste_field_names = []
|
|
|
|
|
|
    def ui(self, is_img2img):
        """Build the NegPiP accordion for the txt2img/img2img tab.

        Returns the list of Gradio components whose values are passed to
        process_batch (just the "Active" checkbox).
        """
        # Accordion title reflects the saved startup default for this tab.
        with gr.Accordion(f"NegPiP : {active_i if is_img2img else active_t}",open = False, visible = not opt_hideui) as acc:
            with gr.Row():
                active = gr.Checkbox(value=False, label="Active",interactive=True)
                toggle = gr.Button(elem_id="switch_default", value=f"Toggle startup with Active(Now:{startup_i if is_img2img else startup_t})",variant="primary")

            def f_toggle(is_img2img):
                # Flip the saved "Active" default in ui-config.json for this tab
                # and refresh the button label.
                key = NEGPIP_I if is_img2img else NEGPIP_T

                with open(CONFIG, 'r', encoding="utf-8") as json_file:
                    data = json.load(json_file)
                    data[key] = not data[key]

                with open(CONFIG, 'w', encoding="utf-8") as json_file:
                    json.dump(data, json_file, indent=4)

                return gr.update(value = f"Toggle startup Active(Now:{data[key]})")

            # Hidden checkbox carries the tab identity into the click handler.
            toggle.click(fn=f_toggle,inputs=[gr.Checkbox(value = is_img2img, visible = False)],outputs=[toggle])
            # Live-update the accordion label when the checkbox changes.
            active.change(fn=lambda x:gr.update(label = f"NegPiP : {'Active' if x else 'Not Active'}"),inputs=active, outputs=[acc])

        # Infotext round-trip: lets "NegPiP Active" be restored from PNG info.
        self.infotext_fields = [
            (active, "NegPiP Active"),
        ]

        # NOTE(review): appends to the class-level paste_field_names list.
        for _,name in self.infotext_fields:
            self.paste_field_names.append(name)

        return [active]
|
|
|
|
|
|
    def process_batch(self, p, active,**kwargs):
        """Pre-compute NegPiP conditionings before the batch is generated.

        Scans the (hires-aware) prompt schedules for negative-weight terms
        like "(text:-1.2)", strips them from the normal prompts, encodes them
        with the sign flipped, and stores the resulting extra cross-attention
        conditionings. Deactivates itself when no such term exists.
        """
        # Reset all per-generation state from any previous run.
        self.__init__()
        # Becomes True once at least one negative-weight target is found.
        flag = False

        # Honor the global settings override, then the per-tab checkbox.
        if getattr(shared.opts,OPT_HIDE, False) and not getattr(shared.opts,OPT_ACT, False): return
        elif not active: return

        self.rpscript = None

        # Detect the Regional Prompter extension (its hooks interact with ours).
        # NOTE(review): always scans txt2img scripts, even for img2img — confirm intended.
        from modules.scripts import scripts_txt2img
        for script in scripts_txt2img.alwayson_scripts:
            if "rp.py" in script.filename:
                self.rpscript = script

        # Whether hires-fix supplies its own positive/negative prompts.
        self.hrp, self.hrn = hr_dealer(p)

        self.active = active
        self.batch = p.batch_size
        # SDXL models expose a "conditioner" attribute; SD1/2 do not.
        self.isxl = hasattr(shared.sd_model,"conditioner")

        # These samplers pass chunks in (uncond, cond) order; Forge flips the
        # convention again (see the forge probe at module import).
        self.rev = p.sampler_name in ["DDIM", "PLMS", "UniPC"]
        if forge: self.rev = not self.rev

        tokenizer = shared.sd_model.conditioner.embedders[0].tokenize_line if self.isxl else shared.sd_model.cond_stage_model.tokenize_line

        def getshedulednegs(scheduled,prompts):
            # Extract negative-weight targets from each scheduled prompt and
            # strip them from `prompts` in place. Result layout:
            # [batch][step] -> [step, [[pad_offset, [[text, weight], ...]], ...]].
            output = []
            nonlocal flag
            for i, batch_shedule in enumerate(scheduled):
                stepout = []
                seps = None
                if self.rpscript:
                    if hasattr(self.rpscript,"seps"):
                        seps = self.rpscript.seps
                    self.enable_rp_latent = seps == "AND"

                for step,prompt in batch_shedule:
                    # With Regional Prompter, process each region separately.
                    sep_prompts = prompt.split(seps) if seps else [prompt]
                    padd = 0
                    padtextweight = []
                    for sep_prompt in sep_prompts:
                        minusmatches = re.finditer(minusgetter, sep_prompt)
                        minus_targets = []
                        textweights = []
                        for minusmatch in minusmatches:
                            minus_targets.append(minusmatch.group().replace("(","").replace(")",""))
                            # Remove the "(text:-w)" term from the live prompt.
                            prompts[i] = prompts[i].replace(minusmatch.group(),"")
                        minus_targets = [x.split(":") for x in minus_targets]
                        for text,weight in minus_targets:
                            weight = float(weight)
                            if text == "BREAK": continue
                            if weight < 0:
                                textweights.append([text,weight])
                                flag = True
                        padtextweight.append([padd,textweights])
                        # Each 75-token chunk of this region occupies one slot.
                        tokens, tokensnum = tokenizer(sep_prompt)
                        padd = tokensnum // 75 + 1 + padd
                    stepout.append([step,padtextweight])
                output.append(stepout)
            return output

        # Step schedules for the base pass...
        scheduled_p = prompt_parser.get_learned_conditioning_prompt_schedules(p.prompts,p.steps)
        scheduled_np = prompt_parser.get_learned_conditioning_prompt_schedules(p.negative_prompts,p.steps)

        # ...and for the hires pass, when it has its own prompts.
        if self.hrp: scheduled_hr_p = prompt_parser.get_learned_conditioning_prompt_schedules(p.hr_prompts,p.hr_second_pass_steps if p.hr_second_pass_steps > 0 else p.steps)
        if self.hrn: scheduled_hr_np = prompt_parser.get_learned_conditioning_prompt_schedules(p.hr_negative_prompts,p.hr_second_pass_steps if p.hr_second_pass_steps > 0 else p.steps)

        # nip: negatives embedded in the positive prompt; pin: the reverse.
        nip = getshedulednegs(scheduled_p,p.prompts)
        pin = getshedulednegs(scheduled_np,p.negative_prompts)

        if self.hrp: hr_nip = getshedulednegs(scheduled_hr_p,p.hr_prompts)
        if self.hrn: hr_pin = getshedulednegs(scheduled_hr_np,p.hr_negative_prompts)

        def conddealer(targets):
            # Encode each (text, weight) target with the sign flipped and pack
            # the embeddings into padded 77-token windows.
            # Returns (tensor repeated per batch item, tokens-per-chunk list).
            conds =[]
            start = None
            end = None
            for target in targets:
                input = SdConditioning([f"({target[0]}:{-target[1]})"], width=p.width, height=p.height)
                cond = prompt_parser.get_learned_conditioning(shared.sd_model,input,p.steps)
                # BOS/EOS embeddings taken from the first encoded result.
                if start is None: start = cond[0][0].cond[0:1,:] if not self.isxl else cond[0][0].cond["crossattn"][0:1,:]
                if end is None: end = cond[0][0].cond[-1:,:] if not self.isxl else cond[0][0].cond["crossattn"][-1:,:]
                token, tokenlen = tokenizer(target[0])
                conds.append(cond[0][0].cond[1:tokenlen +2,:] if not self.isxl else cond[0][0].cond["crossattn"][1:tokenlen +2,:] )
            conds = torch.cat(conds, 0)

            # Re-chunk into 75-token pieces, then pad each back to 77 with
            # BOS at the front and repeated EOS at the tail.
            conds = torch.split(conds, 75, dim=0)
            condsout = []
            condcount = []
            for cond in conds:
                condcount.append(cond.shape[0])
                repeat = 0 if cond.shape[0] == 75 else 75 - cond.shape[0]
                cond = torch.cat((start,cond,end.repeat(repeat + 1,1)),0)
                condsout.append(cond)
            condout = torch.cat(condsout,0).unsqueeze(0)
            return condout.repeat(self.batch,1,1), condcount

        def calcconds(targetlist):
            # Resolve the nested [batch][step][region] target lists into
            # tensors; regions without targets get (None, None).
            outconds = []
            for batch in targetlist:
                stepconds = []
                for step, regions in batch:
                    regionconds = []
                    for region, targets in regions:
                        if targets:
                            conds, contokens = conddealer(targets)
                            regionconds.append([region, conds, contokens])
                        else:
                            regionconds.append([region, None, None])
                    stepconds.append([step,regionconds])
                outconds.append(stepconds)
            return outconds

        self.conds_all = calcconds(nip)
        self.unconds_all = calcconds(pin)

        if self.hrp: self.hr_conds_all = calcconds(hr_nip)
        if self.hrn: self.hr_unconds_all = calcconds(hr_pin)

        # Prompts were edited in place above, so the webui's cached
        # conditionings are stale — clear them.
        resetpcache(p)

        def calcsets(A, B):
            # Ceiling division: number of B-sized chunks needed for A tokens.
            return A // B if A % B == 0 else A // B + 1

        # Chunk counts of the (stripped) first prompt / negative prompt,
        # used by the forward hook to tell cond from uncond calls.
        self.conlen = calcsets(tokenizer(p.prompts[0])[1],75)
        self.unlen = calcsets(tokenizer(p.negative_prompts[0])[1],75)

        # No negative-weight token anywhere: disable and unhook.
        if not flag:
            self.active = False
            unload(self,p)
            return

        # Register the per-step callback once per Script instance.
        if not hasattr(self,"negpip_dr_callbacks"):
            self.negpip_dr_callbacks = on_cfg_denoiser(self.denoiser_callback)

        # Skip hooking when Regional Prompter already wrapped the forwards.
        already_hooked = False
        if self.rpscript is not None and hasattr(self.rpscript,"hooked"):already_hooked = self.rpscript.hooked

        if not already_hooked:
            self.handle = hook_forwards(self, p.sd_model.model.diffusion_model)

        print(f"NegPiP enable, Positive:{self.conds_all[0][0][1][0][2]},Negative:{self.unconds_all[0][0][1][0][2]}")

        # Record activation in the generation parameters / infotext.
        p.extra_generation_params.update({
            "NegPiP Active":active,
            })
|
|
|
|
|
|
def postprocess(self, p, processed, *args):
|
|
|
unload(self,p)
|
|
|
self.conds_all = None
|
|
|
self.unconds_all = None
|
|
|
|
|
|
    def denoiser_callback(self, params: CFGDenoiserParams):
        # Runs before every denoising step: selects the pre-computed extra
        # conds/unconds whose scheduled step covers the current step, switching
        # to the hires set once the latent size changes.
        if debug: print(params.text_cond.shape)
        if self.active:
            # A change in latent shape marks the start of the hires pass.
            if self.x is None: self.x = params.x.shape
            if self.x != params.x.shape: self.hr = True

            # Reset the Regional Prompter latent-region cursor for this step.
            self.latenti = 0

            condslist = []
            tokenslist = []
            conds = self.hr_conds_all if self.hrp and self.hr else self.conds_all
            if conds is not None:
                # Pick the first schedule entry still ahead of the current step.
                # NOTE(review): "+ 2" offset presumably aligns 0-based sampling
                # steps with the 1-based schedule — confirm.
                for step, regions in conds[0]:
                    if step >= params.sampling_step + 2:
                        # NOTE(review): loop variable `conds` shadows the outer
                        # `conds`; harmless because the outer value is done with.
                        for region, conds, tokens in regions:
                            condslist.append(conds)
                            tokenslist.append(tokens)
                        if debug: print(f"current:{params.sampling_step + 2},selected:{step}")
                        break
            self.conds = condslist
            self.contokens = tokenslist

            # Same selection for the negative-prompt side.
            uncondslist = []
            untokenslist = []
            unconds = self.hr_unconds_all if self.hrn and self.hr else self.unconds_all
            if unconds is not None:
                for step, regions in unconds[0]:
                    if step >= params.sampling_step + 2:
                        for region, unconds, untokens in regions:
                            uncondslist.append(unconds)
                            untokenslist.append(untokens)
                        break

            self.unconds = uncondslist
            self.untokens = untokenslist
|
|
|
|
|
|
from pprint import pprint
|
|
|
|
|
|
def unload(self, p):
    """Detach NegPiP's forward hooks from the model, if they were installed."""
    if not hasattr(self, "handle"):
        return
    hook_forwards(self, p.sd_model.model.diffusion_model, remove=True)
    del self.handle
|
|
|
|
|
|
def hook_forward(self, module):
    # Builds a replacement for a CrossAttention module's forward that appends
    # the NegPiP conditionings to the context before computing attention.
    def forward(x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0, value = None, transformer_options=None):
        if debug: print(" x.shape:",x.shape,"context.shape:",context.shape,"self.contokens",self.contokens,"self.untokens",self.untokens)

        def sub_forward(x, context, mask, additional_tokens, n_times_crossframe_attn_in_self,conds,contokens,unconds,untokens, latent = None):
            if debug: print(" x.shape[0]:",x.shape[0],"batch:",self.batch *2)

            # Case 1: cond and uncond arrive batched together (batch * 2).
            if x.shape[0] == self.batch *2:
                if debug: print(" x.shape[0] == self.batch *2")

                # self.rev marks samplers whose chunk order is (uncond, cond).
                if self.rev:
                    contn,contp = context.chunk(2)
                    ixn,ixp = x.chunk(2)
                else:
                    contp,contn = context.chunk(2)
                    ixp,ixn = x.chunk(2)

                # Append the NegPiP chunks to the matching context half,
                # broadcasting over the batch dimension when needed.
                if conds is not None:
                    if contp.shape[0] != conds.shape[0]:
                        conds = conds.expand(contp.shape[0],-1,-1)
                    contp = torch.cat((contp,conds),1)
                if unconds is not None:
                    if contn.shape[0] != unconds.shape[0]:
                        unconds = unconds.expand(contn.shape[0],-1,-1)
                    contn = torch.cat((contn,unconds),1)
                xp = main_foward(self, module, ixp,contp,mask,additional_tokens,n_times_crossframe_attn_in_self,contokens)
                xn = main_foward(self, module, ixn,contn,mask,additional_tokens,n_times_crossframe_attn_in_self,untokens)

                # Re-assemble in the sampler's expected chunk order.
                out = torch.cat([xn,xp]) if self.rev else torch.cat([xp,xn])
                return out

            # Case 2: Regional Prompter latent mode — caller tells us whether
            # this invocation is the cond (latent=True) or uncond side.
            elif latent is not None:
                if debug:print(" latent is not None")
                if latent:
                    conds = conds if conds is not None else None
                else:
                    conds = unconds if unconds is not None else None
                if conds is not None:
                    if context.shape[0] != conds.shape[0]:
                        conds = conds.expand(context.shape[0],-1,-1)
                    context = torch.cat([context,conds],1)

                tokens = contokens if contokens is not None else untokens

                out = main_foward(self, module, x,context,mask,additional_tokens,n_times_crossframe_attn_in_self,tokens)
                return out

            # Case 3: separate cond/uncond calls — tell them apart by the
            # context length (chunks * 77) plus the global call counter.
            else:
                if debug:
                    print(" Else")
                    print(context.shape[1] , self.conlen,self.unlen)

                tokens = []
                concon = counter(self.isxl)
                if debug: print(concon)
                if context.shape[1] == self.conlen * 77 and concon:
                    if conds is not None:
                        if context.shape[0] != conds.shape[0]:
                            conds = conds.expand(context.shape[0],-1,-1)
                        context = torch.cat([context,conds],1)
                    tokens = contokens
                elif context.shape[1] == self.unlen * 77 and concon:
                    if unconds is not None:
                        if context.shape[0] != unconds.shape[0]:
                            unconds = unconds.expand(context.shape[0],-1,-1)
                        context = torch.cat([context,unconds],1)
                    tokens = untokens
                out = main_foward(self, module, x,context,mask,additional_tokens,n_times_crossframe_attn_in_self,tokens)
                return out

        if self.enable_rp_latent:
            # Regional Prompter "AND" mode: each region gets its own conds,
            # cycled through with self.latenti.
            if len(self.conds) - 1 >= self.latenti:
                out = sub_forward(x, context, mask, additional_tokens, n_times_crossframe_attn_in_self,self.conds[self.latenti],self.contokens[self.latenti],None,None ,latent = True)
                self.latenti += 1
            else:
                out = sub_forward(x, context, mask, additional_tokens, n_times_crossframe_attn_in_self,None,None,self.unconds[0],self.untokens[0], latent = False)
                self.latenti = 0
            return out
        else:
            if self.conds is not None and self.unconds is not None and len(self.conds) > 0 and len(self.unconds) > 0:
                return sub_forward(x, context, mask, additional_tokens, n_times_crossframe_attn_in_self,self.conds[0],self.contokens[0],self.unconds[0],self.untokens[0])
            else:
                return sub_forward(x, context, mask, additional_tokens, n_times_crossframe_attn_in_self,None,None,None,None)

    return forward
|
|
|
|
|
|
# Global call counter used by hook_forward's "case 3" to tell cond calls
# from uncond calls when they arrive as separate forward invocations.
count = 0
pn = True


def counter(isxl):
    """Return the current cond/uncond phase, flipping it every N calls.

    N is 70 for SDXL models and 16 otherwise (the number of hooked
    cross-attention calls per half of a sampling step).
    """
    global count, pn
    count += 1

    flip_at = 70 if isxl else 16
    current = pn

    # Capture the phase first, then flip and restart the count on the
    # boundary call itself.
    if count == flip_at:
        pn = not pn
        count = 0
    return current
|
|
|
|
|
|
def main_foward(self, module, x, context, mask, additional_tokens, n_times_crossframe_attn_in_self, tokens):
    # Manual cross-attention computation mirroring ldm's CrossAttention.forward,
    # except that the value vectors of NegPiP's appended token chunks are
    # negated — so those tokens push the output away from their embeddings.
    h = module.heads
    context = context.to(x.dtype)
    q = module.to_q(x)

    # Self-attention fallback: use x when no context was supplied.
    context = atm.default(context, x)
    k = module.to_k(context)
    v = module.to_v(context)
    if debug: print(h,context.shape,q.shape,k.shape,v.shape)

    _, _, dim_head = q.shape
    dim_head //= h
    scale = dim_head ** -0.5

    # Split heads into the batch dimension, then scaled dot-product scores.
    q, k, v = map(lambda t: atm.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
    sim = atm.einsum('b i d, b j d -> b i j', q, k) * scale

    if self.active:
        if tokens:
            for token in tokens:
                # The NegPiP chunks occupy the last len(tokens) 77-token
                # windows of the context; skip the BOS slot (+1).
                # NOTE(review): `start` does not advance per token, so with
                # multiple chunks the same window is negated repeatedly —
                # confirm intended for targets spanning >1 chunk.
                start = (v.shape[1]//77 - len(tokens)) * 77

                v[:,start+1:start+token,:] = -v[:,start+1:start+token,:]

    if atm.exists(mask):
        mask = atm.rearrange(mask, 'b ... -> b (...)')
        max_neg_value = -torch.finfo(sim.dtype).max
        mask = atm.repeat(mask, 'b j -> (b h) () j', h=h)
        sim.masked_fill_(~mask, max_neg_value)

    attn = sim.softmax(dim=-1)

    out = atm.einsum('b i j, b j d -> b i d', attn, v)

    # Merge heads back and project out.
    out = atm.rearrange(out, '(b h) n d -> b n (h d)', h=h)

    return module.to_out(out)
|
|
|
|
|
|
import inspect
|
|
|
|
|
|
def hook_forwards(self, root_module: torch.nn.Module, remove=False):
    """Install (or remove) NegPiP's replacement forward on every
    cross-attention module under root_module.

    Targets modules whose qualified name contains "attn2" and whose class is
    named "CrossAttention". Installation assigns an instance-level `forward`
    that shadows the class method; removal deletes that instance attribute so
    the original class forward becomes visible again.

    Args:
        root_module: the diffusion model to walk.
        remove: when True, undo a previous hook instead of installing one.
    """
    for name, module in root_module.named_modules():
        if "attn2" in name and module.__class__.__name__ == "CrossAttention":
            if remove:
                # Original code assigned a fresh hook and immediately deleted
                # it; guard the deletion instead so removal is a no-op when no
                # hook was installed and no throwaway closure is built.
                if "forward" in vars(module):
                    del module.forward
            else:
                module.forward = hook_forward(self, module)
|
|
|
|
|
|
def resetpcache(p):
    """Invalidate all of the processing object's cached conditionings.

    NegPiP edits the prompts in place after the webui may have cached
    conditionings for them, so every cache slot is reset to [None, None].
    """
    for slot in ("cached_c", "cached_uc", "cached_hr_c", "cached_hr_uc"):
        setattr(p, slot, [None, None])
|
|
|
|
|
|
|
|
|
class SdConditioning(list):
    """A list of prompts carrying the metadata prompt_parser expects.

    Mirrors the webui's SdConditioning: a plain list of prompt strings with
    `is_negative_prompt`, `width` and `height` attached, each falling back to
    the corresponding attribute of `copy_from` (or of `prompts` itself).
    """

    def __init__(self, prompts, is_negative_prompt=False, width=None, height=None, copy_from=None):
        super().__init__()
        self.extend(prompts)

        # Fall back to the source object's attributes when not given explicitly.
        source = prompts if copy_from is None else copy_from
        self.is_negative_prompt = is_negative_prompt or getattr(source, 'is_negative_prompt', False)
        self.width = width or getattr(source, 'width', None)
        self.height = height or getattr(source, 'height', None)
|
|
|
|
|
|
def ext_on_ui_settings():
    """Register NegPiP's two options in the webui settings page."""
    section = ('negpip', "NegPiP")

    # (key, default, label) triples, unpacked into shared.OptionInfo below.
    negpip_options = [
        (OPT_HIDE, False, "Hide in Txt2Img/Img2Img tab(Reload UI required)"),
        (OPT_ACT, True, "Active(Effective when Hide is Checked)",),
    ]

    for cur_setting_name, *option_info in negpip_options:
        shared.opts.add_option(cur_setting_name, shared.OptionInfo(*option_info, section=section))
|
|
|
|
|
|
# Register the NegPiP settings section when the webui builds its settings UI.
on_ui_settings(ext_on_ui_settings)
|
|
|
|
|
|
def hr_dealer(p):
    """Ensure the hires prompt attributes exist on p and report their presence.

    Missing attributes are created as None so later code can read them
    unconditionally.

    Returns:
        (has_hr_prompts, has_hr_negative_prompts) as booleans.
    """
    for attr in ("hr_prompts", "hr_negative_prompts"):
        if not hasattr(p, attr):
            setattr(p, attr, None)

    return bool(p.hr_prompts), bool(p.hr_negative_prompts)
|
|
|
|