Update modeling_motif.py
Browse files
- modeling_motif.py +0 -27
modeling_motif.py
CHANGED
|
@@ -571,33 +571,6 @@ class MotifFlashAttention2(MotifAttention):
|
|
| 571 |
|
| 572 |
bsz = query_states.shape[0]
|
| 573 |
|
| 574 |
-
if batch_num:
|
| 575 |
-
query_states = query_states.reshape(bsz*q_len,self.num_heads,self.head_dim)
|
| 576 |
-
key_states = key_states.reshape(bsz*q_len,self.num_heads,self.head_dim)
|
| 577 |
-
value_states = value_states.reshape(bsz*q_len,self.num_heads,self.head_dim)
|
| 578 |
-
|
| 579 |
-
attn_out = moreh_ops.flash_attention_varlen_dp(query_states,
|
| 580 |
-
key_states,
|
| 581 |
-
value_states,
|
| 582 |
-
attention_mask,
|
| 583 |
-
attention_mask,
|
| 584 |
-
max_seqlen_q=q_len,
|
| 585 |
-
max_seqlen_kv=q_len,
|
| 586 |
-
dropout_p=dropout_rate,
|
| 587 |
-
softmax_scale=scale_factor,
|
| 588 |
-
is_causal=causal,
|
| 589 |
-
batch_num=batch_num)
|
| 590 |
-
attn_out = attn_out.reshape(bsz, q_len, self.num_heads, -1)
|
| 591 |
-
else:
|
| 592 |
-
return MorehFlashAttention(query_states,
|
| 593 |
-
key_states,
|
| 594 |
-
value_states,
|
| 595 |
-
padding_mask=attention_mask,
|
| 596 |
-
dropout_p=dropout_rate,
|
| 597 |
-
softmax_scale=scale_factor,
|
| 598 |
-
causal=causal)
|
| 599 |
-
return attn_out
|
| 600 |
-
else:
|
| 601 |
return _flash_attention_forward(query_states,
|
| 602 |
key_states,
|
| 603 |
value_states,
|
|
|
|
| 571 |
|
| 572 |
bsz = query_states.shape[0]
|
| 573 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 574 |
return _flash_attention_forward(query_states,
|
| 575 |
key_states,
|
| 576 |
value_states,
|