Upload modeling_fegeo_qwen2.py
Browse files- modeling_fegeo_qwen2.py +7 -3
modeling_fegeo_qwen2.py
CHANGED
|
@@ -615,10 +615,14 @@ from .configuration_fegeo_qwen2 import Qwen2Config
|
|
| 615 |
|
| 616 |
|
| 617 |
if is_flash_attn_2_available():
|
| 618 |
-
|
| 619 |
-
|
|
|
|
| 620 |
|
| 621 |
-
|
|
|
|
|
|
|
|
|
|
| 622 |
|
| 623 |
|
| 624 |
logger = logging.get_logger(__name__)
|
|
|
|
| 615 |
|
| 616 |
|
| 617 |
if is_flash_attn_2_available():
|
| 618 |
+
try:
|
| 619 |
+
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
| 620 |
+
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
| 621 |
|
| 622 |
+
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
|
| 623 |
+
except:
|
| 624 |
+
flash_attn_func, flash_attn_varlen_func, index_first_axis, pad_input, unpad_input = None, None, None, None, None
|
| 625 |
+
|
| 626 |
|
| 627 |
|
| 628 |
logger = logging.get_logger(__name__)
|