vinayh19 committed on
Commit
29b747c
·
verified ·
1 Parent(s): 1346bc8

Update bert_layers.py

Browse files
Files changed (1) hide show
  1. bert_layers.py +3 -2
bert_layers.py CHANGED
@@ -27,8 +27,9 @@ from bert_padding import (index_first_axis, index_put_first_axis, pad_input, unp
27
  #from bert_padding_module import (index_first_axis, index_put_first_axis, pad_input, unpad_input, unpad_input_only)
28
 
29
  try:
30
- from .flash_attn_triton import flash_attn_qkvpacked_func
31
- #flash_attn_qkvpacked_func = flash_attn_triton.flash_attn_qkvpacked_func
 
32
  except ImportError as e:
33
  flash_attn_qkvpacked_func = None
34
 
 
27
  #from bert_padding_module import (index_first_axis, index_put_first_axis, pad_input, unpad_input, unpad_input_only)
28
 
29
  try:
30
+ # Force-disable the Triton flash-attention kernel due to an API incompatibility (the trans_b argument) or dtype issues
31
+ # from .flash_attn_triton import flash_attn_qkvpacked_func
32
+ flash_attn_qkvpacked_func = None
33
  except ImportError as e:
34
  flash_attn_qkvpacked_func = None
35