"""Utility functions for training.

For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""

from torch.utils.checkpoint import checkpoint


def checkpoint_wrapper(self, fn, *args, **kwargs):
    """Call ``fn``, optionally under gradient checkpointing.

    When ``self.grad_checkpointing`` is truthy, ``fn`` is executed through
    ``torch.utils.checkpoint.checkpoint`` in non-reentrant mode; otherwise
    ``fn`` is simply called directly.

    Args:
        self: Object carrying a ``grad_checkpointing`` attribute that toggles
            checkpointing.
        fn: Callable to execute.
        *args: Positional arguments forwarded to ``fn``.
        **kwargs: Keyword arguments forwarded to ``fn``. On the checkpointed
            path they are passed through ``checkpoint``, so names that
            ``checkpoint`` itself accepts (e.g. ``use_reentrant``) must not be
            used as keyword names for ``fn``.

    Returns:
        Whatever ``fn(*args, **kwargs)`` returns.

    Raises:
        AttributeError: If ``self`` has no ``grad_checkpointing`` attribute.
    """
    # Fail loudly rather than silently skipping checkpointing on objects that
    # were never configured for it.
    if not hasattr(self, "grad_checkpointing"):
        raise AttributeError(
            "Trying to apply grad checkpointing on a model that does not have a grad_checkpointing "
            "attribute."
        )

    if not self.grad_checkpointing:
        return fn(*args, **kwargs)

    # Keyword arguments for the wrapped function are only accepted by
    # ``checkpoint`` in non-reentrant mode, which is what is used here.
    return checkpoint(fn, *args, use_reentrant=False, **kwargs)