# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import warnings
from collections.abc import Callable
from contextlib import contextmanager
from copy import deepcopy
from typing import TYPE_CHECKING, Any

import accelerate
import torch.nn as nn
import transformers
from accelerate import Accelerator
from packaging.version import Version
from torch.distributed.fsdp import FSDPModule
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from transformers import GenerationConfig, PreTrainedModel

from ..import_utils import suppress_experimental_warning


with suppress_experimental_warning():
    from ..experimental.utils import create_reference_model as _create_reference_model

if Version(accelerate.__version__) >= Version("1.11.0"):
    from accelerate.utils.fsdp_utils import get_parameters_from_modules


if TYPE_CHECKING:
    from deepspeed.runtime.engine import DeepSpeedEngine
    from torch.nn import Module
    from torch.nn.parallel.distributed import DistributedDataParallel


def remove_hooks(model: "DeepSpeedEngine") -> None:
    """Removes the optimizer hooks from a DeepSpeed ZeRO-3 model.

    Does nothing when `model` has no `optimizer` attribute.

    Raises:
        RuntimeError: If `model.optimizer` is `None`.
    """
    if not hasattr(model, "optimizer"):  # before the first training step, the model has no optimizer
        return
    if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
        optimizer_offload = model.optimizer.parameter_offload
    elif model.optimizer is not None:
        optimizer_offload = model.optimizer
    else:
        raise RuntimeError("The model optimizer is None, which is not yet supported.")

    for param in iter_params(optimizer_offload.module, recurse=True):
        param.ds_active_sub_modules.clear()

    for hook in optimizer_offload.forward_hooks:
        hook.remove()
    for hook in optimizer_offload.backward_hooks:
        hook.remove()

    optimizer_offload.forward_hooks = []
    optimizer_offload.backward_hooks = []


def get_all_parameters(sub_module, recurse=False):
    """Chain `sub_module.named_parameters(recurse=recurse)` with `sub_module.ds_external_parameters()`.

    `iter_params` unpacks every item as a 2-tuple, so `ds_external_parameters()` presumably yields
    `(name, parameter)` pairs like `named_parameters` does -- confirm against DeepSpeed.
    """
    return itertools.chain(sub_module.named_parameters(recurse=recurse), sub_module.ds_external_parameters())


def iter_params(module, recurse=False):
    """Return a list with the second element of every pair produced by `get_all_parameters`."""
    return [param for _, param in get_all_parameters(module, recurse)]


def add_hooks(model: "DeepSpeedEngine") -> None:
    """Adds the optimizer hooks from a DeepSpeed ZeRO-3 model.

    Does nothing when `model` has no `optimizer` attribute.

    Raises:
        RuntimeError: If `model.optimizer` is `None`.
    """
    import deepspeed

    if not hasattr(model, "optimizer"):  # before the first training step, the model has no optimizer
        return
    if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
        optimizer_offload = model.optimizer.parameter_offload
    elif model.optimizer is not None:
        optimizer_offload = model.optimizer
    else:
        raise RuntimeError("The model optimizer is None, which is not yet supported.")
    if Version(deepspeed.__version__) >= Version("0.16.4"):
        # Account for renaming in https://github.com/deepspeedai/DeepSpeed/pull/6847
        optimizer_offload._register_deepspeed_module(optimizer_offload.module)
    else:
        optimizer_offload._register_hooks_recursively(optimizer_offload.module)


@contextmanager
def _unwrap_model_for_generation(
    model: "DistributedDataParallel | DeepSpeedEngine",
    accelerator: "Accelerator",
    gather_deepspeed3_params: bool = True,
):
    """
    Context manager to unwrap distributed or accelerated models for generation tasks.

    Gradient checkpointing, if enabled on the unwrapped model, is disabled for the duration of the context and
    re-enabled on exit. For DeepSpeed ZeRO Stage 3 with parameter gathering, the optimizer hooks are removed inside
    the context and added back on exit. Both restorations also run when the body of the `with` statement raises.

    Args:
        model (`DistributedDataParallel | DeepSpeedEngine`):
            Model to be unwrapped.
        accelerator ([`~accelerate.Accelerator`]):
            Accelerator instance managing the model.
        gather_deepspeed3_params (`bool`, *optional*, defaults to `True`):
            Whether to gather weights for DeepSpeed ZeRO Stage 3 models. If `False`, skips parameter gathering, which
            can be more memory-efficient but may lead to slower generation times.

    Yields:
        Unwrapped model.

    Example:
    ```python
    with _unwrap_model_for_generation(model, accelerator) as unwrapped_model:
        generated_outputs = unwrapped_model.generate(input_ids)
    ```
    """
    unwrapped_model = accelerator.unwrap_model(model)
    is_gradient_checkpointing = unwrapped_model.is_gradient_checkpointing
    if is_gradient_checkpointing:
        unwrapped_model.gradient_checkpointing_disable()
    # An exception raised by the caller's `with` body is re-raised at the `yield`; the `finally` clauses make sure
    # the model is put back in its training configuration in that case too.
    try:
        if accelerator.state.deepspeed_plugin is not None and accelerator.state.deepspeed_plugin.zero_stage == 3:
            if not gather_deepspeed3_params:
                yield accelerator.unwrap_model(model)
            else:
                import deepspeed

                with deepspeed.zero.GatheredParameters(model.parameters()):
                    remove_hooks(model)
                    try:
                        yield accelerator.unwrap_model(model)
                    finally:
                        add_hooks(model)
        else:
            yield unwrapped_model
    finally:
        if is_gradient_checkpointing:
            unwrapped_model.gradient_checkpointing_enable()


@contextmanager
def _override_model_generation_config(model, generation_kwargs=None):
    """
    Context manager to temporarily override a model's generation_config with training config.

    This works around transformers' config merging logic that would otherwise overwrite values matching global defaults
    with model-specific values (see upstream issue transformers#42762; fixed in transformers v5 by PR
    `transformers#42702`).

    By temporarily setting the model's generation_config to match the passed generation_config, we avoid the conflict.

    The model's original generation_config is preserved outside this context, ensuring that saved/pushed models retain
    their intended inference behavior.

    Args:
        model: The model (typically unwrapped_model) whose generation_config to temporarily override.
        generation_kwargs (dict): Generation kwargs to be used to override model's generation config.

    Yields:
        `model` unchanged when no override is applied (transformers >= 5.0.0, `generation_kwargs` is `None`, or the
        model has no `generation_config`); `None` when the override is applied.
    """
    if (
        # Issue fixed in transformers v5 by PR transformers#42702
        Version(transformers.__version__) >= Version("5.0.0")
        or generation_kwargs is None
        or not hasattr(model, "generation_config")
    ):
        yield model
        return

    # If it is a PEFT model, override the underlying base model
    if hasattr(model, "get_base_model"):
        model = model.get_base_model()

    # Keep original model generation_config
    original_config = model.generation_config

    # Create training-specific generation config from the model's original generation config
    # Then overwrite it with the training-specific generation kwargs
    generation_config = GenerationConfig.from_dict(model.generation_config.to_dict())
    generation_config.update(**generation_kwargs)
    model.generation_config = generation_config
    try:
        yield
    finally:
        model.generation_config = original_config


@contextmanager
def unwrap_model_for_generation(
    model: "DistributedDataParallel | DeepSpeedEngine",
    accelerator: "Accelerator",
    gather_deepspeed3_params: bool = True,
    generation_kwargs: dict | None = None,
):
    """
    Context manager to unwrap distributed or accelerated models for generation tasks.

    This function unwraps distributed models (FSDP, DeepSpeed) and optionally overrides the model's generation_config
    temporarily during generation. This is useful for applying training-specific generation parameters without
    permanently modifying the model's original generation_config.

    Args:
        model (`DistributedDataParallel | DeepSpeedEngine`):
            Model to be unwrapped.
        accelerator ([`~accelerate.Accelerator`]):
            Accelerator instance managing the model.
        gather_deepspeed3_params (`bool`, *optional*, defaults to `True`):
            Whether to gather weights for DeepSpeed ZeRO Stage 3 models. If `False`, skips parameter gathering, which
            can be more memory-efficient but may lead to slower generation times.
        generation_kwargs (dict, *optional*):
            If provided, temporarily overrides the model's generation_config during generation. The original config is
            automatically restored when exiting the context. This is useful for using different generation parameters
            during training vs. inference.

    Yields:
        Unwrapped model with optionally overridden generation_config.
    """
    with (
        _unwrap_model_for_generation(
            model, accelerator, gather_deepspeed3_params=gather_deepspeed3_params
        ) as unwrapped_model,
        _override_model_generation_config(unwrapped_model, generation_kwargs=generation_kwargs),
    ):
        yield unwrapped_model


def prepare_deepspeed(model: "Module", accelerator: "Accelerator"):
    """Prepares the model for DeepSpeed inference or evaluation by initializing it with the appropriate configuration.

    Adapted from accelerate:
    https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473

    The plugin's DeepSpeed config is deep-copied before being modified, the model is initialized with
    `deepspeed.initialize`, switched to eval mode and returned.
    """
    import deepspeed  # local import (instead of top-level) to avoid DS init interfering with other backends (like vllm): https://github.com/deepspeedai/DeepSpeed/issues/7252

    deepspeed_plugin = accelerator.state.deepspeed_plugin
    config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
    stage = config_kwargs["zero_optimization"]["stage"]

    if model is not None:
        hidden_size = (
            max(model.config.hidden_sizes)
            if getattr(model.config, "hidden_sizes", None)
            else getattr(model.config, "hidden_size", None)
        )
        if hidden_size is not None and stage == 3:
            # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache
            # @ step 0: expected module 1, but got module 0`
            # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
            # NOTE(review): these overrides are written as flat dotted keys at the top level of the config, whereas
            # `stage` above is read from the nested `zero_optimization` dict -- confirm DeepSpeed consumes them.
            config_kwargs.update(
                {
                    "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
                    "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
                    "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
                }
            )

    # If ZeRO-3 is used, we shard both the active and reference model.
    # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO
    # disabled (stage 0)
    if stage != 3:
        config_kwargs["zero_optimization"]["stage"] = 0
    model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
    model.eval()
    return model


def prepare_fsdp(model, accelerator: Accelerator) -> FSDP | FSDPModule:
    """Wrap `model` with FSDP according to the accelerator's FSDP plugin and put it in eval mode.

    A model that is already an `FSDP` / `FSDPModule` instance is not wrapped again. For `fsdp_version == 1` the model
    is wrapped in `FullyShardedDataParallel`; for `fsdp_version == 2` `fully_shard` is applied to it.

    Raises:
        ValueError: If the plugin's `fsdp_version` is neither 1 nor 2.
    """
    # Check if the model is already a FSDP model due to `Manual Wrapping` and if so, don't wrap it again
    if not isinstance(model, (FSDP, FSDPModule)):
        fsdp_plugin = accelerator.state.fsdp_plugin
        if fsdp_plugin.fsdp_version == 1:
            accelerator.state.fsdp_plugin.set_auto_wrap_policy(model)
            kwargs = {
                "sharding_strategy": fsdp_plugin.sharding_strategy or fsdp_plugin.reshard_after_forward,
                "cpu_offload": fsdp_plugin.cpu_offload,
                "auto_wrap_policy": fsdp_plugin.auto_wrap_policy,
                "mixed_precision": fsdp_plugin.mixed_precision_policy,
                "sync_module_states": fsdp_plugin.sync_module_states,
                "backward_prefetch": fsdp_plugin.backward_prefetch,
                "forward_prefetch": fsdp_plugin.forward_prefetch,
                "use_orig_params": fsdp_plugin.use_orig_params,
                "param_init_fn": fsdp_plugin.param_init_fn,
                "ignored_modules": fsdp_plugin.ignored_modules,
                "limit_all_gathers": fsdp_plugin.limit_all_gathers,
                "device_id": accelerator.device,
            }
            model = FSDP(model, **kwargs)
        elif fsdp_plugin.fsdp_version == 2:
            from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

            mesh = getattr(accelerator, "torch_device_mesh", None)
            if Version(accelerate.__version__) >= Version("1.11.0"):
                ignored_params = get_parameters_from_modules(fsdp_plugin.ignored_modules, model, accelerator.device)
            else:
                warnings.warn(
                    "FSDP version 2 is being used with accelerate version < 1.11.0, which may lead to incorrect "
                    "handling of ignored modules. Please upgrade accelerate to v1.11.0 or later for proper support."
                )
                ignored_params = None
            fully_shard(
                model,
                reshard_after_forward=fsdp_plugin.reshard_after_forward,
                offload_policy=fsdp_plugin.cpu_offload,
                # `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
                mp_policy=fsdp_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
                mesh=mesh[tuple(accelerator.parallelism_config.fsdp_dim_names)] if mesh is not None else None,
                ignored_params=ignored_params,
            )
        else:
            raise ValueError(f"FSDP version {fsdp_plugin.fsdp_version} is not supported.")
    model.eval()
    return model


class _ForwardRedirection:
    """Implements the `forward-redirection`.

    Taken from Pytorch-lightning:
    https://github.com/Lightning-AI/pytorch-lightning/blob/02311d03fb982560246eead7c08104481fac9579/src/lightning/pytorch/strategies/strategy.py#L602

    A method call to a wrapped module gets rerouted through the wrapper's `forward` method instead.
    """

    def __call__(
        self, wrapper_module: nn.Module, original_module: nn.Module, method: Callable, *args: Any, **kwargs: Any
    ):
        """Reroutes a method call through the `wrapper_module`'s `forward` method.

        Args:
            wrapper_module: The module that has `original_module` wrapped.
            original_module: The module that was wrapped inside `wrapper_module`.
            method: The method that should be called on the `original_module` after inputs get
                redirected through the `wrapper_module`'s `forward` method.
            *args: The positional arguments to the `method`. They will get passed to a patched
                `forward` method instead.
            **kwargs: The keyword arguments to the `method`. They will get passed to a patched
                `forward` method instead.

        Returns:
            Whatever calling `wrapper_module(*args, **kwargs)` returns.
        """
        original_forward = original_module.forward

        def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any:
            # Unpatch ourselves immediately before calling the method `method_name`
            # because itself may want to call the real `forward`
            original_module.forward = original_forward  # type: ignore[method-assign]
            # Call the actual method e.g. `.training_step(...)`
            out = method(*_args, **_kwargs)
            self.on_after_inner_forward(wrapper_module, original_module)
            return out

        # Patch the original_module's forward so we can redirect the arguments back to the real method
        original_module.forward = wrapped_forward  # type: ignore[method-assign]

        wrapper_output = wrapper_module(*args, **kwargs)
        self.on_after_outer_forward(wrapper_module, original_module)
        return wrapper_output

    def on_after_inner_forward(self, wrapper_module: nn.Module, original_module: nn.Module) -> None:
        """Hook called after `method` returns inside the patched forward; a no-op here."""
        pass

    def on_after_outer_forward(self, wrapper_module: nn.Module, original_module: nn.Module) -> None:
        """Hook called after the `wrapper_module` call returns; a no-op here."""
        pass


@contextmanager
def disable_gradient_checkpointing(model: PreTrainedModel, gradient_checkpointing_kwargs: dict | None = None):
    """
    Temporarily disable gradient checkpointing, restoring the previous state afterward.

    Args:
        model (`PreTrainedModel`):
            Model for which to temporarily disable gradient checkpointing.
        gradient_checkpointing_kwargs (`dict` or `None`, *optional*):
            Additional kwargs for gradient checkpointing enabling.
    """
    was_enabled = model.is_gradient_checkpointing
    if was_enabled:
        model.gradient_checkpointing_disable()
    try:
        yield
    finally:
        if was_enabled:
            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs)


def create_reference_model(
    model: nn.Module, num_shared_layers: int | None = None, pattern: str | None = None
) -> nn.Module:
    """Deprecated alias that emits a `FutureWarning` and forwards to `trl.experimental.utils.create_reference_model`."""
    warnings.warn(
        "The `create_reference_model` function is now located in `trl.experimental.utils`. Please update your "
        "imports to `from trl.experimental.utils import create_reference_model`. This import path will be removed in "
        "TRL 1.0.0.",
        FutureWarning,
        stacklevel=2,
    )
    return _create_reference_model(model, num_shared_layers=num_shared_layers, pattern=pattern)