Upload modeling_fegeo_llama.py
Browse files- modeling_fegeo_llama.py +5 -2
modeling_fegeo_llama.py
CHANGED
|
@@ -898,8 +898,11 @@ from .configuration_fegeo_llama import LlamaConfig
|
|
| 898 |
|
| 899 |
|
| 900 |
if is_flash_attn_2_available():
|
| 901 |
-
|
| 902 |
-
|
|
|
|
|
|
|
|
|
|
| 903 |
|
| 904 |
|
| 905 |
logger = logging.get_logger(__name__)
|
|
|
|
| 898 |
|
| 899 |
|
| 900 |
if is_flash_attn_2_available():
|
| 901 |
+
try:
|
| 902 |
+
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
| 903 |
+
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
| 904 |
+
except:
|
| 905 |
+
flash_attn_func, flash_attn_varlen_func, index_first_axis, pad_input, unpad_input = None, None, None, None, None
|
| 906 |
|
| 907 |
|
| 908 |
logger = logging.get_logger(__name__)
|