revi13 committed
Commit 282bb6d (verified) · Parent: 294e0b0

Update ip_adapter/attention_processor_faceid.py

ip_adapter/attention_processor_faceid.py CHANGED
@@ -280,7 +280,7 @@ class LoRAAttnProcessor2_0(nn.Module):
         head_dim = inner_dim // attn.heads
 
         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
+        query = query.to(dtype=key.dtype)
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
@@ -389,10 +389,9 @@ class LoRAIPAttnProcessor2_0(nn.Module):
         head_dim = inner_dim // attn.heads
 
         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
+
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
         hidden_states = F.scaled_dot_product_attention(
@@ -416,7 +415,6 @@ class LoRAIPAttnProcessor2_0(nn.Module):
             query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
         )
 
-
         ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         ip_hidden_states = ip_hidden_states.to(dtype=ip_key.dtype)
 
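For context on the diff above: the one substantive change is the added query = query.to(dtype=key.dtype) in LoRAAttnProcessor2_0 (the other two hunks only touch blank lines). F.scaled_dot_product_attention requires query, key, and value to share a dtype, and the query can leave its LoRA-augmented projection in a different precision than the key/value projections. Below is a minimal sketch of the failure mode and the cast; the shapes and the fp16/fp32 split are illustrative assumptions, not values taken from this repo.

import torch
import torch.nn.functional as F

batch_size, heads, seq_len, head_dim = 2, 8, 77, 64  # illustrative shapes only

# Hypothetical mismatch: suppose the query comes out of its projection in
# fp16 while key/value are still fp32 (any half/float split triggers the
# same error, whichever side is which).
query = torch.randn(batch_size, heads, seq_len, head_dim, dtype=torch.float16)
key = torch.randn(batch_size, heads, seq_len, head_dim)    # fp32
value = torch.randn(batch_size, heads, seq_len, head_dim)  # fp32

try:
    F.scaled_dot_product_attention(query, key, value)
except RuntimeError as err:
    # PyTorch rejects mixed dtypes with a RuntimeError along the lines of
    # "Expected query, key, and value to have the same dtype ..."
    print(f"SDPA rejected the inputs: {err}")

# The commit's fix, applied to the sketch: align query with key's dtype.
query = query.to(dtype=key.dtype)
out = F.scaled_dot_product_attention(query, key, value)
print(out.shape, out.dtype)  # torch.Size([2, 8, 77, 64]) torch.float32

The same defensive pattern already appears a few lines later in the unchanged context, where ip_hidden_states is cast back to ip_key.dtype after the image-prompt attention call.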