primerz committed
Commit 669defd · verified · 1 Parent(s): 5585712

Delete ip_attention_processor_compatible.py

Files changed (1):
  ip_attention_processor_compatible.py  +0 -213
ip_attention_processor_compatible.py DELETED
@@ -1,213 +0,0 @@
- """
- Torch 2.0 Optimized IP-Adapter Attention - Maintains Weight Compatibility
- ===========================================================================
-
- Architecture IDENTICAL to InstantID's pretrained weights.
- Only adds torch 2.0 performance optimizations.
-
- Author: Pixagram Team
- License: MIT
- """
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from typing import Optional
- from diffusers.models.attention_processor import AttnProcessor2_0
-
-
- class IPAttnProcessorCompatible(nn.Module):
-     """
-     IP-Adapter attention processor with EXACT architecture for weight loading.
-     Optimized for torch 2.0 but maintains compatibility.
-     """
-
-     def __init__(
-         self,
-         hidden_size: int,
-         cross_attention_dim: Optional[int] = None,
-         scale: float = 1.0,
-         num_tokens: int = 4,
-     ):
-         super().__init__()
-
-         if not hasattr(F, "scaled_dot_product_attention"):
-             raise ImportError("Requires PyTorch 2.0+ for scaled_dot_product_attention")
-
-         self.hidden_size = hidden_size
-         self.cross_attention_dim = cross_attention_dim or hidden_size
-         self.scale = scale
-         self.num_tokens = num_tokens
-
-         # Dedicated K/V projections - MUST match pretrained architecture
-         self.to_k_ip = nn.Linear(self.cross_attention_dim, hidden_size, bias=False)
-         self.to_v_ip = nn.Linear(self.cross_attention_dim, hidden_size, bias=False)
-
-     def forward(
-         self,
-         attn,
-         hidden_states: torch.FloatTensor,
-         encoder_hidden_states: Optional[torch.FloatTensor] = None,
-         attention_mask: Optional[torch.FloatTensor] = None,
-         temb: Optional[torch.FloatTensor] = None,
-     ) -> torch.FloatTensor:
-         """Standard IP-Adapter forward pass with torch 2.0 attention."""
-         residual = hidden_states
-
-         if attn.spatial_norm is not None:
-             hidden_states = attn.spatial_norm(hidden_states, temb)
-
-         input_ndim = hidden_states.ndim
-         if input_ndim == 4:
-             batch_size, channel, height, width = hidden_states.shape
-             hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-         batch_size, sequence_length, _ = (
-             hidden_states.shape if encoder_hidden_states is None
-             else encoder_hidden_states.shape
-         )
-
-         if attention_mask is not None:
-             attention_mask = attn.prepare_attention_mask(
-                 attention_mask, sequence_length, batch_size
-             )
-             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
-         if attn.group_norm is not None:
-             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-         query = attn.to_q(hidden_states)
-
-         # Split text and image embeddings
-         if encoder_hidden_states is None:
-             encoder_hidden_states = hidden_states
-             ip_hidden_states = None
-         else:
-             end_pos = encoder_hidden_states.shape[1] - self.num_tokens
-             encoder_hidden_states, ip_hidden_states = (
-                 encoder_hidden_states[:, :end_pos, :],
-                 encoder_hidden_states[:, end_pos:, :]
-             )
-             if attn.norm_cross:
-                 encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-         # Text attention with torch 2.0
-         key = attn.to_k(encoder_hidden_states)
-         value = attn.to_v(encoder_hidden_states)
-
-         inner_dim = key.shape[-1]
-         head_dim = inner_dim // attn.heads
-
-         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-         # Torch 2.0 optimized attention
-         hidden_states = F.scaled_dot_product_attention(
-             query, key, value,
-             attn_mask=attention_mask,
-             dropout_p=0.0,
-             is_causal=False
-         )
-
-         hidden_states = hidden_states.transpose(1, 2).reshape(
-             batch_size, -1, attn.heads * head_dim
-         )
-         hidden_states = hidden_states.to(query.dtype)
-
-         # Image attention if available
-         if ip_hidden_states is not None:
-             ip_key = self.to_k_ip(ip_hidden_states)
-             ip_value = self.to_v_ip(ip_hidden_states)
-
-             ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-             ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-             # Torch 2.0 image attention
-             ip_hidden_states = F.scaled_dot_product_attention(
-                 query, ip_key, ip_value,
-                 attn_mask=None,
-                 dropout_p=0.0,
-                 is_causal=False
-             )
-
-             ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
-                 batch_size, -1, attn.heads * head_dim
-             )
-             ip_hidden_states = ip_hidden_states.to(query.dtype)
-
-             # Blend with scale
-             hidden_states = hidden_states + self.scale * ip_hidden_states
-
-         # Output projection
-         hidden_states = attn.to_out[0](hidden_states)
-         hidden_states = attn.to_out[1](hidden_states)
-
-         if input_ndim == 4:
-             hidden_states = hidden_states.transpose(-1, -2).reshape(
-                 batch_size, channel, height, width
-             )
-
-         if attn.residual_connection:
-             hidden_states = hidden_states + residual
-
-         hidden_states = hidden_states / attn.rescale_output_factor
-
-         return hidden_states
-
-
- def setup_compatible_ip_adapter_attention(
-     pipe,
-     ip_adapter_scale: float = 1.0,
-     num_tokens: int = 4,
-     device: str = "cuda",
-     dtype=torch.float16,
- ):
-     """
-     Setup IP-Adapter with compatible architecture for weight loading.
-     """
-     attn_procs = {}
-
-     for name in pipe.unet.attn_processors.keys():
-         cross_attention_dim = None if name.endswith("attn1.processor") else pipe.unet.config.cross_attention_dim
-
-         if name.startswith("mid_block"):
-             hidden_size = pipe.unet.config.block_out_channels[-1]
-         elif name.startswith("up_blocks"):
-             block_id = int(name[len("up_blocks.")])
-             hidden_size = list(reversed(pipe.unet.config.block_out_channels))[block_id]
-         elif name.startswith("down_blocks"):
-             block_id = int(name[len("down_blocks.")])
-             hidden_size = pipe.unet.config.block_out_channels[block_id]
-         else:
-             hidden_size = pipe.unet.config.block_out_channels[-1]
-
-         if cross_attention_dim is None:
-             attn_procs[name] = AttnProcessor2_0()
-         else:
-             attn_procs[name] = IPAttnProcessorCompatible(
-                 hidden_size=hidden_size,
-                 cross_attention_dim=cross_attention_dim,
-                 scale=ip_adapter_scale,
-                 num_tokens=num_tokens
-             ).to(device, dtype=dtype)
-
-     print("[OK] Compatible attention processors created")
-     print("     - Architecture matches pretrained weights")
-     print("     - Using torch 2.0 optimizations")
-
-     return attn_procs
-
-
- if __name__ == "__main__":
-     print("Testing Compatible IP-Adapter Processor...")
-
-     processor = IPAttnProcessorCompatible(
-         hidden_size=1280,
-         cross_attention_dim=2048,
-         scale=0.8,
-         num_tokens=4
-     )
-
-     print("[OK] Compatible processor created")
-     print(f"Parameters: {sum(p.numel() for p in processor.parameters()):,}")