BryanW committed on
Commit
204b9ca
·
verified ·
1 Parent(s): 1611dcd

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/__pycache__/__init__.cpython-312.pyc +0 -0
  2. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/__pycache__/callbacks.cpython-312.pyc +0 -0
  3. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/__pycache__/configuration_utils.cpython-312.pyc +0 -0
  4. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/__pycache__/dependency_versions_check.cpython-312.pyc +0 -0
  5. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/__pycache__/dependency_versions_table.cpython-312.pyc +0 -0
  6. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/__pycache__/image_processor.cpython-312.pyc +0 -0
  7. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/__pycache__/optimization.cpython-312.pyc +0 -0
  8. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/__pycache__/training_utils.cpython-312.pyc +0 -0
  9. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/__pycache__/video_processor.cpython-312.pyc +0 -0
  10. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/commands/__init__.py +27 -0
  11. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/commands/custom_blocks.py +132 -0
  12. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/commands/diffusers_cli.py +45 -0
  13. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/commands/env.py +180 -0
  14. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/commands/fp16_safetensors.py +132 -0
  15. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/experimental/__init__.py +1 -0
  16. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__init__.py +31 -0
  17. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/adaptive_projected_guidance.py +237 -0
  18. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/adaptive_projected_guidance_mix.py +297 -0
  19. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/auto_guidance.py +198 -0
  20. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/classifier_free_guidance.py +156 -0
  21. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/classifier_free_zero_star_guidance.py +164 -0
  22. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/frequency_decoupled_guidance.py +335 -0
  23. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/guider_utils.py +396 -0
  24. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/magnitude_aware_guidance.py +159 -0
  25. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/perturbed_attention_guidance.py +289 -0
  26. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/skip_layer_guidance.py +280 -0
  27. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/smoothed_energy_guidance.py +269 -0
  28. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/tangential_classifier_free_guidance.py +151 -0
  29. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__init__.py +29 -0
  30. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/__init__.cpython-312.pyc +0 -0
  31. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/_common.cpython-312.pyc +0 -0
  32. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/_helpers.cpython-312.pyc +0 -0
  33. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/context_parallel.cpython-312.pyc +0 -0
  34. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/faster_cache.cpython-312.pyc +0 -0
  35. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/first_block_cache.cpython-312.pyc +0 -0
  36. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/group_offloading.cpython-312.pyc +0 -0
  37. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/hooks.cpython-312.pyc +0 -0
  38. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/layer_skip.cpython-312.pyc +0 -0
  39. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/layerwise_casting.cpython-312.pyc +0 -0
  40. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/mag_cache.cpython-312.pyc +0 -0
  41. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/pyramid_attention_broadcast.cpython-312.pyc +0 -0
  42. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/smoothed_energy_guidance_utils.cpython-312.pyc +0 -0
  43. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/taylorseer_cache.cpython-312.pyc +0 -0
  44. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/utils.cpython-312.pyc +0 -0
  45. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/_common.py +61 -0
  46. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/_helpers.py +381 -0
  47. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/context_parallel.py +383 -0
  48. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/faster_cache.py +654 -0
  49. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/first_block_cache.py +258 -0
  50. URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/group_offloading.py +966 -0
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (38.8 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/__pycache__/callbacks.cpython-312.pyc ADDED
Binary file (10.9 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/__pycache__/configuration_utils.cpython-312.pyc ADDED
Binary file (38.8 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/__pycache__/dependency_versions_check.cpython-312.pyc ADDED
Binary file (917 Bytes). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/__pycache__/dependency_versions_table.cpython-312.pyc ADDED
Binary file (2.29 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/__pycache__/image_processor.cpython-312.pyc ADDED
Binary file (65.1 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/__pycache__/optimization.cpython-312.pyc ADDED
Binary file (15.9 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/__pycache__/training_utils.cpython-312.pyc ADDED
Binary file (40.2 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/__pycache__/video_processor.cpython-312.pyc ADDED
Binary file (9.41 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/commands/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from abc import ABC, abstractmethod
16
+ from argparse import ArgumentParser
17
+
18
+
19
class BaseDiffusersCLICommand(ABC):
    """Abstract interface implemented by every `diffusers-cli` sub-command.

    Concrete commands must provide both `register_subcommand` (to declare their
    command-line arguments) and `run` (to do the actual work); the class cannot
    be instantiated until both are overridden.
    """

    @staticmethod
    @abstractmethod
    def register_subcommand(parser: ArgumentParser):
        """Register this command and its arguments on `parser`.

        Args:
            parser: The object the command attaches its own sub-parser to.
        """
        raise NotImplementedError()

    @abstractmethod
    def run(self):
        """Execute the command."""
        raise NotImplementedError()
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/commands/custom_blocks.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Usage example:
17
+ TODO
18
+ """
19
+
20
+ import ast
21
+ import importlib.util
22
+ import os
23
+ from argparse import ArgumentParser, Namespace
24
+ from pathlib import Path
25
+
26
+ from ..utils import logging
27
+ from . import BaseDiffusersCLICommand
28
+
29
+
30
# Base-class names that identify a class in the user's module as a custom block.
EXPECTED_PARENT_CLASSES = ["ModularPipelineBlocks"]
# Name of the configuration file (only referenced by the manual auto-map alternative).
CONFIG = "config.json"
32
+
33
+
34
def conversion_command_factory(args: Namespace):
    """Build the `custom_blocks` command from the parsed command-line arguments.

    Args:
        args: Parsed arguments; `block_module_name` and `block_class_name` are read.

    Returns:
        A `CustomBlocksCommand` configured from `args`.
    """
    return CustomBlocksCommand(args.block_module_name, args.block_class_name)
36
+
37
+
38
class CustomBlocksCommand(BaseDiffusersCLICommand):
    """`diffusers-cli custom_blocks`: find a custom block class in a module file and save it.

    The module is parsed statically to list top-level classes whose bases include
    one of `EXPECTED_PARENT_CLASSES`; the selected class is then imported,
    instantiated without arguments and saved into the current working directory.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Register the `custom_blocks` sub-command and its options on `parser`."""
        conversion_parser = parser.add_parser("custom_blocks")
        conversion_parser.add_argument(
            "--block_module_name",
            type=str,
            default="block.py",
            help="Module filename in which the custom block will be implemented.",
        )
        conversion_parser.add_argument(
            "--block_class_name",
            type=str,
            default=None,
            help="Name of the custom block. If provided None, we will try to infer it.",
        )
        conversion_parser.set_defaults(func=conversion_command_factory)

    def __init__(self, block_module_name: str = "block.py", block_class_name: str = None):
        """Store the module path and the (optional) name of the class to save.

        Args:
            block_module_name: Path of the Python file that defines the custom block.
            block_class_name: Class to save; when `None` the first matching class is used.
        """
        self.logger = logging.get_logger("diffusers-cli/custom_blocks")
        self.block_module_name = Path(block_module_name)
        self.block_class_name = block_class_name

    def run(self):
        """Select the block class, import its module and call `save_pretrained` on an instance.

        Raises:
            ValueError: If the module cannot be parsed or loaded, if it contains no
                class deriving from `EXPECTED_PARENT_CLASSES`, or if the requested
                `block_class_name` is not among the classes found.
        """
        # determine the block to be saved.
        out = self._get_class_names(self.block_module_name)
        # De-duplicate while keeping source order, so the class reported below is
        # the same one that is actually picked (`out[0]`).
        classes_found = list(dict.fromkeys(cls for cls, _ in out))

        if self.block_class_name is not None:
            child_class, parent_class = self._choose_block(out, self.block_class_name)
            if child_class is None and parent_class is None:
                raise ValueError(
                    "`block_class_name` could not be retrieved. Available classes from "
                    f"{self.block_module_name}:\n{classes_found}"
                )
        else:
            if not out:
                # Fail with an explicit message instead of an IndexError on `out[0]`.
                raise ValueError(
                    f"No class deriving from {EXPECTED_PARENT_CLASSES} was found in {self.block_module_name}."
                )
            self.logger.info(
                f"Found classes: {classes_found} will be using {classes_found[0]}. "
                "If this needs to be changed, re-run the command specifying `block_class_name`."
            )
            child_class, parent_class = out[0]

        # Dynamically import the custom block and initialize it to call `save_pretrained`
        # in the current directory. NOTE: this executes the user's module; the command is
        # run by the user on their own file, so the file is treated as trusted.
        module_name = f"__dynamic__{self.block_module_name.stem}"
        spec = importlib.util.spec_from_file_location(module_name, str(self.block_module_name))
        if spec is None or spec.loader is None:
            # `spec_from_file_location` returns None when no loader matches the file.
            raise ValueError(f"Could not load {self.block_module_name} as a Python module.")
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        getattr(module, child_class)().save_pretrained(os.getcwd())

        # Alternative (not enabled): build the config manually with
        # `self._create_automap(parent_class=parent_class, child_class=child_class)`
        # and dump it as JSON into `CONFIG`.

    def _choose_block(self, candidates, chosen=None):
        """Return the `(class_name, base_name)` pair whose class name equals `chosen`.

        Returns `(None, None)` when no candidate matches.
        """
        for cls, base in candidates:
            if cls == chosen:
                return cls, base
        return None, None

    def _get_class_names(self, file_path):
        """Statically list the top-level classes of `file_path` that derive from an expected base.

        Args:
            file_path: `pathlib.Path` of the module to inspect (read as UTF-8).

        Returns:
            A list of `(class_name, base_name)` tuples, one per expected base found
            among the class's bases, in source order.

        Raises:
            ValueError: If the file is not valid Python.
        """
        source = file_path.read_text(encoding="utf-8")
        try:
            tree = ast.parse(source, filename=str(file_path))
        except SyntaxError as e:
            raise ValueError(f"Could not parse {file_path!r}: {e}") from e

        results: list[tuple[str, str]] = []
        for node in tree.body:
            if not isinstance(node, ast.ClassDef):
                continue

            # extract all base names for this class
            base_names = [bname for b in node.bases if (bname := self._get_base_name(b)) is not None]

            # for each allowed base that appears in the class's bases, emit a tuple
            for allowed in EXPECTED_PARENT_CLASSES:
                if allowed in base_names:
                    results.append((node.name, allowed))

        return results

    def _get_base_name(self, node: ast.expr):
        """Return the dotted name of a base-class expression, or `None` if it is not a plain name."""
        if isinstance(node, ast.Name):
            return node.id
        elif isinstance(node, ast.Attribute):
            val = self._get_base_name(node.value)
            return f"{val}.{node.attr}" if val else node.attr
        return None

    def _create_automap(self, parent_class, child_class):
        """Build the `{"auto_map": {parent_class: "<module>.<child_class>"}}` mapping."""
        module = str(self.block_module_name).replace(".py", "").rsplit(".", 1)[-1]
        auto_map = {f"{parent_class}": f"{module}.{child_class}"}
        return {"auto_map": auto_map}
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/commands/diffusers_cli.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from argparse import ArgumentParser
17
+
18
+ from .custom_blocks import CustomBlocksCommand
19
+ from .env import EnvironmentCommand
20
+ from .fp16_safetensors import FP16SafetensorsCommand
21
+
22
+
23
def main():
    """Entry point of `diffusers-cli`: parse the command line and run the selected command.

    Prints the help text and exits with status 1 when no sub-command was given.
    """
    parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli <command> [<args>]")
    commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")

    # Register commands
    EnvironmentCommand.register_subcommand(commands_parser)
    FP16SafetensorsCommand.register_subcommand(commands_parser)
    CustomBlocksCommand.register_subcommand(commands_parser)

    # Let's go
    args = parser.parse_args()

    if not hasattr(args, "func"):
        parser.print_help()
        # `exit()` is injected by the `site` module and is not guaranteed to exist
        # (e.g. under `python -S`); raising SystemExit is the equivalent, always-available form.
        raise SystemExit(1)

    # Run: each sub-command stored a factory in `func` that builds the command from `args`.
    service = args.func(args)
    service.run()
42
+
43
+
44
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/commands/env.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import platform
16
+ import subprocess
17
+ from argparse import ArgumentParser
18
+
19
+ import huggingface_hub
20
+
21
+ from .. import __version__ as version
22
+ from ..utils import (
23
+ is_accelerate_available,
24
+ is_bitsandbytes_available,
25
+ is_flax_available,
26
+ is_google_colab,
27
+ is_peft_available,
28
+ is_safetensors_available,
29
+ is_torch_available,
30
+ is_transformers_available,
31
+ is_xformers_available,
32
+ )
33
+ from . import BaseDiffusersCLICommand
34
+
35
+
36
def info_command_factory(_):
    """Build the `env` command; the parsed arguments are ignored because it takes no options."""
    return EnvironmentCommand()
38
+
39
+
40
class EnvironmentCommand(BaseDiffusersCLICommand):
    """`diffusers-cli env`: collect and print version/platform details for bug reports."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser) -> None:
        """Register the argument-less `env` sub-command on `parser`."""
        download_parser = parser.add_parser("env")
        download_parser.set_defaults(func=info_command_factory)

    def run(self) -> dict:
        """Gather environment information, print it as a bullet list and return it.

        Optional libraries are imported only when the matching `is_*_available()`
        check passes; otherwise their version is reported as "not installed".

        Returns:
            The mapping of report labels to values that was printed.
        """
        hub_version = huggingface_hub.__version__

        safetensors_version = "not installed"
        if is_safetensors_available():
            import safetensors

            safetensors_version = safetensors.__version__

        pt_version = "not installed"
        pt_cuda_available = "NA"
        if is_torch_available():
            import torch

            pt_version = torch.__version__
            pt_cuda_available = torch.cuda.is_available()

        flax_version = "not installed"
        jax_version = "not installed"
        jaxlib_version = "not installed"
        jax_backend = "NA"
        if is_flax_available():
            import flax
            import jax
            import jaxlib

            flax_version = flax.__version__
            jax_version = jax.__version__
            jaxlib_version = jaxlib.__version__
            # Public API for the default backend's platform name; avoids reaching
            # into the internal `jax.lib.xla_bridge` module.
            jax_backend = jax.default_backend()

        transformers_version = "not installed"
        if is_transformers_available():
            import transformers

            transformers_version = transformers.__version__

        accelerate_version = "not installed"
        if is_accelerate_available():
            import accelerate

            accelerate_version = accelerate.__version__

        peft_version = "not installed"
        if is_peft_available():
            import peft

            peft_version = peft.__version__

        bitsandbytes_version = "not installed"
        if is_bitsandbytes_available():
            import bitsandbytes

            bitsandbytes_version = bitsandbytes.__version__

        xformers_version = "not installed"
        if is_xformers_available():
            import xformers

            xformers_version = xformers.__version__

        platform_info = platform.platform()

        is_google_colab_str = "Yes" if is_google_colab() else "No"

        accelerator = self._detect_accelerator()

        info = {
            "🤗 Diffusers version": version,
            "Platform": platform_info,
            "Running on Google Colab?": is_google_colab_str,
            "Python version": platform.python_version(),
            "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
            "Flax version (CPU?/GPU?/TPU?)": f"{flax_version} ({jax_backend})",
            "Jax version": jax_version,
            "JaxLib version": jaxlib_version,
            "Huggingface_hub version": hub_version,
            "Transformers version": transformers_version,
            "Accelerate version": accelerate_version,
            "PEFT version": peft_version,
            "Bitsandbytes version": bitsandbytes_version,
            "Safetensors version": safetensors_version,
            "xFormers version": xformers_version,
            "Accelerator": accelerator,
            "Using GPU in script?": "<fill in>",
            "Using distributed or parallel set-up in script?": "<fill in>",
        }

        print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
        print(self.format_dict(info))

        return info

    @staticmethod
    def _command_output(cmd):
        """Run `cmd` and return its stdout decoded as UTF-8, or `None` if the executable is missing."""
        try:
            sp = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except FileNotFoundError:
            return None
        out_bytes, _ = sp.communicate()
        return out_bytes.decode("utf-8")

    @staticmethod
    def _field_after(text, label):
        """Return the stripped remainder of the line following `label` in `text`, or `None` if absent."""
        start = text.find(label)
        if start == -1:
            return None
        # Splitting (rather than slicing up to `find("\n")`) keeps the last character
        # when the label sits on the final line without a trailing newline.
        return text[start + len(label) :].split("\n", 1)[0].strip()

    @classmethod
    def _detect_accelerator(cls) -> str:
        """Describe the GPU via `nvidia-smi` (Linux/Windows) or `system_profiler` (macOS); "NA" if unknown."""
        accelerator = "NA"
        system = platform.system()
        if system in {"Linux", "Windows"}:
            out_str = cls._command_output(
                ["nvidia-smi", "--query-gpu=gpu_name,memory.total", "--format=csv,noheader"]
            )
            if out_str:
                accelerator = out_str.strip()
        elif system == "Darwin":  # Mac OS
            out_str = cls._command_output(["system_profiler", "SPDisplaysDataType"])
            if out_str is not None:
                chipset = cls._field_after(out_str, "Chipset Model:")
                if chipset is not None:
                    accelerator = chipset

                vram = cls._field_after(out_str, "VRAM (Total):")
                if vram is not None:
                    accelerator += " VRAM: " + vram
        else:
            print("It seems you are running an unusual OS. Could you fill in the accelerator manually?")
        return accelerator

    @staticmethod
    def format_dict(d: dict) -> str:
        """Render `d` as one "- key: value" line per entry, terminated by a newline."""
        return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/commands/fp16_safetensors.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Usage example:
17
+ diffusers-cli fp16_safetensors --ckpt_id=openai/shap-e --fp16 --use_safetensors
18
+ """
19
+
20
+ import glob
21
+ import json
22
+ import warnings
23
+ from argparse import ArgumentParser, Namespace
24
+ from importlib import import_module
25
+
26
+ import huggingface_hub
27
+ import torch
28
+ from huggingface_hub import hf_hub_download
29
+ from packaging import version
30
+
31
+ from ..utils import logging
32
+ from . import BaseDiffusersCLICommand
33
+
34
+
35
def conversion_command_factory(args: Namespace):
    """Build the `fp16_safetensors` command from the parsed command-line arguments.

    Emits a deprecation warning when `--use_auth_token` was passed; the flag is
    otherwise ignored.
    """
    if args.use_auth_token:
        # The two literals are concatenated, so the separating space must be explicit.
        warnings.warn(
            "The `--use_auth_token` flag is deprecated and will be removed in a future version. "
            "Authentication is now handled automatically if the user is logged in."
        )
    return FP16SafetensorsCommand(args.ckpt_id, args.fp16, args.use_safetensors)
42
+
43
+
44
class FP16SafetensorsCommand(BaseDiffusersCLICommand):
    """`diffusers-cli fp16_safetensors`: re-serialize a Hub pipeline and open a PR with the result.

    The pipeline is loaded from the Hub, saved locally in FP16 and/or safetensors
    format, and the newly written weight files are proposed to the source
    repository as a pull request.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Register the `fp16_safetensors` sub-command and its options on `parser`."""
        conversion_parser = parser.add_parser("fp16_safetensors")
        conversion_parser.add_argument(
            "--ckpt_id",
            type=str,
            # The command cannot do anything without a repository to convert.
            required=True,
            help="Repo id of the checkpoints on which to run the conversion. Example: 'openai/shap-e'.",
        )
        conversion_parser.add_argument(
            "--fp16", action="store_true", help="If serializing the variables in FP16 precision."
        )
        conversion_parser.add_argument(
            "--use_safetensors", action="store_true", help="If serializing in the safetensors format."
        )
        conversion_parser.add_argument(
            "--use_auth_token",
            action="store_true",
            help="When working with checkpoints having private visibility. When used `hf auth login` needs to be run beforehand.",
        )
        conversion_parser.set_defaults(func=conversion_command_factory)

    def __init__(self, ckpt_id: str, fp16: bool, use_safetensors: bool):
        """Store the conversion options.

        Args:
            ckpt_id: Hub repository id of the pipeline to convert.
            fp16: Whether to serialize the weights in FP16 precision.
            use_safetensors: Whether to serialize in the safetensors format.

        Raises:
            NotImplementedError: If both `fp16` and `use_safetensors` are false.
        """
        self.logger = logging.get_logger("diffusers-cli/fp16_safetensors")
        self.ckpt_id = ckpt_id
        self.local_ckpt_dir = f"/tmp/{ckpt_id}"
        self.fp16 = fp16

        self.use_safetensors = use_safetensors

        if not self.use_safetensors and not self.fp16:
            raise NotImplementedError(
                "When `use_safetensors` and `fp16` both are False, then this command is of no use."
            )

    def run(self):
        """Convert the pipeline, save it under `self.local_ckpt_dir` and open a Hub PR.

        Raises:
            ImportError: If the installed `huggingface_hub` is older than 0.9.0.
        """
        import os

        if version.parse(huggingface_hub.__version__) < version.parse("0.9.0"):
            raise ImportError(
                "The huggingface_hub version must be >= 0.9.0 to use this command. Please update your huggingface_hub"
                " installation."
            )

        from huggingface_hub import create_commit
        from huggingface_hub._commit_api import CommitOperationAdd

        model_index = hf_hub_download(repo_id=self.ckpt_id, filename="model_index.json")
        with open(model_index, "r", encoding="utf-8") as f:
            pipeline_class_name = json.load(f)["_class_name"]
        pipeline_class = getattr(import_module("diffusers"), pipeline_class_name)
        self.logger.info(f"Pipeline class imported: {pipeline_class_name}.")

        # Load the appropriate pipeline. We could have used `DiffusionPipeline`
        # here, but just to avoid potential edge cases.
        pipeline = pipeline_class.from_pretrained(
            self.ckpt_id, torch_dtype=torch.float16 if self.fp16 else torch.float32
        )
        pipeline.save_pretrained(
            self.local_ckpt_dir,
            safe_serialization=bool(self.use_safetensors),
            variant="fp16" if self.fp16 else None,
        )
        self.logger.info(f"Pipeline locally saved to {self.local_ckpt_dir}.")

        # Fetch all the paths. `__init__` guarantees at least one of the two flags is
        # set, so the non-FP16 case is necessarily the safetensors one.
        file_pattern = "*.fp16.*" if self.fp16 else "*.safetensors"
        modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/{file_pattern}")

        # Prepare for the PR. The path inside the repo is the path relative to the local
        # save directory; computing it (instead of dropping a fixed number of path
        # components) also works for repo ids without an `org/` prefix.
        commit_message = f"Serialize variables with FP16: {self.fp16} and safetensors: {self.use_safetensors}."
        operations = [
            CommitOperationAdd(
                path_in_repo=os.path.relpath(path, self.local_ckpt_dir).replace(os.sep, "/"),
                path_or_fileobj=path,
            )
            for path in modified_paths
        ]

        # Open the PR.
        commit_description = (
            "Variables converted by the [`diffusers`' `fp16_safetensors`"
            " CLI](https://github.com/huggingface/diffusers/blob/main/src/diffusers/commands/fp16_safetensors.py)."
        )
        hub_pr_url = create_commit(
            repo_id=self.ckpt_id,
            operations=operations,
            commit_message=commit_message,
            commit_description=commit_description,
            repo_type="model",
            create_pr=True,
        ).pr_url
        self.logger.info(f"PR created here: {hub_pr_url}.")
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/experimental/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .rl import ValueGuidedRLPipeline
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__init__.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from ..utils import is_torch_available, logging
17
+
18
+
19
+ if is_torch_available():
20
+ from .adaptive_projected_guidance import AdaptiveProjectedGuidance
21
+ from .adaptive_projected_guidance_mix import AdaptiveProjectedMixGuidance
22
+ from .auto_guidance import AutoGuidance
23
+ from .classifier_free_guidance import ClassifierFreeGuidance
24
+ from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
25
+ from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
26
+ from .guider_utils import BaseGuidance
27
+ from .magnitude_aware_guidance import MagnitudeAwareGuidance
28
+ from .perturbed_attention_guidance import PerturbedAttentionGuidance
29
+ from .skip_layer_guidance import SkipLayerGuidance
30
+ from .smoothed_energy_guidance import SmoothedEnergyGuidance
31
+ from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/adaptive_projected_guidance.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import math
18
+ from typing import TYPE_CHECKING
19
+
20
+ import torch
21
+
22
+ from ..configuration_utils import register_to_config
23
+ from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
24
+
25
+
26
+ if TYPE_CHECKING:
27
+ from ..modular_pipelines.modular_pipeline import BlockState
28
+
29
+
30
class AdaptiveProjectedGuidance(BaseGuidance):
    """
    Adaptive Projected Guidance (APG): https://huggingface.co/papers/2410.02416

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
            prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
            deterioration of image quality.
        adaptive_projected_guidance_momentum (`float`, *optional*, defaults to `None`):
            The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
        adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
            Upper bound on the L2 norm of the guidance update (the difference between the conditional and
            unconditional predictions, after momentum has been applied). The norm is taken over every dimension
            except the first; updates with a larger norm are scaled down to this value. Values `<= 0` disable the
            rescaling.
        eta (`float`, defaults to `1.0`):
            Weight applied to the component of the guidance update that is parallel to the conditional prediction.
            The orthogonal component is always applied in full.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
        enabled (`bool`, defaults to `True`):
            Whether this guidance is enabled.
    """

    _input_predictions = ["pred_cond", "pred_uncond"]

    @register_to_config
    def __init__(
        self,
        guidance_scale: float = 7.5,
        adaptive_projected_guidance_momentum: float | None = None,
        adaptive_projected_guidance_rescale: float = 15.0,
        eta: float = 1.0,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
        enabled: bool = True,
    ):
        super().__init__(start, stop, enabled)

        self.guidance_scale = guidance_scale
        self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
        self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
        self.eta = eta
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation
        # Created lazily at step 0 by the `prepare_inputs*` methods.
        self.momentum_buffer = None

    def _reset_momentum_buffer(self) -> None:
        """
        At the first step (`_step == 0`), start a fresh momentum buffer, or clear any stale one when momentum is
        disabled. Does nothing on later steps so the running average keeps accumulating.
        """
        if self._step != 0:
            return
        momentum = self.adaptive_projected_guidance_momentum
        self.momentum_buffer = MomentumBuffer(momentum) if momentum is not None else None

    def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
        """Build one batch per active condition (conditional only, or conditional + unconditional) from `data`."""
        self._reset_momentum_buffer()
        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        # `zip` truncates `_input_predictions` to the number of active conditions.
        return [
            self._prepare_batch(data, tuple_idx, input_prediction)
            for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions)
        ]

    def prepare_inputs_from_block_state(
        self, data: "BlockState", input_fields: dict[str, str | tuple[str, str]]
    ) -> list["BlockState"]:
        """Same as `prepare_inputs`, but reads the fields named in `input_fields` from a block state."""
        self._reset_momentum_buffer()
        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        return [
            self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
            for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions)
        ]

    def forward(self, pred_cond: torch.Tensor, pred_uncond: torch.Tensor | None = None) -> GuiderOutput:
        """
        Combine the conditional and unconditional predictions.

        When APG is not active the conditional prediction is returned unchanged; otherwise the prediction is
        computed by `normalized_guidance`. `guidance_rescale` is applied afterwards when it is positive.
        """
        if not self._is_apg_enabled():
            pred = pred_cond
        else:
            pred = normalized_guidance(
                pred_cond,
                pred_uncond,
                self.guidance_scale,
                self.momentum_buffer,
                self.eta,
                # The APG "rescale" setting is the norm threshold of the guidance update.
                self.adaptive_projected_guidance_rescale,
                self.use_original_formulation,
            )

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

    @property
    def is_conditional(self) -> bool:
        return self._count_prepared == 1

    @property
    def num_conditions(self) -> int:
        """`2` while APG is active (conditional + unconditional pass), otherwise `1`."""
        num_conditions = 1
        if self._is_apg_enabled():
            num_conditions += 1
        return num_conditions

    def _is_apg_enabled(self) -> bool:
        """
        Return whether APG should be applied at the current step: the guider is enabled, the step lies inside the
        `[start, stop)` window (when the number of inference steps is known), and `guidance_scale` is not the
        neutral value of the selected formulation.
        """
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        # A scale of 0.0 (original formulation) or 1.0 (diffusers formulation) leaves the prediction unguided.
        neutral_scale = 0.0 if self.use_original_formulation else 1.0
        is_close = math.isclose(self.guidance_scale, neutral_scale)

        return is_within_range and not is_close
155
+
156
+
157
+ class MomentumBuffer:
158
+ def __init__(self, momentum: float):
159
+ self.momentum = momentum
160
+ self.running_average = 0
161
+
162
+ def update(self, update_value: torch.Tensor):
163
+ new_average = self.momentum * self.running_average
164
+ self.running_average = update_value + new_average
165
+
166
+ def __repr__(self) -> str:
167
+ """
168
+ Returns a string representation showing momentum, shape, statistics, and a slice of the running_average.
169
+ """
170
+ if isinstance(self.running_average, torch.Tensor):
171
+ shape = tuple(self.running_average.shape)
172
+
173
+ # Calculate statistics
174
+ with torch.no_grad():
175
+ stats = {
176
+ "mean": self.running_average.mean().item(),
177
+ "std": self.running_average.std().item(),
178
+ "min": self.running_average.min().item(),
179
+ "max": self.running_average.max().item(),
180
+ }
181
+
182
+ # Get a slice (max 3 elements per dimension)
183
+ slice_indices = tuple(slice(None, min(3, dim)) for dim in shape)
184
+ sliced_data = self.running_average[slice_indices]
185
+
186
+ # Format the slice for display (convert to float32 for numpy compatibility with bfloat16)
187
+ slice_str = str(sliced_data.detach().float().cpu().numpy())
188
+ if len(slice_str) > 200: # Truncate if too long
189
+ slice_str = slice_str[:200] + "..."
190
+
191
+ stats_str = ", ".join([f"{k}={v:.4f}" for k, v in stats.items()])
192
+
193
+ return (
194
+ f"MomentumBuffer(\n"
195
+ f" momentum={self.momentum},\n"
196
+ f" shape={shape},\n"
197
+ f" stats=[{stats_str}],\n"
198
+ f" slice={slice_str}\n"
199
+ f")"
200
+ )
201
+ else:
202
+ return f"MomentumBuffer(momentum={self.momentum}, running_average={self.running_average})"
203
+
204
+
205
+ def normalized_guidance(
206
+ pred_cond: torch.Tensor,
207
+ pred_uncond: torch.Tensor,
208
+ guidance_scale: float,
209
+ momentum_buffer: MomentumBuffer | None = None,
210
+ eta: float = 1.0,
211
+ norm_threshold: float = 0.0,
212
+ use_original_formulation: bool = False,
213
+ ):
214
+ diff = pred_cond - pred_uncond
215
+ dim = [-i for i in range(1, len(diff.shape))]
216
+
217
+ if momentum_buffer is not None:
218
+ momentum_buffer.update(diff)
219
+ diff = momentum_buffer.running_average
220
+
221
+ if norm_threshold > 0:
222
+ ones = torch.ones_like(diff)
223
+ diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
224
+ scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
225
+ diff = diff * scale_factor
226
+
227
+ v0, v1 = diff.double(), pred_cond.double()
228
+ v1 = torch.nn.functional.normalize(v1, dim=dim)
229
+ v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
230
+ v0_orthogonal = v0 - v0_parallel
231
+ diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
232
+ normalized_update = diff_orthogonal + eta * diff_parallel
233
+
234
+ pred = pred_cond if use_original_formulation else pred_uncond
235
+ pred = pred + guidance_scale * normalized_update
236
+
237
+ return pred
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/adaptive_projected_guidance_mix.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from typing import TYPE_CHECKING
17
+
18
+ import torch
19
+
20
+ from ..configuration_utils import register_to_config
21
+ from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
22
+
23
+
24
+ if TYPE_CHECKING:
25
+ from ..modular_pipelines.modular_pipeline import BlockState
26
+
27
+
28
class AdaptiveProjectedMixGuidance(BaseGuidance):
    """
    Adaptive Projected Guidance (APG) https://huggingface.co/papers/2410.02416 combined with Classifier-Free Guidance
    (CFG). This guider is used in HunyuanImage2.1 https://github.com/Tencent-Hunyuan/HunyuanImage-2.1

    Args:
        guidance_scale (`float`, defaults to `3.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
            prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
            deterioration of image quality.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions for classifier-free guidance. This is used to improve
            image quality and fix overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample
            Steps are Flawed](https://huggingface.co/papers/2305.08891).
        adaptive_projected_guidance_scale (`float`, defaults to `10.0`):
            The scale parameter used instead of `guidance_scale` on the steps where adaptive projected guidance is
            active.
        adaptive_projected_guidance_momentum (`float`, *optional*, defaults to `-0.5`):
            The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
        adaptive_projected_guidance_rescale (`float`, defaults to `10.0`):
            Upper bound on the L2 norm of the adaptive projected guidance update (the difference between the
            conditional and unconditional predictions, after momentum has been applied). The norm is taken over
            every dimension except the first; updates with a larger norm are scaled down to this value. Values
            `<= 0` disable the rescaling.
        eta (`float`, defaults to `0.0`):
            Weight applied to the component of the adaptive projected guidance update that is parallel to the
            conditional prediction. The orthogonal component is always applied in full.
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which the classifier-free guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which the classifier-free guidance stops.
        adaptive_projected_guidance_start_step (`int`, defaults to `5`):
            Adaptive projected guidance is applied on steps strictly greater than this value. On earlier steps
            (this one included) classifier-free guidance is used and the momentum buffer is updated.
        enabled (`bool`, defaults to `True`):
            Whether this guidance is enabled.
    """

    _input_predictions = ["pred_cond", "pred_uncond"]

    @register_to_config
    def __init__(
        self,
        guidance_scale: float = 3.5,
        guidance_rescale: float = 0.0,
        adaptive_projected_guidance_scale: float = 10.0,
        adaptive_projected_guidance_momentum: float | None = -0.5,
        adaptive_projected_guidance_rescale: float = 10.0,
        eta: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
        adaptive_projected_guidance_start_step: int = 5,
        enabled: bool = True,
    ):
        super().__init__(start, stop, enabled)

        self.guidance_scale = guidance_scale
        self.guidance_rescale = guidance_rescale
        self.adaptive_projected_guidance_scale = adaptive_projected_guidance_scale
        self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
        self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
        self.eta = eta
        self.adaptive_projected_guidance_start_step = adaptive_projected_guidance_start_step
        self.use_original_formulation = use_original_formulation
        # Created lazily at step 0 by the `prepare_inputs*` methods.
        self.momentum_buffer = None

    def _reset_momentum_buffer(self) -> None:
        """
        At the first step (`_step == 0`), start a fresh momentum buffer, or clear any stale one when momentum is
        disabled. Does nothing on later steps so the running average keeps accumulating.
        """
        if self._step != 0:
            return
        momentum = self.adaptive_projected_guidance_momentum
        self.momentum_buffer = MomentumBuffer(momentum) if momentum is not None else None

    def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
        """Build one batch per active condition (conditional only, or conditional + unconditional) from `data`."""
        self._reset_momentum_buffer()
        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        # `zip` truncates `_input_predictions` to the number of active conditions.
        return [
            self._prepare_batch(data, tuple_idx, input_prediction)
            for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions)
        ]

    def prepare_inputs_from_block_state(
        self, data: "BlockState", input_fields: dict[str, str | tuple[str, str]]
    ) -> list["BlockState"]:
        """Same as `prepare_inputs`, but reads the fields named in `input_fields` from a block state."""
        self._reset_momentum_buffer()
        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        return [
            self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
            for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions)
        ]

    def forward(self, pred_cond: torch.Tensor, pred_uncond: torch.Tensor | None = None) -> GuiderOutput:
        """
        Combine the conditional and unconditional predictions.

        Three regimes are possible: no guidance (the conditional prediction is returned), plain CFG while APG has
        not started yet, and APG afterwards. `guidance_rescale` is applied to the result when it is positive.
        """
        if not self._is_cfg_enabled():
            # no guidance
            pred = pred_cond
        elif not self._is_apg_enabled():
            # CFG; the momentum buffer is still updated so it has history by the time APG takes over
            if self.momentum_buffer is not None:
                update_momentum_buffer(pred_cond, pred_uncond, self.momentum_buffer)
            shift = pred_cond - pred_uncond
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift
        else:
            # APG
            pred = normalized_guidance(
                pred_cond,
                pred_uncond,
                self.adaptive_projected_guidance_scale,
                self.momentum_buffer,
                self.eta,
                # The APG "rescale" setting is the norm threshold of the guidance update.
                self.adaptive_projected_guidance_rescale,
                self.use_original_formulation,
            )

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

    @property
    def is_conditional(self) -> bool:
        return self._count_prepared == 1

    @property
    def num_conditions(self) -> int:
        """`2` while CFG or APG is active (conditional + unconditional pass), otherwise `1`."""
        num_conditions = 1
        if self._is_apg_enabled() or self._is_cfg_enabled():
            num_conditions += 1
        return num_conditions

    # Copied from diffusers.guiders.classifier_free_guidance.ClassifierFreeGuidance._is_cfg_enabled
    def _is_cfg_enabled(self) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close

    def _is_apg_enabled(self) -> bool:
        """
        Return whether APG (rather than plain CFG) should be applied at the current step. APG requires CFG to be
        enabled, the current step to be strictly greater than `adaptive_projected_guidance_start_step`, and
        `adaptive_projected_guidance_scale` not to be the neutral value of the selected formulation.
        """
        if not self._enabled:
            return False

        if not self._is_cfg_enabled():
            return False

        is_within_range = False
        if self._step is not None:
            is_within_range = self._step > self.adaptive_projected_guidance_start_step

        # A scale of 0.0 (original formulation) or 1.0 (diffusers formulation) leaves the prediction unguided.
        neutral_scale = 0.0 if self.use_original_formulation else 1.0
        is_close = math.isclose(self.adaptive_projected_guidance_scale, neutral_scale)

        return is_within_range and not is_close

    def get_state(self):
        """Extend the base state with the momentum buffer and the current CFG/APG activation flags."""
        state = super().get_state()
        state["momentum_buffer"] = self.momentum_buffer
        state["is_apg_enabled"] = self._is_apg_enabled()
        state["is_cfg_enabled"] = self._is_cfg_enabled()
        return state
203
+
204
+
205
+ # Copied from diffusers.guiders.adaptive_projected_guidance.MomentumBuffer
206
+ class MomentumBuffer:
207
+ def __init__(self, momentum: float):
208
+ self.momentum = momentum
209
+ self.running_average = 0
210
+
211
+ def update(self, update_value: torch.Tensor):
212
+ new_average = self.momentum * self.running_average
213
+ self.running_average = update_value + new_average
214
+
215
+ def __repr__(self) -> str:
216
+ """
217
+ Returns a string representation showing momentum, shape, statistics, and a slice of the running_average.
218
+ """
219
+ if isinstance(self.running_average, torch.Tensor):
220
+ shape = tuple(self.running_average.shape)
221
+
222
+ # Calculate statistics
223
+ with torch.no_grad():
224
+ stats = {
225
+ "mean": self.running_average.mean().item(),
226
+ "std": self.running_average.std().item(),
227
+ "min": self.running_average.min().item(),
228
+ "max": self.running_average.max().item(),
229
+ }
230
+
231
+ # Get a slice (max 3 elements per dimension)
232
+ slice_indices = tuple(slice(None, min(3, dim)) for dim in shape)
233
+ sliced_data = self.running_average[slice_indices]
234
+
235
+ # Format the slice for display (convert to float32 for numpy compatibility with bfloat16)
236
+ slice_str = str(sliced_data.detach().float().cpu().numpy())
237
+ if len(slice_str) > 200: # Truncate if too long
238
+ slice_str = slice_str[:200] + "..."
239
+
240
+ stats_str = ", ".join([f"{k}={v:.4f}" for k, v in stats.items()])
241
+
242
+ return (
243
+ f"MomentumBuffer(\n"
244
+ f" momentum={self.momentum},\n"
245
+ f" shape={shape},\n"
246
+ f" stats=[{stats_str}],\n"
247
+ f" slice={slice_str}\n"
248
+ f")"
249
+ )
250
+ else:
251
+ return f"MomentumBuffer(momentum={self.momentum}, running_average={self.running_average})"
252
+
253
+
254
+ def update_momentum_buffer(
255
+ pred_cond: torch.Tensor,
256
+ pred_uncond: torch.Tensor,
257
+ momentum_buffer: MomentumBuffer | None = None,
258
+ ):
259
+ diff = pred_cond - pred_uncond
260
+ if momentum_buffer is not None:
261
+ momentum_buffer.update(diff)
262
+
263
+
264
+ def normalized_guidance(
265
+ pred_cond: torch.Tensor,
266
+ pred_uncond: torch.Tensor,
267
+ guidance_scale: float,
268
+ momentum_buffer: MomentumBuffer | None = None,
269
+ eta: float = 1.0,
270
+ norm_threshold: float = 0.0,
271
+ use_original_formulation: bool = False,
272
+ ):
273
+ if momentum_buffer is not None:
274
+ update_momentum_buffer(pred_cond, pred_uncond, momentum_buffer)
275
+ diff = momentum_buffer.running_average
276
+ else:
277
+ diff = pred_cond - pred_uncond
278
+
279
+ dim = [-i for i in range(1, len(diff.shape))]
280
+
281
+ if norm_threshold > 0:
282
+ ones = torch.ones_like(diff)
283
+ diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
284
+ scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
285
+ diff = diff * scale_factor
286
+
287
+ v0, v1 = diff.double(), pred_cond.double()
288
+ v1 = torch.nn.functional.normalize(v1, dim=dim)
289
+ v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
290
+ v0_orthogonal = v0 - v0_parallel
291
+ diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
292
+ normalized_update = diff_orthogonal + eta * diff_parallel
293
+
294
+ pred = pred_cond if use_original_formulation else pred_uncond
295
+ pred = pred + guidance_scale * normalized_update
296
+
297
+ return pred
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/auto_guidance.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import math
18
+ from typing import TYPE_CHECKING, Any
19
+
20
+ import torch
21
+
22
+ from ..configuration_utils import register_to_config
23
+ from ..hooks import HookRegistry, LayerSkipConfig
24
+ from ..hooks.layer_skip import _apply_layer_skip_hook
25
+ from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
26
+
27
+
28
+ if TYPE_CHECKING:
29
+ from ..modular_pipelines.modular_pipeline import BlockState
30
+
31
+
32
class AutoGuidance(BaseGuidance):
    """
    AutoGuidance: https://huggingface.co/papers/2406.02507

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
            prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
            deterioration of image quality.
        auto_guidance_layers (`int` or `list[int]`, *optional*):
            The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
            provided, `auto_guidance_config` must be provided.
        auto_guidance_config (`LayerSkipConfig`, `list[LayerSkipConfig]` or `dict`, *optional*):
            The configuration for the skip layer guidance. Can be a single `LayerSkipConfig`, a list of
            `LayerSkipConfig`, or the dictionary form of either (converted with `LayerSkipConfig.from_dict`). If not
            provided, `auto_guidance_layers` must be provided.
        dropout (`float`, *optional*):
            The dropout probability for autoguidance on the layers given by `auto_guidance_layers`. Required when
            `auto_guidance_config` is not provided; a `ValueError` is raised otherwise.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
        enabled (`bool`, defaults to `True`):
            Whether this guidance is enabled.

    Raises:
        ValueError: If neither or both of `auto_guidance_layers` and `auto_guidance_config` are provided, if
            `dropout` is missing while `auto_guidance_config` is not provided, or if either argument has an
            unsupported type.
    """

    _input_predictions = ["pred_cond", "pred_uncond"]

    @register_to_config
    def __init__(
        self,
        guidance_scale: float = 7.5,
        auto_guidance_layers: int | list[int] | None = None,
        auto_guidance_config: LayerSkipConfig | list[LayerSkipConfig] | dict[str, Any] | list[dict[str, Any]] | None = None,
        dropout: float | None = None,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
        enabled: bool = True,
    ):
        super().__init__(start, stop, enabled)

        self.guidance_scale = guidance_scale
        self.auto_guidance_layers = auto_guidance_layers
        self.auto_guidance_config = auto_guidance_config
        self.dropout = dropout
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

        # Exactly one of `auto_guidance_layers` / `auto_guidance_config` must be given.
        is_layer_or_config_provided = auto_guidance_layers is not None or auto_guidance_config is not None
        is_layer_and_config_provided = auto_guidance_layers is not None and auto_guidance_config is not None
        if not is_layer_or_config_provided:
            raise ValueError(
                "Either `auto_guidance_layers` or `auto_guidance_config` must be provided to enable AutoGuidance."
            )
        if is_layer_and_config_provided:
            raise ValueError("Only one of `auto_guidance_layers` or `auto_guidance_config` can be provided.")
        if auto_guidance_config is None and dropout is None:
            raise ValueError("`dropout` must be provided if `auto_guidance_layers` is provided.")

        # Layer indices are expanded into one `LayerSkipConfig` per layer, all sharing `dropout`.
        if auto_guidance_layers is not None:
            if isinstance(auto_guidance_layers, int):
                auto_guidance_layers = [auto_guidance_layers]
            if not isinstance(auto_guidance_layers, list):
                raise ValueError(
                    f"Expected `auto_guidance_layers` to be an int or a list of ints, but got {type(auto_guidance_layers)}."
                )
            auto_guidance_config = [
                LayerSkipConfig(layer, fqn="auto", dropout=dropout) for layer in auto_guidance_layers
            ]

        # Normalise every accepted config form to a list of `LayerSkipConfig`.
        if isinstance(auto_guidance_config, dict):
            auto_guidance_config = LayerSkipConfig.from_dict(auto_guidance_config)

        if isinstance(auto_guidance_config, LayerSkipConfig):
            auto_guidance_config = [auto_guidance_config]

        if not isinstance(auto_guidance_config, list):
            raise ValueError(
                f"Expected `auto_guidance_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(auto_guidance_config)}."
            )
        elif isinstance(next(iter(auto_guidance_config), None), dict):
            # Only the first element is inspected to detect a list of dictionaries.
            auto_guidance_config = [LayerSkipConfig.from_dict(config) for config in auto_guidance_config]

        self.auto_guidance_config = auto_guidance_config
        # One hook name per config, used to register and later remove the layer-skip hooks.
        self._auto_guidance_hook_names = [f"AutoGuidance_{i}" for i in range(len(self.auto_guidance_config))]

    def prepare_models(self, denoiser: torch.nn.Module) -> None:
        """Count the preparation and, for the unconditional pass with guidance active, attach the layer-skip hooks."""
        self._count_prepared += 1
        if self._is_ag_enabled() and self.is_unconditional:
            for name, config in zip(self._auto_guidance_hook_names, self.auto_guidance_config):
                _apply_layer_skip_hook(denoiser, config, name=name)

    def cleanup_models(self, denoiser: torch.nn.Module) -> None:
        """Remove the layer-skip hooks, under the same condition that `prepare_models` uses to attach them."""
        if self._is_ag_enabled() and self.is_unconditional:
            for name in self._auto_guidance_hook_names:
                registry = HookRegistry.check_if_exists_or_initialize(denoiser)
                registry.remove_hook(name, recurse=True)

    def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
        """Build one batch per active condition (conditional only, or conditional + unconditional) from `data`."""
        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        # `zip` truncates `_input_predictions` to the number of active conditions.
        return [
            self._prepare_batch(data, tuple_idx, input_prediction)
            for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions)
        ]

    def prepare_inputs_from_block_state(
        self, data: "BlockState", input_fields: dict[str, str | tuple[str, str]]
    ) -> list["BlockState"]:
        """Same as `prepare_inputs`, but reads the fields named in `input_fields` from a block state."""
        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        return [
            self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
            for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions)
        ]

    def forward(self, pred_cond: torch.Tensor, pred_uncond: torch.Tensor | None = None) -> GuiderOutput:
        """
        Combine the two predictions with the classifier-free guidance formula.

        When guidance is not active the conditional prediction is returned unchanged. `guidance_rescale` is
        applied to the result when it is positive.
        """
        if not self._is_ag_enabled():
            pred = pred_cond
        else:
            shift = pred_cond - pred_uncond
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

    @property
    def is_conditional(self) -> bool:
        return self._count_prepared == 1

    @property
    def num_conditions(self) -> int:
        """`2` while guidance is active (conditional + unconditional pass), otherwise `1`."""
        num_conditions = 1
        if self._is_ag_enabled():
            num_conditions += 1
        return num_conditions

    def _is_ag_enabled(self) -> bool:
        """
        Return whether guidance should be applied at the current step: the guider is enabled, the step lies inside
        the `[start, stop)` window (when the number of inference steps is known), and `guidance_scale` is not the
        neutral value of the selected formulation.
        """
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        # A scale of 0.0 (original formulation) or 1.0 (diffusers formulation) leaves the prediction unguided.
        neutral_scale = 0.0 if self.use_original_formulation else 1.0
        is_close = math.isclose(self.guidance_scale, neutral_scale)

        return is_within_range and not is_close
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/classifier_free_guidance.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import math
18
+ from typing import TYPE_CHECKING
19
+
20
+ import torch
21
+
22
+ from ..configuration_utils import register_to_config
23
+ from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
24
+
25
+
26
+ if TYPE_CHECKING:
27
+ from ..modular_pipelines.modular_pipeline import BlockState
28
+
29
+
30
class ClassifierFreeGuidance(BaseGuidance):
    """
    Classifier-Free Guidance (CFG) guider for diffusion models.

    Reference: https://huggingface.co/papers/2207.12598

    CFG trains a model jointly on conditional and unconditional data and blends the two predictions at inference
    time, trading off quality (high guidance) against diversity (low guidance).

    **Two CFG formulations are supported:**

    1. **Original formulation** (from the paper):
       ```
       x_pred = x_cond + guidance_scale * (x_cond - x_uncond)
       ```
       Pushes the conditional prediction further away from the unconditional one.

    2. **Diffusers-native formulation** (default, from the Imagen paper):
       ```
       x_pred = x_uncond + guidance_scale * (x_cond - x_uncond)
       ```
       Pulls the unconditional prediction toward the conditional one, effectively suppressing negative features
       (e.g. "bad quality", "watermarks"). Equivalent in theory but more intuitive.

    Pass `use_original_formulation=True` to select the original formulation.

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            CFG scale applied by this guider during post-processing. Larger values condition more strongly on the
            prompt but may reduce quality. Typical range: 1.0-20.0.
        guidance_rescale (`float`, defaults to `0.0`):
            Rescaling factor that counteracts overexposure at high guidance scales. Based on [Common Diffusion
            Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). Range: 0.0 (no
            rescaling) to 1.0 (full rescaling).
        use_original_formulation (`bool`, defaults to `False`):
            If `True`, use the original CFG formulation from the paper; otherwise use the diffusers-native
            formulation from the Imagen paper.
        start (`float`, defaults to `0.0`):
            Fraction of denoising steps (0.0-1.0) after which CFG starts. Use a value > 0.0 to skip CFG during
            early denoising steps.
        stop (`float`, defaults to `1.0`):
            Fraction of denoising steps (0.0-1.0) after which CFG stops. Use a value < 1.0 to skip CFG during late
            denoising steps.
        enabled (`bool`, defaults to `True`):
            Whether CFG is enabled. When `False`, only the conditional prediction is used.
    """

    _input_predictions = ["pred_cond", "pred_uncond"]

    @register_to_config
    def __init__(
        self,
        guidance_scale: float = 7.5,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
        enabled: bool = True,
    ):
        super().__init__(start, stop, enabled)

        self.guidance_scale = guidance_scale
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

    def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
        """Build one batch per required prediction (conditional first) from `data`."""
        # Only tuple index 0 is consumed while CFG is inactive.
        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        return [
            self._prepare_batch(data, tuple_idx, prediction_name)
            for tuple_idx, prediction_name in zip(tuple_indices, self._input_predictions)
        ]

    def prepare_inputs_from_block_state(
        self, data: "BlockState", input_fields: dict[str, str | tuple[str, str]]
    ) -> list["BlockState"]:
        """Build one batch per required prediction (conditional first) from a block state."""
        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        return [
            self._prepare_batch_from_block_state(input_fields, data, tuple_idx, prediction_name)
            for tuple_idx, prediction_name in zip(tuple_indices, self._input_predictions)
        ]

    def forward(self, pred_cond: torch.Tensor, pred_uncond: torch.Tensor | None = None) -> GuiderOutput:
        """
        Blend the conditional and unconditional predictions.

        `pred_uncond` is only read while CFG is active; otherwise `pred_cond` is passed through (and still
        rescaled when `guidance_rescale > 0`).
        """
        if self._is_cfg_enabled():
            delta = pred_cond - pred_uncond
            if self.use_original_formulation:
                anchor = pred_cond
            else:
                anchor = pred_uncond
            guided = anchor + self.guidance_scale * delta
        else:
            guided = pred_cond

        rescale = self.guidance_rescale
        if rescale > 0.0:
            guided = rescale_noise_cfg(guided, pred_cond, rescale)

        return GuiderOutput(pred=guided, pred_cond=pred_cond, pred_uncond=pred_uncond)

    @property
    def is_conditional(self) -> bool:
        """Whether exactly one batch has been prepared so far (`_count_prepared == 1`)."""
        prepared_batches = self._count_prepared
        return prepared_batches == 1

    @property
    def num_conditions(self) -> int:
        """Number of predictions needed this step: 2 while CFG is active, otherwise 1."""
        return 2 if self._is_cfg_enabled() else 1

    def _is_cfg_enabled(self) -> bool:
        """
        Decide whether CFG should be applied at the current step.

        CFG is active only when the guider is enabled, the current step lies in the half-open window
        `[int(start * num_steps), int(stop * num_steps))` (skipped when the step count is unknown), and the
        guidance scale is not at its neutral value (`0.0` for the original formulation, `1.0` otherwise).
        """
        if not self._enabled:
            return False

        total_steps = self._num_inference_steps
        if total_steps is None:
            in_window = True
        else:
            first_step = int(self._start * total_steps)
            end_step = int(self._stop * total_steps)
            in_window = first_step <= self._step < end_step

        neutral_scale = 0.0 if self.use_original_formulation else 1.0
        scale_is_neutral = math.isclose(self.guidance_scale, neutral_scale)

        return in_window and not scale_is_neutral
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/classifier_free_zero_star_guidance.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import math
18
+ from typing import TYPE_CHECKING
19
+
20
+ import torch
21
+
22
+ from ..configuration_utils import register_to_config
23
+ from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
24
+
25
+
26
+ if TYPE_CHECKING:
27
+ from ..modular_pipelines.modular_pipeline import BlockState
28
+
29
+
30
class ClassifierFreeZeroStarGuidance(BaseGuidance):
    """
    Classifier-free Zero* (CFG-Zero*): https://huggingface.co/papers/2503.18886

    This is an implementation of the Classifier-Free Zero* guidance technique, which is a variant of classifier-free
    guidance. It proposes zero initialization of the noise predictions for the first few steps of the diffusion
    process, and also introduces an optimal rescaling factor for the noise predictions, which can help in improving the
    quality of generated images.

    The authors of the paper suggest setting zero initialization in the first 4% of the inference steps.

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
            prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
            deterioration of image quality.
        zero_init_steps (`int`, defaults to `1`):
            The number of inference steps for which the noise predictions are zeroed out (see Section 4.2).
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
        enabled (`bool`, defaults to `True`):
            Whether this guider is enabled. When `False`, `forward` passes the conditional prediction through and
            the zero-initialization steps are skipped as well.
    """

    _input_predictions = ["pred_cond", "pred_uncond"]

    @register_to_config
    def __init__(
        self,
        guidance_scale: float = 7.5,
        zero_init_steps: int = 1,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
        enabled: bool = True,
    ):
        super().__init__(start, stop, enabled)

        self.guidance_scale = guidance_scale
        self.zero_init_steps = zero_init_steps
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

    def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
        """Build one batch per required prediction (conditional first) from `data`."""
        # Only tuple index 0 is consumed while guidance is inactive.
        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
            data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
            data_batches.append(data_batch)
        return data_batches

    def prepare_inputs_from_block_state(
        self, data: "BlockState", input_fields: dict[str, str | tuple[str, str]]
    ) -> list["BlockState"]:
        """Build one batch per required prediction (conditional first) from a block state."""
        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
            data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
            data_batches.append(data_batch)
        return data_batches

    def forward(self, pred_cond: torch.Tensor, pred_uncond: torch.Tensor | None = None) -> GuiderOutput:
        """
        Produce the CFG-Zero* guided prediction.

        Branches are evaluated in this order: guider disabled -> pass `pred_cond` through; step is within the first
        `zero_init_steps` -> all-zeros prediction; CFG inactive for this step -> pass `pred_cond` through;
        otherwise scale `pred_uncond` by the per-sample factor from `cfg_zero_star_scale` and apply the CFG update.
        """
        pred = None

        # A disabled guider must not zero out the early steps, so this check comes first.
        if not self._enabled:
            pred = pred_cond

        elif self._step < self.zero_init_steps:
            pred = torch.zeros_like(pred_cond)
        elif not self._is_cfg_enabled():
            pred = pred_cond
        else:
            # One scalar per sample, computed over all non-batch dimensions.
            pred_cond_flat = pred_cond.flatten(1)
            pred_uncond_flat = pred_uncond.flatten(1)
            alpha = cfg_zero_star_scale(pred_cond_flat, pred_uncond_flat)
            # Reshape to (B, 1, ..., 1) so the factor broadcasts against the prediction.
            alpha = alpha.view(-1, *(1,) * (len(pred_cond.shape) - 1))
            pred_uncond = pred_uncond * alpha
            shift = pred_cond - pred_uncond
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        # NOTE: in the guided branch the returned `pred_uncond` is the alpha-scaled tensor, not the raw input.
        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

    @property
    def is_conditional(self) -> bool:
        """Whether exactly one batch has been prepared so far (`_count_prepared == 1`)."""
        return self._count_prepared == 1

    @property
    def num_conditions(self) -> int:
        """Number of predictions needed this step: 2 while CFG is active, otherwise 1."""
        num_conditions = 1
        if self._is_cfg_enabled():
            num_conditions += 1
        return num_conditions

    def _is_cfg_enabled(self) -> bool:
        """
        Decide whether CFG should be applied at the current step.

        CFG is active only when the guider is enabled, the current step lies in the half-open window
        `[int(start * num_steps), int(stop * num_steps))` (skipped when the step count is unknown), and the
        guidance scale is not at its neutral value (`0.0` for the original formulation, `1.0` otherwise).
        """
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close
154
+
155
+
156
def cfg_zero_star_scale(cond: torch.Tensor, uncond: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    Compute the per-row CFG-Zero* scaling factor `<cond, uncond> / (||uncond||^2 + eps)`.

    Both inputs are reduced over dim 1 (callers pass flattened `[B, N]` tensors). The arithmetic runs in float32
    and the result, shaped with a kept dim 1, is cast back to `cond`'s dtype.
    """
    out_dtype = cond.dtype
    cond_f32, uncond_f32 = cond.float(), uncond.float()
    # st_star = v_cond^T * v_uncond / ||v_uncond||^2
    numerator = (cond_f32 * uncond_f32).sum(dim=1, keepdim=True)
    denominator = uncond_f32.pow(2).sum(dim=1, keepdim=True) + eps
    return (numerator / denominator).to(dtype=out_dtype)
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/frequency_decoupled_guidance.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import math
18
+ from typing import TYPE_CHECKING
19
+
20
+ import torch
21
+
22
+ from ..configuration_utils import register_to_config
23
+ from ..utils import is_kornia_available
24
+ from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
25
+
26
+
27
+ if TYPE_CHECKING:
28
+ from ..modular_pipelines.modular_pipeline import BlockState
29
+
30
+
31
+ _CAN_USE_KORNIA = is_kornia_available()
32
+
33
+
34
+ if _CAN_USE_KORNIA:
35
+ from kornia.geometry import pyrup as upsample_and_blur_func
36
+ from kornia.geometry.transform import build_laplacian_pyramid as build_laplacian_pyramid_func
37
+ else:
38
+ upsample_and_blur_func = None
39
+ build_laplacian_pyramid_func = None
40
+
41
+
42
def project(v0: torch.Tensor, v1: torch.Tensor, upcast_to_double: bool = True) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Split `v0` into its components parallel and orthogonal to `v1` (paper Algorithm 2).

    The first dimension is treated as the batch dimension; the projection is taken over all remaining dimensions.
    With `upcast_to_double=True` the math runs in float64 and both results are cast back to `v0`'s original dtype.

    Returns:
        `tuple[torch.Tensor, torch.Tensor]`: `(v0_parallel, v0_orthogonal)`.
    """
    # v0 shape: [B, ...]; v1 shape: [B, ...]
    reduce_dims = list(range(1, v0.dim()))
    original_dtype = v0.dtype

    if upcast_to_double:
        v0 = v0.double()
        v1 = v1.double()

    unit_v1 = torch.nn.functional.normalize(v1, dim=reduce_dims)
    coefficient = (v0 * unit_v1).sum(dim=reduce_dims, keepdim=True)
    parallel = coefficient * unit_v1
    orthogonal = v0 - parallel

    if upcast_to_double:
        parallel = parallel.to(original_dtype)
        orthogonal = orthogonal.to(original_dtype)

    return parallel, orthogonal
61
+
62
+
63
def build_image_from_pyramid(pyramid: list[torch.Tensor]) -> torch.Tensor:
    """
    Collapse a Laplacian pyramid back into data space (paper Algorithm 2).

    Starting from the last pyramid entry, each step upsamples-and-blurs the running image and adds the next entry,
    walking towards index 0.
    """
    # pyramid shapes: [[B, C, H, W], [B, C, H/2, W/2], ...]
    img = pyramid[-1]
    for finer_level in reversed(pyramid[:-1]):
        img = upsample_and_blur_func(img) + finer_level
    return img
73
+
74
+
75
class FrequencyDecoupledGuidance(BaseGuidance):
    """
    Frequency-Decoupled Guidance (FDG): https://huggingface.co/papers/2506.19713

    FDG is a technique similar to (and based on) classifier-free guidance (CFG) which is used to improve generation
    quality and condition-following in diffusion models. Like CFG, during training we jointly train the model on both
    conditional and unconditional data, and use a combination of the two during inference. (If you want more details on
    how CFG works, you can check out the CFG guider.)

    FDG differs from CFG in that the normal CFG prediction is instead decoupled into low- and high-frequency components
    using a frequency transform (such as a Laplacian pyramid). The CFG update is then performed in frequency space
    separately for the low- and high-frequency components with different guidance scales. Finally, the inverse
    frequency transform is used to map the CFG frequency predictions back to data space (e.g. pixel space for images)
    to form the final FDG prediction.

    For images, the FDG authors found that using low guidance scales for the low-frequency components retains sample
    diversity and realistic color composition, while using high guidance scales for high-frequency components enhances
    sample quality (such as better visual details). Therefore, they recommend using low guidance scales (low w_low) for
    the low-frequency components and high guidance scales (high w_high) for the high-frequency components. As an
    example, they suggest w_low = 5.0 and w_high = 10.0 for Stable Diffusion XL (see Table 8 in the paper).

    As with CFG, Diffusers implements the scaling and shifting on the unconditional prediction based on the [Imagen
    paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original CFG paper proposed in
    theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]

    The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
    paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.

    Args:
        guidance_scales (`list[float]`, defaults to `[10.0, 5.0]`):
            The scale parameter for frequency-decoupled guidance for each frequency component, listed from highest
            frequency level to lowest. Higher values result in stronger conditioning on the text prompt, while lower
            values allow for more freedom in generation. Higher values may lead to saturation and deterioration of
            image quality. The FDG authors recommend using higher guidance scales for higher frequency components and
            lower guidance scales for lower frequency components (so `guidance_scales` should typically be sorted in
            descending order).
        guidance_rescale (`float` or `list[float]`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891). If a list is supplied, it should be the same length as
            `guidance_scales`.
        parallel_weights (`float` or `list[float]`, *optional*):
            Optional weights for the parallel component of each frequency component of the projected CFG shift. If not
            set, the weights will default to `1.0` for all components, which corresponds to using the normal CFG shift
            (that is, equal weights for the parallel and orthogonal components). If set, a value in `[0, 1]` is
            recommended. If a list is supplied, it should be the same length as `guidance_scales`.
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float` or `list[float]`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts. If a list is supplied, it
            should be the same length as `guidance_scales`.
        stop (`float` or `list[float]`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops. If a list is supplied, it
            should be the same length as `guidance_scales`.
        guidance_rescale_space (`str`, defaults to `"data"`):
            Whether to perform guidance rescaling in `"data"` space (after the full FDG update in data space) or in
            `"freq"` space (right after the CFG update, for each freq level). Note that frequency space rescaling is
            speculative and may not produce expected results. If `"data"` is set, the first `guidance_rescale` value
            will be used; otherwise, per-frequency-level guidance rescale values will be used if available.
        upcast_to_double (`bool`, defaults to `True`):
            Whether to upcast certain operations, such as the projection operation when using `parallel_weights`, to
            float64 when performing guidance. This may result in better performance at the cost of increased runtime.
        enabled (`bool`, defaults to `True`):
            Whether FDG is enabled. When `False`, `forward` passes the conditional prediction through.

    Raises:
        ImportError: If `kornia` is not installed.
        ValueError: If a per-level argument has a length different from `guidance_scales`, or if
            `guidance_rescale_space` is not `"data"` or `"freq"`.
    """

    _input_predictions = ["pred_cond", "pred_uncond"]

    @register_to_config
    def __init__(
        self,
        guidance_scales: list[float] | tuple[float] = [10.0, 5.0],
        guidance_rescale: float | list[float] | tuple[float] = 0.0,
        parallel_weights: float | list[float] | tuple[float] | None = None,
        use_original_formulation: bool = False,
        start: float | list[float] | tuple[float] = 0.0,
        stop: float | list[float] | tuple[float] = 1.0,
        guidance_rescale_space: str = "data",
        upcast_to_double: bool = True,
        enabled: bool = True,
    ):
        if not _CAN_USE_KORNIA:
            raise ImportError(
                "The `FrequencyDecoupledGuidance` guider cannot be instantiated because the `kornia` library on which "
                "it depends is not available in the current environment. You can install `kornia` with `pip install "
                "kornia`."
            )

        # Set start to earliest start for any freq component and stop to latest stop for any freq component.
        # Integer scalars (e.g. `start=0`) are accepted alongside floats instead of being passed to `min`/`max`.
        min_start = start if isinstance(start, (int, float)) else min(start)
        max_stop = stop if isinstance(stop, (int, float)) else max(stop)
        super().__init__(min_start, max_stop, enabled)

        # NOTE: the default list object is shared between instances; treat `guidance_scales` as read-only.
        self.guidance_scales = guidance_scales
        self.levels = len(guidance_scales)

        self.guidance_rescale = self._expand_per_level("guidance_rescale", guidance_rescale)

        # Whether to perform guidance rescaling in frequency space (right after the CFG update) or data space (after
        # transforming from frequency space back to data space)
        if guidance_rescale_space not in ["data", "freq"]:
            raise ValueError(
                f"Guidance rescale space is {guidance_rescale_space} but must be one of `data` or `freq`."
            )
        self.guidance_rescale_space = guidance_rescale_space

        # `None` means the normal CFG shift (equal weights for parallel and orthogonal components).
        self.parallel_weights = self._expand_per_level(
            "parallel_weights", 1.0 if parallel_weights is None else parallel_weights
        )

        self.use_original_formulation = use_original_formulation
        self.upcast_to_double = upcast_to_double

        self.guidance_start = self._expand_per_level("start", start)
        self.guidance_stop = self._expand_per_level("stop", stop)

    def _expand_per_level(self, name: str, value: float | list[float] | tuple[float]) -> list[float] | tuple[float]:
        """Broadcast a scalar to one value per frequency level, or validate that a sequence has `levels` entries."""
        if isinstance(value, (int, float)):
            return [value] * self.levels
        if len(value) == self.levels:
            return value
        raise ValueError(
            f"`{name}` has length {len(value)} but should have the same length as "
            f"`guidance_scales` ({len(self.guidance_scales)})"
        )

    def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
        """Build one batch per required prediction (conditional first) from `data`."""
        # Only tuple index 0 is consumed while FDG is inactive.
        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
            data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
            data_batches.append(data_batch)
        return data_batches

    def prepare_inputs_from_block_state(
        self, data: "BlockState", input_fields: dict[str, str | tuple[str, str]]
    ) -> list["BlockState"]:
        """Build one batch per required prediction (conditional first) from a block state."""
        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
            data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
            data_batches.append(data_batch)
        return data_batches

    def forward(self, pred_cond: torch.Tensor, pred_uncond: torch.Tensor | None = None) -> GuiderOutput:
        """
        Produce the FDG prediction.

        While FDG is active, both predictions are decomposed into `levels` frequency bands, each band gets its own
        CFG update (or keeps the conditional band when guidance is inactive for that level), and the bands are
        recombined into data space. While FDG is inactive, `pred_cond` is passed through.
        """
        pred = None

        if not self._is_fdg_enabled():
            pred = pred_cond
        else:
            # Apply the frequency transform (e.g. Laplacian pyramid) to the conditional and unconditional predictions.
            pred_cond_pyramid = build_laplacian_pyramid_func(pred_cond, self.levels)
            pred_uncond_pyramid = build_laplacian_pyramid_func(pred_uncond, self.levels)

            # From high frequencies to low frequencies, following the paper implementation
            pred_guided_pyramid = []
            parameters = zip(self.guidance_scales, self.parallel_weights, self.guidance_rescale)
            for level, (guidance_scale, parallel_weight, guidance_rescale) in enumerate(parameters):
                # Look up the conditional band before branching: the non-guided path needs it too, and it must be
                # this level's band (bands differ in resolution), never a leftover from a previous iteration.
                pred_cond_freq = pred_cond_pyramid[level]

                if not self._is_fdg_enabled_for_level(level):
                    # Add the current pred_cond_pyramid level as the "non-FDG" prediction
                    pred_guided_pyramid.append(pred_cond_freq)
                    continue

                pred_uncond_freq = pred_uncond_pyramid[level]
                shift = pred_cond_freq - pred_uncond_freq

                # Apply parallel weights, if used (1.0 corresponds to using the normal CFG shift)
                if not math.isclose(parallel_weight, 1.0):
                    shift_parallel, shift_orthogonal = project(shift, pred_cond_freq, self.upcast_to_double)
                    shift = parallel_weight * shift_parallel + shift_orthogonal

                # Apply CFG update for the current frequency level
                pred_freq = pred_cond_freq if self.use_original_formulation else pred_uncond_freq
                pred_freq = pred_freq + guidance_scale * shift

                if self.guidance_rescale_space == "freq" and guidance_rescale > 0.0:
                    pred_freq = rescale_noise_cfg(pred_freq, pred_cond_freq, guidance_rescale)

                # Add the current FDG guided level to the FDG prediction pyramid
                pred_guided_pyramid.append(pred_freq)

            # Convert from frequency space back to data (e.g. pixel) space by applying inverse freq transform
            pred = build_image_from_pyramid(pred_guided_pyramid)

        # If rescaling in data space, use the first elem of self.guidance_rescale as the "global" rescale value
        # across all freq levels
        if self.guidance_rescale_space == "data" and self.guidance_rescale[0] > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale[0])

        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

    @property
    def is_conditional(self) -> bool:
        """Whether exactly one batch has been prepared so far (`_count_prepared == 1`)."""
        return self._count_prepared == 1

    @property
    def num_conditions(self) -> int:
        """Number of predictions needed this step: 2 while FDG is active, otherwise 1."""
        num_conditions = 1
        if self._is_fdg_enabled():
            num_conditions += 1
        return num_conditions

    def _is_fdg_enabled(self) -> bool:
        """
        Decide whether FDG is active for the current step at all.

        Uses the overall window (earliest start, latest stop) and treats the scales as neutral only when every
        level's guidance scale is neutral (`0.0` for the original formulation, `1.0` otherwise).
        """
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = all(math.isclose(guidance_scale, 0.0) for guidance_scale in self.guidance_scales)
        else:
            is_close = all(math.isclose(guidance_scale, 1.0) for guidance_scale in self.guidance_scales)

        return is_within_range and not is_close

    def _is_fdg_enabled_for_level(self, level: int) -> bool:
        """Decide whether guidance is active for frequency `level`, using that level's window and scale."""
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self.guidance_start[level] * self._num_inference_steps)
            skip_stop_step = int(self.guidance_stop[level] * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scales[level], 0.0)
        else:
            is_close = math.isclose(self.guidance_scales[level], 1.0)

        return is_within_range and not is_close
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/guider_utils.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import os
18
+ from typing import TYPE_CHECKING, Any
19
+
20
+ import torch
21
+ from huggingface_hub.utils import validate_hf_hub_args
22
+ from typing_extensions import Self
23
+
24
+ from ..configuration_utils import ConfigMixin
25
+ from ..utils import BaseOutput, PushToHubMixin, get_logger
26
+
27
+
28
+ if TYPE_CHECKING:
29
+ from ..modular_pipelines.modular_pipeline import BlockState
30
+
31
+
32
+ GUIDER_CONFIG_NAME = "guider_config.json"
33
+
34
+
35
+ logger = get_logger(__name__) # pylint: disable=invalid-name
36
+
37
+
38
class BaseGuidance(ConfigMixin, PushToHubMixin):
    r"""
    Base class providing the skeleton for implementing guidance techniques.

    Subclasses must set the class attribute `_input_predictions` to a list of prediction names and implement
    `prepare_inputs`, `prepare_inputs_from_block_state`, `forward`, `is_conditional` and `num_conditions`.

    Args:
        start (`float`, defaults to `0.0`):
            Fraction of the denoising steps stored as `_start`. Must satisfy `0.0 <= start < 1.0`.
        stop (`float`, defaults to `1.0`):
            Fraction of the denoising steps stored as `_stop`. Must satisfy `start <= stop <= 1.0`.
        enabled (`bool`, defaults to `True`):
            Initial value of the enabled flag toggled by `enable()` / `disable()`.
    """

    config_name = GUIDER_CONFIG_NAME
    # Subclasses override this with the list of prediction names they consume.
    _input_predictions = None
    # Attribute name under which each prepared batch carries the name of the prediction it produces.
    _identifier_key = "__guidance_identifier__"

    def __init__(self, start: float = 0.0, stop: float = 1.0, enabled: bool = True):
        logger.warning(
            "Guiders are currently an experimental feature under active development. The API is subject to breaking changes in future releases."
        )

        self._start = start
        self._stop = stop
        # Per-step state; populated by `set_state` before each denoising step.
        self._step: int | None = None
        self._num_inference_steps: int | None = None
        self._timestep: torch.LongTensor | None = None
        self._count_prepared = 0
        self._input_fields: dict[str, str | tuple[str, str]] | None = None
        self._enabled = enabled

        if not (0.0 <= start < 1.0):
            raise ValueError(f"Expected `start` to be between 0.0 and 1.0, but got {start}.")
        if not (start <= stop <= 1.0):
            raise ValueError(f"Expected `stop` to be between {start} and 1.0, but got {stop}.")

        if self._input_predictions is None or not isinstance(self._input_predictions, list):
            raise ValueError(
                "`_input_predictions` must be a list of required prediction names for the guidance technique."
            )

    def new(self, **kwargs):
        """
        Creates a copy of this guider instance, optionally with modified configuration parameters.

        Args:
            **kwargs: Configuration parameters to override in the new instance. If no kwargs are provided,
                returns an exact copy with the same configuration.

        Returns:
            A new guider instance with the same (or updated) configuration.

        Example:
            ```python
            # Create a CFG guider
            guider = ClassifierFreeGuidance(guidance_scale=3.5)

            # Create an exact copy
            same_guider = guider.new()

            # Create a copy with a different guidance scale, keeping other config the same
            new_guider = guider.new(guidance_scale=5)
            ```
        """
        return self.__class__.from_config(self.config, **kwargs)

    def disable(self):
        """Turn the guider off by clearing the enabled flag."""
        self._enabled = False

    def enable(self):
        """Turn the guider on by setting the enabled flag."""
        self._enabled = True

    def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
        """
        Record the current position in the denoising loop and reset the prepared-model counter.

        Args:
            step (`int`): Current inference step.
            num_inference_steps (`int`): Total number of inference steps.
            timestep (`torch.LongTensor`): Current timestep tensor.
        """
        self._step = step
        self._num_inference_steps = num_inference_steps
        self._timestep = timestep
        self._count_prepared = 0

    def get_state(self) -> dict[str, Any]:
        """
        Returns the current state of the guidance technique as a dictionary. The state variables will be included in
        the `__repr__` method.

        Returns:
            `dict[str, Any]`: A dictionary containing the current state variables including:
                - step: Current inference step
                - num_inference_steps: Total number of inference steps
                - timestep: Current timestep tensor
                - count_prepared: Number of times prepare_models has been called
                - enabled: Whether the guidance is enabled
                - num_conditions: Number of conditions
        """
        state = {
            "step": self._step,
            "num_inference_steps": self._num_inference_steps,
            "timestep": self._timestep,
            "count_prepared": self._count_prepared,
            "enabled": self._enabled,
            "num_conditions": self.num_conditions,
        }
        return state

    def __repr__(self) -> str:
        """
        Returns a string representation of the guidance object including both config and current state.
        """
        # Get ConfigMixin's __repr__
        str_repr = super().__repr__()

        # Get current state
        state = self.get_state()

        # Format each state variable on its own line with indentation
        state_lines = []
        for k, v in state.items():
            # Convert value to string and handle multi-line values
            v_str = str(v)
            if "\n" in v_str:
                # For multi-line values (like MomentumBuffer), indent subsequent lines
                v_lines = v_str.split("\n")
                v_str = v_lines[0] + "\n" + "\n".join([" " + line for line in v_lines[1:]])
            state_lines.append(f" {k}: {v_str}")

        state_str = "\n".join(state_lines)

        return f"{str_repr}\nState:\n{state_str}"

    def prepare_models(self, denoiser: torch.nn.Module) -> None:
        """
        Prepares the models for the guidance technique on a given batch of data. This method should be overridden in
        subclasses to implement specific model preparation logic.

        The base implementation only increments the prepared-model counter and does not touch `denoiser`.
        """
        self._count_prepared += 1

    def cleanup_models(self, denoiser: torch.nn.Module) -> None:
        """
        Cleans up the models for the guidance technique after a given batch of data. This method should be overridden
        in subclasses to implement specific model cleanup logic. It is useful for removing any hooks or other stateful
        modifications made during `prepare_models`.
        """
        pass

    def prepare_inputs(self, data: "BlockState") -> list["BlockState"]:
        """Build one input batch per required prediction. Must be implemented in subclasses."""
        raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")

    def prepare_inputs_from_block_state(
        self, data: "BlockState", input_fields: dict[str, str | tuple[str, str]]
    ) -> list["BlockState"]:
        """Build one input batch per required prediction from a block state. Must be implemented in subclasses."""
        raise NotImplementedError("BaseGuidance::prepare_inputs_from_block_state must be implemented in subclasses.")

    def __call__(self, data: list["BlockState"]) -> Any:
        """
        Collect the `noise_pred` of every batch and dispatch them to `forward` as keyword arguments.

        Args:
            data (`list[BlockState]`):
                One batch per condition. Each must expose `noise_pred` and the identifier attribute written by the
                `_prepare_batch*` helpers; the identifier becomes the keyword name passed to `forward`.

        Returns:
            Whatever `forward` returns.

        Raises:
            ValueError: If a batch lacks `noise_pred`, or the number of batches differs from `num_conditions`.
        """
        if not all(hasattr(d, "noise_pred") for d in data):
            raise ValueError("Expected all data to have `noise_pred` attribute.")
        if len(data) != self.num_conditions:
            raise ValueError(
                f"Expected {self.num_conditions} data items, but got {len(data)}. Please check the input data."
            )
        forward_inputs = {getattr(d, self._identifier_key): d.noise_pred for d in data}
        return self.forward(**forward_inputs)

    def forward(self, *args, **kwargs) -> Any:
        """Combine the predictions into the guided prediction. Must be implemented in subclasses."""
        raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.")

    @property
    def is_conditional(self) -> bool:
        """Whether the batch currently being processed is the conditional one. Must be implemented in subclasses."""
        raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.")

    @property
    def is_unconditional(self) -> bool:
        """Logical negation of `is_conditional`."""
        return not self.is_conditional

    @property
    def num_conditions(self) -> int:
        """Number of batches expected by `__call__`. Must be implemented in subclasses."""
        raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.")

    @classmethod
    def _prepare_batch(
        cls,
        data: dict[str, tuple[torch.Tensor, torch.Tensor]],
        tuple_index: int,
        identifier: str,
    ) -> "BlockState":
        """
        Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of the
        `BaseGuidance` class. It prepares the batch based on the provided tuple index.

        Args:
            data (`dict[str, torch.Tensor | tuple[torch.Tensor, ...]]`):
                Mapping from field name to the input data. A tensor value is shared as-is by every batch; a tuple
                value is indexed with `tuple_index`. Values of any other type are skipped with a debug log.
            tuple_index (`int`):
                The index to use when accessing values that are tuples.
            identifier (`str`):
                Name of the prediction this batch will produce; stored on the batch under `_identifier_key`.

        Returns:
            `BlockState`: The prepared batch of data.
        """
        from ..modular_pipelines.modular_pipeline import BlockState

        data_batch = {}
        for key, value in data.items():
            if isinstance(value, torch.Tensor):
                data_batch[key] = value
            elif isinstance(value, tuple):
                data_batch[key] = value[tuple_index]
            else:
                # Unsupported entries are skipped (best effort) rather than aborting the whole batch.
                logger.debug(
                    f"Skipping `{key}`: expected a tensor or a tuple of tensors, but got {type(value)}."
                )
        data_batch[cls._identifier_key] = identifier
        return BlockState(**data_batch)

    @classmethod
    def _prepare_batch_from_block_state(
        cls,
        input_fields: dict[str, str | tuple[str, str]],
        data: "BlockState",
        tuple_index: int,
        identifier: str,
    ) -> "BlockState":
        """
        Prepares a batch of data for the guidance technique. This method is used in the
        `prepare_inputs_from_block_state` method of the `BaseGuidance` class. It prepares the batch based on the
        provided tuple index.

        Args:
            input_fields (`dict[str, str | tuple[str, str]]`):
                A dictionary where the keys are the names of the fields that will be used to store the data once it is
                prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
                to look up the required data provided for preparation. If a string is provided, it will be used as the
                conditional data (or unconditional if used with a guidance method that requires it). If a tuple of
                length 2 is provided, the first element must be the conditional data identifier and the second element
                must be the unconditional data identifier or None.
            data (`BlockState`):
                The input data to be prepared.
            tuple_index (`int`):
                The index to use when accessing input fields that are tuples.
            identifier (`str`):
                Name of the prediction this batch will produce; stored on the batch under `_identifier_key`.

        Returns:
            `BlockState`: The prepared batch of data.
        """
        from ..modular_pipelines.modular_pipeline import BlockState

        data_batch = {}
        for key, value in input_fields.items():
            try:
                if isinstance(value, str):
                    data_batch[key] = getattr(data, value)
                elif isinstance(value, tuple):
                    data_batch[key] = getattr(data, value[tuple_index])
                else:
                    # We've already checked that value is a string or a tuple of strings with length 2
                    pass
            except AttributeError:
                logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
        data_batch[cls._identifier_key] = identifier
        return BlockState(**data_batch)

    @classmethod
    @validate_hf_hub_args
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str | os.PathLike | None = None,
        subfolder: str | None = None,
        return_unused_kwargs=False,
        **kwargs,
    ) -> Self:
        r"""
        Instantiate a guider from a pre-defined JSON configuration file in a local directory or Hub repository.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the guider configuration
                      saved with [`~BaseGuidance.save_pretrained`].
            subfolder (`str`, *optional*):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                Whether kwargs that are not consumed by the Python class should be returned or not.
            cache_dir (`str | os.PathLike`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            proxies (`dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only(`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.

        > [!TIP]
        > To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
        > `hf auth login`. You can also activate the special
        > ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
        > firewalled environment.

        """
        # The commit hash is requested from `load_config` but not needed to build the guider.
        config, kwargs, _ = cls.load_config(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            subfolder=subfolder,
            return_unused_kwargs=True,
            return_commit_hash=True,
            **kwargs,
        )
        return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)

    def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs):
        """
        Save a guider configuration object to a directory so that it can be reloaded using the
        [`~BaseGuidance.from_pretrained`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the configuration JSON file will be saved (will be created if it does not exist).
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            kwargs (`dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
        self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
366
+
367
+
368
class GuiderOutput(BaseOutput):
    """
    Output container returned by guider `forward` implementations.

    NOTE(review): field order is kept as declared; presumably `BaseOutput` exposes fields positionally - confirm
    before reordering.
    """

    # Final prediction after guidance (and optional rescaling) has been applied.
    pred: torch.Tensor
    # Conditional prediction that was passed into the guider, returned unchanged.
    pred_cond: torch.Tensor | None
    # Unconditional prediction that was passed into the guider; may be `None`.
    pred_uncond: torch.Tensor | None
372
+
373
+
374
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    r"""
    Rescale the guided prediction so its per-sample standard deviation matches the text-conditioned prediction,
    then blend the rescaled result with the unscaled one. Intended to improve image quality and fix overexposure;
    based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://huggingface.co/papers/2305.08891).

    Args:
        noise_cfg (`torch.Tensor`):
            The predicted noise tensor for the guided diffusion process.
        noise_pred_text (`torch.Tensor`):
            The predicted noise tensor for the text-guided diffusion process.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Blend weight: `0.0` keeps `noise_cfg` unchanged, `1.0` uses the fully rescaled prediction.

    Returns:
        noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
    """

    def _per_sample_std(tensor):
        # Reduce over every dimension except the leading one, keeping dims so the result broadcasts.
        return tensor.std(dim=list(range(1, tensor.ndim)), keepdim=True)

    text_std = _per_sample_std(noise_pred_text)
    cfg_std = _per_sample_std(noise_cfg)

    # Match the spread of the guided prediction to the text prediction (fixes overexposure).
    std_matched = noise_cfg * (text_std / cfg_std)

    # Mix with the original guided result to avoid "plain looking" images.
    blended = guidance_rescale * std_matched + (1 - guidance_rescale) * noise_cfg
    return blended
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/magnitude_aware_guidance.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from typing import TYPE_CHECKING
17
+
18
+ import torch
19
+
20
+ from ..configuration_utils import register_to_config
21
+ from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
22
+
23
+
24
+ if TYPE_CHECKING:
25
+ from ..modular_pipelines.modular_pipeline import BlockState
26
+
27
+
28
class MagnitudeAwareGuidance(BaseGuidance):
    """
    Magnitude-Aware Mitigation for Boosted Guidance (MAMBO-G): https://huggingface.co/papers/2508.03442

    Args:
        guidance_scale (`float`, defaults to `10.0`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
            prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
            deterioration of image quality.
        alpha (`float`, defaults to `8.0`):
            The alpha parameter for the magnitude-aware guidance. Higher values cause more aggressive suppression of
            the guidance scale when the magnitude of the guidance update is large.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
        enabled (`bool`, defaults to `True`):
            Whether the guider starts out enabled.
    """

    _input_predictions = ["pred_cond", "pred_uncond"]

    @register_to_config
    def __init__(
        self,
        guidance_scale: float = 10.0,
        alpha: float = 8.0,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
        enabled: bool = True,
    ):
        super().__init__(start, stop, enabled)

        self.guidance_scale = guidance_scale
        self.alpha = alpha
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

    def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
        """Build one batch per active condition from a mapping of tensors / tensor tuples."""
        # Only the conditional entry (index 0) is needed while guidance is inactive.
        active_indices = [0] if self.num_conditions == 1 else [0, 1]
        return [
            self._prepare_batch(data, index, prediction_name)
            for index, prediction_name in zip(active_indices, self._input_predictions)
        ]

    def prepare_inputs_from_block_state(
        self, data: "BlockState", input_fields: dict[str, str | tuple[str, str]]
    ) -> list["BlockState"]:
        """Build one batch per active condition by looking up `input_fields` on a block state."""
        # Only the conditional entry (index 0) is needed while guidance is inactive.
        active_indices = [0] if self.num_conditions == 1 else [0, 1]
        return [
            self._prepare_batch_from_block_state(input_fields, data, index, prediction_name)
            for index, prediction_name in zip(active_indices, self._input_predictions)
        ]

    def forward(self, pred_cond: torch.Tensor, pred_uncond: torch.Tensor | None = None) -> GuiderOutput:
        """
        Combine the conditional and unconditional predictions.

        While guidance is inactive the conditional prediction is used directly; otherwise `mambo_guidance` is
        applied. If `guidance_rescale` is positive, the result is then passed through `rescale_noise_cfg`.
        """
        if self._is_mambo_g_enabled():
            guided = mambo_guidance(
                pred_cond,
                pred_uncond,
                self.guidance_scale,
                self.alpha,
                self.use_original_formulation,
            )
        else:
            guided = pred_cond

        if self.guidance_rescale > 0.0:
            guided = rescale_noise_cfg(guided, pred_cond, self.guidance_rescale)

        return GuiderOutput(pred=guided, pred_cond=pred_cond, pred_uncond=pred_uncond)

    @property
    def is_conditional(self) -> bool:
        """`True` when the prepared-model counter equals exactly 1."""
        return self._count_prepared == 1

    @property
    def num_conditions(self) -> int:
        """`2` while guidance is active (conditional + unconditional), otherwise `1`."""
        return 2 if self._is_mambo_g_enabled() else 1

    def _is_mambo_g_enabled(self) -> bool:
        """
        Decide whether guidance is active for the current step.

        Requires the guider to be enabled, the current step to lie in `[int(start * N), int(stop * N))` when the
        number of inference steps `N` is known, and `guidance_scale` to differ from its neutral value (`0.0` with
        `use_original_formulation`, `1.0` otherwise).
        """
        if not self._enabled:
            return False

        total_steps = self._num_inference_steps
        if total_steps is None:
            # No schedule information yet: treat every step as inside the window.
            in_window = True
        else:
            first_step = int(self._start * total_steps)
            end_step = int(self._stop * total_steps)
            in_window = first_step <= self._step < end_step

        neutral_scale = 0.0 if self.use_original_formulation else 1.0
        scale_is_neutral = math.isclose(self.guidance_scale, neutral_scale)

        return in_window and not scale_is_neutral
139
+
140
+
141
def mambo_guidance(
    pred_cond: torch.Tensor,
    pred_uncond: torch.Tensor,
    guidance_scale: float,
    alpha: float = 8.0,
    use_original_formulation: bool = False,
):
    """
    Apply classifier-free guidance with a guidance scale damped by the relative magnitude of the update.

    For each sample, the ratio `||pred_cond - pred_uncond|| / ||pred_uncond||` is computed over all non-batch
    dimensions and the guidance scale is attenuated by `exp(-alpha * ratio)`.

    Args:
        pred_cond (`torch.Tensor`):
            Conditional prediction; the first dimension is treated as the batch dimension.
        pred_uncond (`torch.Tensor`):
            Unconditional prediction, same shape as `pred_cond`.
        guidance_scale (`float`):
            Undamped guidance scale.
        alpha (`float`, defaults to `8.0`):
            Damping strength; `0.0` disables damping.
        use_original_formulation (`bool`, defaults to `False`):
            If `True`, return `pred_cond + scale * diff` with `scale = guidance_scale * damping`. Otherwise return
            `pred_uncond + scale * diff` with `scale = 1 + (guidance_scale - 1) * damping`.

    Returns:
        `torch.Tensor`: The guided prediction, same shape as `pred_cond`.
    """
    reduce_dims = list(range(1, pred_cond.ndim))
    diff = pred_cond - pred_uncond

    # `torch.linalg.vector_norm` computes the 2-norm over any number of dimensions. The legacy `torch.norm` with
    # its default `p="fro"` is deprecated and not reliable for more than two reduction dims (4-D / 5-D latents).
    diff_norm = torch.linalg.vector_norm(diff, dim=reduce_dims, keepdim=True)
    uncond_norm = torch.linalg.vector_norm(pred_uncond, dim=reduce_dims, keepdim=True)
    # NOTE(review): an all-zero `pred_uncond` sample makes this ratio inf/nan - confirm callers never pass one.
    ratio = diff_norm / uncond_norm

    damping = torch.exp(-alpha * ratio)
    if use_original_formulation:
        guidance_scale_final = guidance_scale * damping
        base = pred_cond
    else:
        guidance_scale_final = 1.0 + (guidance_scale - 1.0) * damping
        base = pred_uncond

    return base + guidance_scale_final * diff
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/perturbed_attention_guidance.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import math
18
+ from typing import TYPE_CHECKING, Any
19
+
20
+ import torch
21
+
22
+ from ..configuration_utils import register_to_config
23
+ from ..hooks import HookRegistry, LayerSkipConfig
24
+ from ..hooks.layer_skip import _apply_layer_skip_hook
25
+ from ..utils import get_logger
26
+ from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
27
+
28
+
29
+ if TYPE_CHECKING:
30
+ from ..modular_pipelines.modular_pipeline import BlockState
31
+
32
+
33
+ logger = get_logger(__name__) # pylint: disable=invalid-name
34
+
35
+
36
+ class PerturbedAttentionGuidance(BaseGuidance):
37
+ """
38
+ Perturbed Attention Guidance (PAG): https://huggingface.co/papers/2403.17377
39
+
40
+ The intution behind PAG can be thought of as moving the CFG predicted distribution estimates further away from
41
+ worse versions of the conditional distribution estimates. PAG was one of the first techniques to introduce the idea
42
+ of using a worse version of the trained model for better guiding itself in the denoising process. It perturbs the
43
+ attention scores of the latent stream by replacing the score matrix with an identity matrix for selectively chosen
44
+ layers.
45
+
46
+ Additional reading:
47
+ - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)
48
+
49
+ PAG is implemented with similar implementation to SkipLayerGuidance due to overlap in the configuration parameters
50
+ and implementation details.
51
+
52
+ Args:
53
+ guidance_scale (`float`, defaults to `7.5`):
54
+ The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
55
+ prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
56
+ deterioration of image quality.
57
+ perturbed_guidance_scale (`float`, defaults to `2.8`):
58
+ The scale parameter for perturbed attention guidance.
59
+ perturbed_guidance_start (`float`, defaults to `0.01`):
60
+ The fraction of the total number of denoising steps after which perturbed attention guidance starts.
61
+ perturbed_guidance_stop (`float`, defaults to `0.2`):
62
+ The fraction of the total number of denoising steps after which perturbed attention guidance stops.
63
+ perturbed_guidance_layers (`int` or `list[int]`, *optional*):
64
+ The layer indices to apply perturbed attention guidance to. Can be a single integer or a list of integers.
65
+ If not provided, `perturbed_guidance_config` must be provided.
66
+ perturbed_guidance_config (`LayerSkipConfig` or `list[LayerSkipConfig]`, *optional*):
67
+ The configuration for the perturbed attention guidance. Can be a single `LayerSkipConfig` or a list of
68
+ `LayerSkipConfig`. If not provided, `perturbed_guidance_layers` must be provided.
69
+ guidance_rescale (`float`, defaults to `0.0`):
70
+ The rescale factor applied to the noise predictions. This is used to improve image quality and fix
71
+ overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
72
+ Flawed](https://huggingface.co/papers/2305.08891).
73
+ use_original_formulation (`bool`, defaults to `False`):
74
+ Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
75
+ we use the diffusers-native implementation that has been in the codebase for a long time. See
76
+ [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
77
+ start (`float`, defaults to `0.01`):
78
+ The fraction of the total number of denoising steps after which guidance starts.
79
+ stop (`float`, defaults to `0.2`):
80
+ The fraction of the total number of denoising steps after which guidance stops.
81
+ """
82
+
83
+ # NOTE: The current implementation does not account for joint latent conditioning (text + image/video tokens in
84
+ # the same latent stream). It assumes the entire latent is a single stream of visual tokens. It would be very
85
+ # complex to support joint latent conditioning in a model-agnostic manner without specializing the implementation
86
+ # for each model architecture.
87
+
88
+ _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
89
+
90
+ @register_to_config
91
def __init__(
    self,
    guidance_scale: float = 7.5,
    perturbed_guidance_scale: float = 2.8,
    perturbed_guidance_start: float = 0.01,
    perturbed_guidance_stop: float = 0.2,
    perturbed_guidance_layers: int | list[int] | None = None,
    perturbed_guidance_config: LayerSkipConfig | list[LayerSkipConfig] | dict[str, Any] = None,
    guidance_rescale: float = 0.0,
    use_original_formulation: bool = False,
    start: float = 0.0,
    stop: float = 1.0,
    enabled: bool = True,
):
    """
    Store the guidance settings and normalise the layer-skip configuration into a list of `LayerSkipConfig`.

    Exactly one of `perturbed_guidance_layers` and `perturbed_guidance_config` must be given. Every resulting
    config is forced to `skip_attention=False`, `skip_attention_scores=True`, `skip_ff=False`; a warning is logged
    for each config that had to be changed.

    Args:
        guidance_scale: Multiplier applied to `pred_cond - pred_uncond` in `forward`.
        perturbed_guidance_scale: Multiplier applied to `pred_cond - pred_cond_skip` in `forward`.
        perturbed_guidance_start: Fraction of the inference steps bounding the start of the perturbed window.
        perturbed_guidance_stop: Fraction of the inference steps bounding the end of the perturbed window.
        perturbed_guidance_layers: Layer indices used to build a default `LayerSkipConfig`.
        perturbed_guidance_config: A `LayerSkipConfig`, a dict accepted by `LayerSkipConfig.from_dict`, or a list
            of either.
        guidance_rescale: When greater than `0.0`, `forward` passes the prediction through `rescale_noise_cfg`.
        use_original_formulation: Selects the base prediction and the neutral guidance scale.
        start: Forwarded to the parent constructor.
        stop: Forwarded to the parent constructor.
        enabled: Forwarded to the parent constructor.

    Raises:
        ValueError: If both or neither of the layer arguments are provided, or if the config has an unsupported
            type.
    """
    super().__init__(start, stop, enabled)

    self.guidance_scale = guidance_scale
    self.guidance_rescale = guidance_rescale
    self.use_original_formulation = use_original_formulation
    # The perturbed-guidance values live under the skip-layer attribute names that the other methods read.
    self.skip_layer_guidance_scale = perturbed_guidance_scale
    self.skip_layer_guidance_start = perturbed_guidance_start
    self.skip_layer_guidance_stop = perturbed_guidance_stop

    layers_given = perturbed_guidance_layers is not None
    if perturbed_guidance_config is not None:
        if layers_given:
            raise ValueError(
                "`perturbed_guidance_layers` should not be provided if `perturbed_guidance_config` is specified."
            )
    elif not layers_given:
        raise ValueError(
            "`perturbed_guidance_layers` must be provided if `perturbed_guidance_config` is not specified."
        )
    else:
        perturbed_guidance_config = LayerSkipConfig(
            indices=perturbed_guidance_layers,
            fqn="auto",
            skip_attention=False,
            skip_attention_scores=True,
            skip_ff=False,
        )

    # Normalise: dict -> LayerSkipConfig -> single-element list.
    if isinstance(perturbed_guidance_config, dict):
        perturbed_guidance_config = LayerSkipConfig.from_dict(perturbed_guidance_config)
    if isinstance(perturbed_guidance_config, LayerSkipConfig):
        perturbed_guidance_config = [perturbed_guidance_config]

    if not isinstance(perturbed_guidance_config, list):
        raise ValueError(
            "`perturbed_guidance_config` must be a `LayerSkipConfig`, a list of `LayerSkipConfig`, or a dict that can be converted to a `LayerSkipConfig`."
        )
    # Only the first element is inspected to decide whether the list holds dicts.
    if perturbed_guidance_config and isinstance(perturbed_guidance_config[0], dict):
        perturbed_guidance_config = [LayerSkipConfig.from_dict(entry) for entry in perturbed_guidance_config]

    for config in perturbed_guidance_config:
        matches_pag = not config.skip_attention and config.skip_attention_scores and not config.skip_ff
        if matches_pag:
            continue
        logger.warning(
            "Perturbed Attention Guidance is designed to perturb attention scores, so `skip_attention` should be False, `skip_attention_scores` should be True, and `skip_ff` should be False. "
            "Please check your configuration. Modifying the config to match the expected values."
        )
        # The caller's config object is mutated in place.
        config.skip_attention = False
        config.skip_attention_scores = True
        config.skip_ff = False

    self.skip_layer_config = perturbed_guidance_config
    self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i, _ in enumerate(self.skip_layer_config)]
157
+
158
+ # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_models
159
def prepare_models(self, denoiser: torch.nn.Module) -> None:
    """
    Count this preparation call and, when the skip pass is due, install the layer-skip hooks on `denoiser`.

    Hooks are installed only if skip guidance is enabled, `is_conditional` is true and this is not the first
    preparation call.
    """
    self._count_prepared += 1
    install_hooks = self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1
    if not install_hooks:
        return
    for hook_name, layer_config in zip(self._skip_layer_hook_names, self.skip_layer_config):
        _apply_layer_skip_hook(denoiser, layer_config, name=hook_name)
164
+
165
+ # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.cleanup_models
166
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
    """
    Remove the layer-skip hooks from `denoiser`, using the same condition under which `prepare_models` adds them.
    """
    hooks_were_installed = self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1
    if not hooks_were_installed:
        return
    # Remove the hooks after inference
    hook_registry = HookRegistry.check_if_exists_or_initialize(denoiser)
    for name in self._skip_layer_hook_names:
        hook_registry.remove_hook(name, recurse=True)
172
+
173
+ # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs
174
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
    """
    Build one batch per active prediction by calling `self._prepare_batch(data, tuple_index, prediction_name)`.

    The number of batches equals `self.num_conditions`; each batch is tagged with the name of the prediction it
    produces (`pred_cond`, `pred_uncond`, `pred_cond_skip`).
    """
    condition_count = self.num_conditions
    if condition_count == 1:
        plan = [(0, "pred_cond")]
    elif condition_count == 2:
        second_name = "pred_uncond" if self._is_cfg_enabled() else "pred_cond_skip"
        # NOTE(review): without CFG the skip batch takes tuple index 1 here, whereas the three-condition case
        # below uses index 0 for `pred_cond_skip` -- confirm this is intended.
        plan = [(0, "pred_cond"), (1, second_name)]
    else:
        plan = [(0, "pred_cond"), (1, "pred_uncond"), (0, "pred_cond_skip")]
    return [self._prepare_batch(data, tuple_index, name) for tuple_index, name in plan]
191
+
192
def prepare_inputs_from_block_state(
    self, data: "BlockState", input_fields: dict[str, str | tuple[str, str]]
) -> list["BlockState"]:
    """
    Build one batch per active prediction via
    `self._prepare_batch_from_block_state(input_fields, data, tuple_index, prediction_name)`.

    The selection of tuple indices and prediction names is the same as in `prepare_inputs`.
    """
    condition_count = self.num_conditions
    if condition_count == 1:
        plan = [(0, "pred_cond")]
    elif condition_count == 2:
        second_name = "pred_uncond" if self._is_cfg_enabled() else "pred_cond_skip"
        # NOTE(review): without CFG the skip batch takes tuple index 1 here, whereas the three-condition case
        # below uses index 0 for `pred_cond_skip` -- confirm this is intended.
        plan = [(0, "pred_cond"), (1, second_name)]
    else:
        plan = [(0, "pred_cond"), (1, "pred_uncond"), (0, "pred_cond_skip")]
    return [
        self._prepare_batch_from_block_state(input_fields, data, tuple_index, name) for tuple_index, name in plan
    ]
211
+
212
+ # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.forward
213
def forward(
    self,
    pred_cond: torch.Tensor,
    pred_uncond: torch.Tensor | None = None,
    pred_cond_skip: torch.Tensor | None = None,
) -> GuiderOutput:
    """
    Combine the available predictions into the guided prediction.

    With neither guidance active the conditional prediction is returned unchanged. Otherwise the result is a base
    prediction plus `guidance_scale * (pred_cond - pred_uncond)` when CFG is active and
    `skip_layer_guidance_scale * (pred_cond - pred_cond_skip)` when skip guidance is active. The base is
    `pred_cond` under the original formulation, else `pred_uncond` (or `pred_cond_skip` when CFG is inactive).

    Returns:
        A `GuiderOutput` carrying `pred`, `pred_cond` and `pred_uncond`.
    """
    cfg_active = self._is_cfg_enabled()
    skip_active = self._is_slg_enabled()

    if not cfg_active and not skip_active:
        pred = pred_cond
    else:
        if self.use_original_formulation:
            pred = pred_cond
        elif cfg_active:
            pred = pred_uncond
        else:
            pred = pred_cond_skip
        # Terms are added left to right, CFG first, matching the single-expression form.
        if cfg_active:
            pred = pred + self.guidance_scale * (pred_cond - pred_uncond)
        if skip_active:
            pred = pred + self.skip_layer_guidance_scale * (pred_cond - pred_cond_skip)

    if self.guidance_rescale > 0.0:
        pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

    return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
241
+
242
+ @property
243
+ # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.is_conditional
244
def is_conditional(self) -> bool:
    """Return True on the first and third preparation calls, i.e. when `_count_prepared` is 1 or 3."""
    return self._count_prepared in (1, 3)
246
+
247
+ @property
248
+ # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.num_conditions
249
def num_conditions(self) -> int:
    """Return how many predictions are needed: 1, plus one each for active CFG and active skip guidance."""
    cfg_active = bool(self._is_cfg_enabled())
    skip_active = bool(self._is_slg_enabled())
    return 1 + cfg_active + skip_active
256
+
257
+ # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_cfg_enabled
258
def _is_cfg_enabled(self) -> bool:
    """
    Report whether classifier-free guidance applies at the current step.

    CFG is off when the guider is disabled, when the current step is outside `[start, stop)` (checked only if the
    number of inference steps is known), or when `guidance_scale` equals the neutral value: `0.0` under the
    original formulation, `1.0` otherwise.
    """
    if not self._enabled:
        return False

    total_steps = self._num_inference_steps
    if total_steps is None:
        in_window = True
    else:
        first_step = int(self._start * total_steps)
        end_step = int(self._stop * total_steps)
        in_window = first_step <= self._step < end_step

    neutral_scale = 0.0 if self.use_original_formulation else 1.0
    scale_is_neutral = math.isclose(self.guidance_scale, neutral_scale)

    return in_window and not scale_is_neutral
275
+
276
+ # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_slg_enabled
277
def _is_slg_enabled(self) -> bool:
    """
    Report whether the skip-layer (perturbed) guidance term applies at the current step.

    It is off when the guider is disabled, when `skip_layer_guidance_scale` is zero, or when the current step is
    outside the window derived from `skip_layer_guidance_start` / `skip_layer_guidance_stop` (checked only if the
    number of inference steps is known).
    """
    if not self._enabled:
        return False

    total_steps = self._num_inference_steps
    if total_steps is None:
        in_window = True
    else:
        first_step = int(self.skip_layer_guidance_start * total_steps)
        end_step = int(self.skip_layer_guidance_stop * total_steps)
        # Both bounds are exclusive here, unlike the CFG window whose lower bound is inclusive.
        in_window = first_step < self._step < end_step

    scale_is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)

    return in_window and not scale_is_zero
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/skip_layer_guidance.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import math
18
+ from typing import TYPE_CHECKING, Any
19
+
20
+ import torch
21
+
22
+ from ..configuration_utils import register_to_config
23
+ from ..hooks import HookRegistry, LayerSkipConfig
24
+ from ..hooks.layer_skip import _apply_layer_skip_hook
25
+ from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
26
+
27
+
28
+ if TYPE_CHECKING:
29
+ from ..modular_pipelines.modular_pipeline import BlockState
30
+
31
+
32
class SkipLayerGuidance(BaseGuidance):
    """
    Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5

    Spatio-Temporal Guidance (STG): https://huggingface.co/papers/2411.18664

    SLG was introduced by StabilityAI for improving structure and anatomy coherence in generated images. It works by
    skipping the forward pass of specified transformer blocks during the denoising process on an additional conditional
    batch of data, apart from the conditional and unconditional batches already used in CFG
    ([~guiders.classifier_free_guidance.ClassifierFreeGuidance]), and then scaling and shifting the CFG predictions
    based on the difference between conditional without skipping and conditional with skipping predictions.

    The intuition behind SLG can be thought of as moving the CFG predicted distribution estimates further away from
    worse versions of the conditional distribution estimates (because skipping layers is equivalent to using a worse
    version of the model for the conditional prediction).

    STG is an improvement and follow-up work combining ideas from SLG, PAG and similar techniques for improving
    generation quality in video diffusion models.

    Additional reading:
    - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)

    The values for `skip_layer_guidance_scale`, `skip_layer_guidance_start`, and `skip_layer_guidance_stop` are
    defaulted to the recommendations by StabilityAI for Stable Diffusion 3.5 Medium.

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
            prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
            deterioration of image quality.
        skip_layer_guidance_scale (`float`, defaults to `2.8`):
            The scale parameter for skip layer guidance. Anatomy and structure coherence may improve with higher
            values, but it may also lead to overexposure and saturation.
        skip_layer_guidance_start (`float`, defaults to `0.01`):
            The fraction of the total number of denoising steps after which skip layer guidance starts.
        skip_layer_guidance_stop (`float`, defaults to `0.2`):
            The fraction of the total number of denoising steps after which skip layer guidance stops.
        skip_layer_guidance_layers (`int` or `list[int]`, *optional*):
            The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
            provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion
            3.5 Medium.
        skip_layer_config (`LayerSkipConfig`, `list[LayerSkipConfig]` or `dict`, *optional*):
            The configuration for the skip layer guidance. Can be a single `LayerSkipConfig`, a list of
            `LayerSkipConfig`, or a dict (or list of dicts) accepted by `LayerSkipConfig.from_dict`. If not provided,
            `skip_layer_guidance_layers` must be provided.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
        enabled (`bool`, defaults to `True`):
            Forwarded to the `BaseGuidance` constructor together with `start` and `stop`.
    """

    _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]

    @register_to_config
    def __init__(
        self,
        guidance_scale: float = 7.5,
        skip_layer_guidance_scale: float = 2.8,
        skip_layer_guidance_start: float = 0.01,
        skip_layer_guidance_stop: float = 0.2,
        skip_layer_guidance_layers: int | list[int] | None = None,
        skip_layer_config: LayerSkipConfig | list[LayerSkipConfig] | dict[str, Any] = None,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
        enabled: bool = True,
    ):
        """
        Validate the arguments and normalise the skip configuration into a list of `LayerSkipConfig`.

        Raises:
            ValueError: If the start/stop fractions are out of range, if both or neither of
                `skip_layer_guidance_layers` and `skip_layer_config` are given, or if either has an unsupported type.
        """
        super().__init__(start, stop, enabled)

        self.guidance_scale = guidance_scale
        self.skip_layer_guidance_scale = skip_layer_guidance_scale
        self.skip_layer_guidance_start = skip_layer_guidance_start
        self.skip_layer_guidance_stop = skip_layer_guidance_stop
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

        if not (0.0 <= skip_layer_guidance_start < 1.0):
            raise ValueError(
                f"Expected `skip_layer_guidance_start` to be between 0.0 and 1.0, but got {skip_layer_guidance_start}."
            )
        if not (skip_layer_guidance_start <= skip_layer_guidance_stop <= 1.0):
            # The lower bound is the start fraction, not 0.0, so the message reports the bound actually enforced.
            raise ValueError(
                f"Expected `skip_layer_guidance_stop` to be between `skip_layer_guidance_start` "
                f"({skip_layer_guidance_start}) and 1.0, but got {skip_layer_guidance_stop}."
            )

        if skip_layer_guidance_layers is None and skip_layer_config is None:
            raise ValueError(
                "Either `skip_layer_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance."
            )
        if skip_layer_guidance_layers is not None and skip_layer_config is not None:
            raise ValueError("Only one of `skip_layer_guidance_layers` or `skip_layer_config` can be provided.")

        if skip_layer_guidance_layers is not None:
            if isinstance(skip_layer_guidance_layers, int):
                skip_layer_guidance_layers = [skip_layer_guidance_layers]
            if not isinstance(skip_layer_guidance_layers, list):
                raise ValueError(
                    f"Expected `skip_layer_guidance_layers` to be an int or a list of ints, but got {type(skip_layer_guidance_layers)}."
                )
            # One config is created per requested layer.
            skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_layer_guidance_layers]

        if isinstance(skip_layer_config, dict):
            skip_layer_config = LayerSkipConfig.from_dict(skip_layer_config)

        if isinstance(skip_layer_config, LayerSkipConfig):
            skip_layer_config = [skip_layer_config]

        if not isinstance(skip_layer_config, list):
            raise ValueError(
                f"Expected `skip_layer_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(skip_layer_config)}."
            )
        elif isinstance(next(iter(skip_layer_config), None), dict):
            # Only the first element is inspected to decide whether the list holds dicts.
            skip_layer_config = [LayerSkipConfig.from_dict(config) for config in skip_layer_config]

        self.skip_layer_config = skip_layer_config
        self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]

    def prepare_models(self, denoiser: torch.nn.Module) -> None:
        """
        Count this preparation call and install the layer-skip hooks on `denoiser` when SLG is enabled,
        `is_conditional` is true and this is not the first preparation call.
        """
        self._count_prepared += 1
        if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
            for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
                _apply_layer_skip_hook(denoiser, config, name=name)

    def cleanup_models(self, denoiser: torch.nn.Module) -> None:
        """Remove the layer-skip hooks under the same condition that `prepare_models` uses to install them."""
        if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
            registry = HookRegistry.check_if_exists_or_initialize(denoiser)
            # Remove the hooks after inference
            for hook_name in self._skip_layer_hook_names:
                registry.remove_hook(hook_name, recurse=True)

    def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
        """
        Build one batch per active prediction via `self._prepare_batch(data, tuple_index, prediction_name)`.

        The number of batches equals `self.num_conditions`.
        """
        if self.num_conditions == 1:
            tuple_indices = [0]
            input_predictions = ["pred_cond"]
        elif self.num_conditions == 2:
            # NOTE(review): without CFG, `pred_cond_skip` is built from tuple index 1 here but from index 0 in the
            # three-condition case below -- confirm this is intended.
            tuple_indices = [0, 1]
            input_predictions = (
                ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
            )
        else:
            tuple_indices = [0, 1, 0]
            input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
        data_batches = []
        for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
            data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
            data_batches.append(data_batch)
        return data_batches

    def prepare_inputs_from_block_state(
        self, data: "BlockState", input_fields: dict[str, str | tuple[str, str]]
    ) -> list["BlockState"]:
        """
        Build one batch per active prediction via
        `self._prepare_batch_from_block_state(input_fields, data, tuple_index, prediction_name)`.

        The tuple indices and prediction names are selected exactly as in `prepare_inputs`.
        """
        if self.num_conditions == 1:
            tuple_indices = [0]
            input_predictions = ["pred_cond"]
        elif self.num_conditions == 2:
            tuple_indices = [0, 1]
            input_predictions = (
                ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
            )
        else:
            tuple_indices = [0, 1, 0]
            input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
        data_batches = []
        for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
            data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
            data_batches.append(data_batch)
        return data_batches

    def forward(
        self,
        pred_cond: torch.Tensor,
        pred_uncond: torch.Tensor | None = None,
        pred_cond_skip: torch.Tensor | None = None,
    ) -> GuiderOutput:
        """
        Combine the available predictions into the guided prediction.

        With neither guidance active, `pred_cond` is returned unchanged. Otherwise the result is a base prediction
        plus `guidance_scale * (pred_cond - pred_uncond)` when CFG is active and
        `skip_layer_guidance_scale * (pred_cond - pred_cond_skip)` when SLG is active. The base is `pred_cond` under
        the original formulation, else `pred_uncond` (or `pred_cond_skip` when CFG is inactive).

        Returns:
            A `GuiderOutput` carrying `pred`, `pred_cond` and `pred_uncond`.
        """
        pred = None

        if not self._is_cfg_enabled() and not self._is_slg_enabled():
            pred = pred_cond
        elif not self._is_cfg_enabled():
            shift = pred_cond - pred_cond_skip
            pred = pred_cond if self.use_original_formulation else pred_cond_skip
            pred = pred + self.skip_layer_guidance_scale * shift
        elif not self._is_slg_enabled():
            shift = pred_cond - pred_uncond
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift
        else:
            shift = pred_cond - pred_uncond
            shift_skip = pred_cond - pred_cond_skip
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

    @property
    def is_conditional(self) -> bool:
        """True on the first and third preparation calls (`_count_prepared` equal to 1 or 3)."""
        return self._count_prepared == 1 or self._count_prepared == 3

    @property
    def num_conditions(self) -> int:
        """Number of predictions needed: 1, plus one each for active CFG and active SLG."""
        num_conditions = 1
        if self._is_cfg_enabled():
            num_conditions += 1
        if self._is_slg_enabled():
            num_conditions += 1
        return num_conditions

    def _is_cfg_enabled(self) -> bool:
        """
        Report whether CFG applies at the current step: the guider is enabled, the step lies in `[start, stop)`
        (checked only when the number of inference steps is known), and `guidance_scale` is not the neutral value
        (`0.0` under the original formulation, `1.0` otherwise).
        """
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close

    def _is_slg_enabled(self) -> bool:
        """
        Report whether SLG applies at the current step: the guider is enabled, the step lies strictly between the
        SLG start and stop steps (checked only when the number of inference steps is known), and
        `skip_layer_guidance_scale` is non-zero.
        """
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
            skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
            # Both bounds are exclusive here, unlike the CFG window whose lower bound is inclusive.
            is_within_range = skip_start_step < self._step < skip_stop_step

        is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)

        return is_within_range and not is_zero
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/smoothed_energy_guidance.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import math
18
+ from typing import TYPE_CHECKING
19
+
20
+ import torch
21
+
22
+ from ..configuration_utils import register_to_config
23
+ from ..hooks import HookRegistry
24
+ from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook
25
+ from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
26
+
27
+
28
+ if TYPE_CHECKING:
29
+ from ..modular_pipelines.modular_pipeline import BlockState
30
+
31
+
32
class SmoothedEnergyGuidance(BaseGuidance):
    """
    Smoothed Energy Guidance (SEG): https://huggingface.co/papers/2408.00760

    SEG is only supported as an experimental prototype feature for now, so the implementation may be modified in the
    future without warning or guarantee of reproducibility. This implementation assumes:
    - Generated images are square (height == width)
    - The model does not combine different modalities together (e.g., text and image latent streams are not combined
      together such as Flux)

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
            prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
            deterioration of image quality.
        seg_guidance_scale (`float`, defaults to `2.8`):
            The scale parameter for smoothed energy guidance. Anatomy and structure coherence may improve with higher
            values, but it may also lead to overexposure and saturation.
        seg_blur_sigma (`float`, defaults to `9999999.0`):
            The amount by which we blur the attention weights. Setting this value greater than 9999.0 results in
            infinite blur, which means uniform queries. Controlling it exponentially is empirically effective.
        seg_blur_threshold_inf (`float`, defaults to `9999.0`):
            The threshold above which the blur is considered infinite.
        seg_guidance_start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which smoothed energy guidance starts.
        seg_guidance_stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which smoothed energy guidance stops.
        seg_guidance_layers (`int` or `list[int]`, *optional*):
            The layer indices to apply smoothed energy guidance to. Can be a single integer or a list of integers. If
            not provided, `seg_guidance_config` must be provided. The recommended values are `[7, 8, 9]` for Stable
            Diffusion 3.5 Medium.
        seg_guidance_config (`SmoothedEnergyGuidanceConfig` or `list[SmoothedEnergyGuidanceConfig]`, *optional*):
            The configuration for the smoothed energy layer guidance. Can be a single `SmoothedEnergyGuidanceConfig` or
            a list of `SmoothedEnergyGuidanceConfig`; a dict (or list of dicts) accepted by
            `SmoothedEnergyGuidanceConfig.from_dict` is converted as well. If not provided, `seg_guidance_layers` must
            be provided.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
        enabled (`bool`, defaults to `True`):
            Forwarded to the `BaseGuidance` constructor together with `start` and `stop`.
    """

    _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]

    @register_to_config
    def __init__(
        self,
        guidance_scale: float = 7.5,
        seg_guidance_scale: float = 2.8,
        seg_blur_sigma: float = 9999999.0,
        seg_blur_threshold_inf: float = 9999.0,
        seg_guidance_start: float = 0.0,
        seg_guidance_stop: float = 1.0,
        seg_guidance_layers: int | list[int] | None = None,
        seg_guidance_config: SmoothedEnergyGuidanceConfig | list[SmoothedEnergyGuidanceConfig] = None,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
        enabled: bool = True,
    ):
        """
        Validate the arguments and normalise the SEG configuration into a list of `SmoothedEnergyGuidanceConfig`.

        Raises:
            ValueError: If the start/stop fractions are out of range, if both or neither of `seg_guidance_layers`
                and `seg_guidance_config` are given, or if either has an unsupported type.
        """
        super().__init__(start, stop, enabled)

        self.guidance_scale = guidance_scale
        self.seg_guidance_scale = seg_guidance_scale
        self.seg_blur_sigma = seg_blur_sigma
        self.seg_blur_threshold_inf = seg_blur_threshold_inf
        self.seg_guidance_start = seg_guidance_start
        self.seg_guidance_stop = seg_guidance_stop
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

        if not (0.0 <= seg_guidance_start < 1.0):
            raise ValueError(f"Expected `seg_guidance_start` to be between 0.0 and 1.0, but got {seg_guidance_start}.")
        if not (seg_guidance_start <= seg_guidance_stop <= 1.0):
            # The lower bound is the start fraction, not 0.0, so the message reports the bound actually enforced.
            raise ValueError(
                f"Expected `seg_guidance_stop` to be between `seg_guidance_start` ({seg_guidance_start}) and 1.0, "
                f"but got {seg_guidance_stop}."
            )

        if seg_guidance_layers is None and seg_guidance_config is None:
            raise ValueError(
                "Either `seg_guidance_layers` or `seg_guidance_config` must be provided to enable Smoothed Energy Guidance."
            )
        if seg_guidance_layers is not None and seg_guidance_config is not None:
            raise ValueError("Only one of `seg_guidance_layers` or `seg_guidance_config` can be provided.")

        if seg_guidance_layers is not None:
            if isinstance(seg_guidance_layers, int):
                seg_guidance_layers = [seg_guidance_layers]
            if not isinstance(seg_guidance_layers, list):
                raise ValueError(
                    f"Expected `seg_guidance_layers` to be an int or a list of ints, but got {type(seg_guidance_layers)}."
                )
            # One config is created per requested layer.
            seg_guidance_config = [SmoothedEnergyGuidanceConfig(layer, fqn="auto") for layer in seg_guidance_layers]

        if isinstance(seg_guidance_config, dict):
            seg_guidance_config = SmoothedEnergyGuidanceConfig.from_dict(seg_guidance_config)

        if isinstance(seg_guidance_config, SmoothedEnergyGuidanceConfig):
            seg_guidance_config = [seg_guidance_config]

        if not isinstance(seg_guidance_config, list):
            raise ValueError(
                f"Expected `seg_guidance_config` to be a SmoothedEnergyGuidanceConfig or a list of SmoothedEnergyGuidanceConfig, but got {type(seg_guidance_config)}."
            )
        elif isinstance(next(iter(seg_guidance_config), None), dict):
            # Only the first element is inspected to decide whether the list holds dicts.
            seg_guidance_config = [SmoothedEnergyGuidanceConfig.from_dict(config) for config in seg_guidance_config]

        self.seg_guidance_config = seg_guidance_config
        self._seg_layer_hook_names = [f"SmoothedEnergyGuidance_{i}" for i in range(len(self.seg_guidance_config))]

    def prepare_models(self, denoiser: torch.nn.Module) -> None:
        """
        Count this preparation call and install the SEG hooks on `denoiser` when SEG is enabled, `is_conditional`
        is true and this is not the first preparation call.
        """
        # This override does not call the base implementation, so the counter has to be advanced here; the
        # `_count_prepared > 1` check below and `is_conditional` both depend on it.
        self._count_prepared += 1
        if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
            for name, config in zip(self._seg_layer_hook_names, self.seg_guidance_config):
                _apply_smoothed_energy_guidance_hook(denoiser, config, self.seg_blur_sigma, name=name)

    def cleanup_models(self, denoiser: torch.nn.Module) -> None:
        """Remove the SEG hooks under the same condition that `prepare_models` uses to install them."""
        if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
            registry = HookRegistry.check_if_exists_or_initialize(denoiser)
            # Remove the hooks after inference
            for hook_name in self._seg_layer_hook_names:
                registry.remove_hook(hook_name, recurse=True)

    def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
        """
        Build one batch per active prediction via `self._prepare_batch(data, tuple_index, prediction_name)`.

        The number of batches equals `self.num_conditions`.
        """
        if self.num_conditions == 1:
            tuple_indices = [0]
            input_predictions = ["pred_cond"]
        elif self.num_conditions == 2:
            # NOTE(review): without CFG, `pred_cond_seg` is built from tuple index 1 here but from index 0 in the
            # three-condition case below -- confirm this is intended.
            tuple_indices = [0, 1]
            input_predictions = (
                ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"]
            )
        else:
            tuple_indices = [0, 1, 0]
            input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
        data_batches = []
        for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
            data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
            data_batches.append(data_batch)
        return data_batches

    def prepare_inputs_from_block_state(
        self, data: "BlockState", input_fields: dict[str, str | tuple[str, str]]
    ) -> list["BlockState"]:
        """
        Build one batch per active prediction via
        `self._prepare_batch_from_block_state(input_fields, data, tuple_index, prediction_name)`.

        The tuple indices and prediction names are selected exactly as in `prepare_inputs`.
        """
        if self.num_conditions == 1:
            tuple_indices = [0]
            input_predictions = ["pred_cond"]
        elif self.num_conditions == 2:
            tuple_indices = [0, 1]
            input_predictions = (
                ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"]
            )
        else:
            tuple_indices = [0, 1, 0]
            input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
        data_batches = []
        for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
            data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
            data_batches.append(data_batch)
        return data_batches

    def forward(
        self,
        pred_cond: torch.Tensor,
        pred_uncond: torch.Tensor | None = None,
        pred_cond_seg: torch.Tensor | None = None,
    ) -> GuiderOutput:
        """
        Combine the available predictions into the guided prediction.

        With neither guidance active, `pred_cond` is returned unchanged. Otherwise the result is a base prediction
        plus `guidance_scale * (pred_cond - pred_uncond)` when CFG is active and
        `seg_guidance_scale * (pred_cond - pred_cond_seg)` when SEG is active. The base is `pred_cond` under the
        original formulation, else `pred_uncond` (or `pred_cond_seg` when CFG is inactive).

        Returns:
            A `GuiderOutput` carrying `pred`, `pred_cond` and `pred_uncond`.
        """
        pred = None

        if not self._is_cfg_enabled() and not self._is_seg_enabled():
            pred = pred_cond
        elif not self._is_cfg_enabled():
            shift = pred_cond - pred_cond_seg
            pred = pred_cond if self.use_original_formulation else pred_cond_seg
            pred = pred + self.seg_guidance_scale * shift
        elif not self._is_seg_enabled():
            shift = pred_cond - pred_uncond
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift
        else:
            shift = pred_cond - pred_uncond
            shift_seg = pred_cond - pred_cond_seg
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift + self.seg_guidance_scale * shift_seg

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

    @property
    def is_conditional(self) -> bool:
        """True on the first and third preparation calls (`_count_prepared` equal to 1 or 3)."""
        return self._count_prepared == 1 or self._count_prepared == 3

    @property
    def num_conditions(self) -> int:
        """Number of predictions needed: 1, plus one each for active CFG and active SEG."""
        num_conditions = 1
        if self._is_cfg_enabled():
            num_conditions += 1
        if self._is_seg_enabled():
            num_conditions += 1
        return num_conditions

    def _is_cfg_enabled(self) -> bool:
        """
        Report whether CFG applies at the current step: the guider is enabled, the step lies in `[start, stop)`
        (checked only when the number of inference steps is known), and `guidance_scale` is not the neutral value
        (`0.0` under the original formulation, `1.0` otherwise).
        """
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close

    def _is_seg_enabled(self) -> bool:
        """
        Report whether SEG applies at the current step: the guider is enabled, the step lies strictly between the
        SEG start and stop steps (checked only when the number of inference steps is known), and
        `seg_guidance_scale` is non-zero.
        """
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self.seg_guidance_start * self._num_inference_steps)
            skip_stop_step = int(self.seg_guidance_stop * self._num_inference_steps)
            # Both bounds are exclusive here, unlike the CFG window whose lower bound is inclusive.
            is_within_range = skip_start_step < self._step < skip_stop_step

        is_zero = math.isclose(self.seg_guidance_scale, 0.0)

        return is_within_range and not is_zero
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/tangential_classifier_free_guidance.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import math
18
+ from typing import TYPE_CHECKING
19
+
20
+ import torch
21
+
22
+ from ..configuration_utils import register_to_config
23
+ from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
24
+
25
+
26
+ if TYPE_CHECKING:
27
+ from ..modular_pipelines.modular_pipeline import BlockState
28
+
29
+
30
+ class TangentialClassifierFreeGuidance(BaseGuidance):
31
+ """
32
+ Tangential Classifier Free Guidance (TCFG): https://huggingface.co/papers/2503.18137
33
+
34
+ Args:
35
+ guidance_scale (`float`, defaults to `7.5`):
36
+ The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
37
+ prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
38
+ deterioration of image quality.
39
+ guidance_rescale (`float`, defaults to `0.0`):
40
+ The rescale factor applied to the noise predictions. This is used to improve image quality and fix
41
+ overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
42
+ Flawed](https://huggingface.co/papers/2305.08891).
43
+ use_original_formulation (`bool`, defaults to `False`):
44
+ Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
45
+ we use the diffusers-native implementation that has been in the codebase for a long time. See
46
+ [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
47
+ start (`float`, defaults to `0.0`):
48
+ The fraction of the total number of denoising steps after which guidance starts.
49
+ stop (`float`, defaults to `1.0`):
50
+ The fraction of the total number of denoising steps after which guidance stops.
51
+ """
52
+
53
+ _input_predictions = ["pred_cond", "pred_uncond"]
54
+
55
+ @register_to_config
56
+ def __init__(
57
+ self,
58
+ guidance_scale: float = 7.5,
59
+ guidance_rescale: float = 0.0,
60
+ use_original_formulation: bool = False,
61
+ start: float = 0.0,
62
+ stop: float = 1.0,
63
+ enabled: bool = True,
64
+ ):
65
+ super().__init__(start, stop, enabled)
66
+
67
+ self.guidance_scale = guidance_scale
68
+ self.guidance_rescale = guidance_rescale
69
+ self.use_original_formulation = use_original_formulation
70
+
71
+ def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
72
+ tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
73
+ data_batches = []
74
+ for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
75
+ data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
76
+ data_batches.append(data_batch)
77
+ return data_batches
78
+
79
+ def prepare_inputs_from_block_state(
80
+ self, data: "BlockState", input_fields: dict[str, str | tuple[str, str]]
81
+ ) -> list["BlockState"]:
82
+ tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
83
+ data_batches = []
84
+ for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
85
+ data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
86
+ data_batches.append(data_batch)
87
+ return data_batches
88
+
89
+ def forward(self, pred_cond: torch.Tensor, pred_uncond: torch.Tensor | None = None) -> GuiderOutput:
90
+ pred = None
91
+
92
+ if not self._is_tcfg_enabled():
93
+ pred = pred_cond
94
+ else:
95
+ pred = normalized_guidance(pred_cond, pred_uncond, self.guidance_scale, self.use_original_formulation)
96
+
97
+ if self.guidance_rescale > 0.0:
98
+ pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
99
+
100
+ return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
101
+
102
+ @property
103
+ def is_conditional(self) -> bool:
104
+ return self._num_outputs_prepared == 1
105
+
106
+ @property
107
+ def num_conditions(self) -> int:
108
+ num_conditions = 1
109
+ if self._is_tcfg_enabled():
110
+ num_conditions += 1
111
+ return num_conditions
112
+
113
+ def _is_tcfg_enabled(self) -> bool:
114
+ if not self._enabled:
115
+ return False
116
+
117
+ is_within_range = True
118
+ if self._num_inference_steps is not None:
119
+ skip_start_step = int(self._start * self._num_inference_steps)
120
+ skip_stop_step = int(self._stop * self._num_inference_steps)
121
+ is_within_range = skip_start_step <= self._step < skip_stop_step
122
+
123
+ is_close = False
124
+ if self.use_original_formulation:
125
+ is_close = math.isclose(self.guidance_scale, 0.0)
126
+ else:
127
+ is_close = math.isclose(self.guidance_scale, 1.0)
128
+
129
+ return is_within_range and not is_close
130
+
131
+
132
+ def normalized_guidance(
133
+ pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float, use_original_formulation: bool = False
134
+ ) -> torch.Tensor:
135
+ cond_dtype = pred_cond.dtype
136
+ preds = torch.stack([pred_cond, pred_uncond], dim=1).float()
137
+ preds = preds.flatten(2)
138
+ U, S, Vh = torch.linalg.svd(preds, full_matrices=False)
139
+ Vh_modified = Vh.clone()
140
+ Vh_modified[:, 1] = 0
141
+
142
+ uncond_flat = pred_uncond.reshape(pred_uncond.size(0), 1, -1).float()
143
+ x_Vh = torch.matmul(uncond_flat, Vh.transpose(-2, -1))
144
+ x_Vh_V = torch.matmul(x_Vh, Vh_modified)
145
+ pred_uncond = x_Vh_V.reshape(pred_uncond.shape).to(cond_dtype)
146
+
147
+ pred = pred_cond if use_original_formulation else pred_uncond
148
+ shift = pred_cond - pred_uncond
149
+ pred = pred + guidance_scale * shift
150
+
151
+ return pred
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__init__.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from ..utils import is_torch_available
16
+
17
+
18
+ if is_torch_available():
19
+ from .context_parallel import apply_context_parallel
20
+ from .faster_cache import FasterCacheConfig, apply_faster_cache
21
+ from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
22
+ from .group_offloading import apply_group_offloading
23
+ from .hooks import HookRegistry, ModelHook
24
+ from .layer_skip import LayerSkipConfig, apply_layer_skip
25
+ from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
26
+ from .mag_cache import MagCacheConfig, apply_mag_cache
27
+ from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
28
+ from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
29
+ from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.24 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/_common.cpython-312.pyc ADDED
Binary file (1.96 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/_helpers.cpython-312.pyc ADDED
Binary file (11.1 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/context_parallel.cpython-312.pyc ADDED
Binary file (22.8 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/faster_cache.cpython-312.pyc ADDED
Binary file (35.3 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/first_block_cache.cpython-312.pyc ADDED
Binary file (12.7 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/group_offloading.cpython-312.pyc ADDED
Binary file (48.8 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/hooks.cpython-312.pyc ADDED
Binary file (15.6 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/layer_skip.cpython-312.pyc ADDED
Binary file (14.6 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/layerwise_casting.cpython-312.pyc ADDED
Binary file (12.2 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/mag_cache.cpython-312.pyc ADDED
Binary file (20.2 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/pyramid_attention_broadcast.cpython-312.pyc ADDED
Binary file (17.2 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/smoothed_energy_guidance_utils.cpython-312.pyc ADDED
Binary file (8.94 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/taylorseer_cache.cpython-312.pyc ADDED
Binary file (17.8 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/utils.cpython-312.pyc ADDED
Binary file (2.25 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/_common.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+
17
+ from ..models.attention import AttentionModuleMixin, FeedForward, LuminaFeedForward
18
+ from ..models.attention_processor import Attention, MochiAttention
19
+
20
+
21
+ _ATTENTION_CLASSES = (Attention, MochiAttention, AttentionModuleMixin)
22
+ _FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)
23
+
24
+ _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = (
25
+ "blocks",
26
+ "transformer_blocks",
27
+ "single_transformer_blocks",
28
+ "layers",
29
+ "visual_transformer_blocks",
30
+ )
31
+ _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
32
+ _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")
33
+
34
+ _ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
35
+ {
36
+ *_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
37
+ *_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
38
+ *_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
39
+ }
40
+ )
41
+
42
+ # Layers supported for group offloading and layerwise casting
43
+ _GO_LC_SUPPORTED_PYTORCH_LAYERS = (
44
+ torch.nn.Conv1d,
45
+ torch.nn.Conv2d,
46
+ torch.nn.Conv3d,
47
+ torch.nn.ConvTranspose1d,
48
+ torch.nn.ConvTranspose2d,
49
+ torch.nn.ConvTranspose3d,
50
+ torch.nn.Linear,
51
+ torch.nn.Embedding,
52
+ # TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX
53
+ # because of double invocation of the same norm layer in CogVideoXLayerNorm
54
+ )
55
+
56
+
57
+ def _get_submodule_from_fqn(module: torch.nn.Module, fqn: str) -> torch.nn.Module | None:
58
+ for submodule_name, submodule in module.named_modules():
59
+ if submodule_name == fqn:
60
+ return submodule
61
+ return None
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/_helpers.py ADDED
@@ -0,0 +1,381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from dataclasses import dataclass
17
+ from typing import Any, Callable, Type
18
+
19
+
20
+ @dataclass
21
+ class AttentionProcessorMetadata:
22
+ skip_processor_output_fn: Callable[[Any], Any]
23
+
24
+
25
+ @dataclass
26
+ class TransformerBlockMetadata:
27
+ return_hidden_states_index: int = None
28
+ return_encoder_hidden_states_index: int = None
29
+ hidden_states_argument_name: str = "hidden_states"
30
+
31
+ _cls: Type = None
32
+ _cached_parameter_indices: dict[str, int] = None
33
+
34
+ def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
35
+ kwargs = kwargs or {}
36
+ if identifier in kwargs:
37
+ return kwargs[identifier]
38
+ if self._cached_parameter_indices is not None:
39
+ return args[self._cached_parameter_indices[identifier]]
40
+ if self._cls is None:
41
+ raise ValueError("Model class is not set for metadata.")
42
+ parameters = list(inspect.signature(self._cls.forward).parameters.keys())
43
+ parameters = parameters[1:] # skip `self`
44
+ self._cached_parameter_indices = {param: i for i, param in enumerate(parameters)}
45
+ if identifier not in self._cached_parameter_indices:
46
+ raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
47
+ index = self._cached_parameter_indices[identifier]
48
+ if index >= len(args):
49
+ raise ValueError(f"Expected {index} arguments but got {len(args)}.")
50
+ return args[index]
51
+
52
+
53
+ class AttentionProcessorRegistry:
54
+ _registry = {}
55
+ # TODO(aryan): this is only required for the time being because we need to do the registrations
56
+ # for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
57
+ # import errors because of the models imported in this file.
58
+ _is_registered = False
59
+
60
+ @classmethod
61
+ def register(cls, model_class: Type, metadata: AttentionProcessorMetadata):
62
+ cls._register()
63
+ cls._registry[model_class] = metadata
64
+
65
+ @classmethod
66
+ def get(cls, model_class: Type) -> AttentionProcessorMetadata:
67
+ cls._register()
68
+ if model_class not in cls._registry:
69
+ raise ValueError(f"Model class {model_class} not registered.")
70
+ return cls._registry[model_class]
71
+
72
+ @classmethod
73
+ def _register(cls):
74
+ if cls._is_registered:
75
+ return
76
+ cls._is_registered = True
77
+ _register_attention_processors_metadata()
78
+
79
+
80
+ class TransformerBlockRegistry:
81
+ _registry = {}
82
+ # TODO(aryan): this is only required for the time being because we need to do the registrations
83
+ # for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
84
+ # import errors because of the models imported in this file.
85
+ _is_registered = False
86
+
87
+ @classmethod
88
+ def register(cls, model_class: Type, metadata: TransformerBlockMetadata):
89
+ cls._register()
90
+ metadata._cls = model_class
91
+ cls._registry[model_class] = metadata
92
+
93
+ @classmethod
94
+ def get(cls, model_class: Type) -> TransformerBlockMetadata:
95
+ cls._register()
96
+ if model_class not in cls._registry:
97
+ raise ValueError(f"Model class {model_class} not registered.")
98
+ return cls._registry[model_class]
99
+
100
+ @classmethod
101
+ def _register(cls):
102
+ if cls._is_registered:
103
+ return
104
+ cls._is_registered = True
105
+ _register_transformer_blocks_metadata()
106
+
107
+
108
+ def _register_attention_processors_metadata():
109
+ from ..models.attention_processor import AttnProcessor2_0
110
+ from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
111
+ from ..models.transformers.transformer_flux import FluxAttnProcessor
112
+ from ..models.transformers.transformer_hunyuanimage import HunyuanImageAttnProcessor
113
+ from ..models.transformers.transformer_qwenimage import QwenDoubleStreamAttnProcessor2_0
114
+ from ..models.transformers.transformer_wan import WanAttnProcessor2_0
115
+ from ..models.transformers.transformer_z_image import ZSingleStreamAttnProcessor
116
+
117
+ # AttnProcessor2_0
118
+ AttentionProcessorRegistry.register(
119
+ model_class=AttnProcessor2_0,
120
+ metadata=AttentionProcessorMetadata(
121
+ skip_processor_output_fn=_skip_proc_output_fn_Attention_AttnProcessor2_0,
122
+ ),
123
+ )
124
+
125
+ # CogView4AttnProcessor
126
+ AttentionProcessorRegistry.register(
127
+ model_class=CogView4AttnProcessor,
128
+ metadata=AttentionProcessorMetadata(
129
+ skip_processor_output_fn=_skip_proc_output_fn_Attention_CogView4AttnProcessor,
130
+ ),
131
+ )
132
+
133
+ # WanAttnProcessor2_0
134
+ AttentionProcessorRegistry.register(
135
+ model_class=WanAttnProcessor2_0,
136
+ metadata=AttentionProcessorMetadata(
137
+ skip_processor_output_fn=_skip_proc_output_fn_Attention_WanAttnProcessor2_0,
138
+ ),
139
+ )
140
+
141
+ # FluxAttnProcessor
142
+ AttentionProcessorRegistry.register(
143
+ model_class=FluxAttnProcessor,
144
+ metadata=AttentionProcessorMetadata(skip_processor_output_fn=_skip_proc_output_fn_Attention_FluxAttnProcessor),
145
+ )
146
+
147
+ # QwenDoubleStreamAttnProcessor2
148
+ AttentionProcessorRegistry.register(
149
+ model_class=QwenDoubleStreamAttnProcessor2_0,
150
+ metadata=AttentionProcessorMetadata(
151
+ skip_processor_output_fn=_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0
152
+ ),
153
+ )
154
+
155
+ # HunyuanImageAttnProcessor
156
+ AttentionProcessorRegistry.register(
157
+ model_class=HunyuanImageAttnProcessor,
158
+ metadata=AttentionProcessorMetadata(
159
+ skip_processor_output_fn=_skip_proc_output_fn_Attention_HunyuanImageAttnProcessor,
160
+ ),
161
+ )
162
+
163
+ # ZSingleStreamAttnProcessor
164
+ AttentionProcessorRegistry.register(
165
+ model_class=ZSingleStreamAttnProcessor,
166
+ metadata=AttentionProcessorMetadata(
167
+ skip_processor_output_fn=_skip_proc_output_fn_Attention_ZSingleStreamAttnProcessor,
168
+ ),
169
+ )
170
+
171
+
172
+ def _register_transformer_blocks_metadata():
173
+ from ..models.attention import BasicTransformerBlock, JointTransformerBlock
174
+ from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
175
+ from ..models.transformers.transformer_bria import BriaTransformerBlock
176
+ from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
177
+ from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
178
+ from ..models.transformers.transformer_hunyuan_video import (
179
+ HunyuanVideoSingleTransformerBlock,
180
+ HunyuanVideoTokenReplaceSingleTransformerBlock,
181
+ HunyuanVideoTokenReplaceTransformerBlock,
182
+ HunyuanVideoTransformerBlock,
183
+ )
184
+ from ..models.transformers.transformer_hunyuanimage import (
185
+ HunyuanImageSingleTransformerBlock,
186
+ HunyuanImageTransformerBlock,
187
+ )
188
+ from ..models.transformers.transformer_kandinsky import Kandinsky5TransformerDecoderBlock
189
+ from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
190
+ from ..models.transformers.transformer_mochi import MochiTransformerBlock
191
+ from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
192
+ from ..models.transformers.transformer_wan import WanTransformerBlock
193
+ from ..models.transformers.transformer_z_image import ZImageTransformerBlock
194
+
195
+ # BasicTransformerBlock
196
+ TransformerBlockRegistry.register(
197
+ model_class=BasicTransformerBlock,
198
+ metadata=TransformerBlockMetadata(
199
+ return_hidden_states_index=0,
200
+ return_encoder_hidden_states_index=None,
201
+ ),
202
+ )
203
+ TransformerBlockRegistry.register(
204
+ model_class=BriaTransformerBlock,
205
+ metadata=TransformerBlockMetadata(
206
+ return_hidden_states_index=0,
207
+ return_encoder_hidden_states_index=None,
208
+ ),
209
+ )
210
+
211
+ # CogVideoX
212
+ TransformerBlockRegistry.register(
213
+ model_class=CogVideoXBlock,
214
+ metadata=TransformerBlockMetadata(
215
+ return_hidden_states_index=0,
216
+ return_encoder_hidden_states_index=1,
217
+ ),
218
+ )
219
+
220
+ # CogView4
221
+ TransformerBlockRegistry.register(
222
+ model_class=CogView4TransformerBlock,
223
+ metadata=TransformerBlockMetadata(
224
+ return_hidden_states_index=0,
225
+ return_encoder_hidden_states_index=1,
226
+ ),
227
+ )
228
+
229
+ # Flux
230
+ TransformerBlockRegistry.register(
231
+ model_class=FluxTransformerBlock,
232
+ metadata=TransformerBlockMetadata(
233
+ return_hidden_states_index=1,
234
+ return_encoder_hidden_states_index=0,
235
+ ),
236
+ )
237
+ TransformerBlockRegistry.register(
238
+ model_class=FluxSingleTransformerBlock,
239
+ metadata=TransformerBlockMetadata(
240
+ return_hidden_states_index=1,
241
+ return_encoder_hidden_states_index=0,
242
+ ),
243
+ )
244
+
245
+ # HunyuanVideo
246
+ TransformerBlockRegistry.register(
247
+ model_class=HunyuanVideoTransformerBlock,
248
+ metadata=TransformerBlockMetadata(
249
+ return_hidden_states_index=0,
250
+ return_encoder_hidden_states_index=1,
251
+ ),
252
+ )
253
+ TransformerBlockRegistry.register(
254
+ model_class=HunyuanVideoSingleTransformerBlock,
255
+ metadata=TransformerBlockMetadata(
256
+ return_hidden_states_index=0,
257
+ return_encoder_hidden_states_index=1,
258
+ ),
259
+ )
260
+ TransformerBlockRegistry.register(
261
+ model_class=HunyuanVideoTokenReplaceTransformerBlock,
262
+ metadata=TransformerBlockMetadata(
263
+ return_hidden_states_index=0,
264
+ return_encoder_hidden_states_index=1,
265
+ ),
266
+ )
267
+ TransformerBlockRegistry.register(
268
+ model_class=HunyuanVideoTokenReplaceSingleTransformerBlock,
269
+ metadata=TransformerBlockMetadata(
270
+ return_hidden_states_index=0,
271
+ return_encoder_hidden_states_index=1,
272
+ ),
273
+ )
274
+
275
+ # LTXVideo
276
+ TransformerBlockRegistry.register(
277
+ model_class=LTXVideoTransformerBlock,
278
+ metadata=TransformerBlockMetadata(
279
+ return_hidden_states_index=0,
280
+ return_encoder_hidden_states_index=None,
281
+ ),
282
+ )
283
+
284
+ # Mochi
285
+ TransformerBlockRegistry.register(
286
+ model_class=MochiTransformerBlock,
287
+ metadata=TransformerBlockMetadata(
288
+ return_hidden_states_index=0,
289
+ return_encoder_hidden_states_index=1,
290
+ ),
291
+ )
292
+
293
+ # Wan
294
+ TransformerBlockRegistry.register(
295
+ model_class=WanTransformerBlock,
296
+ metadata=TransformerBlockMetadata(
297
+ return_hidden_states_index=0,
298
+ return_encoder_hidden_states_index=None,
299
+ ),
300
+ )
301
+
302
+ # QwenImage
303
+ TransformerBlockRegistry.register(
304
+ model_class=QwenImageTransformerBlock,
305
+ metadata=TransformerBlockMetadata(
306
+ return_hidden_states_index=1,
307
+ return_encoder_hidden_states_index=0,
308
+ ),
309
+ )
310
+
311
+ # HunyuanImage2.1
312
+ TransformerBlockRegistry.register(
313
+ model_class=HunyuanImageTransformerBlock,
314
+ metadata=TransformerBlockMetadata(
315
+ return_hidden_states_index=0,
316
+ return_encoder_hidden_states_index=1,
317
+ ),
318
+ )
319
+ TransformerBlockRegistry.register(
320
+ model_class=HunyuanImageSingleTransformerBlock,
321
+ metadata=TransformerBlockMetadata(
322
+ return_hidden_states_index=0,
323
+ return_encoder_hidden_states_index=1,
324
+ ),
325
+ )
326
+
327
+ # ZImage
328
+ TransformerBlockRegistry.register(
329
+ model_class=ZImageTransformerBlock,
330
+ metadata=TransformerBlockMetadata(
331
+ return_hidden_states_index=0,
332
+ return_encoder_hidden_states_index=None,
333
+ ),
334
+ )
335
+
336
+ TransformerBlockRegistry.register(
337
+ model_class=JointTransformerBlock,
338
+ metadata=TransformerBlockMetadata(
339
+ return_hidden_states_index=1,
340
+ return_encoder_hidden_states_index=0,
341
+ ),
342
+ )
343
+
344
+ # Kandinsky 5.0 (Kandinsky5TransformerDecoderBlock)
345
+ TransformerBlockRegistry.register(
346
+ model_class=Kandinsky5TransformerDecoderBlock,
347
+ metadata=TransformerBlockMetadata(
348
+ return_hidden_states_index=0,
349
+ return_encoder_hidden_states_index=None,
350
+ hidden_states_argument_name="visual_embed",
351
+ ),
352
+ )
353
+
354
+
355
+ # fmt: off
356
+ def _skip_attention___ret___hidden_states(self, *args, **kwargs):
357
+ hidden_states = kwargs.get("hidden_states", None)
358
+ if hidden_states is None and len(args) > 0:
359
+ hidden_states = args[0]
360
+ return hidden_states
361
+
362
+
363
+ def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
364
+ hidden_states = kwargs.get("hidden_states", None)
365
+ encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
366
+ if hidden_states is None and len(args) > 0:
367
+ hidden_states = args[0]
368
+ if encoder_hidden_states is None and len(args) > 1:
369
+ encoder_hidden_states = args[1]
370
+ return hidden_states, encoder_hidden_states
371
+
372
+
373
# Per-processor aliases for the two generic skip functions above; these are the
# names referenced by `_register_attention_processors_metadata`. Only the
# CogView4 processor uses the variant that also returns `encoder_hidden_states`.
_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
_skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hidden_states
# not sure what this is yet.
_skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0 = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_HunyuanImageAttnProcessor = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_ZSingleStreamAttnProcessor = _skip_attention___ret___hidden_states
381
+ # fmt: on
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/context_parallel.py ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import copy
15
+ import functools
16
+ import inspect
17
+ from dataclasses import dataclass
18
+ from typing import Type
19
+
20
+ import torch
21
+ import torch.distributed as dist
22
+
23
+
24
+ if torch.distributed.is_available():
25
+ import torch.distributed._functional_collectives as funcol
26
+
27
+ from ..models._modeling_parallel import (
28
+ ContextParallelConfig,
29
+ ContextParallelInput,
30
+ ContextParallelModelPlan,
31
+ ContextParallelOutput,
32
+ gather_size_by_comm,
33
+ )
34
+ from ..utils import get_logger
35
+ from ..utils.torch_utils import maybe_allow_in_graph, unwrap_module
36
+ from .hooks import HookRegistry, ModelHook
37
+
38
+
39
# Module-level logger.
logger = get_logger(__name__)  # pylint: disable=invalid-name

# Hook-name templates; `{}` is filled with the plan's module id so that the
# input (split) and output (gather) hooks of a module get distinct names.
_CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE = "cp_input---{}"
_CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE = "cp_output---{}"
43
+
44
+
45
+ # TODO(aryan): consolidate with ._helpers.TransformerBlockMetadata
46
@dataclass
class ModuleForwardMetadata:
    """Caches the positional layout of a module class's `forward` parameters."""

    # Lazily built map from forward() parameter name to positional index (excluding `self`).
    cached_parameter_indices: dict[str, int] = None
    # Module class whose `forward` signature is inspected.
    _cls: Type = None

    def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
        """Locate the value passed for forward() parameter `identifier`.

        Args:
            identifier: Name of the forward() parameter to fetch.
            args: Positional arguments of the forward call (without `self`).
            kwargs: Keyword arguments of the forward call.

        Returns:
            A tuple `(value, is_kwarg, index)`. `is_kwarg` is True and `index`
            is None when the value came from `kwargs`; otherwise `index` is the
            position of the value inside `args`.

        Raises:
            ValueError: If `_cls` is unset and no cache exists yet, if
                `identifier` is not a known parameter, or if `args` is too
                short to contain it.
        """
        kwargs = kwargs or {}

        if identifier in kwargs:
            return kwargs[identifier], True, None

        if self.cached_parameter_indices is None:
            if self._cls is None:
                raise ValueError("Model class is not set for metadata.")

            parameters = list(inspect.signature(self._cls.forward).parameters.keys())
            parameters = parameters[1:]  # skip `self`
            self.cached_parameter_indices = {param: i for i, param in enumerate(parameters)}
            not_found_message = f"Parameter '{identifier}' not found in function signature but was requested."
        else:
            not_found_message = f"Parameter '{identifier}' not found in cached indices."

        index = self.cached_parameter_indices.get(identifier, None)
        if index is None:
            raise ValueError(not_found_message)

        # Bounds-check on every call, including cached lookups, so a too-short
        # `args` raises ValueError rather than a bare IndexError.
        if index >= len(args):
            raise ValueError(f"Expected at least {index + 1} positional arguments but got {len(args)}.")

        return args[index], False, index
79
+
80
+
81
def apply_context_parallel(
    module: torch.nn.Module,
    parallel_config: ContextParallelConfig,
    plan: dict[str, ContextParallelModelPlan],
) -> None:
    """Register context-parallel hooks on the submodules named in `plan`.

    For each `module_id -> cp_model_plan` entry, the matching submodule(s) get a
    `ContextParallelSplitHook` when the plan is a dict, or a
    `ContextParallelGatherHook` when it is a `ContextParallelOutput` or a
    list/tuple of them.

    Raises:
        ValueError: If a plan entry has an unsupported type, or a list/tuple
            plan contains something other than `ContextParallelOutput`.
    """
    logger.debug(f"Applying context parallel with CP mesh: {parallel_config._mesh} and plan: {plan}")

    for module_id, cp_model_plan in plan.items():
        submodule = _get_submodule_by_name(module, module_id)
        submodule = submodule if isinstance(submodule, list) else [submodule]

        logger.debug(f"Applying ContextParallelHook to {module_id=} identifying a total of {len(submodule)} modules")

        for m in submodule:
            if isinstance(cp_model_plan, dict):
                hook = ContextParallelSplitHook(cp_model_plan, parallel_config)
                name_template = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE
            elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
                # Normalise a single output spec to a list; the rebinding is
                # kept so later iterations see the already-wrapped plan.
                if isinstance(cp_model_plan, ContextParallelOutput):
                    cp_model_plan = [cp_model_plan]
                if any(not isinstance(x, ContextParallelOutput) for x in cp_model_plan):
                    raise ValueError(f"Expected all elements of cp_model_plan to be CPOutput, but got {cp_model_plan}")
                hook = ContextParallelGatherHook(cp_model_plan, parallel_config)
                name_template = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE
            else:
                raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")

            hook_name = name_template.format(module_id)
            HookRegistry.check_if_exists_or_initialize(m).register_hook(hook, hook_name)
111
+
112
+
113
def remove_context_parallel(module: torch.nn.Module, plan: dict[str, ContextParallelModelPlan]) -> None:
    """Remove the context-parallel hooks named after each entry of `plan` from the matching submodules.

    Raises `ValueError` when a plan entry is neither a `dict` nor a `ContextParallelOutput`/list/tuple.
    """
    for module_id, cp_model_plan in plan.items():
        resolved = _get_submodule_by_name(module, module_id)
        targets = resolved if isinstance(resolved, list) else [resolved]

        for target in targets:
            # The registry is fetched (or created) before the plan type is checked.
            registry = HookRegistry.check_if_exists_or_initialize(target)

            if isinstance(cp_model_plan, dict):
                name_template = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE
            elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
                name_template = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE
            else:
                raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")

            registry.remove_hook(name_template.format(module_id))
128
+
129
+
130
class ContextParallelSplitHook(ModelHook):
    """Hook that shards selected inputs (before forward) or outputs (after forward) of a module.

    `metadata` maps an identifier to a `ContextParallelInput` (or a sequence of them). Entries with
    `split_output=True` are handled in `post_forward`; all others are handled in `pre_forward`.
    """

    def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None:
        super().__init__()
        self.metadata = metadata
        self.parallel_config = parallel_config
        # Filled in by `initialize_hook` once the wrapped module's class is known.
        self.module_forward_metadata = None

    def initialize_hook(self, module):
        """Record the unwrapped module class so forward parameters can be looked up by name."""
        self.module_forward_metadata = ModuleForwardMetadata(_cls=unwrap_module(module).__class__)
        return module

    def pre_forward(self, module, *args, **kwargs):
        """Shard the planned inputs and return the updated `(args, kwargs)` pair."""
        positional = list(args)

        for name, cpm in self.metadata.items():
            # Entries that ask for output splitting are processed in `post_forward`.
            if isinstance(cpm, ContextParallelInput) and cpm.split_output:
                continue

            # The parameter may have been passed either positionally or as a keyword argument.
            value, is_kwarg, index = self.module_forward_metadata._get_parameter_from_args_kwargs(
                name, positional, kwargs
            )
            if value is None:
                continue

            if isinstance(value, torch.Tensor):
                value = self._prepare_cp_input(value, cpm)
            elif isinstance(value, (list, tuple)):
                if len(value) != len(cpm):
                    raise ValueError(
                        f"Expected input model plan to have {len(value)} elements, but got {len(cpm)}."
                    )
                # Non-tensor elements, and elements whose plan requests output splitting, pass through as-is.
                value = [
                    self._prepare_cp_input(item, cpm[i]) if torch.is_tensor(item) and not cpm[i].split_output else item
                    for i, item in enumerate(value)
                ]
            else:
                raise ValueError(f"Unsupported input type: {type(value)}")

            if is_kwarg:
                kwargs[name] = value
                continue

            if index is None or index >= len(positional):
                raise ValueError(
                    f"An unexpected error occurred while processing the input '{name}'. Please open an "
                    f"issue at https://github.com/huggingface/diffusers/issues and provide a minimal reproducible "
                    f"example along with the full stack trace."
                )
            positional[index] = value

        return tuple(positional), kwargs

    def post_forward(self, module, output):
        """Shard the outputs whose plan entry has `split_output=True`; other outputs are returned unchanged."""
        is_tensor = isinstance(output, torch.Tensor)

        if is_tensor:
            results = [output]
        elif isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output):
            results = list(output)
        else:
            raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")

        for index, cpm in self.metadata.items():
            if not (isinstance(cpm, ContextParallelInput) and cpm.split_output):
                continue
            if index >= len(results):
                raise ValueError(f"Index {index} out of bounds for output of length {len(results)}.")
            results[index] = self._prepare_cp_input(results[index], cpm)

        # A single-tensor output is returned as a tensor; sequences are always returned as a tuple.
        return results[0] if is_tensor else tuple(results)

    def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
        """Shard `x` along `cp_input.split_dim`, or return it untouched when its rank is not the expected one."""
        if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
            logger.warning_once(
                f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions, split will not be applied."
            )
            return x

        if self.parallel_config.ulysses_anything:
            shard_fn = PartitionAnythingSharder.shard_anything
        else:
            shard_fn = EquipartitionSharder.shard
        return shard_fn(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
219
+
220
+
221
class ContextParallelGatherHook(ModelHook):
    """Hook that gathers sharded module outputs back together after forward.

    `metadata` is a sequence with one entry per output; a `None` entry leaves that output untouched.
    """

    def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None:
        super().__init__()
        self.metadata = metadata
        self.parallel_config = parallel_config

    def post_forward(self, module, output):
        """Unshard every output that has a plan entry, preserving the tensor-vs-tuple shape of `output`."""
        is_tensor = isinstance(output, torch.Tensor)

        if is_tensor:
            gathered = [output]
        elif isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output):
            gathered = list(output)
        else:
            raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")

        if len(gathered) != len(self.metadata):
            raise ValueError(f"Expected output to have {len(self.metadata)} elements, but got {len(gathered)}.")

        for position, cpm in enumerate(self.metadata):
            if cpm is None:
                continue
            unshard_fn = (
                PartitionAnythingSharder.unshard_anything
                if self.parallel_config.ulysses_anything
                else EquipartitionSharder.unshard
            )
            gathered[position] = unshard_fn(gathered[position], cpm.gather_dim, self.parallel_config._flattened_mesh)

        return gathered[0] if is_tensor else tuple(gathered)
253
+
254
+
255
+ class AllGatherFunction(torch.autograd.Function):
256
+ @staticmethod
257
+ def forward(ctx, tensor, dim, group):
258
+ ctx.dim = dim
259
+ ctx.group = group
260
+ ctx.world_size = torch.distributed.get_world_size(group)
261
+ ctx.rank = torch.distributed.get_rank(group)
262
+ return funcol.all_gather_tensor(tensor, dim, group=group)
263
+
264
+ @staticmethod
265
+ def backward(ctx, grad_output):
266
+ grad_chunks = torch.chunk(grad_output, ctx.world_size, dim=ctx.dim)
267
+ return grad_chunks[ctx.rank], None, None
268
+
269
+
270
class EquipartitionSharder:
    """Shard/unshard a tensor into equally sized pieces across the ranks of a device mesh."""

    @classmethod
    def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
        """Return this rank's chunk of `tensor` along `dim`."""
        num_shards = mesh.size()

        # NOTE: divisibility is not a fundamental requirement. It is enforced for now because the uneven case
        # has not yet been tested/required for any model.
        assert tensor.size()[dim] % num_shards == 0, (
            "Tensor size along dimension to be sharded must be divisible by mesh size"
        )

        # `mesh.get_rank()` is deliberately avoided: it is not fullgraph compatible with Dynamo
        # (fails in DeviceMesh.get_rank), so the rank is queried through the process group instead.
        local_rank = torch.distributed.get_rank(mesh.get_group())
        return tensor.chunk(num_shards, dim=dim)[local_rank]

    @classmethod
    def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
        """All-gather the per-rank chunks along `dim` over the mesh's process group."""
        return AllGatherFunction.apply(tensor.contiguous(), dim, mesh.get_group())
289
+
290
+
291
+ class AllGatherAnythingFunction(torch.autograd.Function):
292
+ @staticmethod
293
+ def forward(ctx, tensor: torch.Tensor, dim: int, group: dist.device_mesh.DeviceMesh):
294
+ ctx.dim = dim
295
+ ctx.group = group
296
+ ctx.world_size = dist.get_world_size(group)
297
+ ctx.rank = dist.get_rank(group)
298
+ gathered_tensor = _all_gather_anything(tensor, dim, group)
299
+ return gathered_tensor
300
+
301
+ @staticmethod
302
+ def backward(ctx, grad_output):
303
+ # NOTE: We use `tensor_split` instead of chunk, because the `chunk`
304
+ # function may return fewer than the specified number of chunks!
305
+ grad_splits = torch.tensor_split(grad_output, ctx.world_size, dim=ctx.dim)
306
+ return grad_splits[ctx.rank], None, None
307
+
308
+
309
class PartitionAnythingSharder:
    """Shard/unshard a tensor across a device mesh without requiring equally sized pieces."""

    @classmethod
    def shard_anything(
        cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh
    ) -> torch.Tensor:
        """Return this rank's piece of `tensor` along `dim`; every rank must receive at least one element."""
        num_shards = mesh.size()
        assert tensor.size()[dim] >= num_shards, (
            f"Cannot shard tensor of size {tensor.size()} along dim {dim} across mesh of size {mesh.size()}."
        )
        # NOTE: `tensor_split` is used instead of `chunk`, because `chunk` may return fewer than the
        # requested number of pieces.
        local_rank = dist.get_rank(mesh.get_group())
        return tensor.tensor_split(num_shards, dim=dim)[local_rank]

    @classmethod
    def unshard_anything(
        cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh
    ) -> torch.Tensor:
        """All-gather the per-rank pieces along `dim` over the mesh's process group."""
        return AllGatherAnythingFunction.apply(tensor.contiguous(), dim, mesh.get_group())
328
+
329
+
330
+ @functools.lru_cache(maxsize=64)
331
+ def _fill_gather_shapes(shape: tuple[int], gather_dims: tuple[int], dim: int, world_size: int) -> list[list[int]]:
332
+ gather_shapes = []
333
+ for i in range(world_size):
334
+ rank_shape = list(copy.deepcopy(shape))
335
+ rank_shape[dim] = gather_dims[i]
336
+ gather_shapes.append(rank_shape)
337
+ return gather_shapes
338
+
339
+
340
@maybe_allow_in_graph
def _all_gather_anything(tensor: torch.Tensor, dim: int, group: dist.device_mesh.DeviceMesh) -> torch.Tensor:
    """All-gather `tensor` over `group` and concatenate the per-rank pieces along `dim`.

    The size along `dim` for every rank is obtained through `gather_size_by_comm`, so the receive buffers are
    allocated with per-rank shapes.
    """
    world_size = dist.get_world_size(group=group)

    tensor = tensor.contiguous()
    local_shape = tensor.shape
    per_rank_sizes = gather_size_by_comm(local_shape[dim], group)

    per_rank_shapes = _fill_gather_shapes(tuple(local_shape), tuple(per_rank_sizes), dim, world_size)

    # One receive buffer per rank, matching the local tensor's device and dtype.
    buffers = []
    for rank_shape in per_rank_shapes:
        buffers.append(torch.empty(rank_shape, device=tensor.device, dtype=tensor.dtype))

    dist.all_gather(buffers, tensor, group=group)
    return torch.cat(buffers, dim=dim)
356
+
357
+
358
def _get_submodule_by_name(model: torch.nn.Module, name: str) -> torch.nn.Module | list[torch.nn.Module]:
    """Resolve a dotted submodule path that may contain at most one `*` wildcard.

    Raises `ValueError` when `name` contains more than one `*`; otherwise delegates to
    `_find_submodule_by_name`.
    """
    wildcard_count = name.count("*")
    if wildcard_count > 1:
        raise ValueError("Wildcard '*' can only be used once in the name")
    return _find_submodule_by_name(model, name)
362
+
363
+
364
+ def _find_submodule_by_name(model: torch.nn.Module, name: str) -> torch.nn.Module | list[torch.nn.Module]:
365
+ if name == "":
366
+ return model
367
+ first_atom, remaining_name = name.split(".", 1) if "." in name else (name, "")
368
+ if first_atom == "*":
369
+ if not isinstance(model, torch.nn.ModuleList):
370
+ raise ValueError("Wildcard '*' can only be used with ModuleList")
371
+ submodules = []
372
+ for submodule in model:
373
+ subsubmodules = _find_submodule_by_name(submodule, remaining_name)
374
+ if not isinstance(subsubmodules, list):
375
+ subsubmodules = [subsubmodules]
376
+ submodules.extend(subsubmodules)
377
+ return submodules
378
+ else:
379
+ if hasattr(model, first_atom):
380
+ submodule = getattr(model, first_atom)
381
+ return _find_submodule_by_name(submodule, remaining_name)
382
+ else:
383
+ raise ValueError(f"'{first_atom}' is not a submodule of '{model.__class__.__name__}'")
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/faster_cache.py ADDED
@@ -0,0 +1,654 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import re
16
+ from dataclasses import dataclass
17
+ from typing import Any, Callable
18
+
19
+ import torch
20
+
21
+ from ..models.attention import AttentionModuleMixin
22
+ from ..models.modeling_outputs import Transformer2DModelOutput
23
+ from ..utils import logging
24
+ from ._common import _ATTENTION_CLASSES
25
+ from .hooks import HookRegistry, ModelHook
26
+
27
+
28
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29
+
30
+
31
# Identifier strings for the denoiser-level and block-level FasterCache hooks.
# NOTE(review): presumably the names used when registering the hooks - confirm against their usage in this file.
_FASTER_CACHE_DENOISER_HOOK = "faster_cache_denoiser"
_FASTER_CACHE_BLOCK_HOOK = "faster_cache_block"
# Default patterns for `FasterCacheConfig.spatial_attention_block_identifiers`. Each pattern is anchored at the
# start of the string with `^`.
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
    "^blocks.*attn",
    "^transformer_blocks.*attn",
    "^single_transformer_blocks.*attn",
)
# Default patterns for `FasterCacheConfig.temporal_attention_block_identifiers`.
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("^temporal_transformer_blocks.*attn",)
# Union of the spatial and temporal attention patterns.
_TRANSFORMER_BLOCK_IDENTIFIERS = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
# Default for `FasterCacheConfig._unconditional_conditional_input_kwargs_identifiers`: keyword-argument names
# treated as holding batchwise-concatenated unconditional and conditional inputs.
_UNCOND_COND_INPUT_KWARGS_IDENTIFIERS = (
    "hidden_states",
    "encoder_hidden_states",
    "timestep",
    "attention_mask",
    "encoder_attention_mask",
)
47
+
48
+
49
@dataclass
class FasterCacheConfig:
    r"""
    Configuration for [FasterCache](https://huggingface.co/papers/2410.19355).

    Attributes:
        spatial_attention_block_skip_range (`int`, defaults to `2`):
            Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will
            be skipped `N - 1` times (i.e., cached attention states will be reused) before computing the new attention
            states again.
        temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`):
            Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will
            be skipped `N - 1` times (i.e., cached attention states will be reused) before computing the new attention
            states again.
        spatial_attention_timestep_skip_range (`tuple[int, int]`, defaults to `(-1, 681)`):
            The timestep range within which the spatial attention computation can be skipped without a significant loss
            in quality. This is to be determined by the user based on the underlying model. The first value in the
            tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for
            denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at
            timestep 0). For the default values, this would mean that the spatial attention computation skipping will
            be applicable only after denoising timestep 681 is reached, and continue until the end of the denoising
            process.
        temporal_attention_timestep_skip_range (`tuple[int, int]`, defaults to `(-1, 681)`):
            The timestep range within which the temporal attention computation can be skipped without a significant
            loss in quality. This is to be determined by the user based on the underlying model. The first value in the
            tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for
            denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at
            timestep 0).
        low_frequency_weight_update_timestep_range (`tuple[int, int]`, defaults to `(99, 901)`):
            The timestep range within which the low frequency weight scaling update is applied. The first value in the
            tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
            function for the update is called only within this range.
        high_frequency_weight_update_timestep_range (`tuple[int, int]`, defaults to `(-1, 301)`):
            The timestep range within which the high frequency weight scaling update is applied. The first value in the
            tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
            function for the update is called only within this range.
        alpha_low_frequency (`float`, defaults to `1.1`):
            The weight to scale the low frequency updates by. This is used to approximate the unconditional branch from
            the conditional branch outputs.
        alpha_high_frequency (`float`, defaults to `1.1`):
            The weight to scale the high frequency updates by. This is used to approximate the unconditional branch
            from the conditional branch outputs.
        unconditional_batch_skip_range (`int`, defaults to `5`):
            Process the unconditional branch every `N` iterations. If this is set to `N`, the unconditional branch
            computation will be skipped `N - 1` times (i.e., cached unconditional branch states will be reused) before
            computing the new unconditional branch states again.
        unconditional_batch_timestep_skip_range (`tuple[int, int]`, defaults to `(-1, 641)`):
            The timestep range within which the unconditional branch computation can be skipped without a significant
            loss in quality. This is to be determined by the user based on the underlying model. The first value in the
            tuple is the lower bound and the second value is the upper bound.
        spatial_attention_block_identifiers (`tuple[str, ...]`, defaults to `("^blocks.*attn", "^transformer_blocks.*attn", "^single_transformer_blocks.*attn")`):
            The identifiers to match the spatial attention blocks in the model. If the name of the block contains any
            of these identifiers, FasterCache will be applied to that block. This can either be the full layer names,
            partial layer names, or regex patterns. Matching will always be done using a regex match.
        temporal_attention_block_identifiers (`tuple[str, ...]`, defaults to `("^temporal_transformer_blocks.*attn",)`):
            The identifiers to match the temporal attention blocks in the model. If the name of the block contains any
            of these identifiers, FasterCache will be applied to that block. This can either be the full layer names,
            partial layer names, or regex patterns. Matching will always be done using a regex match.
        attention_weight_callback (`Callable[[torch.nn.Module], float]`, *optional*, defaults to `None`):
            The callback function to determine the weight to scale the attention outputs by. This function should take
            the attention module as input and return a float value. This is used to approximate the unconditional
            branch from the conditional branch outputs. If not provided, the default weight is 0.5 for all timesteps.
            Typically, as described in the paper, this weight should gradually increase from 0 to 1 as the inference
            progresses. Users are encouraged to experiment and provide custom weight schedules that take into account
            the number of inference steps and underlying model behaviour as denoising progresses.
        low_frequency_weight_callback (`Callable[[torch.nn.Module], float]`, *optional*, defaults to `None`):
            The callback function to determine the weight to scale the low frequency updates by. If not provided, the
            default weight is 1.1 for timesteps within the range specified (as described in the paper).
        high_frequency_weight_callback (`Callable[[torch.nn.Module], float]`, *optional*, defaults to `None`):
            The callback function to determine the weight to scale the high frequency updates by. If not provided, the
            default weight is 1.1 for timesteps within the range specified (as described in the paper).
        tensor_format (`str`, defaults to `"BCFHW"`):
            The format of the input tensors. This should be one of `"BCFHW"`, `"BFCHW"`, or `"BCHW"`. The format is
            used to split individual latent frames in order for low and high frequency components to be computed.
        is_guidance_distilled (`bool`, defaults to `False`):
            Whether the model is guidance distilled or not. If the model is guidance distilled, FasterCache will not be
            applied at the denoiser-level to skip the unconditional branch computation (as there is none).
        current_timestep_callback (`Callable[[], int]`, *optional*, defaults to `None`):
            A zero-argument callable that returns the current denoising timestep. The returned value is what gets
            compared against the timestep ranges configured above.
        _unconditional_conditional_input_kwargs_identifiers (`tuple[str, ...]`, defaults to `("hidden_states", "encoder_hidden_states", "timestep", "attention_mask", "encoder_attention_mask")`):
            The identifiers to match the input kwargs that contain the batchwise-concatenated unconditional and
            conditional inputs. If the name of the input kwargs contains any of these identifiers, FasterCache will
            split the inputs into unconditional and conditional branches. This must be a sequence of exact input
            kwargs names that contain the batchwise-concatenated unconditional and conditional inputs.
    """

    # In the paper and codebase, they hardcode these values to 2. However, it can be made configurable
    # after some testing. We default to 2 if these parameters are not provided.
    spatial_attention_block_skip_range: int = 2
    temporal_attention_block_skip_range: int | None = None

    spatial_attention_timestep_skip_range: tuple[int, int] = (-1, 681)
    temporal_attention_timestep_skip_range: tuple[int, int] = (-1, 681)

    # Indicator functions for low/high frequency as mentioned in Equation 11 of the paper
    low_frequency_weight_update_timestep_range: tuple[int, int] = (99, 901)
    high_frequency_weight_update_timestep_range: tuple[int, int] = (-1, 301)

    # ⍺1 and ⍺2 as mentioned in Equation 11 of the paper
    alpha_low_frequency: float = 1.1
    alpha_high_frequency: float = 1.1

    # n as described in CFG-Cache explanation in the paper - dependent on the model
    unconditional_batch_skip_range: int = 5
    unconditional_batch_timestep_skip_range: tuple[int, int] = (-1, 641)

    spatial_attention_block_identifiers: tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
    temporal_attention_block_identifiers: tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS

    # The callbacks are optional: `None` means "not provided".
    attention_weight_callback: Callable[[torch.nn.Module], float] | None = None
    low_frequency_weight_callback: Callable[[torch.nn.Module], float] | None = None
    high_frequency_weight_callback: Callable[[torch.nn.Module], float] | None = None

    tensor_format: str = "BCFHW"
    is_guidance_distilled: bool = False

    current_timestep_callback: Callable[[], int] | None = None

    # The default is an immutable tuple, so it is safe to share between instances.
    _unconditional_conditional_input_kwargs_identifiers: tuple[str, ...] = _UNCOND_COND_INPUT_KWARGS_IDENTIFIERS

    def __repr__(self) -> str:
        # NOTE: the callbacks, `is_guidance_distilled` and the private kwargs identifiers are not part of this
        # representation.
        return (
            f"FasterCacheConfig(\n"
            f" spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n"
            f" temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n"
            f" spatial_attention_timestep_skip_range={self.spatial_attention_timestep_skip_range},\n"
            f" temporal_attention_timestep_skip_range={self.temporal_attention_timestep_skip_range},\n"
            f" low_frequency_weight_update_timestep_range={self.low_frequency_weight_update_timestep_range},\n"
            f" high_frequency_weight_update_timestep_range={self.high_frequency_weight_update_timestep_range},\n"
            f" alpha_low_frequency={self.alpha_low_frequency},\n"
            f" alpha_high_frequency={self.alpha_high_frequency},\n"
            f" unconditional_batch_skip_range={self.unconditional_batch_skip_range},\n"
            f" unconditional_batch_timestep_skip_range={self.unconditional_batch_timestep_skip_range},\n"
            f" spatial_attention_block_identifiers={self.spatial_attention_block_identifiers},\n"
            f" temporal_attention_block_identifiers={self.temporal_attention_block_identifiers},\n"
            f" tensor_format={self.tensor_format},\n"
            f")"
        )
185
+
186
+
187
class FasterCacheDenoiserState:
    r"""
    Mutable per-denoiser state for [FasterCache](https://huggingface.co/papers/2410.19355): an iteration counter
    plus the stored low and high frequency deltas.
    """

    def __init__(self) -> None:
        # A fresh state is exactly a reset state.
        self.reset()

    def reset(self):
        self.iteration: int = 0
        self.low_frequency_delta: torch.Tensor | None = None
        self.high_frequency_delta: torch.Tensor | None = None
201
+
202
+
203
class FasterCacheBlockState:
    r"""
    Mutable per-block state for [FasterCache](https://huggingface.co/papers/2410.19355). One instance exists for
    every underlying block that FasterCache is applied to.
    """

    def __init__(self) -> None:
        # A fresh state is exactly a reset state.
        self.reset()

    def reset(self):
        self.iteration: int = 0
        self.batch_size: int | None = None
        self.cache: tuple[torch.Tensor, torch.Tensor] | None = None
218
+
219
+
220
+ class FasterCacheDenoiserHook(ModelHook):
221
+ _is_stateful = True
222
+
223
+ def __init__(
224
+ self,
225
+ unconditional_batch_skip_range: int,
226
+ unconditional_batch_timestep_skip_range: tuple[int, int],
227
+ tensor_format: str,
228
+ is_guidance_distilled: bool,
229
+ uncond_cond_input_kwargs_identifiers: list[str],
230
+ current_timestep_callback: Callable[[], int],
231
+ low_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor],
232
+ high_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor],
233
+ ) -> None:
234
+ super().__init__()
235
+
236
+ self.unconditional_batch_skip_range = unconditional_batch_skip_range
237
+ self.unconditional_batch_timestep_skip_range = unconditional_batch_timestep_skip_range
238
+ # We can't easily detect what args are to be split in unconditional and conditional branches. We
239
+ # can only do it for kwargs, hence they are the only ones we split. The args are passed as-is.
240
+ # If a model is to be made compatible with FasterCache, the user must ensure that the inputs that
241
+ # contain batchwise-concatenated unconditional and conditional inputs are passed as kwargs.
242
+ self.uncond_cond_input_kwargs_identifiers = uncond_cond_input_kwargs_identifiers
243
+ self.tensor_format = tensor_format
244
+ self.is_guidance_distilled = is_guidance_distilled
245
+
246
+ self.current_timestep_callback = current_timestep_callback
247
+ self.low_frequency_weight_callback = low_frequency_weight_callback
248
+ self.high_frequency_weight_callback = high_frequency_weight_callback
249
+
250
+ def initialize_hook(self, module):
251
+ self.state = FasterCacheDenoiserState()
252
+ return module
253
+
254
+ @staticmethod
255
+ def _get_cond_input(input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
256
+ # Note: this method assumes that the input tensor is batchwise-concatenated with unconditional inputs
257
+ # followed by conditional inputs.
258
+ _, cond = input.chunk(2, dim=0)
259
+ return cond
260
+
261
+ def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
262
+ # Split the unconditional and conditional inputs. We only want to infer the conditional branch if the
263
+ # requirements for skipping the unconditional branch are met as described in the paper.
264
+ # We skip the unconditional branch only if the following conditions are met:
265
+ # 1. We have completed at least one iteration of the denoiser
266
+ # 2. The current timestep is within the range specified by the user. This is the optimal timestep range
267
+ # where approximating the unconditional branch from the computation of the conditional branch is possible
268
+ # without a significant loss in quality.
269
+ # 3. The current iteration is not a multiple of the unconditional batch skip range. This is done so that
270
+ # we compute the unconditional branch at least once every few iterations to ensure minimal quality loss.
271
+ is_within_timestep_range = (
272
+ self.unconditional_batch_timestep_skip_range[0]
273
+ < self.current_timestep_callback()
274
+ < self.unconditional_batch_timestep_skip_range[1]
275
+ )
276
+ should_skip_uncond = (
277
+ self.state.iteration > 0
278
+ and is_within_timestep_range
279
+ and self.state.iteration % self.unconditional_batch_skip_range != 0
280
+ and not self.is_guidance_distilled
281
+ )
282
+
283
+ if should_skip_uncond:
284
+ is_any_kwarg_uncond = any(k in self.uncond_cond_input_kwargs_identifiers for k in kwargs.keys())
285
+ if is_any_kwarg_uncond:
286
+ logger.debug("FasterCache - Skipping unconditional branch computation")
287
+ args = tuple([self._get_cond_input(arg) if torch.is_tensor(arg) else arg for arg in args])
288
+ kwargs = {
289
+ k: v if k not in self.uncond_cond_input_kwargs_identifiers else self._get_cond_input(v)
290
+ for k, v in kwargs.items()
291
+ }
292
+
293
+ output = self.fn_ref.original_forward(*args, **kwargs)
294
+
295
+ if self.is_guidance_distilled:
296
+ self.state.iteration += 1
297
+ return output
298
+
299
+ if torch.is_tensor(output):
300
+ hidden_states = output
301
+ elif isinstance(output, (tuple, Transformer2DModelOutput)):
302
+ hidden_states = output[0]
303
+
304
+ batch_size = hidden_states.size(0)
305
+
306
+ if should_skip_uncond:
307
+ self.state.low_frequency_delta = self.state.low_frequency_delta * self.low_frequency_weight_callback(
308
+ module
309
+ )
310
+ self.state.high_frequency_delta = self.state.high_frequency_delta * self.high_frequency_weight_callback(
311
+ module
312
+ )
313
+
314
+ if self.tensor_format == "BCFHW":
315
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
316
+ if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
317
+ hidden_states = hidden_states.flatten(0, 1)
318
+
319
+ low_freq_cond, high_freq_cond = _split_low_high_freq(hidden_states.float())
320
+
321
+ # Approximate/compute the unconditional branch outputs as described in Equation 9 and 10 of the paper
322
+ low_freq_uncond = self.state.low_frequency_delta + low_freq_cond
323
+ high_freq_uncond = self.state.high_frequency_delta + high_freq_cond
324
+ uncond_freq = low_freq_uncond + high_freq_uncond
325
+
326
+ uncond_states = torch.fft.ifftshift(uncond_freq)
327
+ uncond_states = torch.fft.ifft2(uncond_states).real
328
+
329
+ if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
330
+ uncond_states = uncond_states.unflatten(0, (batch_size, -1))
331
+ hidden_states = hidden_states.unflatten(0, (batch_size, -1))
332
+ if self.tensor_format == "BCFHW":
333
+ uncond_states = uncond_states.permute(0, 2, 1, 3, 4)
334
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
335
+
336
+ # Concatenate the approximated unconditional and predicted conditional branches
337
+ uncond_states = uncond_states.to(hidden_states.dtype)
338
+ hidden_states = torch.cat([uncond_states, hidden_states], dim=0)
339
+ else:
340
+ uncond_states, cond_states = hidden_states.chunk(2, dim=0)
341
+ if self.tensor_format == "BCFHW":
342
+ uncond_states = uncond_states.permute(0, 2, 1, 3, 4)
343
+ cond_states = cond_states.permute(0, 2, 1, 3, 4)
344
+ if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
345
+ uncond_states = uncond_states.flatten(0, 1)
346
+ cond_states = cond_states.flatten(0, 1)
347
+
348
+ low_freq_uncond, high_freq_uncond = _split_low_high_freq(uncond_states.float())
349
+ low_freq_cond, high_freq_cond = _split_low_high_freq(cond_states.float())
350
+ self.state.low_frequency_delta = low_freq_uncond - low_freq_cond
351
+ self.state.high_frequency_delta = high_freq_uncond - high_freq_cond
352
+
353
+ self.state.iteration += 1
354
+ if torch.is_tensor(output):
355
+ output = hidden_states
356
+ elif isinstance(output, tuple):
357
+ output = (hidden_states, *output[1:])
358
+ else:
359
+ output.sample = hidden_states
360
+
361
+ return output
362
+
363
def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
    """Reset the hook's state object and hand the module back unchanged.

    Args:
        module: The module this hook is attached to. It is returned as-is.

    Returns:
        The same ``module`` object that was passed in.
    """
    hook_state = self.state
    hook_state.reset()
    return module
366
+
367
+
368
class FasterCacheBlockHook(ModelHook):
    r"""
    Hook for an attention block that, on selected iterations, replaces the block's forward pass with an
    extrapolation computed from the outputs cached during the two most recent calls.

    Args:
        block_skip_range (`int`):
            The real forward pass is forced whenever `iteration > 0` and `iteration % block_skip_range == 0`.
        timestep_skip_range (`tuple[int, int]`):
            Exclusive `(lower, upper)` bounds; skipping is only considered while the value returned by
            `current_timestep_callback` lies strictly between them.
        is_guidance_distilled (`bool`):
            When `True`, outputs are cached as-is (no unconditional/conditional split is attempted).
        weight_callback (`Callable[[torch.nn.Module], float]`):
            Returns the extrapolation weight used when approximating the output.
        current_timestep_callback (`Callable[[], int]`):
            Returns the current timestep.
    """

    _is_stateful = True

    def __init__(
        self,
        block_skip_range: int,
        timestep_skip_range: tuple[int, int],
        is_guidance_distilled: bool,
        weight_callback: Callable[[torch.nn.Module], float],
        current_timestep_callback: Callable[[], int],
    ) -> None:
        super().__init__()

        self.block_skip_range = block_skip_range
        self.timestep_skip_range = timestep_skip_range
        self.is_guidance_distilled = is_guidance_distilled

        self.weight_callback = weight_callback
        self.current_timestep_callback = current_timestep_callback

    def initialize_hook(self, module):
        """Create a fresh `FasterCacheBlockState` for this hook and return `module` unchanged."""
        self.state = FasterCacheBlockState()
        return module

    def _compute_approximated_attention_output(
        self, t_2_output: torch.Tensor, t_output: torch.Tensor, weight: float, batch_size: int
    ) -> torch.Tensor:
        """Linearly extrapolate from the two cached outputs: `t_output + (t_output - t_2_output) * weight`.

        If a cached tensor has `2 * batch_size` rows, only its second half (rows `batch_size:`) is used.
        """
        if t_2_output.size(0) != batch_size:
            # The cache t_2_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just
            # take the conditional branch outputs.
            assert t_2_output.size(0) == 2 * batch_size
            t_2_output = t_2_output[batch_size:]
        if t_output.size(0) != batch_size:
            # The cache t_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just
            # take the conditional branch outputs.
            assert t_output.size(0) == 2 * batch_size
            t_output = t_output[batch_size:]
        return t_output + (t_output - t_2_output) * weight

    def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
        """Run the wrapped block, or return an approximation built from the cache when skipping applies.

        The batch size is read from the first tensor found in `args`/`kwargs`. After every call the output
        (real or approximated) is pushed into a two-entry cache and the iteration counter is incremented.
        """
        batch_size = [
            *[arg.size(0) for arg in args if torch.is_tensor(arg)],
            *[v.size(0) for v in kwargs.values() if torch.is_tensor(v)],
        ][0]
        if self.state.batch_size is None:
            # Will be updated on first forward pass through the denoiser
            self.state.batch_size = batch_size

        # If we have to skip due to the skip conditions, then let's skip as expected.
        # But, we can't skip if the denoiser wants to infer both unconditional and conditional branches. This
        # is because the expected output shapes of attention layer will not match if we only return values from
        # the cache (which only caches conditional branch outputs). So, if state.batch_size (which is the true
        # unconditional-conditional batch size) is same as the current batch size, we don't perform the layer
        # skip. Otherwise, we conditionally skip the layer based on what state.skip_callback returns.
        is_within_timestep_range = (
            self.timestep_skip_range[0] < self.current_timestep_callback() < self.timestep_skip_range[1]
        )
        if not is_within_timestep_range:
            should_skip_attention = False
        else:
            should_compute_attention = self.state.iteration > 0 and self.state.iteration % self.block_skip_range == 0
            should_skip_attention = not should_compute_attention
        if should_skip_attention:
            should_skip_attention = self.is_guidance_distilled or self.state.batch_size != batch_size
        if self.state.cache is None:
            # Nothing has been cached yet (e.g. very first call while already inside the timestep range), so
            # there is nothing to extrapolate from; fall back to the real forward pass instead of crashing.
            should_skip_attention = False

        if should_skip_attention:
            logger.debug("FasterCache - Skipping attention and using approximation")
            weight = self.weight_callback(module)
            if torch.is_tensor(self.state.cache[-1]):
                t_2_output, t_output = self.state.cache
                output = self._compute_approximated_attention_output(t_2_output, t_output, weight, batch_size)
            else:
                # The cache contains multiple tensors from past N iterations (N=2 for FasterCache). We need to handle all of them.
                # Diffusers blocks can return multiple tensors - let's call them [A, B, C, ...] for simplicity.
                # In our cache, we would have [[A_1, B_1, C_1, ...], [A_2, B_2, C_2, ...], ...] where each list is the output from
                # a forward pass of the block. We need to compute the approximated output for each of these tensors.
                # The zip(*state.cache) operation will give us [(A_1, A_2, ...), (B_1, B_2, ...), (C_1, C_2, ...), ...] which
                # allows us to compute the approximated attention output for each tensor in the cache.
                output = tuple(
                    self._compute_approximated_attention_output(t_2_output, t_output, weight, batch_size)
                    for t_2_output, t_output in zip(*self.state.cache)
                )
        else:
            logger.debug("FasterCache - Computing attention")
            output = self.fn_ref.original_forward(*args, **kwargs)

        # Note that the following condition for getting hidden_states should suffice since Diffusers blocks either return
        # a single hidden_states tensor, or a tuple of (hidden_states, encoder_hidden_states) tensors. We need to handle
        # both cases.
        if torch.is_tensor(output):
            cache_output = output
            if not self.is_guidance_distilled and cache_output.size(0) == self.state.batch_size:
                # The output here can be both unconditional-conditional branch outputs or just conditional branch outputs.
                # This is determined at the higher-level denoiser module. We only want to cache the conditional branch outputs.
                cache_output = cache_output.chunk(2, dim=0)[1]
        else:
            # Cache all return values and perform the same operation as above
            cache_output = ()
            for out in output:
                if not self.is_guidance_distilled and out.size(0) == self.state.batch_size:
                    out = out.chunk(2, dim=0)[1]
                cache_output += (out,)

        if self.state.cache is None:
            # Seed both slots with the first output so a later extrapolation has two entries to read.
            self.state.cache = [cache_output, cache_output]
        else:
            self.state.cache = [self.state.cache[-1], cache_output]

        self.state.iteration += 1
        return output

    def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
        """Reset the hook state and return `module` unchanged."""
        self.state.reset()
        return module
484
+
485
+
486
def apply_faster_cache(module: torch.nn.Module, config: FasterCacheConfig) -> None:
    r"""
    Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline.

    Any callback left as `None` on `config` is filled in with a default and written back onto `config`. The
    denoiser hook is then registered on `module`, and every attention submodule whose qualified name matches one of
    the known transformer block identifiers is passed on for block-level hook registration.

    Args:
        module (`torch.nn.Module`):
            The pytorch module to apply FasterCache to. Typically, this should be a transformer architecture supported
            in Diffusers, such as `CogVideoXTransformer3DModel`, but external implementations may also work.
        config (`FasterCacheConfig`):
            The configuration to use for FasterCache.

    Raises:
        ValueError: If `config.tensor_format` is not one of `"BCFHW"`, `"BFCHW"` or `"BCHW"`.

    Example:
    ```python
    >>> import torch
    >>> from diffusers import CogVideoXPipeline, FasterCacheConfig, apply_faster_cache

    >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
    >>> pipe.to("cuda")

    >>> config = FasterCacheConfig(
    ...     spatial_attention_block_skip_range=2,
    ...     spatial_attention_timestep_skip_range=(-1, 681),
    ...     low_frequency_weight_update_timestep_range=(99, 641),
    ...     high_frequency_weight_update_timestep_range=(-1, 301),
    ...     spatial_attention_block_identifiers=["transformer_blocks"],
    ...     attention_weight_callback=lambda _: 0.3,
    ...     tensor_format="BFCHW",
    ... )
    >>> apply_faster_cache(pipe.transformer, config)
    ```
    """

    logger.warning(
        "FasterCache is a purely experimental feature and may not work as expected. Not all models support FasterCache. "
        "The API is subject to change in future releases, with no guarantee of backward compatibility. Please report any issues at "
        "https://github.com/huggingface/diffusers/issues."
    )

    if config.attention_weight_callback is None:
        # Without a user-supplied callback a constant weight of 0.5 is used for every timestep. The paper
        # recommends a weight that grows from 0 to 1 as inference progresses, but the best schedule differs from
        # model to model, so users who want that must pass their own callback.
        logger.warning(
            "No `attention_weight_callback` provided when enabling FasterCache. Defaulting to using a weight of 0.5 for all timesteps."
        )
        config.attention_weight_callback = lambda _: 0.5

    if config.low_frequency_weight_callback is None:
        logger.debug(
            "Low frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper."
        )

        def low_frequency_weight_callback(module: torch.nn.Module) -> float:
            # `alpha_low_frequency` strictly inside the configured timestep window, 1.0 otherwise.
            timestep_range = config.low_frequency_weight_update_timestep_range
            if timestep_range[0] < config.current_timestep_callback() < timestep_range[1]:
                return config.alpha_low_frequency
            return 1.0

        config.low_frequency_weight_callback = low_frequency_weight_callback

    if config.high_frequency_weight_callback is None:
        logger.debug(
            "High frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper."
        )

        def high_frequency_weight_callback(module: torch.nn.Module) -> float:
            # `alpha_high_frequency` strictly inside the configured timestep window, 1.0 otherwise.
            timestep_range = config.high_frequency_weight_update_timestep_range
            if timestep_range[0] < config.current_timestep_callback() < timestep_range[1]:
                return config.alpha_high_frequency
            return 1.0

        config.high_frequency_weight_callback = high_frequency_weight_callback

    supported_tensor_formats = ["BCFHW", "BFCHW", "BCHW"]  # TODO(aryan): Support BSC for LTX Video
    if config.tensor_format not in supported_tensor_formats:
        raise ValueError(f"`tensor_format` must be one of {supported_tensor_formats}, but got {config.tensor_format}.")

    _apply_faster_cache_on_denoiser(module, config)

    for name, submodule in module.named_modules():
        if isinstance(submodule, _ATTENTION_CLASSES) and any(
            re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS
        ):
            _apply_faster_cache_on_attention_class(name, submodule, config)
575
+
576
+
577
def _apply_faster_cache_on_denoiser(module: torch.nn.Module, config: FasterCacheConfig) -> None:
    """Build a `FasterCacheDenoiserHook` from `config` and register it on `module`."""
    denoiser_hook = FasterCacheDenoiserHook(
        config.unconditional_batch_skip_range,
        config.unconditional_batch_timestep_skip_range,
        config.tensor_format,
        config.is_guidance_distilled,
        config._unconditional_conditional_input_kwargs_identifiers,
        config.current_timestep_callback,
        config.low_frequency_weight_callback,
        config.high_frequency_weight_callback,
    )
    # Reuse the module's hook registry if it already has one, otherwise create it.
    HookRegistry.check_if_exists_or_initialize(module).register_hook(denoiser_hook, _FASTER_CACHE_DENOISER_HOOK)
590
+
591
+
592
def _apply_faster_cache_on_attention_class(name: str, module: AttentionModuleMixin, config: FasterCacheConfig) -> None:
    """Register a `FasterCacheBlockHook` on `module` if it qualifies as a spatial or temporal self-attention layer.

    A layer qualifies when its `name` matches one of the configured block identifiers (regex search), the
    corresponding block skip range is configured, and the module is not flagged as cross-attention. Spatial
    settings take precedence when both match. Layers that do not qualify are left untouched and a debug message
    is logged.

    Args:
        name: Qualified name of the submodule; matched against the configured identifiers.
        module: The attention module to hook.
        config: The FasterCache configuration providing identifiers, skip ranges and callbacks.
    """
    # Modules that do not define `is_cross_attention` are treated as self-attention. The same lookup is used for
    # both the spatial and the temporal check so that neither can raise AttributeError.
    is_self_attention = not getattr(module, "is_cross_attention", False)

    is_spatial_self_attention = (
        any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
        and config.spatial_attention_block_skip_range is not None
        and is_self_attention
    )
    is_temporal_self_attention = (
        any(re.search(identifier, name) is not None for identifier in config.temporal_attention_block_identifiers)
        and config.temporal_attention_block_skip_range is not None
        and is_self_attention
    )

    block_skip_range, timestep_skip_range, block_type = None, None, None
    if is_spatial_self_attention:
        block_skip_range = config.spatial_attention_block_skip_range
        timestep_skip_range = config.spatial_attention_timestep_skip_range
        block_type = "spatial"
    elif is_temporal_self_attention:
        block_skip_range = config.temporal_attention_block_skip_range
        timestep_skip_range = config.temporal_attention_timestep_skip_range
        block_type = "temporal"

    if block_skip_range is None or timestep_skip_range is None:
        logger.debug(
            f'Unable to apply FasterCache to the selected layer: "{name}" because it does '
            f"not match any of the required criteria for spatial or temporal attention layers. Note, "
            f"however, that this layer may still be valid for applying PAB. Please specify the correct "
            f"block identifiers in the configuration or use the specialized `apply_faster_cache_on_module` "
            f"function to apply FasterCache to this layer."
        )
        return

    logger.debug(f"Enabling FasterCache ({block_type}) for layer: {name}")
    hook = FasterCacheBlockHook(
        block_skip_range,
        timestep_skip_range,
        config.is_guidance_distilled,
        config.attention_weight_callback,
        config.current_timestep_callback,
    )
    registry = HookRegistry.check_if_exists_or_initialize(module)
    registry.register_hook(hook, _FASTER_CACHE_BLOCK_HOOK)
634
+
635
+
636
+ # Reference: https://github.com/Vchitect/FasterCache/blob/fab32c15014636dc854948319c0a9a8d92c7acb4/scripts/latte/faster_cache_sample_latte.py#L127C1-L143C39
637
+ @torch.no_grad()
638
+ def _split_low_high_freq(x):
639
+ fft = torch.fft.fft2(x)
640
+ fft_shifted = torch.fft.fftshift(fft)
641
+ height, width = x.shape[-2:]
642
+ radius = min(height, width) // 5
643
+
644
+ y_grid, x_grid = torch.meshgrid(torch.arange(height), torch.arange(width))
645
+ center_x, center_y = width // 2, height // 2
646
+ mask = (x_grid - center_x) ** 2 + (y_grid - center_y) ** 2 <= radius**2
647
+
648
+ low_freq_mask = mask.unsqueeze(0).unsqueeze(0).to(x.device)
649
+ high_freq_mask = ~low_freq_mask
650
+
651
+ low_freq_fft = fft_shifted * low_freq_mask
652
+ high_freq_fft = fft_shifted * high_freq_mask
653
+
654
+ return low_freq_fft, high_freq_fft
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/first_block_cache.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+
17
+ import torch
18
+
19
+ from ..utils import get_logger
20
+ from ..utils.torch_utils import unwrap_module
21
+ from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
22
+ from ._helpers import TransformerBlockRegistry
23
+ from .hooks import BaseState, HookRegistry, ModelHook, StateManager
24
+
25
+
26
+ logger = get_logger(__name__) # pylint: disable=invalid-name
27
+
28
+ _FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook"
29
+ _FBC_BLOCK_HOOK = "fbc_block_hook"
30
+
31
+
32
@dataclass
class FirstBlockCacheConfig:
    r"""
    Settings for [First Block
    Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching).

    Args:
        threshold (`float`, defaults to `0.05`):
            Decides whether a forward pass through all layers of the model is needed. The value is compared with
            the absmean difference between the current and the cached residual of the first transformer block;
            when the difference is below the threshold, the forward pass is skipped. Raising the threshold
            generally skips more layers and speeds up inference at the risk of lower generation quality, while a
            low threshold may yield little speedup.
    """

    threshold: float = 0.05
49
+
50
+
51
class FBCSharedBlockState(BaseState):
    """State shared between the head block hook and the remaining block hooks of First Block Cache."""

    def __init__(self) -> None:
        super().__init__()

        # Output of the head block recorded on the last full pass.
        self.head_block_output: torch.Tensor | tuple[torch.Tensor, ...] | None = None
        # Head block output minus its input hidden states, recorded on the last full pass.
        self.head_block_residual: torch.Tensor | None = None
        # Tail block output minus head block output, recorded on the last full pass.
        self.tail_block_residuals: torch.Tensor | tuple[torch.Tensor, ...] | None = None
        # Whether the blocks after the head block must run their real forward pass.
        self.should_compute: bool = True

    def reset(self):
        """Return every field to its initial value.

        The head-block fields are cleared together with the tail residuals: keeping a stale head residual while
        the tail residuals are `None` would let the next run decide to reuse a cache that no longer exists.
        """
        self.head_block_output = None
        self.head_block_residual = None
        self.tail_block_residuals = None
        self.should_compute = True
63
+
64
+
65
class FBCHeadBlockHook(ModelHook):
    """Hook for the first transformer block: decides whether the remaining blocks must be recomputed.

    Args:
        state_manager: Provides the `FBCSharedBlockState` shared with the other block hooks.
        threshold: Relative-change threshold; the remaining blocks are recomputed when the change of the head
            block residual exceeds it.
    """

    _is_stateful = True

    def __init__(self, state_manager: StateManager, threshold: float):
        # Initialise the base hook as well, as the sibling `FBCBlockHook` does.
        super().__init__()
        self.state_manager = state_manager
        self.threshold = threshold
        self._metadata = None

    def initialize_hook(self, module):
        """Look up the block metadata registered for the (unwrapped) module class and return `module`."""
        unwrapped_module = unwrap_module(module)
        self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        """Run the head block, then either record its output or substitute the cached final result.

        The head block is always executed. If the remaining blocks must be computed, the head output and its
        residual are stored in the shared state and the real output is returned. Otherwise the cached tail
        residuals are added to the head output, which is returned in place of the real output.
        """
        original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)

        output = self.fn_ref.original_forward(*args, **kwargs)
        is_output_tuple = isinstance(output, tuple)

        if is_output_tuple:
            hidden_states_residual = output[self._metadata.return_hidden_states_index] - original_hidden_states
        else:
            hidden_states_residual = output - original_hidden_states

        shared_state: FBCSharedBlockState = self.state_manager.get_state()
        hidden_states = encoder_hidden_states = None
        # Without cached tail residuals there is nothing to reuse, so a full pass is mandatory regardless of how
        # small the residual change is.
        should_compute = (
            self._should_compute_remaining_blocks(hidden_states_residual)
            or shared_state.tail_block_residuals is None
        )
        shared_state.should_compute = should_compute

        if not should_compute:
            # Apply caching
            if is_output_tuple:
                hidden_states = (
                    shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index]
                )
            else:
                hidden_states = shared_state.tail_block_residuals[0] + output

            if self._metadata.return_encoder_hidden_states_index is not None:
                assert is_output_tuple
                encoder_hidden_states = (
                    shared_state.tail_block_residuals[1] + output[self._metadata.return_encoder_hidden_states_index]
                )

            if is_output_tuple:
                return_output = [None] * len(output)
                return_output[self._metadata.return_hidden_states_index] = hidden_states
                return_output[self._metadata.return_encoder_hidden_states_index] = encoder_hidden_states
                return_output = tuple(return_output)
            else:
                return_output = hidden_states
            output = return_output
        else:
            if is_output_tuple:
                # Stored in a fixed order: index 0 is hidden states, index 1 is encoder hidden states.
                head_block_output = [None] * len(output)
                head_block_output[0] = output[self._metadata.return_hidden_states_index]
                head_block_output[1] = output[self._metadata.return_encoder_hidden_states_index]
            else:
                head_block_output = output
            shared_state.head_block_output = head_block_output
            shared_state.head_block_residual = hidden_states_residual

        return output

    def reset_state(self, module):
        """Reset the state manager and return `module` unchanged."""
        self.state_manager.reset()
        return module

    @torch.compiler.disable
    def _should_compute_remaining_blocks(self, hidden_states_residual: torch.Tensor) -> bool:
        """Return `True` when no previous residual exists or the relative residual change exceeds the threshold.

        The relative change is `mean(|current - previous|) / mean(|previous|)`.
        """
        shared_state = self.state_manager.get_state()
        if shared_state.head_block_residual is None:
            return True
        prev_hidden_states_residual = shared_state.head_block_residual
        absmean = (hidden_states_residual - prev_hidden_states_residual).abs().mean()
        prev_hidden_states_absmean = prev_hidden_states_residual.abs().mean()
        diff = (absmean / prev_hidden_states_absmean).item()
        return diff > self.threshold
143
+
144
+
145
class FBCBlockHook(ModelHook):
    """Hook for every transformer block after the head block.

    When the shared state says a full pass is needed the wrapped block runs normally (and the tail block records
    its residuals); otherwise the block is bypassed and its inputs are returned as its output.
    """

    def __init__(self, state_manager: StateManager, is_tail: bool = False):
        super().__init__()
        self.state_manager = state_manager
        self.is_tail = is_tail
        self._metadata = None

    def initialize_hook(self, module):
        """Fetch the block metadata for the unwrapped module class; return `module` unchanged."""
        self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        """Either run the block or pass its inputs straight through, depending on the shared state."""
        metadata = self._metadata
        encoder_index = metadata.return_encoder_hidden_states_index

        original_hidden_states = metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
        original_encoder_hidden_states = (
            metadata._get_parameter_from_args_kwargs("encoder_hidden_states", args, kwargs)
            if encoder_index is not None
            else None
        )

        shared_state = self.state_manager.get_state()

        if not shared_state.should_compute:
            # Bypass: hand the inputs back in the layout the block would have returned.
            if original_encoder_hidden_states is None:
                return original_hidden_states
            passthrough = [None, None]
            passthrough[metadata.return_hidden_states_index] = original_hidden_states
            passthrough[encoder_index] = original_encoder_hidden_states
            return tuple(passthrough)

        output = self.fn_ref.original_forward(*args, **kwargs)
        if not self.is_tail:
            return output

        # Tail block: remember how far the final output moved away from the head block output.
        head_block_output = shared_state.head_block_output
        if isinstance(output, tuple):
            residuals = (
                output[metadata.return_hidden_states_index] - head_block_output[0],
                output[encoder_index] - head_block_output[1],
            )
        else:
            residuals = (output - head_block_output, None)
        shared_state.tail_block_residuals = residuals
        return output
191
+
192
+
193
def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
    """
    Applies [First Block
    Cache](https://github.com/chengzeyi/ParaAttention/blob/4de137c5b96416489f06e43e19f2c14a772e28fd/README.md#first-block-cache-our-dynamic-caching)
    to a given module.

    First Block Cache builds on the ideas of [TeaCache](https://huggingface.co/papers/2411.19108). It is much simpler
    to implement generically for a wide range of models and has been integrated first for experimental purposes.

    The transformer blocks are collected from the direct children of `module` that are `torch.nn.ModuleList`s with a
    known transformer block identifier as their name. The first collected block receives the head hook, the last one
    the tail hook, and every block in between a regular block hook.

    Args:
        module (`torch.nn.Module`):
            The pytorch module to apply FBCache to. Typically, this should be a transformer architecture supported in
            Diffusers, such as `CogVideoXTransformer3DModel`, but external implementations may also work.
        config (`FirstBlockCacheConfig`):
            The configuration to use for applying the FBCache method.

    Raises:
        ValueError: If fewer than two transformer blocks are found in `module`; a head and a tail block are both
            required.

    Example:
    ```python
    >>> import torch
    >>> from diffusers import CogView4Pipeline
    >>> from diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig

    >>> pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
    >>> pipe.to("cuda")

    >>> apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.2))

    >>> prompt = "A photo of an astronaut riding a horse on mars"
    >>> image = pipe(prompt, generator=torch.Generator().manual_seed(42)).images[0]
    >>> image.save("output.png")
    ```
    """

    remaining_blocks = []

    for name, submodule in module.named_children():
        if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
            continue
        for index, block in enumerate(submodule):
            remaining_blocks.append((f"{name}.{index}", block))

    # Validate before touching any block so that an unsupported module is left without partially applied hooks.
    if len(remaining_blocks) < 2:
        raise ValueError(
            "First Block Cache requires at least two transformer blocks (a head block and a tail block), but found "
            f"{len(remaining_blocks)} in `{module.__class__.__name__}`."
        )

    state_manager = StateManager(FBCSharedBlockState, (), {})

    head_block_name, head_block = remaining_blocks.pop(0)
    tail_block_name, tail_block = remaining_blocks.pop(-1)

    logger.debug(f"Applying FBCHeadBlockHook to '{head_block_name}'")
    _apply_fbc_head_block_hook(head_block, state_manager, config.threshold)

    for name, block in remaining_blocks:
        logger.debug(f"Applying FBCBlockHook to '{name}'")
        _apply_fbc_block_hook(block, state_manager)

    logger.debug(f"Applying FBCBlockHook to tail block '{tail_block_name}'")
    _apply_fbc_block_hook(tail_block, state_manager, is_tail=True)
247
+
248
+
249
def _apply_fbc_head_block_hook(block: torch.nn.Module, state_manager: StateManager, threshold: float) -> None:
    """Register an `FBCHeadBlockHook` (sharing `state_manager`, using `threshold`) on `block`."""
    hook_registry = HookRegistry.check_if_exists_or_initialize(block)
    hook_registry.register_hook(FBCHeadBlockHook(state_manager, threshold), _FBC_LEADER_BLOCK_HOOK)
253
+
254
+
255
def _apply_fbc_block_hook(block: torch.nn.Module, state_manager: StateManager, is_tail: bool = False) -> None:
    """Register an `FBCBlockHook` on `block`; `is_tail` marks the last transformer block."""
    hook_registry = HookRegistry.check_if_exists_or_initialize(block)
    hook_registry.register_hook(FBCBlockHook(state_manager, is_tail), _FBC_BLOCK_HOOK)
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/group_offloading.py ADDED
@@ -0,0 +1,966 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import hashlib
16
+ import os
17
+ from contextlib import contextmanager, nullcontext
18
+ from dataclasses import dataclass, replace
19
+ from enum import Enum
20
+ from typing import Set
21
+
22
+ import safetensors.torch
23
+ import torch
24
+
25
+ from ..utils import get_logger, is_accelerate_available
26
+ from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
27
+ from .hooks import HookRegistry, ModelHook
28
+
29
+
30
+ if is_accelerate_available():
31
+ from accelerate.hooks import AlignDevicesHook, CpuOffload
32
+ from accelerate.utils import send_to_device
33
+
34
+
35
+ logger = get_logger(__name__) # pylint: disable=invalid-name
36
+
37
+
38
+ # fmt: off
39
+ _GROUP_OFFLOADING = "group_offloading"
40
+ _LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
41
+ _LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
42
+ _GROUP_ID_LAZY_LEAF = "lazy_leafs"
43
+ # fmt: on
44
+
45
+
46
class GroupOffloadingType(str, Enum):
    """String-valued enumeration of the available group offloading granularities."""

    BLOCK_LEVEL = "block_level"
    LEAF_LEVEL = "leaf_level"
49
+
50
+
51
@dataclass
class GroupOffloadingConfig:
    """Plain container for group offloading settings.

    The first six fields are required; the rest are optional and default to `None` (or an empty string for
    `module_prefix`). NOTE(review): the precise meaning of each field is defined by the code that consumes this
    config, which is not part of this definition - confirm against the callers before relying on it.
    """

    # Required settings (no defaults).
    onload_device: torch.device
    offload_device: torch.device
    offload_type: GroupOffloadingType
    non_blocking: bool
    record_stream: bool
    low_cpu_mem_usage: bool

    # Optional settings.
    num_blocks_per_group: int | None = None
    offload_to_disk_path: str | None = None
    stream: torch.cuda.Stream | torch.Stream | None = None
    block_modules: list[str] | None = None
    exclude_kwargs: list[str] | None = None
    module_prefix: str = ""
65
+
66
+
67
+ class ModuleGroup:
68
+ def __init__(
69
+ self,
70
+ modules: list[torch.nn.Module],
71
+ offload_device: torch.device,
72
+ onload_device: torch.device,
73
+ offload_leader: torch.nn.Module,
74
+ onload_leader: torch.nn.Module | None = None,
75
+ parameters: list[torch.nn.Parameter] | None = None,
76
+ buffers: list[torch.Tensor] | None = None,
77
+ non_blocking: bool = False,
78
+ stream: torch.cuda.Stream | torch.Stream | None = None,
79
+ record_stream: bool | None = False,
80
+ low_cpu_mem_usage: bool = False,
81
+ onload_self: bool = True,
82
+ offload_to_disk_path: str | None = None,
83
+ group_id: int | str | None = None,
84
+ ) -> None:
85
+ self.modules = modules
86
+ self.offload_device = offload_device
87
+ self.onload_device = onload_device
88
+ self.offload_leader = offload_leader
89
+ self.onload_leader = onload_leader
90
+ self.parameters = parameters or []
91
+ self.buffers = buffers or []
92
+ self.non_blocking = non_blocking or stream is not None
93
+ self.stream = stream
94
+ self.record_stream = record_stream
95
+ self.onload_self = onload_self
96
+ self.low_cpu_mem_usage = low_cpu_mem_usage
97
+
98
+ self.offload_to_disk_path = offload_to_disk_path
99
+ self._is_offloaded_to_disk = False
100
+
101
+ if self.offload_to_disk_path is not None:
102
+ # Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
103
+ self.group_id = group_id if group_id is not None else str(id(self))
104
+ short_hash = _compute_group_hash(self.group_id)
105
+ self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{short_hash}.safetensors")
106
+
107
+ all_tensors = []
108
+ for module in self.modules:
109
+ all_tensors.extend(list(module.parameters()))
110
+ all_tensors.extend(list(module.buffers()))
111
+ all_tensors.extend(self.parameters)
112
+ all_tensors.extend(self.buffers)
113
+ all_tensors = list(dict.fromkeys(all_tensors)) # Remove duplicates
114
+
115
+ self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)}
116
+ self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()}
117
+ self.cpu_param_dict = {}
118
+ else:
119
+ self.cpu_param_dict = self._init_cpu_param_dict()
120
+
121
+ self._torch_accelerator_module = (
122
+ getattr(torch, torch.accelerator.current_accelerator().type)
123
+ if hasattr(torch, "accelerator")
124
+ else torch.cuda
125
+ )
126
+
127
+ def _init_cpu_param_dict(self):
128
+ cpu_param_dict = {}
129
+ if self.stream is None:
130
+ return cpu_param_dict
131
+
132
+ for module in self.modules:
133
+ for param in module.parameters():
134
+ cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
135
+ for buffer in module.buffers():
136
+ cpu_param_dict[buffer] = (
137
+ buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
138
+ )
139
+
140
+ for param in self.parameters:
141
+ cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
142
+
143
+ for buffer in self.buffers:
144
+ cpu_param_dict[buffer] = buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
145
+
146
+ return cpu_param_dict
147
+
148
+ @contextmanager
149
+ def _pinned_memory_tensors(self):
150
+ try:
151
+ pinned_dict = {
152
+ param: tensor.pin_memory() if not tensor.is_pinned() else tensor
153
+ for param, tensor in self.cpu_param_dict.items()
154
+ }
155
+ yield pinned_dict
156
+ finally:
157
+ pinned_dict = None
158
+
159
+ def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
160
+ tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
161
+ if self.record_stream:
162
+ tensor.data.record_stream(default_stream)
163
+
164
+ def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None):
165
+ for group_module in self.modules:
166
+ for param in group_module.parameters():
167
+ source = pinned_memory[param] if pinned_memory else param.data
168
+ self._transfer_tensor_to_device(param, source, default_stream)
169
+ for buffer in group_module.buffers():
170
+ source = pinned_memory[buffer] if pinned_memory else buffer.data
171
+ self._transfer_tensor_to_device(buffer, source, default_stream)
172
+
173
+ for param in self.parameters:
174
+ source = pinned_memory[param] if pinned_memory else param.data
175
+ self._transfer_tensor_to_device(param, source, default_stream)
176
+
177
+ for buffer in self.buffers:
178
+ source = pinned_memory[buffer] if pinned_memory else buffer.data
179
+ self._transfer_tensor_to_device(buffer, source, default_stream)
180
+
181
+ def _onload_from_disk(self):
182
+ if self.stream is not None:
183
+ # Wait for previous Host->Device transfer to complete
184
+ self.stream.synchronize()
185
+
186
+ context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
187
+ current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None
188
+
189
+ with context:
190
+ # Load to CPU (if using streams) or directly to target device, pin, and async copy to device
191
+ device = str(self.onload_device) if self.stream is None else "cpu"
192
+ loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device)
193
+
194
+ if self.stream is not None:
195
+ for key, tensor_obj in self.key_to_tensor.items():
196
+ pinned_tensor = loaded_tensors[key].pin_memory()
197
+ tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
198
+ if self.record_stream:
199
+ tensor_obj.data.record_stream(current_stream)
200
+ else:
201
+ onload_device = (
202
+ self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
203
+ )
204
+ loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
205
+ for key, tensor_obj in self.key_to_tensor.items():
206
+ tensor_obj.data = loaded_tensors[key]
207
+
208
+ def _onload_from_memory(self):
209
+ if self.stream is not None:
210
+ # Wait for previous Host->Device transfer to complete
211
+ self.stream.synchronize()
212
+
213
+ context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
214
+ default_stream = self._torch_accelerator_module.current_stream() if self.stream is not None else None
215
+
216
+ with context:
217
+ if self.stream is not None:
218
+ with self._pinned_memory_tensors() as pinned_memory:
219
+ self._process_tensors_from_modules(pinned_memory, default_stream=default_stream)
220
+ else:
221
+ self._process_tensors_from_modules(None)
222
+
223
+ def _offload_to_disk(self):
224
+ # TODO: we can potentially optimize this code path by checking if the _all_ the desired
225
+ # safetensor files exist on the disk and if so, skip this step entirely, reducing IO
226
+ # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
227
+ # we perform a write.
228
+ # Check if the file has been saved in this session or if it already exists on disk.
229
+ if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
230
+ os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
231
+ tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()}
232
+ safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
233
+
234
+ # The group is now considered offloaded to disk for the rest of the session.
235
+ self._is_offloaded_to_disk = True
236
+
237
+ # We do this to free up the RAM which is still holding the up tensor data.
238
+ for tensor_obj in self.tensor_to_key.keys():
239
+ tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
240
+
241
+ def _offload_to_memory(self):
242
+ if self.stream is not None:
243
+ if not self.record_stream:
244
+ self._torch_accelerator_module.current_stream().synchronize()
245
+
246
+ for group_module in self.modules:
247
+ for param in group_module.parameters():
248
+ param.data = self.cpu_param_dict[param]
249
+ for param in self.parameters:
250
+ param.data = self.cpu_param_dict[param]
251
+ for buffer in self.buffers:
252
+ buffer.data = self.cpu_param_dict[buffer]
253
+ else:
254
+ for group_module in self.modules:
255
+ group_module.to(self.offload_device, non_blocking=False)
256
+ for param in self.parameters:
257
+ param.data = param.data.to(self.offload_device, non_blocking=False)
258
+ for buffer in self.buffers:
259
+ buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
260
+
261
+ @torch.compiler.disable()
262
+ def onload_(self):
263
+ r"""Onloads the group of parameters to the onload_device."""
264
+ if self.offload_to_disk_path is not None:
265
+ self._onload_from_disk()
266
+ else:
267
+ self._onload_from_memory()
268
+
269
+ @torch.compiler.disable()
270
+ def offload_(self):
271
+ r"""Offloads the group of parameters to the offload_device."""
272
+ if self.offload_to_disk_path:
273
+ self._offload_to_disk()
274
+ else:
275
+ self._offload_to_memory()
276
+
277
+
278
class GroupOffloadingHook(ModelHook):
    r"""
    Hook that stores a group of `torch.nn.Module`s on the offload device and brings it to the accelerator device for
    computation. Within a group, the "onload leader" module triggers onloading and the "offload leader" module
    triggers offloading. With prefetching enabled, the onload leader of the previously executed group is the one
    that onloads the current group.
    """

    _is_stateful = False

    def __init__(self, group: ModuleGroup, *, config: GroupOffloadingConfig) -> None:
        self.config = config
        self.group = group
        # Group to prefetch while this one executes; assigned once the execution order is known.
        self.next_group: ModuleGroup | None = None

    def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        # The group starts out offloaded; only the offload leader performs the move.
        if self.group.offload_leader == module:
            self.group.offload_()
        return module

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
        group = self.group

        # Without an explicit onload leader, the first submodule whose forward runs takes that role.
        if group.onload_leader is None:
            group.onload_leader = module

        if group.onload_leader == module:
            if group.onload_self:
                group.onload_()
            else:
                # onload_self=False means this group relies on prefetching from a previous group.
                # Conditionally-executed modules (e.g. patch_short/patch_mid/patch_long in Helios) may be missing
                # from the prefetch chain if they did not run during the first, traced forward pass. Their weights
                # are then still on the offload device, so they are onloaded synchronously here instead.
                first_param = next((p for m in group.modules for p in m.parameters()), None)
                if first_param is None:
                    first_param = next(iter(group.parameters), None)
                if first_param is not None and first_param.device == group.offload_device:
                    group.onload_()
                    if group.stream is not None:
                        group.stream.synchronize()

            # With stream prefetching, this leader kicks off the transfer of the group that runs next.
            upcoming = self.next_group
            should_onload_next_group = upcoming is not None and not upcoming.onload_self
            if should_onload_next_group:
                upcoming.onload_()

            # A group that did not onload itself was filled asynchronously by the previous group, so the side
            # stream must be synchronized before the forward pass may read the weights (otherwise uninitialized
            # weights would be used). The sync at the start of `next_group.onload_()` already covers this, hence
            # it is only needed when no next group was onloaded.
            if not group.onload_self and group.stream is not None and not should_onload_next_group:
                group.stream.synchronize()

        target_device = group.onload_device
        args = send_to_device(args, target_device, non_blocking=group.non_blocking)

        # Some Autoencoder models pass a feature cache through their submodules and modify it in place.
        # `send_to_device` returns a copy of such an object, which would break the in-place updates; keys listed
        # in `exclude_kwargs` are therefore left untouched.
        excluded = self.config.exclude_kwargs or []
        if not excluded:
            kwargs = send_to_device(kwargs, target_device, non_blocking=group.non_blocking)
        else:
            movable = {key: value for key, value in kwargs.items() if key not in excluded}
            kwargs.update(send_to_device(movable, target_device, non_blocking=group.non_blocking))

        return args, kwargs

    def post_forward(self, module: torch.nn.Module, output):
        # Once the offload leader has finished, the whole group can go back to the offload device.
        if self.group.offload_leader == module:
            self.group.offload_()
        return output
360
+
361
+
362
class LazyPrefetchGroupOffloadingHook(ModelHook):
    r"""
    Companion hook of GroupOffloadingHook that sets up lazy prefetching for groups of torch.nn.Module. During the
    first forward pass it records the order in which the layers run; afterwards it links each group offloading hook
    to the group executed next (its `next_group`), so that groups are prefetched in the order they are needed.
    """

    _is_stateful = False

    def __init__(self):
        # (name, submodule) pairs in the order their forward was called.
        self.execution_order: list[tuple[str, torch.nn.Module]] = []
        # Names of all submodules that received a layer execution tracker hook.
        self._layer_execution_tracker_module_names = set()

    def initialize_hook(self, module):
        def build_tracker_callback(current_name, current_submodule):
            # Factory so that every callback captures its own (name, submodule) pair.
            def record_execution():
                if not torch.compiler.is_compiling():
                    logger.debug(f"Adding {current_name} to the execution order")
                self.execution_order.append((current_name, current_submodule))

            return record_execution

        # Every submodule that already carries a group offloading hook (no prefetching is enabled for any group at
        # this point) additionally gets a layer execution tracker hook, which reports when the layer is executed.
        for name, submodule in module.named_modules():
            is_root = name == ""
            if is_root or not hasattr(submodule, "_diffusers_hook"):
                continue

            registry = HookRegistry.check_if_exists_or_initialize(submodule)
            group_offloading_hook = registry.get_hook(_GROUP_OFFLOADING)
            if group_offloading_hook is None:
                continue

            # For the first forward pass, we have to load in a blocking manner
            group_offloading_hook.group.non_blocking = False
            tracker_hook = LayerExecutionTrackerHook(build_tracker_callback(name, submodule))
            registry.register_hook(tracker_hook, _LAYER_EXECUTION_TRACKER)
            self._layer_execution_tracker_module_names.add(name)

        return module

    def post_forward(self, module, output):
        # The first forward pass of `module` is over, so the execution order of its submodules is known. Remove the
        # tracker hooks and chain the group offloading hooks through `next_group` to enable prefetching.
        executed_names = {name for name, _ in self.execution_order}

        # Layers that were tracked but never executed cannot be placed in the prefetch chain. If they are executed in
        # a later forward pass, this may result in device-mismatch related errors.
        if executed_names != self._layer_execution_tracker_module_names:
            unexecuted_layers = list(self._layer_execution_tracker_module_names - executed_names)
            if not torch.compiler.is_compiling():
                logger.warning(
                    "It seems like some layers were not executed during the forward pass. This may lead to problems when "
                    "applying lazy prefetching with automatic tracing and lead to device-mismatch related errors. Please "
                    "make sure that all layers are executed during the forward pass. The following layers were not executed:\n"
                    f"{unexecuted_layers=}"
                )

        base_module_registry = module._diffusers_hook
        executed_registries = [submodule._diffusers_hook for _, submodule in self.execution_order]
        group_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in executed_registries]

        # Remove the layer execution tracker hooks from the executed submodules
        for registry in executed_registries:
            registry.remove_hook(_LAYER_EXECUTION_TRACKER, recurse=False)

        # Remove this hook itself so that it doesn't interfere with the next forward pass
        base_module_registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=False)

        # LazyPrefetchGroupOffloadingHook is only used with streams, so non_blocking should be True. It was disabled
        # for the first forward pass and is re-enabled here so that later passes benefit from prefetching.
        for group_hook in group_hooks:
            group_hook.group.non_blocking = True

        # The base module prefetches the first executed group ...
        if group_hooks:
            base_hook = base_module_registry.get_hook(_GROUP_OFFLOADING)
            base_hook.next_group = group_hooks[0].group
            base_hook.next_group.onload_self = False

        # ... and every executed group prefetches its successor.
        consecutive = zip(self.execution_order, self.execution_order[1:], group_hooks, group_hooks[1:])
        for (name1, _), (name2, _), current_hook, following_hook in consecutive:
            if not torch.compiler.is_compiling():
                logger.debug(f"Applying lazy prefetch group offloading from {name1} to {name2}")
            current_hook.next_group = following_hook.group
            current_hook.next_group.onload_self = False

        return output
457
+
458
+
459
class LayerExecutionTrackerHook(ModelHook):
    r"""
    Hook that reports each execution of its layer by invoking a zero-argument callback right before the layer's
    forward pass. The LazyPrefetchGroupOffloadingHook supplies the callback and uses it to record the execution order.
    """

    _is_stateful = False

    def __init__(self, execution_order_update_callback):
        # Called with no arguments every time `pre_forward` runs.
        self.execution_order_update_callback = execution_order_update_callback

    def pre_forward(self, module, *args, **kwargs):
        notify = self.execution_order_update_callback
        notify()
        # The inputs are passed through unchanged.
        return args, kwargs
473
+
474
+
475
def apply_group_offloading(
    module: torch.nn.Module,
    onload_device: str | torch.device,
    offload_device: str | torch.device = torch.device("cpu"),
    offload_type: str | GroupOffloadingType = "block_level",
    num_blocks_per_group: int | None = None,
    non_blocking: bool = False,
    use_stream: bool = False,
    record_stream: bool = False,
    low_cpu_mem_usage: bool = False,
    offload_to_disk_path: str | None = None,
    block_modules: list[str] | None = None,
    exclude_kwargs: list[str] | None = None,
) -> None:
    r"""
    Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
    where it is beneficial, we need to first provide some context on how other supported offloading methods work.

    Typically, offloading is done at two levels:
    - Module-level: In Diffusers, this can be enabled using the `ModelMixin::enable_model_cpu_offload()` method. It
      works by offloading each component of a pipeline to the CPU for storage, and onloading to the accelerator device
      when needed for computation. This method is more memory-efficient than keeping all components on the accelerator,
      but the memory requirements are still quite high. For this method to work, one needs memory equivalent to size of
      the model in runtime dtype + size of largest intermediate activation tensors to be able to complete the forward
      pass.
    - Leaf-level: In Diffusers, this can be enabled using the `ModelMixin::enable_sequential_cpu_offload()` method. It
      works by offloading the lowest leaf-level parameters of the computation graph to the CPU for storage, and
      onloading only the leafs to the accelerator device for computation. This uses the lowest amount of accelerator
      memory, but can be slower due to the excessive number of device synchronizations.

    Group offloading is a middle ground between the two methods. It works by offloading groups of internal layers,
    (either `torch.nn.ModuleList` or `torch.nn.Sequential`). This method uses lower memory than module-level
    offloading. It is also faster than leaf-level/sequential offloading, as the number of device synchronizations is
    reduced.

    Another supported feature (for CUDA devices with support for asynchronous data transfer streams) is the ability to
    overlap data transfer and computation to reduce the overall execution time compared to sequential offloading. This
    is enabled using layer prefetching with streams, i.e., the layer that is to be executed next starts onloading to
    the accelerator device while the current layer is being executed - this increases the memory requirements slightly.
    Note that this implementation also supports leaf-level offloading but can be made much faster when using streams.

    Args:
        module (`torch.nn.Module`):
            The module to which group offloading is applied.
        onload_device (`str` or `torch.device`):
            The device to which the group of modules are onloaded.
        offload_device (`str` or `torch.device`, defaults to `torch.device("cpu")`):
            The device to which the group of modules are offloaded. This should typically be the CPU. Default is CPU.
        offload_type (`str` or `GroupOffloadingType`, defaults to "block_level"):
            The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
            "block_level".
        num_blocks_per_group (`int`, *optional*):
            The number of blocks per group when using offload_type="block_level". This is required when using
            offload_type="block_level" and must be a positive integer. When `use_stream=True`, block-level
            offloading always uses one block per group.
        non_blocking (`bool`, defaults to `False`):
            If True, offloading and onloading is done with non-blocking data transfer.
        use_stream (`bool`, defaults to `False`):
            If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
            overlapping computation and data transfer.
        record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
            as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
            [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more
            details.
        low_cpu_mem_usage (`bool`, defaults to `False`):
            If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
            option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
            the CPU memory is a bottleneck but may counteract the benefits of using streams.
        offload_to_disk_path (`str`, *optional*, defaults to `None`):
            The path to the directory where parameters will be offloaded. Setting this option can be useful in limited
            RAM environment settings where a reasonable speed-memory trade-off is desired.
        block_modules (`list[str]`, *optional*):
            List of module names that should be treated as blocks for offloading. If not provided, the module's
            `_group_offload_block_modules` attribute is used when it exists.
        exclude_kwargs (`list[str]`, *optional*):
            List of kwarg keys that should not be processed by send_to_device. This is useful for mutable state like
            caching lists that need to maintain their object identity across forward passes. If not provided, will be
            inferred from the module's `_skip_keys` attribute if it exists.

    Raises:
        ValueError:
            If `record_stream=True` without `use_stream=True`, if `num_blocks_per_group` is missing or (without
            streams) not positive for block-level offloading, if `offload_type` is not a valid
            `GroupOffloadingType`, or if streams are requested but neither a CUDA nor an Intel XPU device is
            available.

    Example:
        ```python
        >>> from diffusers import CogVideoXTransformer3DModel
        >>> from diffusers.hooks import apply_group_offloading

        >>> transformer = CogVideoXTransformer3DModel.from_pretrained(
        ...     "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
        ... )

        >>> apply_group_offloading(
        ...     transformer,
        ...     onload_device=torch.device("cuda"),
        ...     offload_device=torch.device("cpu"),
        ...     offload_type="block_level",
        ...     num_blocks_per_group=2,
        ...     use_stream=True,
        ... )
        ```
    """

    onload_device = torch.device(onload_device) if isinstance(onload_device, str) else onload_device
    offload_device = torch.device(offload_device) if isinstance(offload_device, str) else offload_device
    offload_type = GroupOffloadingType(offload_type)

    # Validate the cheap argument constraints first, before any device stream is allocated.
    if not use_stream and record_stream:
        raise ValueError("`record_stream` cannot be True when `use_stream=False`.")
    if offload_type == GroupOffloadingType.BLOCK_LEVEL:
        if num_blocks_per_group is None:
            raise ValueError("`num_blocks_per_group` must be provided when using `offload_type='block_level'.")
        # With streams the value is overridden to 1 by the block-level implementation, so it only has to be
        # checked here for the non-stream case, where it is used as the step size when chunking blocks.
        if not use_stream and num_blocks_per_group < 1:
            raise ValueError(
                "`num_blocks_per_group` must be a positive integer when using `offload_type='block_level'`, "
                f"but got {num_blocks_per_group}."
            )

    stream = None
    if use_stream:
        if torch.cuda.is_available():
            stream = torch.cuda.Stream()
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            stream = torch.Stream()
        else:
            raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.")

    _raise_error_if_accelerate_model_or_sequential_hook_present(module)

    # Fall back to the defaults declared on the module class, if any.
    if block_modules is None:
        block_modules = getattr(module, "_group_offload_block_modules", None)

    if exclude_kwargs is None:
        exclude_kwargs = getattr(module, "_skip_keys", None)

    config = GroupOffloadingConfig(
        onload_device=onload_device,
        offload_device=offload_device,
        offload_type=offload_type,
        num_blocks_per_group=num_blocks_per_group,
        non_blocking=non_blocking,
        stream=stream,
        record_stream=record_stream,
        low_cpu_mem_usage=low_cpu_mem_usage,
        offload_to_disk_path=offload_to_disk_path,
        block_modules=block_modules,
        exclude_kwargs=exclude_kwargs,
    )
    _apply_group_offloading(module, config)
613
+
614
+
615
def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
    r"""
    Dispatches to the block-level or leaf-level implementation according to `config.offload_type`.

    Raises:
        ValueError: If `config.offload_type` is not a supported `GroupOffloadingType`.
    """
    if config.offload_type == GroupOffloadingType.BLOCK_LEVEL:
        _apply_group_offloading_block_level(module, config)
    elif config.offload_type == GroupOffloadingType.LEAF_LEVEL:
        _apply_group_offloading_leaf_level(module, config)
    else:
        # An explicit exception instead of `assert False`: assertions are stripped under `python -O`, which would
        # turn an unsupported offload type into a silent no-op.
        raise ValueError(f"Unsupported offload_type: {config.offload_type!r}")
622
+
623
+
624
def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
    r"""
    Applies offloading to groups of torch.nn.ModuleList / torch.nn.Sequential blocks and to explicitly listed block
    modules. Unlike the more fine-grained "leaf_level" offloading, grouping happens at the level of the direct
    children of `module`.

    Children whose name is listed in `config.block_modules` are not grouped directly; block-level offloading is
    applied to them recursively instead. All remaining children, together with the parameters and buffers that have no
    group-offloaded parent, are collected into one additional group led by `module` itself.
    """
    if config.stream is not None and config.num_blocks_per_group != 1:
        logger.warning(
            f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1."
        )
        config.num_blocks_per_group = 1

    explicit_block_names = set() if config.block_modules is None else set(config.block_modules)

    # Names of (sub)modules that are covered by a group of their own.
    modules_with_group_offloading = set()
    # Children that are neither explicit block modules nor ModuleList/Sequential containers.
    leftover_modules = []
    block_groups = []

    for name, child in module.named_children():
        if name in explicit_block_names:
            # Extend the prefix with the child's name to avoid filename collisions during disk offload. Without
            # it, submodules sharing the same model class would be assigned identical filenames (derived from the
            # class name).
            child_prefix = (config.module_prefix or "") + f"{name}."
            _apply_group_offloading_block_level(child, replace(config, module_prefix=child_prefix))
            modules_with_group_offloading.add(name)
            continue

        if not isinstance(child, (torch.nn.ModuleList, torch.nn.Sequential)):
            leftover_modules.append(child)
            continue

        # ModuleList / Sequential: chunk the blocks into groups of `num_blocks_per_group` consecutive entries.
        for start in range(0, len(child), config.num_blocks_per_group):
            chunk = list(child[start : start + config.num_blocks_per_group])
            if not chunk:
                continue

            last = start + len(chunk) - 1
            block_groups.append(
                ModuleGroup(
                    modules=chunk,
                    offload_device=config.offload_device,
                    onload_device=config.onload_device,
                    offload_to_disk_path=config.offload_to_disk_path,
                    # The first block of the chunk onloads the group, the last one offloads it.
                    offload_leader=chunk[-1],
                    onload_leader=chunk[0],
                    non_blocking=config.non_blocking,
                    stream=config.stream,
                    record_stream=config.record_stream,
                    low_cpu_mem_usage=config.low_cpu_mem_usage,
                    onload_self=True,
                    group_id=f"{config.module_prefix}{name}_{start}_{last}",
                )
            )
            modules_with_group_offloading.update(f"{name}.{index}" for index in range(start, last + 1))

    # Apply group offloading hooks to every module of every block group
    for block_group in block_groups:
        for member in block_group.modules:
            _apply_group_offloading_hook(member, block_group, config=config)

    # Parameters and buffers of the top-level module are offloaded/onloaded separately when its forward pass is
    # called, because the top-level module is not part of any group (doing so would lead to no VRAM savings).
    named_parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
    named_buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
    parameters = [param for _, param in named_parameters]
    buffers = [buffer for _, buffer in named_buffers]

    if not (leftover_modules or parameters or buffers):
        return

    # One group for everything that is left over, so that it is on the correct device when the forward pass of
    # the top-level module is called.
    unmatched_group = ModuleGroup(
        modules=leftover_modules,
        offload_device=config.offload_device,
        onload_device=config.onload_device,
        offload_to_disk_path=config.offload_to_disk_path,
        offload_leader=module,
        onload_leader=module,
        parameters=parameters,
        buffers=buffers,
        non_blocking=False,
        stream=None,
        record_stream=False,
        onload_self=True,
        group_id=f"{config.module_prefix}{module.__class__.__name__}_unmatched_group",
    )
    # With streams, the lazy variant is used for the top-level module's group.
    if config.stream is not None:
        _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
    else:
        _apply_group_offloading_hook(module, unmatched_group, config=config)
723
+
724
+
725
def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
    r"""
    Apply group offloading at the granularity of individual leaf modules of a torch.nn.Module.

    This strategy has minimal memory requirements, but it can be slower than other offloading methods because of
    the large number of device synchronizations. On devices that support streams (so that data transfer and
    computation can overlap), it can reduce memory usage without any performance degradation.

    Args:
        module: The module whose supported leaf layers, and remaining parameters/buffers, receive hooks.
        config: The group offloading configuration used to build every `ModuleGroup` and hook.
    """
    # Step 1: every supported leaf layer becomes its own single-module group.
    offloaded_names = set()
    for leaf_name, leaf in module.named_modules():
        if isinstance(leaf, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
            leaf_group = ModuleGroup(
                modules=[leaf],
                offload_device=config.offload_device,
                onload_device=config.onload_device,
                offload_to_disk_path=config.offload_to_disk_path,
                offload_leader=leaf,
                onload_leader=leaf,
                non_blocking=config.non_blocking,
                stream=config.stream,
                record_stream=config.record_stream,
                low_cpu_mem_usage=config.low_cpu_mem_usage,
                onload_self=True,
                group_id=leaf_name,
            )
            _apply_group_offloading_hook(leaf, leaf_group, config=config)
            offloaded_names.add(leaf_name)

    # Step 2: parameters and buffers that live at non-leaf levels are not covered by any group above, so they
    # have to be offloaded/onloaded separately when the forward pass of their owning module is called.
    named_modules = dict(module.named_modules())
    orphan_parameters = _gather_parameters_with_no_group_offloading_parent(module, offloaded_names)
    orphan_buffers = _gather_buffers_with_no_group_offloading_parent(module, offloaded_names)

    # Bucket each orphan tensor under the name of its closest enclosing module.
    params_by_parent = {}
    for tensor_name, param in orphan_parameters:
        owner = _find_parent_module_in_module_dict(tensor_name, named_modules)
        params_by_parent.setdefault(owner, []).append(param)

    buffers_by_parent = {}
    for tensor_name, buffer in orphan_buffers:
        owner = _find_parent_module_in_module_dict(tensor_name, named_modules)
        buffers_by_parent.setdefault(owner, []).append(buffer)

    # One tensor-only group (no modules) per owning module, hooked onto that module.
    for owner in set(params_by_parent) | set(buffers_by_parent):
        owner_module = named_modules[owner]
        owner_group = ModuleGroup(
            modules=[],
            offload_device=config.offload_device,
            onload_device=config.onload_device,
            offload_leader=owner_module,
            onload_leader=owner_module,
            offload_to_disk_path=config.offload_to_disk_path,
            parameters=params_by_parent.get(owner, []),
            buffers=buffers_by_parent.get(owner, []),
            non_blocking=config.non_blocking,
            stream=config.stream,
            record_stream=config.record_stream,
            low_cpu_mem_usage=config.low_cpu_mem_usage,
            onload_self=True,
            group_id=owner,
        )
        _apply_group_offloading_hook(owner_module, owner_group, config=config)

    if config.stream is None:
        return

    # With streams, prefetching (overlapping data transfer with computation) requires the layer execution
    # order. That order is not known beforehand, so a lazy prefetching hook is attached to the top-level module;
    # it discovers the execution order and applies prefetching accordingly.
    lazy_group = ModuleGroup(
        modules=[],
        offload_device=config.offload_device,
        onload_device=config.onload_device,
        offload_to_disk_path=config.offload_to_disk_path,
        offload_leader=module,
        onload_leader=module,
        parameters=None,
        buffers=None,
        non_blocking=False,
        stream=None,
        record_stream=False,
        low_cpu_mem_usage=config.low_cpu_mem_usage,
        onload_self=True,
        group_id=_GROUP_ID_LAZY_LEAF,
    )
    _apply_lazy_group_offloading_hook(module, lazy_group, config=config)
821
+
822
+
823
def _apply_group_offloading_hook(
    module: torch.nn.Module,
    group: ModuleGroup,
    *,
    config: GroupOffloadingConfig,
) -> None:
    r"""
    Register a `GroupOffloadingHook` for `group` on `module`, unless one is already registered.

    Args:
        module: The module whose hook registry receives the hook.
        group: The module group the new hook is built from.
        config: The group offloading configuration passed to the hook.
    """
    registry = HookRegistry.check_if_exists_or_initialize(module)

    # A group offloading hook may already be present if the module had a torch.nn.Parameter whose parent is
    # this very module. The existing hook must not be overwritten in that case.
    if registry.get_hook(_GROUP_OFFLOADING) is not None:
        return

    registry.register_hook(GroupOffloadingHook(group, config=config), _GROUP_OFFLOADING)
836
+
837
+
838
def _apply_lazy_group_offloading_hook(
    module: torch.nn.Module,
    group: ModuleGroup,
    *,
    config: GroupOffloadingConfig,
) -> None:
    r"""
    Register a `GroupOffloadingHook` (if none exists yet) and a `LazyPrefetchGroupOffloadingHook` on `module`.

    Args:
        module: The module whose hook registry receives the hooks.
        group: The module group used to build the group offloading hook when one has to be created.
        config: The group offloading configuration passed to the group offloading hook.
    """
    registry = HookRegistry.check_if_exists_or_initialize(module)

    # A group offloading hook may already be present if the module had a torch.nn.Parameter whose parent is
    # this very module. The existing hook must not be overwritten in that case.
    already_hooked = registry.get_hook(_GROUP_OFFLOADING) is not None
    if not already_hooked:
        registry.register_hook(GroupOffloadingHook(group, config=config), _GROUP_OFFLOADING)

    # The lazy prefetch hook is registered regardless of whether the group hook above was newly created.
    registry.register_hook(LazyPrefetchGroupOffloadingHook(), _LAZY_PREFETCH_GROUP_OFFLOADING)
854
+
855
+
856
+ def _gather_parameters_with_no_group_offloading_parent(
857
+ module: torch.nn.Module, modules_with_group_offloading: Set[str]
858
+ ) -> list[torch.nn.Parameter]:
859
+ parameters = []
860
+ for name, parameter in module.named_parameters():
861
+ has_parent_with_group_offloading = False
862
+ atoms = name.split(".")
863
+ while len(atoms) > 0:
864
+ parent_name = ".".join(atoms)
865
+ if parent_name in modules_with_group_offloading:
866
+ has_parent_with_group_offloading = True
867
+ break
868
+ atoms.pop()
869
+ if not has_parent_with_group_offloading:
870
+ parameters.append((name, parameter))
871
+ return parameters
872
+
873
+
874
+ def _gather_buffers_with_no_group_offloading_parent(
875
+ module: torch.nn.Module, modules_with_group_offloading: Set[str]
876
+ ) -> list[torch.Tensor]:
877
+ buffers = []
878
+ for name, buffer in module.named_buffers():
879
+ has_parent_with_group_offloading = False
880
+ atoms = name.split(".")
881
+ while len(atoms) > 0:
882
+ parent_name = ".".join(atoms)
883
+ if parent_name in modules_with_group_offloading:
884
+ has_parent_with_group_offloading = True
885
+ break
886
+ atoms.pop()
887
+ if not has_parent_with_group_offloading:
888
+ buffers.append((name, buffer))
889
+ return buffers
890
+
891
+
892
+ def _find_parent_module_in_module_dict(name: str, module_dict: dict[str, torch.nn.Module]) -> str:
893
+ atoms = name.split(".")
894
+ while len(atoms) > 0:
895
+ parent_name = ".".join(atoms)
896
+ if parent_name in module_dict:
897
+ return parent_name
898
+ atoms.pop()
899
+ return ""
900
+
901
+
902
def _raise_error_if_accelerate_model_or_sequential_hook_present(module: torch.nn.Module) -> None:
    r"""
    Reject modules that already carry an Accelerate `AlignDevicesHook` or `CpuOffload` hook.

    Does nothing when Accelerate is not available.

    Args:
        module: The module whose submodules are scanned for an `_hf_hook` attribute.

    Raises:
        ValueError: If any submodule's `_hf_hook` is an `AlignDevicesHook` or `CpuOffload` instance.
    """
    if is_accelerate_available():
        for name, submodule in module.named_modules():
            # Only submodules that expose `_hf_hook` are inspected.
            if hasattr(submodule, "_hf_hook") and isinstance(submodule._hf_hook, (AlignDevicesHook, CpuOffload)):
                raise ValueError(
                    "Cannot apply group offloading to a module that is already applying an alternative "
                    "offloading strategy from Accelerate. If you want to apply group offloading, please "
                    f"disable the existing offloading strategy first. Offending module: {name} ({type(submodule)})"
                )
914
+
915
+
916
def _get_top_level_group_offload_hook(module: torch.nn.Module) -> GroupOffloadingHook | None:
    r"""
    Return the first group offloading hook found while iterating over `module.modules()`.

    Args:
        module: The module to search, together with all of its submodules.

    Returns:
        The first hook registered under `_GROUP_OFFLOADING`, or `None` when no submodule has one.
    """
    for candidate in module.modules():
        # Submodules without a hook registry cannot carry a group offloading hook.
        if not hasattr(candidate, "_diffusers_hook"):
            continue
        found = candidate._diffusers_hook.get_hook(_GROUP_OFFLOADING)
        if found is not None:
            return found
    return None
923
+
924
+
925
def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
    r"""
    Tell whether `module` or any of its submodules has a group offloading hook registered.

    Args:
        module: The module to inspect.

    Returns:
        `True` when `_get_top_level_group_offload_hook` finds a hook, `False` otherwise.
    """
    return _get_top_level_group_offload_hook(module) is not None
928
+
929
+
930
def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
    r"""
    Return the onload device stored in the config of the module's group offloading hook.

    Args:
        module: The module to inspect.

    Returns:
        The `onload_device` of the config attached to the first group offloading hook found.

    Raises:
        ValueError: If no group offloading hook is found on `module` or its submodules.
    """
    hook = _get_top_level_group_offload_hook(module)
    if hook is None:
        raise ValueError("Group offloading is not enabled for the provided module.")
    return hook.config.onload_device
935
+
936
+
937
+ def _compute_group_hash(group_id):
938
+ hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest()
939
+ # first 16 characters for a reasonably short but unique name
940
+ return hashed_id[:16]
941
+
942
+
943
def _maybe_remove_and_reapply_group_offloading(module: torch.nn.Module) -> None:
    r"""
    Remove the group offloading hooks from the module and apply group offloading again.

    This is useful after the module has been modified in-place, when the references-to-tensors held by the group
    offloading hook need to be refreshed. In-place modification can happen in a number of ways, for example when
    fusing QKV or when unloading/loading LoRAs on-the-fly.

    The implementation assumes that group offloading has only been applied at the top-level module, so that all
    submodules share the same onload and offload devices. If that does not hold, e.g. because the user applied
    group offloading at multiple levels, this function will not work as expected.

    When non-default streams are used there is some performance penalty, because the execution order of the layers
    has to be retraced with `LazyPrefetchGroupOffloadingHook`.

    Does nothing when no group offloading hook is found on the module.
    """
    existing_hook = _get_top_level_group_offload_hook(module)
    if existing_hook is None:
        return

    registry = HookRegistry.check_if_exists_or_initialize(module)
    # Strip every group-offloading related hook from the module and, recursively, from its submodules.
    for hook_name in (_GROUP_OFFLOADING, _LAYER_EXECUTION_TRACKER, _LAZY_PREFETCH_GROUP_OFFLOADING):
        registry.remove_hook(hook_name, recurse=True)

    # Re-apply using the config captured from the hook that was found before removal.
    _apply_group_offloading(module, existing_hook.config)