import torch
import torch.nn.functional as F
from ..modules.reference_net import hack_inference_forward
from ..models.basemodel import CustomizedColorizer, CustomizedWrapper
from ..modules.lora import LoraModules
from ..util import exists, expand_to_batch_size, instantiate_from_config, get_crop_scale, resize_and_crop
class InferenceWrapper(CustomizedWrapper, CustomizedColorizer):
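    """Inference-time wrapper for the SDXL-based colorizer.

    Adds a scalar embedder, an image embedder, and optional LoRA modules on
    top of ``CustomizedColorizer``, and assembles the conditional and
    unconditional inputs consumed at sampling time.
    """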
def __init__(
self,
scalar_embedder_config,
img_embedder_config,
lora_config = None,
logits_embed = False,
*args,
**kwargs
):
CustomizedColorizer.__init__(self, version="sdxl", *args, **kwargs)
CustomizedWrapper.__init__(self)
self.scalar_embedder = instantiate_from_config(scalar_embedder_config)
self.img_embedder = instantiate_from_config(img_embedder_config)
self.loras = LoraModules(self, **lora_config) if exists(lora_config) else None
self.logits_embed = logits_embed
new_model_list = {
"scalar_embedder": self.scalar_embedder,
"img_embedder": self.img_embedder,
# "style_encoder": self.style_encoder,
}
self.switch_cond_modules += list(new_model_list.keys())
self.model_list.update(new_model_list)
def retrieve_attn_modules(self):
scale_factor_levels = {"high": 0.5, "low": 0.25, "bottom": 0.25}
from refnet.modules.transformer import BasicTransformerBlock
from refnet.sampling import torch_dfs
attn_modules = []
for module in torch_dfs(self.model.diffusion_model):
if isinstance(module, BasicTransformerBlock):
attn_modules.append(module)
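        # Hard-coded block indices for the SDXL UNet layout, grouped by stage
        # ("encoder", "bottom", "decoder" follow traversal order) and by
        # feature-map resolution ("high", "low").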
self.attn_modules = {
"high": [0, 1, 2, 3] + [64, 65, 66, 67, 68, 69],
"low": [i for i in range(4, 24)] + [i for i in range(34, 64)],
"bottom": [i for i in range(24, 34)],
"encoder": [i for i in range(24)],
"decoder": [i for i in range(34, len(attn_modules))]
}
self.attn_modules["modules"] = attn_modules
for k in ["high", "low", "bottom"]:
scale_factor = scale_factor_levels[k]
for attn in self.attn_modules[k]:
attn_modules[attn].scale_factor = scale_factor
def adjust_reference_scale(self, scale_kwargs):
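        """Apply the encoder reference scale to every collected attention block."""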
for module in self.attn_modules["modules"]:
module.reference_scale = scale_kwargs["scales"]["encoder"]
def adjust_masked_attn(self, scale, mask_threshold, merge_scale):
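        # Note: ``self.attn_layers`` is presumably populated by the base
        # classes; it is not assigned anywhere in this wrapper.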
for layer in self.attn_layers:
layer.mask_scale = scale
layer.mask_threshold = mask_threshold
layer.merge_scale = merge_scale
def rescale_size(self, x: torch.Tensor, height, width):
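        """Enlarge (oh, ow) toward (height, width) while preserving the aspect
        ratio; returns the adjusted size as 1-element tensors."""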
oh, ow = x.shape[2:]
if oh < height or ow < width:
dh, dw = height - oh, width - ow
if dh > dw:
iw = ow + int(dh * ow/oh)
ih = height
else:
ih = oh + int(dw * oh/ow)
iw = width
else:
ih, iw = oh, ow
return torch.Tensor([ih]), torch.Tensor([iw])
def get_learned_embedding(self, c, bg=False, mapping=False, sketch=None, *args, **kwargs):
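        # Fuse the CLIP embedding with the image-embedder output (raw logits
        # when ``self.logits_embed`` is set, features otherwise), optionally
        # remapping the logits to the sketch geometry first.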
clip_emb = self.cond_stage_model.encode(c, "full").detach()
wd_emb, logits = self.img_embedder.encode(c, pooled=False, return_logits=True)
cls_emb, local_emb = clip_emb[:, :1], clip_emb[:, 1:]
if mapping:
_, sketch_logits = self.img_embedder.encode(-sketch, pooled=False, return_logits=True)
            # Pool the sketch logits across tokens before the geometry update.
            sketch_logits = sketch_logits.mean(dim=1, keepdim=True)
logits = self.img_embedder.geometry_update(logits, sketch_logits)
emb = self.proj(clip_emb, logits if self.logits_embed else wd_emb, bg)
return emb, cls_emb
def prepare_conditions(
self,
bs,
sketch,
reference,
height,
width,
control_scale = (1., 1., 1., 1.),
merge_scale = 0,
mask_scale = 1.,
fg_scale = 1.,
bg_scale = 1.,
smask = None,
rmask = None,
mask_threshold_ref = 0.,
mask_threshold_sketch = 0.,
style_enhance = False,
fg_enhance = False,
bg_enhance = False,
background = None,
targets = None,
anchors = None,
controls = None,
target_scales = None,
enhances = None,
thresholds_list = None,
geometry_map = False,
latent_inpaint = False,
low_vram = False,
*args,
**kwargs
):
# prepare reference embedding
# manipulate = self.check_manipulate(target_scales)
c = {}
uc = [{}, {}]
if exists(reference):
emb, cls_emb = self.get_learned_embedding(reference, sketch=sketch, mapping=geometry_map)
else:
emb, cls_emb = map(lambda t: torch.zeros_like(t), self.get_learned_embedding(sketch))
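        # Scalar conditioning: embed the target resolution (sqrt(h*w)) and a
        # fixed score of 7.0, then place the two embeddings side by side.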
h, w, score = torch.Tensor([height]), torch.Tensor([width]), torch.Tensor([7.])
y = torch.cat(self.scalar_embedder(torch.cat([(h*w)**0.5, score])).cuda().chunk(2), 1)
if bg_enhance:
assert exists(rmask) and exists(smask)
if low_vram:
self.low_vram_shift(["first", "cond", "img_embedder", "proj"])
if latent_inpaint and exists(background):
bgh, bgw = background.shape[2:]
ch, cw = get_crop_scale(torch.tensor([height]), torch.tensor([width]), bgh, bgw)
hs_bg = self.get_first_stage_encoding(resize_and_crop(background, ch, cw, height, width).to(self.first_stage_model.dtype))
bg_emb, _ = self.get_learned_embedding(background, bg=True)
hs_bg = expand_to_batch_size(hs_bg, bs)
c.update({"inpaint_bg": hs_bg})
else:
if exists(background):
bg_emb, _ = self.get_learned_embedding(background, bg=True)
else:
                    bg_emb, _ = self.get_learned_embedding(
                        torch.where(rmask < mask_threshold_ref, reference, torch.ones_like(reference)),
                        bg=True,
                    )
emb = torch.cat([emb, bg_emb], 1)
if fg_enhance and exists(self.loras):
self.loras.switch_lora(True, "foreground")
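            # With no background embedding appended, duplicate the context so
            # its sequence length matches the foreground-enhanced layout.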
if not bg_enhance:
emb = emb.repeat(1, 2, 1)
if fg_enhance or bg_enhance:
# sketch mask for cross-attention
smask = expand_to_batch_size(smask.to(self.dtype), bs)
for d in [c] + uc:
d.update({"mask": F.interpolate(smask, scale_factor=0.125)})
elif exists(self.loras):
self.loras.switch_lora(False)
sketch = sketch.to(self.dtype)
context = expand_to_batch_size(emb, bs).to(self.dtype)
y = expand_to_batch_size(y, bs)
uc_context = torch.zeros_like(context)
control = []
uc_control = []
if low_vram:
self.low_vram_shift(["control_encoder"])
encoded_sketch = self.control_encoder(
torch.cat([sketch, -torch.ones_like(sketch)], 0)
)
for idx, es in enumerate(encoded_sketch):
es = es * control_scale[idx]
ec, uec = es.chunk(2)
control.append(expand_to_batch_size(ec, bs))
uc_control.append(expand_to_batch_size(uec, bs))
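        # Two unconditional branches: uc[0] zeroes out the reference context,
        # uc[1] swaps in the blank-sketch control features.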
c.update({"control": control, "context": [context], "y": [y]})
uc[0].update({"control": control, "context": [uc_context], "y": [y]})
uc[1].update({"control": uc_control, "context": [context], "y": [y]})
return c, uc