# pixagram-backup / ip_attention_processor_compatible.py
"""
Torch 2.0 Optimized IP-Adapter Attention - Maintains Weight Compatibility
===========================================================================
The module architecture is identical to InstantID's pretrained IP-Adapter weights;
this file only adds torch 2.0 (scaled_dot_product_attention) performance optimizations.
Author: Pixagram Team
License: MIT
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
from diffusers.models.attention_processor import AttnProcessor2_0
class IPAttnProcessorCompatible(nn.Module):
"""
IP-Adapter attention processor with EXACT architecture for weight loading.
Optimized for torch 2.0 but maintains compatibility.
"""
def __init__(
self,
hidden_size: int,
cross_attention_dim: Optional[int] = None,
scale: float = 1.0,
num_tokens: int = 4,
):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("Requires PyTorch 2.0+ for scaled_dot_product_attention")
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim or hidden_size
self.scale = scale
self.num_tokens = num_tokens
# Dedicated K/V projections - MUST match pretrained architecture
self.to_k_ip = nn.Linear(self.cross_attention_dim, hidden_size, bias=False)
self.to_v_ip = nn.Linear(self.cross_attention_dim, hidden_size, bias=False)
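        # The attribute names above (`to_k_ip`, `to_v_ip`) and the bias-free
        # cross_attention_dim -> hidden_size shape are what pretrained IP-Adapter /
        # InstantID state dicts expect; renaming them would break weight loading.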
def forward(
self,
attn,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
"""Standard IP-Adapter forward pass with torch 2.0 attention."""
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None
else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(
attention_mask, sequence_length, batch_size
)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
        # Split the encoder states into text tokens and the trailing num_tokens image (IP) tokens
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
ip_hidden_states = None
else:
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
encoder_hidden_states[:, end_pos:, :]
)
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
# Text attention with torch 2.0
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
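        # Reshape (batch, seq_len, heads * head_dim) -> (batch, heads, seq_len, head_dim)
        # so F.scaled_dot_product_attention runs all heads in one fused kernel.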
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# Torch 2.0 optimized attention
hidden_states = F.scaled_dot_product_attention(
query, key, value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
hidden_states = hidden_states.to(query.dtype)
# Image attention if available
if ip_hidden_states is not None:
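            # Decoupled cross-attention: reuse the same query as the text branch,
            # but project the image tokens through the dedicated to_k_ip / to_v_ip layers.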
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# Torch 2.0 image attention
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value,
attn_mask=None,
dropout_p=0.0,
is_causal=False
)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
ip_hidden_states = ip_hidden_states.to(query.dtype)
            # Blend: add the image-conditioned attention output, weighted by self.scale
hidden_states = hidden_states + self.scale * ip_hidden_states
        # Output projection (linear) followed by dropout
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(
batch_size, channel, height, width
)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
def setup_compatible_ip_adapter_attention(
pipe,
ip_adapter_scale: float = 1.0,
num_tokens: int = 4,
    device: str = "cuda",
    dtype: torch.dtype = torch.float16,
):
"""
Setup IP-Adapter with compatible architecture for weight loading.
"""
attn_procs = {}
for name in pipe.unet.attn_processors.keys():
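        # "attn1" processors are self-attention (no IP tokens); "attn2" processors are
        # cross-attention and receive the concatenated text + image token sequence.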
cross_attention_dim = None if name.endswith("attn1.processor") else pipe.unet.config.cross_attention_dim
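        # hidden_size follows the UNet's block_out_channels: reversed order for up_blocks,
        # forward order for down_blocks, and the last entry for the mid_block.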
if name.startswith("mid_block"):
hidden_size = pipe.unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(pipe.unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = pipe.unet.config.block_out_channels[block_id]
else:
hidden_size = pipe.unet.config.block_out_channels[-1]
if cross_attention_dim is None:
attn_procs[name] = AttnProcessor2_0()
else:
attn_procs[name] = IPAttnProcessorCompatible(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=ip_adapter_scale,
num_tokens=num_tokens
).to(device, dtype=dtype)
print(f"[OK] Compatible attention processors created")
print(f" - Architecture matches pretrained weights")
print(f" - Using torch 2.0 optimizations")
return attn_procs
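# Minimal usage sketch (assumes `pipe` is a diffusers StableDiffusionXLPipeline already
# loaded elsewhere; the variable names here are illustrative, not defined in this file):
#
#     attn_procs = setup_compatible_ip_adapter_attention(
#         pipe, ip_adapter_scale=0.8, num_tokens=4, device="cuda", dtype=torch.float16
#     )
#     pipe.unet.set_attn_processor(attn_procs)
#
# `unet.set_attn_processor` is the standard diffusers hook for swapping attention
# processors; pretrained to_k_ip / to_v_ip weights can then be loaded into the
# modules returned above.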
if __name__ == "__main__":
print("Testing Compatible IP-Adapter Processor...")
processor = IPAttnProcessorCompatible(
hidden_size=1280,
cross_attention_dim=2048,
scale=0.8,
num_tokens=4
)
print(f"[OK] Compatible processor created")
print(f"Parameters: {sum(p.numel() for p in processor.parameters()):,}")