from difflib import restore
import random
import copy
from pprint import pprint
import re
from typing import Union

import torch
from torchvision.transforms import InterpolationMode, Resize

from modules import devices, shared, extra_networks, sd_hijack
from modules.script_callbacks import CFGDenoisedParams, CFGDenoiserParams

import scripts.attention as att
from scripts.regions import floatdef
from scripts.attention import makerrandman

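# Module-level state shared with the hooked Linear.forward functions below,
# which have no access to the script instance and so read these globals.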
|
|
islora = True
|
|
|
in_hr = False
|
|
|
layer_name = "lora_layer_name"
|
|
|
orig_Linear_forward = None
|
|
|
|
|
|
orig_lora_functional = False
|
|
|
|
|
|
lactive = False
|
|
|
labug =False
|
|
|
MINID = 1000
|
|
|
MAXID = 10000
|
|
|
LORAID = MINID
|
|
|
|
# Forge ships its own patched LoRA handling; detect it by probing for its
# model_management module.
try:
    from ldm_patched.modules import model_management
    forge = True
except ImportError:
    forge = False

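# Save the stock behavior and monkeypatch torch.nn.Linear.forward so that the
# active LoRA set can be swapped per region on every Linear call.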
def setuploras(self):
    global lactive, labug, islora, orig_Linear_forward, orig_lora_functional, layer_name
    lactive = True
    labug = self.debug
    islora = self.isbefore15
    layer_name = self.layer_name
    orig_lora_functional = getattr(shared.opts, "lora_functional", False)

    # WebUI 1.5.x and the slow-LoRA fallback need the non-functional path.
    is15 = 150 <= self.ui_version <= 159
    try:
        shared.opts.lora_functional = not (is15 or self.slowlora)
    except Exception:
        pass

    orig_Linear_forward = torch.nn.Linear.forward
    torch.nn.Linear.forward = h15_Linear_forward if is15 else h_Linear_forward

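# Copy the denoiser inputs from one params-like holder to another.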
def cloneparams(orig, target):
    target.x = orig.x.clone()
    target.image_cond = orig.image_cond.clone()
    target.sigma = orig.sigma.clone()

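# Called before each denoising step. In Prompt mode it builds the region masks
# from the accumulated attention maps; in Latent mode it reorders the cond
# batch so each U-Net pass sees one region across the whole batch.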
def denoiser_callback_s(self, params: CFGDenoiserParams):
    if "Pro" in self.mode:  # Prompt mode: build the region masks from attention maps.
        if self.x is None:
            cloneparams(params, self)
        self.step = params.sampling_step
        self.pfirst = True

        lim = 1 if self.isxl else 3

        if len(att.pmaskshw) > lim:
            # Latent-resolution filters, one set per image in the batch.
            self.filters = []
            for b in range(self.batch_size):
                allmask = []
                basemask = None
                for t, th, bratio in zip(self.pe, self.th, self.bratios):
                    key = f"{t}-{b}"
                    _, _, mask = att.makepmask(att.pmasks[key], params.x.shape[2], params.x.shape[3], th, self.step, bratio=bratio)
                    mask = mask.repeat(params.x.shape[1], 1, 1)
                    # Whatever no region claims is left for the base prompt.
                    basemask = 1 - mask if basemask is None else basemask - mask
                    if self.ex:
                        # Exclusive regions: a later mask carves holes into earlier ones.
                        for l in range(len(allmask)):
                            mt = allmask[l] - mask
                            allmask[l] = torch.where(mt > 0, 1, 0)
                    allmask.append(mask)
                if not self.ex:
                    # Overlapping regions: normalize so the weights sum to 1.
                    msum = torch.stack(allmask, dim=0).sum(dim=0)
                    msum = torch.where(msum == 0, 1, msum)
                    allmask = [mask / msum for mask in allmask]
                basemask = torch.where(basemask > 0, 1, 0)
                allmask.insert(0, basemask)
                self.filters.extend(allmask)
            att.maskready = True

            # Attention-resolution masks for every stored height/width pair.
            for t, th, bratio in zip(self.pe, self.th, self.bratios):
                allmask = []
                for hw in att.pmaskshw:
                    masks = None
                    for b in range(self.batch_size):
                        key = f"{t}-{b}"
                        _, mask, _ = att.makepmask(att.pmasks[key], hw[0], hw[1], th, self.step, bratio=bratio)
                        mask = mask.unsqueeze(0).unsqueeze(-1)
                        masks = mask if b == 0 else torch.cat((masks, mask), dim=0)
                    allmask.append(masks)
                att.pmasksf[key] = allmask
            att.maskready = True

            if not self.rebacked:
                # Masks are ready: restore the saved inputs and restart sampling once.
                cloneparams(self, params)
                params.sampling_step = 0
                self.rebacked = True

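    # Latent mode: the conds arrive grouped per image; the loop below regroups
    # them per region so each U-Net pass sees one region across the whole
    # batch. For batch = 2, areas = 3:
    # [x00 x01 x02 x10 x11 x12] -> [x00 x10 x01 x11 x02 x12] (xba = image b, region a).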
    if "La" in self.calc:
        self.condi = 0
        global in_hr, regioner
        regioner.step = params.sampling_step
        in_hr = self.in_hr
        regioner.u_count = 0
        if "u_list" not in self.log and hasattr(regioner, "u_llist"):
            self.log["u_list"] = regioner.u_llist.copy()
        if "u_list_hr" not in self.log and hasattr(regioner, "u_llist") and in_hr:
            self.log["u_list_hr"] = regioner.u_llist.copy()
        xt = params.x.clone()
        ict = params.image_cond.clone()
        st = params.sigma.clone()
        batch = self.batch_size
        areas = xt.shape[0] // batch - 1

        if hasattr(params, "text_cond"):
            if "DictWithShape" in params.text_cond.__class__.__name__:  # dict-style conds (e.g. SDXL)
                ct = {key: params.text_cond[key].clone() for key in params.text_cond.keys()}
            else:
                ct = params.text_cond.clone()

        for a in range(areas):
            for b in range(batch):
                params.x[b + a * batch] = xt[a + b * areas]
                params.image_cond[b + a * batch] = ict[a + b * areas]
                params.sigma[b + a * batch] = st[a + b * areas]

                if hasattr(params, "text_cond"):
                    if "DictWithShape" in params.text_cond.__class__.__name__:
                        for key in params.text_cond.keys():
                            params.text_cond[key][b + a * batch] = ct[key][a + b * areas]
                    else:
                        params.text_cond[b + a * batch] = ct[a + b * areas]

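# Called after each U-Net pass: composites every region's denoised latent back
# into its image through the region filter (everything outside a region falls
# back to the unconditional result), undoing the shuffle done above.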
def denoised_callback_s(self, params: CFGDenoisedParams):
    batch = self.batch_size
    x = params.x
    xt = params.x.clone()
    areas = xt.shape[0] // batch - 1

    def blend_with_stored():
        # Shared by the three "diff" branches below: blend the stored base
        # latent back in outside each region filter.
        for b in range(batch):
            for a in range(areas):
                fil = self.filters[a + 1]
                orig = self.rps.latent if self.rps.latent.shape[2:] == x.shape[2:] else self.rps.latent_hr
                if self.debug:
                    print(f"x = {x.size()}, i = {a + b * areas}, filsum = {fil if isinstance(fil, int) else torch.sum(fil)}")
                x[:, :, :, :] = orig[:, :, :, :] * (1 - fil) + x[:, :, :, :] * fil

    if "La" in self.calc:
        if "Pro" not in self.mode:
            # Rebuild the filters when they are missing or the latent size
            # changed (e.g. entering the hires pass).
            indrebuild = self.filters == [] or self.filters[0].size() != x[0].size()

            if indrebuild:
                if "Ran" in self.mode:
                    if self.filters == []:
                        self.filters = [self.ranbase] + self.ransors if self.usebase else self.ransors
                    elif self.filters[0][:, :].size() != x[0, 0, :, :].size():
                        self.filters = hrchange(self.ransors, x.shape[2], x.shape[3])
                else:
                    if "Mask" in self.mode:
                        masks = (self.regmasks, self.regbase)
                    else:
                        masks = self.aratios
                    self.filters = makefilters(x.shape[1], x.shape[2], x.shape[3], masks, self.mode, self.usebase, self.bratios, "Mas" in self.mode)
                self.filters = list(self.filters) * batch
        else:
            if not att.maskready:
                # Attention masks not computed yet; pass region 0 through unchanged.
                self.filters = [1, *[0 for a in range(areas - 1)]] * batch

        if self.debug:
            print("filterlength : ", len(self.filters))

        for b in range(batch):
            for a in range(areas):
                fil = self.filters[a + b * areas]
                if self.debug:
                    print(f"x = {x.size()}, i = {a + b * areas}, j = {b + a * batch}, filsum = {fil if isinstance(fil, int) else torch.sum(fil)}, uncond = {x.size()[0] + (b - batch)}")
                x[a + b * areas, :, :, :] = xt[b + a * batch, :, :, :] * fil + x[x.size()[0] + (b - batch), :, :, :] * (1 - fil)

        # "diff" mode: two steps before the end, store the base latent once per
        # resolution; afterwards blend it back in outside the regions.
        if params.total_sampling_steps == params.sampling_step + 2:
            if self.rps is not None and self.diff:
                if self.rps.latent is None:
                    self.rps.latent = x.clone()
                    return
                elif self.rps.latent.shape[2:] != x.shape[2:] and self.rps.latent_hr is None:
                    self.rps.latent_hr = x.clone()
                    return
                else:
                    blend_with_stored()

    if att.maskready:
        # Prompt mode: once the masks exist, apply the same "diff" blending.
        if self.rps is not None and self.diff:
            if self.rps.latent is not None:
                if self.rps.latent.shape[2:] != x.shape[2:]:
                    if self.rps.latent_hr is None:
                        return
                blend_with_stored()

    if params.sampling_step == 0 and self.in_hr:
        # First step of the hires pass: blend against the stored latent as well.
        if self.rps is not None and self.diff:
            if self.rps.latent is not None:
                if self.rps.latent.shape[2:] != x.shape[2:] and self.rps.latent_hr is None:
                    return
                blend_with_stored()

def hrchange(filters, h, w):
    # Rescale the random-region filters to the new (hires) latent size.
    return [makerrandman(f, h, w, True) for f in filters]

def flokey(key):
    """Strip the suffixes Regional Prompter / LoRA Block Weight append to LoRA names."""
    return (key.split("added_by_regional_prompter")[0]
               .split("added_by_lora_block_weight")[0]
               .split("_in_LBW")[0]
               .split("_in_RP")[0])

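# Build the per-region multiplier tables from the prompt. Each "AND"-separated
# subprompt contributes one dict in which only the LoRAs it actually calls keep
# a nonzero weight; e.g. "a <lora:chara:0.8> AND b <lora:other:0.6>" yields one
# dict with only chara active and one with only other active.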
def lora_namer(self, p, lnter, lnur):
    ldict_u = {}
    ldict_te = {}
    lorder = []
    import lora as loraclass
    for lora in loraclass.loaded_loras:
        ldict_u[lora.name] = lora.multiplier if self.isbefore15 else lora.unet_multiplier
        ldict_te[lora.name] = lora.multiplier if self.isbefore15 else lora.te_multiplier

    subprompts = self.current_prompts[0].split("AND")
    # Index 0 keeps the global weights; index i + 1 belongs to subprompt i.
    ldictlist_u = [ldict_u.copy() for i in range(len(subprompts) + 1)]
    ldictlist_te = [ldict_te.copy() for i in range(len(subprompts) + 1)]

    for i, prompt in enumerate(subprompts):
        _, extranets = extra_networks.parse_prompts([prompt])
        calledloras = extranets["lora"]

        names = ""
        tdict = {}

        for called in calledloras:
            names = names + called.items[0]
            tdict[called.items[0]] = syntaxdealer(called.items, "unet=", 1)

        for key in ldictlist_u[i].keys():
            shin_key = flokey(key)
            if shin_key in names:
                ldictlist_u[i + 1][key] = float(tdict[shin_key])
                ldictlist_te[i + 1][key] = float(tdict[shin_key])
                if key not in lorder:
                    lorder.append(key)
            else:
                ldictlist_u[i + 1][key] = 0
                ldictlist_te[i + 1][key] = 0

    if self.debug:
        print("Regioner lorder: ", lorder)
    global regioner
    regioner.__init__(self.lstop, self.lstop_hr)
    # Rotate the U-Net list so the global dict comes last, matching the order
    # in which u_start is reached during a step.
    u_llist = [d.copy() for d in ldictlist_u[1:]]
    u_llist.append(ldictlist_u[0].copy())
    regioner.te_llist = ldictlist_te
    regioner.u_llist = u_llist
    regioner.ndeleter(lnter, lnur, lorder)
    if self.debug:
        print("LoRA regioner : TE list", regioner.te_llist)
        print("LoRA regioner : U list", regioner.u_llist)

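# Pull a named multiplier out of an extra-network call's argument list.
# Assuming the usual <lora:name:te:unet=...> syntax, items like
# ["mylora", "0.6", "unet=0.4"] give syntaxdealer(items, "unet=", 1) -> "0.4",
# while ["mylora", "0.6"] fall back to items[1] -> "0.6". An "@" marks a
# step-scheduled weight (lora-ctl syntax), for which 1 is used here.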
def syntaxdealer(items, prefix, index):
    for item in items:
        if prefix in item:
            if "@" in item:
                return 1
            return item.replace(prefix, "")
    return items[index] if "@" not in items[index] else 1

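# Build one (c, h, w) weight map per region, either from uploaded masks
# (indmask) or from the row/column split ratios; with usebase the base map is
# inserted at index 0.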
def makefilters(c, h, w, masks, mode, usebase, bratios, indmask):
    if indmask:
        (regmasks, regbase) = masks

    filters = []
    x = torch.zeros(c, h, w).to(devices.device)
    if usebase:
        x0 = torch.zeros(c, h, w).to(devices.device)
    i = 0
    if indmask:
        ftrans = Resize((h, w), interpolation=InterpolationMode.NEAREST)
        for rmask, bratio in zip(regmasks, bratios[0]):
            # Resize the uploaded mask to the latent resolution.
            rmask2 = ftrans(rmask.reshape([1, *rmask.shape]))
            rmask2 = rmask2.reshape([1, h, w])
            fx = x.clone()
            if usebase:
                fx[:, :, :] = fx + rmask2 * (1 - bratio)
                x0[:, :, :] = x0 + rmask2 * bratio
            else:
                fx[:, :, :] = fx + rmask2
            filters.append(fx)

        if usebase:
            rmask = regbase
            rmask2 = ftrans(rmask.reshape([1, *rmask.shape]))
            rmask2 = rmask2.reshape([1, h, w])
            x0 = x0 + rmask2
    else:
        for drow in masks:
            for dcell in drow.cols:
                fx = x.clone()
                if "Horizontal" in mode:
                    if usebase:
                        fx[:, int(h * drow.st):int(h * drow.ed), int(w * dcell.st):int(w * dcell.ed)] = 1 - dcell.base
                        x0[:, int(h * drow.st):int(h * drow.ed), int(w * dcell.st):int(w * dcell.ed)] = dcell.base
                    else:
                        fx[:, int(h * drow.st):int(h * drow.ed), int(w * dcell.st):int(w * dcell.ed)] = 1
                elif "Vertical" in mode:
                    if usebase:
                        fx[:, int(h * dcell.st):int(h * dcell.ed), int(w * drow.st):int(w * drow.ed)] = 1 - dcell.base
                        x0[:, int(h * dcell.st):int(h * dcell.ed), int(w * drow.st):int(w * drow.ed)] = dcell.base
                    else:
                        fx[:, int(h * dcell.st):int(h * dcell.ed), int(w * drow.st):int(w * drow.ed)] = 1
                filters.append(fx)
                i += 1
    if usebase:
        filters.insert(0, x0)
    if labug:
        print(i, len(filters))

    return filters

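# Names of the first Linear layer visited in a text-encoder / U-Net forward
# pass; seeing one of these in changethelora marks the start of a new pass.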
TE_START_NAME = "transformer_text_model_encoder_layers_0_self_attn_q_proj"
UNET_START_NAME = "diffusion_model_time_embed_0"
TE_START_NAME_XL = "0_transformer_text_model_encoder_layers_0_self_attn_q_proj"

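# Holds one multiplier dict per region and rotates through them each time the
# hooked Linear.forward reports the start of a text-encoder or U-Net pass.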
class LoRARegioner:

    def __init__(self, stop=0, stop_hr=0):
        self.te_count = 0
        self.u_count = 0
        self.te_llist = [{}]
        self.u_llist = [{}]
        self.mlist = {}
        self.ctl = False
        self.step = 0
        self.stop = stop
        self.stop_hr = stop_hr

        # Snapshot the step schedules if the lora-ctl extension is installed.
        try:
            import lora_ctl_network as ctl
            self.ctlweight = copy.deepcopy(ctl.lora_weights)
            for weights in self.ctlweight.values():
                for weight in weights.values():
                    if isinstance(weight, list):
                        self.ctl = True
        except Exception:
            pass

    def expand_del(self, val, lorder):
        """Broadcast a single or comma-separated val to the LoRA list."""
        lval = val.split(",")
        if len(lval) > len(lorder):
            lval = lval[:len(lorder)]
        lval = [floatdef(v, 0) for v in lval]
        if len(lval) < len(lorder):
            lval.extend([lval[-1]] * (len(lorder) - len(lval)))
        return lval

    def ndeleter(self, lnter, lnur, lorder=None):
        """Multiply the global weights by a 0-1 factor.

        Any value works, negative too, but it doesn't help much.
        """
        lkeys = self.te_llist[0].keys() if lorder is None else lorder
        lnter = self.expand_del(lnter, lkeys)
        for key, val in zip(lkeys, lnter):
            self.te_llist[0][key] *= val

        lkeys = self.u_llist[-1].keys() if lorder is None else lorder
        lnur = self.expand_del(lnur, lkeys)
        for key, val in zip(lkeys, lnur):
            self.u_llist[-1][key] *= val

    def search_key(self, lora, i, xlist):
        # Match a loaded LoRA to an entry in the multiplier dict, tolerating
        # the renamed keys produced by flokey's suffixes.
        lorakey = lora.loaded_loras[i].name
        if lorakey not in xlist:
            shin_key = flokey(lorakey)
            picked = False
            for mlkey in xlist.keys():
                if mlkey.startswith(shin_key):
                    lorakey = mlkey
                    picked = True
            if not picked:
                print(f"key is not found in:{xlist.keys()}")
        return lorakey

    def te_start(self):
        self.mlist = self.te_llist[self.te_count % len(self.te_llist)]
        if self.mlist == {}:
            return
        self.te_count += 1
        import lora
        for i in range(len(lora.loaded_loras)):
            lorakey = self.search_key(lora, i, self.mlist)
            lora.loaded_loras[i].multiplier = self.mlist[lorakey]
            lora.loaded_loras[i].te_multiplier = self.mlist[lorakey]

    def u_start(self):
        if labug:
            print("u_count", self.u_count, "u_count % len(u_llist)", self.u_count % len(self.u_llist))
        self.mlist = self.u_llist[self.u_count % len(self.u_llist)]
        if self.mlist == {}:
            return
        self.u_count += 1

        stopstep = self.stop_hr if in_hr else self.stop

        import lora
        for i in range(len(lora.loaded_loras)):
            lorakey = self.search_key(lora, i, self.mlist)
            # Zero out the LoRA once the configured stop step has passed.
            stopped = stopstep and self.step + 2 > stopstep
            lora.loaded_loras[i].multiplier = 0 if stopped else self.mlist[lorakey]
            lora.loaded_loras[i].unet_multiplier = 0 if stopped else self.mlist[lorakey]
            if labug:
                print(lorakey, lora.loaded_loras[i].multiplier, lora.loaded_loras[i].unet_multiplier)
            if self.ctl:
                # Keep the lora-ctl step schedules consistent with the region weights.
                import lora_ctl_network as ctl
                key = "hrunet" if in_hr else "unet"
                if self.mlist[lorakey] == 0 or stopped:
                    ctl.lora_weights[lorakey][key] = [[0], [0]]
                elif key in self.ctlweight[lorakey].keys():
                    ctl.lora_weights[lorakey][key] = self.ctlweight[lorakey][key]
                else:
                    ctl.lora_weights[lorakey][key] = self.ctlweight[lorakey]["unet"]
                if labug:
                    print(ctl.lora_weights[lorakey])

    def reset(self):
        self.te_count = 0
        self.u_count = 0


regioner = LoRARegioner()

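# Replacements for torch.nn.Linear.forward. Every Linear call first lets the
# regioner swap multipliers (via changethelora), then defers to the stock
# LoRA implementation: h15_* targets WebUI 1.5.x, h_* everything else.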
def h_Linear_forward(self, input):
    changethelora(getattr(self, layer_name, None))
    if islora:
        import lora
        return lora.lora_forward(self, input, torch.nn.Linear_forward_before_lora)
    elif forge:
        return orig_Linear_forward(self, input)
    else:
        import networks
        if shared.opts.lora_functional:
            return networks.network_forward(self, input, networks.originals.Linear_forward)
        networks.network_apply_weights(self)
        return networks.originals.Linear_forward(self, input)

def h15_Linear_forward(self, input):
    changethelora(getattr(self, layer_name, None))
    if islora:
        import lora
        return lora.lora_forward(self, input, torch.nn.Linear_forward_before_lora)
    else:
        import networks
        if shared.opts.lora_functional:
            return networks.network_forward(self, input, networks.network_Linear_forward)
        networks.network_apply_weights(self)
        return torch.nn.Linear_forward_before_network(self, input)

def changethelora(name):
    if lactive:
        if name == TE_START_NAME or name == TE_START_NAME_XL:
            regioner.te_start()
        elif name == UNET_START_NAME:
            regioner.u_start()

# Weight attribute name for each LoRA module variant.
LORAANDSOON = {
    "LoraHadaModule": "w1a",
    "LycoHadaModule": "w1a",
    "NetworkModuleHada": "w1a",
    "FullModule": "weight",
    "NetworkModuleFull": "weight",
    "IA3Module": "w",
    "NetworkModuleIa3": "w",
    "LoraKronModule": "w1",
    "LycoKronModule": "w1",
    "NetworkModuleLokr": "w1",
}

def changethedevice(module):
    # Move a LoRA module's tensors to the active device as float32.
    ltype = type(module).__name__
    if ltype in ("LoraUpDownModule", "LycoUpDownModule"):
        if hasattr(module, "up_model"):
            module.up_model.weight = torch.nn.Parameter(module.up_model.weight.to(devices.device, dtype=torch.float))
            module.down_model.weight = torch.nn.Parameter(module.down_model.weight.to(devices.device, dtype=torch.float))
        else:
            module.up.weight = torch.nn.Parameter(module.up.weight.to(devices.device, dtype=torch.float))
            if hasattr(module.down, "weight"):
                module.down.weight = torch.nn.Parameter(module.down.weight.to(devices.device, dtype=torch.float))

    elif ltype in ("LoraHadaModule", "LycoHadaModule", "NetworkModuleHada"):
        module.w1a = torch.nn.Parameter(module.w1a.to(devices.device, dtype=torch.float))
        module.w1b = torch.nn.Parameter(module.w1b.to(devices.device, dtype=torch.float))
        module.w2a = torch.nn.Parameter(module.w2a.to(devices.device, dtype=torch.float))
        module.w2b = torch.nn.Parameter(module.w2b.to(devices.device, dtype=torch.float))

        if module.t1 is not None:
            module.t1 = torch.nn.Parameter(module.t1.to(devices.device, dtype=torch.float))

        if module.t2 is not None:
            module.t2 = torch.nn.Parameter(module.t2.to(devices.device, dtype=torch.float))

    elif ltype in ("FullModule", "NetworkModuleFull"):
        module.weight = torch.nn.Parameter(module.weight.to(devices.device, dtype=torch.float))

        if hasattr(module, "bias") and module.bias is not None:
            module.bias = torch.nn.Parameter(module.bias.to(devices.device, dtype=torch.float))

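# Undo everything setuploras changed: restore lora_functional, re-register any
# bundled embeddings, drop the loaded LoRAs, and put the original
# Linear.forward back.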
def unloadlorafowards(p):
    global orig_Linear_forward, lactive, labug
    lactive = labug = False

    try:
        shared.opts.lora_functional = orig_lora_functional
    except Exception:
        pass

    emb_db = sd_hijack.model_hijack.embedding_db
    import lora
    for net in lora.loaded_loras:
        if hasattr(net, "bundle_embeddings"):
            for emb_name, embedding in net.bundle_embeddings.items():
                if embedding.loaded:
                    emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)

    lora.loaded_loras.clear()
    if orig_Linear_forward is not None:
        torch.nn.Linear.forward = orig_Linear_forward
        orig_Linear_forward = None