| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import contextlib |
| | from typing import Union |
| | from warnings import warn |
| |
|
| | import psutil |
| | import torch |
| | from torch import nn |
| | from torch.autograd.graph import saved_tensors_hooks |
| |
|
| | from torchtitan.tools.logging import logger |
| |
|
| | try: |
| | import torchao |
| | from torchao.dtypes.nf4tensor import NF4Tensor |
| | except ImportError: |
| | torchao = None |
| | NF4Tensor = None |
| | logger.warning("torchao not found. ") |
| |
|
| | |
| |
|
| |
|
class OffloadActivations(saved_tensors_hooks):
    """Context manager under which activation tensors created in the forward pass will be offloaded.

    Enable the memory efficiency technique of activation offloading, where activations bigger than
    min_offload_size bytes will be offloaded to CPU in the forward and brought back in the backward.
    This is in contrast to maintaining the activation on GPU VRAM throughout the program.

    This manager contains the option of using one additional CUDA stream to handle the communication
    between CUDA and CPU, which is intended to overlap with the default computation stream to improve
    runtime. We designed synchronization with a few heuristics for optimizing the tradeoff between
    runtime vs memory usage.

    Args:
        use_pin_memory (bool): Whether or not the offloaded Tensor will be placed in pinned
            memory on the CPU. Pinned memory allows the Tensor to be moved back onto GPU more quickly
            but is a limited resource. Default: True.

        use_streams (bool): Whether or not to use streams for performance optimization where
            the communications get overlapped with the computation. Requires a torch build
            after torch-2.5.0. Default: True.

        max_fwd_stash_size (int): The maximum size of the forward stash, or the maximum number of
            consecutive activations to keep alive during the forward pass. This number must be at
            least 1. Keeping alive more activations will potentially allow more overlap between the
            communication and compute streams at the cost of increasing memory usage. Keeping alive
            fewer activations will conserve memory, but may cause poor overlap between the streams,
            increasing runtime. Default: 5.

        min_offload_size (int): The minimum number of bytes a Tensor must be in order to qualify
            for offloading. If the tensor is too small, we do not want to waste bandwidth and resources
            moving it to CPU and back. Default: 1024 bytes.

    Raises:
        ValueError: if max_fwd_stash_size is not at least 1 (only checked when use_streams=True).

    Example:
        >>> with OffloadActivations():
        >>>     logits = model(inputs)
        >>>     loss = ...
        >>>     loss.backward()
    """

    def __init__(
        self,
        use_pin_memory: bool = True,
        use_streams: bool = True,
        max_fwd_stash_size: int = 5,
        min_offload_size: int = 1024,
    ) -> None:

        self.use_streams: bool = use_streams

        # Tensors smaller than this stay on GPU: moving them would waste
        # bandwidth for negligible memory savings.
        self.min_tensor_size_bytes = min_offload_size
        # tensor_id -> (tensor, was_offloaded). was_offloaded=True means the
        # stored tensor is a CPU copy that must be brought back in backward.
        self.tracker = {}
        self.tensor_id: int = 0
        self.is_first_forward_call = True
        self.is_first_backward_call = True
        self.is_first_forward_pass = True

        # Host-RAM safety: when pinning memory, warn once if system virtual
        # memory usage exceeds this percentage after the first forward pass.
        self.use_pin_memory: bool = use_pin_memory
        self.virtual_memory_safe_pct = 60

        # s0 is the default (compute) stream.
        self.s0 = torch.cuda.default_stream()

        # Extra state only needed when overlapping comms with compute.
        if self.use_streams:
            self.s1 = torch.cuda.Stream()  # comms stream for D2H/H2D copies
            self.fwd_stash = {}  # tensor_id -> (activation, event)
            if max_fwd_stash_size < 1:
                raise ValueError(
                    f"max_fwd_stash_size should be at least 1 but is {max_fwd_stash_size}"
                )
            self.max_fwd_stash_size = max_fwd_stash_size
            self.bwd_tensor_stash = {}  # tensor_id -> activation
            self.bwd_ev_stash = {}  # tensor_id -> event
            self.curr_graph_id = None
            self.curr_autograd_node = None

        # -------- platform util functions -------- #
        def verify_sufficient_virtual_memory():
            curr_pct = get_cpu_ram_pct()
            if curr_pct > self.virtual_memory_safe_pct:
                warn(
                    f"***** WARNING: {curr_pct=}% > {self.virtual_memory_safe_pct=}% of virtual memory used"
                )

        def get_cpu_ram_pct() -> float:
            # Percentage of system virtual memory currently in use.
            return psutil.virtual_memory().percent

        def get_tensor_id() -> int:
            # Monotonically increasing id for each tensor we manage.
            self.tensor_id += 1
            return self.tensor_id

        def get_num_bytes_tensor(x: torch.Tensor) -> int:
            # Size of the tensor's data in bytes.
            return x.element_size() * x.nelement()

        # -------- core pack / unpack work -------- #
        def pack_tensor(activation: torch.Tensor) -> int:
            # Called by autograd during forward for every saved tensor; we take
            # ownership and return a unique id used to retrieve it in backward.
            if self.is_first_forward_call:
                assert (
                    len(self.tracker) == 0
                ), "backward pass should have cleared tracker of all tensors"

                # Set training phase trackers.
                self.is_first_forward_call = False
                self.is_first_backward_call = True

            # Query basic tensor info.
            num_bytes = get_num_bytes_tensor(activation)
            tensor_id = get_tensor_id()

            # Only offload large CUDA tensors that look like activations (our
            # heuristic: not Parameters and not Buffers).
            if (
                activation.is_cuda
                and num_bytes >= self.min_tensor_size_bytes
                and (
                    not isinstance(activation, torch.nn.Parameter)
                    and not isinstance(activation, torch.nn.Buffer)
                )
            ):
                if self.use_streams:
                    # Drop references to activations whose offload should have
                    # finished long ago (older than max_fwd_stash_size ids),
                    # after making the compute stream wait on their copy events.
                    for id in [k for k in self.fwd_stash.keys()]:
                        if id <= tensor_id - self.max_fwd_stash_size:
                            _, ev = self.fwd_stash[id]
                            self.s0.wait_event(ev)
                            del self.fwd_stash[id]
                        else:
                            break

                    # The comms stream must wait for compute to produce the
                    # activation before copying it out.
                    self.s1.wait_stream(self.s0)

                stream = self.s1 if self.use_streams else self.s0
                with torch.cuda.stream(stream):
                    try:
                        cpu_tensor = torch.empty_like(
                            activation, pin_memory=self.use_pin_memory, device="cpu"
                        )
                    except NotImplementedError as e:
                        # Guard NF4Tensor is not None: when torchao is not
                        # installed, the module-level fallback sets NF4Tensor to
                        # None and isinstance() would raise a confusing
                        # TypeError here instead of re-raising the real error.
                        # NOTE(review): this is a lexicographic string compare,
                        # which misorders versions like "0.10.0" — confirm it is
                        # acceptable for the torchao versions in play.
                        if (
                            NF4Tensor is not None
                            and isinstance(activation, NF4Tensor)
                            and torchao.__version__ < "0.6.0.dev20240917"
                        ):
                            raise RuntimeError(
                                "Offloading NF4Tensors requires torchao-0.6.0.dev20240917 or later"
                            ) from e
                        raise e
                    cpu_tensor.copy_(activation, non_blocking=True)
                    self.tracker[tensor_id] = (
                        cpu_tensor,
                        True,
                    )  # True = offloaded to CPU

                if self.use_streams:
                    event = self.s1.record_event()

                    # Stash the GPU activation to keep it alive until the
                    # async copy on s1 has completed.
                    self.fwd_stash[tensor_id] = (activation, event)
            else:
                self.tracker[tensor_id] = (
                    activation,
                    False,
                )  # False = not offloaded, tensor kept as-is

            return tensor_id

        def unpack_tensor_single_stream(unpack_tensor_id: int) -> torch.Tensor:
            # Backward-pass retrieval when streams are disabled: everything
            # happens synchronously on the default stream.
            if self.is_first_backward_call:
                if self.is_first_forward_pass:
                    self.is_first_forward_pass = False
                    if self.use_pin_memory:
                        verify_sufficient_virtual_memory()

                self.is_first_backward_call = False
                self.is_first_forward_call = True

            assert (
                unpack_tensor_id in self.tracker
            ), f"untracked tensor with id {unpack_tensor_id}"

            maybe_gpu_tensor, modified = self.tracker[unpack_tensor_id]
            if modified:
                gpu_tensor = maybe_gpu_tensor.to("cuda", non_blocking=True)
                maybe_gpu_tensor = gpu_tensor

            # Clear the tensor from tracking.
            del self.tracker[unpack_tensor_id]
            return maybe_gpu_tensor

        def unpack_tensor_with_streams(unpack_tensor_id: int) -> torch.Tensor:
            # Backward-pass retrieval with stream overlap: H2D copies run on
            # s1 while compute proceeds on s0, with per-node cleanup hooks.
            if self.is_first_backward_call:
                self.curr_graph_id = torch._C._current_graph_task_id()

                def wait_and_del_remaining_references() -> None:
                    # End-of-backward cleanup: wait on each stashed tensor's
                    # event before dropping our last reference to it.
                    for id in [k for k in self.bwd_tensor_stash.keys()]:
                        event = self.bwd_ev_stash[id]
                        self.s1.wait_event(event)
                        del self.bwd_tensor_stash[id]

                # Register cleanup to run when the whole graph task finishes.
                torch.autograd.variable.Variable._execution_engine.queue_callback(
                    wait_and_del_remaining_references
                )

                if self.is_first_forward_pass:
                    self.is_first_forward_pass = False
                    if self.use_pin_memory:
                        verify_sufficient_virtual_memory()

                self.is_first_backward_call = False
                self.is_first_forward_call = True

            assert (
                unpack_tensor_id in self.tracker
            ), f"untracked tensor with id {unpack_tensor_id}"

            maybe_gpu_tensor, modified = self.tracker[unpack_tensor_id]
            if modified:
                # Get data on the current autograd node.
                graph_id = torch._C._current_graph_task_id()
                node = torch._C._current_autograd_node()
                prev_node_ids = []

                # If we've moved to a new node in the same graph, mark the
                # previous node's tensors to be freed once its work is done.
                if graph_id == self.curr_graph_id and self.curr_autograd_node != node:
                    self.curr_autograd_node = node
                    prev_node_ids = [id for id in self.bwd_tensor_stash.keys()]

                brought_back_from_cpu = True
                if unpack_tensor_id in self.fwd_stash:
                    # The activation never left the GPU (still in the forward
                    # stash), so reuse it directly.
                    maybe_gpu_tensor = self.fwd_stash[unpack_tensor_id][0]
                    brought_back_from_cpu = False
                else:
                    # Kick off bringing the tensor back on the comms stream.
                    with torch.cuda.stream(self.s1):
                        gpu_tensor = maybe_gpu_tensor.to("cuda", non_blocking=True)
                        maybe_gpu_tensor = gpu_tensor

                    # Compute must wait for the H2D copy before using it.
                    self.s0.wait_stream(self.s1)

                # Stash the tensor to keep its memory alive until the compute
                # stream is done with it.
                self.bwd_tensor_stash[unpack_tensor_id] = maybe_gpu_tensor

                # Note: [Track views of the unpacked]
                # We snapshot the storage refcount now and compare in the hook
                # below. If the refcount grew, something else (a view in the
                # node's outputs, activation-checkpoint recompute, etc.) still
                # references this storage and we cannot free it after just this
                # node — we must keep it alive via record_stream instead.
                storage_refcount = torch._C._storage_Use_Count(
                    maybe_gpu_tensor.untyped_storage()._cdata
                )

                def hook(outputs, inputs):
                    # Runs after this autograd node executes on s0.
                    if brought_back_from_cpu:
                        # See Note: [Track views of the unpacked]. If any output
                        # is a view of the tensor (or a view was otherwise
                        # retained), the tensor may be used later, so defer
                        # deallocation to the allocator via record_stream.
                        unpacked_tensor = self.bwd_tensor_stash[unpack_tensor_id]
                        if (
                            torch._C._storage_Use_Count(
                                unpacked_tensor.untyped_storage()._cdata
                            )
                            > storage_refcount
                        ):
                            unpacked_tensor.record_stream(self.s0)
                            del self.bwd_tensor_stash[unpack_tensor_id]
                        else:
                            # Otherwise record an event so we can free it once
                            # s0's work for this node is known to be complete.
                            event = self.s0.record_event()
                            self.bwd_ev_stash[unpack_tensor_id] = event

                    # We're in backward now: drain anything left in fwd_stash.
                    for id in [k for k in self.fwd_stash.keys()]:
                        _, ev = self.fwd_stash[id]
                        self.s0.wait_event(ev)
                        del self.fwd_stash[id]

                    # Wait on the previous node's events, then drop its tensors.
                    for id in prev_node_ids:
                        event = self.bwd_ev_stash[id]
                        self.s1.wait_event(event)
                        del self.bwd_tensor_stash[id]

                    return outputs

                node.register_hook(hook)

            # Clear the tensor from tracking.
            del self.tracker[unpack_tensor_id]
            return maybe_gpu_tensor

        unpack_tensor = (
            unpack_tensor_with_streams
            if self.use_streams
            else unpack_tensor_single_stream
        )
        super().__init__(pack_tensor, unpack_tensor)
| |
|
| |
|
class NoOpManager(saved_tensors_hooks):
    """
    A saved_tensors_hook manager used to disable any other saved_tensors_hook manager
    applied before. This relies on the behavior that only the most recently registered
    saved_tensors_hook will run.

    One example usage is to opt a local region of code out of activations offloading,
    which is usually applied globally to best track state.
    """

    def __init__(self) -> None:
        # Pack and unpack are both the identity: tensors saved for backward
        # inside this context pass through untouched, effectively shadowing
        # any outer saved_tensors_hooks manager.
        def _identity(tensor):
            return tensor

        super().__init__(_identity, _identity)
| |
|
| |
|
def get_act_offloading_ctx_manager(
    model: nn.Module, enable_activation_offloading: bool
) -> Union[OffloadActivations, contextlib.nullcontext]:
    """Returns the activation offloading context manager for the model, which will be
    a null context if enable_activation_offloading is False.

    If activation offloading is enabled, we return the OffloadActivations context
    manager and, when an output head is detected at ``model.output``, register
    forward hooks that wrap the head in a NoOpManager so its (large) activations
    are not offloaded. If activation offloading is disabled, we return a
    ``contextlib.nullcontext()``.

    Args:
        model (nn.Module): the model to wrap with the activation offloading context manager.
        enable_activation_offloading (bool): whether or not to enable activation offloading
            for the model.

    Returns:
        Union[OffloadActivations, contextlib.nullcontext]: the activation offloading
            context manager for the model, or a null context when disabled.
    """
    if not enable_activation_offloading:
        return contextlib.nullcontext()

    activations_handling_ctx = OffloadActivations()

    # Below is our hack to disable offloading the last output Linear in every
    # step, as the cost for offloading the activation and then soon after
    # bringing it back is expensive.
    output_head_detected = False
    noop_ctx = NoOpManager()

    if hasattr(model, "output") and isinstance(model.output, nn.Module):
        # Enter the no-op manager just before the output head runs and exit
        # right after (always_call=True ensures __exit__ runs even if the
        # forward raises), so only the head's activations skip offloading.
        model.output.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
        model.output.register_forward_hook(
            lambda *args: noop_ctx.__exit__(), always_call=True
        )
        logger.debug("Registered activation-offloading no-op hooks on model.output")
        output_head_detected = True

    if not output_head_detected:
        logger.warning(
            "During activation offloading, no output head was detected. "
            "If your model has an output head, it will be offloaded. "
            "This usually greatly slows training, given the large vocabulary size. "
            "To change this behavior, set your output head as model.output and make it "
            "an nn.Module."
        )

    return activations_handling_ctx
| |
|