File size: 716 Bytes
c20d7cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
"""Utility functions for training.
For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""
import functools

from torch.utils.checkpoint import checkpoint
def checkpoint_wrapper(self, fn, *args, **kwargs):
    """Call ``fn``, applying gradient checkpointing when the model enables it.

    If ``self.grad_checkpointing`` is truthy, ``fn`` is executed through
    :func:`torch.utils.checkpoint.checkpoint` (non-reentrant variant), which
    trades compute for memory by recomputing ``fn``'s intermediate activations
    during the backward pass. Otherwise ``fn`` is simply called.

    Args:
        self: The model (or any object) carrying the ``grad_checkpointing``
            flag that decides whether checkpointing is applied.
        fn: Callable to execute.
        *args: Positional arguments forwarded to ``fn``.
        **kwargs: Keyword arguments forwarded to ``fn``.

    Returns:
        Whatever ``fn`` returns.

    Raises:
        AttributeError: If ``self`` has no ``grad_checkpointing`` attribute.
    """
    try:
        enabled = self.grad_checkpointing
    except AttributeError:
        raise AttributeError(
            "Trying to apply grad checkpointing on a model that does not have a grad_checkpointing "
            "attribute."
        ) from None

    if not enabled:
        return fn(*args, **kwargs)

    # Bind keyword arguments to ``fn`` up front: ``checkpoint`` interprets some
    # keyword names (e.g. ``preserve_rng_state``, ``context_fn``) as its own
    # options, so passing them straight through could be silently swallowed.
    target = functools.partial(fn, **kwargs) if kwargs else fn
    # Non-reentrant checkpointing is the variant PyTorch recommends; unlike the
    # reentrant one it also tracks tensors reached through the partial rather
    # than only those passed positionally.
    return checkpoint(target, *args, use_reentrant=False)
|