revi13 commited on
Commit
c7b8b01
·
verified ·
1 Parent(s): 5071412

Update ip_adapter/attention_processor_faceid.py

Browse files
ip_adapter/attention_processor_faceid.py CHANGED
@@ -181,7 +181,8 @@ class LoRAIPAttnProcessor(nn.Module):
181
  hidden_states = attn.batch_to_head_dim(hidden_states)
182
 
183
  # for ip-adapter
184
- ip_key = self.to_k_ip(ip_hidden_states.to(dtype=self.to_k_ip.weight.dtype))
 
185
  ip_value = self.to_v_ip(ip_hidden_states)
186
 
187
  ip_key = attn.head_to_batch_dim(ip_key)
@@ -400,7 +401,8 @@ class LoRAIPAttnProcessor2_0(nn.Module):
400
  hidden_states = hidden_states.to(query.dtype)
401
 
402
  # for ip
403
- ip_key = self.to_k_ip(ip_hidden_states.to(dtype=self.to_k_ip.weight.dtype))
 
404
  ip_value = self.to_v_ip(ip_hidden_states)
405
 
406
  ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
181
  hidden_states = attn.batch_to_head_dim(hidden_states)
182
 
183
  # for ip-adapter
184
+ ip_hidden_states = ip_hidden_states.to(dtype=self.to_k_ip.weight.dtype) # ★追加・修正
185
+ ip_key = self.to_k_ip(ip_hidden_states)
186
  ip_value = self.to_v_ip(ip_hidden_states)
187
 
188
  ip_key = attn.head_to_batch_dim(ip_key)
 
401
  hidden_states = hidden_states.to(query.dtype)
402
 
403
  # for ip
404
+ ip_hidden_states = ip_hidden_states.to(dtype=self.to_k_ip.weight.dtype) # ★追加・修正
405
+ ip_key = self.to_k_ip(ip_hidden_states)
406
  ip_value = self.to_v_ip(ip_hidden_states)
407
 
408
  ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)