import torch
import torch.nn.functional as F

from refnet.models.basemodel import CustomizedColorizer, CustomizedWrapper
from refnet.util import *
from refnet.modules.lora import LoraModules
from refnet.modules.reference_net import hack_unet_forward, hack_inference_forward
from refnet.sampling.hook import ReferenceAttentionControl


class InferenceWrapperXL(CustomizedWrapper, CustomizedColorizer):
    def __init__(
            self,
            scalar_embedder_config,
            img_embedder_config,
            fg_encoder_config = None,
            bg_encoder_config = None,
            style_encoder_config = None,
            lora_config = None,
            logits_embed = False,
            controller = False,
            *args,
            **kwargs
    ):
        CustomizedColorizer.__init__(self, version="sdxl", *args, **kwargs)
        CustomizedWrapper.__init__(self)
        self.logits_embed = logits_embed

        # Optional encoders are only instantiated when a config is provided.
        (
            self.scalar_embedder,
            self.img_embedder,
            self.fg_encoder,
            self.bg_encoder,
            self.style_encoder
        ) = map(
            lambda t: instantiate_from_config(t) if exists(t) else None,
            (
                scalar_embedder_config,
                img_embedder_config,
                fg_encoder_config,
                bg_encoder_config,
                style_encoder_config
            )
        )
        # tolerate the declared None default
        self.loras = LoraModules(self, **(lora_config or {}))

        if controller:
            self.controller = ReferenceAttentionControl(
                # time_embed_ch = self.model.diffusion_model.model_channels * 4,
                reader_module = self.model.diffusion_model,
                writer_module = self.bg_encoder,
                # only_decoder = True
            )
        else:
            self.controller = None

        new_model_list = {
            # "style_encoder": self.style_encoder,
            "scalar_embedder": self.scalar_embedder,
            "img_embedder": self.img_embedder,
            # "controller": self.controller
        }
        hack_unet_forward(self.model.diffusion_model)
        if exists(self.fg_encoder):
            hack_inference_forward(self.fg_encoder)
            new_model_list["fg_encoder"] = self.fg_encoder
        if exists(self.bg_encoder):
            hack_inference_forward(self.bg_encoder)
            new_model_list["bg_encoder"] = self.bg_encoder
        # hack_inference_forward(self.style_encoder)

        self.switch_cond_modules += list(new_model_list.keys())
        # self.switch_main_modules += ["controller"]
        self.model_list.update(new_model_list)

    def switch_to_fp16(self):
        super().switch_to_fp16()
        self.model.diffusion_model.map_modules.to(self.half_precision_dtype)
        self.model.diffusion_model.warp_modules.to(self.half_precision_dtype)
        self.model.diffusion_model.style_modules.to(self.half_precision_dtype)
        self.model.diffusion_model.conv_fg.to(self.half_precision_dtype)
        if exists(self.fg_encoder):
            self.fg_encoder.to(self.half_precision_dtype)
            self.fg_encoder.dtype = self.half_precision_dtype
            # keep timestep embeddings in fp32 for numerical stability
            self.fg_encoder.time_embed.float()
        if exists(self.bg_encoder):
            self.bg_encoder.to(self.half_precision_dtype)
            self.bg_encoder.dtype = self.half_precision_dtype
            self.bg_encoder.time_embed.float()
        # self.style_encoder.to(self.half_precision_dtype)
        # self.style_encoder.dtype = self.half_precision_dtype
        # self.style_encoder.time_embed.float()

    def switch_to_fp32(self):
        super().switch_to_fp32()
        self.model.diffusion_model.map_modules.float()
        self.model.diffusion_model.warp_modules.float()
        self.model.diffusion_model.style_modules.float()
        self.model.diffusion_model.conv_fg.float()
        # guard the optional encoders, mirroring switch_to_fp16
        if exists(self.fg_encoder):
            self.fg_encoder.float()
            self.fg_encoder.dtype = torch.float32
        if exists(self.bg_encoder):
            self.bg_encoder.float()
            self.bg_encoder.dtype = torch.float32
        # self.style_encoder.float()
        # self.style_encoder.dtype = torch.float32

    def rescale_size(self, x: torch.Tensor, height, width):
        # Upscale (oh, ow) to cover (height, width) while preserving aspect ratio.
        oh, ow = x.shape[2:]
        if oh < height or ow < width:
            dh, dw = height - oh, width - ow
            if dh > dw:
                iw = ow + int(dh * ow / oh)
                ih = height
            else:
                ih = oh + int(dw * oh / ow)
                iw = width
        else:
            ih, iw = oh, ow
        return torch.tensor([ih]), torch.tensor([iw])
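
    # Worked example for rescale_size (illustrative numbers, not from the repo):
    # upscaling a 512x768 reference to height = width = 1024 gives dh, dw = 512, 256;
    # since dh > dw, ih = 1024 and iw = 768 + int(512 * 768 / 512) = 1536, so the
    # 2:3 aspect ratio is kept while both sides reach at least the target size.
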
    def rescale_background_size(self, x, height, width):
        oh, ow = x.shape[2:]
        if oh < height or ow < width:
            # A simple bias to avoid deterioration caused by reference resolution
            mind = max(height, width)
            ih = oh + mind
            iw = ow / oh * ih
        else:
            ih, iw = oh, ow
        # rh, rw = ih / height, iw / width
        return torch.tensor([ih]), torch.tensor([iw])

    def get_learned_embedding(self, c, bg=False, sketch=None, mapping=False, *args, **kwargs):
        clip_emb = self.cond_stage_model.encode(c, "full").detach()
        wd_emb, logits = self.img_embedder.encode(c, pooled=False, return_logits=True)
        cls_emb, local_emb = clip_emb[:, :1], clip_emb[:, 1:]
        if self.logits_embed and exists(sketch) and mapping:
            # align the reference logits with the sketch geometry
            _, sketch_logits = self.img_embedder.encode(-sketch, pooled=True, return_logits=True)
            logits = self.img_embedder.geometry_update(logits, sketch_logits)
        if self.logits_embed:
            emb = self.proj(clip_emb, logits, bg)[0]
        else:
            emb = self.proj(clip_emb, wd_emb, bg)
        return emb.to(self.dtype), cls_emb.to(self.dtype)

    def prepare_conditions(
            self,
            bs,
            sketch,
            reference,
            height,
            width,
            control_scale = 1,
            mask_scale = 1,
            merge_scale = 0.,
            cond_aug = 0.,
            background = None,
            smask = None,
            rmask = None,
            mask_threshold_ref = 0.,
            mask_threshold_sketch = 0.,
            style_enhance = False,
            fg_enhance = False,
            bg_enhance = False,
            latent_inpaint = False,
            fg_disentangle_scale = 1.,
            targets = None,
            anchors = None,
            controls = None,
            target_scales = None,
            enhances = None,
            thresholds_list = None,
            low_vram = False,
            *args,
            **kwargs
    ):
        def prepare_style_modulations(y):
            # Style enhancement part (closes over emb, ct, cl, h, w defined below)
            z_ref = self.get_first_stage_encoding(warp_resize(reference, (height, width)))
            if exists(background) and merge_scale > 0:
                rh, rw = self.rescale_size(background, height, width)
                z_bg = self.get_first_stage_encoding(warp_resize(background, (height, width)))
                bg_emb, bg_cls_emb = self.get_learned_embedding(background)
                scalar_embed = torch.cat(
                    self.scalar_embedder(torch.cat([rh, rw, ct, cl, h, w])).chunk(6), 1
                ).to(bg_emb.device)
                bgy = torch.cat([bg_cls_emb.squeeze(1), scalar_embed], 1).to(self.dtype)
                style_modulations = self.style_encoder(
                    torch.cat([z_ref, z_bg]),
                    timesteps = torch.zeros((2,), dtype=torch.long, device=z_ref.device),
                    context = torch.cat([emb, bg_emb]),
                    y = torch.cat([y, bgy])
                )
                for idx, m in enumerate(style_modulations):
                    fg, bg = m.chunk(2)
                    m = fg * (1 - merge_scale) + merge_scale * bg
                    style_modulations[idx] = expand_to_batch_size(m, bs).to(self.dtype)
            else:
                z_bg = None
                bg_emb = None
                bgy = None
                style_modulations = self.style_encoder(
                    z_ref,
                    timesteps = torch.zeros((1,), dtype=torch.long, device=z_ref.device),
                    context = emb,
                    y = y,
                )
                style_modulations = [expand_to_batch_size(m, bs).to(self.dtype) for m in style_modulations]
            return style_modulations, z_bg, bg_emb, bgy
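
        # Note on the merge above (a reading of the code, not documented in the
        # repo): each style modulation linearly interpolates the reference (fg)
        # and background (bg) halves, e.g. merge_scale = 0.3 yields
        # 0.7 * fg + 0.3 * bg; merge_scale = 0 keeps the pure reference style
        # and merge_scale = 1 adopts the background style entirely.
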
        def prepare_background_latents(z_bg, bg_emb, bgy):
            # Background enhancement part
            bgh, bgw = background.shape[2:] if exists(background) else reference.shape[2:]
            ch, cw = get_crop_scale(h, w, bgh, bgw)
            if low_vram:
                self.low_vram_shift(["first", "cond", "img_embedder"])

            if latent_inpaint and exists(background):
                hs_bg = self.get_first_stage_encoding(resize_and_crop(background, ch, cw, height, width))
                bg_emb, cls_emb = self.get_learned_embedding(background)
            else:
                if not exists(z_bg):
                    bgy = torch.cat(
                        self.scalar_embedder(torch.tensor([ct, cl, ch, cw])).chunk(4), 1
                        # self.scalar_embedder(torch.tensor([bgh / bgw, h / w, ct, cl, ch, cw])).chunk(6), 1
                    ).to(self.dtype).cuda()
                    if exists(background):
                        # bgh, bgw = self.rescale_background_size(background, height, width)
                        z_bg = self.get_first_stage_encoding(warp_resize(background, (height, width)))
                        bg_emb, cls_emb = self.get_learned_embedding(background)
                        # scalar_embed = torch.cat(self.scalar_embedder(torch.cat([bgh, bgw, ct, cl, h, w])).chunk(6), 1).cuda()
                        # bgy = torch.cat([cls_emb.squeeze(1), scalar_embed], 1).to(self.dtype)
                    else:
                        # fall back to the reference with the foreground masked out
                        xbg = torch.where(rmask < mask_threshold_ref, reference, torch.ones_like(reference))
                        z_bg = self.get_first_stage_encoding(warp_resize(xbg, (height, width)))
                        bg_emb, cls_emb = self.get_learned_embedding(xbg)

                if low_vram:
                    self.low_vram_shift(["bg_encoder"])
                hs_bg = self.bg_encoder(
                    x = torch.cat([
                        z_bg,
                        # torch.where(
                        #     smask > mask_threshold_sketch,
                        #     torch.zeros_like(smask),
                        #     F.interpolate(warp_resize(rmask, (height, width)), scale_factor=0.125)
                        # )
                        F.interpolate(warp_resize(smask, (height, width)), scale_factor=0.125),
                        F.interpolate(warp_resize(rmask, (height, width)), scale_factor=0.125)
                    ], 1),
                    timesteps = torch.zeros((1,), dtype=torch.long, device=z_bg.device),
                    # context = bg_emb,
                    y = bgy.to(self.dtype),
                )
            return hs_bg, bg_emb
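
        # Assumed input layout of bg_encoder (inferred from the concatenation
        # above, not stated elsewhere in this file): 4 VAE-latent channels
        # (z_bg) + 1 sketch-mask channel + 1 reference-mask channel = 6, with
        # both masks downsampled by scale_factor = 0.125 to match the
        # 8x-downsampled latent grid.
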
        self.loras.recover_lora()

        # prepare reference embedding
        # manipulate = self.check_manipulate(target_scales)
        c = {}
        uc = [{}, {}]
        self.loras.switch_lora(False)
        if exists(reference):
            emb, cls_emb = self.get_learned_embedding(reference, sketch=sketch)
            # rh, rw = reference.shape[2:]
            # rh, rw = self.rescale_background_size(reference, height, width)
        else:
            emb, cls_emb = map(lambda t: torch.zeros_like(t), self.get_learned_embedding(sketch))
            # rh, rw = torch.Tensor([height]), torch.Tensor([width])

        ct, cl = torch.Tensor([0]), torch.Tensor([0])
        # h, w = torch.Tensor([height]), torch.Tensor([width])
        # scalar_embed = torch.cat(self.scalar_embedder(torch.cat([rh, rw, ct, cl, h, w])).chunk(6), 1).cuda()
        # y = torch.cat([cls_emb.squeeze(1), scalar_embed], 1)
        # y = self.scalar_embedder((h*w)**0.5).cuda()
        # y = torch.cat(self.scalar_embedder(torch.cat([h, w])).chunk(2), 1).cuda()
        h, w, score = torch.Tensor([height]), torch.Tensor([width]), torch.Tensor([7.])
        y = torch.cat(self.scalar_embedder(torch.cat([(h * w) ** 0.5, score])).cuda().chunk(2), 1)
        z_bg, bg_emb, bgy = None, None, None

        # Style enhance part
        if style_enhance:
            style_modulations, z_bg, bg_emb, bgy = prepare_style_modulations(y)
            for d in [c] + uc:
                d.update({"style_modulations": style_modulations})

        # Foreground enhance part
        if fg_enhance:
            assert exists(smask) and exists(rmask)
            self.loras.switch_lora(True, "foreground")
            if low_vram:
                self.low_vram_shift(["first"])
            # encode the reference with the background masked out
            z_fg = self.get_first_stage_encoding(warp_resize(
                torch.where(rmask >= mask_threshold_ref, reference, torch.ones_like(reference)),
                (height, width)
            )) * fg_disentangle_scale
            # z_ref = default(z_ref, self.get_first_stage_encoding(warp_resize(reference, (height, width))))
            # self.loras.switch_lora(True, False)
            self.loras.adjust_lora_scales(fg_disentangle_scale, "foreground")
            if low_vram:
                self.low_vram_shift(["fg_encoder"])
            hs_fg = self.fg_encoder(
                z_fg,
                timesteps = torch.zeros((1,), dtype=torch.long, device=z_fg.device),
            )
            # hs_fg = [hs * fg_disentangle_scale for hs in hs_fg]
            hs_fg = expand_to_batch_size(hs_fg, bs)
            for d in [c] + uc:
                d.update({
                    "hs_fg": hs_fg,
                    "inject_mask": expand_to_batch_size(smask, bs),
                })
            # for d in [c] + uc:
            #     d.update({"z_fg": expand_to_batch_size(z_fg, bs)})

        # Background enhance part
        if bg_enhance:
            assert exists(rmask) and exists(smask)
            # if not self.controller.hooked:
            #     self.controller.register("read", self.model.diffusion_model)
            # self.loras.switch_lora(False, True)
            hs_bg, bg_emb = prepare_background_latents(z_bg, bg_emb, default(bgy, y))
            self.loras.switch_lora(True, "background")
            if latent_inpaint and exists(background):
                hs_bg = expand_to_batch_size(hs_bg, bs)
                c.update({"inpaint_bg": hs_bg})
            elif exists(self.controller):
                # self.loras.merge_lora()
                self.controller.update()
            else:
                hs_bg = expand_to_batch_size(hs_bg, bs)
                for d in [c] + uc:
                    d.update({"hs_bg": hs_bg})
        elif exists(self.controller):
            # self.controller.reader_restore()
            self.controller.clean()

        if fg_enhance or bg_enhance:
            # need to activate mask-guided split cross-attention
            emb = torch.cat([emb, default(bg_emb, emb)], 1)
            smask = expand_to_batch_size(smask.to(self.dtype), bs)
            for d in [c] + uc:
                d.update({"mask": F.interpolate(smask, scale_factor=0.125), "threshold": mask_threshold_sketch})
        # if fg_enhance and bg_enhance:
        #     self.loras.switch_lora(True, True)

        sketch = sketch.to(self.dtype)
        context = expand_to_batch_size(emb, bs).to(self.dtype)
        y = expand_to_batch_size(y, bs).float()
        uc_context = torch.zeros_like(context)

        control = []
        uc_control = []
        if low_vram:
            self.low_vram_shift(["control_encoder"])
        # encode the sketch and a blank counterpart in one batch, then split
        encoded_sketch = self.control_encoder(
            torch.cat([sketch, -torch.ones_like(sketch)], 0)
        )
        if not isinstance(control_scale, (list, tuple)):
            # broadcast a scalar control_scale to one entry per feature level
            control_scale = [control_scale] * len(encoded_sketch)
        for idx, es in enumerate(encoded_sketch):
            es = es * control_scale[idx]
            ec, uec = es.chunk(2)
            control.append(expand_to_batch_size(ec, bs))
            uc_control.append(expand_to_batch_size(uec, bs))

        self.loras.merge_lora()
        c.update({"control": control, "context": [context], "y": [y]})
        uc[0].update({"control": control, "context": [uc_context], "y": [y]})
        uc[1].update({"control": uc_control, "context": [context], "y": [y]})
        return c, uc
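

# Minimal usage sketch (illustrative only; the config path, OmegaConf loading,
# and tensor shapes below are assumptions, not part of this file):
#
#     model = instantiate_from_config(OmegaConf.load("configs/inference_xl.yaml").model)
#     model.switch_to_fp16()
#     c, uc = model.prepare_conditions(
#         bs = 4,
#         sketch = sketch, reference = reference,   # (1, C, H, W) tensors in [-1, 1]
#         height = 1024, width = 1024,
#         control_scale = 1.,                       # scalar is broadcast per level
#     )
#     # c is the conditional branch; uc[0] zeroes the reference context and
#     # uc[1] swaps in the blank-sketch control, giving the two unconditional
#     # branches used for classifier-free guidance.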