Update modeling_motif.py
Browse files
- modeling_motif.py +13 -10
modeling_motif.py
CHANGED
|
@@ -545,16 +545,19 @@ class MotifFlashAttention2(MotifAttention):
|
|
| 545 |
|
| 546 |
bsz = query_states.shape[0]
|
| 547 |
|
| 548 |
-
return
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
|
|
|
|
|
|
|
|
|
|
| 558 |
|
| 559 |
def forward(
|
| 560 |
self,
|
|
|
|
| 545 |
|
| 546 |
bsz = query_states.shape[0]
|
| 547 |
|
| 548 |
+
return map(
|
| 549 |
+
lambda x: x.float32(),
|
| 550 |
+
_flash_attention_forward(query_states.bfloat16(),
|
| 551 |
+
key_states.bfloat16(),
|
| 552 |
+
value_states.bfloat16(),
|
| 553 |
+
attention_mask,
|
| 554 |
+
q_len,
|
| 555 |
+
position_ids=position_ids,
|
| 556 |
+
dropout=dropout_rate,
|
| 557 |
+
sliding_window=sliding_window,
|
| 558 |
+
is_causal=self.is_causal,
|
| 559 |
+
use_top_left_mask=self._flash_attn_uses_top_left_mask)
|
| 560 |
+
)
|
| 561 |
|
| 562 |
def forward(
|
| 563 |
self,
|