| """Utility functions for training. | |
| For licensing see accompanying LICENSE file. | |
| Copyright (C) 2025 Apple Inc. All Rights Reserved. | |
| """ | |
| from torch.utils.checkpoint import checkpoint | |
def checkpoint_wrapper(self, fn, *args):
    """Call ``fn(*args)``, optionally through gradient checkpointing.

    The decision is driven by ``self.grad_checkpointing``: when it is truthy
    the call is routed through ``torch.utils.checkpoint.checkpoint`` (with
    ``use_reentrant=False``); otherwise ``fn`` is invoked directly.

    Args:
        self: Object carrying the ``grad_checkpointing`` flag.
        fn: Callable to execute.
        *args: Positional arguments forwarded to ``fn``.

    Returns:
        Whatever ``fn`` (or the checkpointed call of ``fn``) returns.

    Raises:
        AttributeError: If ``self`` has no ``grad_checkpointing`` attribute.
    """
    has_flag = hasattr(self, "grad_checkpointing")
    if not has_flag:
        # Fail loudly instead of silently falling back to a plain call.
        raise AttributeError(
            "Trying to apply grad checkpointing on a model that does not have a grad_checkpointing "
            "attribute."
        )

    # Checkpointing disabled: plain call.
    if not self.grad_checkpointing:
        return fn(*args)

    return checkpoint(fn, *args, use_reentrant=False)