natmin322 commited on
Commit
2b87f4b
·
1 Parent(s): 008c76c

fix: restore _set_gradient_checkpointing + enable_input_require_grads for gradient checkpointing

Browse files
root_gainlora/runtest_colab.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
root_gainlora/src/run_t5.py CHANGED
@@ -734,6 +734,7 @@ def main():
734
  print(f"-----Gradient checkpointing: {training_args.gradient_checkpointing} -----")
735
  if training_args.gradient_checkpointing:
736
  model.gradient_checkpointing_enable()
 
737
 
738
  world_size = int(os.environ.get("WORLD_SIZE", 1))
739
  training_args.step_per_epoch = math.ceil(len(raw_datasets["train"]) / training_args.per_device_train_batch_size / world_size / training_args.gradient_accumulation_steps)
 
734
  print(f"-----Gradient checkpointing: {training_args.gradient_checkpointing} -----")
735
  if training_args.gradient_checkpointing:
736
  model.gradient_checkpointing_enable()
737
+ model.enable_input_require_grads()
738
 
739
  world_size = int(os.environ.get("WORLD_SIZE", 1))
740
  training_args.step_per_epoch = math.ceil(len(raw_datasets["train"]) / training_args.per_device_train_batch_size / world_size / training_args.gradient_accumulation_steps)
root_gainlora/src/t5_gainlora_inflora.py CHANGED
@@ -1006,11 +1006,9 @@ class T5PreTrainedModel(PreTrainedModel):
1006
  if module.has_relative_attention_bias:
1007
  module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
1008
 
1009
- # NOTE: _set_gradient_checkpointing removed intentionally.
1010
- # The old format (with 'value' param) causes transformers to silently ignore
1011
- # gradient_checkpointing_kwargs (including use_reentrant=False).
1012
- # Without this method, transformers uses the new format which properly
1013
- # passes the checkpointing function with use_reentrant=False.
1014
 
1015
  def _shift_right(self, input_ids):
1016
  decoder_start_token_id = self.config.decoder_start_token_id
 
1006
  if module.has_relative_attention_bias:
1007
  module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
1008
 
1009
+ def _set_gradient_checkpointing(self, module, value=False):
1010
+ if isinstance(module, (T5Attention, T5Stack)):
1011
+ module.gradient_checkpointing = value
 
 
1012
 
1013
  def _shift_right(self, input_ids):
1014
  decoder_start_token_id = self.config.decoder_start_token_id