Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/__pycache__/__init__.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/__pycache__/callbacks.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/__pycache__/configuration_utils.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/__pycache__/dependency_versions_check.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/__pycache__/dependency_versions_table.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/__pycache__/image_processor.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/__pycache__/optimization.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/__pycache__/training_utils.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/__pycache__/video_processor.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/commands/__init__.py +27 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/commands/custom_blocks.py +132 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/commands/diffusers_cli.py +45 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/commands/env.py +180 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/commands/fp16_safetensors.py +132 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/experimental/__init__.py +1 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__init__.py +31 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/adaptive_projected_guidance.py +237 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/adaptive_projected_guidance_mix.py +297 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/auto_guidance.py +198 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/classifier_free_guidance.py +156 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/classifier_free_zero_star_guidance.py +164 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/frequency_decoupled_guidance.py +335 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/guider_utils.py +396 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/magnitude_aware_guidance.py +159 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/perturbed_attention_guidance.py +289 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/skip_layer_guidance.py +280 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/smoothed_energy_guidance.py +269 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/tangential_classifier_free_guidance.py +151 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__init__.py +29 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/__init__.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/_common.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/_helpers.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/context_parallel.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/faster_cache.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/first_block_cache.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/group_offloading.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/hooks.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/layer_skip.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/layerwise_casting.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/mag_cache.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/pyramid_attention_broadcast.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/smoothed_energy_guidance_utils.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/taylorseer_cache.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/utils.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/_common.py +61 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/_helpers.py +381 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/context_parallel.py +383 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/faster_cache.py +654 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/first_block_cache.py +258 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/group_offloading.py +966 -0
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (38.8 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/__pycache__/callbacks.cpython-312.pyc
ADDED
|
Binary file (10.9 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/__pycache__/configuration_utils.cpython-312.pyc
ADDED
|
Binary file (38.8 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/__pycache__/dependency_versions_check.cpython-312.pyc
ADDED
|
Binary file (917 Bytes). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/__pycache__/dependency_versions_table.cpython-312.pyc
ADDED
|
Binary file (2.29 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/__pycache__/image_processor.cpython-312.pyc
ADDED
|
Binary file (65.1 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/__pycache__/optimization.cpython-312.pyc
ADDED
|
Binary file (15.9 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/__pycache__/training_utils.cpython-312.pyc
ADDED
|
Binary file (40.2 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/__pycache__/video_processor.cpython-312.pyc
ADDED
|
Binary file (9.41 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/commands/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from abc import ABC, abstractmethod
|
| 16 |
+
from argparse import ArgumentParser
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class BaseDiffusersCLICommand(ABC):
|
| 20 |
+
@staticmethod
|
| 21 |
+
@abstractmethod
|
| 22 |
+
def register_subcommand(parser: ArgumentParser):
|
| 23 |
+
raise NotImplementedError()
|
| 24 |
+
|
| 25 |
+
@abstractmethod
|
| 26 |
+
def run(self):
|
| 27 |
+
raise NotImplementedError()
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/commands/custom_blocks.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Usage example:
|
| 17 |
+
TODO
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import ast
|
| 21 |
+
import importlib.util
|
| 22 |
+
import os
|
| 23 |
+
from argparse import ArgumentParser, Namespace
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
|
| 26 |
+
from ..utils import logging
|
| 27 |
+
from . import BaseDiffusersCLICommand
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
EXPECTED_PARENT_CLASSES = ["ModularPipelineBlocks"]
|
| 31 |
+
CONFIG = "config.json"
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def conversion_command_factory(args: Namespace):
|
| 35 |
+
return CustomBlocksCommand(args.block_module_name, args.block_class_name)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class CustomBlocksCommand(BaseDiffusersCLICommand):
|
| 39 |
+
@staticmethod
|
| 40 |
+
def register_subcommand(parser: ArgumentParser):
|
| 41 |
+
conversion_parser = parser.add_parser("custom_blocks")
|
| 42 |
+
conversion_parser.add_argument(
|
| 43 |
+
"--block_module_name",
|
| 44 |
+
type=str,
|
| 45 |
+
default="block.py",
|
| 46 |
+
help="Module filename in which the custom block will be implemented.",
|
| 47 |
+
)
|
| 48 |
+
conversion_parser.add_argument(
|
| 49 |
+
"--block_class_name",
|
| 50 |
+
type=str,
|
| 51 |
+
default=None,
|
| 52 |
+
help="Name of the custom block. If provided None, we will try to infer it.",
|
| 53 |
+
)
|
| 54 |
+
conversion_parser.set_defaults(func=conversion_command_factory)
|
| 55 |
+
|
| 56 |
+
def __init__(self, block_module_name: str = "block.py", block_class_name: str = None):
|
| 57 |
+
self.logger = logging.get_logger("diffusers-cli/custom_blocks")
|
| 58 |
+
self.block_module_name = Path(block_module_name)
|
| 59 |
+
self.block_class_name = block_class_name
|
| 60 |
+
|
| 61 |
+
def run(self):
|
| 62 |
+
# determine the block to be saved.
|
| 63 |
+
out = self._get_class_names(self.block_module_name)
|
| 64 |
+
classes_found = list({cls for cls, _ in out})
|
| 65 |
+
|
| 66 |
+
if self.block_class_name is not None:
|
| 67 |
+
child_class, parent_class = self._choose_block(out, self.block_class_name)
|
| 68 |
+
if child_class is None and parent_class is None:
|
| 69 |
+
raise ValueError(
|
| 70 |
+
"`block_class_name` could not be retrieved. Available classes from "
|
| 71 |
+
f"{self.block_module_name}:\n{classes_found}"
|
| 72 |
+
)
|
| 73 |
+
else:
|
| 74 |
+
self.logger.info(
|
| 75 |
+
f"Found classes: {classes_found} will be using {classes_found[0]}. "
|
| 76 |
+
"If this needs to be changed, re-run the command specifying `block_class_name`."
|
| 77 |
+
)
|
| 78 |
+
child_class, parent_class = out[0][0], out[0][1]
|
| 79 |
+
|
| 80 |
+
# dynamically get the custom block and initialize it to call `save_pretrained` in the current directory.
|
| 81 |
+
# the user is responsible for running it, so I guess that is safe?
|
| 82 |
+
module_name = f"__dynamic__{self.block_module_name.stem}"
|
| 83 |
+
spec = importlib.util.spec_from_file_location(module_name, str(self.block_module_name))
|
| 84 |
+
module = importlib.util.module_from_spec(spec)
|
| 85 |
+
spec.loader.exec_module(module)
|
| 86 |
+
getattr(module, child_class)().save_pretrained(os.getcwd())
|
| 87 |
+
|
| 88 |
+
# or, we could create it manually.
|
| 89 |
+
# automap = self._create_automap(parent_class=parent_class, child_class=child_class)
|
| 90 |
+
# with open(CONFIG, "w") as f:
|
| 91 |
+
# json.dump(automap, f)
|
| 92 |
+
|
| 93 |
+
def _choose_block(self, candidates, chosen=None):
|
| 94 |
+
for cls, base in candidates:
|
| 95 |
+
if cls == chosen:
|
| 96 |
+
return cls, base
|
| 97 |
+
return None, None
|
| 98 |
+
|
| 99 |
+
def _get_class_names(self, file_path):
|
| 100 |
+
source = file_path.read_text(encoding="utf-8")
|
| 101 |
+
try:
|
| 102 |
+
tree = ast.parse(source, filename=file_path)
|
| 103 |
+
except SyntaxError as e:
|
| 104 |
+
raise ValueError(f"Could not parse {file_path!r}: {e}") from e
|
| 105 |
+
|
| 106 |
+
results: list[tuple[str, str]] = []
|
| 107 |
+
for node in tree.body:
|
| 108 |
+
if not isinstance(node, ast.ClassDef):
|
| 109 |
+
continue
|
| 110 |
+
|
| 111 |
+
# extract all base names for this class
|
| 112 |
+
base_names = [bname for b in node.bases if (bname := self._get_base_name(b)) is not None]
|
| 113 |
+
|
| 114 |
+
# for each allowed base that appears in the class's bases, emit a tuple
|
| 115 |
+
for allowed in EXPECTED_PARENT_CLASSES:
|
| 116 |
+
if allowed in base_names:
|
| 117 |
+
results.append((node.name, allowed))
|
| 118 |
+
|
| 119 |
+
return results
|
| 120 |
+
|
| 121 |
+
def _get_base_name(self, node: ast.expr):
|
| 122 |
+
if isinstance(node, ast.Name):
|
| 123 |
+
return node.id
|
| 124 |
+
elif isinstance(node, ast.Attribute):
|
| 125 |
+
val = self._get_base_name(node.value)
|
| 126 |
+
return f"{val}.{node.attr}" if val else node.attr
|
| 127 |
+
return None
|
| 128 |
+
|
| 129 |
+
def _create_automap(self, parent_class, child_class):
|
| 130 |
+
module = str(self.block_module_name).replace(".py", "").rsplit(".", 1)[-1]
|
| 131 |
+
auto_map = {f"{parent_class}": f"{module}.{child_class}"}
|
| 132 |
+
return {"auto_map": auto_map}
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/commands/diffusers_cli.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from argparse import ArgumentParser
|
| 17 |
+
|
| 18 |
+
from .custom_blocks import CustomBlocksCommand
|
| 19 |
+
from .env import EnvironmentCommand
|
| 20 |
+
from .fp16_safetensors import FP16SafetensorsCommand
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def main():
|
| 24 |
+
parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli <command> [<args>]")
|
| 25 |
+
commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")
|
| 26 |
+
|
| 27 |
+
# Register commands
|
| 28 |
+
EnvironmentCommand.register_subcommand(commands_parser)
|
| 29 |
+
FP16SafetensorsCommand.register_subcommand(commands_parser)
|
| 30 |
+
CustomBlocksCommand.register_subcommand(commands_parser)
|
| 31 |
+
|
| 32 |
+
# Let's go
|
| 33 |
+
args = parser.parse_args()
|
| 34 |
+
|
| 35 |
+
if not hasattr(args, "func"):
|
| 36 |
+
parser.print_help()
|
| 37 |
+
exit(1)
|
| 38 |
+
|
| 39 |
+
# Run
|
| 40 |
+
service = args.func(args)
|
| 41 |
+
service.run()
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
if __name__ == "__main__":
|
| 45 |
+
main()
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/commands/env.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import platform
|
| 16 |
+
import subprocess
|
| 17 |
+
from argparse import ArgumentParser
|
| 18 |
+
|
| 19 |
+
import huggingface_hub
|
| 20 |
+
|
| 21 |
+
from .. import __version__ as version
|
| 22 |
+
from ..utils import (
|
| 23 |
+
is_accelerate_available,
|
| 24 |
+
is_bitsandbytes_available,
|
| 25 |
+
is_flax_available,
|
| 26 |
+
is_google_colab,
|
| 27 |
+
is_peft_available,
|
| 28 |
+
is_safetensors_available,
|
| 29 |
+
is_torch_available,
|
| 30 |
+
is_transformers_available,
|
| 31 |
+
is_xformers_available,
|
| 32 |
+
)
|
| 33 |
+
from . import BaseDiffusersCLICommand
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def info_command_factory(_):
|
| 37 |
+
return EnvironmentCommand()
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class EnvironmentCommand(BaseDiffusersCLICommand):
|
| 41 |
+
@staticmethod
|
| 42 |
+
def register_subcommand(parser: ArgumentParser) -> None:
|
| 43 |
+
download_parser = parser.add_parser("env")
|
| 44 |
+
download_parser.set_defaults(func=info_command_factory)
|
| 45 |
+
|
| 46 |
+
def run(self) -> dict:
|
| 47 |
+
hub_version = huggingface_hub.__version__
|
| 48 |
+
|
| 49 |
+
safetensors_version = "not installed"
|
| 50 |
+
if is_safetensors_available():
|
| 51 |
+
import safetensors
|
| 52 |
+
|
| 53 |
+
safetensors_version = safetensors.__version__
|
| 54 |
+
|
| 55 |
+
pt_version = "not installed"
|
| 56 |
+
pt_cuda_available = "NA"
|
| 57 |
+
if is_torch_available():
|
| 58 |
+
import torch
|
| 59 |
+
|
| 60 |
+
pt_version = torch.__version__
|
| 61 |
+
pt_cuda_available = torch.cuda.is_available()
|
| 62 |
+
|
| 63 |
+
flax_version = "not installed"
|
| 64 |
+
jax_version = "not installed"
|
| 65 |
+
jaxlib_version = "not installed"
|
| 66 |
+
jax_backend = "NA"
|
| 67 |
+
if is_flax_available():
|
| 68 |
+
import flax
|
| 69 |
+
import jax
|
| 70 |
+
import jaxlib
|
| 71 |
+
|
| 72 |
+
flax_version = flax.__version__
|
| 73 |
+
jax_version = jax.__version__
|
| 74 |
+
jaxlib_version = jaxlib.__version__
|
| 75 |
+
jax_backend = jax.lib.xla_bridge.get_backend().platform
|
| 76 |
+
|
| 77 |
+
transformers_version = "not installed"
|
| 78 |
+
if is_transformers_available():
|
| 79 |
+
import transformers
|
| 80 |
+
|
| 81 |
+
transformers_version = transformers.__version__
|
| 82 |
+
|
| 83 |
+
accelerate_version = "not installed"
|
| 84 |
+
if is_accelerate_available():
|
| 85 |
+
import accelerate
|
| 86 |
+
|
| 87 |
+
accelerate_version = accelerate.__version__
|
| 88 |
+
|
| 89 |
+
peft_version = "not installed"
|
| 90 |
+
if is_peft_available():
|
| 91 |
+
import peft
|
| 92 |
+
|
| 93 |
+
peft_version = peft.__version__
|
| 94 |
+
|
| 95 |
+
bitsandbytes_version = "not installed"
|
| 96 |
+
if is_bitsandbytes_available():
|
| 97 |
+
import bitsandbytes
|
| 98 |
+
|
| 99 |
+
bitsandbytes_version = bitsandbytes.__version__
|
| 100 |
+
|
| 101 |
+
xformers_version = "not installed"
|
| 102 |
+
if is_xformers_available():
|
| 103 |
+
import xformers
|
| 104 |
+
|
| 105 |
+
xformers_version = xformers.__version__
|
| 106 |
+
|
| 107 |
+
platform_info = platform.platform()
|
| 108 |
+
|
| 109 |
+
is_google_colab_str = "Yes" if is_google_colab() else "No"
|
| 110 |
+
|
| 111 |
+
accelerator = "NA"
|
| 112 |
+
if platform.system() in {"Linux", "Windows"}:
|
| 113 |
+
try:
|
| 114 |
+
sp = subprocess.Popen(
|
| 115 |
+
["nvidia-smi", "--query-gpu=gpu_name,memory.total", "--format=csv,noheader"],
|
| 116 |
+
stdout=subprocess.PIPE,
|
| 117 |
+
stderr=subprocess.PIPE,
|
| 118 |
+
)
|
| 119 |
+
out_str, _ = sp.communicate()
|
| 120 |
+
out_str = out_str.decode("utf-8")
|
| 121 |
+
|
| 122 |
+
if len(out_str) > 0:
|
| 123 |
+
accelerator = out_str.strip()
|
| 124 |
+
except FileNotFoundError:
|
| 125 |
+
pass
|
| 126 |
+
elif platform.system() == "Darwin": # Mac OS
|
| 127 |
+
try:
|
| 128 |
+
sp = subprocess.Popen(
|
| 129 |
+
["system_profiler", "SPDisplaysDataType"],
|
| 130 |
+
stdout=subprocess.PIPE,
|
| 131 |
+
stderr=subprocess.PIPE,
|
| 132 |
+
)
|
| 133 |
+
out_str, _ = sp.communicate()
|
| 134 |
+
out_str = out_str.decode("utf-8")
|
| 135 |
+
|
| 136 |
+
start = out_str.find("Chipset Model:")
|
| 137 |
+
if start != -1:
|
| 138 |
+
start += len("Chipset Model:")
|
| 139 |
+
end = out_str.find("\n", start)
|
| 140 |
+
accelerator = out_str[start:end].strip()
|
| 141 |
+
|
| 142 |
+
start = out_str.find("VRAM (Total):")
|
| 143 |
+
if start != -1:
|
| 144 |
+
start += len("VRAM (Total):")
|
| 145 |
+
end = out_str.find("\n", start)
|
| 146 |
+
accelerator += " VRAM: " + out_str[start:end].strip()
|
| 147 |
+
except FileNotFoundError:
|
| 148 |
+
pass
|
| 149 |
+
else:
|
| 150 |
+
print("It seems you are running an unusual OS. Could you fill in the accelerator manually?")
|
| 151 |
+
|
| 152 |
+
info = {
|
| 153 |
+
"🤗 Diffusers version": version,
|
| 154 |
+
"Platform": platform_info,
|
| 155 |
+
"Running on Google Colab?": is_google_colab_str,
|
| 156 |
+
"Python version": platform.python_version(),
|
| 157 |
+
"PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
|
| 158 |
+
"Flax version (CPU?/GPU?/TPU?)": f"{flax_version} ({jax_backend})",
|
| 159 |
+
"Jax version": jax_version,
|
| 160 |
+
"JaxLib version": jaxlib_version,
|
| 161 |
+
"Huggingface_hub version": hub_version,
|
| 162 |
+
"Transformers version": transformers_version,
|
| 163 |
+
"Accelerate version": accelerate_version,
|
| 164 |
+
"PEFT version": peft_version,
|
| 165 |
+
"Bitsandbytes version": bitsandbytes_version,
|
| 166 |
+
"Safetensors version": safetensors_version,
|
| 167 |
+
"xFormers version": xformers_version,
|
| 168 |
+
"Accelerator": accelerator,
|
| 169 |
+
"Using GPU in script?": "<fill in>",
|
| 170 |
+
"Using distributed or parallel set-up in script?": "<fill in>",
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
|
| 174 |
+
print(self.format_dict(info))
|
| 175 |
+
|
| 176 |
+
return info
|
| 177 |
+
|
| 178 |
+
@staticmethod
|
| 179 |
+
def format_dict(d: dict) -> str:
|
| 180 |
+
return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/commands/fp16_safetensors.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Usage example:
|
| 17 |
+
diffusers-cli fp16_safetensors --ckpt_id=openai/shap-e --fp16 --use_safetensors
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import glob
|
| 21 |
+
import json
|
| 22 |
+
import warnings
|
| 23 |
+
from argparse import ArgumentParser, Namespace
|
| 24 |
+
from importlib import import_module
|
| 25 |
+
|
| 26 |
+
import huggingface_hub
|
| 27 |
+
import torch
|
| 28 |
+
from huggingface_hub import hf_hub_download
|
| 29 |
+
from packaging import version
|
| 30 |
+
|
| 31 |
+
from ..utils import logging
|
| 32 |
+
from . import BaseDiffusersCLICommand
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def conversion_command_factory(args: Namespace):
|
| 36 |
+
if args.use_auth_token:
|
| 37 |
+
warnings.warn(
|
| 38 |
+
"The `--use_auth_token` flag is deprecated and will be removed in a future version."
|
| 39 |
+
"Authentication is now handled automatically if the user is logged in."
|
| 40 |
+
)
|
| 41 |
+
return FP16SafetensorsCommand(args.ckpt_id, args.fp16, args.use_safetensors)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class FP16SafetensorsCommand(BaseDiffusersCLICommand):
|
| 45 |
+
@staticmethod
|
| 46 |
+
def register_subcommand(parser: ArgumentParser):
|
| 47 |
+
conversion_parser = parser.add_parser("fp16_safetensors")
|
| 48 |
+
conversion_parser.add_argument(
|
| 49 |
+
"--ckpt_id",
|
| 50 |
+
type=str,
|
| 51 |
+
help="Repo id of the checkpoints on which to run the conversion. Example: 'openai/shap-e'.",
|
| 52 |
+
)
|
| 53 |
+
conversion_parser.add_argument(
|
| 54 |
+
"--fp16", action="store_true", help="If serializing the variables in FP16 precision."
|
| 55 |
+
)
|
| 56 |
+
conversion_parser.add_argument(
|
| 57 |
+
"--use_safetensors", action="store_true", help="If serializing in the safetensors format."
|
| 58 |
+
)
|
| 59 |
+
conversion_parser.add_argument(
|
| 60 |
+
"--use_auth_token",
|
| 61 |
+
action="store_true",
|
| 62 |
+
help="When working with checkpoints having private visibility. When used `hf auth login` needs to be run beforehand.",
|
| 63 |
+
)
|
| 64 |
+
conversion_parser.set_defaults(func=conversion_command_factory)
|
| 65 |
+
|
| 66 |
+
def __init__(self, ckpt_id: str, fp16: bool, use_safetensors: bool):
|
| 67 |
+
self.logger = logging.get_logger("diffusers-cli/fp16_safetensors")
|
| 68 |
+
self.ckpt_id = ckpt_id
|
| 69 |
+
self.local_ckpt_dir = f"/tmp/{ckpt_id}"
|
| 70 |
+
self.fp16 = fp16
|
| 71 |
+
|
| 72 |
+
self.use_safetensors = use_safetensors
|
| 73 |
+
|
| 74 |
+
if not self.use_safetensors and not self.fp16:
|
| 75 |
+
raise NotImplementedError(
|
| 76 |
+
"When `use_safetensors` and `fp16` both are False, then this command is of no use."
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
def run(self):
|
| 80 |
+
if version.parse(huggingface_hub.__version__) < version.parse("0.9.0"):
|
| 81 |
+
raise ImportError(
|
| 82 |
+
"The huggingface_hub version must be >= 0.9.0 to use this command. Please update your huggingface_hub"
|
| 83 |
+
" installation."
|
| 84 |
+
)
|
| 85 |
+
else:
|
| 86 |
+
from huggingface_hub import create_commit
|
| 87 |
+
from huggingface_hub._commit_api import CommitOperationAdd
|
| 88 |
+
|
| 89 |
+
model_index = hf_hub_download(repo_id=self.ckpt_id, filename="model_index.json")
|
| 90 |
+
with open(model_index, "r") as f:
|
| 91 |
+
pipeline_class_name = json.load(f)["_class_name"]
|
| 92 |
+
pipeline_class = getattr(import_module("diffusers"), pipeline_class_name)
|
| 93 |
+
self.logger.info(f"Pipeline class imported: {pipeline_class_name}.")
|
| 94 |
+
|
| 95 |
+
# Load the appropriate pipeline. We could have used `DiffusionPipeline`
|
| 96 |
+
# here, but just to avoid potential edge cases.
|
| 97 |
+
pipeline = pipeline_class.from_pretrained(
|
| 98 |
+
self.ckpt_id, torch_dtype=torch.float16 if self.fp16 else torch.float32
|
| 99 |
+
)
|
| 100 |
+
pipeline.save_pretrained(
|
| 101 |
+
self.local_ckpt_dir,
|
| 102 |
+
safe_serialization=True if self.use_safetensors else False,
|
| 103 |
+
variant="fp16" if self.fp16 else None,
|
| 104 |
+
)
|
| 105 |
+
self.logger.info(f"Pipeline locally saved to {self.local_ckpt_dir}.")
|
| 106 |
+
|
| 107 |
+
# Fetch all the paths.
|
| 108 |
+
if self.fp16:
|
| 109 |
+
modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/*.fp16.*")
|
| 110 |
+
elif self.use_safetensors:
|
| 111 |
+
modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/*.safetensors")
|
| 112 |
+
|
| 113 |
+
# Prepare for the PR.
|
| 114 |
+
commit_message = f"Serialize variables with FP16: {self.fp16} and safetensors: {self.use_safetensors}."
|
| 115 |
+
operations = []
|
| 116 |
+
for path in modified_paths:
|
| 117 |
+
operations.append(CommitOperationAdd(path_in_repo="/".join(path.split("/")[4:]), path_or_fileobj=path))
|
| 118 |
+
|
| 119 |
+
# Open the PR.
|
| 120 |
+
commit_description = (
|
| 121 |
+
"Variables converted by the [`diffusers`' `fp16_safetensors`"
|
| 122 |
+
" CLI](https://github.com/huggingface/diffusers/blob/main/src/diffusers/commands/fp16_safetensors.py)."
|
| 123 |
+
)
|
| 124 |
+
hub_pr_url = create_commit(
|
| 125 |
+
repo_id=self.ckpt_id,
|
| 126 |
+
operations=operations,
|
| 127 |
+
commit_message=commit_message,
|
| 128 |
+
commit_description=commit_description,
|
| 129 |
+
repo_type="model",
|
| 130 |
+
create_pr=True,
|
| 131 |
+
).pr_url
|
| 132 |
+
self.logger.info(f"PR created here: {hub_pr_url}.")
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/experimental/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .rl import ValueGuidedRLPipeline
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__init__.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
from ..utils import is_torch_available, logging
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
if is_torch_available():
|
| 20 |
+
from .adaptive_projected_guidance import AdaptiveProjectedGuidance
|
| 21 |
+
from .adaptive_projected_guidance_mix import AdaptiveProjectedMixGuidance
|
| 22 |
+
from .auto_guidance import AutoGuidance
|
| 23 |
+
from .classifier_free_guidance import ClassifierFreeGuidance
|
| 24 |
+
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
|
| 25 |
+
from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
|
| 26 |
+
from .guider_utils import BaseGuidance
|
| 27 |
+
from .magnitude_aware_guidance import MagnitudeAwareGuidance
|
| 28 |
+
from .perturbed_attention_guidance import PerturbedAttentionGuidance
|
| 29 |
+
from .skip_layer_guidance import SkipLayerGuidance
|
| 30 |
+
from .smoothed_energy_guidance import SmoothedEnergyGuidance
|
| 31 |
+
from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/adaptive_projected_guidance.py
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from typing import TYPE_CHECKING
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
from ..configuration_utils import register_to_config
|
| 23 |
+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
if TYPE_CHECKING:
|
| 27 |
+
from ..modular_pipelines.modular_pipeline import BlockState
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class AdaptiveProjectedGuidance(BaseGuidance):
|
| 31 |
+
"""
|
| 32 |
+
Adaptive Projected Guidance (APG): https://huggingface.co/papers/2410.02416
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
guidance_scale (`float`, defaults to `7.5`):
|
| 36 |
+
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
| 37 |
+
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
| 38 |
+
deterioration of image quality.
|
| 39 |
+
adaptive_projected_guidance_momentum (`float`, defaults to `None`):
|
| 40 |
+
The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
|
| 41 |
+
adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
|
| 42 |
+
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
| 43 |
+
guidance_rescale (`float`, defaults to `0.0`):
|
| 44 |
+
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
| 45 |
+
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
| 46 |
+
Flawed](https://huggingface.co/papers/2305.08891).
|
| 47 |
+
use_original_formulation (`bool`, defaults to `False`):
|
| 48 |
+
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
| 49 |
+
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
| 50 |
+
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
| 51 |
+
start (`float`, defaults to `0.0`):
|
| 52 |
+
The fraction of the total number of denoising steps after which guidance starts.
|
| 53 |
+
stop (`float`, defaults to `1.0`):
|
| 54 |
+
The fraction of the total number of denoising steps after which guidance stops.
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
_input_predictions = ["pred_cond", "pred_uncond"]
|
| 58 |
+
|
| 59 |
+
@register_to_config
|
| 60 |
+
def __init__(
|
| 61 |
+
self,
|
| 62 |
+
guidance_scale: float = 7.5,
|
| 63 |
+
adaptive_projected_guidance_momentum: float | None = None,
|
| 64 |
+
adaptive_projected_guidance_rescale: float = 15.0,
|
| 65 |
+
eta: float = 1.0,
|
| 66 |
+
guidance_rescale: float = 0.0,
|
| 67 |
+
use_original_formulation: bool = False,
|
| 68 |
+
start: float = 0.0,
|
| 69 |
+
stop: float = 1.0,
|
| 70 |
+
enabled: bool = True,
|
| 71 |
+
):
|
| 72 |
+
super().__init__(start, stop, enabled)
|
| 73 |
+
|
| 74 |
+
self.guidance_scale = guidance_scale
|
| 75 |
+
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
|
| 76 |
+
self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
|
| 77 |
+
self.eta = eta
|
| 78 |
+
self.guidance_rescale = guidance_rescale
|
| 79 |
+
self.use_original_formulation = use_original_formulation
|
| 80 |
+
self.momentum_buffer = None
|
| 81 |
+
|
| 82 |
+
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
|
| 83 |
+
if self._step == 0:
|
| 84 |
+
if self.adaptive_projected_guidance_momentum is not None:
|
| 85 |
+
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
|
| 86 |
+
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
| 87 |
+
data_batches = []
|
| 88 |
+
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
| 89 |
+
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
| 90 |
+
data_batches.append(data_batch)
|
| 91 |
+
return data_batches
|
| 92 |
+
|
| 93 |
+
def prepare_inputs_from_block_state(
|
| 94 |
+
self, data: "BlockState", input_fields: dict[str, str | tuple[str, str]]
|
| 95 |
+
) -> list["BlockState"]:
|
| 96 |
+
if self._step == 0:
|
| 97 |
+
if self.adaptive_projected_guidance_momentum is not None:
|
| 98 |
+
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
|
| 99 |
+
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
| 100 |
+
data_batches = []
|
| 101 |
+
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
| 102 |
+
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
| 103 |
+
data_batches.append(data_batch)
|
| 104 |
+
return data_batches
|
| 105 |
+
|
| 106 |
+
def forward(self, pred_cond: torch.Tensor, pred_uncond: torch.Tensor | None = None) -> GuiderOutput:
|
| 107 |
+
pred = None
|
| 108 |
+
|
| 109 |
+
if not self._is_apg_enabled():
|
| 110 |
+
pred = pred_cond
|
| 111 |
+
else:
|
| 112 |
+
pred = normalized_guidance(
|
| 113 |
+
pred_cond,
|
| 114 |
+
pred_uncond,
|
| 115 |
+
self.guidance_scale,
|
| 116 |
+
self.momentum_buffer,
|
| 117 |
+
self.eta,
|
| 118 |
+
self.adaptive_projected_guidance_rescale,
|
| 119 |
+
self.use_original_formulation,
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
if self.guidance_rescale > 0.0:
|
| 123 |
+
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
| 124 |
+
|
| 125 |
+
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
|
| 126 |
+
|
| 127 |
+
@property
|
| 128 |
+
def is_conditional(self) -> bool:
|
| 129 |
+
return self._count_prepared == 1
|
| 130 |
+
|
| 131 |
+
@property
|
| 132 |
+
def num_conditions(self) -> int:
|
| 133 |
+
num_conditions = 1
|
| 134 |
+
if self._is_apg_enabled():
|
| 135 |
+
num_conditions += 1
|
| 136 |
+
return num_conditions
|
| 137 |
+
|
| 138 |
+
def _is_apg_enabled(self) -> bool:
|
| 139 |
+
if not self._enabled:
|
| 140 |
+
return False
|
| 141 |
+
|
| 142 |
+
is_within_range = True
|
| 143 |
+
if self._num_inference_steps is not None:
|
| 144 |
+
skip_start_step = int(self._start * self._num_inference_steps)
|
| 145 |
+
skip_stop_step = int(self._stop * self._num_inference_steps)
|
| 146 |
+
is_within_range = skip_start_step <= self._step < skip_stop_step
|
| 147 |
+
|
| 148 |
+
is_close = False
|
| 149 |
+
if self.use_original_formulation:
|
| 150 |
+
is_close = math.isclose(self.guidance_scale, 0.0)
|
| 151 |
+
else:
|
| 152 |
+
is_close = math.isclose(self.guidance_scale, 1.0)
|
| 153 |
+
|
| 154 |
+
return is_within_range and not is_close
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class MomentumBuffer:
|
| 158 |
+
def __init__(self, momentum: float):
|
| 159 |
+
self.momentum = momentum
|
| 160 |
+
self.running_average = 0
|
| 161 |
+
|
| 162 |
+
def update(self, update_value: torch.Tensor):
|
| 163 |
+
new_average = self.momentum * self.running_average
|
| 164 |
+
self.running_average = update_value + new_average
|
| 165 |
+
|
| 166 |
+
def __repr__(self) -> str:
|
| 167 |
+
"""
|
| 168 |
+
Returns a string representation showing momentum, shape, statistics, and a slice of the running_average.
|
| 169 |
+
"""
|
| 170 |
+
if isinstance(self.running_average, torch.Tensor):
|
| 171 |
+
shape = tuple(self.running_average.shape)
|
| 172 |
+
|
| 173 |
+
# Calculate statistics
|
| 174 |
+
with torch.no_grad():
|
| 175 |
+
stats = {
|
| 176 |
+
"mean": self.running_average.mean().item(),
|
| 177 |
+
"std": self.running_average.std().item(),
|
| 178 |
+
"min": self.running_average.min().item(),
|
| 179 |
+
"max": self.running_average.max().item(),
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
# Get a slice (max 3 elements per dimension)
|
| 183 |
+
slice_indices = tuple(slice(None, min(3, dim)) for dim in shape)
|
| 184 |
+
sliced_data = self.running_average[slice_indices]
|
| 185 |
+
|
| 186 |
+
# Format the slice for display (convert to float32 for numpy compatibility with bfloat16)
|
| 187 |
+
slice_str = str(sliced_data.detach().float().cpu().numpy())
|
| 188 |
+
if len(slice_str) > 200: # Truncate if too long
|
| 189 |
+
slice_str = slice_str[:200] + "..."
|
| 190 |
+
|
| 191 |
+
stats_str = ", ".join([f"{k}={v:.4f}" for k, v in stats.items()])
|
| 192 |
+
|
| 193 |
+
return (
|
| 194 |
+
f"MomentumBuffer(\n"
|
| 195 |
+
f" momentum={self.momentum},\n"
|
| 196 |
+
f" shape={shape},\n"
|
| 197 |
+
f" stats=[{stats_str}],\n"
|
| 198 |
+
f" slice={slice_str}\n"
|
| 199 |
+
f")"
|
| 200 |
+
)
|
| 201 |
+
else:
|
| 202 |
+
return f"MomentumBuffer(momentum={self.momentum}, running_average={self.running_average})"
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def normalized_guidance(
|
| 206 |
+
pred_cond: torch.Tensor,
|
| 207 |
+
pred_uncond: torch.Tensor,
|
| 208 |
+
guidance_scale: float,
|
| 209 |
+
momentum_buffer: MomentumBuffer | None = None,
|
| 210 |
+
eta: float = 1.0,
|
| 211 |
+
norm_threshold: float = 0.0,
|
| 212 |
+
use_original_formulation: bool = False,
|
| 213 |
+
):
|
| 214 |
+
diff = pred_cond - pred_uncond
|
| 215 |
+
dim = [-i for i in range(1, len(diff.shape))]
|
| 216 |
+
|
| 217 |
+
if momentum_buffer is not None:
|
| 218 |
+
momentum_buffer.update(diff)
|
| 219 |
+
diff = momentum_buffer.running_average
|
| 220 |
+
|
| 221 |
+
if norm_threshold > 0:
|
| 222 |
+
ones = torch.ones_like(diff)
|
| 223 |
+
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
|
| 224 |
+
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
|
| 225 |
+
diff = diff * scale_factor
|
| 226 |
+
|
| 227 |
+
v0, v1 = diff.double(), pred_cond.double()
|
| 228 |
+
v1 = torch.nn.functional.normalize(v1, dim=dim)
|
| 229 |
+
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
|
| 230 |
+
v0_orthogonal = v0 - v0_parallel
|
| 231 |
+
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
|
| 232 |
+
normalized_update = diff_orthogonal + eta * diff_parallel
|
| 233 |
+
|
| 234 |
+
pred = pred_cond if use_original_formulation else pred_uncond
|
| 235 |
+
pred = pred + guidance_scale * normalized_update
|
| 236 |
+
|
| 237 |
+
return pred
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/adaptive_projected_guidance_mix.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
from typing import TYPE_CHECKING
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from ..configuration_utils import register_to_config
|
| 21 |
+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
if TYPE_CHECKING:
|
| 25 |
+
from ..modular_pipelines.modular_pipeline import BlockState
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class AdaptiveProjectedMixGuidance(BaseGuidance):
|
| 29 |
+
"""
|
| 30 |
+
Adaptive Projected Guidance (APG) https://huggingface.co/papers/2410.02416 combined with Classifier-Free Guidance
|
| 31 |
+
(CFG). This guider is used in HunyuanImage2.1 https://github.com/Tencent-Hunyuan/HunyuanImage-2.1
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
guidance_scale (`float`, defaults to `7.5`):
|
| 35 |
+
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
| 36 |
+
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
| 37 |
+
deterioration of image quality.
|
| 38 |
+
adaptive_projected_guidance_momentum (`float`, defaults to `None`):
|
| 39 |
+
The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
|
| 40 |
+
adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
|
| 41 |
+
The rescale factor applied to the noise predictions for adaptive projected guidance. This is used to
|
| 42 |
+
improve image quality and fix
|
| 43 |
+
guidance_rescale (`float`, defaults to `0.0`):
|
| 44 |
+
The rescale factor applied to the noise predictions for classifier-free guidance. This is used to improve
|
| 45 |
+
image quality and fix overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample
|
| 46 |
+
Steps are Flawed](https://huggingface.co/papers/2305.08891).
|
| 47 |
+
use_original_formulation (`bool`, defaults to `False`):
|
| 48 |
+
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
| 49 |
+
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
| 50 |
+
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
| 51 |
+
start (`float`, defaults to `0.0`):
|
| 52 |
+
The fraction of the total number of denoising steps after which the classifier-free guidance starts.
|
| 53 |
+
stop (`float`, defaults to `1.0`):
|
| 54 |
+
The fraction of the total number of denoising steps after which the classifier-free guidance stops.
|
| 55 |
+
adaptive_projected_guidance_start_step (`int`, defaults to `5`):
|
| 56 |
+
The step at which the adaptive projected guidance starts (before this step, classifier-free guidance is
|
| 57 |
+
used, and momentum buffer is updated).
|
| 58 |
+
enabled (`bool`, defaults to `True`):
|
| 59 |
+
Whether this guidance is enabled.
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
_input_predictions = ["pred_cond", "pred_uncond"]
|
| 63 |
+
|
| 64 |
+
@register_to_config
|
| 65 |
+
def __init__(
|
| 66 |
+
self,
|
| 67 |
+
guidance_scale: float = 3.5,
|
| 68 |
+
guidance_rescale: float = 0.0,
|
| 69 |
+
adaptive_projected_guidance_scale: float = 10.0,
|
| 70 |
+
adaptive_projected_guidance_momentum: float = -0.5,
|
| 71 |
+
adaptive_projected_guidance_rescale: float = 10.0,
|
| 72 |
+
eta: float = 0.0,
|
| 73 |
+
use_original_formulation: bool = False,
|
| 74 |
+
start: float = 0.0,
|
| 75 |
+
stop: float = 1.0,
|
| 76 |
+
adaptive_projected_guidance_start_step: int = 5,
|
| 77 |
+
enabled: bool = True,
|
| 78 |
+
):
|
| 79 |
+
super().__init__(start, stop, enabled)
|
| 80 |
+
|
| 81 |
+
self.guidance_scale = guidance_scale
|
| 82 |
+
self.guidance_rescale = guidance_rescale
|
| 83 |
+
self.adaptive_projected_guidance_scale = adaptive_projected_guidance_scale
|
| 84 |
+
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
|
| 85 |
+
self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
|
| 86 |
+
self.eta = eta
|
| 87 |
+
self.adaptive_projected_guidance_start_step = adaptive_projected_guidance_start_step
|
| 88 |
+
self.use_original_formulation = use_original_formulation
|
| 89 |
+
self.momentum_buffer = None
|
| 90 |
+
|
| 91 |
+
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
|
| 92 |
+
if self._step == 0:
|
| 93 |
+
if self.adaptive_projected_guidance_momentum is not None:
|
| 94 |
+
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
|
| 95 |
+
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
| 96 |
+
data_batches = []
|
| 97 |
+
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
| 98 |
+
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
| 99 |
+
data_batches.append(data_batch)
|
| 100 |
+
return data_batches
|
| 101 |
+
|
| 102 |
+
def prepare_inputs_from_block_state(
|
| 103 |
+
self, data: "BlockState", input_fields: dict[str, str | tuple[str, str]]
|
| 104 |
+
) -> list["BlockState"]:
|
| 105 |
+
if self._step == 0:
|
| 106 |
+
if self.adaptive_projected_guidance_momentum is not None:
|
| 107 |
+
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
|
| 108 |
+
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
| 109 |
+
data_batches = []
|
| 110 |
+
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
| 111 |
+
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
| 112 |
+
data_batches.append(data_batch)
|
| 113 |
+
return data_batches
|
| 114 |
+
|
| 115 |
+
def forward(self, pred_cond: torch.Tensor, pred_uncond: torch.Tensor | None = None) -> GuiderOutput:
|
| 116 |
+
pred = None
|
| 117 |
+
|
| 118 |
+
# no guidance
|
| 119 |
+
if not self._is_cfg_enabled():
|
| 120 |
+
pred = pred_cond
|
| 121 |
+
|
| 122 |
+
# CFG + update momentum buffer
|
| 123 |
+
elif not self._is_apg_enabled():
|
| 124 |
+
if self.momentum_buffer is not None:
|
| 125 |
+
update_momentum_buffer(pred_cond, pred_uncond, self.momentum_buffer)
|
| 126 |
+
# CFG + update momentum buffer
|
| 127 |
+
shift = pred_cond - pred_uncond
|
| 128 |
+
pred = pred_cond if self.use_original_formulation else pred_uncond
|
| 129 |
+
pred = pred + self.guidance_scale * shift
|
| 130 |
+
|
| 131 |
+
# APG
|
| 132 |
+
elif self._is_apg_enabled():
|
| 133 |
+
pred = normalized_guidance(
|
| 134 |
+
pred_cond,
|
| 135 |
+
pred_uncond,
|
| 136 |
+
self.adaptive_projected_guidance_scale,
|
| 137 |
+
self.momentum_buffer,
|
| 138 |
+
self.eta,
|
| 139 |
+
self.adaptive_projected_guidance_rescale,
|
| 140 |
+
self.use_original_formulation,
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
if self.guidance_rescale > 0.0:
|
| 144 |
+
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
| 145 |
+
|
| 146 |
+
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
|
| 147 |
+
|
| 148 |
+
@property
|
| 149 |
+
def is_conditional(self) -> bool:
|
| 150 |
+
return self._count_prepared == 1
|
| 151 |
+
|
| 152 |
+
@property
|
| 153 |
+
def num_conditions(self) -> int:
|
| 154 |
+
num_conditions = 1
|
| 155 |
+
if self._is_apg_enabled() or self._is_cfg_enabled():
|
| 156 |
+
num_conditions += 1
|
| 157 |
+
return num_conditions
|
| 158 |
+
|
| 159 |
+
# Copied from diffusers.guiders.classifier_free_guidance.ClassifierFreeGuidance._is_cfg_enabled
|
| 160 |
+
def _is_cfg_enabled(self) -> bool:
|
| 161 |
+
if not self._enabled:
|
| 162 |
+
return False
|
| 163 |
+
|
| 164 |
+
is_within_range = True
|
| 165 |
+
if self._num_inference_steps is not None:
|
| 166 |
+
skip_start_step = int(self._start * self._num_inference_steps)
|
| 167 |
+
skip_stop_step = int(self._stop * self._num_inference_steps)
|
| 168 |
+
is_within_range = skip_start_step <= self._step < skip_stop_step
|
| 169 |
+
|
| 170 |
+
is_close = False
|
| 171 |
+
if self.use_original_formulation:
|
| 172 |
+
is_close = math.isclose(self.guidance_scale, 0.0)
|
| 173 |
+
else:
|
| 174 |
+
is_close = math.isclose(self.guidance_scale, 1.0)
|
| 175 |
+
|
| 176 |
+
return is_within_range and not is_close
|
| 177 |
+
|
| 178 |
+
def _is_apg_enabled(self) -> bool:
|
| 179 |
+
if not self._enabled:
|
| 180 |
+
return False
|
| 181 |
+
|
| 182 |
+
if not self._is_cfg_enabled():
|
| 183 |
+
return False
|
| 184 |
+
|
| 185 |
+
is_within_range = False
|
| 186 |
+
if self._step is not None:
|
| 187 |
+
is_within_range = self._step > self.adaptive_projected_guidance_start_step
|
| 188 |
+
|
| 189 |
+
is_close = False
|
| 190 |
+
if self.use_original_formulation:
|
| 191 |
+
is_close = math.isclose(self.adaptive_projected_guidance_scale, 0.0)
|
| 192 |
+
else:
|
| 193 |
+
is_close = math.isclose(self.adaptive_projected_guidance_scale, 1.0)
|
| 194 |
+
|
| 195 |
+
return is_within_range and not is_close
|
| 196 |
+
|
| 197 |
+
def get_state(self):
|
| 198 |
+
state = super().get_state()
|
| 199 |
+
state["momentum_buffer"] = self.momentum_buffer
|
| 200 |
+
state["is_apg_enabled"] = self._is_apg_enabled()
|
| 201 |
+
state["is_cfg_enabled"] = self._is_cfg_enabled()
|
| 202 |
+
return state
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
# Copied from diffusers.guiders.adaptive_projected_guidance.MomentumBuffer
|
| 206 |
+
class MomentumBuffer:
|
| 207 |
+
def __init__(self, momentum: float):
|
| 208 |
+
self.momentum = momentum
|
| 209 |
+
self.running_average = 0
|
| 210 |
+
|
| 211 |
+
def update(self, update_value: torch.Tensor):
|
| 212 |
+
new_average = self.momentum * self.running_average
|
| 213 |
+
self.running_average = update_value + new_average
|
| 214 |
+
|
| 215 |
+
def __repr__(self) -> str:
|
| 216 |
+
"""
|
| 217 |
+
Returns a string representation showing momentum, shape, statistics, and a slice of the running_average.
|
| 218 |
+
"""
|
| 219 |
+
if isinstance(self.running_average, torch.Tensor):
|
| 220 |
+
shape = tuple(self.running_average.shape)
|
| 221 |
+
|
| 222 |
+
# Calculate statistics
|
| 223 |
+
with torch.no_grad():
|
| 224 |
+
stats = {
|
| 225 |
+
"mean": self.running_average.mean().item(),
|
| 226 |
+
"std": self.running_average.std().item(),
|
| 227 |
+
"min": self.running_average.min().item(),
|
| 228 |
+
"max": self.running_average.max().item(),
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
# Get a slice (max 3 elements per dimension)
|
| 232 |
+
slice_indices = tuple(slice(None, min(3, dim)) for dim in shape)
|
| 233 |
+
sliced_data = self.running_average[slice_indices]
|
| 234 |
+
|
| 235 |
+
# Format the slice for display (convert to float32 for numpy compatibility with bfloat16)
|
| 236 |
+
slice_str = str(sliced_data.detach().float().cpu().numpy())
|
| 237 |
+
if len(slice_str) > 200: # Truncate if too long
|
| 238 |
+
slice_str = slice_str[:200] + "..."
|
| 239 |
+
|
| 240 |
+
stats_str = ", ".join([f"{k}={v:.4f}" for k, v in stats.items()])
|
| 241 |
+
|
| 242 |
+
return (
|
| 243 |
+
f"MomentumBuffer(\n"
|
| 244 |
+
f" momentum={self.momentum},\n"
|
| 245 |
+
f" shape={shape},\n"
|
| 246 |
+
f" stats=[{stats_str}],\n"
|
| 247 |
+
f" slice={slice_str}\n"
|
| 248 |
+
f")"
|
| 249 |
+
)
|
| 250 |
+
else:
|
| 251 |
+
return f"MomentumBuffer(momentum={self.momentum}, running_average={self.running_average})"
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def update_momentum_buffer(
|
| 255 |
+
pred_cond: torch.Tensor,
|
| 256 |
+
pred_uncond: torch.Tensor,
|
| 257 |
+
momentum_buffer: MomentumBuffer | None = None,
|
| 258 |
+
):
|
| 259 |
+
diff = pred_cond - pred_uncond
|
| 260 |
+
if momentum_buffer is not None:
|
| 261 |
+
momentum_buffer.update(diff)
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def normalized_guidance(
|
| 265 |
+
pred_cond: torch.Tensor,
|
| 266 |
+
pred_uncond: torch.Tensor,
|
| 267 |
+
guidance_scale: float,
|
| 268 |
+
momentum_buffer: MomentumBuffer | None = None,
|
| 269 |
+
eta: float = 1.0,
|
| 270 |
+
norm_threshold: float = 0.0,
|
| 271 |
+
use_original_formulation: bool = False,
|
| 272 |
+
):
|
| 273 |
+
if momentum_buffer is not None:
|
| 274 |
+
update_momentum_buffer(pred_cond, pred_uncond, momentum_buffer)
|
| 275 |
+
diff = momentum_buffer.running_average
|
| 276 |
+
else:
|
| 277 |
+
diff = pred_cond - pred_uncond
|
| 278 |
+
|
| 279 |
+
dim = [-i for i in range(1, len(diff.shape))]
|
| 280 |
+
|
| 281 |
+
if norm_threshold > 0:
|
| 282 |
+
ones = torch.ones_like(diff)
|
| 283 |
+
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
|
| 284 |
+
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
|
| 285 |
+
diff = diff * scale_factor
|
| 286 |
+
|
| 287 |
+
v0, v1 = diff.double(), pred_cond.double()
|
| 288 |
+
v1 = torch.nn.functional.normalize(v1, dim=dim)
|
| 289 |
+
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
|
| 290 |
+
v0_orthogonal = v0 - v0_parallel
|
| 291 |
+
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
|
| 292 |
+
normalized_update = diff_orthogonal + eta * diff_parallel
|
| 293 |
+
|
| 294 |
+
pred = pred_cond if use_original_formulation else pred_uncond
|
| 295 |
+
pred = pred + guidance_scale * normalized_update
|
| 296 |
+
|
| 297 |
+
return pred
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/auto_guidance.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from typing import TYPE_CHECKING, Any
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
from ..configuration_utils import register_to_config
|
| 23 |
+
from ..hooks import HookRegistry, LayerSkipConfig
|
| 24 |
+
from ..hooks.layer_skip import _apply_layer_skip_hook
|
| 25 |
+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
if TYPE_CHECKING:
|
| 29 |
+
from ..modular_pipelines.modular_pipeline import BlockState
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class AutoGuidance(BaseGuidance):
|
| 33 |
+
"""
|
| 34 |
+
AutoGuidance: https://huggingface.co/papers/2406.02507
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
guidance_scale (`float`, defaults to `7.5`):
|
| 38 |
+
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
| 39 |
+
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
| 40 |
+
deterioration of image quality.
|
| 41 |
+
auto_guidance_layers (`int` or `list[int]`, *optional*):
|
| 42 |
+
The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
|
| 43 |
+
provided, `skip_layer_config` must be provided.
|
| 44 |
+
auto_guidance_config (`LayerSkipConfig` or `list[LayerSkipConfig]`, *optional*):
|
| 45 |
+
The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
|
| 46 |
+
`LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided.
|
| 47 |
+
dropout (`float`, *optional*):
|
| 48 |
+
The dropout probability for autoguidance on the enabled skip layers (either with `auto_guidance_layers` or
|
| 49 |
+
`auto_guidance_config`). If not provided, the dropout probability will be set to 1.0.
|
| 50 |
+
guidance_rescale (`float`, defaults to `0.0`):
|
| 51 |
+
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
| 52 |
+
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
| 53 |
+
Flawed](https://huggingface.co/papers/2305.08891).
|
| 54 |
+
use_original_formulation (`bool`, defaults to `False`):
|
| 55 |
+
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
| 56 |
+
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
| 57 |
+
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
| 58 |
+
start (`float`, defaults to `0.0`):
|
| 59 |
+
The fraction of the total number of denoising steps after which guidance starts.
|
| 60 |
+
stop (`float`, defaults to `1.0`):
|
| 61 |
+
The fraction of the total number of denoising steps after which guidance stops.
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
_input_predictions = ["pred_cond", "pred_uncond"]
|
| 65 |
+
|
| 66 |
+
@register_to_config
|
| 67 |
+
def __init__(
|
| 68 |
+
self,
|
| 69 |
+
guidance_scale: float = 7.5,
|
| 70 |
+
auto_guidance_layers: int | list[int] | None = None,
|
| 71 |
+
auto_guidance_config: LayerSkipConfig | list[LayerSkipConfig] | dict[str, Any] = None,
|
| 72 |
+
dropout: float | None = None,
|
| 73 |
+
guidance_rescale: float = 0.0,
|
| 74 |
+
use_original_formulation: bool = False,
|
| 75 |
+
start: float = 0.0,
|
| 76 |
+
stop: float = 1.0,
|
| 77 |
+
enabled: bool = True,
|
| 78 |
+
):
|
| 79 |
+
super().__init__(start, stop, enabled)
|
| 80 |
+
|
| 81 |
+
self.guidance_scale = guidance_scale
|
| 82 |
+
self.auto_guidance_layers = auto_guidance_layers
|
| 83 |
+
self.auto_guidance_config = auto_guidance_config
|
| 84 |
+
self.dropout = dropout
|
| 85 |
+
self.guidance_rescale = guidance_rescale
|
| 86 |
+
self.use_original_formulation = use_original_formulation
|
| 87 |
+
|
| 88 |
+
is_layer_or_config_provided = auto_guidance_layers is not None or auto_guidance_config is not None
|
| 89 |
+
is_layer_and_config_provided = auto_guidance_layers is not None and auto_guidance_config is not None
|
| 90 |
+
if not is_layer_or_config_provided:
|
| 91 |
+
raise ValueError(
|
| 92 |
+
"Either `auto_guidance_layers` or `auto_guidance_config` must be provided to enable AutoGuidance."
|
| 93 |
+
)
|
| 94 |
+
if is_layer_and_config_provided:
|
| 95 |
+
raise ValueError("Only one of `auto_guidance_layers` or `auto_guidance_config` can be provided.")
|
| 96 |
+
if auto_guidance_config is None and dropout is None:
|
| 97 |
+
raise ValueError("`dropout` must be provided if `auto_guidance_layers` is provided.")
|
| 98 |
+
|
| 99 |
+
if auto_guidance_layers is not None:
|
| 100 |
+
if isinstance(auto_guidance_layers, int):
|
| 101 |
+
auto_guidance_layers = [auto_guidance_layers]
|
| 102 |
+
if not isinstance(auto_guidance_layers, list):
|
| 103 |
+
raise ValueError(
|
| 104 |
+
f"Expected `auto_guidance_layers` to be an int or a list of ints, but got {type(auto_guidance_layers)}."
|
| 105 |
+
)
|
| 106 |
+
auto_guidance_config = [
|
| 107 |
+
LayerSkipConfig(layer, fqn="auto", dropout=dropout) for layer in auto_guidance_layers
|
| 108 |
+
]
|
| 109 |
+
|
| 110 |
+
if isinstance(auto_guidance_config, dict):
|
| 111 |
+
auto_guidance_config = LayerSkipConfig.from_dict(auto_guidance_config)
|
| 112 |
+
|
| 113 |
+
if isinstance(auto_guidance_config, LayerSkipConfig):
|
| 114 |
+
auto_guidance_config = [auto_guidance_config]
|
| 115 |
+
|
| 116 |
+
if not isinstance(auto_guidance_config, list):
|
| 117 |
+
raise ValueError(
|
| 118 |
+
f"Expected `auto_guidance_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(auto_guidance_config)}."
|
| 119 |
+
)
|
| 120 |
+
elif isinstance(next(iter(auto_guidance_config), None), dict):
|
| 121 |
+
auto_guidance_config = [LayerSkipConfig.from_dict(config) for config in auto_guidance_config]
|
| 122 |
+
|
| 123 |
+
self.auto_guidance_config = auto_guidance_config
|
| 124 |
+
self._auto_guidance_hook_names = [f"AutoGuidance_{i}" for i in range(len(self.auto_guidance_config))]
|
| 125 |
+
|
| 126 |
+
def prepare_models(self, denoiser: torch.nn.Module) -> None:
|
| 127 |
+
self._count_prepared += 1
|
| 128 |
+
if self._is_ag_enabled() and self.is_unconditional:
|
| 129 |
+
for name, config in zip(self._auto_guidance_hook_names, self.auto_guidance_config):
|
| 130 |
+
_apply_layer_skip_hook(denoiser, config, name=name)
|
| 131 |
+
|
| 132 |
+
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
|
| 133 |
+
if self._is_ag_enabled() and self.is_unconditional:
|
| 134 |
+
for name in self._auto_guidance_hook_names:
|
| 135 |
+
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
|
| 136 |
+
registry.remove_hook(name, recurse=True)
|
| 137 |
+
|
| 138 |
+
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
|
| 139 |
+
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
| 140 |
+
data_batches = []
|
| 141 |
+
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
| 142 |
+
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
| 143 |
+
data_batches.append(data_batch)
|
| 144 |
+
return data_batches
|
| 145 |
+
|
| 146 |
+
def prepare_inputs_from_block_state(
|
| 147 |
+
self, data: "BlockState", input_fields: dict[str, str | tuple[str, str]]
|
| 148 |
+
) -> list["BlockState"]:
|
| 149 |
+
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
| 150 |
+
data_batches = []
|
| 151 |
+
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
| 152 |
+
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
| 153 |
+
data_batches.append(data_batch)
|
| 154 |
+
return data_batches
|
| 155 |
+
|
| 156 |
+
def forward(self, pred_cond: torch.Tensor, pred_uncond: torch.Tensor | None = None) -> GuiderOutput:
|
| 157 |
+
pred = None
|
| 158 |
+
|
| 159 |
+
if not self._is_ag_enabled():
|
| 160 |
+
pred = pred_cond
|
| 161 |
+
else:
|
| 162 |
+
shift = pred_cond - pred_uncond
|
| 163 |
+
pred = pred_cond if self.use_original_formulation else pred_uncond
|
| 164 |
+
pred = pred + self.guidance_scale * shift
|
| 165 |
+
|
| 166 |
+
if self.guidance_rescale > 0.0:
|
| 167 |
+
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
| 168 |
+
|
| 169 |
+
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
|
| 170 |
+
|
| 171 |
+
@property
|
| 172 |
+
def is_conditional(self) -> bool:
|
| 173 |
+
return self._count_prepared == 1
|
| 174 |
+
|
| 175 |
+
@property
|
| 176 |
+
def num_conditions(self) -> int:
|
| 177 |
+
num_conditions = 1
|
| 178 |
+
if self._is_ag_enabled():
|
| 179 |
+
num_conditions += 1
|
| 180 |
+
return num_conditions
|
| 181 |
+
|
| 182 |
+
def _is_ag_enabled(self) -> bool:
|
| 183 |
+
if not self._enabled:
|
| 184 |
+
return False
|
| 185 |
+
|
| 186 |
+
is_within_range = True
|
| 187 |
+
if self._num_inference_steps is not None:
|
| 188 |
+
skip_start_step = int(self._start * self._num_inference_steps)
|
| 189 |
+
skip_stop_step = int(self._stop * self._num_inference_steps)
|
| 190 |
+
is_within_range = skip_start_step <= self._step < skip_stop_step
|
| 191 |
+
|
| 192 |
+
is_close = False
|
| 193 |
+
if self.use_original_formulation:
|
| 194 |
+
is_close = math.isclose(self.guidance_scale, 0.0)
|
| 195 |
+
else:
|
| 196 |
+
is_close = math.isclose(self.guidance_scale, 1.0)
|
| 197 |
+
|
| 198 |
+
return is_within_range and not is_close
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/classifier_free_guidance.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from typing import TYPE_CHECKING
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
from ..configuration_utils import register_to_config
|
| 23 |
+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
if TYPE_CHECKING:
|
| 27 |
+
from ..modular_pipelines.modular_pipeline import BlockState
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class ClassifierFreeGuidance(BaseGuidance):
|
| 31 |
+
"""
|
| 32 |
+
Implements Classifier-Free Guidance (CFG) for diffusion models.
|
| 33 |
+
|
| 34 |
+
Reference: https://huggingface.co/papers/2207.12598
|
| 35 |
+
|
| 36 |
+
CFG improves generation quality and prompt adherence by jointly training models on both conditional and
|
| 37 |
+
unconditional data, then combining predictions during inference. This allows trading off between quality (high
|
| 38 |
+
guidance) and diversity (low guidance).
|
| 39 |
+
|
| 40 |
+
**Two CFG Formulations:**
|
| 41 |
+
|
| 42 |
+
1. **Original formulation** (from paper):
|
| 43 |
+
```
|
| 44 |
+
x_pred = x_cond + guidance_scale * (x_cond - x_uncond)
|
| 45 |
+
```
|
| 46 |
+
Moves conditional predictions further from unconditional ones.
|
| 47 |
+
|
| 48 |
+
2. **Diffusers-native formulation** (default, from Imagen paper):
|
| 49 |
+
```
|
| 50 |
+
x_pred = x_uncond + guidance_scale * (x_cond - x_uncond)
|
| 51 |
+
```
|
| 52 |
+
Moves unconditional predictions toward conditional ones, effectively suppressing negative features (e.g., "bad
|
| 53 |
+
quality", "watermarks"). Equivalent in theory but more intuitive.
|
| 54 |
+
|
| 55 |
+
Use `use_original_formulation=True` to switch to the original formulation.
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
guidance_scale (`float`, defaults to `7.5`):
|
| 59 |
+
CFG scale applied by this guider during post-processing. Higher values = stronger prompt conditioning but
|
| 60 |
+
may reduce quality. Typical range: 1.0-20.0.
|
| 61 |
+
guidance_rescale (`float`, defaults to `0.0`):
|
| 62 |
+
Rescaling factor to prevent overexposure from high guidance scales. Based on [Common Diffusion Noise
|
| 63 |
+
Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). Range: 0.0 (no rescaling)
|
| 64 |
+
to 1.0 (full rescaling).
|
| 65 |
+
use_original_formulation (`bool`, defaults to `False`):
|
| 66 |
+
If `True`, uses the original CFG formulation from the paper. If `False` (default), uses the
|
| 67 |
+
diffusers-native formulation from the Imagen paper.
|
| 68 |
+
start (`float`, defaults to `0.0`):
|
| 69 |
+
Fraction of denoising steps (0.0-1.0) after which CFG starts. Use > 0.0 to disable CFG in early denoising
|
| 70 |
+
steps.
|
| 71 |
+
stop (`float`, defaults to `1.0`):
|
| 72 |
+
Fraction of denoising steps (0.0-1.0) after which CFG stops. Use < 1.0 to disable CFG in late denoising
|
| 73 |
+
steps.
|
| 74 |
+
enabled (`bool`, defaults to `True`):
|
| 75 |
+
Whether CFG is enabled. Set to `False` to disable CFG entirely (uses only conditional predictions).
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
_input_predictions = ["pred_cond", "pred_uncond"]
|
| 79 |
+
|
| 80 |
+
@register_to_config
|
| 81 |
+
def __init__(
|
| 82 |
+
self,
|
| 83 |
+
guidance_scale: float = 7.5,
|
| 84 |
+
guidance_rescale: float = 0.0,
|
| 85 |
+
use_original_formulation: bool = False,
|
| 86 |
+
start: float = 0.0,
|
| 87 |
+
stop: float = 1.0,
|
| 88 |
+
enabled: bool = True,
|
| 89 |
+
):
|
| 90 |
+
super().__init__(start, stop, enabled)
|
| 91 |
+
|
| 92 |
+
self.guidance_scale = guidance_scale
|
| 93 |
+
self.guidance_rescale = guidance_rescale
|
| 94 |
+
self.use_original_formulation = use_original_formulation
|
| 95 |
+
|
| 96 |
+
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
|
| 97 |
+
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
| 98 |
+
data_batches = []
|
| 99 |
+
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
| 100 |
+
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
| 101 |
+
data_batches.append(data_batch)
|
| 102 |
+
return data_batches
|
| 103 |
+
|
| 104 |
+
def prepare_inputs_from_block_state(
|
| 105 |
+
self, data: "BlockState", input_fields: dict[str, str | tuple[str, str]]
|
| 106 |
+
) -> list["BlockState"]:
|
| 107 |
+
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
| 108 |
+
data_batches = []
|
| 109 |
+
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
| 110 |
+
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
| 111 |
+
data_batches.append(data_batch)
|
| 112 |
+
return data_batches
|
| 113 |
+
|
| 114 |
+
def forward(self, pred_cond: torch.Tensor, pred_uncond: torch.Tensor | None = None) -> GuiderOutput:
|
| 115 |
+
pred = None
|
| 116 |
+
|
| 117 |
+
if not self._is_cfg_enabled():
|
| 118 |
+
pred = pred_cond
|
| 119 |
+
else:
|
| 120 |
+
shift = pred_cond - pred_uncond
|
| 121 |
+
pred = pred_cond if self.use_original_formulation else pred_uncond
|
| 122 |
+
pred = pred + self.guidance_scale * shift
|
| 123 |
+
|
| 124 |
+
if self.guidance_rescale > 0.0:
|
| 125 |
+
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
| 126 |
+
|
| 127 |
+
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
|
| 128 |
+
|
| 129 |
+
@property
|
| 130 |
+
def is_conditional(self) -> bool:
|
| 131 |
+
return self._count_prepared == 1
|
| 132 |
+
|
| 133 |
+
@property
|
| 134 |
+
def num_conditions(self) -> int:
|
| 135 |
+
num_conditions = 1
|
| 136 |
+
if self._is_cfg_enabled():
|
| 137 |
+
num_conditions += 1
|
| 138 |
+
return num_conditions
|
| 139 |
+
|
| 140 |
+
def _is_cfg_enabled(self) -> bool:
|
| 141 |
+
if not self._enabled:
|
| 142 |
+
return False
|
| 143 |
+
|
| 144 |
+
is_within_range = True
|
| 145 |
+
if self._num_inference_steps is not None:
|
| 146 |
+
skip_start_step = int(self._start * self._num_inference_steps)
|
| 147 |
+
skip_stop_step = int(self._stop * self._num_inference_steps)
|
| 148 |
+
is_within_range = skip_start_step <= self._step < skip_stop_step
|
| 149 |
+
|
| 150 |
+
is_close = False
|
| 151 |
+
if self.use_original_formulation:
|
| 152 |
+
is_close = math.isclose(self.guidance_scale, 0.0)
|
| 153 |
+
else:
|
| 154 |
+
is_close = math.isclose(self.guidance_scale, 1.0)
|
| 155 |
+
|
| 156 |
+
return is_within_range and not is_close
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/classifier_free_zero_star_guidance.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from typing import TYPE_CHECKING
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
from ..configuration_utils import register_to_config
|
| 23 |
+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
if TYPE_CHECKING:
|
| 27 |
+
from ..modular_pipelines.modular_pipeline import BlockState
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class ClassifierFreeZeroStarGuidance(BaseGuidance):
|
| 31 |
+
"""
|
| 32 |
+
Classifier-free Zero* (CFG-Zero*): https://huggingface.co/papers/2503.18886
|
| 33 |
+
|
| 34 |
+
This is an implementation of the Classifier-Free Zero* guidance technique, which is a variant of classifier-free
|
| 35 |
+
guidance. It proposes zero initialization of the noise predictions for the first few steps of the diffusion
|
| 36 |
+
process, and also introduces an optimal rescaling factor for the noise predictions, which can help in improving the
|
| 37 |
+
quality of generated images.
|
| 38 |
+
|
| 39 |
+
The authors of the paper suggest setting zero initialization in the first 4% of the inference steps.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
guidance_scale (`float`, defaults to `7.5`):
|
| 43 |
+
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
| 44 |
+
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
| 45 |
+
deterioration of image quality.
|
| 46 |
+
zero_init_steps (`int`, defaults to `1`):
|
| 47 |
+
The number of inference steps for which the noise predictions are zeroed out (see Section 4.2).
|
| 48 |
+
guidance_rescale (`float`, defaults to `0.0`):
|
| 49 |
+
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
| 50 |
+
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
| 51 |
+
Flawed](https://huggingface.co/papers/2305.08891).
|
| 52 |
+
use_original_formulation (`bool`, defaults to `False`):
|
| 53 |
+
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
| 54 |
+
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
| 55 |
+
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
| 56 |
+
start (`float`, defaults to `0.01`):
|
| 57 |
+
The fraction of the total number of denoising steps after which guidance starts.
|
| 58 |
+
stop (`float`, defaults to `0.2`):
|
| 59 |
+
The fraction of the total number of denoising steps after which guidance stops.
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
_input_predictions = ["pred_cond", "pred_uncond"]
|
| 63 |
+
|
| 64 |
+
@register_to_config
|
| 65 |
+
def __init__(
|
| 66 |
+
self,
|
| 67 |
+
guidance_scale: float = 7.5,
|
| 68 |
+
zero_init_steps: int = 1,
|
| 69 |
+
guidance_rescale: float = 0.0,
|
| 70 |
+
use_original_formulation: bool = False,
|
| 71 |
+
start: float = 0.0,
|
| 72 |
+
stop: float = 1.0,
|
| 73 |
+
enabled: bool = True,
|
| 74 |
+
):
|
| 75 |
+
super().__init__(start, stop, enabled)
|
| 76 |
+
|
| 77 |
+
self.guidance_scale = guidance_scale
|
| 78 |
+
self.zero_init_steps = zero_init_steps
|
| 79 |
+
self.guidance_rescale = guidance_rescale
|
| 80 |
+
self.use_original_formulation = use_original_formulation
|
| 81 |
+
|
| 82 |
+
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
|
| 83 |
+
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
| 84 |
+
data_batches = []
|
| 85 |
+
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
| 86 |
+
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
| 87 |
+
data_batches.append(data_batch)
|
| 88 |
+
return data_batches
|
| 89 |
+
|
| 90 |
+
def prepare_inputs_from_block_state(
|
| 91 |
+
self, data: "BlockState", input_fields: dict[str, str | tuple[str, str]]
|
| 92 |
+
) -> list["BlockState"]:
|
| 93 |
+
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
| 94 |
+
data_batches = []
|
| 95 |
+
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
| 96 |
+
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
| 97 |
+
data_batches.append(data_batch)
|
| 98 |
+
return data_batches
|
| 99 |
+
|
| 100 |
+
def forward(self, pred_cond: torch.Tensor, pred_uncond: torch.Tensor | None = None) -> GuiderOutput:
|
| 101 |
+
pred = None
|
| 102 |
+
|
| 103 |
+
# YiYi Notes: add default behavior for self._enabled == False
|
| 104 |
+
if not self._enabled:
|
| 105 |
+
pred = pred_cond
|
| 106 |
+
|
| 107 |
+
elif self._step < self.zero_init_steps:
|
| 108 |
+
pred = torch.zeros_like(pred_cond)
|
| 109 |
+
elif not self._is_cfg_enabled():
|
| 110 |
+
pred = pred_cond
|
| 111 |
+
else:
|
| 112 |
+
pred_cond_flat = pred_cond.flatten(1)
|
| 113 |
+
pred_uncond_flat = pred_uncond.flatten(1)
|
| 114 |
+
alpha = cfg_zero_star_scale(pred_cond_flat, pred_uncond_flat)
|
| 115 |
+
alpha = alpha.view(-1, *(1,) * (len(pred_cond.shape) - 1))
|
| 116 |
+
pred_uncond = pred_uncond * alpha
|
| 117 |
+
shift = pred_cond - pred_uncond
|
| 118 |
+
pred = pred_cond if self.use_original_formulation else pred_uncond
|
| 119 |
+
pred = pred + self.guidance_scale * shift
|
| 120 |
+
|
| 121 |
+
if self.guidance_rescale > 0.0:
|
| 122 |
+
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
| 123 |
+
|
| 124 |
+
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
|
| 125 |
+
|
| 126 |
+
@property
|
| 127 |
+
def is_conditional(self) -> bool:
|
| 128 |
+
return self._count_prepared == 1
|
| 129 |
+
|
| 130 |
+
@property
|
| 131 |
+
def num_conditions(self) -> int:
|
| 132 |
+
num_conditions = 1
|
| 133 |
+
if self._is_cfg_enabled():
|
| 134 |
+
num_conditions += 1
|
| 135 |
+
return num_conditions
|
| 136 |
+
|
| 137 |
+
def _is_cfg_enabled(self) -> bool:
|
| 138 |
+
if not self._enabled:
|
| 139 |
+
return False
|
| 140 |
+
|
| 141 |
+
is_within_range = True
|
| 142 |
+
if self._num_inference_steps is not None:
|
| 143 |
+
skip_start_step = int(self._start * self._num_inference_steps)
|
| 144 |
+
skip_stop_step = int(self._stop * self._num_inference_steps)
|
| 145 |
+
is_within_range = skip_start_step <= self._step < skip_stop_step
|
| 146 |
+
|
| 147 |
+
is_close = False
|
| 148 |
+
if self.use_original_formulation:
|
| 149 |
+
is_close = math.isclose(self.guidance_scale, 0.0)
|
| 150 |
+
else:
|
| 151 |
+
is_close = math.isclose(self.guidance_scale, 1.0)
|
| 152 |
+
|
| 153 |
+
return is_within_range and not is_close
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def cfg_zero_star_scale(cond: torch.Tensor, uncond: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
|
| 157 |
+
cond_dtype = cond.dtype
|
| 158 |
+
cond = cond.float()
|
| 159 |
+
uncond = uncond.float()
|
| 160 |
+
dot_product = torch.sum(cond * uncond, dim=1, keepdim=True)
|
| 161 |
+
squared_norm = torch.sum(uncond**2, dim=1, keepdim=True) + eps
|
| 162 |
+
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
|
| 163 |
+
scale = dot_product / squared_norm
|
| 164 |
+
return scale.to(dtype=cond_dtype)
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/frequency_decoupled_guidance.py
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from typing import TYPE_CHECKING
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
from ..configuration_utils import register_to_config
|
| 23 |
+
from ..utils import is_kornia_available
|
| 24 |
+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
if TYPE_CHECKING:
|
| 28 |
+
from ..modular_pipelines.modular_pipeline import BlockState
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
_CAN_USE_KORNIA = is_kornia_available()
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
if _CAN_USE_KORNIA:
|
| 35 |
+
from kornia.geometry import pyrup as upsample_and_blur_func
|
| 36 |
+
from kornia.geometry.transform import build_laplacian_pyramid as build_laplacian_pyramid_func
|
| 37 |
+
else:
|
| 38 |
+
upsample_and_blur_func = None
|
| 39 |
+
build_laplacian_pyramid_func = None
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def project(v0: torch.Tensor, v1: torch.Tensor, upcast_to_double: bool = True) -> tuple[torch.Tensor, torch.Tensor]:
|
| 43 |
+
"""
|
| 44 |
+
Project vector v0 onto vector v1, returning the parallel and orthogonal components of v0. Implementation from paper
|
| 45 |
+
(Algorithm 2).
|
| 46 |
+
"""
|
| 47 |
+
# v0 shape: [B, ...]
|
| 48 |
+
# v1 shape: [B, ...]
|
| 49 |
+
# Assume first dim is a batch dim and all other dims are channel or "spatial" dims
|
| 50 |
+
all_dims_but_first = list(range(1, len(v0.shape)))
|
| 51 |
+
if upcast_to_double:
|
| 52 |
+
dtype = v0.dtype
|
| 53 |
+
v0, v1 = v0.double(), v1.double()
|
| 54 |
+
v1 = torch.nn.functional.normalize(v1, dim=all_dims_but_first)
|
| 55 |
+
v0_parallel = (v0 * v1).sum(dim=all_dims_but_first, keepdim=True) * v1
|
| 56 |
+
v0_orthogonal = v0 - v0_parallel
|
| 57 |
+
if upcast_to_double:
|
| 58 |
+
v0_parallel = v0_parallel.to(dtype)
|
| 59 |
+
v0_orthogonal = v0_orthogonal.to(dtype)
|
| 60 |
+
return v0_parallel, v0_orthogonal
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def build_image_from_pyramid(pyramid: list[torch.Tensor]) -> torch.Tensor:
|
| 64 |
+
"""
|
| 65 |
+
Recovers the data space latents from the Laplacian pyramid frequency space. Implementation from the paper
|
| 66 |
+
(Algorithm 2).
|
| 67 |
+
"""
|
| 68 |
+
# pyramid shapes: [[B, C, H, W], [B, C, H/2, W/2], ...]
|
| 69 |
+
img = pyramid[-1]
|
| 70 |
+
for i in range(len(pyramid) - 2, -1, -1):
|
| 71 |
+
img = upsample_and_blur_func(img) + pyramid[i]
|
| 72 |
+
return img
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class FrequencyDecoupledGuidance(BaseGuidance):
|
| 76 |
+
"""
|
| 77 |
+
Frequency-Decoupled Guidance (FDG): https://huggingface.co/papers/2506.19713
|
| 78 |
+
|
| 79 |
+
FDG is a technique similar to (and based on) classifier-free guidance (CFG) which is used to improve generation
|
| 80 |
+
quality and condition-following in diffusion models. Like CFG, during training we jointly train the model on both
|
| 81 |
+
conditional and unconditional data, and use a combination of the two during inference. (If you want more details on
|
| 82 |
+
how CFG works, you can check out the CFG guider.)
|
| 83 |
+
|
| 84 |
+
FDG differs from CFG in that the normal CFG prediction is instead decoupled into low- and high-frequency components
|
| 85 |
+
using a frequency transform (such as a Laplacian pyramid). The CFG update is then performed in frequency space
|
| 86 |
+
separately for the low- and high-frequency components with different guidance scales. Finally, the inverse
|
| 87 |
+
frequency transform is used to map the CFG frequency predictions back to data space (e.g. pixel space for images)
|
| 88 |
+
to form the final FDG prediction.
|
| 89 |
+
|
| 90 |
+
For images, the FDG authors found that using low guidance scales for the low-frequency components retains sample
|
| 91 |
+
diversity and realistic color composition, while using high guidance scales for high-frequency components enhances
|
| 92 |
+
sample quality (such as better visual details). Therefore, they recommend using low guidance scales (low w_low) for
|
| 93 |
+
the low-frequency components and high guidance scales (high w_high) for the high-frequency components. As an
|
| 94 |
+
example, they suggest w_low = 5.0 and w_high = 10.0 for Stable Diffusion XL (see Table 8 in the paper).
|
| 95 |
+
|
| 96 |
+
As with CFG, Diffusers implements the scaling and shifting on the unconditional prediction based on the [Imagen
|
| 97 |
+
paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original CFG paper proposed in
|
| 98 |
+
theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
|
| 99 |
+
|
| 100 |
+
The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
|
| 101 |
+
paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
|
| 102 |
+
|
| 103 |
+
Args:
|
| 104 |
+
guidance_scales (`list[float]`, defaults to `[10.0, 5.0]`):
|
| 105 |
+
The scale parameter for frequency-decoupled guidance for each frequency component, listed from highest
|
| 106 |
+
frequency level to lowest. Higher values result in stronger conditioning on the text prompt, while lower
|
| 107 |
+
values allow for more freedom in generation. Higher values may lead to saturation and deterioration of
|
| 108 |
+
image quality. The FDG authors recommend using higher guidance scales for higher frequency components and
|
| 109 |
+
lower guidance scales for lower frequency components (so `guidance_scales` should typically be sorted in
|
| 110 |
+
descending order).
|
| 111 |
+
guidance_rescale (`float` or `list[float]`, defaults to `0.0`):
|
| 112 |
+
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
| 113 |
+
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
| 114 |
+
Flawed](https://huggingface.co/papers/2305.08891). If a list is supplied, it should be the same length as
|
| 115 |
+
`guidance_scales`.
|
| 116 |
+
parallel_weights (`float` or `list[float]`, *optional*):
|
| 117 |
+
Optional weights for the parallel component of each frequency component of the projected CFG shift. If not
|
| 118 |
+
set, the weights will default to `1.0` for all components, which corresponds to using the normal CFG shift
|
| 119 |
+
(that is, equal weights for the parallel and orthogonal components). If set, a value in `[0, 1]` is
|
| 120 |
+
recommended. If a list is supplied, it should be the same length as `guidance_scales`.
|
| 121 |
+
use_original_formulation (`bool`, defaults to `False`):
|
| 122 |
+
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
| 123 |
+
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
| 124 |
+
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
| 125 |
+
start (`float` or `list[float]`, defaults to `0.0`):
|
| 126 |
+
The fraction of the total number of denoising steps after which guidance starts. If a list is supplied, it
|
| 127 |
+
should be the same length as `guidance_scales`.
|
| 128 |
+
stop (`float` or `list[float]`, defaults to `1.0`):
|
| 129 |
+
The fraction of the total number of denoising steps after which guidance stops. If a list is supplied, it
|
| 130 |
+
should be the same length as `guidance_scales`.
|
| 131 |
+
guidance_rescale_space (`str`, defaults to `"data"`):
|
| 132 |
+
Whether to performance guidance rescaling in `"data"` space (after the full FDG update in data space) or in
|
| 133 |
+
`"freq"` space (right after the CFG update, for each freq level). Note that frequency space rescaling is
|
| 134 |
+
speculative and may not produce expected results. If `"data"` is set, the first `guidance_rescale` value
|
| 135 |
+
will be used; otherwise, per-frequency-level guidance rescale values will be used if available.
|
| 136 |
+
upcast_to_double (`bool`, defaults to `True`):
|
| 137 |
+
Whether to upcast certain operations, such as the projection operation when using `parallel_weights`, to
|
| 138 |
+
float64 when performing guidance. This may result in better performance at the cost of increased runtime.
|
| 139 |
+
"""
|
| 140 |
+
|
| 141 |
+
_input_predictions = ["pred_cond", "pred_uncond"]
|
| 142 |
+
|
| 143 |
+
@register_to_config
|
| 144 |
+
def __init__(
|
| 145 |
+
self,
|
| 146 |
+
guidance_scales: list[float] | tuple[float] = [10.0, 5.0],
|
| 147 |
+
guidance_rescale: float | list[float] | tuple[float] = 0.0,
|
| 148 |
+
parallel_weights: float | list[float] | tuple[float] | None = None,
|
| 149 |
+
use_original_formulation: bool = False,
|
| 150 |
+
start: float | list[float] | tuple[float] = 0.0,
|
| 151 |
+
stop: float | list[float] | tuple[float] = 1.0,
|
| 152 |
+
guidance_rescale_space: str = "data",
|
| 153 |
+
upcast_to_double: bool = True,
|
| 154 |
+
enabled: bool = True,
|
| 155 |
+
):
|
| 156 |
+
if not _CAN_USE_KORNIA:
|
| 157 |
+
raise ImportError(
|
| 158 |
+
"The `FrequencyDecoupledGuidance` guider cannot be instantiated because the `kornia` library on which "
|
| 159 |
+
"it depends is not available in the current environment. You can install `kornia` with `pip install "
|
| 160 |
+
"kornia`."
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
# Set start to earliest start for any freq component and stop to latest stop for any freq component
|
| 164 |
+
min_start = start if isinstance(start, float) else min(start)
|
| 165 |
+
max_stop = stop if isinstance(stop, float) else max(stop)
|
| 166 |
+
super().__init__(min_start, max_stop, enabled)
|
| 167 |
+
|
| 168 |
+
self.guidance_scales = guidance_scales
|
| 169 |
+
self.levels = len(guidance_scales)
|
| 170 |
+
|
| 171 |
+
if isinstance(guidance_rescale, float):
|
| 172 |
+
self.guidance_rescale = [guidance_rescale] * self.levels
|
| 173 |
+
elif len(guidance_rescale) == self.levels:
|
| 174 |
+
self.guidance_rescale = guidance_rescale
|
| 175 |
+
else:
|
| 176 |
+
raise ValueError(
|
| 177 |
+
f"`guidance_rescale` has length {len(guidance_rescale)} but should have the same length as "
|
| 178 |
+
f"`guidance_scales` ({len(self.guidance_scales)})"
|
| 179 |
+
)
|
| 180 |
+
# Whether to perform guidance rescaling in frequency space (right after the CFG update) or data space (after
|
| 181 |
+
# transforming from frequency space back to data space)
|
| 182 |
+
if guidance_rescale_space not in ["data", "freq"]:
|
| 183 |
+
raise ValueError(
|
| 184 |
+
f"Guidance rescale space is {guidance_rescale_space} but must be one of `data` or `freq`."
|
| 185 |
+
)
|
| 186 |
+
self.guidance_rescale_space = guidance_rescale_space
|
| 187 |
+
|
| 188 |
+
if parallel_weights is None:
|
| 189 |
+
# Use normal CFG shift (equal weights for parallel and orthogonal components)
|
| 190 |
+
self.parallel_weights = [1.0] * self.levels
|
| 191 |
+
elif isinstance(parallel_weights, float):
|
| 192 |
+
self.parallel_weights = [parallel_weights] * self.levels
|
| 193 |
+
elif len(parallel_weights) == self.levels:
|
| 194 |
+
self.parallel_weights = parallel_weights
|
| 195 |
+
else:
|
| 196 |
+
raise ValueError(
|
| 197 |
+
f"`parallel_weights` has length {len(parallel_weights)} but should have the same length as "
|
| 198 |
+
f"`guidance_scales` ({len(self.guidance_scales)})"
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
self.use_original_formulation = use_original_formulation
|
| 202 |
+
self.upcast_to_double = upcast_to_double
|
| 203 |
+
|
| 204 |
+
if isinstance(start, float):
|
| 205 |
+
self.guidance_start = [start] * self.levels
|
| 206 |
+
elif len(start) == self.levels:
|
| 207 |
+
self.guidance_start = start
|
| 208 |
+
else:
|
| 209 |
+
raise ValueError(
|
| 210 |
+
f"`start` has length {len(start)} but should have the same length as `guidance_scales` "
|
| 211 |
+
f"({len(self.guidance_scales)})"
|
| 212 |
+
)
|
| 213 |
+
if isinstance(stop, float):
|
| 214 |
+
self.guidance_stop = [stop] * self.levels
|
| 215 |
+
elif len(stop) == self.levels:
|
| 216 |
+
self.guidance_stop = stop
|
| 217 |
+
else:
|
| 218 |
+
raise ValueError(
|
| 219 |
+
f"`stop` has length {len(stop)} but should have the same length as `guidance_scales` "
|
| 220 |
+
f"({len(self.guidance_scales)})"
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
|
| 224 |
+
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
| 225 |
+
data_batches = []
|
| 226 |
+
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
| 227 |
+
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
| 228 |
+
data_batches.append(data_batch)
|
| 229 |
+
return data_batches
|
| 230 |
+
|
| 231 |
+
def prepare_inputs_from_block_state(
|
| 232 |
+
self, data: "BlockState", input_fields: dict[str, str | tuple[str, str]]
|
| 233 |
+
) -> list["BlockState"]:
|
| 234 |
+
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
| 235 |
+
data_batches = []
|
| 236 |
+
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
| 237 |
+
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
| 238 |
+
data_batches.append(data_batch)
|
| 239 |
+
return data_batches
|
| 240 |
+
|
| 241 |
+
def forward(self, pred_cond: torch.Tensor, pred_uncond: torch.Tensor | None = None) -> GuiderOutput:
|
| 242 |
+
pred = None
|
| 243 |
+
|
| 244 |
+
if not self._is_fdg_enabled():
|
| 245 |
+
pred = pred_cond
|
| 246 |
+
else:
|
| 247 |
+
# Apply the frequency transform (e.g. Laplacian pyramid) to the conditional and unconditional predictions.
|
| 248 |
+
pred_cond_pyramid = build_laplacian_pyramid_func(pred_cond, self.levels)
|
| 249 |
+
pred_uncond_pyramid = build_laplacian_pyramid_func(pred_uncond, self.levels)
|
| 250 |
+
|
| 251 |
+
# From high frequencies to low frequencies, following the paper implementation
|
| 252 |
+
pred_guided_pyramid = []
|
| 253 |
+
parameters = zip(self.guidance_scales, self.parallel_weights, self.guidance_rescale)
|
| 254 |
+
for level, (guidance_scale, parallel_weight, guidance_rescale) in enumerate(parameters):
|
| 255 |
+
if self._is_fdg_enabled_for_level(level):
|
| 256 |
+
# Get the cond/uncond preds (in freq space) at the current frequency level
|
| 257 |
+
pred_cond_freq = pred_cond_pyramid[level]
|
| 258 |
+
pred_uncond_freq = pred_uncond_pyramid[level]
|
| 259 |
+
|
| 260 |
+
shift = pred_cond_freq - pred_uncond_freq
|
| 261 |
+
|
| 262 |
+
# Apply parallel weights, if used (1.0 corresponds to using the normal CFG shift)
|
| 263 |
+
if not math.isclose(parallel_weight, 1.0):
|
| 264 |
+
shift_parallel, shift_orthogonal = project(shift, pred_cond_freq, self.upcast_to_double)
|
| 265 |
+
shift = parallel_weight * shift_parallel + shift_orthogonal
|
| 266 |
+
|
| 267 |
+
# Apply CFG update for the current frequency level
|
| 268 |
+
pred = pred_cond_freq if self.use_original_formulation else pred_uncond_freq
|
| 269 |
+
pred = pred + guidance_scale * shift
|
| 270 |
+
|
| 271 |
+
if self.guidance_rescale_space == "freq" and guidance_rescale > 0.0:
|
| 272 |
+
pred = rescale_noise_cfg(pred, pred_cond_freq, guidance_rescale)
|
| 273 |
+
|
| 274 |
+
# Add the current FDG guided level to the FDG prediction pyramid
|
| 275 |
+
pred_guided_pyramid.append(pred)
|
| 276 |
+
else:
|
| 277 |
+
# Add the current pred_cond_pyramid level as the "non-FDG" prediction
|
| 278 |
+
pred_guided_pyramid.append(pred_cond_freq)
|
| 279 |
+
|
| 280 |
+
# Convert from frequency space back to data (e.g. pixel) space by applying inverse freq transform
|
| 281 |
+
pred = build_image_from_pyramid(pred_guided_pyramid)
|
| 282 |
+
|
| 283 |
+
# If rescaling in data space, use the first elem of self.guidance_rescale as the "global" rescale value
|
| 284 |
+
# across all freq levels
|
| 285 |
+
if self.guidance_rescale_space == "data" and self.guidance_rescale[0] > 0.0:
|
| 286 |
+
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale[0])
|
| 287 |
+
|
| 288 |
+
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
|
| 289 |
+
|
| 290 |
+
@property
|
| 291 |
+
def is_conditional(self) -> bool:
|
| 292 |
+
return self._count_prepared == 1
|
| 293 |
+
|
| 294 |
+
@property
|
| 295 |
+
def num_conditions(self) -> int:
|
| 296 |
+
num_conditions = 1
|
| 297 |
+
if self._is_fdg_enabled():
|
| 298 |
+
num_conditions += 1
|
| 299 |
+
return num_conditions
|
| 300 |
+
|
| 301 |
+
def _is_fdg_enabled(self) -> bool:
|
| 302 |
+
if not self._enabled:
|
| 303 |
+
return False
|
| 304 |
+
|
| 305 |
+
is_within_range = True
|
| 306 |
+
if self._num_inference_steps is not None:
|
| 307 |
+
skip_start_step = int(self._start * self._num_inference_steps)
|
| 308 |
+
skip_stop_step = int(self._stop * self._num_inference_steps)
|
| 309 |
+
is_within_range = skip_start_step <= self._step < skip_stop_step
|
| 310 |
+
|
| 311 |
+
is_close = False
|
| 312 |
+
if self.use_original_formulation:
|
| 313 |
+
is_close = all(math.isclose(guidance_scale, 0.0) for guidance_scale in self.guidance_scales)
|
| 314 |
+
else:
|
| 315 |
+
is_close = all(math.isclose(guidance_scale, 1.0) for guidance_scale in self.guidance_scales)
|
| 316 |
+
|
| 317 |
+
return is_within_range and not is_close
|
| 318 |
+
|
| 319 |
+
def _is_fdg_enabled_for_level(self, level: int) -> bool:
|
| 320 |
+
if not self._enabled:
|
| 321 |
+
return False
|
| 322 |
+
|
| 323 |
+
is_within_range = True
|
| 324 |
+
if self._num_inference_steps is not None:
|
| 325 |
+
skip_start_step = int(self.guidance_start[level] * self._num_inference_steps)
|
| 326 |
+
skip_stop_step = int(self.guidance_stop[level] * self._num_inference_steps)
|
| 327 |
+
is_within_range = skip_start_step <= self._step < skip_stop_step
|
| 328 |
+
|
| 329 |
+
is_close = False
|
| 330 |
+
if self.use_original_formulation:
|
| 331 |
+
is_close = math.isclose(self.guidance_scales[level], 0.0)
|
| 332 |
+
else:
|
| 333 |
+
is_close = math.isclose(self.guidance_scales[level], 1.0)
|
| 334 |
+
|
| 335 |
+
return is_within_range and not is_close
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/guider_utils.py
ADDED
|
@@ -0,0 +1,396 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
from typing import TYPE_CHECKING, Any
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
from huggingface_hub.utils import validate_hf_hub_args
|
| 22 |
+
from typing_extensions import Self
|
| 23 |
+
|
| 24 |
+
from ..configuration_utils import ConfigMixin
|
| 25 |
+
from ..utils import BaseOutput, PushToHubMixin, get_logger
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
if TYPE_CHECKING:
|
| 29 |
+
from ..modular_pipelines.modular_pipeline import BlockState
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
GUIDER_CONFIG_NAME = "guider_config.json"
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
logger = get_logger(__name__) # pylint: disable=invalid-name
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class BaseGuidance(ConfigMixin, PushToHubMixin):
|
| 39 |
+
r"""Base class providing the skeleton for implementing guidance techniques."""
|
| 40 |
+
|
| 41 |
+
config_name = GUIDER_CONFIG_NAME
|
| 42 |
+
_input_predictions = None
|
| 43 |
+
_identifier_key = "__guidance_identifier__"
|
| 44 |
+
|
| 45 |
+
def __init__(self, start: float = 0.0, stop: float = 1.0, enabled: bool = True):
|
| 46 |
+
logger.warning(
|
| 47 |
+
"Guiders are currently an experimental feature under active development. The API is subject to breaking changes in future releases."
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
self._start = start
|
| 51 |
+
self._stop = stop
|
| 52 |
+
self._step: int = None
|
| 53 |
+
self._num_inference_steps: int = None
|
| 54 |
+
self._timestep: torch.LongTensor = None
|
| 55 |
+
self._count_prepared = 0
|
| 56 |
+
self._input_fields: dict[str, str | tuple[str, str]] = None
|
| 57 |
+
self._enabled = enabled
|
| 58 |
+
|
| 59 |
+
if not (0.0 <= start < 1.0):
|
| 60 |
+
raise ValueError(f"Expected `start` to be between 0.0 and 1.0, but got {start}.")
|
| 61 |
+
if not (start <= stop <= 1.0):
|
| 62 |
+
raise ValueError(f"Expected `stop` to be between {start} and 1.0, but got {stop}.")
|
| 63 |
+
|
| 64 |
+
if self._input_predictions is None or not isinstance(self._input_predictions, list):
|
| 65 |
+
raise ValueError(
|
| 66 |
+
"`_input_predictions` must be a list of required prediction names for the guidance technique."
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
def new(self, **kwargs):
|
| 70 |
+
"""
|
| 71 |
+
Creates a copy of this guider instance, optionally with modified configuration parameters.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
**kwargs: Configuration parameters to override in the new instance. If no kwargs are provided,
|
| 75 |
+
returns an exact copy with the same configuration.
|
| 76 |
+
|
| 77 |
+
Returns:
|
| 78 |
+
A new guider instance with the same (or updated) configuration.
|
| 79 |
+
|
| 80 |
+
Example:
|
| 81 |
+
```python
|
| 82 |
+
# Create a CFG guider
|
| 83 |
+
guider = ClassifierFreeGuidance(guidance_scale=3.5)
|
| 84 |
+
|
| 85 |
+
# Create an exact copy
|
| 86 |
+
same_guider = guider.new()
|
| 87 |
+
|
| 88 |
+
# Create a copy with different start step, keeping other config the same
|
| 89 |
+
new_guider = guider.new(guidance_scale=5)
|
| 90 |
+
```
|
| 91 |
+
"""
|
| 92 |
+
return self.__class__.from_config(self.config, **kwargs)
|
| 93 |
+
|
| 94 |
+
def disable(self):
|
| 95 |
+
self._enabled = False
|
| 96 |
+
|
| 97 |
+
def enable(self):
|
| 98 |
+
self._enabled = True
|
| 99 |
+
|
| 100 |
+
def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
|
| 101 |
+
self._step = step
|
| 102 |
+
self._num_inference_steps = num_inference_steps
|
| 103 |
+
self._timestep = timestep
|
| 104 |
+
self._count_prepared = 0
|
| 105 |
+
|
| 106 |
+
def get_state(self) -> dict[str, Any]:
|
| 107 |
+
"""
|
| 108 |
+
Returns the current state of the guidance technique as a dictionary. The state variables will be included in
|
| 109 |
+
the __repr__ method. Returns:
|
| 110 |
+
`dict[str, Any]`: A dictionary containing the current state variables including:
|
| 111 |
+
- step: Current inference step
|
| 112 |
+
- num_inference_steps: Total number of inference steps
|
| 113 |
+
- timestep: Current timestep tensor
|
| 114 |
+
- count_prepared: Number of times prepare_models has been called
|
| 115 |
+
- enabled: Whether the guidance is enabled
|
| 116 |
+
- num_conditions: Number of conditions
|
| 117 |
+
"""
|
| 118 |
+
state = {
|
| 119 |
+
"step": self._step,
|
| 120 |
+
"num_inference_steps": self._num_inference_steps,
|
| 121 |
+
"timestep": self._timestep,
|
| 122 |
+
"count_prepared": self._count_prepared,
|
| 123 |
+
"enabled": self._enabled,
|
| 124 |
+
"num_conditions": self.num_conditions,
|
| 125 |
+
}
|
| 126 |
+
return state
|
| 127 |
+
|
| 128 |
+
def __repr__(self) -> str:
|
| 129 |
+
"""
|
| 130 |
+
Returns a string representation of the guidance object including both config and current state.
|
| 131 |
+
"""
|
| 132 |
+
# Get ConfigMixin's __repr__
|
| 133 |
+
str_repr = super().__repr__()
|
| 134 |
+
|
| 135 |
+
# Get current state
|
| 136 |
+
state = self.get_state()
|
| 137 |
+
|
| 138 |
+
# Format each state variable on its own line with indentation
|
| 139 |
+
state_lines = []
|
| 140 |
+
for k, v in state.items():
|
| 141 |
+
# Convert value to string and handle multi-line values
|
| 142 |
+
v_str = str(v)
|
| 143 |
+
if "\n" in v_str:
|
| 144 |
+
# For multi-line values (like MomentumBuffer), indent subsequent lines
|
| 145 |
+
v_lines = v_str.split("\n")
|
| 146 |
+
v_str = v_lines[0] + "\n" + "\n".join([" " + line for line in v_lines[1:]])
|
| 147 |
+
state_lines.append(f" {k}: {v_str}")
|
| 148 |
+
|
| 149 |
+
state_str = "\n".join(state_lines)
|
| 150 |
+
|
| 151 |
+
return f"{str_repr}\nState:\n{state_str}"
|
| 152 |
+
|
| 153 |
+
def prepare_models(self, denoiser: torch.nn.Module) -> None:
|
| 154 |
+
"""
|
| 155 |
+
Prepares the models for the guidance technique on a given batch of data. This method should be overridden in
|
| 156 |
+
subclasses to implement specific model preparation logic.
|
| 157 |
+
"""
|
| 158 |
+
self._count_prepared += 1
|
| 159 |
+
|
| 160 |
+
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
|
| 161 |
+
"""
|
| 162 |
+
Cleans up the models for the guidance technique after a given batch of data. This method should be overridden
|
| 163 |
+
in subclasses to implement specific model cleanup logic. It is useful for removing any hooks or other stateful
|
| 164 |
+
modifications made during `prepare_models`.
|
| 165 |
+
"""
|
| 166 |
+
pass
|
| 167 |
+
|
| 168 |
+
def prepare_inputs(self, data: "BlockState") -> list["BlockState"]:
|
| 169 |
+
raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")
|
| 170 |
+
|
| 171 |
+
def prepare_inputs_from_block_state(
|
| 172 |
+
self, data: "BlockState", input_fields: dict[str, str | tuple[str, str]]
|
| 173 |
+
) -> list["BlockState"]:
|
| 174 |
+
raise NotImplementedError("BaseGuidance::prepare_inputs_from_block_state must be implemented in subclasses.")
|
| 175 |
+
|
| 176 |
+
def __call__(self, data: list["BlockState"]) -> Any:
|
| 177 |
+
if not all(hasattr(d, "noise_pred") for d in data):
|
| 178 |
+
raise ValueError("Expected all data to have `noise_pred` attribute.")
|
| 179 |
+
if len(data) != self.num_conditions:
|
| 180 |
+
raise ValueError(
|
| 181 |
+
f"Expected {self.num_conditions} data items, but got {len(data)}. Please check the input data."
|
| 182 |
+
)
|
| 183 |
+
forward_inputs = {getattr(d, self._identifier_key): d.noise_pred for d in data}
|
| 184 |
+
return self.forward(**forward_inputs)
|
| 185 |
+
|
| 186 |
+
def forward(self, *args, **kwargs) -> Any:
|
| 187 |
+
raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.")
|
| 188 |
+
|
| 189 |
+
@property
|
| 190 |
+
def is_conditional(self) -> bool:
|
| 191 |
+
raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.")
|
| 192 |
+
|
| 193 |
+
@property
|
| 194 |
+
def is_unconditional(self) -> bool:
|
| 195 |
+
return not self.is_conditional
|
| 196 |
+
|
| 197 |
+
@property
|
| 198 |
+
def num_conditions(self) -> int:
|
| 199 |
+
raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.")
|
| 200 |
+
|
| 201 |
+
@classmethod
|
| 202 |
+
def _prepare_batch(
|
| 203 |
+
cls,
|
| 204 |
+
data: dict[str, tuple[torch.Tensor, torch.Tensor]],
|
| 205 |
+
tuple_index: int,
|
| 206 |
+
identifier: str,
|
| 207 |
+
) -> "BlockState":
|
| 208 |
+
"""
|
| 209 |
+
Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of the
|
| 210 |
+
`BaseGuidance` class. It prepares the batch based on the provided tuple index.
|
| 211 |
+
|
| 212 |
+
Args:
|
| 213 |
+
input_fields (`dict[str, str | tuple[str, str]]`):
|
| 214 |
+
A dictionary where the keys are the names of the fields that will be used to store the data once it is
|
| 215 |
+
prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
|
| 216 |
+
to look up the required data provided for preparation. If a string is provided, it will be used as the
|
| 217 |
+
conditional data (or unconditional if used with a guidance method that requires it). If a tuple of
|
| 218 |
+
length 2 is provided, the first element must be the conditional data identifier and the second element
|
| 219 |
+
must be the unconditional data identifier or None.
|
| 220 |
+
data (`BlockState`):
|
| 221 |
+
The input data to be prepared.
|
| 222 |
+
tuple_index (`int`):
|
| 223 |
+
The index to use when accessing input fields that are tuples.
|
| 224 |
+
|
| 225 |
+
Returns:
|
| 226 |
+
`BlockState`: The prepared batch of data.
|
| 227 |
+
"""
|
| 228 |
+
from ..modular_pipelines.modular_pipeline import BlockState
|
| 229 |
+
|
| 230 |
+
data_batch = {}
|
| 231 |
+
for key, value in data.items():
|
| 232 |
+
try:
|
| 233 |
+
if isinstance(value, torch.Tensor):
|
| 234 |
+
data_batch[key] = value
|
| 235 |
+
elif isinstance(value, tuple):
|
| 236 |
+
data_batch[key] = value[tuple_index]
|
| 237 |
+
else:
|
| 238 |
+
raise ValueError(f"Invalid value type: {type(value)}")
|
| 239 |
+
except ValueError:
|
| 240 |
+
logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
|
| 241 |
+
data_batch[cls._identifier_key] = identifier
|
| 242 |
+
return BlockState(**data_batch)
|
| 243 |
+
|
| 244 |
+
@classmethod
|
| 245 |
+
def _prepare_batch_from_block_state(
|
| 246 |
+
cls,
|
| 247 |
+
input_fields: dict[str, str | tuple[str, str]],
|
| 248 |
+
data: "BlockState",
|
| 249 |
+
tuple_index: int,
|
| 250 |
+
identifier: str,
|
| 251 |
+
) -> "BlockState":
|
| 252 |
+
"""
|
| 253 |
+
Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of the
|
| 254 |
+
`BaseGuidance` class. It prepares the batch based on the provided tuple index.
|
| 255 |
+
|
| 256 |
+
Args:
|
| 257 |
+
input_fields (`dict[str, str | tuple[str, str]]`):
|
| 258 |
+
A dictionary where the keys are the names of the fields that will be used to store the data once it is
|
| 259 |
+
prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
|
| 260 |
+
to look up the required data provided for preparation. If a string is provided, it will be used as the
|
| 261 |
+
conditional data (or unconditional if used with a guidance method that requires it). If a tuple of
|
| 262 |
+
length 2 is provided, the first element must be the conditional data identifier and the second element
|
| 263 |
+
must be the unconditional data identifier or None.
|
| 264 |
+
data (`BlockState`):
|
| 265 |
+
The input data to be prepared.
|
| 266 |
+
tuple_index (`int`):
|
| 267 |
+
The index to use when accessing input fields that are tuples.
|
| 268 |
+
|
| 269 |
+
Returns:
|
| 270 |
+
`BlockState`: The prepared batch of data.
|
| 271 |
+
"""
|
| 272 |
+
from ..modular_pipelines.modular_pipeline import BlockState
|
| 273 |
+
|
| 274 |
+
data_batch = {}
|
| 275 |
+
for key, value in input_fields.items():
|
| 276 |
+
try:
|
| 277 |
+
if isinstance(value, str):
|
| 278 |
+
data_batch[key] = getattr(data, value)
|
| 279 |
+
elif isinstance(value, tuple):
|
| 280 |
+
data_batch[key] = getattr(data, value[tuple_index])
|
| 281 |
+
else:
|
| 282 |
+
# We've already checked that value is a string or a tuple of strings with length 2
|
| 283 |
+
pass
|
| 284 |
+
except AttributeError:
|
| 285 |
+
logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
|
| 286 |
+
data_batch[cls._identifier_key] = identifier
|
| 287 |
+
return BlockState(**data_batch)
|
| 288 |
+
|
| 289 |
+
@classmethod
|
| 290 |
+
@validate_hf_hub_args
|
| 291 |
+
def from_pretrained(
|
| 292 |
+
cls,
|
| 293 |
+
pretrained_model_name_or_path: str | os.PathLike | None = None,
|
| 294 |
+
subfolder: str | None = None,
|
| 295 |
+
return_unused_kwargs=False,
|
| 296 |
+
**kwargs,
|
| 297 |
+
) -> Self:
|
| 298 |
+
r"""
|
| 299 |
+
Instantiate a guider from a pre-defined JSON configuration file in a local directory or Hub repository.
|
| 300 |
+
|
| 301 |
+
Parameters:
|
| 302 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
| 303 |
+
Can be either:
|
| 304 |
+
|
| 305 |
+
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
| 306 |
+
the Hub.
|
| 307 |
+
- A path to a *directory* (for example `./my_model_directory`) containing the guider configuration
|
| 308 |
+
saved with [`~BaseGuidance.save_pretrained`].
|
| 309 |
+
subfolder (`str`, *optional*):
|
| 310 |
+
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
| 311 |
+
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
|
| 312 |
+
Whether kwargs that are not consumed by the Python class should be returned or not.
|
| 313 |
+
cache_dir (`str | os.PathLike`, *optional*):
|
| 314 |
+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
| 315 |
+
is not used.
|
| 316 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
| 317 |
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
| 318 |
+
cached versions if they exist.
|
| 319 |
+
|
| 320 |
+
proxies (`dict[str, str]`, *optional*):
|
| 321 |
+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
| 322 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
| 323 |
+
output_loading_info(`bool`, *optional*, defaults to `False`):
|
| 324 |
+
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
| 325 |
+
local_files_only(`bool`, *optional*, defaults to `False`):
|
| 326 |
+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
| 327 |
+
won't be downloaded from the Hub.
|
| 328 |
+
token (`str` or *bool*, *optional*):
|
| 329 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
| 330 |
+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
| 331 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
| 332 |
+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
| 333 |
+
allowed by Git.
|
| 334 |
+
|
| 335 |
+
> [!TIP] > To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in
|
| 336 |
+
with `hf > auth login`. You can also activate the special >
|
| 337 |
+
["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a >
|
| 338 |
+
firewalled environment.
|
| 339 |
+
|
| 340 |
+
"""
|
| 341 |
+
config, kwargs, commit_hash = cls.load_config(
|
| 342 |
+
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
| 343 |
+
subfolder=subfolder,
|
| 344 |
+
return_unused_kwargs=True,
|
| 345 |
+
return_commit_hash=True,
|
| 346 |
+
**kwargs,
|
| 347 |
+
)
|
| 348 |
+
return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)
|
| 349 |
+
|
| 350 |
+
def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs):
|
| 351 |
+
"""
|
| 352 |
+
Save a guider configuration object to a directory so that it can be reloaded using the
|
| 353 |
+
[`~BaseGuidance.from_pretrained`] class method.
|
| 354 |
+
|
| 355 |
+
Args:
|
| 356 |
+
save_directory (`str` or `os.PathLike`):
|
| 357 |
+
Directory where the configuration JSON file will be saved (will be created if it does not exist).
|
| 358 |
+
push_to_hub (`bool`, *optional*, defaults to `False`):
|
| 359 |
+
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
|
| 360 |
+
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
| 361 |
+
namespace).
|
| 362 |
+
kwargs (`dict[str, Any]`, *optional*):
|
| 363 |
+
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
| 364 |
+
"""
|
| 365 |
+
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
class GuiderOutput(BaseOutput):
|
| 369 |
+
pred: torch.Tensor
|
| 370 |
+
pred_cond: torch.Tensor | None
|
| 371 |
+
pred_uncond: torch.Tensor | None
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
| 375 |
+
r"""
|
| 376 |
+
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
|
| 377 |
+
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
| 378 |
+
Flawed](https://huggingface.co/papers/2305.08891).
|
| 379 |
+
|
| 380 |
+
Args:
|
| 381 |
+
noise_cfg (`torch.Tensor`):
|
| 382 |
+
The predicted noise tensor for the guided diffusion process.
|
| 383 |
+
noise_pred_text (`torch.Tensor`):
|
| 384 |
+
The predicted noise tensor for the text-guided diffusion process.
|
| 385 |
+
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
| 386 |
+
A rescale factor applied to the noise predictions.
|
| 387 |
+
Returns:
|
| 388 |
+
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
|
| 389 |
+
"""
|
| 390 |
+
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
| 391 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
| 392 |
+
# rescale the results from guidance (fixes overexposure)
|
| 393 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
| 394 |
+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
| 395 |
+
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
| 396 |
+
return noise_cfg
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/magnitude_aware_guidance.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
from typing import TYPE_CHECKING
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from ..configuration_utils import register_to_config
|
| 21 |
+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
if TYPE_CHECKING:
|
| 25 |
+
from ..modular_pipelines.modular_pipeline import BlockState
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class MagnitudeAwareGuidance(BaseGuidance):
|
| 29 |
+
"""
|
| 30 |
+
Magnitude-Aware Mitigation for Boosted Guidance (MAMBO-G): https://huggingface.co/papers/2508.03442
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
guidance_scale (`float`, defaults to `10.0`):
|
| 34 |
+
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
| 35 |
+
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
| 36 |
+
deterioration of image quality.
|
| 37 |
+
alpha (`float`, defaults to `8.0`):
|
| 38 |
+
The alpha parameter for the magnitude-aware guidance. Higher values cause more aggressive supression of
|
| 39 |
+
guidance scale when the magnitude of the guidance update is large.
|
| 40 |
+
guidance_rescale (`float`, defaults to `0.0`):
|
| 41 |
+
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
| 42 |
+
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
| 43 |
+
Flawed](https://huggingface.co/papers/2305.08891).
|
| 44 |
+
use_original_formulation (`bool`, defaults to `False`):
|
| 45 |
+
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
| 46 |
+
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
| 47 |
+
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
| 48 |
+
start (`float`, defaults to `0.0`):
|
| 49 |
+
The fraction of the total number of denoising steps after which guidance starts.
|
| 50 |
+
stop (`float`, defaults to `1.0`):
|
| 51 |
+
The fraction of the total number of denoising steps after which guidance stops.
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
_input_predictions = ["pred_cond", "pred_uncond"]
|
| 55 |
+
|
| 56 |
+
@register_to_config
|
| 57 |
+
def __init__(
|
| 58 |
+
self,
|
| 59 |
+
guidance_scale: float = 10.0,
|
| 60 |
+
alpha: float = 8.0,
|
| 61 |
+
guidance_rescale: float = 0.0,
|
| 62 |
+
use_original_formulation: bool = False,
|
| 63 |
+
start: float = 0.0,
|
| 64 |
+
stop: float = 1.0,
|
| 65 |
+
enabled: bool = True,
|
| 66 |
+
):
|
| 67 |
+
super().__init__(start, stop, enabled)
|
| 68 |
+
|
| 69 |
+
self.guidance_scale = guidance_scale
|
| 70 |
+
self.alpha = alpha
|
| 71 |
+
self.guidance_rescale = guidance_rescale
|
| 72 |
+
self.use_original_formulation = use_original_formulation
|
| 73 |
+
|
| 74 |
+
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
|
| 75 |
+
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
| 76 |
+
data_batches = []
|
| 77 |
+
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
| 78 |
+
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
| 79 |
+
data_batches.append(data_batch)
|
| 80 |
+
return data_batches
|
| 81 |
+
|
| 82 |
+
def prepare_inputs_from_block_state(
|
| 83 |
+
self, data: "BlockState", input_fields: dict[str, str | tuple[str, str]]
|
| 84 |
+
) -> list["BlockState"]:
|
| 85 |
+
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
| 86 |
+
data_batches = []
|
| 87 |
+
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
| 88 |
+
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
| 89 |
+
data_batches.append(data_batch)
|
| 90 |
+
return data_batches
|
| 91 |
+
|
| 92 |
+
def forward(self, pred_cond: torch.Tensor, pred_uncond: torch.Tensor | None = None) -> GuiderOutput:
|
| 93 |
+
pred = None
|
| 94 |
+
|
| 95 |
+
if not self._is_mambo_g_enabled():
|
| 96 |
+
pred = pred_cond
|
| 97 |
+
else:
|
| 98 |
+
pred = mambo_guidance(
|
| 99 |
+
pred_cond,
|
| 100 |
+
pred_uncond,
|
| 101 |
+
self.guidance_scale,
|
| 102 |
+
self.alpha,
|
| 103 |
+
self.use_original_formulation,
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
if self.guidance_rescale > 0.0:
|
| 107 |
+
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
| 108 |
+
|
| 109 |
+
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
|
| 110 |
+
|
| 111 |
+
@property
|
| 112 |
+
def is_conditional(self) -> bool:
|
| 113 |
+
return self._count_prepared == 1
|
| 114 |
+
|
| 115 |
+
@property
|
| 116 |
+
def num_conditions(self) -> int:
|
| 117 |
+
num_conditions = 1
|
| 118 |
+
if self._is_mambo_g_enabled():
|
| 119 |
+
num_conditions += 1
|
| 120 |
+
return num_conditions
|
| 121 |
+
|
| 122 |
+
def _is_mambo_g_enabled(self) -> bool:
|
| 123 |
+
if not self._enabled:
|
| 124 |
+
return False
|
| 125 |
+
|
| 126 |
+
is_within_range = True
|
| 127 |
+
if self._num_inference_steps is not None:
|
| 128 |
+
skip_start_step = int(self._start * self._num_inference_steps)
|
| 129 |
+
skip_stop_step = int(self._stop * self._num_inference_steps)
|
| 130 |
+
is_within_range = skip_start_step <= self._step < skip_stop_step
|
| 131 |
+
|
| 132 |
+
is_close = False
|
| 133 |
+
if self.use_original_formulation:
|
| 134 |
+
is_close = math.isclose(self.guidance_scale, 0.0)
|
| 135 |
+
else:
|
| 136 |
+
is_close = math.isclose(self.guidance_scale, 1.0)
|
| 137 |
+
|
| 138 |
+
return is_within_range and not is_close
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def mambo_guidance(
|
| 142 |
+
pred_cond: torch.Tensor,
|
| 143 |
+
pred_uncond: torch.Tensor,
|
| 144 |
+
guidance_scale: float,
|
| 145 |
+
alpha: float = 8.0,
|
| 146 |
+
use_original_formulation: bool = False,
|
| 147 |
+
):
|
| 148 |
+
dim = list(range(1, len(pred_cond.shape)))
|
| 149 |
+
diff = pred_cond - pred_uncond
|
| 150 |
+
ratio = torch.norm(diff, dim=dim, keepdim=True) / torch.norm(pred_uncond, dim=dim, keepdim=True)
|
| 151 |
+
guidance_scale_final = (
|
| 152 |
+
guidance_scale * torch.exp(-alpha * ratio)
|
| 153 |
+
if use_original_formulation
|
| 154 |
+
else 1.0 + (guidance_scale - 1.0) * torch.exp(-alpha * ratio)
|
| 155 |
+
)
|
| 156 |
+
pred = pred_cond if use_original_formulation else pred_uncond
|
| 157 |
+
pred = pred + guidance_scale_final * diff
|
| 158 |
+
|
| 159 |
+
return pred
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/perturbed_attention_guidance.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from typing import TYPE_CHECKING, Any
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
from ..configuration_utils import register_to_config
|
| 23 |
+
from ..hooks import HookRegistry, LayerSkipConfig
|
| 24 |
+
from ..hooks.layer_skip import _apply_layer_skip_hook
|
| 25 |
+
from ..utils import get_logger
|
| 26 |
+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
if TYPE_CHECKING:
|
| 30 |
+
from ..modular_pipelines.modular_pipeline import BlockState
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
logger = get_logger(__name__) # pylint: disable=invalid-name
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class PerturbedAttentionGuidance(BaseGuidance):
|
| 37 |
+
"""
|
| 38 |
+
Perturbed Attention Guidance (PAG): https://huggingface.co/papers/2403.17377
|
| 39 |
+
|
| 40 |
+
The intution behind PAG can be thought of as moving the CFG predicted distribution estimates further away from
|
| 41 |
+
worse versions of the conditional distribution estimates. PAG was one of the first techniques to introduce the idea
|
| 42 |
+
of using a worse version of the trained model for better guiding itself in the denoising process. It perturbs the
|
| 43 |
+
attention scores of the latent stream by replacing the score matrix with an identity matrix for selectively chosen
|
| 44 |
+
layers.
|
| 45 |
+
|
| 46 |
+
Additional reading:
|
| 47 |
+
- [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)
|
| 48 |
+
|
| 49 |
+
PAG is implemented with similar implementation to SkipLayerGuidance due to overlap in the configuration parameters
|
| 50 |
+
and implementation details.
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
guidance_scale (`float`, defaults to `7.5`):
|
| 54 |
+
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
| 55 |
+
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
| 56 |
+
deterioration of image quality.
|
| 57 |
+
perturbed_guidance_scale (`float`, defaults to `2.8`):
|
| 58 |
+
The scale parameter for perturbed attention guidance.
|
| 59 |
+
perturbed_guidance_start (`float`, defaults to `0.01`):
|
| 60 |
+
The fraction of the total number of denoising steps after which perturbed attention guidance starts.
|
| 61 |
+
perturbed_guidance_stop (`float`, defaults to `0.2`):
|
| 62 |
+
The fraction of the total number of denoising steps after which perturbed attention guidance stops.
|
| 63 |
+
perturbed_guidance_layers (`int` or `list[int]`, *optional*):
|
| 64 |
+
The layer indices to apply perturbed attention guidance to. Can be a single integer or a list of integers.
|
| 65 |
+
If not provided, `perturbed_guidance_config` must be provided.
|
| 66 |
+
perturbed_guidance_config (`LayerSkipConfig` or `list[LayerSkipConfig]`, *optional*):
|
| 67 |
+
The configuration for the perturbed attention guidance. Can be a single `LayerSkipConfig` or a list of
|
| 68 |
+
`LayerSkipConfig`. If not provided, `perturbed_guidance_layers` must be provided.
|
| 69 |
+
guidance_rescale (`float`, defaults to `0.0`):
|
| 70 |
+
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
| 71 |
+
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
| 72 |
+
Flawed](https://huggingface.co/papers/2305.08891).
|
| 73 |
+
use_original_formulation (`bool`, defaults to `False`):
|
| 74 |
+
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
| 75 |
+
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
| 76 |
+
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
| 77 |
+
start (`float`, defaults to `0.01`):
|
| 78 |
+
The fraction of the total number of denoising steps after which guidance starts.
|
| 79 |
+
stop (`float`, defaults to `0.2`):
|
| 80 |
+
The fraction of the total number of denoising steps after which guidance stops.
|
| 81 |
+
"""
|
| 82 |
+
|
| 83 |
+
# NOTE: The current implementation does not account for joint latent conditioning (text + image/video tokens in
|
| 84 |
+
# the same latent stream). It assumes the entire latent is a single stream of visual tokens. It would be very
|
| 85 |
+
# complex to support joint latent conditioning in a model-agnostic manner without specializing the implementation
|
| 86 |
+
# for each model architecture.
|
| 87 |
+
|
| 88 |
+
_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
| 89 |
+
|
| 90 |
+
@register_to_config
|
| 91 |
+
def __init__(
|
| 92 |
+
self,
|
| 93 |
+
guidance_scale: float = 7.5,
|
| 94 |
+
perturbed_guidance_scale: float = 2.8,
|
| 95 |
+
perturbed_guidance_start: float = 0.01,
|
| 96 |
+
perturbed_guidance_stop: float = 0.2,
|
| 97 |
+
perturbed_guidance_layers: int | list[int] | None = None,
|
| 98 |
+
perturbed_guidance_config: LayerSkipConfig | list[LayerSkipConfig] | dict[str, Any] = None,
|
| 99 |
+
guidance_rescale: float = 0.0,
|
| 100 |
+
use_original_formulation: bool = False,
|
| 101 |
+
start: float = 0.0,
|
| 102 |
+
stop: float = 1.0,
|
| 103 |
+
enabled: bool = True,
|
| 104 |
+
):
|
| 105 |
+
super().__init__(start, stop, enabled)
|
| 106 |
+
|
| 107 |
+
self.guidance_scale = guidance_scale
|
| 108 |
+
self.skip_layer_guidance_scale = perturbed_guidance_scale
|
| 109 |
+
self.skip_layer_guidance_start = perturbed_guidance_start
|
| 110 |
+
self.skip_layer_guidance_stop = perturbed_guidance_stop
|
| 111 |
+
self.guidance_rescale = guidance_rescale
|
| 112 |
+
self.use_original_formulation = use_original_formulation
|
| 113 |
+
|
| 114 |
+
if perturbed_guidance_config is None:
|
| 115 |
+
if perturbed_guidance_layers is None:
|
| 116 |
+
raise ValueError(
|
| 117 |
+
"`perturbed_guidance_layers` must be provided if `perturbed_guidance_config` is not specified."
|
| 118 |
+
)
|
| 119 |
+
perturbed_guidance_config = LayerSkipConfig(
|
| 120 |
+
indices=perturbed_guidance_layers,
|
| 121 |
+
fqn="auto",
|
| 122 |
+
skip_attention=False,
|
| 123 |
+
skip_attention_scores=True,
|
| 124 |
+
skip_ff=False,
|
| 125 |
+
)
|
| 126 |
+
else:
|
| 127 |
+
if perturbed_guidance_layers is not None:
|
| 128 |
+
raise ValueError(
|
| 129 |
+
"`perturbed_guidance_layers` should not be provided if `perturbed_guidance_config` is specified."
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
if isinstance(perturbed_guidance_config, dict):
|
| 133 |
+
perturbed_guidance_config = LayerSkipConfig.from_dict(perturbed_guidance_config)
|
| 134 |
+
|
| 135 |
+
if isinstance(perturbed_guidance_config, LayerSkipConfig):
|
| 136 |
+
perturbed_guidance_config = [perturbed_guidance_config]
|
| 137 |
+
|
| 138 |
+
if not isinstance(perturbed_guidance_config, list):
|
| 139 |
+
raise ValueError(
|
| 140 |
+
"`perturbed_guidance_config` must be a `LayerSkipConfig`, a list of `LayerSkipConfig`, or a dict that can be converted to a `LayerSkipConfig`."
|
| 141 |
+
)
|
| 142 |
+
elif isinstance(next(iter(perturbed_guidance_config), None), dict):
|
| 143 |
+
perturbed_guidance_config = [LayerSkipConfig.from_dict(config) for config in perturbed_guidance_config]
|
| 144 |
+
|
| 145 |
+
for config in perturbed_guidance_config:
|
| 146 |
+
if config.skip_attention or not config.skip_attention_scores or config.skip_ff:
|
| 147 |
+
logger.warning(
|
| 148 |
+
"Perturbed Attention Guidance is designed to perturb attention scores, so `skip_attention` should be False, `skip_attention_scores` should be True, and `skip_ff` should be False. "
|
| 149 |
+
"Please check your configuration. Modifying the config to match the expected values."
|
| 150 |
+
)
|
| 151 |
+
config.skip_attention = False
|
| 152 |
+
config.skip_attention_scores = True
|
| 153 |
+
config.skip_ff = False
|
| 154 |
+
|
| 155 |
+
self.skip_layer_config = perturbed_guidance_config
|
| 156 |
+
self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]
|
| 157 |
+
|
| 158 |
+
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_models
|
| 159 |
+
def prepare_models(self, denoiser: torch.nn.Module) -> None:
|
| 160 |
+
self._count_prepared += 1
|
| 161 |
+
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
|
| 162 |
+
for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
|
| 163 |
+
_apply_layer_skip_hook(denoiser, config, name=name)
|
| 164 |
+
|
| 165 |
+
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.cleanup_models
|
| 166 |
+
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
|
| 167 |
+
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
|
| 168 |
+
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
|
| 169 |
+
# Remove the hooks after inference
|
| 170 |
+
for hook_name in self._skip_layer_hook_names:
|
| 171 |
+
registry.remove_hook(hook_name, recurse=True)
|
| 172 |
+
|
| 173 |
+
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs
|
| 174 |
+
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
|
| 175 |
+
if self.num_conditions == 1:
|
| 176 |
+
tuple_indices = [0]
|
| 177 |
+
input_predictions = ["pred_cond"]
|
| 178 |
+
elif self.num_conditions == 2:
|
| 179 |
+
tuple_indices = [0, 1]
|
| 180 |
+
input_predictions = (
|
| 181 |
+
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
|
| 182 |
+
)
|
| 183 |
+
else:
|
| 184 |
+
tuple_indices = [0, 1, 0]
|
| 185 |
+
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
| 186 |
+
data_batches = []
|
| 187 |
+
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
|
| 188 |
+
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
| 189 |
+
data_batches.append(data_batch)
|
| 190 |
+
return data_batches
|
| 191 |
+
|
| 192 |
+
def prepare_inputs_from_block_state(
|
| 193 |
+
self, data: "BlockState", input_fields: dict[str, str | tuple[str, str]]
|
| 194 |
+
) -> list["BlockState"]:
|
| 195 |
+
if self.num_conditions == 1:
|
| 196 |
+
tuple_indices = [0]
|
| 197 |
+
input_predictions = ["pred_cond"]
|
| 198 |
+
elif self.num_conditions == 2:
|
| 199 |
+
tuple_indices = [0, 1]
|
| 200 |
+
input_predictions = (
|
| 201 |
+
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
|
| 202 |
+
)
|
| 203 |
+
else:
|
| 204 |
+
tuple_indices = [0, 1, 0]
|
| 205 |
+
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
| 206 |
+
data_batches = []
|
| 207 |
+
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
|
| 208 |
+
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
| 209 |
+
data_batches.append(data_batch)
|
| 210 |
+
return data_batches
|
| 211 |
+
|
| 212 |
+
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.forward
|
| 213 |
+
def forward(
|
| 214 |
+
self,
|
| 215 |
+
pred_cond: torch.Tensor,
|
| 216 |
+
pred_uncond: torch.Tensor | None = None,
|
| 217 |
+
pred_cond_skip: torch.Tensor | None = None,
|
| 218 |
+
) -> GuiderOutput:
|
| 219 |
+
pred = None
|
| 220 |
+
|
| 221 |
+
if not self._is_cfg_enabled() and not self._is_slg_enabled():
|
| 222 |
+
pred = pred_cond
|
| 223 |
+
elif not self._is_cfg_enabled():
|
| 224 |
+
shift = pred_cond - pred_cond_skip
|
| 225 |
+
pred = pred_cond if self.use_original_formulation else pred_cond_skip
|
| 226 |
+
pred = pred + self.skip_layer_guidance_scale * shift
|
| 227 |
+
elif not self._is_slg_enabled():
|
| 228 |
+
shift = pred_cond - pred_uncond
|
| 229 |
+
pred = pred_cond if self.use_original_formulation else pred_uncond
|
| 230 |
+
pred = pred + self.guidance_scale * shift
|
| 231 |
+
else:
|
| 232 |
+
shift = pred_cond - pred_uncond
|
| 233 |
+
shift_skip = pred_cond - pred_cond_skip
|
| 234 |
+
pred = pred_cond if self.use_original_formulation else pred_uncond
|
| 235 |
+
pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip
|
| 236 |
+
|
| 237 |
+
if self.guidance_rescale > 0.0:
|
| 238 |
+
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
| 239 |
+
|
| 240 |
+
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
|
| 241 |
+
|
| 242 |
+
@property
|
| 243 |
+
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.is_conditional
|
| 244 |
+
def is_conditional(self) -> bool:
|
| 245 |
+
return self._count_prepared == 1 or self._count_prepared == 3
|
| 246 |
+
|
| 247 |
+
@property
|
| 248 |
+
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.num_conditions
|
| 249 |
+
def num_conditions(self) -> int:
|
| 250 |
+
num_conditions = 1
|
| 251 |
+
if self._is_cfg_enabled():
|
| 252 |
+
num_conditions += 1
|
| 253 |
+
if self._is_slg_enabled():
|
| 254 |
+
num_conditions += 1
|
| 255 |
+
return num_conditions
|
| 256 |
+
|
| 257 |
+
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_cfg_enabled
|
| 258 |
+
def _is_cfg_enabled(self) -> bool:
|
| 259 |
+
if not self._enabled:
|
| 260 |
+
return False
|
| 261 |
+
|
| 262 |
+
is_within_range = True
|
| 263 |
+
if self._num_inference_steps is not None:
|
| 264 |
+
skip_start_step = int(self._start * self._num_inference_steps)
|
| 265 |
+
skip_stop_step = int(self._stop * self._num_inference_steps)
|
| 266 |
+
is_within_range = skip_start_step <= self._step < skip_stop_step
|
| 267 |
+
|
| 268 |
+
is_close = False
|
| 269 |
+
if self.use_original_formulation:
|
| 270 |
+
is_close = math.isclose(self.guidance_scale, 0.0)
|
| 271 |
+
else:
|
| 272 |
+
is_close = math.isclose(self.guidance_scale, 1.0)
|
| 273 |
+
|
| 274 |
+
return is_within_range and not is_close
|
| 275 |
+
|
| 276 |
+
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_slg_enabled
|
| 277 |
+
def _is_slg_enabled(self) -> bool:
|
| 278 |
+
if not self._enabled:
|
| 279 |
+
return False
|
| 280 |
+
|
| 281 |
+
is_within_range = True
|
| 282 |
+
if self._num_inference_steps is not None:
|
| 283 |
+
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
|
| 284 |
+
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
|
| 285 |
+
is_within_range = skip_start_step < self._step < skip_stop_step
|
| 286 |
+
|
| 287 |
+
is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
|
| 288 |
+
|
| 289 |
+
return is_within_range and not is_zero
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/skip_layer_guidance.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from typing import TYPE_CHECKING, Any
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
from ..configuration_utils import register_to_config
|
| 23 |
+
from ..hooks import HookRegistry, LayerSkipConfig
|
| 24 |
+
from ..hooks.layer_skip import _apply_layer_skip_hook
|
| 25 |
+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
if TYPE_CHECKING:
|
| 29 |
+
from ..modular_pipelines.modular_pipeline import BlockState
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class SkipLayerGuidance(BaseGuidance):
|
| 33 |
+
"""
|
| 34 |
+
Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5
|
| 35 |
+
|
| 36 |
+
Spatio-Temporal Guidance (STG): https://huggingface.co/papers/2411.18664
|
| 37 |
+
|
| 38 |
+
SLG was introduced by StabilityAI for improving structure and anotomy coherence in generated images. It works by
|
| 39 |
+
skipping the forward pass of specified transformer blocks during the denoising process on an additional conditional
|
| 40 |
+
batch of data, apart from the conditional and unconditional batches already used in CFG
|
| 41 |
+
([~guiders.classifier_free_guidance.ClassifierFreeGuidance]), and then scaling and shifting the CFG predictions
|
| 42 |
+
based on the difference between conditional without skipping and conditional with skipping predictions.
|
| 43 |
+
|
| 44 |
+
The intution behind SLG can be thought of as moving the CFG predicted distribution estimates further away from
|
| 45 |
+
worse versions of the conditional distribution estimates (because skipping layers is equivalent to using a worse
|
| 46 |
+
version of the model for the conditional prediction).
|
| 47 |
+
|
| 48 |
+
STG is an improvement and follow-up work combining ideas from SLG, PAG and similar techniques for improving
|
| 49 |
+
generation quality in video diffusion models.
|
| 50 |
+
|
| 51 |
+
Additional reading:
|
| 52 |
+
- [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)
|
| 53 |
+
|
| 54 |
+
The values for `skip_layer_guidance_scale`, `skip_layer_guidance_start`, and `skip_layer_guidance_stop` are
|
| 55 |
+
defaulted to the recommendations by StabilityAI for Stable Diffusion 3.5 Medium.
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
guidance_scale (`float`, defaults to `7.5`):
|
| 59 |
+
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
| 60 |
+
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
| 61 |
+
deterioration of image quality.
|
| 62 |
+
skip_layer_guidance_scale (`float`, defaults to `2.8`):
|
| 63 |
+
The scale parameter for skip layer guidance. Anatomy and structure coherence may improve with higher
|
| 64 |
+
values, but it may also lead to overexposure and saturation.
|
| 65 |
+
skip_layer_guidance_start (`float`, defaults to `0.01`):
|
| 66 |
+
The fraction of the total number of denoising steps after which skip layer guidance starts.
|
| 67 |
+
skip_layer_guidance_stop (`float`, defaults to `0.2`):
|
| 68 |
+
The fraction of the total number of denoising steps after which skip layer guidance stops.
|
| 69 |
+
skip_layer_guidance_layers (`int` or `list[int]`, *optional*):
|
| 70 |
+
The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
|
| 71 |
+
provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion
|
| 72 |
+
3.5 Medium.
|
| 73 |
+
skip_layer_config (`LayerSkipConfig` or `list[LayerSkipConfig]`, *optional*):
|
| 74 |
+
The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
|
| 75 |
+
`LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided.
|
| 76 |
+
guidance_rescale (`float`, defaults to `0.0`):
|
| 77 |
+
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
| 78 |
+
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
| 79 |
+
Flawed](https://huggingface.co/papers/2305.08891).
|
| 80 |
+
use_original_formulation (`bool`, defaults to `False`):
|
| 81 |
+
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
| 82 |
+
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
| 83 |
+
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
| 84 |
+
start (`float`, defaults to `0.01`):
|
| 85 |
+
The fraction of the total number of denoising steps after which guidance starts.
|
| 86 |
+
stop (`float`, defaults to `0.2`):
|
| 87 |
+
The fraction of the total number of denoising steps after which guidance stops.
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
| 91 |
+
|
| 92 |
+
@register_to_config
|
| 93 |
+
def __init__(
|
| 94 |
+
self,
|
| 95 |
+
guidance_scale: float = 7.5,
|
| 96 |
+
skip_layer_guidance_scale: float = 2.8,
|
| 97 |
+
skip_layer_guidance_start: float = 0.01,
|
| 98 |
+
skip_layer_guidance_stop: float = 0.2,
|
| 99 |
+
skip_layer_guidance_layers: int | list[int] | None = None,
|
| 100 |
+
skip_layer_config: LayerSkipConfig | list[LayerSkipConfig] | dict[str, Any] = None,
|
| 101 |
+
guidance_rescale: float = 0.0,
|
| 102 |
+
use_original_formulation: bool = False,
|
| 103 |
+
start: float = 0.0,
|
| 104 |
+
stop: float = 1.0,
|
| 105 |
+
enabled: bool = True,
|
| 106 |
+
):
|
| 107 |
+
super().__init__(start, stop, enabled)
|
| 108 |
+
|
| 109 |
+
self.guidance_scale = guidance_scale
|
| 110 |
+
self.skip_layer_guidance_scale = skip_layer_guidance_scale
|
| 111 |
+
self.skip_layer_guidance_start = skip_layer_guidance_start
|
| 112 |
+
self.skip_layer_guidance_stop = skip_layer_guidance_stop
|
| 113 |
+
self.guidance_rescale = guidance_rescale
|
| 114 |
+
self.use_original_formulation = use_original_formulation
|
| 115 |
+
|
| 116 |
+
if not (0.0 <= skip_layer_guidance_start < 1.0):
|
| 117 |
+
raise ValueError(
|
| 118 |
+
f"Expected `skip_layer_guidance_start` to be between 0.0 and 1.0, but got {skip_layer_guidance_start}."
|
| 119 |
+
)
|
| 120 |
+
if not (skip_layer_guidance_start <= skip_layer_guidance_stop <= 1.0):
|
| 121 |
+
raise ValueError(
|
| 122 |
+
f"Expected `skip_layer_guidance_stop` to be between 0.0 and 1.0, but got {skip_layer_guidance_stop}."
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
if skip_layer_guidance_layers is None and skip_layer_config is None:
|
| 126 |
+
raise ValueError(
|
| 127 |
+
"Either `skip_layer_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance."
|
| 128 |
+
)
|
| 129 |
+
if skip_layer_guidance_layers is not None and skip_layer_config is not None:
|
| 130 |
+
raise ValueError("Only one of `skip_layer_guidance_layers` or `skip_layer_config` can be provided.")
|
| 131 |
+
|
| 132 |
+
if skip_layer_guidance_layers is not None:
|
| 133 |
+
if isinstance(skip_layer_guidance_layers, int):
|
| 134 |
+
skip_layer_guidance_layers = [skip_layer_guidance_layers]
|
| 135 |
+
if not isinstance(skip_layer_guidance_layers, list):
|
| 136 |
+
raise ValueError(
|
| 137 |
+
f"Expected `skip_layer_guidance_layers` to be an int or a list of ints, but got {type(skip_layer_guidance_layers)}."
|
| 138 |
+
)
|
| 139 |
+
skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_layer_guidance_layers]
|
| 140 |
+
|
| 141 |
+
if isinstance(skip_layer_config, dict):
|
| 142 |
+
skip_layer_config = LayerSkipConfig.from_dict(skip_layer_config)
|
| 143 |
+
|
| 144 |
+
if isinstance(skip_layer_config, LayerSkipConfig):
|
| 145 |
+
skip_layer_config = [skip_layer_config]
|
| 146 |
+
|
| 147 |
+
if not isinstance(skip_layer_config, list):
|
| 148 |
+
raise ValueError(
|
| 149 |
+
f"Expected `skip_layer_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(skip_layer_config)}."
|
| 150 |
+
)
|
| 151 |
+
elif isinstance(next(iter(skip_layer_config), None), dict):
|
| 152 |
+
skip_layer_config = [LayerSkipConfig.from_dict(config) for config in skip_layer_config]
|
| 153 |
+
|
| 154 |
+
self.skip_layer_config = skip_layer_config
|
| 155 |
+
self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]
|
| 156 |
+
|
| 157 |
+
def prepare_models(self, denoiser: torch.nn.Module) -> None:
|
| 158 |
+
self._count_prepared += 1
|
| 159 |
+
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
|
| 160 |
+
for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
|
| 161 |
+
_apply_layer_skip_hook(denoiser, config, name=name)
|
| 162 |
+
|
| 163 |
+
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
|
| 164 |
+
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
|
| 165 |
+
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
|
| 166 |
+
# Remove the hooks after inference
|
| 167 |
+
for hook_name in self._skip_layer_hook_names:
|
| 168 |
+
registry.remove_hook(hook_name, recurse=True)
|
| 169 |
+
|
| 170 |
+
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
|
| 171 |
+
if self.num_conditions == 1:
|
| 172 |
+
tuple_indices = [0]
|
| 173 |
+
input_predictions = ["pred_cond"]
|
| 174 |
+
elif self.num_conditions == 2:
|
| 175 |
+
tuple_indices = [0, 1]
|
| 176 |
+
input_predictions = (
|
| 177 |
+
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
|
| 178 |
+
)
|
| 179 |
+
else:
|
| 180 |
+
tuple_indices = [0, 1, 0]
|
| 181 |
+
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
| 182 |
+
data_batches = []
|
| 183 |
+
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
|
| 184 |
+
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
| 185 |
+
data_batches.append(data_batch)
|
| 186 |
+
return data_batches
|
| 187 |
+
|
| 188 |
+
def prepare_inputs_from_block_state(
|
| 189 |
+
self, data: "BlockState", input_fields: dict[str, str | tuple[str, str]]
|
| 190 |
+
) -> list["BlockState"]:
|
| 191 |
+
if self.num_conditions == 1:
|
| 192 |
+
tuple_indices = [0]
|
| 193 |
+
input_predictions = ["pred_cond"]
|
| 194 |
+
elif self.num_conditions == 2:
|
| 195 |
+
tuple_indices = [0, 1]
|
| 196 |
+
input_predictions = (
|
| 197 |
+
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
|
| 198 |
+
)
|
| 199 |
+
else:
|
| 200 |
+
tuple_indices = [0, 1, 0]
|
| 201 |
+
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
| 202 |
+
data_batches = []
|
| 203 |
+
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
|
| 204 |
+
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
| 205 |
+
data_batches.append(data_batch)
|
| 206 |
+
return data_batches
|
| 207 |
+
|
| 208 |
+
def forward(
|
| 209 |
+
self,
|
| 210 |
+
pred_cond: torch.Tensor,
|
| 211 |
+
pred_uncond: torch.Tensor | None = None,
|
| 212 |
+
pred_cond_skip: torch.Tensor | None = None,
|
| 213 |
+
) -> GuiderOutput:
|
| 214 |
+
pred = None
|
| 215 |
+
|
| 216 |
+
if not self._is_cfg_enabled() and not self._is_slg_enabled():
|
| 217 |
+
pred = pred_cond
|
| 218 |
+
elif not self._is_cfg_enabled():
|
| 219 |
+
shift = pred_cond - pred_cond_skip
|
| 220 |
+
pred = pred_cond if self.use_original_formulation else pred_cond_skip
|
| 221 |
+
pred = pred + self.skip_layer_guidance_scale * shift
|
| 222 |
+
elif not self._is_slg_enabled():
|
| 223 |
+
shift = pred_cond - pred_uncond
|
| 224 |
+
pred = pred_cond if self.use_original_formulation else pred_uncond
|
| 225 |
+
pred = pred + self.guidance_scale * shift
|
| 226 |
+
else:
|
| 227 |
+
shift = pred_cond - pred_uncond
|
| 228 |
+
shift_skip = pred_cond - pred_cond_skip
|
| 229 |
+
pred = pred_cond if self.use_original_formulation else pred_uncond
|
| 230 |
+
pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip
|
| 231 |
+
|
| 232 |
+
if self.guidance_rescale > 0.0:
|
| 233 |
+
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
| 234 |
+
|
| 235 |
+
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
|
| 236 |
+
|
| 237 |
+
@property
|
| 238 |
+
def is_conditional(self) -> bool:
|
| 239 |
+
return self._count_prepared == 1 or self._count_prepared == 3
|
| 240 |
+
|
| 241 |
+
@property
|
| 242 |
+
def num_conditions(self) -> int:
|
| 243 |
+
num_conditions = 1
|
| 244 |
+
if self._is_cfg_enabled():
|
| 245 |
+
num_conditions += 1
|
| 246 |
+
if self._is_slg_enabled():
|
| 247 |
+
num_conditions += 1
|
| 248 |
+
return num_conditions
|
| 249 |
+
|
| 250 |
+
def _is_cfg_enabled(self) -> bool:
|
| 251 |
+
if not self._enabled:
|
| 252 |
+
return False
|
| 253 |
+
|
| 254 |
+
is_within_range = True
|
| 255 |
+
if self._num_inference_steps is not None:
|
| 256 |
+
skip_start_step = int(self._start * self._num_inference_steps)
|
| 257 |
+
skip_stop_step = int(self._stop * self._num_inference_steps)
|
| 258 |
+
is_within_range = skip_start_step <= self._step < skip_stop_step
|
| 259 |
+
|
| 260 |
+
is_close = False
|
| 261 |
+
if self.use_original_formulation:
|
| 262 |
+
is_close = math.isclose(self.guidance_scale, 0.0)
|
| 263 |
+
else:
|
| 264 |
+
is_close = math.isclose(self.guidance_scale, 1.0)
|
| 265 |
+
|
| 266 |
+
return is_within_range and not is_close
|
| 267 |
+
|
| 268 |
+
def _is_slg_enabled(self) -> bool:
|
| 269 |
+
if not self._enabled:
|
| 270 |
+
return False
|
| 271 |
+
|
| 272 |
+
is_within_range = True
|
| 273 |
+
if self._num_inference_steps is not None:
|
| 274 |
+
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
|
| 275 |
+
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
|
| 276 |
+
is_within_range = skip_start_step < self._step < skip_stop_step
|
| 277 |
+
|
| 278 |
+
is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
|
| 279 |
+
|
| 280 |
+
return is_within_range and not is_zero
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/smoothed_energy_guidance.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from typing import TYPE_CHECKING
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
from ..configuration_utils import register_to_config
|
| 23 |
+
from ..hooks import HookRegistry
|
| 24 |
+
from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook
|
| 25 |
+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
if TYPE_CHECKING:
|
| 29 |
+
from ..modular_pipelines.modular_pipeline import BlockState
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class SmoothedEnergyGuidance(BaseGuidance):
|
| 33 |
+
"""
|
| 34 |
+
Smoothed Energy Guidance (SEG): https://huggingface.co/papers/2408.00760
|
| 35 |
+
|
| 36 |
+
SEG is only supported as an experimental prototype feature for now, so the implementation may be modified in the
|
| 37 |
+
future without warning or guarantee of reproducibility. This implementation assumes:
|
| 38 |
+
- Generated images are square (height == width)
|
| 39 |
+
- The model does not combine different modalities together (e.g., text and image latent streams are not combined
|
| 40 |
+
together such as Flux)
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
guidance_scale (`float`, defaults to `7.5`):
|
| 44 |
+
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
| 45 |
+
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
| 46 |
+
deterioration of image quality.
|
| 47 |
+
seg_guidance_scale (`float`, defaults to `3.0`):
|
| 48 |
+
The scale parameter for smoothed energy guidance. Anatomy and structure coherence may improve with higher
|
| 49 |
+
values, but it may also lead to overexposure and saturation.
|
| 50 |
+
seg_blur_sigma (`float`, defaults to `9999999.0`):
|
| 51 |
+
The amount by which we blur the attention weights. Setting this value greater than 9999.0 results in
|
| 52 |
+
infinite blur, which means uniform queries. Controlling it exponentially is empirically effective.
|
| 53 |
+
seg_blur_threshold_inf (`float`, defaults to `9999.0`):
|
| 54 |
+
The threshold above which the blur is considered infinite.
|
| 55 |
+
seg_guidance_start (`float`, defaults to `0.0`):
|
| 56 |
+
The fraction of the total number of denoising steps after which smoothed energy guidance starts.
|
| 57 |
+
seg_guidance_stop (`float`, defaults to `1.0`):
|
| 58 |
+
The fraction of the total number of denoising steps after which smoothed energy guidance stops.
|
| 59 |
+
seg_guidance_layers (`int` or `list[int]`, *optional*):
|
| 60 |
+
The layer indices to apply smoothed energy guidance to. Can be a single integer or a list of integers. If
|
| 61 |
+
not provided, `seg_guidance_config` must be provided. The recommended values are `[7, 8, 9]` for Stable
|
| 62 |
+
Diffusion 3.5 Medium.
|
| 63 |
+
seg_guidance_config (`SmoothedEnergyGuidanceConfig` or `list[SmoothedEnergyGuidanceConfig]`, *optional*):
|
| 64 |
+
The configuration for the smoothed energy layer guidance. Can be a single `SmoothedEnergyGuidanceConfig` or
|
| 65 |
+
a list of `SmoothedEnergyGuidanceConfig`. If not provided, `seg_guidance_layers` must be provided.
|
| 66 |
+
guidance_rescale (`float`, defaults to `0.0`):
|
| 67 |
+
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
| 68 |
+
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
| 69 |
+
Flawed](https://huggingface.co/papers/2305.08891).
|
| 70 |
+
use_original_formulation (`bool`, defaults to `False`):
|
| 71 |
+
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
| 72 |
+
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
| 73 |
+
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
| 74 |
+
start (`float`, defaults to `0.01`):
|
| 75 |
+
The fraction of the total number of denoising steps after which guidance starts.
|
| 76 |
+
stop (`float`, defaults to `0.2`):
|
| 77 |
+
The fraction of the total number of denoising steps after which guidance stops.
|
| 78 |
+
"""
|
| 79 |
+
|
| 80 |
+
_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
|
| 81 |
+
|
| 82 |
+
@register_to_config
|
| 83 |
+
def __init__(
|
| 84 |
+
self,
|
| 85 |
+
guidance_scale: float = 7.5,
|
| 86 |
+
seg_guidance_scale: float = 2.8,
|
| 87 |
+
seg_blur_sigma: float = 9999999.0,
|
| 88 |
+
seg_blur_threshold_inf: float = 9999.0,
|
| 89 |
+
seg_guidance_start: float = 0.0,
|
| 90 |
+
seg_guidance_stop: float = 1.0,
|
| 91 |
+
seg_guidance_layers: int | list[int] | None = None,
|
| 92 |
+
seg_guidance_config: SmoothedEnergyGuidanceConfig | list[SmoothedEnergyGuidanceConfig] = None,
|
| 93 |
+
guidance_rescale: float = 0.0,
|
| 94 |
+
use_original_formulation: bool = False,
|
| 95 |
+
start: float = 0.0,
|
| 96 |
+
stop: float = 1.0,
|
| 97 |
+
enabled: bool = True,
|
| 98 |
+
):
|
| 99 |
+
super().__init__(start, stop, enabled)
|
| 100 |
+
|
| 101 |
+
self.guidance_scale = guidance_scale
|
| 102 |
+
self.seg_guidance_scale = seg_guidance_scale
|
| 103 |
+
self.seg_blur_sigma = seg_blur_sigma
|
| 104 |
+
self.seg_blur_threshold_inf = seg_blur_threshold_inf
|
| 105 |
+
self.seg_guidance_start = seg_guidance_start
|
| 106 |
+
self.seg_guidance_stop = seg_guidance_stop
|
| 107 |
+
self.guidance_rescale = guidance_rescale
|
| 108 |
+
self.use_original_formulation = use_original_formulation
|
| 109 |
+
|
| 110 |
+
if not (0.0 <= seg_guidance_start < 1.0):
|
| 111 |
+
raise ValueError(f"Expected `seg_guidance_start` to be between 0.0 and 1.0, but got {seg_guidance_start}.")
|
| 112 |
+
if not (seg_guidance_start <= seg_guidance_stop <= 1.0):
|
| 113 |
+
raise ValueError(f"Expected `seg_guidance_stop` to be between 0.0 and 1.0, but got {seg_guidance_stop}.")
|
| 114 |
+
|
| 115 |
+
if seg_guidance_layers is None and seg_guidance_config is None:
|
| 116 |
+
raise ValueError(
|
| 117 |
+
"Either `seg_guidance_layers` or `seg_guidance_config` must be provided to enable Smoothed Energy Guidance."
|
| 118 |
+
)
|
| 119 |
+
if seg_guidance_layers is not None and seg_guidance_config is not None:
|
| 120 |
+
raise ValueError("Only one of `seg_guidance_layers` or `seg_guidance_config` can be provided.")
|
| 121 |
+
|
| 122 |
+
if seg_guidance_layers is not None:
|
| 123 |
+
if isinstance(seg_guidance_layers, int):
|
| 124 |
+
seg_guidance_layers = [seg_guidance_layers]
|
| 125 |
+
if not isinstance(seg_guidance_layers, list):
|
| 126 |
+
raise ValueError(
|
| 127 |
+
f"Expected `seg_guidance_layers` to be an int or a list of ints, but got {type(seg_guidance_layers)}."
|
| 128 |
+
)
|
| 129 |
+
seg_guidance_config = [SmoothedEnergyGuidanceConfig(layer, fqn="auto") for layer in seg_guidance_layers]
|
| 130 |
+
|
| 131 |
+
if isinstance(seg_guidance_config, dict):
|
| 132 |
+
seg_guidance_config = SmoothedEnergyGuidanceConfig.from_dict(seg_guidance_config)
|
| 133 |
+
|
| 134 |
+
if isinstance(seg_guidance_config, SmoothedEnergyGuidanceConfig):
|
| 135 |
+
seg_guidance_config = [seg_guidance_config]
|
| 136 |
+
|
| 137 |
+
if not isinstance(seg_guidance_config, list):
|
| 138 |
+
raise ValueError(
|
| 139 |
+
f"Expected `seg_guidance_config` to be a SmoothedEnergyGuidanceConfig or a list of SmoothedEnergyGuidanceConfig, but got {type(seg_guidance_config)}."
|
| 140 |
+
)
|
| 141 |
+
elif isinstance(next(iter(seg_guidance_config), None), dict):
|
| 142 |
+
seg_guidance_config = [SmoothedEnergyGuidanceConfig.from_dict(config) for config in seg_guidance_config]
|
| 143 |
+
|
| 144 |
+
self.seg_guidance_config = seg_guidance_config
|
| 145 |
+
self._seg_layer_hook_names = [f"SmoothedEnergyGuidance_{i}" for i in range(len(self.seg_guidance_config))]
|
| 146 |
+
|
| 147 |
+
def prepare_models(self, denoiser: torch.nn.Module) -> None:
|
| 148 |
+
if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
|
| 149 |
+
for name, config in zip(self._seg_layer_hook_names, self.seg_guidance_config):
|
| 150 |
+
_apply_smoothed_energy_guidance_hook(denoiser, config, self.seg_blur_sigma, name=name)
|
| 151 |
+
|
| 152 |
+
def cleanup_models(self, denoiser: torch.nn.Module):
|
| 153 |
+
if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
|
| 154 |
+
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
|
| 155 |
+
# Remove the hooks after inference
|
| 156 |
+
for hook_name in self._seg_layer_hook_names:
|
| 157 |
+
registry.remove_hook(hook_name, recurse=True)
|
| 158 |
+
|
| 159 |
+
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
|
| 160 |
+
if self.num_conditions == 1:
|
| 161 |
+
tuple_indices = [0]
|
| 162 |
+
input_predictions = ["pred_cond"]
|
| 163 |
+
elif self.num_conditions == 2:
|
| 164 |
+
tuple_indices = [0, 1]
|
| 165 |
+
input_predictions = (
|
| 166 |
+
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"]
|
| 167 |
+
)
|
| 168 |
+
else:
|
| 169 |
+
tuple_indices = [0, 1, 0]
|
| 170 |
+
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
|
| 171 |
+
data_batches = []
|
| 172 |
+
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
|
| 173 |
+
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
| 174 |
+
data_batches.append(data_batch)
|
| 175 |
+
return data_batches
|
| 176 |
+
|
| 177 |
+
def prepare_inputs_from_block_state(
|
| 178 |
+
self, data: "BlockState", input_fields: dict[str, str | tuple[str, str]]
|
| 179 |
+
) -> list["BlockState"]:
|
| 180 |
+
if self.num_conditions == 1:
|
| 181 |
+
tuple_indices = [0]
|
| 182 |
+
input_predictions = ["pred_cond"]
|
| 183 |
+
elif self.num_conditions == 2:
|
| 184 |
+
tuple_indices = [0, 1]
|
| 185 |
+
input_predictions = (
|
| 186 |
+
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"]
|
| 187 |
+
)
|
| 188 |
+
else:
|
| 189 |
+
tuple_indices = [0, 1, 0]
|
| 190 |
+
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
|
| 191 |
+
data_batches = []
|
| 192 |
+
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
|
| 193 |
+
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
| 194 |
+
data_batches.append(data_batch)
|
| 195 |
+
return data_batches
|
| 196 |
+
|
| 197 |
+
def forward(
|
| 198 |
+
self,
|
| 199 |
+
pred_cond: torch.Tensor,
|
| 200 |
+
pred_uncond: torch.Tensor | None = None,
|
| 201 |
+
pred_cond_seg: torch.Tensor | None = None,
|
| 202 |
+
) -> GuiderOutput:
|
| 203 |
+
pred = None
|
| 204 |
+
|
| 205 |
+
if not self._is_cfg_enabled() and not self._is_seg_enabled():
|
| 206 |
+
pred = pred_cond
|
| 207 |
+
elif not self._is_cfg_enabled():
|
| 208 |
+
shift = pred_cond - pred_cond_seg
|
| 209 |
+
pred = pred_cond if self.use_original_formulation else pred_cond_seg
|
| 210 |
+
pred = pred + self.seg_guidance_scale * shift
|
| 211 |
+
elif not self._is_seg_enabled():
|
| 212 |
+
shift = pred_cond - pred_uncond
|
| 213 |
+
pred = pred_cond if self.use_original_formulation else pred_uncond
|
| 214 |
+
pred = pred + self.guidance_scale * shift
|
| 215 |
+
else:
|
| 216 |
+
shift = pred_cond - pred_uncond
|
| 217 |
+
shift_seg = pred_cond - pred_cond_seg
|
| 218 |
+
pred = pred_cond if self.use_original_formulation else pred_uncond
|
| 219 |
+
pred = pred + self.guidance_scale * shift + self.seg_guidance_scale * shift_seg
|
| 220 |
+
|
| 221 |
+
if self.guidance_rescale > 0.0:
|
| 222 |
+
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
| 223 |
+
|
| 224 |
+
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
|
| 225 |
+
|
| 226 |
+
@property
|
| 227 |
+
def is_conditional(self) -> bool:
|
| 228 |
+
return self._count_prepared == 1 or self._count_prepared == 3
|
| 229 |
+
|
| 230 |
+
@property
|
| 231 |
+
def num_conditions(self) -> int:
|
| 232 |
+
num_conditions = 1
|
| 233 |
+
if self._is_cfg_enabled():
|
| 234 |
+
num_conditions += 1
|
| 235 |
+
if self._is_seg_enabled():
|
| 236 |
+
num_conditions += 1
|
| 237 |
+
return num_conditions
|
| 238 |
+
|
| 239 |
+
def _is_cfg_enabled(self) -> bool:
|
| 240 |
+
if not self._enabled:
|
| 241 |
+
return False
|
| 242 |
+
|
| 243 |
+
is_within_range = True
|
| 244 |
+
if self._num_inference_steps is not None:
|
| 245 |
+
skip_start_step = int(self._start * self._num_inference_steps)
|
| 246 |
+
skip_stop_step = int(self._stop * self._num_inference_steps)
|
| 247 |
+
is_within_range = skip_start_step <= self._step < skip_stop_step
|
| 248 |
+
|
| 249 |
+
is_close = False
|
| 250 |
+
if self.use_original_formulation:
|
| 251 |
+
is_close = math.isclose(self.guidance_scale, 0.0)
|
| 252 |
+
else:
|
| 253 |
+
is_close = math.isclose(self.guidance_scale, 1.0)
|
| 254 |
+
|
| 255 |
+
return is_within_range and not is_close
|
| 256 |
+
|
| 257 |
+
def _is_seg_enabled(self) -> bool:
|
| 258 |
+
if not self._enabled:
|
| 259 |
+
return False
|
| 260 |
+
|
| 261 |
+
is_within_range = True
|
| 262 |
+
if self._num_inference_steps is not None:
|
| 263 |
+
skip_start_step = int(self.seg_guidance_start * self._num_inference_steps)
|
| 264 |
+
skip_stop_step = int(self.seg_guidance_stop * self._num_inference_steps)
|
| 265 |
+
is_within_range = skip_start_step < self._step < skip_stop_step
|
| 266 |
+
|
| 267 |
+
is_zero = math.isclose(self.seg_guidance_scale, 0.0)
|
| 268 |
+
|
| 269 |
+
return is_within_range and not is_zero
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/tangential_classifier_free_guidance.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from typing import TYPE_CHECKING
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
from ..configuration_utils import register_to_config
|
| 23 |
+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
if TYPE_CHECKING:
|
| 27 |
+
from ..modular_pipelines.modular_pipeline import BlockState
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class TangentialClassifierFreeGuidance(BaseGuidance):
|
| 31 |
+
"""
|
| 32 |
+
Tangential Classifier Free Guidance (TCFG): https://huggingface.co/papers/2503.18137
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
guidance_scale (`float`, defaults to `7.5`):
|
| 36 |
+
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
| 37 |
+
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
| 38 |
+
deterioration of image quality.
|
| 39 |
+
guidance_rescale (`float`, defaults to `0.0`):
|
| 40 |
+
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
| 41 |
+
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
| 42 |
+
Flawed](https://huggingface.co/papers/2305.08891).
|
| 43 |
+
use_original_formulation (`bool`, defaults to `False`):
|
| 44 |
+
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
| 45 |
+
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
| 46 |
+
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
| 47 |
+
start (`float`, defaults to `0.0`):
|
| 48 |
+
The fraction of the total number of denoising steps after which guidance starts.
|
| 49 |
+
stop (`float`, defaults to `1.0`):
|
| 50 |
+
The fraction of the total number of denoising steps after which guidance stops.
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
_input_predictions = ["pred_cond", "pred_uncond"]
|
| 54 |
+
|
| 55 |
+
@register_to_config
|
| 56 |
+
def __init__(
|
| 57 |
+
self,
|
| 58 |
+
guidance_scale: float = 7.5,
|
| 59 |
+
guidance_rescale: float = 0.0,
|
| 60 |
+
use_original_formulation: bool = False,
|
| 61 |
+
start: float = 0.0,
|
| 62 |
+
stop: float = 1.0,
|
| 63 |
+
enabled: bool = True,
|
| 64 |
+
):
|
| 65 |
+
super().__init__(start, stop, enabled)
|
| 66 |
+
|
| 67 |
+
self.guidance_scale = guidance_scale
|
| 68 |
+
self.guidance_rescale = guidance_rescale
|
| 69 |
+
self.use_original_formulation = use_original_formulation
|
| 70 |
+
|
| 71 |
+
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
|
| 72 |
+
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
| 73 |
+
data_batches = []
|
| 74 |
+
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
| 75 |
+
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
| 76 |
+
data_batches.append(data_batch)
|
| 77 |
+
return data_batches
|
| 78 |
+
|
| 79 |
+
def prepare_inputs_from_block_state(
|
| 80 |
+
self, data: "BlockState", input_fields: dict[str, str | tuple[str, str]]
|
| 81 |
+
) -> list["BlockState"]:
|
| 82 |
+
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
| 83 |
+
data_batches = []
|
| 84 |
+
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
| 85 |
+
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
| 86 |
+
data_batches.append(data_batch)
|
| 87 |
+
return data_batches
|
| 88 |
+
|
| 89 |
+
def forward(self, pred_cond: torch.Tensor, pred_uncond: torch.Tensor | None = None) -> GuiderOutput:
|
| 90 |
+
pred = None
|
| 91 |
+
|
| 92 |
+
if not self._is_tcfg_enabled():
|
| 93 |
+
pred = pred_cond
|
| 94 |
+
else:
|
| 95 |
+
pred = normalized_guidance(pred_cond, pred_uncond, self.guidance_scale, self.use_original_formulation)
|
| 96 |
+
|
| 97 |
+
if self.guidance_rescale > 0.0:
|
| 98 |
+
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
| 99 |
+
|
| 100 |
+
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
|
| 101 |
+
|
| 102 |
+
@property
|
| 103 |
+
def is_conditional(self) -> bool:
|
| 104 |
+
return self._num_outputs_prepared == 1
|
| 105 |
+
|
| 106 |
+
@property
|
| 107 |
+
def num_conditions(self) -> int:
|
| 108 |
+
num_conditions = 1
|
| 109 |
+
if self._is_tcfg_enabled():
|
| 110 |
+
num_conditions += 1
|
| 111 |
+
return num_conditions
|
| 112 |
+
|
| 113 |
+
def _is_tcfg_enabled(self) -> bool:
|
| 114 |
+
if not self._enabled:
|
| 115 |
+
return False
|
| 116 |
+
|
| 117 |
+
is_within_range = True
|
| 118 |
+
if self._num_inference_steps is not None:
|
| 119 |
+
skip_start_step = int(self._start * self._num_inference_steps)
|
| 120 |
+
skip_stop_step = int(self._stop * self._num_inference_steps)
|
| 121 |
+
is_within_range = skip_start_step <= self._step < skip_stop_step
|
| 122 |
+
|
| 123 |
+
is_close = False
|
| 124 |
+
if self.use_original_formulation:
|
| 125 |
+
is_close = math.isclose(self.guidance_scale, 0.0)
|
| 126 |
+
else:
|
| 127 |
+
is_close = math.isclose(self.guidance_scale, 1.0)
|
| 128 |
+
|
| 129 |
+
return is_within_range and not is_close
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def normalized_guidance(
|
| 133 |
+
pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float, use_original_formulation: bool = False
|
| 134 |
+
) -> torch.Tensor:
|
| 135 |
+
cond_dtype = pred_cond.dtype
|
| 136 |
+
preds = torch.stack([pred_cond, pred_uncond], dim=1).float()
|
| 137 |
+
preds = preds.flatten(2)
|
| 138 |
+
U, S, Vh = torch.linalg.svd(preds, full_matrices=False)
|
| 139 |
+
Vh_modified = Vh.clone()
|
| 140 |
+
Vh_modified[:, 1] = 0
|
| 141 |
+
|
| 142 |
+
uncond_flat = pred_uncond.reshape(pred_uncond.size(0), 1, -1).float()
|
| 143 |
+
x_Vh = torch.matmul(uncond_flat, Vh.transpose(-2, -1))
|
| 144 |
+
x_Vh_V = torch.matmul(x_Vh, Vh_modified)
|
| 145 |
+
pred_uncond = x_Vh_V.reshape(pred_uncond.shape).to(cond_dtype)
|
| 146 |
+
|
| 147 |
+
pred = pred_cond if use_original_formulation else pred_uncond
|
| 148 |
+
shift = pred_cond - pred_uncond
|
| 149 |
+
pred = pred + guidance_scale * shift
|
| 150 |
+
|
| 151 |
+
return pred
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__init__.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from ..utils import is_torch_available
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
if is_torch_available():
|
| 19 |
+
from .context_parallel import apply_context_parallel
|
| 20 |
+
from .faster_cache import FasterCacheConfig, apply_faster_cache
|
| 21 |
+
from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
|
| 22 |
+
from .group_offloading import apply_group_offloading
|
| 23 |
+
from .hooks import HookRegistry, ModelHook
|
| 24 |
+
from .layer_skip import LayerSkipConfig, apply_layer_skip
|
| 25 |
+
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
|
| 26 |
+
from .mag_cache import MagCacheConfig, apply_mag_cache
|
| 27 |
+
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
|
| 28 |
+
from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
|
| 29 |
+
from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (1.24 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/_common.cpython-312.pyc
ADDED
|
Binary file (1.96 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/_helpers.cpython-312.pyc
ADDED
|
Binary file (11.1 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/context_parallel.cpython-312.pyc
ADDED
|
Binary file (22.8 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/faster_cache.cpython-312.pyc
ADDED
|
Binary file (35.3 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/first_block_cache.cpython-312.pyc
ADDED
|
Binary file (12.7 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/group_offloading.cpython-312.pyc
ADDED
|
Binary file (48.8 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/hooks.cpython-312.pyc
ADDED
|
Binary file (15.6 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/layer_skip.cpython-312.pyc
ADDED
|
Binary file (14.6 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/layerwise_casting.cpython-312.pyc
ADDED
|
Binary file (12.2 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/mag_cache.cpython-312.pyc
ADDED
|
Binary file (20.2 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/pyramid_attention_broadcast.cpython-312.pyc
ADDED
|
Binary file (17.2 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/smoothed_energy_guidance_utils.cpython-312.pyc
ADDED
|
Binary file (8.94 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/taylorseer_cache.cpython-312.pyc
ADDED
|
Binary file (17.8 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/__pycache__/utils.cpython-312.pyc
ADDED
|
Binary file (2.25 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/_common.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
|
| 17 |
+
from ..models.attention import AttentionModuleMixin, FeedForward, LuminaFeedForward
|
| 18 |
+
from ..models.attention_processor import Attention, MochiAttention
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
_ATTENTION_CLASSES = (Attention, MochiAttention, AttentionModuleMixin)
|
| 22 |
+
_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)
|
| 23 |
+
|
| 24 |
+
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = (
|
| 25 |
+
"blocks",
|
| 26 |
+
"transformer_blocks",
|
| 27 |
+
"single_transformer_blocks",
|
| 28 |
+
"layers",
|
| 29 |
+
"visual_transformer_blocks",
|
| 30 |
+
)
|
| 31 |
+
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
|
| 32 |
+
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")
|
| 33 |
+
|
| 34 |
+
_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
|
| 35 |
+
{
|
| 36 |
+
*_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
|
| 37 |
+
*_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
|
| 38 |
+
*_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
|
| 39 |
+
}
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
# Layers supported for group offloading and layerwise casting
|
| 43 |
+
_GO_LC_SUPPORTED_PYTORCH_LAYERS = (
|
| 44 |
+
torch.nn.Conv1d,
|
| 45 |
+
torch.nn.Conv2d,
|
| 46 |
+
torch.nn.Conv3d,
|
| 47 |
+
torch.nn.ConvTranspose1d,
|
| 48 |
+
torch.nn.ConvTranspose2d,
|
| 49 |
+
torch.nn.ConvTranspose3d,
|
| 50 |
+
torch.nn.Linear,
|
| 51 |
+
torch.nn.Embedding,
|
| 52 |
+
# TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX
|
| 53 |
+
# because of double invocation of the same norm layer in CogVideoXLayerNorm
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _get_submodule_from_fqn(module: torch.nn.Module, fqn: str) -> torch.nn.Module | None:
|
| 58 |
+
for submodule_name, submodule in module.named_modules():
|
| 59 |
+
if submodule_name == fqn:
|
| 60 |
+
return submodule
|
| 61 |
+
return None
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/_helpers.py
ADDED
|
@@ -0,0 +1,381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import inspect
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
from typing import Any, Callable, Type
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
|
| 21 |
+
class AttentionProcessorMetadata:
|
| 22 |
+
skip_processor_output_fn: Callable[[Any], Any]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
|
| 26 |
+
class TransformerBlockMetadata:
|
| 27 |
+
return_hidden_states_index: int = None
|
| 28 |
+
return_encoder_hidden_states_index: int = None
|
| 29 |
+
hidden_states_argument_name: str = "hidden_states"
|
| 30 |
+
|
| 31 |
+
_cls: Type = None
|
| 32 |
+
_cached_parameter_indices: dict[str, int] = None
|
| 33 |
+
|
| 34 |
+
def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
|
| 35 |
+
kwargs = kwargs or {}
|
| 36 |
+
if identifier in kwargs:
|
| 37 |
+
return kwargs[identifier]
|
| 38 |
+
if self._cached_parameter_indices is not None:
|
| 39 |
+
return args[self._cached_parameter_indices[identifier]]
|
| 40 |
+
if self._cls is None:
|
| 41 |
+
raise ValueError("Model class is not set for metadata.")
|
| 42 |
+
parameters = list(inspect.signature(self._cls.forward).parameters.keys())
|
| 43 |
+
parameters = parameters[1:] # skip `self`
|
| 44 |
+
self._cached_parameter_indices = {param: i for i, param in enumerate(parameters)}
|
| 45 |
+
if identifier not in self._cached_parameter_indices:
|
| 46 |
+
raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
|
| 47 |
+
index = self._cached_parameter_indices[identifier]
|
| 48 |
+
if index >= len(args):
|
| 49 |
+
raise ValueError(f"Expected {index} arguments but got {len(args)}.")
|
| 50 |
+
return args[index]
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class AttentionProcessorRegistry:
|
| 54 |
+
_registry = {}
|
| 55 |
+
# TODO(aryan): this is only required for the time being because we need to do the registrations
|
| 56 |
+
# for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
|
| 57 |
+
# import errors because of the models imported in this file.
|
| 58 |
+
_is_registered = False
|
| 59 |
+
|
| 60 |
+
@classmethod
|
| 61 |
+
def register(cls, model_class: Type, metadata: AttentionProcessorMetadata):
|
| 62 |
+
cls._register()
|
| 63 |
+
cls._registry[model_class] = metadata
|
| 64 |
+
|
| 65 |
+
@classmethod
|
| 66 |
+
def get(cls, model_class: Type) -> AttentionProcessorMetadata:
|
| 67 |
+
cls._register()
|
| 68 |
+
if model_class not in cls._registry:
|
| 69 |
+
raise ValueError(f"Model class {model_class} not registered.")
|
| 70 |
+
return cls._registry[model_class]
|
| 71 |
+
|
| 72 |
+
@classmethod
|
| 73 |
+
def _register(cls):
|
| 74 |
+
if cls._is_registered:
|
| 75 |
+
return
|
| 76 |
+
cls._is_registered = True
|
| 77 |
+
_register_attention_processors_metadata()
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class TransformerBlockRegistry:
|
| 81 |
+
_registry = {}
|
| 82 |
+
# TODO(aryan): this is only required for the time being because we need to do the registrations
|
| 83 |
+
# for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
|
| 84 |
+
# import errors because of the models imported in this file.
|
| 85 |
+
_is_registered = False
|
| 86 |
+
|
| 87 |
+
@classmethod
|
| 88 |
+
def register(cls, model_class: Type, metadata: TransformerBlockMetadata):
|
| 89 |
+
cls._register()
|
| 90 |
+
metadata._cls = model_class
|
| 91 |
+
cls._registry[model_class] = metadata
|
| 92 |
+
|
| 93 |
+
@classmethod
|
| 94 |
+
def get(cls, model_class: Type) -> TransformerBlockMetadata:
|
| 95 |
+
cls._register()
|
| 96 |
+
if model_class not in cls._registry:
|
| 97 |
+
raise ValueError(f"Model class {model_class} not registered.")
|
| 98 |
+
return cls._registry[model_class]
|
| 99 |
+
|
| 100 |
+
@classmethod
|
| 101 |
+
def _register(cls):
|
| 102 |
+
if cls._is_registered:
|
| 103 |
+
return
|
| 104 |
+
cls._is_registered = True
|
| 105 |
+
_register_transformer_blocks_metadata()
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def _register_attention_processors_metadata():
|
| 109 |
+
from ..models.attention_processor import AttnProcessor2_0
|
| 110 |
+
from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
|
| 111 |
+
from ..models.transformers.transformer_flux import FluxAttnProcessor
|
| 112 |
+
from ..models.transformers.transformer_hunyuanimage import HunyuanImageAttnProcessor
|
| 113 |
+
from ..models.transformers.transformer_qwenimage import QwenDoubleStreamAttnProcessor2_0
|
| 114 |
+
from ..models.transformers.transformer_wan import WanAttnProcessor2_0
|
| 115 |
+
from ..models.transformers.transformer_z_image import ZSingleStreamAttnProcessor
|
| 116 |
+
|
| 117 |
+
# AttnProcessor2_0
|
| 118 |
+
AttentionProcessorRegistry.register(
|
| 119 |
+
model_class=AttnProcessor2_0,
|
| 120 |
+
metadata=AttentionProcessorMetadata(
|
| 121 |
+
skip_processor_output_fn=_skip_proc_output_fn_Attention_AttnProcessor2_0,
|
| 122 |
+
),
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
# CogView4AttnProcessor
|
| 126 |
+
AttentionProcessorRegistry.register(
|
| 127 |
+
model_class=CogView4AttnProcessor,
|
| 128 |
+
metadata=AttentionProcessorMetadata(
|
| 129 |
+
skip_processor_output_fn=_skip_proc_output_fn_Attention_CogView4AttnProcessor,
|
| 130 |
+
),
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
# WanAttnProcessor2_0
|
| 134 |
+
AttentionProcessorRegistry.register(
|
| 135 |
+
model_class=WanAttnProcessor2_0,
|
| 136 |
+
metadata=AttentionProcessorMetadata(
|
| 137 |
+
skip_processor_output_fn=_skip_proc_output_fn_Attention_WanAttnProcessor2_0,
|
| 138 |
+
),
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
# FluxAttnProcessor
|
| 142 |
+
AttentionProcessorRegistry.register(
|
| 143 |
+
model_class=FluxAttnProcessor,
|
| 144 |
+
metadata=AttentionProcessorMetadata(skip_processor_output_fn=_skip_proc_output_fn_Attention_FluxAttnProcessor),
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
# QwenDoubleStreamAttnProcessor2
|
| 148 |
+
AttentionProcessorRegistry.register(
|
| 149 |
+
model_class=QwenDoubleStreamAttnProcessor2_0,
|
| 150 |
+
metadata=AttentionProcessorMetadata(
|
| 151 |
+
skip_processor_output_fn=_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0
|
| 152 |
+
),
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
# HunyuanImageAttnProcessor
|
| 156 |
+
AttentionProcessorRegistry.register(
|
| 157 |
+
model_class=HunyuanImageAttnProcessor,
|
| 158 |
+
metadata=AttentionProcessorMetadata(
|
| 159 |
+
skip_processor_output_fn=_skip_proc_output_fn_Attention_HunyuanImageAttnProcessor,
|
| 160 |
+
),
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
# ZSingleStreamAttnProcessor
|
| 164 |
+
AttentionProcessorRegistry.register(
|
| 165 |
+
model_class=ZSingleStreamAttnProcessor,
|
| 166 |
+
metadata=AttentionProcessorMetadata(
|
| 167 |
+
skip_processor_output_fn=_skip_proc_output_fn_Attention_ZSingleStreamAttnProcessor,
|
| 168 |
+
),
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def _register_transformer_blocks_metadata():
|
| 173 |
+
from ..models.attention import BasicTransformerBlock, JointTransformerBlock
|
| 174 |
+
from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
|
| 175 |
+
from ..models.transformers.transformer_bria import BriaTransformerBlock
|
| 176 |
+
from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
|
| 177 |
+
from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
|
| 178 |
+
from ..models.transformers.transformer_hunyuan_video import (
|
| 179 |
+
HunyuanVideoSingleTransformerBlock,
|
| 180 |
+
HunyuanVideoTokenReplaceSingleTransformerBlock,
|
| 181 |
+
HunyuanVideoTokenReplaceTransformerBlock,
|
| 182 |
+
HunyuanVideoTransformerBlock,
|
| 183 |
+
)
|
| 184 |
+
from ..models.transformers.transformer_hunyuanimage import (
|
| 185 |
+
HunyuanImageSingleTransformerBlock,
|
| 186 |
+
HunyuanImageTransformerBlock,
|
| 187 |
+
)
|
| 188 |
+
from ..models.transformers.transformer_kandinsky import Kandinsky5TransformerDecoderBlock
|
| 189 |
+
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
|
| 190 |
+
from ..models.transformers.transformer_mochi import MochiTransformerBlock
|
| 191 |
+
from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
|
| 192 |
+
from ..models.transformers.transformer_wan import WanTransformerBlock
|
| 193 |
+
from ..models.transformers.transformer_z_image import ZImageTransformerBlock
|
| 194 |
+
|
| 195 |
+
# BasicTransformerBlock
|
| 196 |
+
TransformerBlockRegistry.register(
|
| 197 |
+
model_class=BasicTransformerBlock,
|
| 198 |
+
metadata=TransformerBlockMetadata(
|
| 199 |
+
return_hidden_states_index=0,
|
| 200 |
+
return_encoder_hidden_states_index=None,
|
| 201 |
+
),
|
| 202 |
+
)
|
| 203 |
+
TransformerBlockRegistry.register(
|
| 204 |
+
model_class=BriaTransformerBlock,
|
| 205 |
+
metadata=TransformerBlockMetadata(
|
| 206 |
+
return_hidden_states_index=0,
|
| 207 |
+
return_encoder_hidden_states_index=None,
|
| 208 |
+
),
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
# CogVideoX
|
| 212 |
+
TransformerBlockRegistry.register(
|
| 213 |
+
model_class=CogVideoXBlock,
|
| 214 |
+
metadata=TransformerBlockMetadata(
|
| 215 |
+
return_hidden_states_index=0,
|
| 216 |
+
return_encoder_hidden_states_index=1,
|
| 217 |
+
),
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
# CogView4
|
| 221 |
+
TransformerBlockRegistry.register(
|
| 222 |
+
model_class=CogView4TransformerBlock,
|
| 223 |
+
metadata=TransformerBlockMetadata(
|
| 224 |
+
return_hidden_states_index=0,
|
| 225 |
+
return_encoder_hidden_states_index=1,
|
| 226 |
+
),
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
# Flux
|
| 230 |
+
TransformerBlockRegistry.register(
|
| 231 |
+
model_class=FluxTransformerBlock,
|
| 232 |
+
metadata=TransformerBlockMetadata(
|
| 233 |
+
return_hidden_states_index=1,
|
| 234 |
+
return_encoder_hidden_states_index=0,
|
| 235 |
+
),
|
| 236 |
+
)
|
| 237 |
+
TransformerBlockRegistry.register(
|
| 238 |
+
model_class=FluxSingleTransformerBlock,
|
| 239 |
+
metadata=TransformerBlockMetadata(
|
| 240 |
+
return_hidden_states_index=1,
|
| 241 |
+
return_encoder_hidden_states_index=0,
|
| 242 |
+
),
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
# HunyuanVideo
|
| 246 |
+
TransformerBlockRegistry.register(
|
| 247 |
+
model_class=HunyuanVideoTransformerBlock,
|
| 248 |
+
metadata=TransformerBlockMetadata(
|
| 249 |
+
return_hidden_states_index=0,
|
| 250 |
+
return_encoder_hidden_states_index=1,
|
| 251 |
+
),
|
| 252 |
+
)
|
| 253 |
+
TransformerBlockRegistry.register(
|
| 254 |
+
model_class=HunyuanVideoSingleTransformerBlock,
|
| 255 |
+
metadata=TransformerBlockMetadata(
|
| 256 |
+
return_hidden_states_index=0,
|
| 257 |
+
return_encoder_hidden_states_index=1,
|
| 258 |
+
),
|
| 259 |
+
)
|
| 260 |
+
TransformerBlockRegistry.register(
|
| 261 |
+
model_class=HunyuanVideoTokenReplaceTransformerBlock,
|
| 262 |
+
metadata=TransformerBlockMetadata(
|
| 263 |
+
return_hidden_states_index=0,
|
| 264 |
+
return_encoder_hidden_states_index=1,
|
| 265 |
+
),
|
| 266 |
+
)
|
| 267 |
+
TransformerBlockRegistry.register(
|
| 268 |
+
model_class=HunyuanVideoTokenReplaceSingleTransformerBlock,
|
| 269 |
+
metadata=TransformerBlockMetadata(
|
| 270 |
+
return_hidden_states_index=0,
|
| 271 |
+
return_encoder_hidden_states_index=1,
|
| 272 |
+
),
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
# LTXVideo
|
| 276 |
+
TransformerBlockRegistry.register(
|
| 277 |
+
model_class=LTXVideoTransformerBlock,
|
| 278 |
+
metadata=TransformerBlockMetadata(
|
| 279 |
+
return_hidden_states_index=0,
|
| 280 |
+
return_encoder_hidden_states_index=None,
|
| 281 |
+
),
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
# Mochi
|
| 285 |
+
TransformerBlockRegistry.register(
|
| 286 |
+
model_class=MochiTransformerBlock,
|
| 287 |
+
metadata=TransformerBlockMetadata(
|
| 288 |
+
return_hidden_states_index=0,
|
| 289 |
+
return_encoder_hidden_states_index=1,
|
| 290 |
+
),
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
# Wan
|
| 294 |
+
TransformerBlockRegistry.register(
|
| 295 |
+
model_class=WanTransformerBlock,
|
| 296 |
+
metadata=TransformerBlockMetadata(
|
| 297 |
+
return_hidden_states_index=0,
|
| 298 |
+
return_encoder_hidden_states_index=None,
|
| 299 |
+
),
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
# QwenImage
|
| 303 |
+
TransformerBlockRegistry.register(
|
| 304 |
+
model_class=QwenImageTransformerBlock,
|
| 305 |
+
metadata=TransformerBlockMetadata(
|
| 306 |
+
return_hidden_states_index=1,
|
| 307 |
+
return_encoder_hidden_states_index=0,
|
| 308 |
+
),
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
# HunyuanImage2.1
|
| 312 |
+
TransformerBlockRegistry.register(
|
| 313 |
+
model_class=HunyuanImageTransformerBlock,
|
| 314 |
+
metadata=TransformerBlockMetadata(
|
| 315 |
+
return_hidden_states_index=0,
|
| 316 |
+
return_encoder_hidden_states_index=1,
|
| 317 |
+
),
|
| 318 |
+
)
|
| 319 |
+
TransformerBlockRegistry.register(
|
| 320 |
+
model_class=HunyuanImageSingleTransformerBlock,
|
| 321 |
+
metadata=TransformerBlockMetadata(
|
| 322 |
+
return_hidden_states_index=0,
|
| 323 |
+
return_encoder_hidden_states_index=1,
|
| 324 |
+
),
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
# ZImage
|
| 328 |
+
TransformerBlockRegistry.register(
|
| 329 |
+
model_class=ZImageTransformerBlock,
|
| 330 |
+
metadata=TransformerBlockMetadata(
|
| 331 |
+
return_hidden_states_index=0,
|
| 332 |
+
return_encoder_hidden_states_index=None,
|
| 333 |
+
),
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
TransformerBlockRegistry.register(
|
| 337 |
+
model_class=JointTransformerBlock,
|
| 338 |
+
metadata=TransformerBlockMetadata(
|
| 339 |
+
return_hidden_states_index=1,
|
| 340 |
+
return_encoder_hidden_states_index=0,
|
| 341 |
+
),
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
# Kandinsky 5.0 (Kandinsky5TransformerDecoderBlock)
|
| 345 |
+
TransformerBlockRegistry.register(
|
| 346 |
+
model_class=Kandinsky5TransformerDecoderBlock,
|
| 347 |
+
metadata=TransformerBlockMetadata(
|
| 348 |
+
return_hidden_states_index=0,
|
| 349 |
+
return_encoder_hidden_states_index=None,
|
| 350 |
+
hidden_states_argument_name="visual_embed",
|
| 351 |
+
),
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
# fmt: off
|
| 356 |
+
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
|
| 357 |
+
hidden_states = kwargs.get("hidden_states", None)
|
| 358 |
+
if hidden_states is None and len(args) > 0:
|
| 359 |
+
hidden_states = args[0]
|
| 360 |
+
return hidden_states
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
|
| 364 |
+
hidden_states = kwargs.get("hidden_states", None)
|
| 365 |
+
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
|
| 366 |
+
if hidden_states is None and len(args) > 0:
|
| 367 |
+
hidden_states = args[0]
|
| 368 |
+
if encoder_hidden_states is None and len(args) > 1:
|
| 369 |
+
encoder_hidden_states = args[1]
|
| 370 |
+
return hidden_states, encoder_hidden_states
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
|
| 374 |
+
_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
|
| 375 |
+
_skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hidden_states
|
| 376 |
+
# not sure what this is yet.
|
| 377 |
+
_skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
|
| 378 |
+
_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0 = _skip_attention___ret___hidden_states
|
| 379 |
+
_skip_proc_output_fn_Attention_HunyuanImageAttnProcessor = _skip_attention___ret___hidden_states
|
| 380 |
+
_skip_proc_output_fn_Attention_ZSingleStreamAttnProcessor = _skip_attention___ret___hidden_states
|
| 381 |
+
# fmt: on
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/context_parallel.py
ADDED
|
@@ -0,0 +1,383 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import copy
|
| 15 |
+
import functools
|
| 16 |
+
import inspect
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
from typing import Type
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.distributed as dist
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
if torch.distributed.is_available():
|
| 25 |
+
import torch.distributed._functional_collectives as funcol
|
| 26 |
+
|
| 27 |
+
from ..models._modeling_parallel import (
|
| 28 |
+
ContextParallelConfig,
|
| 29 |
+
ContextParallelInput,
|
| 30 |
+
ContextParallelModelPlan,
|
| 31 |
+
ContextParallelOutput,
|
| 32 |
+
gather_size_by_comm,
|
| 33 |
+
)
|
| 34 |
+
from ..utils import get_logger
|
| 35 |
+
from ..utils.torch_utils import maybe_allow_in_graph, unwrap_module
|
| 36 |
+
from .hooks import HookRegistry, ModelHook
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
logger = get_logger(__name__) # pylint: disable=invalid-name
|
| 40 |
+
|
| 41 |
+
_CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE = "cp_input---{}"
|
| 42 |
+
_CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE = "cp_output---{}"
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# TODO(aryan): consolidate with ._helpers.TransformerBlockMetadata
|
| 46 |
+
@dataclass
|
| 47 |
+
class ModuleForwardMetadata:
|
| 48 |
+
cached_parameter_indices: dict[str, int] = None
|
| 49 |
+
_cls: Type = None
|
| 50 |
+
|
| 51 |
+
def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
|
| 52 |
+
kwargs = kwargs or {}
|
| 53 |
+
|
| 54 |
+
if identifier in kwargs:
|
| 55 |
+
return kwargs[identifier], True, None
|
| 56 |
+
|
| 57 |
+
if self.cached_parameter_indices is not None:
|
| 58 |
+
index = self.cached_parameter_indices.get(identifier, None)
|
| 59 |
+
if index is None:
|
| 60 |
+
raise ValueError(f"Parameter '{identifier}' not found in cached indices.")
|
| 61 |
+
return args[index], False, index
|
| 62 |
+
|
| 63 |
+
if self._cls is None:
|
| 64 |
+
raise ValueError("Model class is not set for metadata.")
|
| 65 |
+
|
| 66 |
+
parameters = list(inspect.signature(self._cls.forward).parameters.keys())
|
| 67 |
+
parameters = parameters[1:] # skip `self`
|
| 68 |
+
self.cached_parameter_indices = {param: i for i, param in enumerate(parameters)}
|
| 69 |
+
|
| 70 |
+
if identifier not in self.cached_parameter_indices:
|
| 71 |
+
raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
|
| 72 |
+
|
| 73 |
+
index = self.cached_parameter_indices[identifier]
|
| 74 |
+
|
| 75 |
+
if index >= len(args):
|
| 76 |
+
raise ValueError(f"Expected {index} arguments but got {len(args)}.")
|
| 77 |
+
|
| 78 |
+
return args[index], False, index
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def apply_context_parallel(
|
| 82 |
+
module: torch.nn.Module,
|
| 83 |
+
parallel_config: ContextParallelConfig,
|
| 84 |
+
plan: dict[str, ContextParallelModelPlan],
|
| 85 |
+
) -> None:
|
| 86 |
+
"""Apply context parallel on a model."""
|
| 87 |
+
logger.debug(f"Applying context parallel with CP mesh: {parallel_config._mesh} and plan: {plan}")
|
| 88 |
+
|
| 89 |
+
for module_id, cp_model_plan in plan.items():
|
| 90 |
+
submodule = _get_submodule_by_name(module, module_id)
|
| 91 |
+
if not isinstance(submodule, list):
|
| 92 |
+
submodule = [submodule]
|
| 93 |
+
|
| 94 |
+
logger.debug(f"Applying ContextParallelHook to {module_id=} identifying a total of {len(submodule)} modules")
|
| 95 |
+
|
| 96 |
+
for m in submodule:
|
| 97 |
+
if isinstance(cp_model_plan, dict):
|
| 98 |
+
hook = ContextParallelSplitHook(cp_model_plan, parallel_config)
|
| 99 |
+
hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
|
| 100 |
+
elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
|
| 101 |
+
if isinstance(cp_model_plan, ContextParallelOutput):
|
| 102 |
+
cp_model_plan = [cp_model_plan]
|
| 103 |
+
if not all(isinstance(x, ContextParallelOutput) for x in cp_model_plan):
|
| 104 |
+
raise ValueError(f"Expected all elements of cp_model_plan to be CPOutput, but got {cp_model_plan}")
|
| 105 |
+
hook = ContextParallelGatherHook(cp_model_plan, parallel_config)
|
| 106 |
+
hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
|
| 107 |
+
else:
|
| 108 |
+
raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
|
| 109 |
+
registry = HookRegistry.check_if_exists_or_initialize(m)
|
| 110 |
+
registry.register_hook(hook, hook_name)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def remove_context_parallel(module: torch.nn.Module, plan: dict[str, ContextParallelModelPlan]) -> None:
|
| 114 |
+
for module_id, cp_model_plan in plan.items():
|
| 115 |
+
submodule = _get_submodule_by_name(module, module_id)
|
| 116 |
+
if not isinstance(submodule, list):
|
| 117 |
+
submodule = [submodule]
|
| 118 |
+
|
| 119 |
+
for m in submodule:
|
| 120 |
+
registry = HookRegistry.check_if_exists_or_initialize(m)
|
| 121 |
+
if isinstance(cp_model_plan, dict):
|
| 122 |
+
hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
|
| 123 |
+
elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
|
| 124 |
+
hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
|
| 125 |
+
else:
|
| 126 |
+
raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
|
| 127 |
+
registry.remove_hook(hook_name)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class ContextParallelSplitHook(ModelHook):
|
| 131 |
+
def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None:
|
| 132 |
+
super().__init__()
|
| 133 |
+
self.metadata = metadata
|
| 134 |
+
self.parallel_config = parallel_config
|
| 135 |
+
self.module_forward_metadata = None
|
| 136 |
+
|
| 137 |
+
def initialize_hook(self, module):
|
| 138 |
+
cls = unwrap_module(module).__class__
|
| 139 |
+
self.module_forward_metadata = ModuleForwardMetadata(_cls=cls)
|
| 140 |
+
return module
|
| 141 |
+
|
| 142 |
+
def pre_forward(self, module, *args, **kwargs):
|
| 143 |
+
args_list = list(args)
|
| 144 |
+
|
| 145 |
+
for name, cpm in self.metadata.items():
|
| 146 |
+
if isinstance(cpm, ContextParallelInput) and cpm.split_output:
|
| 147 |
+
continue
|
| 148 |
+
|
| 149 |
+
# Maybe the parameter was passed as a keyword argument
|
| 150 |
+
input_val, is_kwarg, index = self.module_forward_metadata._get_parameter_from_args_kwargs(
|
| 151 |
+
name, args_list, kwargs
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
if input_val is None:
|
| 155 |
+
continue
|
| 156 |
+
|
| 157 |
+
# The input_val may be a tensor or list/tuple of tensors. In certain cases, user may specify to shard
|
| 158 |
+
# the output instead of input for a particular layer by setting split_output=True
|
| 159 |
+
if isinstance(input_val, torch.Tensor):
|
| 160 |
+
input_val = self._prepare_cp_input(input_val, cpm)
|
| 161 |
+
elif isinstance(input_val, (list, tuple)):
|
| 162 |
+
if len(input_val) != len(cpm):
|
| 163 |
+
raise ValueError(
|
| 164 |
+
f"Expected input model plan to have {len(input_val)} elements, but got {len(cpm)}."
|
| 165 |
+
)
|
| 166 |
+
sharded_input_val = []
|
| 167 |
+
for i, x in enumerate(input_val):
|
| 168 |
+
if torch.is_tensor(x) and not cpm[i].split_output:
|
| 169 |
+
x = self._prepare_cp_input(x, cpm[i])
|
| 170 |
+
sharded_input_val.append(x)
|
| 171 |
+
input_val = sharded_input_val
|
| 172 |
+
else:
|
| 173 |
+
raise ValueError(f"Unsupported input type: {type(input_val)}")
|
| 174 |
+
|
| 175 |
+
if is_kwarg:
|
| 176 |
+
kwargs[name] = input_val
|
| 177 |
+
elif index is not None and index < len(args_list):
|
| 178 |
+
args_list[index] = input_val
|
| 179 |
+
else:
|
| 180 |
+
raise ValueError(
|
| 181 |
+
f"An unexpected error occurred while processing the input '{name}'. Please open an "
|
| 182 |
+
f"issue at https://github.com/huggingface/diffusers/issues and provide a minimal reproducible "
|
| 183 |
+
f"example along with the full stack trace."
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
return tuple(args_list), kwargs
|
| 187 |
+
|
| 188 |
+
def post_forward(self, module, output):
|
| 189 |
+
is_tensor = isinstance(output, torch.Tensor)
|
| 190 |
+
is_tensor_list = isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)
|
| 191 |
+
|
| 192 |
+
if not is_tensor and not is_tensor_list:
|
| 193 |
+
raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")
|
| 194 |
+
|
| 195 |
+
output = [output] if is_tensor else list(output)
|
| 196 |
+
for index, cpm in self.metadata.items():
|
| 197 |
+
if not isinstance(cpm, ContextParallelInput) or not cpm.split_output:
|
| 198 |
+
continue
|
| 199 |
+
if index >= len(output):
|
| 200 |
+
raise ValueError(f"Index {index} out of bounds for output of length {len(output)}.")
|
| 201 |
+
current_output = output[index]
|
| 202 |
+
current_output = self._prepare_cp_input(current_output, cpm)
|
| 203 |
+
output[index] = current_output
|
| 204 |
+
|
| 205 |
+
return output[0] if is_tensor else tuple(output)
|
| 206 |
+
|
| 207 |
+
def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
|
| 208 |
+
if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
|
| 209 |
+
logger.warning_once(
|
| 210 |
+
f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions, split will not be applied."
|
| 211 |
+
)
|
| 212 |
+
return x
|
| 213 |
+
else:
|
| 214 |
+
if self.parallel_config.ulysses_anything:
|
| 215 |
+
return PartitionAnythingSharder.shard_anything(
|
| 216 |
+
x, cp_input.split_dim, self.parallel_config._flattened_mesh
|
| 217 |
+
)
|
| 218 |
+
return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
class ContextParallelGatherHook(ModelHook):
|
| 222 |
+
def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None:
|
| 223 |
+
super().__init__()
|
| 224 |
+
self.metadata = metadata
|
| 225 |
+
self.parallel_config = parallel_config
|
| 226 |
+
|
| 227 |
+
def post_forward(self, module, output):
|
| 228 |
+
is_tensor = isinstance(output, torch.Tensor)
|
| 229 |
+
|
| 230 |
+
if is_tensor:
|
| 231 |
+
output = [output]
|
| 232 |
+
elif not (isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)):
|
| 233 |
+
raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")
|
| 234 |
+
|
| 235 |
+
output = list(output)
|
| 236 |
+
|
| 237 |
+
if len(output) != len(self.metadata):
|
| 238 |
+
raise ValueError(f"Expected output to have {len(self.metadata)} elements, but got {len(output)}.")
|
| 239 |
+
|
| 240 |
+
for i, cpm in enumerate(self.metadata):
|
| 241 |
+
if cpm is None:
|
| 242 |
+
continue
|
| 243 |
+
if self.parallel_config.ulysses_anything:
|
| 244 |
+
output[i] = PartitionAnythingSharder.unshard_anything(
|
| 245 |
+
output[i], cpm.gather_dim, self.parallel_config._flattened_mesh
|
| 246 |
+
)
|
| 247 |
+
else:
|
| 248 |
+
output[i] = EquipartitionSharder.unshard(
|
| 249 |
+
output[i], cpm.gather_dim, self.parallel_config._flattened_mesh
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
return output[0] if is_tensor else tuple(output)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
class AllGatherFunction(torch.autograd.Function):
|
| 256 |
+
@staticmethod
|
| 257 |
+
def forward(ctx, tensor, dim, group):
|
| 258 |
+
ctx.dim = dim
|
| 259 |
+
ctx.group = group
|
| 260 |
+
ctx.world_size = torch.distributed.get_world_size(group)
|
| 261 |
+
ctx.rank = torch.distributed.get_rank(group)
|
| 262 |
+
return funcol.all_gather_tensor(tensor, dim, group=group)
|
| 263 |
+
|
| 264 |
+
@staticmethod
|
| 265 |
+
def backward(ctx, grad_output):
|
| 266 |
+
grad_chunks = torch.chunk(grad_output, ctx.world_size, dim=ctx.dim)
|
| 267 |
+
return grad_chunks[ctx.rank], None, None
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
class EquipartitionSharder:
|
| 271 |
+
@classmethod
|
| 272 |
+
def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
|
| 273 |
+
# NOTE: the following assertion does not have to be true in general. We simply enforce it for now
|
| 274 |
+
# because the alternate case has not yet been tested/required for any model.
|
| 275 |
+
assert tensor.size()[dim] % mesh.size() == 0, (
|
| 276 |
+
"Tensor size along dimension to be sharded must be divisible by mesh size"
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
# The following is not fullgraph compatible with Dynamo (fails in DeviceMesh.get_rank)
|
| 280 |
+
# return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()]
|
| 281 |
+
|
| 282 |
+
return tensor.chunk(mesh.size(), dim=dim)[torch.distributed.get_rank(mesh.get_group())]
|
| 283 |
+
|
| 284 |
+
@classmethod
|
| 285 |
+
def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
|
| 286 |
+
tensor = tensor.contiguous()
|
| 287 |
+
tensor = AllGatherFunction.apply(tensor, dim, mesh.get_group())
|
| 288 |
+
return tensor
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
class AllGatherAnythingFunction(torch.autograd.Function):
|
| 292 |
+
@staticmethod
|
| 293 |
+
def forward(ctx, tensor: torch.Tensor, dim: int, group: dist.device_mesh.DeviceMesh):
|
| 294 |
+
ctx.dim = dim
|
| 295 |
+
ctx.group = group
|
| 296 |
+
ctx.world_size = dist.get_world_size(group)
|
| 297 |
+
ctx.rank = dist.get_rank(group)
|
| 298 |
+
gathered_tensor = _all_gather_anything(tensor, dim, group)
|
| 299 |
+
return gathered_tensor
|
| 300 |
+
|
| 301 |
+
@staticmethod
|
| 302 |
+
def backward(ctx, grad_output):
|
| 303 |
+
# NOTE: We use `tensor_split` instead of chunk, because the `chunk`
|
| 304 |
+
# function may return fewer than the specified number of chunks!
|
| 305 |
+
grad_splits = torch.tensor_split(grad_output, ctx.world_size, dim=ctx.dim)
|
| 306 |
+
return grad_splits[ctx.rank], None, None
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
class PartitionAnythingSharder:
|
| 310 |
+
@classmethod
|
| 311 |
+
def shard_anything(
|
| 312 |
+
cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh
|
| 313 |
+
) -> torch.Tensor:
|
| 314 |
+
assert tensor.size()[dim] >= mesh.size(), (
|
| 315 |
+
f"Cannot shard tensor of size {tensor.size()} along dim {dim} across mesh of size {mesh.size()}."
|
| 316 |
+
)
|
| 317 |
+
# NOTE: We use `tensor_split` instead of chunk, because the `chunk`
|
| 318 |
+
# function may return fewer than the specified number of chunks!
|
| 319 |
+
return tensor.tensor_split(mesh.size(), dim=dim)[dist.get_rank(mesh.get_group())]
|
| 320 |
+
|
| 321 |
+
@classmethod
|
| 322 |
+
def unshard_anything(
|
| 323 |
+
cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh
|
| 324 |
+
) -> torch.Tensor:
|
| 325 |
+
tensor = tensor.contiguous()
|
| 326 |
+
tensor = AllGatherAnythingFunction.apply(tensor, dim, mesh.get_group())
|
| 327 |
+
return tensor
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
@functools.lru_cache(maxsize=64)
|
| 331 |
+
def _fill_gather_shapes(shape: tuple[int], gather_dims: tuple[int], dim: int, world_size: int) -> list[list[int]]:
|
| 332 |
+
gather_shapes = []
|
| 333 |
+
for i in range(world_size):
|
| 334 |
+
rank_shape = list(copy.deepcopy(shape))
|
| 335 |
+
rank_shape[dim] = gather_dims[i]
|
| 336 |
+
gather_shapes.append(rank_shape)
|
| 337 |
+
return gather_shapes
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
@maybe_allow_in_graph
|
| 341 |
+
def _all_gather_anything(tensor: torch.Tensor, dim: int, group: dist.device_mesh.DeviceMesh) -> torch.Tensor:
|
| 342 |
+
world_size = dist.get_world_size(group=group)
|
| 343 |
+
|
| 344 |
+
tensor = tensor.contiguous()
|
| 345 |
+
shape = tensor.shape
|
| 346 |
+
rank_dim = shape[dim]
|
| 347 |
+
gather_dims = gather_size_by_comm(rank_dim, group)
|
| 348 |
+
|
| 349 |
+
gather_shapes = _fill_gather_shapes(tuple(shape), tuple(gather_dims), dim, world_size)
|
| 350 |
+
|
| 351 |
+
gathered_tensors = [torch.empty(shape, device=tensor.device, dtype=tensor.dtype) for shape in gather_shapes]
|
| 352 |
+
|
| 353 |
+
dist.all_gather(gathered_tensors, tensor, group=group)
|
| 354 |
+
gathered_tensor = torch.cat(gathered_tensors, dim=dim)
|
| 355 |
+
return gathered_tensor
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def _get_submodule_by_name(model: torch.nn.Module, name: str) -> torch.nn.Module | list[torch.nn.Module]:
|
| 359 |
+
if name.count("*") > 1:
|
| 360 |
+
raise ValueError("Wildcard '*' can only be used once in the name")
|
| 361 |
+
return _find_submodule_by_name(model, name)
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def _find_submodule_by_name(model: torch.nn.Module, name: str) -> torch.nn.Module | list[torch.nn.Module]:
|
| 365 |
+
if name == "":
|
| 366 |
+
return model
|
| 367 |
+
first_atom, remaining_name = name.split(".", 1) if "." in name else (name, "")
|
| 368 |
+
if first_atom == "*":
|
| 369 |
+
if not isinstance(model, torch.nn.ModuleList):
|
| 370 |
+
raise ValueError("Wildcard '*' can only be used with ModuleList")
|
| 371 |
+
submodules = []
|
| 372 |
+
for submodule in model:
|
| 373 |
+
subsubmodules = _find_submodule_by_name(submodule, remaining_name)
|
| 374 |
+
if not isinstance(subsubmodules, list):
|
| 375 |
+
subsubmodules = [subsubmodules]
|
| 376 |
+
submodules.extend(subsubmodules)
|
| 377 |
+
return submodules
|
| 378 |
+
else:
|
| 379 |
+
if hasattr(model, first_atom):
|
| 380 |
+
submodule = getattr(model, first_atom)
|
| 381 |
+
return _find_submodule_by_name(submodule, remaining_name)
|
| 382 |
+
else:
|
| 383 |
+
raise ValueError(f"'{first_atom}' is not a submodule of '{model.__class__.__name__}'")
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/faster_cache.py
ADDED
|
@@ -0,0 +1,654 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import re
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
from typing import Any, Callable
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
|
| 21 |
+
from ..models.attention import AttentionModuleMixin
|
| 22 |
+
from ..models.modeling_outputs import Transformer2DModelOutput
|
| 23 |
+
from ..utils import logging
|
| 24 |
+
from ._common import _ATTENTION_CLASSES
|
| 25 |
+
from .hooks import HookRegistry, ModelHook
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
_FASTER_CACHE_DENOISER_HOOK = "faster_cache_denoiser"
|
| 32 |
+
_FASTER_CACHE_BLOCK_HOOK = "faster_cache_block"
|
| 33 |
+
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
|
| 34 |
+
"^blocks.*attn",
|
| 35 |
+
"^transformer_blocks.*attn",
|
| 36 |
+
"^single_transformer_blocks.*attn",
|
| 37 |
+
)
|
| 38 |
+
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("^temporal_transformer_blocks.*attn",)
|
| 39 |
+
_TRANSFORMER_BLOCK_IDENTIFIERS = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
|
| 40 |
+
_UNCOND_COND_INPUT_KWARGS_IDENTIFIERS = (
|
| 41 |
+
"hidden_states",
|
| 42 |
+
"encoder_hidden_states",
|
| 43 |
+
"timestep",
|
| 44 |
+
"attention_mask",
|
| 45 |
+
"encoder_attention_mask",
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@dataclass
|
| 50 |
+
class FasterCacheConfig:
|
| 51 |
+
r"""
|
| 52 |
+
Configuration for [FasterCache](https://huggingface.co/papers/2410.19355).
|
| 53 |
+
|
| 54 |
+
Attributes:
|
| 55 |
+
spatial_attention_block_skip_range (`int`, defaults to `2`):
|
| 56 |
+
Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will
|
| 57 |
+
be skipped `N - 1` times (i.e., cached attention states will be reused) before computing the new attention
|
| 58 |
+
states again.
|
| 59 |
+
temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`):
|
| 60 |
+
Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will
|
| 61 |
+
be skipped `N - 1` times (i.e., cached attention states will be reused) before computing the new attention
|
| 62 |
+
states again.
|
| 63 |
+
spatial_attention_timestep_skip_range (`tuple[float, float]`, defaults to `(-1, 681)`):
|
| 64 |
+
The timestep range within which the spatial attention computation can be skipped without a significant loss
|
| 65 |
+
in quality. This is to be determined by the user based on the underlying model. The first value in the
|
| 66 |
+
tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for
|
| 67 |
+
denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at
|
| 68 |
+
timestep 0). For the default values, this would mean that the spatial attention computation skipping will
|
| 69 |
+
be applicable only after denoising timestep 681 is reached, and continue until the end of the denoising
|
| 70 |
+
process.
|
| 71 |
+
temporal_attention_timestep_skip_range (`tuple[float, float]`, *optional*, defaults to `None`):
|
| 72 |
+
The timestep range within which the temporal attention computation can be skipped without a significant
|
| 73 |
+
loss in quality. This is to be determined by the user based on the underlying model. The first value in the
|
| 74 |
+
tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for
|
| 75 |
+
denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at
|
| 76 |
+
timestep 0).
|
| 77 |
+
low_frequency_weight_update_timestep_range (`tuple[int, int]`, defaults to `(99, 901)`):
|
| 78 |
+
The timestep range within which the low frequency weight scaling update is applied. The first value in the
|
| 79 |
+
tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
|
| 80 |
+
function for the update is called only within this range.
|
| 81 |
+
high_frequency_weight_update_timestep_range (`tuple[int, int]`, defaults to `(-1, 301)`):
|
| 82 |
+
The timestep range within which the high frequency weight scaling update is applied. The first value in the
|
| 83 |
+
tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
|
| 84 |
+
function for the update is called only within this range.
|
| 85 |
+
alpha_low_frequency (`float`, defaults to `1.1`):
|
| 86 |
+
The weight to scale the low frequency updates by. This is used to approximate the unconditional branch from
|
| 87 |
+
the conditional branch outputs.
|
| 88 |
+
alpha_high_frequency (`float`, defaults to `1.1`):
|
| 89 |
+
The weight to scale the high frequency updates by. This is used to approximate the unconditional branch
|
| 90 |
+
from the conditional branch outputs.
|
| 91 |
+
unconditional_batch_skip_range (`int`, defaults to `5`):
|
| 92 |
+
Process the unconditional branch every `N` iterations. If this is set to `N`, the unconditional branch
|
| 93 |
+
computation will be skipped `N - 1` times (i.e., cached unconditional branch states will be reused) before
|
| 94 |
+
computing the new unconditional branch states again.
|
| 95 |
+
unconditional_batch_timestep_skip_range (`tuple[float, float]`, defaults to `(-1, 641)`):
|
| 96 |
+
The timestep range within which the unconditional branch computation can be skipped without a significant
|
| 97 |
+
loss in quality. This is to be determined by the user based on the underlying model. The first value in the
|
| 98 |
+
tuple is the lower bound and the second value is the upper bound.
|
| 99 |
+
spatial_attention_block_identifiers (`tuple[str, ...]`, defaults to `("blocks.*attn1", "transformer_blocks.*attn1", "single_transformer_blocks.*attn1")`):
|
| 100 |
+
The identifiers to match the spatial attention blocks in the model. If the name of the block contains any
|
| 101 |
+
of these identifiers, FasterCache will be applied to that block. This can either be the full layer names,
|
| 102 |
+
partial layer names, or regex patterns. Matching will always be done using a regex match.
|
| 103 |
+
temporal_attention_block_identifiers (`tuple[str, ...]`, defaults to `("temporal_transformer_blocks.*attn1",)`):
|
| 104 |
+
The identifiers to match the temporal attention blocks in the model. If the name of the block contains any
|
| 105 |
+
of these identifiers, FasterCache will be applied to that block. This can either be the full layer names,
|
| 106 |
+
partial layer names, or regex patterns. Matching will always be done using a regex match.
|
| 107 |
+
attention_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
|
| 108 |
+
The callback function to determine the weight to scale the attention outputs by. This function should take
|
| 109 |
+
the attention module as input and return a float value. This is used to approximate the unconditional
|
| 110 |
+
branch from the conditional branch outputs. If not provided, the default weight is 0.5 for all timesteps.
|
| 111 |
+
Typically, as described in the paper, this weight should gradually increase from 0 to 1 as the inference
|
| 112 |
+
progresses. Users are encouraged to experiment and provide custom weight schedules that take into account
|
| 113 |
+
the number of inference steps and underlying model behaviour as denoising progresses.
|
| 114 |
+
low_frequency_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
|
| 115 |
+
The callback function to determine the weight to scale the low frequency updates by. If not provided, the
|
| 116 |
+
default weight is 1.1 for timesteps within the range specified (as described in the paper).
|
| 117 |
+
high_frequency_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
|
| 118 |
+
The callback function to determine the weight to scale the high frequency updates by. If not provided, the
|
| 119 |
+
default weight is 1.1 for timesteps within the range specified (as described in the paper).
|
| 120 |
+
tensor_format (`str`, defaults to `"BCFHW"`):
|
| 121 |
+
The format of the input tensors. This should be one of `"BCFHW"`, `"BFCHW"`, or `"BCHW"`. The format is
|
| 122 |
+
used to split individual latent frames in order for low and high frequency components to be computed.
|
| 123 |
+
is_guidance_distilled (`bool`, defaults to `False`):
|
| 124 |
+
Whether the model is guidance distilled or not. If the model is guidance distilled, FasterCache will not be
|
| 125 |
+
applied at the denoiser-level to skip the unconditional branch computation (as there is none).
|
| 126 |
+
_unconditional_conditional_input_kwargs_identifiers (`list[str]`, defaults to `("hidden_states", "encoder_hidden_states", "timestep", "attention_mask", "encoder_attention_mask")`):
|
| 127 |
+
The identifiers to match the input kwargs that contain the batchwise-concatenated unconditional and
|
| 128 |
+
conditional inputs. If the name of the input kwargs contains any of these identifiers, FasterCache will
|
| 129 |
+
split the inputs into unconditional and conditional branches. This must be a list of exact input kwargs
|
| 130 |
+
names that contain the batchwise-concatenated unconditional and conditional inputs.
|
| 131 |
+
"""
|
| 132 |
+
|
| 133 |
+
# In the paper and codebase, they hardcode these values to 2. However, it can be made configurable
|
| 134 |
+
# after some testing. We default to 2 if these parameters are not provided.
|
| 135 |
+
spatial_attention_block_skip_range: int = 2
|
| 136 |
+
temporal_attention_block_skip_range: int | None = None
|
| 137 |
+
|
| 138 |
+
spatial_attention_timestep_skip_range: tuple[int, int] = (-1, 681)
|
| 139 |
+
temporal_attention_timestep_skip_range: tuple[int, int] = (-1, 681)
|
| 140 |
+
|
| 141 |
+
# Indicator functions for low/high frequency as mentioned in Equation 11 of the paper
|
| 142 |
+
low_frequency_weight_update_timestep_range: tuple[int, int] = (99, 901)
|
| 143 |
+
high_frequency_weight_update_timestep_range: tuple[int, int] = (-1, 301)
|
| 144 |
+
|
| 145 |
+
# ⍺1 and ⍺2 as mentioned in Equation 11 of the paper
|
| 146 |
+
alpha_low_frequency: float = 1.1
|
| 147 |
+
alpha_high_frequency: float = 1.1
|
| 148 |
+
|
| 149 |
+
# n as described in CFG-Cache explanation in the paper - dependent on the model
|
| 150 |
+
unconditional_batch_skip_range: int = 5
|
| 151 |
+
unconditional_batch_timestep_skip_range: tuple[int, int] = (-1, 641)
|
| 152 |
+
|
| 153 |
+
spatial_attention_block_identifiers: tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
|
| 154 |
+
temporal_attention_block_identifiers: tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
|
| 155 |
+
|
| 156 |
+
attention_weight_callback: Callable[[torch.nn.Module], float] = None
|
| 157 |
+
low_frequency_weight_callback: Callable[[torch.nn.Module], float] = None
|
| 158 |
+
high_frequency_weight_callback: Callable[[torch.nn.Module], float] = None
|
| 159 |
+
|
| 160 |
+
tensor_format: str = "BCFHW"
|
| 161 |
+
is_guidance_distilled: bool = False
|
| 162 |
+
|
| 163 |
+
current_timestep_callback: Callable[[], int] = None
|
| 164 |
+
|
| 165 |
+
_unconditional_conditional_input_kwargs_identifiers: list[str] = _UNCOND_COND_INPUT_KWARGS_IDENTIFIERS
|
| 166 |
+
|
| 167 |
+
def __repr__(self) -> str:
|
| 168 |
+
return (
|
| 169 |
+
f"FasterCacheConfig(\n"
|
| 170 |
+
f" spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n"
|
| 171 |
+
f" temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n"
|
| 172 |
+
f" spatial_attention_timestep_skip_range={self.spatial_attention_timestep_skip_range},\n"
|
| 173 |
+
f" temporal_attention_timestep_skip_range={self.temporal_attention_timestep_skip_range},\n"
|
| 174 |
+
f" low_frequency_weight_update_timestep_range={self.low_frequency_weight_update_timestep_range},\n"
|
| 175 |
+
f" high_frequency_weight_update_timestep_range={self.high_frequency_weight_update_timestep_range},\n"
|
| 176 |
+
f" alpha_low_frequency={self.alpha_low_frequency},\n"
|
| 177 |
+
f" alpha_high_frequency={self.alpha_high_frequency},\n"
|
| 178 |
+
f" unconditional_batch_skip_range={self.unconditional_batch_skip_range},\n"
|
| 179 |
+
f" unconditional_batch_timestep_skip_range={self.unconditional_batch_timestep_skip_range},\n"
|
| 180 |
+
f" spatial_attention_block_identifiers={self.spatial_attention_block_identifiers},\n"
|
| 181 |
+
f" temporal_attention_block_identifiers={self.temporal_attention_block_identifiers},\n"
|
| 182 |
+
f" tensor_format={self.tensor_format},\n"
|
| 183 |
+
f")"
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class FasterCacheDenoiserState:
|
| 188 |
+
r"""
|
| 189 |
+
State for [FasterCache](https://huggingface.co/papers/2410.19355) top-level denoiser module.
|
| 190 |
+
"""
|
| 191 |
+
|
| 192 |
+
def __init__(self) -> None:
|
| 193 |
+
self.iteration: int = 0
|
| 194 |
+
self.low_frequency_delta: torch.Tensor = None
|
| 195 |
+
self.high_frequency_delta: torch.Tensor = None
|
| 196 |
+
|
| 197 |
+
def reset(self):
|
| 198 |
+
self.iteration = 0
|
| 199 |
+
self.low_frequency_delta = None
|
| 200 |
+
self.high_frequency_delta = None
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
class FasterCacheBlockState:
|
| 204 |
+
r"""
|
| 205 |
+
State for [FasterCache](https://huggingface.co/papers/2410.19355). Every underlying block that FasterCache is
|
| 206 |
+
applied to will have an instance of this state.
|
| 207 |
+
"""
|
| 208 |
+
|
| 209 |
+
def __init__(self) -> None:
|
| 210 |
+
self.iteration: int = 0
|
| 211 |
+
self.batch_size: int = None
|
| 212 |
+
self.cache: tuple[torch.Tensor, torch.Tensor] = None
|
| 213 |
+
|
| 214 |
+
def reset(self):
|
| 215 |
+
self.iteration = 0
|
| 216 |
+
self.batch_size = None
|
| 217 |
+
self.cache = None
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class FasterCacheDenoiserHook(ModelHook):
|
| 221 |
+
_is_stateful = True
|
| 222 |
+
|
| 223 |
+
def __init__(
|
| 224 |
+
self,
|
| 225 |
+
unconditional_batch_skip_range: int,
|
| 226 |
+
unconditional_batch_timestep_skip_range: tuple[int, int],
|
| 227 |
+
tensor_format: str,
|
| 228 |
+
is_guidance_distilled: bool,
|
| 229 |
+
uncond_cond_input_kwargs_identifiers: list[str],
|
| 230 |
+
current_timestep_callback: Callable[[], int],
|
| 231 |
+
low_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor],
|
| 232 |
+
high_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor],
|
| 233 |
+
) -> None:
|
| 234 |
+
super().__init__()
|
| 235 |
+
|
| 236 |
+
self.unconditional_batch_skip_range = unconditional_batch_skip_range
|
| 237 |
+
self.unconditional_batch_timestep_skip_range = unconditional_batch_timestep_skip_range
|
| 238 |
+
# We can't easily detect what args are to be split in unconditional and conditional branches. We
|
| 239 |
+
# can only do it for kwargs, hence they are the only ones we split. The args are passed as-is.
|
| 240 |
+
# If a model is to be made compatible with FasterCache, the user must ensure that the inputs that
|
| 241 |
+
# contain batchwise-concatenated unconditional and conditional inputs are passed as kwargs.
|
| 242 |
+
self.uncond_cond_input_kwargs_identifiers = uncond_cond_input_kwargs_identifiers
|
| 243 |
+
self.tensor_format = tensor_format
|
| 244 |
+
self.is_guidance_distilled = is_guidance_distilled
|
| 245 |
+
|
| 246 |
+
self.current_timestep_callback = current_timestep_callback
|
| 247 |
+
self.low_frequency_weight_callback = low_frequency_weight_callback
|
| 248 |
+
self.high_frequency_weight_callback = high_frequency_weight_callback
|
| 249 |
+
|
| 250 |
+
def initialize_hook(self, module):
|
| 251 |
+
self.state = FasterCacheDenoiserState()
|
| 252 |
+
return module
|
| 253 |
+
|
| 254 |
+
@staticmethod
|
| 255 |
+
def _get_cond_input(input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 256 |
+
# Note: this method assumes that the input tensor is batchwise-concatenated with unconditional inputs
|
| 257 |
+
# followed by conditional inputs.
|
| 258 |
+
_, cond = input.chunk(2, dim=0)
|
| 259 |
+
return cond
|
| 260 |
+
|
| 261 |
+
def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
|
| 262 |
+
# Split the unconditional and conditional inputs. We only want to infer the conditional branch if the
|
| 263 |
+
# requirements for skipping the unconditional branch are met as described in the paper.
|
| 264 |
+
# We skip the unconditional branch only if the following conditions are met:
|
| 265 |
+
# 1. We have completed at least one iteration of the denoiser
|
| 266 |
+
# 2. The current timestep is within the range specified by the user. This is the optimal timestep range
|
| 267 |
+
# where approximating the unconditional branch from the computation of the conditional branch is possible
|
| 268 |
+
# without a significant loss in quality.
|
| 269 |
+
# 3. The current iteration is not a multiple of the unconditional batch skip range. This is done so that
|
| 270 |
+
# we compute the unconditional branch at least once every few iterations to ensure minimal quality loss.
|
| 271 |
+
is_within_timestep_range = (
|
| 272 |
+
self.unconditional_batch_timestep_skip_range[0]
|
| 273 |
+
< self.current_timestep_callback()
|
| 274 |
+
< self.unconditional_batch_timestep_skip_range[1]
|
| 275 |
+
)
|
| 276 |
+
should_skip_uncond = (
|
| 277 |
+
self.state.iteration > 0
|
| 278 |
+
and is_within_timestep_range
|
| 279 |
+
and self.state.iteration % self.unconditional_batch_skip_range != 0
|
| 280 |
+
and not self.is_guidance_distilled
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
if should_skip_uncond:
|
| 284 |
+
is_any_kwarg_uncond = any(k in self.uncond_cond_input_kwargs_identifiers for k in kwargs.keys())
|
| 285 |
+
if is_any_kwarg_uncond:
|
| 286 |
+
logger.debug("FasterCache - Skipping unconditional branch computation")
|
| 287 |
+
args = tuple([self._get_cond_input(arg) if torch.is_tensor(arg) else arg for arg in args])
|
| 288 |
+
kwargs = {
|
| 289 |
+
k: v if k not in self.uncond_cond_input_kwargs_identifiers else self._get_cond_input(v)
|
| 290 |
+
for k, v in kwargs.items()
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
output = self.fn_ref.original_forward(*args, **kwargs)
|
| 294 |
+
|
| 295 |
+
if self.is_guidance_distilled:
|
| 296 |
+
self.state.iteration += 1
|
| 297 |
+
return output
|
| 298 |
+
|
| 299 |
+
if torch.is_tensor(output):
|
| 300 |
+
hidden_states = output
|
| 301 |
+
elif isinstance(output, (tuple, Transformer2DModelOutput)):
|
| 302 |
+
hidden_states = output[0]
|
| 303 |
+
|
| 304 |
+
batch_size = hidden_states.size(0)
|
| 305 |
+
|
| 306 |
+
if should_skip_uncond:
|
| 307 |
+
self.state.low_frequency_delta = self.state.low_frequency_delta * self.low_frequency_weight_callback(
|
| 308 |
+
module
|
| 309 |
+
)
|
| 310 |
+
self.state.high_frequency_delta = self.state.high_frequency_delta * self.high_frequency_weight_callback(
|
| 311 |
+
module
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
if self.tensor_format == "BCFHW":
|
| 315 |
+
hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
|
| 316 |
+
if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
|
| 317 |
+
hidden_states = hidden_states.flatten(0, 1)
|
| 318 |
+
|
| 319 |
+
low_freq_cond, high_freq_cond = _split_low_high_freq(hidden_states.float())
|
| 320 |
+
|
| 321 |
+
# Approximate/compute the unconditional branch outputs as described in Equation 9 and 10 of the paper
|
| 322 |
+
low_freq_uncond = self.state.low_frequency_delta + low_freq_cond
|
| 323 |
+
high_freq_uncond = self.state.high_frequency_delta + high_freq_cond
|
| 324 |
+
uncond_freq = low_freq_uncond + high_freq_uncond
|
| 325 |
+
|
| 326 |
+
uncond_states = torch.fft.ifftshift(uncond_freq)
|
| 327 |
+
uncond_states = torch.fft.ifft2(uncond_states).real
|
| 328 |
+
|
| 329 |
+
if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
|
| 330 |
+
uncond_states = uncond_states.unflatten(0, (batch_size, -1))
|
| 331 |
+
hidden_states = hidden_states.unflatten(0, (batch_size, -1))
|
| 332 |
+
if self.tensor_format == "BCFHW":
|
| 333 |
+
uncond_states = uncond_states.permute(0, 2, 1, 3, 4)
|
| 334 |
+
hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
|
| 335 |
+
|
| 336 |
+
# Concatenate the approximated unconditional and predicted conditional branches
|
| 337 |
+
uncond_states = uncond_states.to(hidden_states.dtype)
|
| 338 |
+
hidden_states = torch.cat([uncond_states, hidden_states], dim=0)
|
| 339 |
+
else:
|
| 340 |
+
uncond_states, cond_states = hidden_states.chunk(2, dim=0)
|
| 341 |
+
if self.tensor_format == "BCFHW":
|
| 342 |
+
uncond_states = uncond_states.permute(0, 2, 1, 3, 4)
|
| 343 |
+
cond_states = cond_states.permute(0, 2, 1, 3, 4)
|
| 344 |
+
if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
|
| 345 |
+
uncond_states = uncond_states.flatten(0, 1)
|
| 346 |
+
cond_states = cond_states.flatten(0, 1)
|
| 347 |
+
|
| 348 |
+
low_freq_uncond, high_freq_uncond = _split_low_high_freq(uncond_states.float())
|
| 349 |
+
low_freq_cond, high_freq_cond = _split_low_high_freq(cond_states.float())
|
| 350 |
+
self.state.low_frequency_delta = low_freq_uncond - low_freq_cond
|
| 351 |
+
self.state.high_frequency_delta = high_freq_uncond - high_freq_cond
|
| 352 |
+
|
| 353 |
+
self.state.iteration += 1
|
| 354 |
+
if torch.is_tensor(output):
|
| 355 |
+
output = hidden_states
|
| 356 |
+
elif isinstance(output, tuple):
|
| 357 |
+
output = (hidden_states, *output[1:])
|
| 358 |
+
else:
|
| 359 |
+
output.sample = hidden_states
|
| 360 |
+
|
| 361 |
+
return output
|
| 362 |
+
|
| 363 |
+
def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
|
| 364 |
+
self.state.reset()
|
| 365 |
+
return module
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
class FasterCacheBlockHook(ModelHook):
|
| 369 |
+
_is_stateful = True
|
| 370 |
+
|
| 371 |
+
def __init__(
|
| 372 |
+
self,
|
| 373 |
+
block_skip_range: int,
|
| 374 |
+
timestep_skip_range: tuple[int, int],
|
| 375 |
+
is_guidance_distilled: bool,
|
| 376 |
+
weight_callback: Callable[[torch.nn.Module], float],
|
| 377 |
+
current_timestep_callback: Callable[[], int],
|
| 378 |
+
) -> None:
|
| 379 |
+
super().__init__()
|
| 380 |
+
|
| 381 |
+
self.block_skip_range = block_skip_range
|
| 382 |
+
self.timestep_skip_range = timestep_skip_range
|
| 383 |
+
self.is_guidance_distilled = is_guidance_distilled
|
| 384 |
+
|
| 385 |
+
self.weight_callback = weight_callback
|
| 386 |
+
self.current_timestep_callback = current_timestep_callback
|
| 387 |
+
|
| 388 |
+
def initialize_hook(self, module):
|
| 389 |
+
self.state = FasterCacheBlockState()
|
| 390 |
+
return module
|
| 391 |
+
|
| 392 |
+
def _compute_approximated_attention_output(
|
| 393 |
+
self, t_2_output: torch.Tensor, t_output: torch.Tensor, weight: float, batch_size: int
|
| 394 |
+
) -> torch.Tensor:
|
| 395 |
+
if t_2_output.size(0) != batch_size:
|
| 396 |
+
# The cache t_2_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just
|
| 397 |
+
# take the conditional branch outputs.
|
| 398 |
+
assert t_2_output.size(0) == 2 * batch_size
|
| 399 |
+
t_2_output = t_2_output[batch_size:]
|
| 400 |
+
if t_output.size(0) != batch_size:
|
| 401 |
+
# The cache t_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just
|
| 402 |
+
# take the conditional branch outputs.
|
| 403 |
+
assert t_output.size(0) == 2 * batch_size
|
| 404 |
+
t_output = t_output[batch_size:]
|
| 405 |
+
return t_output + (t_output - t_2_output) * weight
|
| 406 |
+
|
| 407 |
+
def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
|
| 408 |
+
batch_size = [
|
| 409 |
+
*[arg.size(0) for arg in args if torch.is_tensor(arg)],
|
| 410 |
+
*[v.size(0) for v in kwargs.values() if torch.is_tensor(v)],
|
| 411 |
+
][0]
|
| 412 |
+
if self.state.batch_size is None:
|
| 413 |
+
# Will be updated on first forward pass through the denoiser
|
| 414 |
+
self.state.batch_size = batch_size
|
| 415 |
+
|
| 416 |
+
# If we have to skip due to the skip conditions, then let's skip as expected.
|
| 417 |
+
# But, we can't skip if the denoiser wants to infer both unconditional and conditional branches. This
|
| 418 |
+
# is because the expected output shapes of attention layer will not match if we only return values from
|
| 419 |
+
# the cache (which only caches conditional branch outputs). So, if state.batch_size (which is the true
|
| 420 |
+
# unconditional-conditional batch size) is same as the current batch size, we don't perform the layer
|
| 421 |
+
# skip. Otherwise, we conditionally skip the layer based on what state.skip_callback returns.
|
| 422 |
+
is_within_timestep_range = (
|
| 423 |
+
self.timestep_skip_range[0] < self.current_timestep_callback() < self.timestep_skip_range[1]
|
| 424 |
+
)
|
| 425 |
+
if not is_within_timestep_range:
|
| 426 |
+
should_skip_attention = False
|
| 427 |
+
else:
|
| 428 |
+
should_compute_attention = self.state.iteration > 0 and self.state.iteration % self.block_skip_range == 0
|
| 429 |
+
should_skip_attention = not should_compute_attention
|
| 430 |
+
if should_skip_attention:
|
| 431 |
+
should_skip_attention = self.is_guidance_distilled or self.state.batch_size != batch_size
|
| 432 |
+
|
| 433 |
+
if should_skip_attention:
|
| 434 |
+
logger.debug("FasterCache - Skipping attention and using approximation")
|
| 435 |
+
if torch.is_tensor(self.state.cache[-1]):
|
| 436 |
+
t_2_output, t_output = self.state.cache
|
| 437 |
+
weight = self.weight_callback(module)
|
| 438 |
+
output = self._compute_approximated_attention_output(t_2_output, t_output, weight, batch_size)
|
| 439 |
+
else:
|
| 440 |
+
# The cache contains multiple tensors from past N iterations (N=2 for FasterCache). We need to handle all of them.
|
| 441 |
+
# Diffusers blocks can return multiple tensors - let's call them [A, B, C, ...] for simplicity.
|
| 442 |
+
# In our cache, we would have [[A_1, B_1, C_1, ...], [A_2, B_2, C_2, ...], ...] where each list is the output from
|
| 443 |
+
# a forward pass of the block. We need to compute the approximated output for each of these tensors.
|
| 444 |
+
# The zip(*state.cache) operation will give us [(A_1, A_2, ...), (B_1, B_2, ...), (C_1, C_2, ...), ...] which
|
| 445 |
+
# allows us to compute the approximated attention output for each tensor in the cache.
|
| 446 |
+
output = ()
|
| 447 |
+
for t_2_output, t_output in zip(*self.state.cache):
|
| 448 |
+
result = self._compute_approximated_attention_output(
|
| 449 |
+
t_2_output, t_output, self.weight_callback(module), batch_size
|
| 450 |
+
)
|
| 451 |
+
output += (result,)
|
| 452 |
+
else:
|
| 453 |
+
logger.debug("FasterCache - Computing attention")
|
| 454 |
+
output = self.fn_ref.original_forward(*args, **kwargs)
|
| 455 |
+
|
| 456 |
+
# Note that the following condition for getting hidden_states should suffice since Diffusers blocks either return
|
| 457 |
+
# a single hidden_states tensor, or a tuple of (hidden_states, encoder_hidden_states) tensors. We need to handle
|
| 458 |
+
# both cases.
|
| 459 |
+
if torch.is_tensor(output):
|
| 460 |
+
cache_output = output
|
| 461 |
+
if not self.is_guidance_distilled and cache_output.size(0) == self.state.batch_size:
|
| 462 |
+
# The output here can be both unconditional-conditional branch outputs or just conditional branch outputs.
|
| 463 |
+
# This is determined at the higher-level denoiser module. We only want to cache the conditional branch outputs.
|
| 464 |
+
cache_output = cache_output.chunk(2, dim=0)[1]
|
| 465 |
+
else:
|
| 466 |
+
# Cache all return values and perform the same operation as above
|
| 467 |
+
cache_output = ()
|
| 468 |
+
for out in output:
|
| 469 |
+
if not self.is_guidance_distilled and out.size(0) == self.state.batch_size:
|
| 470 |
+
out = out.chunk(2, dim=0)[1]
|
| 471 |
+
cache_output += (out,)
|
| 472 |
+
|
| 473 |
+
if self.state.cache is None:
|
| 474 |
+
self.state.cache = [cache_output, cache_output]
|
| 475 |
+
else:
|
| 476 |
+
self.state.cache = [self.state.cache[-1], cache_output]
|
| 477 |
+
|
| 478 |
+
self.state.iteration += 1
|
| 479 |
+
return output
|
| 480 |
+
|
| 481 |
+
def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
|
| 482 |
+
self.state.reset()
|
| 483 |
+
return module
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
def apply_faster_cache(module: torch.nn.Module, config: FasterCacheConfig) -> None:
|
| 487 |
+
r"""
|
| 488 |
+
Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline.
|
| 489 |
+
|
| 490 |
+
Args:
|
| 491 |
+
module (`torch.nn.Module`):
|
| 492 |
+
The pytorch module to apply FasterCache to. Typically, this should be a transformer architecture supported
|
| 493 |
+
in Diffusers, such as `CogVideoXTransformer3DModel`, but external implementations may also work.
|
| 494 |
+
config (`FasterCacheConfig`):
|
| 495 |
+
The configuration to use for FasterCache.
|
| 496 |
+
|
| 497 |
+
Example:
|
| 498 |
+
```python
|
| 499 |
+
>>> import torch
|
| 500 |
+
>>> from diffusers import CogVideoXPipeline, FasterCacheConfig, apply_faster_cache
|
| 501 |
+
|
| 502 |
+
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
|
| 503 |
+
>>> pipe.to("cuda")
|
| 504 |
+
|
| 505 |
+
>>> config = FasterCacheConfig(
|
| 506 |
+
... spatial_attention_block_skip_range=2,
|
| 507 |
+
... spatial_attention_timestep_skip_range=(-1, 681),
|
| 508 |
+
... low_frequency_weight_update_timestep_range=(99, 641),
|
| 509 |
+
... high_frequency_weight_update_timestep_range=(-1, 301),
|
| 510 |
+
... spatial_attention_block_identifiers=["transformer_blocks"],
|
| 511 |
+
... attention_weight_callback=lambda _: 0.3,
|
| 512 |
+
... tensor_format="BFCHW",
|
| 513 |
+
... )
|
| 514 |
+
>>> apply_faster_cache(pipe.transformer, config)
|
| 515 |
+
```
|
| 516 |
+
"""
|
| 517 |
+
|
| 518 |
+
logger.warning(
|
| 519 |
+
"FasterCache is a purely experimental feature and may not work as expected. Not all models support FasterCache. "
|
| 520 |
+
"The API is subject to change in future releases, with no guarantee of backward compatibility. Please report any issues at "
|
| 521 |
+
"https://github.com/huggingface/diffusers/issues."
|
| 522 |
+
)
|
| 523 |
+
|
| 524 |
+
if config.attention_weight_callback is None:
|
| 525 |
+
# If the user has not provided a weight callback, we default to 0.5 for all timesteps.
|
| 526 |
+
# In the paper, they recommend using a gradually increasing weight from 0 to 1 as the inference progresses, but
|
| 527 |
+
# this depends from model-to-model. It is required by the user to provide a weight callback if they want to
|
| 528 |
+
# use a different weight function. Defaulting to 0.5 works well in practice for most cases.
|
| 529 |
+
logger.warning(
|
| 530 |
+
"No `attention_weight_callback` provided when enabling FasterCache. Defaulting to using a weight of 0.5 for all timesteps."
|
| 531 |
+
)
|
| 532 |
+
config.attention_weight_callback = lambda _: 0.5
|
| 533 |
+
|
| 534 |
+
if config.low_frequency_weight_callback is None:
|
| 535 |
+
logger.debug(
|
| 536 |
+
"Low frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper."
|
| 537 |
+
)
|
| 538 |
+
|
| 539 |
+
def low_frequency_weight_callback(module: torch.nn.Module) -> float:
|
| 540 |
+
is_within_range = (
|
| 541 |
+
config.low_frequency_weight_update_timestep_range[0]
|
| 542 |
+
< config.current_timestep_callback()
|
| 543 |
+
< config.low_frequency_weight_update_timestep_range[1]
|
| 544 |
+
)
|
| 545 |
+
return config.alpha_low_frequency if is_within_range else 1.0
|
| 546 |
+
|
| 547 |
+
config.low_frequency_weight_callback = low_frequency_weight_callback
|
| 548 |
+
|
| 549 |
+
if config.high_frequency_weight_callback is None:
|
| 550 |
+
logger.debug(
|
| 551 |
+
"High frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper."
|
| 552 |
+
)
|
| 553 |
+
|
| 554 |
+
def high_frequency_weight_callback(module: torch.nn.Module) -> float:
|
| 555 |
+
is_within_range = (
|
| 556 |
+
config.high_frequency_weight_update_timestep_range[0]
|
| 557 |
+
< config.current_timestep_callback()
|
| 558 |
+
< config.high_frequency_weight_update_timestep_range[1]
|
| 559 |
+
)
|
| 560 |
+
return config.alpha_high_frequency if is_within_range else 1.0
|
| 561 |
+
|
| 562 |
+
config.high_frequency_weight_callback = high_frequency_weight_callback
|
| 563 |
+
|
| 564 |
+
supported_tensor_formats = ["BCFHW", "BFCHW", "BCHW"] # TODO(aryan): Support BSC for LTX Video
|
| 565 |
+
if config.tensor_format not in supported_tensor_formats:
|
| 566 |
+
raise ValueError(f"`tensor_format` must be one of {supported_tensor_formats}, but got {config.tensor_format}.")
|
| 567 |
+
|
| 568 |
+
_apply_faster_cache_on_denoiser(module, config)
|
| 569 |
+
|
| 570 |
+
for name, submodule in module.named_modules():
|
| 571 |
+
if not isinstance(submodule, _ATTENTION_CLASSES):
|
| 572 |
+
continue
|
| 573 |
+
if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS):
|
| 574 |
+
_apply_faster_cache_on_attention_class(name, submodule, config)
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
def _apply_faster_cache_on_denoiser(module: torch.nn.Module, config: FasterCacheConfig) -> None:
|
| 578 |
+
hook = FasterCacheDenoiserHook(
|
| 579 |
+
config.unconditional_batch_skip_range,
|
| 580 |
+
config.unconditional_batch_timestep_skip_range,
|
| 581 |
+
config.tensor_format,
|
| 582 |
+
config.is_guidance_distilled,
|
| 583 |
+
config._unconditional_conditional_input_kwargs_identifiers,
|
| 584 |
+
config.current_timestep_callback,
|
| 585 |
+
config.low_frequency_weight_callback,
|
| 586 |
+
config.high_frequency_weight_callback,
|
| 587 |
+
)
|
| 588 |
+
registry = HookRegistry.check_if_exists_or_initialize(module)
|
| 589 |
+
registry.register_hook(hook, _FASTER_CACHE_DENOISER_HOOK)
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
def _apply_faster_cache_on_attention_class(name: str, module: AttentionModuleMixin, config: FasterCacheConfig) -> None:
|
| 593 |
+
is_spatial_self_attention = (
|
| 594 |
+
any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
|
| 595 |
+
and config.spatial_attention_block_skip_range is not None
|
| 596 |
+
and not getattr(module, "is_cross_attention", False)
|
| 597 |
+
)
|
| 598 |
+
is_temporal_self_attention = (
|
| 599 |
+
any(re.search(identifier, name) is not None for identifier in config.temporal_attention_block_identifiers)
|
| 600 |
+
and config.temporal_attention_block_skip_range is not None
|
| 601 |
+
and not module.is_cross_attention
|
| 602 |
+
)
|
| 603 |
+
|
| 604 |
+
block_skip_range, timestep_skip_range, block_type = None, None, None
|
| 605 |
+
if is_spatial_self_attention:
|
| 606 |
+
block_skip_range = config.spatial_attention_block_skip_range
|
| 607 |
+
timestep_skip_range = config.spatial_attention_timestep_skip_range
|
| 608 |
+
block_type = "spatial"
|
| 609 |
+
elif is_temporal_self_attention:
|
| 610 |
+
block_skip_range = config.temporal_attention_block_skip_range
|
| 611 |
+
timestep_skip_range = config.temporal_attention_timestep_skip_range
|
| 612 |
+
block_type = "temporal"
|
| 613 |
+
|
| 614 |
+
if block_skip_range is None or timestep_skip_range is None:
|
| 615 |
+
logger.debug(
|
| 616 |
+
f'Unable to apply FasterCache to the selected layer: "{name}" because it does '
|
| 617 |
+
f"not match any of the required criteria for spatial or temporal attention layers. Note, "
|
| 618 |
+
f"however, that this layer may still be valid for applying PAB. Please specify the correct "
|
| 619 |
+
f"block identifiers in the configuration or use the specialized `apply_faster_cache_on_module` "
|
| 620 |
+
f"function to apply FasterCache to this layer."
|
| 621 |
+
)
|
| 622 |
+
return
|
| 623 |
+
|
| 624 |
+
logger.debug(f"Enabling FasterCache ({block_type}) for layer: {name}")
|
| 625 |
+
hook = FasterCacheBlockHook(
|
| 626 |
+
block_skip_range,
|
| 627 |
+
timestep_skip_range,
|
| 628 |
+
config.is_guidance_distilled,
|
| 629 |
+
config.attention_weight_callback,
|
| 630 |
+
config.current_timestep_callback,
|
| 631 |
+
)
|
| 632 |
+
registry = HookRegistry.check_if_exists_or_initialize(module)
|
| 633 |
+
registry.register_hook(hook, _FASTER_CACHE_BLOCK_HOOK)
|
| 634 |
+
|
| 635 |
+
|
| 636 |
+
# Reference: https://github.com/Vchitect/FasterCache/blob/fab32c15014636dc854948319c0a9a8d92c7acb4/scripts/latte/faster_cache_sample_latte.py#L127C1-L143C39
|
| 637 |
+
@torch.no_grad()
|
| 638 |
+
def _split_low_high_freq(x):
|
| 639 |
+
fft = torch.fft.fft2(x)
|
| 640 |
+
fft_shifted = torch.fft.fftshift(fft)
|
| 641 |
+
height, width = x.shape[-2:]
|
| 642 |
+
radius = min(height, width) // 5
|
| 643 |
+
|
| 644 |
+
y_grid, x_grid = torch.meshgrid(torch.arange(height), torch.arange(width))
|
| 645 |
+
center_x, center_y = width // 2, height // 2
|
| 646 |
+
mask = (x_grid - center_x) ** 2 + (y_grid - center_y) ** 2 <= radius**2
|
| 647 |
+
|
| 648 |
+
low_freq_mask = mask.unsqueeze(0).unsqueeze(0).to(x.device)
|
| 649 |
+
high_freq_mask = ~low_freq_mask
|
| 650 |
+
|
| 651 |
+
low_freq_fft = fft_shifted * low_freq_mask
|
| 652 |
+
high_freq_fft = fft_shifted * high_freq_mask
|
| 653 |
+
|
| 654 |
+
return low_freq_fft, high_freq_fft
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/first_block_cache.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from dataclasses import dataclass
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
from ..utils import get_logger
|
| 20 |
+
from ..utils.torch_utils import unwrap_module
|
| 21 |
+
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
|
| 22 |
+
from ._helpers import TransformerBlockRegistry
|
| 23 |
+
from .hooks import BaseState, HookRegistry, ModelHook, StateManager
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
logger = get_logger(__name__) # pylint: disable=invalid-name
|
| 27 |
+
|
| 28 |
+
_FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook"
|
| 29 |
+
_FBC_BLOCK_HOOK = "fbc_block_hook"
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@dataclass
|
| 33 |
+
class FirstBlockCacheConfig:
|
| 34 |
+
r"""
|
| 35 |
+
Configuration for [First Block
|
| 36 |
+
Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching).
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
threshold (`float`, defaults to `0.05`):
|
| 40 |
+
The threshold to determine whether or not a forward pass through all layers of the model is required. A
|
| 41 |
+
higher threshold usually results in a forward pass through a lower number of layers and faster inference,
|
| 42 |
+
but might lead to poorer generation quality. A lower threshold may not result in significant generation
|
| 43 |
+
speedup. The threshold is compared against the absmean difference of the residuals between the current and
|
| 44 |
+
cached outputs from the first transformer block. If the difference is below the threshold, the forward pass
|
| 45 |
+
is skipped.
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
threshold: float = 0.05
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class FBCSharedBlockState(BaseState):
|
| 52 |
+
def __init__(self) -> None:
|
| 53 |
+
super().__init__()
|
| 54 |
+
|
| 55 |
+
self.head_block_output: torch.Tensor | tuple[torch.Tensor, ...] = None
|
| 56 |
+
self.head_block_residual: torch.Tensor = None
|
| 57 |
+
self.tail_block_residuals: torch.Tensor | tuple[torch.Tensor, ...] = None
|
| 58 |
+
self.should_compute: bool = True
|
| 59 |
+
|
| 60 |
+
def reset(self):
|
| 61 |
+
self.tail_block_residuals = None
|
| 62 |
+
self.should_compute = True
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class FBCHeadBlockHook(ModelHook):
|
| 66 |
+
_is_stateful = True
|
| 67 |
+
|
| 68 |
+
def __init__(self, state_manager: StateManager, threshold: float):
|
| 69 |
+
self.state_manager = state_manager
|
| 70 |
+
self.threshold = threshold
|
| 71 |
+
self._metadata = None
|
| 72 |
+
|
| 73 |
+
def initialize_hook(self, module):
|
| 74 |
+
unwrapped_module = unwrap_module(module)
|
| 75 |
+
self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
|
| 76 |
+
return module
|
| 77 |
+
|
| 78 |
+
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
| 79 |
+
original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
|
| 80 |
+
|
| 81 |
+
output = self.fn_ref.original_forward(*args, **kwargs)
|
| 82 |
+
is_output_tuple = isinstance(output, tuple)
|
| 83 |
+
|
| 84 |
+
if is_output_tuple:
|
| 85 |
+
hidden_states_residual = output[self._metadata.return_hidden_states_index] - original_hidden_states
|
| 86 |
+
else:
|
| 87 |
+
hidden_states_residual = output - original_hidden_states
|
| 88 |
+
|
| 89 |
+
shared_state: FBCSharedBlockState = self.state_manager.get_state()
|
| 90 |
+
hidden_states = encoder_hidden_states = None
|
| 91 |
+
should_compute = self._should_compute_remaining_blocks(hidden_states_residual)
|
| 92 |
+
shared_state.should_compute = should_compute
|
| 93 |
+
|
| 94 |
+
if not should_compute:
|
| 95 |
+
# Apply caching
|
| 96 |
+
if is_output_tuple:
|
| 97 |
+
hidden_states = (
|
| 98 |
+
shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index]
|
| 99 |
+
)
|
| 100 |
+
else:
|
| 101 |
+
hidden_states = shared_state.tail_block_residuals[0] + output
|
| 102 |
+
|
| 103 |
+
if self._metadata.return_encoder_hidden_states_index is not None:
|
| 104 |
+
assert is_output_tuple
|
| 105 |
+
encoder_hidden_states = (
|
| 106 |
+
shared_state.tail_block_residuals[1] + output[self._metadata.return_encoder_hidden_states_index]
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
if is_output_tuple:
|
| 110 |
+
return_output = [None] * len(output)
|
| 111 |
+
return_output[self._metadata.return_hidden_states_index] = hidden_states
|
| 112 |
+
return_output[self._metadata.return_encoder_hidden_states_index] = encoder_hidden_states
|
| 113 |
+
return_output = tuple(return_output)
|
| 114 |
+
else:
|
| 115 |
+
return_output = hidden_states
|
| 116 |
+
output = return_output
|
| 117 |
+
else:
|
| 118 |
+
if is_output_tuple:
|
| 119 |
+
head_block_output = [None] * len(output)
|
| 120 |
+
head_block_output[0] = output[self._metadata.return_hidden_states_index]
|
| 121 |
+
head_block_output[1] = output[self._metadata.return_encoder_hidden_states_index]
|
| 122 |
+
else:
|
| 123 |
+
head_block_output = output
|
| 124 |
+
shared_state.head_block_output = head_block_output
|
| 125 |
+
shared_state.head_block_residual = hidden_states_residual
|
| 126 |
+
|
| 127 |
+
return output
|
| 128 |
+
|
| 129 |
+
def reset_state(self, module):
|
| 130 |
+
self.state_manager.reset()
|
| 131 |
+
return module
|
| 132 |
+
|
| 133 |
+
@torch.compiler.disable
|
| 134 |
+
def _should_compute_remaining_blocks(self, hidden_states_residual: torch.Tensor) -> bool:
|
| 135 |
+
shared_state = self.state_manager.get_state()
|
| 136 |
+
if shared_state.head_block_residual is None:
|
| 137 |
+
return True
|
| 138 |
+
prev_hidden_states_residual = shared_state.head_block_residual
|
| 139 |
+
absmean = (hidden_states_residual - prev_hidden_states_residual).abs().mean()
|
| 140 |
+
prev_hidden_states_absmean = prev_hidden_states_residual.abs().mean()
|
| 141 |
+
diff = (absmean / prev_hidden_states_absmean).item()
|
| 142 |
+
return diff > self.threshold
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class FBCBlockHook(ModelHook):
|
| 146 |
+
def __init__(self, state_manager: StateManager, is_tail: bool = False):
|
| 147 |
+
super().__init__()
|
| 148 |
+
self.state_manager = state_manager
|
| 149 |
+
self.is_tail = is_tail
|
| 150 |
+
self._metadata = None
|
| 151 |
+
|
| 152 |
+
def initialize_hook(self, module):
|
| 153 |
+
unwrapped_module = unwrap_module(module)
|
| 154 |
+
self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
|
| 155 |
+
return module
|
| 156 |
+
|
| 157 |
+
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
| 158 |
+
original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
|
| 159 |
+
original_encoder_hidden_states = None
|
| 160 |
+
if self._metadata.return_encoder_hidden_states_index is not None:
|
| 161 |
+
original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
|
| 162 |
+
"encoder_hidden_states", args, kwargs
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
shared_state = self.state_manager.get_state()
|
| 166 |
+
|
| 167 |
+
if shared_state.should_compute:
|
| 168 |
+
output = self.fn_ref.original_forward(*args, **kwargs)
|
| 169 |
+
if self.is_tail:
|
| 170 |
+
hidden_states_residual = encoder_hidden_states_residual = None
|
| 171 |
+
if isinstance(output, tuple):
|
| 172 |
+
hidden_states_residual = (
|
| 173 |
+
output[self._metadata.return_hidden_states_index] - shared_state.head_block_output[0]
|
| 174 |
+
)
|
| 175 |
+
encoder_hidden_states_residual = (
|
| 176 |
+
output[self._metadata.return_encoder_hidden_states_index] - shared_state.head_block_output[1]
|
| 177 |
+
)
|
| 178 |
+
else:
|
| 179 |
+
hidden_states_residual = output - shared_state.head_block_output
|
| 180 |
+
shared_state.tail_block_residuals = (hidden_states_residual, encoder_hidden_states_residual)
|
| 181 |
+
return output
|
| 182 |
+
|
| 183 |
+
if original_encoder_hidden_states is None:
|
| 184 |
+
return_output = original_hidden_states
|
| 185 |
+
else:
|
| 186 |
+
return_output = [None, None]
|
| 187 |
+
return_output[self._metadata.return_hidden_states_index] = original_hidden_states
|
| 188 |
+
return_output[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states
|
| 189 |
+
return_output = tuple(return_output)
|
| 190 |
+
return return_output
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
|
| 194 |
+
"""
|
| 195 |
+
Applies [First Block
|
| 196 |
+
Cache](https://github.com/chengzeyi/ParaAttention/blob/4de137c5b96416489f06e43e19f2c14a772e28fd/README.md#first-block-cache-our-dynamic-caching)
|
| 197 |
+
to a given module.
|
| 198 |
+
|
| 199 |
+
First Block Cache builds on the ideas of [TeaCache](https://huggingface.co/papers/2411.19108). It is much simpler
|
| 200 |
+
to implement generically for a wide range of models and has been integrated first for experimental purposes.
|
| 201 |
+
|
| 202 |
+
Args:
|
| 203 |
+
module (`torch.nn.Module`):
|
| 204 |
+
The pytorch module to apply FBCache to. Typically, this should be a transformer architecture supported in
|
| 205 |
+
Diffusers, such as `CogVideoXTransformer3DModel`, but external implementations may also work.
|
| 206 |
+
config (`FirstBlockCacheConfig`):
|
| 207 |
+
The configuration to use for applying the FBCache method.
|
| 208 |
+
|
| 209 |
+
Example:
|
| 210 |
+
```python
|
| 211 |
+
>>> import torch
|
| 212 |
+
>>> from diffusers import CogView4Pipeline
|
| 213 |
+
>>> from diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig
|
| 214 |
+
|
| 215 |
+
>>> pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
|
| 216 |
+
>>> pipe.to("cuda")
|
| 217 |
+
|
| 218 |
+
>>> apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.2))
|
| 219 |
+
|
| 220 |
+
>>> prompt = "A photo of an astronaut riding a horse on mars"
|
| 221 |
+
>>> image = pipe(prompt, generator=torch.Generator().manual_seed(42)).images[0]
|
| 222 |
+
>>> image.save("output.png")
|
| 223 |
+
```
|
| 224 |
+
"""
|
| 225 |
+
|
| 226 |
+
state_manager = StateManager(FBCSharedBlockState, (), {})
|
| 227 |
+
remaining_blocks = []
|
| 228 |
+
|
| 229 |
+
for name, submodule in module.named_children():
|
| 230 |
+
if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
|
| 231 |
+
continue
|
| 232 |
+
for index, block in enumerate(submodule):
|
| 233 |
+
remaining_blocks.append((f"{name}.{index}", block))
|
| 234 |
+
|
| 235 |
+
head_block_name, head_block = remaining_blocks.pop(0)
|
| 236 |
+
tail_block_name, tail_block = remaining_blocks.pop(-1)
|
| 237 |
+
|
| 238 |
+
logger.debug(f"Applying FBCHeadBlockHook to '{head_block_name}'")
|
| 239 |
+
_apply_fbc_head_block_hook(head_block, state_manager, config.threshold)
|
| 240 |
+
|
| 241 |
+
for name, block in remaining_blocks:
|
| 242 |
+
logger.debug(f"Applying FBCBlockHook to '{name}'")
|
| 243 |
+
_apply_fbc_block_hook(block, state_manager)
|
| 244 |
+
|
| 245 |
+
logger.debug(f"Applying FBCBlockHook to tail block '{tail_block_name}'")
|
| 246 |
+
_apply_fbc_block_hook(tail_block, state_manager, is_tail=True)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def _apply_fbc_head_block_hook(block: torch.nn.Module, state_manager: StateManager, threshold: float) -> None:
|
| 250 |
+
registry = HookRegistry.check_if_exists_or_initialize(block)
|
| 251 |
+
hook = FBCHeadBlockHook(state_manager, threshold)
|
| 252 |
+
registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def _apply_fbc_block_hook(block: torch.nn.Module, state_manager: StateManager, is_tail: bool = False) -> None:
|
| 256 |
+
registry = HookRegistry.check_if_exists_or_initialize(block)
|
| 257 |
+
hook = FBCBlockHook(state_manager, is_tail)
|
| 258 |
+
registry.register_hook(hook, _FBC_BLOCK_HOOK)
|
URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/hooks/group_offloading.py
ADDED
|
@@ -0,0 +1,966 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import hashlib
|
| 16 |
+
import os
|
| 17 |
+
from contextlib import contextmanager, nullcontext
|
| 18 |
+
from dataclasses import dataclass, replace
|
| 19 |
+
from enum import Enum
|
| 20 |
+
from typing import Set
|
| 21 |
+
|
| 22 |
+
import safetensors.torch
|
| 23 |
+
import torch
|
| 24 |
+
|
| 25 |
+
from ..utils import get_logger, is_accelerate_available
|
| 26 |
+
from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
|
| 27 |
+
from .hooks import HookRegistry, ModelHook
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
if is_accelerate_available():
|
| 31 |
+
from accelerate.hooks import AlignDevicesHook, CpuOffload
|
| 32 |
+
from accelerate.utils import send_to_device
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
logger = get_logger(__name__) # pylint: disable=invalid-name
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# fmt: off
|
| 39 |
+
_GROUP_OFFLOADING = "group_offloading"
|
| 40 |
+
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
|
| 41 |
+
_LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
|
| 42 |
+
_GROUP_ID_LAZY_LEAF = "lazy_leafs"
|
| 43 |
+
# fmt: on
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class GroupOffloadingType(str, Enum):
|
| 47 |
+
BLOCK_LEVEL = "block_level"
|
| 48 |
+
LEAF_LEVEL = "leaf_level"
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@dataclass
|
| 52 |
+
class GroupOffloadingConfig:
|
| 53 |
+
onload_device: torch.device
|
| 54 |
+
offload_device: torch.device
|
| 55 |
+
offload_type: GroupOffloadingType
|
| 56 |
+
non_blocking: bool
|
| 57 |
+
record_stream: bool
|
| 58 |
+
low_cpu_mem_usage: bool
|
| 59 |
+
num_blocks_per_group: int | None = None
|
| 60 |
+
offload_to_disk_path: str | None = None
|
| 61 |
+
stream: torch.cuda.Stream | torch.Stream | None = None
|
| 62 |
+
block_modules: list[str] | None = None
|
| 63 |
+
exclude_kwargs: list[str] | None = None
|
| 64 |
+
module_prefix: str = ""
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class ModuleGroup:
|
| 68 |
+
def __init__(
|
| 69 |
+
self,
|
| 70 |
+
modules: list[torch.nn.Module],
|
| 71 |
+
offload_device: torch.device,
|
| 72 |
+
onload_device: torch.device,
|
| 73 |
+
offload_leader: torch.nn.Module,
|
| 74 |
+
onload_leader: torch.nn.Module | None = None,
|
| 75 |
+
parameters: list[torch.nn.Parameter] | None = None,
|
| 76 |
+
buffers: list[torch.Tensor] | None = None,
|
| 77 |
+
non_blocking: bool = False,
|
| 78 |
+
stream: torch.cuda.Stream | torch.Stream | None = None,
|
| 79 |
+
record_stream: bool | None = False,
|
| 80 |
+
low_cpu_mem_usage: bool = False,
|
| 81 |
+
onload_self: bool = True,
|
| 82 |
+
offload_to_disk_path: str | None = None,
|
| 83 |
+
group_id: int | str | None = None,
|
| 84 |
+
) -> None:
|
| 85 |
+
self.modules = modules
|
| 86 |
+
self.offload_device = offload_device
|
| 87 |
+
self.onload_device = onload_device
|
| 88 |
+
self.offload_leader = offload_leader
|
| 89 |
+
self.onload_leader = onload_leader
|
| 90 |
+
self.parameters = parameters or []
|
| 91 |
+
self.buffers = buffers or []
|
| 92 |
+
self.non_blocking = non_blocking or stream is not None
|
| 93 |
+
self.stream = stream
|
| 94 |
+
self.record_stream = record_stream
|
| 95 |
+
self.onload_self = onload_self
|
| 96 |
+
self.low_cpu_mem_usage = low_cpu_mem_usage
|
| 97 |
+
|
| 98 |
+
self.offload_to_disk_path = offload_to_disk_path
|
| 99 |
+
self._is_offloaded_to_disk = False
|
| 100 |
+
|
| 101 |
+
if self.offload_to_disk_path is not None:
|
| 102 |
+
# Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
|
| 103 |
+
self.group_id = group_id if group_id is not None else str(id(self))
|
| 104 |
+
short_hash = _compute_group_hash(self.group_id)
|
| 105 |
+
self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{short_hash}.safetensors")
|
| 106 |
+
|
| 107 |
+
all_tensors = []
|
| 108 |
+
for module in self.modules:
|
| 109 |
+
all_tensors.extend(list(module.parameters()))
|
| 110 |
+
all_tensors.extend(list(module.buffers()))
|
| 111 |
+
all_tensors.extend(self.parameters)
|
| 112 |
+
all_tensors.extend(self.buffers)
|
| 113 |
+
all_tensors = list(dict.fromkeys(all_tensors)) # Remove duplicates
|
| 114 |
+
|
| 115 |
+
self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)}
|
| 116 |
+
self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()}
|
| 117 |
+
self.cpu_param_dict = {}
|
| 118 |
+
else:
|
| 119 |
+
self.cpu_param_dict = self._init_cpu_param_dict()
|
| 120 |
+
|
| 121 |
+
self._torch_accelerator_module = (
|
| 122 |
+
getattr(torch, torch.accelerator.current_accelerator().type)
|
| 123 |
+
if hasattr(torch, "accelerator")
|
| 124 |
+
else torch.cuda
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
def _init_cpu_param_dict(self):
|
| 128 |
+
cpu_param_dict = {}
|
| 129 |
+
if self.stream is None:
|
| 130 |
+
return cpu_param_dict
|
| 131 |
+
|
| 132 |
+
for module in self.modules:
|
| 133 |
+
for param in module.parameters():
|
| 134 |
+
cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
|
| 135 |
+
for buffer in module.buffers():
|
| 136 |
+
cpu_param_dict[buffer] = (
|
| 137 |
+
buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
for param in self.parameters:
|
| 141 |
+
cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
|
| 142 |
+
|
| 143 |
+
for buffer in self.buffers:
|
| 144 |
+
cpu_param_dict[buffer] = buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
|
| 145 |
+
|
| 146 |
+
return cpu_param_dict
|
| 147 |
+
|
| 148 |
+
@contextmanager
|
| 149 |
+
def _pinned_memory_tensors(self):
|
| 150 |
+
try:
|
| 151 |
+
pinned_dict = {
|
| 152 |
+
param: tensor.pin_memory() if not tensor.is_pinned() else tensor
|
| 153 |
+
for param, tensor in self.cpu_param_dict.items()
|
| 154 |
+
}
|
| 155 |
+
yield pinned_dict
|
| 156 |
+
finally:
|
| 157 |
+
pinned_dict = None
|
| 158 |
+
|
| 159 |
+
def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
|
| 160 |
+
tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
|
| 161 |
+
if self.record_stream:
|
| 162 |
+
tensor.data.record_stream(default_stream)
|
| 163 |
+
|
| 164 |
+
def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None):
|
| 165 |
+
for group_module in self.modules:
|
| 166 |
+
for param in group_module.parameters():
|
| 167 |
+
source = pinned_memory[param] if pinned_memory else param.data
|
| 168 |
+
self._transfer_tensor_to_device(param, source, default_stream)
|
| 169 |
+
for buffer in group_module.buffers():
|
| 170 |
+
source = pinned_memory[buffer] if pinned_memory else buffer.data
|
| 171 |
+
self._transfer_tensor_to_device(buffer, source, default_stream)
|
| 172 |
+
|
| 173 |
+
for param in self.parameters:
|
| 174 |
+
source = pinned_memory[param] if pinned_memory else param.data
|
| 175 |
+
self._transfer_tensor_to_device(param, source, default_stream)
|
| 176 |
+
|
| 177 |
+
for buffer in self.buffers:
|
| 178 |
+
source = pinned_memory[buffer] if pinned_memory else buffer.data
|
| 179 |
+
self._transfer_tensor_to_device(buffer, source, default_stream)
|
| 180 |
+
|
| 181 |
+
def _onload_from_disk(self):
|
| 182 |
+
if self.stream is not None:
|
| 183 |
+
# Wait for previous Host->Device transfer to complete
|
| 184 |
+
self.stream.synchronize()
|
| 185 |
+
|
| 186 |
+
context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
|
| 187 |
+
current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None
|
| 188 |
+
|
| 189 |
+
with context:
|
| 190 |
+
# Load to CPU (if using streams) or directly to target device, pin, and async copy to device
|
| 191 |
+
device = str(self.onload_device) if self.stream is None else "cpu"
|
| 192 |
+
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device)
|
| 193 |
+
|
| 194 |
+
if self.stream is not None:
|
| 195 |
+
for key, tensor_obj in self.key_to_tensor.items():
|
| 196 |
+
pinned_tensor = loaded_tensors[key].pin_memory()
|
| 197 |
+
tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
|
| 198 |
+
if self.record_stream:
|
| 199 |
+
tensor_obj.data.record_stream(current_stream)
|
| 200 |
+
else:
|
| 201 |
+
onload_device = (
|
| 202 |
+
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
|
| 203 |
+
)
|
| 204 |
+
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
|
| 205 |
+
for key, tensor_obj in self.key_to_tensor.items():
|
| 206 |
+
tensor_obj.data = loaded_tensors[key]
|
| 207 |
+
|
| 208 |
+
def _onload_from_memory(self):
|
| 209 |
+
if self.stream is not None:
|
| 210 |
+
# Wait for previous Host->Device transfer to complete
|
| 211 |
+
self.stream.synchronize()
|
| 212 |
+
|
| 213 |
+
context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
|
| 214 |
+
default_stream = self._torch_accelerator_module.current_stream() if self.stream is not None else None
|
| 215 |
+
|
| 216 |
+
with context:
|
| 217 |
+
if self.stream is not None:
|
| 218 |
+
with self._pinned_memory_tensors() as pinned_memory:
|
| 219 |
+
self._process_tensors_from_modules(pinned_memory, default_stream=default_stream)
|
| 220 |
+
else:
|
| 221 |
+
self._process_tensors_from_modules(None)
|
| 222 |
+
|
| 223 |
+
def _offload_to_disk(self):
|
| 224 |
+
# TODO: we can potentially optimize this code path by checking if the _all_ the desired
|
| 225 |
+
# safetensor files exist on the disk and if so, skip this step entirely, reducing IO
|
| 226 |
+
# overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
|
| 227 |
+
# we perform a write.
|
| 228 |
+
# Check if the file has been saved in this session or if it already exists on disk.
|
| 229 |
+
if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
|
| 230 |
+
os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
|
| 231 |
+
tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()}
|
| 232 |
+
safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
|
| 233 |
+
|
| 234 |
+
# The group is now considered offloaded to disk for the rest of the session.
|
| 235 |
+
self._is_offloaded_to_disk = True
|
| 236 |
+
|
| 237 |
+
# We do this to free up the RAM which is still holding the up tensor data.
|
| 238 |
+
for tensor_obj in self.tensor_to_key.keys():
|
| 239 |
+
tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
|
| 240 |
+
|
| 241 |
+
def _offload_to_memory(self):
|
| 242 |
+
if self.stream is not None:
|
| 243 |
+
if not self.record_stream:
|
| 244 |
+
self._torch_accelerator_module.current_stream().synchronize()
|
| 245 |
+
|
| 246 |
+
for group_module in self.modules:
|
| 247 |
+
for param in group_module.parameters():
|
| 248 |
+
param.data = self.cpu_param_dict[param]
|
| 249 |
+
for param in self.parameters:
|
| 250 |
+
param.data = self.cpu_param_dict[param]
|
| 251 |
+
for buffer in self.buffers:
|
| 252 |
+
buffer.data = self.cpu_param_dict[buffer]
|
| 253 |
+
else:
|
| 254 |
+
for group_module in self.modules:
|
| 255 |
+
group_module.to(self.offload_device, non_blocking=False)
|
| 256 |
+
for param in self.parameters:
|
| 257 |
+
param.data = param.data.to(self.offload_device, non_blocking=False)
|
| 258 |
+
for buffer in self.buffers:
|
| 259 |
+
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
|
| 260 |
+
|
| 261 |
+
@torch.compiler.disable()
|
| 262 |
+
def onload_(self):
|
| 263 |
+
r"""Onloads the group of parameters to the onload_device."""
|
| 264 |
+
if self.offload_to_disk_path is not None:
|
| 265 |
+
self._onload_from_disk()
|
| 266 |
+
else:
|
| 267 |
+
self._onload_from_memory()
|
| 268 |
+
|
| 269 |
+
@torch.compiler.disable()
|
| 270 |
+
def offload_(self):
|
| 271 |
+
r"""Offloads the group of parameters to the offload_device."""
|
| 272 |
+
if self.offload_to_disk_path:
|
| 273 |
+
self._offload_to_disk()
|
| 274 |
+
else:
|
| 275 |
+
self._offload_to_memory()
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
class GroupOffloadingHook(ModelHook):
|
| 279 |
+
r"""
|
| 280 |
+
A hook that offloads groups of torch.nn.Module to the CPU for storage and onloads to accelerator device for
|
| 281 |
+
computation. Each group has one "onload leader" module that is responsible for onloading, and an "offload leader"
|
| 282 |
+
module that is responsible for offloading. If prefetching is enabled, the onload leader of the previous module
|
| 283 |
+
group is responsible for onloading the current module group.
|
| 284 |
+
"""
|
| 285 |
+
|
| 286 |
+
_is_stateful = False
|
| 287 |
+
|
| 288 |
+
def __init__(self, group: ModuleGroup, *, config: GroupOffloadingConfig) -> None:
|
| 289 |
+
self.group = group
|
| 290 |
+
self.next_group: ModuleGroup | None = None
|
| 291 |
+
self.config = config
|
| 292 |
+
|
| 293 |
+
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
|
| 294 |
+
if self.group.offload_leader == module:
|
| 295 |
+
self.group.offload_()
|
| 296 |
+
return module
|
| 297 |
+
|
| 298 |
+
def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
|
| 299 |
+
# If there wasn't an onload_leader assigned, we assume that the submodule that first called its forward
|
| 300 |
+
# method is the onload_leader of the group.
|
| 301 |
+
if self.group.onload_leader is None:
|
| 302 |
+
self.group.onload_leader = module
|
| 303 |
+
|
| 304 |
+
# If the current module is the onload_leader of the group, we onload the group if it is supposed
|
| 305 |
+
# to onload itself. In the case of using prefetching with streams, we onload the next group if
|
| 306 |
+
# it is not supposed to onload itself.
|
| 307 |
+
if self.group.onload_leader == module:
|
| 308 |
+
if self.group.onload_self:
|
| 309 |
+
self.group.onload_()
|
| 310 |
+
else:
|
| 311 |
+
# onload_self=False means this group relies on prefetching from a previous group.
|
| 312 |
+
# However, for conditionally-executed modules (e.g. patch_short/patch_mid/patch_long in Helios),
|
| 313 |
+
# the prefetch chain may not cover them if they were absent during the first forward pass
|
| 314 |
+
# when the execution order was traced. In that case, their weights remain on offload_device,
|
| 315 |
+
# so we fall back to a synchronous onload here.
|
| 316 |
+
params = [p for m in self.group.modules for p in m.parameters()] + list(self.group.parameters)
|
| 317 |
+
if params and params[0].device == self.group.offload_device:
|
| 318 |
+
self.group.onload_()
|
| 319 |
+
if self.group.stream is not None:
|
| 320 |
+
self.group.stream.synchronize()
|
| 321 |
+
|
| 322 |
+
should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
|
| 323 |
+
if should_onload_next_group:
|
| 324 |
+
self.next_group.onload_()
|
| 325 |
+
|
| 326 |
+
should_synchronize = (
|
| 327 |
+
not self.group.onload_self and self.group.stream is not None and not should_onload_next_group
|
| 328 |
+
)
|
| 329 |
+
if should_synchronize:
|
| 330 |
+
# If this group didn't onload itself, it means it was asynchronously onloaded by the
|
| 331 |
+
# previous group. We need to synchronize the side stream to ensure parameters
|
| 332 |
+
# are completely loaded to proceed with forward pass. Without this, uninitialized
|
| 333 |
+
# weights will be used in the computation, leading to incorrect results
|
| 334 |
+
# Also, we should only do this synchronization if we don't already do it from the sync call in
|
| 335 |
+
# self.next_group.onload_, hence the `not should_onload_next_group` check.
|
| 336 |
+
self.group.stream.synchronize()
|
| 337 |
+
|
| 338 |
+
args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
|
| 339 |
+
|
| 340 |
+
# Some Autoencoder models use a feature cache that is passed through submodules
|
| 341 |
+
# and modified in place. The `send_to_device` call returns a copy of this feature cache object
|
| 342 |
+
# which breaks the inplace updates. Use `exclude_kwargs` to mark these cache features
|
| 343 |
+
exclude_kwargs = self.config.exclude_kwargs or []
|
| 344 |
+
if exclude_kwargs:
|
| 345 |
+
moved_kwargs = send_to_device(
|
| 346 |
+
{k: v for k, v in kwargs.items() if k not in exclude_kwargs},
|
| 347 |
+
self.group.onload_device,
|
| 348 |
+
non_blocking=self.group.non_blocking,
|
| 349 |
+
)
|
| 350 |
+
kwargs.update(moved_kwargs)
|
| 351 |
+
else:
|
| 352 |
+
kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
|
| 353 |
+
|
| 354 |
+
return args, kwargs
|
| 355 |
+
|
| 356 |
+
def post_forward(self, module: torch.nn.Module, output):
|
| 357 |
+
if self.group.offload_leader == module:
|
| 358 |
+
self.group.offload_()
|
| 359 |
+
return output
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
class LazyPrefetchGroupOffloadingHook(ModelHook):
|
| 363 |
+
r"""
|
| 364 |
+
A hook, used in conjunction with GroupOffloadingHook, that applies lazy prefetching to groups of torch.nn.Module.
|
| 365 |
+
This hook is used to determine the order in which the layers are executed during the forward pass. Once the layer
|
| 366 |
+
invocation order is known, assignments of the next_group attribute for prefetching can be made, which allows
|
| 367 |
+
prefetching groups in the correct order.
|
| 368 |
+
"""
|
| 369 |
+
|
| 370 |
+
_is_stateful = False
|
| 371 |
+
|
| 372 |
+
def __init__(self):
|
| 373 |
+
self.execution_order: list[tuple[str, torch.nn.Module]] = []
|
| 374 |
+
self._layer_execution_tracker_module_names = set()
|
| 375 |
+
|
| 376 |
+
def initialize_hook(self, module):
|
| 377 |
+
def make_execution_order_update_callback(current_name, current_submodule):
|
| 378 |
+
def callback():
|
| 379 |
+
if not torch.compiler.is_compiling():
|
| 380 |
+
logger.debug(f"Adding {current_name} to the execution order")
|
| 381 |
+
self.execution_order.append((current_name, current_submodule))
|
| 382 |
+
|
| 383 |
+
return callback
|
| 384 |
+
|
| 385 |
+
# To every submodule that contains a group offloading hook (at this point, no prefetching is enabled for any
|
| 386 |
+
# of the groups), we add a layer execution tracker hook that will be used to determine the order in which the
|
| 387 |
+
# layers are executed during the forward pass.
|
| 388 |
+
for name, submodule in module.named_modules():
|
| 389 |
+
if name == "" or not hasattr(submodule, "_diffusers_hook"):
|
| 390 |
+
continue
|
| 391 |
+
|
| 392 |
+
registry = HookRegistry.check_if_exists_or_initialize(submodule)
|
| 393 |
+
group_offloading_hook = registry.get_hook(_GROUP_OFFLOADING)
|
| 394 |
+
|
| 395 |
+
if group_offloading_hook is not None:
|
| 396 |
+
# For the first forward pass, we have to load in a blocking manner
|
| 397 |
+
group_offloading_hook.group.non_blocking = False
|
| 398 |
+
layer_tracker_hook = LayerExecutionTrackerHook(make_execution_order_update_callback(name, submodule))
|
| 399 |
+
registry.register_hook(layer_tracker_hook, _LAYER_EXECUTION_TRACKER)
|
| 400 |
+
self._layer_execution_tracker_module_names.add(name)
|
| 401 |
+
|
| 402 |
+
return module
|
| 403 |
+
|
| 404 |
+
def post_forward(self, module, output):
|
| 405 |
+
# At this point, for the current modules' submodules, we know the execution order of the layers. We can now
|
| 406 |
+
# remove the layer execution tracker hooks and apply prefetching by setting the next_group attribute for each
|
| 407 |
+
# group offloading hook.
|
| 408 |
+
num_executed = len(self.execution_order)
|
| 409 |
+
execution_order_module_names = {name for name, _ in self.execution_order}
|
| 410 |
+
|
| 411 |
+
# It may be possible that some layers were not executed during the forward pass. This can happen if the layer
|
| 412 |
+
# is not used in the forward pass, or if the layer is not executed due to some other reason. In such cases, we
|
| 413 |
+
# may not be able to apply prefetching in the correct order, which can lead to device-mismatch related errors
|
| 414 |
+
# if the missing layers end up being executed in the future.
|
| 415 |
+
if execution_order_module_names != self._layer_execution_tracker_module_names:
|
| 416 |
+
unexecuted_layers = list(self._layer_execution_tracker_module_names - execution_order_module_names)
|
| 417 |
+
if not torch.compiler.is_compiling():
|
| 418 |
+
logger.warning(
|
| 419 |
+
"It seems like some layers were not executed during the forward pass. This may lead to problems when "
|
| 420 |
+
"applying lazy prefetching with automatic tracing and lead to device-mismatch related errors. Please "
|
| 421 |
+
"make sure that all layers are executed during the forward pass. The following layers were not executed:\n"
|
| 422 |
+
f"{unexecuted_layers=}"
|
| 423 |
+
)
|
| 424 |
+
|
| 425 |
+
# Remove the layer execution tracker hooks from the submodules
|
| 426 |
+
base_module_registry = module._diffusers_hook
|
| 427 |
+
registries = [submodule._diffusers_hook for _, submodule in self.execution_order]
|
| 428 |
+
group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries]
|
| 429 |
+
|
| 430 |
+
for i in range(num_executed):
|
| 431 |
+
registries[i].remove_hook(_LAYER_EXECUTION_TRACKER, recurse=False)
|
| 432 |
+
|
| 433 |
+
# Remove the current lazy prefetch group offloading hook so that it doesn't interfere with the next forward pass
|
| 434 |
+
base_module_registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=False)
|
| 435 |
+
|
| 436 |
+
# LazyPrefetchGroupOffloadingHook is only used with streams, so we know that non_blocking should be True.
|
| 437 |
+
# We disable non_blocking for the first forward pass, but need to enable it for the subsequent passes to
|
| 438 |
+
# see the benefits of prefetching.
|
| 439 |
+
for hook in group_offloading_hooks:
|
| 440 |
+
hook.group.non_blocking = True
|
| 441 |
+
|
| 442 |
+
# Set required attributes for prefetching
|
| 443 |
+
if num_executed > 0:
|
| 444 |
+
base_module_group_offloading_hook = base_module_registry.get_hook(_GROUP_OFFLOADING)
|
| 445 |
+
base_module_group_offloading_hook.next_group = group_offloading_hooks[0].group
|
| 446 |
+
base_module_group_offloading_hook.next_group.onload_self = False
|
| 447 |
+
|
| 448 |
+
for i in range(num_executed - 1):
|
| 449 |
+
name1, _ = self.execution_order[i]
|
| 450 |
+
name2, _ = self.execution_order[i + 1]
|
| 451 |
+
if not torch.compiler.is_compiling():
|
| 452 |
+
logger.debug(f"Applying lazy prefetch group offloading from {name1} to {name2}")
|
| 453 |
+
group_offloading_hooks[i].next_group = group_offloading_hooks[i + 1].group
|
| 454 |
+
group_offloading_hooks[i].next_group.onload_self = False
|
| 455 |
+
|
| 456 |
+
return output
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
class LayerExecutionTrackerHook(ModelHook):
|
| 460 |
+
r"""
|
| 461 |
+
A hook that tracks the order in which the layers are executed during the forward pass by calling back to the
|
| 462 |
+
LazyPrefetchGroupOffloadingHook to update the execution order.
|
| 463 |
+
"""
|
| 464 |
+
|
| 465 |
+
_is_stateful = False
|
| 466 |
+
|
| 467 |
+
def __init__(self, execution_order_update_callback):
|
| 468 |
+
self.execution_order_update_callback = execution_order_update_callback
|
| 469 |
+
|
| 470 |
+
def pre_forward(self, module, *args, **kwargs):
|
| 471 |
+
self.execution_order_update_callback()
|
| 472 |
+
return args, kwargs
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
def apply_group_offloading(
|
| 476 |
+
module: torch.nn.Module,
|
| 477 |
+
onload_device: str | torch.device,
|
| 478 |
+
offload_device: str | torch.device = torch.device("cpu"),
|
| 479 |
+
offload_type: str | GroupOffloadingType = "block_level",
|
| 480 |
+
num_blocks_per_group: int | None = None,
|
| 481 |
+
non_blocking: bool = False,
|
| 482 |
+
use_stream: bool = False,
|
| 483 |
+
record_stream: bool = False,
|
| 484 |
+
low_cpu_mem_usage: bool = False,
|
| 485 |
+
offload_to_disk_path: str | None = None,
|
| 486 |
+
block_modules: list[str] | None = None,
|
| 487 |
+
exclude_kwargs: list[str] | None = None,
|
| 488 |
+
) -> None:
|
| 489 |
+
r"""
|
| 490 |
+
Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
|
| 491 |
+
where it is beneficial, we need to first provide some context on how other supported offloading methods work.
|
| 492 |
+
|
| 493 |
+
Typically, offloading is done at two levels:
|
| 494 |
+
- Module-level: In Diffusers, this can be enabled using the `ModelMixin::enable_model_cpu_offload()` method. It
|
| 495 |
+
works by offloading each component of a pipeline to the CPU for storage, and onloading to the accelerator device
|
| 496 |
+
when needed for computation. This method is more memory-efficient than keeping all components on the accelerator,
|
| 497 |
+
but the memory requirements are still quite high. For this method to work, one needs memory equivalent to size of
|
| 498 |
+
the model in runtime dtype + size of largest intermediate activation tensors to be able to complete the forward
|
| 499 |
+
pass.
|
| 500 |
+
- Leaf-level: In Diffusers, this can be enabled using the `ModelMixin::enable_sequential_cpu_offload()` method. It
|
| 501 |
+
works by offloading the lowest leaf-level parameters of the computation graph to the CPU for storage, and
|
| 502 |
+
onloading only the leafs to the accelerator device for computation. This uses the lowest amount of accelerator
|
| 503 |
+
memory, but can be slower due to the excessive number of device synchronizations.
|
| 504 |
+
|
| 505 |
+
Group offloading is a middle ground between the two methods. It works by offloading groups of internal layers,
|
| 506 |
+
(either `torch.nn.ModuleList` or `torch.nn.Sequential`). This method uses lower memory than module-level
|
| 507 |
+
offloading. It is also faster than leaf-level/sequential offloading, as the number of device synchronizations is
|
| 508 |
+
reduced.
|
| 509 |
+
|
| 510 |
+
Another supported feature (for CUDA devices with support for asynchronous data transfer streams) is the ability to
|
| 511 |
+
overlap data transfer and computation to reduce the overall execution time compared to sequential offloading. This
|
| 512 |
+
is enabled using layer prefetching with streams, i.e., the layer that is to be executed next starts onloading to
|
| 513 |
+
the accelerator device while the current layer is being executed - this increases the memory requirements slightly.
|
| 514 |
+
Note that this implementation also supports leaf-level offloading but can be made much faster when using streams.
|
| 515 |
+
|
| 516 |
+
Args:
|
| 517 |
+
module (`torch.nn.Module`):
|
| 518 |
+
The module to which group offloading is applied.
|
| 519 |
+
onload_device (`torch.device`):
|
| 520 |
+
The device to which the group of modules are onloaded.
|
| 521 |
+
offload_device (`torch.device`, defaults to `torch.device("cpu")`):
|
| 522 |
+
The device to which the group of modules are offloaded. This should typically be the CPU. Default is CPU.
|
| 523 |
+
offload_type (`str` or `GroupOffloadingType`, defaults to "block_level"):
|
| 524 |
+
The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
|
| 525 |
+
"block_level".
|
| 526 |
+
offload_to_disk_path (`str`, *optional*, defaults to `None`):
|
| 527 |
+
The path to the directory where parameters will be offloaded. Setting this option can be useful in limited
|
| 528 |
+
RAM environment settings where a reasonable speed-memory trade-off is desired.
|
| 529 |
+
num_blocks_per_group (`int`, *optional*):
|
| 530 |
+
The number of blocks per group when using offload_type="block_level". This is required when using
|
| 531 |
+
offload_type="block_level".
|
| 532 |
+
non_blocking (`bool`, defaults to `False`):
|
| 533 |
+
If True, offloading and onloading is done with non-blocking data transfer.
|
| 534 |
+
use_stream (`bool`, defaults to `False`):
|
| 535 |
+
If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
|
| 536 |
+
overlapping computation and data transfer.
|
| 537 |
+
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
|
| 538 |
+
as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
|
| 539 |
+
[PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more
|
| 540 |
+
details.
|
| 541 |
+
low_cpu_mem_usage (`bool`, defaults to `False`):
|
| 542 |
+
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
|
| 543 |
+
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
|
| 544 |
+
the CPU memory is a bottleneck but may counteract the benefits of using streams.
|
| 545 |
+
block_modules (`list[str]`, *optional*):
|
| 546 |
+
List of module names that should be treated as blocks for offloading. If provided, only these modules will
|
| 547 |
+
be considered for block-level offloading. If not provided, the default block detection logic will be used.
|
| 548 |
+
exclude_kwargs (`list[str]`, *optional*):
|
| 549 |
+
List of kwarg keys that should not be processed by send_to_device. This is useful for mutable state like
|
| 550 |
+
caching lists that need to maintain their object identity across forward passes. If not provided, will be
|
| 551 |
+
inferred from the module's `_skip_keys` attribute if it exists.
|
| 552 |
+
|
| 553 |
+
Example:
|
| 554 |
+
```python
|
| 555 |
+
>>> from diffusers import CogVideoXTransformer3DModel
|
| 556 |
+
>>> from diffusers.hooks import apply_group_offloading
|
| 557 |
+
|
| 558 |
+
>>> transformer = CogVideoXTransformer3DModel.from_pretrained(
|
| 559 |
+
... "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
|
| 560 |
+
... )
|
| 561 |
+
|
| 562 |
+
>>> apply_group_offloading(
|
| 563 |
+
... transformer,
|
| 564 |
+
... onload_device=torch.device("cuda"),
|
| 565 |
+
... offload_device=torch.device("cpu"),
|
| 566 |
+
... offload_type="block_level",
|
| 567 |
+
... num_blocks_per_group=2,
|
| 568 |
+
... use_stream=True,
|
| 569 |
+
... )
|
| 570 |
+
```
|
| 571 |
+
"""
|
| 572 |
+
|
| 573 |
+
onload_device = torch.device(onload_device) if isinstance(onload_device, str) else onload_device
|
| 574 |
+
offload_device = torch.device(offload_device) if isinstance(offload_device, str) else offload_device
|
| 575 |
+
offload_type = GroupOffloadingType(offload_type)
|
| 576 |
+
|
| 577 |
+
stream = None
|
| 578 |
+
if use_stream:
|
| 579 |
+
if torch.cuda.is_available():
|
| 580 |
+
stream = torch.cuda.Stream()
|
| 581 |
+
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
| 582 |
+
stream = torch.Stream()
|
| 583 |
+
else:
|
| 584 |
+
raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.")
|
| 585 |
+
|
| 586 |
+
if not use_stream and record_stream:
|
| 587 |
+
raise ValueError("`record_stream` cannot be True when `use_stream=False`.")
|
| 588 |
+
if offload_type == GroupOffloadingType.BLOCK_LEVEL and num_blocks_per_group is None:
|
| 589 |
+
raise ValueError("`num_blocks_per_group` must be provided when using `offload_type='block_level'.")
|
| 590 |
+
|
| 591 |
+
_raise_error_if_accelerate_model_or_sequential_hook_present(module)
|
| 592 |
+
|
| 593 |
+
if block_modules is None:
|
| 594 |
+
block_modules = getattr(module, "_group_offload_block_modules", None)
|
| 595 |
+
|
| 596 |
+
if exclude_kwargs is None:
|
| 597 |
+
exclude_kwargs = getattr(module, "_skip_keys", None)
|
| 598 |
+
|
| 599 |
+
config = GroupOffloadingConfig(
|
| 600 |
+
onload_device=onload_device,
|
| 601 |
+
offload_device=offload_device,
|
| 602 |
+
offload_type=offload_type,
|
| 603 |
+
num_blocks_per_group=num_blocks_per_group,
|
| 604 |
+
non_blocking=non_blocking,
|
| 605 |
+
stream=stream,
|
| 606 |
+
record_stream=record_stream,
|
| 607 |
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
| 608 |
+
offload_to_disk_path=offload_to_disk_path,
|
| 609 |
+
block_modules=block_modules,
|
| 610 |
+
exclude_kwargs=exclude_kwargs,
|
| 611 |
+
)
|
| 612 |
+
_apply_group_offloading(module, config)
|
| 613 |
+
|
| 614 |
+
|
| 615 |
+
def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
|
| 616 |
+
if config.offload_type == GroupOffloadingType.BLOCK_LEVEL:
|
| 617 |
+
_apply_group_offloading_block_level(module, config)
|
| 618 |
+
elif config.offload_type == GroupOffloadingType.LEAF_LEVEL:
|
| 619 |
+
_apply_group_offloading_leaf_level(module, config)
|
| 620 |
+
else:
|
| 621 |
+
assert False
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
|
| 625 |
+
r"""
|
| 626 |
+
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks, and explicitly
|
| 627 |
+
defined block modules. In comparison to the "leaf_level" offloading, which is more fine-grained, this offloading is
|
| 628 |
+
done at the top-level blocks and modules specified in block_modules.
|
| 629 |
+
|
| 630 |
+
When block_modules is provided, only those modules will be treated as blocks for offloading. For each specified
|
| 631 |
+
module, recursively apply block offloading to it.
|
| 632 |
+
"""
|
| 633 |
+
if config.stream is not None and config.num_blocks_per_group != 1:
|
| 634 |
+
logger.warning(
|
| 635 |
+
f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1."
|
| 636 |
+
)
|
| 637 |
+
config.num_blocks_per_group = 1
|
| 638 |
+
|
| 639 |
+
block_modules = set(config.block_modules) if config.block_modules is not None else set()
|
| 640 |
+
|
| 641 |
+
# Create module groups for ModuleList and Sequential blocks, and explicitly defined block modules
|
| 642 |
+
modules_with_group_offloading = set()
|
| 643 |
+
unmatched_modules = []
|
| 644 |
+
matched_module_groups = []
|
| 645 |
+
|
| 646 |
+
for name, submodule in module.named_children():
|
| 647 |
+
# Check if this is an explicitly defined block module
|
| 648 |
+
if name in block_modules:
|
| 649 |
+
# Track submodule using a prefix to avoid filename collisions during disk offload.
|
| 650 |
+
# Without this, submodules sharing the same model class would be assigned identical
|
| 651 |
+
# filenames (derived from the class name).
|
| 652 |
+
prefix = f"{config.module_prefix}{name}." if config.module_prefix else f"{name}."
|
| 653 |
+
submodule_config = replace(config, module_prefix=prefix)
|
| 654 |
+
|
| 655 |
+
_apply_group_offloading_block_level(submodule, submodule_config)
|
| 656 |
+
modules_with_group_offloading.add(name)
|
| 657 |
+
|
| 658 |
+
elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
|
| 659 |
+
# Handle ModuleList and Sequential blocks as before
|
| 660 |
+
for i in range(0, len(submodule), config.num_blocks_per_group):
|
| 661 |
+
current_modules = list(submodule[i : i + config.num_blocks_per_group])
|
| 662 |
+
if len(current_modules) == 0:
|
| 663 |
+
continue
|
| 664 |
+
|
| 665 |
+
group_id = f"{config.module_prefix}{name}_{i}_{i + len(current_modules) - 1}"
|
| 666 |
+
group = ModuleGroup(
|
| 667 |
+
modules=current_modules,
|
| 668 |
+
offload_device=config.offload_device,
|
| 669 |
+
onload_device=config.onload_device,
|
| 670 |
+
offload_to_disk_path=config.offload_to_disk_path,
|
| 671 |
+
offload_leader=current_modules[-1],
|
| 672 |
+
onload_leader=current_modules[0],
|
| 673 |
+
non_blocking=config.non_blocking,
|
| 674 |
+
stream=config.stream,
|
| 675 |
+
record_stream=config.record_stream,
|
| 676 |
+
low_cpu_mem_usage=config.low_cpu_mem_usage,
|
| 677 |
+
onload_self=True,
|
| 678 |
+
group_id=group_id,
|
| 679 |
+
)
|
| 680 |
+
matched_module_groups.append(group)
|
| 681 |
+
for j in range(i, i + len(current_modules)):
|
| 682 |
+
modules_with_group_offloading.add(f"{name}.{j}")
|
| 683 |
+
else:
|
| 684 |
+
# This is an unmatched module
|
| 685 |
+
unmatched_modules.append((name, submodule))
|
| 686 |
+
|
| 687 |
+
# Apply group offloading hooks to the module groups
|
| 688 |
+
for i, group in enumerate(matched_module_groups):
|
| 689 |
+
for group_module in group.modules:
|
| 690 |
+
_apply_group_offloading_hook(group_module, group, config=config)
|
| 691 |
+
|
| 692 |
+
# Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
|
| 693 |
+
# when the forward pass of this module is called. This is because the top-level module is not
|
| 694 |
+
# part of any group (as doing so would lead to no VRAM savings).
|
| 695 |
+
parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
|
| 696 |
+
buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
|
| 697 |
+
parameters = [param for _, param in parameters]
|
| 698 |
+
buffers = [buffer for _, buffer in buffers]
|
| 699 |
+
|
| 700 |
+
# Create a group for the remaining unmatched submodules of the top-level
|
| 701 |
+
# module so that they are on the correct device when the forward pass is called.
|
| 702 |
+
unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
|
| 703 |
+
if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0:
|
| 704 |
+
unmatched_group = ModuleGroup(
|
| 705 |
+
modules=unmatched_modules,
|
| 706 |
+
offload_device=config.offload_device,
|
| 707 |
+
onload_device=config.onload_device,
|
| 708 |
+
offload_to_disk_path=config.offload_to_disk_path,
|
| 709 |
+
offload_leader=module,
|
| 710 |
+
onload_leader=module,
|
| 711 |
+
parameters=parameters,
|
| 712 |
+
buffers=buffers,
|
| 713 |
+
non_blocking=False,
|
| 714 |
+
stream=None,
|
| 715 |
+
record_stream=False,
|
| 716 |
+
onload_self=True,
|
| 717 |
+
group_id=f"{config.module_prefix}{module.__class__.__name__}_unmatched_group",
|
| 718 |
+
)
|
| 719 |
+
if config.stream is None:
|
| 720 |
+
_apply_group_offloading_hook(module, unmatched_group, config=config)
|
| 721 |
+
else:
|
| 722 |
+
_apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
|
| 723 |
+
|
| 724 |
+
|
| 725 |
+
def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
|
| 726 |
+
r"""
|
| 727 |
+
This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
|
| 728 |
+
requirements. However, it can be slower compared to other offloading methods due to the excessive number of device
|
| 729 |
+
synchronizations. When using devices that support streams to overlap data transfer and computation, this method can
|
| 730 |
+
reduce memory usage without any performance degradation.
|
| 731 |
+
"""
|
| 732 |
+
# Create module groups for leaf modules and apply group offloading hooks
|
| 733 |
+
modules_with_group_offloading = set()
|
| 734 |
+
for name, submodule in module.named_modules():
|
| 735 |
+
if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
|
| 736 |
+
continue
|
| 737 |
+
group = ModuleGroup(
|
| 738 |
+
modules=[submodule],
|
| 739 |
+
offload_device=config.offload_device,
|
| 740 |
+
onload_device=config.onload_device,
|
| 741 |
+
offload_to_disk_path=config.offload_to_disk_path,
|
| 742 |
+
offload_leader=submodule,
|
| 743 |
+
onload_leader=submodule,
|
| 744 |
+
non_blocking=config.non_blocking,
|
| 745 |
+
stream=config.stream,
|
| 746 |
+
record_stream=config.record_stream,
|
| 747 |
+
low_cpu_mem_usage=config.low_cpu_mem_usage,
|
| 748 |
+
onload_self=True,
|
| 749 |
+
group_id=name,
|
| 750 |
+
)
|
| 751 |
+
_apply_group_offloading_hook(submodule, group, config=config)
|
| 752 |
+
modules_with_group_offloading.add(name)
|
| 753 |
+
|
| 754 |
+
# Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
|
| 755 |
+
# of the module is called
|
| 756 |
+
module_dict = dict(module.named_modules())
|
| 757 |
+
parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
|
| 758 |
+
buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
|
| 759 |
+
|
| 760 |
+
# Find closest module parent for each parameter and buffer, and attach group hooks
|
| 761 |
+
parent_to_parameters = {}
|
| 762 |
+
for name, param in parameters:
|
| 763 |
+
parent_name = _find_parent_module_in_module_dict(name, module_dict)
|
| 764 |
+
if parent_name in parent_to_parameters:
|
| 765 |
+
parent_to_parameters[parent_name].append(param)
|
| 766 |
+
else:
|
| 767 |
+
parent_to_parameters[parent_name] = [param]
|
| 768 |
+
|
| 769 |
+
parent_to_buffers = {}
|
| 770 |
+
for name, buffer in buffers:
|
| 771 |
+
parent_name = _find_parent_module_in_module_dict(name, module_dict)
|
| 772 |
+
if parent_name in parent_to_buffers:
|
| 773 |
+
parent_to_buffers[parent_name].append(buffer)
|
| 774 |
+
else:
|
| 775 |
+
parent_to_buffers[parent_name] = [buffer]
|
| 776 |
+
|
| 777 |
+
parent_names = set(parent_to_parameters.keys()) | set(parent_to_buffers.keys())
|
| 778 |
+
for name in parent_names:
|
| 779 |
+
parameters = parent_to_parameters.get(name, [])
|
| 780 |
+
buffers = parent_to_buffers.get(name, [])
|
| 781 |
+
parent_module = module_dict[name]
|
| 782 |
+
group = ModuleGroup(
|
| 783 |
+
modules=[],
|
| 784 |
+
offload_device=config.offload_device,
|
| 785 |
+
onload_device=config.onload_device,
|
| 786 |
+
offload_leader=parent_module,
|
| 787 |
+
onload_leader=parent_module,
|
| 788 |
+
offload_to_disk_path=config.offload_to_disk_path,
|
| 789 |
+
parameters=parameters,
|
| 790 |
+
buffers=buffers,
|
| 791 |
+
non_blocking=config.non_blocking,
|
| 792 |
+
stream=config.stream,
|
| 793 |
+
record_stream=config.record_stream,
|
| 794 |
+
low_cpu_mem_usage=config.low_cpu_mem_usage,
|
| 795 |
+
onload_self=True,
|
| 796 |
+
group_id=name,
|
| 797 |
+
)
|
| 798 |
+
_apply_group_offloading_hook(parent_module, group, config=config)
|
| 799 |
+
|
| 800 |
+
if config.stream is not None:
|
| 801 |
+
# When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
|
| 802 |
+
# and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that will find the
|
| 803 |
+
# execution order and apply prefetching in the correct order.
|
| 804 |
+
unmatched_group = ModuleGroup(
|
| 805 |
+
modules=[],
|
| 806 |
+
offload_device=config.offload_device,
|
| 807 |
+
onload_device=config.onload_device,
|
| 808 |
+
offload_to_disk_path=config.offload_to_disk_path,
|
| 809 |
+
offload_leader=module,
|
| 810 |
+
onload_leader=module,
|
| 811 |
+
parameters=None,
|
| 812 |
+
buffers=None,
|
| 813 |
+
non_blocking=False,
|
| 814 |
+
stream=None,
|
| 815 |
+
record_stream=False,
|
| 816 |
+
low_cpu_mem_usage=config.low_cpu_mem_usage,
|
| 817 |
+
onload_self=True,
|
| 818 |
+
group_id=_GROUP_ID_LAZY_LEAF,
|
| 819 |
+
)
|
| 820 |
+
_apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
|
| 821 |
+
|
| 822 |
+
|
| 823 |
+
def _apply_group_offloading_hook(
|
| 824 |
+
module: torch.nn.Module,
|
| 825 |
+
group: ModuleGroup,
|
| 826 |
+
*,
|
| 827 |
+
config: GroupOffloadingConfig,
|
| 828 |
+
) -> None:
|
| 829 |
+
registry = HookRegistry.check_if_exists_or_initialize(module)
|
| 830 |
+
|
| 831 |
+
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
|
| 832 |
+
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
|
| 833 |
+
if registry.get_hook(_GROUP_OFFLOADING) is None:
|
| 834 |
+
hook = GroupOffloadingHook(group, config=config)
|
| 835 |
+
registry.register_hook(hook, _GROUP_OFFLOADING)
|
| 836 |
+
|
| 837 |
+
|
| 838 |
+
def _apply_lazy_group_offloading_hook(
|
| 839 |
+
module: torch.nn.Module,
|
| 840 |
+
group: ModuleGroup,
|
| 841 |
+
*,
|
| 842 |
+
config: GroupOffloadingConfig,
|
| 843 |
+
) -> None:
|
| 844 |
+
registry = HookRegistry.check_if_exists_or_initialize(module)
|
| 845 |
+
|
| 846 |
+
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
|
| 847 |
+
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
|
| 848 |
+
if registry.get_hook(_GROUP_OFFLOADING) is None:
|
| 849 |
+
hook = GroupOffloadingHook(group, config=config)
|
| 850 |
+
registry.register_hook(hook, _GROUP_OFFLOADING)
|
| 851 |
+
|
| 852 |
+
lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
|
| 853 |
+
registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)
|
| 854 |
+
|
| 855 |
+
|
| 856 |
+
def _gather_parameters_with_no_group_offloading_parent(
|
| 857 |
+
module: torch.nn.Module, modules_with_group_offloading: Set[str]
|
| 858 |
+
) -> list[torch.nn.Parameter]:
|
| 859 |
+
parameters = []
|
| 860 |
+
for name, parameter in module.named_parameters():
|
| 861 |
+
has_parent_with_group_offloading = False
|
| 862 |
+
atoms = name.split(".")
|
| 863 |
+
while len(atoms) > 0:
|
| 864 |
+
parent_name = ".".join(atoms)
|
| 865 |
+
if parent_name in modules_with_group_offloading:
|
| 866 |
+
has_parent_with_group_offloading = True
|
| 867 |
+
break
|
| 868 |
+
atoms.pop()
|
| 869 |
+
if not has_parent_with_group_offloading:
|
| 870 |
+
parameters.append((name, parameter))
|
| 871 |
+
return parameters
|
| 872 |
+
|
| 873 |
+
|
| 874 |
+
def _gather_buffers_with_no_group_offloading_parent(
|
| 875 |
+
module: torch.nn.Module, modules_with_group_offloading: Set[str]
|
| 876 |
+
) -> list[torch.Tensor]:
|
| 877 |
+
buffers = []
|
| 878 |
+
for name, buffer in module.named_buffers():
|
| 879 |
+
has_parent_with_group_offloading = False
|
| 880 |
+
atoms = name.split(".")
|
| 881 |
+
while len(atoms) > 0:
|
| 882 |
+
parent_name = ".".join(atoms)
|
| 883 |
+
if parent_name in modules_with_group_offloading:
|
| 884 |
+
has_parent_with_group_offloading = True
|
| 885 |
+
break
|
| 886 |
+
atoms.pop()
|
| 887 |
+
if not has_parent_with_group_offloading:
|
| 888 |
+
buffers.append((name, buffer))
|
| 889 |
+
return buffers
|
| 890 |
+
|
| 891 |
+
|
| 892 |
+
def _find_parent_module_in_module_dict(name: str, module_dict: dict[str, torch.nn.Module]) -> str:
|
| 893 |
+
atoms = name.split(".")
|
| 894 |
+
while len(atoms) > 0:
|
| 895 |
+
parent_name = ".".join(atoms)
|
| 896 |
+
if parent_name in module_dict:
|
| 897 |
+
return parent_name
|
| 898 |
+
atoms.pop()
|
| 899 |
+
return ""
|
| 900 |
+
|
| 901 |
+
|
| 902 |
+
def _raise_error_if_accelerate_model_or_sequential_hook_present(module: torch.nn.Module) -> None:
|
| 903 |
+
if not is_accelerate_available():
|
| 904 |
+
return
|
| 905 |
+
for name, submodule in module.named_modules():
|
| 906 |
+
if not hasattr(submodule, "_hf_hook"):
|
| 907 |
+
continue
|
| 908 |
+
if isinstance(submodule._hf_hook, (AlignDevicesHook, CpuOffload)):
|
| 909 |
+
raise ValueError(
|
| 910 |
+
f"Cannot apply group offloading to a module that is already applying an alternative "
|
| 911 |
+
f"offloading strategy from Accelerate. If you want to apply group offloading, please "
|
| 912 |
+
f"disable the existing offloading strategy first. Offending module: {name} ({type(submodule)})"
|
| 913 |
+
)
|
| 914 |
+
|
| 915 |
+
|
| 916 |
+
def _get_top_level_group_offload_hook(module: torch.nn.Module) -> GroupOffloadingHook | None:
|
| 917 |
+
for submodule in module.modules():
|
| 918 |
+
if hasattr(submodule, "_diffusers_hook"):
|
| 919 |
+
group_offloading_hook = submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING)
|
| 920 |
+
if group_offloading_hook is not None:
|
| 921 |
+
return group_offloading_hook
|
| 922 |
+
return None
|
| 923 |
+
|
| 924 |
+
|
| 925 |
+
def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
|
| 926 |
+
top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
|
| 927 |
+
return top_level_group_offload_hook is not None
|
| 928 |
+
|
| 929 |
+
|
| 930 |
+
def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
|
| 931 |
+
top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
|
| 932 |
+
if top_level_group_offload_hook is not None:
|
| 933 |
+
return top_level_group_offload_hook.config.onload_device
|
| 934 |
+
raise ValueError("Group offloading is not enabled for the provided module.")
|
| 935 |
+
|
| 936 |
+
|
| 937 |
+
def _compute_group_hash(group_id):
|
| 938 |
+
hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest()
|
| 939 |
+
# first 16 characters for a reasonably short but unique name
|
| 940 |
+
return hashed_id[:16]
|
| 941 |
+
|
| 942 |
+
|
| 943 |
+
def _maybe_remove_and_reapply_group_offloading(module: torch.nn.Module) -> None:
|
| 944 |
+
r"""
|
| 945 |
+
Removes the group offloading hook from the module and re-applies it. This is useful when the module has been
|
| 946 |
+
modified in-place and the group offloading hook references-to-tensors needs to be updated. The in-place
|
| 947 |
+
modification can happen in a number of ways, for example, fusing QKV or unloading/loading LoRAs on-the-fly.
|
| 948 |
+
|
| 949 |
+
In this implementation, we make an assumption that group offloading has only been applied at the top-level module,
|
| 950 |
+
and therefore all submodules have the same onload and offload devices. If this assumption is not true, say in the
|
| 951 |
+
case where user has applied group offloading at multiple levels, this function will not work as expected.
|
| 952 |
+
|
| 953 |
+
There is some performance penalty associated with doing this when non-default streams are used, because we need to
|
| 954 |
+
retrace the execution order of the layers with `LazyPrefetchGroupOffloadingHook`.
|
| 955 |
+
"""
|
| 956 |
+
top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
|
| 957 |
+
|
| 958 |
+
if top_level_group_offload_hook is None:
|
| 959 |
+
return
|
| 960 |
+
|
| 961 |
+
registry = HookRegistry.check_if_exists_or_initialize(module)
|
| 962 |
+
registry.remove_hook(_GROUP_OFFLOADING, recurse=True)
|
| 963 |
+
registry.remove_hook(_LAYER_EXECUTION_TRACKER, recurse=True)
|
| 964 |
+
registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=True)
|
| 965 |
+
|
| 966 |
+
_apply_group_offloading(module, top_level_group_offload_hook.config)
|