v10a
Browse files
improve_gainlora/src/run_t5.py
CHANGED
|
@@ -839,10 +839,11 @@ def main():
|
|
| 839 |
return result
|
| 840 |
print(f"-----Gradient checkpointing: {training_args.gradient_checkpointing} -----")
|
| 841 |
if training_args.gradient_checkpointing:
|
| 842 |
-
# use_reentrant=
|
| 843 |
-
#
|
|
|
|
| 844 |
model.gradient_checkpointing_enable(
|
| 845 |
-
gradient_checkpointing_kwargs={"use_reentrant":
|
| 846 |
)
|
| 847 |
model.enable_input_require_grads()
|
| 848 |
|
|
|
|
| 839 |
return result
|
| 840 |
print(f"-----Gradient checkpointing: {training_args.gradient_checkpointing} -----")
|
| 841 |
if training_args.gradient_checkpointing:
|
| 842 |
+
# use_reentrant=True: allows backward through graph 2x (needed for SpecRoute + PEFT)
|
| 843 |
+
# When key_attention_weights passes through checkpointed layers, reentrant mode
|
| 844 |
+
# handles the complex computation graph without "backward second time" errors
|
| 845 |
model.gradient_checkpointing_enable(
|
| 846 |
+
gradient_checkpointing_kwargs={"use_reentrant": True}
|
| 847 |
)
|
| 848 |
model.enable_input_require_grads()
|
| 849 |
|