Update bert_layers.py
Browse files
- bert_layers.py +3 -2
bert_layers.py
CHANGED
|
@@ -27,8 +27,9 @@ from bert_padding import (index_first_axis, index_put_first_axis, pad_input, unp
|
|
| 27 |
#from bert_padding_module import (index_first_axis, index_put_first_axis, pad_input, unpad_input, unpad_input_only)
|
| 28 |
|
| 29 |
try:
|
| 30 |
-
|
| 31 |
-
#
|
|
|
|
| 32 |
except ImportError as e:
|
| 33 |
flash_attn_qkvpacked_func = None
|
| 34 |
|
|
|
|
| 27 |
#from bert_padding_module import (index_first_axis, index_put_first_axis, pad_input, unpad_input, unpad_input_only)
|
| 28 |
|
| 29 |
try:
|
| 30 |
+
# Force-disable the Triton flash-attention import because of an API incompatibility (the trans_b argument) or dtype issues
|
| 31 |
+
# from .flash_attn_triton import flash_attn_qkvpacked_func
|
| 32 |
+
flash_attn_qkvpacked_func = None
|
| 33 |
except ImportError as e:
|
| 34 |
flash_attn_qkvpacked_func = None
|
| 35 |
|