ycchen commited on
Commit
8b786f3
·
1 Parent(s): 0d5fc6a

Update modeling_qwen.py

Browse files
Files changed (1) hide show
  1. modeling_qwen.py +8 -1
modeling_qwen.py CHANGED
@@ -784,6 +784,9 @@ class QWenModel(QWenPreTrainedModel):
784
 
785
  self.post_init()
786
 
 
 
 
787
  def get_input_embeddings(self):
788
  return self.wte
789
 
@@ -926,8 +929,12 @@ class QWenModel(QWenPreTrainedModel):
926
  if output_hidden_states:
927
  all_hidden_states = all_hidden_states + (hidden_states,)
928
 
929
- if self.gradient_checkpointing and self.training:
 
 
 
930
 
 
931
  def create_custom_forward(module):
932
  def custom_forward(*inputs):
933
  # None for past_key_value
 
784
 
785
  self.post_init()
786
 
787
+ # BUG: hardcode
788
+ self.skip_checkpointing_layer_ids = list(range(30))
789
+
790
  def get_input_embeddings(self):
791
  return self.wte
792
 
 
929
  if output_hidden_states:
930
  all_hidden_states = all_hidden_states + (hidden_states,)
931
 
932
+ # BUG: not work
933
+ forward_checkpointing = (self.gradient_checkpointing and self.training)
934
+ if self.skip_checkpointing_layer_ids is not None and i in self.skip_checkpointing_layer_ids:
935
+ forward_checkpointing = False
936
 
937
+ if forward_checkpointing:
938
  def create_custom_forward(module):
939
  def custom_forward(*inputs):
940
  # None for past_key_value