| """custom checkpointing utils""" | |
| from axolotl.utils.gradient_checkpointing.unsloth import ( | |
| Unsloth_Offloaded_Gradient_Checkpointer, | |
| ) | |
def hf_grad_checkpoint_unsloth_wrapper(
    decoder_layer, *args, use_reentrant=None
):  # pylint: disable=unused-argument
    """Run ``decoder_layer`` through the Unsloth offloaded gradient checkpointer.

    Mirrors the ``checkpoint(function, *args, use_reentrant=...)`` calling
    convention, presumably so it can be installed as a Hugging Face model's
    gradient-checkpointing function (NOTE(review): confirm against caller).

    Args:
        decoder_layer: The callable to checkpoint. When it is a bound method
            (e.g. ``module.__call__``), its owner (``__self__``) is handed to
            the checkpointer, exactly as before. Any other callable is passed
            through unchanged instead of raising ``AttributeError``. This
            includes a ``functools.partial`` wrapping the layer call, which
            has no ``__self__``.
        *args: Positional arguments forwarded to the checkpointer after the
            callable.
        use_reentrant: Accepted only for signature compatibility; ignored.

    Returns:
        Whatever ``Unsloth_Offloaded_Gradient_Checkpointer.apply`` returns.
    """
    # Bound methods expose their owning object via ``__self__``; keep the
    # original behaviour for them. For callables without it (plain functions,
    # ``functools.partial`` objects) pass the callable itself, so that a
    # partial keeps any keyword arguments it has pre-bound.
    forward_function = getattr(decoder_layer, "__self__", decoder_layer)
    return Unsloth_Offloaded_Gradient_Checkpointer.apply(
        forward_function,
        *args,
    )