|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import itertools
|
| import warnings
|
| from collections.abc import Callable
|
| from contextlib import contextmanager
|
| from copy import deepcopy
|
| from typing import TYPE_CHECKING, Any
|
|
|
| import accelerate
|
| import torch.nn as nn
|
| import transformers
|
| from accelerate import Accelerator
|
| from packaging.version import Version
|
| from torch.distributed.fsdp import FSDPModule
|
| from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
|
| from transformers import GenerationConfig, PreTrainedModel
|
|
|
| from ..import_utils import suppress_experimental_warning
|
|
|
|
|
| with suppress_experimental_warning():
|
| from ..experimental.utils import create_reference_model as _create_reference_model
|
|
|
|
|
| if Version(accelerate.__version__) >= Version("1.11.0"):
|
| from accelerate.utils.fsdp_utils import get_parameters_from_modules
|
|
|
| if TYPE_CHECKING:
|
| from deepspeed.runtime.engine import DeepSpeedEngine
|
| from torch.nn import Module
|
| from torch.nn.parallel.distributed import DistributedDataParallel
|
|
|
|
|
def remove_hooks(model: "DeepSpeedEngine") -> None:
    """
    Detach the optimizer hooks from a DeepSpeed ZeRO-3 engine.

    The hooks live on `model.optimizer.parameter_offload` when that attribute exists, otherwise on
    `model.optimizer` itself. Every parameter returned by `iter_params(<offload>.module, recurse=True)` gets its
    `ds_active_sub_modules` set cleared. All forward and backward hooks are removed, and both hook lists are then
    reset to empty lists. Engines without an `optimizer` attribute are left untouched.

    Args:
        model (`DeepSpeedEngine`):
            Engine whose optimizer hooks should be removed.

    Raises:
        RuntimeError: If `model.optimizer` is `None`.
    """
    if not hasattr(model, "optimizer"):
        return
    optimizer = model.optimizer
    if optimizer is None:
        raise RuntimeError("The model optimizer is None, which is not yet supported.")
    offload = optimizer.parameter_offload if hasattr(optimizer, "parameter_offload") else optimizer

    for param in iter_params(offload.module, recurse=True):
        param.ds_active_sub_modules.clear()

    # Forward hooks are removed first, then backward hooks (same order as registration bookkeeping).
    for hook in itertools.chain(offload.forward_hooks, offload.backward_hooks):
        hook.remove()

    offload.forward_hooks = []
    offload.backward_hooks = []
|
|
|
|
|
def get_all_parameters(sub_module, recurse=False):
    """
    Chain a module's named parameters with its DeepSpeed external parameters.

    Args:
        sub_module:
            Object exposing `named_parameters(recurse=...)` and `ds_external_parameters()`.
        recurse (`bool`, *optional*, defaults to `False`):
            Forwarded to `named_parameters`.

    Returns:
        An iterator yielding the entries of `named_parameters(recurse=recurse)` followed by the entries of
        `ds_external_parameters()`. Both methods are called eagerly; only the iteration is lazy.
    """
    own_params = sub_module.named_parameters(recurse=recurse)
    external_params = sub_module.ds_external_parameters()
    return itertools.chain(own_params, external_params)
|
|
|
|
|
def iter_params(module, recurse=False):
    """
    Collect the parameters produced by `get_all_parameters` into a list, dropping their names.

    Args:
        module:
            Module accepted by `get_all_parameters`.
        recurse (`bool`, *optional*, defaults to `False`):
            Forwarded to `get_all_parameters`.

    Returns:
        `list` of parameter objects in the order `get_all_parameters` yields them.
    """
    named_params = get_all_parameters(module, recurse)
    return [tensor for _name, tensor in named_params]
|
|
|
|
|
def add_hooks(model: "DeepSpeedEngine") -> None:
    """
    Re-register the optimizer hooks on a DeepSpeed ZeRO-3 engine (the counterpart of `remove_hooks`).

    The hooks are registered on `model.optimizer.parameter_offload` when that attribute exists, otherwise on
    `model.optimizer` itself. Engines without an `optimizer` attribute are left untouched.

    Args:
        model (`DeepSpeedEngine`):
            Engine whose optimizer hooks should be re-added.

    Raises:
        RuntimeError: If `model.optimizer` is `None`.
    """
    import deepspeed

    if not hasattr(model, "optimizer"):
        return
    optimizer = model.optimizer
    if optimizer is None:
        raise RuntimeError("The model optimizer is None, which is not yet supported.")
    offload = optimizer.parameter_offload if hasattr(optimizer, "parameter_offload") else optimizer

    # DeepSpeed >= 0.16.4 registers hooks via `_register_deepspeed_module`; older versions use
    # `_register_hooks_recursively`.
    if Version(deepspeed.__version__) >= Version("0.16.4"):
        register = offload._register_deepspeed_module
    else:
        register = offload._register_hooks_recursively
    register(offload.module)
|
|
|
|
|
@contextmanager
def _unwrap_model_for_generation(
    model: "DistributedDataParallel | DeepSpeedEngine",
    accelerator: "Accelerator",
    gather_deepspeed3_params: bool = True,
):
    """
    Context manager to unwrap distributed or accelerated models for generation tasks.

    Gradient checkpointing is disabled while the context is active. If it was enabled on entry, it is re-enabled on
    exit, also when the body of the `with` statement raises. For DeepSpeed ZeRO-3 with parameter gathering, the
    optimizer hooks are removed for the duration and re-added on exit, again also on error.

    Args:
        model (`DistributedDataParallel | DeepSpeedEngine`):
            Model to be unwrapped.
        accelerator ([`~accelerate.Accelerator`]):
            Accelerator instance managing the model.
        gather_deepspeed3_params (`bool`, *optional*, defaults to `True`):
            Whether to gather weights for DeepSpeed ZeRO Stage 3 models. If `False`, skips parameter gathering, which
            can be more memory-efficient but may lead to slower generation times.

    Yields:
        Unwrapped model.

    Example:
    ```python
    with _unwrap_model_for_generation(model, accelerator) as unwrapped_model:
        generated_outputs = unwrapped_model.generate(input_ids)
    ```
    """
    unwrapped_model = accelerator.unwrap_model(model)
    is_gradient_checkpointing = unwrapped_model.is_gradient_checkpointing
    if is_gradient_checkpointing:
        unwrapped_model.gradient_checkpointing_disable()
    try:
        deepspeed_plugin = accelerator.state.deepspeed_plugin
        if deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3:
            if not gather_deepspeed3_params:
                yield accelerator.unwrap_model(model)
            else:
                import deepspeed

                with deepspeed.zero.GatheredParameters(model.parameters()):
                    remove_hooks(model)
                    try:
                        yield accelerator.unwrap_model(model)
                    finally:
                        # Re-register the hooks even if generation failed, so the engine stays usable afterwards.
                        add_hooks(model)
        else:
            yield unwrapped_model
    finally:
        if is_gradient_checkpointing:
            # NOTE(review): re-enabled without the original `gradient_checkpointing_kwargs`, which are not known here.
            unwrapped_model.gradient_checkpointing_enable()
|
|
|
|
|
@contextmanager
def _override_model_generation_config(model, generation_kwargs=None):
    """
    Context manager to temporarily override a model's `generation_config` with the given generation kwargs.

    This works around transformers' config merging logic that would otherwise overwrite values matching global defaults
    with model-specific values (see upstream issue transformers#42762; fixed in transformers v5 by PR
    `transformers#42702`).

    By temporarily setting the model's generation_config to match the passed generation kwargs, we avoid the conflict.
    The override is applied to a copy of the current config (`to_dict` -> `from_dict`), so the original config object is
    never mutated. It is put back on exit, also when the body raises, so saved/pushed models retain their intended
    inference behavior.

    Nothing is changed when running transformers >= 5.0.0, when `generation_kwargs` is `None`, or when `model` has
    no `generation_config` attribute. If `model` exposes `get_base_model` (e.g. a PEFT wrapper), the override is
    applied to the base model's `generation_config`.

    Args:
        model: The model (typically unwrapped_model) whose generation_config to temporarily override.
        generation_kwargs (dict): Generation kwargs to be used to override model's generation config.

    Yields:
        The `model` that was passed in, on every code path.
    """
    if (
        Version(transformers.__version__) >= Version("5.0.0")
        or generation_kwargs is None
        or not hasattr(model, "generation_config")
    ):
        yield model
        return

    target = model.get_base_model() if hasattr(model, "get_base_model") else model
    original_config = target.generation_config

    # Dict round-trip yields an independent copy, so `update` below never touches `original_config`.
    generation_config = GenerationConfig.from_dict(original_config.to_dict())
    generation_config.update(**generation_kwargs)
    target.generation_config = generation_config
    try:
        yield model
    finally:
        target.generation_config = original_config
|
|
|
|
|
@contextmanager
def unwrap_model_for_generation(
    model: "DistributedDataParallel | DeepSpeedEngine",
    accelerator: "Accelerator",
    gather_deepspeed3_params: bool = True,
    generation_kwargs: dict | None = None,
):
    """
    Context manager to unwrap distributed or accelerated models for generation tasks.

    Combines two steps. First, the model is unwrapped from its distributed wrapper (FSDP, DeepSpeed, ...). Second,
    the unwrapped model's `generation_config` is optionally overridden for the duration of the context. This lets
    training-specific generation parameters be used without permanently modifying the model's original
    `generation_config`.

    Args:
        model (`DistributedDataParallel | DeepSpeedEngine`):
            Model to be unwrapped.
        accelerator ([`~accelerate.Accelerator`]):
            Accelerator instance managing the model.
        gather_deepspeed3_params (`bool`, *optional*, defaults to `True`):
            Whether to gather weights for DeepSpeed ZeRO Stage 3 models. If `False`, skips parameter gathering, which
            can be more memory-efficient but may lead to slower generation times.
        generation_kwargs (dict, *optional*):
            If provided, temporarily overrides the model's generation_config during generation. The original config is
            automatically restored when exiting the context.

    Yields:
        Unwrapped model with optionally overridden generation_config.
    """
    unwrap_ctx = _unwrap_model_for_generation(model, accelerator, gather_deepspeed3_params=gather_deepspeed3_params)
    with unwrap_ctx as unwrapped_model:
        # The override is entered inside the unwrap context, so it is exited (and the config restored) first.
        with _override_model_generation_config(unwrapped_model, generation_kwargs=generation_kwargs):
            yield unwrapped_model
|
|
|
|
|
def prepare_deepspeed(model: "Module", accelerator: "Accelerator"):
    """Prepares the model for DeepSpeed inference or evaluation by initializing it with the appropriate configuration.

    The accelerator's DeepSpeed config is deep-copied, so the training config is never modified. With ZeRO stage 3
    the model is sharded the same way as the trained model. For any other stage, ZeRO is disabled (stage 0) and the
    model is replicated on each device. The resulting engine is put in eval mode.

    Adapted from accelerate:
    https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473

    Args:
        model (`torch.nn.Module`):
            Model to wrap.
        accelerator ([`~accelerate.Accelerator`]):
            Accelerator whose `state.deepspeed_plugin.deepspeed_config` is used as the base configuration.

    Returns:
        The DeepSpeed engine returned by `deepspeed.initialize`, in eval mode.
    """
    import deepspeed

    deepspeed_plugin = accelerator.state.deepspeed_plugin
    config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
    zero_config = config_kwargs["zero_optimization"]
    stage = zero_config["stage"]

    if model is not None and stage == 3:
        model_config = getattr(model, "config", None)
        hidden_sizes = getattr(model_config, "hidden_sizes", None)
        hidden_size = max(hidden_sizes) if hidden_sizes else getattr(model_config, "hidden_size", None)
        if hidden_size is not None:
            # These keys must be written into the nested `zero_optimization` section: `deepspeed.initialize` does not
            # interpret flat dotted keys such as "zero_optimization.reduce_bucket_size", so merging those at the top
            # level is silently ignored. As in accelerate's `fill_match`, only "auto" placeholders are replaced, so
            # values the user set explicitly are respected.
            # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache
            # @ step 0: expected module 1, but got module 0`; this is expected and not an error
            # (see https://github.com/microsoft/DeepSpeed/discussions/4081).
            size_defaults = {
                "reduce_bucket_size": hidden_size * hidden_size,
                "stage3_param_persistence_threshold": 10 * hidden_size,
                # Bucket sizes are element counts; DeepSpeed's ZeRO config schema types this field as an int.
                "stage3_prefetch_bucket_size": int(0.9 * hidden_size * hidden_size),
            }
            for key, value in size_defaults.items():
                if zero_config.get(key) == "auto":
                    zero_config[key] = value

    if stage != 3:
        # Not ZeRO-3: assume the model fits in memory and replicate it on each device with ZeRO disabled.
        zero_config["stage"] = 0
    model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
    model.eval()
    return model
|
|
|
|
|
def prepare_fsdp(model, accelerator: Accelerator) -> FSDP | FSDPModule:
    """
    Shard `model` with PyTorch FSDP for inference/evaluation, using the accelerator's FSDP plugin settings.

    A model that is already an FSDP (v1) or FSDP2 module (e.g. because of manual wrapping) is not wrapped again.
    FSDP version 1 wraps the model in `FullyShardedDataParallel`; version 2 applies `fully_shard` to it in place.

    Adapted from accelerate:
    https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1421

    Args:
        model (`torch.nn.Module`):
            Model to shard.
        accelerator ([`~accelerate.Accelerator`]):
            Accelerator whose `state.fsdp_plugin` supplies the FSDP settings.

    Returns:
        The sharded model, in eval mode.

    Raises:
        ValueError: If the plugin's `fsdp_version` is neither 1 nor 2.
    """
    if isinstance(model, (FSDP, FSDPModule)):
        model.eval()
        return model

    fsdp_plugin = accelerator.state.fsdp_plugin
    if fsdp_plugin.fsdp_version == 1:
        accelerator.state.fsdp_plugin.set_auto_wrap_policy(model)
        model = FSDP(
            model,
            sharding_strategy=fsdp_plugin.sharding_strategy or fsdp_plugin.reshard_after_forward,
            cpu_offload=fsdp_plugin.cpu_offload,
            auto_wrap_policy=fsdp_plugin.auto_wrap_policy,
            mixed_precision=fsdp_plugin.mixed_precision_policy,
            sync_module_states=fsdp_plugin.sync_module_states,
            backward_prefetch=fsdp_plugin.backward_prefetch,
            forward_prefetch=fsdp_plugin.forward_prefetch,
            use_orig_params=fsdp_plugin.use_orig_params,
            param_init_fn=fsdp_plugin.param_init_fn,
            ignored_modules=fsdp_plugin.ignored_modules,
            limit_all_gathers=fsdp_plugin.limit_all_gathers,
            device_id=accelerator.device,
        )
    elif fsdp_plugin.fsdp_version == 2:
        from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

        mesh = getattr(accelerator, "torch_device_mesh", None)
        has_ignored_params_helper = Version(accelerate.__version__) >= Version("1.11.0")
        if has_ignored_params_helper:
            ignored_params = get_parameters_from_modules(fsdp_plugin.ignored_modules, model, accelerator.device)
        else:
            warnings.warn(
                "FSDP version 2 is being used with accelerate version < 1.11.0, which may lead to incorrect "
                "handling of ignored modules. Please upgrade accelerate to v1.11.0 or later for proper support."
            )
            ignored_params = None
        fully_shard(
            model,
            reshard_after_forward=fsdp_plugin.reshard_after_forward,
            offload_policy=fsdp_plugin.cpu_offload,
            # Fall back to a default policy rather than passing `None` for `mp_policy`.
            mp_policy=fsdp_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
            mesh=None if mesh is None else mesh[tuple(accelerator.parallelism_config.fsdp_dim_names)],
            ignored_params=ignored_params,
        )
    else:
        raise ValueError(f"FSDP version {fsdp_plugin.fsdp_version} is not supported.")
    model.eval()
    return model
|
|
|
|
|
class _ForwardRedirection:
    """Implements the `forward-redirection`.

    Taken from Pytorch-lightning:
    https://github.com/Lightning-AI/pytorch-lightning/blob/02311d03fb982560246eead7c08104481fac9579/src/lightning/pytorch/strategies/strategy.py#L602

    A method call to a wrapped module gets rerouted through the wrapper's `forward` method instead. Whatever the
    wrapper does around its call to the inner module's `forward` therefore also happens around `method`.
    """

    def __call__(
        self, wrapper_module: nn.Module, original_module: nn.Module, method: Callable, *args: Any, **kwargs: Any
    ):
        """Reroutes a method call through the `wrapper_module`'s `forward` method.

        Args:
            wrapper_module: The module that has `original_module` wrapped.
            original_module: The module that was wrapped inside `wrapper_module`.
            method: The method that should be called on the `original_module` after inputs get
                redirected through the `wrapper_module`'s `forward` method.
            *args: The positional arguments to the `method`. They will get passed to a patched
                `forward` method instead.
            **kwargs: The keyword arguments to the `method`. They will get passed to a patched
                `forward` method instead.

        Returns:
            Whatever `wrapper_module(*args, **kwargs)` returns.

        `original_module.forward` is always restored before this call returns, including when the wrapper raises
        or never reaches the inner `forward`.
        """
        original_forward = original_module.forward

        def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any:
            # Unpatch the original forward first, so `method` and anything it calls see the real `forward`.
            original_module.forward = original_forward
            out = method(*_args, **_kwargs)
            self.on_after_inner_forward(wrapper_module, original_module)
            return out

        # Patch the inner module's forward so the wrapper's call to it lands in `method` instead.
        original_module.forward = wrapped_forward
        try:
            wrapper_output = wrapper_module(*args, **kwargs)
        finally:
            # Without this, an exception raised before the inner forward runs (or a wrapper that never calls it)
            # would leave `original_module.forward` permanently redirected to `method`.
            original_module.forward = original_forward
        self.on_after_outer_forward(wrapper_module, original_module)
        return wrapper_output

    def on_after_inner_forward(self, wrapper_module: nn.Module, original_module: nn.Module) -> None:
        """Hook for subclasses, run right after `method` returns inside the patched forward. No-op by default."""
        pass

    def on_after_outer_forward(self, wrapper_module: nn.Module, original_module: nn.Module) -> None:
        """Hook for subclasses, run after the wrapper's forward returns. No-op by default."""
        pass
|
|
|
|
|
@contextmanager
def disable_gradient_checkpointing(model: PreTrainedModel, gradient_checkpointing_kwargs: dict | None = None):
    """
    Context manager that switches gradient checkpointing off for the duration of the `with` block.

    If checkpointing was active on entry, `model.gradient_checkpointing_enable(gradient_checkpointing_kwargs)` is
    called on exit, also when the block raises. If it was inactive, the model is not touched at all.

    Args:
        model (`PreTrainedModel`):
            Model for which to temporarily disable gradient checkpointing.
        gradient_checkpointing_kwargs (`dict` or `None`, *optional*):
            Passed to `gradient_checkpointing_enable` when checkpointing is turned back on.
    """
    if not model.is_gradient_checkpointing:
        yield
        return

    model.gradient_checkpointing_disable()
    try:
        yield
    finally:
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs)
|
|
|
|
|
def create_reference_model(
    model: nn.Module, num_shared_layers: int | None = None, pattern: str | None = None
) -> nn.Module:
    """
    Deprecated alias of `trl.experimental.utils.create_reference_model`.

    Emits a `FutureWarning` (attributed to the caller via `stacklevel=2`) pointing at the new import path. It then
    delegates to the relocated function with the same arguments.

    Args:
        model (`nn.Module`):
            Forwarded unchanged.
        num_shared_layers (`int` or `None`, *optional*):
            Forwarded unchanged.
        pattern (`str` or `None`, *optional*):
            Forwarded unchanged.

    Returns:
        Whatever `trl.experimental.utils.create_reference_model` returns for these arguments.
    """
    deprecation_message = (
        "The `create_reference_model` function is now located in `trl.experimental.utils`. Please update your "
        "imports to `from trl.experimental.utils import create_reference_model`. This import path will be removed in "
        "TRL 1.0.0."
    )
    warnings.warn(deprecation_message, FutureWarning, stacklevel=2)
    return _create_reference_model(model, num_shared_layers=num_shared_layers, pattern=pattern)
|
|
|