| """ | |
| Torch 2.0 Optimized IP-Adapter Attention - Maintains Weight Compatibility | |
| =========================================================================== | |
| Architecture IDENTICAL to InstantID's pretrained weights. | |
| Only adds torch 2.0 performance optimizations. | |
| Author: Pixagram Team | |
| License: MIT | |
| """ | |
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

from diffusers.models.attention_processor import AttnProcessor2_0
class IPAttnProcessorCompatible(nn.Module):
    """
    IP-Adapter attention processor with EXACT architecture for weight loading.
    Optimized for torch 2.0 but maintains compatibility.
    """

    def __init__(
        self,
        hidden_size: int,
        cross_attention_dim: Optional[int] = None,
        scale: float = 1.0,
        num_tokens: int = 4,
    ):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("Requires PyTorch 2.0+ for scaled_dot_product_attention")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim or hidden_size
        self.scale = scale
        self.num_tokens = num_tokens

        # Dedicated K/V projections - MUST match pretrained architecture
        self.to_k_ip = nn.Linear(self.cross_attention_dim, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(self.cross_attention_dim, hidden_size, bias=False)
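        # Note (assumption about checkpoint layout): IP-Adapter/InstantID
        # state-dict keys end in "to_k_ip.weight" / "to_v_ip.weight", so
        # renaming these attributes would break load_state_dict().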
    def forward(
        self,
        attn,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """Standard IP-Adapter forward pass with torch 2.0 attention."""
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        # Flatten spatial inputs: (batch, channel, H, W) -> (batch, H*W, channel)
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # Split text and image embeddings: the pipeline concatenates the image
        # tokens after the text tokens, so the last `num_tokens` entries along
        # the sequence dimension belong to the image prompt
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
            ip_hidden_states = None
        else:
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        # Text attention with torch 2.0
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # Torch 2.0 optimized attention
        hidden_states = F.scaled_dot_product_attention(
            query, key, value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        # Image attention if available
        if ip_hidden_states is not None:
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            # Torch 2.0 image attention
            ip_hidden_states = F.scaled_dot_product_attention(
                query, ip_key, ip_value,
                attn_mask=None,
                dropout_p=0.0,
                is_causal=False,
            )
            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
                batch_size, -1, attn.heads * head_dim
            )
            ip_hidden_states = ip_hidden_states.to(query.dtype)

            # Blend with scale
            hidden_states = hidden_states + self.scale * ip_hidden_states

        # Output projection (linear, then dropout)
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
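# Illustrative sketch of the input layout forward() expects: the caller
# concatenates projected image tokens after the text tokens along the sequence
# dimension, and forward() splits them back apart using `num_tokens`. The dims
# below are assumptions (SDXL-style, batch of 2, 77 text tokens):
#
#   text_embeds = torch.randn(2, 77, 2048)   # prompt embeddings
#   image_embeds = torch.randn(2, 4, 2048)   # projected identity/image tokens
#   encoder_hidden_states = torch.cat([text_embeds, image_embeds], dim=1)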
def setup_compatible_ip_adapter_attention(
    pipe,
    ip_adapter_scale: float = 1.0,
    num_tokens: int = 4,
    device: str = "cuda",
    dtype=torch.float16,
):
    """
    Build per-layer attention processors matching the pretrained IP-Adapter
    architecture: self-attention layers keep the stock AttnProcessor2_0, while
    cross-attention layers get IPAttnProcessorCompatible. The caller is
    responsible for applying the returned dict, e.g. via
    pipe.unet.set_attn_processor(attn_procs), and for loading adapter weights.
    """
    attn_procs = {}

    for name in pipe.unet.attn_processors.keys():
        # attn1 = self-attention (no cross dim), attn2 = cross-attention
        cross_attention_dim = None if name.endswith("attn1.processor") else pipe.unet.config.cross_attention_dim

        # Infer the hidden size of the block this processor belongs to
        if name.startswith("mid_block"):
            hidden_size = pipe.unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(pipe.unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = pipe.unet.config.block_out_channels[block_id]
        else:
            hidden_size = pipe.unet.config.block_out_channels[-1]

        if cross_attention_dim is None:
            attn_procs[name] = AttnProcessor2_0()
        else:
            attn_procs[name] = IPAttnProcessorCompatible(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
                scale=ip_adapter_scale,
                num_tokens=num_tokens,
            ).to(device, dtype=dtype)

    print("[OK] Compatible attention processors created")
    print("  - Architecture matches pretrained weights")
    print("  - Using torch 2.0 optimizations")
    return attn_procs
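# Usage sketch (illustrative; `pipe` and the checkpoint path are assumptions,
# and the "ip_adapter" key follows the IP-Adapter/InstantID state-dict layout):
#
#   attn_procs = setup_compatible_ip_adapter_attention(pipe, ip_adapter_scale=0.8)
#   pipe.unet.set_attn_processor(attn_procs)
#   state_dict = torch.load("ip-adapter.bin", map_location="cpu")
#   ip_layers = nn.ModuleList(pipe.unet.attn_processors.values())
#   ip_layers.load_state_dict(state_dict["ip_adapter"])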
if __name__ == "__main__":
    print("Testing Compatible IP-Adapter Processor...")

    processor = IPAttnProcessorCompatible(
        hidden_size=1280,
        cross_attention_dim=2048,
        scale=0.8,
        num_tokens=4,
    )

    print("[OK] Compatible processor created")
    print(f"Parameters: {sum(p.numel() for p in processor.parameters()):,}")