File size: 3,752 Bytes
fb11af9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import enum
from contextlib import nullcontext
from typing import Tuple, Union

import torch
from torch.autograd.graph import saved_tensors_hooks


class OffloadPolicy(enum.Enum):
    OFFLOAD = 0
    KEEP_ON_GPU = 1
    IGNORE = 2


class custom_save_on_cpu(saved_tensors_hooks):
    def __init__(self, gpu_limit_in_gb: float = 0, pin_memory: bool = False, min_offload_size: int = 1024) -> None:
        self.cur_gpu_ram_in_mb = 0.0

        def pack_to_cpu(tensor: torch.Tensor) -> Tuple[OffloadPolicy, torch.device, torch.Tensor]:
            tensor_num_bytes = tensor.element_size() * tensor.nelement()
            # heuristic to skip nn.Linear.weight
            if type(tensor.grad_fn).__name__ == "TBackward0" or tensor_num_bytes <= min_offload_size:
                return (OffloadPolicy.IGNORE, tensor.device, tensor)

            if self.cur_gpu_ram_in_mb < gpu_limit_in_gb * 1024:
                self.cur_gpu_ram_in_mb += tensor_num_bytes / 1024 / 1024
                return (OffloadPolicy.KEEP_ON_GPU, tensor.device, tensor)

            if not pin_memory:
                return (OffloadPolicy.OFFLOAD, tensor.device, tensor.cpu())

            packed = torch.empty(
                tensor.size(),
                dtype=tensor.dtype,
                layout=tensor.layout,
                pin_memory=(not tensor.is_sparse),
            )
            packed.copy_(tensor)
            return (OffloadPolicy.OFFLOAD, tensor.device, packed)

        def unpack_from_cpu(packed: Tuple[OffloadPolicy, torch.device, torch.Tensor]) -> torch.Tensor:
            offload_policy, device, tensor = packed

            if offload_policy == OffloadPolicy.IGNORE:
                return tensor
            elif offload_policy == OffloadPolicy.KEEP_ON_GPU:
                tensor_num_bytes = tensor.element_size() * tensor.nelement()
                self.cur_gpu_ram_in_mb -= tensor_num_bytes / 1024 / 1024
                return tensor
            else:
                return tensor.to(device, non_blocking=pin_memory)

        super().__init__(pack_to_cpu, unpack_from_cpu)


def build_activation_offloading_context(
    enable_activation_offload: bool = False,
    enable_gradient_checkpointing: bool = False,
    activation_gpu_limit: float = 0.0,
) -> Tuple[Union["saved_tensors_hooks", "nullcontext"], Union["saved_tensors_hooks", "nullcontext"]]:
    model_fwd_context, model_bwd_context = nullcontext(), nullcontext()
    if enable_activation_offload:
        # pin_memory=False since CachingHostAllocator caches pinned memory aggressively.
        # torch._C._host_emptyCache() can be used after version 2.5.
        if enable_gradient_checkpointing:
            # inter-layer activations are always offloaded when enabling gradient checkpointing to avoid potential thrashing
            model_fwd_context = custom_save_on_cpu(gpu_limit_in_gb=0.0, pin_memory=False)
            model_bwd_context = custom_save_on_cpu(gpu_limit_in_gb=activation_gpu_limit, pin_memory=False)
        else:
            model_fwd_context = custom_save_on_cpu(gpu_limit_in_gb=activation_gpu_limit, pin_memory=False)

    return model_fwd_context, model_bwd_context