Update modeling_qwen.py
Browse files
- modeling_qwen.py +8 -1
modeling_qwen.py
CHANGED
|
@@ -784,6 +784,9 @@ class QWenModel(QWenPreTrainedModel):
|
|
| 784 |
|
| 785 |
self.post_init()
|
| 786 |
|
|
|
|
|
|
|
|
|
|
| 787 |
def get_input_embeddings(self):
|
| 788 |
return self.wte
|
| 789 |
|
|
@@ -926,8 +929,12 @@ class QWenModel(QWenPreTrainedModel):
|
|
| 926 |
if output_hidden_states:
|
| 927 |
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 928 |
|
| 929 |
-
|
|
|
|
|
|
|
|
|
|
| 930 |
|
|
|
|
| 931 |
def create_custom_forward(module):
|
| 932 |
def custom_forward(*inputs):
|
| 933 |
# None for past_key_value
|
|
|
|
| 784 |
|
| 785 |
self.post_init()
|
| 786 |
|
| 787 |
+
# BUG: the layer ids to skip are hardcoded (first 30 layers)
|
| 788 |
+
self.skip_checkpointing_layer_ids = list(range(30))
|
| 789 |
+
|
| 790 |
def get_input_embeddings(self):
|
| 791 |
return self.wte
|
| 792 |
|
|
|
|
| 929 |
if output_hidden_states:
|
| 930 |
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 931 |
|
| 932 |
+
# BUG: does not work
|
| 933 |
+
forward_checkpointing = (self.gradient_checkpointing and self.training)
|
| 934 |
+
if self.skip_checkpointing_layer_ids is not None and i in self.skip_checkpointing_layer_ids:
|
| 935 |
+
forward_checkpointing = False
|
| 936 |
|
| 937 |
+
if forward_checkpointing:
|
| 938 |
def create_custom_forward(module):
|
| 939 |
def custom_forward(*inputs):
|
| 940 |
# None for past_key_value
|