fix bug when using gradient_checkpointing
modeling_telechat.py  +4 -3
@@ -43,8 +43,6 @@ except ImportError:
 try:
     from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func  # flashattn2
     print("# FLASH ATTENTION 2 DETECTED #")
-    r
-    r
 except ImportError:
     print("# NO FLASH ATTENTION DETECTED #")
     flash_attn_unpadded_func = None
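An aside on the two deleted `r` lines: they appear to have been stray characters left inside the try block. Since the handler only catches ImportError, a bare name raises an uncaught NameError, so module import would crash precisely when the flash-attn import succeeded. A minimal, self-contained sketch of that failure mode, with a stand-in import rather than the real flash-attn one:

    # Sketch of the failure the deleted lines would cause. `os` stands in
    # for a flash-attn import that succeeds; the bare `r` then raises
    # NameError, which the ImportError handler does not catch.
    try:
        try:
            import os  # optional import succeeds
            r          # the stray line: a bare name raises NameError
        except ImportError:
            pass       # fallback branch; never taken for a NameError
    except NameError as err:
        print("not caught by the ImportError handler:", err)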
@@ -857,6 +855,8 @@ class TELECHATTransformer(TELECHATPretrainedModel):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

+            rotary_embedding=self.wpe if self.relative_encoding == 'rotary' else None
+
             if self.gradient_checkpointing and self.training:

                 if use_cache:
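The added line hoists the rotary embedding selection (self.wpe when relative_encoding is 'rotary', else None) ahead of the branch, so both call paths below receive the same value: the checkpointed branch gains it as an extra positional argument, and the direct branch passes it by keyword. The sketch after the last hunk shows why the checkpointed path needs the explicit argument.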
@@ -880,6 +880,7 @@ class TELECHATTransformer(TELECHATPretrainedModel):
                     head_mask[i],
                     encoder_hidden_states,
                     encoder_attention_mask,
+                    rotary_embedding
                 )
             else:
                 outputs = block(
@@ -889,7 +890,7 @@ class TELECHATTransformer(TELECHATPretrainedModel):
                     head_mask=head_mask[i],
                     encoder_hidden_states=encoder_hidden_states,
                     encoder_attention_mask=encoder_attention_mask,
-                    rotary_embedding=
+                    rotary_embedding=rotary_embedding,
                     use_cache=use_cache,
                     output_attentions=output_attentions
                 )
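Why the checkpointed branch needs the explicit argument: torch.utils.checkpoint.checkpoint replays the wrapped forward during the backward pass, and the wrapped function only ever sees the arguments passed to checkpoint itself. An input that is never passed therefore never reaches the block, which is how the rotary embedding got dropped under gradient checkpointing. A minimal sketch of the pattern with toy stand-ins, not the actual TeleChat classes:

    import torch
    from torch import nn
    from torch.utils.checkpoint import checkpoint

    # Toy stand-in for a transformer block that optionally consumes a
    # rotary embedding; it defaults to None, which is what the pre-fix
    # checkpointed path effectively delivered.
    class ToyBlock(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.proj = nn.Linear(dim, dim)

        def forward(self, hidden_states, rotary_embedding=None):
            if rotary_embedding is not None:
                hidden_states = hidden_states + rotary_embedding
            return self.proj(hidden_states)

    def create_custom_forward(module):
        # The usual checkpointing wrapper: forwards positional args only.
        def custom_forward(*inputs):
            return module(*inputs)
        return custom_forward

    block = ToyBlock(16)
    hidden_states = torch.randn(2, 8, 16, requires_grad=True)
    rotary_embedding = torch.randn(2, 8, 16)

    # Mirrors the fixed branch: rotary_embedding is now an explicit
    # positional argument, so the replayed forward sees it too. Before
    # the fix it was omitted here and the block ran with its default.
    out = checkpoint(create_custom_forward(block), hidden_states,
                     rotary_embedding, use_reentrant=False)
    out.sum().backward()

The direct (non-checkpointed) branch can keep keyword arguments, which is why the second call site in the diff passes rotary_embedding=rotary_embedding instead.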