x54-729
committed on
Commit 7f2ea77
1 Parent(s): 05e5c8c
fix import error
Files changed: modeling_internlm.py (+19, -5)

modeling_internlm.py CHANGED
@@ -48,6 +48,20 @@ logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "InternLMConfig"
 
+flash_attn_func, flash_attn_varlen_func = None, None
+pad_input, index_first_axis, unpad_input = None, None, None
+def _import_flash_attn():
+    global flash_attn_func, flash_attn_varlen_func
+    global pad_input, index_first_axis, unpad_input
+    try:
+        from flash_attn import flash_attn_func as _flash_attn_func, flash_attn_varlen_func as _flash_attn_varlen_func
+        from flash_attn.bert_padding import pad_input as _pad_input, index_first_axis as _index_first_axis, unpad_input as _unpad_input
+        flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
+        pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
+    except ImportError:
+        raise ImportError("flash_attn is not installed.")
+
+
 def _get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
@@ -438,13 +452,11 @@ class InternLMFlashAttention2(InternLMAttention):
             softmax_scale (`float`, *optional*):
                 The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
         """
-        from flash_attn import flash_attn_func, flash_attn_varlen_func
-        from flash_attn.bert_padding import pad_input
         # Contains at least one padding token in the sequence
         causal = self.is_causal and query_length != 1
         if attention_mask is not None:
             batch_size = query_states.shape[0]
-            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self.
+            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input(
                 query_states, key_states, value_states, attention_mask, query_length
             )
 
@@ -472,8 +484,7 @@ class InternLMFlashAttention2(InternLMAttention):
 
         return attn_output
 
-    def
-        from flash_attn.bert_padding import index_first_axis, unpad_input
+    def _unpad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
         indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
         batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
 
@@ -762,6 +773,9 @@ class InternLMModel(InternLMPreTrainedModel):
 
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
+        if self.config.attn_implementation == "flash_attention_2":
+            _import_flash_attn()
+
         # retrieve input_ids and inputs_embeds
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
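The change moves the flash_attn imports from module import time into a lazy `_import_flash_attn()` helper that `InternLMModel.forward` calls only when `config.attn_implementation == "flash_attention_2"`, so importing modeling_internlm.py no longer fails on machines without flash_attn. Below is a minimal, self-contained sketch of the same lazy-import pattern; the standalone `forward()` helper and its flag are illustrative placeholders, not code from this repository:

flash_attn_func = None  # resolved lazily; stays None until flash attention is requested


def _import_flash_attn():
    # Rebind the module-level name only when the optional dependency is actually needed.
    global flash_attn_func
    try:
        from flash_attn import flash_attn_func as _flash_attn_func
    except ImportError:
        raise ImportError("flash_attn is not installed.")
    flash_attn_func = _flash_attn_func


def forward(use_flash_attention_2: bool):
    # Mirrors the guarded call in InternLMModel.forward: the import runs only on
    # the flash path, so the eager path works even when flash_attn is absent.
    if use_flash_attention_2:
        _import_flash_attn()
        return "flash_attention_2 path"
    return "eager attention path"

Keeping the imported symbols in module-level globals lets the attention class keep referring to flash_attn_func, flash_attn_varlen_func, pad_input, index_first_axis, and unpad_input by name, without threading the imported functions through every call site.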