WonwoongCho committed
Commit d761edd · 1 Parent(s): f90a31c

update attention processor

Files changed (1):
  src/attention_processor.py  +1 -1
src/attention_processor.py CHANGED
@@ -73,7 +73,7 @@ class FluxBlendedAttnProcessor2_0(nn.Module):
         ba_value = ba_value.view(chunk, -1, attn.heads, head_dim).transpose(1, 2)
 
         ba_hidden_states = F.scaled_dot_product_attention(
-            ba_query, ba_key, ba_value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=(1 / math.sqrt(ba_query.size(-1)))*self.temperature if self.num_ref > 1 else 1 / math.sqrt(ba_query.size(-1))
+            ba_query, ba_key, ba_value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False#, scale=(1 / math.sqrt(ba_query.size(-1)))*self.temperature if self.num_ref > 1 else 1 / math.sqrt(ba_query.size(-1))
         )
 
         ba_hidden_states = ba_hidden_states.transpose(1, 2).reshape(chunk, -1, attn.heads * head_dim)
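The only change is commenting out the custom scale argument: F.scaled_dot_product_attention falls back to its default 1 / sqrt(head_dim) scaling, so the self.temperature factor previously applied when self.num_ref > 1 is no longer used, while the num_ref <= 1 path is numerically unchanged. A minimal standalone sketch of that equivalence (illustrative shapes, not the repository's code; assumes a PyTorch version that supports the scale keyword, as the original line already does):

import math
import torch
import torch.nn.functional as F

# Illustrative tensors: (batch, heads, tokens, head_dim)
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# Omitting `scale` uses PyTorch's default 1 / sqrt(head_dim),
# which equals the removed else-branch value 1 / math.sqrt(q.size(-1)).
out_default = F.scaled_dot_product_attention(q, k, v)
out_explicit = F.scaled_dot_product_attention(q, k, v, scale=1 / math.sqrt(q.size(-1)))
print(torch.allclose(out_default, out_explicit, atol=1e-6))  # True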