|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import psutil
|
| import torch
|
| from accelerate import logging
|
| from accelerate.utils.versions import is_torch_version
|
| from torch import nn
|
| from torch.autograd.graph import saved_tensors_hooks
|
| from transformers import is_torch_npu_available
|
|
|
|
|
| if is_torch_npu_available():
|
| import torch_npu
|
|
|
|
|
| DTensor = None
|
| if torch.distributed.is_available():
|
| try:
|
| if is_torch_version(">=", "2.5.0"):
|
| from torch.distributed.tensor import DTensor
|
| else:
|
|
|
| from torch.distributed._tensor import DTensor
|
| except (ImportError, AttributeError):
|
| DTensor = None
|
|
|
| logger = logging.get_logger(__name__)
|
|
|
|
|
| def _get_unique_tensor_key(tensor: torch.Tensor) -> tuple:
|
| """
|
| Get a unique key for a tensor based on its storage pointer and dtype. This allows deduplication of tensors that
|
| share the same underlying storage. From:
|
| https://github.com/volcengine/verl/blob/main/verl/utils/activation_offload.py
|
|
|
| Args:
|
| tensor: The tensor to get the key for
|
|
|
| Returns:
|
| A tuple of (storage_pointer, dtype) that uniquely identifies the tensor's storage
|
| """
|
|
|
| actual_tensor = tensor
|
|
|
|
|
| if DTensor is not None and isinstance(tensor, DTensor) and hasattr(tensor, "_local_tensor"):
|
| actual_tensor = tensor._local_tensor
|
|
|
|
|
| try:
|
| storage_ptr = actual_tensor.untyped_storage().data_ptr() + actual_tensor.storage_offset()
|
| except (RuntimeError, AttributeError):
|
|
|
|
|
| storage_ptr = id(actual_tensor)
|
|
|
| return (storage_ptr, actual_tensor.dtype)
|
|
|
|
|
| class OffloadActivations(saved_tensors_hooks):
|
| """
|
| Context manager under which activation tensors created in the forward pass will be offloaded.
|
|
|
| Enable the memory efficiency technique of activation offloading, where activations bigger than `min_offload_size`
|
| bytes will be offloaded to CPU in the forward and brought back in the backward. This is in contrast to maintaining
|
| the activation on GPU VRAM throughout the program.
|
|
|
| This manager contains the option of using one additional CUDA stream to handle the communication between CUDA and
|
| CPU, which is intended to overlap with the default computation stream to improve runtime. We designed
|
| synchronization with a few heuristics for optimizing the tradeoff between runtime vs memory usage.
|
|
|
| Args:
|
| use_pin_memory (`bool`, *optional*, defaults to `True`):
|
| Whether to offloaded Tensor will be placed in pinned memory on the CPU. Pinned memory allows the Tensor to
|
| be moved back onto GPU more quickly but is a limited resource.
|
| use_streams (`bool`, *optional*, defaults to `True`):
|
| Whether to use streams for performance optimization where the communications get overlapped with the
|
| computation. Requires a torch build after torch-2.5.0.
|
| min_offload_size (`int`, *optional*, defaults to `1024`):
|
| Minimum number of bytes a Tensor must be in order to qualify for offloading. If the tensor is too small, we
|
| do not want to waste bandwidth and resources moving it to CPU and back.
|
| max_fwd_stash_size (`int`, *optional*, defaults to `5`):
|
| Maximum size of the forward stash, or the maximum number of consecutive activations to keep alive during
|
| the forward pass. This number must be at least 1. Keeping alive more activations will potentially allow
|
| more overlap between the communication and compute streams at the cost of increasing memory usage. Keeping
|
| alive fewer activations will conserve memory, but may cause poor overlap between the streams, increasing
|
| runtime.
|
|
|
| Raises:
|
| ValueError: if `max_fwd_stash_size` is not at least `1`.
|
|
|
| Example:
|
| ```python
|
| >>> with OffloadActivations():
|
| ... outputs = model(inputs, labels=labels)
|
| >>> loss = outputs.loss
|
| >>> loss.backward()
|
| ```
|
| """
|
|
|
| def __init__(
|
| self,
|
| use_pin_memory: bool = True,
|
| use_streams: bool = True,
|
| min_offload_size: int = 1024,
|
| max_fwd_stash_size: int = 5,
|
| ) -> None:
|
| self.use_streams = use_streams
|
|
|
| self.min_tensor_size_bytes = min_offload_size
|
| self.tracker = {}
|
| self.tensor_id = 0
|
| self.is_first_forward_call = True
|
| self.is_first_backward_call = True
|
| self.is_first_forward_pass = True
|
|
|
|
|
| self.storage_to_tensor_id = {}
|
|
|
|
|
| self.param_storages = set()
|
|
|
|
|
| self.use_pin_memory = use_pin_memory
|
| self.virtual_memory_safe_pct = 60
|
|
|
| self.accelerator_type = (
|
| torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
|
| )
|
|
|
| if self.accelerator_type == "xpu":
|
| self.s0 = torch.xpu.current_stream()
|
| elif is_torch_npu_available() and self.accelerator_type == "npu":
|
| self.s0 = torch.npu.current_stream()
|
| else:
|
| self.s0 = torch.cuda.default_stream()
|
|
|
|
|
| if self.use_streams:
|
| if self.accelerator_type == "xpu":
|
| self.s1 = torch.xpu.Stream()
|
| elif self.accelerator_type == "npu":
|
| self.s1 = torch.npu.Stream()
|
| else:
|
| self.s1 = torch.cuda.Stream()
|
| self.fwd_stash = {}
|
| if max_fwd_stash_size < 1:
|
| raise ValueError(f"max_fwd_stash_size should be at least 1 but is {max_fwd_stash_size}")
|
| self.max_fwd_stash_size = max_fwd_stash_size
|
| self.bwd_tensor_stash = {}
|
| self.bwd_ev_stash = {}
|
| self.curr_graph_id = None
|
| self.curr_autograd_node = None
|
|
|
|
|
| def verify_sufficient_virtual_memory():
|
| curr_pct = get_cpu_ram_pct()
|
| if curr_pct > self.virtual_memory_safe_pct:
|
| logger.warning(f"{curr_pct=}% > {self.virtual_memory_safe_pct=}% of virtual memory used")
|
|
|
| def get_cpu_ram_pct() -> float:
|
|
|
| return psutil.virtual_memory().percent
|
|
|
| def get_tensor_id() -> int:
|
|
|
| self.tensor_id += 1
|
| return self.tensor_id
|
|
|
| def get_num_bytes_tensor(x: torch.Tensor) -> int:
|
|
|
| return x.element_size() * x.nelement()
|
|
|
|
|
| def pack_tensor(activation: torch.Tensor) -> int:
|
|
|
| if self.is_first_forward_call:
|
| if len(self.tracker) != 0:
|
| raise ValueError("Backward pass should have cleared tracker of all tensors")
|
|
|
|
|
| self.is_first_forward_call = False
|
| self.is_first_backward_call = True
|
|
|
| self.storage_to_tensor_id = {}
|
|
|
|
|
| num_bytes = get_num_bytes_tensor(activation)
|
| tensor_id = get_tensor_id()
|
|
|
|
|
|
|
|
|
| storage_key = _get_unique_tensor_key(activation)
|
| if storage_key in self.storage_to_tensor_id:
|
|
|
| self.tracker[tensor_id] = (activation, False, None, None, None)
|
| return tensor_id
|
|
|
|
|
| if activation.device.type not in ["cuda", "xpu", "npu"]:
|
| self.tracker[tensor_id] = (activation, False, None, None, None)
|
| return tensor_id
|
|
|
|
|
| if num_bytes < self.min_tensor_size_bytes:
|
| self.tracker[tensor_id] = (activation, False, None, None, None)
|
| return tensor_id
|
|
|
|
|
| if isinstance(activation, torch.nn.Parameter) or (
|
| hasattr(torch.nn, "Buffer") and isinstance(activation, torch.nn.Buffer)
|
| ):
|
| self.tracker[tensor_id] = (activation, False, None, None, None)
|
| return tensor_id
|
|
|
|
|
| tensor_class_name = type(activation).__name__
|
| if tensor_class_name in ["Float8TrainingTensor", "ScaledMMConfig", "LinearMMConfig"]:
|
| self.tracker[tensor_id] = (activation, False, None, None, None)
|
| return tensor_id
|
|
|
|
|
| try:
|
|
|
| check_tensor = activation
|
| if DTensor is not None and isinstance(activation, DTensor) and hasattr(activation, "_local_tensor"):
|
| check_tensor = activation._local_tensor
|
|
|
| if check_tensor.untyped_storage().data_ptr() in self.param_storages:
|
| self.tracker[tensor_id] = (activation, False, None, None, None)
|
| return tensor_id
|
| except (RuntimeError, AttributeError):
|
|
|
| pass
|
|
|
|
|
| if self.use_streams:
|
|
|
|
|
| for id in list(self.fwd_stash.keys()):
|
| if id <= tensor_id - self.max_fwd_stash_size:
|
| _, ev = self.fwd_stash[id]
|
| self.s0.wait_event(ev)
|
| del self.fwd_stash[id]
|
| else:
|
| break
|
|
|
|
|
| self.s1.wait_stream(self.s0)
|
|
|
| stream = self.s1 if self.use_streams else self.s0
|
| if self.accelerator_type == "xpu":
|
| stream_ctx = torch.xpu.stream(stream)
|
| elif self.accelerator_type == "npu":
|
| stream_ctx = torch.npu.stream(stream)
|
| else:
|
| stream_ctx = torch.cuda.stream(stream)
|
| with stream_ctx:
|
|
|
| original_stride = activation.stride()
|
| original_storage_offset = activation.storage_offset()
|
| original_shape = activation.size()
|
|
|
|
|
|
|
| has_broadcast = 0 in original_stride
|
|
|
| if has_broadcast:
|
|
|
|
|
| storage_size = activation.untyped_storage().size()
|
| cpu_storage = torch.empty(
|
| storage_size // activation.element_size(),
|
| dtype=activation.dtype,
|
| pin_memory=self.use_pin_memory,
|
| device="cpu",
|
| )
|
|
|
| cpu_storage_view = torch.as_strided(
|
| activation, size=(storage_size // activation.element_size(),), stride=(1,), storage_offset=0
|
| )
|
| cpu_storage.copy_(cpu_storage_view, non_blocking=True)
|
| cpu_tensor = cpu_storage
|
| else:
|
|
|
| cpu_tensor = torch.empty_like(activation, pin_memory=self.use_pin_memory, device="cpu")
|
| cpu_tensor.copy_(activation, non_blocking=True)
|
|
|
|
|
| self.tracker[tensor_id] = (
|
| cpu_tensor,
|
| True,
|
| original_stride,
|
| original_storage_offset,
|
| original_shape,
|
| )
|
|
|
| if self.use_streams:
|
| event = self.s1.record_event()
|
|
|
|
|
| self.fwd_stash[tensor_id] = (activation, event)
|
|
|
|
|
| self.storage_to_tensor_id[storage_key] = tensor_id
|
|
|
| return tensor_id
|
|
|
| def unpack_tensor_single_stream(unpack_tensor_id: int) -> torch.Tensor:
|
|
|
|
|
| if self.is_first_backward_call:
|
| if self.is_first_forward_pass:
|
| self.is_first_forward_pass = False
|
| if self.use_pin_memory:
|
| verify_sufficient_virtual_memory()
|
|
|
| self.is_first_backward_call = False
|
|
|
| if unpack_tensor_id not in self.tracker:
|
| raise ValueError(f"Untracked tensor with id {unpack_tensor_id}")
|
|
|
| (
|
| maybe_accelerator_tensor,
|
| modified,
|
| original_stride,
|
| original_storage_offset,
|
| original_shape,
|
| ) = self.tracker[unpack_tensor_id]
|
|
|
| if modified:
|
|
|
| accelerator_tensor = maybe_accelerator_tensor.to(self.accelerator_type, non_blocking=True)
|
|
|
| if original_stride is not None:
|
| accelerator_tensor = torch.as_strided(
|
| accelerator_tensor,
|
| size=original_shape,
|
| stride=original_stride,
|
| storage_offset=original_storage_offset,
|
| )
|
| maybe_accelerator_tensor = accelerator_tensor
|
|
|
|
|
| del self.tracker[unpack_tensor_id]
|
|
|
| if len(self.tracker) == 0:
|
| self.is_first_forward_call = True
|
| return maybe_accelerator_tensor
|
|
|
| def unpack_tensor_with_streams(unpack_tensor_id: int) -> torch.Tensor:
|
|
|
|
|
| if self.is_first_backward_call:
|
| self.curr_graph_id = torch._C._current_graph_task_id()
|
|
|
| def wait_and_del_remaining_references() -> None:
|
| for id in list(self.bwd_tensor_stash.keys()):
|
| if id in self.bwd_ev_stash:
|
| event = self.bwd_ev_stash[id]
|
| self.s1.wait_event(event)
|
| del self.bwd_tensor_stash[id]
|
|
|
|
|
| torch.autograd.variable.Variable._execution_engine.queue_callback(wait_and_del_remaining_references)
|
|
|
| if self.is_first_forward_pass:
|
| self.is_first_forward_pass = False
|
| if self.use_pin_memory:
|
| verify_sufficient_virtual_memory()
|
|
|
| self.is_first_backward_call = False
|
|
|
| if unpack_tensor_id not in self.tracker:
|
| raise ValueError(f"untracked tensor with id {unpack_tensor_id}")
|
|
|
| (
|
| maybe_accelerator_tensor,
|
| modified,
|
| original_stride,
|
| original_storage_offset,
|
| original_shape,
|
| ) = self.tracker[unpack_tensor_id]
|
|
|
| if modified:
|
|
|
| graph_id = torch._C._current_graph_task_id()
|
| node = torch._C._current_autograd_node()
|
| prev_node_ids = []
|
|
|
|
|
| if graph_id == self.curr_graph_id and self.curr_autograd_node != node:
|
| self.curr_autograd_node = node
|
| prev_node_ids = list(self.bwd_tensor_stash.keys())
|
|
|
| brought_back_from_cpu = True
|
| if unpack_tensor_id in self.fwd_stash:
|
| maybe_accelerator_tensor = self.fwd_stash[unpack_tensor_id][0]
|
| brought_back_from_cpu = False
|
| else:
|
|
|
| if self.accelerator_type == "xpu":
|
| stream_ctx = torch.xpu.stream(self.s1)
|
| elif self.accelerator_type == "npu":
|
| stream_ctx = torch.npu.stream(self.s1)
|
| else:
|
| stream_ctx = torch.cuda.stream(self.s1)
|
| with stream_ctx:
|
|
|
| accelerator_tensor = maybe_accelerator_tensor.to(self.accelerator_type, non_blocking=True)
|
|
|
| if original_stride is not None:
|
| accelerator_tensor = torch.as_strided(
|
| accelerator_tensor,
|
| size=original_shape,
|
| stride=original_stride,
|
| storage_offset=original_storage_offset,
|
| )
|
| maybe_accelerator_tensor = accelerator_tensor
|
|
|
|
|
| self.s0.wait_stream(self.s1)
|
|
|
|
|
| self.bwd_tensor_stash[unpack_tensor_id] = maybe_accelerator_tensor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| if self.accelerator_type == "npu":
|
| storage_refcount = torch_npu._C._storage_Use_Count(
|
| maybe_accelerator_tensor.untyped_storage()._cdata
|
| )
|
| else:
|
| storage_refcount = torch._C._storage_Use_Count(
|
| maybe_accelerator_tensor.untyped_storage()._cdata
|
| )
|
|
|
| def hook(outputs, inputs):
|
|
|
| if brought_back_from_cpu:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| if unpack_tensor_id in self.bwd_tensor_stash:
|
| unpacked_tensor = self.bwd_tensor_stash[unpack_tensor_id]
|
| if self.accelerator_type == "npu":
|
| storage_count = torch_npu._C._storage_Use_Count(
|
| unpacked_tensor.untyped_storage()._cdata
|
| )
|
| else:
|
| storage_count = torch._C._storage_Use_Count(unpacked_tensor.untyped_storage()._cdata)
|
| if storage_count > storage_refcount:
|
| unpacked_tensor.record_stream(self.s0)
|
| del self.bwd_tensor_stash[unpack_tensor_id]
|
| else:
|
| event = self.s0.record_event()
|
| self.bwd_ev_stash[unpack_tensor_id] = event
|
|
|
|
|
| for id in list(self.fwd_stash.keys()):
|
| _, ev = self.fwd_stash[id]
|
| self.s0.wait_event(ev)
|
| del self.fwd_stash[id]
|
|
|
|
|
| for id in prev_node_ids:
|
|
|
| if id in self.bwd_ev_stash:
|
| event = self.bwd_ev_stash[id]
|
| self.s1.wait_event(event)
|
| del self.bwd_ev_stash[id]
|
| if id in self.bwd_tensor_stash:
|
| del self.bwd_tensor_stash[id]
|
|
|
| return outputs
|
|
|
| node.register_hook(hook)
|
|
|
|
|
| del self.tracker[unpack_tensor_id]
|
|
|
| if len(self.tracker) == 0:
|
| self.is_first_forward_call = True
|
| return maybe_accelerator_tensor
|
|
|
| unpack_tensor = unpack_tensor_with_streams if self.use_streams else unpack_tensor_single_stream
|
| super().__init__(pack_tensor, unpack_tensor)
|
|
|
| def update_model_params(self, model: nn.Module):
|
| """
|
| Update the set of parameter storage pointers from the model. This allows filtering out model parameters during
|
| offloading, which is especially important for FSDP models where parameters may not be detected by isinstance
|
| checks.
|
|
|
| For FSDP v2, this method handles DTensor parameters which may be sharded across ranks and not have valid local
|
| storage on all ranks. We extract the local tensor from DTensors using _local_tensor when available.
|
|
|
| Args:
|
| model: The model whose parameters should be tracked
|
| """
|
| param_storages = set()
|
|
|
| for p in model.parameters():
|
|
|
| actual_tensor = p
|
| if DTensor is not None and isinstance(p, DTensor) and hasattr(p, "_local_tensor"):
|
| actual_tensor = p._local_tensor
|
|
|
|
|
| try:
|
| storage_ptr = actual_tensor.untyped_storage().data_ptr()
|
| if storage_ptr != 0:
|
| param_storages.add(storage_ptr)
|
| except RuntimeError:
|
|
|
|
|
| continue
|
|
|
| self.param_storages = param_storages
|
|
|
|
|
| class NoOpManager(saved_tensors_hooks):
|
| """
|
| A `saved_tensors_hook` manager used to disable any other `saved_tensors_hook` manager applied before. This relies
|
| on the behavior that only the most recently registered `saved_tensors_hook` will run.
|
|
|
| One example usage is to opt a local region of code out of activations offloading, which is usually applied globally
|
| to best track state.
|
| """
|
|
|
| def __init__(self) -> None:
|
| def noop(tensor):
|
| return tensor
|
|
|
| super().__init__(noop, noop)
|
|
|
|
|
| def get_act_offloading_ctx_manager(
|
| model: nn.Module,
|
| use_pin_memory: bool = True,
|
| use_streams: bool = True,
|
| min_offload_size: int = 1024,
|
| max_fwd_stash_size: int = 5,
|
| warn_if_no_head: bool = True,
|
| ) -> OffloadActivations:
|
| """
|
| Returns the activation offloading context manager for the model. All but the last output Linear in every step will
|
| be offloaded.
|
|
|
| If activation offloading is enabled, we return the OffloadActivations context manager. If activation offloading is
|
| disabled, we return a NoOpManager context manager.
|
|
|
| Args:
|
| model (`nn.Module`):
|
| Model to wrap with the activation offloading context manager.
|
| use_pin_memory (`bool`, *optional*, defaults to `True`):
|
| Whether to offloaded Tensor will be placed in pinned memory on the CPU. Pinned memory allows the Tensor to
|
| be moved back onto GPU more quickly but is a limited resource.
|
| use_streams (`bool`, *optional*, defaults to `True`):
|
| Whether to use streams for performance optimization where the communications get overlapped with the
|
| computation. Requires a torch build after torch-2.5.0.
|
| min_offload_size (`int`, *optional*, defaults to `1024`):
|
| Minimum number of bytes a Tensor must be in order to qualify for offloading. If the tensor is too small, we
|
| do not want to waste bandwidth and resources moving it to CPU and back.
|
| max_fwd_stash_size (`int`, *optional*, defaults to `5`):
|
| Maximum size of the forward stash, or the maximum number of consecutive activations to keep alive during
|
| the forward pass. This number must be at least 1. Keeping alive more activations will potentially allow
|
| more overlap between the communication and compute streams at the cost of increasing memory usage. Keeping
|
| alive fewer activations will conserve memory, but may cause poor overlap between the streams, increasing
|
| runtime.
|
| warn_if_no_head (`bool`, *optional*, defaults to `True`):
|
| Whether to warn if no output head is detected. If set to `False`, no warning will be raised if no output
|
| head is detected.
|
|
|
| Returns:
|
| `contextlib.ContextDecorator`:
|
| Activation offloading context manager for the model.
|
| """
|
| activations_handling_ctx = OffloadActivations(
|
| use_pin_memory=use_pin_memory,
|
| use_streams=use_streams,
|
| min_offload_size=min_offload_size,
|
| max_fwd_stash_size=max_fwd_stash_size,
|
| )
|
|
|
|
|
| activations_handling_ctx.update_model_params(model)
|
|
|
|
|
|
|
|
|
| output_head_detected = False
|
| noop_ctx = NoOpManager()
|
|
|
|
|
| unwrapped_model = model
|
| if hasattr(unwrapped_model, "module"):
|
| unwrapped_model = unwrapped_model.module
|
|
|
| if hasattr(unwrapped_model, "base_model") and hasattr(unwrapped_model, "peft_config"):
|
| unwrapped_model = unwrapped_model.base_model
|
|
|
|
|
| if hasattr(unwrapped_model, "output"):
|
| if isinstance(unwrapped_model.output, nn.Module):
|
| unwrapped_model.output.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
|
| unwrapped_model.output.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True)
|
| output_head_detected = True
|
| elif hasattr(unwrapped_model.output, "linear") and isinstance(unwrapped_model.output.linear, nn.Module):
|
| unwrapped_model.output.linear.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
|
| unwrapped_model.output.linear.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True)
|
| output_head_detected = True
|
|
|
|
|
| elif hasattr(unwrapped_model, "lm_head"):
|
| unwrapped_model.lm_head.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
|
| unwrapped_model.lm_head.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True)
|
| output_head_detected = True
|
|
|
|
|
| elif hasattr(unwrapped_model, "decoder"):
|
| decoder = unwrapped_model.decoder
|
| if hasattr(decoder, "output"):
|
| decoder.output.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
|
| decoder.output.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True)
|
| output_head_detected = True
|
|
|
| elif hasattr(decoder, "lm_head"):
|
| decoder.lm_head.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
|
| decoder.lm_head.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True)
|
| output_head_detected = True
|
|
|
|
|
| elif hasattr(unwrapped_model, "final_layer_norm") or hasattr(unwrapped_model, "ln_f"):
|
| final_norm = getattr(unwrapped_model, "final_layer_norm", None) or unwrapped_model.ln_f
|
| final_norm.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
|
| final_norm.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True)
|
| output_head_detected = True
|
|
|
|
|
| elif hasattr(unwrapped_model, "head") and isinstance(unwrapped_model.head, nn.Module):
|
| unwrapped_model.head.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
|
| unwrapped_model.head.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True)
|
| output_head_detected = True
|
|
|
| if not output_head_detected and warn_if_no_head:
|
| logger.warning(
|
| "During activation offloading, no output head was detected. If your model has an output head, it will be "
|
| "offloaded. This usually greatly slows training, given the large vocabulary size. To change this "
|
| "behavior, set your output head as model.output and make it an nn.Module. You can disable this warning by "
|
| "passing `warn_if_no_head=False`."
|
| )
|
|
|
|
|
| for name, module in unwrapped_model.named_modules():
|
| if "liger" in name.lower():
|
| module.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
|
| module.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True)
|
|
|
| return activations_handling_ctx
|
|
|