Commit: fix a bug when flash-attn is not installed
Changed file: modeling_time_moe.py (+6 −2)
@@ -16,10 +16,14 @@ from .ts_generation_mixin import TSGenerationMixin

Before (modeling_time_moe.py, lines 16-25):

    16
    17  logger = logging.get_logger(__name__)
    18
    19  if is_flash_attn_2_available():
    20      from flash_attn import flash_attn_func, flash_attn_varlen_func
    21      from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
    22
    23
    24  def _get_unpad_data(attention_mask):
    25      seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
After (modeling_time_moe.py, lines 16-29):

    16
    17  logger = logging.get_logger(__name__)
    18
    19  # if is_flash_attn_2_available():
    20  #     from flash_attn import flash_attn_func, flash_attn_varlen_func
    21  #     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
    22  try:
    23      from flash_attn import flash_attn_func, flash_attn_varlen_func
    24      from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
    25  except:
    26      pass
    27
    28  def _get_unpad_data(attention_mask):
    29      seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)

NOTE(review): the new code replaces the `is_flash_attn_2_available()` guard with a try/except around the import, so the module loads even when flash-attn is absent. The bare `except: pass` would be better written as `except ImportError: pass` so unrelated errors are not silently swallowed — worth a follow-up.