Add files using upload-large-folder tool
Browse files- pythonProject/.venv/Lib/site-packages/diffusers/guiders/__init__.py +41 -0
- pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/__init__.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/adaptive_projected_guidance.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/auto_guidance.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/classifier_free_guidance.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/classifier_free_zero_star_guidance.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/frequency_decoupled_guidance.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/guider_utils.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/perturbed_attention_guidance.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/skip_layer_guidance.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/smoothed_energy_guidance.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/tangential_classifier_free_guidance.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/guiders/guider_utils.py +315 -0
- pythonProject/.venv/Lib/site-packages/diffusers/guiders/perturbed_attention_guidance.py +271 -0
- pythonProject/.venv/Lib/site-packages/diffusers/guiders/skip_layer_guidance.py +262 -0
- pythonProject/.venv/Lib/site-packages/diffusers/guiders/smoothed_energy_guidance.py +251 -0
- pythonProject/.venv/Lib/site-packages/diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- pythonProject/.venv/Lib/site-packages/diffusers/hooks/faster_cache.py +654 -0
- pythonProject/.venv/Lib/site-packages/diffusers/hooks/first_block_cache.py +259 -0
- pythonProject/.venv/Lib/site-packages/diffusers/hooks/group_offloading.py +898 -0
pythonProject/.venv/Lib/site-packages/diffusers/guiders/__init__.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from typing import Union
|
| 16 |
+
|
| 17 |
+
from ..utils import is_torch_available
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if is_torch_available():
|
| 21 |
+
from .adaptive_projected_guidance import AdaptiveProjectedGuidance
|
| 22 |
+
from .auto_guidance import AutoGuidance
|
| 23 |
+
from .classifier_free_guidance import ClassifierFreeGuidance
|
| 24 |
+
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
|
| 25 |
+
from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
|
| 26 |
+
from .perturbed_attention_guidance import PerturbedAttentionGuidance
|
| 27 |
+
from .skip_layer_guidance import SkipLayerGuidance
|
| 28 |
+
from .smoothed_energy_guidance import SmoothedEnergyGuidance
|
| 29 |
+
from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance
|
| 30 |
+
|
| 31 |
+
GuiderType = Union[
|
| 32 |
+
AdaptiveProjectedGuidance,
|
| 33 |
+
AutoGuidance,
|
| 34 |
+
ClassifierFreeGuidance,
|
| 35 |
+
ClassifierFreeZeroStarGuidance,
|
| 36 |
+
FrequencyDecoupledGuidance,
|
| 37 |
+
PerturbedAttentionGuidance,
|
| 38 |
+
SkipLayerGuidance,
|
| 39 |
+
SmoothedEnergyGuidance,
|
| 40 |
+
TangentialClassifierFreeGuidance,
|
| 41 |
+
]
|
pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (1.01 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/adaptive_projected_guidance.cpython-310.pyc
ADDED
|
Binary file (6.65 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/auto_guidance.cpython-310.pyc
ADDED
|
Binary file (7.43 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/classifier_free_guidance.cpython-310.pyc
ADDED
|
Binary file (5.93 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/classifier_free_zero_star_guidance.cpython-310.pyc
ADDED
|
Binary file (5.81 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/frequency_decoupled_guidance.cpython-310.pyc
ADDED
|
Binary file (12.3 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/guider_utils.cpython-310.pyc
ADDED
|
Binary file (14.9 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/perturbed_attention_guidance.cpython-310.pyc
ADDED
|
Binary file (9.77 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/skip_layer_guidance.cpython-310.pyc
ADDED
|
Binary file (10.3 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/smoothed_energy_guidance.cpython-310.pyc
ADDED
|
Binary file (9.95 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/tangential_classifier_free_guidance.cpython-310.pyc
ADDED
|
Binary file (5.16 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/guiders/guider_utils.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
from huggingface_hub.utils import validate_hf_hub_args
|
| 20 |
+
from typing_extensions import Self
|
| 21 |
+
|
| 22 |
+
from ..configuration_utils import ConfigMixin
|
| 23 |
+
from ..utils import BaseOutput, PushToHubMixin, get_logger
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
if TYPE_CHECKING:
|
| 27 |
+
from ..modular_pipelines.modular_pipeline import BlockState
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
GUIDER_CONFIG_NAME = "guider_config.json"
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
logger = get_logger(__name__) # pylint: disable=invalid-name
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class BaseGuidance(ConfigMixin, PushToHubMixin):
|
| 37 |
+
r"""Base class providing the skeleton for implementing guidance techniques."""
|
| 38 |
+
|
| 39 |
+
config_name = GUIDER_CONFIG_NAME
|
| 40 |
+
_input_predictions = None
|
| 41 |
+
_identifier_key = "__guidance_identifier__"
|
| 42 |
+
|
| 43 |
+
def __init__(self, start: float = 0.0, stop: float = 1.0):
|
| 44 |
+
self._start = start
|
| 45 |
+
self._stop = stop
|
| 46 |
+
self._step: int = None
|
| 47 |
+
self._num_inference_steps: int = None
|
| 48 |
+
self._timestep: torch.LongTensor = None
|
| 49 |
+
self._count_prepared = 0
|
| 50 |
+
self._input_fields: Dict[str, Union[str, Tuple[str, str]]] = None
|
| 51 |
+
self._enabled = True
|
| 52 |
+
|
| 53 |
+
if not (0.0 <= start < 1.0):
|
| 54 |
+
raise ValueError(f"Expected `start` to be between 0.0 and 1.0, but got {start}.")
|
| 55 |
+
if not (start <= stop <= 1.0):
|
| 56 |
+
raise ValueError(f"Expected `stop` to be between {start} and 1.0, but got {stop}.")
|
| 57 |
+
|
| 58 |
+
if self._input_predictions is None or not isinstance(self._input_predictions, list):
|
| 59 |
+
raise ValueError(
|
| 60 |
+
"`_input_predictions` must be a list of required prediction names for the guidance technique."
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
def disable(self):
|
| 64 |
+
self._enabled = False
|
| 65 |
+
|
| 66 |
+
def enable(self):
|
| 67 |
+
self._enabled = True
|
| 68 |
+
|
| 69 |
+
def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
|
| 70 |
+
self._step = step
|
| 71 |
+
self._num_inference_steps = num_inference_steps
|
| 72 |
+
self._timestep = timestep
|
| 73 |
+
self._count_prepared = 0
|
| 74 |
+
|
| 75 |
+
def set_input_fields(self, **kwargs: Dict[str, Union[str, Tuple[str, str]]]) -> None:
|
| 76 |
+
"""
|
| 77 |
+
Set the input fields for the guidance technique. The input fields are used to specify the names of the returned
|
| 78 |
+
attributes containing the prepared data after `prepare_inputs` is called. The prepared data is obtained from
|
| 79 |
+
the values of the provided keyword arguments to this method.
|
| 80 |
+
|
| 81 |
+
Args:
|
| 82 |
+
**kwargs (`Dict[str, Union[str, Tuple[str, str]]]`):
|
| 83 |
+
A dictionary where the keys are the names of the fields that will be used to store the data once it is
|
| 84 |
+
prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
|
| 85 |
+
to look up the required data provided for preparation.
|
| 86 |
+
|
| 87 |
+
If a string is provided, it will be used as the conditional data (or unconditional if used with a
|
| 88 |
+
guidance method that requires it). If a tuple of length 2 is provided, the first element must be the
|
| 89 |
+
conditional data identifier and the second element must be the unconditional data identifier or None.
|
| 90 |
+
|
| 91 |
+
Example:
|
| 92 |
+
```
|
| 93 |
+
data = {"prompt_embeds": <some tensor>, "negative_prompt_embeds": <some tensor>, "latents": <some tensor>}
|
| 94 |
+
|
| 95 |
+
BaseGuidance.set_input_fields(
|
| 96 |
+
latents="latents",
|
| 97 |
+
prompt_embeds=("prompt_embeds", "negative_prompt_embeds"),
|
| 98 |
+
)
|
| 99 |
+
```
|
| 100 |
+
"""
|
| 101 |
+
for key, value in kwargs.items():
|
| 102 |
+
is_string = isinstance(value, str)
|
| 103 |
+
is_tuple_of_str_with_len_2 = (
|
| 104 |
+
isinstance(value, tuple) and len(value) == 2 and all(isinstance(v, str) for v in value)
|
| 105 |
+
)
|
| 106 |
+
if not (is_string or is_tuple_of_str_with_len_2):
|
| 107 |
+
raise ValueError(
|
| 108 |
+
f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}."
|
| 109 |
+
)
|
| 110 |
+
self._input_fields = kwargs
|
| 111 |
+
|
| 112 |
+
def prepare_models(self, denoiser: torch.nn.Module) -> None:
|
| 113 |
+
"""
|
| 114 |
+
Prepares the models for the guidance technique on a given batch of data. This method should be overridden in
|
| 115 |
+
subclasses to implement specific model preparation logic.
|
| 116 |
+
"""
|
| 117 |
+
self._count_prepared += 1
|
| 118 |
+
|
| 119 |
+
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
|
| 120 |
+
"""
|
| 121 |
+
Cleans up the models for the guidance technique after a given batch of data. This method should be overridden
|
| 122 |
+
in subclasses to implement specific model cleanup logic. It is useful for removing any hooks or other stateful
|
| 123 |
+
modifications made during `prepare_models`.
|
| 124 |
+
"""
|
| 125 |
+
pass
|
| 126 |
+
|
| 127 |
+
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
| 128 |
+
raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")
|
| 129 |
+
|
| 130 |
+
def __call__(self, data: List["BlockState"]) -> Any:
|
| 131 |
+
if not all(hasattr(d, "noise_pred") for d in data):
|
| 132 |
+
raise ValueError("Expected all data to have `noise_pred` attribute.")
|
| 133 |
+
if len(data) != self.num_conditions:
|
| 134 |
+
raise ValueError(
|
| 135 |
+
f"Expected {self.num_conditions} data items, but got {len(data)}. Please check the input data."
|
| 136 |
+
)
|
| 137 |
+
forward_inputs = {getattr(d, self._identifier_key): d.noise_pred for d in data}
|
| 138 |
+
return self.forward(**forward_inputs)
|
| 139 |
+
|
| 140 |
+
def forward(self, *args, **kwargs) -> Any:
|
| 141 |
+
raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.")
|
| 142 |
+
|
| 143 |
+
@property
|
| 144 |
+
def is_conditional(self) -> bool:
|
| 145 |
+
raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.")
|
| 146 |
+
|
| 147 |
+
@property
|
| 148 |
+
def is_unconditional(self) -> bool:
|
| 149 |
+
return not self.is_conditional
|
| 150 |
+
|
| 151 |
+
@property
|
| 152 |
+
def num_conditions(self) -> int:
|
| 153 |
+
raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.")
|
| 154 |
+
|
| 155 |
+
@classmethod
|
| 156 |
+
def _prepare_batch(
|
| 157 |
+
cls,
|
| 158 |
+
input_fields: Dict[str, Union[str, Tuple[str, str]]],
|
| 159 |
+
data: "BlockState",
|
| 160 |
+
tuple_index: int,
|
| 161 |
+
identifier: str,
|
| 162 |
+
) -> "BlockState":
|
| 163 |
+
"""
|
| 164 |
+
Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of the
|
| 165 |
+
`BaseGuidance` class. It prepares the batch based on the provided tuple index.
|
| 166 |
+
|
| 167 |
+
Args:
|
| 168 |
+
input_fields (`Dict[str, Union[str, Tuple[str, str]]]`):
|
| 169 |
+
A dictionary where the keys are the names of the fields that will be used to store the data once it is
|
| 170 |
+
prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
|
| 171 |
+
to look up the required data provided for preparation. If a string is provided, it will be used as the
|
| 172 |
+
conditional data (or unconditional if used with a guidance method that requires it). If a tuple of
|
| 173 |
+
length 2 is provided, the first element must be the conditional data identifier and the second element
|
| 174 |
+
must be the unconditional data identifier or None.
|
| 175 |
+
data (`BlockState`):
|
| 176 |
+
The input data to be prepared.
|
| 177 |
+
tuple_index (`int`):
|
| 178 |
+
The index to use when accessing input fields that are tuples.
|
| 179 |
+
|
| 180 |
+
Returns:
|
| 181 |
+
`BlockState`: The prepared batch of data.
|
| 182 |
+
"""
|
| 183 |
+
from ..modular_pipelines.modular_pipeline import BlockState
|
| 184 |
+
|
| 185 |
+
if input_fields is None:
|
| 186 |
+
raise ValueError(
|
| 187 |
+
"Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs."
|
| 188 |
+
)
|
| 189 |
+
data_batch = {}
|
| 190 |
+
for key, value in input_fields.items():
|
| 191 |
+
try:
|
| 192 |
+
if isinstance(value, str):
|
| 193 |
+
data_batch[key] = getattr(data, value)
|
| 194 |
+
elif isinstance(value, tuple):
|
| 195 |
+
data_batch[key] = getattr(data, value[tuple_index])
|
| 196 |
+
else:
|
| 197 |
+
# We've already checked that value is a string or a tuple of strings with length 2
|
| 198 |
+
pass
|
| 199 |
+
except AttributeError:
|
| 200 |
+
logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
|
| 201 |
+
data_batch[cls._identifier_key] = identifier
|
| 202 |
+
return BlockState(**data_batch)
|
| 203 |
+
|
| 204 |
+
@classmethod
|
| 205 |
+
@validate_hf_hub_args
|
| 206 |
+
def from_pretrained(
|
| 207 |
+
cls,
|
| 208 |
+
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
|
| 209 |
+
subfolder: Optional[str] = None,
|
| 210 |
+
return_unused_kwargs=False,
|
| 211 |
+
**kwargs,
|
| 212 |
+
) -> Self:
|
| 213 |
+
r"""
|
| 214 |
+
Instantiate a guider from a pre-defined JSON configuration file in a local directory or Hub repository.
|
| 215 |
+
|
| 216 |
+
Parameters:
|
| 217 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
| 218 |
+
Can be either:
|
| 219 |
+
|
| 220 |
+
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
| 221 |
+
the Hub.
|
| 222 |
+
- A path to a *directory* (for example `./my_model_directory`) containing the guider configuration
|
| 223 |
+
saved with [`~BaseGuidance.save_pretrained`].
|
| 224 |
+
subfolder (`str`, *optional*):
|
| 225 |
+
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
| 226 |
+
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
|
| 227 |
+
Whether kwargs that are not consumed by the Python class should be returned or not.
|
| 228 |
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
| 229 |
+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
| 230 |
+
is not used.
|
| 231 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
| 232 |
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
| 233 |
+
cached versions if they exist.
|
| 234 |
+
|
| 235 |
+
proxies (`Dict[str, str]`, *optional*):
|
| 236 |
+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
| 237 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
| 238 |
+
output_loading_info(`bool`, *optional*, defaults to `False`):
|
| 239 |
+
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
| 240 |
+
local_files_only(`bool`, *optional*, defaults to `False`):
|
| 241 |
+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
| 242 |
+
won't be downloaded from the Hub.
|
| 243 |
+
token (`str` or *bool*, *optional*):
|
| 244 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
| 245 |
+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
| 246 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
| 247 |
+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
| 248 |
+
allowed by Git.
|
| 249 |
+
|
| 250 |
+
<Tip>
|
| 251 |
+
|
| 252 |
+
To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with `hf
|
| 253 |
+
auth login`. You can also activate the special
|
| 254 |
+
["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
|
| 255 |
+
firewalled environment.
|
| 256 |
+
|
| 257 |
+
</Tip>
|
| 258 |
+
|
| 259 |
+
"""
|
| 260 |
+
config, kwargs, commit_hash = cls.load_config(
|
| 261 |
+
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
| 262 |
+
subfolder=subfolder,
|
| 263 |
+
return_unused_kwargs=True,
|
| 264 |
+
return_commit_hash=True,
|
| 265 |
+
**kwargs,
|
| 266 |
+
)
|
| 267 |
+
return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)
|
| 268 |
+
|
| 269 |
+
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
|
| 270 |
+
"""
|
| 271 |
+
Save a guider configuration object to a directory so that it can be reloaded using the
|
| 272 |
+
[`~BaseGuidance.from_pretrained`] class method.
|
| 273 |
+
|
| 274 |
+
Args:
|
| 275 |
+
save_directory (`str` or `os.PathLike`):
|
| 276 |
+
Directory where the configuration JSON file will be saved (will be created if it does not exist).
|
| 277 |
+
push_to_hub (`bool`, *optional*, defaults to `False`):
|
| 278 |
+
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
|
| 279 |
+
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
| 280 |
+
namespace).
|
| 281 |
+
kwargs (`Dict[str, Any]`, *optional*):
|
| 282 |
+
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
| 283 |
+
"""
|
| 284 |
+
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
class GuiderOutput(BaseOutput):
|
| 288 |
+
pred: torch.Tensor
|
| 289 |
+
pred_cond: Optional[torch.Tensor]
|
| 290 |
+
pred_uncond: Optional[torch.Tensor]
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
| 294 |
+
r"""
|
| 295 |
+
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
|
| 296 |
+
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
| 297 |
+
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
| 298 |
+
|
| 299 |
+
Args:
|
| 300 |
+
noise_cfg (`torch.Tensor`):
|
| 301 |
+
The predicted noise tensor for the guided diffusion process.
|
| 302 |
+
noise_pred_text (`torch.Tensor`):
|
| 303 |
+
The predicted noise tensor for the text-guided diffusion process.
|
| 304 |
+
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
| 305 |
+
A rescale factor applied to the noise predictions.
|
| 306 |
+
Returns:
|
| 307 |
+
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
|
| 308 |
+
"""
|
| 309 |
+
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
| 310 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
| 311 |
+
# rescale the results from guidance (fixes overexposure)
|
| 312 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
| 313 |
+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
| 314 |
+
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
| 315 |
+
return noise_cfg
|
pythonProject/.venv/Lib/site-packages/diffusers/guiders/perturbed_attention_guidance.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from ..configuration_utils import register_to_config
|
| 21 |
+
from ..hooks import HookRegistry, LayerSkipConfig
|
| 22 |
+
from ..hooks.layer_skip import _apply_layer_skip_hook
|
| 23 |
+
from ..utils import get_logger
|
| 24 |
+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
if TYPE_CHECKING:
|
| 28 |
+
from ..modular_pipelines.modular_pipeline import BlockState
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
logger = get_logger(__name__) # pylint: disable=invalid-name
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class PerturbedAttentionGuidance(BaseGuidance):
|
| 35 |
+
"""
|
| 36 |
+
Perturbed Attention Guidance (PAG): https://huggingface.co/papers/2403.17377
|
| 37 |
+
|
| 38 |
+
The intution behind PAG can be thought of as moving the CFG predicted distribution estimates further away from
|
| 39 |
+
worse versions of the conditional distribution estimates. PAG was one of the first techniques to introduce the idea
|
| 40 |
+
of using a worse version of the trained model for better guiding itself in the denoising process. It perturbs the
|
| 41 |
+
attention scores of the latent stream by replacing the score matrix with an identity matrix for selectively chosen
|
| 42 |
+
layers.
|
| 43 |
+
|
| 44 |
+
Additional reading:
|
| 45 |
+
- [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)
|
| 46 |
+
|
| 47 |
+
PAG is implemented with similar implementation to SkipLayerGuidance due to overlap in the configuration parameters
|
| 48 |
+
and implementation details.
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
guidance_scale (`float`, defaults to `7.5`):
|
| 52 |
+
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
| 53 |
+
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
| 54 |
+
deterioration of image quality.
|
| 55 |
+
perturbed_guidance_scale (`float`, defaults to `2.8`):
|
| 56 |
+
The scale parameter for perturbed attention guidance.
|
| 57 |
+
perturbed_guidance_start (`float`, defaults to `0.01`):
|
| 58 |
+
The fraction of the total number of denoising steps after which perturbed attention guidance starts.
|
| 59 |
+
perturbed_guidance_stop (`float`, defaults to `0.2`):
|
| 60 |
+
The fraction of the total number of denoising steps after which perturbed attention guidance stops.
|
| 61 |
+
perturbed_guidance_layers (`int` or `List[int]`, *optional*):
|
| 62 |
+
The layer indices to apply perturbed attention guidance to. Can be a single integer or a list of integers.
|
| 63 |
+
If not provided, `perturbed_guidance_config` must be provided.
|
| 64 |
+
perturbed_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
|
| 65 |
+
The configuration for the perturbed attention guidance. Can be a single `LayerSkipConfig` or a list of
|
| 66 |
+
`LayerSkipConfig`. If not provided, `perturbed_guidance_layers` must be provided.
|
| 67 |
+
guidance_rescale (`float`, defaults to `0.0`):
|
| 68 |
+
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
| 69 |
+
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
| 70 |
+
Flawed](https://huggingface.co/papers/2305.08891).
|
| 71 |
+
use_original_formulation (`bool`, defaults to `False`):
|
| 72 |
+
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
| 73 |
+
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
| 74 |
+
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
| 75 |
+
start (`float`, defaults to `0.01`):
|
| 76 |
+
The fraction of the total number of denoising steps after which guidance starts.
|
| 77 |
+
stop (`float`, defaults to `0.2`):
|
| 78 |
+
The fraction of the total number of denoising steps after which guidance stops.
|
| 79 |
+
"""
|
| 80 |
+
|
| 81 |
+
# NOTE: The current implementation does not account for joint latent conditioning (text + image/video tokens in
|
| 82 |
+
# the same latent stream). It assumes the entire latent is a single stream of visual tokens. It would be very
|
| 83 |
+
# complex to support joint latent conditioning in a model-agnostic manner without specializing the implementation
|
| 84 |
+
# for each model architecture.
|
| 85 |
+
|
| 86 |
+
_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
| 87 |
+
|
| 88 |
+
@register_to_config
|
| 89 |
+
def __init__(
|
| 90 |
+
self,
|
| 91 |
+
guidance_scale: float = 7.5,
|
| 92 |
+
perturbed_guidance_scale: float = 2.8,
|
| 93 |
+
perturbed_guidance_start: float = 0.01,
|
| 94 |
+
perturbed_guidance_stop: float = 0.2,
|
| 95 |
+
perturbed_guidance_layers: Optional[Union[int, List[int]]] = None,
|
| 96 |
+
perturbed_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
|
| 97 |
+
guidance_rescale: float = 0.0,
|
| 98 |
+
use_original_formulation: bool = False,
|
| 99 |
+
start: float = 0.0,
|
| 100 |
+
stop: float = 1.0,
|
| 101 |
+
):
|
| 102 |
+
super().__init__(start, stop)
|
| 103 |
+
|
| 104 |
+
self.guidance_scale = guidance_scale
|
| 105 |
+
self.skip_layer_guidance_scale = perturbed_guidance_scale
|
| 106 |
+
self.skip_layer_guidance_start = perturbed_guidance_start
|
| 107 |
+
self.skip_layer_guidance_stop = perturbed_guidance_stop
|
| 108 |
+
self.guidance_rescale = guidance_rescale
|
| 109 |
+
self.use_original_formulation = use_original_formulation
|
| 110 |
+
|
| 111 |
+
if perturbed_guidance_config is None:
|
| 112 |
+
if perturbed_guidance_layers is None:
|
| 113 |
+
raise ValueError(
|
| 114 |
+
"`perturbed_guidance_layers` must be provided if `perturbed_guidance_config` is not specified."
|
| 115 |
+
)
|
| 116 |
+
perturbed_guidance_config = LayerSkipConfig(
|
| 117 |
+
indices=perturbed_guidance_layers,
|
| 118 |
+
fqn="auto",
|
| 119 |
+
skip_attention=False,
|
| 120 |
+
skip_attention_scores=True,
|
| 121 |
+
skip_ff=False,
|
| 122 |
+
)
|
| 123 |
+
else:
|
| 124 |
+
if perturbed_guidance_layers is not None:
|
| 125 |
+
raise ValueError(
|
| 126 |
+
"`perturbed_guidance_layers` should not be provided if `perturbed_guidance_config` is specified."
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
if isinstance(perturbed_guidance_config, dict):
|
| 130 |
+
perturbed_guidance_config = LayerSkipConfig.from_dict(perturbed_guidance_config)
|
| 131 |
+
|
| 132 |
+
if isinstance(perturbed_guidance_config, LayerSkipConfig):
|
| 133 |
+
perturbed_guidance_config = [perturbed_guidance_config]
|
| 134 |
+
|
| 135 |
+
if not isinstance(perturbed_guidance_config, list):
|
| 136 |
+
raise ValueError(
|
| 137 |
+
"`perturbed_guidance_config` must be a `LayerSkipConfig`, a list of `LayerSkipConfig`, or a dict that can be converted to a `LayerSkipConfig`."
|
| 138 |
+
)
|
| 139 |
+
elif isinstance(next(iter(perturbed_guidance_config), None), dict):
|
| 140 |
+
perturbed_guidance_config = [LayerSkipConfig.from_dict(config) for config in perturbed_guidance_config]
|
| 141 |
+
|
| 142 |
+
for config in perturbed_guidance_config:
|
| 143 |
+
if config.skip_attention or not config.skip_attention_scores or config.skip_ff:
|
| 144 |
+
logger.warning(
|
| 145 |
+
"Perturbed Attention Guidance is designed to perturb attention scores, so `skip_attention` should be False, `skip_attention_scores` should be True, and `skip_ff` should be False. "
|
| 146 |
+
"Please check your configuration. Modifying the config to match the expected values."
|
| 147 |
+
)
|
| 148 |
+
config.skip_attention = False
|
| 149 |
+
config.skip_attention_scores = True
|
| 150 |
+
config.skip_ff = False
|
| 151 |
+
|
| 152 |
+
self.skip_layer_config = perturbed_guidance_config
|
| 153 |
+
self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]
|
| 154 |
+
|
| 155 |
+
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_models
|
| 156 |
+
def prepare_models(self, denoiser: torch.nn.Module) -> None:
|
| 157 |
+
self._count_prepared += 1
|
| 158 |
+
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
|
| 159 |
+
for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
|
| 160 |
+
_apply_layer_skip_hook(denoiser, config, name=name)
|
| 161 |
+
|
| 162 |
+
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.cleanup_models
|
| 163 |
+
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
|
| 164 |
+
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
|
| 165 |
+
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
|
| 166 |
+
# Remove the hooks after inference
|
| 167 |
+
for hook_name in self._skip_layer_hook_names:
|
| 168 |
+
registry.remove_hook(hook_name, recurse=True)
|
| 169 |
+
|
| 170 |
+
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs
|
| 171 |
+
def prepare_inputs(
|
| 172 |
+
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
| 173 |
+
) -> List["BlockState"]:
|
| 174 |
+
if input_fields is None:
|
| 175 |
+
input_fields = self._input_fields
|
| 176 |
+
|
| 177 |
+
if self.num_conditions == 1:
|
| 178 |
+
tuple_indices = [0]
|
| 179 |
+
input_predictions = ["pred_cond"]
|
| 180 |
+
elif self.num_conditions == 2:
|
| 181 |
+
tuple_indices = [0, 1]
|
| 182 |
+
input_predictions = (
|
| 183 |
+
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
|
| 184 |
+
)
|
| 185 |
+
else:
|
| 186 |
+
tuple_indices = [0, 1, 0]
|
| 187 |
+
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
| 188 |
+
data_batches = []
|
| 189 |
+
for i in range(self.num_conditions):
|
| 190 |
+
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
|
| 191 |
+
data_batches.append(data_batch)
|
| 192 |
+
return data_batches
|
| 193 |
+
|
| 194 |
+
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.forward
|
| 195 |
+
def forward(
|
| 196 |
+
self,
|
| 197 |
+
pred_cond: torch.Tensor,
|
| 198 |
+
pred_uncond: Optional[torch.Tensor] = None,
|
| 199 |
+
pred_cond_skip: Optional[torch.Tensor] = None,
|
| 200 |
+
) -> GuiderOutput:
|
| 201 |
+
pred = None
|
| 202 |
+
|
| 203 |
+
if not self._is_cfg_enabled() and not self._is_slg_enabled():
|
| 204 |
+
pred = pred_cond
|
| 205 |
+
elif not self._is_cfg_enabled():
|
| 206 |
+
shift = pred_cond - pred_cond_skip
|
| 207 |
+
pred = pred_cond if self.use_original_formulation else pred_cond_skip
|
| 208 |
+
pred = pred + self.skip_layer_guidance_scale * shift
|
| 209 |
+
elif not self._is_slg_enabled():
|
| 210 |
+
shift = pred_cond - pred_uncond
|
| 211 |
+
pred = pred_cond if self.use_original_formulation else pred_uncond
|
| 212 |
+
pred = pred + self.guidance_scale * shift
|
| 213 |
+
else:
|
| 214 |
+
shift = pred_cond - pred_uncond
|
| 215 |
+
shift_skip = pred_cond - pred_cond_skip
|
| 216 |
+
pred = pred_cond if self.use_original_formulation else pred_uncond
|
| 217 |
+
pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip
|
| 218 |
+
|
| 219 |
+
if self.guidance_rescale > 0.0:
|
| 220 |
+
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
| 221 |
+
|
| 222 |
+
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
|
| 223 |
+
|
| 224 |
+
@property
|
| 225 |
+
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.is_conditional
|
| 226 |
+
def is_conditional(self) -> bool:
|
| 227 |
+
return self._count_prepared == 1 or self._count_prepared == 3
|
| 228 |
+
|
| 229 |
+
@property
|
| 230 |
+
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.num_conditions
|
| 231 |
+
def num_conditions(self) -> int:
|
| 232 |
+
num_conditions = 1
|
| 233 |
+
if self._is_cfg_enabled():
|
| 234 |
+
num_conditions += 1
|
| 235 |
+
if self._is_slg_enabled():
|
| 236 |
+
num_conditions += 1
|
| 237 |
+
return num_conditions
|
| 238 |
+
|
| 239 |
+
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_cfg_enabled
|
| 240 |
+
def _is_cfg_enabled(self) -> bool:
|
| 241 |
+
if not self._enabled:
|
| 242 |
+
return False
|
| 243 |
+
|
| 244 |
+
is_within_range = True
|
| 245 |
+
if self._num_inference_steps is not None:
|
| 246 |
+
skip_start_step = int(self._start * self._num_inference_steps)
|
| 247 |
+
skip_stop_step = int(self._stop * self._num_inference_steps)
|
| 248 |
+
is_within_range = skip_start_step <= self._step < skip_stop_step
|
| 249 |
+
|
| 250 |
+
is_close = False
|
| 251 |
+
if self.use_original_formulation:
|
| 252 |
+
is_close = math.isclose(self.guidance_scale, 0.0)
|
| 253 |
+
else:
|
| 254 |
+
is_close = math.isclose(self.guidance_scale, 1.0)
|
| 255 |
+
|
| 256 |
+
return is_within_range and not is_close
|
| 257 |
+
|
| 258 |
+
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_slg_enabled
|
| 259 |
+
def _is_slg_enabled(self) -> bool:
|
| 260 |
+
if not self._enabled:
|
| 261 |
+
return False
|
| 262 |
+
|
| 263 |
+
is_within_range = True
|
| 264 |
+
if self._num_inference_steps is not None:
|
| 265 |
+
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
|
| 266 |
+
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
|
| 267 |
+
is_within_range = skip_start_step < self._step < skip_stop_step
|
| 268 |
+
|
| 269 |
+
is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
|
| 270 |
+
|
| 271 |
+
return is_within_range and not is_zero
|
pythonProject/.venv/Lib/site-packages/diffusers/guiders/skip_layer_guidance.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from ..configuration_utils import register_to_config
|
| 21 |
+
from ..hooks import HookRegistry, LayerSkipConfig
|
| 22 |
+
from ..hooks.layer_skip import _apply_layer_skip_hook
|
| 23 |
+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
if TYPE_CHECKING:
|
| 27 |
+
from ..modular_pipelines.modular_pipeline import BlockState
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class SkipLayerGuidance(BaseGuidance):
|
| 31 |
+
"""
|
| 32 |
+
Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5
|
| 33 |
+
|
| 34 |
+
Spatio-Temporal Guidance (STG): https://huggingface.co/papers/2411.18664
|
| 35 |
+
|
| 36 |
+
SLG was introduced by StabilityAI for improving structure and anotomy coherence in generated images. It works by
|
| 37 |
+
skipping the forward pass of specified transformer blocks during the denoising process on an additional conditional
|
| 38 |
+
batch of data, apart from the conditional and unconditional batches already used in CFG
|
| 39 |
+
([~guiders.classifier_free_guidance.ClassifierFreeGuidance]), and then scaling and shifting the CFG predictions
|
| 40 |
+
based on the difference between conditional without skipping and conditional with skipping predictions.
|
| 41 |
+
|
| 42 |
+
The intution behind SLG can be thought of as moving the CFG predicted distribution estimates further away from
|
| 43 |
+
worse versions of the conditional distribution estimates (because skipping layers is equivalent to using a worse
|
| 44 |
+
version of the model for the conditional prediction).
|
| 45 |
+
|
| 46 |
+
STG is an improvement and follow-up work combining ideas from SLG, PAG and similar techniques for improving
|
| 47 |
+
generation quality in video diffusion models.
|
| 48 |
+
|
| 49 |
+
Additional reading:
|
| 50 |
+
- [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)
|
| 51 |
+
|
| 52 |
+
The values for `skip_layer_guidance_scale`, `skip_layer_guidance_start`, and `skip_layer_guidance_stop` are
|
| 53 |
+
defaulted to the recommendations by StabilityAI for Stable Diffusion 3.5 Medium.
|
| 54 |
+
|
| 55 |
+
Args:
|
| 56 |
+
guidance_scale (`float`, defaults to `7.5`):
|
| 57 |
+
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
| 58 |
+
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
| 59 |
+
deterioration of image quality.
|
| 60 |
+
skip_layer_guidance_scale (`float`, defaults to `2.8`):
|
| 61 |
+
The scale parameter for skip layer guidance. Anatomy and structure coherence may improve with higher
|
| 62 |
+
values, but it may also lead to overexposure and saturation.
|
| 63 |
+
skip_layer_guidance_start (`float`, defaults to `0.01`):
|
| 64 |
+
The fraction of the total number of denoising steps after which skip layer guidance starts.
|
| 65 |
+
skip_layer_guidance_stop (`float`, defaults to `0.2`):
|
| 66 |
+
The fraction of the total number of denoising steps after which skip layer guidance stops.
|
| 67 |
+
skip_layer_guidance_layers (`int` or `List[int]`, *optional*):
|
| 68 |
+
The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
|
| 69 |
+
provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion
|
| 70 |
+
3.5 Medium.
|
| 71 |
+
skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
|
| 72 |
+
The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
|
| 73 |
+
`LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided.
|
| 74 |
+
guidance_rescale (`float`, defaults to `0.0`):
|
| 75 |
+
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
| 76 |
+
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
| 77 |
+
Flawed](https://huggingface.co/papers/2305.08891).
|
| 78 |
+
use_original_formulation (`bool`, defaults to `False`):
|
| 79 |
+
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
| 80 |
+
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
| 81 |
+
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
| 82 |
+
start (`float`, defaults to `0.01`):
|
| 83 |
+
The fraction of the total number of denoising steps after which guidance starts.
|
| 84 |
+
stop (`float`, defaults to `0.2`):
|
| 85 |
+
The fraction of the total number of denoising steps after which guidance stops.
|
| 86 |
+
"""
|
| 87 |
+
|
| 88 |
+
_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
| 89 |
+
|
| 90 |
+
@register_to_config
|
| 91 |
+
def __init__(
|
| 92 |
+
self,
|
| 93 |
+
guidance_scale: float = 7.5,
|
| 94 |
+
skip_layer_guidance_scale: float = 2.8,
|
| 95 |
+
skip_layer_guidance_start: float = 0.01,
|
| 96 |
+
skip_layer_guidance_stop: float = 0.2,
|
| 97 |
+
skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None,
|
| 98 |
+
skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
|
| 99 |
+
guidance_rescale: float = 0.0,
|
| 100 |
+
use_original_formulation: bool = False,
|
| 101 |
+
start: float = 0.0,
|
| 102 |
+
stop: float = 1.0,
|
| 103 |
+
):
|
| 104 |
+
super().__init__(start, stop)
|
| 105 |
+
|
| 106 |
+
self.guidance_scale = guidance_scale
|
| 107 |
+
self.skip_layer_guidance_scale = skip_layer_guidance_scale
|
| 108 |
+
self.skip_layer_guidance_start = skip_layer_guidance_start
|
| 109 |
+
self.skip_layer_guidance_stop = skip_layer_guidance_stop
|
| 110 |
+
self.guidance_rescale = guidance_rescale
|
| 111 |
+
self.use_original_formulation = use_original_formulation
|
| 112 |
+
|
| 113 |
+
if not (0.0 <= skip_layer_guidance_start < 1.0):
|
| 114 |
+
raise ValueError(
|
| 115 |
+
f"Expected `skip_layer_guidance_start` to be between 0.0 and 1.0, but got {skip_layer_guidance_start}."
|
| 116 |
+
)
|
| 117 |
+
if not (skip_layer_guidance_start <= skip_layer_guidance_stop <= 1.0):
|
| 118 |
+
raise ValueError(
|
| 119 |
+
f"Expected `skip_layer_guidance_stop` to be between 0.0 and 1.0, but got {skip_layer_guidance_stop}."
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
if skip_layer_guidance_layers is None and skip_layer_config is None:
|
| 123 |
+
raise ValueError(
|
| 124 |
+
"Either `skip_layer_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance."
|
| 125 |
+
)
|
| 126 |
+
if skip_layer_guidance_layers is not None and skip_layer_config is not None:
|
| 127 |
+
raise ValueError("Only one of `skip_layer_guidance_layers` or `skip_layer_config` can be provided.")
|
| 128 |
+
|
| 129 |
+
if skip_layer_guidance_layers is not None:
|
| 130 |
+
if isinstance(skip_layer_guidance_layers, int):
|
| 131 |
+
skip_layer_guidance_layers = [skip_layer_guidance_layers]
|
| 132 |
+
if not isinstance(skip_layer_guidance_layers, list):
|
| 133 |
+
raise ValueError(
|
| 134 |
+
f"Expected `skip_layer_guidance_layers` to be an int or a list of ints, but got {type(skip_layer_guidance_layers)}."
|
| 135 |
+
)
|
| 136 |
+
skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_layer_guidance_layers]
|
| 137 |
+
|
| 138 |
+
if isinstance(skip_layer_config, dict):
|
| 139 |
+
skip_layer_config = LayerSkipConfig.from_dict(skip_layer_config)
|
| 140 |
+
|
| 141 |
+
if isinstance(skip_layer_config, LayerSkipConfig):
|
| 142 |
+
skip_layer_config = [skip_layer_config]
|
| 143 |
+
|
| 144 |
+
if not isinstance(skip_layer_config, list):
|
| 145 |
+
raise ValueError(
|
| 146 |
+
f"Expected `skip_layer_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(skip_layer_config)}."
|
| 147 |
+
)
|
| 148 |
+
elif isinstance(next(iter(skip_layer_config), None), dict):
|
| 149 |
+
skip_layer_config = [LayerSkipConfig.from_dict(config) for config in skip_layer_config]
|
| 150 |
+
|
| 151 |
+
self.skip_layer_config = skip_layer_config
|
| 152 |
+
self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]
|
| 153 |
+
|
| 154 |
+
def prepare_models(self, denoiser: torch.nn.Module) -> None:
|
| 155 |
+
self._count_prepared += 1
|
| 156 |
+
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
|
| 157 |
+
for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
|
| 158 |
+
_apply_layer_skip_hook(denoiser, config, name=name)
|
| 159 |
+
|
| 160 |
+
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
|
| 161 |
+
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
|
| 162 |
+
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
|
| 163 |
+
# Remove the hooks after inference
|
| 164 |
+
for hook_name in self._skip_layer_hook_names:
|
| 165 |
+
registry.remove_hook(hook_name, recurse=True)
|
| 166 |
+
|
| 167 |
+
def prepare_inputs(
|
| 168 |
+
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
| 169 |
+
) -> List["BlockState"]:
|
| 170 |
+
if input_fields is None:
|
| 171 |
+
input_fields = self._input_fields
|
| 172 |
+
|
| 173 |
+
if self.num_conditions == 1:
|
| 174 |
+
tuple_indices = [0]
|
| 175 |
+
input_predictions = ["pred_cond"]
|
| 176 |
+
elif self.num_conditions == 2:
|
| 177 |
+
tuple_indices = [0, 1]
|
| 178 |
+
input_predictions = (
|
| 179 |
+
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
|
| 180 |
+
)
|
| 181 |
+
else:
|
| 182 |
+
tuple_indices = [0, 1, 0]
|
| 183 |
+
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
| 184 |
+
data_batches = []
|
| 185 |
+
for i in range(self.num_conditions):
|
| 186 |
+
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
|
| 187 |
+
data_batches.append(data_batch)
|
| 188 |
+
return data_batches
|
| 189 |
+
|
| 190 |
+
def forward(
|
| 191 |
+
self,
|
| 192 |
+
pred_cond: torch.Tensor,
|
| 193 |
+
pred_uncond: Optional[torch.Tensor] = None,
|
| 194 |
+
pred_cond_skip: Optional[torch.Tensor] = None,
|
| 195 |
+
) -> GuiderOutput:
|
| 196 |
+
pred = None
|
| 197 |
+
|
| 198 |
+
if not self._is_cfg_enabled() and not self._is_slg_enabled():
|
| 199 |
+
pred = pred_cond
|
| 200 |
+
elif not self._is_cfg_enabled():
|
| 201 |
+
shift = pred_cond - pred_cond_skip
|
| 202 |
+
pred = pred_cond if self.use_original_formulation else pred_cond_skip
|
| 203 |
+
pred = pred + self.skip_layer_guidance_scale * shift
|
| 204 |
+
elif not self._is_slg_enabled():
|
| 205 |
+
shift = pred_cond - pred_uncond
|
| 206 |
+
pred = pred_cond if self.use_original_formulation else pred_uncond
|
| 207 |
+
pred = pred + self.guidance_scale * shift
|
| 208 |
+
else:
|
| 209 |
+
shift = pred_cond - pred_uncond
|
| 210 |
+
shift_skip = pred_cond - pred_cond_skip
|
| 211 |
+
pred = pred_cond if self.use_original_formulation else pred_uncond
|
| 212 |
+
pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip
|
| 213 |
+
|
| 214 |
+
if self.guidance_rescale > 0.0:
|
| 215 |
+
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
| 216 |
+
|
| 217 |
+
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
|
| 218 |
+
|
| 219 |
+
@property
|
| 220 |
+
def is_conditional(self) -> bool:
|
| 221 |
+
return self._count_prepared == 1 or self._count_prepared == 3
|
| 222 |
+
|
| 223 |
+
@property
|
| 224 |
+
def num_conditions(self) -> int:
|
| 225 |
+
num_conditions = 1
|
| 226 |
+
if self._is_cfg_enabled():
|
| 227 |
+
num_conditions += 1
|
| 228 |
+
if self._is_slg_enabled():
|
| 229 |
+
num_conditions += 1
|
| 230 |
+
return num_conditions
|
| 231 |
+
|
| 232 |
+
def _is_cfg_enabled(self) -> bool:
|
| 233 |
+
if not self._enabled:
|
| 234 |
+
return False
|
| 235 |
+
|
| 236 |
+
is_within_range = True
|
| 237 |
+
if self._num_inference_steps is not None:
|
| 238 |
+
skip_start_step = int(self._start * self._num_inference_steps)
|
| 239 |
+
skip_stop_step = int(self._stop * self._num_inference_steps)
|
| 240 |
+
is_within_range = skip_start_step <= self._step < skip_stop_step
|
| 241 |
+
|
| 242 |
+
is_close = False
|
| 243 |
+
if self.use_original_formulation:
|
| 244 |
+
is_close = math.isclose(self.guidance_scale, 0.0)
|
| 245 |
+
else:
|
| 246 |
+
is_close = math.isclose(self.guidance_scale, 1.0)
|
| 247 |
+
|
| 248 |
+
return is_within_range and not is_close
|
| 249 |
+
|
| 250 |
+
def _is_slg_enabled(self) -> bool:
|
| 251 |
+
if not self._enabled:
|
| 252 |
+
return False
|
| 253 |
+
|
| 254 |
+
is_within_range = True
|
| 255 |
+
if self._num_inference_steps is not None:
|
| 256 |
+
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
|
| 257 |
+
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
|
| 258 |
+
is_within_range = skip_start_step < self._step < skip_stop_step
|
| 259 |
+
|
| 260 |
+
is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
|
| 261 |
+
|
| 262 |
+
return is_within_range and not is_zero
|
pythonProject/.venv/Lib/site-packages/diffusers/guiders/smoothed_energy_guidance.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from ..configuration_utils import register_to_config
|
| 21 |
+
from ..hooks import HookRegistry
|
| 22 |
+
from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook
|
| 23 |
+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
if TYPE_CHECKING:
|
| 27 |
+
from ..modular_pipelines.modular_pipeline import BlockState
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class SmoothedEnergyGuidance(BaseGuidance):
|
| 31 |
+
"""
|
| 32 |
+
Smoothed Energy Guidance (SEG): https://huggingface.co/papers/2408.00760
|
| 33 |
+
|
| 34 |
+
SEG is only supported as an experimental prototype feature for now, so the implementation may be modified in the
|
| 35 |
+
future without warning or guarantee of reproducibility. This implementation assumes:
|
| 36 |
+
- Generated images are square (height == width)
|
| 37 |
+
- The model does not combine different modalities together (e.g., text and image latent streams are not combined
|
| 38 |
+
together such as Flux)
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
guidance_scale (`float`, defaults to `7.5`):
|
| 42 |
+
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
| 43 |
+
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
| 44 |
+
deterioration of image quality.
|
| 45 |
+
seg_guidance_scale (`float`, defaults to `3.0`):
|
| 46 |
+
The scale parameter for smoothed energy guidance. Anatomy and structure coherence may improve with higher
|
| 47 |
+
values, but it may also lead to overexposure and saturation.
|
| 48 |
+
seg_blur_sigma (`float`, defaults to `9999999.0`):
|
| 49 |
+
The amount by which we blur the attention weights. Setting this value greater than 9999.0 results in
|
| 50 |
+
infinite blur, which means uniform queries. Controlling it exponentially is empirically effective.
|
| 51 |
+
seg_blur_threshold_inf (`float`, defaults to `9999.0`):
|
| 52 |
+
The threshold above which the blur is considered infinite.
|
| 53 |
+
seg_guidance_start (`float`, defaults to `0.0`):
|
| 54 |
+
The fraction of the total number of denoising steps after which smoothed energy guidance starts.
|
| 55 |
+
seg_guidance_stop (`float`, defaults to `1.0`):
|
| 56 |
+
The fraction of the total number of denoising steps after which smoothed energy guidance stops.
|
| 57 |
+
seg_guidance_layers (`int` or `List[int]`, *optional*):
|
| 58 |
+
The layer indices to apply smoothed energy guidance to. Can be a single integer or a list of integers. If
|
| 59 |
+
not provided, `seg_guidance_config` must be provided. The recommended values are `[7, 8, 9]` for Stable
|
| 60 |
+
Diffusion 3.5 Medium.
|
| 61 |
+
seg_guidance_config (`SmoothedEnergyGuidanceConfig` or `List[SmoothedEnergyGuidanceConfig]`, *optional*):
|
| 62 |
+
The configuration for the smoothed energy layer guidance. Can be a single `SmoothedEnergyGuidanceConfig` or
|
| 63 |
+
a list of `SmoothedEnergyGuidanceConfig`. If not provided, `seg_guidance_layers` must be provided.
|
| 64 |
+
guidance_rescale (`float`, defaults to `0.0`):
|
| 65 |
+
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
| 66 |
+
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
| 67 |
+
Flawed](https://huggingface.co/papers/2305.08891).
|
| 68 |
+
use_original_formulation (`bool`, defaults to `False`):
|
| 69 |
+
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
| 70 |
+
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
| 71 |
+
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
| 72 |
+
start (`float`, defaults to `0.01`):
|
| 73 |
+
The fraction of the total number of denoising steps after which guidance starts.
|
| 74 |
+
stop (`float`, defaults to `0.2`):
|
| 75 |
+
The fraction of the total number of denoising steps after which guidance stops.
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
|
| 79 |
+
|
| 80 |
+
@register_to_config
|
| 81 |
+
def __init__(
|
| 82 |
+
self,
|
| 83 |
+
guidance_scale: float = 7.5,
|
| 84 |
+
seg_guidance_scale: float = 2.8,
|
| 85 |
+
seg_blur_sigma: float = 9999999.0,
|
| 86 |
+
seg_blur_threshold_inf: float = 9999.0,
|
| 87 |
+
seg_guidance_start: float = 0.0,
|
| 88 |
+
seg_guidance_stop: float = 1.0,
|
| 89 |
+
seg_guidance_layers: Optional[Union[int, List[int]]] = None,
|
| 90 |
+
seg_guidance_config: Union[SmoothedEnergyGuidanceConfig, List[SmoothedEnergyGuidanceConfig]] = None,
|
| 91 |
+
guidance_rescale: float = 0.0,
|
| 92 |
+
use_original_formulation: bool = False,
|
| 93 |
+
start: float = 0.0,
|
| 94 |
+
stop: float = 1.0,
|
| 95 |
+
):
|
| 96 |
+
super().__init__(start, stop)
|
| 97 |
+
|
| 98 |
+
self.guidance_scale = guidance_scale
|
| 99 |
+
self.seg_guidance_scale = seg_guidance_scale
|
| 100 |
+
self.seg_blur_sigma = seg_blur_sigma
|
| 101 |
+
self.seg_blur_threshold_inf = seg_blur_threshold_inf
|
| 102 |
+
self.seg_guidance_start = seg_guidance_start
|
| 103 |
+
self.seg_guidance_stop = seg_guidance_stop
|
| 104 |
+
self.guidance_rescale = guidance_rescale
|
| 105 |
+
self.use_original_formulation = use_original_formulation
|
| 106 |
+
|
| 107 |
+
if not (0.0 <= seg_guidance_start < 1.0):
|
| 108 |
+
raise ValueError(f"Expected `seg_guidance_start` to be between 0.0 and 1.0, but got {seg_guidance_start}.")
|
| 109 |
+
if not (seg_guidance_start <= seg_guidance_stop <= 1.0):
|
| 110 |
+
raise ValueError(f"Expected `seg_guidance_stop` to be between 0.0 and 1.0, but got {seg_guidance_stop}.")
|
| 111 |
+
|
| 112 |
+
if seg_guidance_layers is None and seg_guidance_config is None:
|
| 113 |
+
raise ValueError(
|
| 114 |
+
"Either `seg_guidance_layers` or `seg_guidance_config` must be provided to enable Smoothed Energy Guidance."
|
| 115 |
+
)
|
| 116 |
+
if seg_guidance_layers is not None and seg_guidance_config is not None:
|
| 117 |
+
raise ValueError("Only one of `seg_guidance_layers` or `seg_guidance_config` can be provided.")
|
| 118 |
+
|
| 119 |
+
if seg_guidance_layers is not None:
|
| 120 |
+
if isinstance(seg_guidance_layers, int):
|
| 121 |
+
seg_guidance_layers = [seg_guidance_layers]
|
| 122 |
+
if not isinstance(seg_guidance_layers, list):
|
| 123 |
+
raise ValueError(
|
| 124 |
+
f"Expected `seg_guidance_layers` to be an int or a list of ints, but got {type(seg_guidance_layers)}."
|
| 125 |
+
)
|
| 126 |
+
seg_guidance_config = [SmoothedEnergyGuidanceConfig(layer, fqn="auto") for layer in seg_guidance_layers]
|
| 127 |
+
|
| 128 |
+
if isinstance(seg_guidance_config, dict):
|
| 129 |
+
seg_guidance_config = SmoothedEnergyGuidanceConfig.from_dict(seg_guidance_config)
|
| 130 |
+
|
| 131 |
+
if isinstance(seg_guidance_config, SmoothedEnergyGuidanceConfig):
|
| 132 |
+
seg_guidance_config = [seg_guidance_config]
|
| 133 |
+
|
| 134 |
+
if not isinstance(seg_guidance_config, list):
|
| 135 |
+
raise ValueError(
|
| 136 |
+
f"Expected `seg_guidance_config` to be a SmoothedEnergyGuidanceConfig or a list of SmoothedEnergyGuidanceConfig, but got {type(seg_guidance_config)}."
|
| 137 |
+
)
|
| 138 |
+
elif isinstance(next(iter(seg_guidance_config), None), dict):
|
| 139 |
+
seg_guidance_config = [SmoothedEnergyGuidanceConfig.from_dict(config) for config in seg_guidance_config]
|
| 140 |
+
|
| 141 |
+
self.seg_guidance_config = seg_guidance_config
|
| 142 |
+
self._seg_layer_hook_names = [f"SmoothedEnergyGuidance_{i}" for i in range(len(self.seg_guidance_config))]
|
| 143 |
+
|
| 144 |
+
def prepare_models(self, denoiser: torch.nn.Module) -> None:
|
| 145 |
+
if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
|
| 146 |
+
for name, config in zip(self._seg_layer_hook_names, self.seg_guidance_config):
|
| 147 |
+
_apply_smoothed_energy_guidance_hook(denoiser, config, self.seg_blur_sigma, name=name)
|
| 148 |
+
|
| 149 |
+
def cleanup_models(self, denoiser: torch.nn.Module):
|
| 150 |
+
if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
|
| 151 |
+
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
|
| 152 |
+
# Remove the hooks after inference
|
| 153 |
+
for hook_name in self._seg_layer_hook_names:
|
| 154 |
+
registry.remove_hook(hook_name, recurse=True)
|
| 155 |
+
|
| 156 |
+
def prepare_inputs(
|
| 157 |
+
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
| 158 |
+
) -> List["BlockState"]:
|
| 159 |
+
if input_fields is None:
|
| 160 |
+
input_fields = self._input_fields
|
| 161 |
+
|
| 162 |
+
if self.num_conditions == 1:
|
| 163 |
+
tuple_indices = [0]
|
| 164 |
+
input_predictions = ["pred_cond"]
|
| 165 |
+
elif self.num_conditions == 2:
|
| 166 |
+
tuple_indices = [0, 1]
|
| 167 |
+
input_predictions = (
|
| 168 |
+
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"]
|
| 169 |
+
)
|
| 170 |
+
else:
|
| 171 |
+
tuple_indices = [0, 1, 0]
|
| 172 |
+
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
|
| 173 |
+
data_batches = []
|
| 174 |
+
for i in range(self.num_conditions):
|
| 175 |
+
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
|
| 176 |
+
data_batches.append(data_batch)
|
| 177 |
+
return data_batches
|
| 178 |
+
|
| 179 |
+
def forward(
|
| 180 |
+
self,
|
| 181 |
+
pred_cond: torch.Tensor,
|
| 182 |
+
pred_uncond: Optional[torch.Tensor] = None,
|
| 183 |
+
pred_cond_seg: Optional[torch.Tensor] = None,
|
| 184 |
+
) -> GuiderOutput:
|
| 185 |
+
pred = None
|
| 186 |
+
|
| 187 |
+
if not self._is_cfg_enabled() and not self._is_seg_enabled():
|
| 188 |
+
pred = pred_cond
|
| 189 |
+
elif not self._is_cfg_enabled():
|
| 190 |
+
shift = pred_cond - pred_cond_seg
|
| 191 |
+
pred = pred_cond if self.use_original_formulation else pred_cond_seg
|
| 192 |
+
pred = pred + self.seg_guidance_scale * shift
|
| 193 |
+
elif not self._is_seg_enabled():
|
| 194 |
+
shift = pred_cond - pred_uncond
|
| 195 |
+
pred = pred_cond if self.use_original_formulation else pred_uncond
|
| 196 |
+
pred = pred + self.guidance_scale * shift
|
| 197 |
+
else:
|
| 198 |
+
shift = pred_cond - pred_uncond
|
| 199 |
+
shift_seg = pred_cond - pred_cond_seg
|
| 200 |
+
pred = pred_cond if self.use_original_formulation else pred_uncond
|
| 201 |
+
pred = pred + self.guidance_scale * shift + self.seg_guidance_scale * shift_seg
|
| 202 |
+
|
| 203 |
+
if self.guidance_rescale > 0.0:
|
| 204 |
+
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
| 205 |
+
|
| 206 |
+
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
|
| 207 |
+
|
| 208 |
+
@property
|
| 209 |
+
def is_conditional(self) -> bool:
|
| 210 |
+
return self._count_prepared == 1 or self._count_prepared == 3
|
| 211 |
+
|
| 212 |
+
@property
|
| 213 |
+
def num_conditions(self) -> int:
|
| 214 |
+
num_conditions = 1
|
| 215 |
+
if self._is_cfg_enabled():
|
| 216 |
+
num_conditions += 1
|
| 217 |
+
if self._is_seg_enabled():
|
| 218 |
+
num_conditions += 1
|
| 219 |
+
return num_conditions
|
| 220 |
+
|
| 221 |
+
def _is_cfg_enabled(self) -> bool:
|
| 222 |
+
if not self._enabled:
|
| 223 |
+
return False
|
| 224 |
+
|
| 225 |
+
is_within_range = True
|
| 226 |
+
if self._num_inference_steps is not None:
|
| 227 |
+
skip_start_step = int(self._start * self._num_inference_steps)
|
| 228 |
+
skip_stop_step = int(self._stop * self._num_inference_steps)
|
| 229 |
+
is_within_range = skip_start_step <= self._step < skip_stop_step
|
| 230 |
+
|
| 231 |
+
is_close = False
|
| 232 |
+
if self.use_original_formulation:
|
| 233 |
+
is_close = math.isclose(self.guidance_scale, 0.0)
|
| 234 |
+
else:
|
| 235 |
+
is_close = math.isclose(self.guidance_scale, 1.0)
|
| 236 |
+
|
| 237 |
+
return is_within_range and not is_close
|
| 238 |
+
|
| 239 |
+
def _is_seg_enabled(self) -> bool:
|
| 240 |
+
if not self._enabled:
|
| 241 |
+
return False
|
| 242 |
+
|
| 243 |
+
is_within_range = True
|
| 244 |
+
if self._num_inference_steps is not None:
|
| 245 |
+
skip_start_step = int(self.seg_guidance_start * self._num_inference_steps)
|
| 246 |
+
skip_stop_step = int(self.seg_guidance_stop * self._num_inference_steps)
|
| 247 |
+
is_within_range = skip_start_step < self._step < skip_stop_step
|
| 248 |
+
|
| 249 |
+
is_zero = math.isclose(self.seg_guidance_scale, 0.0)
|
| 250 |
+
|
| 251 |
+
return is_within_range and not is_zero
|
pythonProject/.venv/Lib/site-packages/diffusers/guiders/tangential_classifier_free_guidance.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from ..configuration_utils import register_to_config
|
| 21 |
+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
if TYPE_CHECKING:
|
| 25 |
+
from ..modular_pipelines.modular_pipeline import BlockState
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class TangentialClassifierFreeGuidance(BaseGuidance):
|
| 29 |
+
"""
|
| 30 |
+
Tangential Classifier Free Guidance (TCFG): https://huggingface.co/papers/2503.18137
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
guidance_scale (`float`, defaults to `7.5`):
|
| 34 |
+
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
| 35 |
+
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
| 36 |
+
deterioration of image quality.
|
| 37 |
+
guidance_rescale (`float`, defaults to `0.0`):
|
| 38 |
+
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
| 39 |
+
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
| 40 |
+
Flawed](https://huggingface.co/papers/2305.08891).
|
| 41 |
+
use_original_formulation (`bool`, defaults to `False`):
|
| 42 |
+
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
| 43 |
+
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
| 44 |
+
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
| 45 |
+
start (`float`, defaults to `0.0`):
|
| 46 |
+
The fraction of the total number of denoising steps after which guidance starts.
|
| 47 |
+
stop (`float`, defaults to `1.0`):
|
| 48 |
+
The fraction of the total number of denoising steps after which guidance stops.
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
_input_predictions = ["pred_cond", "pred_uncond"]
|
| 52 |
+
|
| 53 |
+
@register_to_config
|
| 54 |
+
def __init__(
|
| 55 |
+
self,
|
| 56 |
+
guidance_scale: float = 7.5,
|
| 57 |
+
guidance_rescale: float = 0.0,
|
| 58 |
+
use_original_formulation: bool = False,
|
| 59 |
+
start: float = 0.0,
|
| 60 |
+
stop: float = 1.0,
|
| 61 |
+
):
|
| 62 |
+
super().__init__(start, stop)
|
| 63 |
+
|
| 64 |
+
self.guidance_scale = guidance_scale
|
| 65 |
+
self.guidance_rescale = guidance_rescale
|
| 66 |
+
self.use_original_formulation = use_original_formulation
|
| 67 |
+
|
| 68 |
+
def prepare_inputs(
|
| 69 |
+
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
| 70 |
+
) -> List["BlockState"]:
|
| 71 |
+
if input_fields is None:
|
| 72 |
+
input_fields = self._input_fields
|
| 73 |
+
|
| 74 |
+
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
| 75 |
+
data_batches = []
|
| 76 |
+
for i in range(self.num_conditions):
|
| 77 |
+
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
|
| 78 |
+
data_batches.append(data_batch)
|
| 79 |
+
return data_batches
|
| 80 |
+
|
| 81 |
+
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
|
| 82 |
+
pred = None
|
| 83 |
+
|
| 84 |
+
if not self._is_tcfg_enabled():
|
| 85 |
+
pred = pred_cond
|
| 86 |
+
else:
|
| 87 |
+
pred = normalized_guidance(pred_cond, pred_uncond, self.guidance_scale, self.use_original_formulation)
|
| 88 |
+
|
| 89 |
+
if self.guidance_rescale > 0.0:
|
| 90 |
+
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
| 91 |
+
|
| 92 |
+
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
|
| 93 |
+
|
| 94 |
+
@property
|
| 95 |
+
def is_conditional(self) -> bool:
|
| 96 |
+
return self._num_outputs_prepared == 1
|
| 97 |
+
|
| 98 |
+
@property
|
| 99 |
+
def num_conditions(self) -> int:
|
| 100 |
+
num_conditions = 1
|
| 101 |
+
if self._is_tcfg_enabled():
|
| 102 |
+
num_conditions += 1
|
| 103 |
+
return num_conditions
|
| 104 |
+
|
| 105 |
+
def _is_tcfg_enabled(self) -> bool:
|
| 106 |
+
if not self._enabled:
|
| 107 |
+
return False
|
| 108 |
+
|
| 109 |
+
is_within_range = True
|
| 110 |
+
if self._num_inference_steps is not None:
|
| 111 |
+
skip_start_step = int(self._start * self._num_inference_steps)
|
| 112 |
+
skip_stop_step = int(self._stop * self._num_inference_steps)
|
| 113 |
+
is_within_range = skip_start_step <= self._step < skip_stop_step
|
| 114 |
+
|
| 115 |
+
is_close = False
|
| 116 |
+
if self.use_original_formulation:
|
| 117 |
+
is_close = math.isclose(self.guidance_scale, 0.0)
|
| 118 |
+
else:
|
| 119 |
+
is_close = math.isclose(self.guidance_scale, 1.0)
|
| 120 |
+
|
| 121 |
+
return is_within_range and not is_close
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def normalized_guidance(
|
| 125 |
+
pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float, use_original_formulation: bool = False
|
| 126 |
+
) -> torch.Tensor:
|
| 127 |
+
cond_dtype = pred_cond.dtype
|
| 128 |
+
preds = torch.stack([pred_cond, pred_uncond], dim=1).float()
|
| 129 |
+
preds = preds.flatten(2)
|
| 130 |
+
U, S, Vh = torch.linalg.svd(preds, full_matrices=False)
|
| 131 |
+
Vh_modified = Vh.clone()
|
| 132 |
+
Vh_modified[:, 1] = 0
|
| 133 |
+
|
| 134 |
+
uncond_flat = pred_uncond.reshape(pred_uncond.size(0), 1, -1).float()
|
| 135 |
+
x_Vh = torch.matmul(uncond_flat, Vh.transpose(-2, -1))
|
| 136 |
+
x_Vh_V = torch.matmul(x_Vh, Vh_modified)
|
| 137 |
+
pred_uncond = x_Vh_V.reshape(pred_uncond.shape).to(cond_dtype)
|
| 138 |
+
|
| 139 |
+
pred = pred_cond if use_original_formulation else pred_uncond
|
| 140 |
+
shift = pred_cond - pred_uncond
|
| 141 |
+
pred = pred + guidance_scale * shift
|
| 142 |
+
|
| 143 |
+
return pred
|
pythonProject/.venv/Lib/site-packages/diffusers/hooks/faster_cache.py
ADDED
|
@@ -0,0 +1,654 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import re
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
from typing import Any, Callable, List, Optional, Tuple
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
|
| 21 |
+
from ..models.attention import AttentionModuleMixin
|
| 22 |
+
from ..models.modeling_outputs import Transformer2DModelOutput
|
| 23 |
+
from ..utils import logging
|
| 24 |
+
from ._common import _ATTENTION_CLASSES
|
| 25 |
+
from .hooks import HookRegistry, ModelHook
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
_FASTER_CACHE_DENOISER_HOOK = "faster_cache_denoiser"
|
| 32 |
+
_FASTER_CACHE_BLOCK_HOOK = "faster_cache_block"
|
| 33 |
+
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
|
| 34 |
+
"^blocks.*attn",
|
| 35 |
+
"^transformer_blocks.*attn",
|
| 36 |
+
"^single_transformer_blocks.*attn",
|
| 37 |
+
)
|
| 38 |
+
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("^temporal_transformer_blocks.*attn",)
|
| 39 |
+
_TRANSFORMER_BLOCK_IDENTIFIERS = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
|
| 40 |
+
_UNCOND_COND_INPUT_KWARGS_IDENTIFIERS = (
|
| 41 |
+
"hidden_states",
|
| 42 |
+
"encoder_hidden_states",
|
| 43 |
+
"timestep",
|
| 44 |
+
"attention_mask",
|
| 45 |
+
"encoder_attention_mask",
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@dataclass
|
| 50 |
+
class FasterCacheConfig:
|
| 51 |
+
r"""
|
| 52 |
+
Configuration for [FasterCache](https://huggingface.co/papers/2410.19355).
|
| 53 |
+
|
| 54 |
+
Attributes:
|
| 55 |
+
spatial_attention_block_skip_range (`int`, defaults to `2`):
|
| 56 |
+
Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will
|
| 57 |
+
be skipped `N - 1` times (i.e., cached attention states will be reused) before computing the new attention
|
| 58 |
+
states again.
|
| 59 |
+
temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`):
|
| 60 |
+
Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will
|
| 61 |
+
be skipped `N - 1` times (i.e., cached attention states will be reused) before computing the new attention
|
| 62 |
+
states again.
|
| 63 |
+
spatial_attention_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 681)`):
|
| 64 |
+
The timestep range within which the spatial attention computation can be skipped without a significant loss
|
| 65 |
+
in quality. This is to be determined by the user based on the underlying model. The first value in the
|
| 66 |
+
tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for
|
| 67 |
+
denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at
|
| 68 |
+
timestep 0). For the default values, this would mean that the spatial attention computation skipping will
|
| 69 |
+
be applicable only after denoising timestep 681 is reached, and continue until the end of the denoising
|
| 70 |
+
process.
|
| 71 |
+
temporal_attention_timestep_skip_range (`Tuple[float, float]`, *optional*, defaults to `None`):
|
| 72 |
+
The timestep range within which the temporal attention computation can be skipped without a significant
|
| 73 |
+
loss in quality. This is to be determined by the user based on the underlying model. The first value in the
|
| 74 |
+
tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for
|
| 75 |
+
denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at
|
| 76 |
+
timestep 0).
|
| 77 |
+
low_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(99, 901)`):
|
| 78 |
+
The timestep range within which the low frequency weight scaling update is applied. The first value in the
|
| 79 |
+
tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
|
| 80 |
+
function for the update is called only within this range.
|
| 81 |
+
high_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(-1, 301)`):
|
| 82 |
+
The timestep range within which the high frequency weight scaling update is applied. The first value in the
|
| 83 |
+
tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
|
| 84 |
+
function for the update is called only within this range.
|
| 85 |
+
alpha_low_frequency (`float`, defaults to `1.1`):
|
| 86 |
+
The weight to scale the low frequency updates by. This is used to approximate the unconditional branch from
|
| 87 |
+
the conditional branch outputs.
|
| 88 |
+
alpha_high_frequency (`float`, defaults to `1.1`):
|
| 89 |
+
The weight to scale the high frequency updates by. This is used to approximate the unconditional branch
|
| 90 |
+
from the conditional branch outputs.
|
| 91 |
+
unconditional_batch_skip_range (`int`, defaults to `5`):
|
| 92 |
+
Process the unconditional branch every `N` iterations. If this is set to `N`, the unconditional branch
|
| 93 |
+
computation will be skipped `N - 1` times (i.e., cached unconditional branch states will be reused) before
|
| 94 |
+
computing the new unconditional branch states again.
|
| 95 |
+
unconditional_batch_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 641)`):
|
| 96 |
+
The timestep range within which the unconditional branch computation can be skipped without a significant
|
| 97 |
+
loss in quality. This is to be determined by the user based on the underlying model. The first value in the
|
| 98 |
+
tuple is the lower bound and the second value is the upper bound.
|
| 99 |
+
spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks.*attn1", "transformer_blocks.*attn1", "single_transformer_blocks.*attn1")`):
|
| 100 |
+
The identifiers to match the spatial attention blocks in the model. If the name of the block contains any
|
| 101 |
+
of these identifiers, FasterCache will be applied to that block. This can either be the full layer names,
|
| 102 |
+
partial layer names, or regex patterns. Matching will always be done using a regex match.
|
| 103 |
+
temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks.*attn1",)`):
|
| 104 |
+
The identifiers to match the temporal attention blocks in the model. If the name of the block contains any
|
| 105 |
+
of these identifiers, FasterCache will be applied to that block. This can either be the full layer names,
|
| 106 |
+
partial layer names, or regex patterns. Matching will always be done using a regex match.
|
| 107 |
+
attention_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
|
| 108 |
+
The callback function to determine the weight to scale the attention outputs by. This function should take
|
| 109 |
+
the attention module as input and return a float value. This is used to approximate the unconditional
|
| 110 |
+
branch from the conditional branch outputs. If not provided, the default weight is 0.5 for all timesteps.
|
| 111 |
+
Typically, as described in the paper, this weight should gradually increase from 0 to 1 as the inference
|
| 112 |
+
progresses. Users are encouraged to experiment and provide custom weight schedules that take into account
|
| 113 |
+
the number of inference steps and underlying model behaviour as denoising progresses.
|
| 114 |
+
low_frequency_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
|
| 115 |
+
The callback function to determine the weight to scale the low frequency updates by. If not provided, the
|
| 116 |
+
default weight is 1.1 for timesteps within the range specified (as described in the paper).
|
| 117 |
+
high_frequency_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
|
| 118 |
+
The callback function to determine the weight to scale the high frequency updates by. If not provided, the
|
| 119 |
+
default weight is 1.1 for timesteps within the range specified (as described in the paper).
|
| 120 |
+
tensor_format (`str`, defaults to `"BCFHW"`):
|
| 121 |
+
The format of the input tensors. This should be one of `"BCFHW"`, `"BFCHW"`, or `"BCHW"`. The format is
|
| 122 |
+
used to split individual latent frames in order for low and high frequency components to be computed.
|
| 123 |
+
is_guidance_distilled (`bool`, defaults to `False`):
|
| 124 |
+
Whether the model is guidance distilled or not. If the model is guidance distilled, FasterCache will not be
|
| 125 |
+
applied at the denoiser-level to skip the unconditional branch computation (as there is none).
|
| 126 |
+
_unconditional_conditional_input_kwargs_identifiers (`List[str]`, defaults to `("hidden_states", "encoder_hidden_states", "timestep", "attention_mask", "encoder_attention_mask")`):
|
| 127 |
+
The identifiers to match the input kwargs that contain the batchwise-concatenated unconditional and
|
| 128 |
+
conditional inputs. If the name of the input kwargs contains any of these identifiers, FasterCache will
|
| 129 |
+
split the inputs into unconditional and conditional branches. This must be a list of exact input kwargs
|
| 130 |
+
names that contain the batchwise-concatenated unconditional and conditional inputs.
|
| 131 |
+
"""
|
| 132 |
+
|
| 133 |
+
# In the paper and codebase, they hardcode these values to 2. However, it can be made configurable
|
| 134 |
+
# after some testing. We default to 2 if these parameters are not provided.
|
| 135 |
+
spatial_attention_block_skip_range: int = 2
|
| 136 |
+
temporal_attention_block_skip_range: Optional[int] = None
|
| 137 |
+
|
| 138 |
+
spatial_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)
|
| 139 |
+
temporal_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)
|
| 140 |
+
|
| 141 |
+
# Indicator functions for low/high frequency as mentioned in Equation 11 of the paper
|
| 142 |
+
low_frequency_weight_update_timestep_range: Tuple[int, int] = (99, 901)
|
| 143 |
+
high_frequency_weight_update_timestep_range: Tuple[int, int] = (-1, 301)
|
| 144 |
+
|
| 145 |
+
# ⍺1 and ⍺2 as mentioned in Equation 11 of the paper
|
| 146 |
+
alpha_low_frequency: float = 1.1
|
| 147 |
+
alpha_high_frequency: float = 1.1
|
| 148 |
+
|
| 149 |
+
# n as described in CFG-Cache explanation in the paper - dependent on the model
|
| 150 |
+
unconditional_batch_skip_range: int = 5
|
| 151 |
+
unconditional_batch_timestep_skip_range: Tuple[int, int] = (-1, 641)
|
| 152 |
+
|
| 153 |
+
spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
|
| 154 |
+
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
|
| 155 |
+
|
| 156 |
+
attention_weight_callback: Callable[[torch.nn.Module], float] = None
|
| 157 |
+
low_frequency_weight_callback: Callable[[torch.nn.Module], float] = None
|
| 158 |
+
high_frequency_weight_callback: Callable[[torch.nn.Module], float] = None
|
| 159 |
+
|
| 160 |
+
tensor_format: str = "BCFHW"
|
| 161 |
+
is_guidance_distilled: bool = False
|
| 162 |
+
|
| 163 |
+
current_timestep_callback: Callable[[], int] = None
|
| 164 |
+
|
| 165 |
+
_unconditional_conditional_input_kwargs_identifiers: List[str] = _UNCOND_COND_INPUT_KWARGS_IDENTIFIERS
|
| 166 |
+
|
| 167 |
+
def __repr__(self) -> str:
|
| 168 |
+
return (
|
| 169 |
+
f"FasterCacheConfig(\n"
|
| 170 |
+
f" spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n"
|
| 171 |
+
f" temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n"
|
| 172 |
+
f" spatial_attention_timestep_skip_range={self.spatial_attention_timestep_skip_range},\n"
|
| 173 |
+
f" temporal_attention_timestep_skip_range={self.temporal_attention_timestep_skip_range},\n"
|
| 174 |
+
f" low_frequency_weight_update_timestep_range={self.low_frequency_weight_update_timestep_range},\n"
|
| 175 |
+
f" high_frequency_weight_update_timestep_range={self.high_frequency_weight_update_timestep_range},\n"
|
| 176 |
+
f" alpha_low_frequency={self.alpha_low_frequency},\n"
|
| 177 |
+
f" alpha_high_frequency={self.alpha_high_frequency},\n"
|
| 178 |
+
f" unconditional_batch_skip_range={self.unconditional_batch_skip_range},\n"
|
| 179 |
+
f" unconditional_batch_timestep_skip_range={self.unconditional_batch_timestep_skip_range},\n"
|
| 180 |
+
f" spatial_attention_block_identifiers={self.spatial_attention_block_identifiers},\n"
|
| 181 |
+
f" temporal_attention_block_identifiers={self.temporal_attention_block_identifiers},\n"
|
| 182 |
+
f" tensor_format={self.tensor_format},\n"
|
| 183 |
+
f")"
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class FasterCacheDenoiserState:
|
| 188 |
+
r"""
|
| 189 |
+
State for [FasterCache](https://huggingface.co/papers/2410.19355) top-level denoiser module.
|
| 190 |
+
"""
|
| 191 |
+
|
| 192 |
+
def __init__(self) -> None:
|
| 193 |
+
self.iteration: int = 0
|
| 194 |
+
self.low_frequency_delta: torch.Tensor = None
|
| 195 |
+
self.high_frequency_delta: torch.Tensor = None
|
| 196 |
+
|
| 197 |
+
def reset(self):
|
| 198 |
+
self.iteration = 0
|
| 199 |
+
self.low_frequency_delta = None
|
| 200 |
+
self.high_frequency_delta = None
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
class FasterCacheBlockState:
|
| 204 |
+
r"""
|
| 205 |
+
State for [FasterCache](https://huggingface.co/papers/2410.19355). Every underlying block that FasterCache is
|
| 206 |
+
applied to will have an instance of this state.
|
| 207 |
+
"""
|
| 208 |
+
|
| 209 |
+
def __init__(self) -> None:
|
| 210 |
+
self.iteration: int = 0
|
| 211 |
+
self.batch_size: int = None
|
| 212 |
+
self.cache: Tuple[torch.Tensor, torch.Tensor] = None
|
| 213 |
+
|
| 214 |
+
def reset(self):
|
| 215 |
+
self.iteration = 0
|
| 216 |
+
self.batch_size = None
|
| 217 |
+
self.cache = None
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class FasterCacheDenoiserHook(ModelHook):
|
| 221 |
+
_is_stateful = True
|
| 222 |
+
|
| 223 |
+
def __init__(
|
| 224 |
+
self,
|
| 225 |
+
unconditional_batch_skip_range: int,
|
| 226 |
+
unconditional_batch_timestep_skip_range: Tuple[int, int],
|
| 227 |
+
tensor_format: str,
|
| 228 |
+
is_guidance_distilled: bool,
|
| 229 |
+
uncond_cond_input_kwargs_identifiers: List[str],
|
| 230 |
+
current_timestep_callback: Callable[[], int],
|
| 231 |
+
low_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor],
|
| 232 |
+
high_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor],
|
| 233 |
+
) -> None:
|
| 234 |
+
super().__init__()
|
| 235 |
+
|
| 236 |
+
self.unconditional_batch_skip_range = unconditional_batch_skip_range
|
| 237 |
+
self.unconditional_batch_timestep_skip_range = unconditional_batch_timestep_skip_range
|
| 238 |
+
# We can't easily detect what args are to be split in unconditional and conditional branches. We
|
| 239 |
+
# can only do it for kwargs, hence they are the only ones we split. The args are passed as-is.
|
| 240 |
+
# If a model is to be made compatible with FasterCache, the user must ensure that the inputs that
|
| 241 |
+
# contain batchwise-concatenated unconditional and conditional inputs are passed as kwargs.
|
| 242 |
+
self.uncond_cond_input_kwargs_identifiers = uncond_cond_input_kwargs_identifiers
|
| 243 |
+
self.tensor_format = tensor_format
|
| 244 |
+
self.is_guidance_distilled = is_guidance_distilled
|
| 245 |
+
|
| 246 |
+
self.current_timestep_callback = current_timestep_callback
|
| 247 |
+
self.low_frequency_weight_callback = low_frequency_weight_callback
|
| 248 |
+
self.high_frequency_weight_callback = high_frequency_weight_callback
|
| 249 |
+
|
| 250 |
+
def initialize_hook(self, module):
|
| 251 |
+
self.state = FasterCacheDenoiserState()
|
| 252 |
+
return module
|
| 253 |
+
|
| 254 |
+
@staticmethod
|
| 255 |
+
def _get_cond_input(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 256 |
+
# Note: this method assumes that the input tensor is batchwise-concatenated with unconditional inputs
|
| 257 |
+
# followed by conditional inputs.
|
| 258 |
+
_, cond = input.chunk(2, dim=0)
|
| 259 |
+
return cond
|
| 260 |
+
|
| 261 |
+
def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
|
| 262 |
+
# Split the unconditional and conditional inputs. We only want to infer the conditional branch if the
|
| 263 |
+
# requirements for skipping the unconditional branch are met as described in the paper.
|
| 264 |
+
# We skip the unconditional branch only if the following conditions are met:
|
| 265 |
+
# 1. We have completed at least one iteration of the denoiser
|
| 266 |
+
# 2. The current timestep is within the range specified by the user. This is the optimal timestep range
|
| 267 |
+
# where approximating the unconditional branch from the computation of the conditional branch is possible
|
| 268 |
+
# without a significant loss in quality.
|
| 269 |
+
# 3. The current iteration is not a multiple of the unconditional batch skip range. This is done so that
|
| 270 |
+
# we compute the unconditional branch at least once every few iterations to ensure minimal quality loss.
|
| 271 |
+
is_within_timestep_range = (
|
| 272 |
+
self.unconditional_batch_timestep_skip_range[0]
|
| 273 |
+
< self.current_timestep_callback()
|
| 274 |
+
< self.unconditional_batch_timestep_skip_range[1]
|
| 275 |
+
)
|
| 276 |
+
should_skip_uncond = (
|
| 277 |
+
self.state.iteration > 0
|
| 278 |
+
and is_within_timestep_range
|
| 279 |
+
and self.state.iteration % self.unconditional_batch_skip_range != 0
|
| 280 |
+
and not self.is_guidance_distilled
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
if should_skip_uncond:
|
| 284 |
+
is_any_kwarg_uncond = any(k in self.uncond_cond_input_kwargs_identifiers for k in kwargs.keys())
|
| 285 |
+
if is_any_kwarg_uncond:
|
| 286 |
+
logger.debug("FasterCache - Skipping unconditional branch computation")
|
| 287 |
+
args = tuple([self._get_cond_input(arg) if torch.is_tensor(arg) else arg for arg in args])
|
| 288 |
+
kwargs = {
|
| 289 |
+
k: v if k not in self.uncond_cond_input_kwargs_identifiers else self._get_cond_input(v)
|
| 290 |
+
for k, v in kwargs.items()
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
output = self.fn_ref.original_forward(*args, **kwargs)
|
| 294 |
+
|
| 295 |
+
if self.is_guidance_distilled:
|
| 296 |
+
self.state.iteration += 1
|
| 297 |
+
return output
|
| 298 |
+
|
| 299 |
+
if torch.is_tensor(output):
|
| 300 |
+
hidden_states = output
|
| 301 |
+
elif isinstance(output, (tuple, Transformer2DModelOutput)):
|
| 302 |
+
hidden_states = output[0]
|
| 303 |
+
|
| 304 |
+
batch_size = hidden_states.size(0)
|
| 305 |
+
|
| 306 |
+
if should_skip_uncond:
|
| 307 |
+
self.state.low_frequency_delta = self.state.low_frequency_delta * self.low_frequency_weight_callback(
|
| 308 |
+
module
|
| 309 |
+
)
|
| 310 |
+
self.state.high_frequency_delta = self.state.high_frequency_delta * self.high_frequency_weight_callback(
|
| 311 |
+
module
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
if self.tensor_format == "BCFHW":
|
| 315 |
+
hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
|
| 316 |
+
if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
|
| 317 |
+
hidden_states = hidden_states.flatten(0, 1)
|
| 318 |
+
|
| 319 |
+
low_freq_cond, high_freq_cond = _split_low_high_freq(hidden_states.float())
|
| 320 |
+
|
| 321 |
+
# Approximate/compute the unconditional branch outputs as described in Equation 9 and 10 of the paper
|
| 322 |
+
low_freq_uncond = self.state.low_frequency_delta + low_freq_cond
|
| 323 |
+
high_freq_uncond = self.state.high_frequency_delta + high_freq_cond
|
| 324 |
+
uncond_freq = low_freq_uncond + high_freq_uncond
|
| 325 |
+
|
| 326 |
+
uncond_states = torch.fft.ifftshift(uncond_freq)
|
| 327 |
+
uncond_states = torch.fft.ifft2(uncond_states).real
|
| 328 |
+
|
| 329 |
+
if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
|
| 330 |
+
uncond_states = uncond_states.unflatten(0, (batch_size, -1))
|
| 331 |
+
hidden_states = hidden_states.unflatten(0, (batch_size, -1))
|
| 332 |
+
if self.tensor_format == "BCFHW":
|
| 333 |
+
uncond_states = uncond_states.permute(0, 2, 1, 3, 4)
|
| 334 |
+
hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
|
| 335 |
+
|
| 336 |
+
# Concatenate the approximated unconditional and predicted conditional branches
|
| 337 |
+
uncond_states = uncond_states.to(hidden_states.dtype)
|
| 338 |
+
hidden_states = torch.cat([uncond_states, hidden_states], dim=0)
|
| 339 |
+
else:
|
| 340 |
+
uncond_states, cond_states = hidden_states.chunk(2, dim=0)
|
| 341 |
+
if self.tensor_format == "BCFHW":
|
| 342 |
+
uncond_states = uncond_states.permute(0, 2, 1, 3, 4)
|
| 343 |
+
cond_states = cond_states.permute(0, 2, 1, 3, 4)
|
| 344 |
+
if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
|
| 345 |
+
uncond_states = uncond_states.flatten(0, 1)
|
| 346 |
+
cond_states = cond_states.flatten(0, 1)
|
| 347 |
+
|
| 348 |
+
low_freq_uncond, high_freq_uncond = _split_low_high_freq(uncond_states.float())
|
| 349 |
+
low_freq_cond, high_freq_cond = _split_low_high_freq(cond_states.float())
|
| 350 |
+
self.state.low_frequency_delta = low_freq_uncond - low_freq_cond
|
| 351 |
+
self.state.high_frequency_delta = high_freq_uncond - high_freq_cond
|
| 352 |
+
|
| 353 |
+
self.state.iteration += 1
|
| 354 |
+
if torch.is_tensor(output):
|
| 355 |
+
output = hidden_states
|
| 356 |
+
elif isinstance(output, tuple):
|
| 357 |
+
output = (hidden_states, *output[1:])
|
| 358 |
+
else:
|
| 359 |
+
output.sample = hidden_states
|
| 360 |
+
|
| 361 |
+
return output
|
| 362 |
+
|
| 363 |
+
def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
|
| 364 |
+
self.state.reset()
|
| 365 |
+
return module
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
class FasterCacheBlockHook(ModelHook):
|
| 369 |
+
_is_stateful = True
|
| 370 |
+
|
| 371 |
+
def __init__(
|
| 372 |
+
self,
|
| 373 |
+
block_skip_range: int,
|
| 374 |
+
timestep_skip_range: Tuple[int, int],
|
| 375 |
+
is_guidance_distilled: bool,
|
| 376 |
+
weight_callback: Callable[[torch.nn.Module], float],
|
| 377 |
+
current_timestep_callback: Callable[[], int],
|
| 378 |
+
) -> None:
|
| 379 |
+
super().__init__()
|
| 380 |
+
|
| 381 |
+
self.block_skip_range = block_skip_range
|
| 382 |
+
self.timestep_skip_range = timestep_skip_range
|
| 383 |
+
self.is_guidance_distilled = is_guidance_distilled
|
| 384 |
+
|
| 385 |
+
self.weight_callback = weight_callback
|
| 386 |
+
self.current_timestep_callback = current_timestep_callback
|
| 387 |
+
|
| 388 |
+
def initialize_hook(self, module):
|
| 389 |
+
self.state = FasterCacheBlockState()
|
| 390 |
+
return module
|
| 391 |
+
|
| 392 |
+
def _compute_approximated_attention_output(
|
| 393 |
+
self, t_2_output: torch.Tensor, t_output: torch.Tensor, weight: float, batch_size: int
|
| 394 |
+
) -> torch.Tensor:
|
| 395 |
+
if t_2_output.size(0) != batch_size:
|
| 396 |
+
# The cache t_2_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just
|
| 397 |
+
# take the conditional branch outputs.
|
| 398 |
+
assert t_2_output.size(0) == 2 * batch_size
|
| 399 |
+
t_2_output = t_2_output[batch_size:]
|
| 400 |
+
if t_output.size(0) != batch_size:
|
| 401 |
+
# The cache t_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just
|
| 402 |
+
# take the conditional branch outputs.
|
| 403 |
+
assert t_output.size(0) == 2 * batch_size
|
| 404 |
+
t_output = t_output[batch_size:]
|
| 405 |
+
return t_output + (t_output - t_2_output) * weight
|
| 406 |
+
|
| 407 |
+
def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
|
| 408 |
+
batch_size = [
|
| 409 |
+
*[arg.size(0) for arg in args if torch.is_tensor(arg)],
|
| 410 |
+
*[v.size(0) for v in kwargs.values() if torch.is_tensor(v)],
|
| 411 |
+
][0]
|
| 412 |
+
if self.state.batch_size is None:
|
| 413 |
+
# Will be updated on first forward pass through the denoiser
|
| 414 |
+
self.state.batch_size = batch_size
|
| 415 |
+
|
| 416 |
+
# If we have to skip due to the skip conditions, then let's skip as expected.
|
| 417 |
+
# But, we can't skip if the denoiser wants to infer both unconditional and conditional branches. This
|
| 418 |
+
# is because the expected output shapes of attention layer will not match if we only return values from
|
| 419 |
+
# the cache (which only caches conditional branch outputs). So, if state.batch_size (which is the true
|
| 420 |
+
# unconditional-conditional batch size) is same as the current batch size, we don't perform the layer
|
| 421 |
+
# skip. Otherwise, we conditionally skip the layer based on what state.skip_callback returns.
|
| 422 |
+
is_within_timestep_range = (
|
| 423 |
+
self.timestep_skip_range[0] < self.current_timestep_callback() < self.timestep_skip_range[1]
|
| 424 |
+
)
|
| 425 |
+
if not is_within_timestep_range:
|
| 426 |
+
should_skip_attention = False
|
| 427 |
+
else:
|
| 428 |
+
should_compute_attention = self.state.iteration > 0 and self.state.iteration % self.block_skip_range == 0
|
| 429 |
+
should_skip_attention = not should_compute_attention
|
| 430 |
+
if should_skip_attention:
|
| 431 |
+
should_skip_attention = self.is_guidance_distilled or self.state.batch_size != batch_size
|
| 432 |
+
|
| 433 |
+
if should_skip_attention:
|
| 434 |
+
logger.debug("FasterCache - Skipping attention and using approximation")
|
| 435 |
+
if torch.is_tensor(self.state.cache[-1]):
|
| 436 |
+
t_2_output, t_output = self.state.cache
|
| 437 |
+
weight = self.weight_callback(module)
|
| 438 |
+
output = self._compute_approximated_attention_output(t_2_output, t_output, weight, batch_size)
|
| 439 |
+
else:
|
| 440 |
+
# The cache contains multiple tensors from past N iterations (N=2 for FasterCache). We need to handle all of them.
|
| 441 |
+
# Diffusers blocks can return multiple tensors - let's call them [A, B, C, ...] for simplicity.
|
| 442 |
+
# In our cache, we would have [[A_1, B_1, C_1, ...], [A_2, B_2, C_2, ...], ...] where each list is the output from
|
| 443 |
+
# a forward pass of the block. We need to compute the approximated output for each of these tensors.
|
| 444 |
+
# The zip(*state.cache) operation will give us [(A_1, A_2, ...), (B_1, B_2, ...), (C_1, C_2, ...), ...] which
|
| 445 |
+
# allows us to compute the approximated attention output for each tensor in the cache.
|
| 446 |
+
output = ()
|
| 447 |
+
for t_2_output, t_output in zip(*self.state.cache):
|
| 448 |
+
result = self._compute_approximated_attention_output(
|
| 449 |
+
t_2_output, t_output, self.weight_callback(module), batch_size
|
| 450 |
+
)
|
| 451 |
+
output += (result,)
|
| 452 |
+
else:
|
| 453 |
+
logger.debug("FasterCache - Computing attention")
|
| 454 |
+
output = self.fn_ref.original_forward(*args, **kwargs)
|
| 455 |
+
|
| 456 |
+
# Note that the following condition for getting hidden_states should suffice since Diffusers blocks either return
|
| 457 |
+
# a single hidden_states tensor, or a tuple of (hidden_states, encoder_hidden_states) tensors. We need to handle
|
| 458 |
+
# both cases.
|
| 459 |
+
if torch.is_tensor(output):
|
| 460 |
+
cache_output = output
|
| 461 |
+
if not self.is_guidance_distilled and cache_output.size(0) == self.state.batch_size:
|
| 462 |
+
# The output here can be both unconditional-conditional branch outputs or just conditional branch outputs.
|
| 463 |
+
# This is determined at the higher-level denoiser module. We only want to cache the conditional branch outputs.
|
| 464 |
+
cache_output = cache_output.chunk(2, dim=0)[1]
|
| 465 |
+
else:
|
| 466 |
+
# Cache all return values and perform the same operation as above
|
| 467 |
+
cache_output = ()
|
| 468 |
+
for out in output:
|
| 469 |
+
if not self.is_guidance_distilled and out.size(0) == self.state.batch_size:
|
| 470 |
+
out = out.chunk(2, dim=0)[1]
|
| 471 |
+
cache_output += (out,)
|
| 472 |
+
|
| 473 |
+
if self.state.cache is None:
|
| 474 |
+
self.state.cache = [cache_output, cache_output]
|
| 475 |
+
else:
|
| 476 |
+
self.state.cache = [self.state.cache[-1], cache_output]
|
| 477 |
+
|
| 478 |
+
self.state.iteration += 1
|
| 479 |
+
return output
|
| 480 |
+
|
| 481 |
+
def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
|
| 482 |
+
self.state.reset()
|
| 483 |
+
return module
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
def apply_faster_cache(module: torch.nn.Module, config: FasterCacheConfig) -> None:
|
| 487 |
+
r"""
|
| 488 |
+
Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline.
|
| 489 |
+
|
| 490 |
+
Args:
|
| 491 |
+
module (`torch.nn.Module`):
|
| 492 |
+
The pytorch module to apply FasterCache to. Typically, this should be a transformer architecture supported
|
| 493 |
+
in Diffusers, such as `CogVideoXTransformer3DModel`, but external implementations may also work.
|
| 494 |
+
config (`FasterCacheConfig`):
|
| 495 |
+
The configuration to use for FasterCache.
|
| 496 |
+
|
| 497 |
+
Example:
|
| 498 |
+
```python
|
| 499 |
+
>>> import torch
|
| 500 |
+
>>> from diffusers import CogVideoXPipeline, FasterCacheConfig, apply_faster_cache
|
| 501 |
+
|
| 502 |
+
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
|
| 503 |
+
>>> pipe.to("cuda")
|
| 504 |
+
|
| 505 |
+
>>> config = FasterCacheConfig(
|
| 506 |
+
... spatial_attention_block_skip_range=2,
|
| 507 |
+
... spatial_attention_timestep_skip_range=(-1, 681),
|
| 508 |
+
... low_frequency_weight_update_timestep_range=(99, 641),
|
| 509 |
+
... high_frequency_weight_update_timestep_range=(-1, 301),
|
| 510 |
+
... spatial_attention_block_identifiers=["transformer_blocks"],
|
| 511 |
+
... attention_weight_callback=lambda _: 0.3,
|
| 512 |
+
... tensor_format="BFCHW",
|
| 513 |
+
... )
|
| 514 |
+
>>> apply_faster_cache(pipe.transformer, config)
|
| 515 |
+
```
|
| 516 |
+
"""
|
| 517 |
+
|
| 518 |
+
logger.warning(
|
| 519 |
+
"FasterCache is a purely experimental feature and may not work as expected. Not all models support FasterCache. "
|
| 520 |
+
"The API is subject to change in future releases, with no guarantee of backward compatibility. Please report any issues at "
|
| 521 |
+
"https://github.com/huggingface/diffusers/issues."
|
| 522 |
+
)
|
| 523 |
+
|
| 524 |
+
if config.attention_weight_callback is None:
|
| 525 |
+
# If the user has not provided a weight callback, we default to 0.5 for all timesteps.
|
| 526 |
+
# In the paper, they recommend using a gradually increasing weight from 0 to 1 as the inference progresses, but
|
| 527 |
+
# this depends from model-to-model. It is required by the user to provide a weight callback if they want to
|
| 528 |
+
# use a different weight function. Defaulting to 0.5 works well in practice for most cases.
|
| 529 |
+
logger.warning(
|
| 530 |
+
"No `attention_weight_callback` provided when enabling FasterCache. Defaulting to using a weight of 0.5 for all timesteps."
|
| 531 |
+
)
|
| 532 |
+
config.attention_weight_callback = lambda _: 0.5
|
| 533 |
+
|
| 534 |
+
if config.low_frequency_weight_callback is None:
|
| 535 |
+
logger.debug(
|
| 536 |
+
"Low frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper."
|
| 537 |
+
)
|
| 538 |
+
|
| 539 |
+
def low_frequency_weight_callback(module: torch.nn.Module) -> float:
|
| 540 |
+
is_within_range = (
|
| 541 |
+
config.low_frequency_weight_update_timestep_range[0]
|
| 542 |
+
< config.current_timestep_callback()
|
| 543 |
+
< config.low_frequency_weight_update_timestep_range[1]
|
| 544 |
+
)
|
| 545 |
+
return config.alpha_low_frequency if is_within_range else 1.0
|
| 546 |
+
|
| 547 |
+
config.low_frequency_weight_callback = low_frequency_weight_callback
|
| 548 |
+
|
| 549 |
+
if config.high_frequency_weight_callback is None:
|
| 550 |
+
logger.debug(
|
| 551 |
+
"High frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper."
|
| 552 |
+
)
|
| 553 |
+
|
| 554 |
+
def high_frequency_weight_callback(module: torch.nn.Module) -> float:
|
| 555 |
+
is_within_range = (
|
| 556 |
+
config.high_frequency_weight_update_timestep_range[0]
|
| 557 |
+
< config.current_timestep_callback()
|
| 558 |
+
< config.high_frequency_weight_update_timestep_range[1]
|
| 559 |
+
)
|
| 560 |
+
return config.alpha_high_frequency if is_within_range else 1.0
|
| 561 |
+
|
| 562 |
+
config.high_frequency_weight_callback = high_frequency_weight_callback
|
| 563 |
+
|
| 564 |
+
supported_tensor_formats = ["BCFHW", "BFCHW", "BCHW"] # TODO(aryan): Support BSC for LTX Video
|
| 565 |
+
if config.tensor_format not in supported_tensor_formats:
|
| 566 |
+
raise ValueError(f"`tensor_format` must be one of {supported_tensor_formats}, but got {config.tensor_format}.")
|
| 567 |
+
|
| 568 |
+
_apply_faster_cache_on_denoiser(module, config)
|
| 569 |
+
|
| 570 |
+
for name, submodule in module.named_modules():
|
| 571 |
+
if not isinstance(submodule, _ATTENTION_CLASSES):
|
| 572 |
+
continue
|
| 573 |
+
if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS):
|
| 574 |
+
_apply_faster_cache_on_attention_class(name, submodule, config)
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
def _apply_faster_cache_on_denoiser(module: torch.nn.Module, config: FasterCacheConfig) -> None:
|
| 578 |
+
hook = FasterCacheDenoiserHook(
|
| 579 |
+
config.unconditional_batch_skip_range,
|
| 580 |
+
config.unconditional_batch_timestep_skip_range,
|
| 581 |
+
config.tensor_format,
|
| 582 |
+
config.is_guidance_distilled,
|
| 583 |
+
config._unconditional_conditional_input_kwargs_identifiers,
|
| 584 |
+
config.current_timestep_callback,
|
| 585 |
+
config.low_frequency_weight_callback,
|
| 586 |
+
config.high_frequency_weight_callback,
|
| 587 |
+
)
|
| 588 |
+
registry = HookRegistry.check_if_exists_or_initialize(module)
|
| 589 |
+
registry.register_hook(hook, _FASTER_CACHE_DENOISER_HOOK)
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
def _apply_faster_cache_on_attention_class(name: str, module: AttentionModuleMixin, config: FasterCacheConfig) -> None:
|
| 593 |
+
is_spatial_self_attention = (
|
| 594 |
+
any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
|
| 595 |
+
and config.spatial_attention_block_skip_range is not None
|
| 596 |
+
and not getattr(module, "is_cross_attention", False)
|
| 597 |
+
)
|
| 598 |
+
is_temporal_self_attention = (
|
| 599 |
+
any(re.search(identifier, name) is not None for identifier in config.temporal_attention_block_identifiers)
|
| 600 |
+
and config.temporal_attention_block_skip_range is not None
|
| 601 |
+
and not module.is_cross_attention
|
| 602 |
+
)
|
| 603 |
+
|
| 604 |
+
block_skip_range, timestep_skip_range, block_type = None, None, None
|
| 605 |
+
if is_spatial_self_attention:
|
| 606 |
+
block_skip_range = config.spatial_attention_block_skip_range
|
| 607 |
+
timestep_skip_range = config.spatial_attention_timestep_skip_range
|
| 608 |
+
block_type = "spatial"
|
| 609 |
+
elif is_temporal_self_attention:
|
| 610 |
+
block_skip_range = config.temporal_attention_block_skip_range
|
| 611 |
+
timestep_skip_range = config.temporal_attention_timestep_skip_range
|
| 612 |
+
block_type = "temporal"
|
| 613 |
+
|
| 614 |
+
if block_skip_range is None or timestep_skip_range is None:
|
| 615 |
+
logger.debug(
|
| 616 |
+
f'Unable to apply FasterCache to the selected layer: "{name}" because it does '
|
| 617 |
+
f"not match any of the required criteria for spatial or temporal attention layers. Note, "
|
| 618 |
+
f"however, that this layer may still be valid for applying PAB. Please specify the correct "
|
| 619 |
+
f"block identifiers in the configuration or use the specialized `apply_faster_cache_on_module` "
|
| 620 |
+
f"function to apply FasterCache to this layer."
|
| 621 |
+
)
|
| 622 |
+
return
|
| 623 |
+
|
| 624 |
+
logger.debug(f"Enabling FasterCache ({block_type}) for layer: {name}")
|
| 625 |
+
hook = FasterCacheBlockHook(
|
| 626 |
+
block_skip_range,
|
| 627 |
+
timestep_skip_range,
|
| 628 |
+
config.is_guidance_distilled,
|
| 629 |
+
config.attention_weight_callback,
|
| 630 |
+
config.current_timestep_callback,
|
| 631 |
+
)
|
| 632 |
+
registry = HookRegistry.check_if_exists_or_initialize(module)
|
| 633 |
+
registry.register_hook(hook, _FASTER_CACHE_BLOCK_HOOK)
|
| 634 |
+
|
| 635 |
+
|
| 636 |
+
# Reference: https://github.com/Vchitect/FasterCache/blob/fab32c15014636dc854948319c0a9a8d92c7acb4/scripts/latte/faster_cache_sample_latte.py#L127C1-L143C39
|
| 637 |
+
@torch.no_grad()
|
| 638 |
+
def _split_low_high_freq(x):
|
| 639 |
+
fft = torch.fft.fft2(x)
|
| 640 |
+
fft_shifted = torch.fft.fftshift(fft)
|
| 641 |
+
height, width = x.shape[-2:]
|
| 642 |
+
radius = min(height, width) // 5
|
| 643 |
+
|
| 644 |
+
y_grid, x_grid = torch.meshgrid(torch.arange(height), torch.arange(width))
|
| 645 |
+
center_x, center_y = width // 2, height // 2
|
| 646 |
+
mask = (x_grid - center_x) ** 2 + (y_grid - center_y) ** 2 <= radius**2
|
| 647 |
+
|
| 648 |
+
low_freq_mask = mask.unsqueeze(0).unsqueeze(0).to(x.device)
|
| 649 |
+
high_freq_mask = ~low_freq_mask
|
| 650 |
+
|
| 651 |
+
low_freq_fft = fft_shifted * low_freq_mask
|
| 652 |
+
high_freq_fft = fft_shifted * high_freq_mask
|
| 653 |
+
|
| 654 |
+
return low_freq_fft, high_freq_fft
|
pythonProject/.venv/Lib/site-packages/diffusers/hooks/first_block_cache.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from dataclasses import dataclass
|
| 16 |
+
from typing import Tuple, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from ..utils import get_logger
|
| 21 |
+
from ..utils.torch_utils import unwrap_module
|
| 22 |
+
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
|
| 23 |
+
from ._helpers import TransformerBlockRegistry
|
| 24 |
+
from .hooks import BaseState, HookRegistry, ModelHook, StateManager
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
logger = get_logger(__name__) # pylint: disable=invalid-name
|
| 28 |
+
|
| 29 |
+
_FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook"
|
| 30 |
+
_FBC_BLOCK_HOOK = "fbc_block_hook"
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@dataclass
|
| 34 |
+
class FirstBlockCacheConfig:
|
| 35 |
+
r"""
|
| 36 |
+
Configuration for [First Block
|
| 37 |
+
Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching).
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
threshold (`float`, defaults to `0.05`):
|
| 41 |
+
The threshold to determine whether or not a forward pass through all layers of the model is required. A
|
| 42 |
+
higher threshold usually results in a forward pass through a lower number of layers and faster inference,
|
| 43 |
+
but might lead to poorer generation quality. A lower threshold may not result in significant generation
|
| 44 |
+
speedup. The threshold is compared against the absmean difference of the residuals between the current and
|
| 45 |
+
cached outputs from the first transformer block. If the difference is below the threshold, the forward pass
|
| 46 |
+
is skipped.
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
threshold: float = 0.05
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class FBCSharedBlockState(BaseState):
|
| 53 |
+
def __init__(self) -> None:
|
| 54 |
+
super().__init__()
|
| 55 |
+
|
| 56 |
+
self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
|
| 57 |
+
self.head_block_residual: torch.Tensor = None
|
| 58 |
+
self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
|
| 59 |
+
self.should_compute: bool = True
|
| 60 |
+
|
| 61 |
+
def reset(self):
|
| 62 |
+
self.tail_block_residuals = None
|
| 63 |
+
self.should_compute = True
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class FBCHeadBlockHook(ModelHook):
|
| 67 |
+
_is_stateful = True
|
| 68 |
+
|
| 69 |
+
def __init__(self, state_manager: StateManager, threshold: float):
|
| 70 |
+
self.state_manager = state_manager
|
| 71 |
+
self.threshold = threshold
|
| 72 |
+
self._metadata = None
|
| 73 |
+
|
| 74 |
+
def initialize_hook(self, module):
|
| 75 |
+
unwrapped_module = unwrap_module(module)
|
| 76 |
+
self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
|
| 77 |
+
return module
|
| 78 |
+
|
| 79 |
+
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
| 80 |
+
original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
|
| 81 |
+
|
| 82 |
+
output = self.fn_ref.original_forward(*args, **kwargs)
|
| 83 |
+
is_output_tuple = isinstance(output, tuple)
|
| 84 |
+
|
| 85 |
+
if is_output_tuple:
|
| 86 |
+
hidden_states_residual = output[self._metadata.return_hidden_states_index] - original_hidden_states
|
| 87 |
+
else:
|
| 88 |
+
hidden_states_residual = output - original_hidden_states
|
| 89 |
+
|
| 90 |
+
shared_state: FBCSharedBlockState = self.state_manager.get_state()
|
| 91 |
+
hidden_states = encoder_hidden_states = None
|
| 92 |
+
should_compute = self._should_compute_remaining_blocks(hidden_states_residual)
|
| 93 |
+
shared_state.should_compute = should_compute
|
| 94 |
+
|
| 95 |
+
if not should_compute:
|
| 96 |
+
# Apply caching
|
| 97 |
+
if is_output_tuple:
|
| 98 |
+
hidden_states = (
|
| 99 |
+
shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index]
|
| 100 |
+
)
|
| 101 |
+
else:
|
| 102 |
+
hidden_states = shared_state.tail_block_residuals[0] + output
|
| 103 |
+
|
| 104 |
+
if self._metadata.return_encoder_hidden_states_index is not None:
|
| 105 |
+
assert is_output_tuple
|
| 106 |
+
encoder_hidden_states = (
|
| 107 |
+
shared_state.tail_block_residuals[1] + output[self._metadata.return_encoder_hidden_states_index]
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
if is_output_tuple:
|
| 111 |
+
return_output = [None] * len(output)
|
| 112 |
+
return_output[self._metadata.return_hidden_states_index] = hidden_states
|
| 113 |
+
return_output[self._metadata.return_encoder_hidden_states_index] = encoder_hidden_states
|
| 114 |
+
return_output = tuple(return_output)
|
| 115 |
+
else:
|
| 116 |
+
return_output = hidden_states
|
| 117 |
+
output = return_output
|
| 118 |
+
else:
|
| 119 |
+
if is_output_tuple:
|
| 120 |
+
head_block_output = [None] * len(output)
|
| 121 |
+
head_block_output[0] = output[self._metadata.return_hidden_states_index]
|
| 122 |
+
head_block_output[1] = output[self._metadata.return_encoder_hidden_states_index]
|
| 123 |
+
else:
|
| 124 |
+
head_block_output = output
|
| 125 |
+
shared_state.head_block_output = head_block_output
|
| 126 |
+
shared_state.head_block_residual = hidden_states_residual
|
| 127 |
+
|
| 128 |
+
return output
|
| 129 |
+
|
| 130 |
+
def reset_state(self, module):
|
| 131 |
+
self.state_manager.reset()
|
| 132 |
+
return module
|
| 133 |
+
|
| 134 |
+
@torch.compiler.disable
|
| 135 |
+
def _should_compute_remaining_blocks(self, hidden_states_residual: torch.Tensor) -> bool:
|
| 136 |
+
shared_state = self.state_manager.get_state()
|
| 137 |
+
if shared_state.head_block_residual is None:
|
| 138 |
+
return True
|
| 139 |
+
prev_hidden_states_residual = shared_state.head_block_residual
|
| 140 |
+
absmean = (hidden_states_residual - prev_hidden_states_residual).abs().mean()
|
| 141 |
+
prev_hidden_states_absmean = prev_hidden_states_residual.abs().mean()
|
| 142 |
+
diff = (absmean / prev_hidden_states_absmean).item()
|
| 143 |
+
return diff > self.threshold
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class FBCBlockHook(ModelHook):
|
| 147 |
+
def __init__(self, state_manager: StateManager, is_tail: bool = False):
|
| 148 |
+
super().__init__()
|
| 149 |
+
self.state_manager = state_manager
|
| 150 |
+
self.is_tail = is_tail
|
| 151 |
+
self._metadata = None
|
| 152 |
+
|
| 153 |
+
def initialize_hook(self, module):
|
| 154 |
+
unwrapped_module = unwrap_module(module)
|
| 155 |
+
self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
|
| 156 |
+
return module
|
| 157 |
+
|
| 158 |
+
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
| 159 |
+
original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
|
| 160 |
+
original_encoder_hidden_states = None
|
| 161 |
+
if self._metadata.return_encoder_hidden_states_index is not None:
|
| 162 |
+
original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
|
| 163 |
+
"encoder_hidden_states", args, kwargs
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
shared_state = self.state_manager.get_state()
|
| 167 |
+
|
| 168 |
+
if shared_state.should_compute:
|
| 169 |
+
output = self.fn_ref.original_forward(*args, **kwargs)
|
| 170 |
+
if self.is_tail:
|
| 171 |
+
hidden_states_residual = encoder_hidden_states_residual = None
|
| 172 |
+
if isinstance(output, tuple):
|
| 173 |
+
hidden_states_residual = (
|
| 174 |
+
output[self._metadata.return_hidden_states_index] - shared_state.head_block_output[0]
|
| 175 |
+
)
|
| 176 |
+
encoder_hidden_states_residual = (
|
| 177 |
+
output[self._metadata.return_encoder_hidden_states_index] - shared_state.head_block_output[1]
|
| 178 |
+
)
|
| 179 |
+
else:
|
| 180 |
+
hidden_states_residual = output - shared_state.head_block_output
|
| 181 |
+
shared_state.tail_block_residuals = (hidden_states_residual, encoder_hidden_states_residual)
|
| 182 |
+
return output
|
| 183 |
+
|
| 184 |
+
if original_encoder_hidden_states is None:
|
| 185 |
+
return_output = original_hidden_states
|
| 186 |
+
else:
|
| 187 |
+
return_output = [None, None]
|
| 188 |
+
return_output[self._metadata.return_hidden_states_index] = original_hidden_states
|
| 189 |
+
return_output[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states
|
| 190 |
+
return_output = tuple(return_output)
|
| 191 |
+
return return_output
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
|
| 195 |
+
"""
|
| 196 |
+
Applies [First Block
|
| 197 |
+
Cache](https://github.com/chengzeyi/ParaAttention/blob/4de137c5b96416489f06e43e19f2c14a772e28fd/README.md#first-block-cache-our-dynamic-caching)
|
| 198 |
+
to a given module.
|
| 199 |
+
|
| 200 |
+
First Block Cache builds on the ideas of [TeaCache](https://huggingface.co/papers/2411.19108). It is much simpler
|
| 201 |
+
to implement generically for a wide range of models and has been integrated first for experimental purposes.
|
| 202 |
+
|
| 203 |
+
Args:
|
| 204 |
+
module (`torch.nn.Module`):
|
| 205 |
+
The pytorch module to apply FBCache to. Typically, this should be a transformer architecture supported in
|
| 206 |
+
Diffusers, such as `CogVideoXTransformer3DModel`, but external implementations may also work.
|
| 207 |
+
config (`FirstBlockCacheConfig`):
|
| 208 |
+
The configuration to use for applying the FBCache method.
|
| 209 |
+
|
| 210 |
+
Example:
|
| 211 |
+
```python
|
| 212 |
+
>>> import torch
|
| 213 |
+
>>> from diffusers import CogView4Pipeline
|
| 214 |
+
>>> from diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig
|
| 215 |
+
|
| 216 |
+
>>> pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
|
| 217 |
+
>>> pipe.to("cuda")
|
| 218 |
+
|
| 219 |
+
>>> apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.2))
|
| 220 |
+
|
| 221 |
+
>>> prompt = "A photo of an astronaut riding a horse on mars"
|
| 222 |
+
>>> image = pipe(prompt, generator=torch.Generator().manual_seed(42)).images[0]
|
| 223 |
+
>>> image.save("output.png")
|
| 224 |
+
```
|
| 225 |
+
"""
|
| 226 |
+
|
| 227 |
+
state_manager = StateManager(FBCSharedBlockState, (), {})
|
| 228 |
+
remaining_blocks = []
|
| 229 |
+
|
| 230 |
+
for name, submodule in module.named_children():
|
| 231 |
+
if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
|
| 232 |
+
continue
|
| 233 |
+
for index, block in enumerate(submodule):
|
| 234 |
+
remaining_blocks.append((f"{name}.{index}", block))
|
| 235 |
+
|
| 236 |
+
head_block_name, head_block = remaining_blocks.pop(0)
|
| 237 |
+
tail_block_name, tail_block = remaining_blocks.pop(-1)
|
| 238 |
+
|
| 239 |
+
logger.debug(f"Applying FBCHeadBlockHook to '{head_block_name}'")
|
| 240 |
+
_apply_fbc_head_block_hook(head_block, state_manager, config.threshold)
|
| 241 |
+
|
| 242 |
+
for name, block in remaining_blocks:
|
| 243 |
+
logger.debug(f"Applying FBCBlockHook to '{name}'")
|
| 244 |
+
_apply_fbc_block_hook(block, state_manager)
|
| 245 |
+
|
| 246 |
+
logger.debug(f"Applying FBCBlockHook to tail block '{tail_block_name}'")
|
| 247 |
+
_apply_fbc_block_hook(tail_block, state_manager, is_tail=True)
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def _apply_fbc_head_block_hook(block: torch.nn.Module, state_manager: StateManager, threshold: float) -> None:
|
| 251 |
+
registry = HookRegistry.check_if_exists_or_initialize(block)
|
| 252 |
+
hook = FBCHeadBlockHook(state_manager, threshold)
|
| 253 |
+
registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def _apply_fbc_block_hook(block: torch.nn.Module, state_manager: StateManager, is_tail: bool = False) -> None:
|
| 257 |
+
registry = HookRegistry.check_if_exists_or_initialize(block)
|
| 258 |
+
hook = FBCBlockHook(state_manager, is_tail)
|
| 259 |
+
registry.register_hook(hook, _FBC_BLOCK_HOOK)
|
pythonProject/.venv/Lib/site-packages/diffusers/hooks/group_offloading.py
ADDED
|
@@ -0,0 +1,898 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import hashlib
|
| 16 |
+
import os
|
| 17 |
+
from contextlib import contextmanager, nullcontext
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from enum import Enum
|
| 20 |
+
from typing import Dict, List, Optional, Set, Tuple, Union
|
| 21 |
+
|
| 22 |
+
import safetensors.torch
|
| 23 |
+
import torch
|
| 24 |
+
|
| 25 |
+
from ..utils import get_logger, is_accelerate_available
|
| 26 |
+
from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
|
| 27 |
+
from .hooks import HookRegistry, ModelHook
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
if is_accelerate_available():
|
| 31 |
+
from accelerate.hooks import AlignDevicesHook, CpuOffload
|
| 32 |
+
from accelerate.utils import send_to_device
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
logger = get_logger(__name__) # pylint: disable=invalid-name
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# fmt: off
|
| 39 |
+
_GROUP_OFFLOADING = "group_offloading"
|
| 40 |
+
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
|
| 41 |
+
_LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
|
| 42 |
+
_GROUP_ID_LAZY_LEAF = "lazy_leafs"
|
| 43 |
+
# fmt: on
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class GroupOffloadingType(str, Enum):
|
| 47 |
+
BLOCK_LEVEL = "block_level"
|
| 48 |
+
LEAF_LEVEL = "leaf_level"
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@dataclass
|
| 52 |
+
class GroupOffloadingConfig:
|
| 53 |
+
onload_device: torch.device
|
| 54 |
+
offload_device: torch.device
|
| 55 |
+
offload_type: GroupOffloadingType
|
| 56 |
+
non_blocking: bool
|
| 57 |
+
record_stream: bool
|
| 58 |
+
low_cpu_mem_usage: bool
|
| 59 |
+
num_blocks_per_group: Optional[int] = None
|
| 60 |
+
offload_to_disk_path: Optional[str] = None
|
| 61 |
+
stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class ModuleGroup:
|
| 65 |
+
def __init__(
|
| 66 |
+
self,
|
| 67 |
+
modules: List[torch.nn.Module],
|
| 68 |
+
offload_device: torch.device,
|
| 69 |
+
onload_device: torch.device,
|
| 70 |
+
offload_leader: torch.nn.Module,
|
| 71 |
+
onload_leader: Optional[torch.nn.Module] = None,
|
| 72 |
+
parameters: Optional[List[torch.nn.Parameter]] = None,
|
| 73 |
+
buffers: Optional[List[torch.Tensor]] = None,
|
| 74 |
+
non_blocking: bool = False,
|
| 75 |
+
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
|
| 76 |
+
record_stream: Optional[bool] = False,
|
| 77 |
+
low_cpu_mem_usage: bool = False,
|
| 78 |
+
onload_self: bool = True,
|
| 79 |
+
offload_to_disk_path: Optional[str] = None,
|
| 80 |
+
group_id: Optional[int] = None,
|
| 81 |
+
) -> None:
|
| 82 |
+
self.modules = modules
|
| 83 |
+
self.offload_device = offload_device
|
| 84 |
+
self.onload_device = onload_device
|
| 85 |
+
self.offload_leader = offload_leader
|
| 86 |
+
self.onload_leader = onload_leader
|
| 87 |
+
self.parameters = parameters or []
|
| 88 |
+
self.buffers = buffers or []
|
| 89 |
+
self.non_blocking = non_blocking or stream is not None
|
| 90 |
+
self.stream = stream
|
| 91 |
+
self.record_stream = record_stream
|
| 92 |
+
self.onload_self = onload_self
|
| 93 |
+
self.low_cpu_mem_usage = low_cpu_mem_usage
|
| 94 |
+
|
| 95 |
+
self.offload_to_disk_path = offload_to_disk_path
|
| 96 |
+
self._is_offloaded_to_disk = False
|
| 97 |
+
|
| 98 |
+
if self.offload_to_disk_path is not None:
|
| 99 |
+
# Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
|
| 100 |
+
self.group_id = group_id if group_id is not None else str(id(self))
|
| 101 |
+
short_hash = _compute_group_hash(self.group_id)
|
| 102 |
+
self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{short_hash}.safetensors")
|
| 103 |
+
|
| 104 |
+
all_tensors = []
|
| 105 |
+
for module in self.modules:
|
| 106 |
+
all_tensors.extend(list(module.parameters()))
|
| 107 |
+
all_tensors.extend(list(module.buffers()))
|
| 108 |
+
all_tensors.extend(self.parameters)
|
| 109 |
+
all_tensors.extend(self.buffers)
|
| 110 |
+
all_tensors = list(dict.fromkeys(all_tensors)) # Remove duplicates
|
| 111 |
+
|
| 112 |
+
self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)}
|
| 113 |
+
self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()}
|
| 114 |
+
self.cpu_param_dict = {}
|
| 115 |
+
else:
|
| 116 |
+
self.cpu_param_dict = self._init_cpu_param_dict()
|
| 117 |
+
|
| 118 |
+
self._torch_accelerator_module = (
|
| 119 |
+
getattr(torch, torch.accelerator.current_accelerator().type)
|
| 120 |
+
if hasattr(torch, "accelerator")
|
| 121 |
+
else torch.cuda
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
def _init_cpu_param_dict(self):
|
| 125 |
+
cpu_param_dict = {}
|
| 126 |
+
if self.stream is None:
|
| 127 |
+
return cpu_param_dict
|
| 128 |
+
|
| 129 |
+
for module in self.modules:
|
| 130 |
+
for param in module.parameters():
|
| 131 |
+
cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
|
| 132 |
+
for buffer in module.buffers():
|
| 133 |
+
cpu_param_dict[buffer] = (
|
| 134 |
+
buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
for param in self.parameters:
|
| 138 |
+
cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
|
| 139 |
+
|
| 140 |
+
for buffer in self.buffers:
|
| 141 |
+
cpu_param_dict[buffer] = buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
|
| 142 |
+
|
| 143 |
+
return cpu_param_dict
|
| 144 |
+
|
| 145 |
+
@contextmanager
|
| 146 |
+
def _pinned_memory_tensors(self):
|
| 147 |
+
try:
|
| 148 |
+
pinned_dict = {
|
| 149 |
+
param: tensor.pin_memory() if not tensor.is_pinned() else tensor
|
| 150 |
+
for param, tensor in self.cpu_param_dict.items()
|
| 151 |
+
}
|
| 152 |
+
yield pinned_dict
|
| 153 |
+
finally:
|
| 154 |
+
pinned_dict = None
|
| 155 |
+
|
| 156 |
+
def _transfer_tensor_to_device(self, tensor, source_tensor):
|
| 157 |
+
tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
|
| 158 |
+
if self.record_stream:
|
| 159 |
+
tensor.data.record_stream(self._torch_accelerator_module.current_stream())
|
| 160 |
+
|
| 161 |
+
def _process_tensors_from_modules(self, pinned_memory=None):
|
| 162 |
+
for group_module in self.modules:
|
| 163 |
+
for param in group_module.parameters():
|
| 164 |
+
source = pinned_memory[param] if pinned_memory else param.data
|
| 165 |
+
self._transfer_tensor_to_device(param, source)
|
| 166 |
+
for buffer in group_module.buffers():
|
| 167 |
+
source = pinned_memory[buffer] if pinned_memory else buffer.data
|
| 168 |
+
self._transfer_tensor_to_device(buffer, source)
|
| 169 |
+
|
| 170 |
+
for param in self.parameters:
|
| 171 |
+
source = pinned_memory[param] if pinned_memory else param.data
|
| 172 |
+
self._transfer_tensor_to_device(param, source)
|
| 173 |
+
|
| 174 |
+
for buffer in self.buffers:
|
| 175 |
+
source = pinned_memory[buffer] if pinned_memory else buffer.data
|
| 176 |
+
self._transfer_tensor_to_device(buffer, source)
|
| 177 |
+
|
| 178 |
+
def _onload_from_disk(self):
|
| 179 |
+
if self.stream is not None:
|
| 180 |
+
# Wait for previous Host->Device transfer to complete
|
| 181 |
+
self.stream.synchronize()
|
| 182 |
+
|
| 183 |
+
context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
|
| 184 |
+
current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None
|
| 185 |
+
|
| 186 |
+
with context:
|
| 187 |
+
# Load to CPU (if using streams) or directly to target device, pin, and async copy to device
|
| 188 |
+
device = str(self.onload_device) if self.stream is None else "cpu"
|
| 189 |
+
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device)
|
| 190 |
+
|
| 191 |
+
if self.stream is not None:
|
| 192 |
+
for key, tensor_obj in self.key_to_tensor.items():
|
| 193 |
+
pinned_tensor = loaded_tensors[key].pin_memory()
|
| 194 |
+
tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
|
| 195 |
+
if self.record_stream:
|
| 196 |
+
tensor_obj.data.record_stream(current_stream)
|
| 197 |
+
else:
|
| 198 |
+
onload_device = (
|
| 199 |
+
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
|
| 200 |
+
)
|
| 201 |
+
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
|
| 202 |
+
for key, tensor_obj in self.key_to_tensor.items():
|
| 203 |
+
tensor_obj.data = loaded_tensors[key]
|
| 204 |
+
|
| 205 |
+
def _onload_from_memory(self):
|
| 206 |
+
if self.stream is not None:
|
| 207 |
+
# Wait for previous Host->Device transfer to complete
|
| 208 |
+
self.stream.synchronize()
|
| 209 |
+
|
| 210 |
+
context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
|
| 211 |
+
with context:
|
| 212 |
+
if self.stream is not None:
|
| 213 |
+
with self._pinned_memory_tensors() as pinned_memory:
|
| 214 |
+
self._process_tensors_from_modules(pinned_memory)
|
| 215 |
+
else:
|
| 216 |
+
self._process_tensors_from_modules(None)
|
| 217 |
+
|
| 218 |
+
def _offload_to_disk(self):
|
| 219 |
+
# TODO: we can potentially optimize this code path by checking if the _all_ the desired
|
| 220 |
+
# safetensor files exist on the disk and if so, skip this step entirely, reducing IO
|
| 221 |
+
# overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
|
| 222 |
+
# we perform a write.
|
| 223 |
+
# Check if the file has been saved in this session or if it already exists on disk.
|
| 224 |
+
if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
|
| 225 |
+
os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
|
| 226 |
+
tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()}
|
| 227 |
+
safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
|
| 228 |
+
|
| 229 |
+
# The group is now considered offloaded to disk for the rest of the session.
|
| 230 |
+
self._is_offloaded_to_disk = True
|
| 231 |
+
|
| 232 |
+
# We do this to free up the RAM which is still holding the up tensor data.
|
| 233 |
+
for tensor_obj in self.tensor_to_key.keys():
|
| 234 |
+
tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
|
| 235 |
+
|
| 236 |
+
def _offload_to_memory(self):
|
| 237 |
+
if self.stream is not None:
|
| 238 |
+
if not self.record_stream:
|
| 239 |
+
self._torch_accelerator_module.current_stream().synchronize()
|
| 240 |
+
|
| 241 |
+
for group_module in self.modules:
|
| 242 |
+
for param in group_module.parameters():
|
| 243 |
+
param.data = self.cpu_param_dict[param]
|
| 244 |
+
for param in self.parameters:
|
| 245 |
+
param.data = self.cpu_param_dict[param]
|
| 246 |
+
for buffer in self.buffers:
|
| 247 |
+
buffer.data = self.cpu_param_dict[buffer]
|
| 248 |
+
else:
|
| 249 |
+
for group_module in self.modules:
|
| 250 |
+
group_module.to(self.offload_device, non_blocking=False)
|
| 251 |
+
for param in self.parameters:
|
| 252 |
+
param.data = param.data.to(self.offload_device, non_blocking=False)
|
| 253 |
+
for buffer in self.buffers:
|
| 254 |
+
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
|
| 255 |
+
|
| 256 |
+
@torch.compiler.disable()
|
| 257 |
+
def onload_(self):
|
| 258 |
+
r"""Onloads the group of parameters to the onload_device."""
|
| 259 |
+
if self.offload_to_disk_path is not None:
|
| 260 |
+
self._onload_from_disk()
|
| 261 |
+
else:
|
| 262 |
+
self._onload_from_memory()
|
| 263 |
+
|
| 264 |
+
@torch.compiler.disable()
|
| 265 |
+
def offload_(self):
|
| 266 |
+
r"""Offloads the group of parameters to the offload_device."""
|
| 267 |
+
if self.offload_to_disk_path:
|
| 268 |
+
self._offload_to_disk()
|
| 269 |
+
else:
|
| 270 |
+
self._offload_to_memory()
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
class GroupOffloadingHook(ModelHook):
|
| 274 |
+
r"""
|
| 275 |
+
A hook that offloads groups of torch.nn.Module to the CPU for storage and onloads to accelerator device for
|
| 276 |
+
computation. Each group has one "onload leader" module that is responsible for onloading, and an "offload leader"
|
| 277 |
+
module that is responsible for offloading. If prefetching is enabled, the onload leader of the previous module
|
| 278 |
+
group is responsible for onloading the current module group.
|
| 279 |
+
"""
|
| 280 |
+
|
| 281 |
+
_is_stateful = False
|
| 282 |
+
|
| 283 |
+
def __init__(self, group: ModuleGroup, *, config: GroupOffloadingConfig) -> None:
|
| 284 |
+
self.group = group
|
| 285 |
+
self.next_group: Optional[ModuleGroup] = None
|
| 286 |
+
self.config = config
|
| 287 |
+
|
| 288 |
+
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
|
| 289 |
+
if self.group.offload_leader == module:
|
| 290 |
+
self.group.offload_()
|
| 291 |
+
return module
|
| 292 |
+
|
| 293 |
+
def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
|
| 294 |
+
# If there wasn't an onload_leader assigned, we assume that the submodule that first called its forward
|
| 295 |
+
# method is the onload_leader of the group.
|
| 296 |
+
if self.group.onload_leader is None:
|
| 297 |
+
self.group.onload_leader = module
|
| 298 |
+
|
| 299 |
+
# If the current module is the onload_leader of the group, we onload the group if it is supposed
|
| 300 |
+
# to onload itself. In the case of using prefetching with streams, we onload the next group if
|
| 301 |
+
# it is not supposed to onload itself.
|
| 302 |
+
if self.group.onload_leader == module:
|
| 303 |
+
if self.group.onload_self:
|
| 304 |
+
self.group.onload_()
|
| 305 |
+
|
| 306 |
+
should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
|
| 307 |
+
if should_onload_next_group:
|
| 308 |
+
self.next_group.onload_()
|
| 309 |
+
|
| 310 |
+
should_synchronize = (
|
| 311 |
+
not self.group.onload_self and self.group.stream is not None and not should_onload_next_group
|
| 312 |
+
)
|
| 313 |
+
if should_synchronize:
|
| 314 |
+
# If this group didn't onload itself, it means it was asynchronously onloaded by the
|
| 315 |
+
# previous group. We need to synchronize the side stream to ensure parameters
|
| 316 |
+
# are completely loaded to proceed with forward pass. Without this, uninitialized
|
| 317 |
+
# weights will be used in the computation, leading to incorrect results
|
| 318 |
+
# Also, we should only do this synchronization if we don't already do it from the sync call in
|
| 319 |
+
# self.next_group.onload_, hence the `not should_onload_next_group` check.
|
| 320 |
+
self.group.stream.synchronize()
|
| 321 |
+
|
| 322 |
+
args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
|
| 323 |
+
kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
|
| 324 |
+
return args, kwargs
|
| 325 |
+
|
| 326 |
+
def post_forward(self, module: torch.nn.Module, output):
|
| 327 |
+
if self.group.offload_leader == module:
|
| 328 |
+
self.group.offload_()
|
| 329 |
+
return output
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
class LazyPrefetchGroupOffloadingHook(ModelHook):
|
| 333 |
+
r"""
|
| 334 |
+
A hook, used in conjunction with GroupOffloadingHook, that applies lazy prefetching to groups of torch.nn.Module.
|
| 335 |
+
This hook is used to determine the order in which the layers are executed during the forward pass. Once the layer
|
| 336 |
+
invocation order is known, assignments of the next_group attribute for prefetching can be made, which allows
|
| 337 |
+
prefetching groups in the correct order.
|
| 338 |
+
"""
|
| 339 |
+
|
| 340 |
+
_is_stateful = False
|
| 341 |
+
|
| 342 |
+
def __init__(self):
|
| 343 |
+
self.execution_order: List[Tuple[str, torch.nn.Module]] = []
|
| 344 |
+
self._layer_execution_tracker_module_names = set()
|
| 345 |
+
|
| 346 |
+
def initialize_hook(self, module):
|
| 347 |
+
def make_execution_order_update_callback(current_name, current_submodule):
|
| 348 |
+
def callback():
|
| 349 |
+
if not torch.compiler.is_compiling():
|
| 350 |
+
logger.debug(f"Adding {current_name} to the execution order")
|
| 351 |
+
self.execution_order.append((current_name, current_submodule))
|
| 352 |
+
|
| 353 |
+
return callback
|
| 354 |
+
|
| 355 |
+
# To every submodule that contains a group offloading hook (at this point, no prefetching is enabled for any
|
| 356 |
+
# of the groups), we add a layer execution tracker hook that will be used to determine the order in which the
|
| 357 |
+
# layers are executed during the forward pass.
|
| 358 |
+
for name, submodule in module.named_modules():
|
| 359 |
+
if name == "" or not hasattr(submodule, "_diffusers_hook"):
|
| 360 |
+
continue
|
| 361 |
+
|
| 362 |
+
registry = HookRegistry.check_if_exists_or_initialize(submodule)
|
| 363 |
+
group_offloading_hook = registry.get_hook(_GROUP_OFFLOADING)
|
| 364 |
+
|
| 365 |
+
if group_offloading_hook is not None:
|
| 366 |
+
# For the first forward pass, we have to load in a blocking manner
|
| 367 |
+
group_offloading_hook.group.non_blocking = False
|
| 368 |
+
layer_tracker_hook = LayerExecutionTrackerHook(make_execution_order_update_callback(name, submodule))
|
| 369 |
+
registry.register_hook(layer_tracker_hook, _LAYER_EXECUTION_TRACKER)
|
| 370 |
+
self._layer_execution_tracker_module_names.add(name)
|
| 371 |
+
|
| 372 |
+
return module
|
| 373 |
+
|
| 374 |
+
def post_forward(self, module, output):
|
| 375 |
+
# At this point, for the current modules' submodules, we know the execution order of the layers. We can now
|
| 376 |
+
# remove the layer execution tracker hooks and apply prefetching by setting the next_group attribute for each
|
| 377 |
+
# group offloading hook.
|
| 378 |
+
num_executed = len(self.execution_order)
|
| 379 |
+
execution_order_module_names = {name for name, _ in self.execution_order}
|
| 380 |
+
|
| 381 |
+
# It may be possible that some layers were not executed during the forward pass. This can happen if the layer
|
| 382 |
+
# is not used in the forward pass, or if the layer is not executed due to some other reason. In such cases, we
|
| 383 |
+
# may not be able to apply prefetching in the correct order, which can lead to device-mismatch related errors
|
| 384 |
+
# if the missing layers end up being executed in the future.
|
| 385 |
+
if execution_order_module_names != self._layer_execution_tracker_module_names:
|
| 386 |
+
unexecuted_layers = list(self._layer_execution_tracker_module_names - execution_order_module_names)
|
| 387 |
+
if not torch.compiler.is_compiling():
|
| 388 |
+
logger.warning(
|
| 389 |
+
"It seems like some layers were not executed during the forward pass. This may lead to problems when "
|
| 390 |
+
"applying lazy prefetching with automatic tracing and lead to device-mismatch related errors. Please "
|
| 391 |
+
"make sure that all layers are executed during the forward pass. The following layers were not executed:\n"
|
| 392 |
+
f"{unexecuted_layers=}"
|
| 393 |
+
)
|
| 394 |
+
|
| 395 |
+
# Remove the layer execution tracker hooks from the submodules
|
| 396 |
+
base_module_registry = module._diffusers_hook
|
| 397 |
+
registries = [submodule._diffusers_hook for _, submodule in self.execution_order]
|
| 398 |
+
group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries]
|
| 399 |
+
|
| 400 |
+
for i in range(num_executed):
|
| 401 |
+
registries[i].remove_hook(_LAYER_EXECUTION_TRACKER, recurse=False)
|
| 402 |
+
|
| 403 |
+
# Remove the current lazy prefetch group offloading hook so that it doesn't interfere with the next forward pass
|
| 404 |
+
base_module_registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=False)
|
| 405 |
+
|
| 406 |
+
# LazyPrefetchGroupOffloadingHook is only used with streams, so we know that non_blocking should be True.
|
| 407 |
+
# We disable non_blocking for the first forward pass, but need to enable it for the subsequent passes to
|
| 408 |
+
# see the benefits of prefetching.
|
| 409 |
+
for hook in group_offloading_hooks:
|
| 410 |
+
hook.group.non_blocking = True
|
| 411 |
+
|
| 412 |
+
# Set required attributes for prefetching
|
| 413 |
+
if num_executed > 0:
|
| 414 |
+
base_module_group_offloading_hook = base_module_registry.get_hook(_GROUP_OFFLOADING)
|
| 415 |
+
base_module_group_offloading_hook.next_group = group_offloading_hooks[0].group
|
| 416 |
+
base_module_group_offloading_hook.next_group.onload_self = False
|
| 417 |
+
|
| 418 |
+
for i in range(num_executed - 1):
|
| 419 |
+
name1, _ = self.execution_order[i]
|
| 420 |
+
name2, _ = self.execution_order[i + 1]
|
| 421 |
+
if not torch.compiler.is_compiling():
|
| 422 |
+
logger.debug(f"Applying lazy prefetch group offloading from {name1} to {name2}")
|
| 423 |
+
group_offloading_hooks[i].next_group = group_offloading_hooks[i + 1].group
|
| 424 |
+
group_offloading_hooks[i].next_group.onload_self = False
|
| 425 |
+
|
| 426 |
+
return output
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
class LayerExecutionTrackerHook(ModelHook):
|
| 430 |
+
r"""
|
| 431 |
+
A hook that tracks the order in which the layers are executed during the forward pass by calling back to the
|
| 432 |
+
LazyPrefetchGroupOffloadingHook to update the execution order.
|
| 433 |
+
"""
|
| 434 |
+
|
| 435 |
+
_is_stateful = False
|
| 436 |
+
|
| 437 |
+
def __init__(self, execution_order_update_callback):
|
| 438 |
+
self.execution_order_update_callback = execution_order_update_callback
|
| 439 |
+
|
| 440 |
+
def pre_forward(self, module, *args, **kwargs):
|
| 441 |
+
self.execution_order_update_callback()
|
| 442 |
+
return args, kwargs
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
def apply_group_offloading(
|
| 446 |
+
module: torch.nn.Module,
|
| 447 |
+
onload_device: Union[str, torch.device],
|
| 448 |
+
offload_device: Union[str, torch.device] = torch.device("cpu"),
|
| 449 |
+
offload_type: Union[str, GroupOffloadingType] = "block_level",
|
| 450 |
+
num_blocks_per_group: Optional[int] = None,
|
| 451 |
+
non_blocking: bool = False,
|
| 452 |
+
use_stream: bool = False,
|
| 453 |
+
record_stream: bool = False,
|
| 454 |
+
low_cpu_mem_usage: bool = False,
|
| 455 |
+
offload_to_disk_path: Optional[str] = None,
|
| 456 |
+
) -> None:
|
| 457 |
+
r"""
|
| 458 |
+
Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
|
| 459 |
+
where it is beneficial, we need to first provide some context on how other supported offloading methods work.
|
| 460 |
+
|
| 461 |
+
Typically, offloading is done at two levels:
|
| 462 |
+
- Module-level: In Diffusers, this can be enabled using the `ModelMixin::enable_model_cpu_offload()` method. It
|
| 463 |
+
works by offloading each component of a pipeline to the CPU for storage, and onloading to the accelerator device
|
| 464 |
+
when needed for computation. This method is more memory-efficient than keeping all components on the accelerator,
|
| 465 |
+
but the memory requirements are still quite high. For this method to work, one needs memory equivalent to size of
|
| 466 |
+
the model in runtime dtype + size of largest intermediate activation tensors to be able to complete the forward
|
| 467 |
+
pass.
|
| 468 |
+
- Leaf-level: In Diffusers, this can be enabled using the `ModelMixin::enable_sequential_cpu_offload()` method. It
|
| 469 |
+
works by offloading the lowest leaf-level parameters of the computation graph to the CPU for storage, and
|
| 470 |
+
onloading only the leafs to the accelerator device for computation. This uses the lowest amount of accelerator
|
| 471 |
+
memory, but can be slower due to the excessive number of device synchronizations.
|
| 472 |
+
|
| 473 |
+
Group offloading is a middle ground between the two methods. It works by offloading groups of internal layers,
|
| 474 |
+
(either `torch.nn.ModuleList` or `torch.nn.Sequential`). This method uses lower memory than module-level
|
| 475 |
+
offloading. It is also faster than leaf-level/sequential offloading, as the number of device synchronizations is
|
| 476 |
+
reduced.
|
| 477 |
+
|
| 478 |
+
Another supported feature (for CUDA devices with support for asynchronous data transfer streams) is the ability to
|
| 479 |
+
overlap data transfer and computation to reduce the overall execution time compared to sequential offloading. This
|
| 480 |
+
is enabled using layer prefetching with streams, i.e., the layer that is to be executed next starts onloading to
|
| 481 |
+
the accelerator device while the current layer is being executed - this increases the memory requirements slightly.
|
| 482 |
+
Note that this implementation also supports leaf-level offloading but can be made much faster when using streams.
|
| 483 |
+
|
| 484 |
+
Args:
|
| 485 |
+
module (`torch.nn.Module`):
|
| 486 |
+
The module to which group offloading is applied.
|
| 487 |
+
onload_device (`torch.device`):
|
| 488 |
+
The device to which the group of modules are onloaded.
|
| 489 |
+
offload_device (`torch.device`, defaults to `torch.device("cpu")`):
|
| 490 |
+
The device to which the group of modules are offloaded. This should typically be the CPU. Default is CPU.
|
| 491 |
+
offload_type (`str` or `GroupOffloadingType`, defaults to "block_level"):
|
| 492 |
+
The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
|
| 493 |
+
"block_level".
|
| 494 |
+
offload_to_disk_path (`str`, *optional*, defaults to `None`):
|
| 495 |
+
The path to the directory where parameters will be offloaded. Setting this option can be useful in limited
|
| 496 |
+
RAM environment settings where a reasonable speed-memory trade-off is desired.
|
| 497 |
+
num_blocks_per_group (`int`, *optional*):
|
| 498 |
+
The number of blocks per group when using offload_type="block_level". This is required when using
|
| 499 |
+
offload_type="block_level".
|
| 500 |
+
non_blocking (`bool`, defaults to `False`):
|
| 501 |
+
If True, offloading and onloading is done with non-blocking data transfer.
|
| 502 |
+
use_stream (`bool`, defaults to `False`):
|
| 503 |
+
If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
|
| 504 |
+
overlapping computation and data transfer.
|
| 505 |
+
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
|
| 506 |
+
as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
|
| 507 |
+
[PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more
|
| 508 |
+
details.
|
| 509 |
+
low_cpu_mem_usage (`bool`, defaults to `False`):
|
| 510 |
+
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
|
| 511 |
+
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
|
| 512 |
+
the CPU memory is a bottleneck but may counteract the benefits of using streams.
|
| 513 |
+
|
| 514 |
+
Example:
|
| 515 |
+
```python
|
| 516 |
+
>>> from diffusers import CogVideoXTransformer3DModel
|
| 517 |
+
>>> from diffusers.hooks import apply_group_offloading
|
| 518 |
+
|
| 519 |
+
>>> transformer = CogVideoXTransformer3DModel.from_pretrained(
|
| 520 |
+
... "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
|
| 521 |
+
... )
|
| 522 |
+
|
| 523 |
+
>>> apply_group_offloading(
|
| 524 |
+
... transformer,
|
| 525 |
+
... onload_device=torch.device("cuda"),
|
| 526 |
+
... offload_device=torch.device("cpu"),
|
| 527 |
+
... offload_type="block_level",
|
| 528 |
+
... num_blocks_per_group=2,
|
| 529 |
+
... use_stream=True,
|
| 530 |
+
... )
|
| 531 |
+
```
|
| 532 |
+
"""
|
| 533 |
+
|
| 534 |
+
onload_device = torch.device(onload_device) if isinstance(onload_device, str) else onload_device
|
| 535 |
+
offload_device = torch.device(offload_device) if isinstance(offload_device, str) else offload_device
|
| 536 |
+
offload_type = GroupOffloadingType(offload_type)
|
| 537 |
+
|
| 538 |
+
stream = None
|
| 539 |
+
if use_stream:
|
| 540 |
+
if torch.cuda.is_available():
|
| 541 |
+
stream = torch.cuda.Stream()
|
| 542 |
+
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
| 543 |
+
stream = torch.Stream()
|
| 544 |
+
else:
|
| 545 |
+
raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.")
|
| 546 |
+
|
| 547 |
+
if not use_stream and record_stream:
|
| 548 |
+
raise ValueError("`record_stream` cannot be True when `use_stream=False`.")
|
| 549 |
+
if offload_type == GroupOffloadingType.BLOCK_LEVEL and num_blocks_per_group is None:
|
| 550 |
+
raise ValueError("`num_blocks_per_group` must be provided when using `offload_type='block_level'.")
|
| 551 |
+
|
| 552 |
+
_raise_error_if_accelerate_model_or_sequential_hook_present(module)
|
| 553 |
+
|
| 554 |
+
config = GroupOffloadingConfig(
|
| 555 |
+
onload_device=onload_device,
|
| 556 |
+
offload_device=offload_device,
|
| 557 |
+
offload_type=offload_type,
|
| 558 |
+
num_blocks_per_group=num_blocks_per_group,
|
| 559 |
+
non_blocking=non_blocking,
|
| 560 |
+
stream=stream,
|
| 561 |
+
record_stream=record_stream,
|
| 562 |
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
| 563 |
+
offload_to_disk_path=offload_to_disk_path,
|
| 564 |
+
)
|
| 565 |
+
_apply_group_offloading(module, config)
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
|
| 569 |
+
if config.offload_type == GroupOffloadingType.BLOCK_LEVEL:
|
| 570 |
+
_apply_group_offloading_block_level(module, config)
|
| 571 |
+
elif config.offload_type == GroupOffloadingType.LEAF_LEVEL:
|
| 572 |
+
_apply_group_offloading_leaf_level(module, config)
|
| 573 |
+
else:
|
| 574 |
+
assert False
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
|
| 578 |
+
r"""
|
| 579 |
+
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
|
| 580 |
+
the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks.
|
| 581 |
+
"""
|
| 582 |
+
|
| 583 |
+
if config.stream is not None and config.num_blocks_per_group != 1:
|
| 584 |
+
logger.warning(
|
| 585 |
+
f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1."
|
| 586 |
+
)
|
| 587 |
+
config.num_blocks_per_group = 1
|
| 588 |
+
|
| 589 |
+
# Create module groups for ModuleList and Sequential blocks
|
| 590 |
+
modules_with_group_offloading = set()
|
| 591 |
+
unmatched_modules = []
|
| 592 |
+
matched_module_groups = []
|
| 593 |
+
for name, submodule in module.named_children():
|
| 594 |
+
if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
|
| 595 |
+
unmatched_modules.append((name, submodule))
|
| 596 |
+
modules_with_group_offloading.add(name)
|
| 597 |
+
continue
|
| 598 |
+
|
| 599 |
+
for i in range(0, len(submodule), config.num_blocks_per_group):
|
| 600 |
+
current_modules = submodule[i : i + config.num_blocks_per_group]
|
| 601 |
+
group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
|
| 602 |
+
group = ModuleGroup(
|
| 603 |
+
modules=current_modules,
|
| 604 |
+
offload_device=config.offload_device,
|
| 605 |
+
onload_device=config.onload_device,
|
| 606 |
+
offload_to_disk_path=config.offload_to_disk_path,
|
| 607 |
+
offload_leader=current_modules[-1],
|
| 608 |
+
onload_leader=current_modules[0],
|
| 609 |
+
non_blocking=config.non_blocking,
|
| 610 |
+
stream=config.stream,
|
| 611 |
+
record_stream=config.record_stream,
|
| 612 |
+
low_cpu_mem_usage=config.low_cpu_mem_usage,
|
| 613 |
+
onload_self=True,
|
| 614 |
+
group_id=group_id,
|
| 615 |
+
)
|
| 616 |
+
matched_module_groups.append(group)
|
| 617 |
+
for j in range(i, i + len(current_modules)):
|
| 618 |
+
modules_with_group_offloading.add(f"{name}.{j}")
|
| 619 |
+
|
| 620 |
+
# Apply group offloading hooks to the module groups
|
| 621 |
+
for i, group in enumerate(matched_module_groups):
|
| 622 |
+
for group_module in group.modules:
|
| 623 |
+
_apply_group_offloading_hook(group_module, group, config=config)
|
| 624 |
+
|
| 625 |
+
# Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
|
| 626 |
+
# when the forward pass of this module is called. This is because the top-level module is not
|
| 627 |
+
# part of any group (as doing so would lead to no VRAM savings).
|
| 628 |
+
parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
|
| 629 |
+
buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
|
| 630 |
+
parameters = [param for _, param in parameters]
|
| 631 |
+
buffers = [buffer for _, buffer in buffers]
|
| 632 |
+
|
| 633 |
+
# Create a group for the unmatched submodules of the top-level module so that they are on the correct
|
| 634 |
+
# device when the forward pass is called.
|
| 635 |
+
unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
|
| 636 |
+
unmatched_group = ModuleGroup(
|
| 637 |
+
modules=unmatched_modules,
|
| 638 |
+
offload_device=config.offload_device,
|
| 639 |
+
onload_device=config.onload_device,
|
| 640 |
+
offload_to_disk_path=config.offload_to_disk_path,
|
| 641 |
+
offload_leader=module,
|
| 642 |
+
onload_leader=module,
|
| 643 |
+
parameters=parameters,
|
| 644 |
+
buffers=buffers,
|
| 645 |
+
non_blocking=False,
|
| 646 |
+
stream=None,
|
| 647 |
+
record_stream=False,
|
| 648 |
+
onload_self=True,
|
| 649 |
+
group_id=f"{module.__class__.__name__}_unmatched_group",
|
| 650 |
+
)
|
| 651 |
+
if config.stream is None:
|
| 652 |
+
_apply_group_offloading_hook(module, unmatched_group, config=config)
|
| 653 |
+
else:
|
| 654 |
+
_apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
|
| 655 |
+
|
| 656 |
+
|
| 657 |
+
def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
|
| 658 |
+
r"""
|
| 659 |
+
This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
|
| 660 |
+
requirements. However, it can be slower compared to other offloading methods due to the excessive number of device
|
| 661 |
+
synchronizations. When using devices that support streams to overlap data transfer and computation, this method can
|
| 662 |
+
reduce memory usage without any performance degradation.
|
| 663 |
+
"""
|
| 664 |
+
# Create module groups for leaf modules and apply group offloading hooks
|
| 665 |
+
modules_with_group_offloading = set()
|
| 666 |
+
for name, submodule in module.named_modules():
|
| 667 |
+
if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
|
| 668 |
+
continue
|
| 669 |
+
group = ModuleGroup(
|
| 670 |
+
modules=[submodule],
|
| 671 |
+
offload_device=config.offload_device,
|
| 672 |
+
onload_device=config.onload_device,
|
| 673 |
+
offload_to_disk_path=config.offload_to_disk_path,
|
| 674 |
+
offload_leader=submodule,
|
| 675 |
+
onload_leader=submodule,
|
| 676 |
+
non_blocking=config.non_blocking,
|
| 677 |
+
stream=config.stream,
|
| 678 |
+
record_stream=config.record_stream,
|
| 679 |
+
low_cpu_mem_usage=config.low_cpu_mem_usage,
|
| 680 |
+
onload_self=True,
|
| 681 |
+
group_id=name,
|
| 682 |
+
)
|
| 683 |
+
_apply_group_offloading_hook(submodule, group, config=config)
|
| 684 |
+
modules_with_group_offloading.add(name)
|
| 685 |
+
|
| 686 |
+
# Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
|
| 687 |
+
# of the module is called
|
| 688 |
+
module_dict = dict(module.named_modules())
|
| 689 |
+
parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
|
| 690 |
+
buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
|
| 691 |
+
|
| 692 |
+
# Find closest module parent for each parameter and buffer, and attach group hooks
|
| 693 |
+
parent_to_parameters = {}
|
| 694 |
+
for name, param in parameters:
|
| 695 |
+
parent_name = _find_parent_module_in_module_dict(name, module_dict)
|
| 696 |
+
if parent_name in parent_to_parameters:
|
| 697 |
+
parent_to_parameters[parent_name].append(param)
|
| 698 |
+
else:
|
| 699 |
+
parent_to_parameters[parent_name] = [param]
|
| 700 |
+
|
| 701 |
+
parent_to_buffers = {}
|
| 702 |
+
for name, buffer in buffers:
|
| 703 |
+
parent_name = _find_parent_module_in_module_dict(name, module_dict)
|
| 704 |
+
if parent_name in parent_to_buffers:
|
| 705 |
+
parent_to_buffers[parent_name].append(buffer)
|
| 706 |
+
else:
|
| 707 |
+
parent_to_buffers[parent_name] = [buffer]
|
| 708 |
+
|
| 709 |
+
parent_names = set(parent_to_parameters.keys()) | set(parent_to_buffers.keys())
|
| 710 |
+
for name in parent_names:
|
| 711 |
+
parameters = parent_to_parameters.get(name, [])
|
| 712 |
+
buffers = parent_to_buffers.get(name, [])
|
| 713 |
+
parent_module = module_dict[name]
|
| 714 |
+
group = ModuleGroup(
|
| 715 |
+
modules=[],
|
| 716 |
+
offload_device=config.offload_device,
|
| 717 |
+
onload_device=config.onload_device,
|
| 718 |
+
offload_leader=parent_module,
|
| 719 |
+
onload_leader=parent_module,
|
| 720 |
+
offload_to_disk_path=config.offload_to_disk_path,
|
| 721 |
+
parameters=parameters,
|
| 722 |
+
buffers=buffers,
|
| 723 |
+
non_blocking=config.non_blocking,
|
| 724 |
+
stream=config.stream,
|
| 725 |
+
record_stream=config.record_stream,
|
| 726 |
+
low_cpu_mem_usage=config.low_cpu_mem_usage,
|
| 727 |
+
onload_self=True,
|
| 728 |
+
group_id=name,
|
| 729 |
+
)
|
| 730 |
+
_apply_group_offloading_hook(parent_module, group, config=config)
|
| 731 |
+
|
| 732 |
+
if config.stream is not None:
|
| 733 |
+
# When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
|
| 734 |
+
# and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that will find the
|
| 735 |
+
# execution order and apply prefetching in the correct order.
|
| 736 |
+
unmatched_group = ModuleGroup(
|
| 737 |
+
modules=[],
|
| 738 |
+
offload_device=config.offload_device,
|
| 739 |
+
onload_device=config.onload_device,
|
| 740 |
+
offload_to_disk_path=config.offload_to_disk_path,
|
| 741 |
+
offload_leader=module,
|
| 742 |
+
onload_leader=module,
|
| 743 |
+
parameters=None,
|
| 744 |
+
buffers=None,
|
| 745 |
+
non_blocking=False,
|
| 746 |
+
stream=None,
|
| 747 |
+
record_stream=False,
|
| 748 |
+
low_cpu_mem_usage=config.low_cpu_mem_usage,
|
| 749 |
+
onload_self=True,
|
| 750 |
+
group_id=_GROUP_ID_LAZY_LEAF,
|
| 751 |
+
)
|
| 752 |
+
_apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
|
| 753 |
+
|
| 754 |
+
|
| 755 |
+
def _apply_group_offloading_hook(
|
| 756 |
+
module: torch.nn.Module,
|
| 757 |
+
group: ModuleGroup,
|
| 758 |
+
*,
|
| 759 |
+
config: GroupOffloadingConfig,
|
| 760 |
+
) -> None:
|
| 761 |
+
registry = HookRegistry.check_if_exists_or_initialize(module)
|
| 762 |
+
|
| 763 |
+
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
|
| 764 |
+
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
|
| 765 |
+
if registry.get_hook(_GROUP_OFFLOADING) is None:
|
| 766 |
+
hook = GroupOffloadingHook(group, config=config)
|
| 767 |
+
registry.register_hook(hook, _GROUP_OFFLOADING)
|
| 768 |
+
|
| 769 |
+
|
| 770 |
+
def _apply_lazy_group_offloading_hook(
|
| 771 |
+
module: torch.nn.Module,
|
| 772 |
+
group: ModuleGroup,
|
| 773 |
+
*,
|
| 774 |
+
config: GroupOffloadingConfig,
|
| 775 |
+
) -> None:
|
| 776 |
+
registry = HookRegistry.check_if_exists_or_initialize(module)
|
| 777 |
+
|
| 778 |
+
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
|
| 779 |
+
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
|
| 780 |
+
if registry.get_hook(_GROUP_OFFLOADING) is None:
|
| 781 |
+
hook = GroupOffloadingHook(group, config=config)
|
| 782 |
+
registry.register_hook(hook, _GROUP_OFFLOADING)
|
| 783 |
+
|
| 784 |
+
lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
|
| 785 |
+
registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)
|
| 786 |
+
|
| 787 |
+
|
| 788 |
+
def _gather_parameters_with_no_group_offloading_parent(
|
| 789 |
+
module: torch.nn.Module, modules_with_group_offloading: Set[str]
|
| 790 |
+
) -> List[torch.nn.Parameter]:
|
| 791 |
+
parameters = []
|
| 792 |
+
for name, parameter in module.named_parameters():
|
| 793 |
+
has_parent_with_group_offloading = False
|
| 794 |
+
atoms = name.split(".")
|
| 795 |
+
while len(atoms) > 0:
|
| 796 |
+
parent_name = ".".join(atoms)
|
| 797 |
+
if parent_name in modules_with_group_offloading:
|
| 798 |
+
has_parent_with_group_offloading = True
|
| 799 |
+
break
|
| 800 |
+
atoms.pop()
|
| 801 |
+
if not has_parent_with_group_offloading:
|
| 802 |
+
parameters.append((name, parameter))
|
| 803 |
+
return parameters
|
| 804 |
+
|
| 805 |
+
|
| 806 |
+
def _gather_buffers_with_no_group_offloading_parent(
|
| 807 |
+
module: torch.nn.Module, modules_with_group_offloading: Set[str]
|
| 808 |
+
) -> List[torch.Tensor]:
|
| 809 |
+
buffers = []
|
| 810 |
+
for name, buffer in module.named_buffers():
|
| 811 |
+
has_parent_with_group_offloading = False
|
| 812 |
+
atoms = name.split(".")
|
| 813 |
+
while len(atoms) > 0:
|
| 814 |
+
parent_name = ".".join(atoms)
|
| 815 |
+
if parent_name in modules_with_group_offloading:
|
| 816 |
+
has_parent_with_group_offloading = True
|
| 817 |
+
break
|
| 818 |
+
atoms.pop()
|
| 819 |
+
if not has_parent_with_group_offloading:
|
| 820 |
+
buffers.append((name, buffer))
|
| 821 |
+
return buffers
|
| 822 |
+
|
| 823 |
+
|
| 824 |
+
def _find_parent_module_in_module_dict(name: str, module_dict: Dict[str, torch.nn.Module]) -> str:
|
| 825 |
+
atoms = name.split(".")
|
| 826 |
+
while len(atoms) > 0:
|
| 827 |
+
parent_name = ".".join(atoms)
|
| 828 |
+
if parent_name in module_dict:
|
| 829 |
+
return parent_name
|
| 830 |
+
atoms.pop()
|
| 831 |
+
return ""
|
| 832 |
+
|
| 833 |
+
|
| 834 |
+
def _raise_error_if_accelerate_model_or_sequential_hook_present(module: torch.nn.Module) -> None:
|
| 835 |
+
if not is_accelerate_available():
|
| 836 |
+
return
|
| 837 |
+
for name, submodule in module.named_modules():
|
| 838 |
+
if not hasattr(submodule, "_hf_hook"):
|
| 839 |
+
continue
|
| 840 |
+
if isinstance(submodule._hf_hook, (AlignDevicesHook, CpuOffload)):
|
| 841 |
+
raise ValueError(
|
| 842 |
+
f"Cannot apply group offloading to a module that is already applying an alternative "
|
| 843 |
+
f"offloading strategy from Accelerate. If you want to apply group offloading, please "
|
| 844 |
+
f"disable the existing offloading strategy first. Offending module: {name} ({type(submodule)})"
|
| 845 |
+
)
|
| 846 |
+
|
| 847 |
+
|
| 848 |
+
def _get_top_level_group_offload_hook(module: torch.nn.Module) -> Optional[GroupOffloadingHook]:
|
| 849 |
+
for submodule in module.modules():
|
| 850 |
+
if hasattr(submodule, "_diffusers_hook"):
|
| 851 |
+
group_offloading_hook = submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING)
|
| 852 |
+
if group_offloading_hook is not None:
|
| 853 |
+
return group_offloading_hook
|
| 854 |
+
return None
|
| 855 |
+
|
| 856 |
+
|
| 857 |
+
def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
|
| 858 |
+
top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
|
| 859 |
+
return top_level_group_offload_hook is not None
|
| 860 |
+
|
| 861 |
+
|
| 862 |
+
def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
|
| 863 |
+
top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
|
| 864 |
+
if top_level_group_offload_hook is not None:
|
| 865 |
+
return top_level_group_offload_hook.config.onload_device
|
| 866 |
+
raise ValueError("Group offloading is not enabled for the provided module.")
|
| 867 |
+
|
| 868 |
+
|
| 869 |
+
def _compute_group_hash(group_id):
|
| 870 |
+
hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest()
|
| 871 |
+
# first 16 characters for a reasonably short but unique name
|
| 872 |
+
return hashed_id[:16]
|
| 873 |
+
|
| 874 |
+
|
| 875 |
+
def _maybe_remove_and_reapply_group_offloading(module: torch.nn.Module) -> None:
|
| 876 |
+
r"""
|
| 877 |
+
Removes the group offloading hook from the module and re-applies it. This is useful when the module has been
|
| 878 |
+
modified in-place and the group offloading hook references-to-tensors needs to be updated. The in-place
|
| 879 |
+
modification can happen in a number of ways, for example, fusing QKV or unloading/loading LoRAs on-the-fly.
|
| 880 |
+
|
| 881 |
+
In this implementation, we make an assumption that group offloading has only been applied at the top-level module,
|
| 882 |
+
and therefore all submodules have the same onload and offload devices. If this assumption is not true, say in the
|
| 883 |
+
case where user has applied group offloading at multiple levels, this function will not work as expected.
|
| 884 |
+
|
| 885 |
+
There is some performance penalty associated with doing this when non-default streams are used, because we need to
|
| 886 |
+
retrace the execution order of the layers with `LazyPrefetchGroupOffloadingHook`.
|
| 887 |
+
"""
|
| 888 |
+
top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
|
| 889 |
+
|
| 890 |
+
if top_level_group_offload_hook is None:
|
| 891 |
+
return
|
| 892 |
+
|
| 893 |
+
registry = HookRegistry.check_if_exists_or_initialize(module)
|
| 894 |
+
registry.remove_hook(_GROUP_OFFLOADING, recurse=True)
|
| 895 |
+
registry.remove_hook(_LAYER_EXECUTION_TRACKER, recurse=True)
|
| 896 |
+
registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=True)
|
| 897 |
+
|
| 898 |
+
_apply_group_offloading(module, top_level_group_offload_hook.config)
|