import torch
import torch.nn.functional as F

from refnet.models.basemodel import CustomizedColorizer, CustomizedWrapper
from refnet.util import *
from refnet.modules.lora import LoraModules
from refnet.modules.reference_net import hack_unet_forward, hack_inference_forward
from refnet.sampling.hook import ReferenceAttentionControl


class InferenceWrapperXL(CustomizedWrapper, CustomizedColorizer):
    def __init__(
            self,
            scalar_embedder_config,
            img_embedder_config,
            fg_encoder_config = None,
            bg_encoder_config = None,
            style_encoder_config = None,
            lora_config = None,
            logits_embed = False,
            controller = False,
            *args,
            **kwargs
    ):
        # Pass version as a keyword after *args so positional arguments
        # cannot collide with it.
        CustomizedColorizer.__init__(self, *args, version="sdxl", **kwargs)
        CustomizedWrapper.__init__(self)
        self.logits_embed = logits_embed
        # Instantiate each optional sub-model from its config; absent
        # configs stay None.
        (
            self.scalar_embedder,
            self.img_embedder,
            self.fg_encoder,
            self.bg_encoder,
            self.style_encoder
        ) = map(
            lambda t: instantiate_from_config(t) if exists(t) else None,
            (
                scalar_embedder_config,
                img_embedder_config,
                fg_encoder_config,
                bg_encoder_config,
                style_encoder_config
            )
        )
        # Guard against lora_config=None (the default) before unpacking.
        self.loras = LoraModules(self, **(lora_config or {}))
        if controller:
            # Reference attention: the main UNet reads attention states
            # written by the background encoder.
            self.controller = ReferenceAttentionControl(
                # time_embed_ch = self.model.diffusion_model.model_channels * 4,
                reader_module = self.model.diffusion_model,
                writer_module = self.bg_encoder,
                # only_decoder = True
            )
        else:
            self.controller = None

        new_model_list = {
            # "style_encoder": self.style_encoder,
            "scalar_embedder": self.scalar_embedder,
            "img_embedder": self.img_embedder,
            # "controller": self.controller
        }
        hack_unet_forward(self.model.diffusion_model)
        if exists(self.fg_encoder):
            hack_inference_forward(self.fg_encoder)
            new_model_list["fg_encoder"] = self.fg_encoder
        if exists(self.bg_encoder):
            hack_inference_forward(self.bg_encoder)
            new_model_list["bg_encoder"] = self.bg_encoder
        # hack_inference_forward(self.bg_encoder)
        # hack_inference_forward(self.style_encoder)
        self.switch_cond_modules += list(new_model_list.keys())
        # self.switch_main_modules += ["controller"]
        self.model_list.update(new_model_list)

    def switch_to_fp16(self):
        super().switch_to_fp16()
        self.model.diffusion_model.map_modules.to(self.half_precision_dtype)
        self.model.diffusion_model.warp_modules.to(self.half_precision_dtype)
        self.model.diffusion_model.style_modules.to(self.half_precision_dtype)
        self.model.diffusion_model.conv_fg.to(self.half_precision_dtype)
        if exists(self.fg_encoder):
            self.fg_encoder.to(self.half_precision_dtype)
            self.fg_encoder.dtype = self.half_precision_dtype
            # Keep timestep embeddings in fp32 for numerical stability.
            self.fg_encoder.time_embed.float()
        if exists(self.bg_encoder):
            self.bg_encoder.to(self.half_precision_dtype)
            self.bg_encoder.dtype = self.half_precision_dtype
            self.bg_encoder.time_embed.float()
        # self.style_encoder.to(self.half_precision_dtype)
        # self.style_encoder.dtype = self.half_precision_dtype
        # self.style_encoder.time_embed.float()

    def switch_to_fp32(self):
        super().switch_to_fp32()
        self.model.diffusion_model.map_modules.float()
        self.model.diffusion_model.warp_modules.float()
        self.model.diffusion_model.style_modules.float()
        # Restore conv_fg as well, for symmetry with switch_to_fp16.
        self.model.diffusion_model.conv_fg.float()
        # Mirror the exists() guards of switch_to_fp16 so optional encoders
        # do not raise when absent.
        if exists(self.fg_encoder):
            self.fg_encoder.float()
            self.fg_encoder.dtype = torch.float32
        if exists(self.bg_encoder):
            self.bg_encoder.float()
            self.bg_encoder.dtype = torch.float32
        # self.style_encoder.float()
        # self.style_encoder.dtype = torch.float32
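
    # Illustrative precision round trip (usage assumed, not defined here):
    #   wrapper.switch_to_fp16()   # adapters and encoders to fp16,
    #                              # time embeddings stay fp32
    #   ...run sampling...
    #   wrapper.switch_to_fp32()   # restore full precision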

    def rescale_size(self, x: torch.Tensor, height, width):
        # Upscale (ih, iw) so the reference covers (height, width) while
        # keeping the aspect ratio of x; larger inputs pass through.
        oh, ow = x.shape[2:]
        if oh < height or ow < width:
            dh, dw = height - oh, width - ow
            if dh > dw:
                iw = ow + int(dh * ow / oh)
                ih = height
            else:
                ih = oh + int(dw * oh / ow)
                iw = width
        else:
            ih, iw = oh, ow
        return torch.tensor([ih]), torch.tensor([iw])
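
    # Worked example (illustrative): x of shape (1, 3, 400, 600) with target
    # 512x512 gives dh=112 > dw=-88, so ih=512 and
    # iw = 600 + int(112 * 600 / 400) = 768, preserving the 2:3 aspect ratio.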

    def rescale_background_size(self, x, height, width):
        oh, ow = x.shape[2:]
        if oh < height or ow < width:
            # A simple bias to avoid deterioration caused by reference
            # resolution: grow the height by the larger target dimension,
            # then scale the width to keep the aspect ratio of x.
            mind = max(height, width)
            ih = oh + mind
            iw = ow / oh * ih
        else:
            ih, iw = oh, ow
        # rh, rw = ih / height, iw / width
        return torch.tensor([ih]), torch.tensor([iw])
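
    # Worked example (illustrative): x of shape (1, 3, 400, 600) with target
    # 512x512 gives ih = 400 + 512 = 912 and iw = 600 / 400 * 912 = 1368.0
    # (note iw stays a float here, unlike rescale_size).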

    def get_learned_embedding(self, c, bg=False, sketch=None, mapping=False, *args, **kwargs):
        clip_emb = self.cond_stage_model.encode(c, "full").detach()
        wd_emb, logits = self.img_embedder.encode(c, pooled=False, return_logits=True)
        cls_emb, local_emb = clip_emb[:, :1], clip_emb[:, 1:]
        if self.logits_embed and exists(sketch) and mapping:
            # Geometry-correct the image-embedder logits with the sketch.
            _, sketch_logits = self.img_embedder.encode(-sketch, pooled=True, return_logits=True)
            logits = self.img_embedder.geometry_update(logits, sketch_logits)
        if self.logits_embed:
            emb = self.proj(clip_emb, logits, bg)[0]
        else:
            emb = self.proj(clip_emb, wd_emb, bg)
        return emb.to(self.dtype), cls_emb.to(self.dtype)
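
    # Note (illustrative): emb is used as cross-attention context, while
    # cls_emb (the CLS token, shape (B, 1, D)) is folded into the scalar
    # conditioning vector bgy in prepare_style_modulations below.
    #   emb, cls_emb = self.get_learned_embedding(reference, sketch=sketch)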

    def prepare_conditions(
            self,
            bs,
            sketch,
            reference,
            height,
            width,
            control_scale = 1,
            mask_scale = 1,
            merge_scale = 0.,
            cond_aug = 0.,
            background = None,
            smask = None,
            rmask = None,
            mask_threshold_ref = 0.,
            mask_threshold_sketch = 0.,
            style_enhance = False,
            fg_enhance = False,
            bg_enhance = False,
            latent_inpaint = False,
            fg_disentangle_scale = 1.,
            targets = None,
            anchors = None,
            controls = None,
            target_scales = None,
            enhances = None,
            thresholds_list = None,
            low_vram = False,
            *args,
            **kwargs
    ):
        def prepare_style_modulations(y):
            # Style enhancement part
            z_ref = self.get_first_stage_encoding(warp_resize(reference, (height, width)))
            if exists(background) and merge_scale > 0:
                rh, rw = self.rescale_size(background, height, width)
                z_bg = self.get_first_stage_encoding(warp_resize(background, (height, width)))
                bg_emb, bg_cls_emb = self.get_learned_embedding(background)
                scalar_embed = torch.cat(
                    self.scalar_embedder(torch.cat([rh, rw, ct, cl, h, w])).chunk(6), 1
                ).to(bg_emb.device)
                bgy = torch.cat([bg_cls_emb.squeeze(1), scalar_embed], 1).to(self.dtype)
                style_modulations = self.style_encoder(
                    torch.cat([z_ref, z_bg]),
                    timesteps = torch.zeros((2,), dtype=torch.long, device=z_ref.device),
                    context = torch.cat([emb, bg_emb]),
                    y = torch.cat([y, bgy])
                )
                # Blend foreground and background modulations by merge_scale.
                for idx, m in enumerate(style_modulations):
                    fg, bg = m.chunk(2)
                    m = fg * (1 - merge_scale) + merge_scale * bg
                    style_modulations[idx] = expand_to_batch_size(m, bs).to(self.dtype)
            else:
                z_bg = None
                bg_emb = None
                bgy = None
                style_modulations = self.style_encoder(
                    z_ref,
                    timesteps = torch.zeros((1,), dtype=torch.long, device=z_ref.device),
                    context = emb,
                    y = y,
                )
                style_modulations = [expand_to_batch_size(m, bs).to(self.dtype) for m in style_modulations]
            return style_modulations, z_bg, bg_emb, bgy

        def prepare_background_latents(z_bg, bg_emb, bgy):
            # Background enhancement part
            bgh, bgw = background.shape[2:] if exists(background) else reference.shape[2:]
            ch, cw = get_crop_scale(h, w, bgh, bgw)
            if low_vram:
                self.low_vram_shift(["first", "cond", "img_embedder"])
            if latent_inpaint and exists(background):
                hs_bg = self.get_first_stage_encoding(resize_and_crop(background, ch, cw, height, width))
                bg_emb, cls_emb = self.get_learned_embedding(background)
            else:
                if not exists(z_bg):
                    bgy = torch.cat(
                        self.scalar_embedder(torch.tensor([ct, cl, ch, cw])).chunk(4), 1
                        # self.scalar_embedder(torch.tensor([bgh / bgw, h / w, ct, cl, ch, cw])).chunk(6), 1
                    ).to(self.dtype).cuda()
                    if exists(background):
                        # bgh, bgw = self.rescale_background_size(background, height, width)
                        z_bg = self.get_first_stage_encoding(warp_resize(background, (height, width)))
                        bg_emb, cls_emb = self.get_learned_embedding(background)
                        # scalar_embed = torch.cat(self.scalar_embedder(torch.cat([bgh, bgw, ct, cl, h, w])).chunk(6), 1).cuda()
                        # bgy = torch.cat([cls_emb.squeeze(1), scalar_embed], 1).to(self.dtype)
                    else:
                        # No explicit background: mask the reference down to
                        # its non-foreground region and use that instead.
                        xbg = torch.where(rmask < mask_threshold_ref, reference, torch.ones_like(reference))
                        z_bg = self.get_first_stage_encoding(warp_resize(xbg, (height, width)))
                        bg_emb, cls_emb = self.get_learned_embedding(xbg)
                if low_vram:
                    self.low_vram_shift(["bg_encoder"])
                hs_bg = self.bg_encoder(
                    x = torch.cat([
                        z_bg,
                        # torch.where(
                        #     smask > mask_threshold_sketch,
                        #     torch.zeros_like(smask),
                        #     F.interpolate(warp_resize(rmask, (height, width)), scale_factor=0.125)
                        # )
                        F.interpolate(warp_resize(smask, (height, width)), scale_factor=0.125),
                        F.interpolate(warp_resize(rmask, (height, width)), scale_factor=0.125)
                    ], 1),
                    timesteps = torch.zeros((1,), dtype=torch.long, device=z_bg.device),
                    # context = bg_emb,
                    y = bgy.to(self.dtype),
                )
            return hs_bg, bg_emb
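
        # Note: prepare_background_latents yields either VAE latents of the
        # cropped background (latent-inpaint path) or bg_encoder hidden
        # states (reference-injection path); bg_enhance below handles both.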

        self.loras.recover_lora()
        # prepare reference embedding
        # manipulate = self.check_manipulate(target_scales)
        c = {}
        uc = [{}, {}]
        self.loras.switch_lora(False)
        # self.loras.recover_lora()
        if exists(reference):
            emb, cls_emb = self.get_learned_embedding(reference, sketch=sketch)
            # rh, rw = reference.shape[2:]
            # rh, rw = self.rescale_background_size(reference, height, width)
        else:
            emb, cls_emb = map(lambda t: torch.zeros_like(t), self.get_learned_embedding(sketch))
            # rh, rw = torch.Tensor([height]), torch.Tensor([width])
        # ct and cl are needed by both helper closures above, so define them
        # unconditionally.
        ct, cl = torch.Tensor([0]), torch.Tensor([0])
        # h, w = torch.Tensor([height]), torch.Tensor([width])
        # scalar_embed = torch.cat(self.scalar_embedder(torch.cat([rh, rw, ct, cl, h, w])).chunk(6), 1).cuda()
        # y = torch.cat([cls_emb.squeeze(1), scalar_embed], 1)
        # y = self.scalar_embedder((h*w)**0.5).cuda()
        # y = torch.cat(self.scalar_embedder(torch.cat([h, w])).chunk(2), 1).cuda()
        h, w, score = torch.Tensor([height]), torch.Tensor([width]), torch.Tensor([7.])
        y = torch.cat(self.scalar_embedder(torch.cat([(h * w) ** 0.5, score])).cuda().chunk(2), 1)
        z_bg, bg_emb, bgy = None, None, None

        # Style enhance part
        if style_enhance:
            style_modulations, z_bg, bg_emb, bgy = prepare_style_modulations(y)
            for d in [c] + uc:
                d.update({"style_modulations": style_modulations})

        # Foreground enhance part
        if fg_enhance:
            assert exists(smask) and exists(rmask)
            self.loras.switch_lora(True, "foreground")
            if low_vram:
                self.low_vram_shift(["first"])
            # Keep only the foreground region of the reference, scaled by
            # the disentanglement strength.
            z_fg = self.get_first_stage_encoding(warp_resize(
                torch.where(rmask >= mask_threshold_ref, reference, torch.ones_like(reference)),
                (height, width)
            )) * fg_disentangle_scale
            # z_ref = default(z_ref, self.get_first_stage_encoding(warp_resize(reference, (height, width))))
            # self.loras.switch_lora(True, False)
            self.loras.adjust_lora_scales(fg_disentangle_scale, "foreground")
            if low_vram:
                self.low_vram_shift(["fg_encoder"])
            hs_fg = self.fg_encoder(
                z_fg,
                timesteps = torch.zeros((1,), dtype=torch.long, device=z_fg.device),
            )
            # hs_fg = [hs * fg_disentangle_scale for hs in hs_fg]
            hs_fg = expand_to_batch_size(hs_fg, bs)
            for d in [c] + uc:
                d.update({
                    "hs_fg": hs_fg,
                    "inject_mask": expand_to_batch_size(smask, bs),
                })
            # for d in [c] + uc:
            #     d.update({"z_fg": expand_to_batch_size(z_fg, bs)})

        # Background enhance part
        if bg_enhance:
            assert exists(rmask) and exists(smask)
            # if not self.controller.hooked:
            #     self.controller.register("read", self.model.diffusion_model)
            # self.loras.switch_lora(False, True)
            hs_bg, bg_emb = prepare_background_latents(z_bg, bg_emb, default(bgy, y))
            self.loras.switch_lora(True, "background")
            if latent_inpaint and exists(background):
                hs_bg = expand_to_batch_size(hs_bg, bs)
                c.update({"inpaint_bg": hs_bg})
            elif exists(self.controller):
                # self.loras.merge_lora()
                self.controller.update()
            else:
                hs_bg = expand_to_batch_size(hs_bg, bs)
                for d in [c] + uc:
                    d.update({"hs_bg": hs_bg})
        elif exists(self.controller):
            # self.controller.reader_restore()
            self.controller.clean()

        if fg_enhance or bg_enhance:
            # Need to activate mask-guided split cross-attention.
            emb = torch.cat([emb, default(bg_emb, emb)], 1)
            smask = expand_to_batch_size(smask.to(self.dtype), bs)
            for d in [c] + uc:
                d.update({"mask": F.interpolate(smask, scale_factor=0.125), "threshold": mask_threshold_sketch})
            # if fg_enhance and bg_enhance:
            #     self.loras.switch_lora(True, True)

        sketch = sketch.to(self.dtype)
        context = expand_to_batch_size(emb, bs).to(self.dtype)
        y = expand_to_batch_size(y, bs).float()
        uc_context = torch.zeros_like(context)

        control = []
        uc_control = []
        if low_vram:
            self.low_vram_shift(["control_encoder"])
        # Encode the sketch and an all-negative dummy in one batch; the
        # dummy provides the unconditional control branch.
        encoded_sketch = self.control_encoder(
            torch.cat([sketch, -torch.ones_like(sketch)], 0)
        )
        for idx, es in enumerate(encoded_sketch):
            # control_scale may be a uniform scalar (the default of 1) or a
            # per-level list, so guard the indexing.
            scale = control_scale[idx] if isinstance(control_scale, (list, tuple)) else control_scale
            es = es * scale
            ec, uec = es.chunk(2)
            control.append(expand_to_batch_size(ec, bs))
            uc_control.append(expand_to_batch_size(uec, bs))
        self.loras.merge_lora()

        c.update({"control": control, "context": [context], "y": [y]})
        uc[0].update({"control": control, "context": [uc_context], "y": [y]})
        uc[1].update({"control": uc_control, "context": [context], "y": [y]})
        return c, uc
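

# ----------------------------------------------------------------------------
# Minimal usage sketch. Everything below is illustrative: the config objects,
# tensor shapes, and value ranges are assumptions, not APIs defined in this
# file.
#
#   wrapper = InferenceWrapperXL(
#       scalar_embedder_config=scalar_cfg,   # hypothetical instantiate_from_config targets
#       img_embedder_config=img_cfg,
#       bg_encoder_config=bg_cfg,
#       lora_config={},                      # kwargs forwarded to LoraModules
#       controller=True,
#   )
#   wrapper.switch_to_fp16()
#   c, uc = wrapper.prepare_conditions(
#       bs=4,
#       sketch=sketch,                       # assumed (1, C, H, W), values in [-1, 1]
#       reference=reference,
#       height=1024, width=1024,
#       control_scale=1.0,
#   )
#   # c is the conditional dict; uc holds two unconditional variants for
#   # classifier-free guidance: uc[0] zeroes the context, uc[1] swaps in the
#   # dummy-sketch control.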