Spaces:
Sleeping
Sleeping
Update ip_adapter/attention_processor_faceid.py
Browse files
ip_adapter/attention_processor_faceid.py
CHANGED
|
@@ -181,7 +181,8 @@ class LoRAIPAttnProcessor(nn.Module):
|
|
| 181 |
hidden_states = attn.batch_to_head_dim(hidden_states)
|
| 182 |
|
| 183 |
# for ip-adapter
|
| 184 |
-
|
|
|
|
| 185 |
ip_value = self.to_v_ip(ip_hidden_states)
|
| 186 |
|
| 187 |
ip_key = attn.head_to_batch_dim(ip_key)
|
|
@@ -400,7 +401,8 @@ class LoRAIPAttnProcessor2_0(nn.Module):
|
|
| 400 |
hidden_states = hidden_states.to(query.dtype)
|
| 401 |
|
| 402 |
# for ip
|
| 403 |
-
|
|
|
|
| 404 |
ip_value = self.to_v_ip(ip_hidden_states)
|
| 405 |
|
| 406 |
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
|
|
|
| 181 |
hidden_states = attn.batch_to_head_dim(hidden_states)
|
| 182 |
|
| 183 |
# for ip-adapter
|
| 184 |
+
ip_hidden_states = ip_hidden_states.to(dtype=self.to_k_ip.weight.dtype) # ★追加・修正
|
| 185 |
+
ip_key = self.to_k_ip(ip_hidden_states)
|
| 186 |
ip_value = self.to_v_ip(ip_hidden_states)
|
| 187 |
|
| 188 |
ip_key = attn.head_to_batch_dim(ip_key)
|
|
|
|
| 401 |
hidden_states = hidden_states.to(query.dtype)
|
| 402 |
|
| 403 |
# for ip
|
| 404 |
+
ip_hidden_states = ip_hidden_states.to(dtype=self.to_k_ip.weight.dtype) # ★追加・修正
|
| 405 |
+
ip_key = self.to_k_ip(ip_hidden_states)
|
| 406 |
ip_value = self.to_v_ip(ip_hidden_states)
|
| 407 |
|
| 408 |
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|