import gradio as gr
import torch
import re
import json
import inspect
from pprint import pprint
import ldm.modules.attention as atm
import modules.ui
import modules
from modules import prompt_parser
from modules import shared
from modules.script_callbacks import CFGDenoiserParams, on_cfg_denoiser, on_ui_settings

debug = False
debug_p = False

# Forge detection: a successful import of ldm_patched is taken to mean we run under Forge.
try:
    from ldm_patched.modules import model_management
    forge = True
except Exception:  # best-effort probe; any failure simply means "not Forge"
    forge = False

OPT_ACT = "negpip_active"
OPT_HIDE = "negpip_hide"

# Keys in the webui ui-config file that hold the startup value of the "Active" checkbox.
NEGPIP_T = "customscript/negpip.py/txt2img/Active/value"
NEGPIP_I = "customscript/negpip.py/img2img/Active/value"
CONFIG = shared.cmd_opts.ui_config_file


def _load_ui_config():
    """Return the parsed ui-config file, or an empty dict if it is missing or not valid JSON."""
    try:
        with open(CONFIG, 'r', encoding="utf-8") as json_file:
            return json.load(json_file)
    except (OSError, ValueError):  # ValueError covers json.JSONDecodeError
        return {}


ui_config = _load_ui_config()

startup_t = ui_config.get(NEGPIP_T)
startup_i = ui_config.get(NEGPIP_I)
active_t = "Active" if startup_t else "Not Active"
active_i = "Active" if startup_i else "Not Active"

opt_active = getattr(shared.opts, OPT_ACT, True)
opt_hideui = getattr(shared.opts, OPT_HIDE, False)

# Matches "(text:-weight)" with a negative numeric weight; ":" and parentheses are not allowed in text.
minusgetter = r'\(([^(:)]*):\s*-[\d]+(\.[\d]+)?(?:\s*)\)'


class Script(modules.scripts.Script):
    """NegPiP: support negatively weighted prompt terms such as ``(word:-1.0)``.

    Negative-weight terms are removed from the prompt text, encoded separately
    with the sign of their weight flipped, appended to the cross-attention
    context, and their value vectors are negated inside the hooked attention
    layers (see ``main_foward``).
    """

    def __init__(self):
        self.active = False
        self.conds = None
        self.unconds = None
        self.conlen = []
        self.unlen = []
        self.contokens = []
        self.untokens = []
        self.hr = False
        self.x = None
        self.ipa = None
        self.enable_rp_latent = False
        # Index of the next Regional Prompter latent region; reset by denoiser_callback.
        self.latenti = 0

    def title(self):
        return "NegPiP"

    def show(self, is_img2img):
        return modules.scripts.AlwaysVisible

    infotext_fields = None
    paste_field_names = []

    def ui(self, is_img2img):
        """Build the NegPiP accordion and return its input components ([active])."""
        startup = startup_i if is_img2img else startup_t
        with gr.Accordion(f"NegPiP : {active_i if is_img2img else active_t}", open=False, visible=not opt_hideui) as acc:
            with gr.Row():
                active = gr.Checkbox(value=False, label="Active", interactive=True)
                toggle = gr.Button(elem_id="switch_default", value=f"Toggle startup with Active(Now:{startup})", variant="primary")

        def f_toggle(is_img2img):
            """Flip the stored startup value of the Active checkbox in the ui-config file."""
            key = NEGPIP_I if is_img2img else NEGPIP_T

            with open(CONFIG, 'r', encoding="utf-8") as json_file:
                data = json.load(json_file)
            # A missing key means the value was never saved; treat it as False.
            data[key] = not data.get(key, False)

            with open(CONFIG, 'w', encoding="utf-8") as json_file:
                json.dump(data, json_file, indent=4)

            return gr.update(value=f"Toggle startup Active(Now:{data[key]})")

        toggle.click(fn=f_toggle, inputs=[gr.Checkbox(value=is_img2img, visible=False)], outputs=[toggle])
        active.change(fn=lambda x: gr.update(label=f"NegPiP : {'Active' if x else 'Not Active'}"), inputs=active, outputs=[acc])

        self.infotext_fields = [
            (active, "NegPiP Active"),
        ]
        # Per-instance list: appending to the class attribute would accumulate
        # duplicates across the txt2img and img2img Script instances.
        self.paste_field_names = [name for _, name in self.infotext_fields]

        return [active]

    def process_batch(self, p, active, **kwargs):
        """Parse negative-weight terms, precompute their conditionings and hook attention.

        Negative-weight terms are removed in place from ``p.prompts`` /
        ``p.negative_prompts`` (and the hires variants). Does nothing when NegPiP
        is inactive, and unhooks/deactivates when no such term is found.
        """
        self.__init__()
        flag = False  # set to True once any negative-weight term is found

        if getattr(shared.opts, OPT_HIDE, False) and not getattr(shared.opts, OPT_ACT, False):
            return
        elif not active:
            return

        # Get information from Regional Prompter (if installed). Prefer the script
        # runner that processes this job; fall back to the txt2img runner.
        self.rpscript = None
        from modules.scripts import scripts_txt2img
        runner = getattr(p, "scripts", None) or scripts_txt2img
        for script in getattr(runner, "alwayson_scripts", ()):
            if "rp.py" in script.filename:
                self.rpscript = script

        self.hrp, self.hrn = hr_dealer(p)

        self.active = active
        self.batch = p.batch_size
        self.isxl = hasattr(shared.sd_model, "conditioner")

        # rev=True means the combined batch is ordered (uncond, cond) rather than (cond, uncond).
        self.rev = p.sampler_name in ["DDIM", "PLMS", "UniPC"]
        if forge:
            self.rev = not self.rev

        if self.isxl:
            tokenizer = shared.sd_model.conditioner.embedders[0].tokenize_line
        else:
            tokenizer = shared.sd_model.cond_stage_model.tokenize_line

        def getshedulednegs(scheduled, prompts):
            """Extract negative-weight terms per batch item and schedule step.

            Returns ``[[ [step, [[chunk_offset, [[text, weight], ...]], ...]], ...], ...]``
            (batch -> steps -> regions). Matched terms are removed from ``prompts[i]``.
            """
            output = []
            nonlocal flag
            for i, batch_shedule in enumerate(scheduled):
                stepout = []
                seps = None
                if self.rpscript and hasattr(self.rpscript, "seps"):
                    seps = self.rpscript.seps
                    self.enable_rp_latent = seps == "AND"

                for step, prompt in batch_shedule:
                    sep_prompts = prompt.split(seps) if seps else [prompt]
                    padd = 0  # chunk offset of the current region
                    padtextweight = []
                    for sep_prompt in sep_prompts:
                        minus_targets = []
                        textweights = []
                        for minusmatch in re.finditer(minusgetter, sep_prompt):
                            minus_targets.append(minusmatch.group().replace("(", "").replace(")", ""))
                            prompts[i] = prompts[i].replace(minusmatch.group(), "")
                        for text, weight in (x.split(":") for x in minus_targets):
                            weight = float(weight)
                            if text == "BREAK":
                                continue
                            if weight < 0:
                                textweights.append([text, weight])
                                flag = True
                        padtextweight.append([padd, textweights])
                        _, tokensnum = tokenizer(sep_prompt)
                        padd += tokensnum // 75 + 1
                    stepout.append([step, padtextweight])
                output.append(stepout)
            return output

        scheduled_p = prompt_parser.get_learned_conditioning_prompt_schedules(p.prompts, p.steps)
        scheduled_np = prompt_parser.get_learned_conditioning_prompt_schedules(p.negative_prompts, p.steps)
        if self.hrp or self.hrn:
            hr_steps = p.hr_second_pass_steps if p.hr_second_pass_steps > 0 else p.steps
        if self.hrp:
            scheduled_hr_p = prompt_parser.get_learned_conditioning_prompt_schedules(p.hr_prompts, hr_steps)
        if self.hrn:
            scheduled_hr_np = prompt_parser.get_learned_conditioning_prompt_schedules(p.hr_negative_prompts, hr_steps)

        nip = getshedulednegs(scheduled_p, p.prompts)
        pin = getshedulednegs(scheduled_np, p.negative_prompts)
        if self.hrp:
            hr_nip = getshedulednegs(scheduled_hr_p, p.hr_prompts)
        if self.hrn:
            hr_pin = getshedulednegs(scheduled_hr_np, p.hr_negative_prompts)

        def crossattn(cond):
            """Return the cross-attention embedding tensor of a learned conditioning."""
            return cond["crossattn"] if self.isxl else cond

        def conddealer(targets):
            """Encode [text, weight] targets into 77-row chunks for appending to the context.

            Returns (tensor repeated to the batch size, list of per-chunk NegPiP row counts).
            """
            conds = []
            start = None
            end = None
            for target in targets:
                # Weight sign is flipped: the term is encoded positively, negation happens in attention.
                prompt_in = SdConditioning([f"({target[0]}:{-target[1]})"], width=p.width, height=p.height)
                cond = prompt_parser.get_learned_conditioning(shared.sd_model, prompt_in, p.steps)
                emb = crossattn(cond[0][0].cond)
                if start is None:
                    start = emb[0:1, :]  # first row of the encoded chunk
                if end is None:
                    end = emb[-1:, :]  # last row of the encoded chunk, used as padding
                _, tokenlen = tokenizer(target[0])
                conds.append(emb[1:tokenlen + 2, :])
            conds = torch.cat(conds, 0)

            condsout = []
            condcount = []
            for cond in torch.split(conds, 75, dim=0):
                condcount.append(cond.shape[0])
                # Pad each chunk to 1 + 75 + 1 = 77 rows.
                repeat = 75 - cond.shape[0]
                condsout.append(torch.cat((start, cond, end.repeat(repeat + 1, 1)), 0))
            condout = torch.cat(condsout, 0).unsqueeze(0)
            return condout.repeat(self.batch, 1, 1), condcount

        def calcconds(targetlist):
            """Map the output of getshedulednegs to [[step, [[region, conds, tokens], ...]], ...]."""
            outconds = []
            for batch in targetlist:
                stepconds = []
                for step, regions in batch:
                    regionconds = []
                    for region, targets in regions:
                        if targets:
                            conds, contokens = conddealer(targets)
                            regionconds.append([region, conds, contokens])
                        else:
                            regionconds.append([region, None, None])
                    stepconds.append([step, regionconds])
                outconds.append(stepconds)
            return outconds

        self.conds_all = calcconds(nip)
        self.unconds_all = calcconds(pin)
        if self.hrp:
            self.hr_conds_all = calcconds(hr_nip)
        if self.hrn:
            self.hr_unconds_all = calcconds(hr_pin)

        # The prompts were edited above, so cached conditionings must be recomputed.
        resetpcache(p)

        def calcsets(A, B):
            return A // B if A % B == 0 else A // B + 1

        self.conlen = calcsets(tokenizer(p.prompts[0])[1], 75)
        self.unlen = calcsets(tokenizer(p.negative_prompts[0])[1], 75)

        if not flag:
            self.active = False
            unload(self, p)
            return

        if not hasattr(self, "negpip_dr_callbacks"):
            self.negpip_dr_callbacks = on_cfg_denoiser(self.denoiser_callback)

        # Do not hook the attention layers if Regional Prompter already hooked them;
        # the NegPiP operation is then handled inside Regional Prompter.
        already_hooked = False
        if self.rpscript is not None and hasattr(self.rpscript, "hooked"):
            already_hooked = self.rpscript.hooked

        if not already_hooked:
            self.handle = hook_forwards(self, p.sd_model.model.diffusion_model)

        print(f"NegPiP enable, Positive:{self.conds_all[0][0][1][0][2]},Negative:{self.unconds_all[0][0][1][0][2]}")

        p.extra_generation_params.update({
            "NegPiP Active": active,
        })

    def postprocess(self, p, processed, *args):
        """Remove the attention hooks and drop the cached conditionings."""
        unload(self, p)
        self.conds_all = None
        self.unconds_all = None
        # The CFG-denoiser callback stays registered; deactivate it until the next process_batch.
        self.active = False

    @staticmethod
    def _select_step(schedule, sampling_step, log=False):
        """Return (conds, tokens) lists for the first batch-0 schedule entry with step >= sampling_step + 2."""
        conds_list = []
        tokens_list = []
        if schedule is not None:
            for step, regions in schedule[0]:
                if step >= sampling_step + 2:
                    for _region, region_conds, region_tokens in regions:
                        conds_list.append(region_conds)
                        tokens_list.append(region_tokens)
                    if debug and log:
                        print(f"current:{sampling_step + 2},selected:{step}")
                    break
        return conds_list, tokens_list

    def denoiser_callback(self, params: CFGDenoiserParams):
        """Select the per-region conditionings the hooked layers use for this sampling step."""
        if debug:
            print(params.text_cond.shape)
        if not self.active:
            return

        # A latent shape different from the first one seen is treated as the hires pass.
        if self.x is None:
            self.x = params.x.shape
        if self.x != params.x.shape:
            self.hr = True

        self.latenti = 0

        conds_all = self.hr_conds_all if self.hrp and self.hr else self.conds_all
        self.conds, self.contokens = self._select_step(conds_all, params.sampling_step, log=True)

        unconds_all = self.hr_unconds_all if self.hrn and self.hr else self.unconds_all
        self.unconds, self.untokens = self._select_step(unconds_all, params.sampling_step)


def unload(self, p):
    """Remove the attention hooks if this Script instance installed them."""
    if hasattr(self, "handle"):
        hook_forwards(self, p.sd_model.model.diffusion_model, remove=True)
        del self.handle


def hook_forward(self, module):
    """Build a replacement ``forward`` for a CrossAttention ``module`` that applies NegPiP."""

    def append_cond(ctx, extra):
        """Concatenate ``extra`` to ``ctx`` along dim 1, expanding its batch dim to match."""
        if ctx.shape[0] != extra.shape[0]:
            extra = extra.expand(ctx.shape[0], -1, -1)
        return torch.cat((ctx, extra), 1)

    def forward(x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0, value=None, transformer_options=None):
        if debug:
            print(" x.shape:", x.shape, "context.shape:", context.shape, "self.contokens", self.contokens, "self.untokens", self.untokens)

        def sub_forward(x, context, mask, additional_tokens, n_times_crossframe_attn_in_self, conds, contokens, unconds, untokens, latent=None):
            if debug:
                print(" x.shape[0]:", x.shape[0], "batch:", self.batch * 2)

            if x.shape[0] == self.batch * 2:
                # cond and uncond arrive batched together: split, extend each half, recombine.
                if debug:
                    print(" x.shape[0] == self.batch *2")
                if self.rev:
                    contn, contp = context.chunk(2)
                    ixn, ixp = x.chunk(2)
                else:
                    contp, contn = context.chunk(2)
                    ixp, ixn = x.chunk(2)
                if conds is not None:
                    contp = append_cond(contp, conds)
                if unconds is not None:
                    contn = append_cond(contn, unconds)
                xp = main_foward(self, module, ixp, contp, mask, additional_tokens, n_times_crossframe_attn_in_self, contokens)
                xn = main_foward(self, module, ixn, contn, mask, additional_tokens, n_times_crossframe_attn_in_self, untokens)
                return torch.cat([xn, xp]) if self.rev else torch.cat([xp, xn])

            if latent is not None:
                # Regional Prompter latent mode: one conditioning per call.
                if debug:
                    print(" latent is not None")
                extra = conds if latent else unconds
                if extra is not None:
                    context = append_cond(context, extra)
                tokens = contokens if contokens is not None else untokens
                return main_foward(self, module, x, context, mask, additional_tokens, n_times_crossframe_attn_in_self, tokens)

            # cond and uncond arrive in separate calls: tell them apart by context
            # length plus the layer-call counter.
            if debug:
                print(" Else")
                print(context.shape[1], self.conlen, self.unlen)
            tokens = []
            concon = counter(self.isxl)
            if debug:
                print(concon)
            if context.shape[1] == self.conlen * 77 and concon:
                if conds is not None:
                    context = append_cond(context, conds)
                    tokens = contokens
            elif context.shape[1] == self.unlen * 77 and concon:
                # NOTE(review): this uncond branch also requires concon to be True; confirm
                # whether `not concon` was intended for the uncond pass.
                if unconds is not None:
                    context = append_cond(context, unconds)
                    tokens = untokens
            return main_foward(self, module, x, context, mask, additional_tokens, n_times_crossframe_attn_in_self, tokens)

        if self.enable_rp_latent:
            conds_list = self.conds or []
            if self.latenti < len(conds_list):
                out = sub_forward(x, context, mask, additional_tokens, n_times_crossframe_attn_in_self, conds_list[self.latenti], self.contokens[self.latenti], None, None, latent=True)
                self.latenti += 1
            else:
                unconds = self.unconds[0] if self.unconds else None
                untokens = self.untokens[0] if self.untokens else None
                out = sub_forward(x, context, mask, additional_tokens, n_times_crossframe_attn_in_self, None, None, unconds, untokens, latent=False)
                self.latenti = 0
            return out

        if self.conds and self.unconds:
            return sub_forward(x, context, mask, additional_tokens, n_times_crossframe_attn_in_self, self.conds[0], self.contokens[0], self.unconds[0], self.untokens[0])
        return sub_forward(x, context, mask, additional_tokens, n_times_crossframe_attn_in_self, None, None, None, None)

    return forward


count = 0
pn = True


def counter(isxl):
    """Return the current pass flag, toggling it every `limit` calls.

    `limit` is presumably the number of hooked cross-attention layers per UNet
    call (70 for SDXL, 16 otherwise) — TODO confirm.
    """
    global count, pn
    count += 1

    limit = 70 if isxl else 16
    outpn = pn

    if count == limit:
        pn = not pn
        count = 0
    return outpn


def main_foward(self, module, x, context, mask, additional_tokens, n_times_crossframe_attn_in_self, tokens):
    """Cross-attention forward that negates the value rows of appended NegPiP tokens.

    ``tokens`` lists, per appended 77-row chunk, its NegPiP row count; those
    chunks are the last ``len(tokens)`` chunks of ``context``.
    """
    h = module.heads
    context = atm.default(context, x).to(x.dtype)
    q = module.to_q(x)
    k = module.to_k(context)
    v = module.to_v(context)
    if debug:
        print(h, context.shape, q.shape, k.shape, v.shape)

    _, _, dim_head = q.shape
    dim_head //= h
    scale = dim_head ** -0.5

    q, k, v = map(lambda t: atm.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
    sim = atm.einsum('b i d, b j d -> b i j', q, k) * scale

    if self.active and tokens:
        first_chunk = v.shape[1] // 77 - len(tokens)
        for i, token in enumerate(tokens):
            # Each appended chunk has its own offset; a shared offset would re-negate
            # (and so cancel) the first chunk and never touch the later ones.
            start = (first_chunk + i) * 77
            v[:, start + 1:start + token, :] = -v[:, start + 1:start + token, :]

    if atm.exists(mask):
        mask = atm.rearrange(mask, 'b ... -> b (...)')
        max_neg_value = -torch.finfo(sim.dtype).max
        mask = atm.repeat(mask, 'b j -> (b h) () j', h=h)
        sim.masked_fill_(~mask, max_neg_value)

    attn = sim.softmax(dim=-1)
    out = atm.einsum('b i j, b j d -> b i d', attn, v)
    out = atm.rearrange(out, '(b h) n d -> b n (h d)', h=h)
    return module.to_out(out)


def hook_forwards(self, root_module: torch.nn.Module, remove=False):
    """Install (or, with remove=True, remove) the NegPiP forward on every attn2 CrossAttention."""
    for name, module in root_module.named_modules():
        if "attn2" in name and module.__class__.__name__ == "CrossAttention":
            if remove:
                # Dropping the instance attribute restores the class's own forward.
                if "forward" in vars(module):
                    del module.forward
            else:
                module.forward = hook_forward(self, module)


def resetpcache(p):
    """Invalidate the processing object's cached (hires) cond/uncond so prompts are re-encoded."""
    p.cached_c = [None, None]
    p.cached_uc = [None, None]
    p.cached_hr_c = [None, None]
    p.cached_hr_uc = [None, None]


class SdConditioning(list):
    """List of prompts carrying the metadata prompt_parser.get_learned_conditioning reads."""

    def __init__(self, prompts, is_negative_prompt=False, width=None, height=None, copy_from=None):
        super().__init__()
        self.extend(prompts)

        if copy_from is None:
            copy_from = prompts

        self.is_negative_prompt = is_negative_prompt or getattr(copy_from, 'is_negative_prompt', False)
        self.width = width or getattr(copy_from, 'width', None)
        self.height = height or getattr(copy_from, 'height', None)


def ext_on_ui_settings():
    """Register the NegPiP options in the webui settings page."""
    # (setting_name, default, label)
    negpip_options = [
        (OPT_HIDE, False, "Hide in Txt2Img/Img2Img tab(Reload UI required)"),
        (OPT_ACT, True, "Active(Effective when Hide is Checked)",),
    ]
    section = ('negpip', "NegPiP")

    for cur_setting_name, *option_info in negpip_options:
        shared.opts.add_option(cur_setting_name, shared.OptionInfo(*option_info, section=section))


on_ui_settings(ext_on_ui_settings)


def hr_dealer(p):
    """Ensure p has hr_prompts / hr_negative_prompts; return whether each is non-empty."""
    if not hasattr(p, "hr_prompts"):
        p.hr_prompts = None
    if not hasattr(p, "hr_negative_prompts"):
        p.hr_negative_prompts = None
    return bool(p.hr_prompts), bool(p.hr_negative_prompts)