xiaoanyu123 committed on
Commit
8f3b606
·
verified ·
1 Parent(s): f01352f

Add files using upload-large-folder tool

Browse files
Files changed (20) hide show
  1. pythonProject/.venv/Lib/site-packages/diffusers/guiders/__init__.py +41 -0
  2. pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/__init__.cpython-310.pyc +0 -0
  3. pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/adaptive_projected_guidance.cpython-310.pyc +0 -0
  4. pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/auto_guidance.cpython-310.pyc +0 -0
  5. pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/classifier_free_guidance.cpython-310.pyc +0 -0
  6. pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/classifier_free_zero_star_guidance.cpython-310.pyc +0 -0
  7. pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/frequency_decoupled_guidance.cpython-310.pyc +0 -0
  8. pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/guider_utils.cpython-310.pyc +0 -0
  9. pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/perturbed_attention_guidance.cpython-310.pyc +0 -0
  10. pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/skip_layer_guidance.cpython-310.pyc +0 -0
  11. pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/smoothed_energy_guidance.cpython-310.pyc +0 -0
  12. pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/tangential_classifier_free_guidance.cpython-310.pyc +0 -0
  13. pythonProject/.venv/Lib/site-packages/diffusers/guiders/guider_utils.py +315 -0
  14. pythonProject/.venv/Lib/site-packages/diffusers/guiders/perturbed_attention_guidance.py +271 -0
  15. pythonProject/.venv/Lib/site-packages/diffusers/guiders/skip_layer_guidance.py +262 -0
  16. pythonProject/.venv/Lib/site-packages/diffusers/guiders/smoothed_energy_guidance.py +251 -0
  17. pythonProject/.venv/Lib/site-packages/diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  18. pythonProject/.venv/Lib/site-packages/diffusers/hooks/faster_cache.py +654 -0
  19. pythonProject/.venv/Lib/site-packages/diffusers/hooks/first_block_cache.py +259 -0
  20. pythonProject/.venv/Lib/site-packages/diffusers/hooks/group_offloading.py +898 -0
pythonProject/.venv/Lib/site-packages/diffusers/guiders/__init__.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

from ..utils import is_torch_available


# The guider implementations depend on torch, so they (and the `GuiderType`
# union alias) are only imported when torch is available.
if is_torch_available():
    from .adaptive_projected_guidance import AdaptiveProjectedGuidance
    from .auto_guidance import AutoGuidance
    from .classifier_free_guidance import ClassifierFreeGuidance
    from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
    from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
    from .perturbed_attention_guidance import PerturbedAttentionGuidance
    from .skip_layer_guidance import SkipLayerGuidance
    from .smoothed_energy_guidance import SmoothedEnergyGuidance
    from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance

    # Union of all concrete guider classes, for use in type annotations.
    GuiderType = Union[
        AdaptiveProjectedGuidance,
        AutoGuidance,
        ClassifierFreeGuidance,
        ClassifierFreeZeroStarGuidance,
        FrequencyDecoupledGuidance,
        PerturbedAttentionGuidance,
        SkipLayerGuidance,
        SmoothedEnergyGuidance,
        TangentialClassifierFreeGuidance,
    ]
pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.01 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/adaptive_projected_guidance.cpython-310.pyc ADDED
Binary file (6.65 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/auto_guidance.cpython-310.pyc ADDED
Binary file (7.43 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/classifier_free_guidance.cpython-310.pyc ADDED
Binary file (5.93 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/classifier_free_zero_star_guidance.cpython-310.pyc ADDED
Binary file (5.81 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/frequency_decoupled_guidance.cpython-310.pyc ADDED
Binary file (12.3 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/guider_utils.cpython-310.pyc ADDED
Binary file (14.9 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/perturbed_attention_guidance.cpython-310.pyc ADDED
Binary file (9.77 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/skip_layer_guidance.cpython-310.pyc ADDED
Binary file (10.3 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/smoothed_energy_guidance.cpython-310.pyc ADDED
Binary file (9.95 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/guiders/__pycache__/tangential_classifier_free_guidance.cpython-310.pyc ADDED
Binary file (5.16 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/guiders/guider_utils.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ from huggingface_hub.utils import validate_hf_hub_args
20
+ from typing_extensions import Self
21
+
22
+ from ..configuration_utils import ConfigMixin
23
+ from ..utils import BaseOutput, PushToHubMixin, get_logger
24
+
25
+
26
+ if TYPE_CHECKING:
27
+ from ..modular_pipelines.modular_pipeline import BlockState
28
+
29
+
30
+ GUIDER_CONFIG_NAME = "guider_config.json"
31
+
32
+
33
+ logger = get_logger(__name__) # pylint: disable=invalid-name
34
+
35
+
36
class BaseGuidance(ConfigMixin, PushToHubMixin):
    r"""
    Base class providing the skeleton for implementing guidance techniques.

    Subclasses must set `_input_predictions` to the list of prediction names the technique
    consumes (e.g. `["pred_cond", "pred_uncond"]`) and implement `prepare_inputs`, `forward`,
    `is_conditional`, and `num_conditions`.

    Args:
        start (`float`, defaults to `0.0`):
            Fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            Fraction of the total number of denoising steps after which guidance stops.
    """

    # File name used by ConfigMixin for (de)serializing the guider configuration.
    config_name = GUIDER_CONFIG_NAME
    # Must be overridden by subclasses with a list of required prediction names.
    _input_predictions = None
    # Attribute name used to tag each prepared batch with its prediction identifier.
    _identifier_key = "__guidance_identifier__"

    def __init__(self, start: float = 0.0, stop: float = 1.0):
        self._start = start
        self._stop = stop
        # Per-step denoising state; refreshed via `set_state` on every step.
        self._step: Optional[int] = None
        self._num_inference_steps: Optional[int] = None
        self._timestep: Optional[torch.LongTensor] = None
        # Number of `prepare_models` calls made for the current step.
        self._count_prepared = 0
        # Mapping of output field name -> source attribute name(s); see `set_input_fields`.
        self._input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
        self._enabled = True

        if not (0.0 <= start < 1.0):
            raise ValueError(f"Expected `start` to be between 0.0 and 1.0, but got {start}.")
        if not (start <= stop <= 1.0):
            raise ValueError(f"Expected `stop` to be between {start} and 1.0, but got {stop}.")

        if self._input_predictions is None or not isinstance(self._input_predictions, list):
            raise ValueError(
                "`_input_predictions` must be a list of required prediction names for the guidance technique."
            )

    def disable(self):
        # Turn guidance off; subclasses are expected to consult `self._enabled`
        # in their enablement checks.
        self._enabled = False

    def enable(self):
        self._enabled = True

    def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
        # Refresh per-step state at the beginning of each denoising step and reset the
        # prepare counter so conditional/unconditional tracking starts over.
        self._step = step
        self._num_inference_steps = num_inference_steps
        self._timestep = timestep
        self._count_prepared = 0

    def set_input_fields(self, **kwargs: Dict[str, Union[str, Tuple[str, str]]]) -> None:
        """
        Set the input fields for the guidance technique. The input fields are used to specify the names of the returned
        attributes containing the prepared data after `prepare_inputs` is called. The prepared data is obtained from
        the values of the provided keyword arguments to this method.

        Args:
            **kwargs (`Dict[str, Union[str, Tuple[str, str]]]`):
                A dictionary where the keys are the names of the fields that will be used to store the data once it is
                prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
                to look up the required data provided for preparation.

                If a string is provided, it will be used as the conditional data (or unconditional if used with a
                guidance method that requires it). If a tuple of length 2 is provided, the first element must be the
                conditional data identifier and the second element must be the unconditional data identifier or None.

        Example:
        ```
        data = {"prompt_embeds": <some tensor>, "negative_prompt_embeds": <some tensor>, "latents": <some tensor>}

        BaseGuidance.set_input_fields(
            latents="latents",
            prompt_embeds=("prompt_embeds", "negative_prompt_embeds"),
        )
        ```
        """
        # Validate each mapping before storing; only strings and (str, str) tuples are accepted.
        for key, value in kwargs.items():
            is_string = isinstance(value, str)
            is_tuple_of_str_with_len_2 = (
                isinstance(value, tuple) and len(value) == 2 and all(isinstance(v, str) for v in value)
            )
            if not (is_string or is_tuple_of_str_with_len_2):
                raise ValueError(
                    f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}."
                )
        self._input_fields = kwargs

    def prepare_models(self, denoiser: torch.nn.Module) -> None:
        """
        Prepares the models for the guidance technique on a given batch of data. This method should be overridden in
        subclasses to implement specific model preparation logic.
        """
        self._count_prepared += 1

    def cleanup_models(self, denoiser: torch.nn.Module) -> None:
        """
        Cleans up the models for the guidance technique after a given batch of data. This method should be overridden
        in subclasses to implement specific model cleanup logic. It is useful for removing any hooks or other stateful
        modifications made during `prepare_models`.
        """
        pass

    def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
        raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")

    def __call__(self, data: List["BlockState"]) -> Any:
        # Each prepared batch must carry a `noise_pred` plus the identifier assigned by
        # `_prepare_batch`; predictions are routed to `forward` keyed by that identifier.
        if not all(hasattr(d, "noise_pred") for d in data):
            raise ValueError("Expected all data to have `noise_pred` attribute.")
        if len(data) != self.num_conditions:
            raise ValueError(
                f"Expected {self.num_conditions} data items, but got {len(data)}. Please check the input data."
            )
        forward_inputs = {getattr(d, self._identifier_key): d.noise_pred for d in data}
        return self.forward(**forward_inputs)

    def forward(self, *args, **kwargs) -> Any:
        raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.")

    @property
    def is_conditional(self) -> bool:
        raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.")

    @property
    def is_unconditional(self) -> bool:
        # Convenience inverse of `is_conditional`.
        return not self.is_conditional

    @property
    def num_conditions(self) -> int:
        raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.")

    @classmethod
    def _prepare_batch(
        cls,
        input_fields: Dict[str, Union[str, Tuple[str, str]]],
        data: "BlockState",
        tuple_index: int,
        identifier: str,
    ) -> "BlockState":
        """
        Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of the
        `BaseGuidance` class. It prepares the batch based on the provided tuple index.

        Args:
            input_fields (`Dict[str, Union[str, Tuple[str, str]]]`):
                A dictionary where the keys are the names of the fields that will be used to store the data once it is
                prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
                to look up the required data provided for preparation. If a string is provided, it will be used as the
                conditional data (or unconditional if used with a guidance method that requires it). If a tuple of
                length 2 is provided, the first element must be the conditional data identifier and the second element
                must be the unconditional data identifier or None.
            data (`BlockState`):
                The input data to be prepared.
            tuple_index (`int`):
                The index to use when accessing input fields that are tuples.
            identifier (`str`):
                The prediction name stored under `cls._identifier_key` on the returned batch, used by
                `__call__` to route predictions to `forward`.

        Returns:
            `BlockState`: The prepared batch of data.
        """
        # Imported lazily to avoid an import cycle with modular_pipeline.
        from ..modular_pipelines.modular_pipeline import BlockState

        if input_fields is None:
            raise ValueError(
                "Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs."
            )
        data_batch = {}
        for key, value in input_fields.items():
            try:
                if isinstance(value, str):
                    data_batch[key] = getattr(data, value)
                elif isinstance(value, tuple):
                    data_batch[key] = getattr(data, value[tuple_index])
                else:
                    # We've already checked that value is a string or a tuple of strings with length 2
                    pass
            except AttributeError:
                # Missing attributes are skipped on purpose: not every batch provides
                # every field (e.g. no unconditional embeddings when CFG is disabled).
                logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
        data_batch[cls._identifier_key] = identifier
        return BlockState(**data_batch)

    @classmethod
    @validate_hf_hub_args
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
        subfolder: Optional[str] = None,
        return_unused_kwargs: bool = False,
        **kwargs,
    ) -> Self:
        r"""
        Instantiate a guider from a pre-defined JSON configuration file in a local directory or Hub repository.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the guider configuration
                      saved with [`~BaseGuidance.save_pretrained`].
            subfolder (`str`, *optional*):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                Whether kwargs that are not consumed by the Python class should be returned or not.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only(`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.

        <Tip>

        To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with `hf
        auth login`. You can also activate the special
        ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
        firewalled environment.

        </Tip>

        """
        # NOTE(review): `commit_hash` is requested but not used afterwards; only the parsed
        # config and leftover kwargs feed `from_config`.
        config, kwargs, commit_hash = cls.load_config(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            subfolder=subfolder,
            return_unused_kwargs=True,
            return_commit_hash=True,
            **kwargs,
        )
        return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)

    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Save a guider configuration object to a directory so that it can be reloaded using the
        [`~BaseGuidance.from_pretrained`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the configuration JSON file will be saved (will be created if it does not exist).
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
        self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
285
+
286
+
287
class GuiderOutput(BaseOutput):
    r"""Output of a guider's `forward` call."""

    # Final guided noise prediction.
    pred: torch.Tensor
    # Raw conditional / unconditional predictions (may be None when not applicable).
    pred_cond: Optional[torch.Tensor]
    pred_uncond: Optional[torch.Tensor]
291
+
292
+
293
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    r"""
    Rescale `noise_cfg` toward the per-sample standard deviation of `noise_pred_text` to fix the
    overexposure introduced by classifier-free guidance. Implements Section 3.4 of [Common Diffusion
    Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).

    Args:
        noise_cfg (`torch.Tensor`):
            The predicted noise tensor for the guided diffusion process.
        noise_pred_text (`torch.Tensor`):
            The predicted noise tensor for the text-guided diffusion process.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Interpolation factor between the raw and the std-rescaled prediction.

    Returns:
        `torch.Tensor`: The rescaled noise prediction tensor.
    """
    # Per-sample standard deviation, computed over every non-batch dimension.
    text_std = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    cfg_std = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # Bring the guided prediction back to the text prediction's scale (fixes overexposure).
    rescaled = noise_cfg * (text_std / cfg_std)
    # Blend with the unrescaled result to avoid "plain looking" images.
    return guidance_rescale * rescaled + (1.0 - guidance_rescale) * noise_cfg
pythonProject/.venv/Lib/site-packages/diffusers/guiders/perturbed_attention_guidance.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+
20
+ from ..configuration_utils import register_to_config
21
+ from ..hooks import HookRegistry, LayerSkipConfig
22
+ from ..hooks.layer_skip import _apply_layer_skip_hook
23
+ from ..utils import get_logger
24
+ from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
25
+
26
+
27
+ if TYPE_CHECKING:
28
+ from ..modular_pipelines.modular_pipeline import BlockState
29
+
30
+
31
+ logger = get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+
34
class PerturbedAttentionGuidance(BaseGuidance):
    """
    Perturbed Attention Guidance (PAG): https://huggingface.co/papers/2403.17377

    The intuition behind PAG can be thought of as moving the CFG predicted distribution estimates further away from
    worse versions of the conditional distribution estimates. PAG was one of the first techniques to introduce the idea
    of using a worse version of the trained model for better guiding itself in the denoising process. It perturbs the
    attention scores of the latent stream by replacing the score matrix with an identity matrix for selectively chosen
    layers.

    Additional reading:
    - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)

    PAG is implemented with similar implementation to SkipLayerGuidance due to overlap in the configuration parameters
    and implementation details.

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
            prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
            deterioration of image quality.
        perturbed_guidance_scale (`float`, defaults to `2.8`):
            The scale parameter for perturbed attention guidance.
        perturbed_guidance_start (`float`, defaults to `0.01`):
            The fraction of the total number of denoising steps after which perturbed attention guidance starts.
        perturbed_guidance_stop (`float`, defaults to `0.2`):
            The fraction of the total number of denoising steps after which perturbed attention guidance stops.
        perturbed_guidance_layers (`int` or `List[int]`, *optional*):
            The layer indices to apply perturbed attention guidance to. Can be a single integer or a list of integers.
            If not provided, `perturbed_guidance_config` must be provided.
        perturbed_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
            The configuration for the perturbed attention guidance. Can be a single `LayerSkipConfig` or a list of
            `LayerSkipConfig`. If not provided, `perturbed_guidance_layers` must be provided.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
    """

    # NOTE: The current implementation does not account for joint latent conditioning (text + image/video tokens in
    # the same latent stream). It assumes the entire latent is a single stream of visual tokens. It would be very
    # complex to support joint latent conditioning in a model-agnostic manner without specializing the implementation
    # for each model architecture.

    # Prediction names; their order also defines the prepare order: cond, uncond, cond_skip.
    _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]

    @register_to_config
    def __init__(
        self,
        guidance_scale: float = 7.5,
        perturbed_guidance_scale: float = 2.8,
        perturbed_guidance_start: float = 0.01,
        perturbed_guidance_stop: float = 0.2,
        perturbed_guidance_layers: Optional[Union[int, List[int]]] = None,
        perturbed_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__(start, stop)

        self.guidance_scale = guidance_scale
        # Internal attribute names mirror SkipLayerGuidance, from which several
        # methods below are copied (see "Copied from" markers).
        self.skip_layer_guidance_scale = perturbed_guidance_scale
        self.skip_layer_guidance_start = perturbed_guidance_start
        self.skip_layer_guidance_stop = perturbed_guidance_stop
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

        if perturbed_guidance_config is None:
            if perturbed_guidance_layers is None:
                raise ValueError(
                    "`perturbed_guidance_layers` must be provided if `perturbed_guidance_config` is not specified."
                )
            # Build the canonical PAG config: keep attention but replace its score matrix.
            perturbed_guidance_config = LayerSkipConfig(
                indices=perturbed_guidance_layers,
                fqn="auto",
                skip_attention=False,
                skip_attention_scores=True,
                skip_ff=False,
            )
        else:
            # `layers` and `config` are mutually exclusive ways of specifying the target layers.
            if perturbed_guidance_layers is not None:
                raise ValueError(
                    "`perturbed_guidance_layers` should not be provided if `perturbed_guidance_config` is specified."
                )

        # Normalize the accepted input forms (dict, single config, list of configs/dicts)
        # into a list of LayerSkipConfig instances.
        if isinstance(perturbed_guidance_config, dict):
            perturbed_guidance_config = LayerSkipConfig.from_dict(perturbed_guidance_config)

        if isinstance(perturbed_guidance_config, LayerSkipConfig):
            perturbed_guidance_config = [perturbed_guidance_config]

        if not isinstance(perturbed_guidance_config, list):
            raise ValueError(
                "`perturbed_guidance_config` must be a `LayerSkipConfig`, a list of `LayerSkipConfig`, or a dict that can be converted to a `LayerSkipConfig`."
            )
        elif isinstance(next(iter(perturbed_guidance_config), None), dict):
            # List of dicts: convert each entry. (Only the first element is inspected to
            # decide; a mixed list is assumed homogeneous.)
            perturbed_guidance_config = [LayerSkipConfig.from_dict(config) for config in perturbed_guidance_config]

        # PAG requires exactly skip_attention_scores=True; coerce any deviating config.
        for config in perturbed_guidance_config:
            if config.skip_attention or not config.skip_attention_scores or config.skip_ff:
                logger.warning(
                    "Perturbed Attention Guidance is designed to perturb attention scores, so `skip_attention` should be False, `skip_attention_scores` should be True, and `skip_ff` should be False. "
                    "Please check your configuration. Modifying the config to match the expected values."
                )
                config.skip_attention = False
                config.skip_attention_scores = True
                config.skip_ff = False

        self.skip_layer_config = perturbed_guidance_config
        self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]

    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_models
    def prepare_models(self, denoiser: torch.nn.Module) -> None:
        # The layer-skip hooks are only installed for the *second* conditional pass
        # (`_count_prepared > 1`), i.e. the perturbed forward pass.
        self._count_prepared += 1
        if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
            for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
                _apply_layer_skip_hook(denoiser, config, name=name)

    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.cleanup_models
    def cleanup_models(self, denoiser: torch.nn.Module) -> None:
        if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
            registry = HookRegistry.check_if_exists_or_initialize(denoiser)
            # Remove the hooks after inference
            for hook_name in self._skip_layer_hook_names:
                registry.remove_hook(hook_name, recurse=True)

    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs
    def prepare_inputs(
        self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
    ) -> List["BlockState"]:
        if input_fields is None:
            input_fields = self._input_fields

        # Decide which predictions (and hence which conditional/unconditional tuple slots)
        # are needed for the current step, depending on which guidance modes are active.
        if self.num_conditions == 1:
            tuple_indices = [0]
            input_predictions = ["pred_cond"]
        elif self.num_conditions == 2:
            tuple_indices = [0, 1]
            input_predictions = (
                ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
            )
        else:
            # Both CFG and SLG active: cond, uncond, and a second cond batch for the skip pass.
            tuple_indices = [0, 1, 0]
            input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
        data_batches = []
        for i in range(self.num_conditions):
            data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
            data_batches.append(data_batch)
        return data_batches

    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.forward
    def forward(
        self,
        pred_cond: torch.Tensor,
        pred_uncond: Optional[torch.Tensor] = None,
        pred_cond_skip: Optional[torch.Tensor] = None,
    ) -> GuiderOutput:
        pred = None

        if not self._is_cfg_enabled() and not self._is_slg_enabled():
            # Neither guidance active: pass the conditional prediction through.
            pred = pred_cond
        elif not self._is_cfg_enabled():
            # SLG only: shift away from the perturbed (skip) prediction.
            shift = pred_cond - pred_cond_skip
            pred = pred_cond if self.use_original_formulation else pred_cond_skip
            pred = pred + self.skip_layer_guidance_scale * shift
        elif not self._is_slg_enabled():
            # CFG only: standard classifier-free guidance.
            shift = pred_cond - pred_uncond
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift
        else:
            # Both: combine the CFG shift and the perturbed-attention shift.
            shift = pred_cond - pred_uncond
            shift_skip = pred_cond - pred_cond_skip
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

    @property
    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.is_conditional
    def is_conditional(self) -> bool:
        # Prepare order is cond (1), uncond (2), cond_skip (3); counts 1 and 3 are conditional.
        return self._count_prepared == 1 or self._count_prepared == 3

    @property
    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.num_conditions
    def num_conditions(self) -> int:
        num_conditions = 1
        if self._is_cfg_enabled():
            num_conditions += 1
        if self._is_slg_enabled():
            num_conditions += 1
        return num_conditions

    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_cfg_enabled
    def _is_cfg_enabled(self) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        # CFG degenerates to a no-op at scale 0.0 (original formulation) or 1.0 (diffusers).
        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close

    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_slg_enabled
    def _is_slg_enabled(self) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
            skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
            # NOTE(review): the start bound is exclusive here, unlike `_is_cfg_enabled`
            # which uses `<=` — kept as-is from the copied SkipLayerGuidance source.
            is_within_range = skip_start_step < self._step < skip_stop_step

        is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)

        return is_within_range and not is_zero
pythonProject/.venv/Lib/site-packages/diffusers/guiders/skip_layer_guidance.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+
20
+ from ..configuration_utils import register_to_config
21
+ from ..hooks import HookRegistry, LayerSkipConfig
22
+ from ..hooks.layer_skip import _apply_layer_skip_hook
23
+ from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
24
+
25
+
26
+ if TYPE_CHECKING:
27
+ from ..modular_pipelines.modular_pipeline import BlockState
28
+
29
+
30
class SkipLayerGuidance(BaseGuidance):
    """
    Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5

    Spatio-Temporal Guidance (STG): https://huggingface.co/papers/2411.18664

    SLG was introduced by StabilityAI for improving structure and anatomy coherence in generated images. It works by
    skipping the forward pass of specified transformer blocks during the denoising process on an additional conditional
    batch of data, apart from the conditional and unconditional batches already used in CFG
    ([~guiders.classifier_free_guidance.ClassifierFreeGuidance]), and then scaling and shifting the CFG predictions
    based on the difference between conditional without skipping and conditional with skipping predictions.

    The intuition behind SLG can be thought of as moving the CFG predicted distribution estimates further away from
    worse versions of the conditional distribution estimates (because skipping layers is equivalent to using a worse
    version of the model for the conditional prediction).

    STG is an improvement and follow-up work combining ideas from SLG, PAG and similar techniques for improving
    generation quality in video diffusion models.

    Additional reading:
    - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)

    The values for `skip_layer_guidance_scale`, `skip_layer_guidance_start`, and `skip_layer_guidance_stop` are
    defaulted to the recommendations by StabilityAI for Stable Diffusion 3.5 Medium.

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
            prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
            deterioration of image quality.
        skip_layer_guidance_scale (`float`, defaults to `2.8`):
            The scale parameter for skip layer guidance. Anatomy and structure coherence may improve with higher
            values, but it may also lead to overexposure and saturation.
        skip_layer_guidance_start (`float`, defaults to `0.01`):
            The fraction of the total number of denoising steps after which skip layer guidance starts.
        skip_layer_guidance_stop (`float`, defaults to `0.2`):
            The fraction of the total number of denoising steps after which skip layer guidance stops.
        skip_layer_guidance_layers (`int` or `List[int]`, *optional*):
            The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
            provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion
            3.5 Medium.
        skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
            The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
            `LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
    """

    # Fixed ordering of the model outputs this guider consumes; prepare_inputs
    # selects a subset of these names depending on which branches are enabled.
    _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]

    @register_to_config
    def __init__(
        self,
        guidance_scale: float = 7.5,
        skip_layer_guidance_scale: float = 2.8,
        skip_layer_guidance_start: float = 0.01,
        skip_layer_guidance_stop: float = 0.2,
        skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None,
        skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__(start, stop)

        self.guidance_scale = guidance_scale
        self.skip_layer_guidance_scale = skip_layer_guidance_scale
        self.skip_layer_guidance_start = skip_layer_guidance_start
        self.skip_layer_guidance_stop = skip_layer_guidance_stop
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

        if not (0.0 <= skip_layer_guidance_start < 1.0):
            raise ValueError(
                f"Expected `skip_layer_guidance_start` to be between 0.0 and 1.0, but got {skip_layer_guidance_start}."
            )
        if not (skip_layer_guidance_start <= skip_layer_guidance_stop <= 1.0):
            raise ValueError(
                f"Expected `skip_layer_guidance_stop` to be between 0.0 and 1.0, but got {skip_layer_guidance_stop}."
            )

        # Exactly one of `skip_layer_guidance_layers` / `skip_layer_config` must be given;
        # layer indices are converted into LayerSkipConfig objects below.
        if skip_layer_guidance_layers is None and skip_layer_config is None:
            raise ValueError(
                "Either `skip_layer_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance."
            )
        if skip_layer_guidance_layers is not None and skip_layer_config is not None:
            raise ValueError("Only one of `skip_layer_guidance_layers` or `skip_layer_config` can be provided.")

        if skip_layer_guidance_layers is not None:
            if isinstance(skip_layer_guidance_layers, int):
                skip_layer_guidance_layers = [skip_layer_guidance_layers]
            if not isinstance(skip_layer_guidance_layers, list):
                raise ValueError(
                    f"Expected `skip_layer_guidance_layers` to be an int or a list of ints, but got {type(skip_layer_guidance_layers)}."
                )
            skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_layer_guidance_layers]

        # Normalize `skip_layer_config` into a list of LayerSkipConfig objects,
        # accepting a dict, a single config, or a list of either.
        if isinstance(skip_layer_config, dict):
            skip_layer_config = LayerSkipConfig.from_dict(skip_layer_config)

        if isinstance(skip_layer_config, LayerSkipConfig):
            skip_layer_config = [skip_layer_config]

        if not isinstance(skip_layer_config, list):
            raise ValueError(
                f"Expected `skip_layer_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(skip_layer_config)}."
            )
        elif isinstance(next(iter(skip_layer_config), None), dict):
            skip_layer_config = [LayerSkipConfig.from_dict(config) for config in skip_layer_config]

        self.skip_layer_config = skip_layer_config
        # One uniquely named hook per config so they can be removed individually.
        self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]

    def prepare_models(self, denoiser: torch.nn.Module) -> None:
        """Attach layer-skip hooks to `denoiser` before the skip-branch forward pass.

        Called once per prepared batch; `_count_prepared > 1` ensures the hooks are
        only applied for the later (skip) conditional pass, not the first plain
        conditional pass.
        """
        self._count_prepared += 1
        if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
            for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
                _apply_layer_skip_hook(denoiser, config, name=name)

    def cleanup_models(self, denoiser: torch.nn.Module) -> None:
        """Remove any layer-skip hooks that `prepare_models` attached to `denoiser`."""
        if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
            registry = HookRegistry.check_if_exists_or_initialize(denoiser)
            # Remove the hooks after inference
            for hook_name in self._skip_layer_hook_names:
                registry.remove_hook(hook_name, recurse=True)

    def prepare_inputs(
        self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
    ) -> List["BlockState"]:
        """Split `data` into one batch per enabled condition (cond / uncond / cond_skip)."""
        if input_fields is None:
            input_fields = self._input_fields

        # tuple_indices selects the cond (0) or uncond (1) element of each input
        # field tuple; the skip branch reuses the conditional inputs (index 0).
        if self.num_conditions == 1:
            tuple_indices = [0]
            input_predictions = ["pred_cond"]
        elif self.num_conditions == 2:
            tuple_indices = [0, 1]
            input_predictions = (
                ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
            )
        else:
            tuple_indices = [0, 1, 0]
            input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
        data_batches = []
        for i in range(self.num_conditions):
            data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
            data_batches.append(data_batch)
        return data_batches

    def forward(
        self,
        pred_cond: torch.Tensor,
        pred_uncond: Optional[torch.Tensor] = None,
        pred_cond_skip: Optional[torch.Tensor] = None,
    ) -> GuiderOutput:
        """Combine the branch predictions into the final guided prediction.

        The four branches cover every combination of CFG/SLG being enabled; when
        both are active, the CFG and skip-layer shifts are applied additively.
        """
        pred = None

        if not self._is_cfg_enabled() and not self._is_slg_enabled():
            pred = pred_cond
        elif not self._is_cfg_enabled():
            shift = pred_cond - pred_cond_skip
            pred = pred_cond if self.use_original_formulation else pred_cond_skip
            pred = pred + self.skip_layer_guidance_scale * shift
        elif not self._is_slg_enabled():
            shift = pred_cond - pred_uncond
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift
        else:
            shift = pred_cond - pred_uncond
            shift_skip = pred_cond - pred_cond_skip
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

    @property
    def is_conditional(self) -> bool:
        # Batches are prepared in the order cond (1), uncond (2), cond_skip (3),
        # so the 1st and 3rd prepared passes are conditional.
        return self._count_prepared == 1 or self._count_prepared == 3

    @property
    def num_conditions(self) -> int:
        # Always at least the plain conditional pass; add one each for CFG and SLG.
        num_conditions = 1
        if self._is_cfg_enabled():
            num_conditions += 1
        if self._is_slg_enabled():
            num_conditions += 1
        return num_conditions

    def _is_cfg_enabled(self) -> bool:
        """Whether classifier-free guidance is active at the current step."""
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        # A scale of 1.0 (diffusers formulation) or 0.0 (original formulation)
        # makes CFG a no-op, so treat it as disabled.
        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close

    def _is_slg_enabled(self) -> bool:
        """Whether skip-layer guidance is active at the current step."""
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
            skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
            is_within_range = skip_start_step < self._step < skip_stop_step

        is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)

        return is_within_range and not is_zero
pythonProject/.venv/Lib/site-packages/diffusers/guiders/smoothed_energy_guidance.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+
20
+ from ..configuration_utils import register_to_config
21
+ from ..hooks import HookRegistry
22
+ from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook
23
+ from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
24
+
25
+
26
+ if TYPE_CHECKING:
27
+ from ..modular_pipelines.modular_pipeline import BlockState
28
+
29
+
30
class SmoothedEnergyGuidance(BaseGuidance):
    """
    Smoothed Energy Guidance (SEG): https://huggingface.co/papers/2408.00760

    SEG is only supported as an experimental prototype feature for now, so the implementation may be modified in the
    future without warning or guarantee of reproducibility. This implementation assumes:
    - Generated images are square (height == width)
    - The model does not combine different modalities together (e.g., text and image latent streams are not combined
    together such as Flux)

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
            prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
            deterioration of image quality.
        seg_guidance_scale (`float`, defaults to `2.8`):
            The scale parameter for smoothed energy guidance. Anatomy and structure coherence may improve with higher
            values, but it may also lead to overexposure and saturation.
        seg_blur_sigma (`float`, defaults to `9999999.0`):
            The amount by which we blur the attention weights. Setting this value greater than 9999.0 results in
            infinite blur, which means uniform queries. Controlling it exponentially is empirically effective.
        seg_blur_threshold_inf (`float`, defaults to `9999.0`):
            The threshold above which the blur is considered infinite.
        seg_guidance_start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which smoothed energy guidance starts.
        seg_guidance_stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which smoothed energy guidance stops.
        seg_guidance_layers (`int` or `List[int]`, *optional*):
            The layer indices to apply smoothed energy guidance to. Can be a single integer or a list of integers. If
            not provided, `seg_guidance_config` must be provided. The recommended values are `[7, 8, 9]` for Stable
            Diffusion 3.5 Medium.
        seg_guidance_config (`SmoothedEnergyGuidanceConfig` or `List[SmoothedEnergyGuidanceConfig]`, *optional*):
            The configuration for the smoothed energy layer guidance. Can be a single `SmoothedEnergyGuidanceConfig` or
            a list of `SmoothedEnergyGuidanceConfig`. If not provided, `seg_guidance_layers` must be provided.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
    """

    # Fixed ordering of the model outputs this guider consumes; prepare_inputs
    # selects a subset of these names depending on which branches are enabled.
    _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]

    @register_to_config
    def __init__(
        self,
        guidance_scale: float = 7.5,
        seg_guidance_scale: float = 2.8,
        seg_blur_sigma: float = 9999999.0,
        seg_blur_threshold_inf: float = 9999.0,
        seg_guidance_start: float = 0.0,
        seg_guidance_stop: float = 1.0,
        seg_guidance_layers: Optional[Union[int, List[int]]] = None,
        seg_guidance_config: Union[SmoothedEnergyGuidanceConfig, List[SmoothedEnergyGuidanceConfig]] = None,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__(start, stop)

        self.guidance_scale = guidance_scale
        self.seg_guidance_scale = seg_guidance_scale
        self.seg_blur_sigma = seg_blur_sigma
        self.seg_blur_threshold_inf = seg_blur_threshold_inf
        self.seg_guidance_start = seg_guidance_start
        self.seg_guidance_stop = seg_guidance_stop
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

        if not (0.0 <= seg_guidance_start < 1.0):
            raise ValueError(f"Expected `seg_guidance_start` to be between 0.0 and 1.0, but got {seg_guidance_start}.")
        if not (seg_guidance_start <= seg_guidance_stop <= 1.0):
            raise ValueError(f"Expected `seg_guidance_stop` to be between 0.0 and 1.0, but got {seg_guidance_stop}.")

        # Exactly one of `seg_guidance_layers` / `seg_guidance_config` must be given;
        # layer indices are converted into SmoothedEnergyGuidanceConfig objects below.
        if seg_guidance_layers is None and seg_guidance_config is None:
            raise ValueError(
                "Either `seg_guidance_layers` or `seg_guidance_config` must be provided to enable Smoothed Energy Guidance."
            )
        if seg_guidance_layers is not None and seg_guidance_config is not None:
            raise ValueError("Only one of `seg_guidance_layers` or `seg_guidance_config` can be provided.")

        if seg_guidance_layers is not None:
            if isinstance(seg_guidance_layers, int):
                seg_guidance_layers = [seg_guidance_layers]
            if not isinstance(seg_guidance_layers, list):
                raise ValueError(
                    f"Expected `seg_guidance_layers` to be an int or a list of ints, but got {type(seg_guidance_layers)}."
                )
            seg_guidance_config = [SmoothedEnergyGuidanceConfig(layer, fqn="auto") for layer in seg_guidance_layers]

        # Normalize `seg_guidance_config` into a list of SmoothedEnergyGuidanceConfig
        # objects, accepting a dict, a single config, or a list of either.
        if isinstance(seg_guidance_config, dict):
            seg_guidance_config = SmoothedEnergyGuidanceConfig.from_dict(seg_guidance_config)

        if isinstance(seg_guidance_config, SmoothedEnergyGuidanceConfig):
            seg_guidance_config = [seg_guidance_config]

        if not isinstance(seg_guidance_config, list):
            raise ValueError(
                f"Expected `seg_guidance_config` to be a SmoothedEnergyGuidanceConfig or a list of SmoothedEnergyGuidanceConfig, but got {type(seg_guidance_config)}."
            )
        elif isinstance(next(iter(seg_guidance_config), None), dict):
            seg_guidance_config = [SmoothedEnergyGuidanceConfig.from_dict(config) for config in seg_guidance_config]

        self.seg_guidance_config = seg_guidance_config
        # One uniquely named hook per config so they can be removed individually.
        self._seg_layer_hook_names = [f"SmoothedEnergyGuidance_{i}" for i in range(len(self.seg_guidance_config))]

    def prepare_models(self, denoiser: torch.nn.Module) -> None:
        """Attach smoothed-energy hooks to `denoiser` before the SEG-branch forward pass.

        Called once per prepared batch; `_count_prepared > 1` ensures the hooks are
        only applied for the later (SEG) conditional pass, not the first plain
        conditional pass.
        """
        # Bugfix: track how many batches have been prepared. Without this
        # increment, `is_conditional` (`_count_prepared == 1 or == 3`) and the
        # `_count_prepared > 1` gate below can never fire, unlike the otherwise
        # identical SkipLayerGuidance.prepare_models which does increment.
        self._count_prepared += 1
        if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
            for name, config in zip(self._seg_layer_hook_names, self.seg_guidance_config):
                _apply_smoothed_energy_guidance_hook(denoiser, config, self.seg_blur_sigma, name=name)

    def cleanup_models(self, denoiser: torch.nn.Module) -> None:
        """Remove any smoothed-energy hooks that `prepare_models` attached to `denoiser`."""
        if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
            registry = HookRegistry.check_if_exists_or_initialize(denoiser)
            # Remove the hooks after inference
            for hook_name in self._seg_layer_hook_names:
                registry.remove_hook(hook_name, recurse=True)

    def prepare_inputs(
        self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
    ) -> List["BlockState"]:
        """Split `data` into one batch per enabled condition (cond / uncond / cond_seg)."""
        if input_fields is None:
            input_fields = self._input_fields

        # tuple_indices selects the cond (0) or uncond (1) element of each input
        # field tuple; the SEG branch reuses the conditional inputs (index 0).
        if self.num_conditions == 1:
            tuple_indices = [0]
            input_predictions = ["pred_cond"]
        elif self.num_conditions == 2:
            tuple_indices = [0, 1]
            input_predictions = (
                ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"]
            )
        else:
            tuple_indices = [0, 1, 0]
            input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
        data_batches = []
        for i in range(self.num_conditions):
            data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
            data_batches.append(data_batch)
        return data_batches

    def forward(
        self,
        pred_cond: torch.Tensor,
        pred_uncond: Optional[torch.Tensor] = None,
        pred_cond_seg: Optional[torch.Tensor] = None,
    ) -> GuiderOutput:
        """Combine the branch predictions into the final guided prediction.

        The four branches cover every combination of CFG/SEG being enabled; when
        both are active, the CFG and smoothed-energy shifts are applied additively.
        """
        pred = None

        if not self._is_cfg_enabled() and not self._is_seg_enabled():
            pred = pred_cond
        elif not self._is_cfg_enabled():
            shift = pred_cond - pred_cond_seg
            pred = pred_cond if self.use_original_formulation else pred_cond_seg
            pred = pred + self.seg_guidance_scale * shift
        elif not self._is_seg_enabled():
            shift = pred_cond - pred_uncond
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift
        else:
            shift = pred_cond - pred_uncond
            shift_seg = pred_cond - pred_cond_seg
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift + self.seg_guidance_scale * shift_seg

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

    @property
    def is_conditional(self) -> bool:
        # Batches are prepared in the order cond (1), uncond (2), cond_seg (3),
        # so the 1st and 3rd prepared passes are conditional.
        return self._count_prepared == 1 or self._count_prepared == 3

    @property
    def num_conditions(self) -> int:
        # Always at least the plain conditional pass; add one each for CFG and SEG.
        num_conditions = 1
        if self._is_cfg_enabled():
            num_conditions += 1
        if self._is_seg_enabled():
            num_conditions += 1
        return num_conditions

    def _is_cfg_enabled(self) -> bool:
        """Whether classifier-free guidance is active at the current step."""
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        # A scale of 1.0 (diffusers formulation) or 0.0 (original formulation)
        # makes CFG a no-op, so treat it as disabled.
        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close

    def _is_seg_enabled(self) -> bool:
        """Whether smoothed energy guidance is active at the current step."""
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self.seg_guidance_start * self._num_inference_steps)
            skip_stop_step = int(self.seg_guidance_stop * self._num_inference_steps)
            is_within_range = skip_start_step < self._step < skip_stop_step

        is_zero = math.isclose(self.seg_guidance_scale, 0.0)

        return is_within_range and not is_zero
pythonProject/.venv/Lib/site-packages/diffusers/guiders/tangential_classifier_free_guidance.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+
20
+ from ..configuration_utils import register_to_config
21
+ from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
22
+
23
+
24
+ if TYPE_CHECKING:
25
+ from ..modular_pipelines.modular_pipeline import BlockState
26
+
27
+
28
class TangentialClassifierFreeGuidance(BaseGuidance):
    """
    Tangential Classifier Free Guidance (TCFG): https://huggingface.co/papers/2503.18137

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
            prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
            deterioration of image quality.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
    """

    # Fixed ordering of the model outputs this guider consumes.
    _input_predictions = ["pred_cond", "pred_uncond"]

    @register_to_config
    def __init__(
        self,
        guidance_scale: float = 7.5,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__(start, stop)

        self.guidance_scale = guidance_scale
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

    def prepare_inputs(
        self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
    ) -> List["BlockState"]:
        """Split `data` into one batch per enabled condition (cond, and uncond when TCFG is active)."""
        if input_fields is None:
            input_fields = self._input_fields

        # Index 0 selects the conditional element of each input field tuple,
        # index 1 the unconditional one.
        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for i in range(self.num_conditions):
            data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
            data_batches.append(data_batch)
        return data_batches

    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
        """Combine conditional/unconditional predictions via tangential-projection CFG."""
        pred = None

        if not self._is_tcfg_enabled():
            pred = pred_cond
        else:
            pred = normalized_guidance(pred_cond, pred_uncond, self.guidance_scale, self.use_original_formulation)

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

    @property
    def is_conditional(self) -> bool:
        # Consistency fix: every other guider in this package tracks prepared
        # batches via `_count_prepared`; `_num_outputs_prepared` is not set
        # anywhere and would raise AttributeError here.
        return self._count_prepared == 1

    @property
    def num_conditions(self) -> int:
        # The plain conditional pass, plus the unconditional pass when TCFG is active.
        num_conditions = 1
        if self._is_tcfg_enabled():
            num_conditions += 1
        return num_conditions

    def _is_tcfg_enabled(self) -> bool:
        """Whether tangential CFG is active at the current step."""
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        # A scale of 1.0 (diffusers formulation) or 0.0 (original formulation)
        # makes guidance a no-op, so treat it as disabled.
        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close
122
+
123
+
124
def normalized_guidance(
    pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float, use_original_formulation: bool = False
) -> torch.Tensor:
    """Apply tangential classifier-free guidance (TCFG).

    The unconditional prediction is projected onto the principal singular
    direction shared by the conditional/unconditional pair (the second
    singular direction is discarded), and standard CFG extrapolation is then
    performed against that projected unconditional prediction.

    Args:
        pred_cond: Conditional model prediction.
        pred_uncond: Unconditional model prediction (same shape as `pred_cond`).
        guidance_scale: CFG extrapolation strength.
        use_original_formulation: If True, extrapolate from the conditional
            prediction instead of the projected unconditional one.

    Returns:
        The guided prediction, in `pred_cond`'s original dtype.
    """
    original_dtype = pred_cond.dtype

    # SVD of the (cond, uncond) pair, flattened to (batch, 2, features).
    stacked = torch.stack([pred_cond, pred_uncond], dim=1).float().flatten(2)
    _, _, right_vectors = torch.linalg.svd(stacked, full_matrices=False)

    # Keep only the first right-singular direction for the reconstruction.
    truncated_basis = right_vectors.clone()
    truncated_basis[:, 1] = 0

    # Project the unconditional prediction: coords in the full basis,
    # reconstructed through the truncated basis.
    flat_uncond = pred_uncond.reshape(pred_uncond.size(0), 1, -1).float()
    coords = torch.matmul(flat_uncond, right_vectors.transpose(-2, -1))
    reconstructed = torch.matmul(coords, truncated_basis)
    pred_uncond = reconstructed.reshape(pred_uncond.shape).to(original_dtype)

    base = pred_cond if use_original_formulation else pred_uncond
    return base + guidance_scale * (pred_cond - pred_uncond)
pythonProject/.venv/Lib/site-packages/diffusers/hooks/faster_cache.py ADDED
@@ -0,0 +1,654 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import re
16
+ from dataclasses import dataclass
17
+ from typing import Any, Callable, List, Optional, Tuple
18
+
19
+ import torch
20
+
21
+ from ..models.attention import AttentionModuleMixin
22
+ from ..models.modeling_outputs import Transformer2DModelOutput
23
+ from ..utils import logging
24
+ from ._common import _ATTENTION_CLASSES
25
+ from .hooks import HookRegistry, ModelHook
26
+
27
+
28
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# Names under which the FasterCache hooks are registered on each module's HookRegistry.
_FASTER_CACHE_DENOISER_HOOK = "faster_cache_denoiser"
_FASTER_CACHE_BLOCK_HOOK = "faster_cache_block"
# Regex patterns (matched with re.search against submodule names) used to locate
# spatial / temporal attention blocks in supported transformer architectures.
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
    "^blocks.*attn",
    "^transformer_blocks.*attn",
    "^single_transformer_blocks.*attn",
)
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("^temporal_transformer_blocks.*attn",)
_TRANSFORMER_BLOCK_IDENTIFIERS = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
# Kwarg names whose values may carry batchwise-concatenated [uncond; cond] inputs;
# these are split when the unconditional branch is skipped by the denoiser hook.
_UNCOND_COND_INPUT_KWARGS_IDENTIFIERS = (
    "hidden_states",
    "encoder_hidden_states",
    "timestep",
    "attention_mask",
    "encoder_attention_mask",
)
+
48
+
49
@dataclass
class FasterCacheConfig:
    r"""
    Configuration for [FasterCache](https://huggingface.co/papers/2410.19355).

    Attributes:
        spatial_attention_block_skip_range (`int`, defaults to `2`):
            Compute spatial attention states only every `N` iterations; cached states are reused for the `N - 1`
            iterations in between.
        temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`):
            Same as above, for temporal attention blocks. `None` disables temporal-attention skipping.
        spatial_attention_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 681)`):
            Open interval `(lower, upper)` of diffusion timesteps within which spatial attention computation may
            be skipped without significant quality loss. Diffusion timesteps typically run in reverse from 1000
            down to 0, so the default means skipping applies from timestep 681 to the end of denoising.
        temporal_attention_timestep_skip_range (`Tuple[float, float]`, *optional*, defaults to `None`):
            Timestep interval within which temporal attention computation may be skipped.
        low_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(99, 901)`):
            Timestep interval within which the low-frequency weight scaling update is applied.
        high_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(-1, 301)`):
            Timestep interval within which the high-frequency weight scaling update is applied.
        alpha_low_frequency (`float`, defaults to `1.1`):
            Scale applied to low-frequency updates when approximating the unconditional branch from the
            conditional branch outputs.
        alpha_high_frequency (`float`, defaults to `1.1`):
            Scale applied to high-frequency updates when approximating the unconditional branch from the
            conditional branch outputs.
        unconditional_batch_skip_range (`int`, defaults to `5`):
            Compute the unconditional branch only every `N` iterations; the approximation from cached state is
            used for the `N - 1` iterations in between.
        unconditional_batch_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 641)`):
            Timestep interval within which the unconditional branch may be skipped.
        spatial_attention_block_identifiers (`Tuple[str, ...]`):
            Regex patterns matched (full or partial, via regex search) against layer names to select spatial
            attention blocks.
        temporal_attention_block_identifiers (`Tuple[str, ...]`):
            Regex patterns matched against layer names to select temporal attention blocks.
        attention_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
            Given the attention module, returns the weight used to scale cached attention outputs. Defaults to a
            constant 0.5 for all timesteps; the paper recommends a schedule that gradually increases from 0 to 1
            over inference.
        low_frequency_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
            Returns the scale for low-frequency updates. Defaults to the paper's behaviour: 1.1 inside the
            configured timestep range, 1.0 outside.
        high_frequency_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
            Returns the scale for high-frequency updates, defaulting analogously to the low-frequency callback.
        tensor_format (`str`, defaults to `"BCFHW"`):
            Layout of the denoiser output tensor: one of `"BCFHW"`, `"BFCHW"`, or `"BCHW"`. Used to split latent
            frames so low/high frequency components can be computed per frame.
        is_guidance_distilled (`bool`, defaults to `False`):
            Guidance-distilled models have no unconditional branch, so denoiser-level skipping is disabled.
        current_timestep_callback (`Callable[[], int]`, defaults to `None`):
            Returns the current diffusion timestep; the hooks call it to decide when skipping is allowed.
        _unconditional_conditional_input_kwargs_identifiers (`List[str]`):
            Exact kwarg names whose values contain batchwise-concatenated unconditional and conditional inputs.
    """

    # Hardcoded to 2 in the paper/codebase; kept configurable here.
    spatial_attention_block_skip_range: int = 2
    temporal_attention_block_skip_range: Optional[int] = None

    spatial_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)
    temporal_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)

    # Indicator functions for low/high frequency as in Equation 11 of the paper.
    low_frequency_weight_update_timestep_range: Tuple[int, int] = (99, 901)
    high_frequency_weight_update_timestep_range: Tuple[int, int] = (-1, 301)

    # ⍺1 and ⍺2 from Equation 11 of the paper.
    alpha_low_frequency: float = 1.1
    alpha_high_frequency: float = 1.1

    # n as described in the CFG-Cache section of the paper — model dependent.
    unconditional_batch_skip_range: int = 5
    unconditional_batch_timestep_skip_range: Tuple[int, int] = (-1, 641)

    spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
    temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS

    attention_weight_callback: Callable[[torch.nn.Module], float] = None
    low_frequency_weight_callback: Callable[[torch.nn.Module], float] = None
    high_frequency_weight_callback: Callable[[torch.nn.Module], float] = None

    tensor_format: str = "BCFHW"
    is_guidance_distilled: bool = False

    current_timestep_callback: Callable[[], int] = None

    _unconditional_conditional_input_kwargs_identifiers: List[str] = _UNCOND_COND_INPUT_KWARGS_IDENTIFIERS

    def __repr__(self) -> str:
        # Only the user-facing tuning knobs are shown, in a fixed order.
        shown_fields = (
            "spatial_attention_block_skip_range",
            "temporal_attention_block_skip_range",
            "spatial_attention_timestep_skip_range",
            "temporal_attention_timestep_skip_range",
            "low_frequency_weight_update_timestep_range",
            "high_frequency_weight_update_timestep_range",
            "alpha_low_frequency",
            "alpha_high_frequency",
            "unconditional_batch_skip_range",
            "unconditional_batch_timestep_skip_range",
            "spatial_attention_block_identifiers",
            "temporal_attention_block_identifiers",
            "tensor_format",
        )
        body = "".join(f"  {name}={getattr(self, name)},\n" for name in shown_fields)
        return f"FasterCacheConfig(\n{body})"
185
+
186
+
187
class FasterCacheDenoiserState:
    r"""
    Mutable per-run state for the top-level [FasterCache](https://huggingface.co/papers/2410.19355)
    denoiser hook.
    """

    def __init__(self) -> None:
        # Initialization and reset are the same operation: a clean slate.
        self.reset()

    def reset(self):
        # iteration: number of denoiser forward passes seen so far.
        # low/high_frequency_delta: cached spectral differences between the
        # unconditional and conditional branch outputs.
        self.iteration: int = 0
        self.low_frequency_delta: torch.Tensor = None
        self.high_frequency_delta: torch.Tensor = None
+
202
+
203
class FasterCacheBlockState:
    r"""
    Per-block state for [FasterCache](https://huggingface.co/papers/2410.19355). Each underlying block
    that FasterCache is applied to owns one instance of this state.
    """

    def __init__(self) -> None:
        # Initialization and reset are the same operation: a clean slate.
        self.reset()

    def reset(self):
        # iteration: forward passes seen so far for this block.
        # batch_size: true (uncond + cond) batch size, captured on the first pass.
        # cache: block outputs from the two most recent computed passes.
        self.iteration: int = 0
        self.batch_size: int = None
        self.cache: Tuple[torch.Tensor, torch.Tensor] = None
+
219
+
220
class FasterCacheDenoiserHook(ModelHook):
    """Top-level denoiser hook implementing FasterCache's CFG-Cache: skips the
    unconditional branch on most iterations and approximates its output in the
    frequency domain from the conditional branch output plus cached deltas."""

    _is_stateful = True

    def __init__(
        self,
        unconditional_batch_skip_range: int,
        unconditional_batch_timestep_skip_range: Tuple[int, int],
        tensor_format: str,
        is_guidance_distilled: bool,
        uncond_cond_input_kwargs_identifiers: List[str],
        current_timestep_callback: Callable[[], int],
        low_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor],
        high_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor],
    ) -> None:
        super().__init__()

        self.unconditional_batch_skip_range = unconditional_batch_skip_range
        self.unconditional_batch_timestep_skip_range = unconditional_batch_timestep_skip_range
        # We can't easily detect what args are to be split in unconditional and conditional branches. We
        # can only do it for kwargs, hence they are the only ones we split. The args are passed as-is.
        # If a model is to be made compatible with FasterCache, the user must ensure that the inputs that
        # contain batchwise-concatenated unconditional and conditional inputs are passed as kwargs.
        self.uncond_cond_input_kwargs_identifiers = uncond_cond_input_kwargs_identifiers
        self.tensor_format = tensor_format
        self.is_guidance_distilled = is_guidance_distilled

        self.current_timestep_callback = current_timestep_callback
        self.low_frequency_weight_callback = low_frequency_weight_callback
        self.high_frequency_weight_callback = high_frequency_weight_callback

    def initialize_hook(self, module):
        # Fresh per-run state; reset again via reset_state() between pipeline runs.
        self.state = FasterCacheDenoiserState()
        return module

    @staticmethod
    def _get_cond_input(input: torch.Tensor) -> torch.Tensor:
        # Note: this method assumes that the input tensor is batchwise-concatenated with unconditional inputs
        # followed by conditional inputs.
        _, cond = input.chunk(2, dim=0)
        return cond

    def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
        """Run the wrapped denoiser forward, skipping the unconditional half of the
        batch when allowed and reconstructing it from cached frequency deltas."""
        # Split the unconditional and conditional inputs. We only want to infer the conditional branch if the
        # requirements for skipping the unconditional branch are met as described in the paper.
        # We skip the unconditional branch only if the following conditions are met:
        # 1. We have completed at least one iteration of the denoiser
        # 2. The current timestep is within the range specified by the user. This is the optimal timestep range
        # where approximating the unconditional branch from the computation of the conditional branch is possible
        # without a significant loss in quality.
        # 3. The current iteration is not a multiple of the unconditional batch skip range. This is done so that
        # we compute the unconditional branch at least once every few iterations to ensure minimal quality loss.
        is_within_timestep_range = (
            self.unconditional_batch_timestep_skip_range[0]
            < self.current_timestep_callback()
            < self.unconditional_batch_timestep_skip_range[1]
        )
        should_skip_uncond = (
            self.state.iteration > 0
            and is_within_timestep_range
            and self.state.iteration % self.unconditional_batch_skip_range != 0
            and not self.is_guidance_distilled
        )

        if should_skip_uncond:
            is_any_kwarg_uncond = any(k in self.uncond_cond_input_kwargs_identifiers for k in kwargs.keys())
            if is_any_kwarg_uncond:
                logger.debug("FasterCache - Skipping unconditional branch computation")
                # Positional tensors are split too; non-tensors pass through unchanged.
                args = tuple([self._get_cond_input(arg) if torch.is_tensor(arg) else arg for arg in args])
                kwargs = {
                    k: v if k not in self.uncond_cond_input_kwargs_identifiers else self._get_cond_input(v)
                    for k, v in kwargs.items()
                }

        output = self.fn_ref.original_forward(*args, **kwargs)

        if self.is_guidance_distilled:
            # No unconditional branch exists — nothing to cache or approximate.
            self.state.iteration += 1
            return output

        # NOTE(review): if `output` is neither a tensor, tuple, nor
        # Transformer2DModelOutput, `hidden_states` is unbound and the next line
        # raises — assumes all supported denoisers return one of these forms.
        if torch.is_tensor(output):
            hidden_states = output
        elif isinstance(output, (tuple, Transformer2DModelOutput)):
            hidden_states = output[0]

        batch_size = hidden_states.size(0)

        if should_skip_uncond:
            self.state.low_frequency_delta = self.state.low_frequency_delta * self.low_frequency_weight_callback(
                module
            )
            self.state.high_frequency_delta = self.state.high_frequency_delta * self.high_frequency_weight_callback(
                module
            )

            # Reshape so frames are flattened into the batch dim before the 2D FFT split.
            if self.tensor_format == "BCFHW":
                hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
            if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
                hidden_states = hidden_states.flatten(0, 1)

            low_freq_cond, high_freq_cond = _split_low_high_freq(hidden_states.float())

            # Approximate/compute the unconditional branch outputs as described in Equation 9 and 10 of the paper
            low_freq_uncond = self.state.low_frequency_delta + low_freq_cond
            high_freq_uncond = self.state.high_frequency_delta + high_freq_cond
            uncond_freq = low_freq_uncond + high_freq_uncond

            uncond_states = torch.fft.ifftshift(uncond_freq)
            uncond_states = torch.fft.ifft2(uncond_states).real

            # Undo the frame flattening/permutation applied above.
            if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
                uncond_states = uncond_states.unflatten(0, (batch_size, -1))
                hidden_states = hidden_states.unflatten(0, (batch_size, -1))
            if self.tensor_format == "BCFHW":
                uncond_states = uncond_states.permute(0, 2, 1, 3, 4)
                hidden_states = hidden_states.permute(0, 2, 1, 3, 4)

            # Concatenate the approximated unconditional and predicted conditional branches
            uncond_states = uncond_states.to(hidden_states.dtype)
            hidden_states = torch.cat([uncond_states, hidden_states], dim=0)
        else:
            # Both branches were computed: refresh the cached frequency deltas.
            uncond_states, cond_states = hidden_states.chunk(2, dim=0)
            if self.tensor_format == "BCFHW":
                uncond_states = uncond_states.permute(0, 2, 1, 3, 4)
                cond_states = cond_states.permute(0, 2, 1, 3, 4)
            if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
                uncond_states = uncond_states.flatten(0, 1)
                cond_states = cond_states.flatten(0, 1)

            low_freq_uncond, high_freq_uncond = _split_low_high_freq(uncond_states.float())
            low_freq_cond, high_freq_cond = _split_low_high_freq(cond_states.float())
            self.state.low_frequency_delta = low_freq_uncond - low_freq_cond
            self.state.high_frequency_delta = high_freq_uncond - high_freq_cond

        self.state.iteration += 1
        # Re-wrap the (possibly reconstructed) hidden states in the original output form.
        if torch.is_tensor(output):
            output = hidden_states
        elif isinstance(output, tuple):
            output = (hidden_states, *output[1:])
        else:
            output.sample = hidden_states

        return output

    def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
        self.state.reset()
        return module
+
367
+
368
class FasterCacheBlockHook(ModelHook):
    """Attention-block-level FasterCache hook: skips the block's computation on most
    iterations and extrapolates its output from the two most recently cached results."""

    _is_stateful = True

    def __init__(
        self,
        block_skip_range: int,
        timestep_skip_range: Tuple[int, int],
        is_guidance_distilled: bool,
        weight_callback: Callable[[torch.nn.Module], float],
        current_timestep_callback: Callable[[], int],
    ) -> None:
        super().__init__()

        self.block_skip_range = block_skip_range
        self.timestep_skip_range = timestep_skip_range
        self.is_guidance_distilled = is_guidance_distilled

        self.weight_callback = weight_callback
        self.current_timestep_callback = current_timestep_callback

    def initialize_hook(self, module):
        # Fresh per-run state; reset again via reset_state() between pipeline runs.
        self.state = FasterCacheBlockState()
        return module

    def _compute_approximated_attention_output(
        self, t_2_output: torch.Tensor, t_output: torch.Tensor, weight: float, batch_size: int
    ) -> torch.Tensor:
        """Linear extrapolation from the two most recent cached outputs:
        out ≈ t_output + weight * (t_output - t_2_output)."""
        if t_2_output.size(0) != batch_size:
            # The cache t_2_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just
            # take the conditional branch outputs.
            assert t_2_output.size(0) == 2 * batch_size
            t_2_output = t_2_output[batch_size:]
        if t_output.size(0) != batch_size:
            # The cache t_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just
            # take the conditional branch outputs.
            assert t_output.size(0) == 2 * batch_size
            t_output = t_output[batch_size:]
        return t_output + (t_output - t_2_output) * weight

    def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
        """Either run the wrapped block or return an extrapolation from the cache."""
        # Current batch size: taken from the first tensor among args/kwargs.
        batch_size = [
            *[arg.size(0) for arg in args if torch.is_tensor(arg)],
            *[v.size(0) for v in kwargs.values() if torch.is_tensor(v)],
        ][0]
        if self.state.batch_size is None:
            # Will be updated on first forward pass through the denoiser
            self.state.batch_size = batch_size

        # If we have to skip due to the skip conditions, then let's skip as expected.
        # But, we can't skip if the denoiser wants to infer both unconditional and conditional branches. This
        # is because the expected output shapes of attention layer will not match if we only return values from
        # the cache (which only caches conditional branch outputs). So, if state.batch_size (which is the true
        # unconditional-conditional batch size) is same as the current batch size, we don't perform the layer
        # skip. Otherwise, we conditionally skip the layer based on what state.skip_callback returns.
        is_within_timestep_range = (
            self.timestep_skip_range[0] < self.current_timestep_callback() < self.timestep_skip_range[1]
        )
        if not is_within_timestep_range:
            should_skip_attention = False
        else:
            should_compute_attention = self.state.iteration > 0 and self.state.iteration % self.block_skip_range == 0
            should_skip_attention = not should_compute_attention
            if should_skip_attention:
                should_skip_attention = self.is_guidance_distilled or self.state.batch_size != batch_size

        if should_skip_attention:
            logger.debug("FasterCache - Skipping attention and using approximation")
            if torch.is_tensor(self.state.cache[-1]):
                t_2_output, t_output = self.state.cache
                weight = self.weight_callback(module)
                output = self._compute_approximated_attention_output(t_2_output, t_output, weight, batch_size)
            else:
                # The cache contains multiple tensors from past N iterations (N=2 for FasterCache). We need to handle all of them.
                # Diffusers blocks can return multiple tensors - let's call them [A, B, C, ...] for simplicity.
                # In our cache, we would have [[A_1, B_1, C_1, ...], [A_2, B_2, C_2, ...], ...] where each list is the output from
                # a forward pass of the block. We need to compute the approximated output for each of these tensors.
                # The zip(*state.cache) operation will give us [(A_1, A_2, ...), (B_1, B_2, ...), (C_1, C_2, ...), ...] which
                # allows us to compute the approximated attention output for each tensor in the cache.
                output = ()
                for t_2_output, t_output in zip(*self.state.cache):
                    result = self._compute_approximated_attention_output(
                        t_2_output, t_output, self.weight_callback(module), batch_size
                    )
                    output += (result,)
        else:
            logger.debug("FasterCache - Computing attention")
            output = self.fn_ref.original_forward(*args, **kwargs)

        # Note that the following condition for getting hidden_states should suffice since Diffusers blocks either return
        # a single hidden_states tensor, or a tuple of (hidden_states, encoder_hidden_states) tensors. We need to handle
        # both cases.
        if torch.is_tensor(output):
            cache_output = output
            if not self.is_guidance_distilled and cache_output.size(0) == self.state.batch_size:
                # The output here can be both unconditional-conditional branch outputs or just conditional branch outputs.
                # This is determined at the higher-level denoiser module. We only want to cache the conditional branch outputs.
                cache_output = cache_output.chunk(2, dim=0)[1]
        else:
            # Cache all return values and perform the same operation as above
            cache_output = ()
            for out in output:
                if not self.is_guidance_distilled and out.size(0) == self.state.batch_size:
                    out = out.chunk(2, dim=0)[1]
                cache_output += (out,)

        # Keep a rolling window of the two most recent outputs (t-2 and t).
        if self.state.cache is None:
            self.state.cache = [cache_output, cache_output]
        else:
            self.state.cache = [self.state.cache[-1], cache_output]

        self.state.iteration += 1
        return output

    def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
        self.state.reset()
        return module
+
485
+
486
def apply_faster_cache(module: torch.nn.Module, config: FasterCacheConfig) -> None:
    r"""
    Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline.

    Args:
        module (`torch.nn.Module`):
            The pytorch module to apply FasterCache to. Typically, this should be a transformer architecture supported
            in Diffusers, such as `CogVideoXTransformer3DModel`, but external implementations may also work.
        config (`FasterCacheConfig`):
            The configuration to use for FasterCache. Note that the hooks call
            `config.current_timestep_callback()` on every forward pass, so it must be set to a callable
            returning the current diffusion timestep.

    Example:
    ```python
    >>> import torch
    >>> from diffusers import CogVideoXPipeline, FasterCacheConfig, apply_faster_cache

    >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
    >>> pipe.to("cuda")

    >>> config = FasterCacheConfig(
    ...     spatial_attention_block_skip_range=2,
    ...     spatial_attention_timestep_skip_range=(-1, 681),
    ...     low_frequency_weight_update_timestep_range=(99, 641),
    ...     high_frequency_weight_update_timestep_range=(-1, 301),
    ...     spatial_attention_block_identifiers=["transformer_blocks"],
    ...     attention_weight_callback=lambda _: 0.3,
    ...     tensor_format="BFCHW",
    ... )
    >>> apply_faster_cache(pipe.transformer, config)
    ```
    """

    logger.warning(
        "FasterCache is a purely experimental feature and may not work as expected. Not all models support FasterCache. "
        "The API is subject to change in future releases, with no guarantee of backward compatibility. Please report any issues at "
        "https://github.com/huggingface/diffusers/issues."
    )

    # NOTE: missing callbacks are filled in on `config` itself, so the passed-in
    # config object is mutated by this function.
    if config.attention_weight_callback is None:
        # If the user has not provided a weight callback, we default to 0.5 for all timesteps.
        # In the paper, they recommend using a gradually increasing weight from 0 to 1 as the inference progresses, but
        # this depends from model-to-model. It is required by the user to provide a weight callback if they want to
        # use a different weight function. Defaulting to 0.5 works well in practice for most cases.
        logger.warning(
            "No `attention_weight_callback` provided when enabling FasterCache. Defaulting to using a weight of 0.5 for all timesteps."
        )
        config.attention_weight_callback = lambda _: 0.5

    if config.low_frequency_weight_callback is None:
        logger.debug(
            "Low frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper."
        )

        def low_frequency_weight_callback(module: torch.nn.Module) -> float:
            # Equation 11 of the paper: alpha applies only inside the configured range.
            is_within_range = (
                config.low_frequency_weight_update_timestep_range[0]
                < config.current_timestep_callback()
                < config.low_frequency_weight_update_timestep_range[1]
            )
            return config.alpha_low_frequency if is_within_range else 1.0

        config.low_frequency_weight_callback = low_frequency_weight_callback

    if config.high_frequency_weight_callback is None:
        logger.debug(
            "High frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper."
        )

        def high_frequency_weight_callback(module: torch.nn.Module) -> float:
            # Equation 11 of the paper: alpha applies only inside the configured range.
            is_within_range = (
                config.high_frequency_weight_update_timestep_range[0]
                < config.current_timestep_callback()
                < config.high_frequency_weight_update_timestep_range[1]
            )
            return config.alpha_high_frequency if is_within_range else 1.0

        config.high_frequency_weight_callback = high_frequency_weight_callback

    supported_tensor_formats = ["BCFHW", "BFCHW", "BCHW"]  # TODO(aryan): Support BSC for LTX Video
    if config.tensor_format not in supported_tensor_formats:
        raise ValueError(f"`tensor_format` must be one of {supported_tensor_formats}, but got {config.tensor_format}.")

    _apply_faster_cache_on_denoiser(module, config)

    # Hook every attention module whose name matches the known transformer-block patterns.
    for name, submodule in module.named_modules():
        if not isinstance(submodule, _ATTENTION_CLASSES):
            continue
        if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS):
            _apply_faster_cache_on_attention_class(name, submodule, config)
+
576
+
577
def _apply_faster_cache_on_denoiser(module: torch.nn.Module, config: FasterCacheConfig) -> None:
    """Register the denoiser-level FasterCache (CFG-Cache) hook on the top-level module."""
    hook = FasterCacheDenoiserHook(
        unconditional_batch_skip_range=config.unconditional_batch_skip_range,
        unconditional_batch_timestep_skip_range=config.unconditional_batch_timestep_skip_range,
        tensor_format=config.tensor_format,
        is_guidance_distilled=config.is_guidance_distilled,
        uncond_cond_input_kwargs_identifiers=config._unconditional_conditional_input_kwargs_identifiers,
        current_timestep_callback=config.current_timestep_callback,
        low_frequency_weight_callback=config.low_frequency_weight_callback,
        high_frequency_weight_callback=config.high_frequency_weight_callback,
    )
    HookRegistry.check_if_exists_or_initialize(module).register_hook(hook, _FASTER_CACHE_DENOISER_HOOK)
+
591
+
592
def _apply_faster_cache_on_attention_class(name: str, module: AttentionModuleMixin, config: FasterCacheConfig) -> None:
    """Register a `FasterCacheBlockHook` on `module` if its name matches the configured
    spatial or temporal self-attention identifiers; otherwise log and do nothing.

    Args:
        name: Fully-qualified submodule name, used for regex matching against the identifiers.
        module: The attention module to hook.
        config: The FasterCache configuration providing identifiers, skip ranges and callbacks.
    """
    is_spatial_self_attention = (
        any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
        and config.spatial_attention_block_skip_range is not None
        and not getattr(module, "is_cross_attention", False)
    )
    # Fix: guard with getattr(..., False) here as well (the spatial branch above already
    # does), so attention classes that do not expose `is_cross_attention` no longer raise
    # AttributeError when matched by a temporal identifier.
    is_temporal_self_attention = (
        any(re.search(identifier, name) is not None for identifier in config.temporal_attention_block_identifiers)
        and config.temporal_attention_block_skip_range is not None
        and not getattr(module, "is_cross_attention", False)
    )

    block_skip_range, timestep_skip_range, block_type = None, None, None
    if is_spatial_self_attention:
        block_skip_range = config.spatial_attention_block_skip_range
        timestep_skip_range = config.spatial_attention_timestep_skip_range
        block_type = "spatial"
    elif is_temporal_self_attention:
        block_skip_range = config.temporal_attention_block_skip_range
        timestep_skip_range = config.temporal_attention_timestep_skip_range
        block_type = "temporal"

    if block_skip_range is None or timestep_skip_range is None:
        # Not a recognized self-attention layer for FasterCache — leave it untouched.
        logger.debug(
            f'Unable to apply FasterCache to the selected layer: "{name}" because it does '
            f"not match any of the required criteria for spatial or temporal attention layers. Note, "
            f"however, that this layer may still be valid for applying PAB. Please specify the correct "
            f"block identifiers in the configuration or use the specialized `apply_faster_cache_on_module` "
            f"function to apply FasterCache to this layer."
        )
        return

    logger.debug(f"Enabling FasterCache ({block_type}) for layer: {name}")
    hook = FasterCacheBlockHook(
        block_skip_range,
        timestep_skip_range,
        config.is_guidance_distilled,
        config.attention_weight_callback,
        config.current_timestep_callback,
    )
    registry = HookRegistry.check_if_exists_or_initialize(module)
    registry.register_hook(hook, _FASTER_CACHE_BLOCK_HOOK)
+
635
+
636
# Reference: https://github.com/Vchitect/FasterCache/blob/fab32c15014636dc854948319c0a9a8d92c7acb4/scripts/latte/faster_cache_sample_latte.py#L127C1-L143C39
@torch.no_grad()
def _split_low_high_freq(x):
    """Split `x` (..., H, W) into low- and high-frequency spectral components.

    Returns `(low_freq_fft, high_freq_fft)`: the centered (fftshift-ed) 2D FFT of `x`,
    masked by a circular low-pass filter of radius `min(H, W) // 5` and its complement.
    The two components sum to the full shifted spectrum.
    """
    fft = torch.fft.fft2(x)
    fft_shifted = torch.fft.fftshift(fft)
    height, width = x.shape[-2:]
    radius = min(height, width) // 5

    # Fix: pass `indexing="ij"` explicitly. It is the historical default (rows vary
    # along dim 0), so behavior is unchanged, but omitting it emits a deprecation
    # warning on recent torch versions.
    y_grid, x_grid = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
    center_x, center_y = width // 2, height // 2
    mask = (x_grid - center_x) ** 2 + (y_grid - center_y) ** 2 <= radius**2

    # Broadcast the (H, W) mask over the two leading dims, on the input's device.
    low_freq_mask = mask.unsqueeze(0).unsqueeze(0).to(x.device)
    high_freq_mask = ~low_freq_mask

    low_freq_fft = fft_shifted * low_freq_mask
    high_freq_fft = fft_shifted * high_freq_mask

    return low_freq_fft, high_freq_fft
+ return low_freq_fft, high_freq_fft
pythonProject/.venv/Lib/site-packages/diffusers/hooks/first_block_cache.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import Tuple, Union
17
+
18
+ import torch
19
+
20
+ from ..utils import get_logger
21
+ from ..utils.torch_utils import unwrap_module
22
+ from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
23
+ from ._helpers import TransformerBlockRegistry
24
+ from .hooks import BaseState, HookRegistry, ModelHook, StateManager
25
+
26
+
27
+ logger = get_logger(__name__) # pylint: disable=invalid-name
28
+
29
+ _FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook"
30
+ _FBC_BLOCK_HOOK = "fbc_block_hook"
31
+
32
+
33
@dataclass
class FirstBlockCacheConfig:
    r"""
    Configuration for [First Block
    Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching).

    Args:
        threshold (`float`, defaults to `0.05`):
            Cutoff used to decide whether the full stack of transformer blocks must be run. The residual produced by
            the first transformer block is compared (via relative absmean difference) against the residual cached from
            the previous step; when the difference stays below this value the remaining blocks are skipped and cached
            residuals are reused instead. Larger values skip more often (faster inference, potentially lower quality);
            very small values rarely skip and give little speedup.
    """

    threshold: float = 0.05
50
+
51
+
52
class FBCSharedBlockState(BaseState):
    """State shared between the head-block hook and all subsequent block hooks of one model."""

    def __init__(self) -> None:
        super().__init__()

        # Output of the first transformer block captured on the last full pass
        # (hidden states first, encoder hidden states second when present).
        self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
        # Residual (output - input) of the first block from the last full pass,
        # used as the comparison baseline for the skip decision.
        self.head_block_residual: torch.Tensor = None
        # Residuals of the final outputs w.r.t. the first block's outputs, cached by the
        # tail block and added back when the remaining blocks are skipped.
        self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
        # Decision made by the head hook for the current step; read by every FBCBlockHook.
        self.should_compute: bool = True

    def reset(self):
        # Note: head_block_output/head_block_residual are intentionally kept so the very
        # first step after a reset still has a baseline only if recomputed; the skip
        # decision is forced back to "compute" here.
        self.tail_block_residuals = None
        self.should_compute = True
65
+
66
class FBCHeadBlockHook(ModelHook):
    """Hook attached to the first transformer block.

    On each forward it computes the residual of the block's output w.r.t. its input and compares
    it to the residual cached from the previous step. If the relative change falls below the
    threshold, the remaining blocks are skipped and the cached tail residuals are applied to this
    block's output instead of running the rest of the model.
    """

    _is_stateful = True

    def __init__(self, state_manager: StateManager, threshold: float):
        # state_manager: yields the FBCSharedBlockState shared with the other block hooks.
        # threshold: relative absmean-difference below which the remaining blocks are skipped.
        self.state_manager = state_manager
        self.threshold = threshold
        self._metadata = None

    def initialize_hook(self, module):
        # Resolve output-layout metadata for the wrapped transformer block class; done lazily
        # because the module class is only known once the hook is attached.
        unwrapped_module = unwrap_module(module)
        self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)

        output = self.fn_ref.original_forward(*args, **kwargs)
        is_output_tuple = isinstance(output, tuple)

        # Residual of the first block: its output minus its input hidden states.
        if is_output_tuple:
            hidden_states_residual = output[self._metadata.return_hidden_states_index] - original_hidden_states
        else:
            hidden_states_residual = output - original_hidden_states

        shared_state: FBCSharedBlockState = self.state_manager.get_state()
        hidden_states = encoder_hidden_states = None
        should_compute = self._should_compute_remaining_blocks(hidden_states_residual)
        shared_state.should_compute = should_compute

        if not should_compute:
            # Apply caching: add the tail residuals captured on the last full pass to this
            # block's output, approximating a full forward through all remaining blocks.
            if is_output_tuple:
                hidden_states = (
                    shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index]
                )
            else:
                hidden_states = shared_state.tail_block_residuals[0] + output

            if self._metadata.return_encoder_hidden_states_index is not None:
                # Blocks that return encoder hidden states always return a tuple.
                assert is_output_tuple
                encoder_hidden_states = (
                    shared_state.tail_block_residuals[1] + output[self._metadata.return_encoder_hidden_states_index]
                )

            # Re-pack the outputs in the same layout the wrapped block uses.
            if is_output_tuple:
                return_output = [None] * len(output)
                return_output[self._metadata.return_hidden_states_index] = hidden_states
                return_output[self._metadata.return_encoder_hidden_states_index] = encoder_hidden_states
                return_output = tuple(return_output)
            else:
                return_output = hidden_states
            output = return_output
        else:
            # Full pass: cache this block's output (hidden states at index 0, encoder hidden
            # states at index 1) and its residual as the baseline for the next step's decision.
            if is_output_tuple:
                head_block_output = [None] * len(output)
                head_block_output[0] = output[self._metadata.return_hidden_states_index]
                head_block_output[1] = output[self._metadata.return_encoder_hidden_states_index]
            else:
                head_block_output = output
            shared_state.head_block_output = head_block_output
            shared_state.head_block_residual = hidden_states_residual

        return output

    def reset_state(self, module):
        # Clear the shared caching state (e.g. between generations).
        self.state_manager.reset()
        return module

    @torch.compiler.disable
    def _should_compute_remaining_blocks(self, hidden_states_residual: torch.Tensor) -> bool:
        """Return True when the remaining transformer blocks must be executed."""
        shared_state = self.state_manager.get_state()
        # First step (nothing cached yet) always requires a full forward pass.
        if shared_state.head_block_residual is None:
            return True
        # Relative absmean difference between the current and previous first-block residuals.
        prev_hidden_states_residual = shared_state.head_block_residual
        absmean = (hidden_states_residual - prev_hidden_states_residual).abs().mean()
        prev_hidden_states_absmean = prev_hidden_states_residual.abs().mean()
        diff = (absmean / prev_hidden_states_absmean).item()
        return diff > self.threshold
144
+
145
+
146
class FBCBlockHook(ModelHook):
    """Hook attached to every transformer block after the first.

    When the shared state requests a full pass, the wrapped block runs normally and, on the tail
    block, the residuals of the final outputs w.r.t. the first block's cached outputs are stored.
    When the head hook decided to skip, the block's inputs are returned unchanged.
    """

    def __init__(self, state_manager: StateManager, is_tail: bool = False):
        super().__init__()
        self.state_manager = state_manager
        # True only for the last transformer block, which records the tail residuals.
        self.is_tail = is_tail
        self._metadata = None

    def initialize_hook(self, module):
        # Resolve output-layout metadata for the wrapped transformer block class.
        unwrapped_module = unwrap_module(module)
        self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
        original_encoder_hidden_states = None
        if self._metadata.return_encoder_hidden_states_index is not None:
            original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
                "encoder_hidden_states", args, kwargs
            )

        shared_state = self.state_manager.get_state()

        if shared_state.should_compute:
            output = self.fn_ref.original_forward(*args, **kwargs)
            if self.is_tail:
                # Cache residuals of the final outputs relative to the first block's outputs;
                # the head hook adds these back whenever it decides to skip the remaining blocks.
                hidden_states_residual = encoder_hidden_states_residual = None
                if isinstance(output, tuple):
                    hidden_states_residual = (
                        output[self._metadata.return_hidden_states_index] - shared_state.head_block_output[0]
                    )
                    encoder_hidden_states_residual = (
                        output[self._metadata.return_encoder_hidden_states_index] - shared_state.head_block_output[1]
                    )
                else:
                    hidden_states_residual = output - shared_state.head_block_output
                shared_state.tail_block_residuals = (hidden_states_residual, encoder_hidden_states_residual)
            return output

        # Skipped step: pass the inputs straight through, matching the block's output layout.
        if original_encoder_hidden_states is None:
            return_output = original_hidden_states
        else:
            return_output = [None, None]
            return_output[self._metadata.return_hidden_states_index] = original_hidden_states
            return_output[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states
            return_output = tuple(return_output)
        return return_output
192
+
193
+
194
def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
    """
    Applies [First Block
    Cache](https://github.com/chengzeyi/ParaAttention/blob/4de137c5b96416489f06e43e19f2c14a772e28fd/README.md#first-block-cache-our-dynamic-caching)
    to a given module.

    First Block Cache builds on the ideas of [TeaCache](https://huggingface.co/papers/2411.19108). It is much simpler
    to implement generically for a wide range of models and has been integrated first for experimental purposes.

    Args:
        module (`torch.nn.Module`):
            The pytorch module to apply FBCache to. Typically, this should be a transformer architecture supported in
            Diffusers, such as `CogVideoXTransformer3DModel`, but external implementations may also work.
        config (`FirstBlockCacheConfig`):
            The configuration to use for applying the FBCache method.

    Raises:
        ValueError: If fewer than two transformer blocks are discovered on the module, since FBCache
            needs a distinct head block (decision maker) and tail block (residual recorder).

    Example:
    ```python
    >>> import torch
    >>> from diffusers import CogView4Pipeline
    >>> from diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig

    >>> pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
    >>> pipe.to("cuda")

    >>> apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.2))

    >>> prompt = "A photo of an astronaut riding a horse on mars"
    >>> image = pipe(prompt, generator=torch.Generator().manual_seed(42)).images[0]
    >>> image.save("output.png")
    ```
    """

    state_manager = StateManager(FBCSharedBlockState, (), {})
    remaining_blocks = []

    # Collect all transformer blocks from recognized ModuleList attributes, in order.
    for name, submodule in module.named_children():
        if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
            continue
        for index, block in enumerate(submodule):
            remaining_blocks.append((f"{name}.{index}", block))

    # FBCache needs at least a distinct head and tail block. Without this guard, the pops
    # below would fail with an opaque IndexError on unsupported modules.
    if len(remaining_blocks) < 2:
        raise ValueError(
            "First Block Cache requires the module to expose at least two transformer blocks via a "
            f"recognized `torch.nn.ModuleList` attribute, but found {len(remaining_blocks)}. Please "
            "verify that the module is a supported transformer architecture."
        )

    head_block_name, head_block = remaining_blocks.pop(0)
    tail_block_name, tail_block = remaining_blocks.pop(-1)

    logger.debug(f"Applying FBCHeadBlockHook to '{head_block_name}'")
    _apply_fbc_head_block_hook(head_block, state_manager, config.threshold)

    for name, block in remaining_blocks:
        logger.debug(f"Applying FBCBlockHook to '{name}'")
        _apply_fbc_block_hook(block, state_manager)

    logger.debug(f"Applying FBCBlockHook to tail block '{tail_block_name}'")
    _apply_fbc_block_hook(tail_block, state_manager, is_tail=True)
248
+
249
+
250
def _apply_fbc_head_block_hook(block: torch.nn.Module, state_manager: StateManager, threshold: float) -> None:
    # Register the hook that decides, per step, whether the remaining blocks must run.
    hook = FBCHeadBlockHook(state_manager, threshold)
    HookRegistry.check_if_exists_or_initialize(block).register_hook(hook, _FBC_LEADER_BLOCK_HOOK)
254
+
255
+
256
def _apply_fbc_block_hook(block: torch.nn.Module, state_manager: StateManager, is_tail: bool = False) -> None:
    # Register the pass-through/skip hook on a non-head block (tail blocks also record residuals).
    hook = FBCBlockHook(state_manager, is_tail)
    HookRegistry.check_if_exists_or_initialize(block).register_hook(hook, _FBC_BLOCK_HOOK)
pythonProject/.venv/Lib/site-packages/diffusers/hooks/group_offloading.py ADDED
@@ -0,0 +1,898 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import hashlib
16
+ import os
17
+ from contextlib import contextmanager, nullcontext
18
+ from dataclasses import dataclass
19
+ from enum import Enum
20
+ from typing import Dict, List, Optional, Set, Tuple, Union
21
+
22
+ import safetensors.torch
23
+ import torch
24
+
25
+ from ..utils import get_logger, is_accelerate_available
26
+ from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
27
+ from .hooks import HookRegistry, ModelHook
28
+
29
+
30
+ if is_accelerate_available():
31
+ from accelerate.hooks import AlignDevicesHook, CpuOffload
32
+ from accelerate.utils import send_to_device
33
+
34
+
35
+ logger = get_logger(__name__) # pylint: disable=invalid-name
36
+
37
+
38
+ # fmt: off
39
+ _GROUP_OFFLOADING = "group_offloading"
40
+ _LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
41
+ _LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
42
+ _GROUP_ID_LAZY_LEAF = "lazy_leafs"
43
+ # fmt: on
44
+
45
+
46
class GroupOffloadingType(str, Enum):
    """Granularity at which module groups are formed for offloading.

    Inherits from `str` so plain strings ("block_level"/"leaf_level") compare equal to members.
    """

    BLOCK_LEVEL = "block_level"
    LEAF_LEVEL = "leaf_level"
49
+
50
+
51
@dataclass
class GroupOffloadingConfig:
    """Options controlling how module groups are offloaded/onloaded."""

    # Device the groups are moved to for computation (typically the accelerator).
    onload_device: torch.device
    # Device (typically CPU) the groups are stored on between uses.
    offload_device: torch.device
    # Whether groups are formed per block list ("block_level") or per leaf module ("leaf_level").
    offload_type: GroupOffloadingType
    # Use non-blocking device transfers (implied when a stream is used).
    non_blocking: bool
    # Use Tensor.record_stream instead of stream synchronization for async transfers.
    record_stream: bool
    # Skip pinning CPU copies to reduce host memory usage (slower transfers).
    low_cpu_mem_usage: bool
    # Number of consecutive blocks per group; only meaningful for block_level offloading.
    num_blocks_per_group: Optional[int] = None
    # When set, group tensors are serialized to safetensors files under this path
    # instead of being kept in CPU RAM.
    offload_to_disk_path: Optional[str] = None
    # Side stream used to prefetch the next group while the current one computes.
    stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None
62
+
63
+
64
class ModuleGroup:
    """A set of modules/parameters/buffers that are onloaded and offloaded as one unit.

    Between uses, the group's tensors live either in CPU memory (optionally pinned) or, when
    `offload_to_disk_path` is set, in a safetensors file on disk. For computation they are moved
    to `onload_device`, optionally on a side `stream` so the transfer overlaps with compute.
    """

    def __init__(
        self,
        modules: List[torch.nn.Module],
        offload_device: torch.device,
        onload_device: torch.device,
        offload_leader: torch.nn.Module,
        onload_leader: Optional[torch.nn.Module] = None,
        parameters: Optional[List[torch.nn.Parameter]] = None,
        buffers: Optional[List[torch.Tensor]] = None,
        non_blocking: bool = False,
        stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
        record_stream: Optional[bool] = False,
        low_cpu_mem_usage: bool = False,
        onload_self: bool = True,
        offload_to_disk_path: Optional[str] = None,
        group_id: Optional[int] = None,
    ) -> None:
        self.modules = modules
        self.offload_device = offload_device
        self.onload_device = onload_device
        # Module whose post_forward triggers offloading of the whole group.
        self.offload_leader = offload_leader
        # Module whose pre_forward triggers onloading; may be discovered lazily (first caller).
        self.onload_leader = onload_leader
        # Standalone parameters/buffers belonging to the group but not owned by `modules`.
        self.parameters = parameters or []
        self.buffers = buffers or []
        # Streamed transfers are always non-blocking, regardless of the flag passed in.
        self.non_blocking = non_blocking or stream is not None
        self.stream = stream
        self.record_stream = record_stream
        self.onload_self = onload_self
        self.low_cpu_mem_usage = low_cpu_mem_usage

        self.offload_to_disk_path = offload_to_disk_path
        self._is_offloaded_to_disk = False

        if self.offload_to_disk_path is not None:
            # Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
            self.group_id = group_id if group_id is not None else str(id(self))
            # Hash the group id into a short, filesystem-safe safetensors filename.
            short_hash = _compute_group_hash(self.group_id)
            self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{short_hash}.safetensors")

            # Build a stable tensor <-> key mapping for (de)serialization.
            all_tensors = []
            for module in self.modules:
                all_tensors.extend(list(module.parameters()))
                all_tensors.extend(list(module.buffers()))
            all_tensors.extend(self.parameters)
            all_tensors.extend(self.buffers)
            all_tensors = list(dict.fromkeys(all_tensors))  # Remove duplicates

            self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)}
            self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()}
            self.cpu_param_dict = {}
        else:
            self.cpu_param_dict = self._init_cpu_param_dict()

        # Module namespace (torch.cuda, torch.xpu, ...) used for stream operations; fall back
        # to torch.cuda on torch versions without the accelerator API.
        self._torch_accelerator_module = (
            getattr(torch, torch.accelerator.current_accelerator().type)
            if hasattr(torch, "accelerator")
            else torch.cuda
        )

    def _init_cpu_param_dict(self):
        """Snapshot every tensor's data on CPU (pinned unless low_cpu_mem_usage) for reuse on offload."""
        cpu_param_dict = {}
        # Without a stream there is no async transfer, so no CPU copies are needed up front.
        if self.stream is None:
            return cpu_param_dict

        for module in self.modules:
            for param in module.parameters():
                cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
            for buffer in module.buffers():
                cpu_param_dict[buffer] = (
                    buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
                )

        for param in self.parameters:
            cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()

        for buffer in self.buffers:
            cpu_param_dict[buffer] = buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()

        return cpu_param_dict

    @contextmanager
    def _pinned_memory_tensors(self):
        """Yield a map of CPU tensors pinned on demand (used in the low_cpu_mem_usage path)."""
        try:
            pinned_dict = {
                param: tensor.pin_memory() if not tensor.is_pinned() else tensor
                for param, tensor in self.cpu_param_dict.items()
            }
            yield pinned_dict
        finally:
            # Drop the reference so on-demand pinned copies can be freed promptly.
            pinned_dict = None

    def _transfer_tensor_to_device(self, tensor, source_tensor):
        """Repoint `tensor`'s storage to a copy of `source_tensor` on the onload device."""
        tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
        if self.record_stream:
            # Tell the allocator this tensor is used on the current stream so its memory is
            # not reclaimed while the async copy is still in flight.
            tensor.data.record_stream(self._torch_accelerator_module.current_stream())

    def _process_tensors_from_modules(self, pinned_memory=None):
        """Move every tensor in the group to the onload device, optionally from pinned copies."""
        for group_module in self.modules:
            for param in group_module.parameters():
                source = pinned_memory[param] if pinned_memory else param.data
                self._transfer_tensor_to_device(param, source)
            for buffer in group_module.buffers():
                source = pinned_memory[buffer] if pinned_memory else buffer.data
                self._transfer_tensor_to_device(buffer, source)

        for param in self.parameters:
            source = pinned_memory[param] if pinned_memory else param.data
            self._transfer_tensor_to_device(param, source)

        for buffer in self.buffers:
            source = pinned_memory[buffer] if pinned_memory else buffer.data
            self._transfer_tensor_to_device(buffer, source)

    def _onload_from_disk(self):
        """Load the group's tensors from its safetensors file onto the onload device."""
        if self.stream is not None:
            # Wait for previous Host->Device transfer to complete
            self.stream.synchronize()

        context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
        current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None

        with context:
            # Load to CPU (if using streams) or directly to target device, pin, and async copy to device
            device = str(self.onload_device) if self.stream is None else "cpu"
            loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device)

            if self.stream is not None:
                for key, tensor_obj in self.key_to_tensor.items():
                    pinned_tensor = loaded_tensors[key].pin_memory()
                    tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
                    if self.record_stream:
                        tensor_obj.data.record_stream(current_stream)
            else:
                onload_device = (
                    self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
                )
                loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
                for key, tensor_obj in self.key_to_tensor.items():
                    tensor_obj.data = loaded_tensors[key]

    def _onload_from_memory(self):
        """Move the group's tensors from CPU memory onto the onload device."""
        if self.stream is not None:
            # Wait for previous Host->Device transfer to complete
            self.stream.synchronize()

        context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
        with context:
            if self.stream is not None:
                with self._pinned_memory_tensors() as pinned_memory:
                    self._process_tensors_from_modules(pinned_memory)
            else:
                self._process_tensors_from_modules(None)

    def _offload_to_disk(self):
        """Serialize the group's tensors to disk (once) and free their RAM copies."""
        # TODO: we can potentially optimize this code path by checking if the _all_ the desired
        # safetensor files exist on the disk and if so, skip this step entirely, reducing IO
        # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
        # we perform a write.
        # Check if the file has been saved in this session or if it already exists on disk.
        if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
            os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
            tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()}
            safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)

        # The group is now considered offloaded to disk for the rest of the session.
        self._is_offloaded_to_disk = True

        # We do this to free up the RAM which is still holding on to the tensor data.
        for tensor_obj in self.tensor_to_key.keys():
            tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)

    def _offload_to_memory(self):
        """Move the group's tensors back to the offload device (reusing CPU copies if streamed)."""
        if self.stream is not None:
            if not self.record_stream:
                # Ensure in-flight compute on the current stream finished before we repoint
                # tensor storage back to the CPU copies.
                self._torch_accelerator_module.current_stream().synchronize()

            for group_module in self.modules:
                for param in group_module.parameters():
                    param.data = self.cpu_param_dict[param]
            for param in self.parameters:
                param.data = self.cpu_param_dict[param]
            for buffer in self.buffers:
                buffer.data = self.cpu_param_dict[buffer]
        else:
            for group_module in self.modules:
                group_module.to(self.offload_device, non_blocking=False)
            for param in self.parameters:
                param.data = param.data.to(self.offload_device, non_blocking=False)
            for buffer in self.buffers:
                buffer.data = buffer.data.to(self.offload_device, non_blocking=False)

    @torch.compiler.disable()
    def onload_(self):
        r"""Onloads the group of parameters to the onload_device."""
        if self.offload_to_disk_path is not None:
            self._onload_from_disk()
        else:
            self._onload_from_memory()

    @torch.compiler.disable()
    def offload_(self):
        r"""Offloads the group of parameters to the offload_device."""
        if self.offload_to_disk_path:
            self._offload_to_disk()
        else:
            self._offload_to_memory()
271
+
272
+
273
class GroupOffloadingHook(ModelHook):
    r"""
    A hook that offloads groups of torch.nn.Module to the CPU for storage and onloads to accelerator device for
    computation. Each group has one "onload leader" module that is responsible for onloading, and an "offload leader"
    module that is responsible for offloading. If prefetching is enabled, the onload leader of the previous module
    group is responsible for onloading the current module group.
    """

    _is_stateful = False

    def __init__(self, group: ModuleGroup, *, config: GroupOffloadingConfig) -> None:
        self.group = group
        # Group to prefetch while this one computes; assigned later by the lazy-prefetch hook.
        self.next_group: Optional[ModuleGroup] = None
        self.config = config

    def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        # Start with the group offloaded so the model fits on the offload device at rest.
        if self.group.offload_leader == module:
            self.group.offload_()
        return module

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
        # If there wasn't an onload_leader assigned, we assume that the submodule that first called its forward
        # method is the onload_leader of the group.
        if self.group.onload_leader is None:
            self.group.onload_leader = module

        # If the current module is the onload_leader of the group, we onload the group if it is supposed
        # to onload itself. In the case of using prefetching with streams, we onload the next group if
        # it is not supposed to onload itself.
        if self.group.onload_leader == module:
            if self.group.onload_self:
                self.group.onload_()

            should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
            if should_onload_next_group:
                self.next_group.onload_()

            should_synchronize = (
                not self.group.onload_self and self.group.stream is not None and not should_onload_next_group
            )
            if should_synchronize:
                # If this group didn't onload itself, it means it was asynchronously onloaded by the
                # previous group. We need to synchronize the side stream to ensure parameters
                # are completely loaded to proceed with forward pass. Without this, uninitialized
                # weights will be used in the computation, leading to incorrect results
                # Also, we should only do this synchronization if we don't already do it from the sync call in
                # self.next_group.onload_, hence the `not should_onload_next_group` check.
                self.group.stream.synchronize()

        # Move the inputs alongside the parameters so the forward runs entirely on onload_device.
        args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
        kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
        return args, kwargs

    def post_forward(self, module: torch.nn.Module, output):
        # Offload as soon as the group's designated offload leader finishes its forward.
        if self.group.offload_leader == module:
            self.group.offload_()
        return output
330
+
331
+
332
class LazyPrefetchGroupOffloadingHook(ModelHook):
    r"""
    A hook, used in conjunction with GroupOffloadingHook, that applies lazy prefetching to groups of torch.nn.Module.
    This hook is used to determine the order in which the layers are executed during the forward pass. Once the layer
    invocation order is known, assignments of the next_group attribute for prefetching can be made, which allows
    prefetching groups in the correct order.
    """

    _is_stateful = False

    def __init__(self):
        # (name, submodule) pairs in the order their forwards ran during the tracing pass.
        self.execution_order: List[Tuple[str, torch.nn.Module]] = []
        # Names of all submodules that were given a tracker hook, for completeness checking.
        self._layer_execution_tracker_module_names = set()

    def initialize_hook(self, module):
        def make_execution_order_update_callback(current_name, current_submodule):
            # Bind name/submodule per layer; the callback appends to the shared execution order.
            def callback():
                if not torch.compiler.is_compiling():
                    logger.debug(f"Adding {current_name} to the execution order")
                self.execution_order.append((current_name, current_submodule))

            return callback

        # To every submodule that contains a group offloading hook (at this point, no prefetching is enabled for any
        # of the groups), we add a layer execution tracker hook that will be used to determine the order in which the
        # layers are executed during the forward pass.
        for name, submodule in module.named_modules():
            if name == "" or not hasattr(submodule, "_diffusers_hook"):
                continue

            registry = HookRegistry.check_if_exists_or_initialize(submodule)
            group_offloading_hook = registry.get_hook(_GROUP_OFFLOADING)

            if group_offloading_hook is not None:
                # For the first forward pass, we have to load in a blocking manner
                group_offloading_hook.group.non_blocking = False
                layer_tracker_hook = LayerExecutionTrackerHook(make_execution_order_update_callback(name, submodule))
                registry.register_hook(layer_tracker_hook, _LAYER_EXECUTION_TRACKER)
                self._layer_execution_tracker_module_names.add(name)

        return module

    def post_forward(self, module, output):
        # At this point, for the current modules' submodules, we know the execution order of the layers. We can now
        # remove the layer execution tracker hooks and apply prefetching by setting the next_group attribute for each
        # group offloading hook.
        num_executed = len(self.execution_order)
        execution_order_module_names = {name for name, _ in self.execution_order}

        # It may be possible that some layers were not executed during the forward pass. This can happen if the layer
        # is not used in the forward pass, or if the layer is not executed due to some other reason. In such cases, we
        # may not be able to apply prefetching in the correct order, which can lead to device-mismatch related errors
        # if the missing layers end up being executed in the future.
        if execution_order_module_names != self._layer_execution_tracker_module_names:
            unexecuted_layers = list(self._layer_execution_tracker_module_names - execution_order_module_names)
            if not torch.compiler.is_compiling():
                logger.warning(
                    "It seems like some layers were not executed during the forward pass. This may lead to problems when "
                    "applying lazy prefetching with automatic tracing and lead to device-mismatch related errors. Please "
                    "make sure that all layers are executed during the forward pass. The following layers were not executed:\n"
                    f"{unexecuted_layers=}"
                )

        # Remove the layer execution tracker hooks from the submodules
        base_module_registry = module._diffusers_hook
        registries = [submodule._diffusers_hook for _, submodule in self.execution_order]
        group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries]

        for i in range(num_executed):
            registries[i].remove_hook(_LAYER_EXECUTION_TRACKER, recurse=False)

        # Remove the current lazy prefetch group offloading hook so that it doesn't interfere with the next forward pass
        base_module_registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=False)

        # LazyPrefetchGroupOffloadingHook is only used with streams, so we know that non_blocking should be True.
        # We disable non_blocking for the first forward pass, but need to enable it for the subsequent passes to
        # see the benefits of prefetching.
        for hook in group_offloading_hooks:
            hook.group.non_blocking = True

        # Set required attributes for prefetching
        if num_executed > 0:
            # The base module's own group prefetches the first-executed layer's group.
            base_module_group_offloading_hook = base_module_registry.get_hook(_GROUP_OFFLOADING)
            base_module_group_offloading_hook.next_group = group_offloading_hooks[0].group
            base_module_group_offloading_hook.next_group.onload_self = False

        # Chain each executed layer's group to the next one in execution order.
        for i in range(num_executed - 1):
            name1, _ = self.execution_order[i]
            name2, _ = self.execution_order[i + 1]
            if not torch.compiler.is_compiling():
                logger.debug(f"Applying lazy prefetch group offloading from {name1} to {name2}")
            group_offloading_hooks[i].next_group = group_offloading_hooks[i + 1].group
            group_offloading_hooks[i].next_group.onload_self = False

        return output
427
+
428
+
429
class LayerExecutionTrackerHook(ModelHook):
    r"""
    Hook that records the order in which layers run. On every pre-forward invocation it fires a callback supplied
    by the LazyPrefetchGroupOffloadingHook, which uses these notifications to reconstruct the execution order of
    the layers.
    """

    _is_stateful = False

    def __init__(self, execution_order_update_callback):
        # Callback provided by LazyPrefetchGroupOffloadingHook; invoked once each time this layer is entered.
        self.execution_order_update_callback = execution_order_update_callback

    def pre_forward(self, module, *args, **kwargs):
        # Notify the tracker that this module is about to execute, then pass the inputs through untouched.
        notify = self.execution_order_update_callback
        notify()
        return args, kwargs
+
444
+
445
def apply_group_offloading(
    module: torch.nn.Module,
    onload_device: Union[str, torch.device],
    offload_device: Union[str, torch.device] = torch.device("cpu"),
    offload_type: Union[str, GroupOffloadingType] = "block_level",
    num_blocks_per_group: Optional[int] = None,
    non_blocking: bool = False,
    use_stream: bool = False,
    record_stream: bool = False,
    low_cpu_mem_usage: bool = False,
    offload_to_disk_path: Optional[str] = None,
) -> None:
    r"""
    Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
    where it is beneficial, we need to first provide some context on how other supported offloading methods work.

    Typically, offloading is done at two levels:
    - Module-level: In Diffusers, this can be enabled using the `ModelMixin::enable_model_cpu_offload()` method. It
      works by offloading each component of a pipeline to the CPU for storage, and onloading to the accelerator device
      when needed for computation. This method is more memory-efficient than keeping all components on the accelerator,
      but the memory requirements are still quite high. For this method to work, one needs memory equivalent to size of
      the model in runtime dtype + size of largest intermediate activation tensors to be able to complete the forward
      pass.
    - Leaf-level: In Diffusers, this can be enabled using the `ModelMixin::enable_sequential_cpu_offload()` method. It
      works by offloading the lowest leaf-level parameters of the computation graph to the CPU for storage, and
      onloading only the leafs to the accelerator device for computation. This uses the lowest amount of accelerator
      memory, but can be slower due to the excessive number of device synchronizations.

    Group offloading is a middle ground between the two methods. It works by offloading groups of internal layers,
    (either `torch.nn.ModuleList` or `torch.nn.Sequential`). This method uses lower memory than module-level
    offloading. It is also faster than leaf-level/sequential offloading, as the number of device synchronizations is
    reduced.

    Another supported feature (for CUDA devices with support for asynchronous data transfer streams) is the ability to
    overlap data transfer and computation to reduce the overall execution time compared to sequential offloading. This
    is enabled using layer prefetching with streams, i.e., the layer that is to be executed next starts onloading to
    the accelerator device while the current layer is being executed - this increases the memory requirements slightly.
    Note that this implementation also supports leaf-level offloading but can be made much faster when using streams.

    Args:
        module (`torch.nn.Module`):
            The module to which group offloading is applied.
        onload_device (`torch.device`):
            The device to which the group of modules are onloaded.
        offload_device (`torch.device`, defaults to `torch.device("cpu")`):
            The device to which the group of modules are offloaded. This should typically be the CPU. Default is CPU.
        offload_type (`str` or `GroupOffloadingType`, defaults to "block_level"):
            The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
            "block_level".
        offload_to_disk_path (`str`, *optional*, defaults to `None`):
            The path to the directory where parameters will be offloaded. Setting this option can be useful in limited
            RAM environment settings where a reasonable speed-memory trade-off is desired.
        num_blocks_per_group (`int`, *optional*):
            The number of blocks per group when using offload_type="block_level". This is required when using
            offload_type="block_level".
        non_blocking (`bool`, defaults to `False`):
            If True, offloading and onloading is done with non-blocking data transfer.
        use_stream (`bool`, defaults to `False`):
            If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
            overlapping computation and data transfer.
        record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
            as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
            [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more
            details.
        low_cpu_mem_usage (`bool`, defaults to `False`):
            If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
            option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
            the CPU memory is a bottleneck but may counteract the benefits of using streams.

    Raises:
        ValueError: if `use_stream=True` but no CUDA/XPU device is available, if `record_stream=True` without
            `use_stream=True`, or if `offload_type="block_level"` without `num_blocks_per_group`.

    Example:
        ```python
        >>> from diffusers import CogVideoXTransformer3DModel
        >>> from diffusers.hooks import apply_group_offloading

        >>> transformer = CogVideoXTransformer3DModel.from_pretrained(
        ...     "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
        ... )

        >>> apply_group_offloading(
        ...     transformer,
        ...     onload_device=torch.device("cuda"),
        ...     offload_device=torch.device("cpu"),
        ...     offload_type="block_level",
        ...     num_blocks_per_group=2,
        ...     use_stream=True,
        ... )
        ```
    """

    # Normalize string device specs / string offload types to their canonical objects.
    onload_device = torch.device(onload_device) if isinstance(onload_device, str) else onload_device
    offload_device = torch.device(offload_device) if isinstance(offload_device, str) else offload_device
    offload_type = GroupOffloadingType(offload_type)

    stream = None
    if use_stream:
        if torch.cuda.is_available():
            stream = torch.cuda.Stream()
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            stream = torch.Stream()
        else:
            raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.")

    if not use_stream and record_stream:
        raise ValueError("`record_stream` cannot be True when `use_stream=False`.")
    if offload_type == GroupOffloadingType.BLOCK_LEVEL and num_blocks_per_group is None:
        # Fixed: the error message previously had an unbalanced backtick around `offload_type='block_level'`.
        raise ValueError("`num_blocks_per_group` must be provided when using `offload_type='block_level'`.")

    # Group offloading is incompatible with Accelerate's own offloading hooks; fail early if present.
    _raise_error_if_accelerate_model_or_sequential_hook_present(module)

    config = GroupOffloadingConfig(
        onload_device=onload_device,
        offload_device=offload_device,
        offload_type=offload_type,
        num_blocks_per_group=num_blocks_per_group,
        non_blocking=non_blocking,
        stream=stream,
        record_stream=record_stream,
        low_cpu_mem_usage=low_cpu_mem_usage,
        offload_to_disk_path=offload_to_disk_path,
    )
    _apply_group_offloading(module, config)
+
567
+
568
def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
    """Dispatch to the offloading strategy selected by ``config.offload_type``.

    Raises:
        ValueError: if ``config.offload_type`` is not a known ``GroupOffloadingType`` member.
    """
    if config.offload_type == GroupOffloadingType.BLOCK_LEVEL:
        _apply_group_offloading_block_level(module, config)
    elif config.offload_type == GroupOffloadingType.LEAF_LEVEL:
        _apply_group_offloading_leaf_level(module, config)
    else:
        # Was `assert False`, which is silently stripped under `python -O`; raise explicitly so an
        # unexpected enum value can never fall through without error.
        raise ValueError(f"Unsupported offload_type: {config.offload_type!r}")
+
576
+
577
def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
    r"""
    This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
    the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks.
    """

    # Stream-based prefetching only supports one block per group; clamp the config with a warning.
    if config.stream is not None and config.num_blocks_per_group != 1:
        logger.warning(
            f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1."
        )
        config.num_blocks_per_group = 1

    # Create module groups for ModuleList and Sequential blocks
    modules_with_group_offloading = set()
    unmatched_modules = []
    matched_module_groups = []
    for name, submodule in module.named_children():
        if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
            # Non-container children are handled by the fallback "unmatched" group created below.
            unmatched_modules.append((name, submodule))
            modules_with_group_offloading.add(name)
            continue

        # Chunk the container into groups of `num_blocks_per_group` consecutive blocks.
        for i in range(0, len(submodule), config.num_blocks_per_group):
            current_modules = submodule[i : i + config.num_blocks_per_group]
            group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
            group = ModuleGroup(
                modules=current_modules,
                offload_device=config.offload_device,
                onload_device=config.onload_device,
                offload_to_disk_path=config.offload_to_disk_path,
                offload_leader=current_modules[-1],
                onload_leader=current_modules[0],
                non_blocking=config.non_blocking,
                stream=config.stream,
                record_stream=config.record_stream,
                low_cpu_mem_usage=config.low_cpu_mem_usage,
                onload_self=True,
                group_id=group_id,
            )
            matched_module_groups.append(group)
            for j in range(i, i + len(current_modules)):
                modules_with_group_offloading.add(f"{name}.{j}")

    # Apply group offloading hooks to the module groups.
    # (Previously `for i, group in enumerate(...)` with the index `i` unused.)
    for group in matched_module_groups:
        for group_module in group.modules:
            _apply_group_offloading_hook(group_module, group, config=config)

    # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
    # when the forward pass of this module is called. This is because the top-level module is not
    # part of any group (as doing so would lead to no VRAM savings).
    parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
    buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
    parameters = [param for _, param in parameters]
    buffers = [buffer for _, buffer in buffers]

    # Create a group for the unmatched submodules of the top-level module so that they are on the correct
    # device when the forward pass is called.
    unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
    unmatched_group = ModuleGroup(
        modules=unmatched_modules,
        offload_device=config.offload_device,
        onload_device=config.onload_device,
        offload_to_disk_path=config.offload_to_disk_path,
        offload_leader=module,
        onload_leader=module,
        parameters=parameters,
        buffers=buffers,
        non_blocking=False,
        stream=None,
        record_stream=False,
        onload_self=True,
        group_id=f"{module.__class__.__name__}_unmatched_group",
    )
    if config.stream is None:
        _apply_group_offloading_hook(module, unmatched_group, config=config)
    else:
        # With streams, attach the unmatched group lazily so the layer execution order can be traced first.
        _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
+
656
+
657
def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
    r"""
    This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
    requirements. However, it can be slower compared to other offloading methods due to the excessive number of device
    synchronizations. When using devices that support streams to overlap data transfer and computation, this method can
    reduce memory usage without any performance degradation.
    """
    # Create module groups for leaf modules and apply group offloading hooks
    modules_with_group_offloading = set()
    for name, submodule in module.named_modules():
        # Only module types listed in _GO_LC_SUPPORTED_PYTORCH_LAYERS get their own single-module group.
        if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
            continue
        group = ModuleGroup(
            modules=[submodule],
            offload_device=config.offload_device,
            onload_device=config.onload_device,
            offload_to_disk_path=config.offload_to_disk_path,
            offload_leader=submodule,
            onload_leader=submodule,
            non_blocking=config.non_blocking,
            stream=config.stream,
            record_stream=config.record_stream,
            low_cpu_mem_usage=config.low_cpu_mem_usage,
            onload_self=True,
            group_id=name,
        )
        _apply_group_offloading_hook(submodule, group, config=config)
        modules_with_group_offloading.add(name)

    # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
    # of the module is called
    module_dict = dict(module.named_modules())
    parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
    buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)

    # Find closest module parent for each parameter and buffer, and attach group hooks
    parent_to_parameters = {}
    for name, param in parameters:
        parent_name = _find_parent_module_in_module_dict(name, module_dict)
        if parent_name in parent_to_parameters:
            parent_to_parameters[parent_name].append(param)
        else:
            parent_to_parameters[parent_name] = [param]

    parent_to_buffers = {}
    for name, buffer in buffers:
        parent_name = _find_parent_module_in_module_dict(name, module_dict)
        if parent_name in parent_to_buffers:
            parent_to_buffers[parent_name].append(buffer)
        else:
            parent_to_buffers[parent_name] = [buffer]

    # One module-less group per parent: it moves only the stray parameters/buffers owned by that parent.
    # NOTE: `parameters` and `buffers` are deliberately rebound inside this loop; the gathered lists above
    # have already been fully consumed into parent_to_parameters / parent_to_buffers.
    parent_names = set(parent_to_parameters.keys()) | set(parent_to_buffers.keys())
    for name in parent_names:
        parameters = parent_to_parameters.get(name, [])
        buffers = parent_to_buffers.get(name, [])
        parent_module = module_dict[name]
        group = ModuleGroup(
            modules=[],
            offload_device=config.offload_device,
            onload_device=config.onload_device,
            offload_leader=parent_module,
            onload_leader=parent_module,
            offload_to_disk_path=config.offload_to_disk_path,
            parameters=parameters,
            buffers=buffers,
            non_blocking=config.non_blocking,
            stream=config.stream,
            record_stream=config.record_stream,
            low_cpu_mem_usage=config.low_cpu_mem_usage,
            onload_self=True,
            group_id=name,
        )
        _apply_group_offloading_hook(parent_module, group, config=config)

    if config.stream is not None:
        # When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
        # and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that will find the
        # execution order and apply prefetching in the correct order.
        unmatched_group = ModuleGroup(
            modules=[],
            offload_device=config.offload_device,
            onload_device=config.onload_device,
            offload_to_disk_path=config.offload_to_disk_path,
            offload_leader=module,
            onload_leader=module,
            parameters=None,
            buffers=None,
            non_blocking=False,
            stream=None,
            record_stream=False,
            low_cpu_mem_usage=config.low_cpu_mem_usage,
            onload_self=True,
            group_id=_GROUP_ID_LAZY_LEAF,
        )
        _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
+
754
+
755
def _apply_group_offloading_hook(
    module: torch.nn.Module,
    group: ModuleGroup,
    *,
    config: GroupOffloadingConfig,
) -> None:
    """Attach a GroupOffloadingHook for ``group`` to ``module``, unless one is already registered."""
    registry = HookRegistry.check_if_exists_or_initialize(module)

    # A group offloading hook may already be registered if the module was previously picked as the parent
    # of a stray torch.nn.Parameter; in that case keep the existing hook rather than overwriting it.
    if registry.get_hook(_GROUP_OFFLOADING) is not None:
        return
    registry.register_hook(GroupOffloadingHook(group, config=config), _GROUP_OFFLOADING)
+
769
+
770
def _apply_lazy_group_offloading_hook(
    module: torch.nn.Module,
    group: ModuleGroup,
    *,
    config: GroupOffloadingConfig,
) -> None:
    """Attach a GroupOffloadingHook (if not already present) plus a LazyPrefetchGroupOffloadingHook to ``module``."""
    registry = HookRegistry.check_if_exists_or_initialize(module)

    # A group offloading hook may already be registered if the module was previously picked as the parent
    # of a stray torch.nn.Parameter; in that case keep the existing hook rather than overwriting it.
    if registry.get_hook(_GROUP_OFFLOADING) is None:
        registry.register_hook(GroupOffloadingHook(group, config=config), _GROUP_OFFLOADING)

    # The lazy prefetch hook traces execution order on the first forward pass and wires up prefetching.
    registry.register_hook(LazyPrefetchGroupOffloadingHook(), _LAZY_PREFETCH_GROUP_OFFLOADING)
+
787
+
788
+ def _gather_parameters_with_no_group_offloading_parent(
789
+ module: torch.nn.Module, modules_with_group_offloading: Set[str]
790
+ ) -> List[torch.nn.Parameter]:
791
+ parameters = []
792
+ for name, parameter in module.named_parameters():
793
+ has_parent_with_group_offloading = False
794
+ atoms = name.split(".")
795
+ while len(atoms) > 0:
796
+ parent_name = ".".join(atoms)
797
+ if parent_name in modules_with_group_offloading:
798
+ has_parent_with_group_offloading = True
799
+ break
800
+ atoms.pop()
801
+ if not has_parent_with_group_offloading:
802
+ parameters.append((name, parameter))
803
+ return parameters
804
+
805
+
806
+ def _gather_buffers_with_no_group_offloading_parent(
807
+ module: torch.nn.Module, modules_with_group_offloading: Set[str]
808
+ ) -> List[torch.Tensor]:
809
+ buffers = []
810
+ for name, buffer in module.named_buffers():
811
+ has_parent_with_group_offloading = False
812
+ atoms = name.split(".")
813
+ while len(atoms) > 0:
814
+ parent_name = ".".join(atoms)
815
+ if parent_name in modules_with_group_offloading:
816
+ has_parent_with_group_offloading = True
817
+ break
818
+ atoms.pop()
819
+ if not has_parent_with_group_offloading:
820
+ buffers.append((name, buffer))
821
+ return buffers
822
+
823
+
824
+ def _find_parent_module_in_module_dict(name: str, module_dict: Dict[str, torch.nn.Module]) -> str:
825
+ atoms = name.split(".")
826
+ while len(atoms) > 0:
827
+ parent_name = ".".join(atoms)
828
+ if parent_name in module_dict:
829
+ return parent_name
830
+ atoms.pop()
831
+ return ""
832
+
833
+
834
def _raise_error_if_accelerate_model_or_sequential_hook_present(module: torch.nn.Module) -> None:
    """Raise if any submodule already carries an Accelerate AlignDevicesHook or CpuOffload hook.

    Group offloading cannot coexist with Accelerate's own offloading strategies, so we refuse to proceed.
    """
    if not is_accelerate_available():
        return
    for name, submodule in module.named_modules():
        # Accelerate stores its hook under the `_hf_hook` attribute; absence means no conflict here.
        hf_hook = getattr(submodule, "_hf_hook", None)
        if isinstance(hf_hook, (AlignDevicesHook, CpuOffload)):
            raise ValueError(
                f"Cannot apply group offloading to a module that is already applying an alternative "
                f"offloading strategy from Accelerate. If you want to apply group offloading, please "
                f"disable the existing offloading strategy first. Offending module: {name} ({type(submodule)})"
            )
+
847
+
848
def _get_top_level_group_offload_hook(module: torch.nn.Module) -> Optional[GroupOffloadingHook]:
    """Return the first group-offloading hook found while walking ``module`` and its submodules, else None."""
    for submodule in module.modules():
        registry = getattr(submodule, "_diffusers_hook", None)
        if registry is None:
            continue
        hook = registry.get_hook(_GROUP_OFFLOADING)
        if hook is not None:
            return hook
    return None
+
856
+
857
def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
    """True if any (sub)module of ``module`` carries a group-offloading hook."""
    return _get_top_level_group_offload_hook(module) is not None
+
861
+
862
def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
    """Return the onload device configured for ``module``'s group offloading.

    Raises:
        ValueError: if group offloading has not been applied to ``module``.
    """
    hook = _get_top_level_group_offload_hook(module)
    if hook is None:
        raise ValueError("Group offloading is not enabled for the provided module.")
    return hook.config.onload_device
+
868
+
869
+ def _compute_group_hash(group_id):
870
+ hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest()
871
+ # first 16 characters for a reasonably short but unique name
872
+ return hashed_id[:16]
873
+
874
+
875
def _maybe_remove_and_reapply_group_offloading(module: torch.nn.Module) -> None:
    r"""
    Strips all group-offloading hooks from ``module`` and re-applies group offloading with the same
    configuration. Useful after in-place modifications (e.g. fusing QKV, loading/unloading LoRAs) that
    invalidate the tensor references held by the existing hooks. No-op if group offloading was never applied.

    Assumes group offloading was applied only at the top-level module, so every submodule shares the same
    onload/offload devices; with multi-level application this function will not behave as expected.

    When non-default streams are in use there is a performance penalty, because the layer execution order
    must be retraced by `LazyPrefetchGroupOffloadingHook`.
    """
    top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
    if top_level_group_offload_hook is None:
        # Group offloading was never enabled; nothing to rebuild.
        return

    # Remove every offloading-related hook from the whole module tree...
    registry = HookRegistry.check_if_exists_or_initialize(module)
    for hook_name in (_GROUP_OFFLOADING, _LAYER_EXECUTION_TRACKER, _LAZY_PREFETCH_GROUP_OFFLOADING):
        registry.remove_hook(hook_name, recurse=True)

    # ...then re-apply with the previously-used configuration so all tensor references are fresh.
    _apply_group_offloading(module, top_level_group_offload_hook.config)