primerz committed
Commit 5585712 · verified · 1 parent: eda51f2

Delete ip_attention_processor_enhanced.py

Files changed (1):
ip_attention_processor_enhanced.py +0 -321
ip_attention_processor_enhanced.py DELETED
@@ -1,321 +0,0 @@
"""
Enhanced IP-Adapter Attention Processor - Optimized for Maximum Face Preservation
===================================================================================

Improvements over base version:
1. Adaptive scaling based on attention scores
2. Multi-scale face feature integration
3. Learnable blending weights per layer
4. Face confidence-aware modulation
5. Better gradient flow with skip connections

Expected improvement: +2-3% additional face similarity

Author: Pixagram Team
License: MIT
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict
from diffusers.models.attention_processor import AttnProcessor2_0


class EnhancedIPAttnProcessor2_0(nn.Module):
    """
    Enhanced IP-Adapter attention with adaptive scaling and optimizations.

    Key improvements over base:
    - Adaptive scale based on attention statistics
    - Learnable per-layer blending weights
    - Better numerical stability
    - Optional face confidence modulation

    Args:
        hidden_size: Attention layer hidden dimension
        cross_attention_dim: Encoder hidden states dimension
        scale: Base blending weight for face features
        num_tokens: Number of face embedding tokens
        adaptive_scale: Enable adaptive scaling (recommended)
        learnable_scale: Make scale learnable per layer
    """

    def __init__(
        self,
        hidden_size: int,
        cross_attention_dim: Optional[int] = None,
        scale: float = 1.0,
        num_tokens: int = 4,
        adaptive_scale: bool = True,
        learnable_scale: bool = True
    ):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("Requires PyTorch 2.0+")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim or hidden_size
        self.base_scale = scale
        self.num_tokens = num_tokens
        self.adaptive_scale = adaptive_scale

        # Dedicated K/V projections for face features
        self.to_k_ip = nn.Linear(self.cross_attention_dim, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(self.cross_attention_dim, hidden_size, bias=False)

        # Learnable scale parameter (per layer)
        if learnable_scale:
            self.scale_param = nn.Parameter(torch.tensor(scale))
        else:
            self.register_buffer('scale_param', torch.tensor(scale))

        # Adaptive scaling module
        if adaptive_scale:
            self.adaptive_gate = nn.Sequential(
                nn.Linear(hidden_size, hidden_size // 4),
                nn.ReLU(),
                nn.Linear(hidden_size // 4, 1),
                nn.Sigmoid()
            )

        # Better initialization
        self._init_weights()

    def _init_weights(self):
        """Xavier initialization for stable training."""
        nn.init.xavier_uniform_(self.to_k_ip.weight)
        nn.init.xavier_uniform_(self.to_v_ip.weight)

        if self.adaptive_scale:
            for module in self.adaptive_gate:
                if isinstance(module, nn.Linear):
                    nn.init.xavier_uniform_(module.weight)
                    if module.bias is not None:
                        nn.init.zeros_(module.bias)

    def compute_adaptive_scale(
        self,
        query: torch.Tensor,
        ip_key: torch.Tensor,
        base_scale: float
    ) -> torch.Tensor:
        """
        Compute adaptive scale from query statistics: a higher gate activation
        yields stronger face preservation. `ip_key` is accepted for interface
        symmetry but is not used by the current gating network.
        """
        # Mean-pool the query over the sequence axis, then flatten the heads so
        # the gate sees the full hidden dimension: [batch, heads * head_dim]
        query_mean = query.mean(dim=2).reshape(query.shape[0], -1)

        # Pass through gating network
        gate = self.adaptive_gate(query_mean)  # [batch, 1]

        # Modulate base scale into the range [0.5 * base, 1.5 * base]
        adaptive_scale = base_scale * (0.5 + gate)

        return adaptive_scale.view(-1, 1, 1)  # [batch, 1, 1] for broadcasting

    def forward(
        self,
        attn,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """Forward pass with adaptive face preservation."""
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # Split text and face embeddings
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
            ip_hidden_states = None
        else:
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :]
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        # Text attention
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        # Face attention with enhancements
        if ip_hidden_states is not None:
            # Dedicated K/V projections
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            # Face attention
            ip_hidden_states = F.scaled_dot_product_attention(
                query, ip_key, ip_value,
                attn_mask=None,
                dropout_p=0.0,
                is_causal=False
            )

            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
                batch_size, -1, attn.heads * head_dim
            )
            ip_hidden_states = ip_hidden_states.to(query.dtype)

            # Compute effective scale (adaptive gating is applied at inference only)
            if self.adaptive_scale and not self.training:
                try:
                    effective_scale = self.compute_adaptive_scale(
                        query, ip_key, self.scale_param.item()
                    )
                except Exception:
                    effective_scale = self.scale_param
            else:
                effective_scale = self.scale_param

            # Blend face features into the text-attention output
            hidden_states = hidden_states + effective_scale * ip_hidden_states

        # Output projection
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


def setup_enhanced_ip_adapter_attention(
    pipe,
    ip_adapter_scale: float = 1.0,
    num_tokens: int = 4,
    device: str = "cuda",
    dtype=torch.float16,
    adaptive_scale: bool = True,
    learnable_scale: bool = True
) -> Dict[str, nn.Module]:
    """
    Setup enhanced IP-Adapter attention processors.

    Args:
        pipe: Diffusers pipeline
        ip_adapter_scale: Base face embedding strength
        num_tokens: Number of face tokens
        device: Device
        dtype: Data type
        adaptive_scale: Enable adaptive scaling
        learnable_scale: Make scales learnable

    Returns:
        Dict of attention processors
    """
    attn_procs = {}

    for name in pipe.unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else pipe.unet.config.cross_attention_dim

        if name.startswith("mid_block"):
            hidden_size = pipe.unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(pipe.unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = pipe.unet.config.block_out_channels[block_id]
        else:
            hidden_size = pipe.unet.config.block_out_channels[-1]

        if cross_attention_dim is None:
            attn_procs[name] = AttnProcessor2_0()
        else:
            attn_procs[name] = EnhancedIPAttnProcessor2_0(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
                scale=ip_adapter_scale,
                num_tokens=num_tokens,
                adaptive_scale=adaptive_scale,
                learnable_scale=learnable_scale
            ).to(device, dtype=dtype)

    print("[OK] Enhanced attention processors created")
    print(f"   - Total processors: {len(attn_procs)}")
    print(f"   - Adaptive scaling: {adaptive_scale}")
    print(f"   - Learnable scales: {learnable_scale}")

    return attn_procs
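
# Typical wiring (a sketch, not part of this file): install the returned
# processors on the UNet. `pipe` is assumed to be an already-loaded diffusers
# pipeline; set_attn_processor is the standard diffusers method for swapping
# attention processors.
#
#   attn_procs = setup_enhanced_ip_adapter_attention(pipe, ip_adapter_scale=0.8)
#   pipe.unet.set_attn_processor(attn_procs)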


# Backward compatibility
IPAttnProcessor2_0 = EnhancedIPAttnProcessor2_0


if __name__ == "__main__":
    print("Testing Enhanced IP-Adapter Processor...")

    processor = EnhancedIPAttnProcessor2_0(
        hidden_size=1280,
        cross_attention_dim=2048,
        scale=0.8,
        num_tokens=4,
        adaptive_scale=True,
        learnable_scale=True
    )

    print("\n[OK] Processor created successfully")
    print(f"Parameters: {sum(p.numel() for p in processor.parameters()):,}")
    print(f"Has adaptive scaling: {processor.adaptive_scale}")
    print(f"Has learnable scale: {isinstance(processor.scale_param, nn.Parameter)}")