Spaces:
Runtime error
Runtime error
| import itertools | |
| import torch | |
| from typing import Union, Dict | |
| from .ipadapter_model import ImageEmbed, IPAdapterModel | |
| from ..enums import StableDiffusionVersion, TransformerID | |
def get_block(model, flag):
    """Return the group of blocks of ``model`` selected by ``flag``.

    ``flag`` must be ``"input"``, ``"middle"`` or ``"output"``; any other
    value raises ``KeyError``. The single ``middle_block`` is wrapped in a
    one-element list so the result can be indexed like the other two groups.
    """
    # All three attributes are read up front, exactly as a dict literal would.
    blocks_by_flag = dict(
        input=model.input_blocks,
        middle=[model.middle_block],
        output=model.output_blocks,
    )
    return blocks_by_flag[flag]
def attn_forward_hacked(self, x, context=None, **kwargs):
    """Attention forward pass that also adds the output of registered hooks.

    Computes multi-head scaled-dot-product attention of ``x`` against
    ``context`` (or against ``x`` itself when ``context`` is ``None``), then
    adds the return value of every callable in ``self.ipadapter_hacks`` before
    applying ``self.to_out``.

    Args:
        self: Attention module providing ``heads``, ``to_q``, ``to_k``,
            ``to_v``, ``to_out`` and ``ipadapter_hacks``.
        x: Rank-3 tensor; its last dimension is split evenly across heads.
        context: Optional tensor used as the key/value source.
        **kwargs: Accepted for signature compatibility and ignored.

    Returns:
        ``self.to_out`` applied to the attention output plus all hook outputs.
    """
    batch_size, _, inner_dim = x.shape
    num_heads = self.heads
    head_dim = inner_dim // num_heads

    def split_heads(t):
        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        return t.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

    kv_input = x if context is None else context
    query = self.to_q(x)
    key = self.to_k(kv_input)
    value = self.to_v(kv_input)
    del kv_input, context

    query = split_heads(query)
    key = split_heads(key)
    value = split_heads(value)

    attended = torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
    )
    # Merge the heads back into the last dimension.
    merged = attended.transpose(1, 2).reshape(batch_size, -1, num_heads * head_dim)
    del key, value, attended

    # Each hook receives the module, the raw input and the per-head query.
    # The sum is deliberately out-of-place so hook outputs are never mutated.
    for hook in self.ipadapter_hacks:
        merged = merged + hook(self, x, query)
    del query, x

    return self.to_out(merged)
# Registry of patched attention modules: maps each module to the ``forward``
# it had before ``hack_blk`` replaced it, so ``clear_all_ip_adapter`` can
# put it back.
all_hacks = dict()
# Model most recently passed to ``PlugableIPAdapter.hook``; reset to ``None``
# by ``clear_all_ip_adapter``.
current_model = None
def hack_blk(block, function, type):
    """Register ``function`` as an IP-Adapter hook on ``block``.

    The first time a hook is added to a block (or the first time after its
    hook list was emptied), the block's current ``forward`` is saved in
    ``all_hacks`` and replaced by ``attn_forward_hacked`` bound to the block.

    Args:
        block: Attention module to patch.
        function: Callable later invoked as ``function(block, x, q)``.
        type: Owner class handed to the descriptor when binding the new
            ``forward``.
    """
    try:
        hooks = block.ipadapter_hacks
    except AttributeError:
        block.ipadapter_hacks = []
        hooks = block.ipadapter_hacks

    if not hooks:
        # Remember the untouched forward exactly once per patch cycle.
        all_hacks[block] = block.forward
        block.forward = attn_forward_hacked.__get__(block, type)

    block.ipadapter_hacks.append(function)
def set_model_attn2_replace(
    model,
    target_cls,
    function,
    transformer_id: TransformerID,
):
    """Attach ``function`` as a hook on one ``attn2`` module of ``model``.

    The module is located from ``transformer_id``: its ``block_type`` picks
    the block group via ``get_block``, ``block_id`` picks the entry in that
    group, and ``block_index`` picks the transformer block whose ``attn2`` is
    patched.

    Args:
        model: Model exposing ``input_blocks``, ``middle_block`` and
            ``output_blocks``.
        target_cls: Class forwarded to ``hack_blk`` for method binding.
        function: Hook callable forwarded to ``hack_blk``.
        transformer_id: Locator of the attention layer to patch.
    """
    block_group = get_block(model, transformer_id.block_type.value)
    # Element 1 of the selected entry is the one that owns ``transformer_blocks``.
    container = block_group[transformer_id.block_id][1]
    transformer_block = container.transformer_blocks[transformer_id.block_index]
    hack_blk(transformer_block.attn2, function, target_cls)
def clear_all_ip_adapter():
    """Undo every patch recorded in ``all_hacks`` and reset module state.

    Each recorded module gets its saved ``forward`` back and an empty
    ``ipadapter_hacks`` list; afterwards ``all_hacks`` is rebound to a fresh
    dict and ``current_model`` to ``None``.
    """
    global all_hacks, current_model

    for module, saved_forward in all_hacks.items():
        module.forward = saved_forward
        module.ipadapter_hacks = []

    # Rebind (rather than clear in place) to match the original semantics.
    all_hacks, current_model = {}, None
class PlugableIPAdapter(torch.nn.Module):
    """Wraps an ``IPAdapterModel`` and hooks it into a model's ``attn2`` layers.

    ``hook`` stores the run settings and registers one closure per transformer
    layer; each closure computes the image-conditioned attention term that
    ``attn_forward_hacked`` adds to the regular attention output.
    """

    def __init__(self, ipadapter: IPAdapterModel):
        super().__init__()
        self.ipadapter = ipadapter
        self.disable_memory_management = True
        self.dtype = None
        # Either one global weight or a per-transformer-index mapping.
        self.weight: Union[float, Dict[int, float]] = 1.0
        # Populated by ``hook`` / ``reset``; keyed by projection-layer name.
        self.cache = None
        # Sampling-progress window in which the adapter is active.
        self.p_start = 0.0
        self.p_end = 1.0

    def reset(self):
        """Drop all cached key/value projections."""
        self.cache = {}

    def hook(
        self, model, preprocessor_outputs, weight, start, end, dtype=torch.float32
    ):
        """Install this adapter on ``model``.

        Args:
            model: Model whose ``attn2`` modules are patched; also stored in
                the module-level ``current_model``.
            preprocessor_outputs: One preprocessor output, or a list/tuple of
                them; their image embeddings are averaged.
            weight: Global weight, or dict mapping transformer index to weight.
            start: Lower bound of the active sampling-progress window.
            end: Upper bound of the active sampling-progress window.
            dtype: dtype the adapter is converted to.
        """
        global current_model
        current_model = model

        self.p_start, self.p_end = start, end
        self.cache = {}
        self.weight = weight
        self.dtype = dtype
        # The adapter weights are kept on the CPU here.
        self.ipadapter.to(torch.device("cpu"), dtype=self.dtype)

        if not isinstance(preprocessor_outputs, (list, tuple)):
            preprocessor_outputs = [preprocessor_outputs]
        embeds = [self.ipadapter.get_image_emb(o) for o in preprocessor_outputs]
        self.image_emb = ImageEmbed.average_of(*embeds)

        # The attention class lives in a different package per model family.
        if self.ipadapter.is_sdxl:
            sd_version = StableDiffusionVersion.SDXL
            from sgm.modules.attention import CrossAttention
        else:
            sd_version = StableDiffusionVersion.SD1x
            from ldm.modules.attention import CrossAttention

        input_ids, output_ids, middle_ids = sd_version.transformer_ids
        ordered_ids = itertools.chain(input_ids, output_ids, middle_ids)
        for i, transformer_id in enumerate(ordered_ids):
            patched = self.patch_forward(i, transformer_id.transformer_index)
            set_model_attn2_replace(model, CrossAttention, patched, transformer_id)

    def weight_on_transformer(self, transformer_index: int) -> float:
        """Return the weight to apply at ``transformer_index``.

        A dict weight yields its entry for the index (``0.0`` when absent);
        a scalar weight is returned as is.
        """
        if not isinstance(self.weight, dict):
            assert isinstance(self.weight, (float, int))
            return self.weight
        return self.weight.get(transformer_index, 0.0)

    def call_ip(self, key: str, feat, device):
        """Return the projection ``to_kvs[key](feat)`` moved to ``device``.

        Results are memoised in ``self.cache`` by ``key`` only, so ``feat`` and
        ``device`` are ignored on a cache hit.
        """
        if key not in self.cache:
            projection = self.ipadapter.ip_layers.to_kvs[key](feat)
            self.cache[key] = projection.to(device)
        return self.cache[key]

    def patch_forward(self, number: int, transformer_index: int):
        """Build the hook closure for the ``number``-th patched layer.

        Args:
            number: Position of the layer in patch order; selects the
                ``to_kvs`` projection names.
            transformer_index: Index used to look up the layer's weight.

        Returns:
            A callable ``forward(attn_blk, x, q)`` returning either ``0.0``
            (adapter inactive) or the weighted image-attention output.
        """
        # The projection names depend only on ``number``; build them once.
        layer_id = number * 2 + 1
        k_key = f"{layer_id}_to_k_ip"
        v_key = f"{layer_id}_to_v_ip"

        def forward(attn_blk, x, q):
            batch_size, _, inner_dim = x.shape
            heads = attn_blk.heads
            head_dim = inner_dim // heads

            weight = self.weight_on_transformer(transformer_index)
            progress = getattr(current_model, "current_sampling_percent", 0.5)
            outside_window = progress < self.p_start or progress > self.p_end
            if outside_window or weight == 0.0:
                return 0.0

            image_emb = self.image_emb.eval(current_model.cond_mark)
            ip_k = self.call_ip(k_key, image_emb, device=q.device)
            ip_v = self.call_ip(v_key, image_emb, device=q.device)

            # Split the projections into heads the same way ``q`` was split.
            ip_k = ip_k.view(batch_size, -1, heads, head_dim).transpose(1, 2)
            ip_v = ip_v.view(batch_size, -1, heads, head_dim).transpose(1, 2)
            assert ip_k.dtype == ip_v.dtype

            # On MacOS, q can be float16 instead of float32.
            # https://github.com/Mikubill/sd-webui-controlnet/issues/2208
            if q.dtype != ip_k.dtype:
                ip_k = ip_k.to(dtype=q.dtype)
                ip_v = ip_v.to(dtype=q.dtype)

            attended = torch.nn.functional.scaled_dot_product_attention(
                q, ip_k, ip_v, attn_mask=None, dropout_p=0.0, is_causal=False
            )
            merged = attended.transpose(1, 2).reshape(
                batch_size, -1, heads * head_dim
            )
            return merged * weight

        return forward