KalvinPhan committed
Commit a993899 · verified · 1 Parent(s): 6bf52ea

Update modeling_internlm2.py

Files changed (1)
  1. modeling_internlm2.py +30 -24
modeling_internlm2.py CHANGED
@@ -46,44 +46,50 @@ logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = 'InternLM2Config'
 
-
-# --- PATCH: Safe FlashAttention import ---
+# --- PATCH: Safe FlashAttention import (for Kaggle/Colab without flash_attn) ---
 flash_attn_func, flash_attn_varlen_func = None, None
 pad_input, index_first_axis, unpad_input = None, None, None
+has_flash_attn = False  # default = False
 
 try:
-    from flash_attn import flash_attn_func as _flash_attn_func
-    from flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func
-    from flash_attn.bert_padding import index_first_axis as _index_first_axis
-    from flash_attn.bert_padding import pad_input as _pad_input
-    from flash_attn.bert_padding import unpad_input as _unpad_input
-
-    flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
-    pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
-    has_flash_attn = True
-    print("[INFO] FlashAttention detected and enabled.")
+    import importlib.util
+    if importlib.util.find_spec("flash_attn") is not None:
+        from flash_attn import flash_attn_func as _flash_attn_func
+        from flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func
+        from flash_attn.bert_padding import (
+            index_first_axis as _index_first_axis,
+            pad_input as _pad_input,
+            unpad_input as _unpad_input,
+        )
+        flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
+        pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
+        has_flash_attn = True
+        print("[INFO] FlashAttention detected and enabled.")
+    else:
+        print("[INFO] FlashAttention not installed. Using PyTorch attention instead.")
 except Exception as e:
-    has_flash_attn = False
-    print(f"[WARNING] FlashAttention not available ({e}). Using PyTorch scaled_dot_product_attention instead.")
+    print(f"[WARNING] Failed to import flash_attn ({e}). Using PyTorch attention fallback.")
 
 def _import_flash_attn():
-    """Safe import for FlashAttention; if not available, fallback to torch attention."""
+    """Safe re-import; ignored if flash_attn is missing."""
     global flash_attn_func, flash_attn_varlen_func
-    global pad_input, index_first_axis, unpad_input
+    global pad_input, index_first_axis, unpad_input, has_flash_attn
     try:
         from flash_attn import flash_attn_func as _flash_attn_func
         from flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func
-        from flash_attn.bert_padding import index_first_axis as _index_first_axis
-        from flash_attn.bert_padding import pad_input as _pad_input
-        from flash_attn.bert_padding import unpad_input as _unpad_input
-
+        from flash_attn.bert_padding import (
+            index_first_axis as _index_first_axis,
+            pad_input as _pad_input,
+            unpad_input as _unpad_input,
+        )
         flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
         pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
-        print("[INFO] FlashAttention successfully imported.")
+        has_flash_attn = True
+        print("[INFO] FlashAttention successfully re-imported.")
     except ImportError:
-        print("[WARNING] flash_attn is not installed. Continuing with standard attention.")
-        flash_attn_func = None
-        flash_attn_varlen_func = None
-        pad_input = index_first_axis = unpad_input = None
+        has_flash_attn = False
+        print("[WARNING] flash_attn not installed. Using standard torch.nn.functional.scaled_dot_product_attention.")
+
 
 
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
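
After this patch the module exposes three importable globals: `has_flash_attn`, `flash_attn_func`, and `flash_attn_varlen_func` (plus the padding helpers), and importing the file can no longer raise ImportError when flash_attn is absent. Below is a minimal sketch of how an attention call site could branch on those globals; the `attention_forward` helper, its tensor layout, and the stand-in defaults are illustrative and not part of this commit:

import torch
import torch.nn.functional as F

has_flash_attn, flash_attn_func = False, None  # stand-ins for the patched module globals

def attention_forward(query, key, value, causal=True):
    """Hypothetical call site for the patched globals.

    query/key/value: (batch, seq_len, num_heads, head_dim), the layout flash_attn expects.
    """
    if has_flash_attn and flash_attn_func is not None:
        # FlashAttention path: fused kernel, same (b, s, h, d) layout in and out.
        return flash_attn_func(query, key, value, causal=causal)
    # Fallback path: torch's built-in SDPA expects (b, h, s, d),
    # so transpose around the call and transpose back.
    q, k, v = (t.transpose(1, 2) for t in (query, key, value))
    out = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
    return out.transpose(1, 2)

On a Kaggle/Colab machine without flash_attn installed, module import now just prints the `[INFO]`/`[WARNING]` lines and execution takes the SDPA-style fallback path, which is the point of the patch.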