import copy
import math
import warnings
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.image_processor import IPAdapterMaskProcessor
from diffusers.models.attention_processor import Attention
from diffusers.models.embeddings import ImageProjection, MultiIPAdapterImageProjection
from pulid.encoders_transformer import IDFormer

from .resampler import Resampler


def save_ip_adapter(unet, path):
    state_dict = {}
    if (
        hasattr(unet, "encoder_hid_proj")
        and unet.encoder_hid_proj is not None
        and isinstance(unet.encoder_hid_proj, torch.nn.Module)
    ):
        state_dict["encoder_hid_proj"] = unet.encoder_hid_proj.state_dict()
    for name, module in unet.attn_processors.items():
        if isinstance(module, torch.nn.Module):
            state_dict[name] = module.state_dict()
    torch.save(state_dict, path)


def load_ip_adapter(
    unet,
    path=None,
    clip_embeddings_dim=1280,
    cross_attention_dim=2048,
    num_image_text_embeds=4,
    attn_blocks=["down", "mid", "up"],
):
    if path is None:
        state_dict = None
    else:
        state_dict = torch.load(path, map_location="cpu")
        # ImageProjection stores a Linear(image_embed_dim, num_image_text_embeds * cross_attention_dim),
        # so both hyper-parameters can be recovered from the checkpoint itself.
        clip_embeddings_dim = state_dict["encoder_hid_proj"][
            "image_embeds.weight"
        ].shape[-1]
        num_image_text_embeds = (
            state_dict["encoder_hid_proj"]["image_embeds.weight"].shape[0]
            // cross_attention_dim
        )
    if not hasattr(unet, "encoder_hid_proj") or unet.encoder_hid_proj is None:
        unet.encoder_hid_proj = ImageProjection(
            cross_attention_dim=cross_attention_dim,
            image_embed_dim=clip_embeddings_dim,
            num_image_text_embeds=num_image_text_embeds,
        ).to(unet.device, unet.dtype)
    if state_dict is not None:
        unet.encoder_hid_proj.load_state_dict(state_dict["encoder_hid_proj"])
    for name, module in unet.named_modules():
        if (
            "attn2" in name
            and isinstance(module, Attention)
            and any(attn in name for attn in attn_blocks)
        ):
            if not isinstance(module.processor, IPAttnProcessor2_0):
                module.set_processor(
                    IPAttnProcessor2_0(
                        hidden_size=module.query_dim,
                        cross_attention_dim=cross_attention_dim,
                    ).to(unet.device, unet.dtype)
                )
            if state_dict is not None:
                module.processor.load_state_dict(state_dict[f"{name}.processor"])
            else:
                # No checkpoint: initialise the image-prompt projections from the
                # text cross-attention weights.
                module.processor.to_k_ip.load_state_dict(module.to_k.state_dict())
                module.processor.to_v_ip.load_state_dict(module.to_v.state_dict())
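
# Usage sketch (hedged): attaching a single IP-Adapter to a diffusers SDXL UNet.
# The checkpoint path is a placeholder, and `pipe` / `clip_image_embeds` are
# assumed to exist elsewhere in the application.
#
#     load_ip_adapter(pipe.unet, path="checkpoints/ip_adapter.bin")
#     set_ip_adapter_scale(pipe.unet, scale=0.6)
#     image_embeds = pipe.unet.encoder_hid_proj(clip_image_embeds)  # (B, 4, 2048) with the defaults
#     set_ip_hidden_states(pipe.unet, image_embeds)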


def parse_clip_embeddings_dim(
    path,
    state_dict,
):
    if "pulid" in path:
        return None
    else:
        return state_dict["encoder_hid_proj"]["image_embeds.weight"].shape[-1]


def parse_num_image_text_embeds(path, state_dict, cross_attention_dim=2048):
    if "pulid" in path:
        return None
    else:
        return (
            state_dict["encoder_hid_proj"]["image_embeds.weight"].shape[0]
            // cross_attention_dim
        )


def parse_encoder_hid_proj_module(
    path=None,
    cross_attention_dim=2048,
    image_embed_dim=None,
    num_image_text_embeds=None,
):
    if "pulid" in path:
        return IDFormer()
    else:
        return ImageProjection(
            cross_attention_dim=cross_attention_dim,
            image_embed_dim=image_embed_dim,
            num_image_text_embeds=num_image_text_embeds,
        )


def load_multi_ip_adapter(
    unet,
    paths=None,
    clip_embeddings_dim=[1280],
    cross_attention_dim=2048,
    num_image_text_embeds=[4],
):
    if paths is None:
        state_dict = None
    else:
        state_dict = [torch.load(path, map_location="cpu") for path in paths]
        clip_embeddings_dim = [
            parse_clip_embeddings_dim(path=single_path, state_dict=single_state_dict)
            for single_path, single_state_dict in zip(paths, state_dict)
        ]
        num_image_text_embeds = [
            parse_num_image_text_embeds(
                path=single_path,
                state_dict=single_state_dict,
                cross_attention_dim=unet.config.cross_attention_dim,
            )
            for single_path, single_state_dict in zip(paths, state_dict)
        ]
    if not hasattr(unet, "encoder_hid_proj") or unet.encoder_hid_proj is None:
        unet.encoder_hid_proj = MultiIPAdapterImageProjection(
            [
                parse_encoder_hid_proj_module(
                    path=single_path,
                    cross_attention_dim=unet.config.cross_attention_dim,
                    image_embed_dim=single_clip_embeddings_dim,
                    num_image_text_embeds=single_num_image_text_embeds,
                ).to(unet.device, unet.dtype)
                for single_path, single_clip_embeddings_dim, single_num_image_text_embeds in zip(
                    paths, clip_embeddings_dim, num_image_text_embeds
                )
            ]
        ).to(unet.device, unet.dtype)
    if state_dict is not None:
        for single_encoder_hid_proj, single_state_dict in zip(
            unet.encoder_hid_proj.image_projection_layers, state_dict
        ):
            single_encoder_hid_proj.load_state_dict(
                single_state_dict["encoder_hid_proj"]
            )
    for name, module in unet.named_modules():
        if "attn2" in name and isinstance(module, Attention):
            if not isinstance(module.processor, MultiIPAttnProcessor2_0):
                module.set_processor(
                    MultiIPAttnProcessor2_0(
                        hidden_size=module.query_dim,
                        cross_attention_dim=unet.config.cross_attention_dim,
                        num_tokens=num_image_text_embeds,
                    ).to(unet.device, unet.dtype)
                )
            if state_dict is not None:
                for (
                    to_k_ip,
                    to_v_ip,
                    single_state_dict,
                ) in zip(
                    module.processor.to_k_ip,
                    module.processor.to_v_ip,
                    state_dict,
                ):
                    if f"{name}.processor" in single_state_dict.keys():
                        to_k_ip.weight = nn.Parameter(
                            single_state_dict[f"{name}.processor"]["to_k_ip.weight"]
                        )
                        to_v_ip.weight = nn.Parameter(
                            single_state_dict[f"{name}.processor"]["to_v_ip.weight"]
                        )
            module.processor = module.processor.to(unet.device, unet.dtype)
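
# Usage sketch (hedged): loading two adapters at once. File names are
# placeholders; a path containing "pulid" is routed to the IDFormer projector,
# everything else to a plain ImageProjection. `style_embeds` / `id_embeds` are
# assumed to already be projected to the UNet's cross-attention width.
#
#     load_multi_ip_adapter(
#         unet,
#         paths=["checkpoints/ip_adapter.bin", "checkpoints/pulid_adapter.bin"],
#     )
#     set_multi_ip_hidden_states(unet, [style_embeds, id_embeds])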


def load_ip_adapter_plus(
    unet,
    path=None,
    embed_dims=1664,
    depth=4,
    dim_head=64,
    heads=12,
    num_queries=32,
    ff_mult=4,
    attn_blocks=["down", "mid", "up"],
):
    if path is not None:
        state_dict = torch.load(path)
    else:
        state_dict = None
    if not hasattr(unet, "encoder_hid_proj") or unet.encoder_hid_proj is None:
        unet.encoder_hid_proj = Resampler(
            dim=unet.config.cross_attention_dim,
            depth=depth,
            dim_head=dim_head,
            heads=heads,
            num_queries=num_queries,
            embedding_dim=embed_dims,
            output_dim=unet.config.cross_attention_dim,
            ff_mult=ff_mult,
        ).to(unet.device, unet.dtype)
    if state_dict is not None:
        unet.encoder_hid_proj.load_state_dict(state_dict["encoder_hid_proj"])
    for name, module in unet.named_modules():
        if (
            "attn2" in name
            and isinstance(module, Attention)
            and any(attn in name for attn in attn_blocks)
        ):
            if not isinstance(module.processor, IPAttnProcessor2_0):
                module.set_processor(
                    IPAttnProcessor2_0(
                        hidden_size=module.query_dim,
                        cross_attention_dim=unet.config.cross_attention_dim,
                    ).to(unet.device, unet.dtype)
                )
            if state_dict is not None and f"{name}.processor" in state_dict.keys():
                module.processor.load_state_dict(state_dict[f"{name}.processor"])
            else:
                module.processor.to_k_ip.load_state_dict(module.to_k.state_dict())
                module.processor.to_v_ip.load_state_dict(module.to_v.state_dict())
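
# Usage sketch (hedged): the "plus" variant projects per-patch CLIP features
# through the Resampler rather than a pooled embedding. `image_encoder` is
# assumed to be a CLIPVisionModelWithProjection; the default embed_dims=1664
# matches the penultimate hidden size of the OpenCLIP ViT-bigG encoder.
#
#     load_ip_adapter_plus(unet, path="checkpoints/ip_adapter_plus.bin")
#     clip_out = image_encoder(pixel_values, output_hidden_states=True)
#     patch_features = clip_out.hidden_states[-2]            # (B, 257, 1664)
#     image_embeds = unet.encoder_hid_proj(patch_features)   # (B, 32, 2048)
#     set_ip_hidden_states(unet, image_embeds)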


def set_ip_hidden_states(unet, image_embeds):
    for name, module in unet.attn_processors.items():
        if isinstance(module, IPAttnProcessor2_0) or isinstance(
            module, MultiIPAttnProcessor2_0
        ):
            module.ip_hidden_states = image_embeds.clone()


def set_multi_ip_hidden_states(unet, image_embeds):
    for name, module in unet.attn_processors.items():
        if isinstance(module, IPAttnProcessor2_0) or isinstance(
            module, MultiIPAttnProcessor2_0
        ):
            module.ip_hidden_states = image_embeds


def set_multi_ip_attn_masks(unet, attn_masks):
    for name, module in unet.attn_processors.items():
        if isinstance(module, IPAttnProcessor2_0) or isinstance(
            module, MultiIPAttnProcessor2_0
        ):
            module.ip_hidden_states = attn_masks


def clear_ip_hidden_states(model):
    for name, module in model.named_modules():
        if isinstance(module, IPAttnProcessor2_0):
            module.ip_hidden_states = None


def set_ip_adapter_scale(unet, scale=1.0, attn_blocks=["down", "mid", "up"]):
    for name, module in unet.named_modules():
        if isinstance(module, IPAttnProcessor2_0) and any(
            target_module in name for target_module in attn_blocks
        ):
            module.scale = scale
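
# Usage sketch (hedged): the image-prompt states are stored on the processors
# rather than passed through `cross_attention_kwargs`, so they should be set
# once before the denoising loop and cleared when switching prompts.
#
#     set_ip_hidden_states(unet, image_embeds)
#     ...  # run the denoising loop
#     clear_ip_hidden_states(unet)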


def downsample(
    mask: torch.Tensor, batch_size: int, num_queries: int, value_embed_dim: int
):
    """
    Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention. If the
    aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.

    Args:
        mask (`torch.Tensor`):
            The input mask tensor generated with `IPAdapterMaskProcessor.preprocess()`.
        batch_size (`int`):
            The batch size.
        num_queries (`int`):
            The number of queries.
        value_embed_dim (`int`):
            The dimensionality of the value embeddings.

    Returns:
        `torch.Tensor`:
            The downsampled mask tensor.
    """
    o_h = mask.shape[2]
    o_w = mask.shape[3]
    ratio = o_w / o_h
    mask_h = int(math.sqrt(num_queries / ratio))
    mask_h = int(mask_h) + int((num_queries % int(mask_h)) != 0)
    mask_w = num_queries // mask_h

    mask_downsample = F.interpolate(mask, size=(mask_h, mask_w), mode="bicubic")

    # Repeat batch_size times
    if mask_downsample.shape[0] < batch_size:
        mask_downsample = mask_downsample.repeat(batch_size, 1, 1, 1)

    mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1)

    downsampled_area = mask_h * mask_w
    # If the output image and the mask do not have the same aspect ratio, tensor shapes will not match
    # Pad tensor if downsampled_mask.shape[1] is smaller than num_queries
    if downsampled_area < num_queries:
        warnings.warn(
            "The aspect ratio of the mask does not match the aspect ratio of the output image. "
            "Please update your masks or adjust the output size for optimal performance.",
            UserWarning,
        )
        mask_downsample = F.pad(
            mask_downsample, (0, num_queries - mask_downsample.shape[1]), value=0.0
        )
    # Discard last embeddings if downsampled_mask.shape[1] is bigger than num_queries
    if downsampled_area > num_queries:
        warnings.warn(
            "The aspect ratio of the mask does not match the aspect ratio of the output image. "
            "Please update your masks or adjust the output size for optimal performance.",
            UserWarning,
        )
        mask_downsample = mask_downsample[:, :num_queries]

    # Repeat last dimension to match SDPA output shape
    mask_downsample = mask_downsample.view(
        mask_downsample.shape[0], mask_downsample.shape[1], 1
    ).repeat(1, 1, value_embed_dim)

    return mask_downsample
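
# Example sketch (hedged): preparing a region mask for the masked IP-Adapter
# path. The mask file name is a placeholder; 4096 queries corresponds to a
# 64x64 latent (e.g. 512x512 output with an 8x VAE downsampling factor).
#
#     processor = IPAdapterMaskProcessor()
#     masks = processor.preprocess(Image.open("face_region.png"))  # (1, 1, H, W)
#     mask_downsample = downsample(masks, batch_size=2, num_queries=4096, value_embed_dim=640)
#     # mask_downsample: (2, 4096, 640), used to gate ip_hidden_states elementwise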


class IPAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for IP-Adapter for PyTorch 2.0.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            The weight scale of the image prompt.
        num_tokens (`int`, defaults to 4; for IP-Adapter Plus it should be 16):
            The context length of the image features.
        use_align_sem_and_layout_loss (`bool`, defaults to `False`):
            Whether to cache queries and attention weights for the semantic/layout alignment losses.
    """

    def __init__(
        self,
        hidden_size,
        cross_attention_dim=None,
        scale=1.0,
        num_tokens=4,
        use_align_sem_and_layout_loss=False,
    ):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(
            cross_attention_dim or hidden_size, hidden_size, bias=False
        )
        self.to_v_ip = nn.Linear(
            cross_attention_dim or hidden_size, hidden_size, bias=False
        )
        self.ip_hidden_states = None

        self.use_align_sem_and_layout_loss = use_align_sem_and_layout_loss
        if self.use_align_sem_and_layout_loss:
            self.align_sem_loss = None
            self.align_layout_loss = None
            self.cache_query = None
            self.cache_attn_weights = None

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        ip_adapter_masks: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            # Fall back to self-attention when no encoder states are provided,
            # mirroring the default diffusers processor.
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        if self.use_align_sem_and_layout_loss:
            if self.cache_query is None:
                self.cache_query = query.clone().detach()
                self.cache_attn_weights = (key @ query.transpose(-2, -1)) / math.sqrt(
                    query.size(-1)
                )
                self.cache_attn_weights = torch.softmax(self.cache_attn_weights, dim=-1)
            else:
                self.attn_weights = (key @ query.transpose(-2, -1)) / math.sqrt(
                    query.size(-1)
                )
                self.query = query
                self.attn_weights = torch.softmax(self.attn_weights, dim=-1)

        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        if self.scale != 0.0:
            # for ip-adapter
            ip_key = self.to_k_ip(self.ip_hidden_states).to(dtype=query.dtype)
            ip_value = self.to_v_ip(self.ip_hidden_states).to(dtype=query.dtype)

            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(
                1, 2
            )

            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            # TODO: add support for attn.scale when we move to Torch 2.1
            ip_hidden_states = F.scaled_dot_product_attention(
                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
            )
            # with torch.no_grad():
            #     self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
            #     print(self.attn_map.shape)

            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
                batch_size, -1, attn.heads * head_dim
            )
            ip_hidden_states = ip_hidden_states.to(query.dtype)

            if ip_adapter_masks is not None:
                mask_downsample = downsample(
                    ip_adapter_masks,
                    batch_size,
                    ip_hidden_states.shape[1],
                    ip_hidden_states.shape[2],
                )
                mask_downsample = mask_downsample.to(
                    dtype=query.dtype, device=query.device
                )
                ip_hidden_states = ip_hidden_states * mask_downsample

            hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
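
# Usage sketch (hedged): `ip_adapter_masks` reaches the processor through the
# pipeline's `cross_attention_kwargs`; `pipe` and `masks` are assumed to exist
# and `masks` to be an `IPAdapterMaskProcessor.preprocess()` output.
#
#     image = pipe(
#         prompt,
#         cross_attention_kwargs={"ip_adapter_masks": masks},  # (1, 1, H, W) tensor
#     ).images[0]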


def set_ortho(unet, ortho):
    for name, module in unet.attn_processors.items():
        if isinstance(module, IPAttnProcessor2_0) or isinstance(
            module, MultiIPAttnProcessor2_0
        ):
            module.ortho = ortho


def set_num_zero(unet, num_zero):
    for name, module in unet.attn_processors.items():
        if isinstance(module, IPAttnProcessor2_0) or isinstance(
            module, MultiIPAttnProcessor2_0
        ):
            module.num_zero = num_zero


class MultiIPAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for IP-Adapter for PyTorch 2.0.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
            The context length of the image features.
        scale (`float` or `List[float]`, defaults to 1.0):
            The weight scale of the image prompt.
    """

    def __init__(
        self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0
    ):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim

        if not isinstance(num_tokens, (tuple, list)):
            num_tokens = [num_tokens]
        self.num_tokens = num_tokens

        if not isinstance(scale, list):
            scale = [scale] * len(num_tokens)
        if len(scale) != len(num_tokens):
            raise ValueError(
                "`scale` should be a list of integers with the same length as `num_tokens`."
            )
        self.scale = scale

        self.to_k_ip = nn.ModuleList(
            [
                nn.Linear(cross_attention_dim, hidden_size, bias=False)
                for _ in range(len(num_tokens))
            ]
        )
        self.to_v_ip = nn.ModuleList(
            [
                nn.Linear(cross_attention_dim, hidden_size, bias=False)
                for _ in range(len(num_tokens))
            ]
        )
        self.ip_hidden_states = None
        self.num_zero = [None] * len(num_tokens)
        self.ortho = [None] * len(num_tokens)

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        scale: float = 1.0,
        ip_adapter_masks: Optional[torch.FloatTensor] = None,
    ):
        residual = hidden_states
        ip_hidden_states = self.ip_hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        if ip_adapter_masks is not None:
            if (
                not isinstance(ip_adapter_masks, torch.Tensor)
                or ip_adapter_masks.ndim != 4
            ):
                raise ValueError(
                    " ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]."
                    " Please use `IPAdapterMaskProcessor` to preprocess your mask"
                )
            if len(ip_adapter_masks) != len(self.scale):
                raise ValueError(
                    f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})"
                )
        else:
            ip_adapter_masks = [None] * len(self.scale)

        # for ip-adapter
        for (
            current_ip_hidden_states,
            scale,
            to_k_ip,
            to_v_ip,
            mask,
            num_zero,
            ortho,
        ) in zip(
            ip_hidden_states,
            self.scale,
            self.to_k_ip,
            self.to_v_ip,
            ip_adapter_masks,
            self.num_zero,
            self.ortho,
        ):
            if scale == 0:
                continue

            if num_zero is not None:
                # Append `num_zero` all-zero tokens to the image-prompt embeddings.
                zero_tensor = torch.zeros(
                    (
                        current_ip_hidden_states.size(0),
                        num_zero,
                        current_ip_hidden_states.size(-1),
                    ),
                    dtype=current_ip_hidden_states.dtype,
                    device=current_ip_hidden_states.device,
                )
                current_ip_hidden_states = torch.concat(
                    [current_ip_hidden_states, zero_tensor], dim=1
                )

            ip_key = to_k_ip(current_ip_hidden_states)
            ip_value = to_v_ip(current_ip_hidden_states)

            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(
                1, 2
            )

            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            # TODO: add support for attn.scale when we move to Torch 2.1
            current_ip_hidden_states = F.scaled_dot_product_attention(
                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
            )

            current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
                batch_size, -1, attn.heads * head_dim
            )
            current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)

            if mask is not None:
                mask_downsample = IPAdapterMaskProcessor.downsample(
                    mask,
                    batch_size,
                    current_ip_hidden_states.shape[1],
                    current_ip_hidden_states.shape[2],
                )
                mask_downsample = mask_downsample.to(
                    dtype=query.dtype, device=query.device
                )
                current_ip_hidden_states = current_ip_hidden_states * mask_downsample

            if ortho is None:
                hidden_states = hidden_states + scale * current_ip_hidden_states
            elif ortho == "ortho":
                orig_dtype = hidden_states.dtype
                hidden_states = hidden_states.to(torch.float32)
                current_ip_hidden_states = current_ip_hidden_states.to(torch.float32)
                projection = (
                    torch.sum(
                        (hidden_states * current_ip_hidden_states), dim=-2, keepdim=True
                    )
                    / torch.sum((hidden_states * hidden_states), dim=-2, keepdim=True)
                    * hidden_states
                )
                orthogonal = current_ip_hidden_states - projection
                # Add only the component of the image-prompt states orthogonal to the
                # text-conditioned states, weighted by the per-adapter scale.
                hidden_states = hidden_states + scale * orthogonal
                hidden_states = hidden_states.to(orig_dtype)
            elif ortho == "ortho_v2":
                orig_dtype = hidden_states.dtype
                hidden_states = hidden_states.to(torch.float32)
                current_ip_hidden_states = current_ip_hidden_states.to(torch.float32)
                attn_map = query @ ip_key.transpose(-2, -1)
                attn_mean = attn_map.softmax(dim=-1).mean(dim=1)
                attn_mean = attn_mean[:, :, :5].sum(dim=-1, keepdim=True)
                projection = (
                    torch.sum(
                        (hidden_states * current_ip_hidden_states), dim=-2, keepdim=True
                    )
                    / torch.sum((hidden_states * hidden_states), dim=-2, keepdim=True)
                    * hidden_states
                )
                orthogonal = current_ip_hidden_states + (attn_mean - 1) * projection
                hidden_states = hidden_states + scale * orthogonal
                hidden_states = hidden_states.to(orig_dtype)
            else:
                raise ValueError(f"{ortho} not supported")

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
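

if __name__ == "__main__":
    # Minimal smoke-test sketch (hedged): wires freshly initialised IP-Adapter
    # layers into the SDXL UNet and counts the installed processors. The
    # "stabilityai/stable-diffusion-xl-base-1.0" checkpoint is an assumption
    # (it is downloaded on first use), and because this module uses a relative
    # import it must be run as a module, e.g. `python -m <package>.<this_module>`.
    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
    )
    # No checkpoint: to_k_ip/to_v_ip are initialised from the text cross-attention weights.
    load_ip_adapter(
        unet, path=None, cross_attention_dim=unet.config.cross_attention_dim
    )
    set_ip_adapter_scale(unet, scale=0.6)
    n_ip = sum(
        isinstance(p, IPAttnProcessor2_0) for p in unet.attn_processors.values()
    )
    print(f"IP-Adapter attention processors installed: {n_ip}")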