KalvinPhan committed
Commit 6bf52ea (verified) · Parent: 5968185

Update modeling_internlm2.py
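Make the flash_attn import non-fatal: when FlashAttention is unavailable, log a warning and fall back to PyTorch attention instead of raising ImportError.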

Files changed (1):
  modeling_internlm2.py  +14 -7
modeling_internlm2.py CHANGED
@@ -46,8 +46,11 @@ logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = 'InternLM2Config'
 
+
+# --- PATCH: Safe FlashAttention import ---
 flash_attn_func, flash_attn_varlen_func = None, None
 pad_input, index_first_axis, unpad_input = None, None, None
+
 try:
     from flash_attn import flash_attn_func as _flash_attn_func
     from flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func
@@ -58,25 +61,29 @@ try:
     flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
     pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
     has_flash_attn = True
-except:
+    print("[INFO] FlashAttention detected and enabled.")
+except Exception as e:
     has_flash_attn = False
-
+    print(f"[WARNING] FlashAttention not available ({e}). Using PyTorch scaled_dot_product_attention instead.")
 
 def _import_flash_attn():
+    """Safe import for FlashAttention; if not available, fall back to torch attention."""
     global flash_attn_func, flash_attn_varlen_func
     global pad_input, index_first_axis, unpad_input
     try:
         from flash_attn import flash_attn_func as _flash_attn_func
-        from flash_attn import \
-            flash_attn_varlen_func as _flash_attn_varlen_func
-        from flash_attn.bert_padding import \
-            index_first_axis as _index_first_axis
+        from flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func
+        from flash_attn.bert_padding import index_first_axis as _index_first_axis
         from flash_attn.bert_padding import pad_input as _pad_input
         from flash_attn.bert_padding import unpad_input as _unpad_input
         flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
         pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
+        print("[INFO] FlashAttention successfully imported.")
     except ImportError:
-        raise ImportError('flash_attn is not installed.')
+        print("[WARNING] flash_attn is not installed. Continuing with standard attention.")
+        flash_attn_func = None
+        flash_attn_varlen_func = None
+        pad_input = index_first_axis = unpad_input = None
 
 
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
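
The new warning message names PyTorch's scaled_dot_product_attention as the fallback, but the diff only touches the import block. A minimal sketch of the dispatch this enables (the helper attention_with_fallback is hypothetical, not part of the commit):

import torch.nn.functional as F

# Mirror the patched import block: flash_attn stays optional.
try:
    from flash_attn import flash_attn_func
    has_flash_attn = True
except Exception:
    flash_attn_func, has_flash_attn = None, False

def attention_with_fallback(q, k, v, causal=True):
    """q, k, v: (batch, seq_len, num_heads, head_dim).

    Hypothetical helper, not part of the commit: use FlashAttention when
    the import above succeeded, otherwise PyTorch's built-in SDPA.
    """
    if has_flash_attn:
        # flash_attn_func consumes (batch, seqlen, nheads, headdim) directly,
        # but requires fp16/bf16 tensors on a CUDA device.
        return flash_attn_func(q, k, v, causal=causal)
    # F.scaled_dot_product_attention expects (batch, nheads, seqlen, headdim),
    # so move the head dimension in and back out of that layout.
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
        is_causal=causal,
    )
    return out.transpose(1, 2)

This is presumably the point of the commit: the old raise ImportError('flash_attn is not installed.') made flash_attn a hard dependency, whereas the patched module degrades to standard attention and keeps loading on machines (e.g. CPU-only) where FlashAttention cannot be installed.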