File size: 716 Bytes
c20d7cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
"""Utility functions for training.

For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""

from torch.utils.checkpoint import checkpoint


def checkpoint_wrapper(self, fn, *args):
    """Call ``fn(*args)``, optionally under gradient checkpointing.

    The ``grad_checkpointing`` attribute of ``self`` decides how ``fn`` runs:

    - When it is truthy, ``fn`` runs through ``torch.utils.checkpoint.checkpoint``
      with ``use_reentrant=False``.
    - Otherwise ``fn`` is called directly.

    Args:
        self: Object carrying the ``grad_checkpointing`` flag (presumably the
            calling model/module; it only needs to expose that attribute).
        fn: Callable to execute.
        *args: Positional arguments forwarded unchanged to ``fn``.

    Returns:
        The value returned by ``fn`` (or by ``checkpoint`` wrapping ``fn``).

    Raises:
        AttributeError: If ``self`` has no ``grad_checkpointing`` attribute.
    """
    # Fail loudly instead of silently skipping checkpointing when the flag is absent.
    if not hasattr(self, "grad_checkpointing"):
        message = (
            "Trying to apply grad checkpointing on a model that does not have a "
            "grad_checkpointing attribute."
        )
        raise AttributeError(message)

    # Flag off: plain call, no recomputation overhead.
    if not self.grad_checkpointing:
        return fn(*args)

    return checkpoint(fn, *args, use_reentrant=False)