import torch from refnet.util import exists, fitting_weights, instantiate_from_config, load_weights, delete_states from refnet.ldm import LatentDiffusion from typing import Union from refnet.sampling import ( UnetHook, KDiffusionSampler, DiffuserDenoiser, ) class GuidanceFlag: none = 0 reference = 1 sketch = 10 both = 11 def reconstruct_cond(cond, uncond): if not isinstance(uncond, list): uncond = [uncond] for k in cond.keys(): if k == "inpaint_bg": continue for uc in uncond: if isinstance(cond[k], list): cond[k] = [torch.cat([cond[k][i], uc[k][i]]) for i in range(len(cond[k]))] elif isinstance(cond[k], torch.Tensor): cond[k] = torch.cat([cond[k], uc[k]]) return cond class CustomizedLDM(LatentDiffusion): def __init__( self, dtype = torch.float32, sigma_max = None, sigma_min = None, *args, **kwargs ): super().__init__(*args, **kwargs) self.dtype = dtype self.sigma_max = sigma_max self.sigma_min = sigma_min self.model_list = { "first": self.first_stage_model, "cond": self.cond_stage_model, "unet": self.model, } self.switch_cond_modules = ["cond"] self.switch_main_modules = ["unet"] self.retrieve_attn_modules() self.retrieve_attn_layers() def init_from_ckpt( self, path, only_model = False, logging = False, make_it_fit = False, ignore_keys: list[str] = (), ): sd = delete_states(load_weights(path), ignore_keys) if make_it_fit: sd = fitting_weights(self, sd) missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model \ else self.model.load_state_dict(sd, strict=False) filtered_missing = [] filtered_unexpect = [] for k in missing: if not k.find("cond_stage_model") > -1 and not k.find("img_embedder") > -1 and not k.find("fg") > -1: filtered_missing.append(k) for k in unexpected: if not k.find("cond_stage_model") > -1 and not k.find("img_embedder") > -1: filtered_unexpect.append(k) print( f"Restored from {path} with {len(filtered_missing)} filtered missing and " f"{len(filtered_unexpect)} filtered unexpected keys") if logging: if len(missing) > 0: print(f"Filtered missing Keys: {filtered_missing}") if len(unexpected) > 0: print(f"Filtered unexpected Keys: {filtered_unexpect}") def sample( self, cond: dict, uncond: Union[dict, list[dict]] = None, cfg_scale: Union[float, list[float]] = 1., bs: int = 1, shape: Union[tuple, list] = None, step: int = 20, sampler = "DPM++ 3M SDE", scheduler = "Automatic", device = "cuda", x_T = None, seed = None, deterministic = False, **kwargs ): shape = shape or (self.channels, self.image_size, self.image_size) x = x_T or torch.randn(bs, *shape, device=device) if exists(uncond): cond = reconstruct_cond(cond, uncond) if sampler.startswith("diffuser"): # Using huggingface diffuser noise sampler and scheduler sampler = DiffuserDenoiser( sampler, prediction_type = "v_prediction" if self.parameterization == "v" else "epsilon", use_karras = scheduler == "Karras" ) samples = sampler( x, cond, cond_scale=cfg_scale, unet=self, timesteps=step, generator=torch.manual_seed(seed) if exists(seed) else None, device=device ) else: # Using k-diffusion sampler and noise scheduler seed = seed or torch.seed() sampler = KDiffusionSampler(sampler, scheduler, self, device) sigmas = sampler.get_sigmas(step) extra_args = { "cond": cond, "cond_scale": cfg_scale, } seed = [seed for _ in range(bs)] if deterministic else seed samples = sampler(x, sigmas, extra_args, seed, deterministic, step) return samples def switch_to_fp16(self): unet = self.model.diffusion_model unet.input_blocks = unet.input_blocks.to(self.half_precision_dtype) unet.middle_block = unet.middle_block.to(self.half_precision_dtype) unet.output_blocks = unet.output_blocks.to(self.half_precision_dtype) self.dtype = self.half_precision_dtype unet.dtype = self.half_precision_dtype def switch_to_fp32(self): unet = self.model.diffusion_model unet.input_blocks = unet.input_blocks.float() unet.middle_block = unet.middle_block.float() unet.output_blocks = unet.output_blocks.float() self.dtype = torch.float32 unet.dtype = torch.float32 def switch_vae_to_fp16(self): self.first_stage_model = self.first_stage_model.to(self.half_precision_dtype) def switch_vae_to_fp32(self): self.first_stage_model = self.first_stage_model.float() def low_vram_shift(self, cuda_list: Union[str, list[str]]): if not isinstance(cuda_list, list): cuda_list = [cuda_list] cpu_list = self.model_list.keys() - cuda_list for model in cpu_list: self.model_list[model] = self.model_list[model].cpu() torch.cuda.empty_cache() for model in cuda_list: self.model_list[model] = self.model_list[model].cuda() def retrieve_attn_modules(self): from refnet.modules.transformer import BasicTransformerBlock from refnet.sampling import torch_dfs scale_factor_levels = {"high": 0.5, "low": 0.25, "bottom": 0.25} attn_modules = [] for module in torch_dfs(self.model.diffusion_model): if isinstance(module, BasicTransformerBlock): attn_modules.append(module) self.attn_modules = { "high": [0, 1, 2, 3] + [64, 65, 66, 67, 68, 69], "low": [i for i in range(4, 24)] + [i for i in range(34, 64)], "bottom": [i for i in range(24, 34)], "encoder": [i for i in range(24)], "decoder": [i for i in range(34, len(attn_modules))] } self.attn_modules["modules"] = attn_modules for k in ["high", "low", "bottom"]: scale_factor = scale_factor_levels[k] for attn in self.attn_modules[k]: attn_modules[attn].scale_factor = scale_factor def retrieve_attn_layers(self): self.attn_layers = [] for module in (self.attn_modules["modules"]): if hasattr(module, "attn2") and exists(getattr(module, "attn2")): self.attn_layers.append(module.attn2) class CustomizedColorizer(CustomizedLDM): def __init__( self, control_encoder_config, proj_config, token_type = "full", *args, **kwargs ): super().__init__(*args, **kwargs) self.control_encoder = instantiate_from_config(control_encoder_config) self.proj = instantiate_from_config(proj_config) self.token_type = token_type self.model_list.update({"control_encoder": self.control_encoder, "proj": self.proj}) self.switch_cond_modules += ["control_encoder", "proj"] def switch_to_fp16(self): self.control_encoder = self.control_encoder.to(self.half_precision_dtype) super().switch_to_fp16() def switch_to_fp32(self): self.control_encoder = self.control_encoder.float() super().switch_to_fp32() from refnet.modules.unet import hack_inference_forward class CustomizedWrapper: def __init__(self): self.scaling_sample = False self.guidance_steps = (0, 1) self.no_guidance_steps = (-0.05, 0.05) hack_inference_forward(self.model.diffusion_model) def adjust_reference_scale(self, scale_kwargs): if isinstance(scale_kwargs, dict): if scale_kwargs["level_control"]: for key in scale_kwargs["scales"]: if key == "middle": continue for idx in self.attn_modules[key]: self.attn_modules["modules"][idx].reference_scale = scale_kwargs["scales"][key] else: for idx, s in enumerate(scale_kwargs["scales"]): self.attn_modules["modules"][idx].reference_scale = s else: for module in self.attn_modules["modules"]: module.reference_scale = scale_kwargs def adjust_fgbg_scale(self, fg_scale, bg_scale, merge_scale, mask_threshold): for layer in self.attn_layers: layer.fg_scale = fg_scale layer.bg_scale = bg_scale layer.merge_scale = merge_scale layer.mask_threshold = mask_threshold # for layer in self.attn_modules["modules"]: # layer.fg_scale = fg_scale # layer.bg_scale = bg_scale # layer.merge_scale = merge_scale # layer.mask_threshold = mask_threshold def apply_model(self, x_noisy, t, cond): tr = 1 - t[0] / (self.num_timesteps - 1) crossattn = cond["context"][0] if ((tr < self.guidance_steps[0] or tr > self.guidance_steps[1]) or (tr >= self.no_guidance_steps[0] and tr <= self.no_guidance_steps[1])): crossattn = torch.zeros_like(crossattn)[:, :1] cond["context"] = [crossattn] model_cond = {k: v for k, v in cond.items() if k != "inpaint_bg"} return self.model(x_noisy, t, **model_cond) def prepare_conditions(self, *args, **kwargs): raise NotImplementedError("Inputs preprocessing function is not implemented.") def check_manipulate(self, scales): if exists(scales) and len(scales) > 0: for scale in scales: if scale > 0: return True return False @torch.inference_mode() def generate( self, # Conditional inputs cond: dict, ctl_scale: Union[float|list[float]], merge_scale: float, mask_scale: float, mask_thresh: float, mask_thresh_sketch: float, # Sampling settings sampler, scheduler, step: int, bs: int, gs: list[float], strength: Union[float, list[float]], fg_strength: float, bg_strength: float, seed: int, start_step: float = 0.0, end_step: float = 1.0, no_start_step: float = -0.05, no_end_step: float = -0.05, deterministic: bool = False, style_enhance: bool = False, bg_enhance: bool = False, fg_enhance: bool = False, latent_inpaint: bool = False, height: int = 512, width: int = 512, # Injection settings injection: bool = False, injection_cfg: float = 0.5, injection_control: float = 0, injection_start_step: float = 0, hook_xr: torch.Tensor = None, hook_xs: torch.Tensor = None, # Additional settings low_vram: bool = True, return_intermediate = False, manipulation_params = None, **kwargs, ): """ User interface function. """ hook_unet = UnetHook() self.guidance_steps = (start_step, end_step) self.no_guidance_steps = (no_start_step, no_end_step) self.adjust_reference_scale(strength) self.adjust_fgbg_scale(fg_strength, bg_strength, merge_scale, mask_thresh_sketch) if low_vram: self.low_vram_shift(self.switch_cond_modules) else: self.low_vram_shift(list(self.model_list.keys())) c, uc = self.prepare_conditions( bs = bs, control_scale = ctl_scale, merge_scale = merge_scale, mask_scale = mask_scale, mask_threshold_ref = mask_thresh, mask_threshold_sketch = mask_thresh_sketch, style_enhance = style_enhance, bg_enhance = bg_enhance, fg_enhance = fg_enhance, latent_inpaint = latent_inpaint, height = height, width = width, bg_strength = bg_strength, low_vram = low_vram, **cond, **manipulation_params, **kwargs ) cfg = int(gs[0] > 1) * GuidanceFlag.reference + int(gs[1] > 1) * GuidanceFlag.sketch gr_indice = [] if (cfg == GuidanceFlag.none or cfg == GuidanceFlag.sketch) else [i for i in range(bs, bs*2)] repeat = 1 if cfg == GuidanceFlag.none: gs = 1 uc = None if cfg == GuidanceFlag.reference: gs = gs[0] uc = uc[0] repeat = 2 if cfg == GuidanceFlag.sketch: gs = gs[1] uc = uc[1] repeat = 2 if cfg == GuidanceFlag.both: repeat = 3 if low_vram: self.low_vram_shift("first") if injection: rx = self.get_first_stage_encoding(hook_xr.to(self.first_stage_model.dtype)) hook_unet.enhance_reference( model = self.model, ldm = self, bs = bs * repeat, s = -hook_xr.to(self.dtype), r = rx, style_cfg = injection_cfg, control_cfg = injection_control, gr_indice = gr_indice, start_step = injection_start_step, ) if low_vram: self.low_vram_shift(self.switch_main_modules) z = self.sample( cond = c, uncond = uc, bs = bs, shape = (self.channels, height // 8, width // 8), cfg_scale = gs, step = step, sampler = sampler, scheduler = scheduler, seed = seed, deterministic = deterministic, return_intermediate = return_intermediate, ) if injection: hook_unet.restore(self.model) if low_vram: self.low_vram_shift("first") return self.decode_first_stage(z.to(self.first_stage_model.dtype))