Update modeling_gritlm7b.py
Browse files- modeling_gritlm7b.py +3 -2
modeling_gritlm7b.py
CHANGED
|
@@ -46,12 +46,13 @@ from ...utils import (
|
|
| 46 |
from .configuration_mistral import MistralConfig
|
| 47 |
|
| 48 |
|
| 49 |
-
|
| 50 |
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
| 51 |
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
| 52 |
|
| 53 |
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
|
| 54 |
-
|
|
|
|
| 55 |
|
| 56 |
logger = logging.get_logger(__name__)
|
| 57 |
|
|
|
|
| 46 |
from .configuration_mistral import MistralConfig
|
| 47 |
|
| 48 |
|
| 49 |
+
try:
|
| 50 |
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
| 51 |
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
| 52 |
|
| 53 |
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
|
| 54 |
+
except:
|
| 55 |
+
pass
|
| 56 |
|
| 57 |
logger = logging.get_logger(__name__)
|
| 58 |
|