revi13 committed on
Commit
6159afa
·
verified ·
1 Parent(s): bde45d4

Update ip_adapter/attention_processor_faceid.py

Browse files
ip_adapter/attention_processor_faceid.py CHANGED
@@ -392,7 +392,7 @@ class LoRAIPAttnProcessor2_0(nn.Module):
392
  head_dim = inner_dim // attn.heads
393
 
394
  query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
395
- query = query.to(dtype=ip_key.dtype) # ← これを追加
396
  key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
397
  value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
398
 
@@ -413,7 +413,7 @@ class LoRAIPAttnProcessor2_0(nn.Module):
413
  ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
414
  ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
415
 
416
-
417
 
418
  # the output of sdp = (batch, num_heads, seq_len, head_dim)
419
  # TODO: add support for attn.scale when we move to Torch 2.1
 
392
  head_dim = inner_dim // attn.heads
393
 
394
  query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
395
+ query = query.to(dtype=key.dtype) # ← これを追加
396
  key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
397
  value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
398
 
 
413
  ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
414
  ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
415
 
416
+ query = query.to(dtype=ip_key.dtype) # ← これを追加
417
 
418
  # the output of sdp = (batch, num_heads, seq_len, head_dim)
419
  # TODO: add support for attn.scale when we move to Torch 2.1