revi13 commited on
Commit
0213b58
·
verified ·
1 Parent(s): c7b8b01

Update ip_adapter/attention_processor_faceid.py

Browse files
ip_adapter/attention_processor_faceid.py CHANGED
@@ -60,6 +60,8 @@ class LoRAAttnProcessor(nn.Module):
60
  hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
61
 
62
  query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
 
 
63
 
64
  if encoder_hidden_states is None:
65
  encoder_hidden_states = hidden_states
@@ -156,6 +158,7 @@ class LoRAIPAttnProcessor(nn.Module):
156
  hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
157
 
158
  query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
 
159
 
160
  if encoder_hidden_states is None:
161
  encoder_hidden_states = hidden_states
 
60
  hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
61
 
62
  query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
63
+ query = query.to(dtype=key.dtype) # ★ 追加
64
+
65
 
66
  if encoder_hidden_states is None:
67
  encoder_hidden_states = hidden_states
 
158
  hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
159
 
160
  query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
161
+ query = query.to(dtype=key.dtype) # ★追加(または .half())
162
 
163
  if encoder_hidden_states is None:
164
  encoder_hidden_states = hidden_states