natmin322 committed on
Commit
2200936
·
1 Parent(s): f90d880
Files changed (1) hide show
  1. improve_gainlora/src/run_t5.py +4 -3
improve_gainlora/src/run_t5.py CHANGED
@@ -839,10 +839,11 @@ def main():
839
  return result
840
  print(f"-----Gradient checkpointing: {training_args.gradient_checkpointing} -----")
841
  if training_args.gradient_checkpointing:
842
- # use_reentrant=False: don't require input requires_grad=True
843
- # Recommended by PyTorch 2.5+ (will be mandatory in future versions)
 
844
  model.gradient_checkpointing_enable(
845
- gradient_checkpointing_kwargs={"use_reentrant": False}
846
  )
847
  model.enable_input_require_grads()
848
 
 
839
  return result
840
  print(f"-----Gradient checkpointing: {training_args.gradient_checkpointing} -----")
841
  if training_args.gradient_checkpointing:
842
+ # use_reentrant=True: allows backward through graph 2x (needed for SpecRoute + PEFT)
843
+ # When key_attention_weights passes through checkpointed layers, reentrant mode
844
+ # handles the complex computation graph without "backward second time" errors
845
  model.gradient_checkpointing_enable(
846
+ gradient_checkpointing_kwargs={"use_reentrant": True}
847
  )
848
  model.enable_input_require_grads()
849