File size: 17,890 Bytes
1fa3c6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import warnings
from collections.abc import Callable
from contextlib import contextmanager
from copy import deepcopy
from typing import TYPE_CHECKING, Any

import accelerate
import torch.nn as nn
import transformers
from accelerate import Accelerator
from packaging.version import Version
from torch.distributed.fsdp import FSDPModule
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from transformers import GenerationConfig, PreTrainedModel

from ..import_utils import suppress_experimental_warning


with suppress_experimental_warning():
    from ..experimental.utils import create_reference_model as _create_reference_model


if Version(accelerate.__version__) >= Version("1.11.0"):
    from accelerate.utils.fsdp_utils import get_parameters_from_modules

if TYPE_CHECKING:
    from deepspeed.runtime.engine import DeepSpeedEngine
    from torch.nn import Module
    from torch.nn.parallel.distributed import DistributedDataParallel


def remove_hooks(model: "DeepSpeedEngine") -> None:
    """
    Detach the forward/backward hooks registered on a DeepSpeed ZeRO-3 model.

    The hooks are looked up on `model.optimizer.parameter_offload` when that attribute exists, otherwise on
    `model.optimizer` itself. Every parameter reachable from the offload module has its `ds_active_sub_modules`
    cleared, all recorded hooks are removed, and both hook lists are reset to empty lists.

    Args:
        model (`DeepSpeedEngine`):
            Engine whose hooks should be removed. If it has no `optimizer` attribute, nothing happens.

    Raises:
        RuntimeError: If `model.optimizer` is `None`.
    """
    # Before the first training step, the model has no optimizer: nothing to detach.
    if not hasattr(model, "optimizer"):
        return

    optimizer = model.optimizer
    if optimizer is None:
        raise RuntimeError("The model optimizer is None, which is not yet supported.")
    offload = optimizer.parameter_offload if hasattr(optimizer, "parameter_offload") else optimizer

    for parameter in iter_params(offload.module, recurse=True):
        parameter.ds_active_sub_modules.clear()

    for forward_hook in offload.forward_hooks:
        forward_hook.remove()
    for backward_hook in offload.backward_hooks:
        backward_hook.remove()

    offload.forward_hooks = []
    offload.backward_hooks = []


def get_all_parameters(sub_module, recurse=False):
    """
    Chain a module's named parameters with its DeepSpeed external parameters.

    Args:
        sub_module:
            Object exposing `named_parameters(recurse=...)` and `ds_external_parameters()`.
        recurse (`bool`, *optional*, defaults to `False`):
            Forwarded to `named_parameters`.

    Returns:
        An `itertools.chain` yielding the items of `named_parameters` first, then those of
        `ds_external_parameters`.
    """
    own_parameters = sub_module.named_parameters(recurse=recurse)
    external_parameters = sub_module.ds_external_parameters()
    return itertools.chain(own_parameters, external_parameters)


def iter_params(module, recurse=False):
    """
    Return the parameters of `module` as a list, dropping their names.

    Args:
        module:
            Module passed to `get_all_parameters`.
        recurse (`bool`, *optional*, defaults to `False`):
            Forwarded to `get_all_parameters`.

    Returns:
        `list`: The second element of every `(name, parameter)` pair produced by `get_all_parameters`.
    """
    named_parameters = get_all_parameters(module, recurse)
    return [parameter for _name, parameter in named_parameters]


def add_hooks(model: "DeepSpeedEngine") -> None:
    """
    Re-register the optimizer hooks of a DeepSpeed ZeRO-3 model.

    The hooks are registered on `model.optimizer.parameter_offload` when that attribute exists, otherwise on
    `model.optimizer` itself.

    Args:
        model (`DeepSpeedEngine`):
            Engine whose hooks should be registered. If it has no `optimizer` attribute, nothing happens.

    Raises:
        RuntimeError: If `model.optimizer` is `None`.
    """
    import deepspeed

    # Before the first training step, the model has no optimizer: nothing to register.
    if not hasattr(model, "optimizer"):
        return

    optimizer = model.optimizer
    if optimizer is None:
        raise RuntimeError("The model optimizer is None, which is not yet supported.")
    offload = optimizer.parameter_offload if hasattr(optimizer, "parameter_offload") else optimizer

    # Account for renaming in https://github.com/deepspeedai/DeepSpeed/pull/6847
    has_renamed_api = Version(deepspeed.__version__) >= Version("0.16.4")
    if has_renamed_api:
        register = offload._register_deepspeed_module
    else:
        register = offload._register_hooks_recursively
    register(offload.module)


@contextmanager
def _unwrap_model_for_generation(
    model: "DistributedDataParallel | DeepSpeedEngine",
    accelerator: "Accelerator",
    gather_deepspeed3_params: bool = True,
):
    """
    Context manager to unwrap distributed or accelerated models for generation tasks.

    Gradient checkpointing is disabled for the duration of the context when the unwrapped model reports it as
    enabled. The previous state is restored on exit, including when the body of the `with` statement raises:
    gradient checkpointing is re-enabled and, on the DeepSpeed ZeRO-3 gathering path, the optimizer hooks are
    re-added.

    Args:
        model (`DistributedDataParallel | DeepSpeedEngine`):
            Model to be unwrapped.
        accelerator ([`~accelerate.Accelerator`]):
            Accelerator instance managing the model.
        gather_deepspeed3_params (`bool`, *optional*, defaults to `True`):
            Whether to gather weights for DeepSpeed ZeRO Stage 3 models. If `False`, skips parameter gathering, which
            can be more memory-efficient but may lead to slower generation times.

    Yields:
        Unwrapped model.

    Example:
    ```python
    with _unwrap_model_for_generation(model, accelerator) as unwrapped_model:
        generated_outputs = unwrapped_model.generate(input_ids)
    ```
    """
    unwrapped_model = accelerator.unwrap_model(model)
    # Models that do not expose `is_gradient_checkpointing` are treated as not using it.
    is_gradient_checkpointing = getattr(unwrapped_model, "is_gradient_checkpointing", False)
    if is_gradient_checkpointing:
        unwrapped_model.gradient_checkpointing_disable()
    try:
        deepspeed_plugin = accelerator.state.deepspeed_plugin
        if deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3:
            if not gather_deepspeed3_params:
                yield accelerator.unwrap_model(model)
            else:
                import deepspeed

                with deepspeed.zero.GatheredParameters(model.parameters()):
                    remove_hooks(model)
                    try:
                        yield accelerator.unwrap_model(model)
                    finally:
                        # Re-register the hooks even if generation failed, while the parameters are still gathered.
                        add_hooks(model)
        else:
            yield unwrapped_model
    finally:
        # Always restore gradient checkpointing so that a failed generation does not silently change training.
        if is_gradient_checkpointing:
            unwrapped_model.gradient_checkpointing_enable()


@contextmanager
def _override_model_generation_config(model, generation_kwargs=None):
    """
    Context manager to temporarily override a model's generation_config with training config.

    This works around transformers' config merging logic that would otherwise overwrite values matching global defaults
    with model-specific values (see upstream issue transformers#42762; fixed in transformers v5 by PR
    `transformers#42702`).

    Inside the context, the model's `generation_config` is a copy of its original one updated with
    `generation_kwargs`. The original object is put back on exit, even if the body raises, so saved/pushed models
    keep their intended inference behavior.

    The override is skipped, and `model` is yielded unchanged, when transformers is v5 or later, when
    `generation_kwargs` is `None`, or when `model` has no `generation_config` attribute. When the override is
    applied, the context yields `None`.

    Args:
        model: The model (typically unwrapped_model) whose generation_config to temporarily override.
        generation_kwargs (dict): Generation kwargs to be used to override model's generation config.
    """
    override_needed = (
        # Issue fixed in transformers v5 by PR transformers#42702
        Version(transformers.__version__) < Version("5.0.0")
        and generation_kwargs is not None
        and hasattr(model, "generation_config")
    )
    if not override_needed:
        yield model
        return

    # If it is a PEFT model, override the underlying base model
    target = model.get_base_model() if hasattr(model, "get_base_model") else model

    # Keep a handle on the original config so it can be restored afterwards
    saved_config = target.generation_config
    # Start from a copy of the original config, then apply the training-specific generation kwargs on top
    training_config = GenerationConfig.from_dict(saved_config.to_dict())
    training_config.update(**generation_kwargs)
    target.generation_config = training_config
    try:
        yield
    finally:
        target.generation_config = saved_config


@contextmanager
def unwrap_model_for_generation(
    model: "DistributedDataParallel | DeepSpeedEngine",
    accelerator: "Accelerator",
    gather_deepspeed3_params: bool = True,
    generation_kwargs: dict | None = None,
):
    """
    Context manager to unwrap distributed or accelerated models for generation tasks.

    It combines two steps: the model is unwrapped with `_unwrap_model_for_generation`, then its generation_config is
    optionally overridden with `_override_model_generation_config` for the duration of the context. This is useful
    for applying training-specific generation parameters without permanently modifying the model's original
    generation_config.

    Args:
        model (`DistributedDataParallel | DeepSpeedEngine`):
            Model to be unwrapped.
        accelerator ([`~accelerate.Accelerator`]):
            Accelerator instance managing the model.
        gather_deepspeed3_params (`bool`, *optional*, defaults to `True`):
            Whether to gather weights for DeepSpeed ZeRO Stage 3 models. If `False`, skips parameter gathering, which
            can be more memory-efficient but may lead to slower generation times.
        generation_kwargs (dict, *optional*):
            If provided, temporarily overrides the model's generation_config during generation. The original config is
            automatically restored when exiting the context. This is useful for using different generation parameters
            during training vs. inference.

    Yields:
        Unwrapped model with optionally overridden generation_config.
    """
    unwrap_context = _unwrap_model_for_generation(
        model, accelerator, gather_deepspeed3_params=gather_deepspeed3_params
    )
    with unwrap_context as unwrapped_model:
        # The override is entered second so that it is undone first, before the model is re-wrapped.
        with _override_model_generation_config(unwrapped_model, generation_kwargs=generation_kwargs):
            yield unwrapped_model


def prepare_deepspeed(model: "Module", accelerator: "Accelerator"):
    """Prepares the model for DeepSpeed inference or evaluation by initializing it with the appropriate configuration.

    Adapted from accelerate:
    https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473

    The DeepSpeed config of the accelerator's plugin is deep-copied, so the plugin's own config is never modified.
    With ZeRO stage 3 and a model whose config exposes `hidden_sizes` or `hidden_size`, the bucket sizes and the
    parameter persistence threshold are derived from the hidden size and written into the nested
    `zero_optimization` section. Only entries that are missing or set to `"auto"` are filled in; explicit values
    are kept. For any other stage, ZeRO is disabled (stage 0).

    Args:
        model (`Module`):
            Model to initialize with DeepSpeed.
        accelerator ([`~accelerate.Accelerator`]):
            Accelerator whose `state.deepspeed_plugin.deepspeed_config` provides the base configuration.

    Returns:
        The first element returned by `deepspeed.initialize`, put in evaluation mode.
    """
    import deepspeed  # local import (instead of top-level) to avoid DS init interfering with other backends (like vllm): https://github.com/deepspeedai/DeepSpeed/issues/7252

    deepspeed_plugin = accelerator.state.deepspeed_plugin
    config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
    zero_config = config_kwargs["zero_optimization"]
    stage = zero_config["stage"]

    # Not every module carries a `config` attribute; skip the hidden-size based tuning in that case.
    model_config = getattr(model, "config", None)
    if model_config is not None and stage == 3:
        hidden_sizes = getattr(model_config, "hidden_sizes", None)
        hidden_size = max(hidden_sizes) if hidden_sizes else getattr(model_config, "hidden_size", None)
        if hidden_size is not None:
            # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache
            # @ step 0: expected module 1, but got module 0`
            # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
            tuned_values = {
                "reduce_bucket_size": hidden_size * hidden_size,
                "stage3_param_persistence_threshold": 10 * hidden_size,
                # Cast to int: 0.9 * hidden_size**2 is generally fractional.
                "stage3_prefetch_bucket_size": int(0.9 * hidden_size * hidden_size),
            }
            # The config dict is handed to `deepspeed.initialize` as-is, so the values must live inside the nested
            # `zero_optimization` section; flat "zero_optimization.<key>" entries at the top level are not part of
            # that section.
            for key, value in tuned_values.items():
                if zero_config.get(key, "auto") == "auto":
                    zero_config[key] = value

    # If ZeRO-3 is used, we shard both the active and reference model.
    # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO
    # disabled (stage 0)
    if stage != 3:
        zero_config["stage"] = 0
    model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
    model.eval()
    return model


def prepare_fsdp(model, accelerator: Accelerator) -> FSDP | FSDPModule:
    """
    Wrap `model` with FSDP according to the accelerator's FSDP plugin and put it in evaluation mode.

    A model that is already an `FSDP` or `FSDPModule` instance is not wrapped again. Otherwise, FSDP version 1
    wraps the model in `FullyShardedDataParallel`, and FSDP version 2 applies `fully_shard` to it.

    Args:
        model:
            Model to wrap.
        accelerator ([`~accelerate.Accelerator`]):
            Accelerator whose `state.fsdp_plugin` provides the FSDP settings.

    Returns:
        The FSDP-wrapped model, in evaluation mode.

    Raises:
        ValueError: If the plugin's `fsdp_version` is neither 1 nor 2.
    """
    # Check if the model is already a FSDP model due to `Manual Wrapping` and if so, don't wrap it again
    if isinstance(model, (FSDP, FSDPModule)):
        model.eval()
        return model

    fsdp_plugin = accelerator.state.fsdp_plugin
    if fsdp_plugin.fsdp_version == 1:
        accelerator.state.fsdp_plugin.set_auto_wrap_policy(model)
        model = FSDP(
            model,
            sharding_strategy=fsdp_plugin.sharding_strategy or fsdp_plugin.reshard_after_forward,
            cpu_offload=fsdp_plugin.cpu_offload,
            auto_wrap_policy=fsdp_plugin.auto_wrap_policy,
            mixed_precision=fsdp_plugin.mixed_precision_policy,
            sync_module_states=fsdp_plugin.sync_module_states,
            backward_prefetch=fsdp_plugin.backward_prefetch,
            forward_prefetch=fsdp_plugin.forward_prefetch,
            use_orig_params=fsdp_plugin.use_orig_params,
            param_init_fn=fsdp_plugin.param_init_fn,
            ignored_modules=fsdp_plugin.ignored_modules,
            limit_all_gathers=fsdp_plugin.limit_all_gathers,
            device_id=accelerator.device,
        )
    elif fsdp_plugin.fsdp_version == 2:
        from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

        device_mesh = getattr(accelerator, "torch_device_mesh", None)
        if Version(accelerate.__version__) >= Version("1.11.0"):
            ignored_params = get_parameters_from_modules(fsdp_plugin.ignored_modules, model, accelerator.device)
        else:
            warnings.warn(
                "FSDP version 2 is being used with accelerate version < 1.11.0, which may lead to incorrect "
                "handling of ignored modules. Please upgrade accelerate to v1.11.0 or later for proper support."
            )
            ignored_params = None

        shard_kwargs = {
            "reshard_after_forward": fsdp_plugin.reshard_after_forward,
            "offload_policy": fsdp_plugin.cpu_offload,
            # `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
            "mp_policy": fsdp_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
            "mesh": (
                None
                if device_mesh is None
                else device_mesh[tuple(accelerator.parallelism_config.fsdp_dim_names)]
            ),
            "ignored_params": ignored_params,
        }
        # `fully_shard` is applied for its effect on `model`; its return value is not used.
        fully_shard(model, **shard_kwargs)
    else:
        raise ValueError(f"FSDP version {fsdp_plugin.fsdp_version} is not supported.")

    model.eval()
    return model


class _ForwardRedirection:
    """Implements the `forward-redirection`.

    Taken from Pytorch-lightning:
    https://github.com/Lightning-AI/pytorch-lightning/blob/02311d03fb982560246eead7c08104481fac9579/src/lightning/pytorch/strategies/strategy.py#L602

    A method call to a wrapped module gets rerouted through the wrapper's `forward` method instead.
    """

    def __call__(
        self, wrapper_module: nn.Module, original_module: nn.Module, method: Callable, *args: Any, **kwargs: Any
    ):
        """Reroutes a method call through the `wrapper_module`'s `forward` method.

        The `forward` of `original_module` is temporarily replaced by a function that restores the real `forward`
        and then calls `method`. If the wrapper raises before reaching that function, or never calls it, the real
        `forward` is restored anyway so that `original_module` is not left patched.

        Args:
            wrapper_module: The module that has `original_module` wrapped.
            original_module: The module that was wrapped inside `wrapper_module`.
            method: The method that should be called on the `original_module` after inputs get
                redirected through the `wrapper_module`'s `forward` method.
            *args: The positional arguments to the `method`. They will get passed to a patched
                `forward` method instead.
            **kwargs: The keyword arguments to the `method`. They will get passed to a patched
                `forward` method instead.

        Returns:
            Whatever calling `wrapper_module(*args, **kwargs)` returns.
        """
        original_forward = original_module.forward

        def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any:
            # Unpatch ourselves immediately before calling the method `method_name`
            # because itself may want to call the real `forward`
            original_module.forward = original_forward  # type: ignore[method-assign]
            # Call the actual method e.g. `.training_step(...)`
            out = method(*_args, **_kwargs)
            self.on_after_inner_forward(wrapper_module, original_module)
            return out

        # Patch the original_module's forward so we can redirect the arguments back to the real method
        original_module.forward = wrapped_forward  # type: ignore[method-assign]

        try:
            wrapper_output = wrapper_module(*args, **kwargs)
        finally:
            # Only undo our own patch: if `wrapped_forward` ran, the real `forward` is already back in place.
            if getattr(original_module, "forward", None) is wrapped_forward:
                original_module.forward = original_forward  # type: ignore[method-assign]
        self.on_after_outer_forward(wrapper_module, original_module)
        return wrapper_output

    def on_after_inner_forward(self, wrapper_module: nn.Module, original_module: nn.Module) -> None:
        """Hook called right after `method` returns inside the patched `forward`. Does nothing by default."""
        pass

    def on_after_outer_forward(self, wrapper_module: nn.Module, original_module: nn.Module) -> None:
        """Hook called after the wrapper's call returns successfully. Does nothing by default."""
        pass


@contextmanager
def disable_gradient_checkpointing(model: PreTrainedModel, gradient_checkpointing_kwargs: dict | None = None):
    """
    Temporarily disable gradient checkpointing, restoring the previous state afterward.

    If gradient checkpointing is not enabled on `model` when the context is entered, the model is left untouched.

    Args:
        model (`PreTrainedModel`):
            Model for which to temporarily disable gradient checkpointing.
        gradient_checkpointing_kwargs (`dict` or `None`, *optional*):
            Additional kwargs for gradient checkpointing enabling.
    """
    # Nothing to toggle: run the body as-is.
    if not model.is_gradient_checkpointing:
        yield
        return

    model.gradient_checkpointing_disable()
    try:
        yield
    finally:
        # Re-enable even if the body raised.
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs)


def create_reference_model(
    model: nn.Module, num_shared_layers: int | None = None, pattern: str | None = None
) -> nn.Module:
    """
    Deprecated alias that forwards to the `create_reference_model` imported from `..experimental.utils`.

    Emits a `FutureWarning` pointing to the new import path, then delegates with the same arguments.

    Args:
        model (`nn.Module`):
            Forwarded unchanged.
        num_shared_layers (`int` or `None`, *optional*):
            Forwarded unchanged.
        pattern (`str` or `None`, *optional*):
            Forwarded unchanged.

    Returns:
        `nn.Module`: The result of the forwarded call.
    """
    deprecation_message = (
        "The `create_reference_model` function is now located in `trl.experimental.utils`. Please update your "
        "imports to `from trl.experimental.utils import create_reference_model`. This import path will be removed in "
        "TRL 1.0.0."
    )
    # stacklevel=2 attributes the warning to the caller of this function rather than to this module.
    warnings.warn(deprecation_message, category=FutureWarning, stacklevel=2)
    return _create_reference_model(model, num_shared_layers=num_shared_layers, pattern=pattern)