import torch
import torch.nn.functional as F

from ..modules.reference_net import hack_inference_forward
from ..models.basemodel import CustomizedColorizer, CustomizedWrapper
from ..modules.lora import LoraModules
from ..util import exists, expand_to_batch_size, instantiate_from_config, get_crop_scale, resize_and_crop


class InferenceWrapper(CustomizedWrapper, CustomizedColorizer):
    def __init__(
            self,
            scalar_embedder_config,
            img_embedder_config,
            lora_config=None,
            logits_embed=False,
            *args,
            **kwargs
    ):
        CustomizedColorizer.__init__(self, version="sdxl", *args, **kwargs)
        CustomizedWrapper.__init__(self)
        self.scalar_embedder = instantiate_from_config(scalar_embedder_config)
        self.img_embedder = instantiate_from_config(img_embedder_config)
        self.loras = LoraModules(self, **lora_config) if exists(lora_config) else None
        self.logits_embed = logits_embed

        # Register the extra conditioning modules alongside the base model list.
        new_model_list = {
            "scalar_embedder": self.scalar_embedder,
            "img_embedder": self.img_embedder,
            # "style_encoder": self.style_encoder,
        }
        self.switch_cond_modules += list(new_model_list.keys())
        self.model_list.update(new_model_list)
    def retrieve_attn_modules(self):
        scale_factor_levels = {"high": 0.5, "low": 0.25, "bottom": 0.25}
        from refnet.modules.transformer import BasicTransformerBlock
        from refnet.sampling import torch_dfs

        # Collect every BasicTransformerBlock of the UNet in depth-first order.
        attn_modules = []
        for module in torch_dfs(self.model.diffusion_model):
            if isinstance(module, BasicTransformerBlock):
                attn_modules.append(module)

        # Index groups over the collected blocks; "high"/"low"/"bottom" receive
        # the per-group scale factors assigned below.
        self.attn_modules = {
            "high": [0, 1, 2, 3] + [64, 65, 66, 67, 68, 69],
            "low": list(range(4, 24)) + list(range(34, 64)),
            "bottom": list(range(24, 34)),
            "encoder": list(range(24)),
            "decoder": list(range(34, len(attn_modules))),
        }
        self.attn_modules["modules"] = attn_modules

        for k in ["high", "low", "bottom"]:
            scale_factor = scale_factor_levels[k]
            for attn in self.attn_modules[k]:
                attn_modules[attn].scale_factor = scale_factor
    def adjust_reference_scale(self, scale_kwargs):
        for module in self.attn_modules["modules"]:
            module.reference_scale = scale_kwargs["scales"]["encoder"]

    def adjust_masked_attn(self, scale, mask_threshold, merge_scale):
        for layer in self.attn_layers:
            layer.mask_scale = scale
            layer.mask_threshold = mask_threshold
            layer.merge_scale = merge_scale
    def rescale_size(self, x: torch.Tensor, height, width):
        # If the input is smaller than the target in either dimension, scale it
        # up proportionally so the more deficient side reaches the target size.
        oh, ow = x.shape[2:]
        if oh < height or ow < width:
            dh, dw = height - oh, width - ow
            if dh > dw:
                iw = ow + int(dh * ow / oh)
                ih = height
            else:
                ih = oh + int(dw * oh / ow)
                iw = width
        else:
            ih, iw = oh, ow
        return torch.Tensor([ih]), torch.Tensor([iw])
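    # Illustrative note (not from the original file): a worked example of the
    # rescale_size arithmetic above. For x of shape (B, C, 512, 768) with a
    # target of height=768, width=768: dh = 256 > dw = 0, so ih = 768 and
    # iw = 768 + int(256 * 768 / 512) = 1152, i.e. the input is upscaled by
    # 768 / 512 = 1.5 in both dimensions, preserving its aspect ratio.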
    def get_learned_embedding(self, c, bg=False, mapping=False, sketch=None, *args, **kwargs):
        clip_emb = self.cond_stage_model.encode(c, "full").detach()
        wd_emb, logits = self.img_embedder.encode(c, pooled=False, return_logits=True)
        cls_emb, local_emb = clip_emb[:, :1], clip_emb[:, 1:]

        if mapping:
            _, sketch_logits = self.img_embedder.encode(-sketch, pooled=False, return_logits=True)
            # NOTE: the mean below is computed but never assigned, so it has no effect.
            sketch_logits.mean(dim=1, keepdim=True)
            logits = self.img_embedder.geometry_update(logits, sketch_logits)

        emb = self.proj(clip_emb, logits if self.logits_embed else wd_emb, bg)
        return emb, cls_emb
    def prepare_conditions(
            self,
            bs,
            sketch,
            reference,
            height,
            width,
            control_scale=(1., 1., 1., 1.),
            merge_scale=0,
            mask_scale=1.,
            fg_scale=1.,
            bg_scale=1.,
            smask=None,
            rmask=None,
            mask_threshold_ref=0.,
            mask_threshold_sketch=0.,
            style_enhance=False,
            fg_enhance=False,
            bg_enhance=False,
            background=None,
            targets=None,
            anchors=None,
            controls=None,
            target_scales=None,
            enhances=None,
            thresholds_list=None,
            geometry_map=False,
            latent_inpaint=False,
            low_vram=False,
            *args,
            **kwargs
    ):
        # prepare reference embedding
        # manipulate = self.check_manipulate(target_scales)
        c = {}
        uc = [{}, {}]

        if exists(reference):
            emb, cls_emb = self.get_learned_embedding(reference, sketch=sketch, mapping=geometry_map)
        else:
            emb, cls_emb = map(lambda t: torch.zeros_like(t), self.get_learned_embedding(sketch))

        # Scalar conditioning: sqrt(H * W) plus a fixed score of 7.
        h, w, score = torch.Tensor([height]), torch.Tensor([width]), torch.Tensor([7.])
        y = torch.cat(self.scalar_embedder(torch.cat([(h * w) ** 0.5, score])).cuda().chunk(2), 1)

        if bg_enhance:
            assert exists(rmask) and exists(smask)
            if low_vram:
                self.low_vram_shift(["first", "cond", "img_embedder", "proj"])

            if latent_inpaint and exists(background):
                bgh, bgw = background.shape[2:]
                ch, cw = get_crop_scale(torch.tensor([height]), torch.tensor([width]), bgh, bgw)
                hs_bg = self.get_first_stage_encoding(
                    resize_and_crop(background, ch, cw, height, width).to(self.first_stage_model.dtype)
                )
                bg_emb, _ = self.get_learned_embedding(background, bg=True)
                hs_bg = expand_to_batch_size(hs_bg, bs)
                c.update({"inpaint_bg": hs_bg})
            else:
                if exists(background):
                    bg_emb, _ = self.get_learned_embedding(background, bg=True)
                else:
                    # Embed the reference with the masked-out region replaced by ones
                    # as the background condition.
                    bg_emb, _ = self.get_learned_embedding(
                        torch.where(rmask < mask_threshold_ref, reference, torch.ones_like(reference)),
                        True
                    )
            emb = torch.cat([emb, bg_emb], 1)

        if fg_enhance and exists(self.loras):
            self.loras.switch_lora(True, "foreground")
            if not bg_enhance:
                emb = emb.repeat(1, 2, 1)

        if fg_enhance or bg_enhance:
            # sketch mask for cross-attention
            smask = expand_to_batch_size(smask.to(self.dtype), bs)
            for d in [c] + uc:
                d.update({"mask": F.interpolate(smask, scale_factor=0.125)})
        elif exists(self.loras):
            self.loras.switch_lora(False)

        sketch = sketch.to(self.dtype)
        context = expand_to_batch_size(emb, bs).to(self.dtype)
        y = expand_to_batch_size(y, bs)
        uc_context = torch.zeros_like(context)

        # Encode the sketch and a blank (all -1) sketch in one pass; the latter
        # provides the unconditional control branch.
        control = []
        uc_control = []
        if low_vram:
            self.low_vram_shift(["control_encoder"])
        encoded_sketch = self.control_encoder(
            torch.cat([sketch, -torch.ones_like(sketch)], 0)
        )
        for idx, es in enumerate(encoded_sketch):
            es = es * control_scale[idx]
            ec, uec = es.chunk(2)
            control.append(expand_to_batch_size(ec, bs))
            uc_control.append(expand_to_batch_size(uec, bs))

        # uc[0] drops the reference context, uc[1] drops the sketch control.
        c.update({"control": control, "context": [context], "y": [y]})
        uc[0].update({"control": control, "context": [uc_context], "y": [y]})
        uc[1].update({"control": uc_control, "context": [context], "y": [y]})
        return c, uc
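

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module). It relies only on the
# method names and signatures defined above; the config path, model
# instantiation, and tensor shapes are illustrative assumptions.
#
#   from omegaconf import OmegaConf
#   model = instantiate_from_config(OmegaConf.load("configs/inference.yaml").model)  # hypothetical config
#   model = model.cuda().eval()
#   model.retrieve_attn_modules()
#
#   sketch = torch.randn(1, 3, 1024, 1024).cuda()      # stand-in line art in [-1, 1]
#   reference = torch.randn(1, 3, 1024, 1024).cuda()   # stand-in colour reference
#   c, uc = model.prepare_conditions(
#       bs=2, sketch=sketch, reference=reference, height=1024, width=1024,
#   )
#   # c carries the conditional inputs ("control", "context", "y"); uc holds the
#   # two unconditional variants used for classifier-free guidance.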