Update modeling_motif.py
modeling_motif.py  CHANGED  (+4 -7)
@@ -545,9 +545,7 @@ class MotifFlashAttention2(MotifAttention):
 
         bsz = query_states.shape[0]
 
-        return
-            lambda x: x.float32(),
-        _flash_attention_forward(query_states.bfloat16(),
+        return _flash_attention_forward(query_states.bfloat16(),
                                         key_states.bfloat16(),
                                         value_states.bfloat16(),
                                         attention_mask,
@@ -557,7 +555,6 @@ class MotifFlashAttention2(MotifAttention):
                                         sliding_window=sliding_window,
                                         is_causal=self.is_causal,
                                         use_top_left_mask=self._flash_attn_uses_top_left_mask)
-        )
 
     def forward(
         self,
@@ -642,7 +639,7 @@ class MotifFlashAttention2(MotifAttention):
         self._compute_attention(q2, k2, v2, attention_mask, q_len, position_ids, dropout_rate, sliding_window, self.batch_num)
 
 
-        attn1, attn2 = torch.cat([attn11, attn12], dim=-1), torch.cat([attn21, attn22], dim=-1)
+        attn1, attn2 = torch.cat([attn11, attn12], dim=-1).float(), torch.cat([attn21, attn22], dim=-1).float()
 
         lambda_q1 = self.lambda_q1.unsqueeze(0).expand([bsz, self.lambda_q1.shape[0]])  # bsz, num_head
         lambda_q2 = self.lambda_q2.unsqueeze(0).expand([bsz, self.lambda_q2.shape[0]])  # bsz, num_head
@@ -661,10 +658,10 @@ class MotifFlashAttention2(MotifAttention):
             raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                              f" {attn_output.size()}")
 
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).bfloat16()
         attn_output = self.o_proj(attn_output) * self.o_proj_alpha
 
-        return attn_output, None, past_key_value
+        return attn_output.float(), None, past_key_value
 
 
     # @log_timing
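Taken together, the hunks amount to a dtype round-trip around the flash-attention call: query/key/value states are cast to bfloat16 before _flash_attention_forward, the concatenated attention halves are promoted back to float32 for the subsequent lambda mixing, and the result is cast to bfloat16 again for o_proj before being returned as float32. Below is a minimal sketch of that pattern only; _flash_attn_bf16 and attention_dtype_roundtrip are hypothetical names, the differential-attention lambda mixing is omitted, and the plain softmax kernel merely stands in for the real flash-attention call.

import torch
from torch import nn

def _flash_attn_bf16(q, k, v):
    # Hypothetical stand-in for transformers' _flash_attention_forward:
    # any kernel taking bfloat16 (bsz, seq, heads, head_dim) inputs fits here.
    scores = torch.einsum("bqhd,bkhd->bhqk", q, k) / q.shape[-1] ** 0.5
    return torch.einsum("bhqk,bkhd->bqhd", scores.softmax(dim=-1), v)

def attention_dtype_roundtrip(q, k, v, o_proj, o_proj_alpha):
    bsz, q_len = q.shape[0], q.shape[1]
    # 1) run the attention kernel in bfloat16
    attn = _flash_attn_bf16(q.bfloat16(), k.bfloat16(), v.bfloat16())
    # 2) promote to float32 before any further mixing (the diff does this
    #    right after concatenating the two differential-attention halves)
    attn = attn.float()
    # 3) back to bfloat16 for the output projection, float32 on the way out
    attn = attn.reshape(bsz, q_len, -1).bfloat16()
    return (o_proj(attn) * o_proj_alpha).float()

# usage sketch: output projection weights created directly in bfloat16
bsz, q_len, heads, head_dim = 2, 16, 4, 32
q = k = v = torch.randn(bsz, q_len, heads, head_dim)
o_proj = nn.Linear(heads * head_dim, heads * head_dim, bias=False, dtype=torch.bfloat16)
out = attention_dtype_roundtrip(q, k, v, o_proj, o_proj_alpha=1.0)
print(out.dtype)  # torch.float32

If o_proj is held in bfloat16, step 3 keeps the projection matmul in the weights' dtype while callers still see float32 activations, which appears to be the point of the extra casts.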