# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import warnings
from collections.abc import Callable
from contextlib import contextmanager
from copy import deepcopy
from typing import TYPE_CHECKING, Any
import accelerate
import torch.nn as nn
import transformers
from accelerate import Accelerator
from packaging.version import Version
from torch.distributed.fsdp import FSDPModule
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from transformers import GenerationConfig, PreTrainedModel
from ..import_utils import suppress_experimental_warning
with suppress_experimental_warning():
from ..experimental.utils import create_reference_model as _create_reference_model
if Version(accelerate.__version__) >= Version("1.11.0"):
from accelerate.utils.fsdp_utils import get_parameters_from_modules
if TYPE_CHECKING:
from deepspeed.runtime.engine import DeepSpeedEngine
from torch.nn import Module
from torch.nn.parallel.distributed import DistributedDataParallel
def remove_hooks(model: "DeepSpeedEngine") -> None:
    """Detach the ZeRO-3 parameter-offload hooks from a DeepSpeed engine.

    The hook owner is `model.optimizer.parameter_offload` when the optimizer exposes that attribute, and the optimizer
    itself otherwise. Every parameter returned by `iter_params(owner.module, recurse=True)` gets its
    `ds_active_sub_modules` cleared. Then every handle in `owner.forward_hooks` and `owner.backward_hooks` is removed,
    and both lists are reset to empty.

    Args:
        model (`DeepSpeedEngine`):
            Engine whose hooks should be detached. If it has no `optimizer` attribute yet (before the first training
            step), nothing happens.

    Raises:
        RuntimeError: If `model.optimizer` is `None`.
    """
    # Before the first training step the engine has no optimizer, so there is nothing to detach.
    if not hasattr(model, "optimizer"):
        return

    optimizer = model.optimizer
    if optimizer is None:
        raise RuntimeError("The model optimizer is None, which is not yet supported.")
    offload = getattr(optimizer, "parameter_offload", optimizer)

    for param in iter_params(offload.module, recurse=True):
        param.ds_active_sub_modules.clear()

    for handle in offload.forward_hooks:
        handle.remove()
    for handle in offload.backward_hooks:
        handle.remove()

    offload.forward_hooks = []
    offload.backward_hooks = []
def get_all_parameters(sub_module, recurse=False):
    """Chain `sub_module`'s own named parameters with its DeepSpeed external parameters.

    Both `sub_module.named_parameters(recurse=recurse)` and `sub_module.ds_external_parameters()` are called right
    away. The returned iterator lazily yields the items of the first and then the items of the second. Presumably
    both yield `(name, parameter)` pairs, since that is how `iter_params` unpacks them.
    """
    own_params = sub_module.named_parameters(recurse=recurse)
    external_params = sub_module.ds_external_parameters()
    return itertools.chain(own_params, external_params)
def iter_params(module, recurse=False):
    """Return a list of the parameter objects from `get_all_parameters(module, recurse)`, dropping their names."""
    params = []
    for _name, param in get_all_parameters(module, recurse):
        params.append(param)
    return params
def add_hooks(model: "DeepSpeedEngine") -> None:
    """Re-register the ZeRO-3 parameter-offload hooks on a DeepSpeed engine.

    This is the counterpart of `remove_hooks`. The hook owner is `model.optimizer.parameter_offload` when the
    optimizer exposes that attribute, and the optimizer itself otherwise. Hooks are registered recursively on the
    owner's `module`, using whichever private registration method matches the installed DeepSpeed version.

    Args:
        model (`DeepSpeedEngine`):
            Engine whose hooks should be re-registered. If it has no `optimizer` attribute yet (before the first
            training step), nothing happens.

    Raises:
        RuntimeError: If `model.optimizer` is `None`.
    """
    # Before the first training step the engine has no optimizer, so there is nothing to re-hook.
    if not hasattr(model, "optimizer"):
        return
    if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
        optimizer_offload = model.optimizer.parameter_offload
    elif model.optimizer is not None:
        optimizer_offload = model.optimizer
    else:
        raise RuntimeError("The model optimizer is None, which is not yet supported.")

    # Import DeepSpeed only once we know hooks must actually be registered. This keeps the no-op and error paths
    # above free of DeepSpeed's import-time side effects.
    import deepspeed

    if Version(deepspeed.__version__) >= Version("0.16.4"):
        # Account for renaming in https://github.com/deepspeedai/DeepSpeed/pull/6847
        optimizer_offload._register_deepspeed_module(optimizer_offload.module)
    else:
        optimizer_offload._register_hooks_recursively(optimizer_offload.module)
@contextmanager
def _unwrap_model_for_generation(
model: "DistributedDataParallel | DeepSpeedEngine",
accelerator: "Accelerator",
gather_deepspeed3_params: bool = True,
):
"""
Context manager to unwrap distributed or accelerated models for generation tasks.
Args:
model (`DistributedDataParallel | DeepSpeedEngine`):
Model to be unwrapped.
accelerator ([`~accelerate.Accelerator`]):
Accelerator instance managing the model.
gather_deepspeed3_params (`bool`, *optional*, defaults to `True`):
Whether to gather weights for DeepSpeed ZeRO Stage 3 models. If `False`, skips parameter gathering, which
can be more memory-efficient but may lead to slower generation times.
Yields:
Unwrapped model.
Example:
```python
with _unwrap_model_for_generation(model, accelerator) as unwrapped_model:
generated_outputs = unwrapped_model.generate(input_ids)
```
"""
unwrapped_model = accelerator.unwrap_model(model)
is_gradient_checkpointing = unwrapped_model.is_gradient_checkpointing
if is_gradient_checkpointing:
unwrapped_model.gradient_checkpointing_disable()
if accelerator.state.deepspeed_plugin is not None and accelerator.state.deepspeed_plugin.zero_stage == 3:
if not gather_deepspeed3_params:
yield accelerator.unwrap_model(model)
else:
import deepspeed
with deepspeed.zero.GatheredParameters(model.parameters()):
remove_hooks(model)
yield accelerator.unwrap_model(model)
add_hooks(model)
else:
yield unwrapped_model
if is_gradient_checkpointing:
unwrapped_model.gradient_checkpointing_enable()
@contextmanager
def _override_model_generation_config(model, generation_kwargs=None):
    """
    Context manager to temporarily override a model's generation_config with training config.

    This works around transformers' config merging logic that would otherwise overwrite values matching global defaults
    with model-specific values (see upstream issue transformers#42762; fixed in transformers v5 by PR
    `transformers#42702`).

    By temporarily setting the model's generation_config to match the passed generation_config, we avoid the conflict.
    The model's original generation_config is restored when the context exits (also on error), ensuring that
    saved/pushed models retain their intended inference behavior.

    Nothing is overridden when transformers >= 5.0.0, when `generation_kwargs` is `None`, or when `model` has no
    `generation_config` attribute.

    Args:
        model: The model (typically unwrapped_model) whose generation_config to temporarily override. For PEFT models
            (anything with `get_base_model`), the override is applied to the underlying base model.
        generation_kwargs (dict): Generation kwargs to be used to override model's generation config.

    Yields:
        The `model` argument exactly as passed in, in every branch.
    """
    if (
        # Issue fixed in transformers v5 by PR transformers#42702
        Version(transformers.__version__) >= Version("5.0.0")
        or generation_kwargs is None
        or not hasattr(model, "generation_config")
    ):
        yield model
        return

    # If it is a PEFT model, override the underlying base model
    target = model.get_base_model() if hasattr(model, "get_base_model") else model

    # Build the training-specific config from a copy of the original, then apply the training-specific kwargs on top
    original_config = target.generation_config
    generation_config = GenerationConfig.from_dict(original_config.to_dict())
    generation_config.update(**generation_kwargs)
    target.generation_config = generation_config
    try:
        # Yield the same object as the pass-through branch above so callers get a consistent value
        yield model
    finally:
        target.generation_config = original_config
@contextmanager
def unwrap_model_for_generation(
    model: "DistributedDataParallel | DeepSpeedEngine",
    accelerator: "Accelerator",
    gather_deepspeed3_params: bool = True,
    generation_kwargs: dict | None = None,
):
    """
    Context manager that unwraps a distributed model for generation and optionally swaps in a generation config.

    The model is first unwrapped via `_unwrap_model_for_generation`, which handles DeepSpeed ZeRO-3 gathering and
    gradient-checkpointing toggling. The unwrapped model's generation_config is then overridden via
    `_override_model_generation_config`. Both are undone, in reverse order, when the context exits. The model's own
    generation_config is therefore never permanently modified.

    Args:
        model (`DistributedDataParallel | DeepSpeedEngine`):
            Model to be unwrapped.
        accelerator ([`~accelerate.Accelerator`]):
            Accelerator instance managing the model.
        gather_deepspeed3_params (`bool`, *optional*, defaults to `True`):
            Whether to gather weights for DeepSpeed ZeRO Stage 3 models. If `False`, skips parameter gathering, which
            can be more memory-efficient but may lead to slower generation times.
        generation_kwargs (dict, *optional*):
            If provided, these kwargs are applied on top of a copy of the model's generation_config for the duration
            of the context. The original config is put back on exit.

    Yields:
        Unwrapped model with optionally overridden generation_config.
    """
    unwrap_ctx = _unwrap_model_for_generation(model, accelerator, gather_deepspeed3_params=gather_deepspeed3_params)
    with unwrap_ctx as unwrapped_model:
        override_ctx = _override_model_generation_config(unwrapped_model, generation_kwargs=generation_kwargs)
        with override_ctx:
            yield unwrapped_model
def prepare_deepspeed(model: "Module", accelerator: "Accelerator"):
    """Prepares the model for DeepSpeed inference or evaluation by initializing it with the appropriate configuration.

    Adapted from accelerate:
    https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473

    Args:
        model (`torch.nn.Module`):
            Model to initialize with DeepSpeed (typically a frozen reference model).
        accelerator ([`~accelerate.Accelerator`]):
            Accelerator whose DeepSpeed plugin config serves as the template. The config is deep-copied, so the
            plugin's own config is never mutated here.

    Returns:
        The engine returned by `deepspeed.initialize`, switched to eval mode.
    """
    import deepspeed  # local import (instead of top-level) to avoid DS init interfering with other backends (like vllm): https://github.com/deepspeedai/DeepSpeed/issues/7252

    deepspeed_plugin = accelerator.state.deepspeed_plugin
    config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
    zero_config = config_kwargs["zero_optimization"]
    stage = zero_config["stage"]

    # Models without a `config` (e.g. plain nn.Modules) simply skip the size-based tuning below.
    model_config = getattr(model, "config", None)
    if model_config is not None and stage == 3:
        hidden_size = (
            max(model_config.hidden_sizes)
            if getattr(model_config, "hidden_sizes", None)
            else getattr(model_config, "hidden_size", None)
        )
        if hidden_size is not None:
            # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache
            # @ step 0: expected module 1, but got module 0`
            # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
            #
            # Dotted names like "zero_optimization.reduce_bucket_size" are accelerate's lookup syntax. Stored as
            # literal top-level keys they never reach the nested `zero_optimization` section, so write into that
            # section directly. As in accelerate, only unresolved "auto" placeholders are filled; explicit user
            # values are kept.
            size_heuristics = {
                "reduce_bucket_size": hidden_size * hidden_size,
                "stage3_param_persistence_threshold": 10 * hidden_size,
                # Cast to int: a fractional float is not a valid bucket size
                "stage3_prefetch_bucket_size": int(0.9 * hidden_size * hidden_size),
            }
            for key, value in size_heuristics.items():
                if zero_config.get(key) == "auto":
                    zero_config[key] = value

    # If ZeRO-3 is used, we shard both the active and reference model.
    # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO
    # disabled (stage 0)
    if stage != 3:
        zero_config["stage"] = 0
    model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
    model.eval()
    return model
def prepare_fsdp(model, accelerator: Accelerator) -> FSDP | FSDPModule:
    """Shard `model` with FSDP according to the accelerator's FSDP plugin and switch it to eval mode.

    A model that is already an `FSDP` or `FSDPModule` instance (e.g. due to manual wrapping) is not wrapped again.
    With `fsdp_version == 1` the model is wrapped in `FullyShardedDataParallel` using the plugin's settings. With
    `fsdp_version == 2`, `fully_shard` is applied once to `model`.

    Raises:
        ValueError: If the plugin's `fsdp_version` is neither 1 nor 2.
    """
    # Already sharded (e.g. `Manual Wrapping`): don't wrap it again, just switch to eval mode.
    if isinstance(model, (FSDP, FSDPModule)):
        model.eval()
        return model

    fsdp_plugin = accelerator.state.fsdp_plugin
    version = fsdp_plugin.fsdp_version

    if version == 1:
        # The auto-wrap policy must be set before it is read into the FSDP constructor below.
        accelerator.state.fsdp_plugin.set_auto_wrap_policy(model)
        model = FSDP(
            model,
            sharding_strategy=fsdp_plugin.sharding_strategy or fsdp_plugin.reshard_after_forward,
            cpu_offload=fsdp_plugin.cpu_offload,
            auto_wrap_policy=fsdp_plugin.auto_wrap_policy,
            mixed_precision=fsdp_plugin.mixed_precision_policy,
            sync_module_states=fsdp_plugin.sync_module_states,
            backward_prefetch=fsdp_plugin.backward_prefetch,
            forward_prefetch=fsdp_plugin.forward_prefetch,
            use_orig_params=fsdp_plugin.use_orig_params,
            param_init_fn=fsdp_plugin.param_init_fn,
            ignored_modules=fsdp_plugin.ignored_modules,
            limit_all_gathers=fsdp_plugin.limit_all_gathers,
            device_id=accelerator.device,
        )
    elif version == 2:
        from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

        full_mesh = getattr(accelerator, "torch_device_mesh", None)
        supports_ignored_params = Version(accelerate.__version__) >= Version("1.11.0")
        if supports_ignored_params:
            ignored_params = get_parameters_from_modules(fsdp_plugin.ignored_modules, model, accelerator.device)
        else:
            warnings.warn(
                "FSDP version 2 is being used with accelerate version < 1.11.0, which may lead to incorrect "
                "handling of ignored modules. Please upgrade accelerate to v1.11.0 or later for proper support."
            )
            ignored_params = None
        fully_shard(
            model,
            reshard_after_forward=fsdp_plugin.reshard_after_forward,
            offload_policy=fsdp_plugin.cpu_offload,
            # `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
            mp_policy=fsdp_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
            # Restrict the accelerator's mesh to the FSDP dimensions, when a mesh exists at all.
            mesh=None if full_mesh is None else full_mesh[tuple(accelerator.parallelism_config.fsdp_dim_names)],
            ignored_params=ignored_params,
        )
    else:
        raise ValueError(f"FSDP version {fsdp_plugin.fsdp_version} is not supported.")

    model.eval()
    return model
class _ForwardRedirection:
"""Implements the `forward-redirection`.
Taken from Pytorch-lightning:
https://github.com/Lightning-AI/pytorch-lightning/blob/02311d03fb982560246eead7c08104481fac9579/src/lightning/pytorch/strategies/strategy.py#L602
A method call to a wrapped module gets rerouted through the wrapper's `forward` method instead.
"""
def __call__(
self, wrapper_module: nn.Module, original_module: nn.Module, method: Callable, *args: Any, **kwargs: Any
):
"""Reroutes a method call through the `wrapper_module`'s `forward` method.
Args:
wrapper_module: The module that has `original_module` wrapped.
original_module: The module that was wrapped inside `wrapper_module`.
method: The method that should be called on the `original_module` after inputs get
redirected through the `wrapper_module`'s `forward` method.
*args: The positional arguments to the `method`. They will get passed to a patched
`forward` method instead.
**kwargs: The keyword arguments to the `method`. They will get passed to a patched
`forward` method instead.
"""
original_forward = original_module.forward
def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any:
# Unpatch ourselves immediately before calling the method `method_name`
# because itself may want to call the real `forward`
original_module.forward = original_forward # type: ignore[method-assign]
# Call the actual method e.g. `.training_step(...)`
out = method(*_args, **_kwargs)
self.on_after_inner_forward(wrapper_module, original_module)
return out
# Patch the original_module's forward so we can redirect the arguments back to the real method
original_module.forward = wrapped_forward # type: ignore[method-assign]
wrapper_output = wrapper_module(*args, **kwargs)
self.on_after_outer_forward(wrapper_module, original_module)
return wrapper_output
def on_after_inner_forward(self, wrapper_module: nn.Module, original_module: nn.Module) -> None:
pass
def on_after_outer_forward(self, wrapper_module: nn.Module, original_module: nn.Module) -> None:
pass
@contextmanager
def disable_gradient_checkpointing(model: PreTrainedModel, gradient_checkpointing_kwargs: dict | None = None):
    """
    Context manager that switches gradient checkpointing off for the duration of the block.

    If `model.is_gradient_checkpointing` is true on entry, checkpointing is disabled. It is re-enabled on exit, even
    when the body raises, via `model.gradient_checkpointing_enable(gradient_checkpointing_kwargs)`. If it is false on
    entry, the model is left untouched.

    Args:
        model (`PreTrainedModel`):
            Model whose gradient checkpointing is temporarily turned off.
        gradient_checkpointing_kwargs (`dict` or `None`, *optional*):
            Passed to `gradient_checkpointing_enable` when checkpointing is turned back on.
    """
    if not model.is_gradient_checkpointing:
        # Nothing to toggle; just run the body.
        yield
        return

    model.gradient_checkpointing_disable()
    try:
        yield
    finally:
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs)
def create_reference_model(
    model: nn.Module, num_shared_layers: int | None = None, pattern: str | None = None
) -> nn.Module:
    """Deprecated alias that forwards to the experimental `create_reference_model`.

    Emits a `FutureWarning` pointing callers to `trl.experimental.utils`. It then returns the result of the
    experimental implementation, called with the same arguments.
    """
    deprecation_message = (
        "The `create_reference_model` function is now located in `trl.experimental.utils`. Please update your "
        "imports to `from trl.experimental.utils import create_reference_model`. This import path will be removed in "
        "TRL 1.0.0."
    )
    # stacklevel=2 attributes the warning to the caller's line rather than to this shim.
    warnings.warn(deprecation_message, FutureWarning, stacklevel=2)
    return _create_reference_model(model, num_shared_layers=num_shared_layers, pattern=pattern)
|