Update ip_adapter/attention_processor_faceid.py
ip_adapter/attention_processor_faceid.py
CHANGED
@@ -280,7 +280,7 @@ class LoRAAttnProcessor2_0(nn.Module):
         head_dim = inner_dim // attn.heads
 
         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
+        query = query.to(dtype=key.dtype)
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
@@ -389,10 +389,9 @@ class LoRAIPAttnProcessor2_0(nn.Module):
         head_dim = inner_dim // attn.heads
 
         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
+
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
         hidden_states = F.scaled_dot_product_attention(
@@ -416,7 +415,6 @@ class LoRAIPAttnProcessor2_0(nn.Module):
             query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
         )
 
-
         ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         ip_hidden_states = ip_hidden_states.to(dtype=ip_key.dtype)
 
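The only functional change is the added query = query.to(dtype=key.dtype) in LoRAAttnProcessor2_0; the remaining hunks appear to touch only blank lines. Below is a minimal standalone sketch of why such a cast is needed. It is not part of the commit, and the tensor shapes and the float32/bfloat16 mix are illustrative assumptions standing in for a mixed-precision run in which the LoRA-processed query ends up in a different dtype than the key/value projections; F.scaled_dot_product_attention requires query, key, and value to share a dtype, so the query is cast to the key's dtype before the call.

import torch
import torch.nn.functional as F

# Standalone sketch (not from the commit): shapes and dtypes are illustrative.
# In mixed precision the LoRA-processed query can come out in a different dtype
# (here float32) than key/value (here bfloat16), and
# F.scaled_dot_product_attention expects all three dtypes to match.
batch_size, heads, seq_len, head_dim = 2, 8, 77, 64
query = torch.randn(batch_size, heads, seq_len, head_dim, dtype=torch.float32)
key = torch.randn(batch_size, heads, seq_len, head_dim, dtype=torch.bfloat16)
value = torch.randn(batch_size, heads, seq_len, head_dim, dtype=torch.bfloat16)

query = query.to(dtype=key.dtype)  # the one-line fix added in LoRAAttnProcessor2_0

hidden_states = F.scaled_dot_product_attention(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
)
print(hidden_states.shape, hidden_states.dtype)  # torch.Size([2, 8, 77, 64]) torch.bfloat16

Without the cast, the same call raises a dtype-mismatch error; with it, the attention output comes back in the key/value dtype, matching the existing ip_hidden_states.to(dtype=ip_key.dtype) pattern already present in the file.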