fix: restore _set_gradient_checkpointing + enable_input_require_grads for gradient checkpointing
Browse files
root_gainlora/runtest_colab.ipynb
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
root_gainlora/src/run_t5.py
CHANGED
|
@@ -734,6 +734,7 @@ def main():
|
|
| 734 |
print(f"-----Gradient checkpointing: {training_args.gradient_checkpointing} -----")
|
| 735 |
if training_args.gradient_checkpointing:
|
| 736 |
model.gradient_checkpointing_enable()
|
|
|
|
| 737 |
|
| 738 |
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
| 739 |
training_args.step_per_epoch = math.ceil(len(raw_datasets["train"]) / training_args.per_device_train_batch_size / world_size / training_args.gradient_accumulation_steps)
|
|
|
|
| 734 |
print(f"-----Gradient checkpointing: {training_args.gradient_checkpointing} -----")
|
| 735 |
if training_args.gradient_checkpointing:
|
| 736 |
model.gradient_checkpointing_enable()
|
| 737 |
+
model.enable_input_require_grads()
|
| 738 |
|
| 739 |
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
| 740 |
training_args.step_per_epoch = math.ceil(len(raw_datasets["train"]) / training_args.per_device_train_batch_size / world_size / training_args.gradient_accumulation_steps)
|
root_gainlora/src/t5_gainlora_inflora.py
CHANGED
|
@@ -1006,11 +1006,9 @@ class T5PreTrainedModel(PreTrainedModel):
|
|
| 1006 |
if module.has_relative_attention_bias:
|
| 1007 |
module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
|
| 1008 |
|
| 1009 |
-
|
| 1010 |
-
|
| 1011 |
-
|
| 1012 |
-
# Without this method, transformers uses the new format which properly
|
| 1013 |
-
# passes the checkpointing function with use_reentrant=False.
|
| 1014 |
|
| 1015 |
def _shift_right(self, input_ids):
|
| 1016 |
decoder_start_token_id = self.config.decoder_start_token_id
|
|
|
|
| 1006 |
if module.has_relative_attention_bias:
|
| 1007 |
module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
|
| 1008 |
|
| 1009 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
| 1010 |
+
if isinstance(module, (T5Attention, T5Stack)):
|
| 1011 |
+
module.gradient_checkpointing = value
|
|
|
|
|
|
|
| 1012 |
|
| 1013 |
def _shift_right(self, input_ids):
|
| 1014 |
decoder_start_token_id = self.config.decoder_start_token_id
|