| |
| |
| import torch |
| from PIL import Image |
| import numpy as np |
| from .attention_processor import IPAFluxAttnProcessor2_0 |
| from .utils import is_model_pathched, FluxUpdateModules |
| from .sd3.resampler import TimeResampler |
| from .sd3.joinblock import JointBlockIPWrapper, IPAttnProcessor |
|
|
# Shared projection network; it is created lazily the first time
# InstantXFluxIpadapterApply.apply_ipadapter runs and reused afterwards.
image_proj_model = None


class MLPProjModel(torch.nn.Module):
    """Two-layer MLP that turns one embedding vector into ``num_tokens`` tokens.

    Args:
        cross_attention_dim: Width of every produced token (also the width
            normalised by the final ``LayerNorm``).
        id_embeddings_dim: Width of the incoming embedding vectors.
        num_tokens: Number of tokens produced for each input vector.
    """

    def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.num_tokens = num_tokens

        widened_dim = id_embeddings_dim * 2
        flat_output_dim = cross_attention_dim * num_tokens
        layers = [
            torch.nn.Linear(id_embeddings_dim, widened_dim),
            torch.nn.GELU(),
            torch.nn.Linear(widened_dim, flat_output_dim),
        ]
        # The attribute names and the layer order determine the state-dict
        # keys (proj.0.*, proj.2.*, norm.*), so they must stay as they are.
        self.proj = torch.nn.Sequential(*layers)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, id_embeds):
        """Project ``id_embeds`` and return normalised tokens.

        The flat projection is reshaped to
        ``(-1, num_tokens, cross_attention_dim)`` before the layer norm.
        """
        token_shape = (-1, self.num_tokens, self.cross_attention_dim)
        tokens = self.proj(id_embeds).reshape(token_shape)
        return self.norm(tokens)
|
|
class InstantXFluxIpadapterApply:
    """Attach an InstantX IP-Adapter checkpoint to a Flux model.

    The class builds one ``IPAFluxAttnProcessor2_0`` per double and single
    block of the diffusion model, loads the checkpoint weights into them,
    encodes a reference image into image-prompt tokens and hands everything
    to ``FluxUpdateModules`` on a clone of the incoming model.

    Args:
        num_tokens: Number of image tokens produced by the projection model
            and expected by every attention processor.
    """

    def __init__(self, num_tokens=128):
        self.device = None
        self.dtype = torch.float16
        self.num_tokens = num_tokens
        self.ip_ckpt = None
        self.clip_vision = None
        self.image_encoder = None
        self.clip_image_processor = None

        self.state_dict = None
        self.joint_attention_dim = 4096
        self.hidden_size = 3072

    @staticmethod
    def _resolve_device(provider):
        """Return a lower-cased device string, falling back when ``provider`` is None."""
        if provider is None:
            # ``provider`` defaults to None in apply_ipadapter; calling
            # ``.lower()`` on it used to raise AttributeError.
            return "cuda" if torch.cuda.is_available() else "cpu"
        return provider.lower()

    @staticmethod
    def _image_to_uint8(image):
        """Convert the first image of ``image`` to a ``uint8`` numpy array.

        Values are scaled by 255 and clipped to ``[0, 255]`` before the cast
        so that out-of-range inputs saturate instead of wrapping around.
        NOTE(review): assumes a (batch, height, width, channel) tensor with
        values nominally in 0..1 - confirm against the caller.
        """
        pixels = image[0].detach().cpu().numpy() * 255.0
        return np.clip(pixels, 0, 255).astype(np.uint8)

    def set_ip_adapter(self, flux_model, weight, timestep_percent_range=(0.0, 1.0)):
        """Create an attention processor for every double and single block.

        Args:
            flux_model: Object exposing ``model_sampling.percent_to_sigma``
                and ``diffusion_model.double_blocks`` / ``single_blocks``.
            weight: Passed to every processor as ``scale``.
            timestep_percent_range: ``(start, end)`` percentages; each is
                converted with ``percent_to_sigma`` before being passed on.

        Returns:
            dict mapping ``"double_blocks.<i>"`` / ``"single_blocks.<i>"`` to
            the processor for that block, double blocks first.
        """
        sampling = flux_model.model_sampling
        timestep_range = (
            sampling.percent_to_sigma(timestep_percent_range[0]),
            sampling.percent_to_sigma(timestep_percent_range[1]),
        )
        diffusion_model = flux_model.diffusion_model
        block_counts = (
            ("double_blocks", len(diffusion_model.double_blocks)),
            ("single_blocks", len(diffusion_model.single_blocks)),
        )

        # Insertion order matters: load_ip_adapter wraps the values in a
        # ModuleList, so the position of each processor becomes its key
        # prefix when the checkpoint is loaded.
        ip_attn_procs = {}
        for prefix, count in block_counts:
            for index in range(count):
                ip_attn_procs[f"{prefix}.{index}"] = IPAFluxAttnProcessor2_0(
                    hidden_size=self.hidden_size,
                    cross_attention_dim=self.joint_attention_dim,
                    num_tokens=self.num_tokens,
                    scale=weight,
                    timestep_range=timestep_range,
                ).to(self.device, dtype=self.dtype)
        return ip_attn_procs

    def load_ip_adapter(self, flux_model, weight, timestep_percent_range=(0.0, 1.0)):
        """Load checkpoint weights into the projection model and new processors.

        Reads ``self.state_dict["image_proj"]`` and
        ``self.state_dict["ip_adapter"]``; both are loaded with
        ``strict=True``.

        Returns:
            The processor dict produced by :meth:`set_ip_adapter`.

        Raises:
            ValueError: If no checkpoint has been stored on the instance yet.
        """
        if self.state_dict is None:
            raise ValueError(
                "No IP-Adapter checkpoint loaded: expected ipadapter['ipadapter']['model']."
            )
        image_proj_model.load_state_dict(self.state_dict["image_proj"], strict=True)
        ip_attn_procs = self.set_ip_adapter(flux_model, weight, timestep_percent_range)
        ip_layers = torch.nn.ModuleList(ip_attn_procs.values())
        ip_layers.load_state_dict(self.state_dict["ip_adapter"], strict=True)
        return ip_attn_procs

    def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
        """Return image-prompt tokens for an image or for precomputed embeddings.

        Args:
            pil_image: A PIL image or a list of them. When given, it is run
                through the image processor and encoder and the encoder's
                ``pooler_output`` is used.
            clip_image_embeds: Precomputed embeddings, used only when
                ``pil_image`` is None.

        Returns:
            Output of the shared projection model for the embeddings.

        Raises:
            ValueError: If both arguments are None.
        """
        if pil_image is None and clip_image_embeds is None:
            raise ValueError("Either pil_image or clip_image_embeds must be provided.")

        if pil_image is not None:
            if isinstance(pil_image, Image.Image):
                pil_image = [pil_image]
            clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
            clip_image_embeds = self.image_encoder(
                clip_image.to(self.device, dtype=self.image_encoder.dtype)
            ).pooler_output
            clip_image_embeds = clip_image_embeds.to(dtype=self.dtype)
        else:
            clip_image_embeds = clip_image_embeds.to(self.device, dtype=self.dtype)
        return image_proj_model(clip_image_embeds)

    def apply_ipadapter(self, model, ipadapter, image, weight, start_at, end_at, provider=None, use_tiled=False):
        """Patch a clone of ``model`` with the IP-Adapter for ``image``.

        Args:
            model: Patcher object exposing ``.model`` and ``.clone()``.
            ipadapter: dict that may contain ``"clipvision"`` (with
                ``['model']['image_encoder']`` and
                ``['model']['clip_image_processor']``) and ``"ipadapter"``
                (with ``['file']`` and ``['model']``).
            image: Image tensor; only the first element of the batch is used.
            weight: Adapter strength forwarded to the processors.
            start_at: Start of the active range, as a percentage.
            end_at: End of the active range, as a percentage.
            provider: Device name, lower-cased before use. When None, CUDA is
                used if available, otherwise CPU.
            use_tiled: Accepted for interface compatibility; not used here.

        Returns:
            ``(patched_clone, image)``.
        """
        global image_proj_model

        self.device = self._resolve_device(provider)
        if "clipvision" in ipadapter:
            clip_bundle = ipadapter["clipvision"]['model']
            self.image_encoder = clip_bundle['image_encoder'].to(self.device, dtype=self.dtype)
            self.clip_image_processor = clip_bundle['clip_image_processor']
        if "ipadapter" in ipadapter:
            self.ip_ckpt = ipadapter["ipadapter"]['file']
            self.state_dict = ipadapter["ipadapter"]['model']

        pil_image = Image.fromarray(self._image_to_uint8(image))

        # The projection model is cached at module level. Rebuild it when the
        # cached instance was created for a different token count or width,
        # otherwise the strict state-dict load below cannot succeed.
        if (
            image_proj_model is None
            or image_proj_model.num_tokens != self.num_tokens
            or image_proj_model.cross_attention_dim != self.joint_attention_dim
        ):
            image_proj_model = MLPProjModel(
                cross_attention_dim=self.joint_attention_dim,
                id_embeddings_dim=1152,
                num_tokens=self.num_tokens,
            )
        # Always move: a cached model may live on the device of a prior call.
        image_proj_model.to(self.device, dtype=self.dtype)
        ip_attn_procs = self.load_ip_adapter(model.model, weight, (start_at, end_at))

        image_prompt_embeds = self.get_image_embeds(pil_image=pil_image, clip_image_embeds=None)

        bi = model.clone()
        FluxUpdateModules(bi, ip_attn_procs, image_prompt_embeds)

        return (bi, image)
|
|
|
|
def patch_sd3(
    patcher,
    ip_procs,
    resampler: TimeResampler,
    clip_embeds,
    weight=1.0,
    start=0.0,
    end=1.0,
):
    """
    Install the IP-Adapter on ``patcher``.

    A UNet function wrapper is registered that refreshes the shared
    ``ip_options`` dict on every call, and each joint block of the diffusion
    model gets a ``JointBlockIPWrapper`` replacement patch that receives the
    same dict.

    Args:
        patcher: Object exposing ``model.diffusion_model``,
            ``model.model_config.sampling_settings``,
            ``set_model_unet_function_wrapper`` and ``set_model_patch_replace``.
        ip_procs: Indexable collection with one processor per joint block.
        resampler: Called as ``resampler(embeds, timestep, need_temb=True)``
            and expected to return ``(image_emb, t_emb)``.
        clip_embeds: Image embeddings, indexed by ``args["cond_or_uncond"]``.
        weight: Stored under ``"weight"`` in the shared options.
        start: Lower bound (inclusive) of the active ``1 - timestep`` range.
        end: Upper bound (inclusive) of the active ``1 - timestep`` range.
    """
    diffusion_model = patcher.model.diffusion_model
    sampling_settings = patcher.model.model_config.sampling_settings
    timestep_schedule_max = sampling_settings.get("timesteps", 1000)

    # One mutable dict shared by the wrapper below and every block wrapper.
    ip_options = {"hidden_states": None, "t_emb": None, "weight": weight}

    def ddit_wrapper(forward, args):
        raw_timestep = args["timestep"]
        # Progress is derived from the first timestep entry only.
        t_percent = 1 - raw_timestep.flatten()[0].cpu().item()

        image_emb = None
        t_emb = None
        if start <= t_percent <= end:
            cond_flags = args["cond_or_uncond"]
            batch_size = args["input"].shape[0] // len(cond_flags)
            # Pick the embedding row for each cond/uncond chunk, then repeat
            # every row once per sample of that chunk.
            embeds = torch.repeat_interleave(clip_embeds[cond_flags], batch_size, dim=0)
            scaled_timestep = raw_timestep * timestep_schedule_max
            image_emb, t_emb = resampler(embeds, scaled_timestep, need_temb=True)

        # Outside the active range both entries are reset to None.
        ip_options["hidden_states"] = image_emb
        ip_options["t_emb"] = t_emb

        return forward(args["input"], args["timestep"], **args["c"])

    patcher.set_model_unet_function_wrapper(ddit_wrapper)

    for index, joint_block in enumerate(diffusion_model.joint_blocks):
        patcher.set_model_patch_replace(
            JointBlockIPWrapper(joint_block, ip_procs[index], ip_options),
            "dit",
            "double_block",
            index,
        )
|
|
class InstantXSD3IpadapterApply:
    """Attach an InstantX IP-Adapter checkpoint to an SD3 model.

    The class builds a ``TimeResampler`` and one ``IPAttnProcessor`` per
    checkpoint entry, loads the checkpoint weights into them, encodes the
    reference image and calls ``patch_sd3`` on a clone of the incoming model.
    """

    def __init__(self):
        self.device = None
        self.dtype = torch.float16
        self.clip_image_processor = None
        self.image_encoder = None
        self.resampler = None
        self.procs = None
        # Declared up front so the attributes exist before apply_ipadapter runs.
        self.ip_ckpt = None
        self.state_dict = None

    @staticmethod
    def _resolve_device(provider):
        """Return a lower-cased device string, falling back when ``provider`` is None."""
        if provider is None:
            # ``provider`` defaults to None in apply_ipadapter; calling
            # ``.lower()`` on it used to raise AttributeError.
            return "cuda" if torch.cuda.is_available() else "cpu"
        return provider.lower()

    @torch.inference_mode()
    def encode(self, image):
        """Encode ``image`` into embeddings for the resampler.

        The image is preprocessed with ``do_rescale=False`` and passed
        through the image encoder; the second-to-last hidden state is kept.
        A zero tensor of the same shape is concatenated after it along
        dim 0, and the result is cast to ``self.dtype``.
        """
        clip_image = self.clip_image_processor.image_processor(
            image, return_tensors="pt", do_rescale=False
        ).pixel_values
        hidden_states = self.image_encoder(
            clip_image.to(self.device, dtype=self.image_encoder.dtype),
            output_hidden_states=True,
        ).hidden_states
        clip_image_embeds = hidden_states[-2]
        clip_image_embeds = torch.cat(
            [clip_image_embeds, torch.zeros_like(clip_image_embeds)], dim=0
        )
        return clip_image_embeds.to(dtype=self.dtype)

    def apply_ipadapter(self, model, ipadapter, image, weight, start_at, end_at, provider=None, use_tiled=False):
        """Patch a clone of ``model`` with the IP-Adapter for ``image``.

        Args:
            model: Patcher object exposing ``.clone()``.
            ipadapter: dict that may contain ``"clipvision"`` (with
                ``['model']['image_encoder']`` and
                ``['model']['clip_image_processor']``) and ``"ipadapter"``
                (with ``['file']`` and ``['model']``).
            image: Reference image passed to :meth:`encode`.
            weight: Adapter strength forwarded to ``patch_sd3``.
            start_at: Start of the active range forwarded to ``patch_sd3``.
            end_at: End of the active range forwarded to ``patch_sd3``.
            provider: Device name, lower-cased before use. When None, CUDA is
                used if available, otherwise CPU.
            use_tiled: Accepted for interface compatibility; not used here.

        Returns:
            ``(patched_clone, image)``.

        Raises:
            ValueError: If no checkpoint is available on the instance.
        """
        self.device = self._resolve_device(provider)
        if "clipvision" in ipadapter:
            clip_bundle = ipadapter["clipvision"]['model']
            self.image_encoder = clip_bundle['image_encoder'].to(self.device, dtype=self.dtype)
            self.clip_image_processor = clip_bundle['clip_image_processor']
        if "ipadapter" in ipadapter:
            self.ip_ckpt = ipadapter["ipadapter"]['file']
            self.state_dict = ipadapter["ipadapter"]['model']

        if self.state_dict is None:
            raise ValueError(
                "No IP-Adapter checkpoint loaded: expected ipadapter['ipadapter']['model']."
            )

        self.resampler = TimeResampler(
            dim=1280,
            depth=4,
            dim_head=64,
            heads=20,
            num_queries=64,
            embedding_dim=1152,
            output_dim=2432,
            ff_mult=4,
            timestep_in_dim=320,
            timestep_flip_sin_to_cos=True,
            timestep_freq_shift=0,
        )
        self.resampler.eval()
        self.resampler.to(self.device, dtype=self.dtype)
        self.resampler.load_state_dict(self.state_dict["image_proj"])

        # The number of processors is the number of distinct leading key
        # components in the checkpoint's "ip_adapter" section.
        n_procs = len({key.split(".")[0] for key in self.state_dict["ip_adapter"]})
        self.procs = torch.nn.ModuleList(
            [
                IPAttnProcessor(
                    hidden_size=2432,
                    cross_attention_dim=2432,
                    ip_hidden_states_dim=2432,
                    ip_encoder_hidden_states_dim=2432,
                    head_dim=64,
                    timesteps_emb_dim=1280,
                ).to(self.device, dtype=self.dtype)
                for _ in range(n_procs)
            ]
        )
        self.procs.load_state_dict(self.state_dict["ip_adapter"])

        work_model = model.clone()
        embeds = self.encode(image)

        patch_sd3(
            work_model,
            self.procs,
            self.resampler,
            embeds,
            weight,
            start_at,
            end_at,
        )

        return (work_model, image)