ml-sharp / src /sharp /utils /training.py
amael-apple's picture
Initial commit
c20d7cc
raw
history blame contribute delete
716 Bytes
"""Utility functions for training.
For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""
from torch.utils.checkpoint import checkpoint
def checkpoint_wrapper(self, fn, *args):
    """Call ``fn(*args)``, routing it through grad checkpointing when enabled.

    The decision is driven by the ``grad_checkpointing`` attribute of ``self``:
    when it is truthy the call goes through ``torch.utils.checkpoint.checkpoint``
    with ``use_reentrant=False``; otherwise ``fn`` is invoked directly.

    Args:
        self: Object carrying the ``grad_checkpointing`` flag.
        fn: Callable to execute.
        *args: Positional arguments forwarded unchanged to ``fn``.

    Returns:
        Whatever ``fn(*args)`` produces, either directly or via ``checkpoint``.

    Raises:
        AttributeError: If ``self`` has no ``grad_checkpointing`` attribute.
    """
    # Fail loudly instead of silently skipping checkpointing on objects
    # that never declared the flag.
    if not hasattr(self, "grad_checkpointing"):
        message = (
            "Trying to apply grad checkpointing on a model that does not have a grad_checkpointing "
            "attribute."
        )
        raise AttributeError(message)

    # Flag is off: plain call, no checkpoint machinery involved.
    if not self.grad_checkpointing:
        return fn(*args)

    return checkpoint(fn, *args, use_reentrant=False)